Files
clang-p2996/mlir/lib/Dialect/Transform/Interfaces/TransformInterfaces.cpp
MaheshRavishankar 205c5325b3 [mlir] Add a utility method to move operation dependencies. (#129975)
The added utility method moves all SSA values that an operation depends
upon before an insertion point. This is useful during transformations
where such movements might make transformations (like fusion) more
powerful.

To test the utility, add a transform dialect op that calls the move
operation. To capture the `notifyMatchFailure` messages from the
transformation and report/check them in the test, modify the
`ErrorCheckingTrackingListener` to capture the last match failure
notification.

---------

Signed-off-by: MaheshRavishankar <mahesh.ravishankar@gmail.com>
2025-03-10 20:23:08 -07:00

2046 lines
80 KiB
C++

//===- TransformInterfaces.cpp - Transform Dialect Interfaces -------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
#include "mlir/Dialect/Transform/Interfaces/TransformInterfaces.h"
#include "mlir/IR/Diagnostics.h"
#include "mlir/IR/Operation.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/Interfaces/CastInterfaces.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/ScopeExit.h"
#include "llvm/Support/Debug.h"
#include "llvm/Support/ErrorHandling.h"
#define DEBUG_TYPE "transform-dialect"
#define DEBUG_TYPE_FULL "transform-dialect-full"
#define DEBUG_PRINT_AFTER_ALL "transform-dialect-print-top-level-after-all"
#define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE "] ")
#define LDBG(X) LLVM_DEBUG(DBGS() << (X))
#define FULL_LDBG(X) DEBUG_WITH_TYPE(DEBUG_TYPE_FULL, (DBGS() << (X)))
using namespace mlir;
//===----------------------------------------------------------------------===//
// Helper functions
//===----------------------------------------------------------------------===//
/// Returns true if `a` is ordered before `b`. Starting at `a` and moving
/// outward through its parent operations, the first operation whose block
/// also contains an ancestor of `b` decides the answer by block order.
/// Returns false if `b` is nested inside the operation currently examined,
/// or if no such shared block is found.
static bool happensBefore(Operation *a, Operation *b) {
  for (Operation *current = a; current; current = current->getParentOp()) {
    // `b` lives inside `current`, so it cannot come after it.
    if (current->isProperAncestor(b))
      return false;
    Operation *bAncestor = current->getBlock()->findAncestorOpInBlock(*b);
    if (bAncestor)
      return current->isBeforeInBlock(bAncestor);
  }
  return false;
}
//===----------------------------------------------------------------------===//
// TransformState
//===----------------------------------------------------------------------===//
// Out-of-line definition of the static constexpr data member declared in the
// class, so that ODR-uses of it (e.g. binding it to a reference) have a
// definition in this translation unit (needed before C++17 inline variables).
constexpr const Value transform::TransformState::kTopLevelValue;
/// Builds the state for applying transforms rooted at `payloadRoot`.
/// Every entry of `extraMappings` is copied into `topLevelMappedValues`. When
/// a transform `region` is provided, a RegionScope for it is created and its
/// ownership is handed to `topLevelRegionScope`.
transform::TransformState::TransformState(
    Region *region, Operation *payloadRoot,
    const RaggedArray<MappedValue> &extraMappings,
    const TransformOptions &options)
    : topLevel(payloadRoot), options(options) {
  topLevelMappedValues.reserve(extraMappings.size());
  for (ArrayRef<MappedValue> extraValues : extraMappings)
    topLevelMappedValues.push_back(extraValues);

  if (!region)
    return;
  topLevelRegionScope.reset(new RegionScope(*this, *region));
}
/// Returns the payload root this state was constructed with.
Operation *transform::TransformState::getTopLevel() const {
  return topLevel;
}
/// Returns a view of the payload operations currently associated with the
/// operation handle `value`. Asserts that such an association exists; a
/// missing one usually means a param or value handle was passed instead.
ArrayRef<Operation *>
transform::TransformState::getPayloadOpsView(Value value) const {
  const TransformOpMapping &opsByHandle = getMapping(value).direct;
  auto found = opsByHandle.find(value);
  assert(found != opsByHandle.end() &&
         "cannot find mapping for payload handle (param/value handle "
         "provided?)");
  return found->getSecond();
}
/// Returns the parameters currently associated with the parameter handle
/// `value`. Asserts that such an association exists.
ArrayRef<Attribute> transform::TransformState::getParams(Value value) const {
  const ParamMapping &paramsByHandle = getMapping(value).params;
  auto found = paramsByHandle.find(value);
  assert(found != paramsByHandle.end() &&
         "cannot find mapping for param handle "
         "(operation/value handle provided?)");
  return found->getSecond();
}
/// Returns a view of the payload values currently associated with the value
/// handle `handleValue`. Asserts that such an association exists.
ArrayRef<Value>
transform::TransformState::getPayloadValuesView(Value handleValue) const {
  const ValueMapping &valuesByHandle = getMapping(handleValue).values;
  auto found = valuesByHandle.find(handleValue);
  assert(found != valuesByHandle.end() &&
         "cannot find mapping for value handle "
         "(param/operation handle provided?)");
  return found->getSecond();
}
/// Appends to `handles` every operation handle associated with `op`,
/// visiting the entries of `mappings` in reverse order (presumably innermost
/// scope first). Unless `includeOutOfScope` is set, the visit stops after the
/// first entry whose region's parent op is isolated from above. Returns
/// failure if no handle was found.
LogicalResult transform::TransformState::getHandlesForPayloadOp(
    Operation *op, SmallVectorImpl<Value> &handles,
    bool includeOutOfScope) const {
  bool anyFound = false;
  for (const auto &[scopeRegion, scopeMapping] : llvm::reverse(mappings)) {
    const auto &opToHandles = scopeMapping->reverse;
    if (auto entry = opToHandles.find(op); entry != opToHandles.end()) {
      llvm::append_range(handles, entry->getSecond());
      anyFound = true;
    }
    // Handles beyond an isolated-from-above boundary are out of scope.
    if (!includeOutOfScope &&
        scopeRegion->getParentOp()->hasTrait<OpTrait::IsIsolatedFromAbove>())
      break;
  }
  return success(anyFound);
}
/// Appends to `handles` every value handle associated with `payloadValue`,
/// visiting the entries of `mappings` in reverse order (presumably innermost
/// scope first). Unless `includeOutOfScope` is set, the visit stops after the
/// first entry whose region's parent op is isolated from above. Returns
/// failure if no handle was found.
LogicalResult transform::TransformState::getHandlesForPayloadValue(
    Value payloadValue, SmallVectorImpl<Value> &handles,
    bool includeOutOfScope) const {
  bool anyFound = false;
  for (const auto &[scopeRegion, scopeMapping] : llvm::reverse(mappings)) {
    const auto &valueToHandles = scopeMapping->reverseValues;
    if (auto entry = valueToHandles.find(payloadValue);
        entry != valueToHandles.end()) {
      llvm::append_range(handles, entry->getSecond());
      anyFound = true;
    }
    // Handles beyond an isolated-from-above boundary are out of scope.
    if (!includeOutOfScope &&
        scopeRegion->getParentOp()->hasTrait<OpTrait::IsIsolatedFromAbove>())
      break;
  }
  return success(anyFound);
}
/// Converts each element of `values` to the payload kind implied by the type
/// interface of `handle`:
/// - operations for operation handles;
/// - values for value handles;
/// - attributes for parameter handles.
/// It then passes the converted list to the matching callback. Returns a
/// silenceable failure when an element has the wrong kind, a definite failure
/// when the callback fails, and success otherwise.
static DiagnosedSilenceableFailure dispatchMappedValues(
    Value handle, ArrayRef<transform::MappedValue> values,
    function_ref<LogicalResult(ArrayRef<Operation *>)> operationsFn,
    function_ref<LogicalResult(ArrayRef<transform::Param>)> paramsFn,
    function_ref<LogicalResult(ValueRange)> valuesFn) {
  Type handleType = handle.getType();

  // Maps the outcome of a callback to the corresponding diagnostic state.
  auto fromCallback = [](LogicalResult outcome) {
    if (failed(outcome))
      return DiagnosedSilenceableFailure::definiteFailure();
    return DiagnosedSilenceableFailure::success();
  };

  if (llvm::isa<transform::TransformHandleTypeInterface>(handleType)) {
    SmallVector<Operation *> operations;
    operations.reserve(values.size());
    for (transform::MappedValue element : values) {
      auto *op = llvm::dyn_cast_if_present<Operation *>(element);
      if (!op)
        return emitSilenceableFailure(handle.getLoc())
               << "wrong kind of value provided for top-level operation handle";
      operations.push_back(op);
    }
    return fromCallback(operationsFn(operations));
  }

  if (llvm::isa<transform::TransformValueHandleTypeInterface>(handleType)) {
    SmallVector<Value> payloadValues;
    payloadValues.reserve(values.size());
    for (transform::MappedValue element : values) {
      auto payloadValue = llvm::dyn_cast_if_present<Value>(element);
      if (!payloadValue)
        return emitSilenceableFailure(handle.getLoc())
               << "wrong kind of value provided for the top-level value handle";
      payloadValues.push_back(payloadValue);
    }
    return fromCallback(valuesFn(payloadValues));
  }

  // Anything else must be a parameter handle.
  assert(llvm::isa<transform::TransformParamTypeInterface>(handleType) &&
         "unsupported kind of block argument");
  SmallVector<transform::Param> parameters;
  parameters.reserve(values.size());
  for (transform::MappedValue element : values) {
    auto attr = llvm::dyn_cast_if_present<Attribute>(element);
    if (!attr)
      return emitSilenceableFailure(handle.getLoc())
             << "wrong kind of value provided for top-level parameter";
    parameters.push_back(attr);
  }
  return fromCallback(paramsFn(parameters));
}
/// Associates the block argument `argument` with `values`. The payload kind
/// is chosen from the argument's handle type. Any diagnostic produced is
/// reported right away and converted to a LogicalResult.
LogicalResult
transform::TransformState::mapBlockArgument(BlockArgument argument,
                                            ArrayRef<MappedValue> values) {
  auto assignOps = [&](ArrayRef<Operation *> operations) {
    return setPayloadOps(argument, operations);
  };
  auto assignParams = [&](ArrayRef<Param> params) {
    return setParams(argument, params);
  };
  auto assignValues = [&](ValueRange payloadValues) {
    return setPayloadValues(argument, payloadValues);
  };
  DiagnosedSilenceableFailure dispatched = dispatchMappedValues(
      argument, values, assignOps, assignParams, assignValues);
  return dispatched.checkAndReport();
}
/// Maps each of `arguments` to the list at the same position in `mapping`.
/// The two ranges must have equal length (checked by `zip_equal`). Stops at
/// the first argument that fails to map and returns failure.
LogicalResult transform::TransformState::mapBlockArguments(
    Block::BlockArgListType arguments,
    ArrayRef<SmallVector<MappedValue>> mapping) {
  bool allMapped = llvm::all_of(
      llvm::zip_equal(arguments, mapping), [&](auto &&argAndValues) {
        auto &&[argument, values] = argAndValues;
        return succeeded(mapBlockArgument(argument, values));
      });
  return success(allMapped);
}
/// Associates the operation handle `value` with the payload ops `targets`.
/// It first rejects null ops and verifies the payload against the handle
/// type, then records `value` in the reverse mapping of every target. The
/// handle must not already have an association (asserted).
LogicalResult
transform::TransformState::setPayloadOps(Value value,
                                         ArrayRef<Operation *> targets) {
  assert(value != kTopLevelValue &&
         "attempting to reset the transformation root");
  assert(llvm::isa<TransformHandleTypeInterface>(value.getType()) &&
         "wrong handle type");

  bool hasNullTarget =
      llvm::any_of(targets, [](Operation *target) { return !target; });
  if (hasNullTarget)
    return emitError(value.getLoc())
           << "attempting to assign a null payload op to this transform value";

  auto handleType = llvm::cast<TransformHandleTypeInterface>(value.getType());
  DiagnosedSilenceableFailure verified =
      handleType.checkPayload(value.getLoc(), targets);
  if (failed(verified.checkAndReport()))
    return failure();

  // Setting new payload for the value without cleaning it first is a misuse of
  // the API, assert here.
  SmallVector<Operation *> targetsCopy(targets);
  Mappings &scopeMappings = getMapping(value);
  bool inserted =
      scopeMappings.direct.insert({value, std::move(targetsCopy)}).second;
  assert(inserted && "value is already associated with another list");
  (void)inserted;

  for (Operation *target : targets)
    scopeMappings.reverse[target].push_back(value);
  return success();
}
/// Associates the value handle `handle` with `payloadValues`. It first
/// rejects null payload values and verifies the payload against the handle
/// type, then records `handle` in the reverse mapping of every payload value.
/// The handle must not already have an association (asserted).
LogicalResult
transform::TransformState::setPayloadValues(Value handle,
                                            ValueRange payloadValues) {
  assert(handle != nullptr &&
         "attempting to set payload values for a null handle");
  assert(llvm::isa<TransformValueHandleTypeInterface>(handle.getType()) &&
         "wrong handle type");

  for (Value payload : payloadValues) {
    if (payload)
      continue;
    return emitError(handle.getLoc()) << "attempting to assign a null payload "
                                         "value to this transform handle";
  }

  auto iface = llvm::cast<TransformValueHandleTypeInterface>(handle.getType());
  SmallVector<Value> payloadValueVector = llvm::to_vector(payloadValues);
  DiagnosedSilenceableFailure result =
      iface.checkPayload(handle.getLoc(), payloadValueVector);
  if (failed(result.checkAndReport()))
    return failure();

  Mappings &mappings = getMapping(handle);
  bool inserted =
      mappings.values.insert({handle, std::move(payloadValueVector)}).second;
  assert(
      inserted &&
      "value handle is already associated with another list of payload values");
  (void)inserted;

  // Keep the payload value -> handles direction in sync.
  for (Value payload : payloadValues)
    mappings.reverseValues[payload].push_back(handle);

  return success();
}
/// Associates the parameter handle `value` with `params`. It first rejects
/// null parameters and verifies the payload against the handle type. The
/// handle must have a parameter type and must not already have an
/// association (both asserted).
LogicalResult transform::TransformState::setParams(Value value,
                                                   ArrayRef<Param> params) {
  assert(value != nullptr && "attempting to set params for a null value");

  for (Attribute attr : params) {
    if (attr)
      continue;
    return emitError(value.getLoc())
           << "attempting to assign a null parameter to this transform value";
  }

  auto valueType = llvm::dyn_cast<TransformParamTypeInterface>(value.getType());
  // Check the cast result. `value` was already asserted non-null above, so
  // asserting on it again here would never catch a non-parameter handle type.
  assert(valueType &&
         "cannot associate parameter with a value of non-parameter type");
  DiagnosedSilenceableFailure result =
      valueType.checkPayload(value.getLoc(), params);
  if (failed(result.checkAndReport()))
    return failure();

  Mappings &mappings = getMapping(value);
  bool inserted =
      mappings.params.insert({value, llvm::to_vector(params)}).second;
  assert(inserted && "value is already associated with another list of params");
  (void)inserted;

  return success();
}
/// Removes every occurrence of `mapped` from the list stored under `key` in
/// `mapping`. If that list becomes empty, the entry for `key` is erased
/// altogether. Does nothing if `key` is absent.
///
/// This helper is local to this file, so it has internal linkage: a generic
/// name like this must not collide with an identically-named template in
/// another translation unit.
template <typename Mapping, typename Key, typename Mapped>
static void dropMappingEntry(Mapping &mapping, Key key, Mapped mapped) {
  auto it = mapping.find(key);
  if (it == mapping.end())
    return;
  auto &entries = it->getSecond();
  llvm::erase(entries, mapped);
  if (entries.empty())
    mapping.erase(it);
}
/// Drops the association of the operation handle `opHandle` with its payload
/// ops, in both directions; `allowOutOfScope` is forwarded to the lookup of
/// `opHandle`'s mapping. Then, for each value in `origOpFlatResults` (the
/// results of those ops, as collected by the caller), drops the association
/// between that value and every in-scope value handle pointing to it.
void transform::TransformState::forgetMapping(Value opHandle,
                                              ValueRange origOpFlatResults,
                                              bool allowOutOfScope) {
  Mappings &mappings = getMapping(opHandle, allowOutOfScope);
  for (Operation *op : mappings.direct[opHandle])
    dropMappingEntry(mappings.reverse, op, opHandle);
  mappings.direct.erase(opHandle);
#if LLVM_ENABLE_ABI_BREAKING_CHECKS
  // Payload IR is removed from the mapping. This invalidates the respective
  // iterators.
  mappings.incrementTimestamp(opHandle);
#endif // LLVM_ENABLE_ABI_BREAKING_CHECKS

  for (Value opResult : origOpFlatResults) {
    SmallVector<Value> resultHandles;
    (void)getHandlesForPayloadValue(opResult, resultHandles);
    for (Value resultHandle : resultHandles) {
      // `resultHandle` may belong to a different scope than `opHandle`, so
      // all bookkeeping for it goes through its own mappings.
      Mappings &localMappings = getMapping(resultHandle);
      dropMappingEntry(localMappings.values, resultHandle, opResult);
#if LLVM_ENABLE_ABI_BREAKING_CHECKS
      // Payload IR is removed from the mapping. This invalidates the respective
      // iterators.
      localMappings.incrementTimestamp(resultHandle);
#endif // LLVM_ENABLE_ABI_BREAKING_CHECKS
      dropMappingEntry(localMappings.reverseValues, opResult, resultHandle);
    }
  }
}
/// Drops the association of the value handle `valueHandle` with its payload
/// values, in both directions. Then, for each op in `payloadOperations`, drops
/// the association between that op and every in-scope operation handle
/// pointing to it.
void transform::TransformState::forgetValueMapping(
    Value valueHandle, ArrayRef<Operation *> payloadOperations) {
  Mappings &mappings = getMapping(valueHandle);
  // `values` maps a handle to its payload values, and `reverseValues` maps a
  // payload value to its handles. Walk the former to clean up the latter.
  // Looking up `reverseValues` with the handle itself would find nothing and
  // leave stale reverse entries behind.
  auto it = mappings.values.find(valueHandle);
  if (it != mappings.values.end()) {
    for (Value payloadValue : it->getSecond())
      dropMappingEntry(mappings.reverseValues, payloadValue, valueHandle);
    // Only `reverseValues` was modified above, so `it` is still valid.
    mappings.values.erase(it);
  }
#if LLVM_ENABLE_ABI_BREAKING_CHECKS
  // Payload IR is removed from the mapping. This invalidates the respective
  // iterators.
  mappings.incrementTimestamp(valueHandle);
#endif // LLVM_ENABLE_ABI_BREAKING_CHECKS

  for (Operation *payloadOp : payloadOperations) {
    SmallVector<Value> opHandles;
    (void)getHandlesForPayloadOp(payloadOp, opHandles);
    for (Value opHandle : opHandles) {
      Mappings &localMappings = getMapping(opHandle);
      dropMappingEntry(localMappings.direct, opHandle, payloadOp);
      dropMappingEntry(localMappings.reverse, payloadOp, opHandle);
#if LLVM_ENABLE_ABI_BREAKING_CHECKS
      // Payload IR is removed from the mapping. This invalidates the respective
      // iterators.
      localMappings.incrementTimestamp(opHandle);
#endif // LLVM_ENABLE_ABI_BREAKING_CHECKS
    }
  }
}
/// Replaces `op` with `replacement` in the payload of every handle associated
/// with `op`. The handle lookup searches all region scopes, including those
/// beyond an isolated-from-above boundary (`includeOutOfScope` /
/// `allowOutOfScope` are set). A null `replacement` stands for erasure: the
/// slot is set to nullptr and the handle is queued in `opHandlesToCompact`,
/// where the null entries are removed later. Returns failure if `op` is not
/// associated with any handle.
LogicalResult
transform::TransformState::replacePayloadOp(Operation *op,
                                            Operation *replacement) {
  // TODO: consider invalidating the handles to nested objects here.

#ifndef NDEBUG
  // Debug-only sanity check: no value handle may still point to a result of
  // `op` at the time `op` itself is replaced.
  for (Value opResult : op->getResults()) {
    SmallVector<Value> valueHandles;
    (void)getHandlesForPayloadValue(opResult, valueHandles,
                                    /*includeOutOfScope=*/true);
    assert(valueHandles.empty() && "expected no mapping to old results");
  }
#endif // NDEBUG

  // Drop the mapping between the op and all handles that point to it. Fail if
  // there are no handles.
  SmallVector<Value> opHandles;
  if (failed(getHandlesForPayloadOp(op, opHandles, /*includeOutOfScope=*/true)))
    return failure();
  for (Value handle : opHandles) {
    Mappings &mappings = getMapping(handle, /*allowOutOfScope=*/true);
    dropMappingEntry(mappings.reverse, op, handle);
  }

  // Replace the pointed-to object of all handles with the replacement object.
  // In case a payload op was erased (replacement object is nullptr), a nullptr
  // is stored in the mapping. These nullptrs are removed after each transform.
  // Furthermore, nullptrs are not enumerated by payload op iterators. The
  // relative order of ops is preserved.
  //
  // Removing an op from the mapping would be problematic because removing an
  // element from an array invalidates iterators; merely changing the value of
  // elements does not.
  for (Value handle : opHandles) {
    Mappings &mappings = getMapping(handle, /*allowOutOfScope=*/true);
    auto it = mappings.direct.find(handle);
    if (it == mappings.direct.end())
      continue;

    SmallVector<Operation *, 2> &association = it->getSecond();
    // Note that an operation may be associated with the handle more than once.
    for (Operation *&mapped : association) {
      if (mapped == op)
        mapped = replacement;
    }

    if (replacement) {
      // Keep the payload op -> handles direction in sync.
      mappings.reverse[replacement].push_back(handle);
    } else {
      // Schedule the handle so its nullptr entries get stripped later.
      opHandlesToCompact.insert(handle);
    }
  }

  return success();
}
/// Replaces the payload value `value` with `replacement` in every handle
/// associated with it, including out-of-scope handles. A null `replacement`
/// removes `value` from those handles instead. Returns failure if no handle
/// is associated with `value`.
LogicalResult
transform::TransformState::replacePayloadValue(Value value, Value replacement) {
  SmallVector<Value> affectedHandles;
  if (failed(getHandlesForPayloadValue(value, affectedHandles,
                                       /*includeOutOfScope=*/true)))
    return failure();

  for (Value handle : affectedHandles) {
    Mappings &handleMappings = getMapping(handle, /*allowOutOfScope=*/true);
    dropMappingEntry(handleMappings.reverseValues, value, handle);

    if (replacement) {
      // Rewrite the matching slots in place; the list keeps its size.
      auto entry = handleMappings.values.find(handle);
      if (entry == handleMappings.values.end())
        continue;
      SmallVector<Value> &payload = entry->getSecond();
      for (Value &slot : payload) {
        if (slot == value)
          slot = replacement;
      }
      handleMappings.reverseValues[replacement].push_back(handle);
      continue;
    }

    // Replacing with null means erasure: drop the value from the handle.
    dropMappingEntry(handleMappings.values, handle, value);
#if LLVM_ENABLE_ABI_BREAKING_CHECKS
    // Payload IR is removed from the mapping. This invalidates the respective
    // iterators.
    handleMappings.incrementTimestamp(handle);
#endif // LLVM_ENABLE_ABI_BREAKING_CHECKS
  }
  return success();
}
/// Records in `newlyInvalidated` that `otherHandle` becomes invalid if
/// `payloadOp` is one of `potentialAncestors` or is nested in one of them.
/// (`payloadOp` is one of the ops `otherHandle` points to; `Operation::
/// isAncestor` also accepts the op itself.) `consumingHandle` is the consumed
/// operand responsible for the invalidation. `throughValue`, if non-null, is
/// the consumed payload value and gets an extra note. The stored callback
/// emits the diagnostic later, at the location where the invalidated handle
/// is used. Nothing is recorded if `otherHandle` is already invalidated.
void transform::TransformState::recordOpHandleInvalidationOne(
    OpOperand &consumingHandle, ArrayRef<Operation *> potentialAncestors,
    Operation *payloadOp, Value otherHandle, Value throughValue,
    transform::TransformState::InvalidatedHandleMap &newlyInvalidated) const {
  // If the op is associated with invalidated handle, skip the check as it
  // may be reading invalid IR. This also ensures we report the first
  // invalidation and not the last one.
  if (invalidatedHandles.count(otherHandle) ||
      newlyInvalidated.count(otherHandle))
    return;

  FULL_LDBG("--recordOpHandleInvalidationOne\n");
  DEBUG_WITH_TYPE(
      DEBUG_TYPE_FULL,
      llvm::interleaveComma(potentialAncestors, DBGS() << "--ancestors: ",
                            [](Operation *op) { llvm::dbgs() << *op; });
      llvm::dbgs() << "\n");

  Operation *owner = consumingHandle.getOwner();
  unsigned operandNo = consumingHandle.getOperandNumber();
  for (Operation *ancestor : potentialAncestors) {
    // clang-format off
    DEBUG_WITH_TYPE(DEBUG_TYPE_FULL,
      { (DBGS() << "----handle one ancestor: " << *ancestor << "\n"); });
    DEBUG_WITH_TYPE(DEBUG_TYPE_FULL,
      { (DBGS() << "----of payload with name: "
                << payloadOp->getName().getIdentifier() << "\n"); });
    DEBUG_WITH_TYPE(DEBUG_TYPE_FULL,
      { (DBGS() << "----of payload: " << *payloadOp << "\n"); });
    // clang-format on
    // Only ancestors that contain (or are) `payloadOp` invalidate the handle.
    if (!ancestor->isAncestor(payloadOp))
      continue;

    // Make sure the error-reporting lambda doesn't capture anything
    // by-reference because it will go out of scope. Additionally, extract
    // location from Payload IR ops because the ops themselves may be
    // deleted before the lambda gets called.
    Location ancestorLoc = ancestor->getLoc();
    Location opLoc = payloadOp->getLoc();
    std::optional<Location> throughValueLoc =
        throughValue ? std::make_optional(throughValue.getLoc()) : std::nullopt;
    // A later matching ancestor overwrites the callback stored for an earlier
    // one within this loop.
    newlyInvalidated[otherHandle] = [ancestorLoc, opLoc, owner, operandNo,
                                     otherHandle,
                                     throughValueLoc](Location currentLoc) {
      InFlightDiagnostic diag = emitError(currentLoc)
                                << "op uses a handle invalidated by a "
                                   "previously executed transform op";
      diag.attachNote(otherHandle.getLoc()) << "handle to invalidated ops";
      diag.attachNote(owner->getLoc())
          << "invalidated by this transform op that consumes its operand #"
          << operandNo
          << " and invalidates all handles to payload IR entities associated "
             "with this operand and entities nested in them";
      diag.attachNote(ancestorLoc) << "ancestor payload op";
      diag.attachNote(opLoc) << "nested payload op";
      if (throughValueLoc) {
        diag.attachNote(*throughValueLoc)
            << "consumed handle points to this payload value";
      }
    };
  }
}
/// Records in `newlyInvalidated` that the value handle `valueHandle` becomes
/// invalid if `payloadValue` (one of its payload values) is defined by an op
/// that is one of `potentialAncestors` or nested in one of them. "Defined by"
/// covers both an op result and an argument of a block in one of the op's
/// regions. `opHandle` is the consumed operand responsible for the
/// invalidation. The stored callback emits the diagnostic later, at the
/// location where the invalidated handle is used.
void transform::TransformState::recordValueHandleInvalidationByOpHandleOne(
    OpOperand &opHandle, ArrayRef<Operation *> potentialAncestors,
    Value payloadValue, Value valueHandle,
    transform::TransformState::InvalidatedHandleMap &newlyInvalidated) const {
  // If the op is associated with invalidated handle, skip the check as it
  // may be reading invalid IR. This also ensures we report the first
  // invalidation and not the last one.
  if (invalidatedHandles.count(valueHandle) ||
      newlyInvalidated.count(valueHandle))
    return;
  // Nothing can contain the definition; do not inspect the value at all.
  if (potentialAncestors.empty())
    return;

  // Where `payloadValue` is defined does not depend on the ancestor being
  // examined, so compute it once instead of once per ancestor.
  Operation *definingOp = nullptr;
  std::optional<unsigned> resultNo;
  unsigned argumentNo = std::numeric_limits<unsigned>::max();
  unsigned blockNo = std::numeric_limits<unsigned>::max();
  unsigned regionNo = std::numeric_limits<unsigned>::max();
  if (auto opResult = llvm::dyn_cast<OpResult>(payloadValue)) {
    definingOp = opResult.getOwner();
    resultNo = opResult.getResultNumber();
  } else {
    auto arg = llvm::cast<BlockArgument>(payloadValue);
    definingOp = arg.getParentBlock()->getParentOp();
    argumentNo = arg.getArgNumber();
    blockNo = std::distance(arg.getOwner()->getParent()->begin(),
                            arg.getOwner()->getIterator());
    regionNo = arg.getOwner()->getParent()->getRegionNumber();
  }
  assert(definingOp && "expected the value to be defined by an op as result "
                       "or block argument");

  // The error-reporting lambda must capture everything by value. Locations
  // are extracted now because the payload ops may be deleted before the
  // lambda runs.
  Operation *owner = opHandle.getOwner();
  unsigned operandNo = opHandle.getOperandNumber();
  Location opLoc = definingOp->getLoc();
  Location valueLoc = payloadValue.getLoc();

  for (Operation *ancestor : potentialAncestors) {
    if (!ancestor->isAncestor(definingOp))
      continue;
    Location ancestorLoc = ancestor->getLoc();
    newlyInvalidated[valueHandle] = [valueHandle, owner, operandNo, resultNo,
                                     argumentNo, blockNo, regionNo, ancestorLoc,
                                     opLoc, valueLoc](Location currentLoc) {
      InFlightDiagnostic diag = emitError(currentLoc)
                                << "op uses a handle invalidated by a "
                                   "previously executed transform op";
      diag.attachNote(valueHandle.getLoc()) << "invalidated handle";
      diag.attachNote(owner->getLoc())
          << "invalidated by this transform op that consumes its operand #"
          << operandNo
          << " and invalidates all handles to payload IR entities "
             "associated with this operand and entities nested in them";
      diag.attachNote(ancestorLoc)
          << "ancestor op associated with the consumed handle";
      if (resultNo) {
        diag.attachNote(opLoc)
            << "op defining the value as result #" << *resultNo;
      } else {
        diag.attachNote(opLoc)
            << "op defining the value as block argument #" << argumentNo
            << " of block #" << blockNo << " in region #" << regionNo;
      }
      diag.attachNote(valueLoc) << "payload value";
    };
  }
}
/// Records in `newlyInvalidated` the handles that become invalid when the
/// operand `handle` is consumed and its payload is rooted at the ops in
/// `potentialAncestors`.
/// - If that list is empty, only the consumed handle itself is recorded, with
///   a diagnostic mentioning its empty payload.
/// - Otherwise, every op handle and value handle in the visible scopes whose
///   payload lies in (or is) one of `potentialAncestors` is recorded.
/// `throughValue`, if non-null, is the consumed payload value through which
/// `potentialAncestors` were reached.
void transform::TransformState::recordOpHandleInvalidation(
    OpOperand &handle, ArrayRef<Operation *> potentialAncestors,
    Value throughValue,
    transform::TransformState::InvalidatedHandleMap &newlyInvalidated) const {

  if (potentialAncestors.empty()) {
    DEBUG_WITH_TYPE(DEBUG_TYPE_FULL, {
      (DBGS() << "----recording invalidation for empty handle: " << handle.get()
              << "\n");
    });

    // Capture by value only: the callback runs after this frame is gone.
    Operation *owner = handle.getOwner();
    unsigned operandNo = handle.getOperandNumber();
    newlyInvalidated[handle.get()] = [owner, operandNo](Location currentLoc) {
      InFlightDiagnostic diag = emitError(currentLoc)
                                << "op uses a handle associated with empty "
                                   "payload and invalidated by a "
                                   "previously executed transform op";
      diag.attachNote(owner->getLoc())
          << "invalidated by this transform op that consumes its operand #"
          << operandNo;
    };
    return;
  }

  // Iterate over the mapping and invalidate aliasing handles. This is quite
  // expensive and only necessary for error reporting in case of transform
  // dialect misuse with dangling handles. Iteration over the handles is based
  // on the assumption that the number of handles is significantly less than the
  // number of IR objects (operations and values). Alternatively, we could walk
  // the IR nested in each payload op associated with the given handle and look
  // for handles associated with each operation and value.
  for (const auto &[region, mapping] : llvm::reverse(mappings)) {
    // Go over all op handle mappings and mark as invalidated any handle
    // pointing to any of the payload ops associated with the given handle or
    // any op nested in them.
    for (const auto &[payloadOp, otherHandles] : mapping->reverse) {
      for (Value otherHandle : otherHandles)
        recordOpHandleInvalidationOne(handle, potentialAncestors, payloadOp,
                                      otherHandle, throughValue,
                                      newlyInvalidated);
    }
    // Go over all value handle mappings and mark as invalidated any handle
    // pointing to any result of the payload op associated with the given handle
    // or any op nested in them. Similarly invalidate handles to argument of
    // blocks belonging to any region of any payload op associated with the
    // given handle or any op nested in them.
    for (const auto &[payloadValue, valueHandles] : mapping->reverseValues) {
      for (Value valueHandle : valueHandles)
        recordValueHandleInvalidationByOpHandleOne(handle, potentialAncestors,
                                                   payloadValue, valueHandle,
                                                   newlyInvalidated);
    }

    // Stop lookup when reaching a region that is isolated from above.
    if (region->getParentOp()->hasTrait<OpTrait::IsIsolatedFromAbove>())
      break;
  }
}
/// Records the invalidations caused by consuming the value handle operand
/// `valueHandle`. For each of its payload values, two kinds of handles are
/// recorded in `newlyInvalidated`:
/// (1) every in-scope value handle pointing to that same payload value;
/// (2) via recordOpHandleInvalidation, handles related to where the value is
///     defined — the defining op for an op result, or each op of the owning
///     block for a block argument.
void transform::TransformState::recordValueHandleInvalidation(
    OpOperand &valueHandle,
    transform::TransformState::InvalidatedHandleMap &newlyInvalidated) const {
  // The consuming op and operand number are the same for every diagnostic.
  Operation *owner = valueHandle.getOwner();
  unsigned operandNo = valueHandle.getOperandNumber();

  for (Value payloadValue : getPayloadValuesView(valueHandle.get())) {
    // Invalidate other handles to the same value.
    SmallVector<Value> otherValueHandles;
    (void)getHandlesForPayloadValue(payloadValue, otherValueHandles);
    Location valueLoc = payloadValue.getLoc();
    for (Value otherHandle : otherValueHandles) {
      newlyInvalidated[otherHandle] = [otherHandle, owner, operandNo,
                                       valueLoc](Location currentLoc) {
        InFlightDiagnostic diag = emitError(currentLoc)
                                  << "op uses a handle invalidated by a "
                                     "previously executed transform op";
        diag.attachNote(otherHandle.getLoc()) << "invalidated handle";
        diag.attachNote(owner->getLoc())
            << "invalidated by this transform op that consumes its operand #"
            << operandNo
            << " and invalidates handles to the same values as associated with "
               "it";
        diag.attachNote(valueLoc) << "payload value";
      };
    }

    // Invalidate op handles into the IR that defines the value.
    if (auto opResult = llvm::dyn_cast<OpResult>(payloadValue)) {
      Operation *payloadOp = opResult.getOwner();
      recordOpHandleInvalidation(valueHandle, payloadOp, payloadValue,
                                 newlyInvalidated);
    } else {
      // A Value that is not an OpResult is a BlockArgument. `cast` asserts
      // this instead of silently yielding a null argument that is then
      // dereferenced.
      auto arg = llvm::cast<BlockArgument>(payloadValue);
      for (Operation &payloadOp : *arg.getOwner())
        recordOpHandleInvalidation(valueHandle, &payloadOp, payloadValue,
                                   newlyInvalidated);
    }
  }
}
/// Checks that `transform` does not use invalidated handles as operands.
/// On the first such operand, it emits the stored diagnostic and returns
/// failure.
/// For each operand that `transform` consumes (a `Free` effect on
/// `TransformMappingResource` for that operand), it records in
/// `newlyInvalidated` the handles that become invalid: handles to the same
/// payload and to payload nested in it. It does not modify
/// `invalidatedHandles` itself.
/// Invalidations recorded for earlier operands of this same op are reported
/// only when the op does not allow repeated handle operands.
LogicalResult transform::TransformState::checkAndRecordHandleInvalidationImpl(
    transform::TransformOpInterface transform,
    transform::TransformState::InvalidatedHandleMap &newlyInvalidated) const {
  FULL_LDBG("--Start checkAndRecordHandleInvalidation\n");
  // Collect the op's effects on the handle-mapping resource once; they are
  // used below to decide which operands are consumed.
  auto memoryEffectsIface =
      cast<MemoryEffectOpInterface>(transform.getOperation());
  SmallVector<MemoryEffects::EffectInstance> effects;
  memoryEffectsIface.getEffectsOnResource(
      transform::TransformMappingResource::get(), effects);

  for (OpOperand &target : transform->getOpOperands()) {
    DEBUG_WITH_TYPE(DEBUG_TYPE_FULL, {
      (DBGS() << "----iterate on handle: " << target.get() << "\n");
    });
    // If the operand uses an invalidated handle, report it. If the operation
    // allows handles to point to repeated payload operations, only report
    // pre-existing invalidation errors. Otherwise, also report invalidations
    // caused by the current transform operation affecting its other operands.
    auto it = invalidatedHandles.find(target.get());
    auto nit = newlyInvalidated.find(target.get());
    if (it != invalidatedHandles.end()) {
      FULL_LDBG("--End checkAndRecordHandleInvalidation, found already "
                "invalidated -> FAILURE\n");
      // Emit the stored diagnostic at this op's location, then fail.
      return it->getSecond()(transform->getLoc()), failure();
    }
    if (!transform.allowsRepeatedHandleOperands() &&
        nit != newlyInvalidated.end()) {
      FULL_LDBG("--End checkAndRecordHandleInvalidation, found newly "
                "invalidated (by this op) -> FAILURE\n");
      // Emit the stored diagnostic at this op's location, then fail.
      return nit->getSecond()(transform->getLoc()), failure();
    }

    // Invalidate handles pointing to the operations nested in the operation
    // associated with the handle consumed by this operation.
    auto consumesTarget = [&](const MemoryEffects::EffectInstance &effect) {
      return isa<MemoryEffects::Free>(effect.getEffect()) &&
             effect.getValue() == target.get();
    };
    if (llvm::any_of(effects, consumesTarget)) {
      FULL_LDBG("----found consume effect\n");
      if (llvm::isa<transform::TransformHandleTypeInterface>(
              target.get().getType())) {
        FULL_LDBG("----recordOpHandleInvalidation\n");
        SmallVector<Operation *> payloadOps =
            llvm::to_vector(getPayloadOps(target.get()));
        recordOpHandleInvalidation(target, payloadOps, nullptr,
                                   newlyInvalidated);
      } else if (llvm::isa<transform::TransformValueHandleTypeInterface>(
                     target.get().getType())) {
        FULL_LDBG("----recordValueHandleInvalidation\n");
        recordValueHandleInvalidation(target, newlyInvalidated);
      } else {
        // Consumed parameter handles have no payload IR to invalidate.
        FULL_LDBG("----not a TransformHandle -> SKIP AND DROP ON THE FLOOR\n");
      }
    } else {
      FULL_LDBG("----no consume effect -> SKIP\n");
    }
  }

  FULL_LDBG("--End checkAndRecordHandleInvalidation -> SUCCESS\n");
  return success();
}
/// Runs the handle-invalidation check for `transform`. Every invalidation it
/// recorded is then merged into `invalidatedHandles`, whether or not the
/// check succeeded; keys already present keep their existing callback.
/// Returns the result of the check.
LogicalResult transform::TransformState::checkAndRecordHandleInvalidation(
    transform::TransformOpInterface transform) {
  InvalidatedHandleMap invalidatedByThisOp;
  LogicalResult outcome =
      checkAndRecordHandleInvalidationImpl(transform, invalidatedByThisOp);
  auto mergeBegin = std::make_move_iterator(invalidatedByThisOp.begin());
  auto mergeEnd = std::make_move_iterator(invalidatedByThisOp.end());
  invalidatedHandles.insert(mergeBegin, mergeEnd);
  return outcome;
}
/// Checks that `payload` — the payload of the handle that `transform` takes
/// as operand `operandNumber` and consumes — contains no entity more than
/// once. On the first repetition, returns a silenceable error with a note at
/// the repeated op or value; otherwise returns success. `T` is either an
/// operation pointer or a `Value`.
///
/// This helper is local to this file, so it has internal linkage.
template <typename T>
static DiagnosedSilenceableFailure
checkRepeatedConsumptionInOperand(ArrayRef<T> payload,
                                  transform::TransformOpInterface transform,
                                  unsigned operandNumber) {
  DenseSet<T> seen;
  for (T p : payload) {
    if (seen.insert(p).second)
      continue;
    DiagnosedSilenceableFailure diag =
        transform.emitSilenceableError()
        << "a handle passed as operand #" << operandNumber
        << " and consumed by this operation points to a payload "
           "entity more than once";
    if constexpr (std::is_pointer_v<T>)
      diag.attachNote(p->getLoc()) << "repeated target op";
    else
      diag.attachNote(p.getLoc()) << "repeated target value";
    return diag;
  }
  return DiagnosedSilenceableFailure::success();
}
/// Removes the nullptr entries (left behind for erased payload ops) from the
/// payload of every handle queued in `opHandlesToCompact`, then empties the
/// queue. A handle's timestamp is bumped only if it actually contained a
/// nullptr.
void transform::TransformState::compactOpHandles() {
  for (Value handle : opHandlesToCompact) {
    Mappings &mappings = getMapping(handle, /*allowOutOfScope=*/true);
    // Use a single lookup. `operator[]` would also re-create an empty
    // association for a handle that was forgotten after being queued.
    auto it = mappings.direct.find(handle);
    if (it == mappings.direct.end())
      continue;
    SmallVector<Operation *, 2> &association = it->getSecond();
#if LLVM_ENABLE_ABI_BREAKING_CHECKS
    if (llvm::is_contained(association, nullptr))
      // Payload IR is removed from the mapping. This invalidates the respective
      // iterators.
      mappings.incrementTimestamp(handle);
#endif // LLVM_ENABLE_ABI_BREAKING_CHECKS
    llvm::erase(association, nullptr);
  }
  opHandlesToCompact.clear();
}
DiagnosedSilenceableFailure
transform::TransformState::applyTransform(TransformOpInterface transform) {
LLVM_DEBUG({
DBGS() << "applying: ";
transform->print(llvm::dbgs(), OpPrintingFlags().skipRegions());
llvm::dbgs() << "\n";
});
DEBUG_WITH_TYPE(DEBUG_TYPE_FULL,
DBGS() << "Top-level payload before application:\n"
<< *getTopLevel() << "\n");
auto printOnFailureRAII = llvm::make_scope_exit([this] {
(void)this;
LLVM_DEBUG(DBGS() << "Failing Top-level payload:\n"; getTopLevel()->print(
llvm::dbgs(), mlir::OpPrintingFlags().printGenericOpForm()););
});
// Set current transform op.
regionStack.back()->currentTransform = transform;
// Expensive checks to detect invalid transform IR.
if (options.getExpensiveChecksEnabled()) {
FULL_LDBG("ExpensiveChecksEnabled\n");
if (failed(checkAndRecordHandleInvalidation(transform)))
return DiagnosedSilenceableFailure::definiteFailure();
for (OpOperand &operand : transform->getOpOperands()) {
DEBUG_WITH_TYPE(DEBUG_TYPE_FULL, {
(DBGS() << "iterate on handle: " << operand.get() << "\n");
});
if (!isHandleConsumed(operand.get(), transform)) {
FULL_LDBG("--handle not consumed -> SKIP\n");
continue;
}
if (transform.allowsRepeatedHandleOperands()) {
FULL_LDBG("--op allows repeated handles -> SKIP\n");
continue;
}
FULL_LDBG("--handle is consumed\n");
Type operandType = operand.get().getType();
if (llvm::isa<TransformHandleTypeInterface>(operandType)) {
FULL_LDBG("--checkRepeatedConsumptionInOperand for Operation*\n");
DiagnosedSilenceableFailure check =
checkRepeatedConsumptionInOperand<Operation *>(
getPayloadOpsView(operand.get()), transform,
operand.getOperandNumber());
if (!check.succeeded()) {
FULL_LDBG("----FAILED\n");
return check;
}
} else if (llvm::isa<TransformValueHandleTypeInterface>(operandType)) {
FULL_LDBG("--checkRepeatedConsumptionInOperand For Value\n");
DiagnosedSilenceableFailure check =
checkRepeatedConsumptionInOperand<Value>(
getPayloadValuesView(operand.get()), transform,
operand.getOperandNumber());
if (!check.succeeded()) {
FULL_LDBG("----FAILED\n");
return check;
}
} else {
FULL_LDBG("--not a TransformHandle -> SKIP AND DROP ON THE FLOOR\n");
}
}
}
// Find which operands are consumed.
SmallVector<OpOperand *> consumedOperands =
transform.getConsumedHandleOpOperands();
// Remember the results of the payload ops associated with the consumed
// op handles or the ops defining the value handles so we can drop the
// association with them later. This must happen here because the
// transformation may destroy or mutate them so we cannot traverse the payload
// IR after that.
SmallVector<Value> origOpFlatResults;
SmallVector<Operation *> origAssociatedOps;
for (OpOperand *opOperand : consumedOperands) {
Value operand = opOperand->get();
if (llvm::isa<TransformHandleTypeInterface>(operand.getType())) {
for (Operation *payloadOp : getPayloadOps(operand)) {
llvm::append_range(origOpFlatResults, payloadOp->getResults());
}
continue;
}
if (llvm::isa<TransformValueHandleTypeInterface>(operand.getType())) {
for (Value payloadValue : getPayloadValuesView(operand)) {
if (llvm::isa<OpResult>(payloadValue)) {
origAssociatedOps.push_back(payloadValue.getDefiningOp());
continue;
}
llvm::append_range(
origAssociatedOps,
llvm::map_range(*llvm::cast<BlockArgument>(payloadValue).getOwner(),
[](Operation &op) { return &op; }));
}
continue;
}
DiagnosedDefiniteFailure diag =
emitDefiniteFailure(transform->getLoc())
<< "unexpectedly consumed a value that is not a handle as operand #"
<< opOperand->getOperandNumber();
diag.attachNote(operand.getLoc())
<< "value defined here with type " << operand.getType();
return diag;
}
// Prepare rewriter and listener.
TrackingListenerConfig config;
config.skipHandleFn = [&](Value handle) {
// Skip handle if it is dead.
auto scopeIt =
llvm::find_if(llvm::reverse(regionStack), [&](RegionScope *scope) {
return handle.getParentRegion() == scope->region;
});
assert(scopeIt != regionStack.rend() &&
"could not find region scope for handle");
RegionScope *scope = *scopeIt;
return llvm::all_of(handle.getUsers(), [&](Operation *user) {
return user == scope->currentTransform ||
happensBefore(user, scope->currentTransform);
});
};
transform::ErrorCheckingTrackingListener trackingListener(*this, transform,
config);
transform::TransformRewriter rewriter(transform->getContext(),
&trackingListener);
// Compute the result but do not short-circuit the silenceable failure case as
// we still want the handles to propagate properly so the "suppress" mode can
// proceed on a best effort basis.
transform::TransformResults results(transform->getNumResults());
DiagnosedSilenceableFailure result(transform.apply(rewriter, results, *this));
compactOpHandles();
// Error handling: fail if transform or listener failed.
DiagnosedSilenceableFailure trackingFailure =
trackingListener.checkAndResetError();
if (!transform->hasTrait<ReportTrackingListenerFailuresOpTrait>() ||
transform->hasAttr(FindPayloadReplacementOpInterface::
kSilenceTrackingFailuresAttrName)) {
// Only report failures for ReportTrackingListenerFailuresOpTrait ops. Also
// do not report failures if the above mentioned attribute is set.
if (trackingFailure.isSilenceableFailure())
(void)trackingFailure.silence();
trackingFailure = DiagnosedSilenceableFailure::success();
}
if (!trackingFailure.succeeded()) {
if (result.succeeded()) {
result = std::move(trackingFailure);
} else {
// Transform op errors have precedence, report those first.
if (result.isSilenceableFailure())
result.attachNote() << "tracking listener also failed: "
<< trackingFailure.getMessage();
(void)trackingFailure.silence();
}
}
if (result.isDefiniteFailure())
return result;
// If a silenceable failure was produced, some results may be unset, set them
// to empty lists.
if (result.isSilenceableFailure())
results.setRemainingToEmpty(transform);
// Remove the mapping for the operand if it is consumed by the operation. This
// allows us to catch use-after-free with assertions later on.
for (OpOperand *opOperand : consumedOperands) {
Value operand = opOperand->get();
if (llvm::isa<TransformHandleTypeInterface>(operand.getType())) {
forgetMapping(operand, origOpFlatResults);
} else if (llvm::isa<TransformValueHandleTypeInterface>(
operand.getType())) {
forgetValueMapping(operand, origAssociatedOps);
}
}
if (failed(updateStateFromResults(results, transform->getResults())))
return DiagnosedSilenceableFailure::definiteFailure();
printOnFailureRAII.release();
DEBUG_WITH_TYPE(DEBUG_PRINT_AFTER_ALL, {
DBGS() << "Top-level payload:\n";
getTopLevel()->print(llvm::dbgs());
});
return result;
}
LogicalResult transform::TransformState::updateStateFromResults(
    const TransformResults &results, ResultRange opResults) {
  // Copy every result of the transform op into the state. The transform type
  // of each result decides how it is stored: as parameters, as payload value
  // handles, or (the fallback) as payload operation handles.
  for (OpResult opResult : opResults) {
    unsigned position = opResult.getResultNumber();
    Type resultType = opResult.getType();
    LogicalResult status = success();
    if (llvm::isa<TransformParamTypeInterface>(resultType)) {
      assert(results.isParam(position) &&
             "expected parameters for the parameter-typed result");
      status = setParams(opResult, results.getParams(position));
    } else if (llvm::isa<TransformValueHandleTypeInterface>(resultType)) {
      assert(results.isValue(position) &&
             "expected values for value-type-result");
      status = setPayloadValues(opResult, results.getValues(position));
    } else {
      assert(!results.isParam(position) &&
             "expected payload ops for the non-parameter typed result");
      status = setPayloadOps(opResult, results.get(position));
    }
    // Stop at the first result that cannot be recorded.
    if (failed(status))
      return failure();
  }
  return success();
}
//===----------------------------------------------------------------------===//
// TransformState::Extension
//===----------------------------------------------------------------------===//
// Out-of-line defaulted destructor for TransformState extensions.
// NOTE(review): defining it in this file rather than inline in the header is
// presumably meant to anchor the vtable here; confirm.
transform::TransformState::Extension::~Extension() = default;
LogicalResult
transform::TransformState::Extension::replacePayloadOp(Operation *op,
                                                       Operation *replacement) {
  // Delegate to the TransformState this extension is attached to.
  // TODO: we may need to invalidate handles to operations and values nested in
  // the operation being replaced.
  LogicalResult outcome = state.replacePayloadOp(op, replacement);
  return outcome;
}
LogicalResult
transform::TransformState::Extension::replacePayloadValue(Value value,
                                                          Value replacement) {
  // Delegate to the TransformState this extension is attached to.
  LogicalResult outcome = state.replacePayloadValue(value, replacement);
  return outcome;
}
//===----------------------------------------------------------------------===//
// TransformState::RegionScope
//===----------------------------------------------------------------------===//
transform::TransformState::RegionScope::~RegionScope() {
  // Remove handle invalidation notices as handles are going out of scope.
  // The same region may be re-entered leading to incorrect invalidation
  // errors.
  for (Block &block : *region) {
    for (Value handle : block.getArguments())
      state.invalidatedHandles.erase(handle);
    for (Operation &op : block) {
      for (Value handle : op.getResults())
        state.invalidatedHandles.erase(handle);
    }
  }

  // Drop the mappings owned by this region and leave the scope. Nothing
  // else needs to be inspected here: building a list of referenced payload
  // ops at this point only yielded a value that was never used and required
  // indexing `mappings` (which may default-insert) before erasing.
  state.mappings.erase(region);
  state.regionStack.pop_back();
}
//===----------------------------------------------------------------------===//
// TransformResults
//===----------------------------------------------------------------------===//
transform::TransformResults::TransformResults(unsigned numSegments) {
  // Reserve one empty row per result in each storage kind. The setters assert
  // that at most one kind is ever populated for a given result.
  values.appendEmptyRows(numSegments);
  params.appendEmptyRows(numSegments);
  operations.appendEmptyRows(numSegments);
}
void transform::TransformResults::setParams(
    OpResult value, ArrayRef<transform::TransformState::Param> params) {
  const int64_t slot = value.getResultNumber();
  // The slot must exist and must not have been filled with any kind of
  // payload yet: each result is set once, with a single kind.
  assert(slot < static_cast<int64_t>(this->params.size()) &&
         "setting params for a non-existent handle");
  assert(this->params[slot].data() == nullptr && "params already set");
  assert(operations[slot].data() == nullptr &&
         "another kind of results already set");
  assert(values[slot].data() == nullptr &&
         "another kind of results already set");
  this->params.replace(slot, params);
}
void transform::TransformResults::setMappedValues(
    OpResult handle, ArrayRef<MappedValue> values) {
  // One setter per payload kind; dispatchMappedValues picks the one matching
  // the kind of `values`.
  auto onOps = [&](ArrayRef<Operation *> ops) -> LogicalResult {
    set(handle, ops);
    return success();
  };
  auto onParams = [&](ArrayRef<Param> paramValues) -> LogicalResult {
    setParams(handle, paramValues);
    return success();
  };
  auto onValues = [&](ValueRange payloadValues) -> LogicalResult {
    setValues(handle, payloadValues);
    return success();
  };
  DiagnosedSilenceableFailure diag =
      dispatchMappedValues(handle, values, onOps, onParams, onValues);
#ifndef NDEBUG
  if (!diag.succeeded())
    llvm::dbgs() << diag.getStatusString() << "\n";
  assert(diag.succeeded() && "incorrect mapping");
#endif // NDEBUG
  (void)diag.silence();
}
void transform::TransformResults::setRemainingToEmpty(
    transform::TransformOpInterface transform) {
  // Give every result that is still unset an empty payload list.
  for (OpResult opResult : transform->getResults()) {
    if (isSet(opResult.getResultNumber()))
      continue;
    setMappedValues(opResult, {});
  }
}
ArrayRef<Operation *>
transform::TransformResults::get(unsigned resultNumber) const {
  assert(resultNumber < operations.size() &&
         "querying results for a non-existent handle");
  ArrayRef<Operation *> row = operations[resultNumber];
  // A null data pointer means this result holds no operations.
  assert(row.data() != nullptr &&
         "querying unset results (values or params expected?)");
  return row;
}
ArrayRef<transform::TransformState::Param>
transform::TransformResults::getParams(unsigned resultNumber) const {
  assert(resultNumber < params.size() &&
         "querying params for a non-existent handle");
  ArrayRef<transform::TransformState::Param> row = params[resultNumber];
  // A null data pointer means this result holds no parameters.
  assert(row.data() != nullptr &&
         "querying unset params (ops or values expected?)");
  return row;
}
ArrayRef<Value>
transform::TransformResults::getValues(unsigned resultNumber) const {
  assert(resultNumber < values.size() &&
         "querying values for a non-existent handle");
  ArrayRef<Value> row = values[resultNumber];
  // A null data pointer means this result holds no payload values.
  assert(row.data() != nullptr &&
         "querying unset values (ops or params expected?)");
  return row;
}
bool transform::TransformResults::isParam(unsigned resultNumber) const {
assert(resultNumber < params.size() &&
"querying association for a non-existent handle");
return params[resultNumber].data() != nullptr;
}
bool transform::TransformResults::isValue(unsigned resultNumber) const {
assert(resultNumber < values.size() &&
"querying association for a non-existent handle");
return values[resultNumber].data() != nullptr;
}
bool transform::TransformResults::isSet(unsigned resultNumber) const {
  assert(resultNumber < params.size() &&
         "querying association for a non-existent handle");
  // A result is unset only if none of the three storage kinds has a row for
  // it.
  const bool unset = params[resultNumber].data() == nullptr &&
                     operations[resultNumber].data() == nullptr &&
                     values[resultNumber].data() == nullptr;
  return !unset;
}
//===----------------------------------------------------------------------===//
// TrackingListener
//===----------------------------------------------------------------------===//
transform::TrackingListener::TrackingListener(TransformState &state,
                                              TransformOpInterface op,
                                              TrackingListenerConfig config)
    : TransformState::Extension(state), transformOp(op), config(config) {
  // Without a transform op there are no consumed handles to remember.
  if (!op)
    return;
  // Record every handle that the transform op consumes.
  for (OpOperand *consumed : transformOp.getConsumedHandleOpOperands())
    consumedHandles.insert(consumed->get());
}
/// Returns the single operation that defines every non-null value in `values`.
/// Returns nullptr if there is no such operation, which happens when:
///   - `values` contains no non-null entry,
///   - two values are defined by different ops, or
///   - any value is a block argument (it has no defining op).
Operation *transform::TrackingListener::getCommonDefiningOp(ValueRange values) {
  Operation *defOp = nullptr;
  for (Value v : values) {
    // Skip empty values.
    if (!v)
      continue;
    Operation *valueDefOp = v.getDefiningOp();
    // A block argument has no defining op, so no op can be common to all
    // values. Bail out here rather than treating it as "nothing seen yet";
    // otherwise a later op result would be wrongly accepted as the common
    // defining op (and the result would depend on value order).
    if (!valueDefOp)
      return nullptr;
    if (!defOp) {
      defOp = valueDefOp;
      continue;
    }
    if (defOp != valueDefOp)
      return nullptr;
  }
  return defOp;
}
/// Searches for a single operation that can replace `op` in the handle
/// mappings, starting from the replacement values `newValues`.
///
/// On success, stores the found op in `result` and returns success. Otherwise
/// returns a silenceable failure whose notes record every step the search
/// took. `result` is only written on the success paths.
DiagnosedSilenceableFailure transform::TrackingListener::findReplacementOp(
    Operation *&result, Operation *op, ValueRange newValues) const {
  assert(op->getNumResults() == newValues.size() &&
         "invalid number of replacement values");
  // Working set of values under inspection. It is replaced by the operands of
  // ops that the search skips through.
  SmallVector<Value> values(newValues.begin(), newValues.end());
  // Created up front so notes can be accumulated along the way; it is only
  // returned on the failure paths.
  DiagnosedSilenceableFailure diag = emitSilenceableFailure(
      getTransformOp(), "tracking listener failed to find replacement op "
                        "during application of this transform op");

  do {
    // If the replacement values belong to different ops, drop the mapping.
    Operation *defOp = getCommonDefiningOp(values);
    if (!defOp) {
      diag.attachNote() << "replacement values belong to different ops";
      return diag;
    }

    // Skip through ops that implement CastOpInterface (only when enabled in
    // the config): continue the search from the cast's operands.
    if (config.skipCastOps && isa<CastOpInterface>(defOp)) {
      values.clear();
      values.assign(defOp->getOperands().begin(), defOp->getOperands().end());
      diag.attachNote(defOp->getLoc())
          << "using output of 'CastOpInterface' op";
      continue;
    }

    // If the defining op has the same name or we do not care about the name of
    // op replacements at all, we take it as a replacement.
    if (!config.requireMatchingReplacementOpName ||
        op->getName() == defOp->getName()) {
      result = defOp;
      return DiagnosedSilenceableFailure::success();
    }

    // Replacing an op with a constant-like equivalent is a common
    // canonicalization.
    if (defOp->hasTrait<OpTrait::ConstantLike>()) {
      result = defOp;
      return DiagnosedSilenceableFailure::success();
    }

    // Unless the interface below supplies new candidates, the working set
    // stays empty and the loop terminates.
    values.clear();

    // Skip through ops that implement FindPayloadReplacementOpInterface.
    if (auto findReplacementOpInterface =
            dyn_cast<FindPayloadReplacementOpInterface>(defOp)) {
      values.assign(findReplacementOpInterface.getNextOperands());
      diag.attachNote(defOp->getLoc()) << "using operands provided by "
                                          "'FindPayloadReplacementOpInterface'";
      continue;
    }
  } while (!values.empty());

  diag.attachNote() << "ran out of suitable replacement values";
  return diag;
}
void transform::TrackingListener::notifyMatchFailure(
    Location loc, function_ref<void(Diagnostic &)> reasonCallback) {
  // The base tracking listener only prints match failures in debug builds.
  LLVM_DEBUG({
    Diagnostic reason(loc, DiagnosticSeverity::Remark);
    reasonCallback(reason);
    DBGS() << "Match Failure : " << reason.str() << "\n";
  });
}
void transform::TrackingListener::notifyOperationErased(Operation *op) {
  // An erased op has no replacement. Unmap each of its results first, then
  // the op itself.
  for (Value erasedResult : op->getResults())
    (void)replacePayloadValue(erasedResult, Value());
  (void)replacePayloadOp(op, nullptr);
}
/// Updates value and op handles after `op` was replaced by `newValues`.
///
/// - Value handles are always remapped one-to-one.
/// - Op handles are remapped to the op found by findReplacementOp, but only
///   if some handle mapped to `op` is still alive and none was consumed.
/// - If no replacement op can be found, notifyPayloadReplacementNotFound is
///   called and the op is dropped from the mapping.
void transform::TrackingListener::notifyOperationReplaced(
    Operation *op, ValueRange newValues) {
  assert(op->getNumResults() == newValues.size() &&
         "invalid number of replacement values");

  // Replace value handles.
  for (auto [oldValue, newValue] : llvm::zip(op->getResults(), newValues))
    (void)replacePayloadValue(oldValue, newValue);

  // Replace op handle.
  SmallVector<Value> opHandles;
  if (failed(getTransformState().getHandlesForPayloadOp(
          op, opHandles, /*includeOutOfScope=*/true))) {
    // Op is not tracked.
    return;
  }

  // Helper function to check if the current transform op consumes any handle
  // that is mapped to `op`.
  //
  // Note: If a handle was consumed, there shouldn't be any alive users, so it
  // is not really necessary to check for consumed handles. However, in case
  // there are indeed alive handles that were consumed (which is undefined
  // behavior) and a replacement op could not be found, we want to fail with a
  // nicer error message: "op uses a handle invalidated..." instead of "could
  // not find replacement op". This nicer error is produced later.
  auto handleWasConsumed = [&] {
    return llvm::any_of(opHandles,
                        [&](Value h) { return consumedHandles.contains(h); });
  };

  // Check if there are any handles that must be updated.
  Value aliveHandle;
  if (config.skipHandleFn) {
    auto it = llvm::find_if(opHandles,
                            [&](Value v) { return !config.skipHandleFn(v); });
    if (it != opHandles.end())
      aliveHandle = *it;
  } else if (!opHandles.empty()) {
    aliveHandle = opHandles.front();
  }
  if (!aliveHandle || handleWasConsumed()) {
    // The op is tracked but the corresponding handles are dead or were
    // consumed. Drop the op from the mapping.
    (void)replacePayloadOp(op, nullptr);
    return;
  }

  // Initialized explicitly: findReplacementOp only writes this on success,
  // so it must never be left indeterminate.
  Operation *replacement = nullptr;
  DiagnosedSilenceableFailure diag =
      findReplacementOp(replacement, op, newValues);
  // If the op is tracked but no replacement op was found, send a
  // notification.
  if (!diag.succeeded()) {
    diag.attachNote(aliveHandle.getLoc())
        << "replacement is required because this handle must be updated";
    notifyPayloadReplacementNotFound(op, newValues, std::move(diag));
    (void)replacePayloadOp(op, nullptr);
    return;
  }

  (void)replacePayloadOp(op, replacement);
}
transform::ErrorCheckingTrackingListener::~ErrorCheckingTrackingListener() {
  // A recorded error must be consumed via checkAndResetError() before the
  // listener is destroyed, so that tracking failures are never silently lost.
  assert(!failed() && "listener state was not checked");
}
DiagnosedSilenceableFailure
transform::ErrorCheckingTrackingListener::checkAndResetError() {
  // Hand the accumulated status to the caller and return the listener to a
  // clean, successful state with a fresh error counter.
  errorCounter = 0;
  DiagnosedSilenceableFailure previous = std::move(status);
  status = DiagnosedSilenceableFailure::success();
  return previous;
}
bool transform::ErrorCheckingTrackingListener::failed() const {
  // Any status other than success counts as a recorded failure.
  const bool ok = status.succeeded();
  return !ok;
}
void transform::ErrorCheckingTrackingListener::notifyPayloadReplacementNotFound(
    Operation *op, ValueRange values, DiagnosedSilenceableFailure &&diag) {
  // Gather the new diagnostics first, then any already-pending ones, and
  // store the combination as the listener's status.
  SmallVector<Diagnostic> collected;
  diag.takeDiagnostics(collected);
  if (!status.succeeded())
    status.takeDiagnostics(collected);
  status =
      DiagnosedSilenceableFailure::silenceableFailure(std::move(collected));

  // Tag the notes with the running error index so that separate failures can
  // be told apart.
  status.attachNote(op->getLoc()) << "[" << errorCounter << "] replaced op";
  for (unsigned i = 0, e = values.size(); i < e; ++i) {
    status.attachNote(values[i].getLoc())
        << "[" << errorCounter << "] replacement value " << i;
  }
  ++errorCounter;
}
std::string
transform::ErrorCheckingTrackingListener::getLatestMatchFailureMessage() {
  // An empty string means no match failure has been recorded.
  return matchFailure ? matchFailure->str() : std::string();
}
void transform::ErrorCheckingTrackingListener::notifyMatchFailure(
    Location loc, function_ref<void(Diagnostic &)> reasonCallback) {
  // Keep only the most recent match failure, overwriting any earlier one, and
  // let the callback fill in the reason.
  matchFailure.emplace(loc, DiagnosticSeverity::Remark);
  reasonCallback(*matchFailure);
}
//===----------------------------------------------------------------------===//
// TransformRewriter
//===----------------------------------------------------------------------===//
transform::TransformRewriter::TransformRewriter(
    MLIRContext *ctx, ErrorCheckingTrackingListener *listener)
    : RewriterBase(ctx), listener(listener) {
  // The error-checking listener receives all rewrite notifications and is
  // also kept in the `listener` member for direct queries.
  setListener(this->listener);
}
bool transform::TransformRewriter::hasTrackingFailures() const {
  // Tracking failures are recorded by the attached listener.
  const bool listenerFailed = listener->failed();
  return listenerFailed;
}
/// Silence all tracking failures that have been encountered so far.
void transform::TransformRewriter::silenceTrackingFailure() {
if (hasTrackingFailures()) {
DiagnosedSilenceableFailure status = listener->checkAndResetError();
(void)status.silence();
}
}
LogicalResult transform::TransformRewriter::notifyPayloadOperationReplaced(
    Operation *op, Operation *replacement) {
  // Forward to the listener's replacePayloadOp.
  LogicalResult outcome = listener->replacePayloadOp(op, replacement);
  return outcome;
}
//===----------------------------------------------------------------------===//
// Utilities for TransformEachOpTrait.
//===----------------------------------------------------------------------===//
LogicalResult
transform::detail::checkNestedConsumption(Location loc,
                                          ArrayRef<Operation *> targets) {
  // Report the first case where a payload op appears in `targets` before one
  // of its descendants.
  for (size_t i = 0, e = targets.size(); i < e; ++i) {
    Operation *ancestor = targets[i];
    for (size_t j = i + 1; j < e; ++j) {
      Operation *descendant = targets[j];
      if (!ancestor->isAncestor(descendant))
        continue;
      InFlightDiagnostic diag =
          emitError(loc)
          << "transform operation consumes a handle pointing to an ancestor "
             "payload operation before its descendant";
      diag.attachNote()
          << "the ancestor is likely erased or rewritten before the "
             "descendant is accessed, leading to undefined behavior";
      diag.attachNote(ancestor->getLoc()) << "ancestor payload op";
      diag.attachNote(descendant->getLoc()) << "descendant payload op";
      return diag;
    }
  }
  return success();
}
LogicalResult
transform::detail::checkApplyToOne(Operation *transformOp,
                                   Location payloadOpLoc,
                                   const ApplyToEachResultList &partialResult) {
  Location transformOpLoc = transformOp->getLoc();
  StringRef transformOpName = transformOp->getName().getStringRef();
  unsigned expectedNumResults = transformOp->getNumResults();

  // Starts an error at the transform op with a note pointing at the payload
  // op it was applied to. Shared by all the checks below.
  auto emitDiag = [&]() {
    InFlightDiagnostic diag = mlir::emitError(transformOpLoc);
    diag.attachNote(payloadOpLoc) << "when applied to this op";
    return diag;
  };

  // The application must produce exactly one value per op result.
  if (partialResult.size() != expectedNumResults) {
    InFlightDiagnostic diag = emitDiag()
                              << "application of " << transformOpName
                              << " expected to produce " << expectedNumResults
                              << " results (actually produced "
                              << partialResult.size() << ").";
    diag.attachNote(transformOpLoc)
        << "if you need variadic results, consider a generic `apply` "
        << "instead of the specialized `applyToOne`.";
    return failure();
  }

  // Every non-null produced value must have the kind that the type of its
  // corresponding op result calls for.
  for (const auto &[produced, opResult] :
       llvm::zip(partialResult, transformOp->getResults())) {
    if (produced.isNull())
      continue;
    Type resultType = opResult.getType();
    unsigned resultNumber = opResult.getResultNumber();

    if (llvm::isa<TransformHandleTypeInterface>(resultType) &&
        !isa<Operation *>(produced)) {
      return emitDiag() << "application of " << transformOpName
                        << " expected to produce an Operation * for result #"
                        << resultNumber;
    }
    if (llvm::isa<TransformParamTypeInterface>(resultType) &&
        !isa<Attribute>(produced)) {
      return emitDiag() << "application of " << transformOpName
                        << " expected to produce an Attribute for result #"
                        << resultNumber;
    }
    if (llvm::isa<TransformValueHandleTypeInterface>(resultType) &&
        !isa<Value>(produced)) {
      return emitDiag() << "application of " << transformOpName
                        << " expected to produce a Value for result #"
                        << resultNumber;
    }
  }
  return success();
}
/// Returns a copy of `range` with each element cast to `T`. Every element is
/// expected to hold a `T`.
template <typename T>
static SmallVector<T> castVector(ArrayRef<transform::MappedValue> range) {
  SmallVector<T> casted;
  casted.reserve(range.size());
  for (transform::MappedValue value : range)
    casted.push_back(cast<T>(value));
  return casted;
}
void transform::detail::setApplyToOneResults(
    Operation *transformOp, TransformResults &transformResults,
    ArrayRef<ApplyToEachResultList> results) {
  // Transpose the per-payload result lists into one list per op result,
  // skipping any application that left at least one of its results null.
  SmallVector<SmallVector<MappedValue>> columns(transformOp->getNumResults());
  for (const ApplyToEachResultList &row : results) {
    bool hasNullEntry = llvm::any_of(
        row, [](MappedValue value) { return value.isNull(); });
    if (hasNullEntry)
      continue;
    assert(transformOp->getNumResults() == row.size() &&
           "expected as many partial results as op as results");
    for (auto [column, value] : llvm::enumerate(row))
      columns[column].push_back(value);
  }

  // Store each column through the setter matching its result's transform
  // type.
  for (OpResult opResult : transformOp->getResults()) {
    ArrayRef<MappedValue> column = columns[opResult.getResultNumber()];
    Type resultType = opResult.getType();
    if (llvm::isa<TransformParamTypeInterface>(resultType))
      transformResults.setParams(opResult, castVector<Attribute>(column));
    else if (llvm::isa<TransformValueHandleTypeInterface>(resultType))
      transformResults.setValues(opResult, castVector<Value>(column));
    else
      transformResults.set(opResult, castVector<Operation *>(column));
  }
}
//===----------------------------------------------------------------------===//
// Utilities for implementing transform ops with regions.
//===----------------------------------------------------------------------===//
LogicalResult transform::detail::appendValueMappings(
    MutableArrayRef<SmallVector<transform::MappedValue>> mappings,
    ValueRange values, const transform::TransformState &state, bool flatten) {
  assert(mappings.size() == values.size() && "mismatching number of mappings");
  for (auto &&[operand, mapped] : llvm::zip_equal(values, mappings)) {
    const size_t sizeBefore = mapped.size();
    Type operandType = operand.getType();
    if (llvm::isa<TransformHandleTypeInterface>(operandType)) {
      llvm::append_range(mapped, state.getPayloadOps(operand));
    } else if (llvm::isa<TransformValueHandleTypeInterface>(operandType)) {
      llvm::append_range(mapped, state.getPayloadValues(operand));
    } else {
      assert(llvm::isa<TransformParamTypeInterface>(operandType) &&
             "unsupported kind of transform dialect value");
      llvm::append_range(mapped, state.getParams(operand));
    }
    // Unless flattening is allowed, each value must contribute exactly one
    // mapped entry.
    const bool addedExactlyOne = mapped.size() - sizeBefore == 1;
    if (!flatten && !addedExactlyOne)
      return failure();
  }
  return success();
}
void transform::detail::prepareValueMappings(
    SmallVectorImpl<SmallVector<transform::MappedValue>> &mappings,
    ValueRange values, const transform::TransformState &state) {
  // Append one empty slot per value, then fill only the new slots. Existing
  // entries are left untouched.
  const size_t numExisting = mappings.size();
  mappings.resize(numExisting + values.size());
  MutableArrayRef<SmallVector<transform::MappedValue>> allMappings(mappings);
  (void)appendValueMappings(allMappings.drop_front(numExisting), values, state);
}
void transform::detail::forwardTerminatorOperands(
    Block *block, transform::TransformState &state,
    transform::TransformResults &results) {
  // Each result of the parent op receives whatever the matching terminator
  // operand is mapped to.
  Operation *terminator = block->getTerminator();
  Operation *parent = block->getParentOp();
  for (auto &&[yielded, result] :
       llvm::zip(terminator->getOperands(), parent->getOpResults())) {
    Type resultType = result.getType();
    if (llvm::isa<transform::TransformHandleTypeInterface>(resultType)) {
      results.set(result, state.getPayloadOps(yielded));
      continue;
    }
    if (llvm::isa<transform::TransformValueHandleTypeInterface>(resultType)) {
      results.setValues(result, state.getPayloadValues(yielded));
      continue;
    }
    assert(llvm::isa<transform::TransformParamTypeInterface>(resultType) &&
           "unhandled transform type interface");
    results.setParams(result, state.getParams(yielded));
  }
}
/// Test-only factory that constructs a TransformState from `region` and
/// `payloadRoot`.
/// NOTE(review): presumably exists because the TransformState constructor is
/// not publicly accessible; confirm against the header.
transform::TransformState
transform::detail::makeTransformStateForTesting(Region *region,
                                                Operation *payloadRoot) {
  // Returned as a prvalue so that no copy/move constructor is required.
  return TransformState(region, payloadRoot);
}
//===----------------------------------------------------------------------===//
// Utilities for PossibleTopLevelTransformOpTrait.
//===----------------------------------------------------------------------===//
/// Appends to `effects` a copy of every memory effect that `iface` has on
/// `source`. Each copy keeps the effect kind and resource but is attached to
/// `target` instead.
static void
remapEffects(MemoryEffectOpInterface iface, BlockArgument source,
             OpOperand *target,
             SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
  SmallVector<MemoryEffects::EffectInstance> onSource;
  iface.getEffectsOnValue(source, onSource);
  for (const MemoryEffects::EffectInstance &instance : onSource) {
    effects.emplace_back(instance.getEffect(), target,
                         instance.getResource());
  }
}
/// For every op in `block` that implements MemoryEffectOpInterface, appends
/// to `effects`:
///   - its effects on the block arguments, re-attached to the operands in
///     `operands` at the same position, then
///   - its effects on the payload IR resource, unchanged.
static void
remapArgumentEffects(Block &block, MutableArrayRef<OpOperand> operands,
                     SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
  for (Operation &nested : block) {
    if (auto iface = dyn_cast<MemoryEffectOpInterface>(&nested)) {
      for (auto &&[arg, operand] : llvm::zip(block.getArguments(), operands))
        remapEffects(iface, arg, &operand, effects);

      SmallVector<MemoryEffects::EffectInstance> payloadEffects;
      iface.getEffectsOnResource(transform::PayloadIRResource::get(),
                                 payloadEffects);
      llvm::append_range(effects, payloadEffects);
    }
  }
}
/// Collects the memory effects of a possibly top-level transform op.
///
/// The op only reads its operand handles and produces its result handles.
/// - Without a `root` value, every effect of each op in `body` that
///   implements MemoryEffectOpInterface is appended as-is.
/// - Otherwise, effects on the entry block arguments are remapped onto the
///   op's operands (see remapArgumentEffects).
void transform::detail::getPotentialTopLevelEffects(
    Operation *operation, Value root, Block &body,
    SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
  transform::onlyReadsHandle(operation->getOpOperands(), effects);
  transform::producesHandle(operation->getOpResults(), effects);

  if (!root) {
    for (Operation &op : body) {
      auto iface = dyn_cast<MemoryEffectOpInterface>(&op);
      if (!iface)
        continue;
      // Effects are appended directly into `effects`; no intermediate
      // buffer is needed.
      iface.getEffects(effects);
    }
    return;
  }

  // Carry over all effects on arguments of the entry block as those on the
  // operands, this is the same value just remapped.
  remapArgumentEffects(body, operation->getOpOperands(), effects);
}
LogicalResult transform::detail::mapPossibleTopLevelTransformOpBlockArguments(
    TransformState &state, Operation *op, Region &region) {
  Block &entry = region.front();
  SmallVector<Operation *> targets;
  SmallVector<SmallVector<MappedValue>> extraMappings;

  if (op->getNumOperands() == 0) {
    // No operands:
    //   - the first block argument is bound to the state's top-level payload
    //     op;
    //   - the remaining arguments are bound to the top-level mappings, whose
    //     count must match.
    if (state.getNumTopLevelMappings() != entry.getNumArguments() - 1) {
      return emitError(op->getLoc())
             << "operation expects " << entry.getNumArguments() - 1
             << " extra value bindings, but " << state.getNumTopLevelMappings()
             << " were provided to the interpreter";
    }
    targets.push_back(state.getTopLevel());
    for (unsigned i = 0, e = state.getNumTopLevelMappings(); i < e; ++i)
      extraMappings.push_back(llvm::to_vector(state.getTopLevelMapping(i)));
  } else {
    // With operands, arguments are bound to what the operands map to.
    llvm::append_range(targets, state.getPayloadOps(op->getOperand(0)));
    prepareValueMappings(extraMappings, op->getOperands().drop_front(), state);
  }

  if (failed(state.mapBlockArguments(entry.getArgument(0), targets)))
    return failure();
  for (BlockArgument extraArg : entry.getArguments().drop_front()) {
    ArrayRef<MappedValue> mapping = extraMappings[extraArg.getArgNumber() - 1];
    if (failed(state.mapBlockArgument(extraArg, mapping)))
      return failure();
  }
  return success();
}
LogicalResult
transform::detail::verifyPossibleTopLevelTransformOpTrait(Operation *op) {
  // Attaching this trait without the interface is a misuse of the API, but it
  // cannot be caught via a static_assert because interface registration is
  // dynamic.
  assert(isa<TransformOpInterface>(op) &&
         "should implement TransformOpInterface to have "
         "PossibleTopLevelTransformOpTrait");

  // Structural requirements: at least one region, the first of which has a
  // single block with at least one argument.
  if (op->getNumRegions() < 1)
    return op->emitOpError() << "expects at least one region";

  Region &bodyRegion = op->getRegion(0);
  if (!llvm::hasNItems(bodyRegion, 1))
    return op->emitOpError() << "expects a single-block region";

  Block &body = bodyRegion.front();
  if (body.getNumArguments() == 0) {
    return op->emitOpError()
           << "expects the entry block to have at least one argument";
  }

  // The leading argument is an op handle and, when operands exist, has the
  // same type as the first operand.
  BlockArgument rootArg = body.getArgument(0);
  if (!llvm::isa<TransformHandleTypeInterface>(rootArg.getType())) {
    return op->emitOpError()
           << "expects the first entry block argument to be of type "
              "implementing TransformHandleTypeInterface";
  }
  if (op->getNumOperands() != 0 &&
      rootArg.getType() != op->getOperand(0).getType()) {
    return op->emitOpError()
           << "expects the type of the block argument to match "
              "the type of the operand";
  }

  // Trailing arguments may be any kind of transform value.
  for (BlockArgument extraArg : body.getArguments().drop_front()) {
    if (llvm::isa<TransformHandleTypeInterface, TransformParamTypeInterface,
                  TransformValueHandleTypeInterface>(extraArg.getType()))
      continue;

    InFlightDiagnostic diag =
        op->emitOpError()
        << "expects trailing entry block arguments to be of type implementing "
           "TransformHandleTypeInterface, TransformValueHandleTypeInterface or "
           "TransformParamTypeInterface";
    diag.attachNote() << "argument #" << extraArg.getArgNumber()
                      << " does not";
    return diag;
  }

  // When nested in another possible top-level op, every block argument must
  // be fed by an operand.
  Operation *parent =
      op->getParentWithTrait<PossibleTopLevelTransformOpTrait>();
  if (parent && op->getNumOperands() != body.getNumArguments()) {
    InFlightDiagnostic diag =
        op->emitOpError()
        << "expects operands to be provided for a nested op";
    diag.attachNote(parent->getLoc())
        << "nested in another possible top-level op";
    return diag;
  }

  return success();
}
//===----------------------------------------------------------------------===//
// Utilities for ParamProducedTransformOpTrait.
//===----------------------------------------------------------------------===//
void transform::detail::getParamProducerTransformOpTraitEffects(
    Operation *op, SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
  // Results are produced and every operand handle is only read.
  producesHandle(op->getResults(), effects);
  bool readsPayload = false;
  for (OpOperand &operand : op->getOpOperands()) {
    onlyReadsHandle(operand, effects);
    // Op and value handles give access to payload IR, so reading them implies
    // reading the payload.
    Type operandType = operand.get().getType();
    if (llvm::isa<TransformHandleTypeInterface,
                  TransformValueHandleTypeInterface>(operandType))
      readsPayload = true;
  }
  if (readsPayload)
    onlyReadsPayload(effects);
}
/// Verifies an op carrying ParamProducerTransformOpTrait:
///   - the op must implement MemoryEffectOpInterface (a fatal error
///     otherwise, since this is an API misuse);
///   - every result type must implement TransformParamTypeInterface.
LogicalResult
transform::detail::verifyParamProducerTransformOpTrait(Operation *op) {
  // Interfaces can be attached dynamically, so this cannot be a static
  // assert.
  if (!op->getName().getInterface<MemoryEffectOpInterface>()) {
    // Name the interface exactly as it is spelled in the code
    // (MemoryEffectOpInterface) so the message is actionable.
    llvm::report_fatal_error(
        Twine("ParamProducerTransformOpTrait must be attached to an op that "
              "implements MemoryEffectOpInterface, found on ") +
        op->getName().getStringRef());
  }
  for (Value result : op->getResults()) {
    if (llvm::isa<TransformParamTypeInterface>(result.getType()))
      continue;
    return op->emitOpError()
           << "ParamProducerTransformOpTrait attached to this op expects "
              "result types to implement TransformParamTypeInterface";
  }
  return success();
}
//===----------------------------------------------------------------------===//
// Memory effects.
//===----------------------------------------------------------------------===//
void transform::consumesHandle(
    MutableArrayRef<OpOperand> handles,
    SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
  // Consuming a handle means reading it and then freeing it in the
  // transform mapping resource.
  auto *mappingResource = TransformMappingResource::get();
  for (size_t i = 0, e = handles.size(); i < e; ++i) {
    OpOperand *handle = &handles[i];
    effects.emplace_back(MemoryEffects::Read::get(), handle, mappingResource);
    effects.emplace_back(MemoryEffects::Free::get(), handle, mappingResource);
  }
}
/// Returns `true` if `effects` contains at least one instance whose effect is
/// of kind `EffectTy` AND whose resource is of kind `ResourceTy`.
/// Example: `hasEffect<MemoryEffects::Free, TransformMappingResource>(effects)`.
/// `Range` may be any iterable of `MemoryEffects::EffectInstance`, including a
/// lazily filtered view of a larger effect list.
template <typename EffectTy, typename ResourceTy, typename Range>
static bool hasEffect(Range &&effects) {
  return llvm::any_of(effects, [](const MemoryEffects::EffectInstance &effect) {
    return isa<EffectTy>(effect.getEffect()) &&
           isa<ResourceTy>(effect.getResource());
  });
}
/// Returns `true` if `transform` declares both a 'read' and a 'free' effect on
/// the transform mapping resource for `handle`, i.e. it consumes the handle.
bool transform::isHandleConsumed(Value handle,
                                 transform::TransformOpInterface transform) {
  SmallVector<MemoryEffects::EffectInstance> handleEffects;
  cast<MemoryEffectOpInterface>(transform.getOperation())
      .getEffectsOnValue(handle, handleEffects);
  if (!::hasEffect<MemoryEffects::Read, TransformMappingResource>(
          handleEffects))
    return false;
  return ::hasEffect<MemoryEffects::Free, TransformMappingResource>(
      handleEffects);
}
/// Marks every op result in `handles` as a newly produced handle. For each
/// one, an 'allocate' effect followed by a 'write' effect is appended on the
/// transform mapping resource.
void transform::producesHandle(
    ResultRange handles,
    SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
  auto *mapping = TransformMappingResource::get();
  for (OpResult result : handles) {
    effects.emplace_back(MemoryEffects::Allocate::get(), result, mapping);
    effects.emplace_back(MemoryEffects::Write::get(), result, mapping);
  }
}
/// Marks every block argument in `handles` as a newly produced handle. For
/// each one, an 'allocate' effect followed by a 'write' effect is appended on
/// the transform mapping resource.
void transform::producesHandle(
    MutableArrayRef<BlockArgument> handles,
    SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
  auto *mapping = TransformMappingResource::get();
  for (BlockArgument arg : handles) {
    effects.emplace_back(MemoryEffects::Allocate::get(), arg, mapping);
    effects.emplace_back(MemoryEffects::Write::get(), arg, mapping);
  }
}
/// Appends a 'read' effect on the transform mapping resource for every operand
/// in `handles`. Unlike consumesHandle, no 'free' effect is added.
void transform::onlyReadsHandle(
    MutableArrayRef<OpOperand> handles,
    SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
  auto *mapping = TransformMappingResource::get();
  for (OpOperand &operand : handles)
    effects.emplace_back(MemoryEffects::Read::get(), &operand, mapping);
}
/// Appends a 'read' and then a 'write' effect on the payload IR resource.
void transform::modifiesPayload(
    SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
  auto *payload = PayloadIRResource::get();
  effects.emplace_back(MemoryEffects::Read::get(), payload);
  effects.emplace_back(MemoryEffects::Write::get(), payload);
}
/// Appends a single 'read' effect on the payload IR resource (no 'write').
void transform::onlyReadsPayload(
    SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
  auto *payload = PayloadIRResource::get();
  effects.emplace_back(MemoryEffects::Read::get(), payload);
}
/// Returns `true` if `transform` declares a 'write' effect on the payload IR.
bool transform::doesModifyPayload(transform::TransformOpInterface transform) {
  SmallVector<MemoryEffects::EffectInstance> allEffects;
  cast<MemoryEffectOpInterface>(transform.getOperation())
      .getEffects(allEffects);
  return ::hasEffect<MemoryEffects::Write, PayloadIRResource>(allEffects);
}
/// Returns `true` if `transform` declares a 'read' effect on the payload IR.
bool transform::doesReadPayload(transform::TransformOpInterface transform) {
  SmallVector<MemoryEffects::EffectInstance> allEffects;
  cast<MemoryEffectOpInterface>(transform.getOperation())
      .getEffects(allEffects);
  return ::hasEffect<MemoryEffects::Read, PayloadIRResource>(allEffects);
}
/// Inserts into `consumedArguments` the argument numbers of arguments of
/// `block` that some op directly inside `block` frees on the transform mapping
/// resource. Ops that do not implement MemoryEffectOpInterface are skipped.
void transform::getConsumedBlockArguments(
    Block &block, llvm::SmallDenseSet<unsigned int> &consumedArguments) {
  // Reused across ops to avoid reallocating the effect buffer.
  SmallVector<MemoryEffects::EffectInstance> opEffects;
  for (Operation &op : block) {
    auto effectIface = dyn_cast<MemoryEffectOpInterface>(op);
    if (!effectIface)
      continue;

    opEffects.clear();
    effectIface.getEffects(opEffects);
    for (const MemoryEffects::EffectInstance &instance : opEffects) {
      // Only a 'free' on the mapping resource counts as consumption.
      if (!isa<MemoryEffects::Free>(instance.getEffect()) ||
          instance.getResource() != transform::TransformMappingResource::get())
        continue;

      auto blockArg = dyn_cast_or_null<BlockArgument>(instance.getValue());
      if (blockArg && blockArg.getOwner() == &block)
        consumedArguments.insert(blockArg.getArgNumber());
    }
  }
}
//===----------------------------------------------------------------------===//
// Utilities for TransformOpInterface.
//===----------------------------------------------------------------------===//
/// Returns, in operand order, the operands of `transformOp` that carry a
/// 'free' effect on the transform mapping resource (i.e. consumed handles).
SmallVector<OpOperand *> transform::detail::getConsumedHandleOpOperands(
    TransformOpInterface transformOp) {
  auto effectIface = cast<MemoryEffectOpInterface>(transformOp.getOperation());
  auto isMappingFree = [](const MemoryEffects::EffectInstance &instance) {
    return isa<transform::TransformMappingResource>(instance.getResource()) &&
           isa<MemoryEffects::Free>(instance.getEffect());
  };

  SmallVector<OpOperand *> result;
  result.reserve(transformOp->getNumOperands());
  SmallVector<MemoryEffects::EffectInstance, 2> operandEffects;
  for (OpOperand &operand : transformOp->getOpOperands()) {
    operandEffects.clear();
    effectIface.getEffectsOnValue(operand.get(), operandEffects);
    if (llvm::any_of(operandEffects, isMappingFree))
      result.push_back(&operand);
  }
  return result;
}
/// Verifies the memory-effect contract required by TransformOpInterface:
///   - every operand has at least one effect declared on it;
///   - no operand has an 'allocate' effect on the transform mapping resource;
///   - if any operand is consumed ('free' on the mapping resource), the op
///     also declares a 'write' effect on the payload IR;
///   - every result has an 'allocate' effect on the mapping resource.
/// The first violation found is reported as an error with a note and returned.
LogicalResult transform::detail::verifyTransformOpInterface(Operation *op) {
  SmallVector<MemoryEffects::EffectInstance> effects;
  cast<MemoryEffectOpInterface>(op).getEffects(effects);

  // Lazily filtered view over the effects attached to `value`.
  auto effectsOn = [&effects](Value value) {
    return llvm::make_filter_range(
        effects, [value](const MemoryEffects::EffectInstance &instance) {
          return instance.getValue() == value;
        });
  };

  std::optional<unsigned> consumedOperandIdx;
  for (OpOperand &operand : op->getOpOperands()) {
    unsigned idx = operand.getOperandNumber();
    auto operandEffects = effectsOn(operand.get());

    if (operandEffects.empty()) {
      InFlightDiagnostic diag =
          op->emitError() << "TransformOpInterface requires memory effects "
                             "on operands to be specified";
      diag.attachNote() << "no effects specified for operand #" << idx;
      return diag;
    }

    if (::hasEffect<MemoryEffects::Allocate, TransformMappingResource>(
            operandEffects)) {
      InFlightDiagnostic diag = op->emitError()
                                << "TransformOpInterface did not expect "
                                   "'allocate' memory effect on an operand";
      diag.attachNote() << "specified for operand #" << idx;
      return diag;
    }

    // Only the first consumed operand is remembered for the note below.
    if (consumedOperandIdx.has_value())
      continue;
    if (::hasEffect<MemoryEffects::Free, TransformMappingResource>(
            operandEffects))
      consumedOperandIdx = idx;
  }

  if (consumedOperandIdx &&
      !::hasEffect<MemoryEffects::Write, PayloadIRResource>(effects)) {
    InFlightDiagnostic diag =
        op->emitError()
        << "TransformOpInterface expects ops consuming operands to have a "
           "'write' effect on the payload resource";
    diag.attachNote() << "consumes operand #" << *consumedOperandIdx;
    return diag;
  }

  for (OpResult result : op->getResults()) {
    if (::hasEffect<MemoryEffects::Allocate, TransformMappingResource>(
            effectsOn(result)))
      continue;
    InFlightDiagnostic diag =
        op->emitError() << "TransformOpInterface requires 'allocate' memory "
                           "effect to be specified for results";
    diag.attachNote() << "no 'allocate' effect specified for result #"
                      << result.getResultNumber();
    return diag;
  }

  return success();
}
//===----------------------------------------------------------------------===//
// Entry point.
//===----------------------------------------------------------------------===//
/// Applies `transform` to `payloadRoot` and reports any failure.
///
/// If `enforceToplevelTransformOp` is set, `transform` must have
/// PossibleTopLevelTransformOpTrait and no operands. Otherwise it is checked
/// with verifyPossibleTopLevelTransformOpTrait. A TransformState is built over
/// the parent region of `transform`. The optional `stateInitializer` runs on it
/// before application. The optional `stateExporter` runs afterwards, and its
/// result becomes the return value.
LogicalResult transform::applyTransforms(
    Operation *payloadRoot, TransformOpInterface transform,
    const RaggedArray<MappedValue> &extraMapping,
    const TransformOptions &options, bool enforceToplevelTransformOp,
    function_ref<void(TransformState &)> stateInitializer,
    function_ref<LogicalResult(TransformState &)> stateExporter) {
  if (!enforceToplevelTransformOp) {
    if (failed(detail::verifyPossibleTopLevelTransformOpTrait(transform)))
      return failure();
  } else {
    bool startsAtTopLevel =
        transform->hasTrait<PossibleTopLevelTransformOpTrait>() &&
        transform->getNumOperands() == 0;
    if (!startsAtTopLevel)
      return transform->emitError()
             << "expected transform to start at the top-level transform op";
  }

  TransformState state(transform->getParentRegion(), payloadRoot, extraMapping,
                       options);
  if (stateInitializer)
    stateInitializer(state);

  if (state.applyTransform(transform).checkAndReport().failed())
    return failure();

  return stateExporter ? stateExporter(state) : success();
}
//===----------------------------------------------------------------------===//
// Generated interface implementation.
//===----------------------------------------------------------------------===//
#include "mlir/Dialect/Transform/Interfaces/TransformInterfaces.cpp.inc"
#include "mlir/Dialect/Transform/Interfaces/TransformTypeInterfaces.cpp.inc"