//===- LinalgTransformOps.cpp - Implementation of Linalg transform ops ----===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.h"

#include "mlir/AsmParser/AsmParser.h"
#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Bufferization/IR/Bufferization.h"
#include "mlir/Dialect/Bufferization/Transforms/OneShotAnalysis.h"
#include "mlir/Dialect/GPU/IR/GPUDialect.h"
#include "mlir/Dialect/Linalg/IR/Linalg.h"
#include "mlir/Dialect/Linalg/TransformOps/GPUHeuristics.h"
#include "mlir/Dialect/Linalg/TransformOps/Syntax.h"
#include "mlir/Dialect/Linalg/Transforms/Hoisting.h"
#include "mlir/Dialect/Linalg/Transforms/Transforms.h"
#include "mlir/Dialect/Linalg/Utils/Utils.h"
#include "mlir/Dialect/SCF/Transforms/TileUsingInterface.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/Dialect/Tensor/Utils/Utils.h"
#include "mlir/Dialect/Transform/IR/TransformDialect.h"
#include "mlir/Dialect/Transform/IR/TransformOps.h"
#include "mlir/Dialect/Transform/IR/TransformTypes.h"
#include "mlir/Dialect/Transform/Interfaces/TransformInterfaces.h"
#include "mlir/Dialect/Transform/Utils/Utils.h"
#include "mlir/Dialect/Utils/IndexingUtils.h"
#include "mlir/Dialect/Utils/StaticValueUtils.h"
#include "mlir/Dialect/Vector/Transforms/LoweringPatterns.h"
#include "mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h"
#include "mlir/IR/BuiltinTypeInterfaces.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/IR/TypeUtilities.h"
#include "mlir/Interfaces/TilingInterface.h"
#include "mlir/Support/LLVM.h"
#include "mlir/Support/TypeID.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/ScopeExit.h"
#include "llvm/ADT/TypeSwitch.h"
#include "llvm/Support/Debug.h"
#include <type_traits>

using namespace mlir;
using namespace mlir::linalg;
using namespace mlir::transform;

#define DEBUG_TYPE "linalg-transforms"
#define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE "]: ")
#define DBGSNL() (llvm::dbgs() << "\n")
#define LDBG(X) LLVM_DEBUG(DBGS() << (X) << "\n")

/// Attempts to apply the pattern specified as template argument to the given
/// operation. The pattern is expected to have a `returningMatchAndRewrite`
/// function that returns the "main" result or failure. Returns failure if the
/// pattern failed to apply. Extra arguments are forwarded to the pattern
/// constructor.
template <typename PatternTy, typename... Args>
static FailureOr<LinalgOp> tryApply(Operation *operation, Args &&...args) {
  // Check if the given operation has the type expected by the pattern.
  using OpTy = typename llvm::function_traits<
      decltype(&PatternTy::returningMatchAndRewrite)>::template arg_t<0>;
  auto op = dyn_cast<OpTy>(operation);
  if (!op)
    return failure();

  // Apply the pattern directly to the op.
  PatternTy pattern(operation->getContext(), std::forward<Args>(args)...);
  // We want to discourage direct use of PatternRewriter in APIs but in this
  // very specific case, an IRRewriter is not enough.
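  // Note: the rewriter defined below is a bare PatternRewriter subclass with
  // no driver attached; it exists only to satisfy the
  // `returningMatchAndRewrite` signature.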
struct TrivialPatternRewriter : public PatternRewriter { public: explicit TrivialPatternRewriter(MLIRContext *context) : PatternRewriter(context) {} }; TrivialPatternRewriter rewriter(operation->getContext()); rewriter.setInsertionPoint(operation); auto result = pattern.returningMatchAndRewrite(op, rewriter); if (failed(result)) return failure(); return cast(result->getOperation()); } /// Assuming that `ofr` is an index attr or a param of index type /// or a transform dialect handle mapped to exactly one op /// with one index result, return that value. static DiagnosedSilenceableFailure unpackSingleIndexResultPayloadOperations( transform::TransformState &state, TransformOpInterface transformOp, SmallVector &result, ArrayRef ofrs) { for (OpFoldResult ofr : ofrs) { if (ofr.is()) { if (!isa(ofr.get())) return transformOp.emitDefiniteFailure() << "expected IntegerAttr"; result.push_back(ofr); continue; } Value transformValue = ofr.get(); if (isa(transformValue.getType())) { ArrayRef params = state.getParams(transformValue); if (params.size() != 1) return transformOp.emitDefiniteFailure() << "requires exactly one parameter associated"; result.push_back(params[0]); continue; } auto payloadOps = state.getPayloadOps(transformValue); if (!llvm::hasSingleElement(payloadOps)) { DiagnosedSilenceableFailure diag = transformOp.emitSilenceableError() << "handle must be mapped to exactly one payload op"; diag.attachNote(transformValue.getLoc()) << "mapped to " << llvm::range_size(payloadOps) << " payload ops"; return diag; } Operation *op = *payloadOps.begin(); if (op->getNumResults() != 1 || !op->getResult(0).getType().isIndex()) { DiagnosedSilenceableFailure diag = transformOp.emitSilenceableError() << "payload op must have exactly 1 index result"; diag.attachNote(op->getLoc()) << "has " << op->getNumResults() << " results"; return diag; } result.push_back(op->getResult(0)); } return DiagnosedSilenceableFailure::success(); } // Given a list of params that are index attrs or a list of OpFoldResults // that are either index attrs or op handles, return a list of OpFoldResults // of index attrs or a list of OpFoldResults where all op handles are // replaced with the first (and only) OpResult of that payload op. // (There must be exactly one parameter associated with the AnyParamType or // one mapped payload op which must have exactly one index result.) 
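// For example, a param handle carrying the index attributes [2, 4] unpacks to
// those two attributes, and an op handle whose payload ops each produce a
// single index result unpacks to those result values. (Illustrative only.)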
static DiagnosedSilenceableFailure unpackSingleIndexResultPayloadOperations( transform::TransformState &state, TransformOpInterface transformOp, SmallVector &result, Value packedHandle) { if (isa(packedHandle.getType())) { ArrayRef params = state.getParams(packedHandle); for (auto param : params) { if (!isa(param)) return transformOp.emitDefiniteFailure() << "expected the parameter to be associated with an integer " "attribute"; result.push_back(param); } return DiagnosedSilenceableFailure::success(); } for (Operation *op : state.getPayloadOps(packedHandle)) { if (op->getNumResults() != 1 || !op->getResult(0).getType().isIndex()) { DiagnosedSilenceableFailure diag = transformOp.emitSilenceableError() << "payload op must have exactly 1 index result"; diag.attachNote(op->getLoc()) << "has " << op->getNumResults() << " results"; return diag; } result.push_back(op->getResult(0)); } return DiagnosedSilenceableFailure::success(); } //===----------------------------------------------------------------------===// // Apply...PatternsOp //===----------------------------------------------------------------------===// void transform::ApplyEraseUnnecessaryInputsPatternsOp::populatePatterns( RewritePatternSet &patterns) { linalg::populateEraseUnnecessaryInputsPatterns(patterns); } void transform::ApplyFoldUnitExtentDimsViaReshapesPatternsOp::populatePatterns( RewritePatternSet &patterns) { linalg::ControlDropUnitDims options; linalg::populateFoldUnitExtentDimsPatterns(patterns, options); } void transform::ApplyFoldUnitExtentDimsViaSlicesPatternsOp::populatePatterns( RewritePatternSet &patterns) { linalg::ControlDropUnitDims options; options.rankReductionStrategy = linalg::ControlDropUnitDims::RankReductionStrategy::ExtractInsertSlice; linalg::populateFoldUnitExtentDimsPatterns(patterns, options); } void transform::ApplyTilingCanonicalizationPatternsOp::populatePatterns( RewritePatternSet &patterns) { linalg::populateLinalgTilingCanonicalizationPatterns(patterns); } //===----------------------------------------------------------------------===// // BufferizeToAllocationOp //===----------------------------------------------------------------------===// void transform::BufferizeToAllocationOp::build(OpBuilder &b, OperationState &result, Value target, Attribute memorySpace) { SmallVector resultTypes; resultTypes.push_back(b.getType()); resultTypes.push_back(b.getType()); return build(b, result, /*resultTypes=*/resultTypes, /*target=*/target, /*memorySpace=*/memorySpace); } void transform::BufferizeToAllocationOp::build(OpBuilder &b, OperationState &result, Value target, int64_t memorySpace) { SmallVector resultTypes; resultTypes.push_back(b.getType()); resultTypes.push_back(b.getType()); return build(b, result, /*resultTypes=*/resultTypes, /*target=*/target, /*memorySpace=*/b.getI64IntegerAttr(memorySpace)); } namespace { class NewOpsListener : public RewriterBase::ForwardingListener { public: using RewriterBase::ForwardingListener::ForwardingListener; SmallVector getNewOps() const { return SmallVector(newOps.begin(), newOps.end()); } private: void notifyOperationInserted(Operation *op, OpBuilder::InsertPoint previous) override { ForwardingListener::notifyOperationInserted(op, previous); // We only care about newly created ops. 
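    // An insertion notification that carries a valid previous insertion point
    // corresponds to an op being moved, not created, so it is skipped below.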
if (previous.isSet()) return; auto inserted = newOps.insert(op); (void)inserted; assert(inserted.second && "expected newly created op"); } void notifyOperationErased(Operation *op) override { ForwardingListener::notifyOperationErased(op); op->walk([&](Operation *op) { newOps.erase(op); }); } DenseSet newOps; }; } // namespace DiagnosedSilenceableFailure transform::BufferizeToAllocationOp::apply( transform::TransformRewriter &rewriter, transform::TransformResults &results, transform::TransformState &state) { // Attach listener to keep track of newly created ops. OpBuilder::Listener *previousListener = rewriter.getListener(); auto resetListener = llvm::make_scope_exit([&]() { rewriter.setListener(previousListener); }); NewOpsListener newOpsListener(previousListener); rewriter.setListener(&newOpsListener); linalg::BufferizeToAllocationOptions options; if (getMemcpyOp() == "bufferization.materialize_in_destination") { options.memcpyOp = linalg::BufferizeToAllocationOptions::MemcpyOp:: MaterializeInDestination; } else if (getMemcpyOp() == "memref.copy") { options.memcpyOp = linalg::BufferizeToAllocationOptions::MemcpyOp::MemrefCopy; } else if (getMemcpyOp() == "linalg.copy") { options.memcpyOp = linalg::BufferizeToAllocationOptions::MemcpyOp::LinalgCopy; } else { llvm_unreachable("invalid memcpy op"); } if (getAllocOp() == "memref.alloc") { options.allocOp = linalg::BufferizeToAllocationOptions::AllocOp::MemrefAlloc; } else if (getAllocOp() == "memref.alloca") { options.allocOp = linalg::BufferizeToAllocationOptions::AllocOp::MemrefAlloca; } else { llvm_unreachable("invalid alloc op"); } options.bufferizeDestinationOnly = getBufferizeDestinationOnly(); options.emitDealloc = getEmitDealloc(); // Bufferize ops. Attribute memorySpace = getMemorySpace().has_value() ? getMemorySpace().value() : Attribute(); SmallVector allocatedBuffers; for (Operation *op : state.getPayloadOps(getTarget())) { Value buffer = linalg::bufferizeToAllocation(rewriter, options, op, memorySpace); if (!buffer) { DiagnosedSilenceableFailure diag = emitSilenceableError() << "failed to bufferize operation"; diag.attachNote(op->getLoc()) << "target payload op"; return diag; } allocatedBuffers.push_back(buffer); } // Set results. results.setValues(cast(getAllocatedBuffer()), allocatedBuffers); results.set(cast(getNewOps()), newOpsListener.getNewOps()); return DiagnosedSilenceableFailure::success(); } void transform::BufferizeToAllocationOp::getEffects( SmallVectorImpl &effects) { if (getBufferizeDestinationOnly()) { // The destination is replaced with a newly allocated buffer, but the op // itself remains in place. 
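    // Hence the target handle stays valid and is only read here.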
onlyReadsHandle(getTarget(), effects); } else { consumesHandle(getTarget(), effects); } producesHandle(getAllocatedBuffer(), effects); producesHandle(getNewOps(), effects); modifiesPayload(effects); } LogicalResult transform::BufferizeToAllocationOp::verify() { if (getMemcpyOp() != "bufferization.materialize_in_destination" && getMemcpyOp() != "memref.copy" && getMemcpyOp() != "linalg.copy") return emitOpError() << "unsupported memcpy op"; if (getAllocOp() != "memref.alloc" && getAllocOp() != "memref.alloca") return emitOpError() << "unsupported alloc op"; return success(); } //===----------------------------------------------------------------------===// // DecomposeOp //===----------------------------------------------------------------------===// DiagnosedSilenceableFailure transform::DecomposeOp::applyToOne(transform::TransformRewriter &rewriter, LinalgOp target, transform::ApplyToEachResultList &results, transform::TransformState &state) { #define DOWNSCALE(trans) \ { \ FailureOr res = tryApply(target); \ if (succeeded(res)) { \ results.push_back(*res); \ return DiagnosedSilenceableFailure::success(); \ } \ } #define DOWNSCALE_CALL(a, b) DownscaleSizeOneWindowed2DConvolution #define DOWNSCALE_NORMAL(a, b) DOWNSCALE(DOWNSCALE_CALL(a, b)) DOWNSCALE_NORMAL(Conv2DNhwcHwcfOp, Conv1DNwcWcfOp) DOWNSCALE_NORMAL(Conv2DNchwFchwOp, Conv1DNcwFcwOp) DOWNSCALE_NORMAL(PoolingNhwcSumOp, PoolingNwcSumOp) DOWNSCALE_NORMAL(PoolingNchwSumOp, PoolingNcwSumOp) DOWNSCALE_NORMAL(PoolingNhwcMaxOp, PoolingNwcMaxOp) DOWNSCALE_NORMAL(PoolingNhwcMaxUnsignedOp, PoolingNwcMaxUnsignedOp) DOWNSCALE_NORMAL(PoolingNhwcMinOp, PoolingNwcMinOp) DOWNSCALE_NORMAL(PoolingNhwcMinUnsignedOp, PoolingNwcMinUnsignedOp) DOWNSCALE_NORMAL(PoolingNchwMaxOp, PoolingNcwMaxOp) DOWNSCALE(DownscaleDepthwiseConv2DNhwcHwcOp) DOWNSCALE(DownscaleConv2DOp) #undef DOWNSCALE_NORMAL #undef DOWNSCALE_CALL #undef DOWNSCALE return emitDefaultSilenceableFailure(target); } //===----------------------------------------------------------------------===// // DecomposeInterfaceOp //===----------------------------------------------------------------------===// // Decompose the target operation if it implements the AggregatedOpInterface. // Push the decomposed operations (the ones that replaces the values produced by // \p target) in the `results`. 
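// For example, `linalg.softmax` implements AggregatedOpInterface and
// decomposes into a sequence of simpler linalg ops (illustrative; the exact
// set of implementing ops depends on the registered dialects).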
DiagnosedSilenceableFailure transform::DecomposeInterfaceOp::applyToOne( transform::TransformRewriter &rewriter, Operation *target, transform::ApplyToEachResultList &results, transform::TransformState &state) { auto decomposableOp = dyn_cast(target); if (!decomposableOp) { failed(rewriter.notifyMatchFailure(target, "payload is not a decomposable op")); return emitDefaultSilenceableFailure(target); } FailureOr> maybeNewResults = decomposableOp.decomposeOperation(rewriter); if (failed(maybeNewResults)) return emitDefaultSilenceableFailure(target); rewriter.replaceOp(decomposableOp, *maybeNewResults); for (Value val : *maybeNewResults) { Operation *definition = val.getDefiningOp(); if (definition) results.push_back(definition); } return DiagnosedSilenceableFailure::success(); } //===----------------------------------------------------------------------===// // EliminateLinalgOpAnchoredEmptyTensorsOp //===----------------------------------------------------------------------===// void transform::EliminateLinalgOpAnchoredEmptyTensorsOp::getEffects( SmallVectorImpl &effects) { onlyReadsHandle(getTarget(), effects); modifiesPayload(effects); } DiagnosedSilenceableFailure transform::EliminateLinalgOpAnchoredEmptyTensorsOp::apply( transform::TransformRewriter &rewriter, TransformResults &transformResults, TransformState &state) { bufferization::OneShotBufferizationOptions options; options.allowReturnAllocsFromLoops = true; for (Operation *target : state.getPayloadOps(getTarget())) { bufferization::OneShotAnalysisState state(target, options); if (failed(analyzeOp(target, state))) return mlir::emitSilenceableFailure(target->getLoc()) << "failed to analyze op"; if (failed(linalg::linalgOpAnchoredEmptyTensorEliminationStep( rewriter, target, state))) return mlir::emitSilenceableFailure(target->getLoc()) << "failed to eliminate LinalgOp anchored tensor.empty ops"; } return DiagnosedSilenceableFailure::success(); } //===----------------------------------------------------------------------===// // FuseOp //===----------------------------------------------------------------------===// /// Apply a tiling transformation to all payload ops and store both the /// tiled operation as well as the created tile loops. template static LogicalResult applyTilingToAll( RewriterBase &rewriter, Operation *transformOp, Range &&payloadOps, unsigned numLoops, transform::TransformResults &transformResults, function_ref(TilingInterface)> applyFn) { SmallVector tiledLinalgOps; SmallVector> loopOps(numLoops); for (Operation *target : payloadOps) { auto tilingInterfaceOp = dyn_cast(target); if (!tilingInterfaceOp) return transformOp->emitError("only TilingInterface ops are supported"); rewriter.setInsertionPoint(target); FailureOr tiledResults = applyFn(tilingInterfaceOp); if (failed(tiledResults)) return failure(); // Perform the replacement of tiled and fused values. SmallVector opsToReplace{target}; llvm::append_range(opsToReplace, tiledResults->fusedProducers); for (Operation *toReplace : opsToReplace) { for (OpResult res : toReplace->getResults()) if (auto replacement = tiledResults->replacements.lookup(res)) rewriter.replaceAllUsesWith(res, replacement); if (toReplace->use_empty()) { rewriter.eraseOp(toReplace); } } // Report back the relevant handles to the transform op. 
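    // The first tiled-and-fused op stands for the original target; one handle
    // per requested loop nesting level is accumulated alongside it.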
tiledLinalgOps.push_back(tiledResults->tiledAndFusedOps.front()); assert(tiledResults->loops.size() == numLoops && "Mismatched number of loops, tile and fuse transform should have " "failed"); for (unsigned int i = 0; i < numLoops; ++i) loopOps[i].push_back(tiledResults->loops[i]); } transformResults.set(transformOp->getOpResult(0), tiledLinalgOps); for (unsigned int i = 0; i < numLoops; ++i) transformResults.set(transformOp->getOpResult(i + 1), loopOps[i]); return success(); } DiagnosedSilenceableFailure transform::FuseOp::apply(transform::TransformRewriter &rewriter, mlir::transform::TransformResults &transformResults, mlir::transform::TransformState &state) { SmallVector tileSizes = extractFromIntegerArrayAttr(getTileSizes()); SmallVector tileInterchange = extractFromIntegerArrayAttr(getTileInterchange()); scf::SCFTilingOptions tilingOptions; tilingOptions.interchangeVector = tileInterchange; SmallVector tileSizesOfr = getAsIndexOpFoldResult(rewriter.getContext(), tileSizes); tilingOptions = tilingOptions.setTileSizes(tileSizesOfr); scf::SCFTileAndFuseOptions tileAndFuseOptions; tileAndFuseOptions.tilingOptions = tilingOptions; LogicalResult result = applyTilingToAll( rewriter, getOperation(), state.getPayloadOps(getTarget()), tileSizes.size() - llvm::count(tileSizes, 0), transformResults, [&](TilingInterface tilingInterfaceOp) -> FailureOr { return tileConsumerAndFuseProducersUsingSCF(rewriter, tilingInterfaceOp, tileAndFuseOptions); }); return failed(result) ? DiagnosedSilenceableFailure::definiteFailure() : DiagnosedSilenceableFailure::success(); } LogicalResult transform::FuseOp::verify() { SmallVector permutation = extractFromIntegerArrayAttr(getTileInterchange()); auto sequence = llvm::to_vector(llvm::seq(0, permutation.size())); if (!std::is_permutation(sequence.begin(), sequence.end(), permutation.begin(), permutation.end())) { return emitOpError() << "expects interchange to be a permutation, found " << getTileInterchange(); } SmallVector sizes = extractFromIntegerArrayAttr(getTileSizes()); size_t numExpectedLoops = sizes.size() - llvm::count(sizes, 0); if (numExpectedLoops != getNumResults() - 1) return emitOpError() << "expects " << numExpectedLoops << " loop results"; return success(); } //===----------------------------------------------------------------------===// // FuseIntoContainingOp //===----------------------------------------------------------------------===// void transform::FuseIntoContainingOp::build(OpBuilder &builder, OperationState &result, Value producerOp, Value containingOp) { result.addOperands({producerOp, containingOp}); auto resultType = transform::AnyOpType::get(builder.getContext()); result.addTypes({resultType, resultType}); } /// Add new operands to the forall op for users of the producerOp /// that are dominated by the containing scf.forall op. 
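/// The containing scf.forall is recreated with one extra shared output so that
/// the fused producer's tile can be inserted into it and yielded as a new loop
/// result, which then replaces the dominated uses.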
static Operation *replaceForAllWithNewSignature( RewriterBase &rewriter, Diagnostic &diag, Operation *producerOp, Operation *containingOp, TilingResult &tileAndFuseResult, int64_t resultNumber, SmallVector &offsets, SmallVector &sizes) { // Count number of users not including the containing op SetVector dominatedUsers; DominanceInfo domInfo(containingOp); for (Operation *user : producerOp->getResult(resultNumber).getUsers()) { if (!containingOp->isAncestor(user) && (domInfo.dominates(containingOp, user))) { dominatedUsers.insert(user); } } if (dominatedUsers.empty()) return nullptr; // Create new scf.forall op auto forallOp = cast(containingOp); OpBuilder::InsertionGuard g(rewriter); rewriter.setInsertionPoint(forallOp); // Get new output Location loc = forallOp.getLoc(); auto genericOp = dyn_cast(producerOp); if (!genericOp) return nullptr; SmallVector outputs = genericOp.getOutputs(); SmallVector newOuts(forallOp.getOutputs()); newOuts.push_back(outputs[resultNumber]); // Create new scf.forall op auto newforallOp = rewriter.create( loc, forallOp.getMixedLowerBound(), forallOp.getMixedUpperBound(), forallOp.getMixedStep(), newOuts, forallOp.getMapping()); rewriter.eraseBlock(newforallOp.getBody()); newforallOp.getRegion().takeBody(forallOp.getRegion()); // Add additional block argument for new value being returned // and replaces all uses of the new output with corresponding bbArg // inside the scf.forall to enable fusion into this new scf.forall. newforallOp.getBody()->addArgument(newOuts.back().getType(), newOuts.back().getLoc()); auto bbArgs = newforallOp.getBody()->getArguments(); rewriter.replaceUsesWithIf(newOuts.back(), bbArgs.back(), [&](OpOperand &use) { Operation *op = use.getOwner(); return newforallOp->isProperAncestor(op); }); // Fix terminator scf::InParallelOp terminatorOp = newforallOp.getTerminator(); SmallVector yieldingOps = llvm::to_vector<4>(llvm::map_range( terminatorOp.getYieldingOps(), [](Operation &op) { return &op; })); Operation *firstYieldOp = yieldingOps.front(); rewriter.setInsertionPoint(firstYieldOp); Value src = tileAndFuseResult.tiledValues[0]; Value dst = newforallOp.getRegionIterArgs().back(); SmallVector strides(offsets.size(), rewriter.getIndexAttr(1)); rewriter.create(firstYieldOp->getLoc(), src, dst, offsets, sizes, strides); for (auto result : llvm::enumerate(forallOp.getResults())) { rewriter.replaceAllUsesWith(result.value(), newforallOp->getResult(result.index())); } rewriter.replaceUsesWithIf(producerOp->getResult(resultNumber), newforallOp->getResults().back(), [&](OpOperand &use) { Operation *user = use.getOwner(); return dominatedUsers.contains(user); }); return newforallOp; } /// Find the first "extract" user of `producerOp` and tile it right before its /// use. The tiled op is fused under the `containingOp`. /// Return this fused op on success or nullptr if anything fails. /// If tiled op has uses that are dominated by `containingOp`, return /// a new `containingOp` with results of the fused op appended to /// results of the `containingOp` or nullptr if there are no dominated uses. static std::tuple, Operation *> tileAndFuseFirstExtractUse(RewriterBase &rewriter, Diagnostic &diag, Operation *producerOp, Operation *containingOp) { LLVM_DEBUG(DBGS() << "Try to fuse a direct extract use\n"); auto tileableProducer = dyn_cast(producerOp); if (!tileableProducer) { diag.attachNote(producerOp->getLoc()) << "producer is not a TileableInterface: " << *producerOp; return {}; } // Search the producer slices accessed within the containing operation. 
// TODO: Generalize to more extract/insert/parallel_insert triples, maybe // evolve into an interface. auto it = llvm::find_if(tileableProducer->getUsers(), [&](Operation *user) { auto sliceOp = dyn_cast(user); return sliceOp && containingOp->isProperAncestor(sliceOp); }); // Find a fusion opportunity. if (it == tileableProducer->getUsers().end()) { diag.attachNote(tileableProducer->getLoc()) << "could not find fusion opportunity for: " << *tileableProducer; return {}; } auto sliceOpToTile = cast(*it); // Try to fuse the producer in-place. OpBuilder::InsertionGuard guard(rewriter); rewriter.setInsertionPoint(sliceOpToTile); // Tile the producer. int64_t resultNumber = cast(sliceOpToTile.getSource()).getResultNumber(); LLVM_DEBUG(DBGS() << "resultNumber: " << resultNumber << "\n"); SmallVector offsets = sliceOpToTile.getMixedOffsets(); SmallVector sizes = sliceOpToTile.getMixedSizes(); FailureOr tileAndFuseResult = tileableProducer.generateResultTileValue(rewriter, resultNumber, offsets, sizes); if (failed(tileAndFuseResult)) { diag.attachNote(tileableProducer->getLoc()) << "failed to tile producer op: " << *tileableProducer; return {}; } #ifndef NDEBUG for (auto *tiledOp : tileAndFuseResult->tiledOps) { LLVM_DEBUG(DBGS() << "tiledProducer: " << *tiledOp << "\n"); } #endif // Replace the extract op. auto maybeRankReduced = tensor::ExtractSliceOp::rankReduceIfNeeded( rewriter, sliceOpToTile->getLoc(), tileAndFuseResult->tiledValues[0], cast(sliceOpToTile->getResult(0).getType()).getShape()); if (failed(maybeRankReduced)) { diag.attachNote(producerOp->getLoc()) << "shape types don't match (missing canonicalization?):\nTiledOp: " << tileAndFuseResult->tiledValues[0] << "\nSliceOp: " << sliceOpToTile.getOperation() << '\n'; return {}; } rewriter.replaceOp(sliceOpToTile, *maybeRankReduced); // Add new outputs to containing op, if required Operation *newContainingOp = replaceForAllWithNewSignature( rewriter, diag, producerOp, containingOp, *tileAndFuseResult, resultNumber, offsets, sizes); return std::make_tuple(tileAndFuseResult->tiledOps, newContainingOp); } /// First, find the first "scf::ForallOp" user of `producerOp` and ensure /// it is exactly the `containingOp`, otherwise bail. /// Then, find the first "extract" user of the tied block argument and tile it /// right before its "extract" use. The tiled op is fused under the /// `containingOp`. /// Return this fused op on success or nullptr if anything fails. static SmallVector tileAndFuseFirstExtractUseThroughContainingOpBlockArgument( RewriterBase &rewriter, Diagnostic &diag, Operation *producerOp, Operation *containingOp) { LLVM_DEBUG(DBGS() << "Try to fuse an extract use through block argument\n"); auto tileableProducer = dyn_cast(producerOp); if (!tileableProducer) { diag.attachNote(producerOp->getLoc()) << "producer is not a TileableInterface: " << *producerOp; return {}; } // Search the first use by a "scf::ForallOp" user. scf::ForallOp forallOp; auto itProducerUses = llvm::find_if(tileableProducer->getUses(), [&](OpOperand &use) { forallOp = dyn_cast(use.getOwner()); return forallOp; }); // If it's not from the containing op, return. if (!forallOp || forallOp != containingOp) { diag.attachNote(tileableProducer->getLoc()) << "could not find a use by the containing op: " << *tileableProducer; return {}; } // Search the producer slices accessed within the containing // operation. // TODO: Generalize to more extract/insert/parallel_insert triples. // Maybe evolve into an interface. 
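  // The matched use identifies the forall output operand that carries the
  // producer's result; fusion then proceeds at the tied block argument.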
OpOperand *pUse = &(*itProducerUses); BlockArgument bbArg = forallOp.getTiedBlockArgument(pUse); // Search the producer slices accessed within the containing operation. // TODO: Generalize to more extract/insert/parallel_insert triples, maybe // evolve into an interface. auto itBBArgUsers = llvm::find_if(bbArg.getUsers(), [&](Operation *user) { auto sliceOp = dyn_cast(user); return sliceOp && containingOp->isProperAncestor(sliceOp); }); // Find a fusion opportunity. if (itBBArgUsers == bbArg.getUsers().end()) { diag.attachNote(containingOp->getLoc()) << "could not find fusion opportunity for bbArg: " << bbArg; return {}; } auto sliceOpToTile = cast(*itBBArgUsers); // Try to fuse the producer in-place. OpBuilder::InsertionGuard guard(rewriter); rewriter.setInsertionPoint(sliceOpToTile); // Replace the use in the tileableProducer before tiling: clone, replace and // then tile. int64_t resultNumber = cast(pUse->get()).getResultNumber(); LLVM_DEBUG(DBGS() << "resultNumber: " << resultNumber << "\n"); // Gather destination tensors. SmallVector destinationTensors; if (failed(tensor::getOrCreateDestinations( rewriter, tileableProducer->getLoc(), tileableProducer, destinationTensors))) { diag.attachNote(tileableProducer->getLoc()) << "failed to get destination tensors for: " << *tileableProducer; return {}; } IRMapping bvm; bvm.map(destinationTensors[resultNumber], bbArg); auto tileableProducerClone = cast(rewriter.clone(*tileableProducer, bvm)); auto scopeGuard = llvm::make_scope_exit([&]() { rewriter.eraseOp(tileableProducerClone); }); // Tile the producer. FailureOr tileAndFuseResult = tileableProducerClone.generateResultTileValue( rewriter, resultNumber, sliceOpToTile.getMixedOffsets(), sliceOpToTile.getMixedSizes()); if (failed(tileAndFuseResult)) { diag.attachNote(tileableProducer->getLoc()) << "failed to tile producer op: " << *tileableProducer; return {}; } // Replace the extract op. auto maybeRankReduced = tensor::ExtractSliceOp::rankReduceIfNeeded( rewriter, sliceOpToTile->getLoc(), tileAndFuseResult->tiledValues[0], cast(sliceOpToTile->getResult(0).getType()).getShape()); assert(succeeded(maybeRankReduced) && "unexpected shape"); rewriter.replaceOp(sliceOpToTile, *maybeRankReduced); // Replace the use in containingOp. rewriter.modifyOpInPlace(containingOp, [&]() { containingOp->setOperand(pUse->getOperandNumber(), destinationTensors.front()); }); return tileAndFuseResult->tiledOps; } static Operation *cloneAndFuseFirstUse(RewriterBase &rewriter, Diagnostic &diag, Operation *producerOp, Operation *containingOp) { LLVM_DEBUG(DBGS() << "Try to fuse an use by cloning\n"); // Gather all uses inside the containing op. SmallVector uses; for (OpResult result : producerOp->getOpResults()) { for (OpOperand &use : result.getUses()) { if (containingOp->isProperAncestor(use.getOwner())) { uses.push_back(&use); continue; } // Cannot clone and fuse if the use is by the containing op itself: fail // immediately. if (containingOp == use.getOwner()) { diag.attachNote(producerOp->getLoc()) << "producer op use by containing op cannot be fused by cloning"; return nullptr; } } } // Check for a non-empty list of fusion opportunities. if (uses.empty()) { diag.attachNote(producerOp->getLoc()) << "no fusion opportunity by cloning"; return nullptr; } // Clone and fuse inside the containing op. Operation *fusedOp = nullptr; OpOperand *use = uses.front(); // Parallel insert slice is not a valid clone destination. // TODO: Generalize to other type of ops. 
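  // Cloning directly before a tensor.parallel_insert_slice would place the
  // clone inside the scf.forall terminator region, which only admits
  // insert-like ops.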
assert(!isa(use->getOwner()) && "Parallel insert slice is not a valid clone destination"); unsigned resultNumber = cast(use->get()).getResultNumber(); LLVM_DEBUG(DBGS() << "resultNumber: " << resultNumber << "\n"); OpBuilder::InsertionGuard guard(rewriter); rewriter.setInsertionPoint(use->getOwner()); fusedOp = rewriter.clone(*producerOp); rewriter.modifyOpInPlace( use->getOwner(), [&] { use->set(fusedOp->getOpResult(resultNumber)); }); return fusedOp; } bool transform::FuseIntoContainingOp::allowsRepeatedHandleOperands() { // Allow repeated handles since we are fusing everything anyway. return true; } DiagnosedSilenceableFailure transform::FuseIntoContainingOp::apply(transform::TransformRewriter &rewriter, transform::TransformResults &results, transform::TransformState &state) { SmallVector fusedOps; auto producerOps = state.getPayloadOps(getProducerOp()); auto containingOps = state.getPayloadOps(getContainingOp()); if (!llvm::hasSingleElement(containingOps)) { return emitDefiniteFailure() << "requires exactly one containing_op handle (got " << llvm::range_size(containingOps) << ")"; } Operation *containingOp = *containingOps.begin(); // If nothing to fuse, propagate success. if (std::empty(producerOps)) { results.set(cast(getFusedOp()), SmallVector{}); results.set(cast(getNewContainingOp()), {containingOp}); return DiagnosedSilenceableFailure::success(); } // Helper function to find the next producer that should be fused. Take any // producer that has a use inside the containing op. SetVector remainingProducers(producerOps.begin(), producerOps.end()); auto getNextProducer = [&]() -> FailureOr { for (const auto &it : enumerate(remainingProducers)) { Operation *producerOp = it.value(); // The containing op may be a user of producerOp: use isAncestor. int64_t numUsesInContainingOp = llvm::count_if(producerOp->getUsers(), [&](Operation *op) { return containingOp->isAncestor(op); }); // TODO: When resolving the TODO below (no duplicate ops), take an op // that has no use among the remaining producers. This is a topological // sorting. if (numUsesInContainingOp > 0) { if (numUsesInContainingOp == 1) remainingProducers.erase(remainingProducers.begin() + it.index()); return producerOp; } } return failure(); }; while (!remainingProducers.empty()) { auto nextProducer = getNextProducer(); if (failed(nextProducer)) { auto diag = mlir::emitSilenceableFailure(getLoc()) << "could not find next producer to fuse into container"; diag.attachNote(containingOp->getLoc()) << "containing op"; return diag; } Operation *producerOp = *nextProducer; // Default diagnostic, to be complemented with more failure information. Diagnostic diag(producerOp->getLoc(), DiagnosticSeverity::Remark); diag << "could not fuse " << *producerOp << " into " << *containingOp; // TODO: If there are multiple uses of the producer in the containing op, // we currently tile/clone the op multiple times (once per use). In some // cases, we can tile/clone once and reuse the value for each use. // Futhermore, producers should then be traversed according to a // topological sorting. auto [tiledOps, newContainingOp] = tileAndFuseFirstExtractUse(rewriter, diag, producerOp, containingOp); if (!tiledOps.empty()) { LLVM_DEBUG(DBGS() << "\nFused a direct extract use\n" << *containingOp); fusedOps.append(tiledOps); if (newContainingOp) { // Update handles associated with the containing op so we don't need to // invalidate them. 
This is a hack to support better composability // between tiling and fusion while a proper mechanism is being // investigated. // // DO NOT replicate this elsewhere unless you understand what you are // doing. LogicalResult replacementStatus = rewriter.notifyPayloadOperationReplaced(containingOp, newContainingOp); (void)replacementStatus; assert(succeeded(replacementStatus) && "unable to update transform state mapping"); rewriter.eraseOp(containingOp); containingOp = newContainingOp; } continue; } SmallVector tiledContainingOpOperand = tileAndFuseFirstExtractUseThroughContainingOpBlockArgument( rewriter, diag, producerOp, containingOp); if (!tiledContainingOpOperand.empty()) { LLVM_DEBUG(DBGS() << "\nFused an extract use through block argument\n" << *containingOp); fusedOps.append(tiledContainingOpOperand); continue; } Operation *cloned = cloneAndFuseFirstUse(rewriter, diag, producerOp, containingOp); if (cloned) { LLVM_DEBUG(DBGS() << "\nFused an use by cloning\n" << *containingOp); fusedOps.push_back(cloned); continue; } return DiagnosedSilenceableFailure::silenceableFailure(std::move(diag)); } results.set(cast(getFusedOp()), fusedOps); results.set(cast(getNewContainingOp()), {containingOp}); return DiagnosedSilenceableFailure::success(); } void transform::FuseIntoContainingOp::getEffects( SmallVectorImpl &effects) { consumesHandle(getProducerOp(), effects); onlyReadsHandle(getContainingOp(), effects); producesHandle(getResults(), effects); modifiesPayload(effects); } //===----------------------------------------------------------------------===// // GeneralizeOp //===----------------------------------------------------------------------===// DiagnosedSilenceableFailure transform::GeneralizeOp::applyToOne(transform::TransformRewriter &rewriter, LinalgOp target, transform::ApplyToEachResultList &results, transform::TransformState &state) { // Exit early if no transformation is needed. if (isa(target)) { results.push_back(target); return DiagnosedSilenceableFailure::success(); } rewriter.setInsertionPoint(target); FailureOr generic = generalizeNamedOp(rewriter, target); if (succeeded(generic)) { results.push_back(generic->getOperation()); return DiagnosedSilenceableFailure::success(); } return emitDefaultSilenceableFailure(target); } //===----------------------------------------------------------------------===// // SpecializeOp //===----------------------------------------------------------------------===/ DiagnosedSilenceableFailure transform::SpecializeOp::applyToOne(transform::TransformRewriter &rewriter, LinalgOp target, transform::ApplyToEachResultList &results, transform::TransformState &state) { // Exit early if the operation is not a generic. if (!isa(target)) { results.push_back(target); return DiagnosedSilenceableFailure::success(); } rewriter.setInsertionPoint(target); FailureOr named = specializeGenericOp(rewriter, cast(target)); if (succeeded(named)) { results.push_back(named->getOperation()); return DiagnosedSilenceableFailure::success(); } return emitDefaultSilenceableFailure(target); } //===----------------------------------------------------------------------===// // InterchangeOp //===----------------------------------------------------------------------===// DiagnosedSilenceableFailure transform::InterchangeOp::applyToOne(transform::TransformRewriter &rewriter, GenericOp target, transform::ApplyToEachResultList &results, transform::TransformState &state) { ArrayRef interchangeVector = getIteratorInterchange(); // Exit early if no transformation is needed. 
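  // (An empty interchange is treated as the identity permutation.)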
if (interchangeVector.empty()) { results.push_back(target); return DiagnosedSilenceableFailure::success(); } unsigned numLoops = cast(target.getOperation()).getNumLoops(); if (interchangeVector.size() != numLoops) { return emitSilenceableError() << getIteratorInterchangeAttrName() << " has length (" << interchangeVector.size() << ") different from the number of loops in the target operation (" << numLoops << ")"; } FailureOr res = interchangeGenericOp(rewriter, target, SmallVector(interchangeVector.begin(), interchangeVector.end())); if (failed(res)) return emitDefiniteFailure() << "failed to apply"; results.push_back(res->getOperation()); return DiagnosedSilenceableFailure::success(); } LogicalResult transform::InterchangeOp::verify() { ArrayRef permutation = getIteratorInterchange(); auto sequence = llvm::to_vector(llvm::seq(0, permutation.size())); if (!std::is_permutation(sequence.begin(), sequence.end(), permutation.begin(), permutation.end())) { return emitOpError() << "expects iterator_interchange to be a permutation, found " << getIteratorInterchange(); } return success(); } //===----------------------------------------------------------------------===// // LowerPackOp //===----------------------------------------------------------------------===// DiagnosedSilenceableFailure transform::LowerPackOp::applyToOne( transform::TransformRewriter &rewriter, tensor::PackOp target, transform::ApplyToEachResultList &transformResults, transform::TransformState &state) { rewriter.setInsertionPoint(target); FailureOr res = lowerPack(rewriter, target); if (failed(res)) { return mlir::emitSilenceableFailure(target->getLoc()) << "cannot lower to pad + expand + transpose"; } transformResults.push_back(res->padOp); transformResults.push_back(res->expandShapeOp); transformResults.push_back(res->transposeOp); return DiagnosedSilenceableFailure::success(); } //===----------------------------------------------------------------------===// // LowerUnPackOp //===----------------------------------------------------------------------===// DiagnosedSilenceableFailure transform::LowerUnPackOp::applyToOne( transform::TransformRewriter &rewriter, tensor::UnPackOp target, transform::ApplyToEachResultList &transformResults, transform::TransformState &state) { rewriter.setInsertionPoint(target); FailureOr res = lowerUnPack(rewriter, target); if (failed(res)) { DiagnosedSilenceableFailure diag = emitSilenceableError() << "cannot lower to transpose + collapse + extract"; diag.attachNote(target->getLoc()) << "target payload op"; return diag; } transformResults.push_back(res->emptyOp); transformResults.push_back(res->transposeOp); transformResults.push_back(res->collapseShapeOp); transformResults.push_back(res->extractSliceOp); return DiagnosedSilenceableFailure::success(); } //===---------------------------------------------------------------------===// // MatchOp //===---------------------------------------------------------------------===// void transform::MatchOp::build(OpBuilder &builder, OperationState &result, Value target, ArrayRef opNames) { result.addOperands(target); result.addAttribute(MatchOp::getOpsAttrName(result.name), builder.getStrArrayAttr(opNames)); result.addTypes(transform::AnyOpType::get(builder.getContext())); } void transform::MatchOp::build(OpBuilder &builder, OperationState &result, TypeRange resultTypes, Value target, ArrayRef opNames) { result.addOperands(target); result.addAttribute(MatchOp::getOpsAttrName(result.name), builder.getStrArrayAttr(opNames)); result.addTypes(resultTypes); } 
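// Illustrative transform IR use (exact syntax may vary across MLIR versions):
//
//   %m = transform.structured.match ops{["linalg.matmul"]} in %root
//        : (!transform.any_op) -> !transform.any_op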
DiagnosedSilenceableFailure transform::MatchOp::apply(transform::TransformRewriter &rewriter, transform::TransformResults &results, transform::TransformState &state) { llvm::StringSet<> strs; if (getOps().has_value()) strs.insert(getOps()->getAsValueRange().begin(), getOps()->getAsValueRange().end()); auto payloadOps = state.getPayloadOps(getTarget()); if (!llvm::hasSingleElement(payloadOps)) { return emitDefiniteFailure("requires exactly one target handle"); } SmallVector res; bool incorrectNumOperandTypes = false; auto matchFun = [&](Operation *op) { if (getOps().has_value() && !strs.contains(op->getName().getStringRef())) return; // Interfaces cannot be matched by name, just by ID. // So we specifically encode the interfaces we care about for this op. if (getInterface().has_value()) { auto iface = getInterface().value(); if (iface == transform::MatchInterfaceEnum::LinalgOp && !isa(op)) return; if (iface == transform::MatchInterfaceEnum::TilingInterface && !isa(op)) return; if (iface == transform::MatchInterfaceEnum::LoopLikeInterface && !isa(op)) return; } // Check if all specified attributes match. if (getOpAttrs().has_value()) { DictionaryAttr opAttrs = getOpAttrs().value(); for (NamedAttribute attr : opAttrs) { if (attr.getName() == getInterfaceAttrName() || attr.getName() == getOpsAttrName()) continue; if (!op->hasAttr(attr.getName())) return; if (op->getAttr(attr.getName()) != attr.getValue()) return; } } if (getFilterResultType().has_value()) { Type t = getFilterResultType().value(); if (op->getNumResults() != 1 || op->getResultTypes().front() != t) return; } if (getFilterOperandTypes().has_value()) { mlir::ArrayAttr types = getFilterOperandTypes().value(); auto operandTypes = op->getOperandTypes(); if (types.size() == 1) { // All the operands must must be equal to the specified type auto typeattr = dyn_cast(getFilterOperandTypes().value()[0]); Type t = typeattr.getValue().cast<::mlir::Type>(); if (!llvm::all_of(op->getOperandTypes(), [&](Type operandType) { return operandType == t; })) return; } else { // The operand types must match all the types in the list (in the same // order in with they are specified) if (types.size() != operandTypes.size()) { incorrectNumOperandTypes = true; return; } for (auto [attr, operandType] : llvm::zip_equal(getFilterOperandTypes().value(), operandTypes)) { auto typeattr = cast(attr); Type type = typeattr.getValue().cast<::mlir::Type>(); if (type != operandType) return; } } } // All constraints are satisfied. 
res.push_back(op); return; }; (*payloadOps.begin())->walk(matchFun); if (incorrectNumOperandTypes) return emitDefiniteFailure("If filter_operand_types contains more than a " "type, then it must contain as much types as " "the number of operands in the target ops"); results.set(cast(getResult()), res); return DiagnosedSilenceableFailure::success(); } //===---------------------------------------------------------------------===// // MultiTileSizesOp //===---------------------------------------------------------------------===// static void printMultitileSizesTypes(OpAsmPrinter &printer, Operation *op, Type targetType, Type lowSizeType, Type, Type) { printer.printFunctionalType(TypeRange{targetType}, TypeRange{lowSizeType}); } static ParseResult parseMultitileSizesTypes(OpAsmParser &parser, Type &targetType, Type &lowSizeType, Type &highSizeType, Type &splitPointType) { FunctionType funcType; llvm::SMLoc typeLoc = parser.getCurrentLocation(); if (failed(parser.parseType(funcType))) return failure(); if (funcType.getNumInputs() != 1 || funcType.getNumResults() != 1) { parser.emitError(typeLoc) << "expects a trailing functional type with one " "argument and one result"; } targetType = funcType.getInput(0); lowSizeType = highSizeType = splitPointType = funcType.getResult(0); return success(); } DiagnosedSilenceableFailure transform::MultiTileSizesOp::applyToOne( transform::TransformRewriter &rewriter, LinalgOp target, transform::ApplyToEachResultList &results, TransformState &state) { if (isa(getLowSize().getType())) { if (target.hasDynamicShape()) { auto diag = emitSilenceableError() << "cannot compute parametric tile sizes for dynamically " "shaped payload op"; diag.attachNote(target->getLoc()) << "payload op"; return diag; } FailureOr spec = computeStaticMultiTileSizes( target, getDimension(), getTargetSize(), getDivisor()); if (failed(spec)) { return emitSilenceableError() << "failed to compute multi-size tiling sizes"; } Builder builder(target.getContext()); results.assign(llvm::map_range( ArrayRef({spec->lowTileSize, spec->highTileSize, spec->lowTileSize * spec->lowTripCount}), [&builder, this](int64_t value) { return builder.getIntegerAttr( cast(getLowSize().getType()).getType(), value); })); return DiagnosedSilenceableFailure::success(); } OpBuilder builder(target.getContext()); builder.setInsertionPoint(target); OpFoldResult targetSize = builder.getIndexAttr(getTargetSize()); OpFoldResult divisor = builder.getIndexAttr(getDivisor()); FailureOr spec = computeMultiTileSizes( builder, target, getDimension(), targetSize, divisor); if (failed(spec)) { return emitSilenceableError() << "could not generate tile size computation"; } AffineExpr s0 = builder.getAffineSymbolExpr(0); AffineExpr s1 = builder.getAffineSymbolExpr(1); Operation *splitPoint = affine::makeComposedAffineApply(builder, target.getLoc(), s0 * s1, {spec->lowTileSize, spec->lowTripCount}); Operation *lowTileSize = spec->lowTileSize.getDefiningOp(); Operation *highTileSize = spec->highTileSize.getDefiningOp(); assert(lowTileSize && highTileSize && splitPoint && "tile sizes are not produced by operations"); results.reserve(results.size() + 3); results.push_back(lowTileSize); results.push_back(highTileSize); results.push_back(splitPoint); return DiagnosedSilenceableFailure::success(); } void transform::MultiTileSizesOp::getEffects( SmallVectorImpl &effects) { onlyReadsHandle(getTarget(), effects); producesHandle(getResults(), effects); if (isa(getLowSize().getType())) onlyReadsPayload(effects); else modifiesPayload(effects); } 
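// As described on the op, the computed low/high sizes are no larger than
// target_size, divisible by the divisor, and together cover the tiled
// dimension exactly; split_point equals low_size * low_trip_count.
// (Paraphrased from the op documentation; see LinalgTransformOps.td for the
// authoritative contract.)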
LogicalResult transform::MultiTileSizesOp::verify() { if (getLowSize().getType() != getHighSize().getType() || getLowSize().getType() != getSplitPoint().getType()) { return emitOpError() << "expects all results type to be the same"; } return success(); } //===---------------------------------------------------------------------===// // PackOp //===---------------------------------------------------------------------===// void transform::PackOp::build(OpBuilder &builder, OperationState &result, Value target, ArrayRef mixedPackedSizes) { SmallVector staticPackedSizes; SmallVector dynamicPackedSizes; dispatchIndexOpFoldResults(mixedPackedSizes, dynamicPackedSizes, staticPackedSizes); // Call the default builder which sets up the proper operands segment sizes // attributes for multiple variadic operands. In the absence of this, horrible // bugs ensue. Type linalgOpHType = transform::OperationType::get( builder.getContext(), GenericOp::getOperationName()); build(builder, result, /*resultType=*/linalgOpHType, /*target=*/target, /*dynamic_sizes=*/dynamicPackedSizes, /*static_sizes=*/builder.getDenseI64ArrayAttr(staticPackedSizes)); } SmallVector transform::PackOp::getMixedPackedSizes() { Builder b(getContext()); return getMixedValues(getStaticPackedSizes(), getPackedSizes(), b); } DiagnosedSilenceableFailure transform::PackOp::apply(transform::TransformRewriter &rewriter, transform::TransformResults &transformResults, transform::TransformState &state) { auto targetOps = state.getPayloadOps(getTarget()); // If nothing to pack, propagate success. if (std::empty(targetOps)) { transformResults.set(cast(getPackedOp()), ArrayRef({})); return DiagnosedSilenceableFailure::success(); } // Fail on multi-op handles. auto linalgOp = dyn_cast(*targetOps.begin()); if (!llvm::hasSingleElement(targetOps) || !linalgOp) { return emitSilenceableError() << "requires target to map to exactly 1 LinalgOp (got " << llvm::range_size(targetOps) << ")"; } // Fail on mismatched number of pack sizes. if (getMixedPackedSizes().size() != linalgOp.getNumLoops()) { return emitSilenceableError() << "requires number of packed sizes match the number of loops (" << getMixedPackedSizes().size() << " vs " << linalgOp.getNumLoops() << ")"; } // Unpack handles to constants or actual SSA index values. SmallVector packedSizes; DiagnosedSilenceableFailure status = unpackSingleIndexResultPayloadOperations( state, *this, packedSizes, getMixedPackedSizes()); rewriter.setInsertionPoint(linalgOp); FailureOr maybeResult = pack(rewriter, linalgOp, packedSizes); if (failed(maybeResult)) return emitDefiniteFailure("data tiling failed"); transformResults.set(cast(getPackedOp()), {maybeResult->packedLinalgOp.getOperation()}); return DiagnosedSilenceableFailure::success(); } void transform::PackOp::getEffects( SmallVectorImpl &effects) { transform::consumesHandle(getTarget(), effects); transform::onlyReadsHandle(getPackedSizes(), effects); transform::producesHandle(getPackedOp(), effects); transform::modifiesPayload(effects); } //===---------------------------------------------------------------------===// // PackGreedilyOp. //===---------------------------------------------------------------------===// LogicalResult transform::PackGreedilyOp::verify() { if (!isPermutationVector(getMatmulInnerDimsOrder())) { return emitOpError() << getMatmulInnerDimsOrderAttrName() << " is not a valid permutation"; } // TODO: relax to allow empty once we have another strategy than just matmul. 
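  // The check below enforces that, per packed dimension, at most one of the
  // packed size and the next-multiple-of padding is nonzero.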
if (!getMatmulPaddedSizesNextMultipleOf().empty()) { for (auto [s, nmo] : llvm::zip_equal(getMixedMatmulPackedSizes(), getMatmulPaddedSizesNextMultipleOf())) { std::optional maybeStaticPackedSize = getConstantIntValue(s); if (nmo != 0 && (!maybeStaticPackedSize.has_value() || *maybeStaticPackedSize != 0)) { return emitOpError() << "at most one of the packed_size and the " "padded_sizes_next_multiple_of can be nonzero " "for the matmul strategy"; } } } return success(); } DiagnosedSilenceableFailure PackGreedilyOp::apply(transform::TransformRewriter &rewriter, transform::TransformResults &transformResults, transform::TransformState &state) { SmallVector results; for (Operation *op : state.getPayloadOps(getTarget())) { auto linalgOp = dyn_cast(op); if (!linalgOp) continue; // linalgOp will be replaced and the insertion point may be invalidated if // we set it before -> set it after. rewriter.setInsertionPointAfter(linalgOp); // Failing to pack greedily is perfectly fine. // In the future we will want to order packings according to some metric. FailureOr packResult = packMatmulGreedily( /*rewriter=*/rewriter, /*linalgOp=*/linalgOp, /*mnkPackedSizes=*/getMixedMatmulPackedSizes(), /*mnkPaddedSizesNextMultipleOf=*/ getMatmulPaddedSizesNextMultipleOf(), /*mnkOrder=*/getMatmulInnerDimsOrder()); if (succeeded(packResult)) { results.push_back(packResult->packedLinalgOp); continue; } results.push_back(linalgOp); } transformResults.set(cast(getPackedOp()), results); return DiagnosedSilenceableFailure::success(); } SmallVector PackGreedilyOp::getMixedMatmulPackedSizes() { Builder b(getContext()); return getMixedValues(getStaticMatmulPackedSizes(), getMatmulPackedSizes(), b); } void transform::PackGreedilyOp::getEffects( SmallVectorImpl &effects) { transform::consumesHandle(getTarget(), effects); transform::onlyReadsHandle(getMatmulPackedSizes(), effects); transform::producesHandle(getPackedOp(), effects); transform::modifiesPayload(effects); } //===---------------------------------------------------------------------===// // PackTransposeOp //===---------------------------------------------------------------------===// LogicalResult transform::PackTransposeOp::verify() { if (!isPermutationVector(getInnerPerm())) { return emitOpError() << getInnerPermAttrName() << " is not a valid permutation"; } if (!isPermutationVector(getOuterPerm())) { return emitOpError() << getOuterPermAttrName() << " is not a valid permutation"; } if (getInnerPerm().empty() && getOuterPerm().empty()) { return emitOpError() << " at least one of " << getInnerPermAttrName() << " or " << getOuterPermAttrName() << " must be specified"; } return success(); } namespace { enum class OuterOrInnerPerm { Outer = 0, Inner = 1 }; } // namespace /// Return true if `permutation` is a valid permutation of the /// `outer_dims_perm` (case OuterOrInnerPerm::Outer) or `inner_dims_pos` /// (OuterOrInnerPerm::Inner) of the `tensor.pack` or `tensor.unpack` `op. /// This is the case when the `permutation` rank matches the rank expected by /// `op` and `permutation` is itself a permutation vector. /// Return true if either `op` or `permutation` are empty to allow a simpler /// polymorphic implementation. 
template bool isValidPackingPermutation( RelayoutOpTy op, ArrayRef permutation, OuterOrInnerPerm outerOrInnerPerm = OuterOrInnerPerm::Outer) { static_assert( llvm::is_one_of::value, "applies to only pack or unpack operations"); if (!op || permutation.empty()) return true; size_t innerRank = op.getInnerDimsPos().size(); if (outerOrInnerPerm == OuterOrInnerPerm::Inner) return permutation.size() == innerRank && isPermutationVector(permutation); // op.getOuterDimsPerm() may be empty, in which case it is identity. // Don't rely on it. if (std::is_same::value) { return permutation.size() == op.getSourceRank() && isPermutationVector(permutation); } return permutation.size() == op.getDestRank() && isPermutationVector(permutation); } DiagnosedSilenceableFailure transform::PackTransposeOp::apply(transform::TransformRewriter &rewriter, transform::TransformResults &transformResults, transform::TransformState &state) { auto packOrUnpackOps = state.getPayloadOps(getTargetPackOrUnPackOp()); auto linalgOps = state.getPayloadOps(getTargetLinalgOp()); // Step 1. If nothing to pack, propagate success. if (std::empty(packOrUnpackOps)) { transformResults.set(cast(getPackedOp()), {}); transformResults.set(cast(getPackOp()), {}); transformResults.set(cast(getUnPackOp()), {}); return DiagnosedSilenceableFailure::success(); } // Step 2. Bunch of runtime sanity check and error messages. // Step 2.1. Fail on multi-op handles. if (!llvm::hasSingleElement(packOrUnpackOps) || !llvm::hasSingleElement(linalgOps)) { return emitSilenceableError() << "requires target to map to exactly 1 " "packing op and 1 packed op (" << "got " << llvm::range_size(packOrUnpackOps) << " and " << llvm::range_size(linalgOps) << ")"; } // Step 2.2. Fail on wrong type. auto packOp = dyn_cast(*packOrUnpackOps.begin()); auto unPackOp = dyn_cast(*packOrUnpackOps.begin()); if ((!packOp && !unPackOp)) { return emitSilenceableError() << "requires target to map to a " "tensor.pack or tensor.unpack"; } LinalgOp linalgOpTarget = dyn_cast(*linalgOps.begin()); if (!linalgOpTarget) return emitSilenceableError() << "requires a LinalgOp target"; // Step 2.3. Fail if we can't get the producer / consumer Linalg op. LinalgOp linalgOp; if (packOp && packOp.getResult().hasOneUse()) linalgOp = dyn_cast(*(packOp.getResult().getUsers().begin())); else if (unPackOp) linalgOp = unPackOp.getSource().getDefiningOp(); if (linalgOp != linalgOpTarget) { auto errorMsg = packOp ? StringLiteral{"not a single use by the LinalgOp target"} : StringLiteral{"not produced by the LinalgOp target"}; return emitSilenceableError() << errorMsg; } // Step 2.4. If we have an UnPackOp, we need to fetch the symmetrical // PackOp. if (unPackOp) { assert(!packOp && "packOp must be null on entry when unPackOp is not null"); OpOperand *packUse = linalgOp.getDpsInitOperand( cast(unPackOp.getSource()).getResultNumber()); packOp = dyn_cast_or_null(packUse->get().getDefiningOp()); if (!packOp || !packOp.getResult().hasOneUse()) return emitSilenceableError() << "could not find matching pack op"; } // Step 2.5. Fail if any permutation does not validate. for (auto permType : {OuterOrInnerPerm::Outer, OuterOrInnerPerm::Inner}) { ArrayRef perm = (permType == OuterOrInnerPerm::Outer) ? getOuterPerm() : getInnerPerm(); auto errorMsg = (permType == OuterOrInnerPerm::Outer) ? StringLiteral{"invalid outer_perm"} : StringLiteral{"invalid inner_perm"}; if (!isValidPackingPermutation(packOp, perm, permType) || !isValidPackingPermutation(unPackOp, perm, permType)) { Operation *packOrUnpackOp = unPackOp ? 
unPackOp.getOperation() : packOp.getOperation(); return emitSilenceableError() << errorMsg << ": " << *packOrUnpackOp; } } // From here on, packOp and linalgOp are always present, unPackOp may or may // not be present. assert(packOp && linalgOp && "unexpected null op"); // Step 3. Actually transpose the ops. FailureOr res = packTranspose( rewriter, packOp, linalgOp, unPackOp, getOuterPerm(), getInnerPerm()); // Preconditions have been checked, it is an error to fail here. assert(succeeded(res) && "unexpected packTranspose failure"); // Step 4. Return results. transformResults.set(cast(getPackOp()), {res->transposedPackOp}); transformResults.set(cast(getPackedOp()), {res->transposedLinalgOp}); if (unPackOp) { transformResults.set(cast(getUnPackOp()), {res->transposedUnPackOp}); } else { transformResults.set(cast(getUnPackOp()), {}); } return DiagnosedSilenceableFailure::success(); } //===---------------------------------------------------------------------===// // PadOp //===---------------------------------------------------------------------===// void transform::PadOp::build(OpBuilder &b, OperationState &result, Value target, ArrayRef paddingDimensions, ArrayRef padToMultipleOf, ArrayRef packPaddings, ArrayRef transposePaddings, StringRef copyBackOp) { auto resultType = transform::AnyOpType::get(b.getContext()); return build(/*builder=*/b, /*result=*/result, /*types=*/TypeRange{resultType, resultType}, /*target=*/target, /*paddingValues=*/ArrayAttr(), // let inference handle this /*paddingDimensions=*/b.getI64ArrayAttr(paddingDimensions), /*padToMultipleOf=*/ (padToMultipleOf.empty() ? ArrayAttr() : b.getI64ArrayAttr(padToMultipleOf)), /*packPaddings=*/b.getI64ArrayAttr(packPaddings), /*transposePaddings=*/b.getArrayAttr(transposePaddings), /*copyBackOp=*/b.getStringAttr(copyBackOp)); } DiagnosedSilenceableFailure transform::PadOp::apply(transform::TransformRewriter &rewriter, transform::TransformResults &results, transform::TransformState &state) { SmallVector paddedOps, padOps, copyBackOps; for (Operation *target : state.getPayloadOps(getTarget())) { auto linalgTarget = dyn_cast(target); if (!linalgTarget) { auto diag = emitSilenceableError() << "expected LinalgOp target"; diag.attachNote(target->getLoc()) << "target op"; return diag; } // Convert the integer packing flags to booleans. SmallVector packPaddings; for (int64_t packPadding : extractFromIntegerArrayAttr(getPackPaddings())) packPaddings.push_back(static_cast(packPadding)); // Convert the padding values to attributes. SmallVector paddingValues; for (auto const &it : llvm::zip(getPaddingValues(), linalgTarget->getOperandTypes())) { auto attr = dyn_cast(std::get<0>(it)); if (!attr) { emitOpError("expects padding values to be typed attributes"); return DiagnosedSilenceableFailure::definiteFailure(); } Type elementType = getElementTypeOrSelf(std::get<1>(it)); // Try to parse string attributes to obtain an attribute of element type. if (auto stringAttr = dyn_cast(attr)) { auto parsedAttr = dyn_cast_if_present(parseAttribute( stringAttr, getContext(), elementType, /*numRead=*/nullptr, /*isKnownNullTerminated=*/true)); if (!parsedAttr || parsedAttr.getType() != elementType) { auto diag = this->emitOpError("expects a padding that parses to ") << elementType << ", got " << std::get<0>(it); diag.attachNote(linalgTarget.getLoc()) << "when applied to this op"; return DiagnosedSilenceableFailure::definiteFailure(); } paddingValues.push_back(parsedAttr); continue; } // Otherwise, add the attribute directly. 
if (attr.getType() != elementType) { auto diag = this->emitOpError("expects a padding value of type ") << elementType << ", got " << attr; diag.attachNote(linalgTarget.getLoc()) << "when applied to this op"; return DiagnosedSilenceableFailure::definiteFailure(); } paddingValues.push_back(attr); } // Extract the transpose vectors. SmallVector> transposePaddings; for (Attribute transposeVector : cast(getTransposePaddings())) transposePaddings.push_back(extractFromIntegerArrayAttr( cast(transposeVector))); LinalgOp paddedOp; LinalgPaddingOptions options; options.paddingDimensions = extractFromIntegerArrayAttr(getPaddingDimensions()); SmallVector padToMultipleOf(options.paddingDimensions.size(), 1); if (getPadToMultipleOf().has_value()) padToMultipleOf = extractFromIntegerArrayAttr(*getPadToMultipleOf()); options.padToMultipleOf = padToMultipleOf; options.paddingValues = paddingValues; options.packPaddings = packPaddings; if (getCopyBackOp() == bufferization::MaterializeInDestinationOp::getOperationName()) { options.copyBackOp = LinalgPaddingOptions::CopyBackOp:: BufferizationMaterializeInDestination; } else if (getCopyBackOp() == linalg::CopyOp::getOperationName()) { options.copyBackOp = LinalgPaddingOptions::CopyBackOp::LinalgCopy; } else if (getCopyBackOp() == kCopyOpNone) { options.copyBackOp = LinalgPaddingOptions::CopyBackOp::None; } else { llvm_unreachable("unsupported copy_back op"); } SmallVector replacements; SmallVector newPadOps; if (failed(rewriteAsPaddedOp(rewriter, linalgTarget, options, paddedOp, replacements, newPadOps))) { auto diag = emitSilenceableError() << "failed to pad op"; diag.attachNote(target->getLoc()) << "target op"; return diag; } // We need to perform our own replacement here because this API is still // used in patterns that "pad and hoist", for which the replacement values // need to be different. // TODO: clean this up and stop "pad and hoist" behavior more globally now // that we have more composable abstractions. 
rewriter.replaceOp(linalgTarget, replacements); paddedOps.push_back(paddedOp); padOps.append(newPadOps.begin(), newPadOps.end()); if (options.copyBackOp != LinalgPaddingOptions::CopyBackOp::None) { for (Value v : replacements) { Operation *copyBackOp = v.getDefiningOp(); if (!llvm::is_contained(copyBackOps, copyBackOp)) copyBackOps.push_back(copyBackOp); } } } results.set(cast(getPadded()), paddedOps); results.set(cast(getPad()), padOps); results.set(cast(getCopy()), copyBackOps); return DiagnosedSilenceableFailure::success(); } LogicalResult transform::PadOp::verify() { SmallVector packPaddings = extractFromIntegerArrayAttr(getPackPaddings()); if (any_of(packPaddings, [](int64_t packPadding) { return packPadding != 0 && packPadding != 1; })) { return emitOpError() << "expects pack_paddings to contain booleans (0/1), found " << getPackPaddings(); } SmallVector paddingDimensions = extractFromIntegerArrayAttr(getPaddingDimensions()); if (any_of(paddingDimensions, [](int64_t paddingDimension) { return paddingDimension < 0; })) { return emitOpError() << "expects padding_dimensions to contain positive " "integers, found " << getPaddingDimensions(); } if (getPadToMultipleOf().has_value()) { if (getPadToMultipleOf()->size() != paddingDimensions.size()) { return emitOpError() << "expects as many multiples as padding_dimensions"; } } ArrayAttr transposes = getTransposePaddings(); for (Attribute attr : transposes) { SmallVector transpose = extractFromIntegerArrayAttr(attr); auto sequence = llvm::to_vector(llvm::seq(0, transpose.size())); if (!std::is_permutation(sequence.begin(), sequence.end(), transpose.begin(), transpose.end())) { return emitOpError() << "expects transpose_paddings to be a permutation, found " << attr; } } if (getCopyBackOp() != bufferization::MaterializeInDestinationOp::getOperationName() && getCopyBackOp() != linalg::CopyOp::getOperationName() && getCopyBackOp() != kCopyOpNone) return emitOpError() << "invalid copy_back_op"; return success(); } //===---------------------------------------------------------------------===// // HoistPadOp //===---------------------------------------------------------------------===// DiagnosedSilenceableFailure transform::HoistPadBuildPackingLoopNestOp::apply( transform::TransformRewriter &rewriter, transform::TransformResults &transformResults, transform::TransformState &state) { auto targetOps = state.getPayloadOps(getTarget()); auto loopOps = state.getPayloadOps(getLoop()); if (!llvm::hasSingleElement(targetOps) || !llvm::hasSingleElement(loopOps)) { return emitDefiniteFailure() << "requires exactly one target and one loop handle (got " << llvm::range_size(targetOps) << " and " << llvm::range_size(loopOps) << ")"; } auto padOp = dyn_cast_or_null(*targetOps.begin()); auto loopOp = dyn_cast_or_null(*loopOps.begin()); if (!padOp || !loopOp) return emitDefiniteFailure() << "requires exactly 2 non-null handles"; FailureOr result = linalg::detail::buildPackingLoopNest(rewriter, padOp, loopOp, getTranspose()); if (failed(result)) return emitDefiniteFailure() << "could not build packing loop nest"; if (result->clonedLoopIvs.empty()) { transformResults.set(cast(getPackingLoop()), {result->hoistedPadOp.getOperation()}); return DiagnosedSilenceableFailure::success(); } auto outerPackedLoop = scf::getForInductionVarOwner(result->clonedLoopIvs.front()); transformResults.set(cast(getPackingLoop()), {outerPackedLoop.getOperation()}); return DiagnosedSilenceableFailure::success(); } LogicalResult transform::HoistPadBuildPackingLoopNestOp::verify() { ArrayRef 
transpose = getTranspose(); auto sequence = llvm::to_vector(llvm::seq(0, transpose.size())); if (!std::is_permutation(sequence.begin(), sequence.end(), transpose.begin(), transpose.end())) { return emitOpError() << "expects transpose to be a permutation, found " << getTranspose(); } return success(); } void transform::HoistPadBuildPackingLoopNestOp::getEffects( SmallVectorImpl &effects) { transform::onlyReadsHandle(getTarget(), effects); transform::onlyReadsHandle(getLoop(), effects); transform::producesHandle(getPackingLoop(), effects); transform::modifiesPayload(effects); } DiagnosedSilenceableFailure transform::HoistPadOp::applyToOne(transform::TransformRewriter &rewriter, tensor::PadOp target, transform::ApplyToEachResultList &results, transform::TransformState &state) { tensor::PadOp hoistedPadOp; SmallVector transposeOps; FailureOr result = hoistPaddingOnTensors(rewriter, target, getNumLoops(), getTranspose(), hoistedPadOp, transposeOps); if (succeeded(result)) { // We need to perform our own replacement here because this API is still // used in patterns that "pad and hoist", for which the replacement values // need to be different. // TODO: clean this up and stop "pad and hoist" behavior more globally now // that we have more composable abstractions. rewriter.replaceOp(target, *result); results.push_back(hoistedPadOp); return DiagnosedSilenceableFailure::success(); } return emitDefaultSilenceableFailure(target); } LogicalResult transform::HoistPadOp::verify() { ArrayRef transpose = getTranspose(); auto sequence = llvm::to_vector(llvm::seq(0, transpose.size())); if (!std::is_permutation(sequence.begin(), sequence.end(), transpose.begin(), transpose.end())) { return emitOpError() << "expects transpose to be a permutation, found " << getTranspose(); } return success(); } //===----------------------------------------------------------------------===// // PromoteOp //===----------------------------------------------------------------------===// DiagnosedSilenceableFailure transform::PromoteOp::applyToOne(transform::TransformRewriter &rewriter, LinalgOp target, transform::ApplyToEachResultList &results, transform::TransformState &state) { LinalgPromotionOptions promotionOptions; if (!getOperandsToPromote().empty()) promotionOptions = promotionOptions.setOperandsToPromote( extractFromIntegerArrayAttr(getOperandsToPromote())); if (getUseFullTilesByDefault()) promotionOptions = promotionOptions.setUseFullTileBuffersByDefault( getUseFullTilesByDefault()); if (getUseAlloca()) promotionOptions = promotionOptions.setUseAlloca(getUseAlloca()); if (!getUseFullTileBuffers().empty()) promotionOptions = promotionOptions.setUseFullTileBuffers( llvm::to_vector(getUseFullTileBuffers().getAsValueRange())); if (getAlignment().has_value()) promotionOptions = promotionOptions.setAlignment(*getAlignment()); if (getMemorySpace().has_value()) promotionOptions = promotionOptions.setMemorySpace(*getMemorySpace()); if (getMapping().has_value()) { // The mapping should only contain an element auto mapping = *getMapping(); if (mapping.size() > 1) return emitDefaultDefiniteFailure(target); auto addressSpace = cast(mapping[0]); if (addressSpace.getAddressSpace() == mlir::gpu::GPUDialect::getWorkgroupAddressSpace()) { promotionOptions = promotionOptions .setAllocationDeallocationFns(allocateWorkgroupMemory, deallocateWorkgroupMemory) .setCopyInOutFns(copyToWorkgroupMemory, copyToWorkgroupMemory) .setUseFullTileBuffers({false, false}); } else if (addressSpace.getAddressSpace() == 
mlir::gpu::GPUDialect::getPrivateAddressSpace()) { promotionOptions = promotionOptions .setAllocationDeallocationFns(allocateGPUPrivateMemory, deallocateGPUPrivateMemory) .setCopyInOutFns(copyToGPUPrivateMemory, copyToGPUPrivateMemory) .setUseFullTileBuffers({false, false}); } else { return emitDefaultDefiniteFailure(target); } } if (failed(promoteSubviewsPrecondition(target, promotionOptions))) return emitDefaultDefiniteFailure(target); rewriter.setInsertionPoint(target); FailureOr res = promoteSubViews(rewriter, target, promotionOptions); if (failed(res)) return emitDefaultDefiniteFailure(target); results.push_back(target); return DiagnosedSilenceableFailure::success(); } //===----------------------------------------------------------------------===// // ReplaceOp //===----------------------------------------------------------------------===// DiagnosedSilenceableFailure transform::ReplaceOp::apply(transform::TransformRewriter &rewriter, TransformResults &transformResults, TransformState &state) { auto payload = state.getPayloadOps(getTarget()); // Check for invalid targets. for (Operation *target : payload) { if (target->getNumOperands() > 0) return emitDefiniteFailure() << "expected target without operands"; if (!target->hasTrait() && target->getNumRegions() > 0) return emitDefiniteFailure() << "expected target that is isolated from above"; } // Clone and replace. Operation *pattern = &getBodyRegion().front().front(); SmallVector replacements; for (Operation *target : payload) { if (getOperation()->isAncestor(target)) continue; rewriter.setInsertionPoint(target); Operation *replacement = rewriter.clone(*pattern); rewriter.replaceOp(target, replacement->getResults()); replacements.push_back(replacement); } transformResults.set(cast(getReplacement()), replacements); return DiagnosedSilenceableFailure::success(); } void transform::ReplaceOp::getEffects( SmallVectorImpl &effects) { consumesHandle(getTarget(), effects); producesHandle(getReplacement(), effects); modifiesPayload(effects); } LogicalResult transform::ReplaceOp::verify() { if (!getBodyRegion().hasOneBlock()) return emitOpError() << "expected one block"; if (std::distance(getBodyRegion().front().begin(), getBodyRegion().front().end()) != 1) return emitOpError() << "expected one operation in block"; Operation *replacement = &getBodyRegion().front().front(); if (replacement->getNumOperands() > 0) return replacement->emitOpError() << "expected replacement without operands"; if (!replacement->hasTrait() && replacement->getNumRegions() > 0) return replacement->emitOpError() << "expect op that is isolated from above"; return success(); } //===----------------------------------------------------------------------===// // ScalarizeOp //===----------------------------------------------------------------------===// DiagnosedSilenceableFailure transform::ScalarizeOp::applyToOne(transform::TransformRewriter &rewriter, LinalgOp target, transform::ApplyToEachResultList &results, transform::TransformState &state) { scf::SCFTilingOptions tilingOptions; tilingOptions.setTileSizeComputationFunction([&](OpBuilder &b, Operation *) { SmallVector tileSizes; Location loc = target.getLoc(); SmallVector allShapeSizes = target.createFlatListOfOperandDims(b, loc); AffineMap map = target.getShapesToLoopsMap(); if (!map) return tileSizes; SmallVector shapeSizes = affine::makeComposedFoldedMultiResultAffineApply(rewriter, loc, map, allShapeSizes); // If the shape size is dynamic, tile by 1. // Otherwise, do not tile (i.e. tile size 0). 
    for (OpFoldResult shapeSize : shapeSizes) {
      tileSizes.push_back(getConstantIntValue(shapeSize) ? b.getIndexAttr(0)
                                                         : b.getIndexAttr(1));
    }
    return tileSizes;
  });
  SmallVector<int64_t> emptyTileSizes;
  rewriter.setInsertionPoint(target);
  FailureOr<scf::SCFTilingResult> maybeTilingResult = tileUsingSCF(
      rewriter, cast<TilingInterface>(target.getOperation()), tilingOptions);
  if (failed(maybeTilingResult))
    return emitDefaultDefiniteFailure(target);

  if (target->getNumResults())
    rewriter.replaceOp(target, maybeTilingResult->replacements);
  else
    rewriter.eraseOp(target);

  results.reserve(maybeTilingResult->tiledOps.size());
  for (Operation *tiled : maybeTilingResult->tiledOps)
    results.push_back(tiled);
  return DiagnosedSilenceableFailure::success();
}

//===----------------------------------------------------------------------===//
// ConvertToLoopsOp
//===----------------------------------------------------------------------===//

DiagnosedSilenceableFailure
transform::ConvertToLoopsOp::apply(transform::TransformRewriter &rewriter,
                                   transform::TransformResults &results,
                                   transform::TransformState &state) {
  SmallVector<Operation *> loops;
  for (Operation *target : state.getPayloadOps(getTarget())) {
    auto tilingOp = dyn_cast<TilingInterface>(*target);
    if (!tilingOp) {
      DiagnosedSilenceableFailure diag =
          emitSilenceableError()
          << "expected the payload to implement TilingInterface";
      diag.attachNote(target->getLoc()) << "payload op";
      return diag;
    }
    rewriter.setInsertionPoint(target);
    FailureOr<SmallVector<scf::ForOp>> generatedLoops =
        scf::lowerToLoopsUsingSCFForOp(rewriter, tilingOp);
    if (failed(generatedLoops))
      return emitDefaultDefiniteFailure(target);
    for (scf::ForOp &loop : *generatedLoops) {
      loops.push_back(loop.getOperation());
    }
    rewriter.eraseOp(target);
  }
  results.set(cast<OpResult>(getResult()), loops);
  return DiagnosedSilenceableFailure::success();
}

//===----------------------------------------------------------------------===//
// RewriteInDestinationPassingStyleOp
//===----------------------------------------------------------------------===//

DiagnosedSilenceableFailure
transform::RewriteInDestinationPassingStyleOp::applyToOne(
    transform::TransformRewriter &rewriter, Operation *target,
    transform::ApplyToEachResultList &results,
    transform::TransformState &state) {
  SmallVector<Operation *> res;
  rewriter.setInsertionPoint(target);
  FailureOr<Operation *> maybeResult =
      TypeSwitch<Operation *, FailureOr<Operation *>>(target)
          .Case<tensor::FromElementsOp, tensor::GenerateOp, tensor::PadOp>(
              [&rewriter](auto op) {
                return rewriteInDestinationPassingStyle(rewriter, op);
              });
  if (failed(maybeResult))
    return emitDefaultSilenceableFailure(target);
  results.push_back(*maybeResult);
  return DiagnosedSilenceableFailure::success();
}

//===----------------------------------------------------------------------===//
// SplitOp
//===----------------------------------------------------------------------===//

DiagnosedSilenceableFailure
SplitOp::apply(transform::TransformRewriter &rewriter, TransformResults &results,
               TransformState &state) {
  // Collect the dynamic split points if provided.
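  // Note: the dynamic split point may be given either as a param handle
  // carrying index attributes or as a value handle mapped to single-result
  // index-typed ops; both forms are normalized to OpFoldResults below.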
SmallVector payload = llvm::to_vector(state.getPayloadOps(getTarget())); SmallVector splitPoints; splitPoints.reserve(payload.size()); if (getDynamicSplitPoint()) { auto diag = DiagnosedSilenceableFailure::success(); if (isa(getDynamicSplitPoint().getType())) { splitPoints = llvm::to_vector(llvm::map_range( state.getPayloadOps(getDynamicSplitPoint()), [&](Operation *op) { if (op->getNumResults() != 1 || !op->getResult(0).getType().isIndex()) { diag = emitSilenceableError() << "expected dynamic split point handle to point to a " "single-result index-typed op"; diag.attachNote(op->getLoc()) << "dynamic split point"; } return OpFoldResult(op->getResult(0)); })); } else { splitPoints = llvm::to_vector( llvm::map_range(state.getParams(getDynamicSplitPoint()), [](Attribute attr) { return OpFoldResult(attr); })); } if (diag.isSilenceableFailure()) return diag; if (splitPoints.size() != payload.size()) { return emitDefiniteFailure() << "expected the dynamic split point handle to point to as " "many operations (" << splitPoints.size() << ") as the target handle (" << payload.size() << ")"; } } else { splitPoints.resize(payload.size(), rewriter.getIndexAttr(getStaticSplitPoint())); } // Split each target operation. SmallVector first, second; Operation *noSecondPart = nullptr; for (const auto &pair : llvm::zip(payload, splitPoints)) { Operation *target = std::get<0>(pair); auto linalgOp = dyn_cast(target); if (!linalgOp) { auto diag = emitSilenceableError() << "only applies to structured ops"; diag.attachNote(target->getLoc()) << "target op"; return diag; } if (getDimension() >= linalgOp.getNumLoops()) { auto diag = emitSilenceableError() << "dimension " << getDimension() << " does not exist in target op"; diag.attachNote(target->getLoc()) << "target op"; return diag; } rewriter.setInsertionPoint(linalgOp); std::tie(first.emplace_back(), second.emplace_back()) = linalg::splitOp( rewriter, cast(linalgOp.getOperation()), getDimension(), std::get<1>(pair)); // Propagate errors. if (!first.back() && !second.back()) { auto diag = emitDefiniteFailure() << "internal failure in splitting"; diag.attachNote(target->getLoc()) << "target op"; return diag; } // Do not add null second parts. 
if (!second.back()) { noSecondPart = target; second.pop_back(); } } if (second.size() != first.size() && !second.empty()) { auto diag = emitSilenceableError() << "splitting does not produce the second part for a subset " "of targets"; diag.attachNote() << "expected splitting to produce the second part of all " "or none of the targets"; diag.attachNote(noSecondPart->getLoc()) << "first target with no second part"; return diag; } results.set(cast(getFirst()), first); results.set(cast(getSecond()), second); return DiagnosedSilenceableFailure::success(); } void SplitOp::getEffects( SmallVectorImpl &effects) { consumesHandle(getTarget(), effects); if (getDynamicSplitPoint()) onlyReadsHandle(getDynamicSplitPoint(), effects); producesHandle(getResults(), effects); modifiesPayload(effects); } ParseResult SplitOp::parse(OpAsmParser &parser, OperationState &result) { OpAsmParser::UnresolvedOperand target, dynamicSplitPoint; IntegerAttr staticSplitPoint; if (parser.parseOperand(target) || parser.parseKeyword("after")) return failure(); OptionalParseResult dynamicPointParseResult = parser.parseOptionalOperand(dynamicSplitPoint); if (!dynamicPointParseResult.has_value()) { int64_t staticSplitPointValue; if (failed(parser.parseInteger(staticSplitPointValue))) return failure(); staticSplitPoint = parser.getBuilder().getI64IntegerAttr(staticSplitPointValue); } Type targetType; if (parser.parseOptionalAttrDict(result.attributes) || parser.parseColonType(targetType) || parser.resolveOperand(target, targetType, result.operands)) { return failure(); } if (dynamicPointParseResult.has_value()) { Type splitPointType; if (failed(*dynamicPointParseResult) || parser.parseComma() || parser.parseType(splitPointType) || parser.resolveOperand(dynamicSplitPoint, splitPointType, result.operands)) { return failure(); } staticSplitPoint = parser.getBuilder().getI64IntegerAttr(ShapedType::kDynamic); } result.addAttribute( SplitOp::getStaticSplitPointAttrName(result.name).getValue(), staticSplitPoint); result.addTypes({targetType, targetType}); return success(); } void SplitOp::print(OpAsmPrinter &printer) { printer << " " << getTarget() << " after "; int64_t staticSplitSize = static_cast(getStaticSplitPoint()); if (staticSplitSize != ShapedType::kDynamic) printer << staticSplitSize; else printer << getDynamicSplitPoint(); printer << " "; printer.printOptionalAttrDict(getOperation()->getAttrs(), {getStaticSplitPointAttrName()}); printer << " : " << getTarget().getType(); if (staticSplitSize == ShapedType::kDynamic) printer << ", " << getDynamicSplitPoint().getType(); } LogicalResult SplitOp::verify() { if ((static_cast(getStaticSplitPoint()) != ShapedType::kDynamic) ^ (getDynamicSplitPoint() == nullptr)) { return emitOpError() << "expects either a dynamic or a static split " "point to be provided"; } return success(); } //===----------------------------------------------------------------------===// // SplitReductionOp //===----------------------------------------------------------------------===// void transform::SplitReductionOp::build( OpBuilder &builder, OperationState &result, Value target, int64_t splitFactor, int64_t insertSplitDimension, bool innerParallel, bool useScalingAlgorithm, bool useAlloc) { MLIRContext *ctx = builder.getContext(); result.addOperands(target); result.addAttribute(SplitReductionOp::getSplitFactorAttrName(result.name), builder.getI64IntegerAttr(splitFactor)); result.addAttribute( SplitReductionOp::getInsertSplitDimensionAttrName(result.name), builder.getI64IntegerAttr(insertSplitDimension)); if 
(innerParallel) { result.addAttribute(SplitReductionOp::getInnerParallelAttrName(result.name), builder.getUnitAttr()); } if (useScalingAlgorithm) { result.addAttribute( SplitReductionOp::getUseScalingAlgorithmAttrName(result.name), builder.getUnitAttr()); } if (useAlloc) { result.addAttribute(SplitReductionOp::getUseAllocAttrName(result.name), builder.getUnitAttr()); } auto resultType = transform::AnyOpType::get(ctx); result.addTypes({resultType, resultType, resultType, resultType}); } DiagnosedSilenceableFailure transform::SplitReductionOp::applyToOne( transform::TransformRewriter &rewriter, LinalgOp target, transform::ApplyToEachResultList &results, transform::TransformState &state) { ControlSplitReductionFn splitFn = [&](LinalgOp) { return linalg::SplitReductionOptions{int64_t(getSplitFactor()), unsigned(getInsertSplitDimension()), bool(getInnerParallel())}; }; rewriter.setInsertionPoint(target); FailureOr splitResult = (getUseScalingAlgorithm()) ? splitReductionByScaling(rewriter, target, splitFn, getUseAlloc()) : splitReduction(rewriter, target, splitFn, getUseAlloc()); if (failed(splitResult)) return emitDefaultDefiniteFailure(target); results.push_back(splitResult->initOrAlloc); results.push_back(splitResult->fillOp); results.push_back(splitResult->splitLinalgOp); results.push_back(splitResult->resultCombiningLinalgOp); return DiagnosedSilenceableFailure::success(); } //===----------------------------------------------------------------------===// // TileReductionUsingForOp //===----------------------------------------------------------------------===// void transform::TileReductionUsingForOp::build( OpBuilder &builder, OperationState &result, Value target, ArrayRef staticTileSizes) { // Call the default builder. // This is future-proof re mixed static-dynamic and setting up the proper // operands segment sizes attributes for multiple variadic operands. // In the absence of this, horrible bugs ensue. // TODO: support mixed static-dynamic (see TileUsingForallOp). MLIRContext *ctx = builder.getContext(); auto opTy = transform::AnyOpType::get(ctx); auto staticTileSizesAttr = builder.getDenseI64ArrayAttr(staticTileSizes); build(builder, result, /*resultTypes=*/TypeRange{opTy, opTy, opTy, opTy}, /*target=*/target, /*tile_sizes=*/staticTileSizesAttr); } DiagnosedSilenceableFailure transform::TileReductionUsingForOp::applyToOne( transform::TransformRewriter &rewriter, LinalgOp target, transform::ApplyToEachResultList &results, transform::TransformState &state) { rewriter.setInsertionPoint(target); FailureOr result = scf::tileReductionUsingScf( rewriter, cast(target.getOperation()), getAsOpFoldResult(rewriter.getI64ArrayAttr(getTileSizes()))); if (failed(result)) return emitDefaultSilenceableFailure(target); results.push_back(result->initialOp); results.push_back(result->parallelTiledOp); results.push_back(result->mergeOp); results.push_back(result->loops.front()); return DiagnosedSilenceableFailure::success(); } //===----------------------------------------------------------------------===// // TileReductionUsingForallOp //===----------------------------------------------------------------------===// void transform::TileReductionUsingForallOp::build( OpBuilder &builder, OperationState &result, Value target, ArrayRef staticNumThreads, ArrayRef staticTileSizes, ArrayAttr mapping) { // Call the default builder. // This is future-proof re mixed static-dynamic and setting up the proper // operands segment sizes attributes for multiple variadic operands. 
// In the absence of this, horrible bugs ensue. // TODO: support mixed static-dynamic (see TileUsingForallOp). MLIRContext *ctx = builder.getContext(); auto opTy = transform::AnyOpType::get(ctx); auto staticNumThreadsAttr = builder.getDenseI64ArrayAttr(staticNumThreads); auto staticTileSizesAttr = builder.getDenseI64ArrayAttr(staticTileSizes); build(builder, result, /*resultTypes=*/TypeRange{opTy, opTy, opTy, opTy}, /*target=*/target, /*num_threads=*/staticNumThreadsAttr, /*tile_sizes=*/staticTileSizesAttr, /*mapping=*/mapping); } DiagnosedSilenceableFailure transform::TileReductionUsingForallOp::applyToOne( transform::TransformRewriter &rewriter, LinalgOp target, transform::ApplyToEachResultList &results, transform::TransformState &state) { rewriter.setInsertionPoint(target); SmallVector numThreads = getAsOpFoldResult(rewriter.getI64ArrayAttr(getNumThreads())); SmallVector tileSizes = getAsOpFoldResult(rewriter.getI64ArrayAttr(getTileSizes())); FailureOr result = linalg::tileReductionUsingForall( rewriter, cast(target.getOperation()), numThreads, tileSizes, getMapping()); if (failed(result)) { auto diag = emitSilenceableError() << "could not tile reduction"; diag.attachNote(target.getLoc()) << "target operation"; return diag; } results.push_back(result->initialOp); results.push_back(result->parallelTiledOp); results.push_back(result->mergeOp); results.push_back(result->loops); return DiagnosedSilenceableFailure::success(); } //===----------------------------------------------------------------------===// // TileUsingForOp //===----------------------------------------------------------------------===// void transform::TileUsingForOp::build( OpBuilder &builder, OperationState &result, TypeRange loopTypes, Value target, ArrayRef staticTileSizes, ArrayRef interchange, std::optional> scalableSizes) { return build(builder, result, loopTypes, /*target=*/target, /*mixedTileSizes=*/ getAsOpFoldResult(builder.getI64ArrayAttr(staticTileSizes)), interchange, scalableSizes); } void transform::TileUsingForOp::build( OpBuilder &builder, OperationState &result, Value target, ArrayRef staticTileSizes, ArrayRef interchange, std::optional> scalableSizes) { build(builder, result, target, getAsOpFoldResult(builder.getI64ArrayAttr(staticTileSizes)), interchange, scalableSizes); } void transform::TileUsingForOp::build( OpBuilder &builder, OperationState &result, Value target, ArrayRef mixedTileSizes, ArrayRef interchange, std::optional> scalableSizes) { // Loop types are automaticaly splat by the callee, setting up one is // enough. SmallVector loopTypes(1, builder.getType()); build(builder, result, loopTypes, target, mixedTileSizes, interchange, scalableSizes); } void transform::TileUsingForOp::build( OpBuilder &builder, OperationState &result, TypeRange loopTypes, Value target, ArrayRef mixedTileSizes, ArrayRef interchange, std::optional> scalableSizes) { SmallVector staticTileSizes; SmallVector dynamicTileSizes; dispatchIndexOpFoldResults(mixedTileSizes, dynamicTileSizes, staticTileSizes); // Call the default builder which sets up the proper operands segment sizes // attributes for multiple variadic operands. In the absence of this, // horrible bugs ensue. 
auto staticTileSizesAttr = builder.getDenseI64ArrayAttr(staticTileSizes); unsigned numExpectedLoops = staticTileSizes.size() - llvm::count(staticTileSizes, 0); SmallVector resultTypes; resultTypes.reserve(numExpectedLoops); assert((loopTypes.size() == 1 || loopTypes.size() == numExpectedLoops) && "expected one loop type or as many as loops"); if (loopTypes.size() == 1) resultTypes.append(numExpectedLoops, loopTypes[0]); else llvm::append_range(resultTypes, loopTypes); SmallVector expandedScalableSizes(mixedTileSizes.size(), false); if (scalableSizes.has_value()) expandedScalableSizes.assign(scalableSizes->begin(), scalableSizes->end()); build(builder, result, /*tiled_linalg_op=*/target.getType(), /*loops=*/resultTypes, /*target=*/target, /*dynamic_sizes=*/dynamicTileSizes, /*static_sizes=*/staticTileSizesAttr, /*interchange=*/builder.getDenseI64ArrayAttr(interchange), /*scalable_sizes=*/expandedScalableSizes); } LogicalResult transform::TileUsingForOp::verify() { if (getMixedSizes().size() != getScalableSizes().size()) return emitOpError("expected same number of sizes (") << getMixedSizes().size() << ") and scalable sizes ()" << getScalableSizes().size() << ")"; return success(); } DiagnosedSilenceableFailure transform::TileUsingForOp::apply(transform::TransformRewriter &rewriter, TransformResults &transformResults, TransformState &state) { ArrayRef tileSizes = getStaticSizes(); SmallVector targets = llvm::to_vector(state.getPayloadOps(getTarget())); SmallVector> dynamicSizeProducers; SmallVector> paramSizes; dynamicSizeProducers.reserve(getDynamicSizes().size()); paramSizes.reserve(getDynamicSizes().size()); for (Value transformValue : getDynamicSizes()) { if (isa(transformValue.getType())) { dynamicSizeProducers.push_back({}); ArrayRef params = state.getParams(transformValue); paramSizes.push_back( llvm::to_vector(llvm::map_range(params, [](Attribute attr) { return cast(attr).getValue().getSExtValue(); }))); if (paramSizes.back().size() != targets.size()) { DiagnosedSilenceableFailure diag = emitSilenceableError() << "expected as many parameter values (" << dynamicSizeProducers.back().size() << ") as target ops (" << targets.size() << ")"; diag.attachNote(transformValue.getLoc()) << "for this parameter"; return diag; } continue; } paramSizes.push_back({}); dynamicSizeProducers.push_back( llvm::to_vector(state.getPayloadOps(transformValue))); if (dynamicSizeProducers.back().size() != targets.size()) { DiagnosedSilenceableFailure diag = emitSilenceableError() << "expected as many dynamic size-producing operations (" << dynamicSizeProducers.back().size() << ") as target ops (" << targets.size() << ")"; diag.attachNote(transformValue.getLoc()) << "for this handle"; return diag; } for (Operation *op : dynamicSizeProducers.back()) { if (op->getNumResults() == 1 && isa(op->getResult(0).getType())) { continue; } DiagnosedSilenceableFailure diag = emitSilenceableError() << "expected sizes to be produced by ops " "with a single index-type result"; diag.attachNote(op->getLoc()) << "size producer op"; diag.attachNote(transformValue.getLoc()) << "for this handle"; return diag; } } SmallVector tiled; SmallVector, 4> loops; loops.resize(getLoops().size()); auto scalableSizes = getScalableSizes(); for (auto [i, op] : llvm::enumerate(targets)) { auto tilingInterface = dyn_cast(op); if (!tilingInterface) { DiagnosedSilenceableFailure diag = emitSilenceableError() << "only ops implementing TilingInterface are supported"; diag.attachNote(op->getLoc()) << "target op"; return diag; } if (tileSizes.size() > 
tilingInterface.getLoopIteratorTypes().size()) { DiagnosedSilenceableFailure diag = emitSilenceableError() << "too many tiles provided, expected at most " << tilingInterface.getLoopIteratorTypes().size() << " found " << tileSizes.size(); diag.attachNote(op->getLoc()) << "target op"; return diag; } scf::SCFTilingOptions tilingOptions; if (tileSizes.empty()) { tilingOptions.setTileSizeComputationFunction( [](OpBuilder &, Operation *) -> SmallVector { return {}; }); } else { tilingOptions.setTileSizeComputationFunction([&, index = i](OpBuilder &b, Operation *) { SmallVector sizes; sizes.reserve(tileSizes.size()); unsigned dynamicIdx = 0; for (auto [ofrIdx, ofr] : llvm::enumerate(getMixedSizes())) { if (auto attr = llvm::dyn_cast_if_present(ofr)) { if (scalableSizes[ofrIdx]) { auto val = b.create( getLoc(), attr.cast().getInt()); Value vscale = b.create(getLoc(), b.getIndexType()); sizes.push_back( b.create(getLoc(), val, vscale).getResult()); } else { sizes.push_back(attr); } continue; } ArrayRef dynamicSizes = dynamicSizeProducers[dynamicIdx]; ArrayRef params = paramSizes[dynamicIdx]; ++dynamicIdx; assert((dynamicSizes.empty() ^ params.empty()) && "expected either dynamic sizes or parameters"); if (!params.empty()) { sizes.push_back(b.getIndexAttr(params[index])); } else { sizes.push_back(dynamicSizes[index]->getResult(0)); } } return sizes; }); } tilingOptions.setInterchange(getInterchange()); FailureOr maybeTilingResult = tileUsingSCF(rewriter, tilingInterface, tilingOptions); if (failed(maybeTilingResult)) return DiagnosedSilenceableFailure::definiteFailure(); rewriter.replaceOp(op, maybeTilingResult->replacements); tiled.append(maybeTilingResult->tiledOps); for (const auto &en2 : llvm::enumerate(maybeTilingResult->loops)) loops[en2.index()].push_back(en2.value()); } transformResults.set(cast(getTiledLinalgOp()), tiled); for (const auto &en : llvm::enumerate(loops)) transformResults.set(cast(getLoops()[en.index()]), en.value()); return DiagnosedSilenceableFailure::success(); } SmallVector transform::TileUsingForOp::getMixedSizes() { ValueRange dynamic = getDynamicSizes(); ArrayRef tileSizes = getStaticSizes(); SmallVector results; results.reserve(tileSizes.size()); unsigned dynamicPos = 0; Builder builder(getContext()); for (int64_t size : tileSizes) { if (size == ShapedType::kDynamic) { results.push_back(dynamic[dynamicPos++]); } else { results.push_back(builder.getIndexAttr(size)); } } return results; } // We want to parse `DenseI64ArrayAttr` using the short form without the // `array` prefix to be consistent in the IR with `parseDynamicIndexList`. 
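// For example, the optional clause is expected to round-trip as
// `interchange = [1, 0]` rather than the long form `array<i64: 1, 0>`
// (values illustrative).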
ParseResult parseOptionalInterchange(OpAsmParser &parser, OperationState &result) { if (failed(parser.parseOptionalKeyword("interchange"))) return success(); if (failed(parser.parseEqual())) return failure(); result.addAttribute( transform::TileUsingForOp::getInterchangeAttrName(result.name), DenseI64ArrayAttr::parse(parser, Type{})); return success(); } void printOptionalInterchange(OpAsmPrinter &p, ArrayRef interchangeVals) { if (!interchangeVals.empty()) { p << " interchange = ["; llvm::interleaveComma(interchangeVals, p, [&](int64_t integer) { p << integer; }); p << "]"; } } ParseResult transform::TileUsingForOp::parse(OpAsmParser &parser, OperationState &result) { OpAsmParser::UnresolvedOperand target; SmallVector dynamicSizes; DenseI64ArrayAttr staticSizes; FunctionType functionalType; llvm::SMLoc operandLoc; DenseBoolArrayAttr scalableVals; if (parser.parseOperand(target) || parser.getCurrentLocation(&operandLoc) || parseDynamicIndexList(parser, dynamicSizes, staticSizes, scalableVals) || parseOptionalInterchange(parser, result) || parser.parseOptionalAttrDict(result.attributes) || parser.parseColonType(functionalType)) return ParseResult::failure(); size_t numExpectedLoops = staticSizes.size() - llvm::count(staticSizes.asArrayRef(), 0); if (functionalType.getNumResults() != numExpectedLoops + 1) { return parser.emitError(parser.getNameLoc()) << "expected " << (numExpectedLoops + 1) << " result type(s)"; } if (functionalType.getNumInputs() != dynamicSizes.size() + 1) { return parser.emitError(operandLoc) << "expected " << dynamicSizes.size() + 1 << " operand type(s)"; } if (parser.resolveOperand(target, functionalType.getInputs().front(), result.operands) || parser.resolveOperands(dynamicSizes, functionalType.getInputs().drop_front(), operandLoc, result.operands)) { return failure(); } result.addAttribute(getScalableSizesAttrName(result.name), scalableVals); result.addAttribute(getStaticSizesAttrName(result.name), staticSizes); result.addTypes(functionalType.getResults()); return success(); } void TileUsingForOp::print(OpAsmPrinter &p) { p << ' ' << getTarget(); printDynamicIndexList(p, getOperation(), getDynamicSizes(), getStaticSizes(), /*valueTypes=*/{}, getScalableSizesAttr(), OpAsmParser::Delimiter::Square); printOptionalInterchange(p, getInterchange()); p.printOptionalAttrDict( (*this)->getAttrs(), /*elidedAttrs=*/{getInterchangeAttrName(getOperation()->getName()), getScalableSizesAttrName(getOperation()->getName()), getStaticSizesAttrName(getOperation()->getName())}); p << " : "; p.printFunctionalType(getOperands().getTypes(), getResults().getTypes()); } void transform::TileUsingForOp::getEffects( SmallVectorImpl &effects) { consumesHandle(getTarget(), effects); onlyReadsHandle(getDynamicSizes(), effects); producesHandle(getTiledLinalgOp(), effects); producesHandle(getLoops(), effects); modifiesPayload(effects); } //===----------------------------------------------------------------------===// // TileUsingForallOp //===----------------------------------------------------------------------===// void transform::TileUsingForallOp::build(OpBuilder &builder, OperationState &result, Value target, ArrayRef staticTileSizes, transform::TileSizesSpec, ArrayAttr mapping) { return build(builder, result, /*target=*/target, /*mixedTileSizes=*/ getAsOpFoldResult(builder.getI64ArrayAttr(staticTileSizes)), /*_=*/TileSizesSpec(), /*mapping=*/mapping); } void transform::TileUsingForallOp::build(OpBuilder &builder, OperationState &result, Value target, ArrayRef mixedTileSizes, 
transform::TileSizesSpec, ArrayAttr mapping) { SmallVector staticTileSizes; SmallVector dynamicTileSizes; dispatchIndexOpFoldResults(mixedTileSizes, dynamicTileSizes, staticTileSizes); // Call the default builder which sets up the proper operands segment sizes // attributes for multiple variadic operands. In the absence of this, // horrible bugs ensue. MLIRContext *ctx = builder.getContext(); auto operationType = transform::AnyOpType::get(ctx); auto staticTileSizesAttr = builder.getDenseI64ArrayAttr(staticTileSizes); build(builder, result, /*resultTypes=*/TypeRange{operationType, operationType}, /*target=*/target, /*num_threads=*/ValueRange{}, /*tile_sizes=*/dynamicTileSizes, /*packed_num_threads=*/Value(), /*packed_tile_sizes=*/Value(), /*static_num_threads=*/builder.getDenseI64ArrayAttr({}), /*static_tile_sizes=*/staticTileSizesAttr, /*mapping=*/mapping); } void transform::TileUsingForallOp::build(OpBuilder &builder, OperationState &result, Value target, ArrayRef staticNumThreads, transform::NumThreadsSpec, ArrayAttr mapping) { return build(builder, result, target, getAsOpFoldResult(builder.getI64ArrayAttr(staticNumThreads)), NumThreadsSpec(), mapping); } void transform::TileUsingForallOp::build(OpBuilder &builder, OperationState &result, Value target, ArrayRef mixedNumThreads, transform::NumThreadsSpec, ArrayAttr mapping) { SmallVector staticNumThreads; SmallVector dynamicNumThreads; dispatchIndexOpFoldResults(mixedNumThreads, dynamicNumThreads, staticNumThreads); // Call the default builder which sets up the proper operands segment sizes // attributes for multiple variadic operands. In the absence of this, // horrible bugs ensue. MLIRContext *ctx = builder.getContext(); auto operationType = transform::AnyOpType::get(ctx); auto staticNumThreadsAttr = builder.getDenseI64ArrayAttr(staticNumThreads); build(builder, result, /*resultTypes=*/TypeRange{operationType, operationType}, /*target=*/target, /*num_threads=*/dynamicNumThreads, /*tile_sizes=*/ValueRange{}, /*packed_num_threads=*/Value(), /*packed_tile_sizes=*/Value(), /*static_num_threads=*/staticNumThreadsAttr, /*static_tile_sizes=*/builder.getDenseI64ArrayAttr({}), /*mapping=*/mapping); } DiagnosedSilenceableFailure transform::tileToForallOpImpl( RewriterBase &rewriter, transform::TransformState &state, TransformOpInterface transformOp, Operation *target, ArrayRef mixedNumThreads, ArrayRef mixedTileSizes, std::optional mapping, linalg::ForallTilingResult &tilingResult) { // Transform all targets one by one. 
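  // Depending on whether num_threads or tile_sizes was provided, the matching
  // linalg::tileToForallOp / tileToForallOpUsingTileSizes entry point is
  // selected below.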
auto tileableOp = dyn_cast(target); if (!tileableOp) { DiagnosedSilenceableFailure diag = transformOp.emitSilenceableError() << "only TilingInterface ops are supported"; diag.attachNote(target->getLoc()) << "target op"; return diag; } rewriter.setInsertionPoint(tileableOp); FailureOr maybeTilingResult = failure(); if (!mixedNumThreads.empty()) { maybeTilingResult = linalg::tileToForallOp(rewriter, tileableOp, mixedNumThreads, mapping); } else { maybeTilingResult = linalg::tileToForallOpUsingTileSizes( rewriter, tileableOp, mixedTileSizes, mapping); } if (failed(maybeTilingResult)) return transformOp.emitDefaultSilenceableFailure(tileableOp); rewriter.replaceOp(tileableOp, maybeTilingResult->tileOp->getResults()); tilingResult = *maybeTilingResult; return DiagnosedSilenceableFailure::success(); } DiagnosedSilenceableFailure transform::TileUsingForallOp::apply( transform::TransformRewriter &rewriter, transform::TransformResults &transformResults, transform::TransformState &state) { auto transformOp = cast(getOperation()); // Result payload ops. SmallVector tileOps; SmallVector tiledOps; // Unpack handles. SmallVector mixedNumThreads; DiagnosedSilenceableFailure status = getPackedNumThreads() ? unpackSingleIndexResultPayloadOperations( state, transformOp, mixedNumThreads, getPackedNumThreads()) : unpackSingleIndexResultPayloadOperations( state, transformOp, mixedNumThreads, getMixedNumThreads()); if (!status.succeeded()) return status; SmallVector mixedTileSizes; status = getPackedTileSizes() ? unpackSingleIndexResultPayloadOperations( state, transformOp, mixedTileSizes, getPackedTileSizes()) : unpackSingleIndexResultPayloadOperations( state, transformOp, mixedTileSizes, getMixedTileSizes()); if (!status.succeeded()) return status; for (Operation *target : state.getPayloadOps(getTarget())) { linalg::ForallTilingResult tilingResult; DiagnosedSilenceableFailure diag = tileToForallOpImpl( rewriter, state, transformOp, target, mixedNumThreads, mixedTileSizes, getMapping(), tilingResult); if (!diag.succeeded()) return diag; tileOps.push_back(tilingResult.tileOp); tiledOps.push_back(tilingResult.tiledOp); } transformResults.set(cast(getForallOp()), tileOps); transformResults.set(cast(getTiledOp()), tiledOps); return DiagnosedSilenceableFailure::success(); } void transform::TileUsingForallOp::getEffects( SmallVectorImpl &effects) { consumesHandle(getTarget(), effects); onlyReadsHandle(getTileSizes(), effects); onlyReadsHandle(getNumThreads(), effects); onlyReadsHandle(getPackedNumThreads(), effects); onlyReadsHandle(getPackedTileSizes(), effects); producesHandle(getResults(), effects); modifiesPayload(effects); } SmallVector TileUsingForallOp::getMixedNumThreads() { Builder b(getContext()); return getMixedValues(getStaticNumThreads(), getNumThreads(), b); } SmallVector TileUsingForallOp::getMixedTileSizes() { Builder b(getContext()); return getMixedValues(getStaticTileSizes(), getTileSizes(), b); } LogicalResult TileUsingForallOp::verify() { int numThreadsSpec = static_cast(!getMixedNumThreads().empty()) + static_cast(getPackedNumThreads() != Value()); if (numThreadsSpec > 1) return emitOpError( "num_threads and packed_num_threads are mutually exclusive"); int tileSizesSpec = static_cast(!getMixedTileSizes().empty()) + static_cast(getPackedTileSizes() != Value()); if (tileSizesSpec > 1) return emitOpError( "tile_sizes and packed_tile_sizes are mutually exclusive"); if (numThreadsSpec == 0 && tileSizesSpec == 0) return emitOpError("either (packed_)num_threads or (packed_)tile_sizes " "must be 
specified"); return success(); } //===----------------------------------------------------------------------===// // VectorizeChildrenAndApplyPatternsOp //===----------------------------------------------------------------------===// void transform::VectorizeChildrenAndApplyPatternsOp::build( OpBuilder &builder, OperationState &result, Value target, bool vectorizePadding, bool vectorizeExtract, bool flatten1DDepthwiseConv) { result.addOperands(target); if (vectorizePadding) { result.addAttribute( VectorizeChildrenAndApplyPatternsOp::getVectorizePaddingAttrName( result.name), builder.getUnitAttr()); } if (vectorizeExtract) { result.addAttribute( VectorizeChildrenAndApplyPatternsOp::getVectorizeNdExtractAttrName( result.name), builder.getUnitAttr()); } if (flatten1DDepthwiseConv) { result.addAttribute( VectorizeChildrenAndApplyPatternsOp::getFlatten_1dDepthwiseConvAttrName( result.name), builder.getUnitAttr()); } result.addTypes(transform::AnyOpType::get(builder.getContext())); } namespace { /// This is an helper only to call vectorize via a pattern inside of /// VectorizeChildrenAndApplyPatternsOp::applyToOne. struct VectorizationPattern : public RewritePattern { explicit VectorizationPattern(MLIRContext *context, bool vectorizeExtract = false, bool flattenConv = false) : RewritePattern(MatchAnyOpTypeTag(), /*benefit=*/1, context), vectorizeNDExtract(vectorizeExtract), flatten1DDepthwiseConv(flattenConv) {} LogicalResult matchAndRewrite(Operation *op, PatternRewriter &rewriter) const override { LinalgOp linalgOp = dyn_cast(op); if (!linalgOp) return rewriter.notifyMatchFailure(op, "expected Linalg Op"); return vectorize(rewriter, linalgOp, /*inputVectorSizes=*/{}, /*scalableVecDims=*/{}, vectorizeNDExtract, flatten1DDepthwiseConv); } private: /// Controls whether to vectorize `tensor.extract` when the input tensor is /// rank >= 2. bool vectorizeNDExtract = false; /// Controls whether to "flatten" the channel dimension when vectorising 1D /// depthwise convolutions. This should lead to bette vectorization for /// tensors with a low number of channel dimensions. 
bool flatten1DDepthwiseConv = false; }; } // namespace DiagnosedSilenceableFailure transform::VectorizeChildrenAndApplyPatternsOp::applyToOne( transform::TransformRewriter &rewriter, Operation *target, transform::ApplyToEachResultList &results, transform::TransformState &state) { if (!target->hasTrait()) { auto diag = this->emitOpError("requires isolated-from-above targets"); diag.attachNote(target->getLoc()) << "non-isolated target"; return DiagnosedSilenceableFailure::definiteFailure(); } MLIRContext *ctx = getContext(); RewritePatternSet patterns(ctx); patterns.add(ctx, getVectorizeNdExtract(), getFlatten_1dDepthwiseConv()); if (!getDisableTransferPermutationMapLoweringPatterns()) vector::populateVectorTransferPermutationMapLoweringPatterns(patterns); if (!getDisableMultiReductionToContractPatterns()) vector::populateVectorReductionToContractPatterns(patterns); vector::populateSinkVectorBroadcastPatterns(patterns); patterns.add(ctx, /*benefit=*/2); vector::TransferReadOp::getCanonicalizationPatterns(patterns, ctx); vector::TransferWriteOp::getCanonicalizationPatterns(patterns, ctx); tensor::populateFoldTensorSubsetIntoVectorTransferPatterns(patterns); patterns.add(ctx); if (getVectorizePadding()) linalg::populatePadOpVectorizationPatterns(patterns); TrackingListener listener(state, *this); GreedyRewriteConfig config; config.listener = &listener; if (failed(applyPatternsAndFoldGreedily(target, std::move(patterns), config))) return emitDefaultDefiniteFailure(target); results.push_back(target); return DiagnosedSilenceableFailure::success(); } //===----------------------------------------------------------------------===// // VectorizeOp //===----------------------------------------------------------------------===// static const StringLiteral kVectorSizesKeyword = "vector_sizes"; ParseResult transform::VectorizeOp::parse(OpAsmParser &parser, OperationState &result) { OpAsmParser::UnresolvedOperand target; SmallVector dynamicSizes; DenseI64ArrayAttr staticSizes; SmallVector operandTypes; llvm::SMLoc operandLoc; DenseBoolArrayAttr scalableVals; if (parser.parseOperand(target) || parser.getCurrentLocation(&operandLoc)) return ParseResult::failure(); if (succeeded(parser.parseOptionalKeyword(kVectorSizesKeyword))) { if (failed(parseDynamicIndexList(parser, dynamicSizes, staticSizes, scalableVals))) return ParseResult::failure(); } if (succeeded(parser.parseOptionalKeyword( getVectorizeNdExtractAttrName(result.name)))) result.addAttribute(getVectorizeNdExtractAttrName(result.name), parser.getBuilder().getUnitAttr()); if (parser.parseOptionalAttrDict(result.attributes) || parser.parseColonTypeList(operandTypes)) return ParseResult::failure(); if (operandTypes.size() != dynamicSizes.size() + 1) { return parser.emitError(operandLoc) << "expected " << dynamicSizes.size() + 1 << " operand type(s)"; } if (parser.resolveOperand(target, operandTypes.front(), result.operands) || parser.resolveOperands(dynamicSizes, ArrayRef(operandTypes).drop_front(), operandLoc, result.operands)) { return failure(); } if (scalableVals) result.addAttribute(getScalableSizesAttrName(result.name), scalableVals); if (staticSizes) result.addAttribute(getStaticVectorSizesAttrName(result.name), staticSizes); return success(); } void transform::VectorizeOp::print(OpAsmPrinter &p) { p << ' ' << getTarget() << ' '; if (!getMixedVectorSizes().empty()) { p << kVectorSizesKeyword << ' '; printDynamicIndexList(p, getOperation(), getVectorSizes(), getStaticVectorSizesAttr(), /*valueTypes=*/{}, getScalableSizesAttr(), 
OpAsmParser::Delimiter::Square); } if (getVectorizeNdExtract()) p << getVectorizeNdExtractAttrName() << ' '; p.printOptionalAttrDict( (*this)->getAttrs(), /*elidedAttrs=*/{ getScalableSizesAttrName(getOperation()->getName()), getStaticVectorSizesAttrName(getOperation()->getName())}); p << " : "; p << getTarget().getType(); if (!getVectorSizes().empty()) { p << ", "; llvm::interleaveComma(getVectorSizes(), p, [&](Value operand) { p << operand.getType(); }); } } DiagnosedSilenceableFailure transform::VectorizeOp::apply( transform::TransformRewriter &rewriter, mlir::transform::TransformResults &transformResults, mlir::transform::TransformState &state) { auto targets = state.getPayloadOps(getTarget()); if (std::empty(targets)) return DiagnosedSilenceableFailure::success(); SmallVector vectorSizes; for (OpFoldResult sz : getMixedVectorSizes()) { if (sz.is()) { auto attr = sz.get(); vectorSizes.push_back(cast(attr).getInt()); continue; } else if (sz.is() && isa(sz.get().getType())) { ArrayRef params = state.getParams(sz.get()); if (params.size() != 1) return emitSilenceableFailure(getLoc()) << "expected a single param"; vectorSizes.push_back( cast(params.front()).getValue().getSExtValue()); continue; } auto szPayloads = state.getPayloadOps(sz.get()); if (!llvm::hasSingleElement(szPayloads)) { auto diag = this->emitOpError( "requires vector size handle that is mapped to 1 payload op"); diag.attachNote(sz.get().getLoc()) << "mapped to " << llvm::range_size(szPayloads) << " payload ops"; return DiagnosedSilenceableFailure::definiteFailure(); } Operation *szPayloadOp = *szPayloads.begin(); if (szPayloadOp->getNumResults() != 1 || !szPayloadOp->getResult(0).getType().isIndex()) { auto diag = this->emitOpError( "requires vector size payload op with 1 index result"); diag.attachNote(szPayloadOp->getLoc()) << "vector size payload op"; return DiagnosedSilenceableFailure::definiteFailure(); } IntegerAttr attr; if (!matchPattern(szPayloadOp->getResult(0), m_Constant(&attr))) { auto diag = this->emitOpError("requires constant vector size"); diag.attachNote(szPayloadOp->getLoc()) << "vector size payload op"; return DiagnosedSilenceableFailure::definiteFailure(); } vectorSizes.push_back(attr.getInt()); } // TODO: Check that the correct number of vectorSizes was provided. for (Operation *target : targets) { if (!isa( target)) { return mlir::emitSilenceableFailure(target->getLoc()) << "Unsupported Op, cannot vectorize"; } if (failed(linalg::vectorize(rewriter, target, vectorSizes, getScalableSizes(), getVectorizeNdExtract().has_value() ? 
getVectorizeNdExtract().value() : false))) { return mlir::emitSilenceableFailure(target->getLoc()) << "Attempted to vectorize, but failed"; } } return DiagnosedSilenceableFailure::success(); } void transform::VectorizeOp::getEffects( SmallVectorImpl &effects) { consumesHandle(getTarget(), effects); onlyReadsHandle(getVectorSizes(), effects); modifiesPayload(effects); } SmallVector VectorizeOp::getMixedVectorSizes() { OpBuilder b(getContext()); return getMixedValues(getStaticVectorSizes(), getVectorSizes(), b); } LogicalResult transform::VectorizeOp::verify() { if (getStaticVectorSizes().size() != getScalableSizes().size()) return emitOpError("expected same number of vector sizes (") << getStaticVectorSizes().size() << ") and scalable sizes (" << getScalableSizes().size() << ")"; return success(); } //===----------------------------------------------------------------------===// // HoistRedundantVectorTransfersOp //===----------------------------------------------------------------------===// DiagnosedSilenceableFailure transform::HoistRedundantVectorTransfersOp::applyToOne( transform::TransformRewriter &rewriter, func::FuncOp target, transform::ApplyToEachResultList &results, transform::TransformState &state) { // WARNING: This hoisting does not model parallelism and is generally // incorrect when used on distributed loops with memref semantics! // TODO: obsolete and should be retired. linalg::hoistRedundantVectorTransfers(target); results.push_back(target); return DiagnosedSilenceableFailure::success(); } //===----------------------------------------------------------------------===// // ConvertConv2DToImg2ColOp. //===----------------------------------------------------------------------===// DiagnosedSilenceableFailure transform::ConvertConv2DToImg2ColOp::applyToOne( transform::TransformRewriter &rewriter, linalg::LinalgOp target, transform::ApplyToEachResultList &results, transform::TransformState &state) { rewriter.setInsertionPoint(target); auto maybeTransformed = TypeSwitch>>( target) .Case([&](linalg::Conv2DNhwcHwcfOp op) { return rewriteInIm2Col(rewriter, op); }) .Case([&](linalg::Conv2DNhwcFhwcOp op) { return rewriteInIm2Col(rewriter, op); }) .Case([&](linalg::DepthwiseConv2DNhwcHwcOp op) { return rewriteInIm2Col(rewriter, op); }) .Case([&](linalg::Conv2DNchwFchwOp op) { return rewriteInIm2Col(rewriter, op); }) .Default([&](Operation *op) { return rewriter.notifyMatchFailure(op, "not supported"); }); if (failed(maybeTransformed)) return emitDefaultSilenceableFailure(target); // Handle to the operation producing the img2col tensor. results.push_back(maybeTransformed->first); // Handle to the operation that replaces the original convolution. results.push_back(maybeTransformed->second); return DiagnosedSilenceableFailure::success(); } //===----------------------------------------------------------------------===// // FlattenElementwiseLinalgOp. //===----------------------------------------------------------------------===// DiagnosedSilenceableFailure transform::FlattenElementwiseLinalgOp::applyToOne( transform::TransformRewriter &rewriter, linalg::LinalgOp target, transform::ApplyToEachResultList &results, transform::TransformState &state) { rewriter.setInsertionPoint(target); if (!isElementwise(target)) return mlir::emitSilenceableFailure(target->getLoc()) << "only elementwise flattening is supported"; // If rank <= 1, do nothing if (target.getNumLoops() <= 1) { results.push_back(target); return DiagnosedSilenceableFailure::success(); } // Attempt to flatten all dims to one. 
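  // E.g. for a rank-3 elementwise op the reassociation built below is
  // [[0, 1, 2]], collapsing all loops into a single dimension (illustrative).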
//===----------------------------------------------------------------------===//
// TransposeConv2DOp
//===----------------------------------------------------------------------===//

DiagnosedSilenceableFailure transform::TransposeConv2DOp::applyToOne(
    transform::TransformRewriter &rewriter, linalg::LinalgOp target,
    transform::ApplyToEachResultList &results,
    transform::TransformState &state) {
  rewriter.setInsertionPoint(target);
  auto maybeTransformed =
      TypeSwitch<Operation *, FailureOr<Operation *>>(target)
          .Case([&](linalg::Conv2DNhwcFhwcOp op) {
            return transposeConv2D(rewriter, op);
          })
          .Case([&](linalg::Conv2DNhwcFhwcQOp op) {
            return transposeConv2D(rewriter, op);
          })
          .Default([&](Operation *op) {
            return rewriter.notifyMatchFailure(op, "not supported");
          });
  if (failed(maybeTransformed))
    return emitDefaultSilenceableFailure(target);
  // Handle to the new Conv2D operation with transposed filters.
  results.push_back(*maybeTransformed);
  return DiagnosedSilenceableFailure::success();
}

//===----------------------------------------------------------------------===//
// InsertSliceToCopyOp
//===----------------------------------------------------------------------===//

template <typename OpTy>
DiagnosedSilenceableFailure doit(RewriterBase &rewriter, OpTy target,
                                 transform::ApplyToEachResultList &results,
                                 transform::TransformState &state) {
  static_assert(llvm::is_one_of<OpTy, tensor::InsertSliceOp,
                                tensor::ParallelInsertSliceOp>() &&
                "wrong op type");

  if (auto copySource =
          target.getSource().template getDefiningOp<linalg::CopyOp>()) {
    results.push_back(copySource);
    return DiagnosedSilenceableFailure::success();
  }

  // If we are inside an InParallel region, temporarily set the insertion point
  // outside: only tensor.parallel_insert_slice ops are allowed in there.
  if constexpr (std::is_same_v<OpTy, tensor::ParallelInsertSliceOp>) {
    rewriter.setInsertionPoint(
        target->template getParentOfType<scf::InParallelOp>());
  }

  Value extracted = rewriter.create<tensor::ExtractSliceOp>(
      target.getLoc(), target.getDest(), target.getMixedOffsets(),
      target.getMixedSizes(), target.getMixedStrides());
  Value copied = rewriter
                     .create<linalg::CopyOp>(target.getLoc(),
                                             target.getSource(), extracted)
                     .getResult(0);

  // Reset the insertion point.
  rewriter.setInsertionPoint(target);
  rewriter.replaceOpWithNewOp<OpTy>(
      target, copied, target.getDest(), target.getMixedOffsets(),
      target.getMixedSizes(), target.getMixedStrides());

  results.push_back(copied.getDefiningOp());
  return DiagnosedSilenceableFailure::success();
}

DiagnosedSilenceableFailure transform::InsertSliceToCopyOp::applyToOne(
    transform::TransformRewriter &rewriter, Operation *targetOp,
    transform::ApplyToEachResultList &results,
    transform::TransformState &state) {
  rewriter.setInsertionPoint(targetOp);
  if (auto target = dyn_cast<tensor::InsertSliceOp>(targetOp))
    return doit(rewriter, target, results, state);
  if (auto target = dyn_cast<tensor::ParallelInsertSliceOp>(targetOp))
    return doit(rewriter, target, results, state);

  DiagnosedSilenceableFailure diag =
      emitSilenceableError()
      << "only InsertSliceOp and ParallelInsertSliceOp ops are supported";
  diag.attachNote(targetOp->getLoc()) << "target op";
  return diag;
}
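
// Sketch of the rewrite performed by `doit` above when the source is not
// already produced by a linalg.copy (offsets, sizes, strides and types are
// placeholders):
//
//   %r = tensor.insert_slice %src into %dest[0, 0] [4, 4] [1, 1]
//       : tensor<4x4xf32> into tensor<8x8xf32>
//
// becomes, roughly:
//
//   %slice = tensor.extract_slice %dest[0, 0] [4, 4] [1, 1]
//       : tensor<8x8xf32> to tensor<4x4xf32>
//   %copy = linalg.copy ins(%src : tensor<4x4xf32>)
//                       outs(%slice : tensor<4x4xf32>) -> tensor<4x4xf32>
//   %r = tensor.insert_slice %copy into %dest[0, 0] [4, 4] [1, 1]
//       : tensor<4x4xf32> into tensor<8x8xf32>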
//===----------------------------------------------------------------------===//
// MapCopyToThreadsOp
//===----------------------------------------------------------------------===//

DiagnosedSilenceableFailure transform::MapCopyToThreadsOp::applyToOne(
    transform::TransformRewriter &rewriter, Operation *target,
    transform::ApplyToEachResultList &results,
    transform::TransformState &state) {
  // Check if the op is supported.
  if (!isa<linalg::CopyOp, tensor::PadOp>(target)) {
    DiagnosedSilenceableFailure diag =
        emitSilenceableError()
        << "only linalg.copy and tensor.pad target ops are supported";
    diag.attachNote(target->getLoc()) << "target op";
    return diag;
  }
  assert(target->getNumResults() == 1 && "expected single result");
  auto resultShapedType = cast<ShapedType>(target->getResult(0).getType());
  if (!resultShapedType.hasStaticShape()) {
    DiagnosedSilenceableFailure diag =
        emitSilenceableError()
        << "only statically sized ops of rank <= 3 are supported";
    diag.attachNote(target->getLoc()) << "target op";
    return diag;
  }

  // Conservatively set the minimum viable desired bitwidth alignment.
  int64_t desiredBitAlignment = getDesiredBitAlignment();
  int64_t eltBitwidth =
      resultShapedType.getElementType().getIntOrFloatBitWidth();
  if (desiredBitAlignment % eltBitwidth != 0) {
    desiredBitAlignment = eltBitwidth;
  }

  gpu::CopyMappingInfo mapping(
      /*ctx=*/getContext(),
      /*totalNumThreads=*/getTotalNumThreads(),
      /*alignment=*/desiredBitAlignment,
      /*sizes=*/resultShapedType.getShape(),
      /*favorPredication=*/false,
      /*elementalBitwidth=*/
      resultShapedType.getElementType().getIntOrFloatBitWidth());
  if (mapping.status == gpu::CopyMappingInfo::Status::Invalid) {
    DiagnosedSilenceableFailure diag =
        emitSilenceableError()
        << "too few threads to map copy op to threads on the most minor "
           "dimension, given alignment and vector size constraints; try a "
           "smaller tile size or mapping to more threads";
    diag.attachNote(target->getLoc()) << "target op";
    return diag;
  }

  // OpBuilder only used to compute attributes.
  OpBuilder b(getContext());
  linalg::ForallTilingResult tilingResult;
  DiagnosedSilenceableFailure diag = tileToForallOpImpl(
      /*rewriter=*/rewriter,
      /*state=*/state,
      /*transformOp=*/*this,
      /*target=*/target,
      /*mixedNumThreads=*/getMixedValues(mapping.numThreads, {}, b),
      /*mixedTileSizes=*/ArrayRef<OpFoldResult>{},
      /*mapping=*/b.getArrayAttr(mapping.threadMapping),
      /*tilingResult=*/tilingResult);
  if (!diag.succeeded())
    return diag;

  results.push_back(tilingResult.tileOp);
  results.push_back(tilingResult.tiledOp);
  return DiagnosedSilenceableFailure::success();
}
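
// Worked example of the alignment clamping above (values are illustrative):
// with a requested 128-bit alignment and f32 elements, 128 % 32 == 0, so the
// requested alignment is kept and up to four f32 elements fit in one 128-bit
// access; with a requested 32-bit alignment and f64 elements, 32 % 64 != 0, so
// the alignment falls back to the 64-bit element width. On success the op
// returns two handles: the generated scf.forall op and the tiled copy nested
// inside it.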

#include "mlir/Dialect/Linalg/TransformOps/LinalgTransformOpsEnums.cpp.inc"

#define GET_OP_CLASSES
#include "mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp.inc"