clang-p2996/mlir/lib/Conversion/NVGPUToNVVM/NVGPUToNVVM.cpp
Tres Popp 5550c82189 [mlir] Move casting calls from methods to function calls
The MLIR classes Type/Attribute/Operation/Op/Value support
cast/dyn_cast/isa/dyn_cast_or_null functionality through llvm's doCast
mechanism, in addition to defining methods with the same names.
This change begins the migration from uses of the methods to the
corresponding free-function calls, which have been decided on as the
more consistent style.
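
For illustration, the mechanical shape of the rewrite looks roughly like
the following (a minimal sketch with made-up identifiers, not a line
taken from the actual diff):

```
// Before: member-function style on an mlir::Type value.
if (auto memref = type.dyn_cast<MemRefType>())
  useMemRef(memref);

// After: preferred free-function style.
if (auto memref = dyn_cast<MemRefType>(type))
  useMemRef(memref);
```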

Note that there still exist classes, such as AffineExpr, that only
define the methods directly; work to support the functional cast/isa
calls for them is not included here.

Caveats include:
- The clang-tidy check used here probably still has more problems.
- This only touches C++ code, so nothing that is generated (e.g. the
  TableGen'd .inc files) is changed.

Context:
- https://mlir.llvm.org/deprecation/ at "Use the free function variants
  for dyn_cast/cast/isa/…"
- Original discussion at https://discourse.llvm.org/t/preferred-casting-style-going-forward/68443

Implementation:
This first patch was created with the following steps. The intention is
to do only automated changes at first, so less time is wasted if it gets
reverted, and so this first mass change is a clearer example to other
teams that will need to follow similar steps.

Steps are described per line here, since comments would be removed by
git:
0. Retrieve the change from the following to build clang-tidy with an
   additional check:
   https://github.com/llvm/llvm-project/compare/main...tpopp:llvm-project:tidy-cast-check
1. Build clang-tidy
2. Run clang-tidy over your entire codebase with all checks disabled
   except the one relevant check. Run on all header files as well.
3. Delete .inc files that were also modified, so the next build rebuilds
   them to a pure state.
4. Some changes were then reverted for the following reasons:
   - Some files had a variable that was also named cast
   - Some files had not included a header file that defines the cast
     functions
   - Some files are the definitions of the classes that have the casting
     methods, so the code would still refer to the method instead of the
     function unless a prefix were added or the method declaration were
     removed at the same time.

```
ninja -C $BUILD_DIR clang-tidy

run-clang-tidy -clang-tidy-binary=$BUILD_DIR/bin/clang-tidy -checks='-*,misc-cast-functions'\
               -header-filter=mlir/ mlir/* -fix

rm -rf $BUILD_DIR/tools/mlir/**/*.inc

git restore mlir/lib/IR mlir/lib/Dialect/DLTI/DLTI.cpp\
            mlir/lib/Dialect/Complex/IR/ComplexDialect.cpp\
            mlir/lib/**/IR/\
            mlir/lib/Dialect/SparseTensor/Transforms/SparseVectorization.cpp\
            mlir/lib/Dialect/Vector/Transforms/LowerVectorMultiReduction.cpp\
            mlir/test/lib/Dialect/Test/TestTypes.cpp\
            mlir/test/lib/Dialect/Transform/TestTransformDialectExtension.cpp\
            mlir/test/lib/Dialect/Test/TestAttributes.cpp\
            mlir/unittests/TableGen/EnumsGenTest.cpp\
            mlir/test/python/lib/PythonTestCAPI.cpp\
            mlir/include/mlir/IR/
```

Differential Revision: https://reviews.llvm.org/D150123
2023-05-12 11:21:25 +02:00

//===- NVGPUToNVVM.cpp - NVGPU to NVVM dialect conversion -----------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
#include "mlir/Conversion/NVGPUToNVVM/NVGPUToNVVM.h"
#include "mlir/Conversion/LLVMCommon/ConversionTarget.h"
#include "mlir/Conversion/LLVMCommon/Pattern.h"
#include "mlir/Dialect/GPU/IR/GPUDialect.h"
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
#include "mlir/Dialect/LLVMIR/NVVMDialect.h"
#include "mlir/Dialect/NVGPU/IR/NVGPUDialect.h"
#include "mlir/IR/TypeUtilities.h"
#include "mlir/Pass/Pass.h"
namespace mlir {
#define GEN_PASS_DEF_CONVERTNVGPUTONVVMPASS
#include "mlir/Conversion/Passes.h.inc"
} // namespace mlir
using namespace mlir;
/// Returns the type for the intrinsic given the vectorResultType of the
/// `gpu.mma.sync` operation.
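/// For example (an illustrative mapping, not an exhaustive list), an
/// `!llvm.array<2 x vector<2xf16>>` result maps to
/// `!llvm.struct<(vector<2xf16>, vector<2xf16>)>`, and an
/// `!llvm.array<1 x vector<2xf64>>` result maps to `!llvm.struct<(f64, f64)>`.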
static Type inferIntrinsicResultType(Type vectorResultType) {
MLIRContext *ctx = vectorResultType.getContext();
auto a = cast<LLVM::LLVMArrayType>(vectorResultType);
auto f16x2Ty = LLVM::getFixedVectorType(Float16Type::get(ctx), 2);
auto i32Ty = IntegerType::get(ctx, 32);
auto i32x2Ty = LLVM::getFixedVectorType(i32Ty, 2);
Type f64Ty = Float64Type::get(ctx);
Type f64x2Ty = LLVM::getFixedVectorType(f64Ty, 2);
Type f32Ty = Float32Type::get(ctx);
Type f32x2Ty = LLVM::getFixedVectorType(f32Ty, 2);
if (a.getElementType() == f16x2Ty) {
return LLVM::LLVMStructType::getLiteral(
ctx, SmallVector<Type>(a.getNumElements(), f16x2Ty));
}
if (a.getElementType() == i32x2Ty) {
return LLVM::LLVMStructType::getLiteral(
ctx,
SmallVector<Type>(static_cast<size_t>(a.getNumElements()) * 2, i32Ty));
}
if (a.getElementType() == f64x2Ty) {
return LLVM::LLVMStructType::getLiteral(ctx, {f64Ty, f64Ty});
}
if (a.getElementType() == f32x2Ty) {
return LLVM::LLVMStructType::getLiteral(
ctx,
SmallVector<Type>(static_cast<size_t>(a.getNumElements()) * 2, f32Ty));
}
if (a.getElementType() == LLVM::getFixedVectorType(f32Ty, 1)) {
return LLVM::LLVMStructType::getLiteral(
ctx, SmallVector<Type>(static_cast<size_t>(a.getNumElements()), f32Ty));
}
return vectorResultType;
}
/// Convert the SSA result of the NVVM intrinsic `nvvm.mma.sync` (which is
/// always an LLVM struct) into a fragment that is compatible with the vector
/// type of this operation. This involves extracting elements from the struct
/// and inserting them into an LLVM array. These extra data-movement
/// operations should be canonicalized away by the LLVM backend.
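/// For instance (illustrative shapes), an `!llvm.struct<(f64, f64)>` intrinsic
/// result destined for an `!llvm.array<1 x vector<2xf64>>` fragment is
/// unpacked with two extractvalue ops and repacked into a single two-element
/// vector.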
static Value convertIntrinsicResult(Location loc, Type intrinsicResultType,
Type resultType, Value intrinsicResult,
RewriterBase &rewriter) {
MLIRContext *ctx = rewriter.getContext();
auto structType = dyn_cast<LLVM::LLVMStructType>(intrinsicResultType);
auto arrayType = dyn_cast<LLVM::LLVMArrayType>(resultType);
Type i32Ty = rewriter.getI32Type();
Type f32Ty = rewriter.getF32Type();
Type f64Ty = rewriter.getF64Type();
Type f16x2Ty = LLVM::getFixedVectorType(rewriter.getF16Type(), 2);
Type i32x2Ty = LLVM::getFixedVectorType(i32Ty, 2);
Type f64x2Ty = LLVM::getFixedVectorType(f64Ty, 2);
Type f32x2Ty = LLVM::getFixedVectorType(f32Ty, 2);
Type f32x1Ty = LLVM::getFixedVectorType(f32Ty, 1);
auto makeConst = [&](int32_t index) -> Value {
return rewriter.create<LLVM::ConstantOp>(loc, IntegerType::get(ctx, 32),
rewriter.getI32IntegerAttr(index));
};
if (arrayType) {
SmallVector<Value, 4> elements;
// The intrinsic returns 32-bit wide elements in a form which can be
// directly bitcasted and inserted into the result vector.
if (arrayType.getElementType() == f16x2Ty ||
arrayType.getElementType() == f32x1Ty) {
for (unsigned i = 0; i < structType.getBody().size(); i++) {
Value el =
rewriter.create<LLVM::ExtractValueOp>(loc, intrinsicResult, i);
el = rewriter.createOrFold<LLVM::BitcastOp>(
loc, arrayType.getElementType(), el);
elements.push_back(el);
}
}
// The intrinsic returns i32, f64, and f32 values as individual scalars,
// even when the result is notionally a 64-bit wide element (e.g. f32x2). We
// need to extract them from the struct and pack them into the 64-bit wide
// rows of the vector result.
if (arrayType.getElementType() == i32x2Ty ||
arrayType.getElementType() == f64x2Ty ||
arrayType.getElementType() == f32x2Ty) {
for (unsigned i = 0, e = structType.getBody().size() / 2; i < e; i++) {
Value vec =
rewriter.create<LLVM::UndefOp>(loc, arrayType.getElementType());
Value x1 =
rewriter.create<LLVM::ExtractValueOp>(loc, intrinsicResult, i * 2);
Value x2 = rewriter.create<LLVM::ExtractValueOp>(loc, intrinsicResult,
i * 2 + 1);
vec = rewriter.create<LLVM::InsertElementOp>(loc, vec.getType(), vec,
x1, makeConst(0));
vec = rewriter.create<LLVM::InsertElementOp>(loc, vec.getType(), vec,
x2, makeConst(1));
elements.push_back(vec);
}
}
// Create the final vectorized result.
Value result = rewriter.create<LLVM::UndefOp>(loc, arrayType);
for (const auto &el : llvm::enumerate(elements)) {
result = rewriter.create<LLVM::InsertValueOp>(loc, result, el.value(),
el.index());
}
return result;
}
return intrinsicResult;
}
/// The `gpu.mma.sync` converter below expects matrix fragment operands to be
/// given as 2D `vector`s where the rows are 32b or 64b wide. The
/// `nvvm.mma.sync` op expects these arguments to be given as a long list of
/// scalars of certain types. This function helps unpack the `vector` arguments
/// and cast them to the types expected by `nvvm.mma.sync`.
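/// For example (purely illustrative element counts), an operand of type
/// `!llvm.array<2 x vector<2xi32>>` is unpacked into four individual i32
/// scalars, while an `!llvm.array<2 x vector<4xi8>>` operand is bitcast into
/// two i32 scalars.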
static SmallVector<Value> unpackOperandVector(RewriterBase &rewriter,
Location loc, Value operand,
NVVM::MMATypes operandPtxType) {
SmallVector<Value> result;
Type i32Ty = rewriter.getI32Type();
Type f64Ty = rewriter.getF64Type();
Type f32Ty = rewriter.getF32Type();
Type i8Ty = rewriter.getI8Type();
Type i4Ty = rewriter.getIntegerType(4);
Type i8x4Ty = LLVM::getFixedVectorType(i8Ty, 4);
Type i4x8Ty = LLVM::getFixedVectorType(i4Ty, 8);
Type f32x1Ty = LLVM::getFixedVectorType(f32Ty, 1);
auto arrayTy = cast<LLVM::LLVMArrayType>(operand.getType());
for (unsigned i = 0, e = arrayTy.getNumElements(); i < e; ++i) {
Value toUse = rewriter.create<LLVM::ExtractValueOp>(loc, operand, i);
// For 4xi8 and 8xi4 vectors (and tf32-typed 1xf32 vectors), the intrinsic
// expects these to be provided as i32 scalar types.
if (arrayTy.getElementType() == i8x4Ty ||
arrayTy.getElementType() == i4x8Ty ||
(arrayTy.getElementType() == f32x1Ty &&
operandPtxType == NVVM::MMATypes::tf32)) {
result.push_back(
rewriter.create<LLVM::BitcastOp>(loc, rewriter.getI32Type(), toUse));
continue;
}
// For some element types (i32, f32, f64), we need to unpack the inner
// vector/array type as well because the intrinsic expects individual
// scalars to be provided.
VectorType innerArrayTy = dyn_cast<VectorType>(arrayTy.getElementType());
if (innerArrayTy && (innerArrayTy.getElementType() == i32Ty ||
innerArrayTy.getElementType() == f64Ty ||
innerArrayTy.getElementType() == f32Ty)) {
for (unsigned idx = 0, innerSize = innerArrayTy.getNumElements();
idx < innerSize; idx++) {
result.push_back(rewriter.create<LLVM::ExtractElementOp>(
loc, toUse,
rewriter.create<LLVM::ConstantOp>(
loc, rewriter.getI64Type(), rewriter.getI64IntegerAttr(idx))));
}
continue;
}
result.push_back(toUse);
}
return result;
}
namespace {
struct MmaLdMatrixOpToNVVM : public ConvertOpToLLVMPattern<nvgpu::LdMatrixOp> {
using ConvertOpToLLVMPattern<nvgpu::LdMatrixOp>::ConvertOpToLLVMPattern;
LogicalResult
matchAndRewrite(nvgpu::LdMatrixOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
MLIRContext *ctx = getContext();
Location loc = op->getLoc();
// The result type of ldmatrix will always be a struct of 32-bit integer
// registers if more than one 32-bit value is returned. Otherwise, the result
// is a single i32. The result type of the GPU operation is always a vector
// of shape (NumRegisters, VectorRegister), where VectorRegister is the
// 32-bit-wide vector type of each register. We bitcast the result of the
// NVVM::LdMatrixOp to this vector type.
auto vectorResultType = dyn_cast<VectorType>(op->getResultTypes()[0]);
if (!vectorResultType) {
return failure();
}
Type innerVectorType = LLVM::getFixedVectorType(
vectorResultType.getElementType(), vectorResultType.getDimSize(1));
int64_t num32BitRegs = vectorResultType.getDimSize(0);
Type ldMatrixResultType;
if (num32BitRegs > 1) {
ldMatrixResultType = LLVM::LLVMStructType::getLiteral(
ctx, SmallVector<Type>(num32BitRegs, rewriter.getI32Type()));
} else {
ldMatrixResultType = rewriter.getI32Type();
}
auto srcMemrefType = cast<MemRefType>(op.getSrcMemref().getType());
Value srcPtr =
getStridedElementPtr(loc, srcMemrefType, adaptor.getSrcMemref(),
adaptor.getIndices(), rewriter);
Value ldMatrixResult = rewriter.create<NVVM::LdMatrixOp>(
loc, ldMatrixResultType, srcPtr,
/*num=*/op.getNumTiles(),
/*layout=*/op.getTranspose() ? NVVM::MMALayout::col
: NVVM::MMALayout::row);
// The ldmatrix operation returns either a single i32 value or a struct of
// i32 values. Here we unpack those values and cast them back to their
// actual vector type (still of width 32b) and repack them into a result
// struct.
Type finalResultType = typeConverter->convertType(vectorResultType);
Value result = rewriter.create<LLVM::UndefOp>(loc, finalResultType);
for (int64_t i = 0, e = vectorResultType.getDimSize(0); i < e; i++) {
Value i32Register =
num32BitRegs > 1
? rewriter.create<LLVM::ExtractValueOp>(loc, ldMatrixResult, i)
: ldMatrixResult;
Value casted =
rewriter.create<LLVM::BitcastOp>(loc, innerVectorType, i32Register);
result = rewriter.create<LLVM::InsertValueOp>(loc, result, casted, i);
}
rewriter.replaceOp(op, result);
return success();
}
};
/// Convert the given type into the corresponding PTX type (NVVM::MMATypes
/// enum).
static FailureOr<NVVM::MMATypes> getNvvmMmaType(Type t) {
Type elType = getElementTypeOrSelf(t);
if (elType.isInteger(8))
return NVVM::MMATypes::s8;
if (elType.isInteger(4))
return NVVM::MMATypes::s4;
if (elType.isF16())
return NVVM::MMATypes::f16;
if (elType.isF64())
return NVVM::MMATypes::f64;
if (elType.isF32())
return NVVM::MMATypes::tf32;
return failure();
}
struct MmaSyncOptoNVVM : public ConvertOpToLLVMPattern<nvgpu::MmaSyncOp> {
using ConvertOpToLLVMPattern<nvgpu::MmaSyncOp>::ConvertOpToLLVMPattern;
LogicalResult
matchAndRewrite(nvgpu::MmaSyncOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
Location loc = op->getLoc();
// Get the shapes of the MMAMatrix type being used. The shapes will
// determine which intrinsic this op will be lowered to.
VectorType aType = op.getMatrixA().getType();
VectorType bType = op.getMatrixB().getType();
VectorType cType = op.getMatrixC().getType();
std::array<int64_t, 3> gemmShape = op.getMmaShapeAsArray();
// Tensor Cores (mma.sync) on F32 work only with TensorFloat32 (TF32).
bool tf32Enabled = op->hasAttr(op.getTf32EnabledAttrName());
if (aType.getElementType().isF32() && !tf32Enabled)
return failure();
FailureOr<NVVM::MMATypes> ptxTypeA = getNvvmMmaType(aType);
if (failed(ptxTypeA))
return op->emitOpError("failed to deduce operand PTX types");
FailureOr<NVVM::MMATypes> ptxTypeB = getNvvmMmaType(bType);
if (failed(ptxTypeB))
return op->emitOpError("failed to deduce operand PTX types");
std::optional<NVVM::MMATypes> ptxTypeC =
NVVM::MmaOp::inferOperandMMAType(cType.getElementType(),
/*isAccumulator=*/true);
if (!ptxTypeC)
return op->emitError(
"could not infer the PTX type for the accumulator/result");
// TODO: add an attribute to the op to customize this behavior.
std::optional<NVVM::MMAIntOverflow> overflow(std::nullopt);
if (isa<IntegerType>(aType.getElementType()))
overflow = NVVM::MMAIntOverflow::satfinite;
SmallVector<Value> matA =
unpackOperandVector(rewriter, loc, adaptor.getMatrixA(), *ptxTypeA);
SmallVector<Value> matB =
unpackOperandVector(rewriter, loc, adaptor.getMatrixB(), *ptxTypeB);
SmallVector<Value> matC =
unpackOperandVector(rewriter, loc, adaptor.getMatrixC(), *ptxTypeC);
Type desiredRetTy = typeConverter->convertType(op->getResultTypes()[0]);
Type intrinsicResTy = inferIntrinsicResultType(
typeConverter->convertType(op->getResultTypes()[0]));
Value intrinsicResult = rewriter.create<NVVM::MmaOp>(
op.getLoc(), intrinsicResTy, matA, matB, matC,
/*shape=*/gemmShape,
/*b1Op=*/std::nullopt,
/*intOverflow=*/overflow,
/*multiplicandPtxTypes=*/
std::array<NVVM::MMATypes, 2>{*ptxTypeA, *ptxTypeB},
/*multiplicandLayouts=*/
std::array<NVVM::MMALayout, 2>{NVVM::MMALayout::row,
NVVM::MMALayout::col});
rewriter.replaceOp(op, convertIntrinsicResult(op.getLoc(), intrinsicResTy,
desiredRetTy, intrinsicResult,
rewriter));
return success();
}
};
struct ConvertNVGPUToNVVMPass
: public impl::ConvertNVGPUToNVVMPassBase<ConvertNVGPUToNVVMPass> {
using Base::Base;
void runOnOperation() override {
LowerToLLVMOptions options(&getContext());
options.useOpaquePointers = useOpaquePointers;
RewritePatternSet patterns(&getContext());
LLVMTypeConverter converter(&getContext(), options);
// Device-side async tokens cannot be materialized in NVVM. We just convert
// them to a dummy i32 type in order to easily drop them during conversion.
converter.addConversion([&](nvgpu::DeviceAsyncTokenType type) -> Type {
return converter.convertType(IntegerType::get(type.getContext(), 32));
});
populateNVGPUToNVVMConversionPatterns(converter, patterns);
LLVMConversionTarget target(getContext());
target.addLegalDialect<::mlir::LLVM::LLVMDialect>();
target.addLegalDialect<::mlir::NVVM::NVVMDialect>();
if (failed(applyPartialConversion(getOperation(), target,
std::move(patterns))))
signalPassFailure();
}
};
static void emitCpAsyncOpZfillAsm(Location loc, Value dstPtr, Value srcPtr,
Value dstBytes, Value srcElements,
mlir::MemRefType elementType,
ConversionPatternRewriter &rewriter) {
auto asmDialectAttr = LLVM::AsmDialectAttr::get(rewriter.getContext(),
LLVM::AsmDialect::AD_ATT);
const char *cpAsyncCgStr = "cp.async.cg.shared.global [$0], [$1], $2, $3;\n";
const char *cpAsyncCaStr = "cp.async.ca.shared.global [$0], [$1], $2, $3;\n";
const char *asmConstraints = "r,l,n,r";
Value c3I32 = rewriter.create<LLVM::ConstantOp>(
loc, rewriter.getI32Type(), rewriter.getI32IntegerAttr(3));
Value bitwidth = rewriter.create<LLVM::ConstantOp>(
loc, rewriter.getI32Type(),
rewriter.getI32IntegerAttr(elementType.getElementTypeBitWidth()));
Value srcElementsI32 =
rewriter.create<LLVM::TruncOp>(loc, rewriter.getI32Type(), srcElements);
Value srcBytes = rewriter.create<LLVM::LShrOp>(
loc, rewriter.create<LLVM::MulOp>(loc, bitwidth, srcElementsI32), c3I32);
SmallVector<Value> asmVals{dstPtr, srcPtr, dstBytes, srcBytes};
// Pick the right asm string based on dstBytes, which is a compile-time
// constant.
auto dstByteConstOp =
dyn_cast<mlir::LLVM::ConstantOp>(dstBytes.getDefiningOp());
auto dstByteAttr = dyn_cast<mlir::IntegerAttr>(dstByteConstOp.getValue());
int64_t dstByteVal = dstByteAttr.getValue().getSExtValue();
assert((dstByteVal == 4 || dstByteVal == 8 || dstByteVal == 16) &&
"cp.async byte copy size must be 4, 8 or 16");
// Cache global (.cg) for 16 dst bytes, Cache all (.ca) for sizes other than
// 16 dst bytes.
const char *asmStr = (dstByteVal == 16) ? cpAsyncCgStr : cpAsyncCaStr;
rewriter.create<LLVM::InlineAsmOp>(
loc, LLVM::LLVMVoidType::get(rewriter.getContext()),
/*operands=*/asmVals,
/*asm_string=*/asmStr,
/*constraints=*/asmConstraints, /*has_side_effects=*/true,
/*is_align_stack=*/false, /*asm_dialect=*/asmDialectAttr,
/*operand_attrs=*/ArrayAttr());
}
/// Returns the constraints for the sparse MMA inline assembly instruction.
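/// For instance (a hypothetical operand count, purely for illustration),
/// matASize = 2, matBSize = 1, and matCSize = 2 yields "=r,=r,r,r,r,r,r,r".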
static std::string buildMmaSparseAsmConstraintString(unsigned matASize,
unsigned matBSize,
unsigned matCSize) {
std::string str;
llvm::raw_string_ostream ss(str);
for (unsigned i = 0; i < matCSize; i++)
ss << "=r,";
for (unsigned i = 0; i < matASize + matBSize + matCSize; i++)
ss << "r,";
// The final operand is for the sparsity metadata.
// The sparsity selector appears as a direct literal.
ss << "r";
ss.flush();
return str;
}
/// Returns the string for the `mma.sp.sync` instruction that corresponds to
/// the given parameters. Note that this function doesn't do any validation,
/// it's expected that the provided parameters correspond to a valid
/// instruction.
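/// The emitted string has the following general shape (a sketch, with
/// placeholder fields in angle brackets):
///   mma.sp.sync.aligned.m<M>n<N>k<K>.row.col[.<overflow>].<d>.<a>.<b>.<c>
///   {<outputs>},{<matA>},{<matB>},{<matC>},$<metadata>,0x<selector>;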
static std::string buildMmaSparseAsmString(
const std::array<int64_t, 3> &shape, unsigned matASize, unsigned matBSize,
unsigned matCSize, NVVM::MMATypes ptxTypeA, NVVM::MMATypes ptxTypeB,
NVVM::MMATypes ptxTypeC, NVVM::MMATypes ptxTypeD,
std::optional<NVVM::MMAIntOverflow> overflow, unsigned metaDataSelector) {
auto ptxTypeStr = [](NVVM::MMATypes ptxType) {
return NVVM::stringifyMMATypes(ptxType);
};
std::string asmStr;
llvm::raw_string_ostream ss(asmStr);
ss << "mma.sp.sync.aligned.m" << shape[0] << "n" << shape[1] << "k"
<< shape[2] << ".row.col.";
if (overflow)
ss << NVVM::stringifyMMAIntOverflow(*overflow) << ".";
ss << ptxTypeStr(ptxTypeD) << "." << ptxTypeStr(ptxTypeA) << "."
<< ptxTypeStr(ptxTypeB) << "." << ptxTypeStr(ptxTypeC) << " ";
unsigned asmArgIdx = 0;
// The operand string is structured into sections `{matC elements...},
// {matA elements...}, {matB elements...}, {matC elements...}` (the leading
// matC-sized group holds the output registers).
for (const auto arrSize : {matCSize, matASize, matBSize, matCSize}) {
ss << "{";
for (unsigned i = 0; i < arrSize; i++)
ss << "$" << asmArgIdx++ << (i < arrSize - 1 ? "," : "");
ss << "},";
}
ss << "$" << asmArgIdx++ << ",";
assert(metaDataSelector <= 1);
ss << "0x" << metaDataSelector << ";";
ss.flush();
return asmStr;
}
/// Builds an inline assembly operation corresponding to the specified MMA
/// sparse sync operation.
static FailureOr<LLVM::InlineAsmOp> emitMmaSparseSyncOpAsm(
Location loc, NVVM::MMATypes ptxTypeA, NVVM::MMATypes ptxTypeB,
NVVM::MMATypes ptxTypeC, NVVM::MMATypes ptxTypeD,
std::optional<NVVM::MMAIntOverflow> overflow, ArrayRef<Value> unpackedAData,
ArrayRef<Value> unpackedB, ArrayRef<Value> unpackedC, Value indexData,
int64_t metadataSelector, const std::array<int64_t, 3> &shape,
Type intrinsicResultType, ConversionPatternRewriter &rewriter) {
auto asmDialectAttr = LLVM::AsmDialectAttr::get(rewriter.getContext(),
LLVM::AsmDialect::AD_ATT);
const unsigned matASize = unpackedAData.size();
const unsigned matBSize = unpackedB.size();
const unsigned matCSize = unpackedC.size();
std::string asmStr = buildMmaSparseAsmString(
shape, matASize, matBSize, matCSize, ptxTypeA, ptxTypeB, ptxTypeC,
ptxTypeD, overflow, metadataSelector);
std::string constraintStr =
buildMmaSparseAsmConstraintString(matASize, matBSize, matCSize);
SmallVector<Value> asmVals;
asmVals.reserve(matASize + matBSize + matCSize + 1);
for (ArrayRef<Value> args : {unpackedAData, unpackedB, unpackedC})
llvm::append_range(asmVals, args);
asmVals.push_back(indexData);
return rewriter.create<LLVM::InlineAsmOp>(loc,
/*resultTypes=*/intrinsicResultType,
/*operands=*/asmVals,
/*asm_string=*/asmStr,
/*constraints=*/constraintStr,
/*has_side_effects=*/true,
/*is_align_stack=*/false,
/*asm_dialect=*/asmDialectAttr,
/*operand_attrs=*/ArrayAttr());
}
/// Lowers `nvgpu.mma.sp.sync` to inline assembly.
struct NVGPUMmaSparseSyncLowering
: public ConvertOpToLLVMPattern<nvgpu::MmaSparseSyncOp> {
using ConvertOpToLLVMPattern<nvgpu::MmaSparseSyncOp>::ConvertOpToLLVMPattern;
LogicalResult
matchAndRewrite(nvgpu::MmaSparseSyncOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
Location loc = op->getLoc();
// Get the shapes of the MMAMatrix type being used. The shapes will
// determine which intrinsic this op will be lowered to.
VectorType aType = op.getMatrixA().getType();
VectorType bType = op.getMatrixB().getType();
VectorType cType = op.getMatrixC().getType();
FailureOr<NVVM::MMATypes> ptxTypeA = getNvvmMmaType(aType);
if (failed(ptxTypeA))
return op->emitOpError("failed to deduce operand PTX types");
FailureOr<NVVM::MMATypes> ptxTypeB = getNvvmMmaType(bType);
if (failed(ptxTypeB))
return op->emitOpError("failed to deduce operand PTX types");
std::optional<NVVM::MMATypes> ptxTypeC =
NVVM::MmaOp::inferOperandMMAType(cType.getElementType(),
/*isAccumulator=*/true);
if (!ptxTypeC)
return op->emitError(
"could not infer the PTX type for the accumulator/result");
// As with `mma.sync`, F32 works only with TensorFloat32 (TF32).
bool tf32Enabled = op->hasAttr(op.getTf32EnabledAttrName());
if (aType.getElementType().isF32() && !tf32Enabled)
return failure();
// TODO: add an attribute to the op to customize this behavior.
std::optional<NVVM::MMAIntOverflow> overflow(std::nullopt);
if (isa<IntegerType>(aType.getElementType()))
overflow = NVVM::MMAIntOverflow::satfinite;
SmallVector<Value> matA =
unpackOperandVector(rewriter, loc, adaptor.getMatrixA(), *ptxTypeA);
SmallVector<Value> matB =
unpackOperandVector(rewriter, loc, adaptor.getMatrixB(), *ptxTypeB);
SmallVector<Value> matC =
unpackOperandVector(rewriter, loc, adaptor.getMatrixC(), *ptxTypeC);
Type desiredRetTy = typeConverter->convertType(op->getResultTypes()[0]);
Type intrinsicResTy = inferIntrinsicResultType(
typeConverter->convertType(op->getResultTypes()[0]));
// Bitcast the sparse metadata from vector<2xi16> to an i32.
Value sparseMetadata = adaptor.getSparseMetadata();
if (sparseMetadata.getType() !=
LLVM::getFixedVectorType(rewriter.getI16Type(), 2))
return op->emitOpError() << "Expected metadata type to be LLVM "
"VectorType of 2 i16 elements";
sparseMetadata = rewriter.create<LLVM::BitcastOp>(
loc, rewriter.getI32Type(), sparseMetadata);
FailureOr<LLVM::InlineAsmOp> intrinsicResult = emitMmaSparseSyncOpAsm(
loc, *ptxTypeA, *ptxTypeB, *ptxTypeC, *ptxTypeC, overflow, matA, matB,
matC, sparseMetadata, op.getSparsitySelector(), op.getMmaShapeAsArray(),
intrinsicResTy, rewriter);
if (failed(intrinsicResult))
return failure();
assert((*intrinsicResult).getNumResults() == 1 &&
"expected inline asm op returns a single LLVM struct type");
rewriter.replaceOp(
op, convertIntrinsicResult(op.getLoc(), intrinsicResTy, desiredRetTy,
(*intrinsicResult)->getResult(0), rewriter));
return success();
}
};
struct NVGPUAsyncCopyLowering
: public ConvertOpToLLVMPattern<nvgpu::DeviceAsyncCopyOp> {
using ConvertOpToLLVMPattern<
nvgpu::DeviceAsyncCopyOp>::ConvertOpToLLVMPattern;
LogicalResult
matchAndRewrite(nvgpu::DeviceAsyncCopyOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
Location loc = op->getLoc();
auto dstMemrefType = cast<MemRefType>(op.getDst().getType());
Value dstPtr = getStridedElementPtr(loc, dstMemrefType, adaptor.getDst(),
adaptor.getDstIndices(), rewriter);
auto i8Ty = IntegerType::get(op.getContext(), 8);
FailureOr<unsigned> dstAddressSpace =
getTypeConverter()->getMemRefAddressSpace(dstMemrefType);
if (failed(dstAddressSpace))
return rewriter.notifyMatchFailure(
loc, "destination memref address space not convertible to integer");
auto dstPointerType =
getTypeConverter()->getPointerType(i8Ty, *dstAddressSpace);
if (!getTypeConverter()->useOpaquePointers())
dstPtr = rewriter.create<LLVM::BitcastOp>(loc, dstPointerType, dstPtr);
auto srcMemrefType = cast<MemRefType>(op.getSrc().getType());
FailureOr<unsigned> srcAddressSpace =
getTypeConverter()->getMemRefAddressSpace(srcMemrefType);
if (failed(srcAddressSpace))
return rewriter.notifyMatchFailure(
loc, "source memref address space not convertible to integer");
Value scrPtr = getStridedElementPtr(loc, srcMemrefType, adaptor.getSrc(),
adaptor.getSrcIndices(), rewriter);
auto srcPointerType =
getTypeConverter()->getPointerType(i8Ty, *srcAddressSpace);
if (!getTypeConverter()->useOpaquePointers())
scrPtr = rewriter.create<LLVM::BitcastOp>(loc, srcPointerType, scrPtr);
// The intrinsic takes a global pointer, so we need an address space cast.
auto srcPointerGlobalType = getTypeConverter()->getPointerType(
i8Ty, NVVM::NVVMMemorySpace::kGlobalMemorySpace);
scrPtr = rewriter.create<LLVM::AddrSpaceCastOp>(loc, srcPointerGlobalType,
scrPtr);
int64_t dstElements = adaptor.getDstElements().getZExtValue();
int64_t sizeInBytes =
(dstMemrefType.getElementTypeBitWidth() * dstElements) / 8;
// Bypassing L1 is only supported for a byte size of 16; we drop the hint
// otherwise.
UnitAttr bypassL1 =
sizeInBytes == 16 ? adaptor.getBypassL1Attr() : UnitAttr();
// When the optional SrcElements argument is present, the source (global
// memory) of the CpAsyncOp is read for only SrcElements elements. The rest
// of the DstElements in the destination (shared memory) are filled with
// zeros.
if (op.getSrcElements())
emitCpAsyncOpZfillAsm(loc, dstPtr, scrPtr,
rewriter.create<LLVM::ConstantOp>(
loc, rewriter.getI32Type(),
rewriter.getI32IntegerAttr(sizeInBytes)),
adaptor.getSrcElements(), srcMemrefType, rewriter);
// When the optional SrcElements argument is *not* present, a regular
// CpAsyncOp is generated. CpAsyncOp reads bytes from the source (global
// memory) to fill DstElements number of elements in the destination (shared
// memory).
else
rewriter.create<NVVM::CpAsyncOp>(loc, dstPtr, scrPtr,
rewriter.getI32IntegerAttr(sizeInBytes),
bypassL1);
// Drop the result token.
Value zero = rewriter.create<LLVM::ConstantOp>(
op->getLoc(), IntegerType::get(op.getContext(), 32),
rewriter.getI32IntegerAttr(0));
rewriter.replaceOp(op, zero);
return success();
}
};
struct NVGPUAsyncCreateGroupLowering
: public ConvertOpToLLVMPattern<nvgpu::DeviceAsyncCreateGroupOp> {
using ConvertOpToLLVMPattern<
nvgpu::DeviceAsyncCreateGroupOp>::ConvertOpToLLVMPattern;
LogicalResult
matchAndRewrite(nvgpu::DeviceAsyncCreateGroupOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
rewriter.create<NVVM::CpAsyncCommitGroupOp>(op.getLoc());
// Drop the result token.
Value zero = rewriter.create<LLVM::ConstantOp>(
op->getLoc(), IntegerType::get(op.getContext(), 32),
rewriter.getI32IntegerAttr(0));
rewriter.replaceOp(op, zero);
return success();
}
};
struct NVGPUAsyncWaitLowering
: public ConvertOpToLLVMPattern<nvgpu::DeviceAsyncWaitOp> {
using ConvertOpToLLVMPattern<
nvgpu::DeviceAsyncWaitOp>::ConvertOpToLLVMPattern;
LogicalResult
matchAndRewrite(nvgpu::DeviceAsyncWaitOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
// If numGroups is not present, pick 0 as a conservative correct value.
int32_t numGroups = adaptor.getNumGroups().value_or(0);
rewriter.create<NVVM::CpAsyncWaitGroupOp>(op.getLoc(), numGroups);
rewriter.eraseOp(op);
return success();
}
};
} // namespace
void mlir::populateNVGPUToNVVMConversionPatterns(LLVMTypeConverter &converter,
RewritePatternSet &patterns) {
patterns.add<MmaSyncOptoNVVM, MmaLdMatrixOpToNVVM, NVGPUAsyncCopyLowering,
NVGPUAsyncCreateGroupLowering, NVGPUAsyncWaitLowering,
NVGPUMmaSparseSyncLowering>(converter);
}