This PR adds `f8E3M4` type to mlir.
The `f8E3M4` type follows IEEE 754 conventions:
```text
f8E3M4 (IEEE 754)
- Exponent bias: 3
- Maximum stored exponent value: 6 (binary 110)
- Maximum unbiased exponent value: 6 - 3 = 3
- Minimum stored exponent value: 1 (binary 001)
- Minimum unbiased exponent value: 1 − 3 = −2
- Precision (the total number of significand/mantissa bits, including the
implicit leading integer bit): 4 + 1 = 5
- Follows IEEE 754 conventions for representation of special values
- Has Positive and Negative zero
- Has Positive and Negative infinity
- Has NaNs
Additional details:
- Max exp (unbiased): 3
- Min exp (unbiased): -2
- Infinities (+/-): S.111.0000
- Zeros (+/-): S.000.0000
- NaNs: S.111.{0,1}⁴ except S.111.0000
- Max normal number: S.110.1111 = +/-2^(6-3) x (1 + 15/16) = +/-2^3 x 31 x 2^(-4) = +/-15.5
- Min normal number: S.001.0000 = +/-2^(1-3) x (1 + 0) = +/-2^(-2)
- Max subnormal number: S.000.1111 = +/-2^(-2) x 15/16 = +/-2^(-2) x 15 x 2^(-4) = +/-15 x 2^(-6)
- Min subnormal number: S.000.0001 = +/-2^(-2) x 1/16 = +/-2^(-2) x 2^(-4) = +/-2^(-6)
```
Related PRs:
- [PR-99698](https://github.com/llvm/llvm-project/pull/99698) [APFloat]
Add support for f8E3M4 IEEE 754 type
- [PR-97118](https://github.com/llvm/llvm-project/pull/97118) [MLIR] Add
f8E4M3 IEEE 754 type
196 lines
7.8 KiB
C++
196 lines
7.8 KiB
C++
//===- EmulateUnsupportedFloats.cpp - Promote small floats --*- C++ -*-===//
|
|
//
|
|
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
|
// See https://llvm.org/LICENSE.txt for license information.
|
|
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
|
//
|
|
//===----------------------------------------------------------------------===//
|
|
// This pass promotes small floats (of some unsupported types T) to a supported
|
|
// type U by wrapping all float operations on Ts with expansion to and
|
|
// truncation from U, then operating on U.
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
#include "mlir/Dialect/Arith/Transforms/Passes.h"

#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Vector/IR/VectorOps.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/Location.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/Transforms/DialectConversion.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/StringSwitch.h"
#include "llvm/Support/ErrorHandling.h"
#include <optional>
|
|
|
|
namespace mlir::arith {
|
|
#define GEN_PASS_DEF_ARITHEMULATEUNSUPPORTEDFLOATS
|
|
#include "mlir/Dialect/Arith/Transforms/Passes.h.inc"
|
|
} // namespace mlir::arith
|
|
|
|
using namespace mlir;
|
|
|
|
namespace {
|
|
struct EmulateUnsupportedFloatsPass
|
|
: arith::impl::ArithEmulateUnsupportedFloatsBase<
|
|
EmulateUnsupportedFloatsPass> {
|
|
using arith::impl::ArithEmulateUnsupportedFloatsBase<
|
|
EmulateUnsupportedFloatsPass>::ArithEmulateUnsupportedFloatsBase;
|
|
|
|
void runOnOperation() override;
|
|
};
|
|
|
|
struct EmulateFloatPattern final : ConversionPattern {
|
|
EmulateFloatPattern(TypeConverter &converter, MLIRContext *ctx)
|
|
: ConversionPattern(converter, Pattern::MatchAnyOpTypeTag(), 1, ctx) {}
|
|
|
|
LogicalResult match(Operation *op) const override;
|
|
void rewrite(Operation *op, ArrayRef<Value> operands,
|
|
ConversionPatternRewriter &rewriter) const override;
|
|
};
|
|
} // end namespace
|
|
|
|
/// Map strings to float types. This function is here because no one else needs
|
|
/// it yet, feel free to abstract it out.
|
|
static std::optional<FloatType> parseFloatType(MLIRContext *ctx,
|
|
StringRef name) {
|
|
Builder b(ctx);
|
|
return llvm::StringSwitch<std::optional<FloatType>>(name)
|
|
.Case("f8E5M2", b.getFloat8E5M2Type())
|
|
.Case("f8E4M3", b.getFloat8E4M3Type())
|
|
.Case("f8E4M3FN", b.getFloat8E4M3FNType())
|
|
.Case("f8E5M2FNUZ", b.getFloat8E5M2FNUZType())
|
|
.Case("f8E4M3FNUZ", b.getFloat8E4M3FNUZType())
|
|
.Case("f8E3M4", b.getFloat8E3M4Type())
|
|
.Case("bf16", b.getBF16Type())
|
|
.Case("f16", b.getF16Type())
|
|
.Case("f32", b.getF32Type())
|
|
.Case("f64", b.getF64Type())
|
|
.Case("f80", b.getF80Type())
|
|
.Case("f128", b.getF128Type())
|
|
.Default(std::nullopt);
|
|
}
|
|
|
|
/// Accept an op only when the type converter does not consider it legal
/// (some operand or result type needs promotion) and the op has no regions,
/// since `rewrite` does not clone regions.
LogicalResult EmulateFloatPattern::match(Operation *op) const {
  const bool alreadyLegal = getTypeConverter()->isLegal(op);
  const bool hasRegions = op->getNumRegions() != 0;
  return success(!alreadyLegal && !hasRegions);
}
|
|
|
|
/// Rebuild `op` on the converted `operands`, with converted result types, and
/// truncate every result whose type changed back to its original type (via
/// `arith.truncf` with the `contract` fast-math flag) so existing users are
/// unaffected.
void EmulateFloatPattern::rewrite(Operation *op, ArrayRef<Value> operands,
                                  ConversionPatternRewriter &rewriter) const {
  const TypeConverter *typeConverter = getTypeConverter();
  SmallVector<Type> promotedTypes;
  if (failed(typeConverter->convertTypes(op->getResultTypes(),
                                         promotedTypes))) {
    // Not expected to happen in practice: if this diagnostic ever shows up,
    // it points at a bug rather than at bad input.
    op->emitOpError("type conversion failed in float emulation");
    return;
  }

  Location loc = op->getLoc();
  Operation *promotedOp = rewriter.create(
      loc, op->getName().getIdentifier(), operands, promotedTypes,
      op->getAttrs(), op->getSuccessors(), /*regions=*/{});

  SmallVector<Value> replacements;
  for (auto [result, originalType, promotedType] :
       llvm::zip_equal(promotedOp->getResults(), op->getResultTypes(),
                       promotedTypes)) {
    if (originalType == promotedType) {
      replacements.push_back(result);
      continue;
    }
    auto narrowed = rewriter.create<arith::TruncFOp>(loc, originalType, result);
    narrowed.setFastmath(arith::FastMathFlags::contract);
    replacements.push_back(narrowed.getResult());
  }
  rewriter.replaceOp(op, replacements);
}
|
|
|
|
/// Teach `converter` to promote each type in `sourceTypes`, and any shaped
/// type whose element type is in `sourceTypes`, to `targetType`. Every other
/// type converts to itself. Values are widened back with `arith.extf` when a
/// target materialization is needed.
void mlir::arith::populateEmulateUnsupportedFloatsConversions(
    TypeConverter &converter, ArrayRef<Type> sourceTypes, Type targetType) {
  // The callback keeps its own copy of the source types because it is stored
  // inside `converter` and may outlive the caller's array.
  converter.addConversion([unsupported = SmallVector<Type>(sourceTypes),
                           targetType](Type type) -> std::optional<Type> {
    if (llvm::is_contained(unsupported, type))
      return targetType;
    auto shaped = dyn_cast<ShapedType>(type);
    if (!shaped || !llvm::is_contained(unsupported, shaped.getElementType()))
      return type; // Everything else is already legal.
    return shaped.clone(targetType);
  });

  // Widen a value to the requested type with arith.extf, tagged with the same
  // `contract` fast-math flag as the truncations emitted by the pattern.
  converter.addTargetMaterialization(
      [](OpBuilder &builder, Type resultType, ValueRange inputs, Location loc) {
        auto widened = builder.create<arith::ExtFOp>(loc, resultType, inputs);
        widened.setFastmath(arith::FastMathFlags::contract);
        return widened;
      });
}
|
|
|
|
/// Register the single catch-all float-emulation pattern. Which ops it
/// rewrites is decided entirely by `converter`.
void mlir::arith::populateEmulateUnsupportedFloatsPatterns(
    RewritePatternSet &patterns, TypeConverter &converter) {
  MLIRContext *context = patterns.getContext();
  patterns.add<EmulateFloatPattern>(converter, context);
}
|
|
|
|
/// Configure `target` so that only arith ops and arithmetic-performing vector
/// ops that touch an unsupported type are considered illegal.
/// `converter` is captured by reference, so it must outlive uses of `target`.
void mlir::arith::populateEmulateUnsupportedFloatsLegality(
    ConversionTarget &target, TypeConverter &converter) {
  // Shared predicate: legal exactly when the converter accepts the op's types.
  auto legalPerConverter = [&converter](Operation *op) -> std::optional<bool> {
    return converter.isLegal(op);
  };

  // Leave functions and any other op that needs no expansion untouched.
  target.markUnknownOpDynamicallyLegal([](Operation *) { return true; });

  target.addDynamicallyLegalDialect<arith::ArithDialect>(legalPerConverter);

  // Vector ops that perform arithmetic must be listed explicitly.
  target.addDynamicallyLegalOp<
      vector::ContractionOp, vector::ReductionOp, vector::MultiDimReductionOp,
      vector::FMAOp, vector::OuterProductOp, vector::MatmulOp, vector::ScanOp>(
      legalPerConverter);

  // These ops are always accepted regardless of their types.
  target.addLegalOp<arith::BitcastOp, arith::ExtFOp, arith::TruncFOp,
                    arith::ConstantOp, vector::SplatOp>();
}
|
|
|
|
void EmulateUnsupportedFloatsPass::runOnOperation() {
|
|
MLIRContext *ctx = &getContext();
|
|
Operation *op = getOperation();
|
|
SmallVector<Type> sourceTypes;
|
|
Type targetType;
|
|
|
|
std::optional<FloatType> maybeTargetType = parseFloatType(ctx, targetTypeStr);
|
|
if (!maybeTargetType) {
|
|
emitError(UnknownLoc::get(ctx), "could not map target type '" +
|
|
targetTypeStr +
|
|
"' to a known floating-point type");
|
|
return signalPassFailure();
|
|
}
|
|
targetType = *maybeTargetType;
|
|
for (StringRef sourceTypeStr : sourceTypeStrs) {
|
|
std::optional<FloatType> maybeSourceType =
|
|
parseFloatType(ctx, sourceTypeStr);
|
|
if (!maybeSourceType) {
|
|
emitError(UnknownLoc::get(ctx), "could not map source type '" +
|
|
sourceTypeStr +
|
|
"' to a known floating-point type");
|
|
return signalPassFailure();
|
|
}
|
|
sourceTypes.push_back(*maybeSourceType);
|
|
}
|
|
if (sourceTypes.empty())
|
|
(void)emitOptionalWarning(
|
|
std::nullopt,
|
|
"no source types specified, float emulation will do nothing");
|
|
|
|
if (llvm::is_contained(sourceTypes, targetType)) {
|
|
emitError(UnknownLoc::get(ctx),
|
|
"target type cannot be an unsupported source type");
|
|
return signalPassFailure();
|
|
}
|
|
TypeConverter converter;
|
|
arith::populateEmulateUnsupportedFloatsConversions(converter, sourceTypes,
|
|
targetType);
|
|
RewritePatternSet patterns(ctx);
|
|
arith::populateEmulateUnsupportedFloatsPatterns(patterns, converter);
|
|
ConversionTarget target(getContext());
|
|
arith::populateEmulateUnsupportedFloatsLegality(target, converter);
|
|
|
|
if (failed(applyPartialConversion(op, target, std::move(patterns))))
|
|
signalPassFailure();
|
|
}
|