Files
clang-p2996/mlir/lib/Dialect/Math/Transforms/LegalizeToF32.cpp
Krzysztof Drewniak 05e85e4fc5 [mlir][Math] Add pass to legalize math functions to f32-or-higher (#78361)
Since most of the operations in the `math` dialect don't have
low-precision implementations, add the -math-legalize-to-f32 pass that
goes through and brackets low-precision math functions (like `math.sin
%0 : f16`) with `arith.extf` and `arith.truncf`. This preserves the
original semantics of the math operation but allows lowering to proceed.

Versions of this lowering are already implicitly present in some passes,
like ConvertGPUToROCDL. However, because those are implicit rewrites,
they hide the floating-point extension and truncation, preventing anyone
from writing passes that operate on those implicit extf/truncf pairs.

Exposing this legalization explicitly is needed to allow lowering 8-bit
floats on AMD GPUs, as the implementation of extf and truncf on that
platform requires the complex logic found in ArithToAMDGPU, which runs
before the GPU-to-ROCDL lowering.
2024-01-18 09:37:43 -06:00

119 lines
4.5 KiB
C++

//===- LegalizeToF32.cpp - Legalize functions on small floats ----------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file implements legalizing math operations on small floating-point
// types through arith.extf and arith.truncf.
//
//===----------------------------------------------------------------------===//
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Math/IR/Math.h"
#include "mlir/Dialect/Math/Transforms/Passes.h"
#include "mlir/IR/Diagnostics.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/IR/TypeUtilities.h"
#include "mlir/Transforms/DialectConversion.h"
#include "llvm/ADT/STLExtras.h"
namespace mlir::math {
#define GEN_PASS_DEF_MATHLEGALIZETOF32
#include "mlir/Dialect/Math/Transforms/Passes.h.inc"
} // namespace mlir::math
using namespace mlir;
namespace {
/// Catch-all conversion pattern (matches any op type). For an op whose types
/// the type converter rejects, it rebuilds the op with converted result
/// types and truncates each changed result back to its original type with
/// `arith.truncf`.
struct LegalizeToF32RewritePattern final : public ConversionPattern {
  LegalizeToF32RewritePattern(TypeConverter &typeConverter, MLIRContext *ctx)
      : ConversionPattern(typeConverter, MatchAnyOpTypeTag(), /*benefit=*/1,
                          ctx) {}

  LogicalResult
  matchAndRewrite(Operation *op, ArrayRef<Value> operands,
                  ConversionPatternRewriter &rewriter) const override;
};

/// Pass wrapper around the legalization; its runOnOperation() runs a partial
/// dialect conversion with the type converter, target, and patterns that are
/// populated by the helpers in this file.
struct LegalizeToF32Pass final
    : public mlir::math::impl::MathLegalizeToF32Base<LegalizeToF32Pass> {
  void runOnOperation() override;
};
} // namespace
/// Registers the type conversions used by the f32 legalization on
/// `typeConverter`:
///  - any type maps to itself unless a more specific rule below applies;
///  - a scalar float type narrower than 32 bits maps to f32;
///  - a shaped type (e.g. a vector or tensor) whose element type is a float
///    narrower than 32 bits maps to the same shape with f32 elements.
/// Floats that are already 32 bits or wider, alone or as shaped-type
/// elements, are left unchanged, so wide operations (e.g. on f64) are never
/// narrowed.
/// When a value of an original type must be turned into the converted type,
/// the target materialization widens it with `arith.extf`.
void mlir::math::populateLegalizeToF32TypeConverter(
    TypeConverter &typeConverter) {
  // Fallback: every type is legal as-is. TypeConverter tries the most
  // recently added conversion first, so the specific rules registered below
  // take precedence over this one.
  typeConverter.addConversion(
      [](Type type) -> std::optional<Type> { return type; });
  typeConverter.addConversion([](FloatType type) -> std::optional<Type> {
    if (type.getWidth() < 32)
      return Float32Type::get(type.getContext());
    return std::nullopt;
  });
  typeConverter.addConversion([](ShapedType type) -> std::optional<Type> {
    // Only widen small float elements, mirroring the scalar rule above.
    // Without the width check, e.g. vector<4xf64> would be "converted" to
    // vector<4xf32>, i.e. narrowed, and bracketed with invalid extf/truncf.
    auto elemTy = dyn_cast<FloatType>(type.getElementType());
    if (elemTy && elemTy.getWidth() < 32)
      return type.clone(Float32Type::get(type.getContext()));
    return std::nullopt;
  });
  typeConverter.addTargetMaterialization(
      [](OpBuilder &b, Type target, ValueRange input, Location loc) {
        return b.create<arith::ExtFOp>(loc, target, input);
      });
}
/// Configures `target` for the f32 legalization.
///  - `math.fma`, `arith.extf` and `arith.truncf` are always legal. The latter
///    two are the ops this legalization inserts around rewritten math ops.
///  - Every other math-dialect op is legal exactly when
///    `typeConverter.isLegal(op)` holds, i.e. its types need no conversion.
/// `typeConverter` is captured by reference and must outlive `target`.
void mlir::math::populateLegalizeToF32ConversionTarget(
    ConversionTarget &target, TypeConverter &typeConverter) {
  target.addLegalOp<FmaOp, arith::ExtFOp, arith::TruncFOp>();
  auto isTypeLegal = [&typeConverter](Operation *op) -> bool {
    return typeConverter.isLegal(op);
  };
  target.addDynamicallyLegalDialect<MathDialect>(isTypeLegal);
}
/// Rewrites an op that the type converter rejects.
///  - The op is re-created with the same name, attributes and the given
///    (driver-supplied) operands, but with result types converted by the
///    pattern's type converter.
///  - Each result whose type changed is truncated back to its original type
///    with `arith.truncf`.
///  - The resulting values replace the original op's results.
/// Fails to match if the op is already legal or its result types cannot be
/// converted.
LogicalResult LegalizeToF32RewritePattern::matchAndRewrite(
    Operation *op, ArrayRef<Value> operands,
    ConversionPatternRewriter &rewriter) const {
  const TypeConverter *typeConverter = getTypeConverter();
  Location loc = op->getLoc();

  // Nothing to bracket if all of the op's types are already acceptable.
  if (typeConverter->isLegal(op))
    return rewriter.notifyMatchFailure(loc, "op already legal");

  SmallVector<Type> widenedTypes;
  if (failed(typeConverter->convertTypes(op->getResultTypes(), widenedTypes)))
    return rewriter.notifyMatchFailure(loc, "couldn't convert return types");

  // Same op, same attributes, converted operand/result types.
  OperationState state(loc, op->getName());
  state.addOperands(operands);
  state.addTypes(widenedTypes);
  state.addAttributes(op->getAttrs());
  Operation *widenedOp = rewriter.create(state);

  SmallVector<Value> replacements;
  replacements.reserve(widenedOp->getNumResults());
  for (auto [widened, widenedType, origType] : llvm::zip_equal(
           widenedOp->getResults(), widenedTypes, op->getResultTypes())) {
    if (widenedType == origType) {
      replacements.push_back(widened);
      continue;
    }
    // Narrow the widened result back so existing users see the original type.
    replacements.push_back(
        rewriter.create<arith::TruncFOp>(loc, origType, widened).getResult());
  }
  rewriter.replaceOp(op, replacements);
  return success();
}
/// Adds the catch-all f32 legalization pattern to `patterns`. The pattern
/// refers to `typeConverter` rather than copying it, so the converter must
/// stay alive while the patterns are in use.
void mlir::math::populateLegalizeToF32Patterns(RewritePatternSet &patterns,
                                               TypeConverter &typeConverter) {
  MLIRContext *ctx = patterns.getContext();
  patterns.add<LegalizeToF32RewritePattern>(typeConverter, ctx);
}
void LegalizeToF32Pass::runOnOperation() {
Operation *op = getOperation();
MLIRContext &ctx = getContext();
TypeConverter typeConverter;
math::populateLegalizeToF32TypeConverter(typeConverter);
ConversionTarget target(ctx);
math::populateLegalizeToF32ConversionTarget(target, typeConverter);
RewritePatternSet patterns(&ctx);
math::populateLegalizeToF32Patterns(patterns, typeConverter);
if (failed(applyPartialConversion(op, target, std::move(patterns))))
return signalPassFailure();
}