Adapt the implementation of TransformEachOpTrait to the existence of parameter values recently introduced into the transform dialect. In particular, allow `applyToOne` hooks to return a list containing a mix of `Operation *` that will be associated with handles and `Attribute` that will be associated with parameter values by the trait implementation of the transform interface's `apply` method. Disentangle the "transposition" of the list of per-payload op partial results to decrease its overall complexity and detemplatize the code that doesn't really need templates. This removes the poorly documented special handling for single-result ops with TransformEachOpTrait that could have assigned null pointer values to handles. Reviewed By: springerm Differential Revision: https://reviews.llvm.org/D140979
679 lines
26 KiB
C++
679 lines
26 KiB
C++
//===- TransformInterfaces.cpp - Transform Dialect Interfaces -------------===//
|
|
//
|
|
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
|
// See https://llvm.org/LICENSE.txt for license information.
|
|
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
|
//
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
#include "mlir/Dialect/Transform/IR/TransformInterfaces.h"
|
|
#include "mlir/Dialect/Transform/IR/TransformTypes.h"
|
|
#include "mlir/IR/Diagnostics.h"
|
|
#include "mlir/IR/Operation.h"
|
|
#include "llvm/ADT/STLExtras.h"
|
|
#include "llvm/ADT/ScopeExit.h"
|
|
#include "llvm/Support/Debug.h"
|
|
#include "llvm/Support/ErrorHandling.h"
|
|
|
|
#define DEBUG_TYPE "transform-dialect"
|
|
#define DEBUG_PRINT_AFTER_ALL "transform-dialect-print-top-level-after-all"
|
|
#define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE "] ")
|
|
|
|
using namespace mlir;
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
// TransformState
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
// Out-of-line definition of the static constexpr member `kTopLevelValue`.
// Since C++17 such members are implicitly inline, so this definition is
// redundant there but harmless.
constexpr const Value transform::TransformState::kTopLevelValue;
|
|
|
|
/// Creates a transform state rooted at `payloadRoot` and opens the (empty)
/// mapping scope for `region`.
transform::TransformState::TransformState(Region *region,
                                          Operation *payloadRoot,
                                          const TransformOptions &options)
    : topLevel(payloadRoot), options(options) {
  // A freshly constructed state must not already know about the region.
  bool isNewScope = mappings.try_emplace(region).second;
  assert(isNewScope && "the region scope is already present");
  (void)isNewScope;
#if LLVM_ENABLE_ABI_BREAKING_CHECKS
  // Only maintained when ABI-breaking checks are enabled; presumably used to
  // validate region scope nesting.
  regionStack.push_back(region);
#endif // LLVM_ENABLE_ABI_BREAKING_CHECKS
}
|
|
|
|
/// Returns the payload root operation this state was constructed with.
Operation *transform::TransformState::getTopLevel() const {
  return topLevel;
}
|
|
|
|
/// Returns the payload operations associated with the handle `value`. The
/// handle must already be mapped (asserted); passing a parameter-typed value
/// is a misuse of the API.
ArrayRef<Operation *>
transform::TransformState::getPayloadOps(Value value) const {
  const TransformOpMapping &directMapping = getMapping(value).direct;
  auto found = directMapping.find(value);
  assert(found != directMapping.end() &&
         "cannot find mapping for payload handle (param handle provided?)");
  return found->getSecond();
}
|
|
|
|
/// Returns the parameters associated with the parameter handle `value`. The
/// handle must already be mapped (asserted); passing a payload handle is a
/// misuse of the API.
ArrayRef<Attribute> transform::TransformState::getParams(Value value) const {
  const ParamMapping &paramMapping = getMapping(value).params;
  auto found = paramMapping.find(value);
  assert(found != paramMapping.end() &&
         "cannot find mapping for param handle (payload handle provided?)");
  return found->getSecond();
}
|
|
|
|
/// Appends to `handles` every handle, from any region scope, that is
/// associated with the payload operation `op`. Succeeds only if at least one
/// scope has a reverse-mapping entry for `op`.
LogicalResult transform::TransformState::getHandlesForPayloadOp(
    Operation *op, SmallVectorImpl<Value> &handles) const {
  bool anyFound = false;
  for (const Mappings &scope : llvm::make_second_range(mappings)) {
    auto reverseIt = scope.reverse.find(op);
    if (reverseIt == scope.reverse.end())
      continue;
    llvm::append_range(handles, reverseIt->getSecond());
    anyFound = true;
  }
  return success(anyFound);
}
|
|
|
|
/// Associates the payload operations `targets` with the handle `value`, in
/// both the direct (handle -> ops) and reverse (op -> handles) mappings.
/// Fails, after reporting, if the handle type rejects the payload.
LogicalResult
transform::TransformState::setPayloadOps(Value value,
                                         ArrayRef<Operation *> targets) {
  assert(value != kTopLevelValue &&
         "attempting to reset the transformation root");
  assert(!value.getType().isa<TransformParamTypeInterface>() &&
         "cannot associate payload ops with a value of parameter type");

  // Let the handle type verify that it accepts this payload.
  auto handleType = value.getType().cast<TransformHandleTypeInterface>();
  DiagnosedSilenceableFailure payloadCheck =
      handleType.checkPayload(value.getLoc(), targets);
  if (failed(payloadCheck.checkAndReport()))
    return failure();

  // The value must have been cleared before being given a new payload;
  // anything else is an API misuse and is asserted against.
  SmallVector<Operation *> ownedTargets(targets.begin(), targets.end());
  Mappings &scope = getMapping(value);
  bool isNewEntry =
      scope.direct.insert({value, std::move(ownedTargets)}).second;
  assert(isNewEntry && "value is already associated with another list");
  (void)isNewEntry;

  // Record the inverse association for every target op.
  for (Operation *target : targets)
    scope.reverse[target].push_back(value);

  return success();
}
|
|
|
|
/// Associates the parameter list `params` with the parameter-typed handle
/// `value`. Fails, after reporting, if the parameter type rejects the values.
/// Associating a value that already has parameters is an API misuse
/// (asserted).
LogicalResult transform::TransformState::setParams(Value value,
                                                   ArrayRef<Param> params) {
  assert(value != nullptr && "attempting to set params for a null value");

  // Check the *type* here: `value` is already known to be non-null, so
  // asserting on it would let a non-parameter-typed value through and the
  // `checkPayload` call below would then be made on a null interface.
  auto valueType = value.getType().dyn_cast<TransformParamTypeInterface>();
  assert(valueType &&
         "cannot associate parameter with a value of non-parameter type");
  DiagnosedSilenceableFailure result =
      valueType.checkPayload(value.getLoc(), params);
  if (failed(result.checkAndReport()))
    return failure();

  Mappings &mappings = getMapping(value);
  bool inserted =
      mappings.params.insert({value, llvm::to_vector(params)}).second;
  assert(inserted && "value is already associated with another list of params");
  (void)inserted;
  return success();
}
|
|
|
|
/// Removes `value` from the list of handles that the reverse mapping
/// associates with `op`, and erases the entry for `op` once that list becomes
/// empty. Does nothing if `op` has no reverse-mapping entry.
void transform::TransformState::dropReverseMapping(Mappings &mappings,
                                                   Operation *op, Value value) {
  auto entry = mappings.reverse.find(op);
  if (entry != mappings.reverse.end()) {
    auto &handlesOfOp = entry->getSecond();
    llvm::erase_value(handlesOfOp, value);
    if (handlesOfOp.empty())
      mappings.reverse.erase(entry);
  }
}
|
|
|
|
/// Removes the association between the handle `value` and its payload
/// operations, from both the direct and the reverse mappings. Does nothing if
/// `value` is not mapped.
void transform::TransformState::removePayloadOps(Value value) {
  Mappings &scope = getMapping(value);
  // Use `find` rather than `operator[]`: the latter would default-insert an
  // entry for an unmapped value only for it to be erased right away.
  auto it = scope.direct.find(value);
  if (it == scope.direct.end())
    return;

  // `dropReverseMapping` only touches the reverse map, so `it` stays valid.
  for (Operation *op : it->getSecond())
    dropReverseMapping(scope, op, value);
  scope.direct.erase(it);
}
|
|
|
|
/// Replaces every payload op `op` associated with the handle `value` by
/// `callback(op)`, dropping ops for which the callback returns null. Updates
/// the direct and reverse mappings consistently. Fails, after reporting, if
/// the handle type rejects the updated payload; the state is then unchanged.
LogicalResult transform::TransformState::updatePayloadOps(
    Value value, function_ref<Operation *(Operation *)> callback) {
  Mappings &scope = getMapping(value);
  auto it = scope.direct.find(value);
  assert(it != scope.direct.end() && "unknown handle");
  SmallVector<Operation *, 2> &association = it->getSecond();

  // Phase 1: compute the updated association without touching any mapping.
  SmallVector<Operation *, 2> updated;
  updated.reserve(association.size());
  for (Operation *op : association)
    if (Operation *updatedOp = callback(op))
      updated.push_back(updatedOp);

  // Validate before mutating, so a failure leaves both mappings untouched
  // instead of a reverse map that disagrees with the direct one.
  auto iface = value.getType().cast<TransformHandleTypeInterface>();
  DiagnosedSilenceableFailure result =
      iface.checkPayload(value.getLoc(), updated);
  if (failed(result.checkAndReport()))
    return failure();

  // Phase 2: drop all old reverse entries before adding the new ones. Doing
  // both per-op in one pass would erase an entry just added for an updated op
  // that also appears later in the old list.
  for (Operation *op : association)
    dropReverseMapping(scope, op, value);
  for (Operation *op : updated)
    scope.reverse[op].push_back(value);

  association = std::move(updated);
  return success();
}
|
|
|
|
/// Records that `otherHandle` is invalidated by the consumption of `handle`
/// when `payloadOp` (associated with `otherHandle`) has one of the payload ops
/// of `handle` as an ancestor (per `Operation::isAncestor`). The recorded
/// callback emits the diagnostic later, when the invalidated handle is used.
void transform::TransformState::recordHandleInvalidationOne(
    OpOperand &handle, Operation *payloadOp, Value otherHandle) {
  ArrayRef<Operation *> candidateAncestors = getPayloadOps(handle.get());
  // Skip handles that are already invalidated: their payload ops may refer to
  // invalid IR.
  if (invalidatedHandles.count(otherHandle))
    return;

  for (Operation *ancestor : candidateAncestors) {
    if (ancestor->isAncestor(payloadOp)) {
      // The callback outlives this call, so it captures only values. Payload
      // locations are extracted now because the ops may be erased before the
      // callback runs.
      Location ancestorLoc = ancestor->getLoc();
      Location nestedLoc = payloadOp->getLoc();
      Operation *consumer = handle.getOwner();
      unsigned consumedOperandNo = handle.getOperandNumber();
      invalidatedHandles[otherHandle] = [ancestorLoc, nestedLoc, consumer,
                                         consumedOperandNo,
                                         otherHandle](Location currentLoc) {
        InFlightDiagnostic diag = emitError(currentLoc)
                                  << "op uses a handle invalidated by a "
                                     "previously executed transform op";
        diag.attachNote(otherHandle.getLoc()) << "handle to invalidated ops";
        diag.attachNote(consumer->getLoc())
            << "invalidated by this transform op that consumes its operand #"
            << consumedOperandNo
            << " and invalidates handles to payload ops nested in payload "
               "ops associated with the consumed handle";
        diag.attachNote(ancestorLoc) << "ancestor payload op";
        diag.attachNote(nestedLoc) << "nested payload op";
      };
    }
  }
}
|
|
|
|
/// Checks every (payload op, handle) association known in any region scope
/// against the consumed `handle`, recording the handles it invalidates.
void transform::TransformState::recordHandleInvalidation(OpOperand &handle) {
  for (const Mappings &scope : llvm::make_second_range(mappings)) {
    for (const auto &entry : scope.reverse) {
      Operation *payloadOp = entry.first;
      for (Value otherHandle : entry.second)
        recordHandleInvalidationOne(handle, payloadOp, otherHandle);
    }
  }
}
|
|
|
|
/// Before `transform` runs: reports (and fails on) the use of a handle that
/// an earlier transform invalidated, unless the op allows repeated handle
/// operands. Also records the invalidation caused by every operand this op
/// frees on the transform mapping resource.
LogicalResult transform::TransformState::checkAndRecordHandleInvalidation(
    TransformOpInterface transform) {
  SmallVector<MemoryEffects::EffectInstance> mappingEffects;
  cast<MemoryEffectOpInterface>(transform.getOperation())
      .getEffectsOnResource(transform::TransformMappingResource::get(),
                            mappingEffects);

  for (OpOperand &operand : transform->getOpOperands()) {
    auto invalidated = invalidatedHandles.find(operand.get());
    if (!transform.allowsRepeatedHandleOperands() &&
        invalidated != invalidatedHandles.end()) {
      invalidated->getSecond()(transform->getLoc());
      return failure();
    }

    // Consuming (freeing) the operand invalidates handles to payload ops
    // nested in the payload ops associated with it.
    bool isConsumed = llvm::any_of(
        mappingEffects, [&](const MemoryEffects::EffectInstance &effect) {
          return isa<MemoryEffects::Free>(effect.getEffect()) &&
                 effect.getValue() == operand.get();
        });
    if (isConsumed)
      recordHandleInvalidation(operand);
  }

  return success();
}
|
|
|
|
/// Applies `transform` to the payload. Runs the expensive handle checks if
/// enabled, drops the mappings of operands the op frees, and associates the
/// op's results with the payload ops or parameters it produced. A silenceable
/// failure does not short-circuit result association, so "suppress" mode can
/// proceed on a best-effort basis; a definite failure is returned directly.
DiagnosedSilenceableFailure
transform::TransformState::applyTransform(TransformOpInterface transform) {
  LLVM_DEBUG(DBGS() << "applying: " << transform << "\n");
  // Dump the payload on every early exit; released right before the final
  // return, which prints the payload itself.
  auto printOnFailureRAII = llvm::make_scope_exit([this] {
    // Avoids an unused-capture warning when DEBUG_WITH_TYPE expands to
    // nothing.
    (void)this;
    DEBUG_WITH_TYPE(DEBUG_PRINT_AFTER_ALL, {
      DBGS() << "Top-level payload:\n";
      getTopLevel()->print(llvm::dbgs(),
                           mlir::OpPrintingFlags().printGenericOpForm());
    });
  });
  if (options.getExpensiveChecksEnabled()) {
    if (failed(checkAndRecordHandleInvalidation(transform)))
      return DiagnosedSilenceableFailure::definiteFailure();

    // A consumed handle must not point to the same payload op twice.
    for (OpOperand &operand : transform->getOpOperands()) {
      if (!isHandleConsumed(operand.get(), transform))
        continue;

      DenseSet<Operation *> seen;
      for (Operation *op : getPayloadOps(operand.get())) {
        if (!seen.insert(op).second) {
          DiagnosedSilenceableFailure diag =
              transform.emitSilenceableError()
              << "a handle passed as operand #" << operand.getOperandNumber()
              << " and consumed by this operation points to a payload "
                 "operation more than once";
          diag.attachNote(op->getLoc()) << "repeated target op";
          return diag;
        }
      }
    }
  }

  transform::TransformResults results(transform->getNumResults());
  // Compute the result but do not short-circuit the silenceable failure case as
  // we still want the handles to propagate properly so the "suppress" mode can
  // proceed on a best effort basis.
  DiagnosedSilenceableFailure result(transform.apply(results, *this));
  if (result.isDefiniteFailure())
    return result;

  // Remove the mapping for the operand if it is consumed by the operation. This
  // allows us to catch use-after-free with assertions later on.
  auto memEffectInterface =
      cast<MemoryEffectOpInterface>(transform.getOperation());
  SmallVector<MemoryEffects::EffectInstance, 2> effects;
  for (OpOperand &target : transform->getOpOperands()) {
    effects.clear();
    memEffectInterface.getEffectsOnValue(target.get(), effects);
    if (llvm::any_of(effects, [](const MemoryEffects::EffectInstance &effect) {
          return isa<transform::TransformMappingResource>(
                     effect.getResource()) &&
                 isa<MemoryEffects::Free>(effect.getEffect());
        })) {
      removePayloadOps(target.get());
    }
  }

  // Associate each op result with what the op produced for it. The loop
  // variable is deliberately not named `result`, which would shadow the
  // application status returned below.
  for (OpResult opResult : transform->getResults()) {
    assert(opResult.getDefiningOp() == transform.getOperation() &&
           "payload IR association for a value other than the result of the "
           "current transform op");
    unsigned resultNumber = opResult.getResultNumber();
    if (opResult.getType().isa<TransformParamTypeInterface>()) {
      assert(results.isParam(resultNumber) &&
             "expected parameters for the parameter-typed result");
      if (failed(setParams(opResult, results.getParams(resultNumber))))
        return DiagnosedSilenceableFailure::definiteFailure();
    } else {
      assert(!results.isParam(resultNumber) &&
             "expected payload ops for the non-parameter typed result");
      if (failed(setPayloadOps(opResult, results.get(resultNumber))))
        return DiagnosedSilenceableFailure::definiteFailure();
    }
  }

  printOnFailureRAII.release();
  DEBUG_WITH_TYPE(DEBUG_PRINT_AFTER_ALL, {
    DBGS() << "Top-level payload:\n";
    getTopLevel()->print(llvm::dbgs());
  });
  return result;
}
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
// TransformState::Extension
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
// Out-of-line, defaulted destructor of the extension base class (presumably
// kept out of line to anchor the class in this translation unit — verify
// against the header).
transform::TransformState::Extension::~Extension() = default;
|
|
|
|
/// Replaces `op` by `replacement` in every handle associated with `op`; other
/// payload ops of those handles are left as they are. Fails if `op` is not
/// associated with any handle, or if updating one of the handles fails.
LogicalResult
transform::TransformState::Extension::replacePayloadOp(Operation *op,
                                                       Operation *replacement) {
  SmallVector<Value> affectedHandles;
  if (failed(state.getHandlesForPayloadOp(op, affectedHandles)))
    return failure();

  auto substitute = [&](Operation *current) -> Operation * {
    return current == op ? replacement : current;
  };
  for (Value handle : affectedHandles)
    if (failed(state.updatePayloadOps(handle, substitute)))
      return failure();
  return success();
}
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
// TransformResults
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
/// Prepares storage for `numSegments` results. Every op segment and param
/// segment starts with a null data pointer, which marks it as "not set yet";
/// `set` and `setParams` assert on that before filling a segment.
transform::TransformResults::TransformResults(unsigned numSegments) {
  const ArrayRef<Operation *> unsetOps(nullptr, static_cast<size_t>(0));
  const ArrayRef<TransformState::Param> unsetParams(nullptr,
                                                    static_cast<size_t>(0));
  segments.resize(numSegments, unsetOps);
  paramSegments.resize(numSegments, unsetParams);
}
|
|
|
|
/// Associates the payload ops `ops` with the result `value`. Each result may be
/// set at most once (asserted). The ops are copied into `operations`, and the
/// result's segment is a view into that storage.
void transform::TransformResults::set(OpResult value,
                                      ArrayRef<Operation *> ops) {
  int64_t position = value.getResultNumber();
  assert(position < static_cast<int64_t>(segments.size()) &&
         "setting results for a non-existent handle");
  assert(segments[position].data() == nullptr && "results already set");

  // Segments are views into `operations`. Appending may reallocate that
  // storage, which would leave previously set segments dangling. Remember
  // where each set segment starts (-1 for unset ones) before appending.
  SmallVector<int64_t> offsets;
  offsets.reserve(segments.size());
  for (ArrayRef<Operation *> segment : segments) {
    offsets.push_back(segment.data() == nullptr
                          ? -1
                          : static_cast<int64_t>(segment.data() -
                                                 operations.data()));
  }

  int64_t start = operations.size();
  llvm::append_range(operations, ops);

  // Re-point previously set segments at the (possibly new) storage.
  ArrayRef<Operation *> storage = makeArrayRef(operations);
  for (size_t i = 0, e = segments.size(); i < e; ++i) {
    if (offsets[i] >= 0)
      segments[i] = storage.slice(offsets[i], segments[i].size());
  }
  segments[position] = storage.drop_front(start);
}
|
|
|
|
/// Associates the parameters `params` with the result `value`. Each result may
/// be set at most once (asserted). The parameters are copied into
/// `this->params`, and the result's param segment is a view into that storage.
void transform::TransformResults::setParams(
    OpResult value, ArrayRef<transform::TransformState::Param> params) {
  int64_t position = value.getResultNumber();
  assert(position < static_cast<int64_t>(paramSegments.size()) &&
         "setting params for a non-existent handle");
  assert(paramSegments[position].data() == nullptr && "params already set");

  // Param segments are views into `this->params`. Appending may reallocate
  // that storage, which would leave previously set segments dangling. Remember
  // where each set segment starts (-1 for unset ones) before appending.
  SmallVector<int64_t> offsets;
  offsets.reserve(paramSegments.size());
  for (ArrayRef<transform::TransformState::Param> segment : paramSegments) {
    offsets.push_back(segment.data() == nullptr
                          ? -1
                          : static_cast<int64_t>(segment.data() -
                                                 this->params.data()));
  }

  size_t start = this->params.size();
  llvm::append_range(this->params, params);

  // Re-point previously set segments at the (possibly new) storage.
  ArrayRef<transform::TransformState::Param> storage =
      makeArrayRef(this->params);
  for (size_t i = 0, e = paramSegments.size(); i < e; ++i) {
    if (offsets[i] >= 0)
      paramSegments[i] = storage.slice(offsets[i], paramSegments[i].size());
  }
  paramSegments[position] = storage.drop_front(start);
}
|
|
|
|
/// Returns the payload ops set for result #`resultNumber`. The result must
/// exist and must have been set with `set` (both asserted).
ArrayRef<Operation *>
transform::TransformResults::get(unsigned resultNumber) const {
  assert(resultNumber < segments.size() &&
         "querying results for a non-existent handle");
  ArrayRef<Operation *> segment = segments[resultNumber];
  assert(segment.data() != nullptr &&
         "querying unset results (param expected?)");
  return segment;
}
|
|
|
|
/// Returns the parameters set for result #`resultNumber`. The result must
/// exist and must have been set with `setParams` (both asserted).
ArrayRef<transform::TransformState::Param>
transform::TransformResults::getParams(unsigned resultNumber) const {
  assert(resultNumber < paramSegments.size() &&
         "querying params for a non-existent handle");
  ArrayRef<transform::TransformState::Param> segment =
      paramSegments[resultNumber];
  assert(segment.data() != nullptr &&
         "querying unset params (payload ops expected?)");
  return segment;
}
|
|
|
|
/// Returns whether result #`resultNumber` was set with parameters, i.e.
/// whether its param segment no longer has the null "unset" data pointer.
bool transform::TransformResults::isParam(unsigned resultNumber) const {
  assert(resultNumber < paramSegments.size() &&
         "querying association for a non-existent handle");
  bool hasParams = paramSegments[resultNumber].data() != nullptr;
  return hasParams;
}
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
// Utilities for TransformEachOpTrait.
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
/// Verifies the result list produced by one `applyToOne` application on the
/// payload op at `payloadOpLoc`:
///   1. it has exactly one entry per result of `transformOp`;
///   2. its entries are either all null or all non-null;
///   3. each entry has the right kind: an `Operation *` for handle-typed
///      results and an `Attribute` for parameter-typed results.
/// Emits an error and fails on the first violated rule.
LogicalResult
transform::detail::checkApplyToOne(Operation *transformOp,
                                   Location payloadOpLoc,
                                   const ApplyToEachResultList &partialResult) {
  using ResultPtr = llvm::PointerUnion<Operation *, Attribute>;

  Location transformOpLoc = transformOp->getLoc();
  StringRef transformOpName = transformOp->getName().getStringRef();
  unsigned expectedNumResults = transformOp->getNumResults();

  // Rule 1.
  // TODO: encode this implicit must always produce `expectedNumResults`
  // and nullptr is fine with a proper trait.
  if (partialResult.size() != expectedNumResults) {
    InFlightDiagnostic diag =
        mlir::emitError(transformOpLoc, "applications of ")
        << transformOpName << " expected to produce " << expectedNumResults
        << " results (actually produced " << partialResult.size() << ").";
    diag.attachNote(transformOpLoc)
        << "If you need variadic results, consider a generic `apply` "
        << "instead of the specialized `applyToOne`.";
    diag.attachNote(transformOpLoc)
        << "Producing " << expectedNumResults << " null results is "
        << "allowed if the use case warrants it.";
    diag.attachNote(payloadOpLoc) << "when applied to this op";
    return failure();
  }

  // Rule 2.
  // TODO: relax this behavior and encode with a proper trait.
  auto isNullResult = [](ResultPtr ptr) { return ptr.isNull(); };
  bool someNull = llvm::any_of(partialResult, isNullResult);
  bool someNonNull = !llvm::all_of(partialResult, isNullResult);
  if (someNull && someNonNull) {
    InFlightDiagnostic diag =
        mlir::emitError(transformOpLoc, "unexpected application of ")
        << transformOpName << " produces both null and non null results.";
    diag.attachNote(payloadOpLoc) << "when applied to this op";
    return failure();
  }

  // Rule 3.
  for (const auto &[ptr, res] :
       llvm::zip(partialResult, transformOp->getResults())) {
    Type resultType = res.getType();
    if (ptr.is<Operation *>() &&
        !resultType.isa<TransformHandleTypeInterface>()) {
      mlir::emitError(transformOpLoc)
          << "applications of " << transformOpName
          << " expected to produce an Attribute for result #"
          << res.getResultNumber();
      return failure();
    }
    if (ptr.is<Attribute>() &&
        !resultType.isa<TransformParamTypeInterface>()) {
      mlir::emitError(transformOpLoc)
          << "applications of " << transformOpName
          << " expected to produce an Operation * for result #"
          << res.getResultNumber();
      return failure();
    }
  }
  return success();
}
|
|
|
|
/// Transposes the per-payload-op result lists in `results` into one list per
/// result of `transformOp` and records each list in `transformResults`:
/// attributes for parameter-typed results, payload ops otherwise.
void transform::detail::setApplyToOneResults(
    Operation *transformOp, TransformResults &transformResults,
    ArrayRef<ApplyToEachResultList> results) {
  for (OpResult r : transformOp->getResults()) {
    unsigned position = r.getResultNumber();
    if (r.getType().isa<TransformParamTypeInterface>()) {
      SmallVector<Attribute> params;
      params.reserve(results.size());
      for (const ApplyToEachResultList &oneResult : results)
        params.push_back(oneResult[position].get<Attribute>());
      transformResults.setParams(r, params);
      continue;
    }

    SmallVector<Operation *> payloads;
    payloads.reserve(results.size());
    for (const ApplyToEachResultList &oneResult : results)
      payloads.push_back(oneResult[position].get<Operation *>());
    transformResults.set(r, payloads);
  }
}
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
// Utilities for PossibleTopLevelTransformOpTrait.
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
/// Maps the first argument of the entry block of `region` to the payload ops
/// of `op`'s first operand if `op` has operands, or to the top-level payload op
/// otherwise.
// Note: the parameter was mangled to `Region ®ion` (an HTML-entity
// corruption of `&reg`), which does not compile; it is `Region &region`.
LogicalResult transform::detail::mapPossibleTopLevelTransformOpBlockArguments(
    TransformState &state, Operation *op, Region &region) {
  SmallVector<Operation *> targets;
  if (op->getNumOperands() != 0)
    llvm::append_range(targets, state.getPayloadOps(op->getOperand(0)));
  else
    targets.push_back(state.getTopLevel());

  return state.mapBlockArguments(region.front().getArgument(0), targets);
}
|
|
|
|
/// Verifies an op carrying PossibleTopLevelTransformOpTrait. The checks are:
///   - it has at least one region;
///   - its first region has a single block;
///   - that block has exactly one argument, of a transform handle type;
///   - it has an operand when nested in another such op.
LogicalResult
transform::detail::verifyPossibleTopLevelTransformOpTrait(Operation *op) {
  // Attaching this trait without the interface is a misuse of the API, but it
  // cannot be caught via a static_assert because interface registration is
  // dynamic.
  assert(isa<TransformOpInterface>(op) &&
         "should implement TransformOpInterface to have "
         "PossibleTopLevelTransformOpTrait");

  if (op->getNumRegions() < 1)
    return op->emitOpError() << "expects at least one region";

  Region &bodyRegion = op->getRegion(0);
  if (!llvm::hasNItems(bodyRegion, 1))
    return op->emitOpError() << "expects a single-block region";

  Block &body = bodyRegion.front();
  bool hasSingleHandleArgument =
      body.getNumArguments() == 1 &&
      body.getArgumentTypes()[0].isa<TransformHandleTypeInterface>();
  if (!hasSingleHandleArgument) {
    return op->emitOpError()
           << "expects the entry block to have one argument "
              "of type implementing TransformHandleTypeInterface";
  }

  Operation *parent =
      op->getParentWithTrait<PossibleTopLevelTransformOpTrait>();
  if (!parent || op->getNumOperands() != 0)
    return success();

  InFlightDiagnostic diag =
      op->emitOpError()
      << "expects the root operation to be provided for a nested op";
  diag.attachNote(parent->getLoc())
      << "nested in another possible top-level op";
  return diag;
}
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
// Utilities for ParamProducerTransformOpTrait.
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
/// Memory effects of a parameter-producing op: it produces all of its result
/// handles and only reads its operand handles. It also reads the payload if at
/// least one operand is a payload handle.
void transform::detail::getParamProducerTransformOpTraitEffects(
    Operation *op, SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
  producesHandle(op->getResults(), effects);
  bool readsPayload = false;
  for (Value operand : op->getOperands()) {
    onlyReadsHandle(operand, effects);
    readsPayload |= operand.getType().isa<TransformHandleTypeInterface>();
  }
  if (readsPayload)
    onlyReadsPayload(effects);
}
|
|
|
|
/// Verifies an op carrying ParamProducerTransformOpTrait. The op must
/// implement MemoryEffectOpInterface; otherwise this aborts, since interfaces
/// can be attached dynamically. Every result must have a parameter type.
LogicalResult
transform::detail::verifyParamProducerTransformOpTrait(Operation *op) {
  // Interfaces can be attached dynamically, so this cannot be a static
  // assert. The message names the actual interface class,
  // `MemoryEffectOpInterface` (previously misspelled "MemoryEffectsOpInterface").
  if (!op->getName().getInterface<MemoryEffectOpInterface>()) {
    llvm::report_fatal_error(
        Twine("ParamProducerTransformOpTrait must be attached to an op that "
              "implements MemoryEffectOpInterface, found on ") +
        op->getName().getStringRef());
  }

  for (Value result : op->getResults()) {
    if (!result.getType().isa<TransformParamTypeInterface>()) {
      return op->emitOpError()
             << "ParamProducerTransformOpTrait attached to this op expects "
                "result types to implement TransformParamTypeInterface";
    }
  }
  return success();
}
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
// Memory effects.
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
/// Marks each of `handles` as consumed: it is read and then freed on the
/// transform mapping resource.
void transform::consumesHandle(
    ValueRange handles,
    SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
  auto *mappingResource = TransformMappingResource::get();
  for (Value handle : handles) {
    effects.emplace_back(MemoryEffects::Read::get(), handle, mappingResource);
    effects.emplace_back(MemoryEffects::Free::get(), handle, mappingResource);
  }
}
|
|
|
|
/// Returns `true` if the given list of effects instances contains an instance
|
|
/// with the effect type specified as template parameter.
|
|
template <typename EffectTy, typename ResourceTy = SideEffects::DefaultResource>
|
|
static bool hasEffect(ArrayRef<MemoryEffects::EffectInstance> effects) {
|
|
return llvm::any_of(effects, [](const MemoryEffects::EffectInstance &effect) {
|
|
return isa<EffectTy>(effect.getEffect()) &&
|
|
isa<ResourceTy>(effect.getResource());
|
|
});
|
|
}
|
|
|
|
/// Returns whether `transform` consumes `handle`, i.e. it both reads and frees
/// the handle on the transform mapping resource.
bool transform::isHandleConsumed(Value handle,
                                 transform::TransformOpInterface transform) {
  SmallVector<MemoryEffects::EffectInstance> effects;
  cast<MemoryEffectOpInterface>(transform.getOperation())
      .getEffectsOnValue(handle, effects);
  bool isRead =
      ::hasEffect<MemoryEffects::Read, TransformMappingResource>(effects);
  return isRead &&
         ::hasEffect<MemoryEffects::Free, TransformMappingResource>(effects);
}
|
|
|
|
/// Marks each of `handles` as produced: it is allocated and then written on
/// the transform mapping resource.
void transform::producesHandle(
    ValueRange handles,
    SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
  auto *mappingResource = TransformMappingResource::get();
  for (Value handle : handles) {
    effects.emplace_back(MemoryEffects::Allocate::get(), handle,
                         mappingResource);
    effects.emplace_back(MemoryEffects::Write::get(), handle, mappingResource);
  }
}
|
|
|
|
/// Marks each of `handles` as only read on the transform mapping resource.
void transform::onlyReadsHandle(
    ValueRange handles,
    SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
  auto *mappingResource = TransformMappingResource::get();
  for (Value handle : handles)
    effects.emplace_back(MemoryEffects::Read::get(), handle, mappingResource);
}
|
|
|
|
/// Records that the payload IR is modified, i.e. both read and written.
void transform::modifiesPayload(
    SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
  auto *payloadResource = PayloadIRResource::get();
  effects.emplace_back(MemoryEffects::Read::get(), payloadResource);
  effects.emplace_back(MemoryEffects::Write::get(), payloadResource);
}
|
|
|
|
/// Records that the payload IR is only read.
void transform::onlyReadsPayload(
    SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
  auto *payloadResource = PayloadIRResource::get();
  effects.emplace_back(MemoryEffects::Read::get(), payloadResource);
}
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
// Entry point.
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
/// Entry point: applies the top-level `transform` op to `payloadRoot` using a
/// fresh TransformState, and reports any diagnostics it produces. In builds
/// without NDEBUG, aborts if `transform` is not a possible top-level op or if
/// it has operands.
LogicalResult transform::applyTransforms(Operation *payloadRoot,
                                         TransformOpInterface transform,
                                         const TransformOptions &options) {
#ifndef NDEBUG
  bool startsAtTopLevel =
      transform->hasTrait<PossibleTopLevelTransformOpTrait>() &&
      transform->getNumOperands() == 0;
  if (!startsAtTopLevel) {
    transform->emitError()
        << "expected transform to start at the top-level transform op";
    llvm::report_fatal_error("could not run transforms",
                             /*gen_crash_diag=*/false);
  }
#endif // NDEBUG

  TransformState state(transform->getParentRegion(), payloadRoot, options);
  return state.applyTransform(transform).checkAndReport();
}
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
// Generated interface implementation.
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
#include "mlir/Dialect/Transform/IR/TransformInterfaces.cpp.inc"
|