For 1:N type conversion, there is a 1:N relationship between the original operands and the converted operands. The same is true for the results. The previous design passed an instance of a "mapping" class into each pattern that helped with handling this 1:N correspondence. However, this was still rather manual and, in particular, it required the use of magic constants for the indices of the different operands. This commit uses the GenericAdaptor class that is generated for each op class in order to simplify this relationship further. The GenericAdaptor can wrap a list of values of an arbitrary type for each operand (via templating); for 1:N type conversion, this allows the operand accessors of the adaptor class to return a ValueRange that corresponds to the N values of the converted types. Patterns can thus use the named accessors instead of magic constants, which eliminates a common class of errors. This commit further simplifies the API that patterns need to implement by making the operand and result type mappings part of the adaptor. Since many patterns only need one of the two (or even neither), this reduces the number of unnecessary arguments in many cases. Reviewed By: springerm Differential Revision: https://reviews.llvm.org/D147225
131 lines
4.5 KiB
C++
131 lines
4.5 KiB
C++
//===-- OneToNFuncConversions.cpp - Func 1:N type conversion ---*- C++ -*-===//
|
|
//
|
|
// Licensed under the Apache License v2.0 with LLVM Exceptions.
|
|
// See https://llvm.org/LICENSE.txt for license information.
|
|
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
|
//
|
|
//===----------------------------------------------------------------------===//
|
|
//
|
|
// The patterns in this file are heavily inspired by (and copied from)
|
|
// convertFuncOpTypes in lib/Transforms/Utils/DialectConversion.cpp and the
|
|
// patterns in lib/Dialect/Func/Transforms/FuncConversions.cpp but work for 1:N
|
|
// type conversions.
|
|
//
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
#include "mlir/Dialect/Func/Transforms/OneToNFuncConversions.h"
|
|
|
|
#include "mlir/Dialect/Func/IR/FuncOps.h"
|
|
#include "mlir/Transforms/OneToNTypeConversion.h"
|
|
|
|
using namespace mlir;
|
|
using namespace mlir::func;
|
|
|
|
namespace {
|
|
|
|
/// Rewrites a `CallOp` whose operand and/or result types undergo a (possibly
/// 1:N) type conversion into a new `CallOp` that consumes the flattened,
/// converted operands and produces the converted result types.
class ConvertTypesInFuncCallOp : public OneToNOpConversionPattern<CallOp> {
public:
  using OneToNOpConversionPattern<CallOp>::OneToNOpConversionPattern;

  LogicalResult
  matchAndRewrite(CallOp op, OpAdaptor adaptor,
                  OneToNPatternRewriter &rewriter) const override {
    Location callLoc = op->getLoc();
    const OneToNTypeMapping &resultTypes = adaptor.getResultMapping();

    // If every operand type and every result type maps to itself, there is
    // nothing to rewrite.
    bool needsRewrite = adaptor.getOperandMapping().hasNonIdentityConversion() ||
                        resultTypes.hasNonIdentityConversion();
    if (!needsRewrite)
      return failure();

    // Build the replacement call from the flattened converted operands and
    // carry over all attributes of the original call.
    auto replacement =
        rewriter.create<CallOp>(callLoc, resultTypes.getConvertedTypes(),
                                adaptor.getFlatOperands());
    replacement->setAttrs(op->getAttrs());

    // Map the (converted) results of the new call back onto the results of
    // the original call according to the result type mapping.
    rewriter.replaceOp(op, replacement->getResults(), resultTypes);
    return success();
  }
};
|
|
|
|
/// Converts the signature of a `FuncOp` with the 1:N type converter.
///
/// Argument and result types are mapped through the converter. The function
/// type is then updated in place. For functions that have a body, the entry
/// block's arguments are converted to match the new argument types.
class ConvertTypesInFuncFuncOp : public OneToNOpConversionPattern<FuncOp> {
public:
  using OneToNOpConversionPattern<FuncOp>::OneToNOpConversionPattern;

  LogicalResult
  matchAndRewrite(FuncOp op, OpAdaptor adaptor,
                  OneToNPatternRewriter &rewriter) const override {
    auto *converter = getTypeConverter<OneToNTypeConverter>();

    // Compute the 1:N mapping of the function's argument types.
    OneToNTypeMapping argMapping(op.getArgumentTypes());
    if (failed(converter->computeTypeMapping(op.getArgumentTypes(),
                                             argMapping)))
      return failure();

    // Compute the 1:N mapping of the function's result types.
    OneToNTypeMapping resMapping(op.getResultTypes());
    if (failed(converter->computeTypeMapping(op.getResultTypes(), resMapping)))
      return failure();

    // A signature in which every type maps to itself needs no rewrite.
    if (!(argMapping.hasNonIdentityConversion() ||
          resMapping.hasNonIdentityConversion()))
      return failure();

    // Install the converted function type on the op itself.
    auto convertedType = FunctionType::get(rewriter.getContext(),
                                           argMapping.getConvertedTypes(),
                                           resMapping.getConvertedTypes());
    rewriter.updateRootInPlace(op, [&] { op.setType(convertedType); });

    // External functions have no body, so there is no block to convert.
    if (op.isExternal())
      return success();

    // Retype the entry block's arguments according to the argument mapping.
    rewriter.applySignatureConversion(&op.getBody().front(), argMapping);
    return success();
  }
};
|
|
|
|
/// Replaces the operands of a `ReturnOp` with their flattened, converted
/// counterparts, modifying the op in place rather than recreating it.
class ConvertTypesInFuncReturnOp : public OneToNOpConversionPattern<ReturnOp> {
public:
  using OneToNOpConversionPattern<ReturnOp>::OneToNOpConversionPattern;

  LogicalResult
  matchAndRewrite(ReturnOp op, OpAdaptor adaptor,
                  OneToNPatternRewriter &rewriter) const override {
    // Leave the op alone unless at least one operand type actually changes.
    const bool anyOperandConverted =
        adaptor.getOperandMapping().hasNonIdentityConversion();
    if (!anyOperandConverted)
      return failure();

    // Swap in the flattened converted operands, notifying the rewriter of the
    // in-place modification.
    auto setConvertedOperands = [&] {
      op->setOperands(adaptor.getFlatOperands());
    };
    rewriter.updateRootInPlace(op, setConvertedOperands);
    return success();
  }
};
|
|
|
|
} // namespace
|
|
|
|
namespace mlir {
|
|
|
|
/// Adds the 1:N type conversion patterns for `CallOp`, `FuncOp`, and
/// `ReturnOp` to `patterns`. All of them are driven by `typeConverter` and
/// created in the context of `patterns`.
void populateFuncTypeConversionPatterns(TypeConverter &typeConverter,
                                        RewritePatternSet &patterns) {
  auto *context = patterns.getContext();
  patterns.add<ConvertTypesInFuncCallOp, ConvertTypesInFuncFuncOp,
               ConvertTypesInFuncReturnOp>(typeConverter, context);
}
|
|
|
|
} // namespace mlir
|