[MLIR][ArmSVE] Add an ArmSVE dialect operation which maps to svusmmla (#135634)

This commit is contained in:
Momchil Velikov
2025-05-16 17:12:35 +01:00
committed by GitHub
parent 7fe1b43122
commit e9c9c33fa4
5 changed files with 96 additions and 38 deletions

View File

@@ -147,11 +147,9 @@ class ScalableMaskedIOp<string mnemonic, string op_description,
"$mask `,` $src1 `,` $src2 attr-dict `:` type($mask) `,` type($res)";
}
def SdotOp : ArmSVE_Op<"sdot",
[Pure,
AllTypesMatch<["src1", "src2"]>,
AllTypesMatch<["acc", "dst"]>,
]> {
def SdotOp : ArmSVE_Op<"sdot", [Pure,
AllTypesMatch<["src1", "src2"]>,
AllTypesMatch<["acc", "dst"]>]> {
let summary = "Vector-vector dot product and accumulate op";
let description = [{
SDOT: Signed integer addition of dot product.
@@ -178,11 +176,9 @@ def SdotOp : ArmSVE_Op<"sdot",
"$acc `,` $src1 `,` $src2 attr-dict `:` type($src1) `to` type($dst)";
}
def SmmlaOp : ArmSVE_Op<"smmla",
[Pure,
AllTypesMatch<["src1", "src2"]>,
AllTypesMatch<["acc", "dst"]>,
]> {
def SmmlaOp : ArmSVE_Op<"smmla", [Pure,
AllTypesMatch<["src1", "src2"]>,
AllTypesMatch<["acc", "dst"]>]> {
let summary = "Matrix-matrix multiply and accumulate op";
let description = [{
SMMLA: Signed integer matrix multiply-accumulate.
@@ -210,11 +206,9 @@ def SmmlaOp : ArmSVE_Op<"smmla",
"$acc `,` $src1 `,` $src2 attr-dict `:` type($src1) `to` type($dst)";
}
def UdotOp : ArmSVE_Op<"udot",
[Pure,
AllTypesMatch<["src1", "src2"]>,
AllTypesMatch<["acc", "dst"]>,
]> {
def UdotOp : ArmSVE_Op<"udot", [Pure,
AllTypesMatch<["src1", "src2"]>,
AllTypesMatch<["acc", "dst"]>]> {
let summary = "Vector-vector dot product and accumulate op";
let description = [{
UDOT: Unsigned integer addition of dot product.
@@ -241,11 +235,9 @@ def UdotOp : ArmSVE_Op<"udot",
"$acc `,` $src1 `,` $src2 attr-dict `:` type($src1) `to` type($dst)";
}
def UmmlaOp : ArmSVE_Op<"ummla",
[Pure,
AllTypesMatch<["src1", "src2"]>,
AllTypesMatch<["acc", "dst"]>,
]> {
def UmmlaOp : ArmSVE_Op<"ummla", [Pure,
AllTypesMatch<["src1", "src2"]>,
AllTypesMatch<["acc", "dst"]>]> {
let summary = "Matrix-matrix multiply and accumulate op";
let description = [{
UMMLA: Unsigned integer matrix multiply-accumulate.
@@ -273,14 +265,42 @@ def UmmlaOp : ArmSVE_Op<"ummla",
"$acc `,` $src1 `,` $src2 attr-dict `:` type($src1) `to` type($dst)";
}
def UsmmlaOp : ArmSVE_Op<"usmmla", [Pure,
                                    AllTypesMatch<["src1", "src2"]>,
                                    AllTypesMatch<["acc", "dst"]>]> {
  let summary = "Matrix-matrix multiply and accumulate op";
  let description = [{
    USMMLA: Unsigned by signed integer matrix multiply-accumulate.

    The unsigned by signed integer matrix multiply-accumulate operation
    multiplies the 2×8 matrix of unsigned 8-bit integer values held in
    the first source vector by the 8×2 matrix of signed 8-bit integer
    values in the second source vector. The resulting 2×2 widened 32-bit
    integer matrix product is then added to the 32-bit integer matrix
    accumulator.

    Source:
    https://developer.arm.com/documentation/100987/0000
  }];
  // Supports (vector<16xi8>, vector<16xi8>) -> (vector<4xi32>)
  let arguments = (ins
          ScalableVectorOfLengthAndType<[4], [I32]>:$acc,
          ScalableVectorOfLengthAndType<[16], [I8]>:$src1,
          ScalableVectorOfLengthAndType<[16], [I8]>:$src2
  );
  let results = (outs ScalableVectorOfLengthAndType<[4], [I32]>:$dst);
  let assemblyFormat =
    "$acc `,` $src1 `,` $src2 attr-dict `:` type($src1) `to` type($dst)";
}
class SvboolTypeConstraint<string lhsArg, string rhsArg> : TypesMatchWith<
"expected corresponding svbool type widened to [16]xi1",
lhsArg, rhsArg,
"VectorType(VectorType::Builder(::llvm::cast<VectorType>($_self)).setDim(::llvm::cast<VectorType>($_self).getRank() - 1, 16))">;
def ConvertFromSvboolOp : ArmSVE_Op<"convert_from_svbool",
[Pure, SvboolTypeConstraint<"result", "source">]>
{
[Pure,
SvboolTypeConstraint<"result", "source">]> {
let summary = "Convert a svbool type to a SVE predicate type";
let description = [{
Converts svbool types (`vector<[16]xi1>` or vectors of that type, e.g.
@@ -313,8 +333,8 @@ def ConvertFromSvboolOp : ArmSVE_Op<"convert_from_svbool",
}
def ConvertToSvboolOp : ArmSVE_Op<"convert_to_svbool",
[Pure, SvboolTypeConstraint<"source", "result">]>
{
[Pure,
SvboolTypeConstraint<"source", "result">]> {
let summary = "Convert a SVE predicate type to a svbool type";
let description = [{
Converts SVE predicate types (or vectors of predicate types, e.g.
@@ -356,10 +376,9 @@ def ZipInputVectorType : AnyTypeOf<[
Scalable1DVectorOfLength<16, [I8]>],
"an SVE vector with element size <= 64-bit">;
def ZipX2Op : ArmSVE_Op<"zip.x2", [
Pure,
AllTypesMatch<["sourceV1", "sourceV2", "resultV1", "resultV2"]>]
> {
def ZipX2Op : ArmSVE_Op<"zip.x2", [Pure,
AllTypesMatch<["sourceV1", "sourceV2",
"resultV1", "resultV2"]>]> {
let summary = "Multi-vector two-way zip op";
let description = [{
@@ -400,12 +419,11 @@ def ZipX2Op : ArmSVE_Op<"zip.x2", [
}];
}
def ZipX4Op : ArmSVE_Op<"zip.x4", [
Pure,
AllTypesMatch<[
"sourceV1", "sourceV2", "sourceV3", "sourceV4",
"resultV1", "resultV2", "resultV3", "resultV4"]>]
> {
def ZipX4Op
: ArmSVE_Op<"zip.x4",
[Pure,
AllTypesMatch<["sourceV1", "sourceV2", "sourceV3", "sourceV4",
"resultV1", "resultV2", "resultV3", "resultV4"]>]> {
let summary = "Multi-vector four-way zip op";
let description = [{
@@ -463,10 +481,7 @@ def ZipX4Op : ArmSVE_Op<"zip.x4", [
}];
}
def PselOp : ArmSVE_Op<"psel", [
Pure,
AllTypesMatch<["p1", "result"]>,
]> {
def PselOp : ArmSVE_Op<"psel", [Pure, AllTypesMatch<["p1", "result"]>]> {
let summary = "Predicate select";
let description = [{
@@ -571,6 +586,10 @@ def SmmlaIntrOp :
ArmSVE_IntrBinaryOverloadedOp<"smmla">,
Arguments<(ins AnyScalableVectorOfAnyRank, AnyScalableVectorOfAnyRank, AnyScalableVectorOfAnyRank)>;
// Intrinsic op corresponding to the llvm.aarch64.sve.usmmla LLVM intrinsic;
// UsmmlaOp is rewritten 1:1 to this op during legalization for LLVM export.
def UsmmlaIntrOp :
ArmSVE_IntrBinaryOverloadedOp<"usmmla">,
Arguments<(ins AnyScalableVectorOfAnyRank, AnyScalableVectorOfAnyRank, AnyScalableVectorOfAnyRank)>;
def SdotIntrOp :
ArmSVE_IntrBinaryOverloadedOp<"sdot">,
Arguments<(ins AnyScalableVectorOfAnyRank, AnyScalableVectorOfAnyRank, AnyScalableVectorOfAnyRank)>;

View File

@@ -24,6 +24,7 @@ using SdotOpLowering = OneToOneConvertToLLVMPattern<SdotOp, SdotIntrOp>;
using SmmlaOpLowering = OneToOneConvertToLLVMPattern<SmmlaOp, SmmlaIntrOp>;
using UdotOpLowering = OneToOneConvertToLLVMPattern<UdotOp, UdotIntrOp>;
using UmmlaOpLowering = OneToOneConvertToLLVMPattern<UmmlaOp, UmmlaIntrOp>;
using UsmmlaOpLowering = OneToOneConvertToLLVMPattern<UsmmlaOp, UsmmlaIntrOp>;
using DupQLaneLowering =
OneToOneConvertToLLVMPattern<DupQLaneOp, DupQLaneIntrOp>;
using ScalableMaskedAddIOpLowering =
@@ -206,6 +207,7 @@ void mlir::populateArmSVELegalizeForLLVMExportPatterns(
SmmlaOpLowering,
UdotOpLowering,
UmmlaOpLowering,
UsmmlaOpLowering,
ZipX2OpLowering,
ZipX4OpLowering,
SdotOpLowering>(converter);
@@ -234,6 +236,7 @@ void mlir::configureArmSVELegalizeForExportTarget(
SmmlaIntrOp,
UdotIntrOp,
UmmlaIntrOp,
UsmmlaIntrOp,
WhileLTIntrOp,
ZipX2IntrOp,
ZipX4IntrOp,
@@ -254,6 +257,7 @@ void mlir::configureArmSVELegalizeForExportTarget(
SmmlaOp,
UdotOp,
UmmlaOp,
UsmmlaOp,
ZipX2Op,
ZipX4Op,
SdotOp>();

View File

@@ -48,6 +48,18 @@ func.func @arm_sve_ummla(%a: vector<[16]xi8>,
// -----
// Checks that arm_sve.usmmla is rewritten to the arm_sve.intr.usmmla
// intrinsic op by the legalize-for-LLVM-export patterns.
func.func @arm_sve_usmmla(%a: vector<[16]xi8>,
%b: vector<[16]xi8>,
%c: vector<[4]xi32>)
-> vector<[4]xi32> {
// CHECK: arm_sve.intr.usmmla
%0 = arm_sve.usmmla %c, %a, %b :
vector<[16]xi8> to vector<[4]xi32>
return %0 : vector<[4]xi32>
}
// -----
func.func @arm_sve_arithi_masked(%a: vector<[4]xi32>,
%b: vector<[4]xi32>,
%c: vector<[4]xi32>,

View File

@@ -44,6 +44,17 @@ func.func @arm_sve_ummla(%a: vector<[16]xi8>,
// -----
// Round-trip (parse/print) test for the arm_sve.usmmla op with its
// supported type signature: two [16]xi8 sources accumulated into [4]xi32.
func.func @arm_sve_usmmla(%a: vector<[16]xi8>,
%b: vector<[16]xi8>,
%c: vector<[4]xi32>) -> vector<[4]xi32> {
// CHECK: arm_sve.usmmla {{.*}}: vector<[16]xi8> to vector<[4]xi3
%0 = arm_sve.usmmla %c, %a, %b :
vector<[16]xi8> to vector<[4]xi32>
return %0 : vector<[4]xi32>
}
// -----
func.func @arm_sve_masked_arithi(%a: vector<[4]xi32>,
%b: vector<[4]xi32>,
%c: vector<[4]xi32>,

View File

@@ -48,6 +48,18 @@ llvm.func @arm_sve_ummla(%arg0: vector<[16]xi8>,
llvm.return %0 : vector<[4]xi32>
}
// CHECK-LABEL: define <vscale x 4 x i32> @arm_sve_usmmla
// Checks that arm_sve.intr.usmmla translates to a call of the
// llvm.aarch64.sve.usmmla LLVM intrinsic on scalable vectors.
llvm.func @arm_sve_usmmla(%arg0: vector<[16]xi8>,
%arg1: vector<[16]xi8>,
%arg2: vector<[4]xi32>)
-> vector<[4]xi32> {
// CHECK: call <vscale x 4 x i32> @llvm.aarch64.sve.usmmla.nxv4i32(<vscale x 4
%0 = "arm_sve.intr.usmmla"(%arg2, %arg0, %arg1) :
(vector<[4]xi32>, vector<[16]xi8>, vector<[16]xi8>)
-> vector<[4]xi32>
llvm.return %0 : vector<[4]xi32>
}
// CHECK-LABEL: define <vscale x 4 x i32> @arm_sve_arithi
llvm.func @arm_sve_arithi(%arg0: vector<[4]xi32>,
%arg1: vector<[4]xi32>,