Transform interfaces are implemented, directly or via extensions, in libraries belonging to multiple other dialects. Those dialects don't need to depend on the non-interface part of the transform dialect, which includes the growing number of ops and transitive dependency footprint. Split out the interfaces into a separate library. This in turn requires flipping the dependency from the interface on the dialect that has crept in because both co-existed in one library. The interface shouldn't depend on the transform dialect either. As a consequence of splitting, the capability of the interpreter to automatically walk the payload IR to identify payload ops of a certain kind based on the type used for the entry point symbol argument is disabled. This is a good move by itself as it simplifies the interpreter logic. This functionality can be trivially replaced by a `transform.structured.match` operation.
258 lines
9.6 KiB
C++
258 lines
9.6 KiB
C++
//===- FuncTransformOps.cpp - Implementation of Func transform ops --------===//
|
|
//
|
|
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
|
// See https://llvm.org/LICENSE.txt for license information.
|
|
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
|
//
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
#include "mlir/Dialect/Func/TransformOps/FuncTransformOps.h"
|
|
|
|
#include "mlir/Conversion/FuncToLLVM/ConvertFuncToLLVM.h"
|
|
#include "mlir/Conversion/LLVMCommon/TypeConverter.h"
|
|
#include "mlir/Dialect/Func/IR/FuncOps.h"
|
|
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
|
|
#include "mlir/Dialect/Transform/IR/TransformDialect.h"
|
|
#include "mlir/Dialect/Transform/IR/TransformOps.h"
|
|
#include "mlir/Dialect/Transform/Interfaces/TransformInterfaces.h"
|
|
#include "mlir/Transforms/DialectConversion.h"
|
|
|
|
using namespace mlir;
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
// Apply...ConversionPatternsOp
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
/// Adds the func-to-LLVM conversion patterns to `patterns`.
///
/// `typeConverter` is downcast to `LLVMTypeConverter` without a runtime check.
/// The sibling `verifyTypeConverter` hook is the place where a builder that
/// reports a different converter type is diagnosed.
void transform::ApplyFuncToLLVMConversionPatternsOp::populatePatterns(
    TypeConverter &typeConverter, RewritePatternSet &patterns) {
  auto &llvmConverter = static_cast<LLVMTypeConverter &>(typeConverter);
  populateFuncToLLVMConversionPatterns(llvmConverter, patterns);
}
|
|
|
|
/// Checks that `builder` reports the converter type "LLVMTypeConverter".
///
/// Returns success when it does; otherwise emits an op error on this op and
/// returns failure.
LogicalResult
transform::ApplyFuncToLLVMConversionPatternsOp::verifyTypeConverter(
    transform::TypeConverterBuilderOpInterface builder) {
  const bool isLLVMConverter =
      builder.getTypeConverterType() == "LLVMTypeConverter";
  if (isLLVMConverter)
    return success();
  return emitOpError("expected LLVMTypeConverter");
}
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
// CastAndCallOp
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
/// Builds a `func::CallOp` at the requested insertion point, feeding it the
/// payload input values and redirecting the uses of the payload output values
/// to the call results.
///
/// Steps, in order:
///   1. Collect the input and output payload values; outputs must be unique.
///   2. Resolve the single insertion point operation and check dominance:
///      the insertion point must dominate every user of every output, and
///      every input must dominate the insertion point.
///   3. Resolve the callee, either by symbol name or through a handle.
///   4. Check that the callee's argument/result counts match the number of
///      inputs/outputs.
///   5. Populate a type converter from the ops in this op's region (if any),
///      materialize casts for mismatching input types, create the call, and
///      materialize casts for mismatching result types.
///   6. Replace all uses of each output (except in the new call) and map the
///      result handle to the created call op.
///
/// Silenceable failures: non-unique outputs, more or fewer than one insertion
/// point op, argument or result count mismatch. Definite failures: dominance
/// violations, unresolved or invalid callee, failed cast materialization.
DiagnosedSilenceableFailure
transform::CastAndCallOp::apply(transform::TransformRewriter &rewriter,
                                transform::TransformResults &results,
                                transform::TransformState &state) {
  SmallVector<Value> inputs;
  if (getInputs())
    llvm::append_range(inputs, state.getPayloadValues(getInputs()));

  SetVector<Value> outputs;
  if (getOutputs()) {
    for (auto output : state.getPayloadValues(getOutputs()))
      outputs.insert(output);

    // Verify that the set of output values to be replaced is unique.
    if (outputs.size() !=
        llvm::range_size(state.getPayloadValues(getOutputs()))) {
      return emitSilenceableFailure(getLoc())
             << "cast and call output values must be unique";
    }
  }

  // Get the insertion point for the call.
  auto insertionOps = state.getPayloadOps(getInsertionPoint());
  if (!llvm::hasSingleElement(insertionOps)) {
    return emitSilenceableFailure(getLoc())
           << "Only one op can be specified as an insertion point";
  }
  bool insertAfter = getInsertAfter();
  Operation *insertionPoint = *insertionOps.begin();

  // Check that all inputs dominate the insertion point, and the insertion
  // point dominates all users of the outputs.
  DominanceInfo dom(insertionPoint);
  for (Value output : outputs) {
    for (Operation *user : output.getUsers()) {
      // If we are inserting after the insertion point operation, the
      // insertion point operation must properly dominate the user. Otherwise
      // basic dominance is enough.
      bool doesDominate = insertAfter
                              ? dom.properlyDominates(insertionPoint, user)
                              : dom.dominates(insertionPoint, user);
      if (!doesDominate) {
        return emitDefiniteFailure()
               << "User " << user << " is not dominated by insertion point "
               << insertionPoint;
      }
    }
  }

  for (Value input : inputs) {
    // If we are inserting before the insertion point operation, the
    // input must properly dominate the insertion point operation. Otherwise
    // basic dominance is enough.
    bool doesDominate = insertAfter
                            ? dom.dominates(input, insertionPoint)
                            : dom.properlyDominates(input, insertionPoint);
    if (!doesDominate) {
      return emitDefiniteFailure()
             << "input " << input << " does not dominate insertion point "
             << insertionPoint;
    }
  }

  // Get the function to call. This can either be specified by symbol or as a
  // transform handle.
  func::FuncOp targetFunction = nullptr;
  if (getFunctionName()) {
    targetFunction = SymbolTable::lookupNearestSymbolFrom<func::FuncOp>(
        insertionPoint, *getFunctionName());
    if (!targetFunction) {
      return emitDefiniteFailure()
             << "unresolved symbol " << *getFunctionName();
    }
  } else if (getFunction()) {
    auto payloadOps = state.getPayloadOps(getFunction());
    if (!llvm::hasSingleElement(payloadOps)) {
      return emitDefiniteFailure() << "requires a single function to call";
    }
    targetFunction = dyn_cast<func::FuncOp>(*payloadOps.begin());
    if (!targetFunction) {
      return emitDefiniteFailure() << "invalid non-function callee";
    }
  } else {
    llvm_unreachable("Invalid CastAndCall op without a function to call");
    return emitDefiniteFailure();
  }

  // Verify that the function argument and result lengths match the inputs and
  // outputs given to this op.
  if (targetFunction.getNumArguments() != inputs.size()) {
    return emitSilenceableFailure(targetFunction.getLoc())
           << "mismatch between number of function arguments "
           << targetFunction.getNumArguments() << " and number of inputs "
           << inputs.size();
  }
  if (targetFunction.getNumResults() != outputs.size()) {
    // Report the same quantity that was compared above. The arrow form
    // (`targetFunction->getNumResults()`) would reach the underlying
    // Operation and print the number of SSA results of the function op
    // itself instead of the number of results in the function signature.
    return emitSilenceableFailure(targetFunction.getLoc())
           << "mismatch between number of function results "
           << targetFunction.getNumResults() << " and number of outputs "
           << outputs.size();
  }

  // Gather all specified converters.
  mlir::TypeConverter converter;
  if (!getRegion().empty()) {
    for (Operation &op : getRegion().front()) {
      cast<transform::TypeConverterBuilderOpInterface>(&op)
          .populateTypeMaterializations(converter);
    }
  }

  if (insertAfter)
    rewriter.setInsertionPointAfter(insertionPoint);
  else
    rewriter.setInsertionPoint(insertionPoint);

  // Cast inputs whose type differs from the callee's argument type. The
  // converted value is written back into `inputs` through the zipped
  // reference so the call below receives it.
  for (auto [input, type] :
       llvm::zip_equal(inputs, targetFunction.getArgumentTypes())) {
    if (input.getType() != type) {
      Value newInput = converter.materializeSourceConversion(
          rewriter, input.getLoc(), type, input);
      if (!newInput) {
        return emitDefiniteFailure() << "Failed to materialize conversion of "
                                     << input << " to type " << type;
      }
      input = newInput;
    }
  }

  auto callOp = rewriter.create<func::CallOp>(insertionPoint->getLoc(),
                                              targetFunction, inputs);

  // Cast the call results back to the expected types. If any conversions fail
  // this is a definite failure as the call has been constructed at this point.
  for (auto [output, newOutput] :
       llvm::zip_equal(outputs, callOp.getResults())) {
    Value convertedOutput = newOutput;
    if (output.getType() != newOutput.getType()) {
      convertedOutput = converter.materializeTargetConversion(
          rewriter, output.getLoc(), output.getType(), newOutput);
      if (!convertedOutput) {
        return emitDefiniteFailure()
               << "Failed to materialize conversion of " << newOutput
               << " to type " << output.getType();
      }
    }
    // The call itself is excluded from the replacement.
    rewriter.replaceAllUsesExcept(output, convertedOutput, callOp);
  }
  results.set(cast<OpResult>(getResult()), {callOp});
  return DiagnosedSilenceableFailure::success();
}
|
|
|
|
/// Verifies CastAndCallOp.
///
/// Checks, in this order:
///   - every op in the first block of the region (when the region is
///     non-empty) implements `TypeConverterBuilderOpInterface`;
///   - a function handle or a function name is provided;
///   - the function handle and the function name are not both provided.
LogicalResult transform::CastAndCallOp::verify() {
  if (!getRegion().empty()) {
    for (Operation &nested : getRegion().front()) {
      if (isa<transform::TypeConverterBuilderOpInterface>(&nested))
        continue;
      InFlightDiagnostic diag =
          emitOpError() << "expected children ops to implement "
                           "TypeConverterBuilderOpInterface";
      diag.attachNote(nested.getLoc()) << "op without interface";
      return diag;
    }
  }

  const bool hasHandle = static_cast<bool>(getFunction());
  const bool hasName = static_cast<bool>(getFunctionName());
  if (!hasHandle && !hasName)
    return emitOpError() << "expected a function handle or name to call";
  if (hasHandle && hasName)
    return emitOpError() << "function handle and name are mutually exclusive";
  return success();
}
|
|
|
|
/// Declares the effects of this op: the insertion point handle is read, the
/// inputs/outputs/function handles are read when present, the result handle
/// is produced, and the payload is modified.
void transform::CastAndCallOp::getEffects(
    SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
  // Records a read effect only for handles that are actually provided.
  auto readIfPresent = [&effects](auto handle) {
    if (handle)
      transform::onlyReadsHandle(handle, effects);
  };

  transform::onlyReadsHandle(getInsertionPoint(), effects);
  readIfPresent(getInputs());
  readIfPresent(getOutputs());
  readIfPresent(getFunction());
  transform::producesHandle(getResult(), effects);
  transform::modifiesPayload(effects);
}
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
// Transform op registration
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
namespace {
/// Transform dialect extension for the Func transform ops defined in this
/// file. File-local: it is only reachable through
/// `mlir::func::registerTransformDialectExtension`.
class FuncTransformDialectExtension
    : public transform::TransformDialectExtension<
          FuncTransformDialectExtension> {
public:
  using Base::Base;

  /// Declares the LLVM dialect as a generated dialect and registers every op
  /// named in the generated op list.
  void init() {
    declareGeneratedDialect<LLVM::LLVMDialect>();

    // The template argument list is the op list expanded from the generated
    // .inc file; the define/include pair must stay inside the angle brackets.
    registerTransformOps<
#define GET_OP_LIST
#include "mlir/Dialect/Func/TransformOps/FuncTransformOps.cpp.inc"
        >();
  }
};
} // namespace
|
|
|
|
#define GET_OP_CLASSES
|
|
#include "mlir/Dialect/Func/TransformOps/FuncTransformOps.cpp.inc"
|
|
|
|
void mlir::func::registerTransformDialectExtension(DialectRegistry ®istry) {
|
|
registry.addExtensions<FuncTransformDialectExtension>();
|
|
}
|