[MLIR][ArmSVE] Add an ArmSVE dialect operation which maps to svusmmla (#135634)
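For orientation, a minimal usage sketch of the new arm_sve.usmmla op, assembled
from the op description and the tests in the diff below (the function name is
invented for illustration; the accumulator operand comes first, per the op's
assembly format):

    func.func @usmmla_example(%acc: vector<[4]xi32>,
                              %a: vector<[16]xi8>,  // unsigned 8-bit operand
                              %b: vector<[16]xi8>)  // signed 8-bit operand
        -> vector<[4]xi32> {
      // For each 128-bit vector segment: multiply a 2x8 matrix of unsigned
      // i8 (from %a) by an 8x2 matrix of signed i8 (from %b), widen the 2x2
      // product to i32, and add it to the 2x2 tile held in the accumulator.
      %0 = arm_sve.usmmla %acc, %a, %b : vector<[16]xi8> to vector<[4]xi32>
      return %0 : vector<[4]xi32>
    }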
--- a/mlir/include/mlir/Dialect/ArmSVE/IR/ArmSVE.td
+++ b/mlir/include/mlir/Dialect/ArmSVE/IR/ArmSVE.td
@@ -147,11 +147,9 @@ class ScalableMaskedIOp<string mnemonic, string op_description,
     "$mask `,` $src1 `,` $src2 attr-dict `:` type($mask) `,` type($res)";
 }
 
-def SdotOp : ArmSVE_Op<"sdot",
-               [Pure,
-               AllTypesMatch<["src1", "src2"]>,
-               AllTypesMatch<["acc", "dst"]>,
-             ]> {
+def SdotOp : ArmSVE_Op<"sdot", [Pure,
+                                AllTypesMatch<["src1", "src2"]>,
+                                AllTypesMatch<["acc", "dst"]>]> {
   let summary = "Vector-vector dot product and accumulate op";
   let description = [{
     SDOT: Signed integer addition of dot product.
@@ -178,11 +176,9 @@ def SdotOp : ArmSVE_Op<"sdot",
     "$acc `,` $src1 `,` $src2 attr-dict `:` type($src1) `to` type($dst)";
 }
 
-def SmmlaOp : ArmSVE_Op<"smmla",
-                [Pure,
-                AllTypesMatch<["src1", "src2"]>,
-                AllTypesMatch<["acc", "dst"]>,
-              ]> {
+def SmmlaOp : ArmSVE_Op<"smmla", [Pure,
+                                  AllTypesMatch<["src1", "src2"]>,
+                                  AllTypesMatch<["acc", "dst"]>]> {
   let summary = "Matrix-matrix multiply and accumulate op";
   let description = [{
     SMMLA: Signed integer matrix multiply-accumulate.
@@ -210,11 +206,9 @@ def SmmlaOp : ArmSVE_Op<"smmla",
     "$acc `,` $src1 `,` $src2 attr-dict `:` type($src1) `to` type($dst)";
 }
 
-def UdotOp : ArmSVE_Op<"udot",
-               [Pure,
-               AllTypesMatch<["src1", "src2"]>,
-               AllTypesMatch<["acc", "dst"]>,
-             ]> {
+def UdotOp : ArmSVE_Op<"udot", [Pure,
+                                AllTypesMatch<["src1", "src2"]>,
+                                AllTypesMatch<["acc", "dst"]>]> {
   let summary = "Vector-vector dot product and accumulate op";
   let description = [{
     UDOT: Unsigned integer addition of dot product.
@@ -241,11 +235,9 @@ def UdotOp : ArmSVE_Op<"udot",
     "$acc `,` $src1 `,` $src2 attr-dict `:` type($src1) `to` type($dst)";
 }
 
-def UmmlaOp : ArmSVE_Op<"ummla",
-                [Pure,
-                AllTypesMatch<["src1", "src2"]>,
-                AllTypesMatch<["acc", "dst"]>,
-              ]> {
+def UmmlaOp : ArmSVE_Op<"ummla", [Pure,
+                                  AllTypesMatch<["src1", "src2"]>,
+                                  AllTypesMatch<["acc", "dst"]>]> {
   let summary = "Matrix-matrix multiply and accumulate op";
   let description = [{
     UMMLA: Unsigned integer matrix multiply-accumulate.
@@ -273,14 +265,42 @@ def UmmlaOp : ArmSVE_Op<"ummla",
     "$acc `,` $src1 `,` $src2 attr-dict `:` type($src1) `to` type($dst)";
 }
 
+def UsmmlaOp : ArmSVE_Op<"usmmla", [Pure,
+                                    AllTypesMatch<["src1", "src2"]>,
+                                    AllTypesMatch<["acc", "dst"]>]> {
+  let summary = "Matrix-matrix multiply and accumulate op";
+  let description = [{
+    USMMLA: Unsigned by signed integer matrix multiply-accumulate.
+
+    The unsigned by signed integer matrix multiply-accumulate operation
+    multiplies the 2×8 matrix of unsigned 8-bit integer values held in
+    the first source vector by the 8×2 matrix of signed 8-bit integer
+    values in the second source vector. The resulting 2×2 widened 32-bit
+    integer matrix product is then added to the 32-bit integer matrix
+    accumulator.
+
+    Source:
+    https://developer.arm.com/documentation/100987/0000
+  }];
+  // Supports (vector<16xi8>, vector<16xi8>) -> (vector<4xi32>)
+  let arguments = (ins
+          ScalableVectorOfLengthAndType<[4], [I32]>:$acc,
+          ScalableVectorOfLengthAndType<[16], [I8]>:$src1,
+          ScalableVectorOfLengthAndType<[16], [I8]>:$src2
+  );
+  let results = (outs ScalableVectorOfLengthAndType<[4], [I32]>:$dst);
+  let assemblyFormat =
+    "$acc `,` $src1 `,` $src2 attr-dict `:` type($src1) `to` type($dst)";
+}
+
 class SvboolTypeConstraint<string lhsArg, string rhsArg> : TypesMatchWith<
   "expected corresponding svbool type widened to [16]xi1",
   lhsArg, rhsArg,
   "VectorType(VectorType::Builder(::llvm::cast<VectorType>($_self)).setDim(::llvm::cast<VectorType>($_self).getRank() - 1, 16))">;
 
 def ConvertFromSvboolOp : ArmSVE_Op<"convert_from_svbool",
-                            [Pure, SvboolTypeConstraint<"result", "source">]>
-{
+                                    [Pure,
+                                     SvboolTypeConstraint<"result", "source">]> {
   let summary = "Convert a svbool type to a SVE predicate type";
   let description = [{
     Converts svbool types (`vector<[16]xi1>` or vectors of that type, e.g.
@@ -313,8 +333,8 @@ def ConvertFromSvboolOp : ArmSVE_Op<"convert_from_svbool",
 }
 
 def ConvertToSvboolOp : ArmSVE_Op<"convert_to_svbool",
-                            [Pure, SvboolTypeConstraint<"source", "result">]>
-{
+                                  [Pure,
+                                   SvboolTypeConstraint<"source", "result">]> {
   let summary = "Convert a SVE predicate type to a svbool type";
   let description = [{
     Converts SVE predicate types (or vectors of predicate types, e.g.
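(These two conversion ops are only reformatted by this commit; for context, a
sketch of their use, with value names invented for illustration: an SVE
predicate such as vector<[4]xi1> widens to the svbool form vector<[16]xi1>
and converts back.)

    %svbool = arm_sve.convert_to_svbool %mask : vector<[4]xi1>
    // %svbool : vector<[16]xi1>
    %mask2 = arm_sve.convert_from_svbool %svbool : vector<[4]xi1>
    // %mask2 : vector<[4]xi1>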
@@ -356,10 +376,9 @@ def ZipInputVectorType : AnyTypeOf<[
     Scalable1DVectorOfLength<16, [I8]>],
   "an SVE vector with element size <= 64-bit">;
 
-def ZipX2Op : ArmSVE_Op<"zip.x2", [
-  Pure,
-  AllTypesMatch<["sourceV1", "sourceV2", "resultV1", "resultV2"]>]
-> {
+def ZipX2Op : ArmSVE_Op<"zip.x2", [Pure,
+                                   AllTypesMatch<["sourceV1", "sourceV2",
+                                                  "resultV1", "resultV2"]>]> {
   let summary = "Multi-vector two-way zip op";
 
   let description = [{
@@ -400,12 +419,11 @@ def ZipX2Op : ArmSVE_Op<"zip.x2", [
   }];
 }
 
-def ZipX4Op : ArmSVE_Op<"zip.x4", [
-  Pure,
-  AllTypesMatch<[
-    "sourceV1", "sourceV2", "sourceV3", "sourceV4",
-    "resultV1", "resultV2", "resultV3", "resultV4"]>]
-> {
+def ZipX4Op
+    : ArmSVE_Op<"zip.x4",
+                [Pure,
+                 AllTypesMatch<["sourceV1", "sourceV2", "sourceV3", "sourceV4",
+                                "resultV1", "resultV2", "resultV3", "resultV4"]>]> {
   let summary = "Multi-vector four-way zip op";
 
   let description = [{
@@ -463,10 +481,7 @@ def ZipX4Op : ArmSVE_Op<"zip.x4", [
   }];
 }
 
-def PselOp : ArmSVE_Op<"psel", [
-  Pure,
-  AllTypesMatch<["p1", "result"]>,
-]> {
+def PselOp : ArmSVE_Op<"psel", [Pure, AllTypesMatch<["p1", "result"]>]> {
   let summary = "Predicate select";
 
   let description = [{
@@ -571,6 +586,10 @@ def SmmlaIntrOp :
   ArmSVE_IntrBinaryOverloadedOp<"smmla">,
   Arguments<(ins AnyScalableVectorOfAnyRank, AnyScalableVectorOfAnyRank, AnyScalableVectorOfAnyRank)>;
 
+def UsmmlaIntrOp :
+  ArmSVE_IntrBinaryOverloadedOp<"usmmla">,
+  Arguments<(ins AnyScalableVectorOfAnyRank, AnyScalableVectorOfAnyRank, AnyScalableVectorOfAnyRank)>;
+
 def SdotIntrOp :
   ArmSVE_IntrBinaryOverloadedOp<"sdot">,
   Arguments<(ins AnyScalableVectorOfAnyRank, AnyScalableVectorOfAnyRank, AnyScalableVectorOfAnyRank)>;
--- a/mlir/lib/Dialect/ArmSVE/Transforms/LegalizeForLLVMExport.cpp
+++ b/mlir/lib/Dialect/ArmSVE/Transforms/LegalizeForLLVMExport.cpp
@@ -24,6 +24,7 @@ using SdotOpLowering = OneToOneConvertToLLVMPattern<SdotOp, SdotIntrOp>;
 using SmmlaOpLowering = OneToOneConvertToLLVMPattern<SmmlaOp, SmmlaIntrOp>;
 using UdotOpLowering = OneToOneConvertToLLVMPattern<UdotOp, UdotIntrOp>;
 using UmmlaOpLowering = OneToOneConvertToLLVMPattern<UmmlaOp, UmmlaIntrOp>;
+using UsmmlaOpLowering = OneToOneConvertToLLVMPattern<UsmmlaOp, UsmmlaIntrOp>;
 using DupQLaneLowering =
     OneToOneConvertToLLVMPattern<DupQLaneOp, DupQLaneIntrOp>;
 using ScalableMaskedAddIOpLowering =
@@ -206,6 +207,7 @@ void mlir::populateArmSVELegalizeForLLVMExportPatterns(
                SmmlaOpLowering,
                UdotOpLowering,
                UmmlaOpLowering,
+               UsmmlaOpLowering,
                ZipX2OpLowering,
                ZipX4OpLowering,
                SdotOpLowering>(converter);
@@ -234,6 +236,7 @@ void mlir::configureArmSVELegalizeForExportTarget(
                     SmmlaIntrOp,
                     UdotIntrOp,
                     UmmlaIntrOp,
+                    UsmmlaIntrOp,
                     WhileLTIntrOp,
                     ZipX2IntrOp,
                     ZipX4IntrOp,
@@ -254,6 +257,7 @@ void mlir::configureArmSVELegalizeForExportTarget(
                SmmlaOp,
                UdotOp,
                UmmlaOp,
+               UsmmlaOp,
                ZipX2Op,
                ZipX4Op,
                SdotOp>();
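For reference, the end-to-end lowering path these patterns enable, sketched
from the tests that follow (value names are illustrative):

    // ArmSVE dialect op:
    %0 = arm_sve.usmmla %acc, %a, %b : vector<[16]xi8> to vector<[4]xi32>

    // After legalization via UsmmlaOpLowering (a one-to-one rewrite to the
    // intrinsic op):
    %0 = "arm_sve.intr.usmmla"(%acc, %a, %b) :
        (vector<[4]xi32>, vector<[16]xi8>, vector<[16]xi8>) -> vector<[4]xi32>

    // Which translates to the LLVM IR intrinsic call:
    //   call <vscale x 4 x i32> @llvm.aarch64.sve.usmmla.nxv4i32(...)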
--- a/mlir/test/Dialect/ArmSVE/legalize-for-llvm.mlir
+++ b/mlir/test/Dialect/ArmSVE/legalize-for-llvm.mlir
@@ -48,6 +48,18 @@ func.func @arm_sve_ummla(%a: vector<[16]xi8>,
 
 // -----
 
+func.func @arm_sve_usmmla(%a: vector<[16]xi8>,
+                          %b: vector<[16]xi8>,
+                          %c: vector<[4]xi32>)
+                          -> vector<[4]xi32> {
+  // CHECK: arm_sve.intr.usmmla
+  %0 = arm_sve.usmmla %c, %a, %b :
+             vector<[16]xi8> to vector<[4]xi32>
+  return %0 : vector<[4]xi32>
+}
+
+// -----
+
 func.func @arm_sve_arithi_masked(%a: vector<[4]xi32>,
                                  %b: vector<[4]xi32>,
                                  %c: vector<[4]xi32>,
--- a/mlir/test/Dialect/ArmSVE/roundtrip.mlir
+++ b/mlir/test/Dialect/ArmSVE/roundtrip.mlir
@@ -44,6 +44,17 @@ func.func @arm_sve_ummla(%a: vector<[16]xi8>,
 
 // -----
 
+func.func @arm_sve_usmmla(%a: vector<[16]xi8>,
+                          %b: vector<[16]xi8>,
+                          %c: vector<[4]xi32>) -> vector<[4]xi32> {
+  // CHECK: arm_sve.usmmla {{.*}}: vector<[16]xi8> to vector<[4]xi32>
+  %0 = arm_sve.usmmla %c, %a, %b :
+             vector<[16]xi8> to vector<[4]xi32>
+  return %0 : vector<[4]xi32>
+}
+
+// -----
+
 func.func @arm_sve_masked_arithi(%a: vector<[4]xi32>,
                                  %b: vector<[4]xi32>,
                                  %c: vector<[4]xi32>,
--- a/mlir/test/Target/LLVMIR/arm-sve.mlir
+++ b/mlir/test/Target/LLVMIR/arm-sve.mlir
@@ -48,6 +48,18 @@ llvm.func @arm_sve_ummla(%arg0: vector<[16]xi8>,
   llvm.return %0 : vector<[4]xi32>
 }
 
+// CHECK-LABEL: define <vscale x 4 x i32> @arm_sve_usmmla
+llvm.func @arm_sve_usmmla(%arg0: vector<[16]xi8>,
+                          %arg1: vector<[16]xi8>,
+                          %arg2: vector<[4]xi32>)
+                          -> vector<[4]xi32> {
+  // CHECK: call <vscale x 4 x i32> @llvm.aarch64.sve.usmmla.nxv4i32(<vscale x 4
+  %0 = "arm_sve.intr.usmmla"(%arg2, %arg0, %arg1) :
+    (vector<[4]xi32>, vector<[16]xi8>, vector<[16]xi8>)
+        -> vector<[4]xi32>
+  llvm.return %0 : vector<[4]xi32>
+}
+
 // CHECK-LABEL: define <vscale x 4 x i32> @arm_sve_arithi
 llvm.func @arm_sve_arithi(%arg0: vector<[4]xi32>,
                           %arg1: vector<[4]xi32>,