//===- TransformDialect.cpp - Transform dialect operations ----------------===// // // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. // See https://llvm.org/LICENSE.txt for license information. // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // //===----------------------------------------------------------------------===// #include "mlir/Dialect/Transform/IR/TransformOps.h" #include "mlir/Dialect/PDL/IR/PDLOps.h" #include "mlir/Dialect/Transform/IR/TransformDialect.h" #include "mlir/Dialect/Transform/IR/TransformInterfaces.h" #include "mlir/Dialect/Transform/IR/TransformTypes.h" #include "mlir/Dialect/Transform/IR/TransformUtils.h" #include "mlir/IR/OpImplementation.h" #include "mlir/IR/PatternMatch.h" #include "mlir/Interfaces/ControlFlowInterfaces.h" #include "mlir/Rewrite/FrozenRewritePatternSet.h" #include "mlir/Rewrite/PatternApplicator.h" #include "llvm/ADT/STLExtras.h" #include "llvm/ADT/ScopeExit.h" #include "llvm/Support/Debug.h" #define DEBUG_TYPE "transform-dialect" #define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE "] ") using namespace mlir; #define GET_OP_CLASSES #include "mlir/Dialect/Transform/IR/TransformOps.cpp.inc" //===----------------------------------------------------------------------===// // PatternApplicatorExtension //===----------------------------------------------------------------------===// namespace { /// A TransformState extension that keeps track of compiled PDL pattern sets. /// This is intended to be used along the WithPDLPatterns op. The extension /// can be constructed given an operation that has a SymbolTable trait and /// contains pdl::PatternOp instances. The patterns are compiled lazily and one /// by one when requested; this behavior is subject to change. class PatternApplicatorExtension : public transform::TransformState::Extension { public: MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(PatternApplicatorExtension) /// Creates the extension for patterns contained in `patternContainer`. explicit PatternApplicatorExtension(transform::TransformState &state, Operation *patternContainer) : Extension(state), patterns(patternContainer) {} /// Appends to `results` the operations contained in `root` that matched the /// PDL pattern with the given name. Note that `root` may or may not be the /// operation that contains PDL patterns. Reports an error if the pattern /// cannot be found. Note that when no operations are matched, this still /// succeeds as long as the pattern exists. LogicalResult findAllMatches(StringRef patternName, Operation *root, SmallVectorImpl &results); private: /// Map from the pattern name to a singleton set of rewrite patterns that only /// contains the pattern with this name. Populated when the pattern is first /// requested. // TODO: reconsider the efficiency of this storage when more usage data is // available. Storing individual patterns in a set and triggering compilation // for each of them has overhead. So does compiling a large set of patterns // only to apply a handlful of them. llvm::StringMap compiledPatterns; /// A symbol table operation containing the relevant PDL patterns. SymbolTable patterns; }; LogicalResult PatternApplicatorExtension::findAllMatches( StringRef patternName, Operation *root, SmallVectorImpl &results) { auto it = compiledPatterns.find(patternName); if (it == compiledPatterns.end()) { auto patternOp = patterns.lookup(patternName); if (!patternOp) return failure(); OwningOpRef pdlModuleOp = ModuleOp::create(patternOp.getLoc()); patternOp->moveBefore(pdlModuleOp->getBody(), pdlModuleOp->getBody()->end()); PDLPatternModule patternModule(std::move(pdlModuleOp)); // Merge in the hooks owned by the dialect. Make a copy as they may be // also used by the following operations. auto *dialect = root->getContext()->getLoadedDialect(); for (const auto &[name, constraintFn] : dialect->getPDLConstraintHooks()) patternModule.registerConstraintFunction(name, constraintFn); // Register a noop rewriter because PDL requires patterns to end with some // rewrite call. patternModule.registerRewriteFunction( "transform.dialect", [](PatternRewriter &, Operation *) {}); it = compiledPatterns .try_emplace(patternOp.getName(), std::move(patternModule)) .first; } PatternApplicator applicator(it->second); transform::TrivialPatternRewriter rewriter(root->getContext()); applicator.applyDefaultCostModel(); root->walk([&](Operation *op) { if (succeeded(applicator.matchAndRewrite(op, rewriter))) results.push_back(op); }); return success(); } } // namespace //===----------------------------------------------------------------------===// // AlternativesOp //===----------------------------------------------------------------------===// OperandRange transform::AlternativesOp::getSuccessorEntryOperands( std::optional index) { if (index && getOperation()->getNumOperands() == 1) return getOperation()->getOperands(); return OperandRange(getOperation()->operand_end(), getOperation()->operand_end()); } void transform::AlternativesOp::getSuccessorRegions( std::optional index, ArrayRef operands, SmallVectorImpl ®ions) { for (Region &alternative : llvm::drop_begin( getAlternatives(), index.has_value() ? *index + 1 : 0)) { regions.emplace_back(&alternative, !getOperands().empty() ? alternative.getArguments() : Block::BlockArgListType()); } if (index.has_value()) regions.emplace_back(getOperation()->getResults()); } void transform::AlternativesOp::getRegionInvocationBounds( ArrayRef operands, SmallVectorImpl &bounds) { (void)operands; // The region corresponding to the first alternative is always executed, the // remaining may or may not be executed. bounds.reserve(getNumRegions()); bounds.emplace_back(1, 1); bounds.resize(getNumRegions(), InvocationBounds(0, 1)); } static void forwardEmptyOperands(Block *block, transform::TransformState &state, transform::TransformResults &results) { for (const auto &res : block->getParentOp()->getOpResults()) results.set(res, {}); } static void forwardTerminatorOperands(Block *block, transform::TransformState &state, transform::TransformResults &results) { for (const auto &pair : llvm::zip(block->getTerminator()->getOperands(), block->getParentOp()->getOpResults())) { Value terminatorOperand = std::get<0>(pair); OpResult result = std::get<1>(pair); results.set(result, state.getPayloadOps(terminatorOperand)); } } DiagnosedSilenceableFailure transform::AlternativesOp::apply(transform::TransformResults &results, transform::TransformState &state) { SmallVector originals; if (Value scopeHandle = getScope()) llvm::append_range(originals, state.getPayloadOps(scopeHandle)); else originals.push_back(state.getTopLevel()); for (Operation *original : originals) { if (original->isAncestor(getOperation())) { auto diag = emitDefiniteFailure() << "scope must not contain the transforms being applied"; diag.attachNote(original->getLoc()) << "scope"; return diag; } if (!original->hasTrait()) { auto diag = emitDefiniteFailure() << "only isolated-from-above ops can be alternative scopes"; diag.attachNote(original->getLoc()) << "scope"; return diag; } } for (Region ® : getAlternatives()) { // Clone the scope operations and make the transforms in this alternative // region apply to them by virtue of mapping the block argument (the only // visible handle) to the cloned scope operations. This effectively prevents // the transformation from accessing any IR outside the scope. auto scope = state.make_region_scope(reg); auto clones = llvm::to_vector( llvm::map_range(originals, [](Operation *op) { return op->clone(); })); auto deleteClones = llvm::make_scope_exit([&] { for (Operation *clone : clones) clone->erase(); }); if (failed(state.mapBlockArguments(reg.front().getArgument(0), clones))) return DiagnosedSilenceableFailure::definiteFailure(); bool failed = false; for (Operation &transform : reg.front().without_terminator()) { DiagnosedSilenceableFailure result = state.applyTransform(cast(transform)); if (result.isSilenceableFailure()) { LLVM_DEBUG(DBGS() << "alternative failed: " << result.getMessage() << "\n"); failed = true; break; } if (::mlir::failed(result.silence())) return DiagnosedSilenceableFailure::definiteFailure(); } // If all operations in the given alternative succeeded, no need to consider // the rest. Replace the original scoping operation with the clone on which // the transformations were performed. if (!failed) { // We will be using the clones, so cancel their scheduled deletion. deleteClones.release(); IRRewriter rewriter(getContext()); for (const auto &kvp : llvm::zip(originals, clones)) { Operation *original = std::get<0>(kvp); Operation *clone = std::get<1>(kvp); original->getBlock()->getOperations().insert(original->getIterator(), clone); rewriter.replaceOp(original, clone->getResults()); } forwardTerminatorOperands(®.front(), state, results); return DiagnosedSilenceableFailure::success(); } } return emitSilenceableError() << "all alternatives failed"; } LogicalResult transform::AlternativesOp::verify() { for (Region &alternative : getAlternatives()) { Block &block = alternative.front(); Operation *terminator = block.getTerminator(); if (terminator->getOperands().getTypes() != getResults().getTypes()) { InFlightDiagnostic diag = emitOpError() << "expects terminator operands to have the " "same type as results of the operation"; diag.attachNote(terminator->getLoc()) << "terminator"; return diag; } } return success(); } //===----------------------------------------------------------------------===// // CastOp //===----------------------------------------------------------------------===// DiagnosedSilenceableFailure transform::CastOp::applyToOne(Operation *target, ApplyToEachResultList &results, transform::TransformState &state) { results.push_back(target); return DiagnosedSilenceableFailure::success(); } void transform::CastOp::getEffects( SmallVectorImpl &effects) { onlyReadsPayload(effects); consumesHandle(getInput(), effects); producesHandle(getOutput(), effects); } bool transform::CastOp::areCastCompatible(TypeRange inputs, TypeRange outputs) { assert(inputs.size() == 1 && "expected one input"); assert(outputs.size() == 1 && "expected one output"); return llvm::all_of( std::initializer_list{inputs.front(), outputs.front()}, [](Type ty) { return ty .isa(); }); } //===----------------------------------------------------------------------===// // ForeachOp //===----------------------------------------------------------------------===// DiagnosedSilenceableFailure transform::ForeachOp::apply(transform::TransformResults &results, transform::TransformState &state) { ArrayRef payloadOps = state.getPayloadOps(getTarget()); SmallVector> resultOps(getNumResults(), {}); for (Operation *op : payloadOps) { auto scope = state.make_region_scope(getBody()); if (failed(state.mapBlockArguments(getIterationVariable(), {op}))) return DiagnosedSilenceableFailure::definiteFailure(); // Execute loop body. for (Operation &transform : getBody().front().without_terminator()) { DiagnosedSilenceableFailure result = state.applyTransform( cast(transform)); if (!result.succeeded()) return result; } // Append yielded payload ops to result list (if any). for (unsigned i = 0; i < getNumResults(); ++i) { ArrayRef yieldedOps = state.getPayloadOps(getYieldOp().getOperand(i)); resultOps[i].append(yieldedOps.begin(), yieldedOps.end()); } } for (unsigned i = 0; i < getNumResults(); ++i) results.set(getResult(i).cast(), resultOps[i]); return DiagnosedSilenceableFailure::success(); } void transform::ForeachOp::getEffects( SmallVectorImpl &effects) { BlockArgument iterVar = getIterationVariable(); if (any_of(getBody().front().without_terminator(), [&](Operation &op) { return isHandleConsumed(iterVar, cast(&op)); })) { consumesHandle(getTarget(), effects); } else { onlyReadsHandle(getTarget(), effects); } for (Value result : getResults()) producesHandle(result, effects); } void transform::ForeachOp::getSuccessorRegions( std::optional index, ArrayRef operands, SmallVectorImpl ®ions) { Region *bodyRegion = &getBody(); if (!index) { regions.emplace_back(bodyRegion, bodyRegion->getArguments()); return; } // Branch back to the region or the parent. assert(*index == 0 && "unexpected region index"); regions.emplace_back(bodyRegion, bodyRegion->getArguments()); regions.emplace_back(); } OperandRange transform::ForeachOp::getSuccessorEntryOperands(std::optional index) { // The iteration variable op handle is mapped to a subset (one op to be // precise) of the payload ops of the ForeachOp operand. assert(index && *index == 0 && "unexpected region index"); return getOperation()->getOperands(); } transform::YieldOp transform::ForeachOp::getYieldOp() { return cast(getBody().front().getTerminator()); } LogicalResult transform::ForeachOp::verify() { auto yieldOp = getYieldOp(); if (getNumResults() != yieldOp.getNumOperands()) return emitOpError() << "expects the same number of results as the " "terminator has operands"; for (Value v : yieldOp.getOperands()) if (!v.getType().isa()) return yieldOp->emitOpError("expects operands to have types implementing " "TransformHandleTypeInterface"); return success(); } //===----------------------------------------------------------------------===// // GetClosestIsolatedParentOp //===----------------------------------------------------------------------===// DiagnosedSilenceableFailure transform::GetClosestIsolatedParentOp::apply( transform::TransformResults &results, transform::TransformState &state) { SetVector parents; for (Operation *target : state.getPayloadOps(getTarget())) { Operation *parent = target->getParentWithTrait(); if (!parent) { DiagnosedSilenceableFailure diag = emitSilenceableError() << "could not find an isolated-from-above parent op"; diag.attachNote(target->getLoc()) << "target op"; return diag; } parents.insert(parent); } results.set(getResult().cast(), parents.getArrayRef()); return DiagnosedSilenceableFailure::success(); } //===----------------------------------------------------------------------===// // GetProducerOfOperand //===----------------------------------------------------------------------===// DiagnosedSilenceableFailure transform::GetProducerOfOperand::apply(transform::TransformResults &results, transform::TransformState &state) { int64_t operandNumber = getOperandNumber(); SmallVector producers; for (Operation *target : state.getPayloadOps(getTarget())) { Operation *producer = target->getNumOperands() <= operandNumber ? nullptr : target->getOperand(operandNumber).getDefiningOp(); if (!producer) { DiagnosedSilenceableFailure diag = emitSilenceableError() << "could not find a producer for operand number: " << operandNumber << " of " << *target; diag.attachNote(target->getLoc()) << "target op"; results.set(getResult().cast(), SmallVector{}); return diag; } producers.push_back(producer); } results.set(getResult().cast(), producers); return DiagnosedSilenceableFailure::success(); } //===----------------------------------------------------------------------===// // MergeHandlesOp //===----------------------------------------------------------------------===// DiagnosedSilenceableFailure transform::MergeHandlesOp::apply(transform::TransformResults &results, transform::TransformState &state) { SmallVector operations; for (Value operand : getHandles()) llvm::append_range(operations, state.getPayloadOps(operand)); if (!getDeduplicate()) { results.set(getResult().cast(), operations); return DiagnosedSilenceableFailure::success(); } SetVector uniqued(operations.begin(), operations.end()); results.set(getResult().cast(), uniqued.getArrayRef()); return DiagnosedSilenceableFailure::success(); } bool transform::MergeHandlesOp::allowsRepeatedHandleOperands() { // Handles may be the same if deduplicating is enabled. return getDeduplicate(); } void transform::MergeHandlesOp::getEffects( SmallVectorImpl &effects) { consumesHandle(getHandles(), effects); producesHandle(getResult(), effects); // There are no effects on the Payload IR as this is only a handle // manipulation. } OpFoldResult transform::MergeHandlesOp::fold(ArrayRef operands) { if (getDeduplicate() || getHandles().size() != 1) return {}; // If deduplication is not required and there is only one operand, it can be // used directly instead of merging. return getHandles().front(); } //===----------------------------------------------------------------------===// // SplitHandlesOp //===----------------------------------------------------------------------===// void transform::SplitHandlesOp::build(OpBuilder &builder, OperationState &result, Value target, int64_t numResultHandles) { result.addOperands(target); result.addAttribute(SplitHandlesOp::getNumResultHandlesAttrName(result.name), builder.getI64IntegerAttr(numResultHandles)); auto pdlOpType = pdl::OperationType::get(builder.getContext()); result.addTypes(SmallVector(numResultHandles, pdlOpType)); } DiagnosedSilenceableFailure transform::SplitHandlesOp::apply(transform::TransformResults &results, transform::TransformState &state) { int64_t numResultHandles = getHandle() ? state.getPayloadOps(getHandle()).size() : 0; int64_t expectedNumResultHandles = getNumResultHandles(); if (numResultHandles != expectedNumResultHandles) { // Failing case needs to propagate gracefully for both suppress and // propagate modes. for (int64_t idx = 0; idx < expectedNumResultHandles; ++idx) results.set(getResults()[idx].cast(), {}); // Empty input handle corner case: always propagates empty handles in both // suppress and propagate modes. if (numResultHandles == 0) return DiagnosedSilenceableFailure::success(); // If the input handle was not empty and the number of result handles does // not match, this is a legit silenceable error. return emitSilenceableError() << getHandle() << " expected to contain " << expectedNumResultHandles << " operation handles but it only contains " << numResultHandles << " handles"; } // Normal successful case. for (const auto &en : llvm::enumerate(state.getPayloadOps(getHandle()))) results.set(getResults()[en.index()].cast(), en.value()); return DiagnosedSilenceableFailure::success(); } void transform::SplitHandlesOp::getEffects( SmallVectorImpl &effects) { consumesHandle(getHandle(), effects); producesHandle(getResults(), effects); // There are no effects on the Payload IR as this is only a handle // manipulation. } //===----------------------------------------------------------------------===// // PDLMatchOp //===----------------------------------------------------------------------===// DiagnosedSilenceableFailure transform::PDLMatchOp::apply(transform::TransformResults &results, transform::TransformState &state) { auto *extension = state.getExtension(); assert(extension && "expected PatternApplicatorExtension to be attached by the parent op"); SmallVector targets; for (Operation *root : state.getPayloadOps(getRoot())) { if (failed(extension->findAllMatches( getPatternName().getLeafReference().getValue(), root, targets))) { emitDefiniteFailure() << "could not find pattern '" << getPatternName() << "'"; } } results.set(getResult().cast(), targets); return DiagnosedSilenceableFailure::success(); } void transform::PDLMatchOp::getEffects( SmallVectorImpl &effects) { onlyReadsHandle(getRoot(), effects); producesHandle(getMatched(), effects); onlyReadsPayload(effects); } //===----------------------------------------------------------------------===// // ReplicateOp //===----------------------------------------------------------------------===// DiagnosedSilenceableFailure transform::ReplicateOp::apply(transform::TransformResults &results, transform::TransformState &state) { unsigned numRepetitions = state.getPayloadOps(getPattern()).size(); for (const auto &en : llvm::enumerate(getHandles())) { Value handle = en.value(); ArrayRef current = state.getPayloadOps(handle); SmallVector payload; payload.reserve(numRepetitions * current.size()); for (unsigned i = 0; i < numRepetitions; ++i) llvm::append_range(payload, current); results.set(getReplicated()[en.index()].cast(), payload); } return DiagnosedSilenceableFailure::success(); } void transform::ReplicateOp::getEffects( SmallVectorImpl &effects) { onlyReadsHandle(getPattern(), effects); consumesHandle(getHandles(), effects); producesHandle(getReplicated(), effects); } //===----------------------------------------------------------------------===// // SequenceOp //===----------------------------------------------------------------------===// DiagnosedSilenceableFailure transform::SequenceOp::apply(transform::TransformResults &results, transform::TransformState &state) { // Map the entry block argument to the list of operations. auto scope = state.make_region_scope(*getBodyBlock()->getParent()); if (failed(mapBlockArguments(state))) return DiagnosedSilenceableFailure::definiteFailure(); // Apply the sequenced ops one by one. for (Operation &transform : getBodyBlock()->without_terminator()) { DiagnosedSilenceableFailure result = state.applyTransform(cast(transform)); if (result.isDefiniteFailure()) return result; if (result.isSilenceableFailure()) { if (getFailurePropagationMode() == FailurePropagationMode::Propagate) { // Propagate empty results in case of early exit. forwardEmptyOperands(getBodyBlock(), state, results); return result; } (void)result.silence(); } } // Forward the operation mapping for values yielded from the sequence to the // values produced by the sequence op. forwardTerminatorOperands(getBodyBlock(), state, results); return DiagnosedSilenceableFailure::success(); } /// Returns `true` if the given op operand may be consuming the handle value in /// the Transform IR. That is, if it may have a Free effect on it. static bool isValueUsePotentialConsumer(OpOperand &use) { // Conservatively assume the effect being present in absence of the interface. auto iface = dyn_cast(use.getOwner()); if (!iface) return true; return isHandleConsumed(use.get(), iface); } LogicalResult checkDoubleConsume(Value value, function_ref reportError) { OpOperand *potentialConsumer = nullptr; for (OpOperand &use : value.getUses()) { if (!isValueUsePotentialConsumer(use)) continue; if (!potentialConsumer) { potentialConsumer = &use; continue; } InFlightDiagnostic diag = reportError() << " has more than one potential consumer"; diag.attachNote(potentialConsumer->getOwner()->getLoc()) << "used here as operand #" << potentialConsumer->getOperandNumber(); diag.attachNote(use.getOwner()->getLoc()) << "used here as operand #" << use.getOperandNumber(); return diag; } return success(); } LogicalResult transform::SequenceOp::verify() { assert(getBodyBlock()->getNumArguments() == 1 && "the number of arguments must have been verified to be 1 by " "PossibleTopLevelTransformOpTrait"); BlockArgument arg = getBodyBlock()->getArgument(0); if (getRoot()) { if (arg.getType() != getRoot().getType()) { return emitOpError() << "expects the type of the block argument to match " "the type of the operand"; } } // Check if the block argument has more than one consuming use. if (failed(checkDoubleConsume( arg, [this]() { return (emitOpError() << "block argument #0"); }))) { return failure(); } // Check properties of the nested operations they cannot check themselves. for (Operation &child : *getBodyBlock()) { if (!isa(child) && &child != &getBodyBlock()->back()) { InFlightDiagnostic diag = emitOpError() << "expected children ops to implement TransformOpInterface"; diag.attachNote(child.getLoc()) << "op without interface"; return diag; } for (OpResult result : child.getResults()) { auto report = [&]() { return (child.emitError() << "result #" << result.getResultNumber()); }; if (failed(checkDoubleConsume(result, report))) return failure(); } } if (getBodyBlock()->getTerminator()->getOperandTypes() != getOperation()->getResultTypes()) { InFlightDiagnostic diag = emitOpError() << "expects the types of the terminator operands " "to match the types of the result"; diag.attachNote(getBodyBlock()->getTerminator()->getLoc()) << "terminator"; return diag; } return success(); } void transform::SequenceOp::getEffects( SmallVectorImpl &effects) { auto *mappingResource = TransformMappingResource::get(); effects.emplace_back(MemoryEffects::Read::get(), getRoot(), mappingResource); for (Value result : getResults()) { effects.emplace_back(MemoryEffects::Allocate::get(), result, mappingResource); effects.emplace_back(MemoryEffects::Write::get(), result, mappingResource); } if (!getRoot()) { for (Operation &op : *getBodyBlock()) { auto iface = dyn_cast(&op); if (!iface) { // TODO: fill all possible effects; or require ops to actually implement // the memory effect interface always assert(false); } SmallVector nestedEffects; iface.getEffects(effects); } return; } // Carry over all effects on the argument of the entry block as those on the // operand, this is the same value just remapped. for (Operation &op : *getBodyBlock()) { auto iface = dyn_cast(&op); if (!iface) { // TODO: fill all possible effects; or require ops to actually implement // the memory effect interface always assert(false); } SmallVector nestedEffects; iface.getEffectsOnValue(getBodyBlock()->getArgument(0), nestedEffects); for (const auto &effect : nestedEffects) effects.emplace_back(effect.getEffect(), getRoot(), effect.getResource()); } } OperandRange transform::SequenceOp::getSuccessorEntryOperands( std::optional index) { assert(index && *index == 0 && "unexpected region index"); if (getOperation()->getNumOperands() == 1) return getOperation()->getOperands(); return OperandRange(getOperation()->operand_end(), getOperation()->operand_end()); } void transform::SequenceOp::getSuccessorRegions( std::optional index, ArrayRef operands, SmallVectorImpl ®ions) { if (!index) { Region *bodyRegion = &getBody(); regions.emplace_back(bodyRegion, !operands.empty() ? bodyRegion->getArguments() : Block::BlockArgListType()); return; } assert(*index == 0 && "unexpected region index"); regions.emplace_back(getOperation()->getResults()); } void transform::SequenceOp::getRegionInvocationBounds( ArrayRef operands, SmallVectorImpl &bounds) { (void)operands; bounds.emplace_back(1, 1); } void transform::SequenceOp::build(OpBuilder &builder, OperationState &state, TypeRange resultTypes, FailurePropagationMode failurePropagationMode, Value root, SequenceBodyBuilderFn bodyBuilder) { build(builder, state, resultTypes, failurePropagationMode, root); Region *region = state.regions.back().get(); Type bbArgType = root.getType(); OpBuilder::InsertionGuard guard(builder); Block *bodyBlock = builder.createBlock( region, region->begin(), TypeRange{bbArgType}, {state.location}); // Populate body. builder.setInsertionPointToStart(bodyBlock); bodyBuilder(builder, state.location, bodyBlock->getArgument(0)); } void transform::SequenceOp::build(OpBuilder &builder, OperationState &state, TypeRange resultTypes, FailurePropagationMode failurePropagationMode, Type bbArgType, SequenceBodyBuilderFn bodyBuilder) { build(builder, state, resultTypes, failurePropagationMode, /*root=*/Value()); Region *region = state.regions.back().get(); OpBuilder::InsertionGuard guard(builder); Block *bodyBlock = builder.createBlock( region, region->begin(), TypeRange{bbArgType}, {state.location}); // Populate body. builder.setInsertionPointToStart(bodyBlock); bodyBuilder(builder, state.location, bodyBlock->getArgument(0)); } //===----------------------------------------------------------------------===// // WithPDLPatternsOp //===----------------------------------------------------------------------===// DiagnosedSilenceableFailure transform::WithPDLPatternsOp::apply(transform::TransformResults &results, transform::TransformState &state) { OwningOpRef pdlModuleOp = ModuleOp::create(getOperation()->getLoc()); TransformOpInterface transformOp = nullptr; for (Operation &nested : getBody().front()) { if (!isa(nested)) { transformOp = cast(nested); break; } } state.addExtension(getOperation()); auto guard = llvm::make_scope_exit( [&]() { state.removeExtension(); }); auto scope = state.make_region_scope(getBody()); if (failed(mapBlockArguments(state))) return DiagnosedSilenceableFailure::definiteFailure(); return state.applyTransform(transformOp); } LogicalResult transform::WithPDLPatternsOp::verify() { Block *body = getBodyBlock(); Operation *topLevelOp = nullptr; for (Operation &op : body->getOperations()) { if (isa(op)) continue; if (op.hasTrait<::mlir::transform::PossibleTopLevelTransformOpTrait>()) { if (topLevelOp) { InFlightDiagnostic diag = emitOpError() << "expects only one non-pattern op in its body"; diag.attachNote(topLevelOp->getLoc()) << "first non-pattern op"; diag.attachNote(op.getLoc()) << "second non-pattern op"; return diag; } topLevelOp = &op; continue; } InFlightDiagnostic diag = emitOpError() << "expects only pattern and top-level transform ops in its body"; diag.attachNote(op.getLoc()) << "offending op"; return diag; } if (auto parent = getOperation()->getParentOfType()) { InFlightDiagnostic diag = emitOpError() << "cannot be nested"; diag.attachNote(parent.getLoc()) << "parent operation"; return diag; } return success(); } //===----------------------------------------------------------------------===// // PrintOp //===----------------------------------------------------------------------===// void transform::PrintOp::build(OpBuilder &builder, OperationState &result, StringRef name) { if (!name.empty()) { result.addAttribute(PrintOp::getNameAttrName(result.name), builder.getStrArrayAttr(name)); } } void transform::PrintOp::build(OpBuilder &builder, OperationState &result, Value target, StringRef name) { result.addOperands({target}); build(builder, result, name); } DiagnosedSilenceableFailure transform::PrintOp::apply(transform::TransformResults &results, transform::TransformState &state) { llvm::outs() << "[[[ IR printer: "; if (getName().has_value()) llvm::outs() << *getName() << " "; if (!getTarget()) { llvm::outs() << "top-level ]]]\n" << *state.getTopLevel() << "\n"; return DiagnosedSilenceableFailure::success(); } llvm::outs() << "]]]\n"; ArrayRef targets = state.getPayloadOps(getTarget()); for (Operation *target : targets) llvm::outs() << *target << "\n"; return DiagnosedSilenceableFailure::success(); } void transform::PrintOp::getEffects( SmallVectorImpl &effects) { onlyReadsHandle(getTarget(), effects); onlyReadsPayload(effects); // There is no resource for stderr file descriptor, so just declare print // writes into the default resource. effects.emplace_back(MemoryEffects::Write::get()); }