[mlir] Add Float8E5M2FNUZ and Float8E4M3FNUZ types to MLIR

Float8E5M2FNUZ and Float8E4M3FNUZ have been added to APFloat in D141863.
This change adds these types as MLIR builtin types alongside Float8E5M2
and Float8E4M3FN (added in D133823 and D138075).

Reviewed By: krzysz00

Differential Revision: https://reviews.llvm.org/D143744
This commit is contained in:
Jake Hall
2023-02-13 14:10:20 +00:00
committed by Chris Jackson
parent 7c84f6a43a
commit 96267b6b88
19 changed files with 201 additions and 3 deletions

View File

@@ -81,6 +81,20 @@ MLIR_CAPI_EXPORTED bool mlirTypeIsAFloat8E4M3FN(MlirType type);
/// context.
MLIR_CAPI_EXPORTED MlirType mlirFloat8E4M3FNTypeGet(MlirContext ctx);
/// Checks whether the given type is an f8E5M2FNUZ type.
MLIR_CAPI_EXPORTED bool mlirTypeIsAFloat8E5M2FNUZ(MlirType type);
/// Creates an f8E5M2FNUZ type in the given context. The type is owned by the
/// context.
MLIR_CAPI_EXPORTED MlirType mlirFloat8E5M2FNUZTypeGet(MlirContext ctx);
/// Checks whether the given type is an f8E4M3FNUZ type.
MLIR_CAPI_EXPORTED bool mlirTypeIsAFloat8E4M3FNUZ(MlirType type);
/// Creates an f8E4M3FNUZ type in the given context. The type is owned by the
/// context.
MLIR_CAPI_EXPORTED MlirType mlirFloat8E4M3FNUZTypeGet(MlirContext ctx);
/// Checks whether the given type is a bf16 type.
MLIR_CAPI_EXPORTED bool mlirTypeIsABF16(MlirType type);

View File

@@ -62,6 +62,8 @@ public:
// Types.
FloatType getFloat8E5M2Type();
FloatType getFloat8E4M3FNType();
FloatType getFloat8E5M2FNUZType();
FloatType getFloat8E4M3FNUZType();
FloatType getBF16Type();
FloatType getF16Type();
FloatType getF32Type();

View File

@@ -47,6 +47,8 @@ public:
static FloatType getF128(MLIRContext *ctx);
static FloatType getFloat8E5M2(MLIRContext *ctx);
static FloatType getFloat8E4M3FN(MLIRContext *ctx);
static FloatType getFloat8E5M2FNUZ(MLIRContext *ctx);
static FloatType getFloat8E4M3FNUZ(MLIRContext *ctx);
/// Methods for support type inquiry through isa, cast, and dyn_cast.
static bool classof(Type type);
@@ -374,8 +376,9 @@ inline bool BaseMemRefType::isValidElementType(Type type) {
}
inline bool FloatType::classof(Type type) {
return type.isa<Float8E5M2Type, Float8E4M3FNType, BFloat16Type, Float16Type,
Float32Type, Float64Type, Float80Type, Float128Type>();
return type.isa<Float8E5M2Type, Float8E4M3FNType, Float8E5M2FNUZType,
Float8E4M3FNUZType, BFloat16Type, Float16Type, Float32Type,
Float64Type, Float80Type, Float128Type>();
}
inline FloatType FloatType::getFloat8E5M2(MLIRContext *ctx) {
@@ -386,6 +389,14 @@ inline FloatType FloatType::getFloat8E4M3FN(MLIRContext *ctx) {
return Float8E4M3FNType::get(ctx);
}
inline FloatType FloatType::getFloat8E5M2FNUZ(MLIRContext *ctx) {
return Float8E5M2FNUZType::get(ctx);
}
inline FloatType FloatType::getFloat8E4M3FNUZ(MLIRContext *ctx) {
return Float8E4M3FNUZType::get(ctx);
}
inline FloatType FloatType::getBF16(MLIRContext *ctx) {
return BFloat16Type::get(ctx);
}

View File

@@ -118,6 +118,50 @@ def Builtin_Float8E4M3FN : Builtin_FloatType<"Float8E4M3FN"> {
}];
}
//===----------------------------------------------------------------------===//
// Float8E5M2FNUZType
def Builtin_Float8E5M2FNUZ : Builtin_FloatType<"Float8E5M2FNUZ"> {
let summary = "8-bit floating point with 2 bit mantissa";
let description = [{
An 8-bit floating point type with 1 sign bit, 5 bits exponent and 2 bits
mantissa. This is not a standard type as defined by IEEE-754, but it follows
similar conventions, with the exception that there are no infinity values,
no negative zero, and only one NaN representation. This type has the
following characteristics:
* bit encoding: S1E5M2
* exponent bias: 16
* infinities: Not supported
* NaNs: Supported with sign bit set to 1, exponent bits and mantissa bits set to all 0s
* denormals when exponent is 0
Described in: https://arxiv.org/abs/2206.02915
}];
}
//===----------------------------------------------------------------------===//
// Float8E4M3FNUZType
def Builtin_Float8E4M3FNUZ : Builtin_FloatType<"Float8E4M3FNUZ"> {
let summary = "8-bit floating point with 3 bit mantissa";
let description = [{
An 8-bit floating point type with 1 sign bit, 4 bits exponent and 3 bits
mantissa. This is not a standard type as defined by IEEE-754, but it follows
similar conventions, with the exception that there are no infinity values,
no negative zero, and only one NaN representation. This type has the
following characteristics:
* bit encoding: S1E4M3
* exponent bias: 8
* infinities: Not supported
* NaNs: Supported with sign bit set to 1, exponent bits and mantissa bits set to all 0s
* denormals when exponent is 0
Described in: https://arxiv.org/abs/2209.05433
}];
}
//===----------------------------------------------------------------------===//
// BFloat16Type

View File

@@ -488,6 +488,10 @@ def F8E4M3FN : Type<CPred<"$_self.isFloat8E4M3FN()">, "f8E4M3FN type">,
BuildableType<"$_builder.getFloat8E4M3FNType()">;
def F8E5M2 : Type<CPred<"$_self.isFloat8E5M2()">, "f8E5M2 type">,
BuildableType<"$_builder.getFloat8E5M2Type()">;
def F8E4M3FNUZ : Type<CPred<"$_self.isFloat8E4M3FNUZ()">, "f8E4M3FNUZ type">,
BuildableType<"$_builder.getFloat8E4M3FNUZType()">;
def F8E5M2FNUZ : Type<CPred<"$_self.isFloat8E5M2FNUZ()">, "f8E5M2FNUZ type">,
BuildableType<"$_builder.getFloat8E5M2FNUZType()">;
def AnyComplex : Type<CPred<"$_self.isa<::mlir::ComplexType>()">,
"complex-type", "::mlir::ComplexType">;

View File

@@ -122,6 +122,8 @@ public:
bool isIndex() const;
bool isFloat8E5M2() const;
bool isFloat8E4M3FN() const;
bool isFloat8E5M2FNUZ() const;
bool isFloat8E4M3FNUZ() const;
bool isBF16() const;
bool isF16() const;
bool isF32() const;

View File

@@ -95,6 +95,8 @@ TOK_KEYWORD(f64)
TOK_KEYWORD(f80)
TOK_KEYWORD(f8E5M2)
TOK_KEYWORD(f8E4M3FN)
TOK_KEYWORD(f8E5M2FNUZ)
TOK_KEYWORD(f8E4M3FNUZ)
TOK_KEYWORD(f128)
TOK_KEYWORD(false)
TOK_KEYWORD(floordiv)

View File

@@ -33,6 +33,8 @@ OptionalParseResult Parser::parseOptionalType(Type &type) {
case Token::inttype:
case Token::kw_f8E5M2:
case Token::kw_f8E4M3FN:
case Token::kw_f8E5M2FNUZ:
case Token::kw_f8E4M3FNUZ:
case Token::kw_bf16:
case Token::kw_f16:
case Token::kw_f32:
@@ -295,6 +297,12 @@ Type Parser::parseNonFunctionType() {
case Token::kw_f8E4M3FN:
consumeToken(Token::kw_f8E4M3FN);
return builder.getFloat8E4M3FNType();
case Token::kw_f8E5M2FNUZ:
consumeToken(Token::kw_f8E5M2FNUZ);
return builder.getFloat8E5M2FNUZType();
case Token::kw_f8E4M3FNUZ:
consumeToken(Token::kw_f8E4M3FNUZ);
return builder.getFloat8E4M3FNUZType();
case Token::kw_bf16:
consumeToken(Token::kw_bf16);
return builder.getBF16Type();

View File

@@ -139,6 +139,42 @@ public:
}
};
/// Floating Point Type subclass - Float8E4M3FNUZ.
class PyFloat8E4M3FNUZType : public PyConcreteType<PyFloat8E4M3FNUZType> {
public:
static constexpr IsAFunctionTy isaFunction = mlirTypeIsAFloat8E4M3FNUZ;
static constexpr const char *pyClassName = "Float8E4M3FNUZType";
using PyConcreteType::PyConcreteType;
static void bindDerived(ClassTy &c) {
c.def_static(
"get",
[](DefaultingPyMlirContext context) {
MlirType t = mlirFloat8E4M3FNUZTypeGet(context->get());
return PyFloat8E4M3FNUZType(context->getRef(), t);
},
py::arg("context") = py::none(), "Create a float8_e4m3fnuz type.");
}
};
/// Floating Point Type subclass - Float8E5M2FNUZ.
class PyFloat8E5M2FNUZType : public PyConcreteType<PyFloat8E5M2FNUZType> {
public:
static constexpr IsAFunctionTy isaFunction = mlirTypeIsAFloat8E5M2FNUZ;
static constexpr const char *pyClassName = "Float8E5M2FNUZType";
using PyConcreteType::PyConcreteType;
static void bindDerived(ClassTy &c) {
c.def_static(
"get",
[](DefaultingPyMlirContext context) {
MlirType t = mlirFloat8E5M2FNUZTypeGet(context->get());
return PyFloat8E5M2FNUZType(context->getRef(), t);
},
py::arg("context") = py::none(), "Create a float8_e5m2fnuz type.");
}
};
/// Floating Point Type subclass - BF16Type.
class PyBF16Type : public PyConcreteType<PyBF16Type> {
public:
@@ -700,6 +736,8 @@ void mlir::python::populateIRTypes(py::module &m) {
PyIndexType::bind(m);
PyFloat8E4M3FNType::bind(m);
PyFloat8E5M2Type::bind(m);
PyFloat8E4M3FNUZType::bind(m);
PyFloat8E5M2FNUZType::bind(m);
PyBF16Type::bind(m);
PyF16Type::bind(m);
PyF32Type::bind(m);

View File

@@ -84,6 +84,22 @@ MlirType mlirFloat8E4M3FNTypeGet(MlirContext ctx) {
return wrap(FloatType::getFloat8E4M3FN(unwrap(ctx)));
}
bool mlirTypeIsAFloat8E5M2FNUZ(MlirType type) {
return unwrap(type).isFloat8E5M2FNUZ();
}
MlirType mlirFloat8E5M2FNUZTypeGet(MlirContext ctx) {
return wrap(FloatType::getFloat8E5M2FNUZ(unwrap(ctx)));
}
bool mlirTypeIsAFloat8E4M3FNUZ(MlirType type) {
return unwrap(type).isFloat8E4M3FNUZ();
}
MlirType mlirFloat8E4M3FNUZTypeGet(MlirContext ctx) {
return wrap(FloatType::getFloat8E4M3FNUZ(unwrap(ctx)));
}
bool mlirTypeIsABF16(MlirType type) { return unwrap(type).isBF16(); }
MlirType mlirBF16TypeGet(MlirContext ctx) {

View File

@@ -2410,6 +2410,8 @@ void AsmPrinter::Impl::printTypeImpl(Type type) {
.Case<IndexType>([&](Type) { os << "index"; })
.Case<Float8E5M2Type>([&](Type) { os << "f8E5M2"; })
.Case<Float8E4M3FNType>([&](Type) { os << "f8E4M3FN"; })
.Case<Float8E5M2FNUZType>([&](Type) { os << "f8E5M2FNUZ"; })
.Case<Float8E4M3FNUZType>([&](Type) { os << "f8E4M3FNUZ"; })
.Case<BFloat16Type>([&](Type) { os << "bf16"; })
.Case<Float16Type>([&](Type) { os << "f16"; })
.Case<Float32Type>([&](Type) { os << "f32"; })

View File

@@ -41,6 +41,14 @@ FloatType Builder::getFloat8E4M3FNType() {
return FloatType::getFloat8E4M3FN(context);
}
FloatType Builder::getFloat8E5M2FNUZType() {
return FloatType::getFloat8E5M2FNUZ(context);
}
FloatType Builder::getFloat8E4M3FNUZType() {
return FloatType::getFloat8E4M3FNUZ(context);
}
FloatType Builder::getBF16Type() { return FloatType::getBF16(context); }
FloatType Builder::getF16Type() { return FloatType::getF16(context); }

View File

@@ -88,7 +88,8 @@ IntegerType IntegerType::scaleElementBitwidth(unsigned scale) {
//===----------------------------------------------------------------------===//
unsigned FloatType::getWidth() {
if (isa<Float8E5M2Type, Float8E4M3FNType>())
if (isa<Float8E5M2Type, Float8E4M3FNType, Float8E5M2FNUZType,
Float8E4M3FNUZType>())
return 8;
if (isa<Float16Type, BFloat16Type>())
return 16;
@@ -109,6 +110,10 @@ const llvm::fltSemantics &FloatType::getFloatSemantics() {
return APFloat::Float8E5M2();
if (isa<Float8E4M3FNType>())
return APFloat::Float8E4M3FN();
if (isa<Float8E5M2FNUZType>())
return APFloat::Float8E5M2FNUZ();
if (isa<Float8E4M3FNUZType>())
return APFloat::Float8E4M3FNUZ();
if (isa<BFloat16Type>())
return APFloat::BFloat();
if (isa<Float16Type>())

View File

@@ -209,6 +209,8 @@ public:
/// Cached Type Instances.
Float8E5M2Type f8E5M2Ty;
Float8E4M3FNType f8E4M3FNTy;
Float8E5M2FNUZType f8E5M2FNUZTy;
Float8E4M3FNUZType f8E4M3FNUZTy;
BFloat16Type bf16Ty;
Float16Type f16Ty;
Float32Type f32Ty;
@@ -281,6 +283,8 @@ MLIRContext::MLIRContext(const DialectRegistry &registry, Threading setting)
/// Floating-point Types.
impl->f8E5M2Ty = TypeUniquer::get<Float8E5M2Type>(this);
impl->f8E4M3FNTy = TypeUniquer::get<Float8E4M3FNType>(this);
impl->f8E5M2FNUZTy = TypeUniquer::get<Float8E5M2FNUZType>(this);
impl->f8E4M3FNUZTy = TypeUniquer::get<Float8E4M3FNUZType>(this);
impl->bf16Ty = TypeUniquer::get<BFloat16Type>(this);
impl->f16Ty = TypeUniquer::get<Float16Type>(this);
impl->f32Ty = TypeUniquer::get<Float32Type>(this);
@@ -870,6 +874,12 @@ Float8E5M2Type Float8E5M2Type::get(MLIRContext *context) {
Float8E4M3FNType Float8E4M3FNType::get(MLIRContext *context) {
return context->getImpl().f8E4M3FNTy;
}
Float8E5M2FNUZType Float8E5M2FNUZType::get(MLIRContext *context) {
return context->getImpl().f8E5M2FNUZTy;
}
Float8E4M3FNUZType Float8E4M3FNUZType::get(MLIRContext *context) {
return context->getImpl().f8E4M3FNUZTy;
}
BFloat16Type BFloat16Type::get(MLIRContext *context) {
return context->getImpl().bf16Ty;
}

View File

@@ -36,6 +36,8 @@ MLIRContext *Type::getContext() const { return getDialect().getContext(); }
bool Type::isFloat8E5M2() const { return isa<Float8E5M2Type>(); }
bool Type::isFloat8E4M3FN() const { return isa<Float8E4M3FNType>(); }
bool Type::isFloat8E5M2FNUZ() const { return isa<Float8E5M2FNUZType>(); }
bool Type::isFloat8E4M3FNUZ() const { return isa<Float8E4M3FNUZType>(); }
bool Type::isBF16() const { return isa<BFloat16Type>(); }
bool Type::isF16() const { return isa<Float16Type>(); }
bool Type::isF32() const { return isa<Float32Type>(); }

View File

@@ -52,6 +52,8 @@ __all__ = [
"DictAttr",
"Float8E4M3FNType",
"Float8E5M2Type",
"Float8E4M3FNUZType",
"Float8E5M2FNUZType",
"F16Type",
"F32Type",
"F64Type",
@@ -593,6 +595,20 @@ class Float8E5M2Type(Type):
@staticmethod
def isinstance(arg: Any) -> bool: ...
class Float8E4M3FNUZType(Type):
def __init__(self, cast_from_type: Type) -> None: ...
@staticmethod
def get(*args, **kwargs) -> Float8E4M3FNUZType: ...
@staticmethod
def isinstance(arg: Any) -> bool: ...
class Float8E5M2FNUZType(Type):
def __init__(self, cast_from_type: Type) -> None: ...
@staticmethod
def get(*args, **kwargs) -> Float8E5M2FNUZType: ...
@staticmethod
def isinstance(arg: Any) -> bool: ...
# TODO: Auto-generated. Audit and fix.
class F16Type(Type):
def __init__(self, cast_from_type: Type) -> None: ...

View File

@@ -44,6 +44,14 @@ func.func @float_attrs_pass() {
// CHECK: float_attr = 2.000000e+00 : f8E4M3FN
float_attr = 2. : f8E4M3FN
} : () -> ()
"test.float_attrs"() {
// CHECK: float_attr = 2.000000e+00 : f8E5M2FNUZ
float_attr = 2. : f8E5M2FNUZ
} : () -> ()
"test.float_attrs"() {
// CHECK: float_attr = 2.000000e+00 : f8E4M3FNUZ
float_attr = 2. : f8E4M3FNUZ
} : () -> ()
"test.float_attrs"() {
// CHECK: float_attr = 2.000000e+00 : f16
float_attr = 2. : f16

View File

@@ -197,6 +197,10 @@ def testFloatType():
print("float:", Float8E4M3FNType.get())
# CHECK: float: f8E5M2
print("float:", Float8E5M2Type.get())
# CHECK: float: f8E5M2FNUZ
print("float:", Float8E5M2FNUZType.get())
# CHECK: float: f8E4M3FNUZ
print("float:", Float8E4M3FNUZType.get())
# CHECK: float: bf16
print("float:", BF16Type.get())
# CHECK: float: f16

View File

@@ -52,6 +52,8 @@ builtin_attr_type_mnemonics = {
"mlir::UnknownLoc": '"loc(unknown)"',
"mlir::Float8E5M2Type": '"f8E5M2"',
"mlir::Float8E4M3FNType": '"f8E4M3FN"',
"mlir::Float8E5M2FNUZType": '"f8E5M2FNUZ"',
"mlir::Float8E4M3FNUZType": '"f8E4M3FNUZ"',
"mlir::BFloat16Type": '"bf16"',
"mlir::Float16Type": '"f16"',
"mlir::Float32Type": '"f32"',