[MLIR][ArmSVE] Add an ArmSVE dialect operation mapping to bfmmla (#145064)
This commit is contained in:
@@ -293,6 +293,35 @@ def UsmmlaOp : ArmSVE_Op<"usmmla", [Pure,
|
|||||||
"$acc `,` $src1 `,` $src2 attr-dict `:` type($src1) `to` type($dst)";
|
"$acc `,` $src1 `,` $src2 attr-dict `:` type($src1) `to` type($dst)";
|
||||||
}
|
}
|
||||||
|
|
||||||
|
def BfmmlaOp : ArmSVE_IntrOp<"bfmmla", [Pure,
|
||||||
|
AllTypesMatch<["src1", "src2"]>,
|
||||||
|
AllTypesMatch<["acc", "res"]>,
|
||||||
|
]> {
|
||||||
|
let summary = "BFloat16 matrix multiply-accumulate";
|
||||||
|
let description = [{
|
||||||
|
BFMMLA: BFloat16 matrix multiply-accumulate into 2×2 matrices";
|
||||||
|
|
||||||
|
This operation multiplies the 2x4 BFloat16 matrix held in each 128-bit
|
||||||
|
segment of the first source vector by the 4x2 BFloat16 matrix in the
|
||||||
|
corresponding segment of the second source vector, then accumulates
|
||||||
|
this intermediate result with the 2x2 Float32 matrix in the corresponding
|
||||||
|
segment of the accumulator vector, yielding the final 2x2 Float32
|
||||||
|
segment of the result.
|
||||||
|
|
||||||
|
Source:
|
||||||
|
https://developer.arm.com/documentation/100987/0000
|
||||||
|
}];
|
||||||
|
// Supports (vector<[8]xbf16>, vector<[8]xbf16>) -> (vector<[4]xf32>)
|
||||||
|
let arguments = (ins
|
||||||
|
ScalableVectorOfLengthAndType<[4], [F32]>:$acc,
|
||||||
|
ScalableVectorOfLengthAndType<[8], [BF16]>:$src1,
|
||||||
|
ScalableVectorOfLengthAndType<[8], [BF16]>:$src2
|
||||||
|
);
|
||||||
|
let results = (outs ScalableVectorOfLengthAndType<[4], [F32]>:$res);
|
||||||
|
let assemblyFormat =
|
||||||
|
"$acc `,` $src1 `,` $src2 attr-dict `:` type($src1) `to` type($res)";
|
||||||
|
}
|
||||||
|
|
||||||
class SvboolTypeConstraint<string lhsArg, string rhsArg> : TypesMatchWith<
|
class SvboolTypeConstraint<string lhsArg, string rhsArg> : TypesMatchWith<
|
||||||
"expected corresponding svbool type widened to [16]xi1",
|
"expected corresponding svbool type widened to [16]xi1",
|
||||||
lhsArg, rhsArg,
|
lhsArg, rhsArg,
|
||||||
|
|||||||
@@ -220,7 +220,8 @@ void mlir::populateArmSVELegalizeForLLVMExportPatterns(
|
|||||||
void mlir::configureArmSVELegalizeForExportTarget(
|
void mlir::configureArmSVELegalizeForExportTarget(
|
||||||
LLVMConversionTarget &target) {
|
LLVMConversionTarget &target) {
|
||||||
// clang-format off
|
// clang-format off
|
||||||
target.addLegalOp<ConvertFromSvboolIntrOp,
|
target.addLegalOp<BfmmlaOp,
|
||||||
|
ConvertFromSvboolIntrOp,
|
||||||
ConvertToSvboolIntrOp,
|
ConvertToSvboolIntrOp,
|
||||||
DupQLaneIntrOp,
|
DupQLaneIntrOp,
|
||||||
PselIntrOp,
|
PselIntrOp,
|
||||||
|
|||||||
@@ -72,3 +72,63 @@ func.func @arm_sve_psel_bad_vector_type(%a : vector<[7]xi1>, %index: index) {
|
|||||||
arm_sve.psel %a, %a[%index] : vector<[7]xi1>, vector<[7]xi1>
|
arm_sve.psel %a, %a[%index] : vector<[7]xi1>, vector<[7]xi1>
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// -----
|
||||||
|
|
||||||
|
func.func @bfmmla_invalid_element_type_lhs_rhs(%acc: vector<[4]xf32>,
|
||||||
|
%lhs: vector<[8]xf16>,
|
||||||
|
%rhs: vector<[8]xf16>) -> vector<[4]xf32> {
|
||||||
|
// expected-error@+1 {{operand #1 must be scalable vector of bfloat16 type values of length 8, but got 'vector<[8]xf16>'}}
|
||||||
|
%0 = arm_sve.intr.bfmmla %acc, %lhs, %rhs : vector<[8]xf16> to vector<[4]xf32>
|
||||||
|
return %0 : vector<[4]xf32>
|
||||||
|
}
|
||||||
|
|
||||||
|
// -----
|
||||||
|
|
||||||
|
func.func @bfmmla_invalid_dimension_lhs_rhs(%acc: vector<[4]xf32>,
|
||||||
|
%lhs: vector<[4]xbf16>,
|
||||||
|
%rhs: vector<[4]xbf16>) -> vector<[4]xf32> {
|
||||||
|
// expected-error@+1 {{operand #1 must be scalable vector of bfloat16 type values of length 8, but got 'vector<[4]xbf16>}}
|
||||||
|
%0 = arm_sve.intr.bfmmla %acc, %lhs, %rhs : vector<[4]xbf16> to vector<[4]xf32>
|
||||||
|
return %0 : vector<[4]xf32>
|
||||||
|
}
|
||||||
|
|
||||||
|
// -----
|
||||||
|
|
||||||
|
func.func @bfmmla_fixed_dimension_lhs_rhs(%acc: vector<[4]xf32>,
|
||||||
|
%lhs: vector<8xbf16>,
|
||||||
|
%rhs: vector<8xbf16>) -> vector<[4]xf32> {
|
||||||
|
// expected-error@+1 {{operand #1 must be scalable vector of bfloat16 type values of length 8, but got 'vector<8xbf16>}}
|
||||||
|
%0 = arm_sve.intr.bfmmla %acc, %lhs, %rhs : vector<8xbf16> to vector<[4]xf32>
|
||||||
|
return %0 : vector<[4]xf32>
|
||||||
|
}
|
||||||
|
|
||||||
|
// -----
|
||||||
|
|
||||||
|
func.func @bfmmla_invalid_element_type_acc(%acc: vector<[4]xi32>,
|
||||||
|
%lhs: vector<[8]xbf16>,
|
||||||
|
%rhs: vector<[8]xbf16>) -> vector<[4]xi32> {
|
||||||
|
// expected-error@+1 {{operand #0 must be scalable vector of 32-bit float values of length 4, but got 'vector<[4]xi32>'}}
|
||||||
|
%0 = arm_sve.intr.bfmmla %acc, %lhs, %rhs : vector<[8]xbf16> to vector<[4]xi32>
|
||||||
|
return %0 : vector<[4]xi32>
|
||||||
|
}
|
||||||
|
|
||||||
|
// -----
|
||||||
|
|
||||||
|
func.func @bfmmla_invalid_dimension_acc(%acc: vector<[8]xf32>,
|
||||||
|
%lhs: vector<[8]xbf16>,
|
||||||
|
%rhs: vector<[8]xbf16>) -> vector<[8]xf32> {
|
||||||
|
// expected-error@+1 {{operand #0 must be scalable vector of 32-bit float values of length 4, but got 'vector<[8]xf32>'}}
|
||||||
|
%0 = arm_sve.intr.bfmmla %acc, %lhs, %rhs : vector<[8]xbf16> to vector<[8]xf32>
|
||||||
|
return %0 : vector<[8]xf32>
|
||||||
|
}
|
||||||
|
|
||||||
|
// -----
|
||||||
|
|
||||||
|
func.func @bfmmla_fixed_dimension_acc(%acc: vector<4xf32>,
|
||||||
|
%lhs: vector<[8]xbf16>,
|
||||||
|
%rhs: vector<[8]xbf16>) -> vector<4xf32> {
|
||||||
|
// expected-error@+1 {{operand #0 must be scalable vector of 32-bit float values of length 4, but got 'vector<4xf32>'}}
|
||||||
|
%0 = arm_sve.intr.bfmmla %acc, %lhs, %rhs : vector<[8]xbf16> to vector<4xf32>
|
||||||
|
return %0 : vector<4xf32>
|
||||||
|
}
|
||||||
|
|||||||
@@ -55,6 +55,16 @@ func.func @arm_sve_usmmla(%a: vector<[16]xi8>,
|
|||||||
|
|
||||||
// -----
|
// -----
|
||||||
|
|
||||||
|
func.func @arm_sve_bfmmla(%a: vector<[8]xbf16>,
|
||||||
|
%b: vector<[8]xbf16>,
|
||||||
|
%c: vector<[4]xf32>) -> vector<[4]xf32> {
|
||||||
|
// CHECK: arm_sve.intr.bfmmla {{.*}}: vector<[8]xbf16> to vector<[4]xf32>
|
||||||
|
%0 = arm_sve.intr.bfmmla %c, %a, %b : vector<[8]xbf16> to vector<[4]xf32>
|
||||||
|
return %0 : vector<[4]xf32>
|
||||||
|
}
|
||||||
|
|
||||||
|
// -----
|
||||||
|
|
||||||
func.func @arm_sve_masked_arithi(%a: vector<[4]xi32>,
|
func.func @arm_sve_masked_arithi(%a: vector<[4]xi32>,
|
||||||
%b: vector<[4]xi32>,
|
%b: vector<[4]xi32>,
|
||||||
%c: vector<[4]xi32>,
|
%c: vector<[4]xi32>,
|
||||||
|
|||||||
@@ -60,6 +60,18 @@ llvm.func @arm_sve_usmmla(%arg0: vector<[16]xi8>,
|
|||||||
llvm.return %0 : vector<[4]xi32>
|
llvm.return %0 : vector<[4]xi32>
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// CHECK-LABEL: define <vscale x 4 x float> @arm_sve_bfmmla
|
||||||
|
llvm.func @arm_sve_bfmmla(%arg0: vector<[8]xbf16>,
|
||||||
|
%arg1: vector<[8]xbf16>,
|
||||||
|
%arg2: vector<[4]xf32>)
|
||||||
|
-> vector<[4]xf32> {
|
||||||
|
// CHECK: call <vscale x 4 x float> @llvm.aarch64.sve.bfmmla(<vscale x 4 x float>
|
||||||
|
%0 = "arm_sve.intr.bfmmla"(%arg2, %arg0, %arg1) :
|
||||||
|
(vector<[4]xf32>, vector<[8]xbf16>, vector<[8]xbf16>)
|
||||||
|
-> vector<[4]xf32>
|
||||||
|
llvm.return %0 : vector<[4]xf32>
|
||||||
|
}
|
||||||
|
|
||||||
// CHECK-LABEL: define <vscale x 4 x i32> @arm_sve_arithi
|
// CHECK-LABEL: define <vscale x 4 x i32> @arm_sve_arithi
|
||||||
llvm.func @arm_sve_arithi(%arg0: vector<[4]xi32>,
|
llvm.func @arm_sve_arithi(%arg0: vector<[4]xi32>,
|
||||||
%arg1: vector<[4]xi32>,
|
%arg1: vector<[4]xi32>,
|
||||||
|
|||||||
Reference in New Issue
Block a user