[MLIR] [AMX] Utilize x86_amx type for AMX dialect in MLIR. (#111197)

This patch is intended to resolve #109481 and improve the usability of
the AMX dialect.

In LLVM IR, AMX intrinsics use `x86_amx`, which is one of the primitive
types. This type is supposed to be used for AMX intrinsic calls and no
other operations. The AMX dialect of MLIR uses regular 2D vector types,
which are then lowered to arrays of vectors in the LLVMIR dialect. This
creates an inconsistency between the types used in the LLVMIR dialect
and in LLVM IR. Translation of AMX intrinsic calls to LLVM IR doesn't
require result types to match, and that is where tile-load and mul
operation results get the `x86_amx` type. This works in very simple
cases, when mul and tile-store operations directly consume the result
of another AMX intrinsic call, but it breaks when an argument is a
block argument (phi node).
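
For illustration, here is a minimal sketch of the failing pattern under
the old vector-based types (value and buffer names are hypothetical):
the accumulator tile is a loop-carried block argument, i.e. a phi node
after lowering, so translation cannot assign it the `x86_amx` type.

```mlir
// Old typing: AMX tiles are plain 2D vectors. The loop-carried
// accumulator %acc becomes a phi node in LLVM IR, and translation
// cannot give it the x86_amx type.
%init = amx.tile_zero : vector<16x16xf32>
%res = scf.for %i = %c0 to %cn step %c1
    iter_args(%acc = %init) -> (vector<16x16xf32>) {
  %a = amx.tile_load %pa[%i, %c0] : memref<?x?xbf16> into vector<16x32xbf16>
  %b = amx.tile_load %pb[%i, %c0] : memref<?x?xbf16> into vector<16x32xbf16>
  %d = amx.tile_mulf %a, %b, %acc
      : vector<16x32xbf16>, vector<16x32xbf16>, vector<16x16xf32>
  scf.yield %d : vector<16x16xf32>
}
```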

In addition to the translation problems, this inconsistency between the
types used in MLIR and LLVM IR makes MLIR verification and
transformation quite problematic. Both `amx.tile_load` and
`vector.transfer_read` can load values of the same type, but only one
of them can be used in AMX operations. In general, by looking at the
type of a value, we cannot determine whether it can be used only in
AMX operations or, on the contrary, in any operations except AMX ones.
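
For example (a sketch; value names are hypothetical), both results
below have type `vector<16x16xf32>`, yet only the first may legally
feed AMX operations, and nothing in the type records that distinction:

```mlir
%c0 = arith.constant 0 : index
%pad = arith.constant 0.0 : f32
// Only %0 is a real AMX tile; %1 is an ordinary vector value.
%0 = amx.tile_load %buf[%c0, %c0] : memref<?x?xf32> into vector<16x16xf32>
%1 = vector.transfer_read %buf[%c0, %c0], %pad
    : memref<?x?xf32>, vector<16x16xf32>
```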

To remove this inconsistency and make AMX operations more explicit
about their limitations, I propose to add an `LLVMX86AMXType` to the
LLVMIR dialect to match the `x86_amx` type in LLVM IR, and to introduce
`amx::TileType` to be used by AMX operations in MLIR. This resolves the
translation problems for AMX usage with phi nodes and provides proper
type verification in MLIR for AMX operations.
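
With the new types, the loop from the earlier sketch reads as follows
(names again hypothetical) and translates correctly, since
`!amx.tile<...>` is converted to `!llvm.x86_amx` during lowering:

```mlir
// New typing: tiles are explicit, so the loop-carried accumulator
// keeps a dedicated AMX tile type all the way to LLVM IR.
%init = amx.tile_zero : !amx.tile<16x16xf32>
%res = scf.for %i = %c0 to %cn step %c1
    iter_args(%acc = %init) -> (!amx.tile<16x16xf32>) {
  %a = amx.tile_load %pa[%i, %c0] : memref<?x?xbf16> into !amx.tile<16x32xbf16>
  %b = amx.tile_load %pb[%i, %c0] : memref<?x?xbf16> into !amx.tile<16x32xbf16>
  %d = amx.tile_mulf %a, %b, %acc
      : !amx.tile<16x32xbf16>, !amx.tile<16x32xbf16>, !amx.tile<16x16xf32>
  scf.yield %d : !amx.tile<16x16xf32>
}
```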

P.S. This patch also adds missing FP16 support. It's trivial but
unrelated to type system changes, so let me know if I should submit it
separately.

---------

Signed-off-by: Ilya Enkovich <ilya.enkovich@intel.com>
Author: Ilya Enkovich
Date: 2024-11-06 08:30:53 -06:00 (committed by GitHub)
Parent: 44ab3805b5
Commit: 2f743ac52e
14 changed files with 329 additions and 140 deletions


@@ -30,6 +30,8 @@
include "mlir/Dialect/LLVMIR/LLVMOpBase.td"
include "mlir/Interfaces/SideEffectInterfaces.td"
include "mlir/IR/AttrTypeBase.td"
include "mlir/IR/BuiltinTypes.td"
//===----------------------------------------------------------------------===//
// AMX dialect definition.
@@ -55,8 +57,77 @@ def AMX_Dialect : Dialect {
For details, see the Intel documentation:
https://software.intel.com/content/www/us/en/develop/articles/intel-sdm.html
}];
let useDefaultTypePrinterParser = 1;
}
//===----------------------------------------------------------------------===//
// AMX Tile definition.
//===----------------------------------------------------------------------===//
class AMX_Type<string typeName, string typeMnemonic, list<Trait> traits = []>
: TypeDef<AMX_Dialect, typeName, traits> {
let mnemonic = typeMnemonic;
}
def AMX_TileTypeElementType : AnyTypeOf<[F32, F16, BF16, I32, I8]> {
let cppFunctionName = "isValidTileTypeElementType";
}
def AMX_TileType : AMX_Type<"Tile", "tile", [ShapedTypeInterface, ValueSemantics]> {
let summary = "AMX 2D tile to be used by AMX operations.";
let description = [{
This type is used to represent values in AMX tile registers. All AMX operations
work on AMX tiles and these tiles cannot be used in other operations directly.
The LLVM IR type for an AMX tile is a primitive type, but in MLIR we provide a
shape and element type for IR verification and lowering to the LLVMIR dialect.
}];
let parameters = (ins
ArrayRefParameter<"int64_t">:$shape,
AMX_TileTypeElementType:$elementType
);
let builders = [
TypeBuilderWithInferredContext<(ins
"ArrayRef<int64_t>":$shape, "Type":$elementType), [{
return $_get(elementType.getContext(), shape, elementType);
}]>
];
let extraClassDeclaration = [{
/// Returns whether this type is ranked (always true).
bool hasRank() const { return true; }
/// Clone this tile type with the given shape and element type. If the
/// provided shape is `std::nullopt`, the current shape of the type is used.
TileType cloneWith(std::optional<ArrayRef<int64_t>> shape,
Type elementType) const {
return get(shape.value_or(getShape()), elementType);
}
}];
let hasCustomAssemblyFormat = 1;
let skipDefaultBuilders = 1;
}
def IsAMXTilePred : And<[CPred<"::llvm::isa<::mlir::amx::TileType>($_self)">,
CPred<[{::llvm::cast<::mlir::amx::TileType>($_self).getRank() == 2}]>]>;
class AMXTileOf<list<Type> allowedTypes> :
ShapedContainerType<allowedTypes, IsAMXTilePred, "tile",
"::mlir::amx::TileType">;
def AnyAMXTile : AMXTileOf<[F32, F16, BF16, I32, I8]>;
def AMXTileF32 : AMXTileOf<[F32]>;
def AMXTileF16OrBF16 : AMXTileOf<[F16, BF16]>;
def AMXTileI32 : AMXTileOf<[I32]>;
def AMXTileI8 : AMXTileOf<[I8]>;
//===----------------------------------------------------------------------===//
// AMX Op and IntrOp definitions.
//===----------------------------------------------------------------------===//
@@ -88,17 +159,16 @@ def TileZeroOp : AMX_Op<"tile_zero", [Pure]> {
Example:
```mlir
%0 = amx.tile_zero : vector<16x16xbf16>
%0 = amx.tile_zero : !amx.tile<16x16xbf16>
```
}];
let results = (outs
VectorOfRankAndType<[2], [F32, BF16, I32, I8]>:$res);
let results = (outs AnyAMXTile:$res);
let extraClassDeclaration = [{
VectorType getVectorType() {
return ::llvm::cast<VectorType>(getRes().getType());
TileType getTileType() {
return ::llvm::cast<TileType>(getRes().getType());
}
}];
let assemblyFormat = "attr-dict `:` type($res)";
let assemblyFormat = "attr-dict `:` qualified(type($res))";
let hasVerifier = 1;
}
@@ -117,23 +187,22 @@ def TileLoadOp : AMX_Op<"tile_load", [Pure]> {
Example:
```mlir
%0 = amx.tile_load %arg0[%c0, %c0] : memref<?x?xi8> into vector<16x64xi8>
%0 = amx.tile_load %arg0[%c0, %c0] : memref<?x?xi8> into !amx.tile<16x64xi8>
```
}];
let arguments = (ins Arg<AnyMemRef, "load base", [MemRead]>:$base,
Variadic<Index>:$indices);
let results = (outs
VectorOfRankAndType<[2], [F32, BF16, I32, I8]>:$res);
let results = (outs AnyAMXTile:$res);
let extraClassDeclaration = [{
MemRefType getMemRefType() {
return ::llvm::cast<MemRefType>(getBase().getType());
}
VectorType getVectorType() {
return ::llvm::cast<VectorType>(getRes().getType());
TileType getTileType() {
return ::llvm::cast<TileType>(getRes().getType());
}
}];
let assemblyFormat = "$base `[` $indices `]` attr-dict `:` "
"type($base) `into` type($res)";
"type($base) `into` qualified(type($res))";
let hasVerifier = 1;
}
@@ -148,22 +217,22 @@ def TileStoreOp : AMX_Op<"tile_store"> {
Example:
```mlir
amx.tile_store %arg1[%c0, %c0], %0 : memref<?x?xi8>, vector<16x64xi8>
amx.tile_store %arg1[%c0, %c0], %0 : memref<?x?xi8>, !amx.tile<16x64xi8>
```
}];
let arguments = (ins Arg<AnyMemRef, "store base", [MemWrite]>:$base,
Variadic<Index>:$indices,
VectorOfRankAndType<[2], [F32, BF16, I32, I8]>:$val);
AnyAMXTile:$val);
let extraClassDeclaration = [{
MemRefType getMemRefType() {
return ::llvm::cast<MemRefType>(getBase().getType());
}
VectorType getVectorType() {
return ::llvm::cast<VectorType>(getVal().getType());
TileType getTileType() {
return ::llvm::cast<TileType>(getVal().getType());
}
}];
let assemblyFormat = "$base `[` $indices `]` `,` $val attr-dict `:` "
"type($base) `,` type($val)";
"type($base) `,` qualified(type($val))";
let hasVerifier = 1;
}
@@ -184,26 +253,27 @@ def TileMulFOp : AMX_Op<"tile_mulf", [
```mlir
%0 = amx.tile_mulf %a, %b, %c
: vector<16x32xbf16>, vector<16x32xbf16>, vector<16x16xf32>
: !amx.tile<16x32xbf16>, !amx.tile<16x32xbf16>, !amx.tile<16x16xf32>
```
}];
let arguments = (ins VectorOfRankAndType<[2], [F32, BF16]>:$lhs,
VectorOfRankAndType<[2], [F32, BF16]>:$rhs,
VectorOfRankAndType<[2], [F32, BF16]>:$acc);
let results = (outs VectorOfRankAndType<[2], [F32, BF16]>:$res);
let arguments = (ins AMXTileF16OrBF16:$lhs,
AMXTileF16OrBF16:$rhs,
AMXTileF32:$acc);
let results = (outs AMXTileF32:$res);
let extraClassDeclaration = [{
VectorType getLhsVectorType() {
return ::llvm::cast<VectorType>(getLhs().getType());
TileType getLhsTileType() {
return ::llvm::cast<TileType>(getLhs().getType());
}
VectorType getRhsVectorType() {
return ::llvm::cast<VectorType>(getRhs().getType());
TileType getRhsTileType() {
return ::llvm::cast<TileType>(getRhs().getType());
}
VectorType getVectorType() {
return ::llvm::cast<VectorType>(getRes().getType());
TileType getTileType() {
return ::llvm::cast<TileType>(getRes().getType());
}
}];
let assemblyFormat = "$lhs `,` $rhs `,` $acc attr-dict `:` "
"type($lhs) `,` type($rhs) `,` type($acc) ";
"qualified(type($lhs)) `,` qualified(type($rhs))"
" `,` qualified(type($acc)) ";
let hasVerifier = 1;
}
@@ -223,29 +293,29 @@ def TileMulIOp : AMX_Op<"tile_muli", [
```mlir
%0 = amx.tile_muli %a zext, %b zext, %c
: vector<16x64xi8>, vector<16x64xi8>, vector<16x16xi32>
: !amx.tile<16x64xi8>, !amx.tile<16x64xi8>, !amx.tile<16x16xi32>
```
}];
let arguments = (ins VectorOfRankAndType<[2], [I32, I8]>:$lhs,
VectorOfRankAndType<[2], [I32, I8]>:$rhs,
VectorOfRankAndType<[2], [I32, I8]>:$acc,
let arguments = (ins AMXTileI8:$lhs,
AMXTileI8:$rhs,
AMXTileI32:$acc,
UnitAttr:$isZextLhs,
UnitAttr:$isZextRhs
);
let results = (outs VectorOfRankAndType<[2], [I32, I8]>:$res);
let results = (outs AMXTileI32:$res);
let extraClassDeclaration = [{
VectorType getLhsVectorType() {
return ::llvm::cast<VectorType>(getLhs().getType());
TileType getLhsTileType() {
return ::llvm::cast<TileType>(getLhs().getType());
}
VectorType getRhsVectorType() {
return ::llvm::cast<VectorType>(getRhs().getType());
TileType getRhsTileType() {
return ::llvm::cast<TileType>(getRhs().getType());
}
VectorType getVectorType() {
return ::llvm::cast<VectorType>(getRes().getType());
TileType getTileType() {
return ::llvm::cast<TileType>(getRes().getType());
}
}];
let assemblyFormat = "$lhs (`zext` $isZextLhs^)? `,` $rhs (`zext` $isZextRhs^)? `,` $acc attr-dict `:` "
"type($lhs) `,` type($rhs) `,` type($acc) ";
"qualified(type($lhs)) `,` qualified(type($rhs)) `,` qualified(type($acc)) ";
let hasVerifier = 1;
}
@@ -286,6 +356,12 @@ def LLVM_x86_amx_tdpbf16ps : AMX_IntrOp<"tdpbf16ps", 1>,
AnyInteger,
AnyInteger, LLVM_Type, LLVM_Type, LLVM_Type)>;
// Dot product of f16 tiles into f32 tile.
def LLVM_x86_amx_tdpfp16ps : AMX_IntrOp<"tdpfp16ps", 1>,
Arguments<(ins AnyInteger,
AnyInteger,
AnyInteger, LLVM_Type, LLVM_Type, LLVM_Type)>;
// Dot product of i8 tiles into i32 tile (with sign/sign extension).
def LLVM_x86_amx_tdpbssd : AMX_IntrOp<"tdpbssd", 1>,
Arguments<(ins AnyInteger,


@@ -21,6 +21,9 @@
#include "mlir/Dialect/AMX/AMXDialect.h.inc"
#define GET_TYPEDEF_CLASSES
#include "mlir/Dialect/AMX/AMXTypes.h.inc"
#define GET_OP_CLASSES
#include "mlir/Dialect/AMX/AMX.h.inc"


@@ -14,16 +14,20 @@ namespace mlir {
class LLVMConversionTarget;
class LLVMTypeConverter;
class RewritePatternSet;
class DialectRegistry;
/// Collect a set of patterns to lower AMX ops to ops that map to LLVM
/// intrinsics.
void populateAMXLegalizeForLLVMExportPatterns(
const LLVMTypeConverter &converter, RewritePatternSet &patterns);
void populateAMXLegalizeForLLVMExportPatterns(LLVMTypeConverter &converter,
RewritePatternSet &patterns);
/// Configure the target to support lowering AMX ops to ops that map to LLVM
/// intrinsics.
void configureAMXLegalizeForExportTarget(LLVMConversionTarget &target);
/// Register LLVM conversion interface for AMX dialect.
void registerConvertAMXToLLVMInterface(DialectRegistry &registry);
} // namespace mlir
#endif // MLIR_DIALECT_AMX_TRANSFORMS_H


@@ -250,4 +250,16 @@ def LLVMTargetExtType : LLVMType<"LLVMTargetExt", "target"> {
}];
}
//===----------------------------------------------------------------------===//
// LLVMX86AMXType
//===----------------------------------------------------------------------===//
def LLVMX86AMXType : LLVMType<"LLVMX86AMX", "x86_amx"> {
let summary = "LLVM x86_amx type.";
let description = [{
The x86_amx type represents a value held in an AMX tile register on an x86
machine. It can only be used in AMX intrinsic calls.
}];
}
#endif // LLVMTYPES_TD


@@ -24,6 +24,7 @@
#include "mlir/Conversion/NVVMToLLVM/NVVMToLLVM.h"
#include "mlir/Conversion/OpenMPToLLVM/ConvertOpenMPToLLVM.h"
#include "mlir/Conversion/UBToLLVM/UBToLLVM.h"
#include "mlir/Dialect/AMX/Transforms.h"
#include "mlir/Dialect/Affine/TransformOps/AffineTransformOps.h"
#include "mlir/Dialect/Bufferization/TransformOps/BufferizationTransformOps.h"
#include "mlir/Dialect/DLTI/TransformOps/DLTITransformOps.h"
@@ -70,6 +71,7 @@ inline void registerAllExtensions(DialectRegistry &registry) {
registerConvertNVVMToLLVMInterface(registry);
registerConvertOpenMPToLLVMInterface(registry);
ub::registerConvertUBToLLVMInterface(registry);
registerConvertAMXToLLVMInterface(registry);
// Register all transform dialect extensions.
affine::registerTransformDialectExtension(registry);


@@ -13,14 +13,22 @@
#include "mlir/Dialect/AMX/AMXDialect.h"
#include "mlir/Dialect/LLVMIR/LLVMTypes.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/DialectImplementation.h"
#include "mlir/IR/OpImplementation.h"
#include "mlir/IR/TypeUtilities.h"
#include "llvm/ADT/TypeSwitch.h"
using namespace mlir;
#include "mlir/Dialect/AMX/AMXDialect.cpp.inc"
void amx::AMXDialect::initialize() {
addTypes<
#define GET_TYPEDEF_LIST
#include "mlir/Dialect/AMX/AMXTypes.cpp.inc"
>();
addOperations<
#define GET_OP_LIST
#include "mlir/Dialect/AMX/AMX.cpp.inc"
@@ -28,7 +36,7 @@ void amx::AMXDialect::initialize() {
}
/// Verify that AMX supports the implied tile shape.
static LogicalResult verifyTileSize(Operation *op, VectorType tp) {
static LogicalResult verifyTileSize(Operation *op, amx::TileType tp) {
const unsigned kMaxRows = 16;
const unsigned kBitsPerRow = 64 * 8;
unsigned col = tp.getDimSize(1) * tp.getElementType().getIntOrFloatBitWidth();
@@ -40,8 +48,8 @@ static LogicalResult verifyTileSize(Operation *op, VectorType tp) {
}
/// Verify that AMX supports the multiplication.
static LogicalResult verifyMultShape(Operation *op, VectorType atp,
VectorType btp, VectorType ctp,
static LogicalResult verifyMultShape(Operation *op, amx::TileType atp,
amx::TileType btp, amx::TileType ctp,
unsigned scale) {
unsigned am = atp.getDimSize(0), ak = atp.getDimSize(1) >> scale;
unsigned bk = btp.getDimSize(0), bn = btp.getDimSize(1) >> scale;
@@ -53,27 +61,27 @@ static LogicalResult verifyMultShape(Operation *op, VectorType atp,
}
LogicalResult amx::TileZeroOp::verify() {
return verifyTileSize(*this, getVectorType());
return verifyTileSize(*this, getTileType());
}
LogicalResult amx::TileLoadOp::verify() {
unsigned rank = getMemRefType().getRank();
if (getIndices().size() != rank)
return emitOpError("requires ") << rank << " indices";
return verifyTileSize(*this, getVectorType());
return verifyTileSize(*this, getTileType());
}
LogicalResult amx::TileStoreOp::verify() {
unsigned rank = getMemRefType().getRank();
if (getIndices().size() != rank)
return emitOpError("requires ") << rank << " indices";
return verifyTileSize(*this, getVectorType());
return verifyTileSize(*this, getTileType());
}
LogicalResult amx::TileMulFOp::verify() {
VectorType aType = getLhsVectorType();
VectorType bType = getRhsVectorType();
VectorType cType = getVectorType();
amx::TileType aType = getLhsTileType();
amx::TileType bType = getRhsTileType();
amx::TileType cType = getTileType();
if (failed(verifyTileSize(*this, aType)) ||
failed(verifyTileSize(*this, bType)) ||
failed(verifyTileSize(*this, cType)) ||
@@ -82,15 +90,15 @@ LogicalResult amx::TileMulFOp::verify() {
Type ta = aType.getElementType();
Type tb = bType.getElementType();
Type tc = cType.getElementType();
if (!ta.isBF16() || !tb.isBF16() || !tc.isF32())
if ((!ta.isBF16() && !ta.isF16()) || (ta != tb) || !tc.isF32())
return emitOpError("unsupported type combination");
return success();
}
LogicalResult amx::TileMulIOp::verify() {
VectorType aType = getLhsVectorType();
VectorType bType = getRhsVectorType();
VectorType cType = getVectorType();
amx::TileType aType = getLhsTileType();
amx::TileType bType = getRhsTileType();
amx::TileType cType = getTileType();
if (failed(verifyTileSize(*this, aType)) ||
failed(verifyTileSize(*this, bType)) ||
failed(verifyTileSize(*this, cType)) ||
@@ -104,5 +112,34 @@ LogicalResult amx::TileMulIOp::verify() {
return success();
}
Type amx::TileType::parse(AsmParser &parser) {
if (parser.parseLess())
return nullptr;
SmallVector<int64_t, 2> shape;
if (parser.parseDimensionList(shape, false, true))
return nullptr;
Type elementType;
if (parser.parseType(elementType))
return nullptr;
if (parser.parseGreater())
return nullptr;
return TileType::get(shape, elementType);
}
void amx::TileType::print(AsmPrinter &os) const {
os << "<";
os.printDimensionList(getShape());
os << 'x';
os.printType(getElementType());
os << '>';
}
#define GET_OP_CLASSES
#include "mlir/Dialect/AMX/AMX.cpp.inc"
#define GET_TYPEDEF_CLASSES
#include "mlir/Dialect/AMX/AMXTypes.cpp.inc"


@@ -8,6 +8,7 @@
#include "mlir/Dialect/AMX/Transforms.h"
#include "mlir/Conversion/ConvertToLLVM/ToLLVMInterface.h"
#include "mlir/Conversion/LLVMCommon/ConversionTarget.h"
#include "mlir/Conversion/LLVMCommon/Pattern.h"
#include "mlir/Dialect/AMX/AMXDialect.h"
@@ -25,13 +26,13 @@ namespace {
/// The second dimension needs to be scaled by the number of bytes.
std::pair<Value, Value> getTileSizes(ConversionPatternRewriter &rewriter,
const LLVMTypeConverter &typeConverter,
VectorType vType, Location loc) {
amx::TileType tType, Location loc) {
Type llvmInt16Type = IntegerType::get(&typeConverter.getContext(), 16);
unsigned width = vType.getElementType().getIntOrFloatBitWidth();
unsigned width = tType.getElementType().getIntOrFloatBitWidth();
assert(llvm::isPowerOf2_64(width) && width >= 8);
unsigned bytes = width >> 3;
auto mattr = rewriter.getI16IntegerAttr(vType.getDimSize(0));
auto nattr = rewriter.getI16IntegerAttr(vType.getDimSize(1) * bytes);
auto mattr = rewriter.getI16IntegerAttr(tType.getDimSize(0));
auto nattr = rewriter.getI16IntegerAttr(tType.getDimSize(1) * bytes);
return std::make_pair(
rewriter.create<LLVM::ConstantOp>(loc, llvmInt16Type, mattr),
rewriter.create<LLVM::ConstantOp>(loc, llvmInt16Type, nattr));
@@ -76,12 +77,12 @@ struct TileZeroConversion : public ConvertOpToLLVMPattern<TileZeroOp> {
LogicalResult
matchAndRewrite(TileZeroOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
VectorType vType = op.getVectorType();
amx::TileType tType = op.getTileType();
// Determine m x n tile sizes.
std::pair<Value, Value> tsz =
getTileSizes(rewriter, *getTypeConverter(), vType, op.getLoc());
getTileSizes(rewriter, *getTypeConverter(), tType, op.getLoc());
// Replace operation with intrinsic.
Type resType = typeConverter->convertType(vType);
Type resType = typeConverter->convertType(tType);
rewriter.replaceOpWithNewOp<amx::x86_amx_tilezero>(op, resType, tsz.first,
tsz.second);
return success();
@@ -95,10 +96,10 @@ struct TileLoadConversion : public ConvertOpToLLVMPattern<TileLoadOp> {
matchAndRewrite(TileLoadOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
MemRefType mType = op.getMemRefType();
VectorType vType = op.getVectorType();
amx::TileType tType = op.getTileType();
// Determine m x n tile sizes.
std::pair<Value, Value> tsz =
getTileSizes(rewriter, *getTypeConverter(), vType, op.getLoc());
getTileSizes(rewriter, *getTypeConverter(), tType, op.getLoc());
// Determine stride.
auto stride = getStride(rewriter, *getTypeConverter(), mType,
adaptor.getBase(), op.getLoc());
@@ -107,7 +108,7 @@ struct TileLoadConversion : public ConvertOpToLLVMPattern<TileLoadOp> {
// Replace operation with intrinsic.
Value ptr = getStridedElementPtr(op.getLoc(), mType, adaptor.getBase(),
adaptor.getIndices(), rewriter);
Type resType = typeConverter->convertType(vType);
Type resType = typeConverter->convertType(tType);
rewriter.replaceOpWithNewOp<amx::x86_amx_tileloadd64>(
op, resType, tsz.first, tsz.second, ptr, stride.value());
return success();
@@ -121,10 +122,10 @@ struct TileStoreConversion : public ConvertOpToLLVMPattern<TileStoreOp> {
matchAndRewrite(TileStoreOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
MemRefType mType = op.getMemRefType();
VectorType vType = op.getVectorType();
amx::TileType tType = op.getTileType();
// Determine m x n tile sizes.
std::pair<Value, Value> tsz =
getTileSizes(rewriter, *getTypeConverter(), vType, op.getLoc());
getTileSizes(rewriter, *getTypeConverter(), tType, op.getLoc());
// Determine stride.
auto stride = getStride(rewriter, *getTypeConverter(), mType,
adaptor.getBase(), op.getLoc());
@@ -144,9 +145,9 @@ struct TileMulFConversion : public ConvertOpToLLVMPattern<TileMulFOp> {
LogicalResult
matchAndRewrite(TileMulFOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
VectorType aType = op.getLhsVectorType();
VectorType bType = op.getRhsVectorType();
VectorType cType = op.getVectorType();
amx::TileType aType = op.getLhsTileType();
amx::TileType bType = op.getRhsTileType();
amx::TileType cType = op.getTileType();
// Determine m x n x k tile sizes.
std::pair<Value, Value> tsza =
getTileSizes(rewriter, *getTypeConverter(), aType, op.getLoc());
@@ -154,9 +155,16 @@ struct TileMulFConversion : public ConvertOpToLLVMPattern<TileMulFOp> {
getTileSizes(rewriter, *getTypeConverter(), bType, op.getLoc());
// Replace operation with intrinsic.
Type resType = typeConverter->convertType(cType);
rewriter.replaceOpWithNewOp<amx::x86_amx_tdpbf16ps>(
op, resType, tsza.first, tszb.second, tsza.second, adaptor.getAcc(),
adaptor.getLhs(), adaptor.getRhs());
if (aType.getElementType().isBF16())
rewriter.replaceOpWithNewOp<amx::x86_amx_tdpbf16ps>(
op, resType, tsza.first, tszb.second, tsza.second, adaptor.getAcc(),
adaptor.getLhs(), adaptor.getRhs());
else if (aType.getElementType().isF16())
rewriter.replaceOpWithNewOp<amx::x86_amx_tdpfp16ps>(
op, resType, tsza.first, tszb.second, tsza.second, adaptor.getAcc(),
adaptor.getLhs(), adaptor.getRhs());
else
llvm_unreachable("Unexpected element type for amx.mulf");
return success();
}
};
@@ -166,9 +174,9 @@ struct TileMulIConversion : public ConvertOpToLLVMPattern<TileMulIOp> {
LogicalResult
matchAndRewrite(TileMulIOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
VectorType aType = op.getLhsVectorType();
VectorType bType = op.getRhsVectorType();
VectorType cType = op.getVectorType();
amx::TileType aType = op.getLhsTileType();
amx::TileType bType = op.getRhsTileType();
amx::TileType cType = op.getTileType();
// Determine m x n x k tile sizes.
std::pair<Value, Value> tsza =
getTileSizes(rewriter, *getTypeConverter(), aType, op.getLoc());
@@ -201,15 +209,37 @@ struct TileMulIConversion : public ConvertOpToLLVMPattern<TileMulIOp> {
} // namespace
void mlir::populateAMXLegalizeForLLVMExportPatterns(
const LLVMTypeConverter &converter, RewritePatternSet &patterns) {
LLVMTypeConverter &converter, RewritePatternSet &patterns) {
patterns.add<TileZeroConversion, TileLoadConversion, TileStoreConversion,
TileMulFConversion, TileMulIConversion>(converter);
converter.addConversion([&](amx::TileType type) {
return LLVM::LLVMX86AMXType::get(&converter.getContext());
});
}
void mlir::configureAMXLegalizeForExportTarget(LLVMConversionTarget &target) {
target.addLegalOp<x86_amx_tilezero, x86_amx_tileloadd64, x86_amx_tilestored64,
x86_amx_tdpbf16ps, x86_amx_tdpbssd, x86_amx_tdpbsud,
x86_amx_tdpbusd, x86_amx_tdpbuud>();
x86_amx_tdpbf16ps, x86_amx_tdpfp16ps, x86_amx_tdpbssd,
x86_amx_tdpbsud, x86_amx_tdpbusd, x86_amx_tdpbuud>();
target.addIllegalOp<TileZeroOp, TileLoadOp, TileStoreOp, TileMulIOp,
TileMulFOp>();
}
namespace {
/// Implement the interface to convert AMX to LLVM.
struct AMXToLLVMDialectInterface : public ConvertToLLVMPatternInterface {
using ConvertToLLVMPatternInterface::ConvertToLLVMPatternInterface;
void populateConvertToLLVMConversionPatterns(
ConversionTarget &target, LLVMTypeConverter &typeConverter,
RewritePatternSet &patterns) const final {
populateAMXLegalizeForLLVMExportPatterns(typeConverter, patterns);
}
};
} // namespace
void mlir::registerConvertAMXToLLVMInterface(DialectRegistry &registry) {
registry.addExtension(+[](MLIRContext *ctx, amx::AMXDialect *dialect) {
dialect->addInterfaces<AMXToLLVMDialectInterface>();
});
}


@@ -45,6 +45,7 @@ static StringRef getTypeKeyword(Type type) {
.Case<LLVMArrayType>([&](Type) { return "array"; })
.Case<LLVMStructType>([&](Type) { return "struct"; })
.Case<LLVMTargetExtType>([&](Type) { return "target"; })
.Case<LLVMX86AMXType>([&](Type) { return "x86_amx"; })
.Default([](Type) -> StringRef {
llvm_unreachable("unexpected 'llvm' type kind");
});
@@ -317,6 +318,7 @@ static Type dispatchParse(AsmParser &parser, bool allowAny = true) {
.Case("array", [&] { return LLVMArrayType::parse(parser); })
.Case("struct", [&] { return parseStructType(parser); })
.Case("target", [&] { return LLVMTargetExtType::parse(parser); })
.Case("x86_amx", [&] { return LLVMX86AMXType::get(ctx); })
.Default([&] {
parser.emitError(keyLoc) << "unknown LLVM type: " << key;
return Type();

View File

@@ -780,7 +780,8 @@ bool mlir::LLVM::isCompatibleOuterType(Type type) {
LLVMFixedVectorType,
LLVMScalableVectorType,
LLVMTargetExtType,
LLVMVoidType
LLVMVoidType,
LLVMX86AMXType
>(type)) {
// clang-format on
return true;
@@ -842,7 +843,8 @@ static bool isCompatibleImpl(Type type, DenseSet<Type> &compatibleTypes) {
LLVMMetadataType,
LLVMPPCFP128Type,
LLVMTokenType,
LLVMVoidType
LLVMVoidType,
LLVMX86AMXType
>([](Type) { return true; })
// clang-format on
.Default([](Type) { return false; });


@@ -63,6 +63,8 @@ private:
return Float128Type::get(&context);
if (type->isX86_FP80Ty())
return Float80Type::get(&context);
if (type->isX86_AMXTy())
return LLVM::LLVMX86AMXType::get(&context);
if (type->isPPC_FP128Ty())
return LLVM::LLVMPPCFP128Type::get(&context);
if (type->isLabelTy())

View File

@@ -67,6 +67,9 @@ public:
.Case([this](LLVM::LLVMMetadataType) {
return llvm::Type::getMetadataTy(context);
})
.Case([this](LLVM::LLVMX86AMXType) {
return llvm::Type::getX86_AMXTy(context);
})
.Case<LLVM::LLVMArrayType, IntegerType, LLVM::LLVMFunctionType,
LLVM::LLVMPointerType, LLVM::LLVMStructType,
LLVM::LLVMFixedVectorType, LLVM::LLVMScalableVectorType,

View File

@@ -4,21 +4,21 @@
func.func @rowheight() {
// expected-error@+1 {{'amx.tile_zero' op bad row height: 17}}
%0 = amx.tile_zero : vector<17x16xbf16>
%0 = amx.tile_zero : !amx.tile<17x16xbf16>
}
// -----
func.func @colwidth() {
// expected-error@+1 {{'amx.tile_zero' op bad column width: 65}}
%0 = amx.tile_zero : vector<16x65xi8>
%0 = amx.tile_zero : !amx.tile<16x65xi8>
}
// -----
func.func @col4bytemultiple() {
// expected-error@+1 {{'amx.tile_zero' op bad column width: 5}}
%0 = amx.tile_zero : vector<16x5xi8>
%0 = amx.tile_zero : !amx.tile<16x5xi8>
}
// -----
@@ -26,7 +26,7 @@ func.func @col4bytemultiple() {
func.func @memtilesize(%arg0: memref<?x?xf32>) {
%0 = arith.constant 0 : index
// expected-error@+1 {{'amx.tile_load' op bad column width: 68}}
%1 = amx.tile_load %arg0[%0, %0] : memref<?x?xf32> into vector<16x17xf32>
%1 = amx.tile_load %arg0[%0, %0] : memref<?x?xf32> into !amx.tile<16x17xf32>
}
// -----
@@ -34,15 +34,15 @@ func.func @memtilesize(%arg0: memref<?x?xf32>) {
func.func @memindexsize(%arg0: memref<?x?xf32>) {
%0 = arith.constant 0 : index
// expected-error@+1 {{'amx.tile_load' op requires 2 indices}}
%1 = amx.tile_load %arg0[%0] : memref<?x?xf32> into vector<16x16xf32>
%1 = amx.tile_load %arg0[%0] : memref<?x?xf32> into !amx.tile<16x16xf32>
}
// -----
func.func @multsize() {
%0 = amx.tile_zero : vector<8x8xbf16>
%1 = amx.tile_zero : vector<8x8xbf16>
%2 = amx.tile_zero : vector<4x4xf32>
%0 = amx.tile_zero : !amx.tile<8x8xbf16>
%1 = amx.tile_zero : !amx.tile<8x8xbf16>
%2 = amx.tile_zero : !amx.tile<4x4xf32>
// expected-error@+1 {{'amx.tile_mulf' op bad mult shape: 4 x 4 x 4}}
%3 = amx.tile_mulf %0, %1, %2 : vector<8x8xbf16>, vector<8x8xbf16>, vector<4x4xf32>
%3 = amx.tile_mulf %0, %1, %2 : !amx.tile<8x8xbf16>, !amx.tile<8x8xbf16>, !amx.tile<4x4xf32>
}


@@ -14,33 +14,49 @@
// CHECK: amx.tilestored64
func.func @muli(%arg0: memref<?x?xi8>, %arg1: memref<?x?xi32>) {
%0 = arith.constant 0 : index
%1 = amx.tile_zero : vector<16x64xi8>
%2 = amx.tile_load %arg0[%0, %0] : memref<?x?xi8> into vector<16x64xi8>
%3 = amx.tile_load %arg1[%0, %0] : memref<?x?xi32> into vector<16x16xi32>
%4 = amx.tile_muli %1 zext, %2 zext, %3 : vector<16x64xi8>, vector<16x64xi8>, vector<16x16xi32>
amx.tile_store %arg1[%0, %0], %4 : memref<?x?xi32>, vector<16x16xi32>
%5 = amx.tile_muli %1, %2, %3 : vector<16x64xi8>, vector<16x64xi8>, vector<16x16xi32>
amx.tile_store %arg1[%0, %0], %5 : memref<?x?xi32>, vector<16x16xi32>
%6 = amx.tile_muli %1 zext, %2, %3 : vector<16x64xi8>, vector<16x64xi8>, vector<16x16xi32>
amx.tile_store %arg1[%0, %0], %6 : memref<?x?xi32>, vector<16x16xi32>
%7 = amx.tile_muli %1, %2 zext, %3 : vector<16x64xi8>, vector<16x64xi8>, vector<16x16xi32>
amx.tile_store %arg1[%0, %0], %7 : memref<?x?xi32>, vector<16x16xi32>
%1 = amx.tile_zero : !amx.tile<16x64xi8>
%2 = amx.tile_load %arg0[%0, %0] : memref<?x?xi8> into !amx.tile<16x64xi8>
%3 = amx.tile_load %arg1[%0, %0] : memref<?x?xi32> into !amx.tile<16x16xi32>
%4 = amx.tile_muli %1 zext, %2 zext, %3 : !amx.tile<16x64xi8>, !amx.tile<16x64xi8>, !amx.tile<16x16xi32>
amx.tile_store %arg1[%0, %0], %4 : memref<?x?xi32>, !amx.tile<16x16xi32>
%5 = amx.tile_muli %1, %2, %3 : !amx.tile<16x64xi8>, !amx.tile<16x64xi8>, !amx.tile<16x16xi32>
amx.tile_store %arg1[%0, %0], %5 : memref<?x?xi32>, !amx.tile<16x16xi32>
%6 = amx.tile_muli %1 zext, %2, %3 : !amx.tile<16x64xi8>, !amx.tile<16x64xi8>, !amx.tile<16x16xi32>
amx.tile_store %arg1[%0, %0], %6 : memref<?x?xi32>, !amx.tile<16x16xi32>
%7 = amx.tile_muli %1, %2 zext, %3 : !amx.tile<16x64xi8>, !amx.tile<16x64xi8>, !amx.tile<16x16xi32>
amx.tile_store %arg1[%0, %0], %7 : memref<?x?xi32>, !amx.tile<16x16xi32>
return
}
// CHECK-LABEL: mulf(
// CHECK-LABEL: mulbf16(
// CHECK: amx.tilezero
// CHECK: amx.tileloadd64
// CHECK: amx.tileloadd64
// CHECK: amx.tdpbf16ps
// CHECK: amx.tilestored64
func.func @mulf(%arg0: memref<?x?xbf16>, %arg1: memref<?x?xf32>) {
func.func @mulbf16(%arg0: memref<?x?xbf16>, %arg1: memref<?x?xf32>) {
%0 = arith.constant 0 : index
%1 = amx.tile_zero : vector<16x32xbf16>
%2 = amx.tile_load %arg0[%0, %0] : memref<?x?xbf16> into vector<16x32xbf16>
%3 = amx.tile_load %arg1[%0, %0] : memref<?x?xf32> into vector<16x16xf32>
%4 = amx.tile_mulf %1, %2, %3 : vector<16x32xbf16>, vector<16x32xbf16>, vector<16x16xf32>
amx.tile_store %arg1[%0, %0], %4 : memref<?x?xf32>, vector<16x16xf32>
%1 = amx.tile_zero : !amx.tile<16x32xbf16>
%2 = amx.tile_load %arg0[%0, %0] : memref<?x?xbf16> into !amx.tile<16x32xbf16>
%3 = amx.tile_load %arg1[%0, %0] : memref<?x?xf32> into !amx.tile<16x16xf32>
%4 = amx.tile_mulf %1, %2, %3 : !amx.tile<16x32xbf16>, !amx.tile<16x32xbf16>, !amx.tile<16x16xf32>
amx.tile_store %arg1[%0, %0], %4 : memref<?x?xf32>, !amx.tile<16x16xf32>
return
}
// CHECK-LABEL: mulfp16(
// CHECK: amx.tilezero
// CHECK: amx.tileloadd64
// CHECK: amx.tileloadd64
// CHECK: amx.tdpfp16ps
// CHECK: amx.tilestored64
func.func @mulfp16(%arg0: memref<?x?xf16>, %arg1: memref<?x?xf32>) {
%0 = arith.constant 0 : index
%1 = amx.tile_zero : !amx.tile<16x32xf16>
%2 = amx.tile_load %arg0[%0, %0] : memref<?x?xf16> into !amx.tile<16x32xf16>
%3 = amx.tile_load %arg1[%0, %0] : memref<?x?xf32> into !amx.tile<16x16xf32>
%4 = amx.tile_mulf %1, %2, %3 : !amx.tile<16x32xf16>, !amx.tile<16x32xf16>, !amx.tile<16x16xf32>
amx.tile_store %arg1[%0, %0], %4 : memref<?x?xf32>, !amx.tile<16x16xf32>
return
}
@@ -63,11 +79,11 @@ func.func @mulf(%arg0: memref<?x?xbf16>, %arg1: memref<?x?xf32>) {
// CHECK: "amx.tilestored64"(%{{.+}}, %{{.+}}, %{{.+}}, %[[STRIDE_2]]
func.func @strides(%arg0: memref<16x32xbf16>, %arg1: memref<16x32xbf16, strided<[64, 1]>>, %arg2: memref<16x32xbf16, strided<[?, 1]>>) {
%0 = arith.constant 0 : index
%1 = amx.tile_load %arg0[%0, %0] : memref<16x32xbf16> into vector<16x32xbf16>
%2 = amx.tile_load %arg1[%0, %0] : memref<16x32xbf16, strided<[64, 1]>> into vector<16x32xbf16>
%3 = amx.tile_load %arg2[%0, %0] : memref<16x32xbf16, strided<[?, 1]>> into vector<16x32xbf16>
amx.tile_store %arg0[%0, %0], %3 : memref<16x32xbf16>, vector<16x32xbf16>
amx.tile_store %arg1[%0, %0], %1 : memref<16x32xbf16, strided<[64, 1]>>, vector<16x32xbf16>
amx.tile_store %arg2[%0, %0], %2 : memref<16x32xbf16, strided<[?, 1]>>, vector<16x32xbf16>
%1 = amx.tile_load %arg0[%0, %0] : memref<16x32xbf16> into !amx.tile<16x32xbf16>
%2 = amx.tile_load %arg1[%0, %0] : memref<16x32xbf16, strided<[64, 1]>> into !amx.tile<16x32xbf16>
%3 = amx.tile_load %arg2[%0, %0] : memref<16x32xbf16, strided<[?, 1]>> into !amx.tile<16x32xbf16>
amx.tile_store %arg0[%0, %0], %3 : memref<16x32xbf16>, !amx.tile<16x32xbf16>
amx.tile_store %arg1[%0, %0], %1 : memref<16x32xbf16, strided<[64, 1]>>, !amx.tile<16x32xbf16>
amx.tile_store %arg2[%0, %0], %2 : memref<16x32xbf16, strided<[?, 1]>>, !amx.tile<16x32xbf16>
return
}


@@ -1,49 +1,49 @@
// RUN: mlir-opt -verify-diagnostics %s | mlir-opt | FileCheck %s
// CHECK-LABEL: tzero
// CHECK: amx.tile_zero : vector<16x16xbf16>
// CHECK: amx.tile_store %{{.*}}[%{{.*}}, %{{.*}}], %{{.*}} : memref<?x?xbf16>, vector<16x16xbf16>
// CHECK: amx.tile_zero : !amx.tile<16x16xbf16>
// CHECK: amx.tile_store %{{.*}}[%{{.*}}, %{{.*}}], %{{.*}} : memref<?x?xbf16>, !amx.tile<16x16xbf16>
func.func @tzero(%arg0: memref<?x?xbf16>) {
%0 = arith.constant 0 : index
%1 = amx.tile_zero : vector<16x16xbf16>
amx.tile_store %arg0[%0, %0], %1 : memref<?x?xbf16>, vector<16x16xbf16>
%1 = amx.tile_zero : !amx.tile<16x16xbf16>
amx.tile_store %arg0[%0, %0], %1 : memref<?x?xbf16>, !amx.tile<16x16xbf16>
return
}
// CHECK-LABEL: tmulf
// CHECK: %[[x:.*]] = amx.tile_load %{{.*}}[%{{.*}}, %{{.*}}] : memref<?x?xbf16> into vector<16x32xbf16>
// CHECK: %[[z:.*]] = amx.tile_load %{{.*}}[%{{.*}}, %{{.*}}] : memref<?x?xf32> into vector<16x16xf32>
// CHECK: %[[m:.*]] = amx.tile_mulf %[[x]], %[[x]], %[[z]] : vector<16x32xbf16>, vector<16x32xbf16>, vector<16x16xf32>
// CHECK: amx.tile_store %{{.*}}[%{{.*}}, %{{.*}}], %[[m]] : memref<?x?xf32>, vector<16x16xf32>
// CHECK: %[[x:.*]] = amx.tile_load %{{.*}}[%{{.*}}, %{{.*}}] : memref<?x?xbf16> into !amx.tile<16x32xbf16>
// CHECK: %[[z:.*]] = amx.tile_load %{{.*}}[%{{.*}}, %{{.*}}] : memref<?x?xf32> into !amx.tile<16x16xf32>
// CHECK: %[[m:.*]] = amx.tile_mulf %[[x]], %[[x]], %[[z]] : !amx.tile<16x32xbf16>, !amx.tile<16x32xbf16>, !amx.tile<16x16xf32>
// CHECK: amx.tile_store %{{.*}}[%{{.*}}, %{{.*}}], %[[m]] : memref<?x?xf32>, !amx.tile<16x16xf32>
func.func @tmulf(%arg0: memref<?x?xbf16>, %arg1: memref<?x?xf32>) {
%0 = arith.constant 0 : index
%1 = amx.tile_load %arg0[%0, %0] : memref<?x?xbf16> into vector<16x32xbf16>
%2 = amx.tile_load %arg1[%0, %0] : memref<?x?xf32> into vector<16x16xf32>
%3 = amx.tile_mulf %1, %1, %2 : vector<16x32xbf16>, vector<16x32xbf16>, vector<16x16xf32>
amx.tile_store %arg1[%0, %0], %3 : memref<?x?xf32>, vector<16x16xf32>
%1 = amx.tile_load %arg0[%0, %0] : memref<?x?xbf16> into !amx.tile<16x32xbf16>
%2 = amx.tile_load %arg1[%0, %0] : memref<?x?xf32> into !amx.tile<16x16xf32>
%3 = amx.tile_mulf %1, %1, %2 : !amx.tile<16x32xbf16>, !amx.tile<16x32xbf16>, !amx.tile<16x16xf32>
amx.tile_store %arg1[%0, %0], %3 : memref<?x?xf32>, !amx.tile<16x16xf32>
return
}
// CHECK-LABEL: tmuli
// CHECK: %[[x:.*]] = amx.tile_load %{{.*}}[%{{.*}}, %{{.*}}] : memref<?x?xi8> into vector<16x64xi8>
// CHECK: %[[y:.*]] = amx.tile_load %{{.*}}[%{{.*}}, %{{.*}}] : memref<?x?xi8> into vector<16x64xi8>
// CHECK: %[[z:.*]] = amx.tile_load %{{.*}}[%{{.*}}, %{{.*}}] : memref<?x?xi32> into vector<16x16xi32>
// CHECK: %[[m:.*]] = amx.tile_muli %[[x]] zext, %[[y]] zext, %[[z]] : vector<16x64xi8>, vector<16x64xi8>, vector<16x16xi32>
// CHECK: amx.tile_store %{{.*}}[%{{.*}}, %{{.*}}], %[[m]] : memref<?x?xi32>, vector<16x16xi32>
// CHECK: %[[x:.*]] = amx.tile_load %{{.*}}[%{{.*}}, %{{.*}}] : memref<?x?xi8> into !amx.tile<16x64xi8>
// CHECK: %[[y:.*]] = amx.tile_load %{{.*}}[%{{.*}}, %{{.*}}] : memref<?x?xi8> into !amx.tile<16x64xi8>
// CHECK: %[[z:.*]] = amx.tile_load %{{.*}}[%{{.*}}, %{{.*}}] : memref<?x?xi32> into !amx.tile<16x16xi32>
// CHECK: %[[m:.*]] = amx.tile_muli %[[x]] zext, %[[y]] zext, %[[z]] : !amx.tile<16x64xi8>, !amx.tile<16x64xi8>, !amx.tile<16x16xi32>
// CHECK: amx.tile_store %{{.*}}[%{{.*}}, %{{.*}}], %[[m]] : memref<?x?xi32>, !amx.tile<16x16xi32>
// Verify the parsing/printing of the sign-extension annotation.
// CHECK: amx.tile_muli %{{.*}}, %{{.*}} zext, %{{.*}}
// CHECK: amx.tile_muli %{{.*}} zext, %{{.*}}, %{{.*}}
// CHECK: amx.tile_muli %{{.*}}, %{{.*}}, %{{.*}}
func.func @tmuli(%arg0: memref<?x?xi8>, %arg1: memref<?x?xi8>, %arg2: memref<?x?xi32>) {
%0 = arith.constant 0 : index
%1 = amx.tile_load %arg0[%0, %0] : memref<?x?xi8> into vector<16x64xi8>
%2 = amx.tile_load %arg1[%0, %0] : memref<?x?xi8> into vector<16x64xi8>
%3 = amx.tile_load %arg2[%0, %0] : memref<?x?xi32> into vector<16x16xi32>
%4 = amx.tile_muli %1 zext, %2 zext, %3 : vector<16x64xi8>, vector<16x64xi8>, vector<16x16xi32>
amx.tile_store %arg2[%0, %0], %4 : memref<?x?xi32>, vector<16x16xi32>
%1 = amx.tile_load %arg0[%0, %0] : memref<?x?xi8> into !amx.tile<16x64xi8>
%2 = amx.tile_load %arg1[%0, %0] : memref<?x?xi8> into !amx.tile<16x64xi8>
%3 = amx.tile_load %arg2[%0, %0] : memref<?x?xi32> into !amx.tile<16x16xi32>
%4 = amx.tile_muli %1 zext, %2 zext, %3 : !amx.tile<16x64xi8>, !amx.tile<16x64xi8>, !amx.tile<16x16xi32>
amx.tile_store %arg2[%0, %0], %4 : memref<?x?xi32>, !amx.tile<16x16xi32>
// Verify the various `zext` combinations.
%5 = amx.tile_muli %1, %2 zext, %3 : vector<16x64xi8>, vector<16x64xi8>, vector<16x16xi32>
%6 = amx.tile_muli %1 zext, %2, %3 : vector<16x64xi8>, vector<16x64xi8>, vector<16x16xi32>
%7 = amx.tile_muli %1, %2, %3 : vector<16x64xi8>, vector<16x64xi8>, vector<16x16xi32>
%5 = amx.tile_muli %1, %2 zext, %3 : !amx.tile<16x64xi8>, !amx.tile<16x64xi8>, !amx.tile<16x16xi32>
%6 = amx.tile_muli %1 zext, %2, %3 : !amx.tile<16x64xi8>, !amx.tile<16x64xi8>, !amx.tile<16x16xi32>
%7 = amx.tile_muli %1, %2, %3 : !amx.tile<16x64xi8>, !amx.tile<16x64xi8>, !amx.tile<16x16xi32>
return
}