//===- NVGPUTransformOps.cpp - Implementation of NVGPU transform ops ------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/NVGPU/TransformOps/NVGPUTransformOps.h"

#include "mlir/Analysis/SliceAnalysis.h"
#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Arith/Utils/Utils.h"
#include "mlir/Dialect/GPU/IR/GPUDialect.h"
#include "mlir/Dialect/Linalg/IR/Linalg.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/NVGPU/IR/NVGPUDialect.h"
#include "mlir/Dialect/SCF/IR/SCF.h"
#include "mlir/Dialect/SCF/Transforms/Patterns.h"
#include "mlir/Dialect/SCF/Transforms/Transforms.h"
#include "mlir/Dialect/Utils/IndexingUtils.h"
#include "mlir/Dialect/Vector/IR/VectorOps.h"
#include "mlir/IR/AffineExpr.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/MLIRContext.h"
#include "mlir/IR/Operation.h"
#include "mlir/IR/TypeRange.h"
#include "mlir/IR/TypeUtilities.h"
#include "mlir/Support/LogicalResult.h"
#include "llvm/ADT/ArrayRef.h"
#include "llvm/Support/Debug.h"

using namespace mlir;
using namespace mlir::linalg;
using namespace mlir::nvgpu;
using namespace mlir::transform;

#define DEBUG_TYPE "nvgpu-transforms"
#define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE "]: ")
#define DBGSNL() (llvm::dbgs() << "\n")
#define LDBG(X) LLVM_DEBUG(DBGS() << X << "\n")

//===----------------------------------------------------------------------===//
// PipelineSharedMemoryCopiesOp
//===----------------------------------------------------------------------===//

/// Returns true if the given type has the default memory space.
static bool hasDefaultMemorySpace(BaseMemRefType type) {
  return !type.getMemorySpace() || type.getMemorySpaceAsInt() == 0;
}

/// Returns true if the given type has the shared (workgroup) memory space.
static bool hasSharedMemorySpace(BaseMemRefType type) {
  auto space =
      dyn_cast_if_present<gpu::AddressSpaceAttr>(type.getMemorySpace());
  return space &&
         space.getValue() == gpu::GPUDialect::getWorkgroupAddressSpace();
}

/// Returns the value produced by a load from the default memory space. Returns
/// null if the operation is not such a load.
static Value getValueLoadedFromGlobal(Operation *op) {
  // TODO: consider an interface or leveraging the memory effects interface.
  auto load = dyn_cast<vector::TransferReadOp>(op);
  if (!load)
    return nullptr;

  auto loadType = dyn_cast<MemRefType>(load.getSource().getType());
  if (!loadType || !hasDefaultMemorySpace(loadType))
    return nullptr;
  return load;
}

/// Returns true if the operation is storing the given value into shared
/// memory.
static bool isStoreToShared(Operation *op, Value v) {
  // TODO: consider an interface or leveraging the memory effects interface.
  auto store = dyn_cast<vector::TransferWriteOp>(op);
  if (!store || store.getVector() != v)
    return false;

  auto storeType = dyn_cast<MemRefType>(store.getSource().getType());
  return storeType && hasSharedMemorySpace(storeType);
}
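// Illustrative note (not exercised in this file): with the helpers above, a
// `memref<128x32xf32>` with no memory space attribute satisfies
// hasDefaultMemorySpace, while a
// `memref<128x32xf32, #gpu.address_space<workgroup>>` satisfies
// hasSharedMemorySpace. A `vector.transfer_read` from the former whose only
// use is a `vector.transfer_write` into the latter is the global-to-shared
// copy pattern the pipelining below looks for.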
/// Returns true if the operation is a load from the default memory space the
/// result of which is only stored into the shared memory space.
static bool isLoadFromGlobalStoredToShared(Operation *op) {
  Value loaded = getValueLoadedFromGlobal(op);
  if (!loaded || !loaded.hasOneUse())
    return false;

  return isStoreToShared(*loaded.getUsers().begin(), loaded);
}

/// Populate `ops` with the set of operations that belong to the stage 0 of the
/// pipelined version of the given loop when pipelining copies to shared
/// memory. Specifically, this collects:
///
///   1. all loads from global memory, both sync and async;
///   2. the barriers for async loads.
///
/// In particular, barriers are omitted if they do not dominate at least one
/// async load for which there is not yet a barrier.
static LogicalResult
collectStage0PipeliningOps(scf::ForOp forOp,
                           llvm::SmallPtrSet<Operation *, 16> &ops) {

  llvm::SmallPtrSet<Operation *, 4> barriers;
  for (Operation &op : *forOp.getBody()) {
    // Bail on nested ops for now.
    if (op.getNumRegions() > 0)
      return failure();

    if (isa<gpu::BarrierOp>(op)) {
      barriers.insert(&op);
      continue;
    }

    if (isa<nvgpu::DeviceAsyncCopyOp, nvgpu::DeviceAsyncCreateGroupOp>(op)) {
      ops.insert(&op);
      // Barriers collected so far dominate this async load; move them into the
      // stage 0 set as well.
      ops.insert(barriers.begin(), barriers.end());
      barriers.clear();
      continue;
    }

    if (isLoadFromGlobalStoredToShared(&op)) {
      ops.insert(&op);
      continue;
    }
  }
  return success();
}

/// Hook for the loop pipeliner that sets the "num groups in flight" attribute
/// of async wait operations corresponding to pipelined shared memory copies.
// TODO: this currently assumes that there are no groups that could be in
// flight in the existing code.
static void
setAsyncWaitGroupsInFlight(OpBuilder &builder, Operation *op,
                           scf::PipeliningOption::PipelinerPart part,
                           unsigned iteration, unsigned depth) {
  // Based on the order of copies within the loop we need to set the number
  // of copies in flight, unless it is already set.
  auto waitOp = dyn_cast<nvgpu::DeviceAsyncWaitOp>(op);
  if (!waitOp || waitOp.getNumGroups())
    return;

  int numGroupInFlight = 0;
  if (part == scf::PipeliningOption::PipelinerPart::Kernel ||
      part == scf::PipeliningOption::PipelinerPart::Prologue) {
    numGroupInFlight = depth - 1;
  } else {
    // By construction there should be no wait op in the prologue as all the
    // waits should be in the last stage.
    assert(part == scf::PipeliningOption::PipelinerPart::Epilogue);
    // Based on the schedule we pick we know how many groups are in flight for
    // each iteration of the epilogue.
    numGroupInFlight = depth - 1 - iteration;
  }
  waitOp.setNumGroups(numGroupInFlight);
}
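// Worked example (illustrative): with `depth` == 3, waits materialized in the
// kernel and prologue are set to keep 3 - 1 == 2 copy groups in flight, while
// the wait emitted for epilogue iteration `i` keeps 2 - i groups in flight,
// draining the outstanding copies one iteration at a time.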
/// Hook for the loop pipeliner that populates `opsWithPipelineStages` with the
/// stage information as follows:
///
///   - operations in `stage0Ops` (typically loads from global memory and
///     related barriers) are at stage 0;
///   - operations in the backward slice of any stage0Ops are all at stage 0;
///   - other operations are at stage `depth`;
///   - the internal order of the pipelined loop has ops at stage `depth`
///     first, then those at stage 0, with relative order within each group
///     preserved.
static void getPipelineStages(
    scf::ForOp forOp,
    std::vector<std::pair<Operation *, unsigned>> &opsWithPipelineStages,
    unsigned depth, llvm::SmallPtrSetImpl<Operation *> &stage0Ops) {
  SetVector<Operation *> dependencies;
  BackwardSliceOptions options([&](Operation *visited) {
    return visited->getBlock() == forOp.getBody();
  });
  options.inclusive = true;
  for (Operation &op : forOp.getBody()->getOperations()) {
    if (stage0Ops.contains(&op))
      getBackwardSlice(&op, &dependencies, options);
  }

  for (Operation &op : forOp.getBody()->getOperations()) {
    if (!dependencies.contains(&op) && !isa<scf::YieldOp>(op))
      opsWithPipelineStages.emplace_back(&op, depth);
  }
  for (Operation &op : forOp.getBody()->getOperations()) {
    if (dependencies.contains(&op))
      opsWithPipelineStages.emplace_back(&op, 0);
  }
}

/// Hook for the loop pipeliner. Replaces op with a predicated version and
/// returns the resulting operation. Returns the original op if the predication
/// isn't necessary for the given op. Returns null if predication is needed but
/// not supported.
static Operation *replaceOpWithPredicatedOp(RewriterBase &rewriter,
                                            Operation *op, Value predicate) {
  // Some operations may be fine to execute "speculatively" more times than the
  // original number of iterations, in particular side-effect free operations
  // and barriers, even if they cannot be predicated.
  if (isMemoryEffectFree(op) ||
      isa<gpu::BarrierOp, nvgpu::DeviceAsyncCreateGroupOp,
          nvgpu::DeviceAsyncWaitOp>(op)) {
    return op;
  }

  // Otherwise, only async copies can currently be predicated.
  auto asyncCopyOp = dyn_cast<nvgpu::DeviceAsyncCopyOp>(op);
  if (!asyncCopyOp)
    return nullptr;

  // Create srcElement Value based on `predicate`. The next lines generate
  // the following code:
  //
  //   srcElement = (pred) ? prevSrcElements : 0;
  //
  Location loc = asyncCopyOp->getLoc();
  Value dstElements = rewriter.create<arith::ConstantOp>(
      loc, asyncCopyOp.getDstElementsAttr());
  Value originalSrcElement =
      asyncCopyOp.getSrcElements() ? asyncCopyOp.getSrcElements() : dstElements;
  Value c0Index = rewriter.create<arith::ConstantIndexOp>(loc, 0);
  auto srcElements = rewriter.create<arith::SelectOp>(
      loc, predicate, originalSrcElement, c0Index);
  auto asyncCopyZeroFillOp = rewriter.create<nvgpu::DeviceAsyncCopyOp>(
      loc, nvgpu::DeviceAsyncTokenType::get(asyncCopyOp.getContext()),
      asyncCopyOp.getDst(), asyncCopyOp.getDstIndices(), asyncCopyOp.getSrc(),
      asyncCopyOp.getSrcIndices(), asyncCopyOp.getDstElements(), srcElements,
      UnitAttr());
  rewriter.replaceOp(asyncCopyOp, asyncCopyZeroFillOp);
  return asyncCopyZeroFillOp;
}
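// Illustrative sketch of the predication rewrite above (pseudo-IR; value names
// are invented for readability):
//
//   %dst_elems = arith.constant <dstElements> : index
//   %c0        = arith.constant 0 : index
//   %src_elems = arith.select %pred, %original_src_elems, %c0 : index
//   %token     = nvgpu.device_async_copy %src[...], %dst[...], <dstElements>,
//                %src_elems
//
// When the predicate is false, zero source elements are read and the
// destination is zero-filled instead.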
/// Applies loop pipelining with the given depth to the given loop so that
/// copies into the shared memory are pipelined. Doesn't affect other loops.
/// Returns a tuple containing the error state and the pipelined op, the latter
/// being null in case of any failure. The error state contains a definite
/// error if the IR has been modified and a silenceable error otherwise.
static std::tuple<DiagnosedSilenceableFailure, scf::ForOp>
pipelineForSharedCopies(RewriterBase &rewriter, scf::ForOp forOp, int64_t depth,
                        bool epiloguePeeling) {
  llvm::SmallPtrSet<Operation *, 16> stage0Ops;
  if (failed(collectStage0PipeliningOps(forOp, stage0Ops))) {
    return std::make_tuple(
        emitSilenceableFailure(forOp, "cannot find stage 0 ops for pipelining"),
        scf::ForOp());
  }
  if (stage0Ops.empty()) {
    return std::make_tuple(
        emitSilenceableFailure(forOp, "no shared memory copy"), scf::ForOp());
  }

  scf::PipeliningOption options;
  unsigned maxDepth = depth;
  auto setAnnotation = [&](Operation *op,
                           scf::PipeliningOption::PipelinerPart part,
                           unsigned iteration) {
    return setAsyncWaitGroupsInFlight(rewriter, op, part, iteration, maxDepth);
  };
  options.getScheduleFn =
      [&](scf::ForOp schedulingFor,
          std::vector<std::pair<Operation *, unsigned>> &ops) {
        if (schedulingFor != forOp)
          return;
        return getPipelineStages(forOp, ops, maxDepth, stage0Ops);
      };
  options.annotateFn = setAnnotation;
  if (!epiloguePeeling) {
    options.peelEpilogue = false;
    options.predicateFn = replaceOpWithPredicatedOp;
  }

  OpBuilder::InsertionGuard guard(rewriter);
  rewriter.setInsertionPoint(forOp);
  bool modifiedIR;
  FailureOr<scf::ForOp> maybePipelined =
      pipelineForLoop(rewriter, forOp, options, &modifiedIR);
  if (succeeded(maybePipelined)) {
    return std::make_tuple(DiagnosedSilenceableFailure::success(),
                           *maybePipelined);
  }
  return std::make_tuple(
      modifiedIR
          ? DiagnosedSilenceableFailure::definiteFailure()
          : emitSilenceableFailure(forOp, "pipelining preconditions failed"),
      scf::ForOp());
}

DiagnosedSilenceableFailure PipelineSharedMemoryCopiesOp::applyToOne(
    TransformRewriter &rewriter, scf::ForOp forOp,
    ApplyToEachResultList &results, TransformState &state) {
  auto [diag, pipelined] = pipelineForSharedCopies(
      rewriter, forOp, static_cast<int64_t>(getDepth()), getPeelEpilogue());
  if (diag.succeeded()) {
    results.push_back(pipelined);
    return DiagnosedSilenceableFailure::success();
  }
  if (diag.isDefiniteFailure()) {
    auto diag = emitDefiniteFailure("irreversible pipelining failure");
    if (!getPeelEpilogue()) {
      diag.attachNote(forOp->getLoc()) << "couldn't predicate?";
      diag.attachNote(getLoc()) << "try setting " << getPeelEpilogueAttrName();
    }
    return diag;
  }
  return std::move(diag);
}

//===----------------------------------------------------------------------===//
// RewriteMatmulAsMmaSyncOp
//===----------------------------------------------------------------------===//

/// Helper struct to encode a pair of row/column indexings in the form of
/// affine expressions.
struct RowColIndexing : private std::pair<AffineExpr, AffineExpr> {
  RowColIndexing(AffineExpr row, AffineExpr col)
      : std::pair<AffineExpr, AffineExpr>(row, col) {}

  AffineExpr row() const { return first; };
  AffineExpr col() const { return second; };

  void print(llvm::raw_ostream &os) const {
    os << "- indexing: " << first << ", " << second;
  }
};

/// Helper struct to provide a simple mapping from matmul operations to the
/// corresponding mma.sync operation. This is constrained to the case where the
/// matmul matches the mma.sync operation 1-1.
struct MmaSyncBuilder {
  MmaSyncBuilder(OpBuilder &b, Location loc, OpFoldResult laneId)
      : b(b), loc(loc), laneId(laneId) {}

  using IndexCalculator =
      std::function<SmallVector<RowColIndexing>(MLIRContext *)>;

  /// Create the mma.sync operation corresponding to `linalgOp` along with all
  /// the supporting load/store and vector operations.
  FailureOr<Operation *> buildMmaSync(LinalgOp linalgOp);

private:
  struct MmaSyncInfo {
    std::tuple<IndexCalculator, IndexCalculator, IndexCalculator> indexFns;
    std::tuple<SmallVector<int64_t>, SmallVector<int64_t>,
               SmallVector<int64_t>>
        vectorShapes;
    SmallVector<int64_t> mmaShape;
    bool tf32Enabled;
  };
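  // Illustrative example (mirrors getIndexCalculators below): for the
  // m16n8k4 tf32 case, the info is
  //   indexFns     = {m16n8k4tf32Lhs, m16n8k4tf32Rhs, m16n8k4tf32Res}
  //   vectorShapes = {2x1, 1x1, 2x2}
  //   mmaShape     = {16, 8, 4}
  //   tf32Enabled  = true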
  /// Return the specific index calculator for the given `linalgOp` or failure
  /// if the op is not supported. This is the toplevel switch that should just
  /// be Tablegen'd in the future.
  FailureOr<MmaSyncInfo> getIndexCalculators(ArrayRef<int64_t> opShape,
                                             TypeRange elementalTypes);

  //===--------------------------------------------------------------------===//
  // Instruction-specific row, column indexing expression builders.
  // These should all be declaratively specified via Tablegen in the future.
  // The Tablegen specification should be as straightforward as possible to
  // only model the existing size and type combinations.
  //===--------------------------------------------------------------------===//
  //
  // TODO: Tablegen all this.

  //===--------------------------------------------------------------------===//
  // m16n8k4 tf32 case.
  //===--------------------------------------------------------------------===//
  /// From the NVIDIA doc:
  ///   groupID         = %laneid >> 2
  ///   threadIDInGroup = %laneid % 4
  ///   row = groupID            for a0
  ///         groupID + 8        for a1
  ///   col = threadIDInGroup
  static SmallVector<RowColIndexing> m16n8k4tf32Lhs(MLIRContext *ctx) {
    auto dim = getAffineDimExpr(0, ctx);
    AffineExpr groupID = dim.floorDiv(4);
    AffineExpr threadIDInGroup = dim % 4;
    return {RowColIndexing{groupID, threadIDInGroup},
            RowColIndexing{groupID + 8, threadIDInGroup}};
  }

  /// From the NVIDIA doc:
  ///   groupID         = %laneid >> 2
  ///   threadIDInGroup = %laneid % 4
  ///   row = threadIDInGroup
  ///   col = groupID
  static SmallVector<RowColIndexing> m16n8k4tf32Rhs(MLIRContext *ctx) {
    auto dim = getAffineDimExpr(0, ctx);
    AffineExpr groupID = dim.floorDiv(4);
    AffineExpr threadIDInGroup = dim % 4;
    return {RowColIndexing{threadIDInGroup, groupID}};
  }

  /// From the NVIDIA doc:
  ///   groupID         = %laneid >> 2
  ///   threadIDInGroup = %laneid % 4
  ///   row = groupID                               for c0 and c1
  ///         groupID + 8                           for c2 and c3
  ///   col = (threadIDInGroup * 2) + (i & 0x1)     for ci where i = {0,..,3}
  static SmallVector<RowColIndexing> m16n8k4tf32Res(MLIRContext *ctx) {
    auto dim = getAffineDimExpr(0, ctx);
    AffineExpr groupID = dim.floorDiv(4);
    AffineExpr threadIDInGroup = dim % 4;
    return {RowColIndexing{groupID, threadIDInGroup * 2 + 0},
            RowColIndexing{groupID, threadIDInGroup * 2 + 1},
            RowColIndexing{groupID + 8, threadIDInGroup * 2 + 0},
            RowColIndexing{groupID + 8, threadIDInGroup * 2 + 1}};
  }
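  // Worked example (derived from the formulas above; illustrative only): for
  // %laneid == 5, groupID == 1 and threadIDInGroup == 1, so that lane owns
  //   a0 at (1, 1) and a1 at (9, 1) for the LHS,
  //   b0 at (1, 1) for the RHS, and
  //   c0..c3 at (1, 2), (1, 3), (9, 2), (9, 3) for the result.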
  //===--------------------------------------------------------------------===//
  // m16n8k16 f16 case.
  //===--------------------------------------------------------------------===//
  /// From the NVIDIA doc:
  ///   groupID         = %laneid >> 2
  ///   threadIDInGroup = %laneid % 4
  ///
  ///   row = groupID                               for ai where 0 <= i < 2 or
  ///                                                            4 <= i < 6
  ///         groupID + 8                           otherwise
  ///
  ///   col = (threadIDInGroup * 2) + (i & 0x1)     for ai where i < 4
  ///         (threadIDInGroup * 2) + (i & 0x1) + 8 for ai where i >= 4
  static SmallVector<RowColIndexing> m16n8k16f16Lhs(MLIRContext *ctx) {
    auto dim = getAffineDimExpr(0, ctx);
    AffineExpr groupID = dim.floorDiv(4);
    AffineExpr threadIDInGroup = dim % 4;
    // clang-format off
    return {
      RowColIndexing{groupID, threadIDInGroup * 2 + 0},         // i == 0
      RowColIndexing{groupID, threadIDInGroup * 2 + 1},         // i == 1
      RowColIndexing{groupID + 8, threadIDInGroup * 2 + 0},     // i == 2
      RowColIndexing{groupID + 8, threadIDInGroup * 2 + 1},     // i == 3
      RowColIndexing{groupID, threadIDInGroup * 2 + 0 + 8},     // i == 4
      RowColIndexing{groupID, threadIDInGroup * 2 + 1 + 8},     // i == 5
      RowColIndexing{groupID + 8, threadIDInGroup * 2 + 0 + 8}, // i == 6
      RowColIndexing{groupID + 8, threadIDInGroup * 2 + 1 + 8}  // i == 7
    };
    // clang-format on
  }

  /// From the NVIDIA doc:
  ///   groupID         = %laneid >> 2
  ///   threadIDInGroup = %laneid % 4
  ///
  ///   row = (threadIDInGroup * 2) + (i & 0x1)     for bi where i < 2
  ///         (threadIDInGroup * 2) + (i & 0x1) + 8 for bi where i >= 2
  ///
  ///   col = groupID
  static SmallVector<RowColIndexing> m16n8k16f16Rhs(MLIRContext *ctx) {
    auto dim = getAffineDimExpr(0, ctx);
    AffineExpr groupID = dim.floorDiv(4);
    AffineExpr threadIDInGroup = dim % 4;
    // clang-format off
    return {
      RowColIndexing{threadIDInGroup * 2 + 0, groupID},     // i == 0
      RowColIndexing{threadIDInGroup * 2 + 1, groupID},     // i == 1
      RowColIndexing{threadIDInGroup * 2 + 0 + 8, groupID}, // i == 2
      RowColIndexing{threadIDInGroup * 2 + 1 + 8, groupID}  // i == 3
    };
    // clang-format on
  }

  /// From the NVIDIA doc:
  ///   groupID         = %laneid >> 2
  ///   threadIDInGroup = %laneid % 4
  ///
  ///   row = groupID                               for ci where i < 2
  ///         groupID + 8                           for ci where i >= 2
  ///
  ///   col = (threadIDInGroup * 2) + (i & 0x1)     for ci where i = {0,..,3}
  static SmallVector<RowColIndexing> m16n8k16f16Res(MLIRContext *ctx) {
    auto dim = getAffineDimExpr(0, ctx);
    AffineExpr groupID = dim.floorDiv(4);
    AffineExpr threadIDInGroup = dim % 4;
    // clang-format off
    return {
      RowColIndexing{groupID, threadIDInGroup * 2 + 0},     // i == 0
      RowColIndexing{groupID, threadIDInGroup * 2 + 1},     // i == 1
      RowColIndexing{groupID + 8, threadIDInGroup * 2 + 0}, // i == 2
      RowColIndexing{groupID + 8, threadIDInGroup * 2 + 1}  // i == 3
    };
    // clang-format on
  }

  //===--------------------------------------------------------------------===//
  /// Helper functions to create customizable load and store operations. The
  /// specific shapes of each MMA instruction are passed via the
  /// IndexCalculator callback.
  //===--------------------------------------------------------------------===//

  /// Build a list of memref.load operations indexed at `(row, col)` indices
  /// that make sense for a particular MMA instruction and specified via the
  /// IndexCalculator callback.
  SmallVector<Value> buildMemrefLoads(OpBuilder &b, Location loc,
                                      OpFoldResult laneId, Value memref,
                                      IndexCalculator indexFn);
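  // Illustrative note: for the m16n8k16 f16 LHS operand (vectorShape 4x2),
  // the helper above emits 8 scalar memref.load ops per lane at the (row, col)
  // pairs returned by the IndexCalculator; buildMmaSyncMemrefLoadOperand below
  // then assembles them into a vector<4x2xf16>.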
  /// Perform a distributed load of a vector operand of `vectorShape` for a
  /// particular MMA instruction whose `(row, col)` indices are specified via
  /// the IndexCalculator callback. Each `laneId` loads the subportion of the
  /// data that makes sense for the particular MMA operation.
  /// The `vectorShape` matches existing NVGPU dialect op specification but
  /// could also be flattened in the future if needed for simplification.
  Value buildMmaSyncMemrefLoadOperand(OpBuilder &b, Location loc,
                                      OpFoldResult laneId, Value memref,
                                      IndexCalculator indexFn,
                                      ArrayRef<int64_t> vectorShape);

  /// Build a list of memref.store operations indexed at `(row, col)` indices
  /// that make sense for a particular MMA instruction and specified via the
  /// IndexCalculator callback.
  SmallVector<Operation *> buildMemrefStores(OpBuilder &b, Location loc,
                                             ValueRange toStore,
                                             OpFoldResult laneId, Value memref,
                                             IndexCalculator indexFn);

  /// Perform a distributed store of a vector operand of `vectorShape` for a
  /// particular MMA instruction whose `(row, col)` indices are specified via
  /// the IndexCalculator callback. Each `laneId` stores the subportion of the
  /// data that makes sense for the particular MMA operation.
  /// The `vectorShape` matches existing NVGPU dialect op specification but
  /// could also be flattened in the future if needed for simplification.
  SmallVector<Operation *> buildMmaSyncMemrefStoreOperand(
      OpBuilder &b, Location loc, Value vectorToStore, OpFoldResult laneId,
      Value memref, IndexCalculator indexFn, ArrayRef<int64_t> vectorShape);

  OpBuilder &b;
  Location loc;
  OpFoldResult laneId;
};

//===--------------------------------------------------------------------===//
/// Helper functions to create customizable load and store operations. The
/// specific shapes of each MMA instruction are passed via the
/// IndexCalculator callback.
//===--------------------------------------------------------------------===//

template <typename ApplyFn, typename ReduceFn>
static void foreachIndividualVectorElement(Value vector, ApplyFn applyFn,
                                           ReduceFn reduceFn) {
  VectorType vectorType = vector.getType().cast<VectorType>();
  auto vectorShape = vectorType.getShape();
  auto strides = computeStrides(vectorShape);
  for (int64_t idx = 0, e = vectorShape[0] * strides[0]; idx < e; ++idx) {
    auto indices = delinearize(idx, strides);
    reduceFn(applyFn(vector, idx, indices), idx, indices);
  }
}

SmallVector<Value> MmaSyncBuilder::buildMemrefLoads(OpBuilder &b, Location loc,
                                                    OpFoldResult laneId,
                                                    Value memref,
                                                    IndexCalculator indexFn) {
  auto aff = [&](AffineExpr e) {
    return affine::makeComposedFoldedAffineApply(b, loc, e, laneId);
  };
  SmallVector<Value> res;
  SmallVector<RowColIndexing> indexings = indexFn(b.getContext());
  for (auto indexing : indexings) {
    Value row = getValueOrCreateConstantIndexOp(b, loc, aff(indexing.row()));
    Value col = getValueOrCreateConstantIndexOp(b, loc, aff(indexing.col()));
    auto load = b.create<memref::LoadOp>(loc, memref, ValueRange{row, col});
    res.push_back(load);
  }
  return res;
}

Value MmaSyncBuilder::buildMmaSyncMemrefLoadOperand(
    OpBuilder &b, Location loc, OpFoldResult laneId, Value memref,
    IndexCalculator indexFn, ArrayRef<int64_t> vectorShape) {
  auto loads = buildMemrefLoads(b, loc, laneId, memref, indexFn);

  Type elementType = getElementTypeOrSelf(memref.getType());
  auto vt = VectorType::get(vectorShape, elementType);
  Value res = b.create<vector::SplatOp>(loc, vt, loads[0]);
  foreachIndividualVectorElement(
      res,
      /*applyFn=*/
      [&](Value v, int64_t linearIdx, ArrayRef<int64_t> indices) {
        return loads[linearIdx];
      },
      /*reduceFn=*/
      [&](Value v, int64_t linearIdx, ArrayRef<int64_t> indices) {
        res = b.create<vector::InsertOp>(loc, v, res, indices);
      });

  return res;
}
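// Illustrative sketch of what buildMmaSyncMemrefLoadOperand produces for a
// 2x2 f32 accumulator (pseudo-IR; value names are invented):
//
//   %l0, %l1, %l2, %l3 = ... 4 x memref.load ...
//   %v0 = vector.splat %l0 : vector<2x2xf32>
//   %v1 = vector.insert %l0, %v0 [0, 0] : f32 into vector<2x2xf32>
//   %v2 = vector.insert %l1, %v1 [0, 1] : f32 into vector<2x2xf32>
//   %v3 = vector.insert %l2, %v2 [1, 0] : f32 into vector<2x2xf32>
//   %v4 = vector.insert %l3, %v3 [1, 1] : f32 into vector<2x2xf32>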
SmallVector<Operation *> MmaSyncBuilder::buildMemrefStores(
    OpBuilder &b, Location loc, ValueRange toStore, OpFoldResult laneId,
    Value memref, IndexCalculator indexFn) {
  auto aff = [&](AffineExpr e) {
    return affine::makeComposedFoldedAffineApply(b, loc, e, laneId);
  };
  SmallVector<Operation *> res;
  for (auto [indexing, val] :
       llvm::zip_equal(indexFn(b.getContext()), toStore)) {
    Value row = getValueOrCreateConstantIndexOp(b, loc, aff(indexing.row()));
    Value col = getValueOrCreateConstantIndexOp(b, loc, aff(indexing.col()));
    Operation *store =
        b.create<memref::StoreOp>(loc, val, memref, ValueRange{row, col});
    res.push_back(store);
  }
  return res;
}

SmallVector<Operation *> MmaSyncBuilder::buildMmaSyncMemrefStoreOperand(
    OpBuilder &b, Location loc, Value vectorToStore, OpFoldResult laneId,
    Value memref, IndexCalculator indexFn, ArrayRef<int64_t> vectorShape) {
  SmallVector<Value> toStore;
  toStore.reserve(32);
  foreachIndividualVectorElement(
      vectorToStore,
      /*applyFn=*/
      [&](Value v, int64_t linearIdx, ArrayRef<int64_t> indices) {
        return b.create<vector::ExtractOp>(loc, vectorToStore, indices);
      },
      /*reduceFn=*/
      [&](Value v, int64_t linearIdx, ArrayRef<int64_t> indices) {
        toStore.push_back(v);
      });
  return buildMemrefStores(b, loc, toStore, laneId, memref, indexFn);
}

static std::tuple<SmallVector<int64_t>, SmallVector<int64_t>,
                  SmallVector<int64_t>>
makeVectorShapes(ArrayRef<int64_t> lhs, ArrayRef<int64_t> rhs,
                 ArrayRef<int64_t> res) {
  SmallVector<int64_t> vlhs{lhs.begin(), lhs.end()};
  SmallVector<int64_t> vrhs{rhs.begin(), rhs.end()};
  SmallVector<int64_t> vres{res.begin(), res.end()};
  return std::make_tuple(vlhs, vrhs, vres);
}

FailureOr<MmaSyncBuilder::MmaSyncInfo>
MmaSyncBuilder::getIndexCalculators(ArrayRef<int64_t> opShape,
                                    TypeRange elementalTypes) {
  // TODO: Tablegen all this.
  Type f16 = b.getF16Type();
  Type f32 = b.getF32Type();
  if (opShape == ArrayRef<int64_t>{16, 8, 4} &&
      elementalTypes == TypeRange{f32, f32, f32}) {
    return MmaSyncInfo{std::make_tuple(&MmaSyncBuilder::m16n8k4tf32Lhs,
                                       &MmaSyncBuilder::m16n8k4tf32Rhs,
                                       &MmaSyncBuilder::m16n8k4tf32Res),
                       makeVectorShapes({2, 1}, {1, 1}, {2, 2}),
                       SmallVector<int64_t>{opShape.begin(), opShape.end()},
                       /*tf32Enabled=*/true};
  }
  // This is the version with f16 accumulation.
  // TODO: version with f32 accumulation.
  if (opShape == ArrayRef<int64_t>{16, 8, 16} &&
      elementalTypes == TypeRange{f16, f16, f16}) {
    return MmaSyncInfo{std::make_tuple(&MmaSyncBuilder::m16n8k16f16Lhs,
                                       &MmaSyncBuilder::m16n8k16f16Rhs,
                                       &MmaSyncBuilder::m16n8k16f16Res),
                       makeVectorShapes({4, 2}, {2, 2}, {2, 2}),
                       SmallVector<int64_t>{opShape.begin(), opShape.end()},
                       /*tf32Enabled=*/false};
  }
  return failure();
}

FailureOr<Operation *> MmaSyncBuilder::buildMmaSync(LinalgOp linalgOp) {
  Value lhsMemref = linalgOp.getDpsInputOperand(0)->get();
  Value rhsMemref = linalgOp.getDpsInputOperand(1)->get();
  Value resMemref = linalgOp.getDpsInitOperand(0)->get();
  assert(lhsMemref.getType().cast<MemRefType>().getRank() == 2 &&
         "expected lhs to be a 2D memref");
  assert(rhsMemref.getType().cast<MemRefType>().getRank() == 2 &&
         "expected rhs to be a 2D memref");
  assert(resMemref.getType().cast<MemRefType>().getRank() == 2 &&
         "expected res to be a 2D memref");

  int64_t m = cast<MemRefType>(lhsMemref.getType()).getShape()[0];
  int64_t n = cast<MemRefType>(rhsMemref.getType()).getShape()[1];
  int64_t k = cast<MemRefType>(lhsMemref.getType()).getShape()[1];
  Type lhsType = getElementTypeOrSelf(lhsMemref.getType());
  Type rhsType = getElementTypeOrSelf(rhsMemref.getType());
  Type resType = getElementTypeOrSelf(resMemref.getType());

  FailureOr<MmaSyncInfo> maybeInfo =
      getIndexCalculators({m, n, k}, {lhsType, rhsType, resType});
  if (failed(maybeInfo))
    return failure();

  MmaSyncInfo info = *maybeInfo;
  auto [lhsIndexFn, rhsIndexFn, resIndexFn] = info.indexFns;
  auto [lhsShape, rhsShape, resShape] = info.vectorShapes;
  Value lhs = buildMmaSyncMemrefLoadOperand(b, loc, laneId, lhsMemref,
                                            lhsIndexFn, lhsShape);
  Value rhs = buildMmaSyncMemrefLoadOperand(b, loc, laneId, rhsMemref,
                                            rhsIndexFn, rhsShape);
  Value res = buildMmaSyncMemrefLoadOperand(b, loc, laneId, resMemref,
                                            resIndexFn, resShape);
  res = b.create<nvgpu::MmaSyncOp>(loc, lhs, rhs, res, info.mmaShape,
                                   info.tf32Enabled);
  buildMmaSyncMemrefStoreOperand(b, loc, res, laneId, resMemref, resIndexFn,
                                 resShape);
  return res.getDefiningOp();
}

DiagnosedSilenceableFailure transform::RewriteMatmulAsMmaSyncOp::applyToOne(
    transform::TransformRewriter &rewriter, LinalgOp linalgOp,
    transform::ApplyToEachResultList &results,
    transform::TransformState &state) {
  bool fail = true;
  // TODO: more robust detection of matmulOp, with transposes etc.
  if (isa_and_nonnull<linalg::MatmulOp>(linalgOp.getOperation())) {
    Location loc = linalgOp.getLoc();
    // TODO: more robust computation of laneId, for now assume a single warp.
    Value laneId = rewriter.create<gpu::ThreadIdOp>(
        loc, rewriter.getIndexType(), gpu::Dimension::x);
    if (succeeded(MmaSyncBuilder(rewriter, loc, laneId).buildMmaSync(linalgOp)))
      fail = false;
  }

  if (fail) {
    DiagnosedSilenceableFailure diag = emitSilenceableError()
                                       << "unsupported target op: " << linalgOp;
    diag.attachNote(linalgOp->getLoc()) << "target op";
    return diag;
  }

  rewriter.eraseOp(linalgOp);
  return DiagnosedSilenceableFailure::success();
}

//===----------------------------------------------------------------------===//
// Transform op registration
//===----------------------------------------------------------------------===//

namespace {
class NVGPUTransformDialectExtension
    : public transform::TransformDialectExtension<
          NVGPUTransformDialectExtension> {
public:
  NVGPUTransformDialectExtension() {
    declareGeneratedDialect<arith::ArithDialect>();
    declareGeneratedDialect<affine::AffineDialect>();
    declareGeneratedDialect<nvgpu::NVGPUDialect>();
    declareGeneratedDialect<vector::VectorDialect>();
    registerTransformOps<
#define GET_OP_LIST
#include "mlir/Dialect/NVGPU/TransformOps/NVGPUTransformOps.cpp.inc"
        >();
  }
};
} // namespace

#define GET_OP_CLASSES
#include "mlir/Dialect/NVGPU/TransformOps/NVGPUTransformOps.cpp.inc"

void mlir::nvgpu::registerTransformDialectExtension(
    DialectRegistry &registry) {
  registry.addExtensions<NVGPUTransformDialectExtension>();
}
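// Typical registration from a tool (illustrative; the surrounding tool setup
// is an assumption, not part of this file):
//
//   mlir::DialectRegistry registry;
//   mlir::nvgpu::registerTransformDialectExtension(registry);
//   mlir::MLIRContext context(registry);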