Note that PointerUnion::{is,get} have been soft deprecated in
PointerUnion.h:
// FIXME: Replace the uses of is(), get() and dyn_cast() with
// isa<T>, cast<T> and the llvm::dyn_cast<T>
I'm not touching PointerUnion::dyn_cast for now because it's a bit
complicated; we could blindly migrate it to dyn_cast_if_present, but
we should probably use dyn_cast when the operand is known to be
non-null.
2031 lines
79 KiB
C++
2031 lines
79 KiB
C++
//===- TransformInterfaces.cpp - Transform Dialect Interfaces -------------===//
|
|
//
|
|
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
|
// See https://llvm.org/LICENSE.txt for license information.
|
|
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
|
//
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
#include "mlir/Dialect/Transform/Interfaces/TransformInterfaces.h"
|
|
|
|
#include "mlir/IR/Diagnostics.h"
|
|
#include "mlir/IR/Operation.h"
|
|
#include "mlir/IR/PatternMatch.h"
|
|
#include "mlir/Interfaces/CastInterfaces.h"
|
|
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
|
|
#include "llvm/ADT/STLExtras.h"
|
|
#include "llvm/ADT/ScopeExit.h"
|
|
#include "llvm/Support/Debug.h"
|
|
#include "llvm/Support/ErrorHandling.h"
|
|
|
|
#define DEBUG_TYPE "transform-dialect"
|
|
#define DEBUG_TYPE_FULL "transform-dialect-full"
|
|
#define DEBUG_PRINT_AFTER_ALL "transform-dialect-print-top-level-after-all"
|
|
#define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE "] ")
|
|
#define LDBG(X) LLVM_DEBUG(DBGS() << (X))
|
|
#define FULL_LDBG(X) DEBUG_WITH_TYPE(DEBUG_TYPE_FULL, (DBGS() << (X)))
|
|
|
|
using namespace mlir;
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
// Helper functions
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
/// Return true if `a` happens before `b`, i.e., `a` or one of its ancestors
|
|
/// properly dominates `b` and `b` is not inside `a`.
|
|
static bool happensBefore(Operation *a, Operation *b) {
|
|
do {
|
|
if (a->isProperAncestor(b))
|
|
return false;
|
|
if (Operation *bAncestor = a->getBlock()->findAncestorOpInBlock(*b)) {
|
|
return a->isBeforeInBlock(bAncestor);
|
|
}
|
|
} while ((a = a->getParentOp()));
|
|
return false;
|
|
}
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
// TransformState
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
// Out-of-class definition of the static constexpr member `kTopLevelValue`
// declared in TransformState.
constexpr const Value transform::TransformState::kTopLevelValue;
|
|
|
|
/// Constructs the state for the payload rooted at `payloadRoot`. Every entry
/// of `extraMappings` is copied into the list of top-level mapped values. If
/// `region` is non-null, a RegionScope for it is created and owned by
/// `topLevelRegionScope`.
transform::TransformState::TransformState(
    Region *region, Operation *payloadRoot,
    const RaggedArray<MappedValue> &extraMappings,
    const TransformOptions &options)
    : topLevel(payloadRoot), options(options) {
  topLevelMappedValues.reserve(extraMappings.size());
  for (ArrayRef<MappedValue> extraMapping : extraMappings)
    topLevelMappedValues.push_back(extraMapping);

  // Without a region there is no scope to set up.
  if (!region)
    return;
  topLevelRegionScope.reset(new RegionScope(*this, *region));
}
|
|
|
|
/// Returns the payload root operation this state was constructed with.
Operation *transform::TransformState::getTopLevel() const {
  return topLevel;
}
|
|
|
|
/// Returns a view of the payload operations stored for the handle `value` in
/// the direct mapping of its scope. Asserts that such an entry exists.
ArrayRef<Operation *>
transform::TransformState::getPayloadOpsView(Value value) const {
  const TransformOpMapping &directMapping = getMapping(value).direct;
  auto entry = directMapping.find(value);
  assert(entry != directMapping.end() &&
         "cannot find mapping for payload handle (param/value handle "
         "provided?)");
  return entry->getSecond();
}
|
|
|
|
/// Returns a view of the parameters (attributes) stored for the handle
/// `value` in the parameter mapping of its scope. Asserts that such an entry
/// exists.
ArrayRef<Attribute> transform::TransformState::getParams(Value value) const {
  const ParamMapping &paramMapping = getMapping(value).params;
  auto entry = paramMapping.find(value);
  assert(entry != paramMapping.end() &&
         "cannot find mapping for param handle "
         "(operation/value handle provided?)");
  return entry->getSecond();
}
|
|
|
|
/// Returns a view of the payload values stored for the value handle
/// `handleValue` in the value mapping of its scope. Asserts that such an
/// entry exists.
ArrayRef<Value>
transform::TransformState::getPayloadValuesView(Value handleValue) const {
  const ValueMapping &valueMapping = getMapping(handleValue).values;
  auto entry = valueMapping.find(handleValue);
  assert(entry != valueMapping.end() &&
         "cannot find mapping for value handle "
         "(param/operation handle provided?)");
  return entry->getSecond();
}
|
|
|
|
/// Appends to `handles` every handle that the reverse op mappings associate
/// with the payload operation `op`, walking the scopes from the last one to
/// the first. Unless `includeOutOfScope` is set, the walk stops after the
/// first scope whose region belongs to an isolated-from-above op. Returns
/// success if at least one handle was found.
LogicalResult transform::TransformState::getHandlesForPayloadOp(
    Operation *op, SmallVectorImpl<Value> &handles,
    bool includeOutOfScope) const {
  bool anyHandle = false;
  for (const auto &[scopeRegion, scopeMappings] : llvm::reverse(mappings)) {
    const auto &reverseMap = scopeMappings->reverse;
    if (auto entry = reverseMap.find(op); entry != reverseMap.end()) {
      llvm::append_range(handles, entry->getSecond());
      anyHandle = true;
    }

    // Out-of-scope lookups never stop early.
    if (includeOutOfScope)
      continue;
    // Stop looking when reaching a region that is isolated from above.
    if (scopeRegion->getParentOp()->hasTrait<OpTrait::IsIsolatedFromAbove>())
      break;
  }

  return success(anyHandle);
}
|
|
|
|
/// Appends to `handles` every handle that the reverse value mappings
/// associate with `payloadValue`, walking the scopes from the last one to the
/// first. Unless `includeOutOfScope` is set, the walk stops after the first
/// scope whose region belongs to an isolated-from-above op. Returns success
/// if at least one handle was found.
LogicalResult transform::TransformState::getHandlesForPayloadValue(
    Value payloadValue, SmallVectorImpl<Value> &handles,
    bool includeOutOfScope) const {
  bool anyHandle = false;
  for (const auto &[scopeRegion, scopeMappings] : llvm::reverse(mappings)) {
    const auto &reverseMap = scopeMappings->reverseValues;
    if (auto entry = reverseMap.find(payloadValue);
        entry != reverseMap.end()) {
      llvm::append_range(handles, entry->getSecond());
      anyHandle = true;
    }

    // Out-of-scope lookups never stop early.
    if (includeOutOfScope)
      continue;
    // Stop looking when reaching a region that is isolated from above.
    if (scopeRegion->getParentOp()->hasTrait<OpTrait::IsIsolatedFromAbove>())
      break;
  }

  return success(anyHandle);
}
|
|
|
|
/// Given a list of MappedValues, cast them to the value kind implied by the
|
|
/// interface of the handle type, and dispatch to one of the callbacks.
|
|
static DiagnosedSilenceableFailure dispatchMappedValues(
|
|
Value handle, ArrayRef<transform::MappedValue> values,
|
|
function_ref<LogicalResult(ArrayRef<Operation *>)> operationsFn,
|
|
function_ref<LogicalResult(ArrayRef<transform::Param>)> paramsFn,
|
|
function_ref<LogicalResult(ValueRange)> valuesFn) {
|
|
if (llvm::isa<transform::TransformHandleTypeInterface>(handle.getType())) {
|
|
SmallVector<Operation *> operations;
|
|
operations.reserve(values.size());
|
|
for (transform::MappedValue value : values) {
|
|
if (auto *op = llvm::dyn_cast_if_present<Operation *>(value)) {
|
|
operations.push_back(op);
|
|
continue;
|
|
}
|
|
return emitSilenceableFailure(handle.getLoc())
|
|
<< "wrong kind of value provided for top-level operation handle";
|
|
}
|
|
if (failed(operationsFn(operations)))
|
|
return DiagnosedSilenceableFailure::definiteFailure();
|
|
return DiagnosedSilenceableFailure::success();
|
|
}
|
|
|
|
if (llvm::isa<transform::TransformValueHandleTypeInterface>(
|
|
handle.getType())) {
|
|
SmallVector<Value> payloadValues;
|
|
payloadValues.reserve(values.size());
|
|
for (transform::MappedValue value : values) {
|
|
if (auto v = llvm::dyn_cast_if_present<Value>(value)) {
|
|
payloadValues.push_back(v);
|
|
continue;
|
|
}
|
|
return emitSilenceableFailure(handle.getLoc())
|
|
<< "wrong kind of value provided for the top-level value handle";
|
|
}
|
|
if (failed(valuesFn(payloadValues)))
|
|
return DiagnosedSilenceableFailure::definiteFailure();
|
|
return DiagnosedSilenceableFailure::success();
|
|
}
|
|
|
|
assert(llvm::isa<transform::TransformParamTypeInterface>(handle.getType()) &&
|
|
"unsupported kind of block argument");
|
|
SmallVector<transform::Param> parameters;
|
|
parameters.reserve(values.size());
|
|
for (transform::MappedValue value : values) {
|
|
if (auto attr = llvm::dyn_cast_if_present<Attribute>(value)) {
|
|
parameters.push_back(attr);
|
|
continue;
|
|
}
|
|
return emitSilenceableFailure(handle.getLoc())
|
|
<< "wrong kind of value provided for top-level parameter";
|
|
}
|
|
if (failed(paramsFn(parameters)))
|
|
return DiagnosedSilenceableFailure::definiteFailure();
|
|
return DiagnosedSilenceableFailure::success();
|
|
}
|
|
|
|
/// Associates the block argument `argument` with `values` by dispatching on
/// the kind of the argument type to `setPayloadOps`, `setParams` or
/// `setPayloadValues`. Any diagnostic produced along the way is reported.
LogicalResult
transform::TransformState::mapBlockArgument(BlockArgument argument,
                                            ArrayRef<MappedValue> values) {
  auto mapOperations = [&](ArrayRef<Operation *> operations) {
    return setPayloadOps(argument, operations);
  };
  auto mapParams = [&](ArrayRef<Param> params) {
    return setParams(argument, params);
  };
  auto mapValues = [&](ValueRange payloadValues) {
    return setPayloadValues(argument, payloadValues);
  };

  DiagnosedSilenceableFailure status = dispatchMappedValues(
      argument, values, mapOperations, mapParams, mapValues);
  return status.checkAndReport();
}
|
|
|
|
/// Maps each block argument in `arguments` to the list of values at the same
/// position in `mapping` (both ranges must have equal length). Stops and
/// returns failure at the first argument that cannot be mapped.
LogicalResult transform::TransformState::mapBlockArguments(
    Block::BlockArgListType arguments,
    ArrayRef<SmallVector<MappedValue>> mapping) {
  for (auto &&[blockArg, mappedValues] : llvm::zip_equal(arguments, mapping)) {
    LogicalResult mapped = mapBlockArgument(blockArg, mappedValues);
    if (failed(mapped))
      return failure();
  }
  return success();
}
|
|
|
|
/// Associates the operation handle `value` with the payload operations in
/// `targets`. Emits an error and fails if any target is null or if the
/// payload check of the handle type rejects the list. On success both the
/// direct (handle -> ops) and the reverse (op -> handles) mappings of the
/// handle's scope are updated.
LogicalResult
transform::TransformState::setPayloadOps(Value value,
                                         ArrayRef<Operation *> targets) {
  assert(value != kTopLevelValue &&
         "attempting to reset the transformation root");
  assert(llvm::isa<TransformHandleTypeInterface>(value.getType()) &&
         "wrong handle type");

  // Null payload operations are never accepted.
  if (llvm::any_of(targets, [](Operation *target) { return !target; }))
    return emitError(value.getLoc())
           << "attempting to assign a null payload op to this transform value";

  // Let the handle type verify the payload.
  auto handleType = llvm::cast<TransformHandleTypeInterface>(value.getType());
  DiagnosedSilenceableFailure payloadCheck =
      handleType.checkPayload(value.getLoc(), targets);
  if (failed(payloadCheck.checkAndReport()))
    return failure();

  // Setting new payload for the value without cleaning it first is a misuse of
  // the API, assert here.
  Mappings &scopeMappings = getMapping(value);
  SmallVector<Operation *> ownedTargets(targets);
  bool inserted =
      scopeMappings.direct.insert({value, std::move(ownedTargets)}).second;
  assert(inserted && "value is already associated with another list");
  (void)inserted;

  // Keep the reverse mapping in sync, one entry per listed operation.
  for (Operation *target : targets)
    scopeMappings.reverse[target].push_back(value);

  return success();
}
|
|
|
|
/// Associates the value handle `handle` with the payload values in
/// `payloadValues`. Emits an error and fails if any payload value is null or
/// if the payload check of the handle type rejects the list. On success both
/// the direct (handle -> values) and the reverse (value -> handles) mappings
/// of the handle's scope are updated.
LogicalResult
transform::TransformState::setPayloadValues(Value handle,
                                            ValueRange payloadValues) {
  assert(handle != nullptr && "attempting to set params for a null value");
  assert(llvm::isa<TransformValueHandleTypeInterface>(handle.getType()) &&
         "wrong handle type");

  // Null payload values are never accepted.
  if (llvm::any_of(payloadValues, [](Value payload) { return !payload; }))
    return emitError(handle.getLoc()) << "attempting to assign a null payload "
                                         "value to this transform handle";

  // Let the handle type verify the payload.
  auto handleType =
      llvm::cast<TransformValueHandleTypeInterface>(handle.getType());
  SmallVector<Value> ownedValues = llvm::to_vector(payloadValues);
  DiagnosedSilenceableFailure payloadCheck =
      handleType.checkPayload(handle.getLoc(), ownedValues);
  if (failed(payloadCheck.checkAndReport()))
    return failure();

  Mappings &scopeMappings = getMapping(handle);
  bool inserted =
      scopeMappings.values.insert({handle, std::move(ownedValues)}).second;
  assert(
      inserted &&
      "value handle is already associated with another list of payload values");
  (void)inserted;

  // Keep the reverse mapping in sync, one entry per listed value.
  for (Value payload : payloadValues)
    scopeMappings.reverseValues[payload].push_back(handle);

  return success();
}
|
|
|
|
/// Associates the parameter handle `value` with the attributes in `params`.
/// Emits an error and fails if any attribute is null or if the payload check
/// of the parameter type rejects the list. On success the list is stored in
/// the parameter mapping of the handle's scope.
LogicalResult transform::TransformState::setParams(Value value,
                                                   ArrayRef<Param> params) {
  assert(value != nullptr && "attempting to set params for a null value");

  for (Attribute attr : params) {
    if (attr)
      continue;
    return emitError(value.getLoc())
           << "attempting to assign a null parameter to this transform value";
  }

  auto valueType = llvm::dyn_cast<TransformParamTypeInterface>(value.getType());
  // Check the result of the cast, not `value` (already asserted non-null
  // above): a non-parameter type yields a null interface that must not be
  // used below.
  assert(valueType &&
         "cannot associate parameter with a value of non-parameter type");
  DiagnosedSilenceableFailure result =
      valueType.checkPayload(value.getLoc(), params);
  if (failed(result.checkAndReport()))
    return failure();

  // Associating a handle with a second list is a misuse of the API.
  Mappings &mappings = getMapping(value);
  bool inserted =
      mappings.params.insert({value, llvm::to_vector(params)}).second;
  assert(inserted && "value is already associated with another list of params");
  (void)inserted;
  return success();
}
|
|
|
|
template <typename Mapping, typename Key, typename Mapped>
|
|
void dropMappingEntry(Mapping &mapping, Key key, Mapped mapped) {
|
|
auto it = mapping.find(key);
|
|
if (it == mapping.end())
|
|
return;
|
|
|
|
llvm::erase(it->getSecond(), mapped);
|
|
if (it->getSecond().empty())
|
|
mapping.erase(it);
|
|
}
|
|
|
|
/// Drops the association between the operation handle `opHandle` and its
/// payload operations, in both the direct and the reverse mapping of the
/// handle's scope (looked up with `allowOutOfScope`). Additionally, for every
/// value in `origOpFlatResults`, removes that value from all value handles
/// found for it, again in both directions.
void transform::TransformState::forgetMapping(Value opHandle,
                                              ValueRange origOpFlatResults,
                                              bool allowOutOfScope) {
  Mappings &mappings = getMapping(opHandle, allowOutOfScope);
  for (Operation *op : mappings.direct[opHandle])
    dropMappingEntry(mappings.reverse, op, opHandle);
  mappings.direct.erase(opHandle);
#if LLVM_ENABLE_ABI_BREAKING_CHECKS
  // Payload IR is removed from the mapping. This invalidates the respective
  // iterators.
  mappings.incrementTimestamp(opHandle);
#endif // LLVM_ENABLE_ABI_BREAKING_CHECKS

  for (Value opResult : origOpFlatResults) {
    SmallVector<Value> resultHandles;
    (void)getHandlesForPayloadValue(opResult, resultHandles);
    for (Value resultHandle : resultHandles) {
      // The result handle may live in a different scope than `opHandle`, so
      // all bookkeeping for it goes through its own mappings.
      Mappings &localMappings = getMapping(resultHandle);
      dropMappingEntry(localMappings.values, resultHandle, opResult);
#if LLVM_ENABLE_ABI_BREAKING_CHECKS
      // Payload IR is removed from the mapping. This invalidates the respective
      // iterators. The timestamp is bumped in the mappings that actually hold
      // `resultHandle`, consistently with forgetValueMapping.
      localMappings.incrementTimestamp(resultHandle);
#endif // LLVM_ENABLE_ABI_BREAKING_CHECKS
      dropMappingEntry(localMappings.reverseValues, opResult, resultHandle);
    }
  }
}
|
|
|
|
/// Drops the association between the value handle `valueHandle` and its
/// payload values, in both the direct and the reverse value mapping of the
/// handle's scope. Additionally, for every operation in `payloadOperations`,
/// removes that operation from all operation handles found for it, again in
/// both directions.
void transform::TransformState::forgetValueMapping(
    Value valueHandle, ArrayRef<Operation *> payloadOperations) {
  Mappings &mappings = getMapping(valueHandle);
  // The payload values of the handle are stored in `values` (handle ->
  // payload values); `reverseValues` is keyed by payload value and must not be
  // indexed with the handle. Walk the direct entry to drop each reverse link,
  // then erase the direct entry itself.
  auto directIt = mappings.values.find(valueHandle);
  if (directIt != mappings.values.end()) {
    for (Value payloadValue : directIt->getSecond())
      dropMappingEntry(mappings.reverseValues, payloadValue, valueHandle);
    mappings.values.erase(directIt);
  }
#if LLVM_ENABLE_ABI_BREAKING_CHECKS
  // Payload IR is removed from the mapping. This invalidates the respective
  // iterators.
  mappings.incrementTimestamp(valueHandle);
#endif // LLVM_ENABLE_ABI_BREAKING_CHECKS

  for (Operation *payloadOp : payloadOperations) {
    SmallVector<Value> opHandles;
    (void)getHandlesForPayloadOp(payloadOp, opHandles);
    for (Value opHandle : opHandles) {
      // The op handle may live in a different scope than `valueHandle`.
      Mappings &localMappings = getMapping(opHandle);
      dropMappingEntry(localMappings.direct, opHandle, payloadOp);
      dropMappingEntry(localMappings.reverse, payloadOp, opHandle);

#if LLVM_ENABLE_ABI_BREAKING_CHECKS
      // Payload IR is removed from the mapping. This invalidates the respective
      // iterators.
      localMappings.incrementTimestamp(opHandle);
#endif // LLVM_ENABLE_ABI_BREAKING_CHECKS
    }
  }
}
|
|
|
|
/// Makes every handle that points to the payload operation `op` point to
/// `replacement` instead. A null `replacement` is stored in place and the
/// handle is queued in `opHandlesToCompact`. Fails if no handle is
/// associated with `op`.
LogicalResult
transform::TransformState::replacePayloadOp(Operation *op,
                                            Operation *replacement) {
  // TODO: consider invalidating the handles to nested objects here.

#ifndef NDEBUG
  // Results of the replaced op are expected to have no handles left.
  for (Value oldResult : op->getResults()) {
    SmallVector<Value> resultHandles;
    (void)getHandlesForPayloadValue(oldResult, resultHandles,
                                    /*includeOutOfScope=*/true);
    assert(resultHandles.empty() && "expected no mapping to old results");
  }
#endif // NDEBUG

  // Collect all handles pointing to the op; having none is a failure.
  SmallVector<Value> opHandles;
  LogicalResult lookup =
      getHandlesForPayloadOp(op, opHandles, /*includeOutOfScope=*/true);
  if (failed(lookup))
    return failure();

  // First pass: remove the reverse links from the op to its handles. This is
  // deliberately completed for all handles before anything is re-added.
  for (Value handle : opHandles) {
    Mappings &scopeMappings = getMapping(handle, /*allowOutOfScope=*/true);
    dropMappingEntry(scopeMappings.reverse, op, handle);
  }

  // Second pass: replace the pointed-to object of all handles with the
  // replacement object. In case a payload op was erased (replacement object is
  // nullptr), a nullptr is stored in the mapping. These nullptrs are removed
  // after each transform. Furthermore, nullptrs are not enumerated by payload
  // op iterators. The relative order of ops is preserved.
  //
  // Removing an op from the mapping would be problematic because removing an
  // element from an array invalidates iterators; merely changing the value of
  // elements does not.
  for (Value handle : opHandles) {
    Mappings &scopeMappings = getMapping(handle, /*allowOutOfScope=*/true);
    auto entry = scopeMappings.direct.find(handle);
    if (entry == scopeMappings.direct.end())
      continue;

    // Note that an operation may be associated with the handle more than once.
    SmallVector<Operation *, 2> &payload = entry->getSecond();
    for (Operation *&slot : payload) {
      if (slot != op)
        continue;
      slot = replacement;
    }

    if (!replacement) {
      opHandlesToCompact.insert(handle);
      continue;
    }
    scopeMappings.reverse[replacement].push_back(handle);
  }

  return success();
}
|
|
|
|
/// Makes every value handle that points to the payload value `value` point to
/// `replacement` instead. When `replacement` is null, `value` is simply
/// removed from the handles. Fails if no handle is associated with `value`.
LogicalResult
transform::TransformState::replacePayloadValue(Value value, Value replacement) {
  SmallVector<Value> valueHandles;
  LogicalResult lookup = getHandlesForPayloadValue(value, valueHandles,
                                                   /*includeOutOfScope=*/true);
  if (failed(lookup))
    return failure();

  for (Value handle : valueHandles) {
    Mappings &scopeMappings = getMapping(handle, /*allowOutOfScope=*/true);
    dropMappingEntry(scopeMappings.reverseValues, value, handle);

    if (replacement) {
      // Substitute the replacement for the old value in place and register
      // the reverse link for it.
      auto entry = scopeMappings.values.find(handle);
      if (entry == scopeMappings.values.end())
        continue;

      SmallVector<Value> &payload = entry->getSecond();
      for (Value &slot : payload) {
        if (slot == value)
          slot = replacement;
      }
      scopeMappings.reverseValues[replacement].push_back(handle);
      continue;
    }

    // If replacing with null, that is erasing the mapping, drop the mapping
    // between the handles and the IR objects
    dropMappingEntry(scopeMappings.values, handle, value);
#if LLVM_ENABLE_ABI_BREAKING_CHECKS
    // Payload IR is removed from the mapping. This invalidates the respective
    // iterators.
    scopeMappings.incrementTimestamp(handle);
#endif // LLVM_ENABLE_ABI_BREAKING_CHECKS
  }

  return success();
}
|
|
|
|
/// Records in `newlyInvalidated` an error reporter for `otherHandle` if
/// `payloadOp` is (nested in) one of `potentialAncestors`. The reporter
/// describes the invalidation caused by the transform op owning
/// `consumingHandle`; when `throughValue` is non-null, its location is
/// mentioned as well. Handles that are already invalidated are left alone.
void transform::TransformState::recordOpHandleInvalidationOne(
    OpOperand &consumingHandle, ArrayRef<Operation *> potentialAncestors,
    Operation *payloadOp, Value otherHandle, Value throughValue,
    transform::TransformState::InvalidatedHandleMap &newlyInvalidated) const {
  // If the op is associated with invalidated handle, skip the check as it
  // may be reading invalid IR. This also ensures we report the first
  // invalidation and not the last one.
  bool alreadyInvalidated = invalidatedHandles.count(otherHandle) ||
                            newlyInvalidated.count(otherHandle);
  if (alreadyInvalidated)
    return;

  FULL_LDBG("--recordOpHandleInvalidationOne\n");
  DEBUG_WITH_TYPE(
      DEBUG_TYPE_FULL,
      llvm::interleaveComma(potentialAncestors, DBGS() << "--ancestors: ",
                            [](Operation *op) { llvm::dbgs() << *op; });
      llvm::dbgs() << "\n");

  Operation *consumer = consumingHandle.getOwner();
  unsigned consumedOperandNo = consumingHandle.getOperandNumber();
  for (Operation *ancestor : potentialAncestors) {
    // clang-format off
    DEBUG_WITH_TYPE(DEBUG_TYPE_FULL, {
      (DBGS() << "----handle one ancestor: " << *ancestor << "\n");
      (DBGS() << "----of payload with name: "
              << payloadOp->getName().getIdentifier() << "\n");
      (DBGS() << "----of payload: " << *payloadOp << "\n");
    });
    // clang-format on
    if (!ancestor->isAncestor(payloadOp))
      continue;

    // Make sure the error-reporting lambda doesn't capture anything
    // by-reference because it will go out of scope. Additionally, extract
    // location from Payload IR ops because the ops themselves may be
    // deleted before the lambda gets called.
    Location ancestorLoc = ancestor->getLoc();
    Location nestedOpLoc = payloadOp->getLoc();
    std::optional<Location> throughValueLoc;
    if (throughValue)
      throughValueLoc = throughValue.getLoc();

    newlyInvalidated[otherHandle] = [ancestorLoc, nestedOpLoc, consumer,
                                     consumedOperandNo, otherHandle,
                                     throughValueLoc](Location currentLoc) {
      InFlightDiagnostic diag = emitError(currentLoc)
                                << "op uses a handle invalidated by a "
                                   "previously executed transform op";
      diag.attachNote(otherHandle.getLoc()) << "handle to invalidated ops";
      diag.attachNote(consumer->getLoc())
          << "invalidated by this transform op that consumes its operand #"
          << consumedOperandNo
          << " and invalidates all handles to payload IR entities associated "
             "with this operand and entities nested in them";
      diag.attachNote(ancestorLoc) << "ancestor payload op";
      diag.attachNote(nestedOpLoc) << "nested payload op";
      if (throughValueLoc.has_value())
        diag.attachNote(*throughValueLoc)
            << "consumed handle points to this payload value";
    };
  }
}
|
|
|
|
/// Records in `newlyInvalidated` an error reporter for `valueHandle` if the
/// op defining `payloadValue` (the owner of the result, or the parent op of
/// the block owning the argument) is (nested in) one of
/// `potentialAncestors`. The reporter describes the invalidation caused by
/// the transform op owning `opHandle`. Handles that are already invalidated
/// are left alone.
void transform::TransformState::recordValueHandleInvalidationByOpHandleOne(
    OpOperand &opHandle, ArrayRef<Operation *> potentialAncestors,
    Value payloadValue, Value valueHandle,
    transform::TransformState::InvalidatedHandleMap &newlyInvalidated) const {
  // If the op is associated with invalidated handle, skip the check as it
  // may be reading invalid IR. This also ensures we report the first
  // invalidation and not the last one.
  bool alreadyInvalidated = invalidatedHandles.count(valueHandle) ||
                            newlyInvalidated.count(valueHandle);
  if (alreadyInvalidated)
    return;

  // Nothing to compare against.
  if (potentialAncestors.empty())
    return;

  // Describe where the payload value comes from. This does not depend on the
  // ancestor being inspected, so it is computed once.
  constexpr unsigned kUnset = std::numeric_limits<unsigned>::max();
  Operation *definingOp = nullptr;
  std::optional<unsigned> resultNo;
  unsigned argumentNo = kUnset;
  unsigned blockNo = kUnset;
  unsigned regionNo = kUnset;
  if (auto opResult = llvm::dyn_cast<OpResult>(payloadValue)) {
    definingOp = opResult.getOwner();
    resultNo = opResult.getResultNumber();
  } else {
    auto arg = llvm::cast<BlockArgument>(payloadValue);
    Block *argBlock = arg.getOwner();
    definingOp = arg.getParentBlock()->getParentOp();
    argumentNo = arg.getArgNumber();
    blockNo = std::distance(argBlock->getParent()->begin(),
                            argBlock->getIterator());
    regionNo = argBlock->getParent()->getRegionNumber();
  }
  assert(definingOp && "expected the value to be defined by an op as result "
                       "or block argument");

  for (Operation *ancestor : potentialAncestors) {
    if (!ancestor->isAncestor(definingOp))
      continue;

    // Capture everything by value: the reporter outlives this call.
    Operation *consumer = opHandle.getOwner();
    unsigned consumedOperandNo = opHandle.getOperandNumber();
    Location ancestorLoc = ancestor->getLoc();
    Location definingOpLoc = definingOp->getLoc();
    Location valueLoc = payloadValue.getLoc();
    newlyInvalidated[valueHandle] = [valueHandle, consumer, consumedOperandNo,
                                     resultNo, argumentNo, blockNo, regionNo,
                                     ancestorLoc, definingOpLoc,
                                     valueLoc](Location currentLoc) {
      InFlightDiagnostic diag = emitError(currentLoc)
                                << "op uses a handle invalidated by a "
                                   "previously executed transform op";
      diag.attachNote(valueHandle.getLoc()) << "invalidated handle";
      diag.attachNote(consumer->getLoc())
          << "invalidated by this transform op that consumes its operand #"
          << consumedOperandNo
          << " and invalidates all handles to payload IR entities "
             "associated with this operand and entities nested in them";
      diag.attachNote(ancestorLoc)
          << "ancestor op associated with the consumed handle";
      if (resultNo.has_value()) {
        diag.attachNote(definingOpLoc)
            << "op defining the value as result #" << *resultNo;
      } else {
        diag.attachNote(definingOpLoc)
            << "op defining the value as block argument #" << argumentNo
            << " of block #" << blockNo << " in region #" << regionNo;
      }
      diag.attachNote(valueLoc) << "payload value";
    };
  }
}
|
|
|
|
/// Records invalidation reporters caused by consuming `handle`. With an empty
/// `potentialAncestors` list only the consumed handle itself is recorded.
/// Otherwise all op and value handles found in the reverse mappings are
/// checked against `potentialAncestors`, walking the scopes from the last one
/// to the first and stopping after a scope whose region belongs to an
/// isolated-from-above op.
void transform::TransformState::recordOpHandleInvalidation(
    OpOperand &handle, ArrayRef<Operation *> potentialAncestors,
    Value throughValue,
    transform::TransformState::InvalidatedHandleMap &newlyInvalidated) const {

  if (potentialAncestors.empty()) {
    DEBUG_WITH_TYPE(DEBUG_TYPE_FULL, {
      (DBGS() << "----recording invalidation for empty handle: " << handle.get()
              << "\n");
    });

    Operation *consumer = handle.getOwner();
    unsigned consumedOperandNo = handle.getOperandNumber();
    newlyInvalidated[handle.get()] = [consumer,
                                      consumedOperandNo](Location currentLoc) {
      InFlightDiagnostic diag = emitError(currentLoc)
                                << "op uses a handle associated with empty "
                                   "payload and invalidated by a "
                                   "previously executed transform op";
      diag.attachNote(consumer->getLoc())
          << "invalidated by this transform op that consumes its operand #"
          << consumedOperandNo;
    };
    return;
  }

  // Iterate over the mapping and invalidate aliasing handles. This is quite
  // expensive and only necessary for error reporting in case of transform
  // dialect misuse with dangling handles. Iteration over the handles is based
  // on the assumption that the number of handles is significantly less than the
  // number of IR objects (operations and values). Alternatively, we could walk
  // the IR nested in each payload op associated with the given handle and look
  // for handles associated with each operation and value.
  for (const auto &[scopeRegion, scopeMappings] : llvm::reverse(mappings)) {
    // Op handles: invalidate any handle pointing to one of the payload ops
    // associated with the given handle or to an op nested in them.
    for (const auto &[payloadOp, opHandles] : scopeMappings->reverse) {
      for (Value opHandle : opHandles) {
        recordOpHandleInvalidationOne(handle, potentialAncestors, payloadOp,
                                      opHandle, throughValue,
                                      newlyInvalidated);
      }
    }

    // Value handles: invalidate any handle pointing to a result of such ops,
    // or to an argument of a block belonging to one of their regions.
    for (const auto &[payloadValue, valueHandles] :
         scopeMappings->reverseValues) {
      for (Value valueHandle : valueHandles) {
        recordValueHandleInvalidationByOpHandleOne(
            handle, potentialAncestors, payloadValue, valueHandle,
            newlyInvalidated);
      }
    }

    // Stop lookup when reaching a region that is isolated from above.
    if (scopeRegion->getParentOp()->hasTrait<OpTrait::IsIsolatedFromAbove>())
      break;
  }
}
|
|
|
|
/// Records invalidation reporters caused by consuming the value handle
/// `valueHandle`. For each payload value of the handle, all handles found for
/// that same value are recorded. Then handles to the defining op (for op
/// results) or to the operations of the owning block (for block arguments),
/// and to anything nested in them, are recorded via
/// recordOpHandleInvalidation.
void transform::TransformState::recordValueHandleInvalidation(
    OpOperand &valueHandle,
    transform::TransformState::InvalidatedHandleMap &newlyInvalidated) const {
  // The consuming transform op and operand index are the same for every
  // payload value.
  Operation *owner = valueHandle.getOwner();
  unsigned operandNo = valueHandle.getOperandNumber();

  // Invalidate other handles to the same value.
  for (Value payloadValue : getPayloadValuesView(valueHandle.get())) {
    SmallVector<Value> otherValueHandles;
    (void)getHandlesForPayloadValue(payloadValue, otherValueHandles);
    Location valueLoc = payloadValue.getLoc();
    for (Value otherHandle : otherValueHandles) {
      // Capture by value only: the reporter outlives this call.
      newlyInvalidated[otherHandle] = [otherHandle, owner, operandNo,
                                       valueLoc](Location currentLoc) {
        InFlightDiagnostic diag = emitError(currentLoc)
                                  << "op uses a handle invalidated by a "
                                     "previously executed transform op";
        diag.attachNote(otherHandle.getLoc()) << "invalidated handle";
        diag.attachNote(owner->getLoc())
            << "invalidated by this transform op that consumes its operand #"
            << operandNo
            << " and invalidates handles to the same values as associated with "
               "it";
        diag.attachNote(valueLoc) << "payload value";
      };
    }

    if (auto opResult = llvm::dyn_cast<OpResult>(payloadValue)) {
      Operation *payloadOp = opResult.getOwner();
      recordOpHandleInvalidation(valueHandle, payloadOp, payloadValue,
                                 newlyInvalidated);
    } else {
      // A value that is not an op result must be a block argument; use `cast`
      // so the invariant is asserted instead of dereferencing an unchecked
      // `dyn_cast` result.
      auto arg = llvm::cast<BlockArgument>(payloadValue);
      for (Operation &payloadOp : *arg.getOwner())
        recordOpHandleInvalidation(valueHandle, &payloadOp, payloadValue,
                                   newlyInvalidated);
    }
  }
}
|
|
|
|
/// Checks that the operation does not use invalidated handles as operands.
|
|
/// Reports errors and returns failure if it does. Otherwise, invalidates the
|
|
/// handles consumed by the operation as well as any handles pointing to payload
|
|
/// IR operations nested in the operations associated with the consumed handles.
|
|
LogicalResult transform::TransformState::checkAndRecordHandleInvalidationImpl(
|
|
transform::TransformOpInterface transform,
|
|
transform::TransformState::InvalidatedHandleMap &newlyInvalidated) const {
|
|
FULL_LDBG("--Start checkAndRecordHandleInvalidation\n");
|
|
auto memoryEffectsIface =
|
|
cast<MemoryEffectOpInterface>(transform.getOperation());
|
|
SmallVector<MemoryEffects::EffectInstance> effects;
|
|
memoryEffectsIface.getEffectsOnResource(
|
|
transform::TransformMappingResource::get(), effects);
|
|
|
|
for (OpOperand &target : transform->getOpOperands()) {
|
|
DEBUG_WITH_TYPE(DEBUG_TYPE_FULL, {
|
|
(DBGS() << "----iterate on handle: " << target.get() << "\n");
|
|
});
|
|
// If the operand uses an invalidated handle, report it. If the operation
|
|
// allows handles to point to repeated payload operations, only report
|
|
// pre-existing invalidation errors. Otherwise, also report invalidations
|
|
// caused by the current transform operation affecting its other operands.
|
|
auto it = invalidatedHandles.find(target.get());
|
|
auto nit = newlyInvalidated.find(target.get());
|
|
if (it != invalidatedHandles.end()) {
|
|
FULL_LDBG("--End checkAndRecordHandleInvalidation, found already "
|
|
"invalidated -> FAILURE\n");
|
|
return it->getSecond()(transform->getLoc()), failure();
|
|
}
|
|
if (!transform.allowsRepeatedHandleOperands() &&
|
|
nit != newlyInvalidated.end()) {
|
|
FULL_LDBG("--End checkAndRecordHandleInvalidation, found newly "
|
|
"invalidated (by this op) -> FAILURE\n");
|
|
return nit->getSecond()(transform->getLoc()), failure();
|
|
}
|
|
|
|
// Invalidate handles pointing to the operations nested in the operation
|
|
// associated with the handle consumed by this operation.
|
|
auto consumesTarget = [&](const MemoryEffects::EffectInstance &effect) {
|
|
return isa<MemoryEffects::Free>(effect.getEffect()) &&
|
|
effect.getValue() == target.get();
|
|
};
|
|
if (llvm::any_of(effects, consumesTarget)) {
|
|
FULL_LDBG("----found consume effect\n");
|
|
if (llvm::isa<transform::TransformHandleTypeInterface>(
|
|
target.get().getType())) {
|
|
FULL_LDBG("----recordOpHandleInvalidation\n");
|
|
SmallVector<Operation *> payloadOps =
|
|
llvm::to_vector(getPayloadOps(target.get()));
|
|
recordOpHandleInvalidation(target, payloadOps, nullptr,
|
|
newlyInvalidated);
|
|
} else if (llvm::isa<transform::TransformValueHandleTypeInterface>(
|
|
target.get().getType())) {
|
|
FULL_LDBG("----recordValueHandleInvalidation\n");
|
|
recordValueHandleInvalidation(target, newlyInvalidated);
|
|
} else {
|
|
FULL_LDBG("----not a TransformHandle -> SKIP AND DROP ON THE FLOOR\n");
|
|
}
|
|
} else {
|
|
FULL_LDBG("----no consume effect -> SKIP\n");
|
|
}
|
|
}
|
|
|
|
FULL_LDBG("--End checkAndRecordHandleInvalidation -> SUCCESS\n");
|
|
return success();
|
|
}
|
|
|
|
/// Runs the invalidation check for `transform` and merges all handles it
/// newly invalidated into `invalidatedHandles`, regardless of whether the
/// check succeeded. Returns the result of the check.
LogicalResult transform::TransformState::checkAndRecordHandleInvalidation(
    transform::TransformOpInterface transform) {
  InvalidatedHandleMap freshlyInvalidated;
  LogicalResult status =
      checkAndRecordHandleInvalidationImpl(transform, freshlyInvalidated);

  // Move the collected reporters into the persistent map.
  auto first = std::make_move_iterator(freshlyInvalidated.begin());
  auto last = std::make_move_iterator(freshlyInvalidated.end());
  invalidatedHandles.insert(first, last);
  return status;
}
|
|
|
|
template <typename T>
|
|
DiagnosedSilenceableFailure
|
|
checkRepeatedConsumptionInOperand(ArrayRef<T> payload,
|
|
transform::TransformOpInterface transform,
|
|
unsigned operandNumber) {
|
|
DenseSet<T> seen;
|
|
for (T p : payload) {
|
|
if (!seen.insert(p).second) {
|
|
DiagnosedSilenceableFailure diag =
|
|
transform.emitSilenceableError()
|
|
<< "a handle passed as operand #" << operandNumber
|
|
<< " and consumed by this operation points to a payload "
|
|
"entity more than once";
|
|
if constexpr (std::is_pointer_v<T>)
|
|
diag.attachNote(p->getLoc()) << "repeated target op";
|
|
else
|
|
diag.attachNote(p.getLoc()) << "repeated target value";
|
|
return diag;
|
|
}
|
|
}
|
|
return DiagnosedSilenceableFailure::success();
|
|
}
|
|
|
|
void transform::TransformState::compactOpHandles() {
|
|
for (Value handle : opHandlesToCompact) {
|
|
Mappings &mappings = getMapping(handle, /*allowOutOfScope=*/true);
|
|
#if LLVM_ENABLE_ABI_BREAKING_CHECKS
|
|
if (llvm::find(mappings.direct[handle], nullptr) !=
|
|
mappings.direct[handle].end())
|
|
// Payload IR is removed from the mapping. This invalidates the respective
|
|
// iterators.
|
|
mappings.incrementTimestamp(handle);
|
|
#endif // LLVM_ENABLE_ABI_BREAKING_CHECKS
|
|
llvm::erase(mappings.direct[handle], nullptr);
|
|
}
|
|
opHandlesToCompact.clear();
|
|
}
|
|
|
|
/// Applies `transform` to the payload IR tracked by this state.
///
/// Steps, in order: optional expensive checks (handle invalidation and
/// repeated consumption), snapshotting the payload entities reachable from
/// consumed handles, running the op through a tracking rewriter, merging
/// tracking-listener failures into the result, forgetting consumed handles and
/// recording the op results in the state.
///
/// Returns the op's own result, possibly replaced by a tracking failure; a
/// definite failure is returned as soon as it is detected.
DiagnosedSilenceableFailure
transform::TransformState::applyTransform(TransformOpInterface transform) {
  LLVM_DEBUG({
    DBGS() << "applying: ";
    transform->print(llvm::dbgs(), OpPrintingFlags().skipRegions());
    llvm::dbgs() << "\n";
  });
  DEBUG_WITH_TYPE(DEBUG_TYPE_FULL,
                  DBGS() << "Top-level payload before application:\n"
                         << *getTopLevel() << "\n");

  // Prints the payload on every exit path unless released right before the
  // final return.
  auto printOnFailureRAII = llvm::make_scope_exit([this] {
    (void)this;
    LLVM_DEBUG({
      DBGS() << "Failing Top-level payload:\n";
      getTopLevel()->print(llvm::dbgs(),
                           mlir::OpPrintingFlags().printGenericOpForm());
    });
  });

  // Record the op being applied in the innermost region scope.
  regionStack.back()->currentTransform = transform;

  // Expensive checks to detect invalid transform IR.
  if (options.getExpensiveChecksEnabled()) {
    FULL_LDBG("ExpensiveChecksEnabled\n");
    if (failed(checkAndRecordHandleInvalidation(transform)))
      return DiagnosedSilenceableFailure::definiteFailure();

    for (OpOperand &operand : transform->getOpOperands()) {
      DEBUG_WITH_TYPE(DEBUG_TYPE_FULL, {
        (DBGS() << "iterate on handle: " << operand.get() << "\n");
      });
      if (!isHandleConsumed(operand.get(), transform)) {
        FULL_LDBG("--handle not consumed -> SKIP\n");
        continue;
      }
      if (transform.allowsRepeatedHandleOperands()) {
        FULL_LDBG("--op allows repeated handles -> SKIP\n");
        continue;
      }
      FULL_LDBG("--handle is consumed\n");

      Type handleType = operand.get().getType();
      if (llvm::isa<TransformHandleTypeInterface>(handleType)) {
        FULL_LDBG("--checkRepeatedConsumptionInOperand for Operation*\n");
        DiagnosedSilenceableFailure opCheck =
            checkRepeatedConsumptionInOperand<Operation *>(
                getPayloadOpsView(operand.get()), transform,
                operand.getOperandNumber());
        if (!opCheck.succeeded()) {
          FULL_LDBG("----FAILED\n");
          return opCheck;
        }
      } else if (llvm::isa<TransformValueHandleTypeInterface>(handleType)) {
        FULL_LDBG("--checkRepeatedConsumptionInOperand For Value\n");
        DiagnosedSilenceableFailure valueCheck =
            checkRepeatedConsumptionInOperand<Value>(
                getPayloadValuesView(operand.get()), transform,
                operand.getOperandNumber());
        if (!valueCheck.succeeded()) {
          FULL_LDBG("----FAILED\n");
          return valueCheck;
        }
      } else {
        FULL_LDBG("--not a TransformHandle -> SKIP AND DROP ON THE FLOOR\n");
      }
    }
  }

  // Find which operands are consumed.
  SmallVector<OpOperand *> consumedOperands =
      transform.getConsumedHandleOpOperands();

  // Snapshot, before the transformation runs, the results of payload ops
  // behind consumed op handles and the ops related to consumed value handles.
  // The transformation may destroy or mutate them, so the payload IR cannot
  // be traversed afterwards; the snapshots are used to drop associations
  // later.
  SmallVector<Value> origOpFlatResults;
  SmallVector<Operation *> origAssociatedOps;
  for (OpOperand *opOperand : consumedOperands) {
    Value operand = opOperand->get();
    if (llvm::isa<TransformHandleTypeInterface>(operand.getType())) {
      for (Operation *payloadOp : getPayloadOps(operand))
        llvm::append_range(origOpFlatResults, payloadOp->getResults());
    } else if (llvm::isa<TransformValueHandleTypeInterface>(
                   operand.getType())) {
      for (Value payloadValue : getPayloadValuesView(operand)) {
        if (llvm::isa<OpResult>(payloadValue)) {
          // Op result: remember its defining op.
          origAssociatedOps.push_back(payloadValue.getDefiningOp());
        } else {
          // Block argument: remember every op of the owning block.
          llvm::append_range(
              origAssociatedOps,
              llvm::map_range(
                  *llvm::cast<BlockArgument>(payloadValue).getOwner(),
                  [](Operation &op) { return &op; }));
        }
      }
    } else {
      // Only op handles and value handles may be consumed.
      DiagnosedDefiniteFailure diag =
          emitDefiniteFailure(transform->getLoc())
          << "unexpectedly consumed a value that is not a handle as operand #"
          << opOperand->getOperandNumber();
      diag.attachNote(operand.getLoc())
          << "value defined here with type " << operand.getType();
      return diag;
    }
  }

  // Prepare rewriter and listener.
  TrackingListenerConfig config;
  config.skipHandleFn = [&](Value handle) {
    // A handle is skipped when it is dead: every user is either the current
    // transform of the handle's region scope or happens before it.
    auto scopeIt =
        llvm::find_if(llvm::reverse(regionStack), [&](RegionScope *scope) {
          return handle.getParentRegion() == scope->region;
        });
    assert(scopeIt != regionStack.rend() &&
           "could not find region scope for handle");
    RegionScope *scope = *scopeIt;
    return llvm::all_of(handle.getUsers(), [&](Operation *user) {
      return user == scope->currentTransform ||
             happensBefore(user, scope->currentTransform);
    });
  };
  transform::ErrorCheckingTrackingListener trackingListener(*this, transform,
                                                            config);
  transform::TransformRewriter rewriter(transform->getContext(),
                                        &trackingListener);

  // Compute the result but do not short-circuit the silenceable failure case
  // as we still want the handles to propagate properly so the "suppress" mode
  // can proceed on a best effort basis.
  transform::TransformResults results(transform->getNumResults());
  DiagnosedSilenceableFailure result(transform.apply(rewriter, results, *this));
  compactOpHandles();

  // Error handling: fail if transform or listener failed.
  DiagnosedSilenceableFailure trackingFailure =
      trackingListener.checkAndResetError();
  if (!transform->hasTrait<ReportTrackingListenerFailuresOpTrait>() ||
      transform->hasAttr(FindPayloadReplacementOpInterface::
                             kSilenceTrackingFailuresAttrName)) {
    // Tracking failures are reported only for ops carrying
    // ReportTrackingListenerFailuresOpTrait, and not when the silencing
    // attribute is present. Otherwise discard them.
    if (trackingFailure.isSilenceableFailure())
      (void)trackingFailure.silence();
    trackingFailure = DiagnosedSilenceableFailure::success();
  }
  if (!trackingFailure.succeeded()) {
    if (result.succeeded()) {
      result = std::move(trackingFailure);
    } else {
      // Transform op errors have precedence, report those first.
      if (result.isSilenceableFailure())
        result.attachNote() << "tracking listener also failed: "
                            << trackingFailure.getMessage();
      (void)trackingFailure.silence();
    }
  }
  if (result.isDefiniteFailure())
    return result;

  // If a silenceable failure was produced, some results may be unset, set
  // them to empty lists.
  if (result.isSilenceableFailure())
    results.setRemainingToEmpty(transform);

  // Remove the mapping for the operand if it is consumed by the operation.
  // This allows us to catch use-after-free with assertions later on.
  for (OpOperand *opOperand : consumedOperands) {
    Value operand = opOperand->get();
    if (llvm::isa<TransformHandleTypeInterface>(operand.getType())) {
      forgetMapping(operand, origOpFlatResults);
    } else if (llvm::isa<TransformValueHandleTypeInterface>(
                   operand.getType())) {
      forgetValueMapping(operand, origAssociatedOps);
    }
  }

  if (failed(updateStateFromResults(results, transform->getResults())))
    return DiagnosedSilenceableFailure::definiteFailure();

  // Past this point the application is not treated as failed for the purpose
  // of the failure dump.
  printOnFailureRAII.release();
  DEBUG_WITH_TYPE(DEBUG_PRINT_AFTER_ALL, {
    DBGS() << "Top-level payload:\n";
    getTopLevel()->print(llvm::dbgs());
  });
  return result;
}
|
|
|
|
/// Records in the state, for each transform op result in `opResults`, the
/// entities stored in `results` at the same result number. The kind of
/// association (params, payload values or payload ops) is selected by the
/// result type. Returns failure as soon as one association cannot be set.
LogicalResult transform::TransformState::updateStateFromResults(
    const TransformResults &results, ResultRange opResults) {
  for (OpResult result : opResults) {
    const unsigned resultNo = result.getResultNumber();
    Type resultType = result.getType();

    // Parameter-typed result.
    if (llvm::isa<TransformParamTypeInterface>(resultType)) {
      assert(results.isParam(resultNo) &&
             "expected parameters for the parameter-typed result");
      if (failed(setParams(result, results.getParams(resultNo))))
        return failure();
      continue;
    }

    // Value-handle-typed result.
    if (llvm::isa<TransformValueHandleTypeInterface>(resultType)) {
      assert(results.isValue(resultNo) &&
             "expected values for value-type-result");
      if (failed(setPayloadValues(result, results.getValues(resultNo))))
        return failure();
      continue;
    }

    // Everything else is handled as a list of payload ops.
    assert(!results.isParam(resultNo) &&
           "expected payload ops for the non-parameter typed result");
    if (failed(setPayloadOps(result, results.get(resultNo))))
      return failure();
  }
  return success();
}
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
// TransformState::Extension
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
// Out-of-line defaulted destructor of the extension base class.
// NOTE(review): presumably defined here to anchor the class in this
// translation unit — confirm against the header.
transform::TransformState::Extension::~Extension() = default;
|
|
|
|
/// Forwards the replacement of payload op `op` by `replacement` to the
/// underlying transform state and returns its result.
LogicalResult transform::TransformState::Extension::replacePayloadOp(
    Operation *op, Operation *replacement) {
  // TODO: we may need to invalidate handles to operations and values nested in
  // the operation being replaced.
  LogicalResult status = state.replacePayloadOp(op, replacement);
  return status;
}
|
|
|
|
/// Forwards the replacement of payload value `value` by `replacement` to the
/// underlying transform state and returns its result.
LogicalResult transform::TransformState::Extension::replacePayloadValue(
    Value value, Value replacement) {
  LogicalResult status = state.replacePayloadValue(value, replacement);
  return status;
}
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
// TransformState::RegionScope
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
/// Leaves the region scope: drops the invalidation notices of all handles
/// defined in the region (block arguments and op results), erases the
/// mappings stored for the region and pops the scope off the region stack.
transform::TransformState::RegionScope::~RegionScope() {
  // Remove handle invalidation notices as handles are going out of scope.
  // The same region may be re-entered leading to incorrect invalidation
  // errors.
  for (Block &block : *region) {
    for (Value handle : block.getArguments())
      state.invalidatedHandles.erase(handle);
    for (Operation &op : block)
      for (Value handle : op.getResults())
        state.invalidatedHandles.erase(handle);
  }

  // The previous implementation collected the payload ops referenced by the
  // region's reverse mapping into a local vector under
  // LLVM_ENABLE_ABI_BREAKING_CHECKS, but nothing ever read that vector, so
  // the dead computation has been removed.
  state.mappings.erase(region);
  state.regionStack.pop_back();
}
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
// TransformResults
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
/// Creates a result container with `numSegments` empty rows in each of the
/// three storages: operations, values and params.
transform::TransformResults::TransformResults(unsigned numSegments) {
  // One row per result in every storage; rows start out empty.
  operations.appendEmptyRows(numSegments);
  values.appendEmptyRows(numSegments);
  params.appendEmptyRows(numSegments);
}
|
|
|
|
/// Stores `params` as the parameters associated with the result `value`.
/// Asserts that the result number is in range and that no parameters,
/// operations or values were set for it before.
void transform::TransformResults::setParams(
    OpResult value, ArrayRef<transform::TransformState::Param> params) {
  const int64_t slot = value.getResultNumber();
  // `this->` is needed because the argument shadows the member.
  assert(slot < static_cast<int64_t>(this->params.size()) &&
         "setting params for a non-existent handle");
  assert(this->params[slot].data() == nullptr && "params already set");
  assert(operations[slot].data() == nullptr &&
         "another kind of results already set");
  assert(values[slot].data() == nullptr &&
         "another kind of results already set");
  this->params.replace(slot, params);
}
|
|
|
|
/// Associates `values` with the result `handle`, dispatching to `set`,
/// `setParams` or `setValues` through `dispatchMappedValues`. In builds with
/// assertions, a failed dispatch is printed and then asserted on; the
/// diagnostic is silenced in all cases.
void transform::TransformResults::setMappedValues(
    OpResult handle, ArrayRef<MappedValue> values) {
  auto storeOps = [&](ArrayRef<Operation *> operations) {
    set(handle, operations);
    return success();
  };
  auto storeParams = [&](ArrayRef<Param> params) {
    setParams(handle, params);
    return success();
  };
  auto storeValues = [&](ValueRange payloadValues) {
    setValues(handle, payloadValues);
    return success();
  };
  DiagnosedSilenceableFailure diag =
      dispatchMappedValues(handle, values, storeOps, storeParams, storeValues);
#ifndef NDEBUG
  if (!diag.succeeded())
    llvm::dbgs() << diag.getStatusString() << "\n";
  assert(diag.succeeded() && "incorrect mapping");
#endif // NDEBUG
  (void)diag.silence();
}
|
|
|
|
/// Sets every result of `transform` that has not been set yet to an empty
/// list of mapped values.
void transform::TransformResults::setRemainingToEmpty(
    transform::TransformOpInterface transform) {
  for (OpResult opResult : transform->getResults()) {
    if (isSet(opResult.getResultNumber()))
      continue;
    setMappedValues(opResult, {});
  }
}
|
|
|
|
/// Returns the payload ops stored for result #`resultNumber`. Asserts that
/// the number is in range and that ops were set for it.
ArrayRef<Operation *>
transform::TransformResults::get(unsigned resultNumber) const {
  assert(resultNumber < operations.size() &&
         "querying results for a non-existent handle");
  ArrayRef<Operation *> row = operations[resultNumber];
  // A null data pointer marks a row that was never set.
  assert(row.data() != nullptr &&
         "querying unset results (values or params expected?)");
  return row;
}
|
|
|
|
/// Returns the parameters stored for result #`resultNumber`. Asserts that
/// the number is in range and that parameters were set for it.
ArrayRef<transform::TransformState::Param>
transform::TransformResults::getParams(unsigned resultNumber) const {
  assert(resultNumber < params.size() &&
         "querying params for a non-existent handle");
  ArrayRef<transform::TransformState::Param> row = params[resultNumber];
  // A null data pointer marks a row that was never set.
  assert(row.data() != nullptr &&
         "querying unset params (ops or values expected?)");
  return row;
}
|
|
|
|
/// Returns the payload values stored for result #`resultNumber`. Asserts
/// that the number is in range and that values were set for it.
ArrayRef<Value>
transform::TransformResults::getValues(unsigned resultNumber) const {
  assert(resultNumber < values.size() &&
         "querying values for a non-existent handle");
  ArrayRef<Value> row = values[resultNumber];
  // A null data pointer marks a row that was never set.
  assert(row.data() != nullptr &&
         "querying unset values (ops or params expected?)");
  return row;
}
|
|
|
|
/// Returns true if parameters were set for result #`resultNumber`, i.e. the
/// corresponding params row has a non-null data pointer.
bool transform::TransformResults::isParam(unsigned resultNumber) const {
  assert(resultNumber < params.size() &&
         "querying association for a non-existent handle");
  const bool hasParams = params[resultNumber].data() != nullptr;
  return hasParams;
}
|
|
|
|
/// Returns true if payload values were set for result #`resultNumber`, i.e.
/// the corresponding values row has a non-null data pointer.
bool transform::TransformResults::isValue(unsigned resultNumber) const {
  assert(resultNumber < values.size() &&
         "querying association for a non-existent handle");
  const bool hasValues = values[resultNumber].data() != nullptr;
  return hasValues;
}
|
|
|
|
bool transform::TransformResults::isSet(unsigned resultNumber) const {
|
|
assert(resultNumber < params.size() &&
|
|
"querying association for a non-existent handle");
|
|
return params[resultNumber].data() != nullptr ||
|
|
operations[resultNumber].data() != nullptr ||
|
|
values[resultNumber].data() != nullptr;
|
|
}
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
// TrackingListener
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
/// Creates a tracking listener attached to `state` on behalf of transform op
/// `op`, using `config`. When `op` is non-null, the handles it consumes are
/// cached in `consumedHandles`.
transform::TrackingListener::TrackingListener(TransformState &state,
                                              TransformOpInterface op,
                                              TrackingListenerConfig config)
    // `config` is received by value, so move it into the member instead of
    // copying it a second time.
    : TransformState::Extension(state), transformOp(op),
      config(std::move(config)) {
  if (!op)
    return;
  for (OpOperand *opOperand : transformOp.getConsumedHandleOpOperands())
    consumedHandles.insert(opOperand->get());
}
|
|
|
|
/// Returns the op that defines all non-null entries of `values`, or nullptr
/// if two entries disagree. Null entries are ignored. While no candidate has
/// been found yet (including when `getDefiningOp()` returned null for earlier
/// entries), the next entry's defining op becomes the candidate.
Operation *transform::TrackingListener::getCommonDefiningOp(ValueRange values) {
  Operation *common = nullptr;
  for (Value value : values) {
    // Skip empty values.
    if (!value)
      continue;
    Operation *producer = value.getDefiningOp();
    if (!common)
      common = producer;
    else if (common != producer)
      return nullptr;
  }
  return common;
}
|
|
|
|
/// Searches for an op that can stand in for `op` after its results were
/// replaced by `newValues`, and stores it in `result` on success.
///
/// Starting from `newValues`, the search repeatedly takes the common defining
/// op of the current values. That op is accepted if op names need not match,
/// if it has the same name as `op`, or if it is constant-like. Cast ops (when
/// `config.skipCastOps` is set) and ops implementing
/// FindPayloadReplacementOpInterface are looked through by continuing with
/// their operands / next operands. Otherwise a silenceable failure with
/// explanatory notes is returned.
DiagnosedSilenceableFailure transform::TrackingListener::findReplacementOp(
    Operation *&result, Operation *op, ValueRange newValues) const {
  assert(op->getNumResults() == newValues.size() &&
         "invalid number of replacement values");
  SmallVector<Value> values(newValues.begin(), newValues.end());

  // Created up front so that notes can be accumulated during the search.
  DiagnosedSilenceableFailure diag = emitSilenceableFailure(
      getTransformOp(), "tracking listener failed to find replacement op "
                        "during application of this transform op");

  do {
    // If the replacement values belong to different ops, drop the mapping.
    Operation *defOp = getCommonDefiningOp(values);
    if (!defOp) {
      diag.attachNote() << "replacement values belong to different ops";
      return diag;
    }

    // Skip through ops that implement CastOpInterface.
    if (config.skipCastOps && isa<CastOpInterface>(defOp)) {
      // `assign` replaces the previous contents.
      values.assign(defOp->getOperands().begin(), defOp->getOperands().end());
      diag.attachNote(defOp->getLoc())
          << "using output of 'CastOpInterface' op";
      continue;
    }

    // Accept the defining op when names need not match, when it has the same
    // name, or when it is constant-like (replacing an op with a constant-like
    // equivalent is a common canonicalization).
    const bool nameAccepted = !config.requireMatchingReplacementOpName ||
                              op->getName() == defOp->getName();
    if (nameAccepted || defOp->hasTrait<OpTrait::ConstantLike>()) {
      result = defOp;
      return DiagnosedSilenceableFailure::success();
    }

    // Skip through ops that implement FindPayloadReplacementOpInterface;
    // without that interface there is nothing left to look at.
    values.clear();
    auto replacementIface = dyn_cast<FindPayloadReplacementOpInterface>(defOp);
    if (!replacementIface)
      break;
    values.assign(replacementIface.getNextOperands());
    diag.attachNote(defOp->getLoc()) << "using operands provided by "
                                        "'FindPayloadReplacementOpInterface'";
  } while (!values.empty());

  diag.attachNote() << "ran out of suitable replacement values";
  return diag;
}
|
|
|
|
/// Debug-only hook: within `LLVM_DEBUG`, builds a remark at `loc`, lets
/// `reasonCallback` fill it in and prints it to the debug stream.
void transform::TrackingListener::notifyMatchFailure(
    Location loc, function_ref<void(Diagnostic &)> reasonCallback) {
  LLVM_DEBUG({
    Diagnostic remark(loc, DiagnosticSeverity::Remark);
    reasonCallback(remark);
    DBGS() << "Match Failure : " << remark.str() << "\n";
  });
}
|
|
|
|
/// Reacts to the erasure of `op` by replacing each of its results, and then
/// the op itself, with null in the payload mappings. Failures are ignored.
void transform::TrackingListener::notifyOperationErased(Operation *op) {
  // Remove mappings for result values.
  for (OpResult erasedResult : op->getResults()) {
    (void)replacePayloadValue(erasedResult, nullptr);
  }
  // Remove mapping for op.
  (void)replacePayloadOp(op, nullptr);
}
|
|
|
|
/// Reacts to the replacement of the results of `op` by `newValues`.
///
/// Value handles are updated pairwise. For op handles: untracked ops are
/// ignored; if every handle to `op` is skipped (dead) or any was consumed by
/// the current transform op, the op is dropped from the mapping; otherwise a
/// replacement op is looked up with `findReplacementOp`. If none is found,
/// `notifyPayloadReplacementNotFound` is called and the op is dropped.
void transform::TrackingListener::notifyOperationReplaced(
    Operation *op, ValueRange newValues) {
  assert(op->getNumResults() == newValues.size() &&
         "invalid number of replacement values");

  // Replace value handles.
  for (auto [oldValue, newValue] : llvm::zip(op->getResults(), newValues))
    (void)replacePayloadValue(oldValue, newValue);

  // Replace op handle.
  SmallVector<Value> opHandles;
  if (failed(getTransformState().getHandlesForPayloadOp(
          op, opHandles, /*includeOutOfScope=*/true))) {
    // Op is not tracked.
    return;
  }

  // Helper function to check if the current transform op consumes any handle
  // that is mapped to `op`.
  //
  // Note: If a handle was consumed, there shouldn't be any alive users, so it
  // is not really necessary to check for consumed handles. However, in case
  // there are indeed alive handles that were consumed (which is undefined
  // behavior) and a replacement op could not be found, we want to fail with a
  // nicer error message: "op uses a handle invalidated..." instead of "could
  // not find replacement op". This nicer error is produced later.
  auto handleWasConsumed = [&] {
    return llvm::any_of(opHandles,
                        [&](Value h) { return consumedHandles.contains(h); });
  };

  // Check if there are any handles that must be updated.
  Value aliveHandle;
  if (config.skipHandleFn) {
    auto it = llvm::find_if(opHandles,
                            [&](Value v) { return !config.skipHandleFn(v); });
    if (it != opHandles.end())
      aliveHandle = *it;
  } else if (!opHandles.empty()) {
    aliveHandle = opHandles.front();
  }
  if (!aliveHandle || handleWasConsumed()) {
    // The op is tracked but the corresponding handles are dead or were
    // consumed. Drop the op from the mapping.
    (void)replacePayloadOp(op, nullptr);
    return;
  }

  // Initialized so the out-parameter never holds an indeterminate pointer,
  // even though it is only read after a successful lookup.
  Operation *replacement = nullptr;
  DiagnosedSilenceableFailure diag =
      findReplacementOp(replacement, op, newValues);
  // If the op is tracked but no replacement op was found, send a
  // notification.
  if (!diag.succeeded()) {
    diag.attachNote(aliveHandle.getLoc())
        << "replacement is required because this handle must be updated";
    notifyPayloadReplacementNotFound(op, newValues, std::move(diag));
    (void)replacePayloadOp(op, nullptr);
    return;
  }

  (void)replacePayloadOp(op, replacement);
}
|
|
|
|
/// Asserts on destruction that the stored status is a success, i.e. that any
/// recorded error was retrieved and reset beforehand.
transform::ErrorCheckingTrackingListener::~ErrorCheckingTrackingListener() {
  // An error still pending here would otherwise be silently lost, so
  // assertion-enabled builds stop on it.
  assert(status.succeeded() && "listener state was not checked");
}
|
|
|
|
/// Hands the currently stored status to the caller and resets the listener:
/// the status becomes success and the error counter goes back to zero.
DiagnosedSilenceableFailure
transform::ErrorCheckingTrackingListener::checkAndResetError() {
  DiagnosedSilenceableFailure previous = std::move(status);
  // Re-arm the listener for subsequent notifications.
  status = DiagnosedSilenceableFailure::success();
  errorCounter = 0;
  return previous;
}
|
|
|
|
bool transform::ErrorCheckingTrackingListener::failed() const {
|
|
return !status.succeeded();
|
|
}
|
|
|
|
/// Records that no replacement was found for the tracked op `op` replaced by
/// `values`. The diagnostics of `diag` are merged with those already stored
/// into a new silenceable failure, and notes pointing at the replaced op and
/// at each replacement value are attached, tagged with the error counter.
void transform::ErrorCheckingTrackingListener::notifyPayloadReplacementNotFound(
    Operation *op, ValueRange values, DiagnosedSilenceableFailure &&diag) {
  // Merge potentially existing diags and store the result in the listener:
  // new diagnostics first, then the ones accumulated so far.
  SmallVector<Diagnostic> merged;
  diag.takeDiagnostics(merged);
  if (!status.succeeded())
    status.takeDiagnostics(merged);
  status = DiagnosedSilenceableFailure::silenceableFailure(std::move(merged));

  // Report more details.
  status.attachNote(op->getLoc()) << "[" << errorCounter << "] replaced op";
  for (auto &&[index, value] : llvm::enumerate(values)) {
    status.attachNote(value.getLoc())
        << "[" << errorCounter << "] replacement value " << index;
  }
  ++errorCounter;
}
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
// TransformRewriter
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
/// Creates a rewriter for context `ctx` that stores `listener` and installs
/// it as the rewriter's listener.
transform::TransformRewriter::TransformRewriter(
    MLIRContext *ctx, ErrorCheckingTrackingListener *listener)
    : RewriterBase(ctx), listener(listener) {
  // The member was just initialized from the argument; install it.
  setListener(this->listener);
}
|
|
|
|
bool transform::TransformRewriter::hasTrackingFailures() const {
|
|
return listener->failed();
|
|
}
|
|
|
|
/// Silence all tracking failures that have been encountered so far.
|
|
void transform::TransformRewriter::silenceTrackingFailure() {
|
|
if (hasTrackingFailures()) {
|
|
DiagnosedSilenceableFailure status = listener->checkAndResetError();
|
|
(void)status.silence();
|
|
}
|
|
}
|
|
|
|
/// Tells the tracking listener that payload op `op` was replaced by
/// `replacement` and returns the listener's result.
LogicalResult transform::TransformRewriter::notifyPayloadOperationReplaced(
    Operation *op, Operation *replacement) {
  LogicalResult status = listener->replacePayloadOp(op, replacement);
  return status;
}
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
// Utilities for TransformEachOpTrait.
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
/// Verifies that no op in `targets` is an ancestor of an op that appears
/// later in the list. On the first such pair, emits an error at `loc` with
/// notes pointing at both ops and returns failure; otherwise returns success.
LogicalResult
transform::detail::checkNestedConsumption(Location loc,
                                          ArrayRef<Operation *> targets) {
  for (auto &&[position, parent] : llvm::enumerate(targets)) {
    // Only ops listed after `parent` are compared against it.
    for (Operation *child : targets.drop_front(position + 1)) {
      if (!parent->isAncestor(child))
        continue;

      InFlightDiagnostic diag =
          emitError(loc)
          << "transform operation consumes a handle pointing to an ancestor "
             "payload operation before its descendant";
      diag.attachNote()
          << "the ancestor is likely erased or rewritten before the "
             "descendant is accessed, leading to undefined behavior";
      diag.attachNote(parent->getLoc()) << "ancestor payload op";
      diag.attachNote(child->getLoc()) << "descendant payload op";
      return diag;
    }
  }
  return success();
}
|
|
|
|
/// Validates the partial result produced by applying `transformOp` to one
/// payload op located at `payloadOpLoc`.
///
/// Fails with an error at the transform op location if the number of entries
/// differs from the number of op results, or if a non-null entry is not of
/// the kind required by the corresponding result type (Operation * for op
/// handles, Attribute for params, Value for value handles). Null entries are
/// not checked.
LogicalResult
transform::detail::checkApplyToOne(Operation *transformOp,
                                   Location payloadOpLoc,
                                   const ApplyToEachResultList &partialResult) {
  Location transformOpLoc = transformOp->getLoc();
  StringRef transformOpName = transformOp->getName().getStringRef();
  unsigned expectedNumResults = transformOp->getNumResults();

  // Starts an error that already carries the note about the payload op.
  auto emitDiag = [&]() {
    auto diag = mlir::emitError(transformOpLoc);
    diag.attachNote(payloadOpLoc) << "when applied to this op";
    return diag;
  };

  if (partialResult.size() != expectedNumResults) {
    auto diag = emitDiag() << "application of " << transformOpName
                           << " expected to produce " << expectedNumResults
                           << " results (actually produced "
                           << partialResult.size() << ").";
    diag.attachNote(transformOpLoc)
        << "if you need variadic results, consider a generic `apply` "
        << "instead of the specialized `applyToOne`.";
    return failure();
  }

  // Check that the right kind of value was produced.
  for (const auto &[ptr, res] :
       llvm::zip(partialResult, transformOp->getResults())) {
    if (ptr.isNull())
      continue;

    Type resultType = res.getType();
    const bool wantsOp = llvm::isa<TransformHandleTypeInterface>(resultType);
    if (wantsOp && !isa<Operation *>(ptr)) {
      return emitDiag() << "application of " << transformOpName
                        << " expected to produce an Operation * for result #"
                        << res.getResultNumber();
    }
    const bool wantsParam = llvm::isa<TransformParamTypeInterface>(resultType);
    if (wantsParam && !isa<Attribute>(ptr)) {
      return emitDiag() << "application of " << transformOpName
                        << " expected to produce an Attribute for result #"
                        << res.getResultNumber();
    }
    const bool wantsValue =
        llvm::isa<TransformValueHandleTypeInterface>(resultType);
    if (wantsValue && !isa<Value>(ptr)) {
      return emitDiag() << "application of " << transformOpName
                        << " expected to produce a Value for result #"
                        << res.getResultNumber();
    }
  }
  return success();
}
|
|
|
|
template <typename T>
|
|
static SmallVector<T> castVector(ArrayRef<transform::MappedValue> range) {
|
|
return llvm::to_vector(llvm::map_range(
|
|
range, [](transform::MappedValue value) { return cast<T>(value); }));
|
|
}
|
|
|
|
/// Transposes the per-payload partial results in `results` into per-result
/// lists and stores them in `transformResults` for each result of
/// `transformOp`. Partial results containing any null entry are skipped
/// entirely. The storage kind is selected by the result type.
void transform::detail::setApplyToOneResults(
    Operation *transformOp, TransformResults &transformResults,
    ArrayRef<ApplyToEachResultList> results) {
  // One column per transform op result.
  SmallVector<SmallVector<MappedValue>> columns;
  columns.resize(transformOp->getNumResults());
  for (const ApplyToEachResultList &partialResults : results) {
    const bool hasNullEntry = llvm::any_of(
        partialResults, [](MappedValue value) { return value.isNull(); });
    if (hasNullEntry)
      continue;
    assert(transformOp->getNumResults() == partialResults.size() &&
           "expected as many partial results as op as results");
    for (auto [i, value] : llvm::enumerate(partialResults))
      columns[i].push_back(value);
  }

  for (OpResult r : transformOp->getResults()) {
    unsigned position = r.getResultNumber();
    Type resultType = r.getType();
    if (llvm::isa<TransformParamTypeInterface>(resultType)) {
      transformResults.setParams(r, castVector<Attribute>(columns[position]));
      continue;
    }
    if (llvm::isa<TransformValueHandleTypeInterface>(resultType)) {
      transformResults.setValues(r, castVector<Value>(columns[position]));
      continue;
    }
    // Remaining result types are handled as payload op lists.
    transformResults.set(r, castVector<Operation *>(columns[position]));
  }
}
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
// Utilities for implementing transform ops with regions.
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
/// Appends, to the i-th list in `mappings`, the payload entities that `state`
/// associates with the i-th value in `values` (ops, payload values or params
/// depending on the value type). Unless `flatten` is set, returns failure as
/// soon as a value contributes a number of entities different from one.
LogicalResult transform::detail::appendValueMappings(
    MutableArrayRef<SmallVector<transform::MappedValue>> mappings,
    ValueRange values, const transform::TransformState &state, bool flatten) {
  assert(mappings.size() == values.size() && "mismatching number of mappings");
  for (auto &&[operand, mapped] : llvm::zip_equal(values, mappings)) {
    // Remember the size to measure how many entries this value adds.
    const size_t sizeBefore = mapped.size();
    Type operandType = operand.getType();
    if (llvm::isa<TransformHandleTypeInterface>(operandType)) {
      llvm::append_range(mapped, state.getPayloadOps(operand));
    } else if (llvm::isa<TransformValueHandleTypeInterface>(operandType)) {
      llvm::append_range(mapped, state.getPayloadValues(operand));
    } else {
      assert(llvm::isa<TransformParamTypeInterface>(operandType) &&
             "unsupported kind of transform dialect value");
      llvm::append_range(mapped, state.getParams(operand));
    }

    if (mapped.size() - sizeBefore != 1 && !flatten)
      return failure();
  }
  return success();
}
|
|
|
|
/// Grows `mappings` by one list per entry of `values` and fills the new
/// trailing lists through `appendValueMappings`, ignoring its result.
void transform::detail::prepareValueMappings(
    SmallVectorImpl<SmallVector<transform::MappedValue>> &mappings,
    ValueRange values, const transform::TransformState &state) {
  mappings.resize(mappings.size() + values.size());
  // View over just the lists added above.
  MutableArrayRef<SmallVector<transform::MappedValue>> appended =
      MutableArrayRef<SmallVector<transform::MappedValue>>(mappings).take_back(
          values.size());
  (void)appendValueMappings(appended, values, state);
}
|
|
|
|
/// Pairs the operands of the terminator of `block` with the results of the
/// block's parent op and stores, for each result, the entities that `state`
/// associates with the paired terminator operand into `results`. The result
/// type selects ops, payload values or params.
void transform::detail::forwardTerminatorOperands(
    Block *block, transform::TransformState &state,
    transform::TransformResults &results) {
  for (auto &&[terminatorOperand, result] :
       llvm::zip(block->getTerminator()->getOperands(),
                 block->getParentOp()->getOpResults())) {
    Type resultType = result.getType();
    if (llvm::isa<transform::TransformHandleTypeInterface>(resultType)) {
      results.set(result, state.getPayloadOps(terminatorOperand));
      continue;
    }
    if (llvm::isa<transform::TransformValueHandleTypeInterface>(resultType)) {
      results.setValues(result, state.getPayloadValues(terminatorOperand));
      continue;
    }
    // Only parameter-typed results are expected past this point.
    assert(llvm::isa<transform::TransformParamTypeInterface>(resultType) &&
           "unhandled transform type interface");
    results.setParams(result, state.getParams(terminatorOperand));
  }
}
|
|
|
|
/// Builds a TransformState from `region` and `payloadRoot` and returns it by
/// value. The name indicates it is intended for tests.
transform::TransformState transform::detail::makeTransformStateForTesting(
    Region *region, Operation *payloadRoot) {
  // Returned as a temporary so no copy or move of the state is required.
  return TransformState(region, payloadRoot);
}
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
// Utilities for PossibleTopLevelTransformOpTrait.
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
/// Appends to `effects` the memory effect instances on `target` with the same
|
|
/// resource and effect as the ones the operation `iface` having on `source`.
|
|
static void
|
|
remapEffects(MemoryEffectOpInterface iface, BlockArgument source,
|
|
OpOperand *target,
|
|
SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
|
|
SmallVector<MemoryEffects::EffectInstance> nestedEffects;
|
|
iface.getEffectsOnValue(source, nestedEffects);
|
|
for (const auto &effect : nestedEffects)
|
|
effects.emplace_back(effect.getEffect(), target, effect.getResource());
|
|
}
|
|
|
|
/// Appends to `effects` the same effects as the operations of `block` have on
|
|
/// block arguments but associated with `operands.`
|
|
static void
|
|
remapArgumentEffects(Block &block, MutableArrayRef<OpOperand> operands,
|
|
SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
|
|
for (Operation &op : block) {
|
|
auto iface = dyn_cast<MemoryEffectOpInterface>(&op);
|
|
if (!iface)
|
|
continue;
|
|
|
|
for (auto &&[source, target] : llvm::zip(block.getArguments(), operands)) {
|
|
remapEffects(iface, source, &target, effects);
|
|
}
|
|
|
|
SmallVector<MemoryEffects::EffectInstance> nestedEffects;
|
|
iface.getEffectsOnResource(transform::PayloadIRResource::get(),
|
|
nestedEffects);
|
|
llvm::append_range(effects, nestedEffects);
|
|
}
|
|
}
|
|
|
|
/// Populates `effects` with the memory effects of `operation`, whose body is
/// `body`.
///
/// The operands of `operation` are always marked as read and its results as
/// produced. Then:
///   - if `root` is null, the effects reported by every op in `body` that
///     implements `MemoryEffectOpInterface` are appended as is;
///   - otherwise, the effects the ops in `body` have on the block arguments
///     are appended re-attached to the operands of `operation`, together with
///     their effects on the payload IR resource.
void transform::detail::getPotentialTopLevelEffects(
    Operation *operation, Value root, Block &body,
    SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
  transform::onlyReadsHandle(operation->getOpOperands(), effects);
  transform::producesHandle(operation->getOpResults(), effects);

  if (!root) {
    for (Operation &op : body) {
      auto iface = dyn_cast<MemoryEffectOpInterface>(&op);
      if (!iface)
        continue;

      // Nested effects are appended directly to the output list; no
      // intermediate buffer is needed.
      iface.getEffects(effects);
    }
    return;
  }

  // Carry over all effects on arguments of the entry block as those on the
  // operands, this is the same value just remapped.
  remapArgumentEffects(body, operation->getOpOperands(), effects);
}
|
|
|
|
/// Maps the arguments of the entry block of `region` in `state`.
///
/// If `op` has operands, the first block argument is mapped to the payload
/// operations associated with operand #0 and the trailing block arguments to
/// the mappings prepared from the remaining operands. Otherwise, the first
/// block argument is mapped to the operation returned by
/// `state.getTopLevel()` and the trailing arguments to the top-level mappings
/// stored in `state`; a mismatch between the number of those mappings and the
/// number of trailing block arguments is reported as an error.
///
/// Returns failure if the count check or any of the mapping calls fails.
LogicalResult transform::detail::mapPossibleTopLevelTransformOpBlockArguments(
    TransformState &state, Operation *op, Region &region) {
  SmallVector<Operation *> targets;
  SmallVector<SmallVector<MappedValue>> extraMappings;
  if (op->getNumOperands() != 0) {
    llvm::append_range(targets, state.getPayloadOps(op->getOperand(0)));
    prepareValueMappings(extraMappings, op->getOperands().drop_front(), state);
  } else {
    // Without operands, the extra bindings come from the state itself and
    // must match the trailing block arguments one-to-one.
    if (state.getNumTopLevelMappings() !=
        region.front().getNumArguments() - 1) {
      return emitError(op->getLoc())
             << "operation expects " << region.front().getNumArguments() - 1
             << " extra value bindings, but " << state.getNumTopLevelMappings()
             << " were provided to the interpreter";
    }

    targets.push_back(state.getTopLevel());

    for (unsigned i = 0, e = state.getNumTopLevelMappings(); i < e; ++i)
      extraMappings.push_back(llvm::to_vector(state.getTopLevelMapping(i)));
  }

  if (failed(state.mapBlockArguments(region.front().getArgument(0), targets)))
    return failure();

  // Trailing argument #N uses the extra mapping at index N - 1.
  for (BlockArgument argument : region.front().getArguments().drop_front()) {
    if (failed(state.mapBlockArgument(
            argument, extraMappings[argument.getArgNumber() - 1])))
      return failure();
  }

  return success();
}
|
|
|
|
/// Verifies the structural requirements on `op`:
///   - it has at least one region, and the first one has exactly one block;
///   - that block has at least one argument, the first of which has a type
///     implementing `TransformHandleTypeInterface` and, when `op` has
///     operands, equal to the type of operand #0;
///   - all trailing block arguments have types implementing one of the
///     transform handle, value handle or parameter type interfaces;
///   - when `op` has an ancestor with `PossibleTopLevelTransformOpTrait`, it
///     has as many operands as the block has arguments.
/// Emits an op error and returns failure on the first violated requirement.
LogicalResult
transform::detail::verifyPossibleTopLevelTransformOpTrait(Operation *op) {
  // Attaching this trait without the interface is a misuse of the API, but it
  // cannot be caught via a static_assert because interface registration is
  // dynamic.
  assert(isa<TransformOpInterface>(op) &&
         "should implement TransformOpInterface to have "
         "PossibleTopLevelTransformOpTrait");

  if (op->getNumRegions() < 1)
    return op->emitOpError() << "expects at least one region";

  Region &bodyRegion = op->getRegion(0);
  if (!llvm::hasNItems(bodyRegion, 1))
    return op->emitOpError() << "expects a single-block region";

  Block &entry = bodyRegion.front();
  if (entry.getNumArguments() == 0)
    return op->emitOpError()
           << "expects the entry block to have at least one argument";

  BlockArgument firstArg = entry.getArgument(0);
  if (!llvm::isa<TransformHandleTypeInterface>(firstArg.getType()))
    return op->emitOpError()
           << "expects the first entry block argument to be of type "
              "implementing TransformHandleTypeInterface";

  if (op->getNumOperands() != 0 &&
      firstArg.getType() != op->getOperand(0).getType())
    return op->emitOpError()
           << "expects the type of the block argument to match "
              "the type of the operand";

  for (BlockArgument trailingArg : entry.getArguments().drop_front()) {
    const bool hasTransformType =
        llvm::isa<TransformHandleTypeInterface, TransformParamTypeInterface,
                  TransformValueHandleTypeInterface>(trailingArg.getType());
    if (hasTransformType)
      continue;

    InFlightDiagnostic diag =
        op->emitOpError()
        << "expects trailing entry block arguments to be of type implementing "
           "TransformHandleTypeInterface, TransformValueHandleTypeInterface or "
           "TransformParamTypeInterface";
    diag.attachNote() << "argument #" << trailingArg.getArgNumber()
                      << " does not";
    return diag;
  }

  Operation *parent =
      op->getParentWithTrait<PossibleTopLevelTransformOpTrait>();
  if (parent && op->getNumOperands() != entry.getNumArguments()) {
    InFlightDiagnostic diag =
        op->emitOpError()
        << "expects operands to be provided for a nested op";
    diag.attachNote(parent->getLoc())
        << "nested in another possible top-level op";
    return diag;
  }

  return success();
}
|
|
|
|
//===----------------------------------------------------------------------===//
// Utilities for ParamProducedTransformOpTrait.
//===----------------------------------------------------------------------===//
|
|
|
|
/// Appends to `effects`, in this order: the "produces" effects for all results
/// of `op`, a read effect for each of its operands and, if at least one
/// operand has a type implementing the transform handle or value handle type
/// interface, a read effect on the payload.
void transform::detail::getParamProducerTransformOpTraitEffects(
    Operation *op, SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
  producesHandle(op->getResults(), effects);
  onlyReadsHandle(op->getOpOperands(), effects);

  const bool readsFromPayload =
      llvm::any_of(op->getOperandTypes(), [](Type operandType) {
        return llvm::isa<TransformHandleTypeInterface,
                         TransformValueHandleTypeInterface>(operandType);
      });
  if (readsFromPayload)
    onlyReadsPayload(effects);
}
|
|
|
|
/// Checks that every result type of `op` implements
/// `TransformParamTypeInterface`, emitting an op error otherwise. Aborts with
/// a fatal error if the operation name of `op` does not provide
/// `MemoryEffectOpInterface`.
LogicalResult
transform::detail::verifyParamProducerTransformOpTrait(Operation *op) {
  // Interfaces can be attached dynamically, so this cannot be a static
  // assert.
  OperationName opName = op->getName();
  if (!opName.getInterface<MemoryEffectOpInterface>()) {
    llvm::report_fatal_error(
        Twine("ParamProducerTransformOpTrait must be attached to an op that "
              "implements MemoryEffectsOpInterface, found on ") +
        opName.getStringRef());
  }

  const bool onlyParamResults =
      llvm::all_of(op->getResultTypes(), [](Type resultType) {
        return llvm::isa<TransformParamTypeInterface>(resultType);
      });
  if (onlyParamResults)
    return success();

  return op->emitOpError()
         << "ParamProducerTransformOpTrait attached to this op expects "
            "result types to implement TransformParamTypeInterface";
}
|
|
|
|
//===----------------------------------------------------------------------===//
// Memory effects.
//===----------------------------------------------------------------------===//
|
|
|
|
/// Appends to `effects`, for each operand in `handles`, a `Read` followed by
/// a `Free` effect on the transform mapping resource.
void transform::consumesHandle(
    MutableArrayRef<OpOperand> handles,
    SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
  auto *mappingResource = TransformMappingResource::get();
  for (OpOperand &operand : handles) {
    OpOperand *consumed = &operand;
    effects.emplace_back(MemoryEffects::Read::get(), consumed,
                         mappingResource);
    effects.emplace_back(MemoryEffects::Free::get(), consumed,
                         mappingResource);
  }
}
|
|
|
|
/// Returns `true` if the given list of effects instances contains an instance
|
|
/// with the effect type specified as template parameter.
|
|
template <typename EffectTy, typename ResourceTy, typename Range>
|
|
static bool hasEffect(Range &&effects) {
|
|
return llvm::any_of(effects, [](const MemoryEffects::EffectInstance &effect) {
|
|
return isa<EffectTy>(effect.getEffect()) &&
|
|
isa<ResourceTy>(effect.getResource());
|
|
});
|
|
}
|
|
|
|
/// Returns `true` if the effects `transform` declares on `handle` include
/// both a `Read` and a `Free` on the transform mapping resource.
bool transform::isHandleConsumed(Value handle,
                                 transform::TransformOpInterface transform) {
  Operation *transformOp = transform.getOperation();
  auto memIface = cast<MemoryEffectOpInterface>(transformOp);

  SmallVector<MemoryEffects::EffectInstance> handleEffects;
  memIface.getEffectsOnValue(handle, handleEffects);

  if (!::hasEffect<MemoryEffects::Read, TransformMappingResource>(
          handleEffects))
    return false;
  return ::hasEffect<MemoryEffects::Free, TransformMappingResource>(
      handleEffects);
}
|
|
|
|
/// Appends to `effects`, for each op result in `handles`, an `Allocate`
/// followed by a `Write` effect on the transform mapping resource.
void transform::producesHandle(
    ResultRange handles,
    SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
  auto *mappingResource = TransformMappingResource::get();
  for (OpResult produced : handles) {
    effects.emplace_back(MemoryEffects::Allocate::get(), produced,
                         mappingResource);
    effects.emplace_back(MemoryEffects::Write::get(), produced,
                         mappingResource);
  }
}
|
|
|
|
/// Appends to `effects`, for each block argument in `handles`, an `Allocate`
/// followed by a `Write` effect on the transform mapping resource.
void transform::producesHandle(
    MutableArrayRef<BlockArgument> handles,
    SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
  auto *mappingResource = TransformMappingResource::get();
  for (BlockArgument produced : handles) {
    effects.emplace_back(MemoryEffects::Allocate::get(), produced,
                         mappingResource);
    effects.emplace_back(MemoryEffects::Write::get(), produced,
                         mappingResource);
  }
}
|
|
|
|
/// Appends to `effects`, for each operand in `handles`, a single `Read`
/// effect on the transform mapping resource.
void transform::onlyReadsHandle(
    MutableArrayRef<OpOperand> handles,
    SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
  auto *mappingResource = TransformMappingResource::get();
  for (OpOperand &operand : handles) {
    OpOperand *readOnly = &operand;
    effects.emplace_back(MemoryEffects::Read::get(), readOnly,
                         mappingResource);
  }
}
|
|
|
|
/// Appends to `effects` a `Read` followed by a `Write` effect on the payload
/// IR resource; neither is attached to a value.
void transform::modifiesPayload(
    SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
  auto *payloadResource = PayloadIRResource::get();
  effects.emplace_back(MemoryEffects::Read::get(), payloadResource);
  effects.emplace_back(MemoryEffects::Write::get(), payloadResource);
}
|
|
|
|
/// Appends to `effects` a single `Read` effect on the payload IR resource,
/// not attached to any value.
void transform::onlyReadsPayload(
    SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
  auto *payloadResource = PayloadIRResource::get();
  effects.emplace_back(MemoryEffects::Read::get(), payloadResource);
}
|
|
|
|
/// Returns `true` if the memory effects declared by `transform` contain a
/// `Write` on the payload IR resource.
bool transform::doesModifyPayload(transform::TransformOpInterface transform) {
  Operation *transformOp = transform.getOperation();
  auto memIface = cast<MemoryEffectOpInterface>(transformOp);

  SmallVector<MemoryEffects::EffectInstance> declaredEffects;
  memIface.getEffects(declaredEffects);

  const bool writesPayload =
      ::hasEffect<MemoryEffects::Write, PayloadIRResource>(declaredEffects);
  return writesPayload;
}
|
|
|
|
/// Returns `true` if the memory effects declared by `transform` contain a
/// `Read` on the payload IR resource.
bool transform::doesReadPayload(transform::TransformOpInterface transform) {
  Operation *transformOp = transform.getOperation();
  auto memIface = cast<MemoryEffectOpInterface>(transformOp);

  SmallVector<MemoryEffects::EffectInstance> declaredEffects;
  memIface.getEffects(declaredEffects);

  const bool readsPayload =
      ::hasEffect<MemoryEffects::Read, PayloadIRResource>(declaredEffects);
  return readsPayload;
}
|
|
|
|
/// Inserts into `consumedArguments` the position of every argument of `block`
/// on which an operation directly contained in `block` declares a `Free`
/// effect on the transform mapping resource. Operations not implementing
/// `MemoryEffectOpInterface` are ignored.
void transform::getConsumedBlockArguments(
    Block &block, llvm::SmallDenseSet<unsigned int> &consumedArguments) {
  // Scratch buffer reused across operations.
  SmallVector<MemoryEffects::EffectInstance> opEffects;
  for (Operation &nested : block) {
    auto memIface = dyn_cast<MemoryEffectOpInterface>(nested);
    if (!memIface)
      continue;

    opEffects.clear();
    memIface.getEffects(opEffects);
    for (const MemoryEffects::EffectInstance &instance : opEffects) {
      // Only effects attached to an argument of this very block matter.
      BlockArgument blockArg =
          dyn_cast_or_null<BlockArgument>(instance.getValue());
      if (!blockArg || blockArg.getOwner() != &block)
        continue;
      if (!isa<MemoryEffects::Free>(instance.getEffect()))
        continue;
      if (instance.getResource() != transform::TransformMappingResource::get())
        continue;

      consumedArguments.insert(blockArg.getArgNumber());
    }
  }
}
|
|
|
|
//===----------------------------------------------------------------------===//
// Utilities for TransformOpInterface.
//===----------------------------------------------------------------------===//
|
|
|
|
/// Returns, in operand order, pointers to the operands of `transformOp` for
/// which the op declares a `Free` effect on the transform mapping resource.
SmallVector<OpOperand *> transform::detail::getConsumedHandleOpOperands(
    TransformOpInterface transformOp) {
  // Matches effect instances that free the transform mapping resource.
  auto freesMapping = [](const MemoryEffects::EffectInstance &instance) {
    return isa<transform::TransformMappingResource>(instance.getResource()) &&
           isa<MemoryEffects::Free>(instance.getEffect());
  };

  auto memIface = cast<MemoryEffectOpInterface>(transformOp.getOperation());

  SmallVector<OpOperand *> consumed;
  consumed.reserve(transformOp->getNumOperands());

  // Scratch buffer reused across operands.
  SmallVector<MemoryEffects::EffectInstance, 2> operandEffects;
  for (OpOperand &operand : transformOp->getOpOperands()) {
    operandEffects.clear();
    memIface.getEffectsOnValue(operand.get(), operandEffects);
    if (llvm::any_of(operandEffects, freesMapping))
      consumed.push_back(&operand);
  }
  return consumed;
}
|
|
|
|
/// Verifies the memory effects declared by `op`:
///   - every operand has at least one effect attached to it, and none of
///     those is an `Allocate` on the transform mapping resource;
///   - if some operand has a `Free` on the transform mapping resource, the op
///     also declares a `Write` on the payload IR resource;
///   - every result has an `Allocate` on the transform mapping resource.
/// Emits an error with an explanatory note and returns failure on the first
/// violation.
LogicalResult transform::detail::verifyTransformOpInterface(Operation *op) {
  auto memIface = cast<MemoryEffectOpInterface>(op);
  SmallVector<MemoryEffects::EffectInstance> allEffects;
  memIface.getEffects(allEffects);

  // Produces a filtered view of `allEffects` restricted to the instances
  // attached to `value`.
  auto effectsOn = [&allEffects](Value value) {
    return llvm::make_filter_range(
        allEffects, [value](const MemoryEffects::EffectInstance &instance) {
          return instance.getValue() == value;
        });
  };

  std::optional<unsigned> firstConsumed;
  for (OpOperand &operand : op->getOpOperands()) {
    auto operandEffects = effectsOn(operand.get());

    if (operandEffects.empty()) {
      InFlightDiagnostic diag =
          op->emitError() << "TransformOpInterface requires memory effects "
                             "on operands to be specified";
      diag.attachNote() << "no effects specified for operand #"
                        << operand.getOperandNumber();
      return diag;
    }

    if (::hasEffect<MemoryEffects::Allocate, TransformMappingResource>(
            operandEffects)) {
      InFlightDiagnostic diag = op->emitError()
                                << "TransformOpInterface did not expect "
                                   "'allocate' memory effect on an operand";
      diag.attachNote() << "specified for operand #"
                        << operand.getOperandNumber();
      return diag;
    }

    // Only the first consumed operand is remembered for the diagnostic below.
    if (firstConsumed.has_value())
      continue;
    if (::hasEffect<MemoryEffects::Free, TransformMappingResource>(
            operandEffects))
      firstConsumed = operand.getOperandNumber();
  }

  if (firstConsumed.has_value() &&
      !::hasEffect<MemoryEffects::Write, PayloadIRResource>(allEffects)) {
    InFlightDiagnostic diag =
        op->emitError()
        << "TransformOpInterface expects ops consuming operands to have a "
           "'write' effect on the payload resource";
    diag.attachNote() << "consumes operand #" << *firstConsumed;
    return diag;
  }

  for (OpResult result : op->getResults()) {
    auto resultEffects = effectsOn(result);
    if (::hasEffect<MemoryEffects::Allocate, TransformMappingResource>(
            resultEffects))
      continue;

    InFlightDiagnostic diag =
        op->emitError() << "TransformOpInterface requires 'allocate' memory "
                           "effect to be specified for results";
    diag.attachNote() << "no 'allocate' effect specified for result #"
                      << result.getResultNumber();
    return diag;
  }

  return success();
}
|
|
|
|
//===----------------------------------------------------------------------===//
// Entry point.
//===----------------------------------------------------------------------===//
|
|
|
|
/// Applies `transform` to the payload rooted at `payloadRoot`.
///
/// When `enforceToplevelTransformOp` is set, `transform` must carry
/// `PossibleTopLevelTransformOpTrait` and have no operands; otherwise it must
/// pass `verifyPossibleTopLevelTransformOpTrait`. A `TransformState` is then
/// built from the parent region of `transform`, `payloadRoot`, `extraMapping`
/// and `options`. If provided, `stateInitializer` is called on the state
/// before the transform is applied, and `stateExporter` is called on it after
/// a successful application, its result becoming the result of this function.
LogicalResult transform::applyTransforms(
    Operation *payloadRoot, TransformOpInterface transform,
    const RaggedArray<MappedValue> &extraMapping,
    const TransformOptions &options, bool enforceToplevelTransformOp,
    function_ref<void(TransformState &)> stateInitializer,
    function_ref<LogicalResult(TransformState &)> stateExporter) {
  if (!enforceToplevelTransformOp) {
    if (failed(detail::verifyPossibleTopLevelTransformOpTrait(transform)))
      return failure();
  } else {
    const bool startsAtTopLevel =
        transform->hasTrait<PossibleTopLevelTransformOpTrait>() &&
        transform->getNumOperands() == 0;
    if (!startsAtTopLevel)
      return transform->emitError()
             << "expected transform to start at the top-level transform op";
  }

  TransformState state(transform->getParentRegion(), payloadRoot, extraMapping,
                       options);
  if (stateInitializer)
    stateInitializer(state);

  if (failed(state.applyTransform(transform).checkAndReport()))
    return failure();

  return stateExporter ? stateExporter(state) : success();
}
|
|
|
|
//===----------------------------------------------------------------------===//
// Generated interface implementation.
//===----------------------------------------------------------------------===//
|
|
|
|
#include "mlir/Dialect/Transform/Interfaces/TransformInterfaces.cpp.inc"
|
|
#include "mlir/Dialect/Transform/Interfaces/TransformTypeInterfaces.cpp.inc"
|