The MLIR classes Type/Attribute/Operation/Op/Value support
cast/dyn_cast/isa/dyn_cast_or_null functionality through LLVM's doCast
machinery in addition to defining methods with the same names. This change
begins the migration from the method form to the corresponding free-function
calls, which were decided on as the more consistent style. Note that there
still exist classes, such as AffineExpr, that only define the methods
directly; this change does not include the work needed to support a
functional cast/isa call for them.

Caveats include:
- This clang-tidy script probably has more problems.
- This only touches C++ code, so nothing that is being generated.

Context:
- https://mlir.llvm.org/deprecation/ at "Use the free function variants for dyn_cast/cast/isa/…"
- Original discussion at https://discourse.llvm.org/t/preferred-casting-style-going-forward/68443

Implementation:
This first patch was created with the following steps. The intention is to
only do automated changes at first, so I waste less time if it's reverted,
and so the first mass change is clearer as an example to other teams that
will need to follow similar steps.

Steps are described per line, as comments are removed by git:
0. Retrieve the change from the following to build clang-tidy with an
   additional check:
   https://github.com/llvm/llvm-project/compare/main...tpopp:llvm-project:tidy-cast-check
1. Build clang-tidy.
2. Run clang-tidy over the entire codebase while disabling all checks and
   enabling only the relevant one. Run on all header files also.
3. Delete .inc files that were also modified, so the next build rebuilds
   them to a pure state.
4. Some changes have been deleted for the following reasons:
   - Some files had a variable also named cast.
   - Some files had not included a header file that defines the cast
     functions.
   - Some files are definitions of the classes that have the casting
     methods, so the code there still refers to the method instead of the
     function unless a prefix is added or the method declaration is removed
     at the same time.

```
ninja -C $BUILD_DIR clang-tidy

run-clang-tidy -clang-tidy-binary=$BUILD_DIR/bin/clang-tidy -checks='-*,misc-cast-functions'\
  -header-filter=mlir/ mlir/* -fix

rm -rf $BUILD_DIR/tools/mlir/**/*.inc

git restore mlir/lib/IR mlir/lib/Dialect/DLTI/DLTI.cpp\
  mlir/lib/Dialect/Complex/IR/ComplexDialect.cpp\
  mlir/lib/**/IR/\
  mlir/lib/Dialect/SparseTensor/Transforms/SparseVectorization.cpp\
  mlir/lib/Dialect/Vector/Transforms/LowerVectorMultiReduction.cpp\
  mlir/test/lib/Dialect/Test/TestTypes.cpp\
  mlir/test/lib/Dialect/Transform/TestTransformDialectExtension.cpp\
  mlir/test/lib/Dialect/Test/TestAttributes.cpp\
  mlir/unittests/TableGen/EnumsGenTest.cpp\
  mlir/test/python/lib/PythonTestCAPI.cpp\
  mlir/include/mlir/IR/
```

Differential Revision: https://reviews.llvm.org/D150123
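As a concrete illustration of the mechanical rewrite (not part of the patch; the helper names `asMemRefBefore`/`asMemRefAfter` are hypothetical), here is a minimal before/after sketch using a type that also appears in the converted file below:

```cpp
#include "mlir/IR/BuiltinTypes.h" // mlir::MemRefType
#include "mlir/Support/LLVM.h"    // re-exports llvm::cast/dyn_cast/isa into mlir

// Before: casting through the member functions on Type, the style being
// migrated away from.
static mlir::MemRefType asMemRefBefore(mlir::Type type) {
  return type.dyn_cast<mlir::MemRefType>();
}

// After: the free-function form that the misc-cast-functions check rewrites to.
static mlir::MemRefType asMemRefAfter(mlir::Type type) {
  return mlir::dyn_cast<mlir::MemRefType>(type);
}

// The same rewrite applies to isa/cast/dyn_cast_or_null, e.g.
//   type.isa<mlir::VectorType>()  ->  isa<mlir::VectorType>(type)
```

The member-function spellings still exist at this point, so both helpers compile; the check only changes the call syntax, which is why files defining the cast methods themselves had to be restored as noted above.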
//===- AMDGPUToROCDL.cpp - AMDGPU to ROCDL dialect conversion -------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include "mlir/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.h"

#include "mlir/Conversion/LLVMCommon/ConversionTarget.h"
#include "mlir/Conversion/LLVMCommon/Pattern.h"
#include "mlir/Dialect/AMDGPU/IR/AMDGPUDialect.h"
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
#include "mlir/Dialect/LLVMIR/ROCDLDialect.h"
#include "mlir/Pass/Pass.h"

#include "llvm/ADT/STLExtras.h"
#include <optional>

namespace mlir {
#define GEN_PASS_DEF_CONVERTAMDGPUTOROCDL
#include "mlir/Conversion/Passes.h.inc"
} // namespace mlir

using namespace mlir;
using namespace mlir::amdgpu;

static Value createI32Constant(ConversionPatternRewriter &rewriter,
                               Location loc, int32_t value) {
  Type llvmI32 = rewriter.getI32Type();
  return rewriter.createOrFold<LLVM::ConstantOp>(loc, llvmI32, value);
}

namespace {
/// Define lowering patterns for raw buffer ops
template <typename GpuOp, typename Intrinsic>
struct RawBufferOpLowering : public ConvertOpToLLVMPattern<GpuOp> {
  RawBufferOpLowering(LLVMTypeConverter &converter, Chipset chipset)
      : ConvertOpToLLVMPattern<GpuOp>(converter), chipset(chipset) {}

  Chipset chipset;
  static constexpr uint32_t maxVectorOpWidth = 128;

  LogicalResult
  matchAndRewrite(GpuOp gpuOp, typename GpuOp::Adaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    Location loc = gpuOp.getLoc();
    Value memref = adaptor.getMemref();
    Value unconvertedMemref = gpuOp.getMemref();
    MemRefType memrefType = cast<MemRefType>(unconvertedMemref.getType());

    if (chipset.majorVersion < 9)
      return gpuOp.emitOpError("Raw buffer ops require GCN or higher");

    Value storeData = adaptor.getODSOperands(0)[0];
    if (storeData == memref) // no write component to this op
      storeData = Value();
    Type wantedDataType;
    if (storeData)
      wantedDataType = storeData.getType();
    else
      wantedDataType = gpuOp.getODSResults(0)[0].getType();

    Value atomicCmpData = Value();
    // Operand index 1 of a load is the indices, trying to read them can crash.
    if (storeData) {
      Value maybeCmpData = adaptor.getODSOperands(1)[0];
      if (maybeCmpData != memref)
        atomicCmpData = maybeCmpData;
    }

    Type llvmWantedDataType = this->typeConverter->convertType(wantedDataType);

    Type i32 = rewriter.getI32Type();
    Type llvmI32 = this->typeConverter->convertType(i32);

    int64_t elementByteWidth = memrefType.getElementTypeBitWidth() / 8;
    Value byteWidthConst = createI32Constant(rewriter, loc, elementByteWidth);

    // If we want to load a vector<NxT> with total size <= 32
    // bits, use a scalar load and bitcast it. Similarly, if bitsize(T) < 32
    // and the total load size is >= 32, use a vector load of N / (bitsize(T) /
    // 32) x i32 and bitcast. Also, the CAS intrinsic requires integer operands,
    // so bitcast any floats to integers.
    Type llvmBufferValType = llvmWantedDataType;
    if (atomicCmpData) {
      if (isa<VectorType>(wantedDataType))
        return gpuOp.emitOpError("vector compare-and-swap does not exist");
      if (auto floatType = dyn_cast<FloatType>(wantedDataType))
        llvmBufferValType = this->getTypeConverter()->convertType(
            rewriter.getIntegerType(floatType.getWidth()));
    }
    if (auto dataVector = dyn_cast<VectorType>(wantedDataType)) {
      uint32_t elemBits = dataVector.getElementTypeBitWidth();
      uint32_t totalBits = elemBits * dataVector.getNumElements();
      if (totalBits > maxVectorOpWidth)
        return gpuOp.emitOpError(
            "Total width of loads or stores must be no more than " +
            Twine(maxVectorOpWidth) + " bits, but we call for " +
            Twine(totalBits) +
            " bits. This should've been caught in validation");
      if (elemBits < 32) {
        if (totalBits > 32) {
          if (totalBits % 32 != 0)
            return gpuOp.emitOpError("Load or store of more than 32-bits that "
                                     "doesn't fit into words. Can't happen\n");
          llvmBufferValType = this->typeConverter->convertType(
              VectorType::get(totalBits / 32, i32));
        } else {
          llvmBufferValType = this->typeConverter->convertType(
              rewriter.getIntegerType(totalBits));
        }
      }
    }

    SmallVector<Value, 6> args;
    if (storeData) {
      if (llvmBufferValType != llvmWantedDataType) {
        Value castForStore =
            rewriter.create<LLVM::BitcastOp>(loc, llvmBufferValType, storeData);
        args.push_back(castForStore);
      } else {
        args.push_back(storeData);
      }
    }

    if (atomicCmpData) {
      if (llvmBufferValType != llvmWantedDataType) {
        Value castForCmp = rewriter.create<LLVM::BitcastOp>(
            loc, llvmBufferValType, atomicCmpData);
        args.push_back(castForCmp);
      } else {
        args.push_back(atomicCmpData);
      }
    }

    // Construct buffer descriptor from memref, attributes
    int64_t offset = 0;
    SmallVector<int64_t, 5> strides;
    if (failed(getStridesAndOffset(memrefType, strides, offset)))
      return gpuOp.emitOpError("Can't lower non-stride-offset memrefs");

    // Resource descriptor
    // bits 0-47: base address
    // bits 48-61: stride (0 for raw buffers)
    // bit 62: texture cache coherency (always 0)
    // bit 63: enable swizzles (always off for raw buffers)
    // bits 64-95 (word 2): Number of records, units of stride
    // bits 96-127 (word 3): See below

    Type llvm4xI32 = this->typeConverter->convertType(VectorType::get(4, i32));
    MemRefDescriptor memrefDescriptor(memref);
    Type llvmI64 = this->typeConverter->convertType(rewriter.getI64Type());
    Value c32I64 = rewriter.create<LLVM::ConstantOp>(
        loc, llvmI64, rewriter.getI64IntegerAttr(32));

    Value resource = rewriter.create<LLVM::UndefOp>(loc, llvm4xI32);

    Value ptr = memrefDescriptor.alignedPtr(rewriter, loc);
    Value ptrAsInt = rewriter.create<LLVM::PtrToIntOp>(loc, llvmI64, ptr);
    Value lowHalf = rewriter.create<LLVM::TruncOp>(loc, llvmI32, ptrAsInt);
    resource = rewriter.create<LLVM::InsertElementOp>(
        loc, llvm4xI32, resource, lowHalf,
        this->createIndexConstant(rewriter, loc, 0));

    // Bits 48-63 are used both for the stride of the buffer and (on gfx10) for
    // enabling swizzling. Prevent the high bits of pointers from accidentally
    // setting those flags.
    Value highHalfShifted = rewriter.create<LLVM::TruncOp>(
        loc, llvmI32, rewriter.create<LLVM::LShrOp>(loc, ptrAsInt, c32I64));
    Value highHalfTruncated = rewriter.create<LLVM::AndOp>(
        loc, llvmI32, highHalfShifted,
        createI32Constant(rewriter, loc, 0x0000ffff));
    resource = rewriter.create<LLVM::InsertElementOp>(
        loc, llvm4xI32, resource, highHalfTruncated,
        this->createIndexConstant(rewriter, loc, 1));

    Value numRecords;
    if (memrefType.hasStaticShape()) {
      numRecords = createI32Constant(
          rewriter, loc,
          static_cast<int32_t>(memrefType.getNumElements() * elementByteWidth));
    } else {
      Value maxIndex;
      for (uint32_t i = 0, e = memrefType.getRank(); i < e; ++i) {
        Value size = memrefDescriptor.size(rewriter, loc, i);
        Value stride = memrefDescriptor.stride(rewriter, loc, i);
        stride = rewriter.create<LLVM::MulOp>(loc, stride, byteWidthConst);
        Value maxThisDim = rewriter.create<LLVM::MulOp>(loc, size, stride);
        maxIndex = maxIndex ? rewriter.create<LLVM::MaximumOp>(loc, maxIndex,
                                                               maxThisDim)
                            : maxThisDim;
      }
      numRecords = rewriter.create<LLVM::TruncOp>(loc, llvmI32, maxIndex);
    }
    resource = rewriter.create<LLVM::InsertElementOp>(
        loc, llvm4xI32, resource, numRecords,
        this->createIndexConstant(rewriter, loc, 2));

    // Final word:
    // bits 0-11: dst sel, ignored by these intrinsics
    // bits 12-14: data format (ignored, must be nonzero, 7=float)
    // bits 15-18: data format (ignored, must be nonzero, 4=32bit)
    // bit 19: In nested heap (0 here)
    // bit 20: Behavior on unmap (0 means "return 0 / ignore")
    // bits 21-22: Index stride for swizzles (N/A)
    // bit 23: Add thread ID (0)
    // bit 24: Reserved to 1 (RDNA) or 0 (CDNA)
    // bits 25-26: Reserved (0)
    // bit 27: Buffer is non-volatile (CDNA only)
    // bits 28-29: Out of bounds select (0 = structured, 1 = check index, 2 =
    //  none, 3 = either swizzles or testing against offset field) RDNA only
    // bits 30-31: Type (must be 0)
    uint32_t word3 = (7 << 12) | (4 << 15);
    if (chipset.majorVersion == 10) {
      word3 |= (1 << 24);
      uint32_t oob = adaptor.getBoundsCheck() ? 3 : 2;
      word3 |= (oob << 28);
    }
    Value word3Const = createI32Constant(rewriter, loc, word3);
    resource = rewriter.create<LLVM::InsertElementOp>(
        loc, llvm4xI32, resource, word3Const,
        this->createIndexConstant(rewriter, loc, 3));
    args.push_back(resource);

    // Indexing (voffset)
    Value voffset = createI32Constant(rewriter, loc, 0);
    for (auto pair : llvm::enumerate(adaptor.getIndices())) {
      size_t i = pair.index();
      Value index = pair.value();
      Value strideOp;
      if (ShapedType::isDynamic(strides[i])) {
        strideOp = rewriter.create<LLVM::MulOp>(
            loc, memrefDescriptor.stride(rewriter, loc, i), byteWidthConst);
      } else {
        strideOp =
            createI32Constant(rewriter, loc, strides[i] * elementByteWidth);
      }
      index = rewriter.create<LLVM::MulOp>(loc, index, strideOp);
      voffset = rewriter.create<LLVM::AddOp>(loc, voffset, index);
    }
    if (adaptor.getIndexOffset()) {
      int32_t indexOffset = *gpuOp.getIndexOffset() * elementByteWidth;
      Value extraOffsetConst = createI32Constant(rewriter, loc, indexOffset);
      voffset =
          voffset ? rewriter.create<LLVM::AddOp>(loc, voffset, extraOffsetConst)
                  : extraOffsetConst;
    }
    args.push_back(voffset);

    Value sgprOffset = adaptor.getSgprOffset();
    if (!sgprOffset)
      sgprOffset = createI32Constant(rewriter, loc, 0);
    if (ShapedType::isDynamic(offset))
      sgprOffset = rewriter.create<LLVM::AddOp>(
          loc, memrefDescriptor.offset(rewriter, loc), sgprOffset);
    else if (offset > 0)
      sgprOffset = rewriter.create<LLVM::AddOp>(
          loc, sgprOffset, createI32Constant(rewriter, loc, offset));
    args.push_back(sgprOffset);

    // bit 0: GLC = 0 (atomics drop value, less coherency)
    // bits 1-2: SLC, DLC = 0 (similarly)
    // bit 3: swizzled (0 for raw)
    args.push_back(createI32Constant(rewriter, loc, 0));

    llvm::SmallVector<Type, 1> resultTypes(gpuOp->getNumResults(),
                                           llvmBufferValType);
    Operation *lowered = rewriter.create<Intrinsic>(loc, resultTypes, args,
                                                    ArrayRef<NamedAttribute>());
    if (lowered->getNumResults() == 1) {
      Value replacement = lowered->getResult(0);
      if (llvmBufferValType != llvmWantedDataType) {
        replacement = rewriter.create<LLVM::BitcastOp>(loc, llvmWantedDataType,
                                                       replacement);
      }
      rewriter.replaceOp(gpuOp, replacement);
    } else {
      rewriter.eraseOp(gpuOp);
    }
    return success();
  }
};

struct LDSBarrierOpLowering : public ConvertOpToLLVMPattern<LDSBarrierOp> {
  using ConvertOpToLLVMPattern<LDSBarrierOp>::ConvertOpToLLVMPattern;

  LogicalResult
  matchAndRewrite(LDSBarrierOp op, LDSBarrierOp::Adaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    auto asmDialectAttr = LLVM::AsmDialectAttr::get(rewriter.getContext(),
                                                    LLVM::AsmDialect::AD_ATT);
    const char *asmStr = "s_waitcnt lgkmcnt(0)\ns_barrier";
    const char *constraints = "";
    rewriter.replaceOpWithNewOp<LLVM::InlineAsmOp>(
        op,
        /*resultTypes=*/TypeRange(), /*operands=*/ValueRange(),
        /*asm_string=*/asmStr, constraints, /*has_side_effects=*/true,
        /*is_align_stack=*/false, /*asm_dialect=*/asmDialectAttr,
        /*operand_attrs=*/ArrayAttr());
    return success();
  }
};
} // namespace

/// If `input` is a vector of bytes, concatenate those bytes in little-endian
/// order to form a single integer of size 8 * [vector length]. This works
/// around a wart in the AMDGPU intrinsics where operations that logically take
/// vectors of bytes instead take integers. Since we do not want to expose this
/// implementation detail to MLIR, we correct for it here.
static Value mfmaConcatIfNeeded(ConversionPatternRewriter &rewriter,
                                Location loc, Value input) {
  Type inputType = input.getType();
  if (auto vectorType = dyn_cast<VectorType>(inputType)) {
    if (!vectorType.getElementType().isInteger(8))
      return input;
    int64_t numBytes = vectorType.getNumElements();
    Type destType = rewriter.getIntegerType(numBytes * 8);
    Value result = rewriter.createOrFold<LLVM::ConstantOp>(
        loc, destType, rewriter.getIntegerAttr(destType, 0));
    for (int64_t i = 0; i < numBytes; ++i) {
      Value idxConst = createI32Constant(rewriter, loc, i);
      Value element =
          rewriter.create<LLVM::ExtractElementOp>(loc, input, idxConst);
      Value extended = rewriter.create<LLVM::ZExtOp>(loc, destType, element);
      Value shiftConst = rewriter.createOrFold<LLVM::ConstantOp>(
          loc, destType, rewriter.getIntegerAttr(destType, i * 8));
      Value shifted = rewriter.create<LLVM::ShlOp>(loc, extended, shiftConst);
      result = rewriter.create<LLVM::OrOp>(loc, result, shifted);
    }
    return result;
  }
  return input;
}

/// Return the `rocdl` intrinsic corresponding to a MFMA operation `mfma`
/// if one exists. This includes checking to ensure the intrinsic is supported
/// on the architecture you are compiling for.
static std::optional<StringRef> mfmaOpToIntrinsic(MFMAOp mfma,
                                                  Chipset chipset) {
  uint32_t m = mfma.getM(), n = mfma.getN(), k = mfma.getK(),
           b = mfma.getBlocks();
  Type sourceElem = mfma.getSourceA().getType();
  if (auto sourceType = dyn_cast<VectorType>(sourceElem))
    sourceElem = sourceType.getElementType();
  Type destElem = mfma.getDestC().getType();
  if (auto destType = dyn_cast<VectorType>(destElem))
    destElem = destType.getElementType();

  if (sourceElem.isF32() && destElem.isF32()) {
    if (mfma.getReducePrecision() && chipset.minorVersion >= 0x40) {
      if (m == 32 && n == 32 && k == 4 && b == 1)
        return ROCDL::mfma_f32_32x32x4_xf32::getOperationName();
      if (m == 16 && n == 16 && k == 8 && b == 1)
        return ROCDL::mfma_f32_16x16x8_xf32::getOperationName();
    }
    if (m == 32 && n == 32 && k == 1 && b == 2)
      return ROCDL::mfma_f32_32x32x1f32::getOperationName();
    if (m == 16 && n == 16 && k == 1 && b == 4)
      return ROCDL::mfma_f32_16x16x1f32::getOperationName();
    if (m == 4 && n == 4 && k == 1 && b == 16)
      return ROCDL::mfma_f32_4x4x1f32::getOperationName();
    if (m == 32 && n == 32 && k == 2 && b == 1)
      return ROCDL::mfma_f32_32x32x2f32::getOperationName();
    if (m == 16 && n == 16 && k == 4 && b == 1)
      return ROCDL::mfma_f32_16x16x4f32::getOperationName();
  }

  if (sourceElem.isF16() && destElem.isF32()) {
    if (m == 32 && n == 32 && k == 4 && b == 2)
      return ROCDL::mfma_f32_32x32x4f16::getOperationName();
    if (m == 16 && n == 16 && k == 4 && b == 4)
      return ROCDL::mfma_f32_16x16x4f16::getOperationName();
    if (m == 4 && n == 4 && k == 4 && b == 16)
      return ROCDL::mfma_f32_4x4x4f16::getOperationName();
    if (m == 32 && n == 32 && k == 8 && b == 1)
      return ROCDL::mfma_f32_32x32x8f16::getOperationName();
    if (m == 16 && n == 16 && k == 16 && b == 1)
      return ROCDL::mfma_f32_16x16x16f16::getOperationName();
  }

  if (sourceElem.isBF16() && destElem.isF32() && chipset.minorVersion >= 0x0a) {
    if (m == 32 && n == 32 && k == 4 && b == 2)
      return ROCDL::mfma_f32_32x32x4bf16_1k::getOperationName();
    if (m == 16 && n == 16 && k == 4 && b == 4)
      return ROCDL::mfma_f32_16x16x4bf16_1k::getOperationName();
    if (m == 4 && n == 4 && k == 4 && b == 16)
      return ROCDL::mfma_f32_4x4x4bf16_1k::getOperationName();
    if (m == 32 && n == 32 && k == 8 && b == 1)
      return ROCDL::mfma_f32_32x32x8bf16_1k::getOperationName();
    if (m == 16 && n == 16 && k == 16 && b == 1)
      return ROCDL::mfma_f32_16x16x16bf16_1k::getOperationName();
  }

  if (sourceElem.isBF16() && destElem.isF32()) {
    if (m == 32 && n == 32 && k == 2 && b == 2)
      return ROCDL::mfma_f32_32x32x2bf16::getOperationName();
    if (m == 16 && n == 16 && k == 2 && b == 4)
      return ROCDL::mfma_f32_16x16x2bf16::getOperationName();
    if (m == 4 && n == 4 && k == 2 && b == 16)
      return ROCDL::mfma_f32_4x4x2bf16::getOperationName();
    if (m == 32 && n == 32 && k == 4 && b == 1)
      return ROCDL::mfma_f32_32x32x4bf16::getOperationName();
    if (m == 16 && n == 16 && k == 8 && b == 1)
      return ROCDL::mfma_f32_16x16x8bf16::getOperationName();
  }

  if (isa<IntegerType>(sourceElem) && destElem.isInteger(32)) {
    if (m == 32 && n == 32 && k == 4 && b == 2)
      return ROCDL::mfma_i32_32x32x4i8::getOperationName();
    if (m == 16 && n == 16 && k == 4 && b == 4)
      return ROCDL::mfma_i32_16x16x4i8::getOperationName();
    if (m == 4 && n == 4 && k == 4 && b == 16)
      return ROCDL::mfma_i32_4x4x4i8::getOperationName();
    if (m == 32 && n == 32 && k == 8 && b == 1)
      return ROCDL::mfma_i32_32x32x8i8::getOperationName();
    if (m == 16 && n == 16 && k == 16 && b == 1)
      return ROCDL::mfma_i32_16x16x16i8::getOperationName();
    if (m == 32 && n == 32 && k == 16 && b == 1 && chipset.minorVersion >= 0x40)
      return ROCDL::mfma_i32_32x32x16_i8::getOperationName();
    if (m == 16 && n == 16 && k == 32 && b == 1 && chipset.minorVersion >= 0x40)
      return ROCDL::mfma_i32_16x16x32_i8::getOperationName();
  }

  if (sourceElem.isF64() && destElem.isF64() && chipset.minorVersion >= 0x0a) {
    if (m == 16 && n == 16 && k == 4 && b == 1)
      return ROCDL::mfma_f64_16x16x4f64::getOperationName();
    if (m == 4 && n == 4 && k == 4 && b == 4)
      return ROCDL::mfma_f64_4x4x4f64::getOperationName();
  }

  if (sourceElem.isFloat8E5M2FNUZ() && destElem.isF32() &&
      chipset.minorVersion >= 0x40) {
    // Known to be correct because there are no scalar f8 instructions and
    // because a length mismatch will have been caught by the verifier.
    Type sourceBElem =
        cast<VectorType>(mfma.getSourceB().getType()).getElementType();
    if (m == 16 && n == 16 && k == 32 && b == 1) {
      if (sourceBElem.isFloat8E5M2FNUZ())
        return ROCDL::mfma_f32_16x16x32_bf8_bf8::getOperationName();
      if (sourceBElem.isFloat8E4M3FNUZ())
        return ROCDL::mfma_f32_16x16x32_bf8_fp8::getOperationName();
    }
    if (m == 32 && n == 32 && k == 16 && b == 1) {
      if (sourceBElem.isFloat8E5M2FNUZ())
        return ROCDL::mfma_f32_32x32x16_bf8_bf8::getOperationName();
      if (sourceBElem.isFloat8E4M3FNUZ())
        return ROCDL::mfma_f32_32x32x16_bf8_fp8::getOperationName();
    }
  }

  if (sourceElem.isFloat8E4M3FNUZ() && destElem.isF32() &&
      chipset.minorVersion >= 0x40) {
    Type sourceBElem =
        cast<VectorType>(mfma.getSourceB().getType()).getElementType();
    if (m == 16 && n == 16 && k == 32 && b == 1) {
      if (sourceBElem.isFloat8E5M2FNUZ())
        return ROCDL::mfma_f32_16x16x32_fp8_bf8::getOperationName();
      if (sourceBElem.isFloat8E4M3FNUZ())
        return ROCDL::mfma_f32_16x16x32_fp8_fp8::getOperationName();
    }
    if (m == 32 && n == 32 && k == 16 && b == 1) {
      if (sourceBElem.isFloat8E5M2FNUZ())
        return ROCDL::mfma_f32_32x32x16_fp8_bf8::getOperationName();
      if (sourceBElem.isFloat8E4M3FNUZ())
        return ROCDL::mfma_f32_32x32x16_fp8_fp8::getOperationName();
    }
  }

  return std::nullopt;
}

namespace {
struct MFMAOpLowering : public ConvertOpToLLVMPattern<MFMAOp> {
  MFMAOpLowering(LLVMTypeConverter &converter, Chipset chipset)
      : ConvertOpToLLVMPattern<MFMAOp>(converter), chipset(chipset) {}

  Chipset chipset;

  LogicalResult
  matchAndRewrite(MFMAOp op, MFMAOpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    Location loc = op.getLoc();
    Type outType = typeConverter->convertType(op.getDestD().getType());

    if (chipset.majorVersion != 9 || chipset.minorVersion < 0x08)
      return op->emitOpError("MFMA only supported on gfx908+");
    uint32_t getBlgpField = static_cast<uint32_t>(op.getBlgp());
    if (op.getNegateA() || op.getNegateB() || op.getNegateC()) {
      if (chipset.minorVersion < 0x40)
        return op.emitOpError("negation unsupported on older than gfx940");
      getBlgpField |=
          op.getNegateA() | (op.getNegateB() << 1) | (op.getNegateC() << 2);
    }
    std::optional<StringRef> maybeIntrinsic = mfmaOpToIntrinsic(op, chipset);
    if (!maybeIntrinsic.has_value())
      return op.emitOpError("no intrinsic matching MFMA size on given chipset");
    OperationState loweredOp(loc, *maybeIntrinsic);
    loweredOp.addTypes(outType);
    loweredOp.addOperands(
        {mfmaConcatIfNeeded(rewriter, loc, adaptor.getSourceA()),
         mfmaConcatIfNeeded(rewriter, loc, adaptor.getSourceB()),
         adaptor.getDestC(), createI32Constant(rewriter, loc, op.getCbsz()),
         createI32Constant(rewriter, loc, op.getAbid()),
         createI32Constant(rewriter, loc, getBlgpField)});
    Operation *lowered = rewriter.create(loweredOp);
    rewriter.replaceOp(op, lowered->getResults());
    return success();
  }
};

struct ConvertAMDGPUToROCDLPass
    : public impl::ConvertAMDGPUToROCDLBase<ConvertAMDGPUToROCDLPass> {
  ConvertAMDGPUToROCDLPass() = default;

  void runOnOperation() override {
    MLIRContext *ctx = &getContext();
    FailureOr<Chipset> maybeChipset = Chipset::parse(chipset);
    if (failed(maybeChipset)) {
      emitError(UnknownLoc::get(ctx), "Invalid chipset name: " + chipset);
      return signalPassFailure();
    }

    RewritePatternSet patterns(ctx);
    LLVMTypeConverter converter(ctx);
    populateAMDGPUToROCDLConversionPatterns(converter, patterns, *maybeChipset);
    LLVMConversionTarget target(getContext());
    target.addIllegalDialect<::mlir::amdgpu::AMDGPUDialect>();
    target.addLegalDialect<::mlir::LLVM::LLVMDialect>();
    target.addLegalDialect<::mlir::ROCDL::ROCDLDialect>();
    if (failed(applyPartialConversion(getOperation(), target,
                                      std::move(patterns))))
      signalPassFailure();
  }
};
} // namespace

void mlir::populateAMDGPUToROCDLConversionPatterns(LLVMTypeConverter &converter,
                                                   RewritePatternSet &patterns,
                                                   Chipset chipset) {
  // ROCDL supports fp8 types in some contexts, but there is no LLVM-level f8
  // type. Therefore, for this target, declare f8 to be equal to i8.
  converter.addConversion([](FloatType type) -> std::optional<Type> {
    if (type.isFloat8E5M2FNUZ() || type.isFloat8E4M3FNUZ())
      return IntegerType::get(type.getContext(), 8);
    return std::nullopt;
  });

  patterns.add<LDSBarrierOpLowering>(converter);
  patterns.add<
      RawBufferOpLowering<RawBufferLoadOp, ROCDL::RawBufferLoadOp>,
      RawBufferOpLowering<RawBufferStoreOp, ROCDL::RawBufferStoreOp>,
      RawBufferOpLowering<RawBufferAtomicFaddOp, ROCDL::RawBufferAtomicFAddOp>,
      RawBufferOpLowering<RawBufferAtomicFmaxOp, ROCDL::RawBufferAtomicFMaxOp>,
      RawBufferOpLowering<RawBufferAtomicSmaxOp, ROCDL::RawBufferAtomicSMaxOp>,
      RawBufferOpLowering<RawBufferAtomicUminOp, ROCDL::RawBufferAtomicUMinOp>,
      RawBufferOpLowering<RawBufferAtomicCmpswapOp,
                          ROCDL::RawBufferAtomicCmpSwap>,
      MFMAOpLowering>(converter, chipset);
}

std::unique_ptr<Pass> mlir::createConvertAMDGPUToROCDLPass() {
  return std::make_unique<ConvertAMDGPUToROCDLPass>();
}