This patch is intended to resolve #109481 and improve the usability of the AMX dialect. In LLVM IR, AMX intrinsics use `x86_amx`, which is one of the primitive types. This type is meant to be used only for AMX intrinsic calls and no other operations. The AMX dialect in MLIR uses regular 2D vector types, which are then lowered to arrays of vectors in the LLVMIR dialect. This creates an inconsistency between the types used in the LLVMIR dialect and in LLVM IR. Translation of AMX intrinsic calls to LLVM IR doesn't require result types to match, and that is where tile loads and mul operation results get the `x86_amx` type. This works in very simple cases, when mul and tile store operations directly consume the result of another AMX intrinsic call, but it doesn't work when an argument is a block argument (phi node). In addition to translation problems, this inconsistency between the types used in MLIR and LLVM IR makes MLIR verification and transformation quite problematic. Both `amx.tileload` and `vector::transfer_read` can load values of the same type, but only one of them can be used in AMX operations. In general, by looking at the type of a value, we cannot determine whether it can be used only in AMX operations or, conversely, only in non-AMX operations. To remove this inconsistency and make the limitations of AMX operations more explicit, I propose adding an `LLVMX86AMXType` type to the LLVMIR dialect to match the `x86_amx` type in LLVM IR, and introducing `amx::TileType` to be used by AMX operations in MLIR. This resolves translation problems for AMX usage with phi nodes and provides proper type verification in MLIR for AMX operations. P.S. This patch also adds missing FP16 support. It's trivial but unrelated to the type system changes, so let me know if I should submit it separately. --------- Signed-off-by: Ilya Enkovich <ilya.enkovich@intel.com>
246 lines
10 KiB
C++
246 lines
10 KiB
C++
//===- LegalizeForLLVMExport.cpp - Prepare AMX for LLVM translation ----===//
|
|
//
|
|
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
|
// See https://llvm.org/LICENSE.txt for license information.
|
|
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
|
//
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
#include "mlir/Dialect/AMX/Transforms.h"
|
|
|
|
#include "mlir/Conversion/ConvertToLLVM/ToLLVMInterface.h"
|
|
#include "mlir/Conversion/LLVMCommon/ConversionTarget.h"
|
|
#include "mlir/Conversion/LLVMCommon/Pattern.h"
|
|
#include "mlir/Dialect/AMX/AMXDialect.h"
|
|
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
|
|
#include "mlir/IR/BuiltinOps.h"
|
|
#include "mlir/IR/PatternMatch.h"
|
|
|
|
using namespace mlir;
|
|
using namespace mlir::amx;
|
|
|
|
namespace {
|
|
|
|
/// Maps the 2-dim vector shape to the two 16-bit tile sizes. The first
|
|
/// dimension directly translates into the number of rows of the tiles.
|
|
/// The second dimensions needs to be scaled by the number of bytes.
|
|
std::pair<Value, Value> getTileSizes(ConversionPatternRewriter &rewriter,
|
|
const LLVMTypeConverter &typeConverter,
|
|
amx::TileType tType, Location loc) {
|
|
Type llvmInt16Type = IntegerType::get(&typeConverter.getContext(), 16);
|
|
unsigned width = tType.getElementType().getIntOrFloatBitWidth();
|
|
assert(llvm::isPowerOf2_64(width) && width >= 8);
|
|
unsigned bytes = width >> 3;
|
|
auto mattr = rewriter.getI16IntegerAttr(tType.getDimSize(0));
|
|
auto nattr = rewriter.getI16IntegerAttr(tType.getDimSize(1) * bytes);
|
|
return std::make_pair(
|
|
rewriter.create<LLVM::ConstantOp>(loc, llvmInt16Type, mattr),
|
|
rewriter.create<LLVM::ConstantOp>(loc, llvmInt16Type, nattr));
|
|
}
|
|
|
|
/// Maps the 2-dim memref shape to the 64-bit stride. Note that the buffer
|
|
/// shape may "envelop" the actual tile shape, and may be dynamically sized.
|
|
/// Returns failure if proper stride couldn't be found.
|
|
FailureOr<Value> getStride(ConversionPatternRewriter &rewriter,
|
|
const LLVMTypeConverter &typeConverter,
|
|
MemRefType mType, Value base, Location loc) {
|
|
if (mType.getRank() < 2)
|
|
return failure();
|
|
int64_t preLast = mType.getRank() - 2;
|
|
Type llvmInt64Type = IntegerType::get(&typeConverter.getContext(), 64);
|
|
unsigned width = mType.getElementType().getIntOrFloatBitWidth();
|
|
assert(llvm::isPowerOf2_64(width) && width >= 8);
|
|
unsigned bytes = width >> 3;
|
|
int64_t offset;
|
|
SmallVector<int64_t, 4> strides;
|
|
if (failed(getStridesAndOffset(mType, strides, offset)) ||
|
|
strides.back() != 1)
|
|
return failure();
|
|
if (strides[preLast] == ShapedType::kDynamic) {
|
|
// Dynamic stride needs code to compute the stride at runtime.
|
|
MemRefDescriptor memrefDescriptor(base);
|
|
auto attr = rewriter.getI64IntegerAttr(bytes);
|
|
Value scale = rewriter.create<LLVM::ConstantOp>(loc, llvmInt64Type, attr);
|
|
return rewriter
|
|
.create<LLVM::MulOp>(loc, llvmInt64Type, scale,
|
|
memrefDescriptor.stride(rewriter, loc, preLast))
|
|
.getResult();
|
|
}
|
|
// Use direct constant for static stride.
|
|
auto attr = rewriter.getI64IntegerAttr(strides[preLast] * bytes);
|
|
return rewriter.create<LLVM::ConstantOp>(loc, llvmInt64Type, attr)
|
|
.getResult();
|
|
}
|
|
|
|
/// Lowers `amx.tilezero` to a call of the `x86_amx_tilezero` intrinsic. The
/// intrinsic takes the tile's row count and bytes-per-row, and its result uses
/// the type-converted tile type.
struct TileZeroConversion : public ConvertOpToLLVMPattern<TileZeroOp> {
  using ConvertOpToLLVMPattern<TileZeroOp>::ConvertOpToLLVMPattern;

  LogicalResult
  matchAndRewrite(TileZeroOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    amx::TileType tileTy = op.getTileType();
    auto [rows, rowBytes] =
        getTileSizes(rewriter, *getTypeConverter(), tileTy, op.getLoc());
    Type llvmTileTy = typeConverter->convertType(tileTy);
    rewriter.replaceOpWithNewOp<amx::x86_amx_tilezero>(op, llvmTileTy, rows,
                                                       rowBytes);
    return success();
  }
};
|
|
|
|
/// Lowers `amx.tileload` to a call of the `x86_amx_tileloadd64` intrinsic.
struct TileLoadConversion : public ConvertOpToLLVMPattern<TileLoadOp> {
  using ConvertOpToLLVMPattern<TileLoadOp>::ConvertOpToLLVMPattern;

  /// The intrinsic operands are the tile sizes (rows, bytes per row), the
  /// address of the first element selected by the op's indices, and the row
  /// stride in bytes.
  ///
  /// Computing the stride is the only step that can fail. It is therefore done
  /// first: every failure path in getStride precedes any op creation, so a
  /// failed match leaves the IR untouched.
  LogicalResult
  matchAndRewrite(TileLoadOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    MemRefType mType = op.getMemRefType();
    amx::TileType tType = op.getTileType();
    // Determine the row stride (in bytes) before creating any other IR.
    FailureOr<Value> stride = getStride(rewriter, *getTypeConverter(), mType,
                                        adaptor.getBase(), op.getLoc());
    if (failed(stride))
      return rewriter.notifyMatchFailure(
          op, "expected a memref of rank >= 2 with a strided layout and unit "
              "innermost stride");
    // Determine m x n tile sizes.
    std::pair<Value, Value> tsz =
        getTileSizes(rewriter, *getTypeConverter(), tType, op.getLoc());
    // Replace operation with intrinsic.
    Value ptr = getStridedElementPtr(op.getLoc(), mType, adaptor.getBase(),
                                     adaptor.getIndices(), rewriter);
    Type resType = typeConverter->convertType(tType);
    rewriter.replaceOpWithNewOp<amx::x86_amx_tileloadd64>(
        op, resType, tsz.first, tsz.second, ptr, *stride);
    return success();
  }
};
|
|
|
|
/// Lowers `amx.tilestore` to a call of the `x86_amx_tilestored64` intrinsic.
struct TileStoreConversion : public ConvertOpToLLVMPattern<TileStoreOp> {
  using ConvertOpToLLVMPattern<TileStoreOp>::ConvertOpToLLVMPattern;

  /// The intrinsic operands are the tile sizes (rows, bytes per row), the
  /// destination address selected by the op's indices, the row stride in
  /// bytes, and the (converted) tile value to store.
  ///
  /// Computing the stride is the only step that can fail. It is therefore done
  /// first: every failure path in getStride precedes any op creation, so a
  /// failed match leaves the IR untouched.
  LogicalResult
  matchAndRewrite(TileStoreOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    MemRefType mType = op.getMemRefType();
    amx::TileType tType = op.getTileType();
    // Determine the row stride (in bytes) before creating any other IR.
    FailureOr<Value> stride = getStride(rewriter, *getTypeConverter(), mType,
                                        adaptor.getBase(), op.getLoc());
    if (failed(stride))
      return rewriter.notifyMatchFailure(
          op, "expected a memref of rank >= 2 with a strided layout and unit "
              "innermost stride");
    // Determine m x n tile sizes.
    std::pair<Value, Value> tsz =
        getTileSizes(rewriter, *getTypeConverter(), tType, op.getLoc());
    // Replace operation with intrinsic.
    Value ptr = getStridedElementPtr(op.getLoc(), mType, adaptor.getBase(),
                                     adaptor.getIndices(), rewriter);
    rewriter.replaceOpWithNewOp<amx::x86_amx_tilestored64>(
        op, tsz.first, tsz.second, ptr, *stride, adaptor.getVal());
    return success();
  }
};
|
|
|
|
/// Lowers `amx.mulf` to a floating-point dot-product intrinsic. The element
/// type of the lhs tile selects it: bf16 -> `x86_amx_tdpbf16ps`,
/// f16 -> `x86_amx_tdpfp16ps`. Any other element type is unreachable.
struct TileMulFConversion : public ConvertOpToLLVMPattern<TileMulFOp> {
  using ConvertOpToLLVMPattern<TileMulFOp>::ConvertOpToLLVMPattern;

  LogicalResult
  matchAndRewrite(TileMulFOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    Location loc = op.getLoc();
    amx::TileType lhsTy = op.getLhsTileType();
    amx::TileType rhsTy = op.getRhsTileType();
    amx::TileType accTy = op.getTileType();

    // (rows, bytes per row) for each input tile.
    std::pair<Value, Value> lhsSizes =
        getTileSizes(rewriter, *getTypeConverter(), lhsTy, loc);
    std::pair<Value, Value> rhsSizes =
        getTileSizes(rewriter, *getTypeConverter(), rhsTy, loc);

    // Size operands in intrinsic order: lhs rows, rhs bytes per row,
    // lhs bytes per row.
    Value m = lhsSizes.first;
    Value n = rhsSizes.second;
    Value k = lhsSizes.second;
    Type resType = typeConverter->convertType(accTy);
    Type elemTy = lhsTy.getElementType();

    if (elemTy.isBF16()) {
      rewriter.replaceOpWithNewOp<amx::x86_amx_tdpbf16ps>(
          op, resType, m, n, k, adaptor.getAcc(), adaptor.getLhs(),
          adaptor.getRhs());
      return success();
    }
    if (elemTy.isF16()) {
      rewriter.replaceOpWithNewOp<amx::x86_amx_tdpfp16ps>(
          op, resType, m, n, k, adaptor.getAcc(), adaptor.getLhs(),
          adaptor.getRhs());
      return success();
    }
    llvm_unreachable("Unexpected element type for amx.mulf");
  }
};
|
|
|
|
/// Lowers `amx.muli` to one of the four integer dot-product intrinsics. The
/// op's zero-extension flags select the variant:
///   zext lhs, zext rhs -> `x86_amx_tdpbuud`
///   zext lhs, sext rhs -> `x86_amx_tdpbusd`
///   sext lhs, zext rhs -> `x86_amx_tdpbsud`
///   sext lhs, sext rhs -> `x86_amx_tdpbssd`
struct TileMulIConversion : public ConvertOpToLLVMPattern<TileMulIOp> {
  using ConvertOpToLLVMPattern<TileMulIOp>::ConvertOpToLLVMPattern;

  LogicalResult
  matchAndRewrite(TileMulIOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    Location loc = op.getLoc();
    amx::TileType lhsTy = op.getLhsTileType();
    amx::TileType rhsTy = op.getRhsTileType();
    amx::TileType accTy = op.getTileType();

    // (rows, bytes per row) for each input tile.
    std::pair<Value, Value> lhsSizes =
        getTileSizes(rewriter, *getTypeConverter(), lhsTy, loc);
    std::pair<Value, Value> rhsSizes =
        getTileSizes(rewriter, *getTypeConverter(), rhsTy, loc);

    // Size operands in intrinsic order: lhs rows, rhs bytes per row,
    // lhs bytes per row.
    Value m = lhsSizes.first;
    Value n = rhsSizes.second;
    Value k = lhsSizes.second;
    Type resType = typeConverter->convertType(accTy);

    const bool zextLhs = op.getIsZextLhs();
    const bool zextRhs = op.getIsZextRhs();
    if (zextLhs) {
      if (zextRhs)
        replaceWithDotProduct<amx::x86_amx_tdpbuud>(rewriter, op, adaptor,
                                                    resType, m, n, k);
      else
        replaceWithDotProduct<amx::x86_amx_tdpbusd>(rewriter, op, adaptor,
                                                    resType, m, n, k);
    } else {
      if (zextRhs)
        replaceWithDotProduct<amx::x86_amx_tdpbsud>(rewriter, op, adaptor,
                                                    resType, m, n, k);
      else
        replaceWithDotProduct<amx::x86_amx_tdpbssd>(rewriter, op, adaptor,
                                                    resType, m, n, k);
    }
    return success();
  }

private:
  /// Replaces `op` with intrinsic `IntrOp`. Its operands are the three size
  /// values followed by the converted accumulator, lhs and rhs tiles.
  template <typename IntrOp>
  static void replaceWithDotProduct(ConversionPatternRewriter &rewriter,
                                    TileMulIOp op, OpAdaptor adaptor,
                                    Type resType, Value m, Value n, Value k) {
    rewriter.replaceOpWithNewOp<IntrOp>(op, resType, m, n, k, adaptor.getAcc(),
                                        adaptor.getLhs(), adaptor.getRhs());
  }
};
|
|
|
|
} // namespace
|
|
|
|
/// Adds the AMX-to-LLVM lowering patterns to `patterns`. Also registers, in
/// `converter`, a conversion from `amx::TileType` to `LLVM::LLVMX86AMXType`
/// (the LLVM `x86_amx` type).
void mlir::populateAMXLegalizeForLLVMExportPatterns(
    LLVMTypeConverter &converter, RewritePatternSet &patterns) {
  patterns.add<TileZeroConversion, TileLoadConversion, TileStoreConversion,
               TileMulFConversion, TileMulIConversion>(converter);
  // Build the result type in the context that owns the source type. The
  // callback is stored inside `converter`, so deriving the context from the
  // argument also keeps it free of captures.
  converter.addConversion([](amx::TileType type) {
    return LLVM::LLVMX86AMXType::get(type.getContext());
  });
}
|
|
|
|
/// Configures `target` for the AMX lowering. The x86 AMX intrinsic ops are
/// marked legal. The high-level AMX tile ops are marked illegal, so the
/// lowering patterns must rewrite them.
void mlir::configureAMXLegalizeForExportTarget(LLVMConversionTarget &target) {
  // Intrinsic-level ops produced by the lowering patterns.
  target.addLegalOp<x86_amx_tilezero, x86_amx_tileloadd64,
                    x86_amx_tilestored64>();
  target.addLegalOp<x86_amx_tdpbf16ps, x86_amx_tdpfp16ps>();
  target.addLegalOp<x86_amx_tdpbssd, x86_amx_tdpbsud, x86_amx_tdpbusd,
                    x86_amx_tdpbuud>();
  // High-level ops that must not survive the conversion.
  target.addIllegalOp<TileZeroOp, TileLoadOp, TileStoreOp>();
  target.addIllegalOp<TileMulIOp, TileMulFOp>();
}
|
|
|
|
namespace {
|
|
/// Implement the interface to convert AMX to LLVM.
|
|
struct AMXToLLVMDialectInterface : public ConvertToLLVMPatternInterface {
|
|
using ConvertToLLVMPatternInterface::ConvertToLLVMPatternInterface;
|
|
|
|
void populateConvertToLLVMConversionPatterns(
|
|
ConversionTarget &target, LLVMTypeConverter &typeConverter,
|
|
RewritePatternSet &patterns) const final {
|
|
populateAMXLegalizeForLLVMExportPatterns(typeConverter, patterns);
|
|
}
|
|
};
|
|
} // namespace
|
|
|
|
void mlir::registerConvertAMXToLLVMInterface(DialectRegistry ®istry) {
|
|
registry.addExtension(+[](MLIRContext *ctx, amx::AMXDialect *dialect) {
|
|
dialect->addInterfaces<AMXToLLVMDialectInterface>();
|
|
});
|
|
}
|