From 63a4cae56cf896abd12ecf8dbd50f7f5fb9549e1 Mon Sep 17 00:00:00 2001
From: Momchil Velikov
Date: Fri, 16 May 2025 16:47:07 +0100
Subject: [PATCH] [MLIR][ArmSVE] Add an ArmSVE dialect operation which maps to
 `svdupq_lane` (#135633)

---
 mlir/include/mlir/Dialect/ArmSVE/IR/ArmSVE.td |  68 ++++++++++++++-
 .../Transforms/LegalizeForLLVMExport.cpp      |  86 ++++++++++---------
 .../Dialect/ArmSVE/legalize-for-llvm.mlir     |  41 +++++++++
 mlir/test/Target/LLVMIR/arm-sve.mlir          |  32 +++++++
 4 files changed, 185 insertions(+), 42 deletions(-)

diff --git a/mlir/include/mlir/Dialect/ArmSVE/IR/ArmSVE.td b/mlir/include/mlir/Dialect/ArmSVE/IR/ArmSVE.td
index cdcf4d8752e8..3a990f8464ef 100644
--- a/mlir/include/mlir/Dialect/ArmSVE/IR/ArmSVE.td
+++ b/mlir/include/mlir/Dialect/ArmSVE/IR/ArmSVE.td
@@ -61,6 +61,13 @@ class Scalable1DVectorOfLength<int length, list<Type> elementTypes> : ShapedCont
     "a 1-D scalable vector with length " # length,
     "::mlir::VectorType">;
 
+def SVEVector : AnyTypeOf<[
+  Scalable1DVectorOfLength<2, [I64, F64]>,
+  Scalable1DVectorOfLength<4, [I32, F32]>,
+  Scalable1DVectorOfLength<8, [I16, F16, BF16]>,
+  Scalable1DVectorOfLength<16, [I8]>],
+  "an SVE vector with element size <= 64-bit">;
+
 //===----------------------------------------------------------------------===//
 // ArmSVE op definitions
 //===----------------------------------------------------------------------===//
@@ -72,14 +79,22 @@ class ArmSVE_IntrOp<string mnemonic,
                     list<Trait> traits = [],
                     list<int> overloadedOperands = [],
                     list<int> overloadedResults = [],
-                    int numResults = 1> :
+                    int numResults = 1,
+                    list<int> immArgPositions = [],
+                    list<string> immArgAttrNames = []> :
   LLVM_IntrOpBase</*Dialect dialect=*/ArmSVE_Dialect,
                   /*string opName=*/"intr." # mnemonic,
                   /*string enumName=*/"aarch64_sve_" # !subst(".", "_", mnemonic),
                   /*list<int> overloadedResults=*/overloadedResults,
                   /*list<int> overloadedOperands=*/overloadedOperands,
                   /*list<Trait> traits=*/traits,
-                  /*int numResults=*/numResults>;
+                  /*int numResults=*/numResults,
+                  /*bit requiresAccessGroup=*/0,
+                  /*bit requiresAliasAnalysis=*/0,
+                  /*bit requiresFastmath=*/0,
+                  /*bit requiresOpBundles=*/0,
+                  /*list<int> immArgPositions=*/immArgPositions,
+                  /*list<string> immArgAttrNames=*/immArgAttrNames>;
 
 class ArmSVE_IntrBinaryOverloadedOp<string mnemonic,
                                     list<Trait> traits = []>:
@@ -509,6 +524,45 @@ def ScalableMaskedUDivIOp : ScalableMaskedIOp<"masked.divi_unsigned",
 
 def ScalableMaskedDivFOp : ScalableMaskedFOp<"masked.divf", "division">;
 
+def DupQLaneOp : ArmSVE_Op<"dupq_lane", [Pure, AllTypesMatch<["src", "dst"]>]> {
+  let summary = "Broadcast indexed 128-bit segment to vector";
+
+  let description = [{
+    This operation fills each 128-bit segment of a vector with the elements
+    from the indexed 128-bit segment of the source vector. If the VL is
+    128 bits, the operation is a NOP. If the index exceeds the number of
+    128-bit segments in the vector, the result is an all-zeroes vector.
+
+    Example:
+    ```mlir
+    // VL == 256
+    // %X = [A B C D x x x x]
+    %Y = arm_sve.dupq_lane %X[0] : vector<[4]xi32>
+    // %Y = [A B C D A B C D]
+
+    // %U = [x x x x x x x x A B C D E F G H]
+    %V = arm_sve.dupq_lane %U[1] : vector<[8]xf16>
+    // %V = [A B C D E F G H A B C D E F G H]
+    ```
+
+    Note: The semantics of the operation match those of the `svdupq_lane` intrinsics.
+    [Source](https://developer.arm.com/architectures/instruction-sets/intrinsics/#q=svdupq_lane)
+  }];
+
+  let arguments = (ins SVEVector:$src,
+                       I64Attr:$lane);
+  let results = (outs SVEVector:$dst);
+
+  let builders = [
+    OpBuilder<(ins "Value":$src, "int64_t":$lane), [{
+      build($_builder, $_state, src.getType(), src, lane);
+    }]>];
+
+  let assemblyFormat = [{
+    $src `[` $lane `]` attr-dict `:` type($dst)
+  }];
+}
+
 def UmmlaIntrOp :
   ArmSVE_IntrBinaryOverloadedOp<"ummla">,
   Arguments<(ins AnyScalableVectorOfAnyRank, AnyScalableVectorOfAnyRank,
                  AnyScalableVectorOfAnyRank)>;
@@ -610,4 +664,14 @@ def WhileLTIntrOp :
     /*overloadedResults=*/[0]>,
   Arguments<(ins I64:$base, I64:$n)>;
 
+def DupQLaneIntrOp : ArmSVE_IntrOp<"dupq_lane",
+                                   /*traits=*/[],
+                                   /*overloadedOperands=*/[0],
+                                   /*overloadedResults=*/[],
+                                   /*numResults=*/1,
+                                   /*immArgPositions=*/[1],
+                                   /*immArgAttrNames=*/["lane"]>,
+  Arguments<(ins Arg<ScalableVectorOfRank<[1]>, "v">:$v,
+                 Arg<I64Attr, "lane">:$lane)>;
+
 #endif // ARMSVE_OPS
diff --git a/mlir/lib/Dialect/ArmSVE/Transforms/LegalizeForLLVMExport.cpp b/mlir/lib/Dialect/ArmSVE/Transforms/LegalizeForLLVMExport.cpp
index 2bdb640699d0..536373b82c67 100644
--- a/mlir/lib/Dialect/ArmSVE/Transforms/LegalizeForLLVMExport.cpp
+++ b/mlir/lib/Dialect/ArmSVE/Transforms/LegalizeForLLVMExport.cpp
@@ -24,6 +24,8 @@ using SdotOpLowering = OneToOneConvertToLLVMPattern<SdotOp, SdotIntrOp>;
 using SmmlaOpLowering = OneToOneConvertToLLVMPattern<SmmlaOp, SmmlaIntrOp>;
 using UdotOpLowering = OneToOneConvertToLLVMPattern<UdotOp, UdotIntrOp>;
 using UmmlaOpLowering = OneToOneConvertToLLVMPattern<UmmlaOp, UmmlaIntrOp>;
+using DupQLaneLowering =
+    OneToOneConvertToLLVMPattern<DupQLaneOp, DupQLaneIntrOp>;
 using ScalableMaskedAddIOpLowering =
     OneToOneConvertToLLVMPattern<ScalableMaskedAddIOp,
                                  ScalableMaskedAddIIntrOp>;
@@ -188,24 +190,25 @@ void mlir::populateArmSVELegalizeForLLVMExportPatterns(
   // Populate conversion patterns
 
   // clang-format off
-  patterns.add<SdotOpLowering,
+  patterns.add<DupQLaneLowering,
                SmmlaOpLowering,
                UdotOpLowering,
                UmmlaOpLowering,
                ScalableMaskedAddIOpLowering,
                ScalableMaskedAddFOpLowering,
                ScalableMaskedSubIOpLowering,
                ScalableMaskedSubFOpLowering,
                ScalableMaskedMulIOpLowering,
                ScalableMaskedMulFOpLowering,
                ScalableMaskedSDivIOpLowering,
                ScalableMaskedUDivIOpLowering,
                ScalableMaskedDivFOpLowering,
                ConvertToSvboolOpLowering,
                ConvertFromSvboolOpLowering,
                PselOpLowering,
                ZipX2OpLowering,
-               ZipX4OpLowering>(converter);
+               ZipX4OpLowering,
+               SdotOpLowering>(converter);
   // Add vector.create_mask conversion with a high benefit as it produces much
   // nicer code than the generic lowering.
   patterns.add<CreateMaskOpLowering>(converter, /*benefit=*/4096);
@@ -215,41 +218,44 @@ void mlir::configureArmSVELegalizeForExportTarget(
     LLVMConversionTarget &target) {
   // clang-format off
-  target.addLegalOp<SdotIntrOp,
+  target.addLegalOp<DupQLaneIntrOp,
                     SmmlaIntrOp,
                     UdotIntrOp,
                     UmmlaIntrOp,
                     ScalableMaskedAddIIntrOp,
                     ScalableMaskedAddFIntrOp,
                     ScalableMaskedSubIIntrOp,
                     ScalableMaskedSubFIntrOp,
                     ScalableMaskedMulIIntrOp,
                     ScalableMaskedMulFIntrOp,
                     ScalableMaskedSDivIIntrOp,
                     ScalableMaskedUDivIIntrOp,
                     ScalableMaskedDivFIntrOp,
                     ConvertToSvboolIntrOp,
                     ConvertFromSvboolIntrOp,
                     PselIntrOp,
                     ZipX2IntrOp,
                     ZipX4IntrOp,
-                    WhileLTIntrOp>();
-  target.addIllegalOp<SdotOp,
+                    WhileLTIntrOp,
+                    SdotIntrOp>();
+  target.addIllegalOp<DupQLaneOp,
                       SmmlaOp,
                       UdotOp,
                       UmmlaOp,
                       ScalableMaskedAddIOp,
                       ScalableMaskedAddFOp,
                       ScalableMaskedSubIOp,
                       ScalableMaskedSubFOp,
                       ScalableMaskedMulIOp,
                       ScalableMaskedMulFOp,
                       ScalableMaskedSDivIOp,
                       ScalableMaskedUDivIOp,
                       ScalableMaskedDivFOp,
                       ConvertToSvboolOp,
                       ConvertFromSvboolOp,
                       PselOp,
                       ZipX2Op,
-                      ZipX4Op>();
+                      ZipX4Op,
+                      SdotOp>();
   // clang-format on
 }
diff --git a/mlir/test/Dialect/ArmSVE/legalize-for-llvm.mlir b/mlir/test/Dialect/ArmSVE/legalize-for-llvm.mlir
index bdb69a95a52d..650b3e72d4ec 100644
--- a/mlir/test/Dialect/ArmSVE/legalize-for-llvm.mlir
+++ b/mlir/test/Dialect/ArmSVE/legalize-for-llvm.mlir
@@ -271,3 +271,44 @@ func.func @arm_sve_psel_mixed_predicate_types(%p0: vector<[8]xi1>, %p1: vector<[
   %0 = arm_sve.psel %p0, %p1[%index] : vector<[8]xi1>, vector<[16]xi1>
   return %0 : vector<[8]xi1>
 }
+
+// -----
+
+// CHECK-LABEL: @arm_sve_dupq_lane(
+// CHECK-SAME: %[[A0:[a-z0-9]+]]: vector<[16]xi8>
+// CHECK-SAME: %[[A1:[a-z0-9]+]]: vector<[8]xi16>
+// CHECK-SAME: %[[A2:[a-z0-9]+]]: vector<[8]xf16>
+// CHECK-SAME: %[[A3:[a-z0-9]+]]: vector<[8]xbf16>
+// CHECK-SAME: %[[A4:[a-z0-9]+]]: vector<[4]xi32>
+// CHECK-SAME: %[[A5:[a-z0-9]+]]: vector<[4]xf32>
+// CHECK-SAME: %[[A6:[a-z0-9]+]]: vector<[2]xi64>
+// CHECK-SAME: %[[A7:[a-z0-9]+]]: vector<[2]xf64>
+// CHECK-SAME: -> !llvm.struct<(vector<[16]xi8>, vector<[8]xi16>, vector<[8]xf16>, vector<[8]xbf16>, vector<[4]xi32>, vector<[4]xf32>, vector<[2]xi64>, vector<[2]xf64>)> {
+func.func @arm_sve_dupq_lane(
+  %v16i8: vector<[16]xi8>, %v8i16: vector<[8]xi16>,
+  %v8f16: vector<[8]xf16>, %v8bf16: vector<[8]xbf16>,
+  %v4i32: vector<[4]xi32>, %v4f32: vector<[4]xf32>,
+  %v2i64: vector<[2]xi64>, %v2f64: vector<[2]xf64>)
+  -> (vector<[16]xi8>, vector<[8]xi16>, vector<[8]xf16>, vector<[8]xbf16>,
+      vector<[4]xi32>, vector<[4]xf32>, vector<[2]xi64>, vector<[2]xf64>) {
+// CHECK: "arm_sve.intr.dupq_lane"(%[[A0]]) <{lane = 0 : i64}> : (vector<[16]xi8>) -> vector<[16]xi8>
+  %0 = arm_sve.dupq_lane %v16i8[0] : vector<[16]xi8>
+// CHECK: "arm_sve.intr.dupq_lane"(%[[A1]]) <{lane = 1 : i64}> : (vector<[8]xi16>) -> vector<[8]xi16>
+  %1 = arm_sve.dupq_lane %v8i16[1] : vector<[8]xi16>
+// CHECK: "arm_sve.intr.dupq_lane"(%[[A2]]) <{lane = 2 : i64}> : (vector<[8]xf16>) -> vector<[8]xf16>
+  %2 = arm_sve.dupq_lane %v8f16[2] : vector<[8]xf16>
+// CHECK: "arm_sve.intr.dupq_lane"(%[[A3]]) <{lane = 3 : i64}> : (vector<[8]xbf16>) -> vector<[8]xbf16>
+  %3 = arm_sve.dupq_lane %v8bf16[3] : vector<[8]xbf16>
+// CHECK: "arm_sve.intr.dupq_lane"(%[[A4]]) <{lane = 4 : i64}> : (vector<[4]xi32>) -> vector<[4]xi32>
+  %4 = arm_sve.dupq_lane %v4i32[4] : vector<[4]xi32>
+// CHECK: "arm_sve.intr.dupq_lane"(%[[A5]]) <{lane = 5 : i64}> : (vector<[4]xf32>) -> vector<[4]xf32>
+  %5 = arm_sve.dupq_lane %v4f32[5] : vector<[4]xf32>
+// CHECK: "arm_sve.intr.dupq_lane"(%[[A6]]) <{lane = 6 : i64}> : (vector<[2]xi64>) -> vector<[2]xi64>
+  %6 = arm_sve.dupq_lane %v2i64[6] : vector<[2]xi64>
+// CHECK: "arm_sve.intr.dupq_lane"(%[[A7]]) <{lane = 7 : i64}> : (vector<[2]xf64>) -> vector<[2]xf64>
+  %7 = arm_sve.dupq_lane %v2f64[7] : vector<[2]xf64>
+
+  return %0, %1, %2, %3, %4, %5, %6, %7
+    : vector<[16]xi8>, vector<[8]xi16>, vector<[8]xf16>, vector<[8]xbf16>,
+      vector<[4]xi32>, vector<[4]xf32>, vector<[2]xi64>, vector<[2]xf64>
+}
diff --git a/mlir/test/Target/LLVMIR/arm-sve.mlir b/mlir/test/Target/LLVMIR/arm-sve.mlir
index ed5a1fc7ba2e..14c68b21fd86 100644
--- a/mlir/test/Target/LLVMIR/arm-sve.mlir
+++ b/mlir/test/Target/LLVMIR/arm-sve.mlir
@@ -390,3 +390,35 @@ llvm.func @arm_sve_psel(%pn: vector<[16]xi1>, %p1: vector<[2]xi1>, %p2: vector<[
   "arm_sve.intr.psel"(%pn, %p4, %index) : (vector<[16]xi1>, vector<[16]xi1>, i32) -> vector<[16]xi1>
   llvm.return
 }
+
+// CHECK-LABEL: @arm_sve_dupq_lane
+// CHECK-SAME: <vscale x 16 x i8> %[[V0:[0-9]+]]
+// CHECK-SAME: <vscale x 8 x i16> %[[V1:[0-9]+]]
+// CHECK-SAME: <vscale x 8 x half> %[[V2:[0-9]+]]
+// CHECK-SAME: <vscale x 8 x bfloat> %[[V3:[0-9]+]]
+// CHECK-SAME: <vscale x 4 x i32> %[[V4:[0-9]+]]
+// CHECK-SAME: <vscale x 4 x float> %[[V5:[0-9]+]]
+// CHECK-SAME: <vscale x 2 x i64> %[[V6:[0-9]+]]
+// CHECK-SAME: <vscale x 2 x double> %[[V7:[0-9]+]]
+llvm.func @arm_sve_dupq_lane(%nxv16i8: vector<[16]xi8>, %nxv8i16: vector<[8]xi16>,
+                             %nxv8f16: vector<[8]xf16>, %nxv8bf16: vector<[8]xbf16>,
+                             %nxv4i32: vector<[4]xi32>, %nxv4f32: vector<[4]xf32>,
+                             %nxv2i64: vector<[2]xi64>, %nxv2f64: vector<[2]xf64>) {
+  // CHECK: call <vscale x 16 x i8> @llvm.aarch64.sve.dupq.lane.nxv16i8(<vscale x 16 x i8> %[[V0]], i64 0)
+  %0 = "arm_sve.intr.dupq_lane"(%nxv16i8) <{lane = 0 : i64}> : (vector<[16]xi8>) -> vector<[16]xi8>
+  // CHECK: call <vscale x 8 x i16> @llvm.aarch64.sve.dupq.lane.nxv8i16(<vscale x 8 x i16> %[[V1]], i64 1)
+  %1 = "arm_sve.intr.dupq_lane"(%nxv8i16) <{lane = 1 : i64}> : (vector<[8]xi16>) -> vector<[8]xi16>
+  // CHECK: call <vscale x 8 x half> @llvm.aarch64.sve.dupq.lane.nxv8f16(<vscale x 8 x half> %[[V2]], i64 2)
+  %2 = "arm_sve.intr.dupq_lane"(%nxv8f16) <{lane = 2 : i64}> : (vector<[8]xf16>) -> vector<[8]xf16>
+  // CHECK: call <vscale x 8 x bfloat> @llvm.aarch64.sve.dupq.lane.nxv8bf16(<vscale x 8 x bfloat> %[[V3]], i64 3)
+  %3 = "arm_sve.intr.dupq_lane"(%nxv8bf16) <{lane = 3 : i64}> : (vector<[8]xbf16>) -> vector<[8]xbf16>
+  // CHECK: call <vscale x 4 x i32> @llvm.aarch64.sve.dupq.lane.nxv4i32(<vscale x 4 x i32> %[[V4]], i64 4)
+  %4 = "arm_sve.intr.dupq_lane"(%nxv4i32) <{lane = 4 : i64}> : (vector<[4]xi32>) -> vector<[4]xi32>
+  // CHECK: call <vscale x 4 x float> @llvm.aarch64.sve.dupq.lane.nxv4f32(<vscale x 4 x float> %[[V5]], i64 5)
+  %5 = "arm_sve.intr.dupq_lane"(%nxv4f32) <{lane = 5 : i64}> : (vector<[4]xf32>) -> vector<[4]xf32>
+  // CHECK: call <vscale x 2 x i64> @llvm.aarch64.sve.dupq.lane.nxv2i64(<vscale x 2 x i64> %[[V6]], i64 6)
+  %6 = "arm_sve.intr.dupq_lane"(%nxv2i64) <{lane = 6 : i64}> : (vector<[2]xi64>) -> vector<[2]xi64>
+  // CHECK: call <vscale x 2 x double> @llvm.aarch64.sve.dupq.lane.nxv2f64(<vscale x 2 x double> %[[V7]], i64 7)
+  %7 = "arm_sve.intr.dupq_lane"(%nxv2f64) <{lane = 7 : i64}> : (vector<[2]xf64>) -> vector<[2]xf64>
+  llvm.return
+}
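
For reference, a minimal usage sketch of the new op (the function and value names below are hypothetical; the lowered forms shown in the comments follow the tests added by this patch):

```mlir
// Broadcast the first 128-bit segment of %v to every 128-bit segment of the result.
func.func @broadcast_first_segment(%v: vector<[4]xi32>) -> vector<[4]xi32> {
  // After ArmSVE legalization (populateArmSVELegalizeForLLVMExportPatterns) this becomes:
  //   "arm_sve.intr.dupq_lane"(%v) <{lane = 0 : i64}> : (vector<[4]xi32>) -> vector<[4]xi32>
  // and translates to LLVM IR as:
  //   call <vscale x 4 x i32> @llvm.aarch64.sve.dupq.lane.nxv4i32(<vscale x 4 x i32> %v, i64 0)
  %0 = arm_sve.dupq_lane %v[0] : vector<[4]xi32>
  return %0 : vector<[4]xi32>
}
```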