[mlir][amx] Simplify intrinsic generation (#140559)
Replaces the separate AMX named-intrinsic operations with direct calls to LLVM intrinsic functions. The existing AMX tests are updated and expanded. The separate conversion step that translated AMX intrinsic ops into LLVM IR is eliminated; that step is now performed by the existing LLVM dialect infrastructure. Related RFC: https://discourse.llvm.org/t/rfc-simplify-x86-intrinsic-generation/85581/7
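The change is easiest to see end to end. Below is a minimal illustrative sketch of the new lowering path (not taken verbatim from the patch; tile sizes and types follow the tests further down, and the IR is abbreviated): a user-facing op such as amx.tile_zero is converted to a generic llvm.call_intrinsic call in the LLVM dialect, and the existing LLVM dialect export then emits the actual intrinsic call in LLVM IR, with no AMX-specific translation step in between.

    // User-facing AMX dialect op:
    %t = amx.tile_zero : !amx.tile<16x16xf32>

    // After conversion to the LLVM dialect: a direct intrinsic call rather
    // than a dedicated amx.tilezero intrinsic op. The two i16 operands are
    // the tile sizes; the column count is scaled to bytes (16 x 4 = 64).
    %m = llvm.mlir.constant(16 : i16) : i16
    %n = llvm.mlir.constant(64 : i16) : i16
    %0 = llvm.call_intrinsic "llvm.x86.tilezero.internal"(%m, %n)
        : (i16, i16) -> !llvm.x86_amx

    // mlir-translate --mlir-to-llvmir then produces plain LLVM IR:
    //   %1 = call x86_amx @llvm.x86.tilezero.internal(i16 16, i16 64)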
@@ -69,6 +69,15 @@ SmallVector<Value> decomposeValue(OpBuilder &builder, Location loc, Value src,
/// function is used to combine multiple values into a single value.
Value composeValue(OpBuilder &builder, Location loc, ValueRange src,
Type dstType);

/// Performs the index computation to get to the element at `indices` of the
/// memory pointed to by `memRefDesc`, using the layout map of `type`.
/// The indices are linearized as:
/// `base_offset + index_0 * stride_0 + ... + index_n * stride_n`.
Value getStridedElementPtr(
OpBuilder &builder, Location loc, const LLVMTypeConverter &converter,
MemRefType type, Value memRefDesc, ValueRange indices,
LLVM::GEPNoWrapFlags noWrapFlags = LLVM::GEPNoWrapFlags::none);
} // namespace LLVM

/// Base class for operation conversions targeting the LLVM IR dialect. It
@@ -107,8 +116,8 @@ protected:
static Value createIndexAttrConstant(OpBuilder &builder, Location loc,
Type resultType, int64_t value);

// This is a strided getElementPtr variant that linearizes subscripts as:
// `base_offset + index_0 * stride_0 + ... + index_n * stride_n`.
/// Convenience wrapper for the corresponding helper utility.
/// This is a strided getElementPtr variant with linearized subscripts.
Value getStridedElementPtr(
ConversionPatternRewriter &rewriter, Location loc, MemRefType type,
Value memRefDesc, ValueRange indices,

@@ -29,6 +29,7 @@
#define AMX

include "mlir/Dialect/LLVMIR/LLVMOpBase.td"
include "mlir/Dialect/AMX/AMXInterfaces.td"
include "mlir/Interfaces/SideEffectInterfaces.td"
include "mlir/IR/AttrTypeBase.td"
include "mlir/IR/BuiltinTypes.td"
@@ -47,8 +48,6 @@ def AMX_Dialect : Dialect {

This `AMX` dialect provides a bridge between MLIR concepts such as
vectors and memrefs and the lower level LLVM IR support of AMX.
The dialect is split into user-facing AMX ops (AMX_Op) and
backend-facing intrinsic ops (AMX_IntrOp).

Note that since configuration changes (implicit at dialect level) are
costly, it is highly recommended to use the AMX dialect on same-shaped
@@ -135,21 +134,17 @@ def AMXTileI8 : AMXTileOf<[I8]>;
class AMX_Op<string mnemonic, list<Trait> traits = []> :
Op<AMX_Dialect, mnemonic, traits> {}

// The "internal" intrinsics are meant for compiler usage.
class AMX_IntrOp<string mnemonic, int numResults, list<Trait> traits = []> :
LLVM_IntrOpBase<AMX_Dialect, mnemonic,
"x86_" # !subst(".", "_", mnemonic) # "_internal",
[], [], traits, numResults>;

//===----------------------------------------------------------------------===//
// AMX Op definitions (user facing).
// AMX Op definitions
//===----------------------------------------------------------------------===//

//
// Tile reset.
//

def TileZeroOp : AMX_Op<"tile_zero", [Pure]> {
def TileZeroOp : AMX_Op<"tile_zero", [Pure,
AMXIntrinsicOpInterface
]> {
let summary = "tile zero operation";
let description = [{
Zeroes the destination tile, with the shape defined by the 2-dim
@@ -167,6 +162,14 @@ def TileZeroOp : AMX_Op<"tile_zero", [Pure]> {
TileType getTileType() {
return ::llvm::cast<TileType>(getRes().getType());
}

std::string getIntrinsicName() {
return "llvm.x86.tilezero.internal";
}
SmallVector<Value> getIntrinsicOperands(
::mlir::ArrayRef<Value> operands,
const ::mlir::LLVMTypeConverter &typeConverter,
::mlir::RewriterBase &rewriter);
}];
let assemblyFormat = "attr-dict `:` qualified(type($res))";
let hasVerifier = 1;
@@ -176,7 +179,9 @@ def TileZeroOp : AMX_Op<"tile_zero", [Pure]> {
// Tile memory operations.
//

def TileLoadOp : AMX_Op<"tile_load", [Pure]> {
def TileLoadOp : AMX_Op<"tile_load", [Pure,
AMXIntrinsicOpInterface
]> {
let summary = "tile load operation";
let description = [{
Loads a tile from memory defined by a base and indices, with the
@@ -200,13 +205,23 @@ def TileLoadOp : AMX_Op<"tile_load", [Pure]> {
TileType getTileType() {
return ::llvm::cast<TileType>(getRes().getType());
}

std::string getIntrinsicName() {
return "llvm.x86.tileloadd64.internal";
}
SmallVector<Value> getIntrinsicOperands(
::mlir::ArrayRef<Value> operands,
const ::mlir::LLVMTypeConverter &typeConverter,
::mlir::RewriterBase &rewriter);
}];
let assemblyFormat = "$base `[` $indices `]` attr-dict `:` "
"type($base) `into` qualified(type($res))";
let hasVerifier = 1;
}

def TileStoreOp : AMX_Op<"tile_store"> {
def TileStoreOp : AMX_Op<"tile_store", [
AMXIntrinsicOpInterface
]> {
let summary = "tile store operation";
let description = [{
Stores a tile to memory defined by a base and indices, with the
@@ -230,6 +245,14 @@ def TileStoreOp : AMX_Op<"tile_store"> {
TileType getTileType() {
return ::llvm::cast<TileType>(getVal().getType());
}

std::string getIntrinsicName() {
return "llvm.x86.tilestored64.internal";
}
SmallVector<Value> getIntrinsicOperands(
::mlir::ArrayRef<Value> operands,
const ::mlir::LLVMTypeConverter &typeConverter,
::mlir::RewriterBase &rewriter);
}];
let assemblyFormat = "$base `[` $indices `]` `,` $val attr-dict `:` "
"type($base) `,` qualified(type($val))";
@@ -240,8 +263,10 @@ def TileStoreOp : AMX_Op<"tile_store"> {
// Tile arithmetic operations.
//

def TileMulFOp : AMX_Op<"tile_mulf", [
Pure, AllTypesMatch<["acc", "res"]>]> {
def TileMulFOp : AMX_Op<"tile_mulf", [Pure,
AMXIntrinsicOpInterface,
AllTypesMatch<["acc", "res"]>
]> {
let summary = "tile multiplication operation (floating-point)";
let description = [{
Multiplies a "m x k" tile with a "k x n" tile and accumulates the results
@@ -270,6 +295,19 @@ def TileMulFOp : AMX_Op<"tile_mulf", [
TileType getTileType() {
return ::llvm::cast<TileType>(getRes().getType());
}

std::string getIntrinsicName() {
std::string intr = "llvm.x86.tdp";
auto elementType =
getLhsTileType().getElementType();
intr += elementType.isF16() ? "fp16" : "bf16";
intr += "ps.internal";
return intr;
}
SmallVector<Value> getIntrinsicOperands(
::mlir::ArrayRef<Value> operands,
const ::mlir::LLVMTypeConverter &typeConverter,
::mlir::RewriterBase &rewriter);
}];
let assemblyFormat = "$lhs `,` $rhs `,` $acc attr-dict `:` "
"qualified(type($lhs)) `,` qualified(type($rhs))"
@@ -277,8 +315,10 @@ def TileMulFOp : AMX_Op<"tile_mulf", [
let hasVerifier = 1;
}

def TileMulIOp : AMX_Op<"tile_muli", [
Pure, AllTypesMatch<["acc", "res"]>]> {
def TileMulIOp : AMX_Op<"tile_muli", [Pure,
AMXIntrinsicOpInterface,
AllTypesMatch<["acc", "res"]>
]> {
let summary = "tile multiplication operation (integer)";
let description = [{
Multiplies a "m x k" tile with a "k x n" tile and accumulates the results
@@ -313,77 +353,22 @@ def TileMulIOp : AMX_Op<"tile_muli", [
TileType getTileType() {
return ::llvm::cast<TileType>(getRes().getType());
}

std::string getIntrinsicName() {
std::string intr = "llvm.x86.tdpb";
intr += getIsZextLhs() ? "u" : "s";
intr += getIsZextRhs() ? "u" : "s";
intr += "d.internal";
return intr;
}
SmallVector<Value> getIntrinsicOperands(
::mlir::ArrayRef<Value> operands,
const ::mlir::LLVMTypeConverter &typeConverter,
::mlir::RewriterBase &rewriter);
}];
let assemblyFormat = "$lhs (`zext` $isZextLhs^)? `,` $rhs (`zext` $isZextRhs^)? `,` $acc attr-dict `:` "
"qualified(type($lhs)) `,` qualified(type($rhs)) `,` qualified(type($acc)) ";
let hasVerifier = 1;
}

//===----------------------------------------------------------------------===//
// AMX IntrOp definitions (LLVM compiler facing).
//===----------------------------------------------------------------------===//

//
// Tile reset. Parameters define the tile size.
//

def LLVM_x86_amx_tilezero : AMX_IntrOp<"tilezero", 1>,
Arguments<(ins AnyInteger, AnyInteger)>;

//
// Tile memory operations. Parameters define the tile size,
// base address, and stride between consecutive rows for the
// memory operation.
//

def LLVM_x86_amx_tileloadd64 : AMX_IntrOp<"tileloadd64", 1>,
Arguments<(ins AnyInteger,
AnyInteger, LLVM_AnyPointer, AnyInteger)>;

def LLVM_x86_amx_tilestored64 : AMX_IntrOp<"tilestored64", 0>,
Arguments<(ins AnyInteger,
AnyInteger, LLVM_AnyPointer, AnyInteger, LLVM_Type)>;

//
// Tile multiplication operations (series of dot products). Parameters
// define the tile sizes and source and destination tiles for the
// operation. Note that the prefix "tdp" stands for tile dot product.
//

// Dot product of bf16 tiles into f32 tile.
def LLVM_x86_amx_tdpbf16ps : AMX_IntrOp<"tdpbf16ps", 1>,
Arguments<(ins AnyInteger,
AnyInteger,
AnyInteger, LLVM_Type, LLVM_Type, LLVM_Type)>;

// Dot product of f16 tiles into f32 tile.
def LLVM_x86_amx_tdpfp16ps : AMX_IntrOp<"tdpfp16ps", 1>,
Arguments<(ins AnyInteger,
AnyInteger,
AnyInteger, LLVM_Type, LLVM_Type, LLVM_Type)>;

// Dot product of i8 tiles into i32 tile (with sign/sign extension).
def LLVM_x86_amx_tdpbssd : AMX_IntrOp<"tdpbssd", 1>,
Arguments<(ins AnyInteger,
AnyInteger,
AnyInteger, LLVM_Type, LLVM_Type, LLVM_Type)>;

// Dot product of i8 tiles into i32 tile (with sign/zero extension).
def LLVM_x86_amx_tdpbsud : AMX_IntrOp<"tdpbsud", 1>,
Arguments<(ins AnyInteger,
AnyInteger,
AnyInteger, LLVM_Type, LLVM_Type, LLVM_Type)>;

// Dot product of i8 tiles into i32 tile (with zero/sign extension).
def LLVM_x86_amx_tdpbusd : AMX_IntrOp<"tdpbusd", 1>,
Arguments<(ins AnyInteger,
AnyInteger,
AnyInteger, LLVM_Type, LLVM_Type, LLVM_Type)>;

// Dot product of i8 tiles into i32 tile (with zero/zero extension).
def LLVM_x86_amx_tdpbuud : AMX_IntrOp<"tdpbuud", 1>,
Arguments<(ins AnyInteger,
AnyInteger,
AnyInteger, LLVM_Type, LLVM_Type, LLVM_Type)>;

#endif // AMX

@@ -14,11 +14,15 @@
#define MLIR_DIALECT_AMX_AMXDIALECT_H_

#include "mlir/Bytecode/BytecodeOpInterface.h"
#include "mlir/Dialect/LLVMIR/LLVMInterfaces.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/Dialect.h"
#include "mlir/IR/OpDefinition.h"
#include "mlir/Interfaces/SideEffectInterfaces.h"

/// Include the generated interface declarations.
#include "mlir/Dialect/AMX/AMXInterfaces.h.inc"

#include "mlir/Dialect/AMX/AMXDialect.h.inc"

#define GET_TYPEDEF_CLASSES

mlir/include/mlir/Dialect/AMX/AMXInterfaces.td (new file, 31 lines)
@@ -0,0 +1,31 @@
//===- AMXInterfaces.td - AMX interfaces -------------------*- tablegen -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file defines interfaces for the AMX dialect.
//
//===----------------------------------------------------------------------===//

#ifndef AMX_INTERFACES
#define AMX_INTERFACES

include "mlir/IR/Interfaces.td"
include "mlir/Dialect/LLVMIR/LLVMInterfaces.td"

//===----------------------------------------------------------------------===//
// AMX Intrinsic Interface
//===----------------------------------------------------------------------===//

def AMXIntrinsicOpInterface
: OpInterface<"AMXIntrinsicOp", [OneToOneIntrinsicOpInterface]> {
let description = [{
A wrapper interface for operations representing AMX LLVM intrinsics.
}];
let cppNamespace = "::mlir::amx";
}

#endif // AMX_INTERFACES
@@ -1,6 +1,5 @@
add_mlir_dialect(AMX amx)
add_mlir_doc(AMX AMX Dialects/ -gen-dialect-doc -dialect=amx)

set(LLVM_TARGET_DEFINITIONS AMX.td)
mlir_tablegen(AMXConversions.inc -gen-llvmir-conversions)
add_public_tablegen_target(MLIRAMXConversionsIncGen)
add_mlir_interface(AMXInterfaces)
add_dependencies(MLIRAMXIncGen MLIRAMXInterfacesIncGen)
@@ -25,9 +25,6 @@ void populateAMXLegalizeForLLVMExportPatterns(LLVMTypeConverter &converter,
/// intrinsics.
void configureAMXLegalizeForExportTarget(LLVMConversionTarget &target);

/// Register LLVM conversion interface for AMX dialect.
void registerConvertAMXToLLVMInterface(DialectRegistry &registry);

} // namespace mlir

#endif // MLIR_DIALECT_AMX_TRANSFORMS_H

@@ -32,7 +32,6 @@
#include "mlir/Conversion/SCFToEmitC/SCFToEmitC.h"
#include "mlir/Conversion/UBToLLVM/UBToLLVM.h"
#include "mlir/Conversion/VectorToLLVM/ConvertVectorToLLVM.h"
#include "mlir/Dialect/AMX/Transforms.h"
#include "mlir/Dialect/Affine/TransformOps/AffineTransformOps.h"
#include "mlir/Dialect/Bufferization/TransformOps/BufferizationTransformOps.h"
#include "mlir/Dialect/DLTI/TransformOps/DLTITransformOps.h"
@@ -84,7 +83,6 @@ inline void registerAllExtensions(DialectRegistry &registry) {
registerConvertOpenMPToLLVMInterface(registry);
registerConvertSCFToEmitCInterface(registry);
ub::registerConvertUBToLLVMInterface(registry);
registerConvertAMXToLLVMInterface(registry);
gpu::registerConvertGpuToLLVMInterface(registry);
NVVM::registerConvertGpuToNVVMInterface(registry);
vector::registerConvertVectorToLLVMInterface(registry);
@@ -1,31 +0,0 @@
//===- AMXToLLVMIRTranslation.h - AMX to LLVM IR ----------------*- C++ -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This provides registration calls for AMX dialect to LLVM IR translation.
//
//===----------------------------------------------------------------------===//

#ifndef MLIR_TARGET_LLVMIR_DIALECT_AMX_AMXTOLLVMIRTRANSLATION_H
#define MLIR_TARGET_LLVMIR_DIALECT_AMX_AMXTOLLVMIRTRANSLATION_H

namespace mlir {

class DialectRegistry;
class MLIRContext;

/// Register the AMX dialect and the translation from it to the LLVM IR
/// in the given registry;
void registerAMXDialectTranslation(DialectRegistry &registry);

/// Register the AMX dialect and the translation from it in the registry
/// associated with the given context.
void registerAMXDialectTranslation(MLIRContext &context);

} // namespace mlir

#endif // MLIR_TARGET_LLVMIR_DIALECT_AMX_AMXTOLLVMIRTRANSLATION_H
@@ -14,7 +14,6 @@
#ifndef MLIR_TARGET_LLVMIR_DIALECT_ALL_H
#define MLIR_TARGET_LLVMIR_DIALECT_ALL_H

#include "mlir/Target/LLVMIR/Dialect/AMX/AMXToLLVMIRTranslation.h"
#include "mlir/Target/LLVMIR/Dialect/ArmNeon/ArmNeonToLLVMIRTranslation.h"
#include "mlir/Target/LLVMIR/Dialect/ArmSME/ArmSMEToLLVMIRTranslation.h"
#include "mlir/Target/LLVMIR/Dialect/ArmSVE/ArmSVEToLLVMIRTranslation.h"
@@ -37,7 +36,6 @@ class DialectRegistry;
/// corresponding translation interfaces.
static inline void registerAllToLLVMIRTranslations(DialectRegistry &registry) {
registerArmNeonDialectTranslation(registry);
registerAMXDialectTranslation(registry);
registerArmSMEDialectTranslation(registry);
registerArmSVEDialectTranslation(registry);
registerBuiltinDialectTranslation(registry);
@@ -62,49 +62,8 @@ Value ConvertToLLVMPattern::getStridedElementPtr(
ConversionPatternRewriter &rewriter, Location loc, MemRefType type,
Value memRefDesc, ValueRange indices,
LLVM::GEPNoWrapFlags noWrapFlags) const {

auto [strides, offset] = type.getStridesAndOffset();

MemRefDescriptor memRefDescriptor(memRefDesc);
// Use a canonical representation of the start address so that later
// optimizations have a longer sequence of instructions to CSE.
// If we don't do that we would sprinkle the memref.offset in various
// position of the different address computations.
Value base =
memRefDescriptor.bufferPtr(rewriter, loc, *getTypeConverter(), type);

LLVM::IntegerOverflowFlags intOverflowFlags =
LLVM::IntegerOverflowFlags::none;
if (LLVM::bitEnumContainsAny(noWrapFlags, LLVM::GEPNoWrapFlags::nusw)) {
intOverflowFlags = intOverflowFlags | LLVM::IntegerOverflowFlags::nsw;
}
if (LLVM::bitEnumContainsAny(noWrapFlags, LLVM::GEPNoWrapFlags::nuw)) {
intOverflowFlags = intOverflowFlags | LLVM::IntegerOverflowFlags::nuw;
}

Type indexType = getIndexType();
Value index;
for (int i = 0, e = indices.size(); i < e; ++i) {
Value increment = indices[i];
if (strides[i] != 1) { // Skip if stride is 1.
Value stride =
ShapedType::isDynamic(strides[i])
? memRefDescriptor.stride(rewriter, loc, i)
: createIndexAttrConstant(rewriter, loc, indexType, strides[i]);
increment = rewriter.create<LLVM::MulOp>(loc, increment, stride,
intOverflowFlags);
}
index = index ? rewriter.create<LLVM::AddOp>(loc, index, increment,
intOverflowFlags)
: increment;
}

Type elementPtrType = memRefDescriptor.getElementPtrType();
return index ? rewriter.create<LLVM::GEPOp>(
loc, elementPtrType,
getTypeConverter()->convertType(type.getElementType()),
base, index, noWrapFlags)
: base;
return LLVM::getStridedElementPtr(rewriter, loc, *getTypeConverter(), type,
memRefDesc, indices, noWrapFlags);
}

// Check if the MemRefType `type` is supported by the lowering. We currently
@@ -524,3 +483,52 @@ Value mlir::LLVM::composeValue(OpBuilder &builder, Location loc, ValueRange src,

return res;
}

Value mlir::LLVM::getStridedElementPtr(OpBuilder &builder, Location loc,
const LLVMTypeConverter &converter,
MemRefType type, Value memRefDesc,
ValueRange indices,
LLVM::GEPNoWrapFlags noWrapFlags) {
auto [strides, offset] = type.getStridesAndOffset();

MemRefDescriptor memRefDescriptor(memRefDesc);
// Use a canonical representation of the start address so that later
// optimizations have a longer sequence of instructions to CSE.
// If we don't do that we would sprinkle the memref.offset in various
// position of the different address computations.
Value base = memRefDescriptor.bufferPtr(builder, loc, converter, type);

LLVM::IntegerOverflowFlags intOverflowFlags =
LLVM::IntegerOverflowFlags::none;
if (LLVM::bitEnumContainsAny(noWrapFlags, LLVM::GEPNoWrapFlags::nusw)) {
intOverflowFlags = intOverflowFlags | LLVM::IntegerOverflowFlags::nsw;
}
if (LLVM::bitEnumContainsAny(noWrapFlags, LLVM::GEPNoWrapFlags::nuw)) {
intOverflowFlags = intOverflowFlags | LLVM::IntegerOverflowFlags::nuw;
}

Type indexType = converter.getIndexType();
Value index;
for (int i = 0, e = indices.size(); i < e; ++i) {
Value increment = indices[i];
if (strides[i] != 1) { // Skip if stride is 1.
Value stride =
ShapedType::isDynamic(strides[i])
? memRefDescriptor.stride(builder, loc, i)
: builder.create<LLVM::ConstantOp>(
loc, indexType, builder.getIndexAttr(strides[i]));
increment =
builder.create<LLVM::MulOp>(loc, increment, stride, intOverflowFlags);
}
index = index ? builder.create<LLVM::AddOp>(loc, index, increment,
intOverflowFlags)
: increment;
}

Type elementPtrType = memRefDescriptor.getElementPtrType();
return index ? builder.create<LLVM::GEPOp>(
loc, elementPtrType,
converter.convertType(type.getElementType()), base, index,
noWrapFlags)
: base;
}
@@ -11,6 +11,8 @@
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/AMX/AMXDialect.h"
#include "mlir/Conversion/LLVMCommon/Pattern.h"
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
#include "mlir/Dialect/LLVMIR/LLVMTypes.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/DialectImplementation.h"
@@ -21,6 +23,8 @@

using namespace mlir;

#include "mlir/Dialect/AMX/AMXInterfaces.cpp.inc"

#include "mlir/Dialect/AMX/AMXDialect.cpp.inc"

void amx::AMXDialect::initialize() {
@@ -60,24 +64,127 @@ static LogicalResult verifyMultShape(Operation *op, amx::TileType atp,
return success();
}

/// Maps the 2-dim vector shape to the two 16-bit tile sizes. The first
/// dimension directly translates into the number of rows of the tiles.
/// The second dimensions needs to be scaled by the number of bytes.
static SmallVector<Value> getTileSizes(Location loc, amx::TileType tType,
RewriterBase &rewriter) {
Type llvmInt16Type = rewriter.getIntegerType(16);
unsigned width = tType.getElementType().getIntOrFloatBitWidth();
assert(llvm::isPowerOf2_64(width) && width >= 8);
unsigned bytes = width >> 3;
auto mattr = rewriter.getI16IntegerAttr(tType.getDimSize(0));
auto nattr = rewriter.getI16IntegerAttr(tType.getDimSize(1) * bytes);
return SmallVector<Value>{
rewriter.create<LLVM::ConstantOp>(loc, llvmInt16Type, mattr),
rewriter.create<LLVM::ConstantOp>(loc, llvmInt16Type, nattr)};
}

/// Maps the 2-dim memref shape to the 64-bit stride. Note that the buffer
/// shape may "envelop" the actual tile shape, and may be dynamically sized.
static Value getStride(Location loc, MemRefType mType, Value base,
RewriterBase &rewriter) {
assert(mType.getRank() >= 2 && "Invalid shape for AMX strides");
int64_t preLast = mType.getRank() - 2;
Type llvmInt64Type = rewriter.getIntegerType(64);
unsigned width = mType.getElementType().getIntOrFloatBitWidth();
assert(llvm::isPowerOf2_64(width) && width >= 8);
unsigned bytes = width >> 3;
auto [strides, offset] = mType.getStridesAndOffset();
if (strides[preLast] == ShapedType::kDynamic) {
// Dynamic stride needs code to compute the stride at runtime.
MemRefDescriptor memrefDescriptor(base);
auto attr = rewriter.getI64IntegerAttr(bytes);
Value scale = rewriter.create<LLVM::ConstantOp>(loc, llvmInt64Type, attr);
return rewriter
.create<LLVM::MulOp>(loc, llvmInt64Type, scale,
memrefDescriptor.stride(rewriter, loc, preLast))
.getResult();
}
// Use direct constant for static stride.
auto attr = rewriter.getI64IntegerAttr(strides[preLast] * bytes);
return rewriter.create<LLVM::ConstantOp>(loc, llvmInt64Type, attr)
.getResult();
}

LogicalResult amx::TileZeroOp::verify() {
return verifyTileSize(*this, getTileType());
}

SmallVector<Value>
amx::TileZeroOp::getIntrinsicOperands(ArrayRef<Value> operands,
const LLVMTypeConverter &typeConverter,
RewriterBase &rewriter) {
return getTileSizes(getLoc(), getTileType(), rewriter);
}

LogicalResult amx::TileLoadOp::verify() {
unsigned rank = getMemRefType().getRank();
MemRefType memrefTy = getMemRefType();
unsigned rank = memrefTy.getRank();
if (rank < 2)
return emitOpError("requires at least 2D memref");
if (getIndices().size() != rank)
return emitOpError("requires ") << rank << " indices";
SmallVector<int64_t> strides;
int64_t offset;
if (failed(memrefTy.getStridesAndOffset(strides, offset)) ||
strides.back() != 1)
return emitOpError("requires memref with unit innermost stride");
return verifyTileSize(*this, getTileType());
}

SmallVector<Value>
amx::TileLoadOp::getIntrinsicOperands(ArrayRef<Value> operands,
const LLVMTypeConverter &typeConverter,
RewriterBase &rewriter) {
auto loc = getLoc();
Adaptor adaptor(operands, *this);

SmallVector<Value> intrinsicOperands;
intrinsicOperands.append(getTileSizes(loc, getTileType(), rewriter));
intrinsicOperands.push_back(
LLVM::getStridedElementPtr(rewriter, loc, typeConverter, getMemRefType(),
adaptor.getBase(), adaptor.getIndices()));
intrinsicOperands.push_back(
getStride(loc, getMemRefType(), adaptor.getBase(), rewriter));

return intrinsicOperands;
}

LogicalResult amx::TileStoreOp::verify() {
unsigned rank = getMemRefType().getRank();
MemRefType memrefTy = getMemRefType();
unsigned rank = memrefTy.getRank();
if (rank < 2)
return emitOpError("requires at least 2D memref");
if (getIndices().size() != rank)
return emitOpError("requires ") << rank << " indices";
SmallVector<int64_t> strides;
int64_t offset;
if (failed(memrefTy.getStridesAndOffset(strides, offset)) ||
strides.back() != 1)
return emitOpError("requires memref with unit innermost stride");
return verifyTileSize(*this, getTileType());
}

SmallVector<Value>
amx::TileStoreOp::getIntrinsicOperands(ArrayRef<Value> operands,
const LLVMTypeConverter &typeConverter,
RewriterBase &rewriter) {
auto loc = getLoc();
Adaptor adaptor(operands, *this);

SmallVector<Value> intrinsicOperands;
intrinsicOperands.append(getTileSizes(loc, getTileType(), rewriter));
intrinsicOperands.push_back(
LLVM::getStridedElementPtr(rewriter, loc, typeConverter, getMemRefType(),
adaptor.getBase(), adaptor.getIndices()));
intrinsicOperands.push_back(
getStride(loc, getMemRefType(), adaptor.getBase(), rewriter));
intrinsicOperands.push_back(adaptor.getVal());

return intrinsicOperands;
}

LogicalResult amx::TileMulFOp::verify() {
amx::TileType aType = getLhsTileType();
amx::TileType bType = getRhsTileType();
@@ -95,6 +202,25 @@ LogicalResult amx::TileMulFOp::verify() {
return success();
}

SmallVector<Value>
amx::TileMulFOp::getIntrinsicOperands(ArrayRef<Value> operands,
const LLVMTypeConverter &typeConverter,
RewriterBase &rewriter) {
auto loc = getLoc();
Adaptor adaptor(operands, *this);

amx::TileType aType = getLhsTileType();
amx::TileType bType = getRhsTileType();
SmallVector<Value> tsza = getTileSizes(loc, aType, rewriter);
SmallVector<Value> tszb = getTileSizes(loc, bType, rewriter);

SmallVector<Value> intrinsicOperands = {tsza[0], tszb[1],
tsza[1], adaptor.getAcc(),
adaptor.getLhs(), adaptor.getRhs()};

return intrinsicOperands;
}

LogicalResult amx::TileMulIOp::verify() {
amx::TileType aType = getLhsTileType();
amx::TileType bType = getRhsTileType();
@@ -112,6 +238,25 @@ LogicalResult amx::TileMulIOp::verify() {
return success();
}

SmallVector<Value>
amx::TileMulIOp::getIntrinsicOperands(ArrayRef<Value> operands,
const LLVMTypeConverter &typeConverter,
RewriterBase &rewriter) {
auto loc = getLoc();
Adaptor adaptor(operands, *this);

amx::TileType aType = getLhsTileType();
amx::TileType bType = getRhsTileType();
SmallVector<Value> tsza = getTileSizes(loc, aType, rewriter);
SmallVector<Value> tszb = getTileSizes(loc, bType, rewriter);

SmallVector<Value> intrinsicOperands = {tsza[0], tszb[1],
tsza[1], adaptor.getAcc(),
adaptor.getLhs(), adaptor.getRhs()};

return intrinsicOperands;
}

Type amx::TileType::parse(AsmParser &parser) {
if (parser.parseLess())
return nullptr;
@@ -9,6 +9,7 @@ add_mlir_dialect_library(MLIRAMXDialect

LINK_LIBS PUBLIC
MLIRIR
MLIRLLVMCommonConversion
MLIRLLVMDialect
MLIRSideEffectInterfaces
)

@@ -1,9 +1,6 @@
add_mlir_dialect_library(MLIRAMXTransforms
LegalizeForLLVMExport.cpp

DEPENDS
MLIRAMXConversionsIncGen

LINK_LIBS PUBLIC
MLIRAMXDialect
MLIRIR
@@ -21,224 +21,42 @@ using namespace mlir::amx;

namespace {

/// Maps the 2-dim vector shape to the two 16-bit tile sizes. The first
/// dimension directly translates into the number of rows of the tiles.
/// The second dimensions needs to be scaled by the number of bytes.
std::pair<Value, Value> getTileSizes(ConversionPatternRewriter &rewriter,
const LLVMTypeConverter &typeConverter,
amx::TileType tType, Location loc) {
Type llvmInt16Type = IntegerType::get(&typeConverter.getContext(), 16);
unsigned width = tType.getElementType().getIntOrFloatBitWidth();
assert(llvm::isPowerOf2_64(width) && width >= 8);
unsigned bytes = width >> 3;
auto mattr = rewriter.getI16IntegerAttr(tType.getDimSize(0));
auto nattr = rewriter.getI16IntegerAttr(tType.getDimSize(1) * bytes);
return std::make_pair(
rewriter.create<LLVM::ConstantOp>(loc, llvmInt16Type, mattr),
rewriter.create<LLVM::ConstantOp>(loc, llvmInt16Type, nattr));
}
/// Generic one-to-one conversion of simply mappable operations into calls
/// to their respective LLVM intrinsics.
struct AMXIntrinsicOpConversion
: public OpInterfaceConversionPattern<amx::AMXIntrinsicOp> {
using OpInterfaceConversionPattern<
amx::AMXIntrinsicOp>::OpInterfaceConversionPattern;

/// Maps the 2-dim memref shape to the 64-bit stride. Note that the buffer
/// shape may "envelop" the actual tile shape, and may be dynamically sized.
/// Returns failure if proper stride couldn't be found.
FailureOr<Value> getStride(ConversionPatternRewriter &rewriter,
const LLVMTypeConverter &typeConverter,
MemRefType mType, Value base, Location loc) {
if (mType.getRank() < 2)
return failure();
int64_t preLast = mType.getRank() - 2;
Type llvmInt64Type = IntegerType::get(&typeConverter.getContext(), 64);
unsigned width = mType.getElementType().getIntOrFloatBitWidth();
assert(llvm::isPowerOf2_64(width) && width >= 8);
unsigned bytes = width >> 3;
int64_t offset;
SmallVector<int64_t, 4> strides;
if (failed(mType.getStridesAndOffset(strides, offset)) || strides.back() != 1)
return failure();
if (strides[preLast] == ShapedType::kDynamic) {
// Dynamic stride needs code to compute the stride at runtime.
MemRefDescriptor memrefDescriptor(base);
auto attr = rewriter.getI64IntegerAttr(bytes);
Value scale = rewriter.create<LLVM::ConstantOp>(loc, llvmInt64Type, attr);
return rewriter
.create<LLVM::MulOp>(loc, llvmInt64Type, scale,
memrefDescriptor.stride(rewriter, loc, preLast))
.getResult();
}
// Use direct constant for static stride.
auto attr = rewriter.getI64IntegerAttr(strides[preLast] * bytes);
return rewriter.create<LLVM::ConstantOp>(loc, llvmInt64Type, attr)
.getResult();
}

struct TileZeroConversion : public ConvertOpToLLVMPattern<TileZeroOp> {
using ConvertOpToLLVMPattern<TileZeroOp>::ConvertOpToLLVMPattern;
LogicalResult
matchAndRewrite(TileZeroOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
amx::TileType tType = op.getTileType();
// Determine m x n tile sizes.
std::pair<Value, Value> tsz =
getTileSizes(rewriter, *getTypeConverter(), tType, op.getLoc());
// Replace operation with intrinsic.
Type resType = typeConverter->convertType(tType);
rewriter.replaceOpWithNewOp<amx::x86_amx_tilezero>(op, resType, tsz.first,
tsz.second);
return success();
}
};

struct TileLoadConversion : public ConvertOpToLLVMPattern<TileLoadOp> {
using ConvertOpToLLVMPattern<TileLoadOp>::ConvertOpToLLVMPattern;
AMXIntrinsicOpConversion(const LLVMTypeConverter &typeConverter,
PatternBenefit benefit = 1)
: OpInterfaceConversionPattern(typeConverter, &typeConverter.getContext(),
benefit),
typeConverter(typeConverter) {}

LogicalResult
matchAndRewrite(TileLoadOp op, OpAdaptor adaptor,
matchAndRewrite(amx::AMXIntrinsicOp op, ArrayRef<Value> operands,
ConversionPatternRewriter &rewriter) const override {
MemRefType mType = op.getMemRefType();
amx::TileType tType = op.getTileType();
// Determine m x n tile sizes.
std::pair<Value, Value> tsz =
getTileSizes(rewriter, *getTypeConverter(), tType, op.getLoc());
// Determine stride.
auto stride = getStride(rewriter, *getTypeConverter(), mType,
adaptor.getBase(), op.getLoc());
if (failed(stride))
return failure();
// Replace operation with intrinsic.
Value ptr = getStridedElementPtr(rewriter, op.getLoc(), mType,
adaptor.getBase(), adaptor.getIndices());
Type resType = typeConverter->convertType(tType);
rewriter.replaceOpWithNewOp<amx::x86_amx_tileloadd64>(
op, resType, tsz.first, tsz.second, ptr, stride.value());
return success();
return LLVM::detail::intrinsicRewrite(
op, rewriter.getStringAttr(op.getIntrinsicName()),
op.getIntrinsicOperands(operands, typeConverter, rewriter),
typeConverter, rewriter);
}
};

struct TileStoreConversion : public ConvertOpToLLVMPattern<TileStoreOp> {
using ConvertOpToLLVMPattern<TileStoreOp>::ConvertOpToLLVMPattern;

LogicalResult
matchAndRewrite(TileStoreOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
MemRefType mType = op.getMemRefType();
amx::TileType tType = op.getTileType();
// Determine m x n tile sizes.
std::pair<Value, Value> tsz =
getTileSizes(rewriter, *getTypeConverter(), tType, op.getLoc());
// Determine stride.
auto stride = getStride(rewriter, *getTypeConverter(), mType,
adaptor.getBase(), op.getLoc());
if (failed(stride))
return failure();
// Replace operation with intrinsic.
Value ptr = getStridedElementPtr(rewriter, op.getLoc(), mType,
adaptor.getBase(), adaptor.getIndices());
rewriter.replaceOpWithNewOp<amx::x86_amx_tilestored64>(
op, tsz.first, tsz.second, ptr, stride.value(), adaptor.getVal());
return success();
}
};

struct TileMulFConversion : public ConvertOpToLLVMPattern<TileMulFOp> {
using ConvertOpToLLVMPattern<TileMulFOp>::ConvertOpToLLVMPattern;
LogicalResult
matchAndRewrite(TileMulFOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
amx::TileType aType = op.getLhsTileType();
amx::TileType bType = op.getRhsTileType();
amx::TileType cType = op.getTileType();
// Determine m x n x k tile sizes.
std::pair<Value, Value> tsza =
getTileSizes(rewriter, *getTypeConverter(), aType, op.getLoc());
std::pair<Value, Value> tszb =
getTileSizes(rewriter, *getTypeConverter(), bType, op.getLoc());
// Replace operation with intrinsic.
Type resType = typeConverter->convertType(cType);
if (aType.getElementType().isBF16())
rewriter.replaceOpWithNewOp<amx::x86_amx_tdpbf16ps>(
op, resType, tsza.first, tszb.second, tsza.second, adaptor.getAcc(),
adaptor.getLhs(), adaptor.getRhs());
else if (aType.getElementType().isF16())
rewriter.replaceOpWithNewOp<amx::x86_amx_tdpfp16ps>(
op, resType, tsza.first, tszb.second, tsza.second, adaptor.getAcc(),
adaptor.getLhs(), adaptor.getRhs());
else
llvm_unreachable("Unexpected element type for amx.mulf");
return success();
}
};

struct TileMulIConversion : public ConvertOpToLLVMPattern<TileMulIOp> {
using ConvertOpToLLVMPattern<TileMulIOp>::ConvertOpToLLVMPattern;
LogicalResult
matchAndRewrite(TileMulIOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
amx::TileType aType = op.getLhsTileType();
amx::TileType bType = op.getRhsTileType();
amx::TileType cType = op.getTileType();
// Determine m x n x k tile sizes.
std::pair<Value, Value> tsza =
getTileSizes(rewriter, *getTypeConverter(), aType, op.getLoc());
std::pair<Value, Value> tszb =
getTileSizes(rewriter, *getTypeConverter(), bType, op.getLoc());
// Replace operation with intrinsic.
Type resType = typeConverter->convertType(cType);
bool zexta = op.getIsZextLhs();
bool zextb = op.getIsZextRhs();
if (zexta && zextb)
rewriter.replaceOpWithNewOp<amx::x86_amx_tdpbuud>(
op, resType, tsza.first, tszb.second, tsza.second, adaptor.getAcc(),
adaptor.getLhs(), adaptor.getRhs());
else if (zexta && !zextb)
rewriter.replaceOpWithNewOp<amx::x86_amx_tdpbusd>(
op, resType, tsza.first, tszb.second, tsza.second, adaptor.getAcc(),
adaptor.getLhs(), adaptor.getRhs());
else if (!zexta && zextb)
rewriter.replaceOpWithNewOp<amx::x86_amx_tdpbsud>(
op, resType, tsza.first, tszb.second, tsza.second, adaptor.getAcc(),
adaptor.getLhs(), adaptor.getRhs());
else
rewriter.replaceOpWithNewOp<amx::x86_amx_tdpbssd>(
op, resType, tsza.first, tszb.second, tsza.second, adaptor.getAcc(),
adaptor.getLhs(), adaptor.getRhs());
return success();
}
private:
const LLVMTypeConverter &typeConverter;
};

} // namespace

void mlir::populateAMXLegalizeForLLVMExportPatterns(
LLVMTypeConverter &converter, RewritePatternSet &patterns) {
patterns.add<TileZeroConversion, TileLoadConversion, TileStoreConversion,
TileMulFConversion, TileMulIConversion>(converter);
patterns.add<AMXIntrinsicOpConversion>(converter);
converter.addConversion([&](amx::TileType type) {
return LLVM::LLVMX86AMXType::get(&converter.getContext());
});
}

void mlir::configureAMXLegalizeForExportTarget(LLVMConversionTarget &target) {
target.addLegalOp<x86_amx_tilezero, x86_amx_tileloadd64, x86_amx_tilestored64,
x86_amx_tdpbf16ps, x86_amx_tdpfp16ps, x86_amx_tdpbssd,
x86_amx_tdpbsud, x86_amx_tdpbusd, x86_amx_tdpbuud>();
target.addIllegalOp<TileZeroOp, TileLoadOp, TileStoreOp, TileMulIOp,
TileMulFOp>();
}

namespace {
/// Implement the interface to convert AMX to LLVM.
struct AMXToLLVMDialectInterface : public ConvertToLLVMPatternInterface {
using ConvertToLLVMPatternInterface::ConvertToLLVMPatternInterface;

void populateConvertToLLVMConversionPatterns(
ConversionTarget &target, LLVMTypeConverter &typeConverter,
RewritePatternSet &patterns) const final {
populateAMXLegalizeForLLVMExportPatterns(typeConverter, patterns);
}
};
} // namespace

void mlir::registerConvertAMXToLLVMInterface(DialectRegistry &registry) {
registry.addExtension(+[](MLIRContext *ctx, amx::AMXDialect *dialect) {
dialect->addInterfaces<AMXToLLVMDialectInterface>();
});
target.addIllegalDialect<AMXDialect>();
}

@@ -51,7 +51,6 @@ add_mlir_translation_library(MLIRToLLVMIRTranslationRegistration
MLIRArmNeonToLLVMIRTranslation
MLIRArmSMEToLLVMIRTranslation
MLIRArmSVEToLLVMIRTranslation
MLIRAMXToLLVMIRTranslation
MLIRBuiltinToLLVMIRTranslation
MLIRGPUToLLVMIRTranslation
MLIRLLVMToLLVMIRTranslation
@@ -1,56 +0,0 @@
//===- AMXToLLVMIRTranslation.cpp - Translate AMX to LLVM IR --------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file implements a translation between the AMX dialect and LLVM IR.
//
//===----------------------------------------------------------------------===//

#include "mlir/Target/LLVMIR/Dialect/AMX/AMXToLLVMIRTranslation.h"
#include "mlir/Dialect/AMX/AMXDialect.h"
#include "mlir/IR/Operation.h"
#include "mlir/Target/LLVMIR/ModuleTranslation.h"

#include "llvm/IR/IRBuilder.h"
#include "llvm/IR/IntrinsicsX86.h"

using namespace mlir;
using namespace mlir::LLVM;

namespace {
/// Implementation of the dialect interface that converts operations belonging
/// to the AMX dialect to LLVM IR.
class AMXDialectLLVMIRTranslationInterface
: public LLVMTranslationDialectInterface {
public:
using LLVMTranslationDialectInterface::LLVMTranslationDialectInterface;

/// Translates the given operation to LLVM IR using the provided IR builder
/// and saving the state in `moduleTranslation`.
LogicalResult
convertOperation(Operation *op, llvm::IRBuilderBase &builder,
LLVM::ModuleTranslation &moduleTranslation) const final {
Operation &opInst = *op;
#include "mlir/Dialect/AMX/AMXConversions.inc"

return failure();
}
};
} // namespace

void mlir::registerAMXDialectTranslation(DialectRegistry &registry) {
registry.insert<amx::AMXDialect>();
registry.addExtension(+[](MLIRContext *ctx, amx::AMXDialect *dialect) {
dialect->addInterfaces<AMXDialectLLVMIRTranslationInterface>();
});
}

void mlir::registerAMXDialectTranslation(MLIRContext &context) {
DialectRegistry registry;
registerAMXDialectTranslation(registry);
context.appendDialectRegistry(registry);
}
@@ -1,16 +0,0 @@
add_mlir_translation_library(MLIRAMXToLLVMIRTranslation
AMXToLLVMIRTranslation.cpp

DEPENDS
MLIRAMXConversionsIncGen

LINK_COMPONENTS
Core

LINK_LIBS PUBLIC
MLIRIR
MLIRAMXDialect
MLIRLLVMDialect
MLIRSupport
MLIRTargetLLVMIRExport
)
@@ -1,7 +1,6 @@
add_subdirectory(ArmNeon)
add_subdirectory(ArmSME)
add_subdirectory(ArmSVE)
add_subdirectory(AMX)
add_subdirectory(Builtin)
add_subdirectory(GPU)
add_subdirectory(LLVMIR)
@@ -1,17 +1,17 @@
// RUN: mlir-opt %s -convert-vector-to-llvm="enable-amx" | mlir-opt | FileCheck %s

// CHECK-LABEL: muli(
// CHECK: amx.tilezero
// CHECK: amx.tileloadd64
// CHECK: amx.tileloadd64
// CHECK: amx.tdpbuud
// CHECK: amx.tilestored64
// CHECK: amx.tdpbssd
// CHECK: amx.tilestored64
// CHECK: amx.tdpbusd
// CHECK: amx.tilestored64
// CHECK: amx.tdpbsud
// CHECK: amx.tilestored64
// CHECK: llvm.call_intrinsic "llvm.x86.tilezero.internal"
// CHECK: llvm.call_intrinsic "llvm.x86.tileloadd64.internal"
// CHECK: llvm.call_intrinsic "llvm.x86.tileloadd64.internal"
// CHECK: llvm.call_intrinsic "llvm.x86.tdpbuud.internal"
// CHECK: llvm.call_intrinsic "llvm.x86.tilestored64.internal"
// CHECK: llvm.call_intrinsic "llvm.x86.tdpbssd.internal"
// CHECK: llvm.call_intrinsic "llvm.x86.tilestored64.internal"
// CHECK: llvm.call_intrinsic "llvm.x86.tdpbusd.internal"
// CHECK: llvm.call_intrinsic "llvm.x86.tilestored64.internal"
// CHECK: llvm.call_intrinsic "llvm.x86.tdpbsud.internal"
// CHECK: llvm.call_intrinsic "llvm.x86.tilestored64.internal"
func.func @muli(%arg0: memref<?x?xi8>, %arg1: memref<?x?xi32>) {
%0 = arith.constant 0 : index
%1 = amx.tile_zero : !amx.tile<16x64xi8>
@@ -29,11 +29,11 @@ func.func @muli(%arg0: memref<?x?xi8>, %arg1: memref<?x?xi32>) {
}

// CHECK-LABEL: mulbf16(
// CHECK: amx.tilezero
// CHECK: amx.tileloadd64
// CHECK: amx.tileloadd64
// CHECK: amx.tdpbf16ps
// CHECK: amx.tilestored64
// CHECK: llvm.call_intrinsic "llvm.x86.tilezero.internal"
// CHECK: llvm.call_intrinsic "llvm.x86.tileloadd64.internal"
// CHECK: llvm.call_intrinsic "llvm.x86.tileloadd64.internal"
// CHECK: llvm.call_intrinsic "llvm.x86.tdpbf16ps.internal"
// CHECK: llvm.call_intrinsic "llvm.x86.tilestored64.internal"
func.func @mulbf16(%arg0: memref<?x?xbf16>, %arg1: memref<?x?xf32>) {
%0 = arith.constant 0 : index
%1 = amx.tile_zero : !amx.tile<16x32xbf16>
@@ -45,11 +45,11 @@ func.func @mulbf16(%arg0: memref<?x?xbf16>, %arg1: memref<?x?xf32>) {
}

// CHECK-LABEL: mulfp16(
// CHECK: amx.tilezero
// CHECK: amx.tileloadd64
// CHECK: amx.tileloadd64
// CHECK: amx.tdpfp16ps
// CHECK: amx.tilestored64
// CHECK: llvm.call_intrinsic "llvm.x86.tilezero.internal"
// CHECK: llvm.call_intrinsic "llvm.x86.tileloadd64.internal"
// CHECK: llvm.call_intrinsic "llvm.x86.tileloadd64.internal"
// CHECK: llvm.call_intrinsic "llvm.x86.tdpfp16ps.internal"
// CHECK: llvm.call_intrinsic "llvm.x86.tilestored64.internal"
func.func @mulfp16(%arg0: memref<?x?xf16>, %arg1: memref<?x?xf32>) {
%0 = arith.constant 0 : index
%1 = amx.tile_zero : !amx.tile<16x32xf16>
@@ -62,21 +62,21 @@ func.func @mulfp16(%arg0: memref<?x?xf16>, %arg1: memref<?x?xf32>) {

// CHECK-LABEL: strides(
// CHECK: %[[CST_64_1:.+]] = llvm.mlir.constant(64 : i64) : i64
// CHECK: "amx.tileloadd64"(%{{.+}}, %{{.+}}, %{{.+}}, %[[CST_64_1]]
// CHECK: llvm.call_intrinsic "llvm.x86.tileloadd64.internal"(%{{.+}}, %{{.+}}, %{{.+}}, %[[CST_64_1]]
// CHECK: %[[CST_128_1:.+]] = llvm.mlir.constant(128 : i64) : i64
// CHECK: "amx.tileloadd64"(%{{.+}}, %{{.+}}, %{{.+}}, %[[CST_128_1]]
// CHECK: llvm.call_intrinsic "llvm.x86.tileloadd64.internal"(%{{.+}}, %{{.+}}, %{{.+}}, %[[CST_128_1]]
// CHECK: llvm.mlir.constant(2 : i64) : i64
// CHECK: llvm.extractvalue %{{.+}}[4, 0]
// CHECK: %[[STRIDE_1:.+]] = llvm.mul
// CHECK: "amx.tileloadd64"(%{{.+}}, %{{.+}}, %{{.+}}, %[[STRIDE_1]]
// CHECK: llvm.call_intrinsic "llvm.x86.tileloadd64.internal"(%{{.+}}, %{{.+}}, %{{.+}}, %[[STRIDE_1]]
// CHECK: %[[CST_64_2:.+]] = llvm.mlir.constant(64 : i64) : i64
// CHECK: "amx.tilestored64"(%{{.+}}, %{{.+}}, %{{.+}}, %[[CST_64_2]]
// CHECK: llvm.call_intrinsic "llvm.x86.tilestored64.internal"(%{{.+}}, %{{.+}}, %{{.+}}, %[[CST_64_2]]
// CHECK: %[[CST_128_2:.+]] = llvm.mlir.constant(128 : i64) : i64
// CHECK: "amx.tilestored64"(%{{.+}}, %{{.+}}, %{{.+}}, %[[CST_128_2]]
// CHECK: llvm.call_intrinsic "llvm.x86.tilestored64.internal"(%{{.+}}, %{{.+}}, %{{.+}}, %[[CST_128_2]]
// CHECK: llvm.mlir.constant(2 : i64) : i64
// CHECK: llvm.extractvalue %{{.+}}[4, 0]
// CHECK: %[[STRIDE_2:.+]] = llvm.mul
// CHECK: "amx.tilestored64"(%{{.+}}, %{{.+}}, %{{.+}}, %[[STRIDE_2]]
// CHECK: llvm.call_intrinsic "llvm.x86.tilestored64.internal"(%{{.+}}, %{{.+}}, %{{.+}}, %[[STRIDE_2]]
func.func @strides(%arg0: memref<16x32xbf16>, %arg1: memref<16x32xbf16, strided<[64, 1]>>, %arg2: memref<16x32xbf16, strided<[?, 1]>>) {
%0 = arith.constant 0 : index
%1 = amx.tile_load %arg0[%0, %0] : memref<16x32xbf16> into !amx.tile<16x32xbf16>
@@ -1,13 +1,90 @@
// RUN: mlir-translate --mlir-to-llvmir %s | FileCheck %s
// RUN: mlir-opt %s --convert-vector-to-llvm="enable-amx" --convert-to-llvm -reconcile-unrealized-casts \
// RUN: | mlir-translate --mlir-to-llvmir \
// RUN: | FileCheck %s

// CHECK-LABEL: define void @target(ptr %0)
// CHECK: %[[c:.*]] = call x86_amx @llvm.x86.tilezero.internal(i16 16, i16 16)
// CHECK: call void @llvm.x86.tilestored64.internal(i16 16, i16 16, ptr %0, i64 32, x86_amx %[[c]]
llvm.func @target(%ptr: !llvm.ptr) {
%c = llvm.mlir.constant(16 : i16) : i16
%s = llvm.mlir.constant(32 : i64) : i64
%0 = "amx.tilezero"(%c, %c) : (i16, i16) -> !llvm.array<16 x vector<16xbf16>>
"amx.tilestored64"(%c, %c, %ptr, %s, %0) : (i16, i16, !llvm.ptr, i64, !llvm.array<16 x vector<16xbf16>>) -> ()
llvm.return
// CHECK-LABEL: define void @amx_tile_zero
func.func @amx_tile_zero(%out: memref<?x?xf32>, %idx: index)
{
// CHECK: call x86_amx @llvm.x86.tilezero.internal(i16 16, i16 64)
// CHECK: call void @llvm.x86.tilestored64.internal
%zero = amx.tile_zero : !amx.tile<16x16xf32>
amx.tile_store %out[%idx, %idx], %zero : memref<?x?xf32>, !amx.tile<16x16xf32>
return
}

// CHECK-LABEL: define void @amx_tile_load_store
func.func @amx_tile_load_store(%base: memref<?x?xi8>, %out: memref<?x?xi8>,
%idx: index)
{
// CHECK: call x86_amx @llvm.x86.tileloadd64.internal
// CHECK: call void @llvm.x86.tilestored64.internal
%val = amx.tile_load %base[%idx, %idx] : memref<?x?xi8> into !amx.tile<16x64xi8>
amx.tile_store %out[%idx, %idx], %val : memref<?x?xi8>, !amx.tile<16x64xi8>
return
}

// CHECK-LABEL: define void @amx_tile_mulf_bf16
func.func @amx_tile_mulf_bf16(
%matA: memref<?x?xbf16>, %matB: memref<?x?xbf16>, %idx: index,
%out: memref<?x?xf32>)
{
// CHECK: call x86_amx @llvm.x86.tilezero.internal(i16 16, i16 64)
%acc = amx.tile_zero : !amx.tile<16x16xf32>
// CHECK-COUNT-2: call x86_amx @llvm.x86.tileloadd64.internal
%tA = amx.tile_load %matA[%idx, %idx] : memref<?x?xbf16> into !amx.tile<16x32xbf16>
%tB = amx.tile_load %matB[%idx, %idx] : memref<?x?xbf16> into !amx.tile<16x32xbf16>
// CHECK: call x86_amx @llvm.x86.tdpbf16ps.internal
%tRes = amx.tile_mulf %tA, %tB, %acc
: !amx.tile<16x32xbf16>, !amx.tile<16x32xbf16>, !amx.tile<16x16xf32>
// CHECK: call void @llvm.x86.tilestored64.internal
amx.tile_store %out[%idx, %idx], %tRes : memref<?x?xf32>, !amx.tile<16x16xf32>
return
}

// CHECK-LABEL: define void @amx_tile_mulf_f16
func.func @amx_tile_mulf_f16(
%matA: memref<?x?xf16>, %matB: memref<?x?xf16>, %idx: index,
%out: memref<?x?xf32>)
{
// CHECK: call x86_amx @llvm.x86.tilezero.internal(i16 16, i16 64)
%acc = amx.tile_zero : !amx.tile<16x16xf32>
// CHECK-COUNT-2: call x86_amx @llvm.x86.tileloadd64.internal
%tA = amx.tile_load %matA[%idx, %idx] : memref<?x?xf16> into !amx.tile<16x32xf16>
%tB = amx.tile_load %matB[%idx, %idx] : memref<?x?xf16> into !amx.tile<16x32xf16>
// CHECK: call x86_amx @llvm.x86.tdpfp16ps.internal
%tRes = amx.tile_mulf %tA, %tB, %acc
: !amx.tile<16x32xf16>, !amx.tile<16x32xf16>, !amx.tile<16x16xf32>
// CHECK: call void @llvm.x86.tilestored64.internal
amx.tile_store %out[%idx, %idx], %tRes : memref<?x?xf32>, !amx.tile<16x16xf32>
return
}

// CHECK-LABEL: define void @amx_tile_muli
func.func @amx_tile_muli(%matA: memref<?x?xi8>, %matB: memref<?x?xi8>,
%matC: memref<?x?xi32>, %idx: index, %out: memref<?x?xi8>)
{
%c0 = arith.constant 0 : index
%c16 = arith.constant 16 : index
// CHECK-COUNT-3: call x86_amx @llvm.x86.tileloadd64.internal
%tA = amx.tile_load %matA[%idx, %idx] : memref<?x?xi8> into !amx.tile<16x64xi8>
%tB = amx.tile_load %matB[%idx, %idx] : memref<?x?xi8> into !amx.tile<16x64xi8>
%acc = amx.tile_load %matC[%idx, %idx] : memref<?x?xi32> into !amx.tile<16x16xi32>
// CHECK: call x86_amx @llvm.x86.tdpbuud.internal
// CHECK: call x86_amx @llvm.x86.tdpbssd.internal
// CHECK: call x86_amx @llvm.x86.tdpbusd.internal
// CHECK: call x86_amx @llvm.x86.tdpbsud.internal
%res = amx.tile_muli %tA zext, %tB zext, %acc
: !amx.tile<16x64xi8>, !amx.tile<16x64xi8>, !amx.tile<16x16xi32>
%res1 = amx.tile_muli %tA, %tB, %acc
: !amx.tile<16x64xi8>, !amx.tile<16x64xi8>, !amx.tile<16x16xi32>
%res2 = amx.tile_muli %tA zext, %tB, %acc
: !amx.tile<16x64xi8>, !amx.tile<16x64xi8>, !amx.tile<16x16xi32>
%res3 = amx.tile_muli %tA, %tB zext, %acc
: !amx.tile<16x64xi8>, !amx.tile<16x64xi8>, !amx.tile<16x16xi32>
// CHECK-COUNT-4: call void @llvm.x86.tilestored64.internal
amx.tile_store %out[%c0, %c0], %res : memref<?x?xi8>, !amx.tile<16x16xi32>
amx.tile_store %out[%c0, %c16], %res1 : memref<?x?xi8>, !amx.tile<16x16xi32>
amx.tile_store %out[%c16, %c0], %res2 : memref<?x?xi8>, !amx.tile<16x16xi32>
amx.tile_store %out[%c16, %c16], %res3 : memref<?x?xi8>, !amx.tile<16x16xi32>
return
}