//===- NVGPUToNVVM.cpp - NVGPU to NVVM dialect conversion -----------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include "mlir/Conversion/NVGPUToNVVM/NVGPUToNVVM.h"

#include "mlir/Conversion/LLVMCommon/ConversionTarget.h"
#include "mlir/Conversion/LLVMCommon/Pattern.h"
#include "mlir/Dialect/GPU/IR/GPUDialect.h"
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
#include "mlir/Dialect/LLVMIR/NVVMDialect.h"
#include "mlir/Dialect/NVGPU/IR/NVGPUDialect.h"
#include "mlir/IR/TypeUtilities.h"
#include "mlir/Pass/Pass.h"

namespace mlir {
#define GEN_PASS_DEF_CONVERTNVGPUTONVVM
#include "mlir/Conversion/Passes.h.inc"
} // namespace mlir

using namespace mlir;

/// Returns the type for the intrinsic given the vectorResultType of the
/// `gpu.mma.sync` operation.
static Type inferIntrinsicResultType(Type vectorResultType) {
  MLIRContext *ctx = vectorResultType.getContext();
  auto a = vectorResultType.cast<LLVM::LLVMArrayType>();
  auto f16x2Ty = LLVM::getFixedVectorType(Float16Type::get(ctx), 2);
  auto i32Ty = IntegerType::get(ctx, 32);
  auto i32x2Ty = LLVM::getFixedVectorType(i32Ty, 2);
  Type f64Ty = Float64Type::get(ctx);
  Type f64x2Ty = LLVM::getFixedVectorType(f64Ty, 2);
  Type f32Ty = Float32Type::get(ctx);
  Type f32x2Ty = LLVM::getFixedVectorType(f32Ty, 2);
  if (a.getElementType() == f16x2Ty) {
    return LLVM::LLVMStructType::getLiteral(
        ctx, SmallVector<Type>(a.getNumElements(), f16x2Ty));
  }
  if (a.getElementType() == i32x2Ty) {
    return LLVM::LLVMStructType::getLiteral(
        ctx,
        SmallVector<Type>(static_cast<size_t>(a.getNumElements()) * 2, i32Ty));
  }
  if (a.getElementType() == f64x2Ty) {
    return LLVM::LLVMStructType::getLiteral(ctx, {f64Ty, f64Ty});
  }
  if (a.getElementType() == f32x2Ty) {
    return LLVM::LLVMStructType::getLiteral(
        ctx,
        SmallVector<Type>(static_cast<size_t>(a.getNumElements()) * 2, f32Ty));
  }
  if (a.getElementType() == LLVM::getFixedVectorType(f32Ty, 1)) {
    return LLVM::LLVMStructType::getLiteral(
        ctx, SmallVector<Type>(static_cast<size_t>(a.getNumElements()), f32Ty));
  }
  return vectorResultType;
}

/// Convert the SSA result of the NVVM intrinsic `nvvm.mma.sync` (which is
/// always an LLVM struct) into a fragment that is compatible with the vector
/// type of this operation. This involves extracting elements from the struct
/// and inserting them into an LLVM array. These extra data-movement
/// operations should be canonicalized away by the LLVM backend.
static Value convertIntrinsicResult(Location loc, Type intrinsicResultType,
                                    Type resultType, Value intrinsicResult,
                                    RewriterBase &rewriter) {
  MLIRContext *ctx = rewriter.getContext();
  auto structType = intrinsicResultType.dyn_cast<LLVM::LLVMStructType>();
  auto arrayType = resultType.dyn_cast<LLVM::LLVMArrayType>();
  Type i32Ty = rewriter.getI32Type();
  Type f32Ty = rewriter.getF32Type();
  Type f64Ty = rewriter.getF64Type();
  Type f16x2Ty = LLVM::getFixedVectorType(rewriter.getF16Type(), 2);
  Type i32x2Ty = LLVM::getFixedVectorType(i32Ty, 2);
  Type f64x2Ty = LLVM::getFixedVectorType(f64Ty, 2);
  Type f32x2Ty = LLVM::getFixedVectorType(f32Ty, 2);
  Type f32x1Ty = LLVM::getFixedVectorType(f32Ty, 1);

  auto makeConst = [&](int32_t index) -> Value {
    return rewriter.create<LLVM::ConstantOp>(loc, IntegerType::get(ctx, 32),
                                             rewriter.getI32IntegerAttr(index));
  };

  if (arrayType) {
    SmallVector<Value> elements;

    // The intrinsic returns 32-bit wide elements in a form which can be
    // directly bitcasted and inserted into the result vector.
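    // For example (illustrative only): a struct result of type
    // !llvm.struct<(vector<2xf16>, vector<2xf16>)> is repacked into an
    // !llvm.array<2 x vector<2xf16>> value.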
    if (arrayType.getElementType() == f16x2Ty ||
        arrayType.getElementType() == f32x1Ty) {
      for (unsigned i = 0; i < structType.getBody().size(); i++) {
        Value el =
            rewriter.create<LLVM::ExtractValueOp>(loc, intrinsicResult, i);
        el = rewriter.createOrFold<LLVM::BitcastOp>(
            loc, arrayType.getElementType(), el);
        elements.push_back(el);
      }
    }

    // The intrinsic returns i32, f64, and f32 values as individual scalars,
    // even when the result is notionally a 64-bit wide element (e.g. f32x2).
    // We need to extract them from the struct and pack them into the 64-bit
    // wide rows of the vector result.
    if (arrayType.getElementType() == i32x2Ty ||
        arrayType.getElementType() == f64x2Ty ||
        arrayType.getElementType() == f32x2Ty) {
      for (unsigned i = 0, e = structType.getBody().size() / 2; i < e; i++) {
        Value vec =
            rewriter.create<LLVM::UndefOp>(loc, arrayType.getElementType());
        Value x1 =
            rewriter.create<LLVM::ExtractValueOp>(loc, intrinsicResult, i * 2);
        Value x2 = rewriter.create<LLVM::ExtractValueOp>(loc, intrinsicResult,
                                                         i * 2 + 1);
        vec = rewriter.create<LLVM::InsertElementOp>(loc, vec.getType(), vec,
                                                     x1, makeConst(0));
        vec = rewriter.create<LLVM::InsertElementOp>(loc, vec.getType(), vec,
                                                     x2, makeConst(1));
        elements.push_back(vec);
      }
    }

    // Create the final vectorized result.
    Value result = rewriter.create<LLVM::UndefOp>(loc, arrayType);
    for (const auto &el : llvm::enumerate(elements)) {
      result = rewriter.create<LLVM::InsertValueOp>(loc, result, el.value(),
                                                    el.index());
    }
    return result;
  }

  return intrinsicResult;
}

/// The `gpu.mma.sync` converter below expects matrix fragment operands to be
/// given as 2D `vectors` where the rows are 32b or 64b wide. The
/// `nvvm.mma.sync` op expects these arguments to be given in a long list of
/// scalars of certain types. This function helps unpack the `vector`
/// arguments and cast them to the types expected by `nvvm.mma.sync`.
static SmallVector<Value> unpackOperandVector(RewriterBase &rewriter,
                                              Location loc, Value operand,
                                              NVVM::MMATypes operandPtxType) {
  SmallVector<Value> result;
  Type i32Ty = rewriter.getI32Type();
  Type f64Ty = rewriter.getF64Type();
  Type f32Ty = rewriter.getF32Type();
  Type i8Ty = rewriter.getI8Type();
  Type i4Ty = rewriter.getIntegerType(4);
  Type i8x4Ty = LLVM::getFixedVectorType(i8Ty, 4);
  Type i4x8Ty = LLVM::getFixedVectorType(i4Ty, 8);
  Type f32x1Ty = LLVM::getFixedVectorType(f32Ty, 1);
  auto arrayTy = operand.getType().cast<LLVM::LLVMArrayType>();

  for (unsigned i = 0, e = arrayTy.getNumElements(); i < e; ++i) {
    Value toUse = rewriter.create<LLVM::ExtractValueOp>(loc, operand, i);

    // For 4xi8 vectors, the intrinsic expects these to be provided as i32
    // scalar types.
    if (arrayTy.getElementType() == i8x4Ty ||
        arrayTy.getElementType() == i4x8Ty ||
        (arrayTy.getElementType() == f32x1Ty &&
         operandPtxType == NVVM::MMATypes::tf32)) {
      result.push_back(
          rewriter.create<LLVM::BitcastOp>(loc, rewriter.getI32Type(), toUse));
      continue;
    }

    // For some element types (i32, f32, f64), we need to unpack the inner
    // vector/array type as well because the intrinsic expects individual
    // scalars to be provided.
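    // For example (illustrative only): an accumulator operand of type
    // !llvm.array<2 x vector<2xf32>> is flattened into four individual f32
    // values.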
    VectorType innerArrayTy = arrayTy.getElementType().dyn_cast<VectorType>();
    if (innerArrayTy && (innerArrayTy.getElementType() == i32Ty ||
                         innerArrayTy.getElementType() == f64Ty ||
                         innerArrayTy.getElementType() == f32Ty)) {
      for (unsigned idx = 0, innerSize = innerArrayTy.getNumElements();
           idx < innerSize; idx++) {
        result.push_back(rewriter.create<LLVM::ExtractElementOp>(
            loc, toUse,
            rewriter.create<LLVM::ConstantOp>(
                loc, rewriter.getI64Type(), rewriter.getI64IntegerAttr(idx))));
      }
      continue;
    }
    result.push_back(toUse);
  }
  return result;
}

namespace {

struct MmaLdMatrixOpToNVVM : public ConvertOpToLLVMPattern<nvgpu::LdMatrixOp> {
  using ConvertOpToLLVMPattern<nvgpu::LdMatrixOp>::ConvertOpToLLVMPattern;

  LogicalResult
  matchAndRewrite(nvgpu::LdMatrixOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    MLIRContext *ctx = getContext();
    Location loc = op->getLoc();

    // The result type of ldmatrix will always be a struct of 32bit integer
    // registers if more than one 32bit value is returned. Otherwise, the
    // result is a single i32. The result type of the GPU operation is always
    // a vector of shape (NumRegisters, VectorRegister) where VectorRegister
    // is the vector type of the result and always 32 bits long. We bitcast
    // the result of the NVVM::LdMatrix to this vector type.
    auto vectorResultType = op->getResultTypes()[0].dyn_cast<VectorType>();
    if (!vectorResultType) {
      return failure();
    }
    Type innerVectorType = LLVM::getFixedVectorType(
        vectorResultType.getElementType(), vectorResultType.getDimSize(1));

    int64_t num32BitRegs = vectorResultType.getDimSize(0);

    Type ldMatrixResultType;
    if (num32BitRegs > 1) {
      ldMatrixResultType = LLVM::LLVMStructType::getLiteral(
          ctx, SmallVector<Type>(num32BitRegs, rewriter.getI32Type()));
    } else {
      ldMatrixResultType = rewriter.getI32Type();
    }

    auto srcMemrefType = op.getSrcMemref().getType().cast<MemRefType>();
    Value srcPtr =
        getStridedElementPtr(loc, srcMemrefType, adaptor.getSrcMemref(),
                             adaptor.getIndices(), rewriter);
    Value ldMatrixResult = rewriter.create<NVVM::LdMatrixOp>(
        loc, ldMatrixResultType, srcPtr,
        /*num=*/op.getNumTiles(),
        /*layout=*/op.getTranspose() ? NVVM::MMALayout::col
                                     : NVVM::MMALayout::row);

    // The ldmatrix operation returns either a single i32 value or a struct of
    // i32 values. Here we unpack those values and cast them back to their
    // actual vector type (still of width 32b) and repack them into a result
    // struct.
    Type finalResultType = typeConverter->convertType(vectorResultType);
    Value result = rewriter.create<LLVM::UndefOp>(loc, finalResultType);
    for (int64_t i = 0, e = vectorResultType.getDimSize(0); i < e; i++) {
      Value i32Register =
          num32BitRegs > 1
              ? rewriter.create<LLVM::ExtractValueOp>(loc, ldMatrixResult, i)
              : ldMatrixResult;
      Value casted =
          rewriter.create<LLVM::BitcastOp>(loc, innerVectorType, i32Register);
      result = rewriter.create<LLVM::InsertValueOp>(loc, result, casted, i);
    }

    rewriter.replaceOp(op, result);
    return success();
  }
};

/// Convert the given type into the corresponding PTX type (NVVM::MMATypes
/// enum).
static FailureOr<NVVM::MMATypes> getNvvmMmaType(Type t) {
  Type elType = getElementTypeOrSelf(t);
  if (elType.isInteger(8))
    return NVVM::MMATypes::s8;
  if (elType.isInteger(4))
    return NVVM::MMATypes::s4;
  if (elType.isF16())
    return NVVM::MMATypes::f16;
  if (elType.isF64())
    return NVVM::MMATypes::f64;
  if (elType.isF32())
    return NVVM::MMATypes::tf32;
  return failure();
}

struct MmaSyncOptoNVVM : public ConvertOpToLLVMPattern<nvgpu::MmaSyncOp> {
  using ConvertOpToLLVMPattern<nvgpu::MmaSyncOp>::ConvertOpToLLVMPattern;

  LogicalResult
  matchAndRewrite(nvgpu::MmaSyncOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    Location loc = op->getLoc();
    // Get the shapes of the MMAMatrix type being used. The shapes will
    // choose which intrinsic this op will be lowered to.
    VectorType aType = op.getMatrixA().getType();
    VectorType bType = op.getMatrixB().getType();
    VectorType cType = op.getMatrixC().getType();
    std::array<int64_t, 3> gemmShape = op.getMmaShapeAsArray();

    // Tensor Cores (mma.sync) on F32 work only with TensorFloat32 (TF32).
    bool tf32Enabled = op->hasAttr(op.getTf32EnabledAttrName());
    if (aType.getElementType().isF32() && !tf32Enabled)
      return failure();

    FailureOr<NVVM::MMATypes> ptxTypeA = getNvvmMmaType(aType);
    if (failed(ptxTypeA))
      return op->emitOpError("failed to deduce operand PTX types");
    FailureOr<NVVM::MMATypes> ptxTypeB = getNvvmMmaType(bType);
    if (failed(ptxTypeB))
      return op->emitOpError("failed to deduce operand PTX types");
    std::optional<NVVM::MMATypes> ptxTypeC = NVVM::MmaOp::inferOperandMMAType(
        cType.getElementType(), /*isAccumulator=*/true);
    if (!ptxTypeC)
      return op->emitError(
          "could not infer the PTX type for the accumulator/result");

    // TODO: add an attribute to the op to customize this behavior.
    std::optional<NVVM::MMAIntOverflow> overflow(std::nullopt);
    if (aType.getElementType().isa<IntegerType>())
      overflow = NVVM::MMAIntOverflow::satfinite;

    SmallVector<Value> matA =
        unpackOperandVector(rewriter, loc, adaptor.getMatrixA(), *ptxTypeA);
    SmallVector<Value> matB =
        unpackOperandVector(rewriter, loc, adaptor.getMatrixB(), *ptxTypeB);
    SmallVector<Value> matC =
        unpackOperandVector(rewriter, loc, adaptor.getMatrixC(), *ptxTypeC);

    Type desiredRetTy = typeConverter->convertType(op->getResultTypes()[0]);
    Type intrinsicResTy = inferIntrinsicResultType(
        typeConverter->convertType(op->getResultTypes()[0]));
    Value intrinsicResult = rewriter.create<NVVM::MmaOp>(
        op.getLoc(), intrinsicResTy, matA, matB, matC,
        /*shape=*/gemmShape,
        /*b1Op=*/std::nullopt,
        /*intOverflow=*/overflow,
        /*multiplicandPtxTypes=*/
        std::array<NVVM::MMATypes, 2>{*ptxTypeA, *ptxTypeB},
        /*multiplicandLayouts=*/
        std::array<NVVM::MMALayout, 2>{NVVM::MMALayout::row,
                                       NVVM::MMALayout::col});
    rewriter.replaceOp(op, convertIntrinsicResult(op.getLoc(), intrinsicResTy,
                                                  desiredRetTy, intrinsicResult,
                                                  rewriter));
    return success();
  }
};

struct ConvertNVGPUToNVVMPass
    : public impl::ConvertNVGPUToNVVMBase<ConvertNVGPUToNVVMPass> {
  ConvertNVGPUToNVVMPass() = default;

  void runOnOperation() override {
    RewritePatternSet patterns(&getContext());
    LLVMTypeConverter converter(&getContext());
    /// device-side async tokens cannot be materialized in nvvm. We just
    /// convert them to a dummy i32 type in order to easily drop them during
    /// conversion.
    converter.addConversion([&](nvgpu::DeviceAsyncTokenType type) -> Type {
      return converter.convertType(IntegerType::get(type.getContext(), 32));
    });
    populateNVGPUToNVVMConversionPatterns(converter, patterns);
    LLVMConversionTarget target(getContext());
    target.addLegalDialect<::mlir::LLVM::LLVMDialect>();
    target.addLegalDialect<::mlir::NVVM::NVVMDialect>();
    if (failed(applyPartialConversion(getOperation(), target,
                                      std::move(patterns))))
      signalPassFailure();
  }
};

static void emitCpAsyncOpZfillAsm(Location loc, Value dstPtr, Value srcPtr,
                                  Value dstBytes, Value srcElements,
                                  mlir::MemRefType elementType,
                                  ConversionPatternRewriter &rewriter) {
  auto asmDialectAttr = LLVM::AsmDialectAttr::get(rewriter.getContext(),
                                                  LLVM::AsmDialect::AD_ATT);
  const char *asmStr = "cp.async.cg.shared.global [$0], [$1], $2, $3;\n";
  const char *asmConstraints = "r,l,n,r";

  Value c3I32 = rewriter.create<LLVM::ConstantOp>(
      loc, rewriter.getI32Type(), rewriter.getI32IntegerAttr(3));
  Value bitwidth = rewriter.create<LLVM::ConstantOp>(
      loc, rewriter.getI32Type(),
      rewriter.getI32IntegerAttr(elementType.getElementTypeBitWidth()));
  Value srcElementsI32 =
      rewriter.create<LLVM::TruncOp>(loc, rewriter.getI32Type(), srcElements);
  Value srcBytes = rewriter.create<LLVM::LShrOp>(
      loc, rewriter.create<LLVM::MulOp>(loc, bitwidth, srcElementsI32), c3I32);

  SmallVector<Value> asmVals{dstPtr, srcPtr, dstBytes, srcBytes};

  rewriter.create<LLVM::InlineAsmOp>(
      loc, LLVM::LLVMVoidType::get(rewriter.getContext()),
      /*operands=*/asmVals,
      /*asm_string=*/asmStr,
      /*constraints=*/asmConstraints,
      /*has_side_effects=*/true,
      /*is_align_stack=*/false,
      /*asm_dialect=*/asmDialectAttr,
      /*operand_attrs=*/ArrayAttr());
}

/// Returns the constraints for the sparse MMA inline assembly instruction.
static std::string buildMmaSparseAsmConstraintString(unsigned matASize,
                                                     unsigned matBSize,
                                                     unsigned matCSize) {
  std::string str;
  llvm::raw_string_ostream ss(str);
  for (unsigned i = 0; i < matCSize; i++)
    ss << "=r,";
  for (unsigned i = 0; i < matASize + matBSize + matCSize; i++)
    ss << "r,";
  // The final two operands are for the sparsity metadata and sparsity
  // selector.
  ss << "r,r";
  ss.flush();
  return str;
}

/// Returns the string for the `mma.sp.sync` instruction that corresponds to
/// the given parameters. Note that this function doesn't do any validation,
/// it's expected that the provided parameters correspond to a valid
/// instruction.
static std::string buildMmaSparseAsmString(
    const std::array<int64_t, 3> &shape, unsigned matASize, unsigned matBSize,
    unsigned matCSize, NVVM::MMATypes ptxTypeA, NVVM::MMATypes ptxTypeB,
    NVVM::MMATypes ptxTypeC, NVVM::MMATypes ptxTypeD,
    std::optional<NVVM::MMAIntOverflow> overflow) {
  auto ptxTypeStr = [](NVVM::MMATypes ptxType) {
    return NVVM::stringifyMMATypes(ptxType);
  };

  std::string asmStr;
  llvm::raw_string_ostream ss(asmStr);
  ss << "mma.sp.sync.aligned.m" << shape[0] << "n" << shape[1] << "k"
     << shape[2] << ".row.col.";

  if (overflow)
    ss << NVVM::stringifyMMAIntOverflow(*overflow) << ".";

  ss << ptxTypeStr(ptxTypeD) << "." << ptxTypeStr(ptxTypeA) << "."
     << ptxTypeStr(ptxTypeB) << "." << ptxTypeStr(ptxTypeC) << " ";
  unsigned asmArgIdx = 0;

  // The operand string is structured into sections `{matC elements...},
  // {matA elements...}, {matB elements...}, {matC elements}`.
  for (const auto arrSize : {matCSize, matASize, matBSize, matCSize}) {
    ss << "{";
    for (unsigned i = 0; i < arrSize; i++)
      ss << "$" << asmArgIdx++ << (i < arrSize - 1 ? "," : "");
    ss << "},";
  }
  ss << "$" << asmArgIdx++ << ",";
  ss << "$" << asmArgIdx++ << ";";
  ss.flush();
  return asmStr;
}

/// Builds an inline assembly operation corresponding to the specified MMA
/// sparse sync operation.
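/// For illustration only: assuming a hypothetical f16 m16n8k16 configuration
/// with two registers per fragment, the assembly string built by the helpers
/// above would resemble
///   mma.sp.sync.aligned.m16n8k16.row.col.f16.f16.f16.f16
///     {$0,$1},{$2,$3},{$4,$5},{$6,$7},$8,$9;
/// where the last two operands are the sparsity metadata and the sparsity
/// selector.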
static FailureOr<LLVM::InlineAsmOp> emitMmaSparseSyncOpAsm(
    Location loc, NVVM::MMATypes ptxTypeA, NVVM::MMATypes ptxTypeB,
    NVVM::MMATypes ptxTypeC, NVVM::MMATypes ptxTypeD,
    std::optional<NVVM::MMAIntOverflow> overflow, ArrayRef<Value> unpackedAData,
    ArrayRef<Value> unpackedB, ArrayRef<Value> unpackedC, Value indexData,
    int64_t metadataSelector, const std::array<int64_t, 3> &shape,
    Type intrinsicResultType, ConversionPatternRewriter &rewriter) {
  auto asmDialectAttr = LLVM::AsmDialectAttr::get(rewriter.getContext(),
                                                  LLVM::AsmDialect::AD_ATT);

  std::string asmStr = buildMmaSparseAsmString(
      shape, unpackedAData.size(), unpackedB.size(), unpackedC.size(), ptxTypeA,
      ptxTypeB, ptxTypeC, ptxTypeD, overflow);
  std::string constraintStr = buildMmaSparseAsmConstraintString(
      unpackedAData.size(), unpackedB.size(), unpackedC.size());

  Value selectorVal = rewriter.create<LLVM::ConstantOp>(
      loc, rewriter.getI32Type(), rewriter.getI32IntegerAttr(metadataSelector));

  SmallVector<Value> asmVals;
  asmVals.reserve(unpackedAData.size() + unpackedB.size() + unpackedC.size() +
                  2);
  for (ArrayRef<Value> args : {unpackedAData, unpackedB, unpackedC})
    llvm::append_range(asmVals, args);
  asmVals.push_back(indexData);
  asmVals.push_back(selectorVal);

  return rewriter.create<LLVM::InlineAsmOp>(
      loc,
      /*resultTypes=*/intrinsicResultType,
      /*operands=*/asmVals,
      /*asm_string=*/asmStr,
      /*constraints=*/constraintStr,
      /*has_side_effects=*/true,
      /*is_align_stack=*/false,
      /*asm_dialect=*/asmDialectAttr,
      /*operand_attrs=*/ArrayAttr());
}

/// Lowers `nvgpu.mma.sp.sync` to inline assembly.
struct NVGPUMmaSparseSyncLowering
    : public ConvertOpToLLVMPattern<nvgpu::MmaSparseSyncOp> {
  using ConvertOpToLLVMPattern<nvgpu::MmaSparseSyncOp>::ConvertOpToLLVMPattern;

  LogicalResult
  matchAndRewrite(nvgpu::MmaSparseSyncOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    Location loc = op->getLoc();
    // Get the shapes of the MMAMatrix type being used. The shapes will
    // choose which intrinsic this op will be lowered to.
    VectorType aType = op.getMatrixA().getType();
    VectorType bType = op.getMatrixB().getType();
    VectorType cType = op.getMatrixC().getType();

    FailureOr<NVVM::MMATypes> ptxTypeA = getNvvmMmaType(aType);
    if (failed(ptxTypeA))
      return op->emitOpError("failed to deduce operand PTX types");
    FailureOr<NVVM::MMATypes> ptxTypeB = getNvvmMmaType(bType);
    if (failed(ptxTypeB))
      return op->emitOpError("failed to deduce operand PTX types");
    std::optional<NVVM::MMATypes> ptxTypeC = NVVM::MmaOp::inferOperandMMAType(
        cType.getElementType(), /*isAccumulator=*/true);
    if (!ptxTypeC)
      return op->emitError(
          "could not infer the PTX type for the accumulator/result");

    // Same as `mma.sync`, F32 works only with TensorFloat32 (TF32).
    bool tf32Enabled = op->hasAttr(op.getTf32EnabledAttrName());
    if (aType.getElementType().isF32() && !tf32Enabled)
      return failure();

    // TODO: add an attribute to the op to customize this behavior.
    std::optional<NVVM::MMAIntOverflow> overflow(std::nullopt);
    if (aType.getElementType().isa<IntegerType>())
      overflow = NVVM::MMAIntOverflow::satfinite;

    SmallVector<Value> matA =
        unpackOperandVector(rewriter, loc, adaptor.getMatrixA(), *ptxTypeA);
    SmallVector<Value> matB =
        unpackOperandVector(rewriter, loc, adaptor.getMatrixB(), *ptxTypeB);
    SmallVector<Value> matC =
        unpackOperandVector(rewriter, loc, adaptor.getMatrixC(), *ptxTypeC);

    Type desiredRetTy = typeConverter->convertType(op->getResultTypes()[0]);
    Type intrinsicResTy = inferIntrinsicResultType(
        typeConverter->convertType(op->getResultTypes()[0]));

    // Bitcast the sparse metadata from vector<2xi16> to an i32.
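    // (The matching "r" constraint emitted for the metadata operand expects a
    // single 32-bit register.)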
    Value sparseMetadata = adaptor.getSparseMetadata();
    if (sparseMetadata.getType() !=
        LLVM::getFixedVectorType(rewriter.getI16Type(), 2))
      return op->emitOpError() << "Expected metadata type to be LLVM "
                                  "VectorType of 2 i16 elements";
    sparseMetadata = rewriter.create<LLVM::BitcastOp>(
        loc, rewriter.getI32Type(), sparseMetadata);

    FailureOr<LLVM::InlineAsmOp> intrinsicResult = emitMmaSparseSyncOpAsm(
        loc, *ptxTypeA, *ptxTypeB, *ptxTypeC, *ptxTypeC, overflow, matA, matB,
        matC, sparseMetadata, op.getSparsitySelector(), op.getMmaShapeAsArray(),
        intrinsicResTy, rewriter);
    if (failed(intrinsicResult))
      return failure();

    assert((*intrinsicResult).getNumResults() == 1 &&
           "expected inline asm op returns a single LLVM struct type");
    rewriter.replaceOp(
        op, convertIntrinsicResult(op.getLoc(), intrinsicResTy, desiredRetTy,
                                   (*intrinsicResult)->getResult(0), rewriter));
    return success();
  }
};

struct NVGPUAsyncCopyLowering
    : public ConvertOpToLLVMPattern<nvgpu::DeviceAsyncCopyOp> {
  using ConvertOpToLLVMPattern<
      nvgpu::DeviceAsyncCopyOp>::ConvertOpToLLVMPattern;

  LogicalResult
  matchAndRewrite(nvgpu::DeviceAsyncCopyOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    Location loc = op->getLoc();
    auto dstMemrefType = op.getDst().getType().cast<MemRefType>();
    Value dstPtr = getStridedElementPtr(loc, dstMemrefType, adaptor.getDst(),
                                        adaptor.getDstIndices(), rewriter);
    auto i8Ty = IntegerType::get(op.getContext(), 8);
    auto dstPointerType =
        LLVM::LLVMPointerType::get(i8Ty, dstMemrefType.getMemorySpaceAsInt());
    dstPtr = rewriter.create<LLVM::BitcastOp>(loc, dstPointerType, dstPtr);

    auto srcMemrefType = op.getSrc().getType().cast<MemRefType>();
    Value srcPtr = getStridedElementPtr(loc, srcMemrefType, adaptor.getSrc(),
                                        adaptor.getSrcIndices(), rewriter);
    auto srcPointerType =
        LLVM::LLVMPointerType::get(i8Ty, srcMemrefType.getMemorySpaceAsInt());
    srcPtr = rewriter.create<LLVM::BitcastOp>(loc, srcPointerType, srcPtr);
    // The intrinsic takes a global pointer, so we need an address space cast.
    auto srcPointerGlobalType = LLVM::LLVMPointerType::get(
        i8Ty, NVVM::NVVMMemorySpace::kGlobalMemorySpace);
    srcPtr = rewriter.create<LLVM::AddrSpaceCastOp>(loc, srcPointerGlobalType,
                                                    srcPtr);
    int64_t dstElements = adaptor.getDstElements().getZExtValue();
    int64_t sizeInBytes =
        (dstMemrefType.getElementTypeBitWidth() * dstElements) / 8;
    // Bypassing L1 is only supported for byte sizes of 16; we drop the hint
    // otherwise.
    UnitAttr bypassL1 =
        sizeInBytes == 16 ? adaptor.getBypassL1Attr() : UnitAttr();

    // When the optional SrcElements argument is present, the source (global
    // memory) of CpAsyncOp is read only for SrcElements number of elements.
    // The rest of the DstElements in the destination (shared memory) are
    // filled with zeros.
    if (op.getSrcElements())
      emitCpAsyncOpZfillAsm(loc, dstPtr, srcPtr,
                            rewriter.create<LLVM::ConstantOp>(
                                loc, rewriter.getI32Type(),
                                rewriter.getI32IntegerAttr(sizeInBytes)),
                            adaptor.getSrcElements(), srcMemrefType, rewriter);

    // When the optional SrcElements argument is *not* present, the regular
    // CpAsyncOp is generated. CopyAsyncOp reads bytes from source (global
    // memory) to fill DstElements number of elements in the destination
    // (shared memory).
    else
      rewriter.create<NVVM::CpAsyncOp>(loc, dstPtr, srcPtr,
                                       rewriter.getI32IntegerAttr(sizeInBytes),
                                       bypassL1);

    // Drop the result token.
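    // The async token type was converted to a dummy i32 by the type
    // converter, so a constant zero stands in for it.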
    Value zero = rewriter.create<LLVM::ConstantOp>(
        op->getLoc(), IntegerType::get(op.getContext(), 32),
        rewriter.getI32IntegerAttr(0));
    rewriter.replaceOp(op, zero);
    return success();
  }
};

struct NVGPUAsyncCreateGroupLowering
    : public ConvertOpToLLVMPattern<nvgpu::DeviceAsyncCreateGroupOp> {
  using ConvertOpToLLVMPattern<
      nvgpu::DeviceAsyncCreateGroupOp>::ConvertOpToLLVMPattern;

  LogicalResult
  matchAndRewrite(nvgpu::DeviceAsyncCreateGroupOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    rewriter.create<NVVM::CpAsyncCommitGroupOp>(op.getLoc());
    // Drop the result token.
    Value zero = rewriter.create<LLVM::ConstantOp>(
        op->getLoc(), IntegerType::get(op.getContext(), 32),
        rewriter.getI32IntegerAttr(0));
    rewriter.replaceOp(op, zero);
    return success();
  }
};

struct NVGPUAsyncWaitLowering
    : public ConvertOpToLLVMPattern<nvgpu::DeviceAsyncWaitOp> {
  using ConvertOpToLLVMPattern<
      nvgpu::DeviceAsyncWaitOp>::ConvertOpToLLVMPattern;

  LogicalResult
  matchAndRewrite(nvgpu::DeviceAsyncWaitOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    // If numGroups is not present, pick 0 as a conservative correct value.
    int32_t numGroups = adaptor.getNumGroups().value_or(0);
    rewriter.create<NVVM::CpAsyncWaitGroupOp>(op.getLoc(), numGroups);
    rewriter.eraseOp(op);
    return success();
  }
};

} // namespace

void mlir::populateNVGPUToNVVMConversionPatterns(LLVMTypeConverter &converter,
                                                 RewritePatternSet &patterns) {
  patterns.add<MmaSyncOptoNVVM, MmaLdMatrixOpToNVVM, NVGPUAsyncCopyLowering,
               NVGPUAsyncCreateGroupLowering, NVGPUAsyncWaitLowering,
               NVGPUMmaSparseSyncLowering>(converter);
}

std::unique_ptr<Pass> mlir::createConvertNVGPUToNVVMPass() {
  return std::make_unique<ConvertNVGPUToNVVMPass>();
}