Currently, the `phaseParity` argument of `nvgpu.mbarrier.try_wait.parity` has type `index`. This is error-prone: the underlying PTX instruction only accepts a phase parity of 0 (even phase) or 1 (odd phase), so passing any other value is invalid. This PR changes the `phaseParity` argument to `i1` so that such misuse cannot be expressed. Here is the relevant excerpt from the PTX documentation:

```
The .parity variant of the instructions test for the completion of the phase indicated by the operand phaseParity, which is the integer parity of either the current phase or the immediately preceding phase of the mbarrier object. An even phase has integer parity 0 and an odd phase has integer parity of 1. So the valid values of phaseParity operand are 0 and 1.
```

See for more information: https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#parallel-synchronization-and-communication-instructions-mbarrier-test-wait-mbarrier-try-wait
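For illustration, here is a minimal sketch of what the change looks like at the NVGPU level. The IR is hand-written for this description (value names, the `%barriers`/`%ticks` operands, and the surrounding module are hypothetical, and the assembly format is abbreviated), not taken from this patch's tests:

```mlir
// Before this PR: the phase parity was an `index`, so nothing in the type
// system prevented passing values such as 2 or 3, which PTX rejects.
//   nvgpu.mbarrier.try_wait.parity %barriers[%c0], %phase_index, %ticks
//       : !nvgpu.mbarrier.group<memorySpace = #gpu.address_space<workgroup>>

// After this PR: the phase parity is an `i1`, so only 0 (even phase) and
// 1 (odd phase) are representable. The NVGPU-to-NVVM lowering zero-extends
// it to i32 before emitting nvvm.mbarrier.try_wait.parity.
%phase = arith.constant 0 : i1
nvgpu.mbarrier.try_wait.parity %barriers[%c0], %phase, %ticks
    : !nvgpu.mbarrier.group<memorySpace = #gpu.address_space<workgroup>>
```
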
//===- NVGPUToNVVM.cpp - NVGPU to NVVM dialect conversion -----------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include "mlir/Conversion/NVGPUToNVVM/NVGPUToNVVM.h"

#include "mlir/Conversion/GPUCommon/GPUCommonPass.h"
#include "mlir/Conversion/LLVMCommon/ConversionTarget.h"
#include "mlir/Conversion/LLVMCommon/Pattern.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/GPU/IR/GPUDialect.h"
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
#include "mlir/Dialect/LLVMIR/LLVMTypes.h"
#include "mlir/Dialect/LLVMIR/NVVMDialect.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/NVGPU/IR/NVGPUDialect.h"
#include "mlir/Dialect/SCF/Transforms/Patterns.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/ImplicitLocOpBuilder.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/IR/TypeUtilities.h"
#include "mlir/IR/Value.h"
#include "mlir/Pass/Pass.h"
#include "llvm/Support/Debug.h"
#include "llvm/Support/ErrorHandling.h"
#include "llvm/Support/raw_ostream.h"
#include <optional>

#define DEBUG_TYPE "nvgpu-to-nvvm"
#define DBGS() (llvm::dbgs() << '[' << DEBUG_TYPE << "] ")
#define DBGSE() (llvm::dbgs())

namespace mlir {
#define GEN_PASS_DEF_CONVERTNVGPUTONVVMPASS
#include "mlir/Conversion/Passes.h.inc"
} // namespace mlir

using namespace mlir;

/// Number of bits that needs to be excluded when building matrix descriptor
/// for wgmma operations.
constexpr int exclude4LSB = 4;

/// GPU has 32 bit registers, this function truncates values when larger width
/// is not needed.
static Value truncToI32(ImplicitLocOpBuilder &b, Value value) {
  Type type = value.getType();
  assert(llvm::isa<IntegerType>(type) && "expected an integer Value");
  if (type.getIntOrFloatBitWidth() <= 32)
    return value;
  return b.create<LLVM::TruncOp>(b.getI32Type(), value);
}

/// Returns the type for the intrinsic given the vectorResultType of the
/// `gpu.mma.sync` operation.
static Type inferIntrinsicResultType(Type vectorResultType) {
  MLIRContext *ctx = vectorResultType.getContext();
  auto a = cast<LLVM::LLVMArrayType>(vectorResultType);
  auto f16x2Ty = LLVM::getFixedVectorType(Float16Type::get(ctx), 2);
  auto i32Ty = IntegerType::get(ctx, 32);
  auto i32x2Ty = LLVM::getFixedVectorType(i32Ty, 2);
  Type f64Ty = Float64Type::get(ctx);
  Type f64x2Ty = LLVM::getFixedVectorType(f64Ty, 2);
  Type f32Ty = Float32Type::get(ctx);
  Type f32x2Ty = LLVM::getFixedVectorType(f32Ty, 2);
  if (a.getElementType() == f16x2Ty) {
    return LLVM::LLVMStructType::getLiteral(
        ctx, SmallVector<Type>(a.getNumElements(), f16x2Ty));
  }
  if (a.getElementType() == i32x2Ty) {
    return LLVM::LLVMStructType::getLiteral(
        ctx,
        SmallVector<Type>(static_cast<size_t>(a.getNumElements()) * 2, i32Ty));
  }
  if (a.getElementType() == f64x2Ty) {
    return LLVM::LLVMStructType::getLiteral(ctx, {f64Ty, f64Ty});
  }
  if (a.getElementType() == f32x2Ty) {
    return LLVM::LLVMStructType::getLiteral(
        ctx,
        SmallVector<Type>(static_cast<size_t>(a.getNumElements()) * 2, f32Ty));
  }
  if (a.getElementType() == LLVM::getFixedVectorType(f32Ty, 1)) {
    return LLVM::LLVMStructType::getLiteral(
        ctx, SmallVector<Type>(static_cast<size_t>(a.getNumElements()), f32Ty));
  }
  return vectorResultType;
}

/// Convert the SSA result of the NVVM intrinsic `nvvm.mma.sync` (which is
/// always an LLVM struct) into a fragment that is compatible with the vector
/// type of this operation. This involves extracting elements from the struct
/// and inserting them into an LLVM array. These extra data-movement
/// operations should be canonicalized away by the LLVM backend.
static Value convertIntrinsicResult(Location loc, Type intrinsicResultType,
                                    Type resultType, Value intrinsicResult,
                                    RewriterBase &rewriter) {
  MLIRContext *ctx = rewriter.getContext();
  auto structType = dyn_cast<LLVM::LLVMStructType>(intrinsicResultType);
  auto arrayType = dyn_cast<LLVM::LLVMArrayType>(resultType);
  Type i32Ty = rewriter.getI32Type();
  Type f32Ty = rewriter.getF32Type();
  Type f64Ty = rewriter.getF64Type();
  Type f16x2Ty = LLVM::getFixedVectorType(rewriter.getF16Type(), 2);
  Type i32x2Ty = LLVM::getFixedVectorType(i32Ty, 2);
  Type f64x2Ty = LLVM::getFixedVectorType(f64Ty, 2);
  Type f32x2Ty = LLVM::getFixedVectorType(f32Ty, 2);
  Type f32x1Ty = LLVM::getFixedVectorType(f32Ty, 1);

  auto makeConst = [&](int32_t index) -> Value {
    return rewriter.create<LLVM::ConstantOp>(loc, IntegerType::get(ctx, 32),
                                             rewriter.getI32IntegerAttr(index));
  };

  if (arrayType) {
    SmallVector<Value, 4> elements;

    // The intrinsic returns 32-bit wide elements in a form which can be
    // directly bitcasted and inserted into the result vector.
    if (arrayType.getElementType() == f16x2Ty ||
        arrayType.getElementType() == f32x1Ty) {
      for (unsigned i = 0; i < structType.getBody().size(); i++) {
        Value el =
            rewriter.create<LLVM::ExtractValueOp>(loc, intrinsicResult, i);
        el = rewriter.createOrFold<LLVM::BitcastOp>(
            loc, arrayType.getElementType(), el);
        elements.push_back(el);
      }
    }

    // The intrinsic returns i32, f64, and f32 values as individual scalars,
    // even when the result is notionally a 64-bit wide element (e.g. f32x2). We
    // need to extract them from the struct and pack them into the 64-bit wide
    // rows of the vector result.
    if (arrayType.getElementType() == i32x2Ty ||
        arrayType.getElementType() == f64x2Ty ||
        arrayType.getElementType() == f32x2Ty) {

      for (unsigned i = 0, e = structType.getBody().size() / 2; i < e; i++) {
        Value vec =
            rewriter.create<LLVM::UndefOp>(loc, arrayType.getElementType());
        Value x1 =
            rewriter.create<LLVM::ExtractValueOp>(loc, intrinsicResult, i * 2);
        Value x2 = rewriter.create<LLVM::ExtractValueOp>(loc, intrinsicResult,
                                                         i * 2 + 1);
        vec = rewriter.create<LLVM::InsertElementOp>(loc, vec.getType(), vec,
                                                     x1, makeConst(0));
        vec = rewriter.create<LLVM::InsertElementOp>(loc, vec.getType(), vec,
                                                     x2, makeConst(1));
        elements.push_back(vec);
      }
    }

    // Create the final vectorized result.
    Value result = rewriter.create<LLVM::UndefOp>(loc, arrayType);
    for (const auto &el : llvm::enumerate(elements)) {
      result = rewriter.create<LLVM::InsertValueOp>(loc, result, el.value(),
                                                    el.index());
    }
    return result;
  }

  return intrinsicResult;
}

/// The `gpu.mma.sync` converter below expects matrix fragment operands to be
/// given as 2D `vectors` where the rows are 32b or 64b wide. The
/// `nvvm.mma.sync` op expects these arguments to be given in a long list of
/// scalars of certain types. This function helps unpack the `vector` arguments
/// and cast them to the types expected by `nvvm.mma.sync`.
static SmallVector<Value> unpackOperandVector(ImplicitLocOpBuilder &b,
                                              Value operand,
                                              NVVM::MMATypes operandPtxType) {
  SmallVector<Value> result;
  Type i32Ty = b.getI32Type();
  Type f64Ty = b.getF64Type();
  Type f32Ty = b.getF32Type();
  Type i64Ty = b.getI64Type();
  Type i8x4Ty = LLVM::getFixedVectorType(b.getI8Type(), 4);
  Type i4x8Ty = LLVM::getFixedVectorType(b.getIntegerType(4), 8);
  Type f32x1Ty = LLVM::getFixedVectorType(f32Ty, 1);
  auto arrayTy = cast<LLVM::LLVMArrayType>(operand.getType());

  for (unsigned i = 0, e = arrayTy.getNumElements(); i < e; ++i) {
    Value toUse = b.create<LLVM::ExtractValueOp>(operand, i);

    // For 4xi8 vectors, the intrinsic expects these to be provided as i32
    // scalar types.
    if (arrayTy.getElementType() == i8x4Ty ||
        arrayTy.getElementType() == i4x8Ty ||
        (arrayTy.getElementType() == f32x1Ty &&
         operandPtxType == NVVM::MMATypes::tf32)) {
      result.push_back(b.create<LLVM::BitcastOp>(i32Ty, toUse));
      continue;
    }

    // For some element types (i32, f32, f64), we need to unpack the inner
    // vector/array type as well because the intrinsic expects individual
    // scalars to be provided.
    VectorType innerArrayTy = dyn_cast<VectorType>(arrayTy.getElementType());
    if (innerArrayTy && (innerArrayTy.getElementType() == i32Ty ||
                         innerArrayTy.getElementType() == f64Ty ||
                         innerArrayTy.getElementType() == f32Ty)) {
      for (unsigned idx = 0, innerSize = innerArrayTy.getNumElements();
           idx < innerSize; idx++) {
        result.push_back(b.create<LLVM::ExtractElementOp>(
            toUse,
            b.create<LLVM::ConstantOp>(i64Ty, b.getI64IntegerAttr(idx))));
      }
      continue;
    }
    result.push_back(toUse);
  }
  return result;
}

/// Returns whether mbarrier object has shared memory address space.
static bool isMbarrierShared(nvgpu::MBarrierGroupType barrierType) {
  return (mlir::nvgpu::NVGPUDialect::isSharedMemoryAddressSpace(
      barrierType.getMemorySpace()));
}

/// Returns the memory space attribute of the mbarrier object.
Attribute nvgpu::getMbarrierMemorySpace(MLIRContext *context,
                                        nvgpu::MBarrierGroupType barrierType) {
  Attribute memorySpace = {};
  if (isMbarrierShared(barrierType)) {
    memorySpace =
        IntegerAttr::get(IntegerType::get(context, 64),
                         nvgpu::NVGPUDialect::kSharedMemoryAddressSpace);
  }
  return memorySpace;
}

/// Returns memref type of the mbarrier object. The type is defined in the
/// MBarrierGroupType.
MemRefType nvgpu::getMBarrierMemrefType(MLIRContext *context,
                                        nvgpu::MBarrierGroupType barrierType) {
  Attribute memorySpace = nvgpu::getMbarrierMemorySpace(context, barrierType);
  MemRefLayoutAttrInterface layout;
  return MemRefType::get({barrierType.getNumBarriers()},
                         IntegerType::get(context, 64), layout, memorySpace);
}

namespace {
|
|
|
|
struct MmaLdMatrixOpToNVVM : public ConvertOpToLLVMPattern<nvgpu::LdMatrixOp> {
|
|
using ConvertOpToLLVMPattern<nvgpu::LdMatrixOp>::ConvertOpToLLVMPattern;
|
|
|
|
LogicalResult
|
|
matchAndRewrite(nvgpu::LdMatrixOp op, OpAdaptor adaptor,
|
|
ConversionPatternRewriter &rewriter) const override {
|
|
MLIRContext *ctx = getContext();
|
|
ImplicitLocOpBuilder b(op.getLoc(), rewriter);
|
|
|
|
// The result type of ldmatrix will always be a struct of 32bit integer
|
|
// registers if more than one 32bit value is returned. Otherwise, the result
|
|
// is a single i32. The result type of the GPU operation is always a vector
|
|
// of shape (NumRegisters, VectorRegister) where VectorRegister is the
|
|
// vector type of the result and always 32 bits long. We bitcast the result
|
|
// of the NVVM::LdMatrix to this vector type.
|
|
auto vectorResultType = dyn_cast<VectorType>(op->getResultTypes()[0]);
|
|
if (!vectorResultType) {
|
|
return failure();
|
|
}
|
|
Type innerVectorType = LLVM::getFixedVectorType(
|
|
vectorResultType.getElementType(), vectorResultType.getDimSize(1));
|
|
|
|
int64_t num32BitRegs = vectorResultType.getDimSize(0);
|
|
|
|
Type ldMatrixResultType;
|
|
if (num32BitRegs > 1) {
|
|
ldMatrixResultType = LLVM::LLVMStructType::getLiteral(
|
|
ctx, SmallVector<Type>(num32BitRegs, rewriter.getI32Type()));
|
|
} else {
|
|
ldMatrixResultType = rewriter.getI32Type();
|
|
}
|
|
|
|
auto srcMemrefType = cast<MemRefType>(op.getSrcMemref().getType());
|
|
Value srcPtr =
|
|
getStridedElementPtr(b.getLoc(), srcMemrefType, adaptor.getSrcMemref(),
|
|
adaptor.getIndices(), rewriter);
|
|
Value ldMatrixResult = b.create<NVVM::LdMatrixOp>(
|
|
ldMatrixResultType, srcPtr,
|
|
/*num=*/op.getNumTiles(),
|
|
/*layout=*/op.getTranspose() ? NVVM::MMALayout::col
|
|
: NVVM::MMALayout::row);
|
|
|
|
// The ldmatrix operation returns either a single i32 value or a struct of
|
|
// i32 values. Here we unpack those values and cast them back to their
|
|
// actual vector type (still of width 32b) and repack them into a result
|
|
// struct.
|
|
Type finalResultType = typeConverter->convertType(vectorResultType);
|
|
Value result = b.create<LLVM::UndefOp>(finalResultType);
|
|
for (int64_t i = 0, e = vectorResultType.getDimSize(0); i < e; i++) {
|
|
Value i32Register =
|
|
num32BitRegs > 1 ? b.create<LLVM::ExtractValueOp>(ldMatrixResult, i)
|
|
: ldMatrixResult;
|
|
Value casted = b.create<LLVM::BitcastOp>(innerVectorType, i32Register);
|
|
result = b.create<LLVM::InsertValueOp>(result, casted, i);
|
|
}
|
|
|
|
rewriter.replaceOp(op, result);
|
|
return success();
|
|
}
|
|
};
|
|
|
|
/// Convert the given type into the corresponding PTX type (NVVM::MMATypes
|
|
/// enum).
|
|
static FailureOr<NVVM::MMATypes> getNvvmMmaType(Type t) {
|
|
Type elType = getElementTypeOrSelf(t);
|
|
if (elType.isInteger(8))
|
|
return NVVM::MMATypes::s8;
|
|
if (elType.isInteger(4))
|
|
return NVVM::MMATypes::s4;
|
|
if (elType.isF16())
|
|
return NVVM::MMATypes::f16;
|
|
if (elType.isF64())
|
|
return NVVM::MMATypes::f64;
|
|
if (elType.isF32())
|
|
return NVVM::MMATypes::tf32;
|
|
return failure();
|
|
}
|
|
|
|
struct MmaSyncOptoNVVM : public ConvertOpToLLVMPattern<nvgpu::MmaSyncOp> {
|
|
using ConvertOpToLLVMPattern<nvgpu::MmaSyncOp>::ConvertOpToLLVMPattern;
|
|
|
|
LogicalResult
|
|
matchAndRewrite(nvgpu::MmaSyncOp op, OpAdaptor adaptor,
|
|
ConversionPatternRewriter &rewriter) const override {
|
|
ImplicitLocOpBuilder b(op.getLoc(), rewriter);
|
|
// Get the shapes of the MMAMatrix type being used. The shapes will
|
|
// choose which intrinsic this op will be lowered to.
|
|
VectorType aType = op.getMatrixA().getType();
|
|
VectorType bType = op.getMatrixA().getType();
|
|
VectorType cType = op.getMatrixC().getType();
|
|
|
|
std::array<int64_t, 3> gemmShape = op.getMmaShapeAsArray();
|
|
|
|
// Tensor Cores (mma.sync) on F32 works only with TensorFloat32 (TF32).
|
|
bool tf32Enabled = op->hasAttr(op.getTf32EnabledAttrName());
|
|
if (aType.getElementType().isF32() && !tf32Enabled)
|
|
return failure();
|
|
|
|
FailureOr<NVVM::MMATypes> ptxTypeA = getNvvmMmaType(aType);
|
|
if (failed(ptxTypeA))
|
|
return op->emitOpError("failed to deduce operand PTX types");
|
|
FailureOr<NVVM::MMATypes> ptxTypeB = getNvvmMmaType(bType);
|
|
if (failed(ptxTypeB))
|
|
return op->emitOpError("failed to deduce operand PTX types");
|
|
std::optional<NVVM::MMATypes> ptxTypeC =
|
|
NVVM::MmaOp::inferOperandMMAType(cType.getElementType(),
|
|
/*isAccumulator=*/true);
|
|
if (!ptxTypeC)
|
|
return op->emitError(
|
|
"could not infer the PTX type for the accumulator/result");
|
|
|
|
// TODO: add an attribute to the op to customize this behavior.
|
|
std::optional<NVVM::MMAIntOverflow> overflow(std::nullopt);
|
|
if (isa<IntegerType>(aType.getElementType()))
|
|
overflow = NVVM::MMAIntOverflow::satfinite;
|
|
|
|
SmallVector<Value> matA =
|
|
unpackOperandVector(b, adaptor.getMatrixA(), *ptxTypeA);
|
|
SmallVector<Value> matB =
|
|
unpackOperandVector(b, adaptor.getMatrixB(), *ptxTypeB);
|
|
SmallVector<Value> matC =
|
|
unpackOperandVector(b, adaptor.getMatrixC(), *ptxTypeC);
|
|
|
|
Type desiredRetTy = typeConverter->convertType(op->getResultTypes()[0]);
|
|
Type intrinsicResTy = inferIntrinsicResultType(
|
|
typeConverter->convertType(op->getResultTypes()[0]));
|
|
Value intrinsicResult = b.create<NVVM::MmaOp>(
|
|
intrinsicResTy, matA, matB, matC,
|
|
/*shape=*/gemmShape,
|
|
/*b1Op=*/std::nullopt,
|
|
/*intOverflow=*/overflow,
|
|
/*multiplicandPtxTypes=*/
|
|
std::array<NVVM::MMATypes, 2>{*ptxTypeA, *ptxTypeB},
|
|
/*multiplicandLayouts=*/
|
|
std::array<NVVM::MMALayout, 2>{NVVM::MMALayout::row,
|
|
NVVM::MMALayout::col});
|
|
rewriter.replaceOp(op, convertIntrinsicResult(op.getLoc(), intrinsicResTy,
|
|
desiredRetTy, intrinsicResult,
|
|
rewriter));
|
|
return success();
|
|
}
|
|
};
|
|
|
|
struct ConvertNVGPUToNVVMPass
|
|
: public impl::ConvertNVGPUToNVVMPassBase<ConvertNVGPUToNVVMPass> {
|
|
using Base::Base;
|
|
|
|
void getDependentDialects(DialectRegistry ®istry) const override {
|
|
registry.insert<memref::MemRefDialect, LLVM::LLVMDialect, NVVM::NVVMDialect,
|
|
arith::ArithDialect>();
|
|
}
|
|
|
|
void runOnOperation() override {
|
|
LowerToLLVMOptions options(&getContext());
|
|
RewritePatternSet patterns(&getContext());
|
|
LLVMTypeConverter converter(&getContext(), options);
|
|
IRRewriter rewriter(&getContext());
|
|
populateGpuMemorySpaceAttributeConversions(
|
|
converter, [](gpu::AddressSpace space) -> unsigned {
|
|
switch (space) {
|
|
case gpu::AddressSpace::Global:
|
|
return static_cast<unsigned>(
|
|
NVVM::NVVMMemorySpace::kGlobalMemorySpace);
|
|
case gpu::AddressSpace::Workgroup:
|
|
return static_cast<unsigned>(
|
|
NVVM::NVVMMemorySpace::kSharedMemorySpace);
|
|
case gpu::AddressSpace::Private:
|
|
return 0;
|
|
}
|
|
llvm_unreachable("unknown address space enum value");
|
|
return 0;
|
|
});
|
|
/// device-side async tokens cannot be materialized in nvvm. We just
|
|
/// convert them to a dummy i32 type in order to easily drop them during
|
|
/// conversion.
|
|
converter.addConversion([&](nvgpu::DeviceAsyncTokenType type) -> Type {
|
|
return converter.convertType(IntegerType::get(type.getContext(), 32));
|
|
});
|
|
converter.addConversion([&](nvgpu::WarpgroupAccumulatorType type) -> Type {
|
|
Type elemType = type.getFragmented().getElementType();
|
|
int64_t sizeM = type.getFragmented().getDimSize(0);
|
|
int64_t sizeN = type.getFragmented().getDimSize(1);
|
|
|
|
unsigned numMembers;
|
|
if (elemType.isF32() || elemType.isInteger(32))
|
|
numMembers = sizeN / 2;
|
|
else if (elemType.isF16())
|
|
numMembers = sizeN / 4;
|
|
else
|
|
llvm_unreachable("unsupported type for warpgroup accumulator");
|
|
|
|
SmallVector<Type> innerStructBody;
|
|
for (unsigned i = 0; i < numMembers; i++)
|
|
innerStructBody.push_back(elemType);
|
|
auto innerStructType =
|
|
LLVM::LLVMStructType::getLiteral(type.getContext(), innerStructBody);
|
|
|
|
SmallVector<Type> structBody;
|
|
for (int i = 0; i < sizeM; i += kWgmmaSizeM)
|
|
structBody.push_back(innerStructType);
|
|
|
|
auto convertedType =
|
|
LLVM::LLVMStructType::getLiteral(type.getContext(), structBody);
|
|
return converter.convertType(convertedType);
|
|
});
|
|
converter.addConversion([&](nvgpu::MBarrierTokenType type) -> Type {
|
|
return converter.convertType(IntegerType::get(type.getContext(), 64));
|
|
});
|
|
converter.addConversion(
|
|
[&](nvgpu::WarpgroupMatrixDescriptorType type) -> Type {
|
|
return converter.convertType(IntegerType::get(type.getContext(), 64));
|
|
});
|
|
converter.addConversion([&](nvgpu::MBarrierGroupType type) -> Type {
|
|
return converter.convertType(
|
|
nvgpu::getMBarrierMemrefType(rewriter.getContext(), type));
|
|
});
|
|
converter.addConversion([&](nvgpu::TensorMapDescriptorType type) -> Type {
|
|
return LLVM::LLVMPointerType::get(type.getContext());
|
|
});
|
|
populateNVGPUToNVVMConversionPatterns(converter, patterns);
|
|
LLVMConversionTarget target(getContext());
|
|
target.addLegalDialect<::mlir::LLVM::LLVMDialect>();
|
|
target.addLegalDialect<::mlir::arith::ArithDialect>();
|
|
target.addLegalDialect<::mlir::memref::MemRefDialect>();
|
|
target.addLegalDialect<::mlir::NVVM::NVVMDialect>();
|
|
mlir::scf::populateSCFStructuralTypeConversionsAndLegality(
|
|
converter, patterns, target);
|
|
if (failed(applyPartialConversion(getOperation(), target,
|
|
std::move(patterns))))
|
|
signalPassFailure();
|
|
}
|
|
};
|
|
|
|
/// Returns the constraints for the sparse MMA inline assembly instruction.
|
|
static std::string buildMmaSparseAsmConstraintString(unsigned matASize,
|
|
unsigned matBSize,
|
|
unsigned matCSize) {
|
|
std::string str;
|
|
llvm::raw_string_ostream ss(str);
|
|
for (unsigned i = 0; i < matCSize; i++)
|
|
ss << "=r,";
|
|
for (unsigned i = 0; i < matASize + matBSize + matCSize; i++)
|
|
ss << "r,";
|
|
// The final operand is for the sparsity metadata.
|
|
// The sparsity selector appears as direct literal.
|
|
ss << "r";
|
|
ss.flush();
|
|
return str;
|
|
}
|
|
|
|
/// Returns the string for the `mma.sp.sync` instruction that corresponds to
|
|
/// the given parameters. Note that this function doesn't do any validation,
|
|
/// it's expected that the provided parameters correspond to a valid
|
|
/// instruction.
|
|
static std::string buildMmaSparseAsmString(
|
|
const std::array<int64_t, 3> &shape, unsigned matASize, unsigned matBSize,
|
|
unsigned matCSize, NVVM::MMATypes ptxTypeA, NVVM::MMATypes ptxTypeB,
|
|
NVVM::MMATypes ptxTypeC, NVVM::MMATypes ptxTypeD,
|
|
std::optional<NVVM::MMAIntOverflow> overflow, unsigned metaDataSelector) {
|
|
auto ptxTypeStr = [](NVVM::MMATypes ptxType) {
|
|
return NVVM::stringifyMMATypes(ptxType);
|
|
};
|
|
|
|
std::string asmStr;
|
|
llvm::raw_string_ostream ss(asmStr);
|
|
ss << "mma.sp.sync.aligned.m" << shape[0] << "n" << shape[1] << "k"
|
|
<< shape[2] << ".row.col.";
|
|
|
|
if (overflow)
|
|
ss << NVVM::stringifyMMAIntOverflow(*overflow) << ".";
|
|
|
|
ss << ptxTypeStr(ptxTypeD) << "." << ptxTypeStr(ptxTypeA) << "."
|
|
<< ptxTypeStr(ptxTypeB) << "." << ptxTypeStr(ptxTypeC) << " ";
|
|
unsigned asmArgIdx = 0;
|
|
|
|
// The operand string is structured into sections `{matC elements...},
|
|
// {matA elements...}, {matB elements...}, {matC elements}`.
|
|
for (const auto arrSize : {matCSize, matASize, matBSize, matCSize}) {
|
|
ss << "{";
|
|
for (unsigned i = 0; i < arrSize; i++)
|
|
ss << "$" << asmArgIdx++ << (i < arrSize - 1 ? "," : "");
|
|
ss << "},";
|
|
}
|
|
ss << "$" << asmArgIdx++ << ",";
|
|
assert(metaDataSelector <= 1);
|
|
ss << "0x" << metaDataSelector << ";";
|
|
ss.flush();
|
|
return asmStr;
|
|
}
|
|
|
|
/// Builds an inline assembly operation corresponding to the specified MMA
|
|
/// sparse sync operation.
|
|
static FailureOr<LLVM::InlineAsmOp> emitMmaSparseSyncOpAsm(
|
|
ImplicitLocOpBuilder &b, NVVM::MMATypes ptxTypeA, NVVM::MMATypes ptxTypeB,
|
|
NVVM::MMATypes ptxTypeC, NVVM::MMATypes ptxTypeD,
|
|
std::optional<NVVM::MMAIntOverflow> overflow, ArrayRef<Value> unpackedAData,
|
|
ArrayRef<Value> unpackedB, ArrayRef<Value> unpackedC, Value indexData,
|
|
int64_t metadataSelector, const std::array<int64_t, 3> &shape,
|
|
Type intrinsicResultType) {
|
|
auto asmDialectAttr =
|
|
LLVM::AsmDialectAttr::get(b.getContext(), LLVM::AsmDialect::AD_ATT);
|
|
|
|
const unsigned matASize = unpackedAData.size();
|
|
const unsigned matBSize = unpackedB.size();
|
|
const unsigned matCSize = unpackedC.size();
|
|
|
|
std::string asmStr = buildMmaSparseAsmString(
|
|
shape, matASize, matBSize, matCSize, ptxTypeA, ptxTypeB, ptxTypeC,
|
|
ptxTypeD, overflow, metadataSelector);
|
|
std::string constraintStr =
|
|
buildMmaSparseAsmConstraintString(matASize, matBSize, matCSize);
|
|
|
|
SmallVector<Value> asmVals;
|
|
asmVals.reserve(matASize + matBSize + matCSize + 1);
|
|
for (ArrayRef<Value> args : {unpackedAData, unpackedB, unpackedC})
|
|
llvm::append_range(asmVals, args);
|
|
asmVals.push_back(indexData);
|
|
|
|
return b.create<LLVM::InlineAsmOp>(
|
|
/*resultTypes=*/intrinsicResultType,
|
|
/*operands=*/asmVals,
|
|
/*asm_string=*/asmStr,
|
|
/*constraints=*/constraintStr,
|
|
/*has_side_effects=*/true,
|
|
/*is_align_stack=*/false,
|
|
/*asm_dialect=*/asmDialectAttr,
|
|
/*operand_attrs=*/ArrayAttr());
|
|
}
|
|
|
|
/// Lowers `nvgpu.mma.sp.sync` to inline assembly.
|
|
struct NVGPUMmaSparseSyncLowering
|
|
: public ConvertOpToLLVMPattern<nvgpu::MmaSparseSyncOp> {
|
|
using ConvertOpToLLVMPattern<nvgpu::MmaSparseSyncOp>::ConvertOpToLLVMPattern;
|
|
|
|
LogicalResult
|
|
matchAndRewrite(nvgpu::MmaSparseSyncOp op, OpAdaptor adaptor,
|
|
ConversionPatternRewriter &rewriter) const override {
|
|
ImplicitLocOpBuilder b(op.getLoc(), rewriter);
|
|
// Get the shapes of the MMAMatrix type being used. The shapes will
|
|
// choose which intrinsic this op will be lowered to.
|
|
VectorType aType = op.getMatrixA().getType();
|
|
VectorType bType = op.getMatrixB().getType();
|
|
VectorType cType = op.getMatrixC().getType();
|
|
|
|
FailureOr<NVVM::MMATypes> ptxTypeA = getNvvmMmaType(aType);
|
|
if (failed(ptxTypeA))
|
|
return op->emitOpError("failed to deduce operand PTX types");
|
|
FailureOr<NVVM::MMATypes> ptxTypeB = getNvvmMmaType(bType);
|
|
if (failed(ptxTypeB))
|
|
return op->emitOpError("failed to deduce operand PTX types");
|
|
std::optional<NVVM::MMATypes> ptxTypeC =
|
|
NVVM::MmaOp::inferOperandMMAType(cType.getElementType(),
|
|
/*isAccumulator=*/true);
|
|
if (!ptxTypeC)
|
|
return op->emitError(
|
|
"could not infer the PTX type for the accumulator/result");
|
|
|
|
// Same as `mma.sync`, F32 works only with TensorFloat32 (TF32).
|
|
bool tf32Enabled = op->hasAttr(op.getTf32EnabledAttrName());
|
|
if (aType.getElementType().isF32() && !tf32Enabled)
|
|
return failure();
|
|
|
|
// TODO: add an attribute to the op to customize this behavior.
|
|
std::optional<NVVM::MMAIntOverflow> overflow(std::nullopt);
|
|
if (isa<IntegerType>(aType.getElementType()))
|
|
overflow = NVVM::MMAIntOverflow::satfinite;
|
|
|
|
SmallVector<Value> matA =
|
|
unpackOperandVector(b, adaptor.getMatrixA(), *ptxTypeA);
|
|
SmallVector<Value> matB =
|
|
unpackOperandVector(b, adaptor.getMatrixB(), *ptxTypeB);
|
|
SmallVector<Value> matC =
|
|
unpackOperandVector(b, adaptor.getMatrixC(), *ptxTypeC);
|
|
|
|
Type desiredRetTy = typeConverter->convertType(op->getResultTypes()[0]);
|
|
Type intrinsicResTy = inferIntrinsicResultType(
|
|
typeConverter->convertType(op->getResultTypes()[0]));
|
|
|
|
// Bitcast the sparse metadata from vector<2xf16> to an i32.
|
|
Value sparseMetadata = adaptor.getSparseMetadata();
|
|
if (sparseMetadata.getType() !=
|
|
LLVM::getFixedVectorType(rewriter.getI16Type(), 2))
|
|
return op->emitOpError() << "Expected metadata type to be LLVM "
|
|
"VectorType of 2 i16 elements";
|
|
sparseMetadata =
|
|
b.create<LLVM::BitcastOp>(rewriter.getI32Type(), sparseMetadata);
|
|
|
|
FailureOr<LLVM::InlineAsmOp> intrinsicResult = emitMmaSparseSyncOpAsm(
|
|
b, *ptxTypeA, *ptxTypeB, *ptxTypeC, *ptxTypeC, overflow, matA, matB,
|
|
matC, sparseMetadata, op.getSparsitySelector(), op.getMmaShapeAsArray(),
|
|
intrinsicResTy);
|
|
if (failed(intrinsicResult))
|
|
return failure();
|
|
|
|
assert((*intrinsicResult).getNumResults() == 1 &&
|
|
"expected inline asm op returns a single LLVM struct type");
|
|
rewriter.replaceOp(
|
|
op, convertIntrinsicResult(op.getLoc(), intrinsicResTy, desiredRetTy,
|
|
(*intrinsicResult)->getResult(0), rewriter));
|
|
return success();
|
|
}
|
|
};
|
|
|
|
struct NVGPUAsyncCopyLowering
|
|
: public ConvertOpToLLVMPattern<nvgpu::DeviceAsyncCopyOp> {
|
|
using ConvertOpToLLVMPattern<
|
|
nvgpu::DeviceAsyncCopyOp>::ConvertOpToLLVMPattern;
|
|
|
|
LogicalResult
|
|
matchAndRewrite(nvgpu::DeviceAsyncCopyOp op, OpAdaptor adaptor,
|
|
ConversionPatternRewriter &rewriter) const override {
|
|
ImplicitLocOpBuilder b(op.getLoc(), rewriter);
|
|
Location loc = op.getLoc();
|
|
auto dstMemrefType = cast<MemRefType>(op.getDst().getType());
|
|
Value dstPtr =
|
|
getStridedElementPtr(b.getLoc(), dstMemrefType, adaptor.getDst(),
|
|
adaptor.getDstIndices(), rewriter);
|
|
FailureOr<unsigned> dstAddressSpace =
|
|
getTypeConverter()->getMemRefAddressSpace(dstMemrefType);
|
|
if (failed(dstAddressSpace))
|
|
return rewriter.notifyMatchFailure(
|
|
loc, "destination memref address space not convertible to integer");
|
|
|
|
auto srcMemrefType = cast<MemRefType>(op.getSrc().getType());
|
|
FailureOr<unsigned> srcAddressSpace =
|
|
getTypeConverter()->getMemRefAddressSpace(srcMemrefType);
|
|
if (failed(srcAddressSpace))
|
|
return rewriter.notifyMatchFailure(
|
|
loc, "source memref address space not convertible to integer");
|
|
|
|
Value scrPtr = getStridedElementPtr(loc, srcMemrefType, adaptor.getSrc(),
|
|
adaptor.getSrcIndices(), rewriter);
|
|
// Intrinsics takes a global pointer so we need an address space cast.
|
|
auto srcPointerGlobalType = LLVM::LLVMPointerType::get(
|
|
op->getContext(), NVVM::NVVMMemorySpace::kGlobalMemorySpace);
|
|
scrPtr = b.create<LLVM::AddrSpaceCastOp>(srcPointerGlobalType, scrPtr);
|
|
int64_t dstElements = adaptor.getDstElements().getZExtValue();
|
|
int64_t sizeInBytes =
|
|
(dstMemrefType.getElementTypeBitWidth() * dstElements) / 8;
|
|
// When the optional SrcElements argument is *not* present, the regular
|
|
// CpAsyncOp is generated. CopyAsyncOp reads bytes from source (global
|
|
// memory) to fill DstElements number of elements in the destination
|
|
// (shared memory).
|
|
Value srcBytes = adaptor.getSrcElements();
|
|
if (srcBytes) {
|
|
// When the optional SrcElements argument is present, the source (global
|
|
// memory) of CpAsyncOp is read only for SrcElements number of elements.
|
|
// The rest of the DstElements in the destination (shared memory) are
|
|
// filled with zeros.
|
|
Value c3I32 =
|
|
b.create<LLVM::ConstantOp>(b.getI32Type(), b.getI32IntegerAttr(3));
|
|
Value bitwidth = b.create<LLVM::ConstantOp>(
|
|
b.getI32Type(),
|
|
b.getI32IntegerAttr(srcMemrefType.getElementTypeBitWidth()));
|
|
Value srcElementsI32 = b.create<LLVM::TruncOp>(b.getI32Type(), srcBytes);
|
|
srcBytes = b.create<LLVM::LShrOp>(
|
|
b.create<LLVM::MulOp>(bitwidth, srcElementsI32), c3I32);
|
|
}
|
|
// Cache global (.cg) for 16 dst bytes, Cache all (.ca) for sizes other than
|
|
// 16 dst bytes.
|
|
NVVM::LoadCacheModifierKind cacheModifier =
|
|
(op.getBypassL1().value_or(false) && sizeInBytes == 16)
|
|
? NVVM::LoadCacheModifierKind::CG
|
|
: NVVM::LoadCacheModifierKind::CA;
|
|
|
|
b.create<NVVM::CpAsyncOp>(
|
|
dstPtr, scrPtr, rewriter.getI32IntegerAttr(sizeInBytes),
|
|
NVVM::LoadCacheModifierKindAttr::get(op->getContext(), cacheModifier),
|
|
srcBytes);
|
|
|
|
// Drop the result token.
|
|
Value zero = b.create<LLVM::ConstantOp>(
|
|
IntegerType::get(op.getContext(), 32), rewriter.getI32IntegerAttr(0));
|
|
rewriter.replaceOp(op, zero);
|
|
return success();
|
|
}
|
|
};
|
|
|
|
struct NVGPUAsyncCreateGroupLowering
|
|
: public ConvertOpToLLVMPattern<nvgpu::DeviceAsyncCreateGroupOp> {
|
|
using ConvertOpToLLVMPattern<
|
|
nvgpu::DeviceAsyncCreateGroupOp>::ConvertOpToLLVMPattern;
|
|
|
|
LogicalResult
|
|
matchAndRewrite(nvgpu::DeviceAsyncCreateGroupOp op, OpAdaptor adaptor,
|
|
ConversionPatternRewriter &rewriter) const override {
|
|
rewriter.create<NVVM::CpAsyncCommitGroupOp>(op.getLoc());
|
|
// Drop the result token.
|
|
Value zero = rewriter.create<LLVM::ConstantOp>(
|
|
op->getLoc(), IntegerType::get(op.getContext(), 32),
|
|
rewriter.getI32IntegerAttr(0));
|
|
rewriter.replaceOp(op, zero);
|
|
return success();
|
|
}
|
|
};
|
|
|
|
struct NVGPUAsyncWaitLowering
|
|
: public ConvertOpToLLVMPattern<nvgpu::DeviceAsyncWaitOp> {
|
|
using ConvertOpToLLVMPattern<
|
|
nvgpu::DeviceAsyncWaitOp>::ConvertOpToLLVMPattern;
|
|
|
|
LogicalResult
|
|
matchAndRewrite(nvgpu::DeviceAsyncWaitOp op, OpAdaptor adaptor,
|
|
ConversionPatternRewriter &rewriter) const override {
|
|
// If numGroup is not present pick 0 as a conservative correct value.
|
|
int32_t numGroups = adaptor.getNumGroups().value_or(0);
|
|
rewriter.create<NVVM::CpAsyncWaitGroupOp>(op.getLoc(), numGroups);
|
|
rewriter.eraseOp(op);
|
|
return success();
|
|
}
|
|
};
|
|
|
|
/// Creates mbarrier object in shared memory
struct NVGPUMBarrierCreateLowering
    : public ConvertOpToLLVMPattern<nvgpu::MBarrierCreateOp> {
  using ConvertOpToLLVMPattern<nvgpu::MBarrierCreateOp>::ConvertOpToLLVMPattern;

  template <typename moduleT>
  memref::GlobalOp generateGlobalBarrier(ConversionPatternRewriter &rewriter,
                                         Operation *funcOp, moduleT moduleOp,
                                         MemRefType barrierType) const {
    SymbolTable symbolTable(moduleOp);
    OpBuilder::InsertionGuard guard(rewriter);
    rewriter.setInsertionPoint(&moduleOp.front());
    auto global = rewriter.create<memref::GlobalOp>(
        funcOp->getLoc(), "__mbarrier",
        /*sym_visibility=*/rewriter.getStringAttr("private"),
        /*type=*/barrierType,
        /*initial_value=*/ElementsAttr(),
        /*constant=*/false,
        /*alignment=*/rewriter.getI64IntegerAttr(8));
    symbolTable.insert(global);
    return global;
  }

  LogicalResult
  matchAndRewrite(nvgpu::MBarrierCreateOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    Operation *funcOp = op->getParentOp();
    MemRefType barrierType = nvgpu::getMBarrierMemrefType(
        rewriter.getContext(), op.getBarriers().getType());

    memref::GlobalOp global;
    if (auto moduleOp = funcOp->getParentOfType<gpu::GPUModuleOp>())
      global = generateGlobalBarrier(rewriter, funcOp, moduleOp, barrierType);
    else if (auto moduleOp = funcOp->getParentOfType<ModuleOp>())
      global = generateGlobalBarrier(rewriter, funcOp, moduleOp, barrierType);

    rewriter.setInsertionPoint(op);
    rewriter.replaceOpWithNewOp<memref::GetGlobalOp>(op, barrierType,
                                                     global.getName());
    return success();
  }
};

/// Base class for lowering mbarrier operations to nvvm intrinsics.
template <typename SourceOp>
struct MBarrierBasePattern : public ConvertOpToLLVMPattern<SourceOp> {
public:
  using ConvertOpToLLVMPattern<SourceOp>::ConvertOpToLLVMPattern;
  /// Returns the base pointer of the mbarrier object.
  Value getMbarrierPtr(ImplicitLocOpBuilder &b,
                       nvgpu::MBarrierGroupType mbarType, Value memrefDesc,
                       Value mbarId,
                       ConversionPatternRewriter &rewriter) const {
    MemRefType mbarrierMemrefType =
        nvgpu::getMBarrierMemrefType(rewriter.getContext(), mbarType);
    return ConvertToLLVMPattern::getStridedElementPtr(
        b.getLoc(), mbarrierMemrefType, memrefDesc, {mbarId}, rewriter);
  }
};

/// Lowers `nvgpu.mbarrier.init` to `nvvm.mbarrier.init`
struct NVGPUMBarrierInitLowering
    : public MBarrierBasePattern<nvgpu::MBarrierInitOp> {
  using MBarrierBasePattern<nvgpu::MBarrierInitOp>::MBarrierBasePattern;

  LogicalResult
  matchAndRewrite(nvgpu::MBarrierInitOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    ImplicitLocOpBuilder b(op->getLoc(), rewriter);
    nvgpu::MBarrierGroupType mbarrierType = op.getBarriers().getType();
    rewriter.setInsertionPoint(op);
    Value barrier = getMbarrierPtr(b, mbarrierType, adaptor.getBarriers(),
                                   adaptor.getMbarId(), rewriter);
    Value count = truncToI32(b, adaptor.getCount());
    if (isMbarrierShared(mbarrierType)) {
      rewriter.replaceOpWithNewOp<NVVM::MBarrierInitSharedOp>(
          op, barrier, count, adaptor.getPredicate());
    } else {
      rewriter.replaceOpWithNewOp<NVVM::MBarrierInitOp>(op, barrier, count,
                                                        adaptor.getPredicate());
    }
    return success();
  }
};

/// Lowers `nvgpu.mbarrier.arrive` to `nvvm.mbarrier.arrive`
struct NVGPUMBarrierArriveLowering
    : public MBarrierBasePattern<nvgpu::MBarrierArriveOp> {
  using MBarrierBasePattern<nvgpu::MBarrierArriveOp>::MBarrierBasePattern;
  LogicalResult
  matchAndRewrite(nvgpu::MBarrierArriveOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    ImplicitLocOpBuilder b(op->getLoc(), rewriter);
    Value barrier =
        getMbarrierPtr(b, op.getBarriers().getType(), adaptor.getBarriers(),
                       adaptor.getMbarId(), rewriter);
    Type tokenType = getTypeConverter()->convertType(
        nvgpu::MBarrierTokenType::get(op->getContext()));
    if (isMbarrierShared(op.getBarriers().getType())) {
      rewriter.replaceOpWithNewOp<NVVM::MBarrierArriveSharedOp>(op, tokenType,
                                                                barrier);
    } else {
      rewriter.replaceOpWithNewOp<NVVM::MBarrierArriveOp>(op, tokenType,
                                                          barrier);
    }
    return success();
  }
};

/// Lowers `nvgpu.mbarrier.arrive.nocomplete` to
/// `nvvm.mbarrier.arrive.nocomplete`
struct NVGPUMBarrierArriveNoCompleteLowering
    : public MBarrierBasePattern<nvgpu::MBarrierArriveNoCompleteOp> {
  using MBarrierBasePattern<
      nvgpu::MBarrierArriveNoCompleteOp>::MBarrierBasePattern;
  LogicalResult
  matchAndRewrite(nvgpu::MBarrierArriveNoCompleteOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    ImplicitLocOpBuilder b(op->getLoc(), rewriter);
    Value barrier =
        getMbarrierPtr(b, op.getBarriers().getType(), adaptor.getBarriers(),
                       adaptor.getMbarId(), rewriter);
    Type tokenType = getTypeConverter()->convertType(
        nvgpu::MBarrierTokenType::get(op->getContext()));
    Value count = truncToI32(b, adaptor.getCount());
    if (isMbarrierShared(op.getBarriers().getType())) {
      rewriter.replaceOpWithNewOp<NVVM::MBarrierArriveNocompleteSharedOp>(
          op, tokenType, barrier, count);
    } else {
      rewriter.replaceOpWithNewOp<NVVM::MBarrierArriveNocompleteOp>(
          op, tokenType, barrier, count);
    }
    return success();
  }
};

/// Lowers `nvgpu.mbarrier.test.wait` to `nvvm.mbarrier.test.wait`
struct NVGPUMBarrierTestWaitLowering
    : public MBarrierBasePattern<nvgpu::MBarrierTestWaitOp> {
  using MBarrierBasePattern<nvgpu::MBarrierTestWaitOp>::MBarrierBasePattern;
  LogicalResult
  matchAndRewrite(nvgpu::MBarrierTestWaitOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    ImplicitLocOpBuilder b(op->getLoc(), rewriter);
    Value barrier =
        getMbarrierPtr(b, op.getBarriers().getType(), adaptor.getBarriers(),
                       adaptor.getMbarId(), rewriter);
    Type retType = rewriter.getI1Type();
    if (isMbarrierShared(op.getBarriers().getType())) {
      rewriter.replaceOpWithNewOp<NVVM::MBarrierTestWaitSharedOp>(
          op, retType, barrier, adaptor.getToken());
    } else {
      rewriter.replaceOpWithNewOp<NVVM::MBarrierTestWaitOp>(
          op, retType, barrier, adaptor.getToken());
    }
    return success();
  }
};

struct NVGPUMBarrierArriveExpectTxLowering
    : public MBarrierBasePattern<nvgpu::MBarrierArriveExpectTxOp> {
  using MBarrierBasePattern<
      nvgpu::MBarrierArriveExpectTxOp>::MBarrierBasePattern;
  LogicalResult
  matchAndRewrite(nvgpu::MBarrierArriveExpectTxOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    ImplicitLocOpBuilder b(op->getLoc(), rewriter);
    Value barrier =
        getMbarrierPtr(b, op.getBarriers().getType(), adaptor.getBarriers(),
                       adaptor.getMbarId(), rewriter);
    Value txcount = truncToI32(b, adaptor.getTxcount());

    if (isMbarrierShared(op.getBarriers().getType())) {
      rewriter.replaceOpWithNewOp<NVVM::MBarrierArriveExpectTxSharedOp>(
          op, barrier, txcount, adaptor.getPredicate());
      return success();
    }

    rewriter.replaceOpWithNewOp<NVVM::MBarrierArriveExpectTxOp>(
        op, barrier, txcount, adaptor.getPredicate());
    return success();
  }
};

struct NVGPUMBarrierTryWaitParityLowering
    : public MBarrierBasePattern<nvgpu::MBarrierTryWaitParityOp> {
  using MBarrierBasePattern<
      nvgpu::MBarrierTryWaitParityOp>::MBarrierBasePattern;
  LogicalResult
  matchAndRewrite(nvgpu::MBarrierTryWaitParityOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    ImplicitLocOpBuilder b(op->getLoc(), rewriter);
    Value barrier =
        getMbarrierPtr(b, op.getBarriers().getType(), adaptor.getBarriers(),
                       adaptor.getMbarId(), rewriter);
    Value ticks = truncToI32(b, adaptor.getTicks());
    Value phase =
        b.create<LLVM::ZExtOp>(b.getI32Type(), adaptor.getPhaseParity());

    if (isMbarrierShared(op.getBarriers().getType())) {
      rewriter.replaceOpWithNewOp<NVVM::MBarrierTryWaitParitySharedOp>(
          op, barrier, phase, ticks);
      return success();
    }

    rewriter.replaceOpWithNewOp<NVVM::MBarrierTryWaitParityOp>(op, barrier,
                                                               phase, ticks);
    return success();
  }
};

struct NVGPUTmaAsyncLoadOpLowering
|
|
: public MBarrierBasePattern<nvgpu::TmaAsyncLoadOp> {
|
|
using MBarrierBasePattern<nvgpu::TmaAsyncLoadOp>::MBarrierBasePattern;
|
|
LogicalResult
|
|
matchAndRewrite(nvgpu::TmaAsyncLoadOp op, OpAdaptor adaptor,
|
|
ConversionPatternRewriter &rewriter) const override {
|
|
ImplicitLocOpBuilder b(op->getLoc(), rewriter);
|
|
auto srcMemrefType = cast<MemRefType>(op.getDst().getType());
|
|
Value dest = getStridedElementPtr(op->getLoc(), srcMemrefType,
|
|
adaptor.getDst(), {}, rewriter);
|
|
Value barrier =
|
|
getMbarrierPtr(b, op.getBarriers().getType(), adaptor.getBarriers(),
|
|
adaptor.getMbarId(), rewriter);
|
|
|
|
SmallVector<Value> coords = adaptor.getCoordinates();
|
|
for (auto [index, value] : llvm::enumerate(coords)) {
|
|
coords[index] = truncToI32(b, value);
|
|
}
|
|
rewriter.replaceOpWithNewOp<NVVM::CpAsyncBulkTensorGlobalToSharedClusterOp>(
|
|
op, dest, adaptor.getTensorMapDescriptor(), coords, barrier,
|
|
ValueRange{}, adaptor.getMulticastMask(), Value{},
|
|
adaptor.getPredicate());
|
|
return success();
|
|
}
|
|
};
|
|
|
|
struct NVGPUTmaAsyncStoreOpLowering
|
|
: public MBarrierBasePattern<nvgpu::TmaAsyncStoreOp> {
|
|
using MBarrierBasePattern<nvgpu::TmaAsyncStoreOp>::MBarrierBasePattern;
|
|
LogicalResult
|
|
matchAndRewrite(nvgpu::TmaAsyncStoreOp op, OpAdaptor adaptor,
|
|
ConversionPatternRewriter &rewriter) const override {
|
|
ImplicitLocOpBuilder b(op->getLoc(), rewriter);
|
|
auto srcMemrefType = cast<MemRefType>(op.getSrc().getType());
|
|
Value dest = getStridedElementPtr(op->getLoc(), srcMemrefType,
|
|
adaptor.getSrc(), {}, rewriter);
|
|
SmallVector<Value> coords = adaptor.getCoordinates();
|
|
for (auto [index, value] : llvm::enumerate(coords)) {
|
|
coords[index] = truncToI32(b, value);
|
|
}
|
|
|
|
rewriter.replaceOpWithNewOp<NVVM::CpAsyncBulkTensorSharedCTAToGlobalOp>(
|
|
op, adaptor.getTensorMapDescriptor(), dest, coords,
|
|
adaptor.getPredicate());
|
|
return success();
|
|
}
|
|
};
|
|
|
|
struct NVGPUGenerateWarpgroupDescriptorLowering
|
|
: public ConvertOpToLLVMPattern<nvgpu::WarpgroupGenerateDescriptorOp> {
|
|
using ConvertOpToLLVMPattern<
|
|
nvgpu::WarpgroupGenerateDescriptorOp>::ConvertOpToLLVMPattern;
|
|
|
|
LogicalResult
|
|
matchAndRewrite(nvgpu::WarpgroupGenerateDescriptorOp op, OpAdaptor adaptor,
|
|
ConversionPatternRewriter &rewriter) const override {
|
|
|
|
ImplicitLocOpBuilder b(op->getLoc(), rewriter);
|
|
|
|
nvgpu::TensorMapSwizzleKind swizzleKind =
|
|
op.getTensorMap().getType().getSwizzle();
|
|
|
|
unsigned layout =
|
|
(swizzleKind == nvgpu::TensorMapSwizzleKind::SWIZZLE_128B) ? 128
|
|
: (swizzleKind == nvgpu::TensorMapSwizzleKind::SWIZZLE_64B) ? 64
|
|
: (swizzleKind == nvgpu::TensorMapSwizzleKind::SWIZZLE_32B) ? 32
|
|
: 1;
|
|
unsigned swizzle =
|
|
(swizzleKind == nvgpu::TensorMapSwizzleKind::SWIZZLE_128B) ? 1
|
|
: (swizzleKind == nvgpu::TensorMapSwizzleKind::SWIZZLE_64B) ? 2
|
|
: (swizzleKind == nvgpu::TensorMapSwizzleKind::SWIZZLE_32B) ? 3
|
|
: 0;
|
|
|
|
auto ti64 = b.getIntegerType(64);
|
|
auto makeConst = [&](uint64_t index) -> Value {
|
|
return b.create<LLVM::ConstantOp>(ti64, b.getI64IntegerAttr(index));
|
|
};
|
|
auto shiftLeft = [&](Value value, unsigned shift) -> Value {
|
|
return b.create<LLVM::ShlOp>(ti64, value, makeConst(shift));
|
|
};
|
|
auto shiftRight = [&](Value value, unsigned shift) -> Value {
|
|
return b.create<LLVM::LShrOp>(ti64, value, makeConst(shift));
|
|
};
|
|
auto insertBit = [&](Value desc, Value val, int startBit) {
|
|
return b.create<LLVM::OrOp>(ti64, desc, shiftLeft(val, startBit));
|
|
};
|
|
|
|
int64_t sizeN = op.getTensorMap().getType().getTensor().getDimSize(0);
|
|
uint64_t strideDimVal = (layout << 3) >> exclude4LSB;
|
|
uint64_t leadDimVal = (sizeN * layout) >> exclude4LSB;
|
|
uint64_t offsetVal = 0;
|
|
|
|
Value strideDim = makeConst(strideDimVal);
|
|
Value leadDim = makeConst(leadDimVal);
|
|
|
|
Value baseAddr = getStridedElementPtr(
|
|
op->getLoc(), cast<MemRefType>(op.getTensor().getType()),
|
|
adaptor.getTensor(), {}, rewriter);
|
|
Value basePtr = b.create<LLVM::PtrToIntOp>(ti64, baseAddr);
|
|
// Just use 14 bits for base address
|
|
Value basePtr14bit = shiftRight(shiftLeft(basePtr, 46), 50);
|
|
|
|
int startSwizzleBit = 62, startOffsetBit = 49, startStrideBit = 32,
|
|
startLeadBit = 16, startBaseAddrBit = 0;
|
|
Value dsc = makeConst(0);
|
|
// // [62,64) swizzle type
|
|
dsc = insertBit(dsc, makeConst(swizzle), startSwizzleBit);
|
|
// // [49,52) base_offset
|
|
dsc = insertBit(dsc, makeConst(offsetVal), startOffsetBit);
|
|
// // [32,46) stride
|
|
dsc = insertBit(dsc, strideDim, startStrideBit);
|
|
// // [16,30) leading dimension
|
|
dsc = insertBit(dsc, leadDim, startLeadBit);
|
|
// // [0,14) start_address
|
|
dsc = insertBit(dsc, basePtr14bit, startBaseAddrBit);
|
|
|
|
LLVM_DEBUG(DBGS() << "Generating warpgroup.descriptor: "
|
|
<< "leading_off:" << leadDimVal << "\t"
|
|
<< "stride_off :" << strideDimVal << "\t"
|
|
<< "base_offset:" << offsetVal << "\t"
|
|
<< "layout_type:" << swizzle << " ("
|
|
<< nvgpu::stringifyTensorMapSwizzleKind(swizzleKind)
|
|
<< ")\n start_addr : " << baseAddr << "\n");
|
|
|
|
rewriter.replaceOp(op, dsc);
|
|
return success();
|
|
}
|
|
};
|
|
|
|
static Value makeI64Const(ImplicitLocOpBuilder &b, int32_t index) {
|
|
return b.create<LLVM::ConstantOp>(b.getIntegerType(64),
|
|
b.getI32IntegerAttr(index));
|
|
}
|
|
|
|
/// Returns a Value that holds data type enum that is expected by CUDA driver.
|
|
static Value elementTypeAsLLVMConstant(ImplicitLocOpBuilder &b, Type type) {
|
|
// Enum is from CUDA driver API
|
|
// https://docs.nvidia.com/cuda/cuda-driver-api/group__CUDA__TYPES.html
|
|
enum CUtensorMapDataTypeEnum {
|
|
CU_TENSOR_MAP_DATA_TYPE_UINT8 = 0,
|
|
CU_TENSOR_MAP_DATA_TYPE_UINT16,
|
|
CU_TENSOR_MAP_DATA_TYPE_UINT32,
|
|
CU_TENSOR_MAP_DATA_TYPE_INT32,
|
|
CU_TENSOR_MAP_DATA_TYPE_UINT64,
|
|
CU_TENSOR_MAP_DATA_TYPE_INT64,
|
|
CU_TENSOR_MAP_DATA_TYPE_FLOAT16,
|
|
CU_TENSOR_MAP_DATA_TYPE_FLOAT32,
|
|
CU_TENSOR_MAP_DATA_TYPE_FLOAT64,
|
|
CU_TENSOR_MAP_DATA_TYPE_BFLOAT16,
|
|
CU_TENSOR_MAP_DATA_TYPE_FLOAT32_FTZ,
|
|
CU_TENSOR_MAP_DATA_TYPE_TFLOAT32,
|
|
CU_TENSOR_MAP_DATA_TYPE_TFLOAT32_FTZ
|
|
};
|
|
|
|
if (type.isUnsignedInteger(8))
|
|
return makeI64Const(b, CU_TENSOR_MAP_DATA_TYPE_UINT8);
|
|
if (type.isUnsignedInteger(16))
|
|
return makeI64Const(b, CU_TENSOR_MAP_DATA_TYPE_UINT16);
|
|
if (type.isUnsignedInteger(32))
|
|
return makeI64Const(b, CU_TENSOR_MAP_DATA_TYPE_UINT32);
|
|
if (type.isUnsignedInteger(64))
|
|
return makeI64Const(b, CU_TENSOR_MAP_DATA_TYPE_UINT64);
|
|
if (type.isSignlessInteger(32))
|
|
return makeI64Const(b, CU_TENSOR_MAP_DATA_TYPE_INT32);
|
|
if (type.isSignlessInteger(64))
|
|
return makeI64Const(b, CU_TENSOR_MAP_DATA_TYPE_INT64);
|
|
if (type.isF16())
|
|
return makeI64Const(b, CU_TENSOR_MAP_DATA_TYPE_FLOAT16);
|
|
if (type.isF32())
|
|
return makeI64Const(b, CU_TENSOR_MAP_DATA_TYPE_FLOAT32);
|
|
if (type.isF64())
|
|
return makeI64Const(b, CU_TENSOR_MAP_DATA_TYPE_FLOAT64);
|
|
if (type.isBF16())
|
|
return makeI64Const(b, CU_TENSOR_MAP_DATA_TYPE_BFLOAT16);
|
|
|
|
llvm_unreachable("Not supported data type");
|
|
}
|
|
|
|
struct NVGPUTmaCreateDescriptorOpLowering
|
|
: public ConvertOpToLLVMPattern<nvgpu::TmaCreateDescriptorOp> {
|
|
using ConvertOpToLLVMPattern<
|
|
nvgpu::TmaCreateDescriptorOp>::ConvertOpToLLVMPattern;
|
|
LogicalResult
|
|
matchAndRewrite(nvgpu::TmaCreateDescriptorOp op, OpAdaptor adaptor,
|
|
ConversionPatternRewriter &rewriter) const override {
|
|
ImplicitLocOpBuilder b(op->getLoc(), rewriter);
|
|
auto llvmPointerType = LLVM::LLVMPointerType::get(op->getContext());
|
|
Type llvmInt64Type = IntegerType::get(op->getContext(), 64);
|
|
|
|
Value tensorElementType =
|
|
elementTypeAsLLVMConstant(b, op.getTensor().getType().getElementType());
|
|
auto promotedOperands = getTypeConverter()->promoteOperands(
|
|
b.getLoc(), op->getOperands(), adaptor.getOperands(), b);
|
|
|
|
Value boxArrayPtr = b.create<LLVM::AllocaOp>(llvmPointerType, llvmInt64Type,
|
|
makeI64Const(b, 5));
|
|
for (auto [index, value] : llvm::enumerate(adaptor.getBoxDimensions())) {
|
|
Value gep = b.create<LLVM::GEPOp>(llvmPointerType, llvmPointerType,
|
|
boxArrayPtr, makeI64Const(b, index));
|
|
b.create<LLVM::StoreOp>(value, gep);
|
|
}
|
|
|
|
nvgpu::TensorMapDescriptorType desc = op.getTensorMap().getType();
|
|
// Set Arguments for the function call
|
|
SmallVector<Value> arguments;
|
|
arguments.push_back(promotedOperands[0]); // rank
|
|
arguments.push_back(promotedOperands[1]); // descriptor
|
|
arguments.push_back(tensorElementType); // data type
|
|
arguments.push_back(
|
|
makeI64Const(b, (int)desc.getInterleave())); // interleave
|
|
arguments.push_back(makeI64Const(b, (int)desc.getSwizzle())); // swizzle
|
|
arguments.push_back(makeI64Const(b, (int)desc.getL2promo())); // l2promo
|
|
arguments.push_back(makeI64Const(b, (int)desc.getOob())); // oob
|
|
arguments.push_back(boxArrayPtr); // box dimensions
|
|
|
|
// Set data types of the arguments
|
|
SmallVector<Type> argTypes = {
|
|
llvmInt64Type, /* int64_t tensorRank */
|
|
llvmPointerType, /* ptr */
|
|
llvmInt64Type, /* int64_t */
|
|
llvmInt64Type, /* int64_t */
|
|
llvmInt64Type, /* int64_t */
|
|
llvmInt64Type, /* int64_t */
|
|
llvmInt64Type, /* int64_t */
|
|
llvmPointerType /* ptr */
|
|
};
|
|
FunctionCallBuilder hostRegisterCallBuilder = {
|
|
"mgpuTensorMapEncodeTiledMemref", llvmPointerType, argTypes};
|
|
Value tensorMap =
|
|
hostRegisterCallBuilder.create(b.getLoc(), b, arguments).getResult();
|
|
|
|
rewriter.replaceOp(op, tensorMap);
|
|
return success();
|
|
}
|
|
};
|
|
|
|
struct NVGPUWarpgroupMmaOpLowering
|
|
: public ConvertOpToLLVMPattern<nvgpu::WarpgroupMmaOp> {
|
|
using ConvertOpToLLVMPattern<nvgpu::WarpgroupMmaOp>::ConvertOpToLLVMPattern;
|
|
|
|
/// This is a helper class to generate required NVVM Ops for warp-group level
|
|
/// matrix multiplication.
|
|
/// When the given GEMM shape is larger than the shape of
|
|
/// a wgmma instrution in PTX, it can generate multiple NVVM::WgmmaMmaAsyncOp
|
|
/// Op(s), group and execute them asynchronously. The class also handles
|
|
/// waiting for completion and iterates through WarpgroupMatrixDescriptor to
|
|
/// create descriptors for each instruction.
|
|
///
|
|
/// For example this is the case when the shape of GEMM is 128x128x128
|
|
///
|
|
/// nvvm.wgmma.fence.aligned
|
|
///
|
|
/// nvvm.wgmma.mma.async descA, descB
|
|
/// iterate(descA, descB)
|
|
/// nvvm.wgmma.mma.async descA, descB
|
|
/// [6x times more]
|
|
///
|
|
/// nvvm.wgmma.group.sync.aligned
|
|
/// nvvm.wgmma.wait.group.sync [groupId]
|
|
///
|
|
class WarpgroupGemm {
|
|
nvgpu::WarpgroupMmaOp op;
|
|
ImplicitLocOpBuilder b;
|
|
OpAdaptor adaptor;
|
|
|
|
// Entire shape of the given Op
|
|
int64_t totalM, totalN, totalK;
|
|
|
|
// Shape of one wgmma instruction
|
|
int wgmmaM = 0, wgmmaN = 0, wgmmaK = 0;
|
|
|
|
// Iteration counts for GEMM
|
|
int iterationM = 0, iterationN = 0, iterationK = 0;
|
|
|
|
/// The function returns the shape of wgmma instruction that is defined in
|
|
/// PTX programming guide.
|
|
/// https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#asynchronous-warpgroup-level-matrix-shape
|
|
void findWgmmaShape(int64_t sizeM, int64_t sizeN, Type inputElemType) {
|
|
wgmmaM = 64;
|
|
wgmmaN = sizeN;
|
|
if (inputElemType.isTF32()) {
|
|
wgmmaK = 8;
|
|
} else if (inputElemType.isF16() || inputElemType.isBF16()) {
|
|
wgmmaK = 16;
|
|
} else if (inputElemType.isFloat8E4M3FN() ||
|
|
inputElemType.isFloat8E5M2() || inputElemType.isInteger(16)) {
|
|
wgmmaK = 32;
|
|
} else if (inputElemType.isInteger(1)) {
|
|
wgmmaK = 256;
|
|
} else {
|
|
llvm_unreachable("msg: not supported K shape");
|
|
}
|
|
LLVM_DEBUG(DBGS() << "Generating WgmmaMmaAsyncOp shape[m = " << wgmmaM
|
|
<< ", n = " << wgmmaN << ", k = " << wgmmaK << "]\n");
|
|
}
|
|
|
|
/// Generates WGMMATypesAttr from MLIR Type
|
|
NVVM::WGMMATypesAttr generateWgmmaType(Type type,
|
|
bool useF32 = false) const {
|
|
auto getWgmmaType = [=](Type elemType) {
|
|
if (elemType.isF32() || elemType.isTF32())
|
|
return useF32 ? NVVM::WGMMATypes::f32 : NVVM::WGMMATypes::tf32;
|
|
if (elemType.isF16())
|
|
return NVVM::WGMMATypes::f16;
|
|
if (elemType.isBF16())
|
|
return NVVM::WGMMATypes::bf16;
|
|
if (elemType.isFloat8E4M3FN())
|
|
return NVVM::WGMMATypes::e4m3;
|
|
if (elemType.isFloat8E5M2())
|
|
return NVVM::WGMMATypes::e5m2;
|
|
if (elemType.isInteger(1))
|
|
return NVVM::WGMMATypes::b1;
|
|
if (elemType.isInteger(8))
|
|
return NVVM::WGMMATypes::s8;
|
|
if (elemType.isUnsignedInteger(8))
|
|
return NVVM::WGMMATypes::u8;
|
|
if (elemType.isInteger(32))
|
|
return NVVM::WGMMATypes::s32;
|
|
llvm_unreachable("unsupported type");
|
|
};
|
|
return NVVM::WGMMATypesAttr::get(op->getContext(), getWgmmaType(type));
|
|
}
|
|
|
|
/// Generates layout attribute for the input matrix for wgmma instruction
|
|
NVVM::MMALayoutAttr
|
|
generateWgmmaLayout(std::optional<bool> transpose) const {
|
|
if (transpose.value_or(false))
|
|
return NVVM::MMALayoutAttr::get(op->getContext(), NVVM::MMALayout::col);
|
|
return NVVM::MMALayoutAttr::get(op->getContext(), NVVM::MMALayout::row);
|
|
}
|
|
|
|
/// Generates shape attribute for wgmma instruction
|
|
NVVM::MMAShapeAttr generateWgmmaShape() const {
|
|
return NVVM::MMAShapeAttr::get(op->getContext(), wgmmaM, wgmmaN, wgmmaK);
|
|
}
|
|
|
|
/// Generates scale attributes of output matrix for wgmma instruction
|
|
NVVM::WGMMAScaleOutAttr generateScaleOut() const {
|
|
return NVVM::WGMMAScaleOutAttr::get(op->getContext(),
|
|
NVVM::WGMMAScaleOut::one);
|
|
}
|
|
/// Generates scale attributes of input matrix for wgmma instruction
|
|
NVVM::WGMMAScaleInAttr generateScaleIn() const {
|
|
return NVVM::WGMMAScaleInAttr::get(op->getContext(),
|
|
NVVM::WGMMAScaleIn::one);
|
|
}
|
|
|
|
/// Basic function to generate Add
|
|
Value makeAdd(Value lhs, Value rhs) {
|
|
return b.create<LLVM::AddOp>(lhs.getType(), lhs, rhs);
|
|
};
|
|
|
|
/// Moves the descriptor pointer of matrix-A for the next wgmma instruction.
|
|
/// Currently, it only handles row-major.
|
|
///
|
|
/// It moves the pointer like below for [128][64] size:
|
|
/// +2 +4 +6
|
|
/// ↓ ↓ ↓
|
|
/// descA ---> +--+--+--+--+
|
|
/// |->|->|->|->|
|
|
/// | | | | |
|
|
/// | | | | |
|
|
/// | | | | |
|
|
/// descA+512---> +-----------+
|
|
/// | | | | |
|
|
/// | | | | |
|
|
/// | | | | |
|
|
/// | | | | |
|
|
/// +-----------+
|
|
///
|
|
Value iterateDescriptorA(Value desc, int i, int j, int k) {
|
|
MemRefType matrixTypeA = op.getDescriptorA().getType().getTensor();
|
|
Type elemA = matrixTypeA.getElementType();
|
|
int byte = elemA.getIntOrFloatBitWidth() / 8;
|
|
int tileShapeA = matrixTypeA.getDimSize(1);
|
|
int incrementVal = ((wgmmaK * k) + (totalK * tileShapeA * i)) * byte;
|
|
incrementVal = incrementVal >> exclude4LSB;
|
|
LLVM_DEBUG(DBGS() << "\t\t[m: " << i << " n: " << j << " k: " << k
|
|
<< "] [wgmma descriptors] Descriptor A + "
|
|
<< incrementVal << " | \t ");
|
|
if (!incrementVal)
|
|
return desc;
|
|
return makeAdd(desc, makeI64Const(b, incrementVal));
|
|
}
|
|
|
|
/// Moves the descriptor pointer of matrix-B for the next wgmma instruction.
|
|
/// Currently, it only handles column-major.
|
|
///
|
|
/// It moves the pointer like below for [128][64] size:
|
|
/// descB ---> +--+--+--+--+--+--+--+--+
|
|
/// |↓ | | | | | | | |
|
|
/// |↓ | | | | | | | |
|
|
/// |↓ | | | | | | | |
|
|
/// |↓ | | | | | | | |
|
|
/// +--+--+--+--+--+--+--+--+
|
|
///
|
|
    Value iterateDescriptorB(Value desc, int i, int j, int k) {
      MemRefType matrixTypeB = op.getDescriptorB().getType().getTensor();
      Type elemB = matrixTypeB.getElementType();
      int byte = elemB.getIntOrFloatBitWidth() / 8;
      int incrementVal = matrixTypeB.getDimSize(0) * wgmmaK * k * byte;
      incrementVal = incrementVal >> exclude4LSB;
      LLVM_DEBUG(DBGSE() << "Descriptor B + " << incrementVal << "\n");
      if (!incrementVal)
        return desc;
      return makeAdd(desc, makeI64Const(b, incrementVal));
    }

    /// Generates a WgmmaMmaAsyncOp using the provided GMMA matrix descriptors,
    /// advancing them according to the induction variables i, j, and k.
    Value generateWgmma(int i, int j, int k, Value matrixC) {
      LLVM_DEBUG(DBGS() << "\t wgmma."
                        << "m" << wgmmaM << "n" << wgmmaN << "k" << wgmmaK
                        << "(A[" << (iterationM * wgmmaM) << ":"
                        << (iterationM * wgmmaM) + wgmmaM << "]["
                        << (iterationK * wgmmaK) << ":"
                        << (iterationK * wgmmaK + wgmmaK) << "] * "
                        << " B[" << (iterationK * wgmmaK) << ":"
                        << (iterationK * wgmmaK + wgmmaK) << "][" << 0 << ":"
                        << wgmmaN << "])\n");

      Value descriptorA = iterateDescriptorA(adaptor.getDescriptorA(), i, j, k);
      Value descriptorB = iterateDescriptorB(adaptor.getDescriptorB(), i, j, k);

      Type elemA = op.getDescriptorA().getType().getTensor().getElementType();
      NVVM::WGMMATypesAttr itypeA = generateWgmmaType(elemA);

      Type elemB = op.getDescriptorB().getType().getTensor().getElementType();
      NVVM::WGMMATypesAttr itypeB = generateWgmmaType(elemB);

      Type elemD = op.getMatrixC().getType().getFragmented().getElementType();
      NVVM::WGMMATypesAttr itypeD = generateWgmmaType(elemD, true);

      NVVM::MMAShapeAttr shape = generateWgmmaShape();
      NVVM::WGMMAScaleOutAttr scaleOut = generateScaleOut();
      NVVM::WGMMAScaleInAttr scaleIn = generateScaleIn();
      NVVM::MMALayoutAttr layoutA = generateWgmmaLayout(op.getTransposeA());
      NVVM::MMALayoutAttr layoutB = generateWgmmaLayout(!op.getTransposeB());

      auto overflow = NVVM::MMAIntOverflowAttr::get(
          op->getContext(), NVVM::MMAIntOverflow::wrapped);

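      // The single scale-in attribute computed above is passed twice below,
      // once for the matrix-A operand and once for the matrix-B operand; the
      // scale-out attribute applies to the accumulator.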
      return b.create<NVVM::WgmmaMmaAsyncOp>(
          matrixC.getType(), matrixC, descriptorA, descriptorB, shape, itypeA,
          itypeB, itypeD, scaleOut, scaleIn, scaleIn, layoutA, layoutB,
          overflow);
    }

    /// Generates multiple wgmma instructions to complete the given GEMM shape
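    ///
    /// The accumulator for each m-tile is extracted from the packed struct,
    /// threaded through the j/k loops of wgmma ops, and then re-inserted into
    /// the packed result struct.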
    Value generateWgmmaGroup() {
      Value wgmmaResult =
          b.create<LLVM::UndefOp>(adaptor.getMatrixC().getType());

      // Perform GEMM
      SmallVector<Value> wgmmaResults;
      for (int i = 0; i < iterationM; ++i) {
        Value matrixC = b.create<LLVM::ExtractValueOp>(adaptor.getMatrixC(), i);
        for (int j = 0; j < iterationN; ++j)
          for (int k = 0; k < iterationK; ++k)
            matrixC = generateWgmma(i, j, k, matrixC);
        wgmmaResults.push_back(matrixC);
      }
      for (auto [idx, matrix] : llvm::enumerate(wgmmaResults)) {
        wgmmaResult = b.create<LLVM::InsertValueOp>(wgmmaResult.getType(),
                                                    wgmmaResult, matrix, idx);
      }
      return wgmmaResult;
    }

  public:
    WarpgroupGemm(nvgpu::WarpgroupMmaOp op, ImplicitLocOpBuilder &b,
                  OpAdaptor adaptor)
        : op(op), b(b), adaptor(adaptor) {
      // Find the entire GEMM Shape
      totalM = op.getDescriptorA().getType().getTensor().getDimSize(0);
      totalN = op.getDescriptorB().getType().getTensor().getDimSize(1);
      totalK = op.getDescriptorA().getType().getTensor().getDimSize(1);
      LLVM_DEBUG(DBGS() << "===--- GEMM D[" << totalM << "][" << totalN
                        << "] += A[" << totalM << "][" << totalK << "] * B["
                        << totalK << "][" << totalN << "] ---===\n");

      // Find the shape for one wgmma instruction
      findWgmmaShape(
          totalM, totalN,
          op.getDescriptorA().getType().getTensor().getElementType());

      // Iteration counts needed to cover the full GEMM shape with the wgmma
      // tile shape
      iterationM = totalM / wgmmaM;
      iterationN = totalN / wgmmaN;
      iterationK = totalK / wgmmaK;
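
      // Illustrative example (not computed here): a 128x128x64 f16 GEMM with
      // a m64.n128.k16 wgmma tile yields iterationM = 2, iterationN = 1 and
      // iterationK = 4, i.e. eight wgmma ops per warpgroup MMA.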
    }

    /// Generates WgmmaMmaAsync ops to complete the specified GEMM shape. It
    /// emits a fence (WgmmaFenceAlignedOp) before the wgmma ops, and after
    /// them a group commit (WgmmaGroupSyncAlignedOp) followed by a wait on
    /// the group (WgmmaWaitGroupSyncOp).
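    ///
    /// Sketch of the emitted sequence (op class names, not exact IR):
    ///   WgmmaFenceAlignedOp
    ///   WgmmaMmaAsyncOp          x (iterationM * iterationN * iterationK)
    ///   WgmmaGroupSyncAlignedOp
    ///   WgmmaWaitGroupSyncOp(op.getWaitGroup())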
    Value generateWarpgroupMma() {
      b.create<NVVM::WgmmaFenceAlignedOp>();
      Value wgmmaResult = generateWgmmaGroup();
      b.create<NVVM::WgmmaGroupSyncAlignedOp>();
      b.create<NVVM::WgmmaWaitGroupSyncOp>(op.getWaitGroup());
      return wgmmaResult;
    }
  };

  LogicalResult
  matchAndRewrite(nvgpu::WarpgroupMmaOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    ImplicitLocOpBuilder b(op->getLoc(), rewriter);

    // Step 1. Build a helper class.
    WarpgroupGemm warpgroupGemm(op, b, adaptor);

    // Step 2. Generate wgmma instructions covering the entire GEMM shape.
    Value wgmmaResult = warpgroupGemm.generateWarpgroupMma();

    // Step 3. Replace the op with the fragmented result struct.
    rewriter.replaceOp(op, wgmmaResult);
    return success();
  }
};

struct NVGPUWarpgroupMmaStoreOpLowering
    : public ConvertOpToLLVMPattern<nvgpu::WarpgroupMmaStoreOp> {
  using ConvertOpToLLVMPattern<
      nvgpu::WarpgroupMmaStoreOp>::ConvertOpToLLVMPattern;

  /// This function stores a fragmented register matrix owned by a warp group
  /// (128 threads) into a memref. Each thread holds 64 registers, together
  /// forming one struct value.
  /// Here is what each thread (T) holds; each `d` is one numbered struct
  /// element.
  ///
  /// Threads in the warp group (128 threads) and what they own in matrixD:
  /// 0-31   Warp-0 -> MatrixD[0:15 ][0:N]
  /// 32-63  Warp-1 -> MatrixD[16:31][0:N]
  /// 64-95  Warp-2 -> MatrixD[32:47][0:N]
  /// 96-127 Warp-3 -> MatrixD[48:63][0:N]
  ///
  /// Matrix-D:
  /// +______________________________________________________________________+
  /// |    0-1   |   2-3   |   4-5   |   6-7   |   8-9  |  10-11 |..| N-8,N-7 |
  /// 0 | T0:d0-d1 |T1:d0-d1 |T2:d0-d1 |T3:d0-d1 |T0:d4-d5| T1:d4-d5..|T0:dX-dY|
  /// 1 | T4:d0-d1 |T5:d0-d1 |T6:d0-d1 |T7:d0-d1 |T4:d4-d5| T5:d4-d5..|T4:dX-dY|
  /// ..| .........|.........|.........|.........|........|...........|........|
  /// 8 | T0:d2-d3 |T1:d2-d3 |T2:d2-d3 |T3:d2-d3 |T0:d6-d7|T1:d6-d7,..|T0:dZ-dW|
  /// 9 | T4:d2-d3 |T5:d2-d3 |T6:d2-d3 |T7:d2-d3 |T4:d6-d7| T5:d6-d7..|T4:dZ-dW|
  /// ..| .........|.........|.........|.........|........|...........|........|
  /// 15| T28:d2-d3|T29:d2-d3|T30:d2-d3|T31:d2-d3|........|...........|........|
  /// 16| T32:d0-d1|T33:d0-d1|T34:d0-d1|T35:d0-d1|........|...........|........|
  /// ..| .........|.........|.........|.........|........|...........|........|
  /// 32| T64:d0-d1|T65:d0-d1|T66:d0-d1|T67:d0-d1|........|...........|........|
  /// ..| .........|.........|.........|.........|........|...........|........|
  /// 48| T96:d0-d1|T97:d0-d1|T98:d0-d1|T99:d0-d1|........|...........|........|
  /// ..| .........|.........|.........|.........|........|...........|........|
  /// +______________________________________________________________________+
  ///
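  /// Worked example derived from the index math below (illustrative only):
  /// thread 35 has warpId = 1 and laneId = 3, so lane4Id = 0 and
  /// lane4modId = 3, giving ti = 0 + 1 * 16 = 16 and tj = 3 * 2 = 6. It
  /// therefore stores d0-d1 to [16][6..7] and d2-d3 to [24][6..7] of the
  /// destination memref (plus `offset` rows, when non-zero).
  ///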
  /// \param b: The ImplicitLocOpBuilder used to create the stores.
  /// \param matrixD: Result of the warp-group MMA operation (fragmented
  /// matrix). Each thread holds it as a struct with 64 elements.
  /// \param dstMemref: The memref where the registers will be stored.
  /// \param offset: the offset within the memref where the registers will be
  /// stored.
  void storeFragmentedMatrix(ImplicitLocOpBuilder &b, Value matrixD,
                             TypedValue<MemRefType> dstMemref,
                             int offset) const {
    Type i32 = b.getI32Type();

    auto makeConst = [&](int32_t index) -> Value {
      return b.create<LLVM::ConstantOp>(i32, b.getI32IntegerAttr(index));
    };
    Value c1 = makeConst(1);
    Value c2 = makeConst(2);
    Value c4 = makeConst(4);
    Value c8 = makeConst(8);
    Value c16 = makeConst(16);
    Value warpSize = makeConst(kWarpSize);

    auto makeMul = [&](Value lhs, Value rhs) -> Value {
      return b.create<LLVM::MulOp>(lhs.getType(), lhs, rhs);
    };
    auto makeAdd = [&](Value lhs, Value rhs) -> Value {
      return b.create<LLVM::AddOp>(lhs.getType(), lhs, rhs);
    };

    auto makeExtractAndStore = [&](int i, Value wgmmaResult, Value x, Value y,
                                   TypedValue<::mlir::MemRefType> memref) {
      Type it = b.getIndexType();
      Value idx = b.create<arith::IndexCastOp>(it, x);
      Value idy0 = b.create<arith::IndexCastOp>(it, y);
      Value idy1 = b.create<arith::IndexCastOp>(it, makeAdd(y, c1));
      Value d0 = b.create<LLVM::ExtractValueOp>(wgmmaResult, i);
      Value d1 = b.create<LLVM::ExtractValueOp>(wgmmaResult, i + 1);
      b.create<memref::StoreOp>(d0, memref, ValueRange{idx, idy0});
      b.create<memref::StoreOp>(d1, memref, ValueRange{idx, idy1});
    };

    Value tidx = b.create<NVVM::ThreadIdXOp>(i32);
    Value laneId = b.create<LLVM::URemOp>(i32, tidx, warpSize);
    Value warpId = b.create<LLVM::UDivOp>(i32, tidx, warpSize);
    Value lane4Id = b.create<LLVM::UDivOp>(i32, laneId, c4);
    Value lane4modId = b.create<LLVM::URemOp>(i32, laneId, c4);

    Value tj = makeMul(lane4modId, c2);
    Value ti = makeAdd(lane4Id, makeMul(warpId, c16));
    if (offset)
      ti = makeAdd(ti, makeConst(offset));

    auto structType = matrixD.getType().cast<LLVM::LLVMStructType>();

    // Number of adjacent 32-bit registers owned per thread
    constexpr unsigned numAdjacentRegisters = 2;
    // Number of 8x8 matrices stacked one below another per warp
    constexpr unsigned numStackedMatrices = 2;

    size_t storeCount = (structType.getBody().size() /
                         (numStackedMatrices * numAdjacentRegisters));

    for (size_t i = 0; i < numStackedMatrices; ++i) {
      Value idx = makeAdd(ti, makeMul(makeConst(i), c8));
      for (size_t j = 0; j < storeCount; ++j) {
        Value idy = makeAdd(tj, makeMul(makeConst(j), c8));
        size_t structIndex = (i * numAdjacentRegisters) +
                             (j * (numStackedMatrices * numAdjacentRegisters));
        makeExtractAndStore(structIndex, matrixD, idx, idy, dstMemref);
      }
    }
  }

  LogicalResult
  matchAndRewrite(nvgpu::WarpgroupMmaStoreOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    int offset = 0;
    ImplicitLocOpBuilder b(op->getLoc(), rewriter);
    Value matrixDValue = adaptor.getMatrixD();
    auto stype = matrixDValue.getType().cast<LLVM::LLVMStructType>();
    for (auto [idx, matrixD] : llvm::enumerate(stype.getBody())) {
      auto structType = matrixD.cast<LLVM::LLVMStructType>();
      Value innerStructValue =
          b.create<LLVM::ExtractValueOp>(matrixDValue, idx);
      storeFragmentedMatrix(b, innerStructValue, op.getDstMemref(), offset);
      offset += structType.getBody().size();
    }
    rewriter.eraseOp(op);
    return success();
  }
};

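/// Lowers nvgpu.warpgroup.mma.init.accumulator to a zero-initialized packed
/// LLVM struct. As a rough, illustrative sketch only (the exact shape depends
/// on the type converter and the accumulator type), a 128x128xf32 accumulator
/// becomes a struct of two inner structs of 64 f32 values each, with every
/// element set to 0.0.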
struct NVGPUWarpgroupMmaInitAccumulatorOpLowering
    : public ConvertOpToLLVMPattern<nvgpu::WarpgroupMmaInitAccumulatorOp> {
  using ConvertOpToLLVMPattern<
      nvgpu::WarpgroupMmaInitAccumulatorOp>::ConvertOpToLLVMPattern;
  LogicalResult
  matchAndRewrite(nvgpu::WarpgroupMmaInitAccumulatorOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    ImplicitLocOpBuilder b(op->getLoc(), rewriter);
    LLVM::LLVMStructType packStructType =
        getTypeConverter()
            ->convertType(op.getMatrixC().getType())
            .cast<LLVM::LLVMStructType>();
    Type elemType = packStructType.getBody()
                        .front()
                        .cast<LLVM::LLVMStructType>()
                        .getBody()
                        .front();
    Value zero = b.create<LLVM::ConstantOp>(elemType, b.getZeroAttr(elemType));
    Value packStruct = b.create<LLVM::UndefOp>(packStructType);
    SmallVector<Value> innerStructs;
    // Unpack the structs and set all values to zero
    for (auto [idx, s] : llvm::enumerate(packStructType.getBody())) {
      auto structType = s.cast<LLVM::LLVMStructType>();
      Value structValue = b.create<LLVM::ExtractValueOp>(packStruct, idx);
      for (unsigned i = 0; i < structType.getBody().size(); ++i) {
        structValue = b.create<LLVM::InsertValueOp>(
            structType, structValue, zero, ArrayRef<int64_t>({i}));
      }
      innerStructs.push_back(structValue);
    }
    // Pack the inner structs into a single struct
    for (auto [idx, matrix] : llvm::enumerate(innerStructs)) {
      packStruct = b.create<LLVM::InsertValueOp>(packStruct.getType(),
                                                 packStruct, matrix, idx);
    }
    rewriter.replaceOp(op, packStruct);
    return success();
  }
};

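/// Lowers nvgpu.tma.prefetch.descriptor to NVVM::PrefetchTensorMapOp as a
/// one-to-one replacement; the tensor-map descriptor and the predicate
/// operand are forwarded unchanged.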
struct NVGPUTmaPrefetchOpLowering
    : public ConvertOpToLLVMPattern<nvgpu::TmaPrefetchOp> {
  using ConvertOpToLLVMPattern<nvgpu::TmaPrefetchOp>::ConvertOpToLLVMPattern;
  LogicalResult
  matchAndRewrite(nvgpu::TmaPrefetchOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    rewriter.replaceOpWithNewOp<NVVM::PrefetchTensorMapOp>(
        op, adaptor.getTensorMapDescriptor(), adaptor.getPredicate());
    return success();
  }
};

} // namespace

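// Minimal usage sketch (illustrative only; assumes it runs inside a
// conversion pass's runOnOperation(), as the in-tree ConvertNVGPUToNVVMPass
// does, with `ctx` and `module` standing in for the pass's context and op):
//
//   RewritePatternSet patterns(ctx);
//   LLVMTypeConverter converter(ctx);
//   populateNVGPUToNVVMConversionPatterns(converter, patterns);
//   LLVMConversionTarget target(*ctx);
//   if (failed(applyPartialConversion(module, target, std::move(patterns))))
//     signalPassFailure();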
void mlir::populateNVGPUToNVVMConversionPatterns(LLVMTypeConverter &converter,
                                                 RewritePatternSet &patterns) {
  patterns.add<
      NVGPUMBarrierCreateLowering,           // nvgpu.mbarrier.create
      NVGPUMBarrierInitLowering,             // nvgpu.mbarrier.init
      NVGPUMBarrierArriveLowering,           // nvgpu.mbarrier.arrive
      NVGPUMBarrierArriveNoCompleteLowering, // nvgpu.mbarrier.arrive.no_complete
      NVGPUMBarrierTestWaitLowering,         // nvgpu.mbarrier.test_wait_parity
      NVGPUMBarrierTryWaitParityLowering,    // nvgpu.mbarrier.try_wait_parity
      NVGPUTmaAsyncLoadOpLowering,           // nvgpu.tma.async.load
      NVGPUTmaAsyncStoreOpLowering,          // nvgpu.tma.async.store
      NVGPUTmaCreateDescriptorOpLowering,    // nvgpu.tma.create.descriptor
      NVGPUTmaPrefetchOpLowering,            // nvgpu.tma.prefetch.descriptor
      NVGPUMBarrierArriveExpectTxLowering,   // nvgpu.mbarrier.arrive.expect_tx
      NVGPUGenerateWarpgroupDescriptorLowering, // nvgpu.warpgroup.generate.descriptor
      NVGPUWarpgroupMmaOpLowering,              // nvgpu.warpgroup.mma
      NVGPUWarpgroupMmaStoreOpLowering,         // nvgpu.warpgroup.mma.store
      NVGPUWarpgroupMmaInitAccumulatorOpLowering, // nvgpu.warpgroup.mma.init.accumulator
      MmaSyncOptoNVVM, MmaLdMatrixOpToNVVM, NVGPUAsyncCopyLowering,
      NVGPUAsyncCreateGroupLowering, NVGPUAsyncWaitLowering,
      NVGPUMmaSparseSyncLowering>(converter);
}