Files
clang-p2996/mlir/test/lib/Transforms/TestDialectConversion.cpp
Matthias Springer f18c3e4e73 [mlir][Transforms] Dialect Conversion: Simplify materialization fn result type (#113031)
This commit simplifies the result type of materialization functions.

Previously: `std::optional<Value>`
Now: `Value`

The previous implementation allowed 3 possible return values:
- Non-null value: The materialization function produced a valid
materialization.
- `std::nullopt`: The materialization function failed, but another
materialization can be attempted.
- `Value()`: The materialization failed and so should the dialect
conversion. (Previously: Dialect conversion can roll back.)

This commit removes the last variant. It is not particularly useful
because the dialect conversion will fail anyway if all other
materialization functions produced `std::nullopt`.

Furthermore, in contrast to type conversions, at least one
materialization callback is expected to succeed. In case of a failing
type conversion, the current dialect conversion can roll back and try a
different pattern. This also used to be the case for materializations,
but that functionality was removed with #107109: failed materializations
can no longer trigger a rollback. (They can just make the entire dialect
conversion fail without rollback.) With this in mind, it is even less
useful to have an additional error state for materialization functions.

This commit is in preparation for merging the 1:1 and 1:N type
converters. Target materializations will have to return multiple values
instead of a single one. With this commit, we can keep the API simple:
`SmallVector<Value>` instead of `std::optional<SmallVector<Value>>`.

Note for LLVM integration: All 1:1 materializations should return
`Value` instead of `std::optional<Value>`. Instead of `std::nullopt`,
return `Value()`.
2024-10-23 07:29:17 -07:00

97 lines
3.3 KiB
C++

//===- TestDialectConversion.cpp - Test DialectConversion functionality ---===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
#include "TestDialect.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/PDL/IR/PDL.h"
#include "mlir/Dialect/PDLInterp/IR/PDLInterp.h"
#include "mlir/Parser/Parser.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Pass/PassManager.h"
#include "mlir/Transforms/DialectConversion.h"
using namespace mlir;
using namespace test;
//===----------------------------------------------------------------------===//
// Test PDLL Support
//===----------------------------------------------------------------------===//
#include "TestDialectConversionPDLLPatterns.h.inc"
namespace {
/// Type converter used by the PDLL dialect conversion test. It registers a
/// single conversion callback plus one cast-based hook that serves as both the
/// argument and the source materialization.
struct PDLLTypeConverter : public TypeConverter {
  PDLLTypeConverter() {
    addConversion(convertType);
    addArgumentMaterialization(materializeCast);
    addSourceMaterialization(materializeCast);
  }

  /// Conversion callback: appends exactly one type to `results` and always
  /// reports success. A signless 64-bit integer becomes F64; any other type is
  /// forwarded unchanged.
  static LogicalResult convertType(Type t, SmallVectorImpl<Type> &results) {
    // Start from the identity mapping and only override it for i64.
    Type converted = t;
    if (t.isSignlessInteger(64))
      converted = FloatType::getF64(t.getContext());
    results.push_back(converted);
    return success();
  }

  /// Materialization callback: bridges `inputs` to `resultType` by creating an
  /// `UnrealizedConversionCastOp` at `loc` and returning its first result.
  static Value materializeCast(OpBuilder &builder, Type resultType,
                               ValueRange inputs, Location loc) {
    auto castOp =
        builder.create<UnrealizedConversionCastOp>(loc, resultType, inputs);
    return castOp.getResult(0);
  }
};
/// Test pass that applies a full dialect conversion using the generated PDLL
/// patterns together with `PDLLTypeConverter`.
struct TestDialectConversionPDLLPass
    : public PassWrapper<TestDialectConversionPDLLPass, OperationPass<>> {
  MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestDialectConversionPDLLPass)

  /// Command-line argument under which the pass is exposed.
  StringRef getArgument() const final { return "test-dialect-conversion-pdll"; }

  /// One-line human-readable summary of the pass.
  StringRef getDescription() const final {
    return "Test DialectConversion PDLL functionality";
  }

  /// Declares the PDL and PDLInterp dialects as dependencies of this pass.
  void getDependentDialects(DialectRegistry &registry) const override {
    registry.insert<pdl::PDLDialect>();
    registry.insert<pdl_interp::PDLInterpDialect>();
  }

  /// Populates `patterns` once, up front, so that the PDL patterns do not have
  /// to be recompiled on every `runOnOperation` call.
  LogicalResult initialize(MLIRContext *ctx) override {
    RewritePatternSet pdllPatterns(ctx);
    registerConversionPDLFunctions(pdllPatterns);
    // The conversion config keeps a pointer to the `converter` member.
    populateGeneratedPDLLPatterns(pdllPatterns,
                                  PDLConversionConfig(&converter));
    patterns = std::move(pdllPatterns);
    return success();
  }

  /// Runs a full conversion on the current operation. Module, function and
  /// return ops are always legal; test dialect ops are legal only when the
  /// type converter reports them as legal.
  void runOnOperation() final {
    ConversionTarget target(getContext());
    target.addLegalOp<ModuleOp, func::FuncOp, func::ReturnOp>();

    auto isLegalForConverter = [this](Operation *op) {
      return converter.isLegal(op);
    };
    target.addDynamicallyLegalDialect<TestDialect>(isLegalForConverter);

    LogicalResult status = applyFullConversion(getOperation(), target, patterns);
    if (failed(status))
      signalPassFailure();
  }

  /// Frozen pattern set built in `initialize`.
  FrozenRewritePatternSet patterns;
  /// Type converter referenced by the patterns and by the legality callback.
  PDLLTypeConverter converter;
};
} // namespace
namespace mlir {
namespace test {
void registerTestDialectConversionPasses() {
PassRegistration<TestDialectConversionPDLLPass>();
}
} // namespace test
} // namespace mlir