This commit simplifies the result type of materialization functions.

Previously: `std::optional<Value>`
Now: `Value`

The previous implementation allowed 3 possible return values:
- Non-null value: The materialization function produced a valid materialization.
- `std::nullopt`: The materialization function failed, but another materialization can be attempted.
- `Value()`: The materialization failed, and so should the dialect conversion. (Previously, the dialect conversion could roll back.)

This commit removes the last variant. It is not particularly useful because the dialect conversion will fail anyway if all other materialization functions produced `std::nullopt`.

Furthermore, in contrast to type conversions, at least one materialization callback is expected to succeed. In case of a failing type conversion, the current dialect conversion can roll back and try a different pattern. This also used to be the case for materializations, but that functionality was removed with #107109: failed materializations can no longer trigger a rollback. (They can only make the entire dialect conversion fail, without rollback.) With this in mind, it is even less useful to have an additional error state for materialization functions.

This commit is in preparation for merging the 1:1 and 1:N type converters. Target materializations will have to return multiple values instead of a single one. With this commit, we can keep the API simple: `SmallVector<Value>` instead of `std::optional<SmallVector<Value>>`.

Note for LLVM integration: All 1:1 materializations should return `Value` instead of `std::optional<Value>`. Instead of `std::nullopt`, return `Value()`.
97 lines
3.3 KiB
C++
97 lines
3.3 KiB
C++
//===- TestDialectConversion.cpp - Test DialectConversion functionality ---===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
|
|
|
|
#include "TestDialect.h"
|
|
#include "mlir/Dialect/Func/IR/FuncOps.h"
|
|
#include "mlir/Dialect/PDL/IR/PDL.h"
|
|
#include "mlir/Dialect/PDLInterp/IR/PDLInterp.h"
|
|
#include "mlir/Parser/Parser.h"
|
|
#include "mlir/Pass/Pass.h"
|
|
#include "mlir/Pass/PassManager.h"
|
|
#include "mlir/Transforms/DialectConversion.h"
|
|
|
|
using namespace mlir;
|
|
using namespace test;
|
|
|
|
//===----------------------------------------------------------------------===//
// Test PDLL Support
//===----------------------------------------------------------------------===//
|
|
|
|
#include "TestDialectConversionPDLLPatterns.h.inc"
|
|
|
|
namespace {
|
|
struct PDLLTypeConverter : public TypeConverter {
|
|
PDLLTypeConverter() {
|
|
addConversion(convertType);
|
|
addArgumentMaterialization(materializeCast);
|
|
addSourceMaterialization(materializeCast);
|
|
}
|
|
|
|
static LogicalResult convertType(Type t, SmallVectorImpl<Type> &results) {
|
|
// Convert I64 to F64.
|
|
if (t.isSignlessInteger(64)) {
|
|
results.push_back(FloatType::getF64(t.getContext()));
|
|
return success();
|
|
}
|
|
|
|
// Otherwise, convert the type directly.
|
|
results.push_back(t);
|
|
return success();
|
|
}
|
|
/// Hook for materializing a conversion.
|
|
static Value materializeCast(OpBuilder &builder, Type resultType,
|
|
ValueRange inputs, Location loc) {
|
|
return builder.create<UnrealizedConversionCastOp>(loc, resultType, inputs)
|
|
.getResult(0);
|
|
}
|
|
};
|
|
|
|
struct TestDialectConversionPDLLPass
|
|
: public PassWrapper<TestDialectConversionPDLLPass, OperationPass<>> {
|
|
MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestDialectConversionPDLLPass)
|
|
|
|
StringRef getArgument() const final { return "test-dialect-conversion-pdll"; }
|
|
StringRef getDescription() const final {
|
|
return "Test DialectConversion PDLL functionality";
|
|
}
|
|
void getDependentDialects(DialectRegistry ®istry) const override {
|
|
registry.insert<pdl::PDLDialect, pdl_interp::PDLInterpDialect>();
|
|
}
|
|
LogicalResult initialize(MLIRContext *ctx) override {
|
|
// Build the pattern set within the `initialize` to avoid recompiling PDL
|
|
// patterns during each `runOnOperation` invocation.
|
|
RewritePatternSet patternList(ctx);
|
|
registerConversionPDLFunctions(patternList);
|
|
populateGeneratedPDLLPatterns(patternList, PDLConversionConfig(&converter));
|
|
patterns = std::move(patternList);
|
|
return success();
|
|
}
|
|
|
|
void runOnOperation() final {
|
|
mlir::ConversionTarget target(getContext());
|
|
target.addLegalOp<ModuleOp, func::FuncOp, func::ReturnOp>();
|
|
target.addDynamicallyLegalDialect<TestDialect>(
|
|
[this](Operation *op) { return converter.isLegal(op); });
|
|
|
|
if (failed(mlir::applyFullConversion(getOperation(), target, patterns)))
|
|
signalPassFailure();
|
|
}
|
|
|
|
FrozenRewritePatternSet patterns;
|
|
PDLLTypeConverter converter;
|
|
};
|
|
} // namespace
|
|
|
|
namespace mlir {
|
|
namespace test {
|
|
void registerTestDialectConversionPasses() {
|
|
PassRegistration<TestDialectConversionPDLLPass>();
|
|
}
|
|
} // namespace test
|
|
} // namespace mlir
|