//===- TransformInterfaces.cpp - Transform Dialect Interfaces -------------===// // // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. // See https://llvm.org/LICENSE.txt for license information. // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // //===----------------------------------------------------------------------===// #include "mlir/Dialect/Transform/IR/TransformInterfaces.h" #include "mlir/Dialect/Transform/IR/TransformTypes.h" #include "mlir/IR/Diagnostics.h" #include "mlir/IR/Operation.h" #include "llvm/ADT/STLExtras.h" #include "llvm/ADT/ScopeExit.h" #include "llvm/Support/Debug.h" #include "llvm/Support/ErrorHandling.h" #define DEBUG_TYPE "transform-dialect" #define DEBUG_PRINT_AFTER_ALL "transform-dialect-print-top-level-after-all" #define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE "] ") using namespace mlir; //===----------------------------------------------------------------------===// // TransformState //===----------------------------------------------------------------------===// constexpr const Value transform::TransformState::kTopLevelValue; transform::TransformState::TransformState(Region *region, Operation *payloadRoot, const TransformOptions &options) : topLevel(payloadRoot), options(options) { auto result = mappings.try_emplace(region); assert(result.second && "the region scope is already present"); (void)result; #if LLVM_ENABLE_ABI_BREAKING_CHECKS regionStack.push_back(region); #endif // LLVM_ENABLE_ABI_BREAKING_CHECKS } Operation *transform::TransformState::getTopLevel() const { return topLevel; } ArrayRef transform::TransformState::getPayloadOps(Value value) const { const TransformOpMapping &operationMapping = getMapping(value).direct; auto iter = operationMapping.find(value); assert(iter != operationMapping.end() && "cannot find mapping for payload handle (param handle provided?)"); return iter->getSecond(); } ArrayRef transform::TransformState::getParams(Value value) const { const ParamMapping &mapping = getMapping(value).params; auto iter = mapping.find(value); assert(iter != mapping.end() && "cannot find mapping for param handle (payload handle provided?)"); return iter->getSecond(); } LogicalResult transform::TransformState::getHandlesForPayloadOp( Operation *op, SmallVectorImpl &handles) const { bool found = false; for (const Mappings &mapping : llvm::make_second_range(mappings)) { auto iterator = mapping.reverse.find(op); if (iterator != mapping.reverse.end()) { llvm::append_range(handles, iterator->getSecond()); found = true; } } return success(found); } LogicalResult transform::TransformState::setPayloadOps(Value value, ArrayRef targets) { assert(value != kTopLevelValue && "attempting to reset the transformation root"); assert(!value.getType().isa() && "cannot associate payload ops with a value of parameter type"); for (Operation *target : targets) { if (target) continue; return emitError(value.getLoc()) << "attempting to assign a null payload op to this transform value"; } auto iface = value.getType().cast(); DiagnosedSilenceableFailure result = iface.checkPayload(value.getLoc(), targets); if (failed(result.checkAndReport())) return failure(); // Setting new payload for the value without cleaning it first is a misuse of // the API, assert here. SmallVector storedTargets(targets.begin(), targets.end()); Mappings &mappings = getMapping(value); bool inserted = mappings.direct.insert({value, std::move(storedTargets)}).second; assert(inserted && "value is already associated with another list"); (void)inserted; for (Operation *op : targets) mappings.reverse[op].push_back(value); return success(); } LogicalResult transform::TransformState::setParams(Value value, ArrayRef params) { assert(value != nullptr && "attempting to set params for a null value"); for (Attribute attr : params) { if (attr) continue; return emitError(value.getLoc()) << "attempting to assign a null parameter to this transform value"; } auto valueType = value.getType().dyn_cast(); assert(value && "cannot associate parameter with a value of non-parameter type"); DiagnosedSilenceableFailure result = valueType.checkPayload(value.getLoc(), params); if (failed(result.checkAndReport())) return failure(); Mappings &mappings = getMapping(value); bool inserted = mappings.params.insert({value, llvm::to_vector(params)}).second; assert(inserted && "value is already associated with another list of params"); (void)inserted; return success(); } void transform::TransformState::dropReverseMapping(Mappings &mappings, Operation *op, Value value) { auto it = mappings.reverse.find(op); if (it == mappings.reverse.end()) return; llvm::erase_value(it->getSecond(), value); if (it->getSecond().empty()) mappings.reverse.erase(it); } void transform::TransformState::removePayloadOps(Value value) { Mappings &mappings = getMapping(value); for (Operation *op : mappings.direct[value]) dropReverseMapping(mappings, op, value); mappings.direct.erase(value); } LogicalResult transform::TransformState::updatePayloadOps( Value value, function_ref callback) { Mappings &mappings = getMapping(value); auto it = mappings.direct.find(value); assert(it != mappings.direct.end() && "unknown handle"); SmallVector &association = it->getSecond(); SmallVector updated; updated.reserve(association.size()); for (Operation *op : association) { dropReverseMapping(mappings, op, value); if (Operation *updatedOp = callback(op)) { updated.push_back(updatedOp); mappings.reverse[updatedOp].push_back(value); } } auto iface = value.getType().cast(); DiagnosedSilenceableFailure result = iface.checkPayload(value.getLoc(), updated); if (failed(result.checkAndReport())) return failure(); it->second = updated; return success(); } void transform::TransformState::recordHandleInvalidationOne( OpOperand &handle, Operation *payloadOp, Value otherHandle) { ArrayRef potentialAncestors = getPayloadOps(handle.get()); // If the op is associated with invalidated handle, skip the check as it // may be reading invalid IR. if (invalidatedHandles.count(otherHandle)) return; for (Operation *ancestor : potentialAncestors) { if (!ancestor->isAncestor(payloadOp)) continue; // Make sure the error-reporting lambda doesn't capture anything // by-reference because it will go out of scope. Additionally, extract // location from Payload IR ops because the ops themselves may be // deleted before the lambda gets called. Location ancestorLoc = ancestor->getLoc(); Location opLoc = payloadOp->getLoc(); Operation *owner = handle.getOwner(); unsigned operandNo = handle.getOperandNumber(); invalidatedHandles[otherHandle] = [ancestorLoc, opLoc, owner, operandNo, otherHandle](Location currentLoc) { InFlightDiagnostic diag = emitError(currentLoc) << "op uses a handle invalidated by a " "previously executed transform op"; diag.attachNote(otherHandle.getLoc()) << "handle to invalidated ops"; diag.attachNote(owner->getLoc()) << "invalidated by this transform op that consumes its operand #" << operandNo << " and invalidates handles to payload ops nested in payload " "ops associated with the consumed handle"; diag.attachNote(ancestorLoc) << "ancestor payload op"; diag.attachNote(opLoc) << "nested payload op"; }; } } void transform::TransformState::recordHandleInvalidation(OpOperand &handle) { for (const Mappings &mapping : llvm::make_second_range(mappings)) for (const auto &[payloadOp, otherHandles] : mapping.reverse) for (Value otherHandle : otherHandles) recordHandleInvalidationOne(handle, payloadOp, otherHandle); } LogicalResult transform::TransformState::checkAndRecordHandleInvalidation( TransformOpInterface transform) { auto memoryEffectsIface = cast(transform.getOperation()); SmallVector effects; memoryEffectsIface.getEffectsOnResource( transform::TransformMappingResource::get(), effects); for (OpOperand &target : transform->getOpOperands()) { // If the operand uses an invalidated handle, report it. auto it = invalidatedHandles.find(target.get()); if (!transform.allowsRepeatedHandleOperands() && it != invalidatedHandles.end()) return it->getSecond()(transform->getLoc()), failure(); // Invalidate handles pointing to the operations nested in the operation // associated with the handle consumed by this operation. auto consumesTarget = [&](const MemoryEffects::EffectInstance &effect) { return isa(effect.getEffect()) && effect.getValue() == target.get(); }; if (llvm::any_of(effects, consumesTarget)) recordHandleInvalidation(target); } return success(); } DiagnosedSilenceableFailure transform::TransformState::applyTransform(TransformOpInterface transform) { LLVM_DEBUG(DBGS() << "applying: " << transform << "\n"); auto printOnFailureRAII = llvm::make_scope_exit([this] { (void)this; DEBUG_WITH_TYPE(DEBUG_PRINT_AFTER_ALL, { DBGS() << "Top-level payload:\n"; getTopLevel()->print(llvm::dbgs(), mlir::OpPrintingFlags().printGenericOpForm()); }); }); if (options.getExpensiveChecksEnabled()) { if (failed(checkAndRecordHandleInvalidation(transform))) return DiagnosedSilenceableFailure::definiteFailure(); for (OpOperand &operand : transform->getOpOperands()) { if (!isHandleConsumed(operand.get(), transform)) continue; DenseSet seen; for (Operation *op : getPayloadOps(operand.get())) { if (!seen.insert(op).second) { DiagnosedSilenceableFailure diag = transform.emitSilenceableError() << "a handle passed as operand #" << operand.getOperandNumber() << " and consumed by this operation points to a payload " "operation more than once"; diag.attachNote(op->getLoc()) << "repeated target op"; return diag; } } } } transform::TransformResults results(transform->getNumResults()); // Compute the result but do not short-circuit the silenceable failure case as // we still want the handles to propagate properly so the "suppress" mode can // proceed on a best effort basis. DiagnosedSilenceableFailure result(transform.apply(results, *this)); if (result.isDefiniteFailure()) return result; // Remove the mapping for the operand if it is consumed by the operation. This // allows us to catch use-after-free with assertions later on. auto memEffectInterface = cast(transform.getOperation()); SmallVector effects; for (OpOperand &target : transform->getOpOperands()) { effects.clear(); memEffectInterface.getEffectsOnValue(target.get(), effects); if (llvm::any_of(effects, [](const MemoryEffects::EffectInstance &effect) { return isa( effect.getResource()) && isa(effect.getEffect()); })) { removePayloadOps(target.get()); } } for (OpResult result : transform->getResults()) { assert(result.getDefiningOp() == transform.getOperation() && "payload IR association for a value other than the result of the " "current transform op"); if (result.getType().isa()) { assert(results.isParam(result.getResultNumber()) && "expected parameters for the parameter-typed result"); if (failed( setParams(result, results.getParams(result.getResultNumber())))) { return DiagnosedSilenceableFailure::definiteFailure(); } } else { assert(!results.isParam(result.getResultNumber()) && "expected payload ops for the non-parameter typed result"); if (failed( setPayloadOps(result, results.get(result.getResultNumber())))) { return DiagnosedSilenceableFailure::definiteFailure(); } } } printOnFailureRAII.release(); DEBUG_WITH_TYPE(DEBUG_PRINT_AFTER_ALL, { DBGS() << "Top-level payload:\n"; getTopLevel()->print(llvm::dbgs()); }); return result; } //===----------------------------------------------------------------------===// // TransformState::Extension //===----------------------------------------------------------------------===// transform::TransformState::Extension::~Extension() = default; LogicalResult transform::TransformState::Extension::replacePayloadOp(Operation *op, Operation *replacement) { SmallVector handles; if (failed(state.getHandlesForPayloadOp(op, handles))) return failure(); for (Value handle : handles) { LogicalResult result = state.updatePayloadOps(handle, [&](Operation *current) { return current == op ? replacement : current; }); if (failed(result)) return failure(); } return success(); } //===----------------------------------------------------------------------===// // TransformResults //===----------------------------------------------------------------------===// transform::TransformResults::TransformResults(unsigned numSegments) { segments.resize(numSegments, ArrayRef(nullptr, static_cast(0))); paramSegments.resize(numSegments, ArrayRef( nullptr, static_cast(0))); } void transform::TransformResults::set(OpResult value, ArrayRef ops) { int64_t position = value.getResultNumber(); assert(position < static_cast(segments.size()) && "setting results for a non-existent handle"); assert(segments[position].data() == nullptr && "results already set"); int64_t start = operations.size(); llvm::append_range(operations, ops); segments[position] = makeArrayRef(operations).drop_front(start); } void transform::TransformResults::setParams( OpResult value, ArrayRef params) { int64_t position = value.getResultNumber(); assert(position < static_cast(paramSegments.size()) && "setting params for a non-existent handle"); assert(paramSegments[position].data() == nullptr && "params already set"); size_t start = this->params.size(); llvm::append_range(this->params, params); paramSegments[position] = makeArrayRef(this->params).drop_front(start); } ArrayRef transform::TransformResults::get(unsigned resultNumber) const { assert(resultNumber < segments.size() && "querying results for a non-existent handle"); assert(segments[resultNumber].data() != nullptr && "querying unset results (param expected?)"); return segments[resultNumber]; } ArrayRef transform::TransformResults::getParams(unsigned resultNumber) const { assert(resultNumber < paramSegments.size() && "querying params for a non-existent handle"); assert(paramSegments[resultNumber].data() != nullptr && "querying unset params (payload ops expected?)"); return paramSegments[resultNumber]; } bool transform::TransformResults::isParam(unsigned resultNumber) const { assert(resultNumber < paramSegments.size() && "querying association for a non-existent handle"); return paramSegments[resultNumber].data() != nullptr; } //===----------------------------------------------------------------------===// // Utilities for TransformEachOpTrait. //===----------------------------------------------------------------------===// LogicalResult transform::detail::checkApplyToOne(Operation *transformOp, Location payloadOpLoc, const ApplyToEachResultList &partialResult) { Location transformOpLoc = transformOp->getLoc(); StringRef transformOpName = transformOp->getName().getStringRef(); unsigned expectedNumResults = transformOp->getNumResults(); // TODO: encode this implicit must always produce `expectedNumResults` // and nullptr is fine with a proper trait. if (partialResult.size() != expectedNumResults) { auto diag = mlir::emitError(transformOpLoc, "applications of ") << transformOpName << " expected to produce " << expectedNumResults << " results (actually produced " << partialResult.size() << ")."; diag.attachNote(transformOpLoc) << "If you need variadic results, consider a generic `apply` " << "instead of the specialized `applyToOne`."; diag.attachNote(transformOpLoc) << "Producing " << expectedNumResults << " null results is " << "allowed if the use case warrants it."; diag.attachNote(payloadOpLoc) << "when applied to this op"; return failure(); } // Check that all is null or none is null // TODO: relax this behavior and encode with a proper trait. if (llvm::any_of( partialResult, [](llvm::PointerUnion ptr) { return ptr; }) && llvm::any_of(partialResult, [](llvm::PointerUnion ptr) { return !ptr; })) { auto diag = mlir::emitError(transformOpLoc, "unexpected application of ") << transformOpName << " produces both null and non null results."; diag.attachNote(payloadOpLoc) << "when applied to this op"; return failure(); } // Check that the right kind of value was produced. for (const auto &[ptr, res] : llvm::zip(partialResult, transformOp->getResults())) { if (ptr.is() && !res.getType().template isa()) { mlir::emitError(transformOpLoc) << "applications of " << transformOpName << " expected to produce an Attribute for result #" << res.getResultNumber(); return failure(); } if (ptr.is() && !res.getType().template isa()) { mlir::emitError(transformOpLoc) << "applications of " << transformOpName << " expected to produce an Operation * for result #" << res.getResultNumber(); return failure(); } } return success(); } void transform::detail::setApplyToOneResults( Operation *transformOp, TransformResults &transformResults, ArrayRef results) { for (OpResult r : transformOp->getResults()) { if (r.getType().isa()) { auto params = llvm::to_vector( llvm::map_range(results, [r](const ApplyToEachResultList &oneResult) { return oneResult[r.getResultNumber()].get(); })); transformResults.setParams(r, params); } else { auto payloads = llvm::to_vector( llvm::map_range(results, [r](const ApplyToEachResultList &oneResult) { return oneResult[r.getResultNumber()].get(); })); transformResults.set(r, payloads); } } } //===----------------------------------------------------------------------===// // Utilities for PossibleTopLevelTransformOpTrait. //===----------------------------------------------------------------------===// LogicalResult transform::detail::mapPossibleTopLevelTransformOpBlockArguments( TransformState &state, Operation *op, Region ®ion) { SmallVector targets; if (op->getNumOperands() != 0) llvm::append_range(targets, state.getPayloadOps(op->getOperand(0))); else targets.push_back(state.getTopLevel()); return state.mapBlockArguments(region.front().getArgument(0), targets); } LogicalResult transform::detail::verifyPossibleTopLevelTransformOpTrait(Operation *op) { // Attaching this trait without the interface is a misuse of the API, but it // cannot be caught via a static_assert because interface registration is // dynamic. assert(isa(op) && "should implement TransformOpInterface to have " "PossibleTopLevelTransformOpTrait"); if (op->getNumRegions() < 1) return op->emitOpError() << "expects at least one region"; Region *bodyRegion = &op->getRegion(0); if (!llvm::hasNItems(*bodyRegion, 1)) return op->emitOpError() << "expects a single-block region"; Block *body = &bodyRegion->front(); if (body->getNumArguments() != 1 || !body->getArgumentTypes()[0].isa()) { return op->emitOpError() << "expects the entry block to have one argument " "of type implementing TransformHandleTypeInterface"; } if (auto *parent = op->getParentWithTrait()) { if (op->getNumOperands() == 0) { InFlightDiagnostic diag = op->emitOpError() << "expects the root operation to be provided for a nested op"; diag.attachNote(parent->getLoc()) << "nested in another possible top-level op"; return diag; } } return success(); } //===----------------------------------------------------------------------===// // Utilities for ParamProducedTransformOpTrait. //===----------------------------------------------------------------------===// void transform::detail::getParamProducerTransformOpTraitEffects( Operation *op, SmallVectorImpl &effects) { producesHandle(op->getResults(), effects); bool hasPayloadOperands = false; for (Value operand : op->getOperands()) { onlyReadsHandle(operand, effects); if (operand.getType().isa()) hasPayloadOperands = true; } if (hasPayloadOperands) onlyReadsPayload(effects); } LogicalResult transform::detail::verifyParamProducerTransformOpTrait(Operation *op) { // Interfaces can be attached dynamically, so this cannot be a static // assert. if (!op->getName().getInterface()) { llvm::report_fatal_error( Twine("ParamProducerTransformOpTrait must be attached to an op that " "implements MemoryEffectsOpInterface, found on ") + op->getName().getStringRef()); } for (Value result : op->getResults()) { if (result.getType().isa()) continue; return op->emitOpError() << "ParamProducerTransformOpTrait attached to this op expects " "result types to implement TransformParamTypeInterface"; } return success(); } //===----------------------------------------------------------------------===// // Memory effects. //===----------------------------------------------------------------------===// void transform::consumesHandle( ValueRange handles, SmallVectorImpl &effects) { for (Value handle : handles) { effects.emplace_back(MemoryEffects::Read::get(), handle, TransformMappingResource::get()); effects.emplace_back(MemoryEffects::Free::get(), handle, TransformMappingResource::get()); } } /// Returns `true` if the given list of effects instances contains an instance /// with the effect type specified as template parameter. template static bool hasEffect(ArrayRef effects) { return llvm::any_of(effects, [](const MemoryEffects::EffectInstance &effect) { return isa(effect.getEffect()) && isa(effect.getResource()); }); } bool transform::isHandleConsumed(Value handle, transform::TransformOpInterface transform) { auto iface = cast(transform.getOperation()); SmallVector effects; iface.getEffectsOnValue(handle, effects); return ::hasEffect(effects) && ::hasEffect(effects); } void transform::producesHandle( ValueRange handles, SmallVectorImpl &effects) { for (Value handle : handles) { effects.emplace_back(MemoryEffects::Allocate::get(), handle, TransformMappingResource::get()); effects.emplace_back(MemoryEffects::Write::get(), handle, TransformMappingResource::get()); } } void transform::onlyReadsHandle( ValueRange handles, SmallVectorImpl &effects) { for (Value handle : handles) { effects.emplace_back(MemoryEffects::Read::get(), handle, TransformMappingResource::get()); } } void transform::modifiesPayload( SmallVectorImpl &effects) { effects.emplace_back(MemoryEffects::Read::get(), PayloadIRResource::get()); effects.emplace_back(MemoryEffects::Write::get(), PayloadIRResource::get()); } void transform::onlyReadsPayload( SmallVectorImpl &effects) { effects.emplace_back(MemoryEffects::Read::get(), PayloadIRResource::get()); } //===----------------------------------------------------------------------===// // Entry point. //===----------------------------------------------------------------------===// LogicalResult transform::applyTransforms(Operation *payloadRoot, TransformOpInterface transform, const TransformOptions &options) { #ifndef NDEBUG if (!transform->hasTrait() || transform->getNumOperands() != 0) { transform->emitError() << "expected transform to start at the top-level transform op"; llvm::report_fatal_error("could not run transforms", /*gen_crash_diag=*/false); } #endif // NDEBUG TransformState state(transform->getParentRegion(), payloadRoot, options); return state.applyTransform(transform).checkAndReport(); } //===----------------------------------------------------------------------===// // Generated interface implementation. //===----------------------------------------------------------------------===// #include "mlir/Dialect/Transform/IR/TransformInterfaces.cpp.inc"