//===- VectorToGPU.cpp - Convert vector to GPU dialect ----------*- C++ -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file implements lowering of vector operations to GPU dialect ops.
//
//===----------------------------------------------------------------------===//

#include <type_traits>

#include "mlir/Conversion/VectorToGPU/VectorToGPU.h"

#include "../PassDetail.h"
#include "mlir/Analysis/SliceAnalysis.h"
#include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
#include "mlir/Dialect/GPU/GPUDialect.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/SCF/SCF.h"
#include "mlir/Dialect/Utils/StructuredOpsUtils.h"
#include "mlir/Dialect/Vector/VectorOps.h"
#include "mlir/Dialect/Vector/VectorUtils.h"
#include "mlir/IR/Builders.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
#include "mlir/Transforms/Passes.h"

using namespace mlir;

// Return true if the contract op can be converted to MMA matmul.
static bool contractSupportsMMAMatrixType(vector::ContractionOp contract) {
  if (llvm::size(contract.masks()) != 0)
    return false;

  using MapList = ArrayRef<ArrayRef<AffineExpr>>;
  auto infer = [](MapList m) { return AffineMap::inferFromExprList(m); };
  AffineExpr m, n, k;
  bindDims(contract.getContext(), m, n, k);
  auto iteratorTypes = contract.iterator_types().getValue();
  if (!(isParallelIterator(iteratorTypes[0]) &&
        isParallelIterator(iteratorTypes[1]) &&
        isReductionIterator(iteratorTypes[2])))
    return false;

  // The contract needs to represent a matmul to be able to convert to
  // MMAMatrix matmul.
  if (contract.getIndexingMaps() != infer({{m, k}, {k, n}, {m, n}}))
    return false;

  return true;
}

// Return the stride of dimension 0 of |type| if it is a memref and has a
// constant stride.
static llvm::Optional<int64_t>
getMemrefConstantHorizontalStride(ShapedType type) {
  auto memrefType = type.dyn_cast<MemRefType>();
  if (!memrefType)
    return false;
  // If the memref is 0 or 1D the horizontal stride is 0.
  if (memrefType.getRank() < 2)
    return 0;
  int64_t offset = 0;
  SmallVector<int64_t, 2> strides;
  if (failed(getStridesAndOffset(memrefType, strides, offset)))
    return llvm::None;
  int64_t stride = strides[strides.size() - 2];
  if (stride == ShapedType::kDynamicStrideOrOffset)
    return llvm::None;
  return stride;
}

// Return true if the transfer op can be converted to a MMA matrix load.
static bool transferReadSupportsMMAMatrixType(vector::TransferReadOp readOp) {
  if (readOp.mask() || readOp.hasOutOfBoundsDim() ||
      readOp.getVectorType().getRank() != 2)
    return false;
  if (!getMemrefConstantHorizontalStride(readOp.getShapedType()))
    return false;
  AffineMap map = readOp.permutation_map();
  OpBuilder b(readOp.getContext());
  AffineExpr innerDim = b.getAffineDimExpr(map.getNumDims() - 1);
  AffineExpr zero = b.getAffineConstantExpr(0);
  auto broadcastInnerDim = AffineMap::get(map.getNumDims(), 0, {zero, innerDim},
                                          readOp.getContext());
  // TODO: Support transpose once it is added to GPU dialect ops.
  // For now we only support (d0, d1) -> (d0, d1) and (d0, d1) -> (0, d1).
  return !(!map.isMinorIdentity() && map != broadcastInnerDim);
}

// Return true if the transfer op can be converted to a MMA matrix store.
static bool
transferWriteSupportsMMAMatrixType(vector::TransferWriteOp writeOp) {
  // TODO: support 0-d corner case.
  if (writeOp.getTransferRank() == 0)
    return false;

  if (writeOp.mask() || writeOp.hasOutOfBoundsDim() ||
      writeOp.getVectorType().getRank() != 2)
    return false;
  if (!getMemrefConstantHorizontalStride(writeOp.getShapedType()))
    return false;
  // TODO: Support transpose once it is added to GPU dialect ops.
  if (!writeOp.permutation_map().isMinorIdentity())
    return false;
  return true;
}

/// Return true if the constant is a splat to a 2D vector so that it can be
/// converted to a MMA constant matrix op.
static bool constantSupportsMMAMatrixType(arith::ConstantOp constantOp) {
  auto vecType = constantOp.getType().dyn_cast<VectorType>();
  if (!vecType || vecType.getRank() != 2)
    return false;
  return constantOp.getValue().isa<SplatElementsAttr>();
}

/// Return true if this is a broadcast from scalar to a 2D vector.
static bool broadcastSupportsMMAMatrixType(vector::BroadcastOp broadcastOp) {
  return broadcastOp.getVectorType().getRank() == 2 &&
         broadcastOp.source().getType().isa<FloatType>();
}

/// Return the MMA elementwise enum associated with `op` if it is supported.
/// Return `llvm::None` otherwise.
static llvm::Optional<gpu::MMAElementwiseOp>
convertElementwiseOpToMMA(Operation *op) {
  if (isa<arith::AddFOp>(op))
    return gpu::MMAElementwiseOp::ADDF;
  if (isa<arith::MulFOp>(op))
    return gpu::MMAElementwiseOp::MULF;
  if (isa<arith::MaxFOp>(op))
    return gpu::MMAElementwiseOp::MAXF;
  if (isa<arith::MinFOp>(op))
    return gpu::MMAElementwiseOp::MINF;
  if (isa<arith::DivFOp>(op))
    return gpu::MMAElementwiseOp::DIVF;
  return llvm::None;
}

/// Return true if the op is supported as elementwise op on MMAMatrix type.
static bool elementwiseSupportsMMAMatrixType(Operation *op) {
  return convertElementwiseOpToMMA(op).hasValue();
}

static bool supportsMMaMatrixType(Operation *op) {
  if (isa<scf::ForOp, scf::YieldOp>(op))
    return true;
  if (auto transferRead = dyn_cast<vector::TransferReadOp>(op))
    return transferReadSupportsMMAMatrixType(transferRead);
  if (auto transferWrite = dyn_cast<vector::TransferWriteOp>(op))
    return transferWriteSupportsMMAMatrixType(transferWrite);
  if (auto contract = dyn_cast<vector::ContractionOp>(op))
    return contractSupportsMMAMatrixType(contract);
  if (auto constant = dyn_cast<arith::ConstantOp>(op))
    return constantSupportsMMAMatrixType(constant);
  if (auto broadcast = dyn_cast<vector::BroadcastOp>(op))
    return broadcastSupportsMMAMatrixType(broadcast);
  return elementwiseSupportsMMAMatrixType(op);
}

/// Return an unsorted slice handling scf.for region differently than
/// `getSlice`. In scf.for we only want to include as part of the slice
/// elements that are part of the use/def chain.
static SetVector<Operation *> getSliceContract(Operation *op,
                                               TransitiveFilter backwardFilter,
                                               TransitiveFilter forwardFilter) {
  SetVector<Operation *> slice;
  slice.insert(op);
  unsigned currentIndex = 0;
  SetVector<Operation *> backwardSlice;
  SetVector<Operation *> forwardSlice;
  while (currentIndex != slice.size()) {
    auto *currentOp = (slice)[currentIndex];
    // Compute and insert the backwardSlice starting from currentOp.
    backwardSlice.clear();
    getBackwardSlice(currentOp, &backwardSlice, backwardFilter);
    slice.insert(backwardSlice.begin(), backwardSlice.end());

    // Compute and insert the forwardSlice starting from currentOp.
    forwardSlice.clear();
    // Special case for ForOp, we don't want to include the whole region but
    // only the value using the region arguments.
    // TODO: We should refine this to only care about the region arguments
    // being converted to matrix type.
    if (auto forOp = dyn_cast<scf::ForOp>(currentOp)) {
      for (Value forOpResult : forOp.getResults())
        getForwardSlice(forOpResult, &forwardSlice, forwardFilter);
      for (BlockArgument &arg : forOp.getRegionIterArgs())
        getForwardSlice(arg, &forwardSlice, forwardFilter);
    } else {
      getForwardSlice(currentOp, &forwardSlice, forwardFilter);
    }
    slice.insert(forwardSlice.begin(), forwardSlice.end());
    ++currentIndex;
  }
  return slice;
}

// Analyze slice of operations based on convert op to figure out if the whole
// slice can be converted to MMA operations.
static SetVector<Operation *> getOpToConvert(mlir::Operation *op) {
  auto hasVectorDest = [](Operation *op) {
    return llvm::any_of(op->getResultTypes(),
                        [](Type t) { return t.isa<VectorType>(); });
  };
  auto hasVectorSrc = [](Operation *op) {
    return llvm::any_of(op->getOperandTypes(),
                        [](Type t) { return t.isa<VectorType>(); });
  };
  SetVector<Operation *> opToConvert;
  op->walk([&](vector::ContractionOp contract) {
    if (opToConvert.contains(contract.getOperation()))
      return;
    SetVector<Operation *> dependentOps =
        getSliceContract(contract, hasVectorDest, hasVectorSrc);
    // If any instruction cannot use MMA matrix type drop the whole
    // chain. MMA matrices are stored in an opaque type so they cannot be used
    // by all operations.
    if (llvm::any_of(dependentOps,
                     [](Operation *op) { return !supportsMMaMatrixType(op); }))
      return;
    opToConvert.insert(dependentOps.begin(), dependentOps.end());
  });
  // Sort the operations so that we can convert them in topological order.
  return topologicalSort(opToConvert);
}

namespace {

// Transform contract into (m, k)x(k, n)x(m, n) form so that it can be
// converted to MMA matmul.
struct PrepareContractToGPUMMA
    : public OpRewritePattern<vector::ContractionOp> {
  using OpRewritePattern<vector::ContractionOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(vector::ContractionOp op,
                                PatternRewriter &rewriter) const override {
    Location loc = op.getLoc();
    Value lhs = op.lhs(), rhs = op.rhs(), res = op.acc();

    // Set up the parallel/reduction structure in the right form.
    using MapList = ArrayRef<ArrayRef<AffineExpr>>;
    auto infer = [](MapList m) { return AffineMap::inferFromExprList(m); };
    AffineExpr m, n, k;
    bindDims(rewriter.getContext(), m, n, k);
    static constexpr std::array<int64_t, 2> perm = {1, 0};
    auto iteratorTypes = op.iterator_types().getValue();
    SmallVector<AffineMap, 4> maps = op.getIndexingMaps();
    if (!(isParallelIterator(iteratorTypes[0]) &&
          isParallelIterator(iteratorTypes[1]) &&
          isReductionIterator(iteratorTypes[2])))
      return failure();
    //
    // Two outer parallel, one inner reduction (matmat flavor).
    //
    if (maps == infer({{m, k}, {k, n}, {m, n}})) {
      // This is the classical row-major matmul, nothing to do.
      return failure();
    }
    if (maps == infer({{m, k}, {n, k}, {m, n}})) {
      rhs = rewriter.create<vector::TransposeOp>(loc, rhs, perm);
    } else if (maps == infer({{k, m}, {k, n}, {m, n}})) {
      lhs = rewriter.create<vector::TransposeOp>(loc, lhs, perm);
    } else if (maps == infer({{k, m}, {n, k}, {m, n}})) {
      rhs = rewriter.create<vector::TransposeOp>(loc, rhs, perm);
      lhs = rewriter.create<vector::TransposeOp>(loc, lhs, perm);
    } else if (maps == infer({{m, k}, {k, n}, {n, m}})) {
      std::swap(rhs, lhs);
      rhs = rewriter.create<vector::TransposeOp>(loc, rhs, perm);
      lhs = rewriter.create<vector::TransposeOp>(loc, lhs, perm);
    } else if (maps == infer({{m, k}, {n, k}, {n, m}})) {
      std::swap(rhs, lhs);
      rhs = rewriter.create<vector::TransposeOp>(loc, rhs, perm);
    } else if (maps == infer({{k, m}, {k, n}, {n, m}})) {
      std::swap(lhs, rhs);
      lhs = rewriter.create<vector::TransposeOp>(loc, lhs, perm);
    } else if (maps == infer({{k, m}, {n, k}, {n, m}})) {
      std::swap(lhs, rhs);
    } else {
      return failure();
    }
    rewriter.replaceOpWithNewOp<vector::ContractionOp>(
        op, lhs, rhs, res,
        rewriter.getAffineMapArrayAttr(infer({{m, k}, {k, n}, {m, n}})),
        op.iterator_types());
    return success();
  }
};

// Merge transpose op into the transfer read op. Transposes are not supported
// on MMA types but the MMA load can transpose the matrix when loading.
struct CombineTransferReadOpTranspose final
    : public OpRewritePattern<vector::TransposeOp> {
  using OpRewritePattern<vector::TransposeOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(vector::TransposeOp op,
                                PatternRewriter &rewriter) const override {
    auto transferReadOp = op.vector().getDefiningOp<vector::TransferReadOp>();
    if (!transferReadOp)
      return failure();

    // TODO: support 0-d corner case.
    if (transferReadOp.getTransferRank() == 0)
      return failure();

    if (transferReadOp.mask() || transferReadOp.hasOutOfBoundsDim())
      return failure();
    SmallVector<int64_t, 2> perm;
    op.getTransp(perm);
    SmallVector<unsigned, 2> permU;
    for (int64_t o : perm)
      permU.push_back(unsigned(o));
    AffineMap permutationMap =
        AffineMap::getPermutationMap(permU, op.getContext());
    AffineMap newMap =
        permutationMap.compose(transferReadOp.permutation_map());
    rewriter.replaceOpWithNewOp<vector::TransferReadOp>(
        op, op.getType(), transferReadOp.source(), transferReadOp.indices(),
        AffineMapAttr::get(newMap), transferReadOp.padding(),
        transferReadOp.mask(), transferReadOp.in_boundsAttr());
    return success();
  }
};

} // namespace

// MMA types have different layouts based on how they are used in matmul ops.
// Figure out the right layout to use by looking at op uses.
// TODO: Change the GPU dialect to abstract the layout at this level and
// only care about it during lowering to NVVM.
template <typename OpTy>
static const char *inferFragType(OpTy op) {
  for (Operation *users : op->getUsers()) {
    auto contract = dyn_cast<vector::ContractionOp>(users);
    if (!contract)
      continue;
    if (contract.lhs() == op.getResult())
      return "AOp";
    if (contract.rhs() == op.getResult())
      return "BOp";
  }
  return "COp";
}

static void convertTransferReadOp(vector::TransferReadOp op,
                                  llvm::DenseMap<Value, Value> &valueMapping) {
  assert(op.getTransferRank() > 0 && "unexpected 0-d transfer");
  assert(transferReadSupportsMMAMatrixType(op));
  Optional<int64_t> stride =
      getMemrefConstantHorizontalStride(op.getShapedType());
  AffineMap map = op.permutation_map();
  // Handle broadcast by setting the stride to 0.
  if (map.getResult(0).isa<AffineConstantExpr>()) {
    assert(map.getResult(0).cast<AffineConstantExpr>().getValue() == 0);
    stride = 0;
  }
  assert(stride);
  const char *fragType = inferFragType(op);
  gpu::MMAMatrixType type =
      gpu::MMAMatrixType::get(op.getVectorType().getShape(),
                              op.getVectorType().getElementType(), fragType);
  OpBuilder b(op);
  Value load = b.create<gpu::SubgroupMmaLoadMatrixOp>(
      op.getLoc(), type, op.source(), op.indices(), b.getIndexAttr(*stride));
  valueMapping[op.getResult()] = load;
}

static void convertTransferWriteOp(vector::TransferWriteOp op,
                                   llvm::DenseMap<Value, Value> &valueMapping) {
  assert(transferWriteSupportsMMAMatrixType(op));
  Optional<int64_t> stride =
      getMemrefConstantHorizontalStride(op.getShapedType());
  assert(stride);
  OpBuilder b(op);
  Value matrix = valueMapping.find(op.vector())->second;
  b.create<gpu::SubgroupMmaStoreMatrixOp>(
      op.getLoc(), matrix, op.source(), op.indices(), b.getIndexAttr(*stride));
  op.erase();
}

static void convertContractOp(vector::ContractionOp op,
                              llvm::DenseMap<Value, Value> &valueMapping) {
  OpBuilder b(op);
  Value opA = valueMapping.find(op.lhs())->second;
  Value opB = valueMapping.find(op.rhs())->second;
  Value opC = valueMapping.find(op.acc())->second;
  Value matmul = b.create<gpu::SubgroupMmaComputeOp>(op.getLoc(), opC.getType(),
                                                     opA, opB, opC);
  valueMapping[op.getResult()] = matmul;
}

/// Convert a 2D splat ConstantOp to a SubgroupMmaConstantMatrix op.
static void convertConstantOp(arith::ConstantOp op,
                              llvm::DenseMap<Value, Value> &valueMapping) {
  assert(constantSupportsMMAMatrixType(op));
  OpBuilder b(op);
  Attribute splat =
      op.getValue().cast<SplatElementsAttr>().getSplatValue<Attribute>();
  auto scalarConstant =
      b.create<arith::ConstantOp>(op.getLoc(), splat.getType(), splat);
  const char *fragType = inferFragType(op);
  auto vecType = op.getType().cast<VectorType>();
  gpu::MMAMatrixType type = gpu::MMAMatrixType::get(
      vecType.getShape(), vecType.getElementType(), llvm::StringRef(fragType));
  auto matrix = b.create<gpu::SubgroupMmaConstantMatrixOp>(op.getLoc(), type,
                                                           scalarConstant);
  valueMapping[op.getResult()] = matrix;
}

/// Convert a vector.broadcast from scalar to a SubgroupMmaConstantMatrix op.
static void convertBroadcastOp(vector::BroadcastOp op,
                               llvm::DenseMap<Value, Value> &valueMapping) {
  assert(broadcastSupportsMMAMatrixType(op));
  OpBuilder b(op);
  const char *fragType = inferFragType(op);
  auto vecType = op.getVectorType();
  gpu::MMAMatrixType type = gpu::MMAMatrixType::get(
      vecType.getShape(), vecType.getElementType(), llvm::StringRef(fragType));
  auto matrix = b.create<gpu::SubgroupMmaConstantMatrixOp>(op.getLoc(), type,
                                                           op.source());
  valueMapping[op.getResult()] = matrix;
}

// Replace ForOp with a new ForOp with extra operands. The YieldOp is not
// updated and needs to be updated separately for the loop to be correct.
static scf::ForOp replaceForOpWithNewSignature(OpBuilder &b, scf::ForOp loop,
                                               ValueRange newIterOperands) {
  // Create a new loop before the existing one, with the extra operands.
  OpBuilder::InsertionGuard g(b);
  b.setInsertionPoint(loop);
  auto operands = llvm::to_vector<4>(loop.getIterOperands());
  operands.append(newIterOperands.begin(), newIterOperands.end());
  scf::ForOp newLoop =
      b.create<scf::ForOp>(loop.getLoc(), loop.getLowerBound(),
                           loop.getUpperBound(), loop.getStep(), operands);
  newLoop.getBody()->erase();
  newLoop.getLoopBody().getBlocks().splice(
      newLoop.getLoopBody().getBlocks().begin(),
      loop.getLoopBody().getBlocks());
  for (Value operand : newIterOperands)
    newLoop.getBody()->addArgument(operand.getType(), operand.getLoc());

  for (auto it : llvm::zip(loop.getResults(), newLoop.getResults().take_front(
                                                  loop.getNumResults())))
    std::get<0>(it).replaceAllUsesWith(std::get<1>(it));
  loop.erase();
  return newLoop;
}

static void convertForOp(scf::ForOp op,
                         llvm::DenseMap<Value, Value> &valueMapping) {
  SmallVector<Value> newOperands;
  SmallVector<std::pair<size_t, size_t>> argMapping;
  for (const auto &operand : llvm::enumerate(op.getIterOperands())) {
    auto it = valueMapping.find(operand.value());
    if (it == valueMapping.end())
      continue;
    argMapping.push_back(std::make_pair(
        operand.index(), op.getNumIterOperands() + newOperands.size()));
    newOperands.push_back(it->second);
  }
  OpBuilder b(op);
  scf::ForOp newForOp = replaceForOpWithNewSignature(b, op, newOperands);
  Block &loopBody = *newForOp.getBody();
  for (auto mapping : argMapping) {
    valueMapping[newForOp.getResult(mapping.first)] =
        newForOp.getResult(mapping.second);
    valueMapping[loopBody.getArgument(mapping.first +
                                      newForOp.getNumInductionVars())] =
        loopBody.getArgument(mapping.second + newForOp.getNumInductionVars());
  }
}

static void convertYieldOp(scf::YieldOp op,
                           llvm::DenseMap<Value, Value> &valueMapping) {
  OpBuilder b(op);
  auto loop = cast<scf::ForOp>(op->getParentOp());
  auto yieldOperands = llvm::to_vector<4>(op.getOperands());
  for (const auto &operand : llvm::enumerate(op.getOperands())) {
    auto it = valueMapping.find(operand.value());
    if (it == valueMapping.end())
      continue;
    // Replace the yield of the old value with the for op argument to make it
    // easier to remove the dead code.
    yieldOperands[operand.index()] = loop.getIterOperands()[operand.index()];
    yieldOperands.push_back(it->second);
  }
  b.create<scf::YieldOp>(op.getLoc(), yieldOperands);
  op.erase();
}

/// Convert an elementwise op to the equivalent elementwise op on MMA matrix.
static void convertElementwiseOp(Operation *op, gpu::MMAElementwiseOp opType,
                                 llvm::DenseMap<Value, Value> &valueMapping) {
  OpBuilder b(op);
  SmallVector<Value> matrixOperands;
  for (Value operand : op->getOperands())
    matrixOperands.push_back(valueMapping.find(operand)->second);
  Value newOp = b.create<gpu::SubgroupMmaElementwiseOp>(
      op->getLoc(), matrixOperands[0].getType(), matrixOperands, opType);
  valueMapping[op->getResult(0)] = newOp;
}

namespace mlir {

void populatePrepareVectorToMMAPatterns(RewritePatternSet &patterns) {
  patterns.add<PrepareContractToGPUMMA, CombineTransferReadOpTranspose>(
      patterns.getContext());
}

void convertVectorToMMAOps(FuncOp funcOp) {
  SetVector<Operation *> ops = getOpToConvert(funcOp);
  llvm::DenseMap<Value, Value> valueMapping;
  for (Operation *op : ops) {
    if (auto transferRead = dyn_cast<vector::TransferReadOp>(op)) {
      convertTransferReadOp(transferRead, valueMapping);
    } else if (auto transferWrite = dyn_cast<vector::TransferWriteOp>(op)) {
      convertTransferWriteOp(transferWrite, valueMapping);
    } else if (auto contractOp = dyn_cast<vector::ContractionOp>(op)) {
      convertContractOp(contractOp, valueMapping);
    } else if (auto constantOp = dyn_cast<arith::ConstantOp>(op)) {
      convertConstantOp(constantOp, valueMapping);
    } else if (auto broadcastOp = dyn_cast<vector::BroadcastOp>(op)) {
      convertBroadcastOp(broadcastOp, valueMapping);
    } else if (auto forOp = dyn_cast<scf::ForOp>(op)) {
      convertForOp(forOp, valueMapping);
    } else if (auto yieldOp = dyn_cast<scf::YieldOp>(op)) {
      convertYieldOp(yieldOp, valueMapping);
    } else if (auto elementwiseType = convertElementwiseOpToMMA(op)) {
      convertElementwiseOp(op, *elementwiseType, valueMapping);
    }
  }
}

} // namespace mlir

namespace {

struct ConvertVectorToGPUPass
    : public ConvertVectorToGPUBase<ConvertVectorToGPUPass> {
  void runOnOperation() override {
    RewritePatternSet patterns(getOperation().getContext());
    populatePrepareVectorToMMAPatterns(patterns);
    (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns));

    convertVectorToMMAOps(getOperation());
  }
};

} // namespace

std::unique_ptr<Pass> mlir::createConvertVectorToGPUPass() {
  return std::make_unique<ConvertVectorToGPUPass>();
}