[mlir] Add Float8E5M2FNUZ and Float8E4M3FNUZ types to MLIR
Float8E5M2FNUZ and Float8E4M3FNUZ have been added to APFloat in D141863. This change adds these types as MLIR builtin types alongside Float8E5M2 and Float8E4M3FN (added in D133823 and D138075). Reviewed By: krzysz00 Differential Revision: https://reviews.llvm.org/D143744
This commit is contained in:
@@ -81,6 +81,20 @@ MLIR_CAPI_EXPORTED bool mlirTypeIsAFloat8E4M3FN(MlirType type);
|
||||
/// context.
|
||||
MLIR_CAPI_EXPORTED MlirType mlirFloat8E4M3FNTypeGet(MlirContext ctx);
|
||||
|
||||
/// Checks whether the given type is an f8E5M2FNUZ type.
|
||||
MLIR_CAPI_EXPORTED bool mlirTypeIsAFloat8E5M2FNUZ(MlirType type);
|
||||
|
||||
/// Creates an f8E5M2FNUZ type in the given context. The type is owned by the
|
||||
/// context.
|
||||
MLIR_CAPI_EXPORTED MlirType mlirFloat8E5M2FNUZTypeGet(MlirContext ctx);
|
||||
|
||||
/// Checks whether the given type is an f8E4M3FNUZ type.
|
||||
MLIR_CAPI_EXPORTED bool mlirTypeIsAFloat8E4M3FNUZ(MlirType type);
|
||||
|
||||
/// Creates an f8E4M3FNUZ type in the given context. The type is owned by the
|
||||
/// context.
|
||||
MLIR_CAPI_EXPORTED MlirType mlirFloat8E4M3FNUZTypeGet(MlirContext ctx);
|
||||
|
||||
/// Checks whether the given type is a bf16 type.
|
||||
MLIR_CAPI_EXPORTED bool mlirTypeIsABF16(MlirType type);
|
||||
|
||||
|
||||
@@ -62,6 +62,8 @@ public:
|
||||
// Types.
|
||||
FloatType getFloat8E5M2Type();
|
||||
FloatType getFloat8E4M3FNType();
|
||||
FloatType getFloat8E5M2FNUZType();
|
||||
FloatType getFloat8E4M3FNUZType();
|
||||
FloatType getBF16Type();
|
||||
FloatType getF16Type();
|
||||
FloatType getF32Type();
|
||||
|
||||
@@ -47,6 +47,8 @@ public:
|
||||
static FloatType getF128(MLIRContext *ctx);
|
||||
static FloatType getFloat8E5M2(MLIRContext *ctx);
|
||||
static FloatType getFloat8E4M3FN(MLIRContext *ctx);
|
||||
static FloatType getFloat8E5M2FNUZ(MLIRContext *ctx);
|
||||
static FloatType getFloat8E4M3FNUZ(MLIRContext *ctx);
|
||||
|
||||
/// Methods for support type inquiry through isa, cast, and dyn_cast.
|
||||
static bool classof(Type type);
|
||||
@@ -374,8 +376,9 @@ inline bool BaseMemRefType::isValidElementType(Type type) {
|
||||
}
|
||||
|
||||
inline bool FloatType::classof(Type type) {
|
||||
return type.isa<Float8E5M2Type, Float8E4M3FNType, BFloat16Type, Float16Type,
|
||||
Float32Type, Float64Type, Float80Type, Float128Type>();
|
||||
return type.isa<Float8E5M2Type, Float8E4M3FNType, Float8E5M2FNUZType,
|
||||
Float8E4M3FNUZType, BFloat16Type, Float16Type, Float32Type,
|
||||
Float64Type, Float80Type, Float128Type>();
|
||||
}
|
||||
|
||||
inline FloatType FloatType::getFloat8E5M2(MLIRContext *ctx) {
|
||||
@@ -386,6 +389,14 @@ inline FloatType FloatType::getFloat8E4M3FN(MLIRContext *ctx) {
|
||||
return Float8E4M3FNType::get(ctx);
|
||||
}
|
||||
|
||||
inline FloatType FloatType::getFloat8E5M2FNUZ(MLIRContext *ctx) {
|
||||
return Float8E5M2FNUZType::get(ctx);
|
||||
}
|
||||
|
||||
inline FloatType FloatType::getFloat8E4M3FNUZ(MLIRContext *ctx) {
|
||||
return Float8E4M3FNUZType::get(ctx);
|
||||
}
|
||||
|
||||
inline FloatType FloatType::getBF16(MLIRContext *ctx) {
|
||||
return BFloat16Type::get(ctx);
|
||||
}
|
||||
|
||||
@@ -118,6 +118,50 @@ def Builtin_Float8E4M3FN : Builtin_FloatType<"Float8E4M3FN"> {
|
||||
}];
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// Float8E5M2FNUZType
|
||||
|
||||
def Builtin_Float8E5M2FNUZ : Builtin_FloatType<"Float8E5M2FNUZ"> {
|
||||
let summary = "8-bit floating point with 2 bit mantissa";
|
||||
let description = [{
|
||||
An 8-bit floating point type with 1 sign bit, 5 bits exponent and 2 bits
|
||||
mantissa. This is not a standard type as defined by IEEE-754, but it follows
|
||||
similar conventions, with the exception that there are no infinity values,
|
||||
no negative zero, and only one NaN representation. This type has the
|
||||
following characteristics:
|
||||
|
||||
* bit encoding: S1E5M2
|
||||
* exponent bias: 16
|
||||
* infinities: Not supported
|
||||
* NaNs: Supported with sign bit set to 1, exponent bits and mantissa bits set to all 0s
|
||||
* denormals when exponent is 0
|
||||
|
||||
Described in: https://arxiv.org/abs/2206.02915
|
||||
}];
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// Float8E4M3FNUZType
|
||||
|
||||
def Builtin_Float8E4M3FNUZ : Builtin_FloatType<"Float8E4M3FNUZ"> {
|
||||
let summary = "8-bit floating point with 3 bit mantissa";
|
||||
let description = [{
|
||||
An 8-bit floating point type with 1 sign bit, 4 bits exponent and 3 bits
|
||||
mantissa. This is not a standard type as defined by IEEE-754, but it follows
|
||||
similar conventions, with the exception that there are no infinity values,
|
||||
no negative zero, and only one NaN representation. This type has the
|
||||
following characteristics:
|
||||
|
||||
* bit encoding: S1E4M3
|
||||
* exponent bias: 8
|
||||
* infinities: Not supported
|
||||
* NaNs: Supported with sign bit set to 1, exponent bits and mantissa bits set to all 0s
|
||||
* denormals when exponent is 0
|
||||
|
||||
Described in: https://arxiv.org/abs/2209.05433
|
||||
}];
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// BFloat16Type
|
||||
|
||||
|
||||
@@ -488,6 +488,10 @@ def F8E4M3FN : Type<CPred<"$_self.isFloat8E4M3FN()">, "f8E4M3FN type">,
|
||||
BuildableType<"$_builder.getFloat8E4M3FNType()">;
|
||||
def F8E5M2 : Type<CPred<"$_self.isFloat8E5M2()">, "f8E5M2 type">,
|
||||
BuildableType<"$_builder.getFloat8E5M2Type()">;
|
||||
def F8E4M3FNUZ : Type<CPred<"$_self.isFloat8E4M3FNUZ()">, "f8E4M3FNUZ type">,
|
||||
BuildableType<"$_builder.getFloat8E4M3FNUZType()">;
|
||||
def F8E5M2FNUZ : Type<CPred<"$_self.isFloat8E5M2FNUZ()">, "f8E5M2FNUZ type">,
|
||||
BuildableType<"$_builder.getFloat8E5M2FNUZType()">;
|
||||
|
||||
def AnyComplex : Type<CPred<"$_self.isa<::mlir::ComplexType>()">,
|
||||
"complex-type", "::mlir::ComplexType">;
|
||||
|
||||
@@ -122,6 +122,8 @@ public:
|
||||
bool isIndex() const;
|
||||
bool isFloat8E5M2() const;
|
||||
bool isFloat8E4M3FN() const;
|
||||
bool isFloat8E5M2FNUZ() const;
|
||||
bool isFloat8E4M3FNUZ() const;
|
||||
bool isBF16() const;
|
||||
bool isF16() const;
|
||||
bool isF32() const;
|
||||
|
||||
@@ -95,6 +95,8 @@ TOK_KEYWORD(f64)
|
||||
TOK_KEYWORD(f80)
|
||||
TOK_KEYWORD(f8E5M2)
|
||||
TOK_KEYWORD(f8E4M3FN)
|
||||
TOK_KEYWORD(f8E5M2FNUZ)
|
||||
TOK_KEYWORD(f8E4M3FNUZ)
|
||||
TOK_KEYWORD(f128)
|
||||
TOK_KEYWORD(false)
|
||||
TOK_KEYWORD(floordiv)
|
||||
|
||||
@@ -33,6 +33,8 @@ OptionalParseResult Parser::parseOptionalType(Type &type) {
|
||||
case Token::inttype:
|
||||
case Token::kw_f8E5M2:
|
||||
case Token::kw_f8E4M3FN:
|
||||
case Token::kw_f8E5M2FNUZ:
|
||||
case Token::kw_f8E4M3FNUZ:
|
||||
case Token::kw_bf16:
|
||||
case Token::kw_f16:
|
||||
case Token::kw_f32:
|
||||
@@ -295,6 +297,12 @@ Type Parser::parseNonFunctionType() {
|
||||
case Token::kw_f8E4M3FN:
|
||||
consumeToken(Token::kw_f8E4M3FN);
|
||||
return builder.getFloat8E4M3FNType();
|
||||
case Token::kw_f8E5M2FNUZ:
|
||||
consumeToken(Token::kw_f8E5M2FNUZ);
|
||||
return builder.getFloat8E5M2FNUZType();
|
||||
case Token::kw_f8E4M3FNUZ:
|
||||
consumeToken(Token::kw_f8E4M3FNUZ);
|
||||
return builder.getFloat8E4M3FNUZType();
|
||||
case Token::kw_bf16:
|
||||
consumeToken(Token::kw_bf16);
|
||||
return builder.getBF16Type();
|
||||
|
||||
@@ -139,6 +139,42 @@ public:
|
||||
}
|
||||
};
|
||||
|
||||
/// Floating Point Type subclass - Float8E4M3FNUZ.
|
||||
class PyFloat8E4M3FNUZType : public PyConcreteType<PyFloat8E4M3FNUZType> {
|
||||
public:
|
||||
static constexpr IsAFunctionTy isaFunction = mlirTypeIsAFloat8E4M3FNUZ;
|
||||
static constexpr const char *pyClassName = "Float8E4M3FNUZType";
|
||||
using PyConcreteType::PyConcreteType;
|
||||
|
||||
static void bindDerived(ClassTy &c) {
|
||||
c.def_static(
|
||||
"get",
|
||||
[](DefaultingPyMlirContext context) {
|
||||
MlirType t = mlirFloat8E4M3FNUZTypeGet(context->get());
|
||||
return PyFloat8E4M3FNUZType(context->getRef(), t);
|
||||
},
|
||||
py::arg("context") = py::none(), "Create a float8_e4m3fnuz type.");
|
||||
}
|
||||
};
|
||||
|
||||
/// Floating Point Type subclass - Float8E5M2FNUZ.
|
||||
class PyFloat8E5M2FNUZType : public PyConcreteType<PyFloat8E5M2FNUZType> {
|
||||
public:
|
||||
static constexpr IsAFunctionTy isaFunction = mlirTypeIsAFloat8E5M2FNUZ;
|
||||
static constexpr const char *pyClassName = "Float8E5M2FNUZType";
|
||||
using PyConcreteType::PyConcreteType;
|
||||
|
||||
static void bindDerived(ClassTy &c) {
|
||||
c.def_static(
|
||||
"get",
|
||||
[](DefaultingPyMlirContext context) {
|
||||
MlirType t = mlirFloat8E5M2FNUZTypeGet(context->get());
|
||||
return PyFloat8E5M2FNUZType(context->getRef(), t);
|
||||
},
|
||||
py::arg("context") = py::none(), "Create a float8_e5m2fnuz type.");
|
||||
}
|
||||
};
|
||||
|
||||
/// Floating Point Type subclass - BF16Type.
|
||||
class PyBF16Type : public PyConcreteType<PyBF16Type> {
|
||||
public:
|
||||
@@ -700,6 +736,8 @@ void mlir::python::populateIRTypes(py::module &m) {
|
||||
PyIndexType::bind(m);
|
||||
PyFloat8E4M3FNType::bind(m);
|
||||
PyFloat8E5M2Type::bind(m);
|
||||
PyFloat8E4M3FNUZType::bind(m);
|
||||
PyFloat8E5M2FNUZType::bind(m);
|
||||
PyBF16Type::bind(m);
|
||||
PyF16Type::bind(m);
|
||||
PyF32Type::bind(m);
|
||||
|
||||
@@ -84,6 +84,22 @@ MlirType mlirFloat8E4M3FNTypeGet(MlirContext ctx) {
|
||||
return wrap(FloatType::getFloat8E4M3FN(unwrap(ctx)));
|
||||
}
|
||||
|
||||
bool mlirTypeIsAFloat8E5M2FNUZ(MlirType type) {
|
||||
return unwrap(type).isFloat8E5M2FNUZ();
|
||||
}
|
||||
|
||||
MlirType mlirFloat8E5M2FNUZTypeGet(MlirContext ctx) {
|
||||
return wrap(FloatType::getFloat8E5M2FNUZ(unwrap(ctx)));
|
||||
}
|
||||
|
||||
bool mlirTypeIsAFloat8E4M3FNUZ(MlirType type) {
|
||||
return unwrap(type).isFloat8E4M3FNUZ();
|
||||
}
|
||||
|
||||
MlirType mlirFloat8E4M3FNUZTypeGet(MlirContext ctx) {
|
||||
return wrap(FloatType::getFloat8E4M3FNUZ(unwrap(ctx)));
|
||||
}
|
||||
|
||||
bool mlirTypeIsABF16(MlirType type) { return unwrap(type).isBF16(); }
|
||||
|
||||
MlirType mlirBF16TypeGet(MlirContext ctx) {
|
||||
|
||||
@@ -2410,6 +2410,8 @@ void AsmPrinter::Impl::printTypeImpl(Type type) {
|
||||
.Case<IndexType>([&](Type) { os << "index"; })
|
||||
.Case<Float8E5M2Type>([&](Type) { os << "f8E5M2"; })
|
||||
.Case<Float8E4M3FNType>([&](Type) { os << "f8E4M3FN"; })
|
||||
.Case<Float8E5M2FNUZType>([&](Type) { os << "f8E5M2FNUZ"; })
|
||||
.Case<Float8E4M3FNUZType>([&](Type) { os << "f8E4M3FNUZ"; })
|
||||
.Case<BFloat16Type>([&](Type) { os << "bf16"; })
|
||||
.Case<Float16Type>([&](Type) { os << "f16"; })
|
||||
.Case<Float32Type>([&](Type) { os << "f32"; })
|
||||
|
||||
@@ -41,6 +41,14 @@ FloatType Builder::getFloat8E4M3FNType() {
|
||||
return FloatType::getFloat8E4M3FN(context);
|
||||
}
|
||||
|
||||
FloatType Builder::getFloat8E5M2FNUZType() {
|
||||
return FloatType::getFloat8E5M2FNUZ(context);
|
||||
}
|
||||
|
||||
FloatType Builder::getFloat8E4M3FNUZType() {
|
||||
return FloatType::getFloat8E4M3FNUZ(context);
|
||||
}
|
||||
|
||||
FloatType Builder::getBF16Type() { return FloatType::getBF16(context); }
|
||||
|
||||
FloatType Builder::getF16Type() { return FloatType::getF16(context); }
|
||||
|
||||
@@ -88,7 +88,8 @@ IntegerType IntegerType::scaleElementBitwidth(unsigned scale) {
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
unsigned FloatType::getWidth() {
|
||||
if (isa<Float8E5M2Type, Float8E4M3FNType>())
|
||||
if (isa<Float8E5M2Type, Float8E4M3FNType, Float8E5M2FNUZType,
|
||||
Float8E4M3FNUZType>())
|
||||
return 8;
|
||||
if (isa<Float16Type, BFloat16Type>())
|
||||
return 16;
|
||||
@@ -109,6 +110,10 @@ const llvm::fltSemantics &FloatType::getFloatSemantics() {
|
||||
return APFloat::Float8E5M2();
|
||||
if (isa<Float8E4M3FNType>())
|
||||
return APFloat::Float8E4M3FN();
|
||||
if (isa<Float8E5M2FNUZType>())
|
||||
return APFloat::Float8E5M2FNUZ();
|
||||
if (isa<Float8E4M3FNUZType>())
|
||||
return APFloat::Float8E4M3FNUZ();
|
||||
if (isa<BFloat16Type>())
|
||||
return APFloat::BFloat();
|
||||
if (isa<Float16Type>())
|
||||
|
||||
@@ -209,6 +209,8 @@ public:
|
||||
/// Cached Type Instances.
|
||||
Float8E5M2Type f8E5M2Ty;
|
||||
Float8E4M3FNType f8E4M3FNTy;
|
||||
Float8E5M2FNUZType f8E5M2FNUZTy;
|
||||
Float8E4M3FNUZType f8E4M3FNUZTy;
|
||||
BFloat16Type bf16Ty;
|
||||
Float16Type f16Ty;
|
||||
Float32Type f32Ty;
|
||||
@@ -281,6 +283,8 @@ MLIRContext::MLIRContext(const DialectRegistry ®istry, Threading setting)
|
||||
/// Floating-point Types.
|
||||
impl->f8E5M2Ty = TypeUniquer::get<Float8E5M2Type>(this);
|
||||
impl->f8E4M3FNTy = TypeUniquer::get<Float8E4M3FNType>(this);
|
||||
impl->f8E5M2FNUZTy = TypeUniquer::get<Float8E5M2FNUZType>(this);
|
||||
impl->f8E4M3FNUZTy = TypeUniquer::get<Float8E4M3FNUZType>(this);
|
||||
impl->bf16Ty = TypeUniquer::get<BFloat16Type>(this);
|
||||
impl->f16Ty = TypeUniquer::get<Float16Type>(this);
|
||||
impl->f32Ty = TypeUniquer::get<Float32Type>(this);
|
||||
@@ -870,6 +874,12 @@ Float8E5M2Type Float8E5M2Type::get(MLIRContext *context) {
|
||||
Float8E4M3FNType Float8E4M3FNType::get(MLIRContext *context) {
|
||||
return context->getImpl().f8E4M3FNTy;
|
||||
}
|
||||
Float8E5M2FNUZType Float8E5M2FNUZType::get(MLIRContext *context) {
|
||||
return context->getImpl().f8E5M2FNUZTy;
|
||||
}
|
||||
Float8E4M3FNUZType Float8E4M3FNUZType::get(MLIRContext *context) {
|
||||
return context->getImpl().f8E4M3FNUZTy;
|
||||
}
|
||||
BFloat16Type BFloat16Type::get(MLIRContext *context) {
|
||||
return context->getImpl().bf16Ty;
|
||||
}
|
||||
|
||||
@@ -36,6 +36,8 @@ MLIRContext *Type::getContext() const { return getDialect().getContext(); }
|
||||
|
||||
bool Type::isFloat8E5M2() const { return isa<Float8E5M2Type>(); }
|
||||
bool Type::isFloat8E4M3FN() const { return isa<Float8E4M3FNType>(); }
|
||||
bool Type::isFloat8E5M2FNUZ() const { return isa<Float8E5M2FNUZType>(); }
|
||||
bool Type::isFloat8E4M3FNUZ() const { return isa<Float8E4M3FNUZType>(); }
|
||||
bool Type::isBF16() const { return isa<BFloat16Type>(); }
|
||||
bool Type::isF16() const { return isa<Float16Type>(); }
|
||||
bool Type::isF32() const { return isa<Float32Type>(); }
|
||||
|
||||
@@ -52,6 +52,8 @@ __all__ = [
|
||||
"DictAttr",
|
||||
"Float8E4M3FNType",
|
||||
"Float8E5M2Type",
|
||||
"Float8E4M3FNUZType",
|
||||
"Float8E5M2FNUZType",
|
||||
"F16Type",
|
||||
"F32Type",
|
||||
"F64Type",
|
||||
@@ -593,6 +595,20 @@ class Float8E5M2Type(Type):
|
||||
@staticmethod
|
||||
def isinstance(arg: Any) -> bool: ...
|
||||
|
||||
class Float8E4M3FNUZType(Type):
|
||||
def __init__(self, cast_from_type: Type) -> None: ...
|
||||
@staticmethod
|
||||
def get(*args, **kwargs) -> Float8E4M3FNUZType: ...
|
||||
@staticmethod
|
||||
def isinstance(arg: Any) -> bool: ...
|
||||
|
||||
class Float8E5M2FNUZType(Type):
|
||||
def __init__(self, cast_from_type: Type) -> None: ...
|
||||
@staticmethod
|
||||
def get(*args, **kwargs) -> Float8E5M2FNUZType: ...
|
||||
@staticmethod
|
||||
def isinstance(arg: Any) -> bool: ...
|
||||
|
||||
# TODO: Auto-generated. Audit and fix.
|
||||
class F16Type(Type):
|
||||
def __init__(self, cast_from_type: Type) -> None: ...
|
||||
|
||||
@@ -44,6 +44,14 @@ func.func @float_attrs_pass() {
|
||||
// CHECK: float_attr = 2.000000e+00 : f8E4M3FN
|
||||
float_attr = 2. : f8E4M3FN
|
||||
} : () -> ()
|
||||
"test.float_attrs"() {
|
||||
// CHECK: float_attr = 2.000000e+00 : f8E5M2FNUZ
|
||||
float_attr = 2. : f8E5M2FNUZ
|
||||
} : () -> ()
|
||||
"test.float_attrs"() {
|
||||
// CHECK: float_attr = 2.000000e+00 : f8E4M3FNUZ
|
||||
float_attr = 2. : f8E4M3FNUZ
|
||||
} : () -> ()
|
||||
"test.float_attrs"() {
|
||||
// CHECK: float_attr = 2.000000e+00 : f16
|
||||
float_attr = 2. : f16
|
||||
|
||||
@@ -197,6 +197,10 @@ def testFloatType():
|
||||
print("float:", Float8E4M3FNType.get())
|
||||
# CHECK: float: f8E5M2
|
||||
print("float:", Float8E5M2Type.get())
|
||||
# CHECK: float: f8E5M2FNUZ
|
||||
print("float:", Float8E5M2FNUZType.get())
|
||||
# CHECK: float: f8E4M3FNUZ
|
||||
print("float:", Float8E4M3FNUZType.get())
|
||||
# CHECK: float: bf16
|
||||
print("float:", BF16Type.get())
|
||||
# CHECK: float: f16
|
||||
|
||||
@@ -52,6 +52,8 @@ builtin_attr_type_mnemonics = {
|
||||
"mlir::UnknownLoc": '"loc(unknown)"',
|
||||
"mlir::Float8E5M2Type": '"f8E5M2"',
|
||||
"mlir::Float8E4M3FNType": '"f8E4M3FN"',
|
||||
"mlir::Float8E5M2FNUZType": '"f8E5M2FNUZ"',
|
||||
"mlir::Float8E4M3FNUZType": '"f8E4M3FNUZ"',
|
||||
"mlir::BFloat16Type": '"bf16"',
|
||||
"mlir::Float16Type": '"f16"',
|
||||
"mlir::Float32Type": '"f32"',
|
||||
|
||||
Reference in New Issue
Block a user