[MLIR][ArmSVE] Add an ArmSVE dialect operation which maps to svdupq_lane (#135633)
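For orientation, a minimal sketch of the new op and the one-to-one lowering it receives in this commit, assembled from the op definition and the tests in the diff below (the value names `%x`/`%y` are illustrative):

```mlir
// Broadcast the 128-bit segment selected by the lane index to every
// 128-bit segment of the result; source and result types match.
%y = arm_sve.dupq_lane %x[0] : vector<[4]xi32>

// After the ArmSVE legalization patterns added below run, the op becomes the
// intrinsic op, with the lane index carried as an immediate attribute:
//   %y = "arm_sve.intr.dupq_lane"(%x) <{lane = 0 : i64}>
//          : (vector<[4]xi32>) -> vector<[4]xi32>
```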
@@ -61,6 +61,13 @@ class Scalable1DVectorOfLength<int length, list<Type> elementTypes> : ShapedCont
   "a 1-D scalable vector with length " # length,
   "::mlir::VectorType">;
 
+def SVEVector : AnyTypeOf<[
+  Scalable1DVectorOfLength<2, [I64, F64]>,
+  Scalable1DVectorOfLength<4, [I32, F32]>,
+  Scalable1DVectorOfLength<8, [I16, F16, BF16]>,
+  Scalable1DVectorOfLength<16, [I8]>],
+  "an SVE vector with element size <= 64-bit">;
+
 //===----------------------------------------------------------------------===//
 // ArmSVE op definitions
 //===----------------------------------------------------------------------===//
@@ -72,14 +79,22 @@ class ArmSVE_IntrOp<string mnemonic,
                     list<Trait> traits = [],
                     list<int> overloadedOperands = [],
                     list<int> overloadedResults = [],
-                    int numResults = 1> :
+                    int numResults = 1,
+                    list<int> immArgPositions = [],
+                    list<string> immArgAttrNames = []> :
   LLVM_IntrOpBase</*Dialect dialect=*/ArmSVE_Dialect,
                   /*string opName=*/"intr." # mnemonic,
                   /*string enumName=*/"aarch64_sve_" # !subst(".", "_", mnemonic),
                   /*list<int> overloadedResults=*/overloadedResults,
                   /*list<int> overloadedOperands=*/overloadedOperands,
                   /*list<Trait> traits=*/traits,
-                  /*int numResults=*/numResults>;
+                  /*int numResults=*/numResults,
+                  /*bit requiresAccessGroup=*/0,
+                  /*bit requiresAliasAnalysis=*/0,
+                  /*bit requiresFastmath=*/0,
+                  /*bit requiresOpBundles=*/0,
+                  /*list<int> immArgPositions=*/immArgPositions,
+                  /*list<string> immArgAttrNames=*/immArgAttrNames>;
 
 class ArmSVE_IntrBinaryOverloadedOp<string mnemonic,
                                     list<Trait> traits = []>:
@@ -509,6 +524,45 @@ def ScalableMaskedUDivIOp : ScalableMaskedIOp<"masked.divi_unsigned",
 
 def ScalableMaskedDivFOp : ScalableMaskedFOp<"masked.divf", "division">;
 
+def DupQLaneOp : ArmSVE_Op<"dupq_lane", [Pure, AllTypesMatch<["src", "dst"]>]> {
+  let summary = "Broadcast indexed 128-bit segment to vector";
+
+  let description = [{
+    This operation fills each 128-bit segment of a vector with the elements
+    from the indexed 128-bit segment of the source vector. If the VL is
+    128 bits, the operation is a NOP. If the index exceeds the number of
+    128-bit segments in the vector, the result is an all-zeroes vector.
+
+    Example:
+    ```mlir
+    // VL == 256
+    // %X = [A B C D x x x x]
+    %Y = arm_sve.dupq_lane %X[0] : vector<[4]xi32>
+    // %Y = [A B C D A B C D]
+
+    // %U = [x x x x x x x x A B C D E F G H]
+    %V = arm_sve.dupq_lane %U[1] : vector<[8]xf16>
+    // %V = [A B C D E F G H A B C D E F G H]
+    ```
+
+    Note: The semantics of the operation match those of the `svdupq_lane` intrinsics.
+    [Source](https://developer.arm.com/architectures/instruction-sets/intrinsics/#q=svdupq_lane)
+  }];
+
+  let arguments = (ins SVEVector:$src,
+                       I64Attr:$lane);
+  let results = (outs SVEVector:$dst);
+
+  let builders = [
+    OpBuilder<(ins "Value":$src, "int64_t":$lane), [{
+      build($_builder, $_state, src.getType(), src, lane);
+    }]>];
+
+  let assemblyFormat = [{
+    $src `[` $lane `]` attr-dict `:` type($dst)
+  }];
+}
+
 def UmmlaIntrOp :
   ArmSVE_IntrBinaryOverloadedOp<"ummla">,
   Arguments<(ins AnyScalableVectorOfAnyRank, AnyScalableVectorOfAnyRank, AnyScalableVectorOfAnyRank)>;
@@ -610,4 +664,14 @@ def WhileLTIntrOp :
                   /*overloadedResults=*/[0]>,
   Arguments<(ins I64:$base, I64:$n)>;
 
+def DupQLaneIntrOp : ArmSVE_IntrOp<"dupq_lane",
+                       /*traits=*/[],
+                       /*overloadedOperands=*/[0],
+                       /*overloadedResults=*/[],
+                       /*numResults=*/1,
+                       /*immArgPositions=*/[1],
+                       /*immArgAttrNames=*/["lane"]>,
+  Arguments<(ins Arg<ScalableVectorOfRank<[1]>, "v">:$v,
+                 Arg<I64Attr, "lane">:$lane)>;
+
 #endif // ARMSVE_OPS
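The `immArgPositions`/`immArgAttrNames` parameters threaded through `ArmSVE_IntrOp` above are what let `DupQLaneIntrOp` model the lane as an attribute rather than an SSA operand. A sketch of the effect, based on the LLVM IR translation test at the end of this diff (the value name `%v` is illustrative):

```mlir
// The `lane` attribute on the intrinsic op is exported as the trailing i64
// immediate of the generated LLVM intrinsic call:
%0 = "arm_sve.intr.dupq_lane"(%v) <{lane = 0 : i64}>
       : (vector<[16]xi8>) -> vector<[16]xi8>
// ==> call <vscale x 16 x i8> @llvm.aarch64.sve.dupq.lane.nxv16i8(<vscale x 16 x i8> %v, i64 0)
```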
@@ -24,6 +24,8 @@ using SdotOpLowering = OneToOneConvertToLLVMPattern<SdotOp, SdotIntrOp>;
 using SmmlaOpLowering = OneToOneConvertToLLVMPattern<SmmlaOp, SmmlaIntrOp>;
 using UdotOpLowering = OneToOneConvertToLLVMPattern<UdotOp, UdotIntrOp>;
 using UmmlaOpLowering = OneToOneConvertToLLVMPattern<UmmlaOp, UmmlaIntrOp>;
+using DupQLaneLowering =
+    OneToOneConvertToLLVMPattern<DupQLaneOp, DupQLaneIntrOp>;
 using ScalableMaskedAddIOpLowering =
     OneToOneConvertToLLVMPattern<ScalableMaskedAddIOp,
                                  ScalableMaskedAddIIntrOp>;
@@ -188,24 +190,25 @@ void mlir::populateArmSVELegalizeForLLVMExportPatterns(
   // Populate conversion patterns
 
   // clang-format off
-  patterns.add<SdotOpLowering,
+  patterns.add<ConvertFromSvboolOpLowering,
+               ConvertToSvboolOpLowering,
+               DupQLaneLowering,
+               PselOpLowering,
+               ScalableMaskedAddFOpLowering,
+               ScalableMaskedAddIOpLowering,
+               ScalableMaskedDivFOpLowering,
+               ScalableMaskedMulFOpLowering,
+               ScalableMaskedMulIOpLowering,
+               ScalableMaskedSDivIOpLowering,
+               ScalableMaskedSubFOpLowering,
+               ScalableMaskedSubIOpLowering,
+               ScalableMaskedUDivIOpLowering,
                SmmlaOpLowering,
                UdotOpLowering,
                UmmlaOpLowering,
-               ScalableMaskedAddIOpLowering,
-               ScalableMaskedAddFOpLowering,
-               ScalableMaskedSubIOpLowering,
-               ScalableMaskedSubFOpLowering,
-               ScalableMaskedMulIOpLowering,
-               ScalableMaskedMulFOpLowering,
-               ScalableMaskedSDivIOpLowering,
-               ScalableMaskedUDivIOpLowering,
-               ScalableMaskedDivFOpLowering,
-               ConvertToSvboolOpLowering,
-               ConvertFromSvboolOpLowering,
                ZipX2OpLowering,
                ZipX4OpLowering,
-               PselOpLowering>(converter);
+               SdotOpLowering>(converter);
   // Add vector.create_mask conversion with a high benefit as it produces much
   // nicer code than the generic lowering.
   patterns.add<CreateMaskOpLowering>(converter, /*benefit=*/4096);
@@ -215,41 +218,44 @@ void mlir::populateArmSVELegalizeForLLVMExportPatterns(
 void mlir::configureArmSVELegalizeForExportTarget(
     LLVMConversionTarget &target) {
   // clang-format off
-  target.addLegalOp<SdotIntrOp,
+  target.addLegalOp<ConvertFromSvboolIntrOp,
+                    ConvertToSvboolIntrOp,
+                    DupQLaneIntrOp,
+                    PselIntrOp,
+                    ScalableMaskedAddFIntrOp,
+                    ScalableMaskedAddIIntrOp,
+                    ScalableMaskedDivFIntrOp,
+                    ScalableMaskedMulFIntrOp,
+                    ScalableMaskedMulIIntrOp,
+                    ScalableMaskedSDivIIntrOp,
+                    ScalableMaskedSubFIntrOp,
+                    ScalableMaskedSubIIntrOp,
+                    ScalableMaskedUDivIIntrOp,
                     SmmlaIntrOp,
                     UdotIntrOp,
                     UmmlaIntrOp,
-                    ScalableMaskedAddIIntrOp,
-                    ScalableMaskedAddFIntrOp,
-                    ScalableMaskedSubIIntrOp,
-                    ScalableMaskedSubFIntrOp,
-                    ScalableMaskedMulIIntrOp,
-                    ScalableMaskedMulFIntrOp,
-                    ScalableMaskedSDivIIntrOp,
-                    ScalableMaskedUDivIIntrOp,
-                    ScalableMaskedDivFIntrOp,
-                    ConvertToSvboolIntrOp,
-                    ConvertFromSvboolIntrOp,
+                    WhileLTIntrOp,
                     ZipX2IntrOp,
                     ZipX4IntrOp,
-                    PselIntrOp,
-                    WhileLTIntrOp>();
-  target.addIllegalOp<SdotOp,
+                    SdotIntrOp>();
+  target.addIllegalOp<ConvertFromSvboolOp,
+                      ConvertToSvboolOp,
+                      DupQLaneOp,
+                      PselOp,
+                      ScalableMaskedAddFOp,
+                      ScalableMaskedAddIOp,
+                      ScalableMaskedDivFOp,
+                      ScalableMaskedMulFOp,
+                      ScalableMaskedMulIOp,
+                      ScalableMaskedSDivIOp,
+                      ScalableMaskedSubFOp,
+                      ScalableMaskedSubIOp,
+                      ScalableMaskedUDivIOp,
                       SmmlaOp,
                       UdotOp,
                       UmmlaOp,
-                      ScalableMaskedAddIOp,
-                      ScalableMaskedAddFOp,
-                      ScalableMaskedSubIOp,
-                      ScalableMaskedSubFOp,
-                      ScalableMaskedMulIOp,
-                      ScalableMaskedMulFOp,
-                      ScalableMaskedSDivIOp,
-                      ScalableMaskedUDivIOp,
-                      ScalableMaskedDivFOp,
-                      ConvertToSvboolOp,
-                      ConvertFromSvboolOp,
                       ZipX2Op,
-                      ZipX4Op>();
+                      ZipX4Op,
+                      SdotOp>();
   // clang-format on
 }
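Because `DupQLaneIntrOp` is overloaded on operand 0 (`/*overloadedOperands=*/[0]`), this single one-to-one pattern covers every SVE element type; the overload suffix of the final LLVM intrinsic follows from the vector type. A sketch based on the tests that follow (value names are illustrative):

```mlir
%a = arm_sve.dupq_lane %v16i8[0] : vector<[16]xi8>  // -> @llvm.aarch64.sve.dupq.lane.nxv16i8
%b = arm_sve.dupq_lane %v2f64[7] : vector<[2]xf64>  // -> @llvm.aarch64.sve.dupq.lane.nxv2f64
```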
@@ -271,3 +271,44 @@ func.func @arm_sve_psel_mixed_predicate_types(%p0: vector<[8]xi1>, %p1: vector<[
   %0 = arm_sve.psel %p0, %p1[%index] : vector<[8]xi1>, vector<[16]xi1>
   return %0 : vector<[8]xi1>
 }
+
+// -----
+
+// CHECK-LABEL: @arm_sve_dupq_lane(
+// CHECK-SAME: %[[A0:[a-z0-9]+]]: vector<[16]xi8>
+// CHECK-SAME: %[[A1:[a-z0-9]+]]: vector<[8]xi16>
+// CHECK-SAME: %[[A2:[a-z0-9]+]]: vector<[8]xf16>
+// CHECK-SAME: %[[A3:[a-z0-9]+]]: vector<[8]xbf16>
+// CHECK-SAME: %[[A4:[a-z0-9]+]]: vector<[4]xi32>
+// CHECK-SAME: %[[A5:[a-z0-9]+]]: vector<[4]xf32>
+// CHECK-SAME: %[[A6:[a-z0-9]+]]: vector<[2]xi64>
+// CHECK-SAME: %[[A7:[a-z0-9]+]]: vector<[2]xf64>
+// CHECK-SAME: -> !llvm.struct<(vector<[16]xi8>, vector<[8]xi16>, vector<[8]xf16>, vector<[8]xbf16>, vector<[4]xi32>, vector<[4]xf32>, vector<[2]xi64>, vector<[2]xf64>)> {
+func.func @arm_sve_dupq_lane(
+    %v16i8: vector<[16]xi8>, %v8i16: vector<[8]xi16>,
+    %v8f16: vector<[8]xf16>, %v8bf16: vector<[8]xbf16>,
+    %v4i32: vector<[4]xi32>, %v4f32: vector<[4]xf32>,
+    %v2i64: vector<[2]xi64>, %v2f64: vector<[2]xf64>)
+    -> (vector<[16]xi8>, vector<[8]xi16>, vector<[8]xf16>, vector<[8]xbf16>,
+        vector<[4]xi32>, vector<[4]xf32>, vector<[2]xi64>, vector<[2]xf64>) {
+  // CHECK: "arm_sve.intr.dupq_lane"(%[[A0]]) <{lane = 0 : i64}> : (vector<[16]xi8>) -> vector<[16]xi8>
+  %0 = arm_sve.dupq_lane %v16i8[0] : vector<[16]xi8>
+  // CHECK: "arm_sve.intr.dupq_lane"(%[[A1]]) <{lane = 1 : i64}> : (vector<[8]xi16>) -> vector<[8]xi16>
+  %1 = arm_sve.dupq_lane %v8i16[1] : vector<[8]xi16>
+  // CHECK: "arm_sve.intr.dupq_lane"(%[[A2]]) <{lane = 2 : i64}> : (vector<[8]xf16>) -> vector<[8]xf16>
+  %2 = arm_sve.dupq_lane %v8f16[2] : vector<[8]xf16>
+  // CHECK: "arm_sve.intr.dupq_lane"(%[[A3]]) <{lane = 3 : i64}> : (vector<[8]xbf16>) -> vector<[8]xbf16>
+  %3 = arm_sve.dupq_lane %v8bf16[3] : vector<[8]xbf16>
+  // CHECK: "arm_sve.intr.dupq_lane"(%[[A4]]) <{lane = 4 : i64}> : (vector<[4]xi32>) -> vector<[4]xi32>
+  %4 = arm_sve.dupq_lane %v4i32[4] : vector<[4]xi32>
+  // CHECK: "arm_sve.intr.dupq_lane"(%[[A5]]) <{lane = 5 : i64}> : (vector<[4]xf32>) -> vector<[4]xf32>
+  %5 = arm_sve.dupq_lane %v4f32[5] : vector<[4]xf32>
+  // CHECK: "arm_sve.intr.dupq_lane"(%[[A6]]) <{lane = 6 : i64}> : (vector<[2]xi64>) -> vector<[2]xi64>
+  %6 = arm_sve.dupq_lane %v2i64[6] : vector<[2]xi64>
+  // CHECK: "arm_sve.intr.dupq_lane"(%[[A7]]) <{lane = 7 : i64}> : (vector<[2]xf64>) -> vector<[2]xf64>
+  %7 = arm_sve.dupq_lane %v2f64[7] : vector<[2]xf64>
+
+  return %0, %1, %2, %3, %4, %5, %6, %7
+    : vector<[16]xi8>, vector<[8]xi16>, vector<[8]xf16>, vector<[8]xbf16>,
+      vector<[4]xi32>, vector<[4]xf32>, vector<[2]xi64>, vector<[2]xf64>
+}
@@ -390,3 +390,35 @@ llvm.func @arm_sve_psel(%pn: vector<[16]xi1>, %p1: vector<[2]xi1>, %p2: vector<[
   "arm_sve.intr.psel"(%pn, %p4, %index) : (vector<[16]xi1>, vector<[16]xi1>, i32) -> vector<[16]xi1>
   llvm.return
 }
+
+// CHECK-LABEL: @arm_sve_dupq_lane
+// CHECK-SAME: <vscale x 16 x i8> %[[V0:[0-9]+]]
+// CHECK-SAME: <vscale x 8 x i16> %[[V1:[0-9]+]]
+// CHECK-SAME: <vscale x 8 x half> %[[V2:[0-9]+]]
+// CHECK-SAME: <vscale x 8 x bfloat> %[[V3:[0-9]+]]
+// CHECK-SAME: <vscale x 4 x i32> %[[V4:[0-9]+]]
+// CHECK-SAME: <vscale x 4 x float> %[[V5:[0-9]+]]
+// CHECK-SAME: <vscale x 2 x i64> %[[V6:[0-9]+]]
+// CHECK-SAME: <vscale x 2 x double> %[[V7:[0-9]+]]
+llvm.func @arm_sve_dupq_lane(%nxv16i8: vector<[16]xi8>, %nxv8i16: vector<[8]xi16>,
+                             %nxv8f16: vector<[8]xf16>, %nxv8bf16: vector<[8]xbf16>,
+                             %nxv4i32: vector<[4]xi32>, %nxv4f32: vector<[4]xf32>,
+                             %nxv2i64: vector<[2]xi64>, %nxv2f64: vector<[2]xf64>) {
+  // CHECK: call <vscale x 16 x i8> @llvm.aarch64.sve.dupq.lane.nxv16i8(<vscale x 16 x i8> %[[V0]], i64 0)
+  %0 = "arm_sve.intr.dupq_lane"(%nxv16i8) <{lane = 0 : i64}> : (vector<[16]xi8>) -> vector<[16]xi8>
+  // CHECK: call <vscale x 8 x i16> @llvm.aarch64.sve.dupq.lane.nxv8i16(<vscale x 8 x i16> %[[V1]], i64 1)
+  %1 = "arm_sve.intr.dupq_lane"(%nxv8i16) <{lane = 1 : i64}> : (vector<[8]xi16>) -> vector<[8]xi16>
+  // CHECK: call <vscale x 8 x half> @llvm.aarch64.sve.dupq.lane.nxv8f16(<vscale x 8 x half> %[[V2]], i64 2)
+  %2 = "arm_sve.intr.dupq_lane"(%nxv8f16) <{lane = 2 : i64}> : (vector<[8]xf16>) -> vector<[8]xf16>
+  // CHECK: call <vscale x 8 x bfloat> @llvm.aarch64.sve.dupq.lane.nxv8bf16(<vscale x 8 x bfloat> %[[V3]], i64 3)
+  %3 = "arm_sve.intr.dupq_lane"(%nxv8bf16) <{lane = 3 : i64}> : (vector<[8]xbf16>) -> vector<[8]xbf16>
+  // CHECK: call <vscale x 4 x i32> @llvm.aarch64.sve.dupq.lane.nxv4i32(<vscale x 4 x i32> %[[V4]], i64 4)
+  %4 = "arm_sve.intr.dupq_lane"(%nxv4i32) <{lane = 4 : i64}> : (vector<[4]xi32>) -> vector<[4]xi32>
+  // CHECK: call <vscale x 4 x float> @llvm.aarch64.sve.dupq.lane.nxv4f32(<vscale x 4 x float> %[[V5]], i64 5)
+  %5 = "arm_sve.intr.dupq_lane"(%nxv4f32) <{lane = 5 : i64}> : (vector<[4]xf32>) -> vector<[4]xf32>
+  // CHECK: call <vscale x 2 x i64> @llvm.aarch64.sve.dupq.lane.nxv2i64(<vscale x 2 x i64> %[[V6]], i64 6)
+  %6 = "arm_sve.intr.dupq_lane"(%nxv2i64) <{lane = 6 : i64}> : (vector<[2]xi64>) -> vector<[2]xi64>
+  // CHECK: call <vscale x 2 x double> @llvm.aarch64.sve.dupq.lane.nxv2f64(<vscale x 2 x double> %[[V7]], i64 7)
+  %7 = "arm_sve.intr.dupq_lane"(%nxv2f64) <{lane = 7 : i64}> : (vector<[2]xf64>) -> vector<[2]xf64>
+  llvm.return
+}