//===------ WmmaOpsToNVVM.cpp - WMMA LD/ST/Compute to NVVM lowering -------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file contains definitions of patterns to lower GPU Subgroup MMA ops to
// NVVM Dialect.
//
//===----------------------------------------------------------------------===//

#include "mlir/Conversion/LLVMCommon/Pattern.h"
#include "mlir/Dialect/GPU/GPUDialect.h"
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
#include "mlir/Dialect/LLVMIR/NVVMDialect.h"

using namespace mlir;

namespace {

/// Checks if all the operands of the op being lowered are of LLVM types. The
/// types are expected to be converted by the `LLVMTypeConverter` before the op
/// is actually lowered. If the type of an operand is not already converted,
/// that hints at a missing type conversion and a match failure is returned.
static LogicalResult areAllLLVMTypes(Operation *op, ValueRange operands,
                                     ConversionPatternRewriter &rewriter) {
  if (!llvm::all_of(operands, [](Value value) {
        return LLVM::isCompatibleType(value.getType());
      })) {
    return rewriter.notifyMatchFailure(
        op, "cannot convert if operands aren't of LLVM type.");
  }
  return success();
}

/// Error string to emit when an unimplemented WMMA variant is encountered.
static constexpr StringRef kInvalidCaseStr =
    "Unimplemented WMMA variant, only M16N16K16 version implemented.";

/// Return the LLVMStructType corresponding to the MMAMatrixType `type`.
static LLVM::LLVMStructType convertMMAToLLVMType(gpu::MMAMatrixType type) {
  StringRef operandStr = type.getOperand();
  assert(type.getElementType().isa<FloatType>());
  Type baseType = type.getElementType().isF16()
                      ? VectorType::get(2, type.getElementType())
                      : type.getElementType();
  auto getLLVMType = [&](int64_t numElements) {
    return LLVM::LLVMStructType::getLiteral(
        type.getContext(), SmallVector<Type, 8>(numElements, baseType));
  };
  if (operandStr.equals("AOp") || operandStr.equals("BOp"))
    return getLLVMType(8);
  if (type.getElementType().isF16())
    return getLLVMType(4);
  return getLLVMType(8);
}
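
// For example, with f16 elements an "AOp"/"BOp" matrix converts to a struct
// of 8 x vector<2xf16>, an f16 "COp" accumulator to a struct of
// 4 x vector<2xf16>, and an f32 "COp" accumulator to a struct of 8 x f32.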

/// This class implements the conversion of GPU MMA loadOp to wmma.load op
/// in the NVVM dialect. Besides the NVVM op itself, the conversion also emits
/// the code needed to compute the linearized load address in the source
/// memref.
struct WmmaLoadOpToNVVMLowering
    : public ConvertOpToLLVMPattern<gpu::SubgroupMmaLoadMatrixOp> {
  using ConvertOpToLLVMPattern<
      gpu::SubgroupMmaLoadMatrixOp>::ConvertOpToLLVMPattern;

  LogicalResult
  matchAndRewrite(gpu::SubgroupMmaLoadMatrixOp subgroupMmaLoadMatrixOp,
                  ArrayRef<Value> operands,
                  ConversionPatternRewriter &rewriter) const override {
    Operation *op = subgroupMmaLoadMatrixOp.getOperation();
    if (failed(areAllLLVMTypes(op, operands, rewriter)))
      return failure();

    unsigned indexTypeBitwidth =
        this->getTypeConverter()->getIndexTypeBitwidth();
    // The corresponding intrinsic expects leadDimension to be a 32-bit
    // integer, so all the calculations of linearizing the load address
    // must also follow this restriction.
    if (indexTypeBitwidth != 32)
      return rewriter.notifyMatchFailure(
          op, "expected indices to the memref to be 32-bit wide.");
    Location loc = op->getLoc();

    auto leadDimension = subgroupMmaLoadMatrixOp.leadDimensionAttr();

    gpu::SubgroupMmaLoadMatrixOpAdaptor adaptor(operands);
    // MemRefDescriptor to extract alignedPtr and offset.
    MemRefDescriptor promotedSrcOp(adaptor.srcMemref());

    // Emit ops which compute the load offset using `srcOffsetI` and
    // `srcOffsetJ`. The actual offset is
    // memrefOffset + (leadDimension * srcOffsetI) + srcOffsetJ, applied to the
    // aligned pointer. The memrefs here are assumed to be normalized and hence
    // the simple conversion works.
    SmallVector<Value> indices(adaptor.indices());
    Value srcOffsetIVal = indices[0];
    Value srcOffsetJVal = indices[1];
    Type i32Ty = rewriter.getI32Type();
    Value leadingDim32 =
        rewriter.create<LLVM::ConstantOp>(loc, i32Ty, leadDimension);
    Value numElemsLeadDim =
        rewriter.create<LLVM::MulOp>(loc, i32Ty, leadingDim32, srcOffsetIVal);
    Value loadOffset = rewriter.create<LLVM::AddOp>(loc, i32Ty, numElemsLeadDim,
                                                    srcOffsetJVal);

    Value promotedSrcOpToUse = promotedSrcOp.offset(rewriter, loc);
    Value actualOffset = rewriter.create<LLVM::AddOp>(loc, i32Ty, loadOffset,
                                                      promotedSrcOpToUse);
    Value loadAddress = rewriter.create<LLVM::GEPOp>(
        loc, promotedSrcOp.getElementPtrType(),
        promotedSrcOp.alignedPtr(rewriter, loc),
        ArrayRef<Value>{actualOffset});

    // Bitcast the aligned base pointer of the source memref, so that values
    // can be loaded in 32-bit chunks and the semantics match the intrinsic
    // exposed by the NVPTX backend.
    Value loadAddressCasted = rewriter.create<LLVM::BitcastOp>(
        loc,
        LLVM::LLVMPointerType::get(
            i32Ty, promotedSrcOp.getElementPtrType().getAddressSpace()),
        loadAddress);

    // Get the shape of the MMAMatrix type being returned. The shape determines
    // which intrinsic this op will be lowered to.
    gpu::MMAMatrixType retType =
        subgroupMmaLoadMatrixOp.res().getType().cast<gpu::MMAMatrixType>();
    ArrayRef<int64_t> retTypeShape = retType.getShape();

    Type resType = convertMMAToLLVMType(retType);
    StringRef operandStr = retType.getOperand();

    // Create the NVVM WMMA load op according to the operand type.
    SmallVector<Value, 2> loadOpOperands({loadAddressCasted, leadingDim32});
    if (operandStr.equals("AOp")) {
      if (retTypeShape[0] == 16 && retTypeShape[1] == 16) {
        rewriter.replaceOpWithNewOp<NVVM::WMMALoadAM16N16K16Op>(op, resType,
                                                                loadOpOperands);
      } else {
        return rewriter.notifyMatchFailure(op, kInvalidCaseStr);
      }
    } else if (operandStr.equals("BOp")) {
      if (retTypeShape[0] == 16 && retTypeShape[1] == 16) {
        rewriter.replaceOpWithNewOp<NVVM::WMMALoadBM16N16K16Op>(op, resType,
                                                                loadOpOperands);
      } else {
        return rewriter.notifyMatchFailure(op, kInvalidCaseStr);
      }
    } else {
      if (retTypeShape[0] == 16 && retTypeShape[1] == 16) {
        if (retType.getElementType().isF16()) {
          rewriter.replaceOpWithNewOp<NVVM::WMMALoadCF16M16N16K16Op>(
              op, resType, loadOpOperands);
        } else if (retType.getElementType().isF32()) {
          rewriter.replaceOpWithNewOp<NVVM::WMMALoadCF32M16N16K16Op>(
              op, resType, loadOpOperands);
        }
      } else {
        return rewriter.notifyMatchFailure(op, kInvalidCaseStr);
      }
    }
    return success();
  }
};
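
// As an illustration (types and attribute values below are hypothetical), the
// pattern above rewrites a load such as
//   %a = gpu.subgroup_mma_load_matrix %src[%i, %j]
//        {leadDimension = 32 : index}
//        : memref<32x32xf16> -> !gpu.mma_matrix<16x16xf16, "AOp">
// into the i32 address arithmetic shown in the pattern followed by the
// m16n16k16 A-operand NVVM load op, whose result has the struct type produced
// by convertMMAToLLVMType().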

/// This class implements the conversion of GPU MMA storeOp to wmma.store op
/// in the NVVM dialect. Besides the NVVM op itself, the conversion also emits
/// the code needed to unpack the source matrix into the individual values
/// expected by the NVVM op and to compute the store address in the destination
/// memref.
struct WmmaStoreOpToNVVMLowering
    : public ConvertOpToLLVMPattern<gpu::SubgroupMmaStoreMatrixOp> {
  using ConvertOpToLLVMPattern<
      gpu::SubgroupMmaStoreMatrixOp>::ConvertOpToLLVMPattern;

  LogicalResult
  matchAndRewrite(gpu::SubgroupMmaStoreMatrixOp subgroupMmaStoreMatrixOp,
                  ArrayRef<Value> operands,
                  ConversionPatternRewriter &rewriter) const override {
    Operation *op = subgroupMmaStoreMatrixOp.getOperation();
    if (failed(areAllLLVMTypes(op, operands, rewriter)))
      return failure();

    unsigned indexTypeBitwidth =
        this->getTypeConverter()->getIndexTypeBitwidth();
    // The corresponding intrinsic expects leadDimension to be a 32-bit
    // integer, so all the calculations of linearizing the store address
    // must also follow this restriction.
    if (indexTypeBitwidth != 32)
      return rewriter.notifyMatchFailure(
          op, "expected indices to the memref to be 32-bit wide.");

    Location loc = op->getLoc();

    gpu::SubgroupMmaStoreMatrixOpAdaptor adaptor(operands);
    // MemRefDescriptor to extract alignedPtr and offset.
    MemRefDescriptor promotedDstOp(adaptor.dstMemref());

    auto leadDimension = subgroupMmaStoreMatrixOp.leadDimensionAttr();

    // Emit ops which compute the store offset using `dstOffsetI` and
    // `dstOffsetJ`. The actual offset is
    // memrefOffset + (leadDimension * dstOffsetI) + dstOffsetJ, applied to the
    // aligned pointer.
    SmallVector<Value> indices(adaptor.indices());
    Value dstOffsetIVal = indices[0];
    Value dstOffsetJVal = indices[1];
    Type i32Ty = rewriter.getI32Type();
    Value leadingDim32 =
        rewriter.create<LLVM::ConstantOp>(loc, i32Ty, leadDimension);
    Value numElemsLeadDim =
        rewriter.create<LLVM::MulOp>(loc, i32Ty, leadingDim32, dstOffsetIVal);
    Value storeOffset = rewriter.create<LLVM::AddOp>(loc, i32Ty,
                                                     numElemsLeadDim,
                                                     dstOffsetJVal);

    Value promotedDstOpToUse = promotedDstOp.offset(rewriter, loc);
    Value actualOffset = rewriter.create<LLVM::AddOp>(loc, i32Ty, storeOffset,
                                                      promotedDstOpToUse);
    Value storeAddress = rewriter.create<LLVM::GEPOp>(
        loc, promotedDstOp.getElementPtrType(),
        promotedDstOp.alignedPtr(rewriter, loc),
        ArrayRef<Value>{actualOffset});

    // Bitcast the aligned base pointer of the destination memref, so that
    // values can be stored in 32-bit chunks and the semantics match the
    // intrinsic exposed by the NVPTX backend.
    Value storeAddressCasted = rewriter.create<LLVM::BitcastOp>(
        loc,
        LLVM::LLVMPointerType::get(
            i32Ty, promotedDstOp.getElementPtrType().getAddressSpace()),
        storeAddress);

    SmallVector<Value, 4> storeOpOperands;
    storeOpOperands.push_back(storeAddressCasted);

    // Get the shape of the MMAMatrix type being stored. The shape determines
    // which intrinsic this op will be lowered to.
    gpu::MMAMatrixType srcType =
        subgroupMmaStoreMatrixOp.src().getType().cast<gpu::MMAMatrixType>();
    ArrayRef<int64_t> srcTypeShape = srcType.getShape();

    // Unpack the source matrix into individual values.
    auto matrixType = adaptor.src().getType().cast<LLVM::LLVMStructType>();
    for (unsigned i = 0, e = matrixType.getBody().size(); i < e; ++i) {
      Value toUse = rewriter.create<LLVM::ExtractValueOp>(
          loc, matrixType.getBody()[i], adaptor.src(),
          rewriter.getI32ArrayAttr(i));
      storeOpOperands.push_back(toUse);
    }
    storeOpOperands.push_back(leadingDim32);

    if (srcType.getElementType().isF16()) {
      // Create the f16 NVVM WMMA store op.
      if (srcTypeShape[0] == 16 && srcTypeShape[1] == 16)
        rewriter.create<NVVM::WMMAStoreF16M16N16K16Op>(loc, storeOpOperands);
      else
        return rewriter.notifyMatchFailure(op, kInvalidCaseStr);
      rewriter.eraseOp(op);
      return success();
    }
    if (srcType.getElementType().isF32()) {
      // Create the f32 NVVM WMMA store op.
      if (srcTypeShape[0] == 16 && srcTypeShape[1] == 16)
        rewriter.create<NVVM::WMMAStoreF32M16N16K16Op>(loc, storeOpOperands);
      else
        return rewriter.notifyMatchFailure(op, kInvalidCaseStr);
      rewriter.eraseOp(op);
      return success();
    }
    return failure();
  }
};
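
// As an illustration (types below are hypothetical), the pattern above
// rewrites
//   gpu.subgroup_mma_store_matrix %c, %dst[%i, %j]
//     {leadDimension = 32 : index}
//     : !gpu.mma_matrix<16x16xf16, "COp">, memref<32x32xf16>
// into llvm.extractvalue ops that unpack the accumulator struct, followed by
// the matching m16n16k16 NVVM store op that takes the bitcast i32 pointer, the
// unpacked elements, and the 32-bit leading dimension.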

/// This class implements the conversion of GPU MMA computeOp to wmma.mma op
/// in the NVVM dialect.
struct WmmaMmaOpToNVVMLowering
    : public ConvertOpToLLVMPattern<gpu::SubgroupMmaComputeOp> {
  using ConvertOpToLLVMPattern<
      gpu::SubgroupMmaComputeOp>::ConvertOpToLLVMPattern;

  LogicalResult
  matchAndRewrite(gpu::SubgroupMmaComputeOp subgroupMmaComputeOp,
                  ArrayRef<Value> operands,
                  ConversionPatternRewriter &rewriter) const override {
    Operation *op = subgroupMmaComputeOp.getOperation();
    if (failed(areAllLLVMTypes(op, operands, rewriter)))
      return failure();

    Location loc = op->getLoc();

    // The wmma.mma intrinsic in LLVM requires the operands as individual
    // values, so the individual elements need to be extracted from the lowered
    // operand structs and then passed on to the intrinsic call. Emit llvm ops
    // to extract those individual values.
    SmallVector<Value> unpackedOps;

    auto unpackOp = [&](Value operand) {
      auto structType = operand.getType().cast<LLVM::LLVMStructType>();
      for (size_t i = 0, e = structType.getBody().size(); i < e; ++i) {
        Value toUse = rewriter.create<LLVM::ExtractValueOp>(
            loc, structType.getBody()[i], operand, rewriter.getI32ArrayAttr(i));
        unpackedOps.push_back(toUse);
      }
    };

    // Get the shapes of the MMAMatrix types being used. The shapes determine
    // which intrinsic this op will be lowered to.
    gpu::MMAMatrixType aType =
        subgroupMmaComputeOp.opA().getType().cast<gpu::MMAMatrixType>();
    ArrayRef<int64_t> aTypeShape = aType.getShape();
    gpu::MMAMatrixType bType =
        subgroupMmaComputeOp.opB().getType().cast<gpu::MMAMatrixType>();
    ArrayRef<int64_t> bTypeShape = bType.getShape();
    gpu::MMAMatrixType cType =
        subgroupMmaComputeOp.opC().getType().cast<gpu::MMAMatrixType>();
    ArrayRef<int64_t> cTypeShape = cType.getShape();

    gpu::SubgroupMmaComputeOpAdaptor transformedOperands(operands);
    unpackOp(transformedOperands.opA());
    unpackOp(transformedOperands.opB());
    unpackOp(transformedOperands.opC());

    if (cType.getElementType().isF16()) {
      if (aTypeShape[0] == 16 && aTypeShape[1] == 16 && bTypeShape[0] == 16 &&
          bTypeShape[1] == 16 && cTypeShape[0] == 16 && cTypeShape[1] == 16) {
        // Create the f16-accumulate NVVM WMMA mma op.
        rewriter.replaceOpWithNewOp<NVVM::WMMAMmaF16F16M16N16K16Op>(
            op, transformedOperands.opC().getType(), unpackedOps);
        return success();
      }
      return rewriter.notifyMatchFailure(op, kInvalidCaseStr);
    }
    if (cType.getElementType().isF32()) {
      if (aTypeShape[0] == 16 && aTypeShape[1] == 16 && bTypeShape[0] == 16 &&
          bTypeShape[1] == 16 && cTypeShape[0] == 16 && cTypeShape[1] == 16) {
        // Create the f32-accumulate NVVM WMMA mma op.
        rewriter.replaceOpWithNewOp<NVVM::WMMAMmaF32F32M16N16K16Op>(
            op, transformedOperands.opC().getType(), unpackedOps);
        return success();
      }
      return rewriter.notifyMatchFailure(op, kInvalidCaseStr);
    }
    return failure();
  }
};
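
// Note: with the struct layouts produced by convertMMAToLLVMType(), the
// all-f16 16x16x16 case above passes 8 (A) + 8 (B) + 4 (C) = 20 vector<2xf16>
// values to the NVVM mma op, while an f32 accumulator contributes 8 scalar
// f32 values instead.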

/// Convert GPU MMA ConstantMatrixOp to a chain of InsertValueOp.
struct WmmaConstantOpToNVVMLowering
    : public ConvertOpToLLVMPattern<gpu::SubgroupMmaConstantMatrixOp> {
  using ConvertOpToLLVMPattern<
      gpu::SubgroupMmaConstantMatrixOp>::ConvertOpToLLVMPattern;

  LogicalResult
  matchAndRewrite(gpu::SubgroupMmaConstantMatrixOp subgroupMmaConstantOp,
                  ArrayRef<Value> operands,
                  ConversionPatternRewriter &rewriter) const override {
    if (failed(areAllLLVMTypes(subgroupMmaConstantOp.getOperation(), operands,
                               rewriter)))
      return failure();
    Location loc = subgroupMmaConstantOp.getLoc();
    Value cst = operands[0];
    LLVM::LLVMStructType type = convertMMAToLLVMType(
        subgroupMmaConstantOp.getType().cast<gpu::MMAMatrixType>());
    // If the element type is a vector, create a vector from the operand.
    if (auto vecType = type.getBody()[0].dyn_cast<VectorType>()) {
      Value vecCst = rewriter.create<LLVM::UndefOp>(loc, vecType);
      for (int64_t vecEl = 0; vecEl < vecType.getNumElements(); vecEl++) {
        Value idx = rewriter.create<LLVM::ConstantOp>(
            loc, typeConverter->convertType(rewriter.getIntegerType(32)),
            rewriter.getI32IntegerAttr(vecEl));
        vecCst = rewriter.create<LLVM::InsertElementOp>(loc, vecType, vecCst,
                                                        cst, idx);
      }
      cst = vecCst;
    }
    // Broadcast the (possibly vectorized) constant into every element of the
    // result struct.
    Value matrixStruct = rewriter.create<LLVM::UndefOp>(loc, type);
    for (size_t i : llvm::seq(size_t(0), type.getBody().size())) {
      matrixStruct = rewriter.create<LLVM::InsertValueOp>(
          loc, matrixStruct, cst, rewriter.getI32ArrayAttr(i));
    }
    rewriter.replaceOp(subgroupMmaConstantOp, matrixStruct);
    return success();
  }
};

} // anonymous namespace

namespace mlir {
void populateGpuWMMAToNVVMConversionPatterns(LLVMTypeConverter &converter,
                                             RewritePatternSet &patterns) {
  patterns.insert<WmmaLoadOpToNVVMLowering, WmmaMmaOpToNVVMLowering,
                  WmmaStoreOpToNVVMLowering, WmmaConstantOpToNVVMLowering>(
      converter);
}
} // namespace mlir
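
// Usage sketch (illustrative only; the real call site lives in the GPU-to-NVVM
// lowering pass and may differ): these patterns are populated alongside the
// other GPU-to-NVVM patterns, with the type converter configured for a 32-bit
// index bitwidth, roughly as
//   RewritePatternSet patterns(ctx);
//   populateGpuWMMAToNVVMConversionPatterns(converter, patterns);
//   if (failed(applyPartialConversion(gpuModule, target, std::move(patterns))))
//     signalPassFailure();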