//===- LinalgTransformOps.cpp - Implementation of Linalg transform ops ----===// // // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. // See https://llvm.org/LICENSE.txt for license information. // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // //===----------------------------------------------------------------------===// #include "mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.h" #include "mlir/AsmParser/AsmParser.h" #include "mlir/Dialect/Affine/IR/AffineOps.h" #include "mlir/Dialect/Arith/IR/Arith.h" #include "mlir/Dialect/GPU/IR/GPUDialect.h" #include "mlir/Dialect/Linalg/IR/Linalg.h" #include "mlir/Dialect/Linalg/Transforms/Transforms.h" #include "mlir/Dialect/PDL/IR/PDL.h" #include "mlir/Dialect/PDL/IR/PDLTypes.h" #include "mlir/Dialect/SCF/Transforms/TileUsingInterface.h" #include "mlir/Dialect/Transform/IR/TransformDialect.h" #include "mlir/Dialect/Transform/IR/TransformInterfaces.h" #include "mlir/Dialect/Transform/IR/TransformUtils.h" #include "mlir/Dialect/Transform/Utils/Utils.h" #include "mlir/IR/BuiltinTypes.h" #include "mlir/IR/Matchers.h" #include "mlir/IR/OpDefinition.h" #include "mlir/Interfaces/TilingInterface.h" #include "mlir/Transforms/GreedyPatternRewriteDriver.h" #include "llvm/ADT/ScopeExit.h" #include "llvm/ADT/StringSet.h" #include "llvm/Support/Debug.h" using namespace mlir; using namespace mlir::linalg; using namespace mlir::transform; #define DEBUG_TYPE "linalg-transforms" /// Attempts to apply the pattern specified as template argument to the given /// operation. The pattern is expected to have a `returningMatchAndRewrite` /// function that returns the "main" result or failure. Returns failure if the /// pattern failed to apply. Extra arguments are forwarded to the pattern /// constructor. template static FailureOr tryApply(Operation *operation, Args &&...args) { // Check if the given operation has the type expected by the pattern. 
using OpTy = typename llvm::function_traits< decltype(&PatternTy::returningMatchAndRewrite)>::template arg_t<0>; auto op = dyn_cast(operation); if (!op) return failure(); // Apply the pattern directly to the op. PatternTy pattern(operation->getContext(), std::forward(args)...); TrivialPatternRewriter rewriter(operation->getContext()); rewriter.setInsertionPoint(operation); auto result = pattern.returningMatchAndRewrite(op, rewriter); if (failed(result)) return failure(); return cast(result->getOperation()); } //===----------------------------------------------------------------------===// // DecomposeOp //===----------------------------------------------------------------------===// DiagnosedSilenceableFailure transform::DecomposeOp::applyToOne(linalg::LinalgOp target, transform::ApplyToEachResultList &results, transform::TransformState &state) { #define DOWNSCALE(trans) \ { \ FailureOr res = tryApply(target); \ if (succeeded(res)) { \ results.push_back(*res); \ return DiagnosedSilenceableFailure::success(); \ } \ } #define DOWNSCALE_CALL(a, b) DownscaleSizeOneWindowed2DConvolution #define DOWNSCALE_NORMAL(a, b) DOWNSCALE(DOWNSCALE_CALL(a, b)) DOWNSCALE_NORMAL(Conv2DNhwcHwcfOp, Conv1DNwcWcfOp) DOWNSCALE_NORMAL(Conv2DNchwFchwOp, Conv1DNcwFcwOp) DOWNSCALE_NORMAL(PoolingNhwcSumOp, PoolingNwcSumOp) DOWNSCALE_NORMAL(PoolingNchwSumOp, PoolingNcwSumOp) DOWNSCALE_NORMAL(PoolingNhwcMaxOp, PoolingNwcMaxOp) DOWNSCALE_NORMAL(PoolingNhwcMaxUnsignedOp, PoolingNwcMaxUnsignedOp) DOWNSCALE_NORMAL(PoolingNhwcMinOp, PoolingNwcMinOp) DOWNSCALE_NORMAL(PoolingNhwcMinUnsignedOp, PoolingNwcMinUnsignedOp) DOWNSCALE_NORMAL(PoolingNchwMaxOp, PoolingNcwMaxOp) DOWNSCALE(DownscaleDepthwiseConv2DNhwcHwcOp) #undef DOWNSCALE_NORMAL #undef DOWNSCALE_CALL #undef DOWNSCALE results.assign(1, nullptr); return emitDefaultSilenceableFailure(target); } //===----------------------------------------------------------------------===// // FuseOp 
//===----------------------------------------------------------------------===// /// Apply a tiling transformation to all payload ops and store both the /// tiled operation as well as the created tile loops. static LogicalResult applyTilingToAll( Operation *transformOp, ArrayRef payloadOps, unsigned numLoops, transform::TransformResults &transformResults, function_ref(TilingInterface)> applyFn) { SmallVector tiledLinalgOps; SmallVector> loopOps(numLoops); for (unsigned int i = 0; i < numLoops; ++i) loopOps[i].reserve(payloadOps.size()); for (Operation *target : payloadOps) { auto tilingInterfaceOp = dyn_cast(target); if (!tilingInterfaceOp) return transformOp->emitError("only TilingInterface ops are supported"); TrivialPatternRewriter rewriter(target->getContext()); rewriter.setInsertionPoint(target); FailureOr tiledResults = applyFn(tilingInterfaceOp); if (failed(tiledResults)) return failure(); // Perform the replacement of tiled and fused values. SmallVector opsToReplace{target}; llvm::append_range(opsToReplace, tiledResults->fusedProducers); for (Operation *toReplace : opsToReplace) { SmallVector replacements; replacements.reserve(toReplace->getNumResults()); for (OpResult res : toReplace->getResults()) { auto it = tiledResults->replacements.find(res); if (it == tiledResults->replacements.end()) replacements.push_back(res); else replacements.push_back(it->getSecond()); } rewriter.replaceOp(toReplace, replacements); } // Report back the relevant handles to the transform op. 
tiledLinalgOps.push_back(tiledResults->tiledAndFusedOps.front()); assert(tiledResults->loops.size() == numLoops && "Mismatched number of loops, tile and fuse transform should have " "failed"); for (unsigned int i = 0; i < numLoops; ++i) loopOps[i].push_back(tiledResults->loops[i]); } transformResults.set(transformOp->getOpResult(0), tiledLinalgOps); for (unsigned int i = 0; i < numLoops; ++i) transformResults.set(transformOp->getOpResult(i + 1), loopOps[i]); return success(); } /// Parse a tiling-like operation that returns the tiled op as well as the /// created tile loops. The function counts the non-zero tile sizes to compute /// the number of results. static ParseResult parseTileLikeOp(OpAsmParser &parser, OperationState &result, StringRef sizesAttrName) { OpAsmParser::UnresolvedOperand targetOperand; SMLoc opLoc = parser.getCurrentLocation(); if (parser.parseOperand(targetOperand) || parser.parseOptionalAttrDict(result.attributes)) return failure(); Attribute sizesAttr = result.attributes.get(sizesAttrName); if (!sizesAttr) return parser.emitError(opLoc) << "expected '" << sizesAttrName << "' attribute"; auto sizesArrayAttr = sizesAttr.dyn_cast(); if (!sizesArrayAttr) return parser.emitError(opLoc) << "'" << sizesAttrName << "' attribute must be an array"; Type pdlOpType = parser.getBuilder().getType(); size_t numExpectedLoops = sizesArrayAttr.size() - llvm::count(extractFromI64ArrayAttr(sizesArrayAttr), 0); result.addTypes(SmallVector(numExpectedLoops + 1, pdlOpType)); if (parser.resolveOperand(targetOperand, pdlOpType, result.operands)) return failure(); return success(); } DiagnosedSilenceableFailure transform::FuseOp::apply(mlir::transform::TransformResults &transformResults, mlir::transform::TransformState &state) { SmallVector tileSizes = extractFromI64ArrayAttr(getTileSizes()); SmallVector tileInterchange = extractFromI64ArrayAttr(getTileInterchange()); scf::SCFTilingOptions tilingOptions; tilingOptions.interchangeVector = tileInterchange; tilingOptions 
= tilingOptions.setTileSizes(tileSizes); scf::SCFTileAndFuseOptions tileAndFuseOptions; tileAndFuseOptions.tilingOptions = tilingOptions; LogicalResult result = applyTilingToAll( getOperation(), state.getPayloadOps(getTarget()), tileSizes.size() - llvm::count(tileSizes, 0), transformResults, [&](TilingInterface tilingInterfaceOp) -> FailureOr { TrivialPatternRewriter rewriter(getContext()); return tileConsumerAndFuseProducerGreedilyUsingSCFForOp( rewriter, tilingInterfaceOp, tileAndFuseOptions); }); return failed(result) ? DiagnosedSilenceableFailure::definiteFailure() : DiagnosedSilenceableFailure::success(); } ParseResult transform::FuseOp::parse(OpAsmParser &parser, OperationState &result) { return parseTileLikeOp( parser, result, transform::FuseOp::getTileSizesAttrName(result.name).getValue()); } void transform::FuseOp::print(OpAsmPrinter &p) { p << ' '; p << getTarget(); p.printOptionalAttrDict((*this)->getAttrs()); } LogicalResult transform::FuseOp::verify() { SmallVector permutation = extractFromI64ArrayAttr(getTileInterchange()); auto sequence = llvm::to_vector(llvm::seq(0, permutation.size())); if (!std::is_permutation(sequence.begin(), sequence.end(), permutation.begin(), permutation.end())) { return emitOpError() << "expects interchange to be a permutation, found " << getTileInterchange(); } return success(); } //===----------------------------------------------------------------------===// // FuseIntoContainingOp //===----------------------------------------------------------------------===// void transform::FuseIntoContainingOp::build(OpBuilder &builder, OperationState &result, Value producerOp, Value containingOp) { result.addOperands({producerOp, containingOp}); result.addTypes(pdl::OperationType::get(builder.getContext())); } /// Find the first "extract" user of `producerOp` and tile it right before its /// use. The tiled op is fused under the `containingOp`. /// Return this fused op on success or nullptr if anything fails. 
static Operation *tileAndFuseFirstExtractUse(RewriterBase &rewriter, Diagnostic &diag, Operation *producerOp, Operation *containingOp) { LLVM_DEBUG(llvm::dbgs() << "Try to fuse a direct extract use\n"); auto tileableProducer = dyn_cast(producerOp); if (!tileableProducer) { diag.attachNote(producerOp->getLoc()) << "producer is not a TileableInterface: " << *producerOp; return nullptr; } // Search the producer slices accessed within the containing operation. // TODO: Generalize to more extract/insert/parallel_insert triples, maybe // evolve into an interface. auto it = llvm::find_if(tileableProducer->getUsers(), [&](Operation *user) { auto sliceOp = dyn_cast(user); return sliceOp && containingOp->isProperAncestor(sliceOp); }); // Find a fusion opportunity. if (it == tileableProducer->getUsers().end()) { diag.attachNote(tileableProducer->getLoc()) << "could not find fusion opportunity for: " << *tileableProducer; return nullptr; } auto sliceOpToTile = cast(*it); // Try to fuse the producer in-place. OpBuilder::InsertionGuard guard(rewriter); rewriter.setInsertionPoint(sliceOpToTile); // Tile the producer. int64_t resultNumber = sliceOpToTile.getSource().cast().getResultNumber(); LLVM_DEBUG(llvm::dbgs() << "resultNumber: " << resultNumber << "\n"); FailureOr tiledProducer = tileableProducer.generateResultTileValue( rewriter, resultNumber, sliceOpToTile.getMixedOffsets(), sliceOpToTile.getMixedSizes()); if (failed(tiledProducer)) { diag.attachNote(tileableProducer->getLoc()) << "failed to tile producer op: " << *tileableProducer; return nullptr; } LLVM_DEBUG(llvm::dbgs() << "tiledProducer: " << *tiledProducer << "\n"); // Replace the extract op. 
Operation *fusedOp = tiledProducer->getDefiningOp(); auto maybeRankReduced = tensor::ExtractSliceOp::rankReduceIfNeeded( rewriter, sliceOpToTile->getLoc(), fusedOp->getResult(resultNumber), sliceOpToTile->getResult(0) .getType() .cast() .getShape()); assert(succeeded(maybeRankReduced) && "unexpected shape"); rewriter.replaceOp(sliceOpToTile, *maybeRankReduced); return fusedOp; } /// First, find the first "scf::ForeachThreadOp" user of `producerOp` and ensure /// it is exactly the `containingOp`, otherwise bail. /// Then, find the first "extract" user of the tied block argument and tile it /// right before its "extract" use. The tiled op is fused under the /// `containingOp`. /// Return this fused op on success or nullptr if anything fails. static Operation *tileAndFuseFirstExtractUseThroughContainingOpBlockArgument( RewriterBase &rewriter, Diagnostic &diag, Operation *producerOp, Operation *containingOp) { LLVM_DEBUG( llvm::dbgs() << "Try to fuse an extract use through block argument\n"); auto tileableProducer = dyn_cast(producerOp); if (!tileableProducer) { diag.attachNote(producerOp->getLoc()) << "producer is not a TileableInterface: " << *producerOp; return nullptr; } // Search the first use by a "scf::ForeachThreadOp" user. scf::ForeachThreadOp foreachThreadOp; auto itProducerUses = llvm::find_if(tileableProducer->getUses(), [&](OpOperand &use) { foreachThreadOp = dyn_cast(use.getOwner()); return foreachThreadOp; }); // If it's not from the containing op, return. if (!foreachThreadOp || foreachThreadOp != containingOp) { diag.attachNote(tileableProducer->getLoc()) << "could not find a use by the containing op: " << *tileableProducer; return nullptr; } // Search the producer slices accessed within the containing // operation. // TODO: Generalize to more extract/insert/parallel_insert triples. // Maybe evolve into an interface. 
OpOperand *pUse = &(*itProducerUses); BlockArgument bbArg = foreachThreadOp.getTiedBlockArgument(pUse); // Search the producer slices accessed within the containing operation. // TODO: Generalize to more extract/insert/parallel_insert triples, maybe // evolve into an interface. auto itBBArgUsers = llvm::find_if(bbArg.getUsers(), [&](Operation *user) { auto sliceOp = dyn_cast(user); return sliceOp && containingOp->isProperAncestor(sliceOp); }); // Find a fusion opportunity. if (itBBArgUsers == bbArg.getUsers().end()) { diag.attachNote(containingOp->getLoc()) << "could not find fusion opportunity for bbArg: " << bbArg; return nullptr; } auto sliceOpToTile = cast(*itBBArgUsers); // Try to fuse the producer in-place. OpBuilder::InsertionGuard guard(rewriter); rewriter.setInsertionPoint(sliceOpToTile); // Replace the use in the tileableProducer before tiling: clone, replace and // then tile. int64_t resultNumber = pUse->get().cast().getResultNumber(); LLVM_DEBUG(llvm::dbgs() << "resultNumber: " << resultNumber << "\n"); // Gather destination tensors. SmallVector destinationTensors; if (failed(tensor::getOrCreateDestinations( rewriter, tileableProducer->getLoc(), tileableProducer, destinationTensors))) { diag.attachNote(tileableProducer->getLoc()) << "failed to get destination tensors for: " << *tileableProducer; return nullptr; } BlockAndValueMapping bvm; bvm.map(destinationTensors[resultNumber], bbArg); auto tileableProducerClone = cast(rewriter.clone(*tileableProducer, bvm)); auto scopeGuard = llvm::make_scope_exit([&]() { rewriter.eraseOp(tileableProducerClone); }); // Tile the producer. 
FailureOr tiledProducer = tileableProducerClone.generateResultTileValue( rewriter, resultNumber, sliceOpToTile.getMixedOffsets(), sliceOpToTile.getMixedSizes()); if (failed(tiledProducer)) { diag.attachNote(tileableProducer->getLoc()) << "failed to tile producer op: " << *tileableProducer; return nullptr; } LLVM_DEBUG(llvm::dbgs() << "tiledProducer: " << *tiledProducer << "\n"); // Replace the extract op. Operation *fusedOp = tiledProducer->getDefiningOp(); auto maybeRankReduced = tensor::ExtractSliceOp::rankReduceIfNeeded( rewriter, sliceOpToTile->getLoc(), fusedOp->getResult(resultNumber), sliceOpToTile->getResult(0) .getType() .cast() .getShape()); assert(succeeded(maybeRankReduced) && "unexpected shape"); rewriter.replaceOp(sliceOpToTile, *maybeRankReduced); // Replace the use in containingOp. rewriter.updateRootInPlace(containingOp, [&]() { containingOp->setOperand(pUse->getOperandNumber(), destinationTensors.front()); }); return fusedOp; } static Operation *cloneAndFuseFirstUse(RewriterBase &rewriter, Diagnostic &diag, Operation *producerOp, Operation *containingOp) { LLVM_DEBUG(llvm::dbgs() << "Try to fuse an use by cloning\n"); // Gather all uses inside the containing op. SmallVector uses; for (OpResult result : producerOp->getOpResults()) { for (OpOperand &use : result.getUses()) { if (containingOp->isProperAncestor(use.getOwner())) { uses.push_back(&use); continue; } // Cannot clone and fuse if the use is by the containing op itself: fail // immediately. if (containingOp == use.getOwner()) { diag.attachNote(producerOp->getLoc()) << "producer op use by containing op cannot be fused by cloning"; return nullptr; } } } // Check for a non-empty list of fusion opportunities. if (uses.empty()) { diag.attachNote(producerOp->getLoc()) << "no fusion opportunity by cloning"; return nullptr; } // Clone and fuse inside the containing op. Operation *fusedOp = nullptr; OpOperand *use = uses.front(); // Parallel insert slice is not a valid clone destination. 
// TODO: Generalize to other type of ops. assert(!isa(use->getOwner()) && "Parallel insert slice is not a valid clone destination"); unsigned resultNumber = use->get().cast().getResultNumber(); LLVM_DEBUG(llvm::dbgs() << "resultNumber: " << resultNumber << "\n"); OpBuilder::InsertionGuard guard(rewriter); rewriter.setInsertionPoint(use->getOwner()); fusedOp = rewriter.clone(*producerOp); rewriter.updateRootInPlace( use->getOwner(), [&] { use->set(fusedOp->getOpResult(resultNumber)); }); return fusedOp; } DiagnosedSilenceableFailure transform::FuseIntoContainingOp::apply(transform::TransformResults &results, transform::TransformState &state) { SmallVector fusedOps; ArrayRef producerOps = state.getPayloadOps(getProducerOp()); // If nothing to fuse, propagate success. if (producerOps.empty()) { results.set(getFusedOp().cast(), SmallVector{}); return DiagnosedSilenceableFailure::success(); } ArrayRef containingOps = state.getPayloadOps(getContainingOp()); if (containingOps.size() != 1) { return emitDefiniteFailure() << "requires exactly one containing_op handle (got " << containingOps.size() << ")"; } Operation *containingOp = containingOps.front(); // Helper function to find the next producer that should be fused. Take any // producer that has a use inside the containing op. SmallVector remainingProducers(producerOps.begin(), producerOps.end()); auto getNextProducer = [&]() -> FailureOr { for (const auto &it : enumerate(remainingProducers)) { Operation *producerOp = it.value(); // The containing op may be a user of producerOp: use isAncestor. int64_t numUsesInContainingOp = llvm::count_if(producerOp->getUsers(), [&](Operation *op) { return containingOp->isAncestor(op); }); // TODO: When resolving the TODO below (no duplicate ops), take an op // that has no use among the remaining producers. This is a topological // sorting. 
if (numUsesInContainingOp > 0) { if (numUsesInContainingOp == 1) remainingProducers.erase(remainingProducers.begin() + it.index()); return producerOp; } } return failure(); }; IRRewriter rewriter(getContext()); while (!remainingProducers.empty()) { auto nextProducer = getNextProducer(); if (failed(nextProducer)) { results.set(getFusedOp().cast(), ArrayRef()); Diagnostic diag(containingOp->getLoc(), DiagnosticSeverity::Remark); diag << "could not find next producer to fuse into container"; return DiagnosedSilenceableFailure::silenceableFailure(std::move(diag)); } Operation *producerOp = *nextProducer; // Default diagnostic, to be complemented with more failure information. Diagnostic diag(producerOp->getLoc(), DiagnosticSeverity::Remark); diag << "could not fuse " << *producerOp << " into " << *containingOp; // TODO: If there are multiple uses of the producer in the containing op, // we currently tile/clone the op multiple times (once per use). In some // cases, we can tile/clone once and reuse the value for each use. // Futhermore, producers should then be traversed according to a // topological sorting. 
Operation *tiled = tileAndFuseFirstExtractUse(rewriter, diag, producerOp, containingOp); if (tiled) { LLVM_DEBUG(llvm::dbgs() << "\nFused a direct extract use\n" << *containingOp); fusedOps.push_back(tiled); continue; } Operation *tiledContainingOpOperand = tileAndFuseFirstExtractUseThroughContainingOpBlockArgument( rewriter, diag, producerOp, containingOp); if (tiledContainingOpOperand) { LLVM_DEBUG(llvm::dbgs() << "\nFused an extract use through block argument\n" << *containingOp); fusedOps.push_back(tiledContainingOpOperand); continue; } Operation *cloned = cloneAndFuseFirstUse(rewriter, diag, producerOp, containingOp); if (cloned) { LLVM_DEBUG(llvm::dbgs() << "\nFused an use by cloning\n" << *containingOp); fusedOps.push_back(cloned); continue; } results.set(getFusedOp().cast(), ArrayRef()); return DiagnosedSilenceableFailure::silenceableFailure(std::move(diag)); } results.set(getFusedOp().cast(), fusedOps); return DiagnosedSilenceableFailure::success(); } //===----------------------------------------------------------------------===// // GeneralizeOp //===----------------------------------------------------------------------===// DiagnosedSilenceableFailure transform::GeneralizeOp::applyToOne(linalg::LinalgOp target, transform::ApplyToEachResultList &results, transform::TransformState &state) { // Exit early if no transformation is needed. 
if (isa(target)) { results.push_back(target); return DiagnosedSilenceableFailure::success(); } FailureOr generic = tryApply(target); if (succeeded(generic)) { results.push_back(generic->getOperation()); return DiagnosedSilenceableFailure::success(); } results.assign(1, nullptr); return emitDefaultSilenceableFailure(target); } //===----------------------------------------------------------------------===// // InterchangeOp //===----------------------------------------------------------------------===// DiagnosedSilenceableFailure transform::InterchangeOp::applyToOne(linalg::GenericOp target, transform::ApplyToEachResultList &results, transform::TransformState &state) { ArrayRef interchangeVector = getIteratorInterchange(); // Exit early if no transformation is needed. if (interchangeVector.empty()) { results.push_back(target); return DiagnosedSilenceableFailure::success(); } TrivialPatternRewriter rewriter(target->getContext()); FailureOr res = interchangeGenericOp(rewriter, target, SmallVector(interchangeVector.begin(), interchangeVector.end())); if (failed(res)) return DiagnosedSilenceableFailure::definiteFailure(); results.push_back(res->getOperation()); return DiagnosedSilenceableFailure::success(); } LogicalResult transform::InterchangeOp::verify() { ArrayRef permutation = getIteratorInterchange(); auto sequence = llvm::to_vector(llvm::seq(0, permutation.size())); if (!std::is_permutation(sequence.begin(), sequence.end(), permutation.begin(), permutation.end())) { return emitOpError() << "expects iterator_interchange to be a permutation, found " << getIteratorInterchange(); } return success(); } //===---------------------------------------------------------------------===// // MatchOp //===---------------------------------------------------------------------===// void transform::MatchOp::build(OpBuilder &builder, OperationState &result, Value target, ArrayRef opNames) { result.addOperands(target); result.addAttribute(MatchOp::getOpsAttrName(result.name), 
builder.getStrArrayAttr(opNames)); result.addTypes(pdl::OperationType::get(builder.getContext())); } DiagnosedSilenceableFailure transform::MatchOp::apply(transform::TransformResults &results, transform::TransformState &state) { llvm::StringSet<> strs; if (getOps().has_value()) strs.insert(getOps()->getAsValueRange().begin(), getOps()->getAsValueRange().end()); ArrayRef payloadOps = state.getPayloadOps(getTarget()); if (payloadOps.size() != 1) { results.set(getResult().cast(), {}); return emitDefiniteFailure("requires exactly one target handle"); } SmallVector res; auto matchFun = [&](Operation *op) { if (getOps().has_value() && !strs.contains(op->getName().getStringRef())) return; // Interfaces cannot be matched by name, just by ID. // So we specifically encode the interfaces we care about for this op. if (getInterface().has_value()) { auto iface = getInterface().value(); if (iface == transform::MatchInterfaceEnum::LinalgOp && !isa(op)) return; if (iface == transform::MatchInterfaceEnum::TilingInterface && isa(op)) return; } // Check if all specified attributes match. if (getOpAttrs().has_value()) { DictionaryAttr opAttrs = getOpAttrs().value(); for (NamedAttribute attr : opAttrs) { if (attr.getName() == getInterfaceAttrName() || attr.getName() == getOpsAttrName()) continue; if (!op->hasAttr(attr.getName())) return; if (op->getAttr(attr.getName()) != attr.getValue()) return; } } if (getFilterResultType().has_value()) { Type t = getFilterResultType().value(); if (op->getNumResults() != 1 || op->getResultTypes().front() != t) return; } // All constraints are satisfied. 
res.push_back(op); return; }; payloadOps.front()->walk(matchFun); results.set(getResult().cast(), res); return DiagnosedSilenceableFailure::success(); } //===---------------------------------------------------------------------===// // MultiTileSizesOp //===---------------------------------------------------------------------===// DiagnosedSilenceableFailure transform::MultiTileSizesOp::applyToOne( LinalgOp target, transform::ApplyToEachResultList &results, TransformState &state) { OpBuilder builder(target.getContext()); builder.setInsertionPoint(target); OpFoldResult targetSize = builder.getIndexAttr(getTargetSize()); OpFoldResult divisor = builder.getIndexAttr(getDivisor()); FailureOr spec = computeMultiTileSizes( builder, target, getDimension(), targetSize, divisor); if (failed(spec)) { return emitSilenceableError() << "could not generate tile size computation"; } AffineExpr s0 = builder.getAffineSymbolExpr(0); AffineExpr s1 = builder.getAffineSymbolExpr(1); Operation *splitPoint = makeComposedAffineApply(builder, target.getLoc(), s0 * s1, {spec->lowTileSize, spec->lowTripCount}); Operation *lowTileSize = spec->lowTileSize.getDefiningOp(); Operation *highTileSize = spec->highTileSize.getDefiningOp(); assert(lowTileSize && highTileSize && splitPoint && "tile sizes are not produced by operations"); results.reserve(results.size() + 3); results.push_back(lowTileSize); results.push_back(highTileSize); results.push_back(splitPoint); return DiagnosedSilenceableFailure::success(); } void transform::MultiTileSizesOp::getEffects( SmallVectorImpl &effects) { onlyReadsHandle(getTarget(), effects); producesHandle(getResults(), effects); modifiesPayload(effects); } //===---------------------------------------------------------------------===// // PadOp //===---------------------------------------------------------------------===// DiagnosedSilenceableFailure transform::PadOp::applyToOne(linalg::LinalgOp target, transform::ApplyToEachResultList &results, 
transform::TransformState &state) { // Convert the integer packing flags to booleans. SmallVector packPaddings; for (int64_t packPadding : extractFromI64ArrayAttr(getPackPaddings())) packPaddings.push_back(static_cast(packPadding)); // Convert the padding values to attributes. SmallVector paddingValues; for (auto const &it : llvm::zip(getPaddingValues(), target->getOperandTypes())) { auto attr = std::get<0>(it).dyn_cast(); if (!attr) { emitOpError("expects padding values to be typed attributes"); return DiagnosedSilenceableFailure::definiteFailure(); } Type elementType = getElementTypeOrSelf(std::get<1>(it)); // Try to parse string attributes to obtain an attribute of element type. if (auto stringAttr = attr.dyn_cast()) { paddingValues.push_back( parseAttribute(attr.cast(), elementType)); if (!paddingValues.back()) { auto diag = this->emitOpError("expects a padding that parses to ") << elementType << ", got " << std::get<0>(it); diag.attachNote(target.getLoc()) << "when applied to this op"; return DiagnosedSilenceableFailure::definiteFailure(); } continue; } // Otherwise, add the attribute directly. if (attr.getType() != elementType) { auto diag = this->emitOpError("expects a padding value of type ") << elementType << ", got " << attr; diag.attachNote(target.getLoc()) << "when applied to this op"; return DiagnosedSilenceableFailure::definiteFailure(); } paddingValues.push_back(attr); } // Extract the transpose vectors. 
SmallVector> transposePaddings; for (Attribute transposeVector : getTransposePaddings().cast()) transposePaddings.push_back( extractFromI64ArrayAttr(transposeVector.cast())); LinalgPaddingOptions paddingOptions; paddingOptions.setPaddingValues(paddingValues); paddingOptions.setPaddingDimensions( extractFromI64ArrayAttr(getPaddingDimensions())); paddingOptions.setPackPaddings(packPaddings); paddingOptions.setHoistPaddings(extractFromI64ArrayAttr(getHoistPaddings())); paddingOptions.setTransposePaddings(transposePaddings); FailureOr result = tryApply(target, paddingOptions); if (succeeded(result)) { results.push_back(result->getOperation()); return DiagnosedSilenceableFailure::success(); } results.assign(1, nullptr); return emitDefaultSilenceableFailure(target); } LogicalResult transform::PadOp::verify() { SmallVector packPaddings = extractFromI64ArrayAttr(getPackPaddings()); if (any_of(packPaddings, [](int64_t packPadding) { return packPadding != 0 && packPadding != 1; })) { return emitOpError() << "expects pack_paddings to contain booleans (0/1), found " << getPackPaddings(); } SmallVector paddingDimensions = extractFromI64ArrayAttr(getPaddingDimensions()); if (any_of(paddingDimensions, [](int64_t paddingDimension) { return paddingDimension < 0; })) { return emitOpError() << "expects padding_dimensions to contain positive " "integers, found " << getPaddingDimensions(); } SmallVector hoistPaddings = extractFromI64ArrayAttr(getHoistPaddings()); if (any_of(hoistPaddings, [](int64_t hoistPadding) { return hoistPadding < 0; })) { return emitOpError() << "expects hoist_paddings to contain positive integers, found " << getHoistPaddings(); } ArrayAttr transposes = getTransposePaddings(); for (Attribute attr : transposes) { SmallVector transpose = extractFromI64ArrayAttr(attr); auto sequence = llvm::to_vector(llvm::seq(0, transpose.size())); if (!std::is_permutation(sequence.begin(), sequence.end(), transpose.begin(), transpose.end())) { return emitOpError() << "expects 
transpose_paddings to be a permutation, found " << attr; } } return success(); } //===----------------------------------------------------------------------===// // PromoteOp //===----------------------------------------------------------------------===// DiagnosedSilenceableFailure transform::PromoteOp::applyToOne(linalg::LinalgOp target, transform::ApplyToEachResultList &results, transform::TransformState &state) { LinalgPromotionOptions promotionOptions; if (!getOperandsToPromote().empty()) promotionOptions = promotionOptions.setOperandsToPromote( extractFromI64ArrayAttr(getOperandsToPromote())); if (getUseFullTilesByDefault()) promotionOptions = promotionOptions.setUseFullTileBuffersByDefault( getUseFullTilesByDefault()); if (getUseAlloca()) promotionOptions = promotionOptions.setUseAlloca(getUseAlloca()); if (!getUseFullTileBuffers().empty()) promotionOptions = promotionOptions.setUseFullTileBuffers( llvm::to_vector(getUseFullTileBuffers().getAsValueRange())); if (getAlignment().has_value()) promotionOptions = promotionOptions.setAlignment(*getAlignment()); if (failed(promoteSubviewsPrecondition(target, promotionOptions))) return emitDefaultDefiniteFailure(target); TrivialPatternRewriter rewriter(target->getContext()); rewriter.setInsertionPoint(target); FailureOr res = promoteSubViews(rewriter, target, promotionOptions); if (failed(res)) return emitDefaultDefiniteFailure(target); results.push_back(target); return DiagnosedSilenceableFailure::success(); } //===----------------------------------------------------------------------===// // ReplaceOp //===----------------------------------------------------------------------===// DiagnosedSilenceableFailure transform::ReplaceOp::apply(TransformResults &transformResults, TransformState &state) { ArrayRef payload = state.getPayloadOps(getTarget()); // Check for invalid targets. 
for (Operation *target : payload) { if (target->getNumOperands() > 0) return emitDefiniteFailure() << "expected target without operands"; if (!target->hasTrait() && target->getNumRegions() > 0) return emitDefiniteFailure() << "expected target that is isloated from above"; } // Clone and replace. IRRewriter rewriter(getContext()); Operation *pattern = &getBodyRegion().front().front(); SmallVector replacements; for (Operation *target : payload) { if (getOperation()->isAncestor(target)) continue; rewriter.setInsertionPoint(target); Operation *replacement = rewriter.clone(*pattern); rewriter.replaceOp(target, replacement->getResults()); replacements.push_back(replacement); } transformResults.set(getReplacement().cast(), replacements); return DiagnosedSilenceableFailure::success(); } void transform::ReplaceOp::getEffects( SmallVectorImpl &effects) { consumesHandle(getTarget(), effects); producesHandle(getReplacement(), effects); modifiesPayload(effects); } LogicalResult transform::ReplaceOp::verify() { if (!getBodyRegion().hasOneBlock()) return emitOpError() << "expected one block"; if (std::distance(getBodyRegion().front().begin(), getBodyRegion().front().end()) != 1) return emitOpError() << "expected one operation in block"; Operation *replacement = &getBodyRegion().front().front(); if (replacement->getNumOperands() > 0) return replacement->emitOpError() << "expected replacement without operands"; if (!replacement->hasTrait() && replacement->getNumRegions() > 0) return replacement->emitOpError() << "expect op that is isolated from above"; return success(); } //===----------------------------------------------------------------------===// // ScalarizeOp //===----------------------------------------------------------------------===// DiagnosedSilenceableFailure transform::ScalarizeOp::applyToOne(linalg::LinalgOp target, transform::ApplyToEachResultList &results, transform::TransformState &state) { scf::SCFTilingOptions tilingOptions; 
tilingOptions.setTileSizeComputationFunction([&](OpBuilder &b, Operation *) { SmallVector tileSizes; Location loc = target.getLoc(); SmallVector allShapeSizes = target.createFlatListOfOperandDims(b, loc); AffineMap map = target.getShapesToLoopsMap(); if (!map) return tileSizes; IRRewriter rewriter(b); SmallVector shapeSizes = makeComposedFoldedMultiResultAffineApply(rewriter, loc, map, allShapeSizes); // If the shape size is dynamic, tile by 1. // Otherwise, do not tile (i.e. tile size 0). for (OpFoldResult shapeSize : shapeSizes) { tileSizes.push_back(getConstantIntValue(shapeSize) ? b.create(loc, 0) : b.create(loc, 1)); } return tileSizes; }); SmallVector emptyTileSizes; TrivialPatternRewriter rewriter(getContext()); rewriter.setInsertionPoint(target); FailureOr maybeTilingResult = tileUsingSCFForOp( rewriter, cast(target.getOperation()), tilingOptions); if (failed(maybeTilingResult)) return emitDefaultDefiniteFailure(target); if (target->getNumResults()) rewriter.replaceOp(target, maybeTilingResult->replacements); else rewriter.eraseOp(target); results.reserve(maybeTilingResult->tiledOps.size()); for (Operation *tiled : maybeTilingResult->tiledOps) results.push_back(tiled); return DiagnosedSilenceableFailure::success(); } //===----------------------------------------------------------------------===// // SplitOp //===----------------------------------------------------------------------===// DiagnosedSilenceableFailure SplitOp::apply(TransformResults &results, TransformState &state) { // Collect the dynamic split points if provided. 
ArrayRef payload = state.getPayloadOps(getTarget()); TrivialPatternRewriter rewriter(getContext()); SmallVector splitPoints; splitPoints.reserve(payload.size()); if (getDynamicSplitPoint()) { auto diag = DiagnosedSilenceableFailure::success(); splitPoints = llvm::to_vector(llvm::map_range( state.getPayloadOps(getDynamicSplitPoint()), [&](Operation *op) { if (op->getNumResults() != 1 || !op->getResult(0).getType().isIndex()) { diag = emitSilenceableError() << "expected dynamic split point handle to point to a " "single-result index-typed op"; diag.attachNote(op->getLoc()) << "dynamic split point"; } return OpFoldResult(op->getResult(0)); })); if (diag.isSilenceableFailure()) { results.set(getFirst().cast(), {}); results.set(getSecond().cast(), {}); return diag; } if (splitPoints.size() != payload.size()) { return emitDefiniteFailure() << "expected the dynamic split point handle to point to as " "many operations (" << splitPoints.size() << ") as the target handle (" << payload.size() << ")"; } } else { splitPoints.resize(payload.size(), rewriter.getIndexAttr(getStaticSplitPoint())); } // Split each target operation. 
SmallVector first, second; Operation *noSecondPart = nullptr; for (const auto &pair : llvm::zip(payload, splitPoints)) { Operation *target = std::get<0>(pair); auto linalgOp = dyn_cast(target); if (!linalgOp) { auto diag = emitSilenceableError() << "only applies to structured ops"; diag.attachNote(target->getLoc()) << "target op"; results.set(getFirst().cast(), {}); results.set(getSecond().cast(), {}); return diag; } if (getDimension() >= linalgOp.getNumLoops()) { auto diag = emitSilenceableError() << "dimension " << getDimension() << " does not exist in target op"; diag.attachNote(target->getLoc()) << "target op"; results.set(getFirst().cast(), {}); results.set(getSecond().cast(), {}); return diag; } rewriter.setInsertionPoint(linalgOp); std::tie(first.emplace_back(), second.emplace_back()) = linalg::splitOp( rewriter, cast(linalgOp.getOperation()), getDimension(), std::get<1>(pair)); // Propagate errors. if (!first.back() && !second.back()) { auto diag = emitDefiniteFailure() << "internal failure in splitting"; diag.attachNote(target->getLoc()) << "target op"; return diag; } // Do not add null second parts. 
if (!second.back()) { noSecondPart = target; second.pop_back(); } } if (second.size() != first.size() && !second.empty()) { results.set(getFirst().cast(), {}); results.set(getSecond().cast(), {}); auto diag = emitSilenceableError() << "splitting does not produce the second part for a subset of targets"; diag.attachNote() << "expected splitting to produce the second part of all " "or none of the targets"; diag.attachNote(noSecondPart->getLoc()) << "first target with no second part"; return diag; } results.set(getFirst().cast(), first); results.set(getSecond().cast(), second); return DiagnosedSilenceableFailure::success(); } void SplitOp::getEffects( SmallVectorImpl &effects) { consumesHandle(getTarget(), effects); if (getDynamicSplitPoint()) onlyReadsHandle(getDynamicSplitPoint(), effects); producesHandle(getResults(), effects); modifiesPayload(effects); } ParseResult SplitOp::parse(OpAsmParser &parser, OperationState &result) { OpAsmParser::UnresolvedOperand target, dynamicSplitPoint; IntegerAttr staticSplitPoint; auto pdlOperationType = pdl::OperationType::get(parser.getBuilder().getContext()); if (parser.parseOperand(target) || parser.resolveOperand(target, pdlOperationType, result.operands) || parser.parseKeyword("after")) return failure(); OptionalParseResult dynamicPointParseResult = parser.parseOptionalOperand(dynamicSplitPoint); if (!dynamicPointParseResult.has_value()) { int64_t staticSplitPointValue; if (failed(parser.parseInteger(staticSplitPointValue))) return failure(); staticSplitPoint = parser.getBuilder().getI64IntegerAttr(staticSplitPointValue); } else { if (failed(*dynamicPointParseResult) || parser.resolveOperand(dynamicSplitPoint, pdlOperationType, result.operands)) { return failure(); } staticSplitPoint = parser.getBuilder().getI64IntegerAttr(ShapedType::kDynamic); } result.addAttribute( SplitOp::getStaticSplitPointAttrName(result.name).getValue(), staticSplitPoint); if (failed(parser.parseOptionalAttrDict(result.attributes))) return failure(); 
result.addTypes({pdlOperationType, pdlOperationType}); return success(); } void SplitOp::print(OpAsmPrinter &printer) { printer << " " << getTarget() << " after "; int64_t staticSplitSize = static_cast(getStaticSplitPoint()); if (staticSplitSize != ShapedType::kDynamic) printer << staticSplitSize; else printer << getDynamicSplitPoint(); printer << " "; printer.printOptionalAttrDict(getOperation()->getAttrs(), {getStaticSplitPointAttrName()}); } LogicalResult SplitOp::verify() { if ((static_cast(getStaticSplitPoint()) != ShapedType::kDynamic) ^ (getDynamicSplitPoint() == nullptr)) { return emitOpError() << "expects either a dynamic or a static split " "point to be provided"; } return success(); } //===----------------------------------------------------------------------===// // SplitReductionOp //===----------------------------------------------------------------------===// void transform::SplitReductionOp::build( OpBuilder &builder, OperationState &result, Value target, int64_t splitFactor, int64_t insertSplitDimension, bool innerParallel, bool useScalingAlgorithm, bool useAlloc) { MLIRContext *ctx = builder.getContext(); result.addOperands(target); result.addAttribute(SplitReductionOp::getSplitFactorAttrName(result.name), builder.getI64IntegerAttr(splitFactor)); result.addAttribute( SplitReductionOp::getInsertSplitDimensionAttrName(result.name), builder.getI64IntegerAttr(insertSplitDimension)); if (innerParallel) { result.addAttribute(SplitReductionOp::getInnerParallelAttrName(result.name), builder.getUnitAttr()); } if (useScalingAlgorithm) { result.addAttribute( SplitReductionOp::getUseScalingAlgorithmAttrName(result.name), builder.getUnitAttr()); } if (useAlloc) { result.addAttribute(SplitReductionOp::getUseAllocAttrName(result.name), builder.getUnitAttr()); } auto resultType = pdl::OperationType::get(ctx); result.addTypes({resultType, resultType, resultType, resultType}); } DiagnosedSilenceableFailure transform::SplitReductionOp::applyToOne( linalg::LinalgOp 
target, transform::ApplyToEachResultList &results, transform::TransformState &state) { ControlSplitReductionFn splitFn = [&](LinalgOp) { return linalg::SplitReductionOptions{int64_t(getSplitFactor()), unsigned(getInsertSplitDimension()), bool(getInnerParallel())}; }; TrivialPatternRewriter rewriter(getContext()); rewriter.setInsertionPoint(target); FailureOr splitResult = (getUseScalingAlgorithm()) ? splitReductionByScaling(rewriter, target, splitFn, getUseAlloc()) : splitReduction(rewriter, target, splitFn, getUseAlloc()); if (failed(splitResult)) return emitDefaultDefiniteFailure(target); results.push_back(splitResult->initOrAlloc); results.push_back(splitResult->fillOp); results.push_back(splitResult->splitLinalgOp); results.push_back(splitResult->resultCombiningLinalgOp); return DiagnosedSilenceableFailure::success(); } //===----------------------------------------------------------------------===// // TileReductionUsingScfOp //===----------------------------------------------------------------------===// void transform::TileReductionUsingScfOp::build( OpBuilder &builder, OperationState &result, Value target, ArrayRef staticTileSizes) { // Call the default builder. // This is future-proof re mixed static-dynamic and setting up the proper // operands segment sizes attributes for multiple variadic operands. // In the absence of this, horrible bugs ensue. // TODO: support mixed static-dynamic (see TileToForeachThreadOp). 
MLIRContext *ctx = builder.getContext(); auto opTy = pdl::OperationType::get(ctx); auto staticTileSizesAttr = builder.getDenseI64ArrayAttr(staticTileSizes); build(builder, result, /*resultTypes=*/TypeRange{opTy, opTy, opTy, opTy}, /*target=*/target, /*tile_sizes=*/staticTileSizesAttr); } DiagnosedSilenceableFailure transform::TileReductionUsingScfOp::applyToOne( linalg::LinalgOp target, transform::ApplyToEachResultList &results, transform::TransformState &state) { TrivialPatternRewriter rewriter(getContext()); rewriter.setInsertionPoint(target); FailureOr result = scf::tileReductionUsingScf( rewriter, cast(target.getOperation()), getAsOpFoldResult(rewriter.getI64ArrayAttr(getTileSizes()))); if (failed(result)) return emitDefaultSilenceableFailure(target); results.push_back(result->loops.front()); results.push_back(result->initialOp); results.push_back(result->parallelTiledOp); results.push_back(result->mergeOp); return DiagnosedSilenceableFailure::success(); } //===----------------------------------------------------------------------===// // TileReductionUsingForeachThreadOp //===----------------------------------------------------------------------===// void transform::TileReductionUsingForeachThreadOp::build( OpBuilder &builder, OperationState &result, Value target, ArrayRef staticNumThreads, ArrayRef staticTileSizes, ArrayAttr mapping) { // Call the default builder. // This is future-proof re mixed static-dynamic and setting up the proper // operands segment sizes attributes for multiple variadic operands. // In the absence of this, horrible bugs ensue. // TODO: support mixed static-dynamic (see TileToForeachThreadOp). 
MLIRContext *ctx = builder.getContext(); auto opTy = pdl::OperationType::get(ctx); auto staticNumThreadsAttr = builder.getDenseI64ArrayAttr(staticNumThreads); auto staticTileSizesAttr = builder.getDenseI64ArrayAttr(staticTileSizes); build(builder, result, /*resultTypes=*/TypeRange{opTy, opTy, opTy, opTy}, /*target=*/target, /*num_threads=*/staticNumThreadsAttr, /*tile_sizes=*/staticTileSizesAttr, /*mapping=*/mapping); } DiagnosedSilenceableFailure transform::TileReductionUsingForeachThreadOp::applyToOne( linalg::LinalgOp target, transform::ApplyToEachResultList &results, transform::TransformState &state) { TrivialPatternRewriter rewriter(getContext()); rewriter.setInsertionPoint(target); SmallVector numThreads = getAsOpFoldResult(rewriter.getI64ArrayAttr(getNumThreads())); SmallVector tileSizes = getAsOpFoldResult(rewriter.getI64ArrayAttr(getTileSizes())); FailureOr result = linalg::tileReductionUsingForeachThread( rewriter, cast(target.getOperation()), numThreads, tileSizes, getMapping()); if (failed(result)) { results.assign(4, nullptr); auto diag = emitSilenceableError() << "could not tile reduction"; diag.attachNote(target.getLoc()) << "target operation"; return diag; } results.push_back(result->loops); results.push_back(result->initialOp); results.push_back(result->parallelTiledOp); results.push_back(result->mergeOp); return DiagnosedSilenceableFailure::success(); } //===----------------------------------------------------------------------===// // TileOp //===----------------------------------------------------------------------===// void transform::TileOp::build(OpBuilder &builder, OperationState &result, Value target, ArrayRef staticTileSizes, ArrayRef interchange) { return build(builder, result, /*target=*/target, /*mixedTileSizes=*/ getAsOpFoldResult(builder.getI64ArrayAttr(staticTileSizes)), interchange); } void transform::TileOp::build(OpBuilder &builder, OperationState &result, Value target, ArrayRef mixedTileSizes, ArrayRef interchange) { SmallVector 
staticTileSizes; SmallVector dynamicTileSizes; dispatchIndexOpFoldResults(mixedTileSizes, dynamicTileSizes, staticTileSizes); // Call the default builder which sets up the proper operands segment sizes // attributes for multiple variadic operands. In the absence of this, horrible // bugs ensue. MLIRContext *ctx = builder.getContext(); auto operationType = pdl::OperationType::get(ctx); auto staticTileSizesAttr = builder.getDenseI64ArrayAttr(staticTileSizes); build(builder, result, /*resultTypes=*/TypeRange{operationType, operationType}, /*target=*/target, /*dynamic_sizes=*/dynamicTileSizes, /*static_sizes=*/staticTileSizesAttr, /*interchange=*/builder.getDenseI64ArrayAttr(interchange)); } DiagnosedSilenceableFailure transform::TileOp::apply(TransformResults &transformResults, TransformState &state) { ArrayRef tileSizes = getStaticSizes(); ArrayRef targets = state.getPayloadOps(getTarget()); SmallVector> dynamicSizeProducers; dynamicSizeProducers.reserve(getDynamicSizes().size()); for (Value dynamicSizeProducerHandle : getDynamicSizes()) { dynamicSizeProducers.push_back( state.getPayloadOps(dynamicSizeProducerHandle)); if (dynamicSizeProducers.back().size() != targets.size()) { DiagnosedSilenceableFailure diag = emitSilenceableError() << "expected as many dynamic size-producing operations (" << dynamicSizeProducers.back().size() << ") as target ops (" << targets.size() << ")"; diag.attachNote(dynamicSizeProducerHandle.getLoc()) << "for this handle"; return diag; } for (Operation *op : dynamicSizeProducers.back()) { if (op->getNumResults() == 1 && op->getResult(0).getType().isa()) continue; DiagnosedSilenceableFailure diag = emitSilenceableError() << "expected sizes to be produced by ops " "with a single index-type result"; diag.attachNote(op->getLoc()) << "size producer op"; diag.attachNote(dynamicSizeProducerHandle.getLoc()) << "for this handle"; return diag; } } SmallVector tiled; SmallVector, 4> loops; loops.resize(getLoops().size()); for (auto &en : 
llvm::enumerate(targets)) { auto linalgOp = dyn_cast(en.value()); if (!linalgOp) { DiagnosedSilenceableFailure diag = emitSilenceableError() << "only linalg ops are supported"; diag.attachNote(en.value()->getLoc()) << "target op"; return diag; } scf::SCFTilingOptions tilingOptions; unsigned index = en.index(); if (!tileSizes.empty()) { tilingOptions.setTileSizeComputationFunction( [&, index](OpBuilder &b, Operation *) { SmallVector sizes; sizes.reserve(tileSizes.size()); unsigned dynamicIdx = 0; for (OpFoldResult ofr : getMixedSizes()) { if (auto attr = ofr.dyn_cast()) { sizes.push_back(b.create( getLoc(), attr.cast().getInt())); } else { sizes.push_back( dynamicSizeProducers[dynamicIdx++][index]->getResult(0)); } } return sizes; }); } tilingOptions.setInterchange(getInterchange()); TrivialPatternRewriter rewriter(linalgOp.getContext()); FailureOr maybeTilingResult = tileUsingSCFForOp( rewriter, cast(linalgOp.getOperation()), tilingOptions); if (failed(maybeTilingResult)) return DiagnosedSilenceableFailure::definiteFailure(); if (linalgOp.hasBufferSemantics()) rewriter.eraseOp(linalgOp); else rewriter.replaceOp(linalgOp, maybeTilingResult->loops.front()->getResults()); tiled.append(maybeTilingResult->tiledOps); for (const auto &en2 : llvm::enumerate(maybeTilingResult->loops)) loops[en2.index()].push_back(en2.value()); } transformResults.set(getTiledLinalgOp().cast(), tiled); for (const auto &en : llvm::enumerate(loops)) transformResults.set(getLoops()[en.index()].cast(), en.value()); return DiagnosedSilenceableFailure::success(); } SmallVector transform::TileOp::getMixedSizes() { ValueRange dynamic = getDynamicSizes(); ArrayRef tileSizes = getStaticSizes(); SmallVector results; results.reserve(tileSizes.size()); unsigned dynamicPos = 0; Builder builder(getContext()); for (int64_t size : tileSizes) { if (size == ShapedType::kDynamic) { results.push_back(dynamic[dynamicPos++]); } else { results.push_back(builder.getIndexAttr(size)); } } return results; } // We want 
to parse `DenseI64ArrayAttr` using the short form without the // `array` prefix to be consistent in the IR with `parseDynamicIndexList`. ParseResult parseOptionalInterchange(OpAsmParser &parser, OperationState &result) { if (succeeded(parser.parseOptionalLBrace())) { if (failed(parser.parseKeyword("interchange"))) return parser.emitError(parser.getNameLoc()) << "expect `interchange`"; if (failed(parser.parseEqual())) return parser.emitError(parser.getNameLoc()) << "expect `=`"; result.addAttribute("interchange", DenseI64ArrayAttr::parse(parser, Type{})); if (failed(parser.parseRBrace())) return parser.emitError(parser.getNameLoc()) << "expect `}`"; } return success(); } void printOptionalInterchange(OpAsmPrinter &p, ArrayRef interchangeVals) { if (!interchangeVals.empty()) { p << " {interchange = ["; llvm::interleaveComma(interchangeVals, p, [&](int64_t integer) { p << integer; }); p << "]}"; } } ParseResult transform::TileOp::parse(OpAsmParser &parser, OperationState &result) { OpAsmParser::UnresolvedOperand target; SmallVector dynamicSizes; DenseI64ArrayAttr staticSizes; auto pdlOperationType = pdl::OperationType::get(parser.getContext()); if (parser.parseOperand(target) || parser.resolveOperand(target, pdlOperationType, result.operands) || parseDynamicIndexList(parser, dynamicSizes, staticSizes) || parser.resolveOperands(dynamicSizes, pdlOperationType, result.operands)) return ParseResult::failure(); // Parse optional interchange. 
if (failed(parseOptionalInterchange(parser, result))) return ParseResult::failure(); result.addAttribute(getStaticSizesAttrName(result.name), staticSizes); size_t numExpectedLoops = staticSizes.size() - llvm::count(staticSizes.asArrayRef(), 0); result.addTypes(SmallVector(numExpectedLoops + 1, pdlOperationType)); return success(); } void TileOp::print(OpAsmPrinter &p) { p << ' ' << getTarget(); printDynamicIndexList(p, getOperation(), getDynamicSizes(), getStaticSizes()); printOptionalInterchange(p, getInterchange()); } void transform::TileOp::getEffects( SmallVectorImpl &effects) { consumesHandle(getTarget(), effects); onlyReadsHandle(getDynamicSizes(), effects); producesHandle(getTiledLinalgOp(), effects); producesHandle(getLoops(), effects); modifiesPayload(effects); } //===----------------------------------------------------------------------===// // TileToForeachThreadOp //===----------------------------------------------------------------------===// void transform::TileToForeachThreadOp::build(OpBuilder &builder, OperationState &result, Value target, ArrayRef staticTileSizes, transform::TileSizesSpec, ArrayAttr mapping) { return build(builder, result, /*target=*/target, /*mixedTileSizes=*/ getAsOpFoldResult(builder.getI64ArrayAttr(staticTileSizes)), /*_=*/TileSizesSpec(), /*mapping=*/mapping); } void transform::TileToForeachThreadOp::build( OpBuilder &builder, OperationState &result, Value target, ArrayRef mixedTileSizes, transform::TileSizesSpec, ArrayAttr mapping) { SmallVector staticTileSizes; SmallVector dynamicTileSizes; dispatchIndexOpFoldResults(mixedTileSizes, dynamicTileSizes, staticTileSizes); // Call the default builder which sets up the proper operands segment sizes // attributes for multiple variadic operands. In the absence of this, horrible // bugs ensue. 
MLIRContext *ctx = builder.getContext(); auto operationType = pdl::OperationType::get(ctx); auto staticTileSizesAttr = builder.getDenseI64ArrayAttr(staticTileSizes); build(builder, result, /*resultTypes=*/TypeRange{operationType, operationType}, /*target=*/target, /*num_threads=*/ValueRange{}, /*tile_sizes=*/dynamicTileSizes, /*packed_num_threads=*/Value(), /*packed_tile_sizes=*/Value(), /*static_num_threads=*/builder.getDenseI64ArrayAttr({}), /*static_tile_sizes=*/staticTileSizesAttr, /*mapping=*/mapping); } void transform::TileToForeachThreadOp::build(OpBuilder &builder, OperationState &result, Value target, ArrayRef staticNumThreads, transform::NumThreadsSpec, ArrayAttr mapping) { return build(builder, result, target, getAsOpFoldResult(builder.getI64ArrayAttr(staticNumThreads)), NumThreadsSpec(), mapping); } void transform::TileToForeachThreadOp::build( OpBuilder &builder, OperationState &result, Value target, ArrayRef mixedNumThreads, transform::NumThreadsSpec, ArrayAttr mapping) { SmallVector staticNumThreads; SmallVector dynamicNumThreads; dispatchIndexOpFoldResults(mixedNumThreads, dynamicNumThreads, staticNumThreads); // Call the default builder which sets up the proper operands segment sizes // attributes for multiple variadic operands. In the absence of this, horrible // bugs ensue. MLIRContext *ctx = builder.getContext(); auto operationType = pdl::OperationType::get(ctx); auto staticNumThreadsAttr = builder.getDenseI64ArrayAttr(staticNumThreads); build(builder, result, /*resultTypes=*/TypeRange{operationType, operationType}, /*target=*/target, /*num_threads=*/dynamicNumThreads, /*tile_sizes=*/ValueRange{}, /*packed_num_threads=*/Value(), /*packed_tile_sizes=*/Value(), /*static_num_threads=*/staticNumThreadsAttr, /*static_tile_sizes=*/builder.getDenseI64ArrayAttr({}), /*mapping=*/mapping); } /// Assuming that `ofr` is an index attr or a transform dialect handle mapped /// to exactly one op with one index result, return that value. 
static DiagnosedSilenceableFailure unpackPDLOperations(
    transform::TransformState &state, TransformOpInterface transformOp,
    SmallVector<OpFoldResult> &result, ArrayRef<OpFoldResult> ofrs) {
  for (OpFoldResult entry : ofrs) {
    // Static entries only need to be validated before being forwarded.
    if (entry.is<Attribute>()) {
      Attribute staticValue = entry.get<Attribute>();
      if (!staticValue.isa<IntegerAttr>())
        return transformOp.emitDefiniteFailure() << "expected IntegerAttr";
      result.push_back(entry);
      continue;
    }

    // Dynamic entries are handles that must resolve to a single payload op.
    Value handle = entry.get<Value>();
    ArrayRef<Operation *> mapped = state.getPayloadOps(handle);
    if (mapped.size() != 1) {
      DiagnosedSilenceableFailure diag =
          transformOp.emitSilenceableError()
          << "handle must be mapped to exactly one payload op";
      diag.attachNote(handle.getLoc())
          << "mapped to " << mapped.size() << " payload ops";
      return diag;
    }

    // That payload op must produce exactly one index-typed value.
    Operation *producer = mapped.front();
    bool hasSingleIndexResult = producer->getNumResults() == 1 &&
                                producer->getResult(0).getType().isIndex();
    if (!hasSingleIndexResult) {
      DiagnosedSilenceableFailure diag =
          transformOp.emitSilenceableError()
          << "payload op must have exactly 1 index result";
      diag.attachNote(producer->getLoc())
          << "has " << producer->getNumResults() << " results";
      return diag;
    }
    result.push_back(producer->getResult(0));
  }
  return DiagnosedSilenceableFailure::success();
}

// Appends to `result` the single index-typed result of every payload op
// associated with `packedHandle`, in payload order. Reports a silenceable
// failure if any of those ops does not have exactly one index result.
static DiagnosedSilenceableFailure unpackPDLOperations(transform::TransformState &state, TransformOpInterface transformOp, SmallVector &result, Value packedHandle) { ArrayRef payloadOps = state.getPayloadOps(packedHandle); for (Operation *op : payloadOps) { if (op->getNumResults() != 1 || !op->getResult(0).getType().isIndex()) { DiagnosedSilenceableFailure diag = transformOp.emitSilenceableError() << "payload op must have exactly 1 index result"; diag.attachNote(op->getLoc()) << "has " << op->getNumResults() << " results"; return diag; } result.push_back(op->getResult(0)); } return DiagnosedSilenceableFailure::success(); } DiagnosedSilenceableFailure transform::tileToForeachThreadOpImpl( RewriterBase &rewriter, transform::TransformState &state, TransformOpInterface transformOp, ArrayRef targets, ArrayRef mixedNumThreads, ArrayRef mixedTileSizes, std::optional mapping, SmallVector &tileOps, SmallVector &tiledOps) { if (targets.empty()) return DiagnosedSilenceableFailure::success(); // Transform all targets one by one. 
for (Operation *target : targets) { auto tilableOp = dyn_cast(target); if (!tilableOp) { DiagnosedSilenceableFailure diag = transformOp.emitSilenceableError() << "only TilingInterface ops are supported"; diag.attachNote(target->getLoc()) << "target op"; return diag; } rewriter.setInsertionPoint(tilableOp); FailureOr tilingResult = failure(); if (!mixedNumThreads.empty()) { tilingResult = linalg::tileToForeachThreadOp(rewriter, tilableOp, mixedNumThreads, mapping); } else { tilingResult = linalg::tileToForeachThreadOpUsingTileSizes( rewriter, tilableOp, mixedTileSizes, mapping); } if (failed(tilingResult)) return transformOp.emitDefaultSilenceableFailure(tilableOp); rewriter.replaceOp(tilableOp, tilingResult->tileOp->getResults()); tileOps.push_back(tilingResult->tileOp); tiledOps.push_back(tilingResult->tiledOp); } return DiagnosedSilenceableFailure::success(); } DiagnosedSilenceableFailure transform::TileToForeachThreadOp::apply( transform::TransformResults &transformResults, transform::TransformState &state) { IRRewriter rewriter(getContext()); auto transformOp = cast(getOperation()); ArrayRef targets = state.getPayloadOps(getTarget()); // Result payload ops. SmallVector tileOps; SmallVector tiledOps; // Unpack handles. SmallVector mixedNumThreads; DiagnosedSilenceableFailure status = getPackedNumThreads() ? unpackPDLOperations(state, transformOp, mixedNumThreads, getPackedNumThreads()) : unpackPDLOperations(state, transformOp, mixedNumThreads, getMixedNumThreads()); if (!status.succeeded()) return status; SmallVector mixedTileSizes; status = getPackedTileSizes() ? 
unpackPDLOperations(state, transformOp, mixedTileSizes, getPackedTileSizes()) : unpackPDLOperations(state, transformOp, mixedTileSizes, getMixedTileSizes()); if (!status.succeeded()) return status; DiagnosedSilenceableFailure diag = tileToForeachThreadOpImpl( rewriter, state, transformOp, targets, mixedNumThreads, mixedTileSizes, getMapping(), tileOps, tiledOps); if (!diag.succeeded()) { transformResults.set(getForeachThreadOp().cast(), {}); transformResults.set(getTiledOp().cast(), {}); return diag; } transformResults.set(getForeachThreadOp().cast(), tileOps); transformResults.set(getTiledOp().cast(), tiledOps); return DiagnosedSilenceableFailure::success(); } void transform::TileToForeachThreadOp::getEffects( SmallVectorImpl &effects) { consumesHandle(getTarget(), effects); onlyReadsHandle(getTileSizes(), effects); onlyReadsHandle(getNumThreads(), effects); producesHandle(getResults(), effects); } SmallVector TileToForeachThreadOp::getMixedNumThreads() { Builder b(getContext()); return getMixedValues(getStaticNumThreads(), getNumThreads(), b); } SmallVector TileToForeachThreadOp::getMixedTileSizes() { Builder b(getContext()); return getMixedValues(getStaticTileSizes(), getTileSizes(), b); } LogicalResult TileToForeachThreadOp::verify() { int numThreadsSpec = static_cast(!getMixedNumThreads().empty()) + static_cast(getPackedNumThreads() != Value()); if (numThreadsSpec > 1) return emitOpError( "num_threads and packed_num_threads are mutually exclusive"); int tileSizesSpec = static_cast(!getMixedTileSizes().empty()) + static_cast(getPackedTileSizes() != Value()); if (tileSizesSpec > 1) return emitOpError( "tile_sizes and packed_tile_sizes are mutually exclusive"); if (numThreadsSpec == 0 && tileSizesSpec == 0) return emitOpError( "either (packed_)num_threads or (packed_)tile_sizes must be specified"); return success(); } //===----------------------------------------------------------------------===// // TileToScfForOp 
//===----------------------------------------------------------------------===// DiagnosedSilenceableFailure transform::TileToScfForOp::apply(TransformResults &transformResults, TransformState &state) { ArrayRef tileSizes = getStaticSizes(); ArrayRef targets = state.getPayloadOps(getTarget()); SmallVector> dynamicSizeProducers; dynamicSizeProducers.reserve(getDynamicSizes().size()); for (Value dynamicSizeProducerHandle : getDynamicSizes()) { dynamicSizeProducers.push_back( state.getPayloadOps(dynamicSizeProducerHandle)); if (dynamicSizeProducers.back().size() != targets.size()) { DiagnosedSilenceableFailure diag = emitSilenceableError() << "expected as many dynamic size-producing operations (" << dynamicSizeProducers.back().size() << ") as target ops (" << targets.size() << ")"; diag.attachNote(dynamicSizeProducerHandle.getLoc()) << "for this handle"; return diag; } for (Operation *op : dynamicSizeProducers.back()) { if (op->getNumResults() == 1 && op->getResult(0).getType().isa()) continue; DiagnosedSilenceableFailure diag = emitSilenceableError() << "expected sizes to be produced by ops " "with a single index-type result"; diag.attachNote(op->getLoc()) << "size producer op"; diag.attachNote(dynamicSizeProducerHandle.getLoc()) << "for this handle"; return diag; } } SmallVector tiled; SmallVector, 4> loops; loops.resize(getLoops().size()); for (auto &en : llvm::enumerate(targets)) { auto tilingInterfaceOp = dyn_cast(en.value()); if (!tilingInterfaceOp) { DiagnosedSilenceableFailure diag = emitSilenceableError() << "only TilingInterface ops are supported"; diag.attachNote(en.value()->getLoc()) << "target op"; return diag; } scf::SCFTilingOptions tilingOptions; unsigned index = en.index(); if (!tileSizes.empty()) { tilingOptions.setTileSizeComputationFunction( [&, index](OpBuilder &b, Operation *) { SmallVector sizes; sizes.reserve(tileSizes.size()); unsigned dynamicIdx = 0; for (OpFoldResult ofr : getMixedSizes()) { if (auto attr = ofr.dyn_cast()) { 
sizes.push_back(b.create( getLoc(), attr.cast().getInt())); } else { sizes.push_back( dynamicSizeProducers[dynamicIdx++][index]->getResult(0)); } } return sizes; }); } tilingOptions.setInterchange(getInterchange()); TrivialPatternRewriter rewriter(tilingInterfaceOp.getContext()); FailureOr tilingResult = tileUsingSCFForOp(rewriter, tilingInterfaceOp, tilingOptions); if (failed(tilingResult)) return DiagnosedSilenceableFailure::definiteFailure(); rewriter.replaceOp(tilingInterfaceOp, tilingResult->replacements); tiled.append(tilingResult->tiledOps); for (const auto &en2 : llvm::enumerate(tilingResult->loops)) loops[en2.index()].push_back(en2.value()); } transformResults.set(getTiledLinalgOp().cast(), tiled); for (const auto &en : llvm::enumerate(loops)) transformResults.set(getLoops()[en.index()].cast(), en.value()); return DiagnosedSilenceableFailure::success(); }

/// Returns the tile sizes as a list that mixes SSA values and index
/// attributes, in the order of the static sizes. Every `ShapedType::kDynamic`
/// placeholder is replaced by the next dynamic size operand. Every other entry
/// becomes an index attribute holding the static value.
SmallVector transform::TileToScfForOp::getMixedSizes() {
  ValueRange dynamic = getDynamicSizes();
  ArrayRef tileSizes = getStaticSizes();
  SmallVector results;
  results.reserve(tileSizes.size());
  // Cursor into `dynamic`; advanced once per kDynamic placeholder.
  unsigned dynamicPos = 0;
  Builder builder(getContext());
  for (int64_t size : tileSizes) {
    if (size == ShapedType::kDynamic) {
      results.push_back(dynamic[dynamicPos++]);
    } else {
      results.push_back(builder.getIndexAttr(size));
    }
  }
  return results;
}

/// Custom parser. Syntax: the target operand, then a dynamic index list of
/// tile sizes (static integers mixed with SSA operands), then an optional
/// interchange clause. The target and the dynamic size operands are both
/// resolved with the PDL operation type. All results also use the PDL
/// operation type: one for the tiled op handle plus one per expected loop.
ParseResult transform::TileToScfForOp::parse(OpAsmParser &parser,
                                              OperationState &result) {
  OpAsmParser::UnresolvedOperand target;
  SmallVector dynamicSizes;
  DenseI64ArrayAttr staticSizes;
  auto pdlOperationType = pdl::OperationType::get(parser.getContext());
  if (parser.parseOperand(target) ||
      parser.resolveOperand(target, pdlOperationType, result.operands) ||
      parseDynamicIndexList(parser, dynamicSizes, staticSizes) ||
      parser.resolveOperands(dynamicSizes, pdlOperationType, result.operands))
    return ParseResult::failure();

  // Parse optional interchange.
  if (failed(parseOptionalInterchange(parser, result)))
    return ParseResult::failure();
  result.addAttribute(getStaticSizesAttrName(result.name), staticSizes);
  // Static sizes equal to 0 are not counted as loops (presumably 0 means
  // "do not tile this dimension"). Dynamic placeholders do count.
  size_t numExpectedLoops =
      staticSizes.size() - llvm::count(staticSizes.asArrayRef(), 0);
  // +1 for the tiled-op handle that precedes/accompanies the loop handles.
  result.addTypes(SmallVector(numExpectedLoops + 1, pdlOperationType));
  return success();
}

/// Custom printer mirroring `parse`: target, tile-size list, then the
/// optional interchange clause.
void TileToScfForOp::print(OpAsmPrinter &p) {
  p << ' ' << getTarget();
  printDynamicIndexList(p, getOperation(), getDynamicSizes(), getStaticSizes());
  printOptionalInterchange(p, getInterchange());
}

/// Declares the transform's effects:
/// - The target handle is consumed, because the payload op is replaced by the
///   tiled loop nest.
/// - The dynamic size handles are only read.
/// - Handles to the tiled op and to the generated loops are produced.
/// - The payload IR is modified.
void transform::TileToScfForOp::getEffects(
    SmallVectorImpl &effects) {
  consumesHandle(getTarget(), effects);
  onlyReadsHandle(getDynamicSizes(), effects);
  producesHandle(getTiledLinalgOp(), effects);
  producesHandle(getLoops(), effects);
  modifiesPayload(effects);
}

//===----------------------------------------------------------------------===//
// VectorizeOp
//===----------------------------------------------------------------------===//

/// Builds a VectorizeOp operating on `target`. The VectorizePadding and
/// VectorizeNdExtract unit attributes are attached only when the matching flag
/// is set. The single result has the PDL operation type.
void transform::VectorizeOp::build(OpBuilder &builder, OperationState &result,
                                   Value target, bool vectorizePadding,
                                   bool vectorizeExtract) {
  result.addOperands(target);
  if (vectorizePadding) {
    result.addAttribute(VectorizeOp::getVectorizePaddingAttrName(result.name),
                        builder.getUnitAttr());
  }
  if (vectorizeExtract) {
    result.addAttribute(VectorizeOp::getVectorizeNdExtractAttrName(result.name),
                        builder.getUnitAttr());
  }
  result.addTypes(pdl::OperationType::get(builder.getContext()));
}

namespace {
/// This is a helper only to call vectorize via a pattern inside of
/// VectorizeOp::applyToOne.
/// It matches any op at benefit 1. Non-Linalg ops are rejected with a match
/// failure. Linalg ops are vectorized with no explicit input vector sizes.
struct VectorizationPattern : public RewritePattern {
  explicit VectorizationPattern(MLIRContext *context,
                                bool vectorizeExtract = false)
      : RewritePattern(MatchAnyOpTypeTag(), /*benefit=*/1, context),
        vectorizeNDExtract(vectorizeExtract) {}
  LogicalResult matchAndRewrite(Operation *op,
                                PatternRewriter &rewriter) const override {
    LinalgOp linalgOp = dyn_cast(op);
    if (!linalgOp)
      return rewriter.notifyMatchFailure(op, "expected Linalg Op");
    return vectorize(rewriter, linalgOp, /*inputVectorSizes=*/{},
                     vectorizeNDExtract);
  }

private:
  /// Controls whether to vectorize `tensor.extract` when the input tensor is
  /// rank >= 2.
  bool vectorizeNDExtract = false;
};
} // namespace

/// Greedily applies the vectorization pattern and a set of vector
/// cleanup/lowering patterns to the ops inside `target`.
/// - `target` must be isolated from above; otherwise an error is emitted and a
///   definite failure is returned.
/// - If the greedy driver reports failure, a default definite failure is
///   returned.
/// - On success, `target` itself is the single result.
DiagnosedSilenceableFailure
transform::VectorizeOp::applyToOne(Operation *target,
                                   transform::ApplyToEachResultList &results,
                                   transform::TransformState &state) {
  if (!target->hasTrait()) {
    auto diag = this->emitOpError("requires isolated-from-above targets");
    diag.attachNote(target->getLoc()) << "non-isolated target";
    return DiagnosedSilenceableFailure::definiteFailure();
  }

  MLIRContext *ctx = getContext();
  RewritePatternSet patterns(ctx);
  // The vectorization pattern itself, honoring the nd-extract flag.
  patterns.add(ctx, getVectorizeNdExtract());

  // Optional vector lowerings, each controlled by a disable_* flag on the op.
  if (!getDisableTransferPermutationMapLoweringPatterns())
    vector::populateVectorTransferPermutationMapLoweringPatterns(patterns);

  if (!getDisableMultiReductionToContractPatterns())
    vector::populateVectorReductionToContractPatterns(patterns);

  patterns.add(ctx,
               /*benefit=*/2);
  vector::TransferReadOp::getCanonicalizationPatterns(patterns, ctx);
  vector::TransferWriteOp::getCanonicalizationPatterns(patterns, ctx);

  patterns.add(ctx);

  if (getVectorizePadding())
    linalg::populatePadOpVectorizationPatterns(patterns);

  if (failed(applyPatternsAndFoldGreedily(target, std::move(patterns))))
    return emitDefaultDefiniteFailure(target);

  results.push_back(target);
  return DiagnosedSilenceableFailure::success();
}

//===----------------------------------------------------------------------===//
// MaskedVectorizeOp
//===----------------------------------------------------------------------===// DiagnosedSilenceableFailure transform::MaskedVectorizeOp::apply( mlir::transform::TransformResults &transformResults, mlir::transform::TransformState &state) { IRRewriter rewriter(getContext()); ArrayRef targets = state.getPayloadOps(getTarget()); if (targets.empty()) return DiagnosedSilenceableFailure::success(); SmallVector vectorSizes; for (OpFoldResult sz : getMixedVectorSizes()) { if (sz.is()) { auto attr = sz.get(); vectorSizes.push_back(attr.cast().getInt()); continue; } ArrayRef szPayloads = state.getPayloadOps(sz.get()); if (szPayloads.size() != 1) { auto diag = this->emitOpError( "requires vector size handle that is mapped to 1 payload op"); diag.attachNote(sz.get().getLoc()) << "mapped to " << szPayloads.size() << " payload ops"; return DiagnosedSilenceableFailure::definiteFailure(); } Operation *szPayloadOp = szPayloads[0]; if (szPayloadOp->getNumResults() != 1 || !szPayloadOp->getResult(0).getType().isIndex()) { auto diag = this->emitOpError( "requires vector size payload op with 1 index result"); diag.attachNote(szPayloadOp->getLoc()) << "vector size payload op"; return DiagnosedSilenceableFailure::definiteFailure(); } IntegerAttr attr; if (!matchPattern(szPayloadOp->getResult(0), m_Constant(&attr))) { auto diag = this->emitOpError("requires constant vector size"); diag.attachNote(szPayloadOp->getLoc()) << "vector size payload op"; return DiagnosedSilenceableFailure::definiteFailure(); } vectorSizes.push_back(attr.getInt()); } // TODO: Check that the correct number of vectorSizes was provided. 
for (Operation *target : targets) { auto linalgOp = dyn_cast(target); if (!linalgOp) { Diagnostic diag(target->getLoc(), DiagnosticSeverity::Error); diag << "cannot vectorize non-Linalg op"; return DiagnosedSilenceableFailure::silenceableFailure(std::move(diag)); } if (failed(linalg::vectorize(rewriter, linalgOp, vectorSizes))) { Diagnostic diag(target->getLoc(), DiagnosticSeverity::Error); diag << "failed to vectorize op"; return DiagnosedSilenceableFailure::silenceableFailure(std::move(diag)); } } return DiagnosedSilenceableFailure::success(); } void transform::MaskedVectorizeOp::getEffects( SmallVectorImpl &effects) { consumesHandle(getTarget(), effects); onlyReadsHandle(getVectorSizes(), effects); } SmallVector MaskedVectorizeOp::getMixedVectorSizes() { OpBuilder b(getContext()); return getMixedValues(getStaticVectorSizes(), getVectorSizes(), b); } //===----------------------------------------------------------------------===// // Transform op registration //===----------------------------------------------------------------------===// namespace { /// Registers new ops and declares PDL as dependent dialect since the /// additional ops are using PDL types for operands and results. class LinalgTransformDialectExtension : public transform::TransformDialectExtension< LinalgTransformDialectExtension> { public: using Base::Base; void init() { declareDependentDialect(); declareDependentDialect(); declareGeneratedDialect(); declareGeneratedDialect(); declareGeneratedDialect(); declareGeneratedDialect(); declareGeneratedDialect(); registerTransformOps< #define GET_OP_LIST #include "mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp.inc" >(); } }; } // namespace #include "mlir/Dialect/Linalg/TransformOps/LinalgTransformOpsEnums.cpp.inc" #define GET_OP_CLASSES #include "mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp.inc" void mlir::linalg::registerTransformDialectExtension( DialectRegistry ®istry) { registry.addExtensions(); }