Common backends (LLVM, SPIR-V) only supports 1D vectors, LLVM conversion handles ND vectors (N >= 2) as `array<array<... vector>>` and SPIR-V conversion doesn't handle them at all at the moment. Sometimes it's preferable to treat multidim vectors as linearized 1D. Add pass to do this. Only constants and simple elementwise ops are supported for now. @krzysz00 I've extracted yours result type conversion code from LegalizeToF32 and moved it to common place. Also, add ConversionPattern class operating on traits.
113 lines
4.2 KiB
C++
113 lines
4.2 KiB
C++
//===- LegalizeToF32.cpp - Legalize functions on small floats ----------===//
|
|
//
|
|
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
|
// See https://llvm.org/LICENSE.txt for license information.
|
|
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
|
//
|
|
//===----------------------------------------------------------------------===//
|
|
//
|
|
// This file implements legalizing math operations on small floating-point
|
|
// types through arith.extf and arith.truncf.
|
|
//
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
#include "mlir/Dialect/Arith/IR/Arith.h"
|
|
#include "mlir/Dialect/Math/IR/Math.h"
|
|
#include "mlir/Dialect/Math/Transforms/Passes.h"
|
|
#include "mlir/IR/Diagnostics.h"
|
|
#include "mlir/IR/PatternMatch.h"
|
|
#include "mlir/IR/TypeUtilities.h"
|
|
#include "mlir/Transforms/DialectConversion.h"
|
|
#include "llvm/ADT/STLExtras.h"
|
|
|
|
namespace mlir::math {
|
|
#define GEN_PASS_DEF_MATHLEGALIZETOF32
|
|
#include "mlir/Dialect/Math/Transforms/Passes.h.inc"
|
|
} // namespace mlir::math
|
|
|
|
using namespace mlir;
|
|
namespace {
|
|
struct LegalizeToF32RewritePattern final : ConversionPattern {
|
|
LegalizeToF32RewritePattern(TypeConverter &converter, MLIRContext *context)
|
|
: ConversionPattern(converter, MatchAnyOpTypeTag{}, 1, context) {}
|
|
LogicalResult
|
|
matchAndRewrite(Operation *op, ArrayRef<Value> operands,
|
|
ConversionPatternRewriter &rewriter) const override;
|
|
};
|
|
|
|
struct LegalizeToF32Pass final
|
|
: mlir::math::impl::MathLegalizeToF32Base<LegalizeToF32Pass> {
|
|
void runOnOperation() override;
|
|
};
|
|
} // namespace
|
|
|
|
void mlir::math::populateLegalizeToF32TypeConverter(
|
|
TypeConverter &typeConverter) {
|
|
typeConverter.addConversion(
|
|
[](Type type) -> std::optional<Type> { return type; });
|
|
typeConverter.addConversion([](FloatType type) -> std::optional<Type> {
|
|
if (type.getWidth() < 32)
|
|
return Float32Type::get(type.getContext());
|
|
return std::nullopt;
|
|
});
|
|
typeConverter.addConversion([](ShapedType type) -> std::optional<Type> {
|
|
if (auto elemTy = dyn_cast<FloatType>(type.getElementType()))
|
|
return type.clone(Float32Type::get(type.getContext()));
|
|
return std::nullopt;
|
|
});
|
|
typeConverter.addTargetMaterialization(
|
|
[](OpBuilder &b, Type target, ValueRange input, Location loc) {
|
|
return b.create<arith::ExtFOp>(loc, target, input);
|
|
});
|
|
}
|
|
|
|
void mlir::math::populateLegalizeToF32ConversionTarget(
|
|
ConversionTarget &target, TypeConverter &typeConverter) {
|
|
target.addDynamicallyLegalDialect<MathDialect>(
|
|
[&typeConverter](Operation *op) -> bool {
|
|
return typeConverter.isLegal(op);
|
|
});
|
|
target.addLegalOp<FmaOp>();
|
|
target.addLegalOp<arith::ExtFOp, arith::TruncFOp>();
|
|
}
|
|
|
|
LogicalResult LegalizeToF32RewritePattern::matchAndRewrite(
|
|
Operation *op, ArrayRef<Value> operands,
|
|
ConversionPatternRewriter &rewriter) const {
|
|
Location loc = op->getLoc();
|
|
const TypeConverter *converter = getTypeConverter();
|
|
FailureOr<Operation *> legalized =
|
|
convertOpResultTypes(op, operands, *converter, rewriter);
|
|
if (failed(legalized))
|
|
return failure();
|
|
|
|
SmallVector<Value> results = (*legalized)->getResults();
|
|
for (auto [result, newType, origType] : llvm::zip_equal(
|
|
results, (*legalized)->getResultTypes(), op->getResultTypes())) {
|
|
if (newType != origType)
|
|
result = rewriter.create<arith::TruncFOp>(loc, origType, result);
|
|
}
|
|
rewriter.replaceOp(op, results);
|
|
return success();
|
|
}
|
|
|
|
void mlir::math::populateLegalizeToF32Patterns(RewritePatternSet &patterns,
|
|
TypeConverter &typeConverter) {
|
|
patterns.add<LegalizeToF32RewritePattern>(typeConverter,
|
|
patterns.getContext());
|
|
}
|
|
|
|
void LegalizeToF32Pass::runOnOperation() {
|
|
Operation *op = getOperation();
|
|
MLIRContext &ctx = getContext();
|
|
|
|
TypeConverter typeConverter;
|
|
math::populateLegalizeToF32TypeConverter(typeConverter);
|
|
ConversionTarget target(ctx);
|
|
math::populateLegalizeToF32ConversionTarget(target, typeConverter);
|
|
RewritePatternSet patterns(&ctx);
|
|
math::populateLegalizeToF32Patterns(patterns, typeConverter);
|
|
if (failed(applyPartialConversion(op, target, std::move(patterns))))
|
|
return signalPassFailure();
|
|
}
|