Introduce arith.scaling_extf and arith.scaling_truncf (#141965)

This PR adds the `arith.scaling_truncf` and `arith.scaling_extf` operations,
which support block quantization following the OCP MXFP spec:
https://www.opencompute.org/documents/ocp-microscaling-formats-mx-v1-0-spec-final-pdf

The OCP MXFP spec comes with a reference implementation:
https://github.com/microsoft/microxcaling/tree/main

An interesting piece of the reference code is the `_quantize_mx` method:
7bc41952de/mx/mx_ops.py (L173).

Both `arith.scaling_truncf` and `arith.scaling_extf` are designed to be
elementwise operations. See their descriptions in `ArithOps.td` for more
details.
 
Internally, `arith.scaling_truncf` computes
`arith.truncf(arith.divf(input, 2^scale))`. Any necessary broadcasting,
clamping, normalization, and NaN propagation must already be applied to
`scale` before calling `arith.scaling_truncf`.

`arith.scaling_extf` computes `arith.mulf(2^scale, input)` after performing
the necessary data type conversions.
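
As a rough scalar illustration of the two formulas above, the following C++
sketch models an f8E8M0FNU scale as a raw byte `e` that encodes `2^(e - 127)`,
with `0xFF` reserved for NaN. The helper names (`decodeE8M0`, `scalingExtF`,
`scalingTruncFDiv`) are hypothetical and not part of MLIR, and the final
narrowing to a small float type is left out. It is only meant to show the
multiply/divide-by-power-of-two structure, not the actual lowering.

```cpp
#include <cmath>
#include <cstdint>
#include <limits>

// Hypothetical helper: decode an f8E8M0FNU byte into a float.
// Assumes the OCP encoding: value = 2^(e - 127), and 0xFF means NaN.
float decodeE8M0(std::uint8_t e) {
  if (e == 0xFF)
    return std::numeric_limits<float>::quiet_NaN();
  return std::ldexp(1.0f, static_cast<int>(e) - 127);
}

// Models arith.scaling_extf on one element: input * 2^scale.
float scalingExtF(float input, std::uint8_t scale) {
  return input * decodeE8M0(scale);
}

// Models the division step of arith.scaling_truncf on one element:
// input / 2^scale. The real op then truncates to the narrow result type.
float scalingTruncFDiv(float input, std::uint8_t scale) {
  return input / decodeE8M0(scale);
}
```

Because a NaN scale decodes to NaN, both helpers propagate NaN from either
operand through ordinary floating-point arithmetic.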


CC: @krzysz00 @dhernandez0 @bjacob @pashu123 @MaheshRavishankar
@tgymnich

---------

Co-authored-by: Prashant Kumar <pk5561@gmail.com>
Co-authored-by: Krzysztof Drewniak <Krzysztof.Drewniak@amd.com>
Umang Yadav
2025-06-09 14:13:31 -04:00
committed by GitHub
parent 5d6218d311
commit 7f08503a3b
7 changed files with 412 additions and 14 deletions


@@ -1215,6 +1215,58 @@ def Arith_ExtFOp : Arith_FToFCastOp<"extf", [DeclareOpInterfaceMethods<ArithFast
attr-dict `:` type($in) `to` type($out) }];
}
//===----------------------------------------------------------------------===//
// Scaling ExtFOp
//===----------------------------------------------------------------------===//
def Arith_ScalingExtFOp
: Arith_Op<
"scaling_extf", [Pure, SameInputOutputTensorDims,
DeclareOpInterfaceMethods<ArithFastMathInterface>,
DeclareOpInterfaceMethods<CastOpInterface>]>,
Arguments<(ins FloatLike:$in, FloatLike:$scale,
OptionalAttr<Arith_FastMathAttr>:$fastmath)>,
Results<(outs FloatLike:$out)> {
let summary = "Upcasts input floats using provided scale values following "
"OCP MXFP Spec";
let description = [{
This operation upcasts input floating-point values using provided scale
values. It expects both scales and the input operand to be of the same shape,
making the operation elementwise. Scales are usually calculated per block
following the OCP MXFP spec as described in https://arxiv.org/abs/2310.10537.
If scales are calculated per block where blockSize != 1, then scales may
require broadcasting to make this operation elementwise. For example, let's
say the input is of shape `<dim1 x dim2 x ... dimN>`. Given blockSize != 1 and
assuming quantization happens on the last axis, the input can be reshaped to
`<dim1 x dim2 x ... (dimN/blockSize) x blockSize>`. Scales will be calculated
per block on the last axis. Therefore, scales will be of shape
`<dim1 x dim2 x ... (dimN/blockSize) x 1>`. Scales could also be of some other
shape as long as it is broadcast compatible with the input, e.g.,
`<1 x 1 x ... (dimN/blockSize) x 1>`.
In this example, before calling into `arith.scaling_extf`, scales must be
broadcast to `<dim1 x dim2 x ... (dimN/blockSize) x blockSize>`. Note
that there could be multiple quantization axes. Internally,
`arith.scaling_extf` would perform the following:
```
resultTy = get_type(result)
scaleTy = get_type(scale)
inputTy = get_type(input)
scale.exponent = arith.truncf(scale) : scaleTy to f8E8M0
scale.extf = arith.extf(scale.exponent) : f8E8M0 to resultTy
input.extf = arith.extf(input) : inputTy to resultTy
result = arith.mulf(scale.extf, input.extf)
```
It propagates NaN values. Therefore, if either scale or the input element
contains NaN, then the output element value will also be a NaN.
}];
let hasVerifier = 1;
let assemblyFormat =
[{ $in `,` $scale (`fastmath` `` $fastmath^)? attr-dict `:`
type($in) `,` type($scale) `to` type($out)}];
}
//===----------------------------------------------------------------------===//
// TruncIOp
//===----------------------------------------------------------------------===//
@@ -1280,6 +1332,63 @@ def Arith_TruncFOp :
attr-dict `:` type($in) `to` type($out) }];
}
//===----------------------------------------------------------------------===//
// Scaling TruncFOp
//===----------------------------------------------------------------------===//
def Arith_ScalingTruncFOp
: Arith_Op<"scaling_truncf",
[Pure, SameInputOutputTensorDims,
DeclareOpInterfaceMethods<ArithRoundingModeInterface>,
DeclareOpInterfaceMethods<ArithFastMathInterface>,
DeclareOpInterfaceMethods<CastOpInterface>]>,
Arguments<(ins FloatLike:$in, FloatLike:$scale,
OptionalAttr<Arith_RoundingModeAttr>:$roundingmode,
OptionalAttr<Arith_FastMathAttr>:$fastmath)>,
Results<(outs FloatLike:$out)> {
let summary = "Downcasts input floating point values using provided scale "
"values following OCP MXFP Spec";
let description = [{
This operation downcasts input using the provided scale values. It expects
both scales and the input operand to be of the same shape, which makes the
operation elementwise. Scales are usually calculated per block
following the OCP MXFP spec as described in https://arxiv.org/abs/2310.10537.
Users are required to normalize and clamp the scales as necessary before
passing them to this operation. The OCP MXFP spec also flushes denorms
on the input operand, which should be handled during lowering by passing the
appropriate fastMath flag to this operation.
If scales are calculated per block where blockSize != 1, scales may require
broadcasting to make this operation elementwise. For example, let's say the
input is of shape `<dim1 x dim2 x ... dimN>`. Given blockSize != 1 and
assuming quantization happens on the last axis, the input can be reshaped to
`<dim1 x dim2 x ... (dimN/blockSize) x blockSize>`. Scales will be calculated
per block on the last axis. Therefore, scales will be of shape
`<dim1 x dim2 x ... (dimN/blockSize) x 1>`. Scales could also be of some other
shape as long as it is broadcast compatible with the input, e.g.,
`<1 x 1 x ... (dimN/blockSize) x 1>`.
In this example, before calling into `arith.scaling_truncf`, scales must be
broadcast to `<dim1 x dim2 x ... (dimN/blockSize) x blockSize>`. Note
that there could be multiple quantization axes. Internally,
`arith.scaling_truncf` would perform the following:
```
scaleTy = get_type(scale)
inputTy = get_type(input)
resultTy = get_type(result)
scale.exponent = arith.truncf(scale) : scaleTy to f8E8M0
scale.extf = arith.extf(scale.exponent) : f8E8M0 to inputTy
result = arith.divf(input, scale.extf)
result.cast = arith.truncf(result, resultTy)
```
}];
let hasVerifier = 1;
let assemblyFormat =
[{ $in `,` $scale ($roundingmode^)? (`fastmath` `` $fastmath^)? attr-dict `:`
type($in) `,` type($scale) `to` type($out)}];
}
//===----------------------------------------------------------------------===//
// UIToFPOp
//===----------------------------------------------------------------------===//
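
The broadcasting requirement in the two op descriptions above can be
illustrated with a flat 1-D buffer. The sketch below is a minimal, hypothetical
C++ helper (`broadcastScales` is not an MLIR API) that assumes quantization
along the last axis only, with one scale byte per contiguous block of
`blockSize` elements, and expands the scales so that each input element gets
its own scale.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <vector>

// Hypothetical helper: expand per-block scales of shape
// <... x (dimN/blockSize) x 1> to the elementwise shape
// <... x (dimN/blockSize) x blockSize>, using flat row-major storage.
std::vector<std::uint8_t>
broadcastScales(const std::vector<std::uint8_t> &blockScales,
                std::size_t blockSize) {
  assert(blockSize > 0 && "block size must be positive");
  std::vector<std::uint8_t> elementScales;
  elementScales.reserve(blockScales.size() * blockSize);
  for (std::uint8_t scale : blockScales)
    // Every element of a block shares that block's scale.
    elementScales.insert(elementScales.end(), blockSize, scale);
  return elementScales;
}
```

With two blocks of three elements each, scales `{127, 130}` become
`{127, 127, 127, 130, 130, 130}`, which matches the input shape and can be fed
to an elementwise op.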


@@ -62,6 +62,9 @@ void populateExpandBFloat16Patterns(RewritePatternSet &patterns);
/// Add patterns to expand Arith f8e8m0 patterns to lower level bitcasts/shifts.
void populateExpandF8E8M0Patterns(RewritePatternSet &patterns);
/// Add patterns to expand scaling ExtF/TruncF ops to equivalent arith ops
void populateExpandScalingExtTruncPatterns(RewritePatternSet &patterns);
/// Add patterns to expand Arith ops.
void populateArithExpandOpsPatterns(RewritePatternSet &patterns);


@@ -60,6 +60,7 @@ public:
Attribute metadata = Attribute());
// Types.
FloatType getF8E8M0Type();
FloatType getBF16Type();
FloatType getF16Type();
FloatType getTF32Type();


@@ -1451,6 +1451,19 @@ bool arith::ExtFOp::areCastCompatible(TypeRange inputs, TypeRange outputs) {
LogicalResult arith::ExtFOp::verify() { return verifyExtOp<FloatType>(*this); }
//===----------------------------------------------------------------------===//
// ScalingExtFOp
//===----------------------------------------------------------------------===//
bool arith::ScalingExtFOp::areCastCompatible(TypeRange inputs,
TypeRange outputs) {
return checkWidthChangeCast<std::greater, FloatType>(inputs.front(), outputs);
}
LogicalResult arith::ScalingExtFOp::verify() {
return verifyExtOp<FloatType>(*this);
}
//===----------------------------------------------------------------------===//
// TruncIOp
//===----------------------------------------------------------------------===//
@@ -1565,6 +1578,19 @@ LogicalResult arith::TruncFOp::verify() {
return verifyTruncateOp<FloatType>(*this);
}
//===----------------------------------------------------------------------===//
// ScalingTruncFOp
//===----------------------------------------------------------------------===//
bool arith::ScalingTruncFOp::areCastCompatible(TypeRange inputs,
TypeRange outputs) {
return checkWidthChangeCast<std::less, FloatType>(inputs.front(), outputs);
}
LogicalResult arith::ScalingTruncFOp::verify() {
return verifyTruncateOp<FloatType>(*this);
}
//===----------------------------------------------------------------------===//
// AndIOp
//===----------------------------------------------------------------------===//
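
Both cast-compatibility checks above reduce to a comparison of element bit
widths between the first input and the result: `scaling_extf` uses
`std::greater` and `scaling_truncf` uses `std::less`. The following standalone
sketch mimics that comparator-parameterized check on plain integers. The
function `widthChangeOk` is a hypothetical stand-in, not MLIR's
`checkWidthChangeCast`, and it ignores the shape and type-kind checks that the
real helper presumably also performs.

```cpp
#include <functional>

// Hypothetical stand-in for a comparator-parameterized width check:
// Comparator is std::greater for extension and std::less for truncation,
// applied as Comparator(resultWidth, inputWidth).
template <template <typename> class Comparator>
bool widthChangeOk(unsigned inputWidth, unsigned resultWidth) {
  return Comparator<unsigned>()(resultWidth, inputWidth);
}

// scaling_extf: result must be strictly wider than the input.
bool isValidScalingExtF(unsigned inputWidth, unsigned resultWidth) {
  return widthChangeOk<std::greater>(inputWidth, resultWidth);
}

// scaling_truncf: result must be strictly narrower than the input.
bool isValidScalingTruncF(unsigned inputWidth, unsigned resultWidth) {
  return widthChangeOk<std::less>(inputWidth, resultWidth);
}
```

Note that only the input and result widths take part in the comparison. The
scale type is not checked here, which is consistent with the diff passing
`inputs.front()` to the width check.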


@@ -6,10 +6,10 @@
//
//===----------------------------------------------------------------------===//
#include "mlir/Dialect/Arith/Transforms/Passes.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Arith/Transforms/Passes.h"
#include "mlir/Dialect/Vector/IR/VectorOps.h"
#include "mlir/IR/BuiltinTypeInterfaces.h"
#include "mlir/IR/ImplicitLocOpBuilder.h"
#include "mlir/IR/TypeUtilities.h"
#include "mlir/Transforms/DialectConversion.h"
@@ -31,7 +31,6 @@ static Value createConst(Location loc, Type type, int value,
return rewriter.create<arith::ConstantOp>(
loc, DenseElementsAttr::get(shapedTy, attr));
}
return rewriter.create<arith::ConstantOp>(loc, attr);
}
@@ -357,9 +356,10 @@ struct F8E8M0ExtFOpConverter : public OpRewritePattern<arith::ExtFOp> {
f32Bits = b.create<arith::SelectOp>(isNan, cF32NaN, f32Bits);
Value result = b.create<arith::BitcastOp>(f32Ty, f32Bits);
if (resultETy.getIntOrFloatBitWidth() < 32) {
result = b.create<arith::TruncFOp>(resultTy, result);
result = b.create<arith::TruncFOp>(resultTy, result, nullptr,
op.getFastmathAttr());
} else if (resultETy.getIntOrFloatBitWidth() > 32) {
result = b.create<arith::ExtFOp>(resultTy, result);
result = b.create<arith::ExtFOp>(resultTy, result, op.getFastmathAttr());
}
rewriter.replaceOp(op, result);
return success();
@@ -395,9 +395,10 @@ struct F8E8M0TruncFOpConverter : public OpRewritePattern<arith::TruncFOp> {
Type f32Ty = cloneToShapedType(operandTy, b.getF32Type());
if (operandETy.getIntOrFloatBitWidth() < 32) {
operand = b.create<arith::ExtFOp>(f32Ty, operand);
operand = b.create<arith::ExtFOp>(f32Ty, operand, op.getFastmathAttr());
} else if (operandETy.getIntOrFloatBitWidth() > 32) {
operand = b.create<arith::TruncFOp>(f32Ty, operand);
operand = b.create<arith::TruncFOp>(
f32Ty, operand, op.getRoundingmodeAttr(), op.getFastmathAttr());
}
Value f32Bits = b.create<arith::BitcastOp>(i32Ty, operand);
Value cF32MantissaWidth = createConst(op->getLoc(), i32Ty, 23, rewriter);
@@ -409,6 +410,83 @@ struct F8E8M0TruncFOpConverter : public OpRewritePattern<arith::TruncFOp> {
}
};
struct ScalingExtFOpConverter : public OpRewritePattern<arith::ScalingExtFOp> {
using OpRewritePattern::OpRewritePattern;
LogicalResult matchAndRewrite(arith::ScalingExtFOp op,
PatternRewriter &rewriter) const final {
ImplicitLocOpBuilder b(op.getLoc(), rewriter);
Value inputOperand = op.getIn();
Value scaleOperand = op.getScale();
Type scaleTy = scaleOperand.getType();
Type scaleETy = getElementTypeOrSelf(scaleOperand);
// allow implicit exponent extraction from 16/32-bit floats
if (scaleETy.getIntOrFloatBitWidth() >= 16) {
scaleETy = b.getF8E8M0Type();
scaleTy = cloneToShapedType(scaleTy, scaleETy);
scaleOperand = b.create<arith::TruncFOp>(scaleTy, scaleOperand, nullptr,
op.getFastmathAttr());
}
if (!llvm::isa<Float8E8M0FNUType>(scaleETy)) {
return rewriter.notifyMatchFailure(
op, "scaling_extf is using scales of type which can not be converted "
"to f8E8M0FNU");
}
Type resultTy = op.getType();
// extf on scale will essentially create a floating point number
// of type resultTy that is 2^scale and will also propagate NaNs
Value scaleExt =
b.create<arith::ExtFOp>(resultTy, scaleOperand, op.getFastmathAttr());
Value inputExt =
b.create<arith::ExtFOp>(resultTy, inputOperand, op.getFastmathAttr());
Value result =
b.create<arith::MulFOp>(inputExt, scaleExt, op.getFastmathAttr());
rewriter.replaceOp(op, result);
return success();
}
};
/*
Expands arith.ScalingTruncFOp(in, scale) into
scale = arith.truncf(scale) : scaleTy -> f8E8M0FNU
result = arith.truncf(in / (2^scale))
*/
struct ScalingTruncFOpConverter
: public OpRewritePattern<arith::ScalingTruncFOp> {
using OpRewritePattern::OpRewritePattern;
LogicalResult matchAndRewrite(arith::ScalingTruncFOp op,
PatternRewriter &rewriter) const final {
ImplicitLocOpBuilder b(op.getLoc(), rewriter);
Value inputOperand = op.getIn();
Value scaleOperand = op.getScale();
Type scaleTy = scaleOperand.getType();
Type scaleETy = getElementTypeOrSelf(scaleOperand);
// allow implicit exponent extraction from 16/32-bit floats
if (scaleETy.getIntOrFloatBitWidth() >= 16) {
scaleETy = b.getF8E8M0Type();
scaleTy = cloneToShapedType(scaleTy, scaleETy);
scaleOperand = b.create<arith::TruncFOp>(scaleTy, scaleOperand, nullptr,
op.getFastmathAttr());
}
if (!llvm::isa<Float8E8M0FNUType>(scaleETy)) {
return rewriter.notifyMatchFailure(
op, "scaling_truncf is using scales type which can not be converted "
"to f8E8M0FNU");
}
Type resultTy = op.getType();
Type inputTy = inputOperand.getType();
// this will create a floating point number of type
// inputTy that is 2^scale and will also propagate NaNs
scaleOperand =
b.create<arith::ExtFOp>(inputTy, scaleOperand, op.getFastmathAttr());
Value result = b.create<arith::DivFOp>(inputOperand, scaleOperand,
op.getFastmathAttr());
Value resultCast = b.create<arith::TruncFOp>(
resultTy, result, op.getRoundingmodeAttr(), op.getFastmathAttr());
rewriter.replaceOp(op, resultCast);
return success();
}
};
struct ArithExpandOpsPass
: public arith::impl::ArithExpandOpsPassBase<ArithExpandOpsPass> {
using ArithExpandOpsPassBase::ArithExpandOpsPassBase;
@@ -432,7 +510,9 @@ struct ArithExpandOpsPass
arith::MaximumFOp,
arith::MinimumFOp,
arith::MaxNumFOp,
arith::MinNumFOp
arith::MinNumFOp,
arith::ScalingExtFOp,
arith::ScalingTruncFOp
>();
if (includeBf16) {
@@ -492,8 +572,15 @@ void mlir::arith::populateExpandF8E8M0Patterns(RewritePatternSet &patterns) {
patterns.getContext());
}
void mlir::arith::populateExpandScalingExtTruncPatterns(
RewritePatternSet &patterns) {
patterns.add<ScalingExtFOpConverter, ScalingTruncFOpConverter>(
patterns.getContext());
}
void mlir::arith::populateArithExpandOpsPatterns(RewritePatternSet &patterns) {
populateCeilFloorDivExpandOpsPatterns(patterns);
populateExpandScalingExtTruncPatterns(patterns);
// clang-format off
patterns.add<
MaxMinIOpConverter<MaxSIOp, arith::CmpIPredicate::sgt>,
@@ -503,7 +590,7 @@ void mlir::arith::populateArithExpandOpsPatterns(RewritePatternSet &patterns) {
MaximumMinimumFOpConverter<MaximumFOp, arith::CmpFPredicate::UGT>,
MaximumMinimumFOpConverter<MinimumFOp, arith::CmpFPredicate::ULT>,
MaxNumMinNumFOpConverter<MaxNumFOp, arith::CmpFPredicate::UGT>,
MaxNumMinNumFOpConverter<MinNumFOp, arith::CmpFPredicate::ULT>
MaxNumMinNumFOpConverter<MinNumFOp, arith::CmpFPredicate::ULT>
>(patterns.getContext());
// clang-format on
}
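
When the scale arrives as a 16-bit or wider float, both converters above first
truncate it to f8E8M0FNU and then extend it again, which effectively keeps only
the power-of-two part of the scale. The following sketch imitates that round
trip for an `f32` scale with plain bit operations (bitcast, shift right by the
23-bit mantissa width, keep the low 8 bits). The function names are
hypothetical, any rounding of the mantissa is ignored, and the code is an
illustration under those assumptions rather than a transcription of the actual
patterns.

```cpp
#include <cstdint>
#include <cstring>

// Hypothetical helper: keep only the biased exponent of an f32 value,
// which is the f8E8M0 bit pattern under the assumptions stated above.
std::uint8_t truncF32ToE8M0(float value) {
  std::uint32_t bits;
  std::memcpy(&bits, &value, sizeof(bits));     // bitcast f32 -> i32
  return static_cast<std::uint8_t>(bits >> 23); // drops sign and mantissa
}

// Hypothetical helper: rebuild an f32 power of two from the exponent byte.
// 0xFF is treated as NaN (all-ones f32 bit pattern).
float extE8M0ToF32(std::uint8_t exponent) {
  std::uint32_t bits = (exponent == 0xFF)
                           ? 0xFFFFFFFFu
                           : static_cast<std::uint32_t>(exponent) << 23;
  float value;
  std::memcpy(&value, &bits, sizeof(value)); // bitcast i32 -> f32
  return value;
}

// Effective scale seen by the mulf/divf after the truncf + extf pair.
float effectiveScale(float scale) {
  return extE8M0ToF32(truncF32ToE8M0(scale));
}
```

In this model a scale of `6.0` (that is, `1.5 * 2^2`) behaves like `4.0`, which
is one reason the scale is expected to be normalized before it reaches these
ops.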


@@ -34,6 +34,8 @@ Location Builder::getFusedLoc(ArrayRef<Location> locs, Attribute metadata) {
// Types.
//===----------------------------------------------------------------------===//
FloatType Builder::getF8E8M0Type() { return Float8E8M0FNUType::get(context); }
FloatType Builder::getBF16Type() { return BFloat16Type::get(context); }
FloatType Builder::getF16Type() { return Float16Type::get(context); }


@@ -1,4 +1,5 @@
// RUN: mlir-opt %s -arith-expand="include-bf16=true include-f8e8m0=true" -split-input-file | FileCheck %s
// RUN: mlir-opt %s -arith-expand="include-bf16=true include-f8e8m0=true" -verify-diagnostics -split-input-file | FileCheck %s
// RUN: mlir-opt %s -arith-expand -split-input-file -verify-diagnostics | FileCheck %s --check-prefix=SCHECK
// Test ceil divide with signed integer
// CHECK-LABEL: func @ceildivi
@@ -253,7 +254,7 @@ func.func @truncf_f32_to_f8E8M0FNU(%arg0 : f32) -> f8E8M0FNU {
%0 = arith.truncf %arg0 : f32 to f8E8M0FNU
return %0 : f8E8M0FNU
}
// CHECK-LABLE: @truncf_f32_to_f8E8M0FNU
// CHECK-LABEL: @truncf_f32_to_f8E8M0FNU
// CHECK: %[[BITCAST:.+]] = arith.bitcast %arg0 : f32 to i32
// CHECK: %[[C23_i32:.+]] = arith.constant 23 : i32
// CHECK: %[[SHRUI:.+]] = arith.shrui %[[BITCAST]], %[[C23_i32]] : i32
@@ -267,7 +268,7 @@ func.func @truncf_f16_to_f8E8M0FNU(%arg0 : f16) -> f8E8M0FNU {
%0 = arith.truncf %arg0 : f16 to f8E8M0FNU
return %0 : f8E8M0FNU
}
// CHECK-LABLE: @truncf_f16_to_f8E8M0FNU
// CHECK-LABEL: @truncf_f16_to_f8E8M0FNU
// CHECK: %[[EXTF:.+]] = arith.extf %arg0 : f16 to f32
// CHECK: %[[BITCAST:.+]] = arith.bitcast %[[EXTF]] : f32 to i32
// CHECK: %[[C23_i32:.+]] = arith.constant 23 : i32
@@ -305,9 +306,76 @@ func.func @truncf_vector_bf16_to_f8E8M0FNU(%arg0 : vector<4xbf16>) -> vector<4xf
// CHECK-LABEL: @truncf_vector_bf16_to_f8E8M0FNU
// CHECK-NOT: arith.truncf
// CHECK: return
// -----
func.func @scaling_truncf_f32_to_f4E2M1FN(%arg0 : f32, %arg1: f8E8M0FNU) -> f4E2M1FN {
%0 = arith.scaling_truncf %arg0, %arg1 : f32, f8E8M0FNU to f4E2M1FN
return %0 : f4E2M1FN
}
// SCHECK-LABEL: @scaling_truncf_f32_to_f4E2M1FN
// SCHECK: %[[SCALEF32:.+]] = arith.extf %arg1 : f8E8M0FNU to f32
// SCHECK: %[[DIVF:.+]] = arith.divf %arg0, %[[SCALEF32]] : f32
// SCHECK: %[[RESULT:.+]] = arith.truncf %[[DIVF]] : f32 to f4E2M1FN
// SCHECK: return %[[RESULT]]
// -----
func.func @scaling_truncf_vector_f16_to_f6E3M2FN(%arg0 : vector<4xf16>, %arg1: vector<4xf8E8M0FNU>) -> vector<4xf6E3M2FN> {
%0 = arith.scaling_truncf %arg0, %arg1 : vector<4xf16>, vector<4xf8E8M0FNU> to vector<4xf6E3M2FN>
return %0 : vector<4xf6E3M2FN>
}
// SCHECK-LABEL: @scaling_truncf_vector_f16_to_f6E3M2FN
// SCHECK: %[[SCALEF16:.+]] = arith.extf %arg1 : vector<4xf8E8M0FNU> to vector<4xf16>
// SCHECK: %[[DIVF:.+]] = arith.divf %arg0, %[[SCALEF16]] : vector<4xf16>
// SCHECK: %[[RESULT:.+]] = arith.truncf %[[DIVF]] : vector<4xf16> to vector<4xf6E3M2FN>
// SCHECK: return %[[RESULT]] : vector<4xf6E3M2FN>
// -----
func.func @scaling_truncf_propagate_rounding_mode_fast_math(%arg0 : vector<4xf16>, %arg1: vector<4xf16>) -> vector<4xf6E3M2FN> {
%0 = arith.scaling_truncf %arg0, %arg1 to_nearest_even fastmath<fast> : vector<4xf16>, vector<4xf16> to vector<4xf6E3M2FN>
return %0 : vector<4xf6E3M2FN>
}
// SCHECK-LABEL: @scaling_truncf_propagate_rounding_mode_fast_math
// SCHECK: %[[SCALEF8:.+]] = arith.truncf %arg1 fastmath<fast> : vector<4xf16> to vector<4xf8E8M0FNU>
// SCHECK: %[[SCALEINTY:.+]] = arith.extf %[[SCALEF8]] fastmath<fast> : vector<4xf8E8M0FNU> to vector<4xf16>
// SCHECK: %[[DIVF:.+]] = arith.divf %arg0, %[[SCALEINTY]] fastmath<fast> : vector<4xf16>
// SCHECK: %[[TRUNCF:.+]] = arith.truncf [[_:%[a-zA-Z0-9_]+]] to_nearest_even fastmath<fast> : vector<4xf16> to vector<4xf6E3M2FN>
// SCHECK: return %[[TRUNCF]] : vector<4xf6E3M2FN>
// -----
func.func @scaling_truncf_f16_to_f4E2M1FN_using_f16_scales(%arg0: f16, %arg1 : f16) -> f4E2M1FN {
%0 = arith.scaling_truncf %arg0, %arg1 : f16, f16 to f4E2M1FN
return %0 : f4E2M1FN
}
// SCHECK-LABEL: @scaling_truncf_f16_to_f4E2M1FN_using_f16_scales
// SCHECK: %[[SCALETRUNCF:.+]] = arith.truncf %arg1 : f16 to f8E8M0FN
// SCHECK: return
// -----
func.func @scaling_truncf_vector_f16_to_f4E2M1FN_using_f16_scales(%arg0: vector<4xf16>, %arg1 : vector<4xf16>) -> vector<4xf4E2M1FN> {
%0 = arith.scaling_truncf %arg0, %arg1 : vector<4xf16>, vector<4xf16> to vector<4xf4E2M1FN>
return %0 : vector<4xf4E2M1FN>
}
// SCHECK-LABEL: @scaling_truncf_vector_f16_to_f4E2M1FN_using_f16_scales
// SCHECK: %[[SCALETRUNCF:.+]] = arith.truncf %arg1 : vector<4xf16> to vector<4xf8E8M0FNU>
// SCHECK: return
// -----
func.func @invalid_scaling_truncf_to_f4E2M1FN(%arg0: f16, %arg1 : f8E5M2FNUZ) -> f4E2M1FN {
// expected-error@+1 {{failed to legalize operation 'arith.scaling_truncf' that was explicitly marked illegal}}
%0 = arith.scaling_truncf %arg0, %arg1 : f16, f8E5M2FNUZ to f4E2M1FN
return %0 : f4E2M1FN
}
// -----
func.func @extf_f8E8M0FNU_to_f32(%arg0 : f8E8M0FNU) -> f32 {
%0 = arith.extf %arg0 : f8E8M0FNU to f32
return %0 : f32
@@ -332,7 +400,7 @@ func.func @extf_f8E8M0FNU_to_f16(%arg0 : f8E8M0FNU) -> f16 {
return %0 : f16
}
// CHECK-LABLE: @extf_f8E8M0FNU_to_f16
// CHECK-LABEL: @extf_f8E8M0FNU_to_f16
// CHECK: %[[BITCAST:.+]] = arith.bitcast %arg0 : f8E8M0FNU to i8
// CHECK-DAG: %[[CF8NAN:.+]] = arith.constant -1 : i8
// CHECK-DAG: %[[CF32NAN:.+]] = arith.constant -1 : i32
@@ -374,7 +442,109 @@ func.func @extf_vector_f8E8M0FNU_to_bf16(%arg0 : vector<4xf8E8M0FNU>) -> vector<
// CHECK-LABEL: @extf_vector_f8E8M0FNU_to_bf16
// CHECK-NOT: arith.extf
// CHECK: return
// -----
func.func @scaling_extf_to_f32(%arg0: f4E2M1FN, %arg1 : f8E8M0FNU) -> f32 {
%0 = arith.scaling_extf %arg0, %arg1 : f4E2M1FN, f8E8M0FNU to f32
return %0 : f32
}
// SCHECK-LABEL: @scaling_extf_to_f32
// SCHECK: %[[EXT_SCALE:.+]] = arith.extf %arg1 : f8E8M0FNU to f32
// SCHECK: %[[EXT_INPUT:.+]] = arith.extf %arg0 : f4E2M1FN to f32
// SCHECK: %[[RESULT:.+]] = arith.mulf %[[EXT_INPUT]], %[[EXT_SCALE]] : f32
// SCHECK: return %[[RESULT]]
// -----
func.func @scaling_extf_to_f32_using_f16_scales(%arg0: f4E2M1FN, %arg1 : f16) -> f32 {
%0 = arith.scaling_extf %arg0, %arg1 : f4E2M1FN, f16 to f32
return %0 : f32
}
// SCHECK-LABEL: @scaling_extf_to_f32_using_f16_scales
// SCHECK: %[[TRUNCF_SCALE:.+]] = arith.truncf %arg1 : f16 to f8E8M0FNU
// SCHECK: %[[EXT_SCALE:.+]] = arith.extf %[[TRUNCF_SCALE]] : f8E8M0FNU to f32
// SCHECK: %[[EXT_INPUT:.+]] = arith.extf %arg0 : f4E2M1FN to f32
// SCHECK: %[[RESULT:.+]] = arith.mulf %[[EXT_INPUT]], %[[EXT_SCALE]] : f32
// SCHECK: return %[[RESULT]]
// -----
func.func @invalid_scaling_extf_to_f32(%arg0: f4E2M1FN, %arg1 : f8E5M2FNUZ) -> f32 {
// expected-error@+1 {{failed to legalize operation 'arith.scaling_extf' that was explicitly marked illegal}}
%0 = arith.scaling_extf %arg0, %arg1 : f4E2M1FN, f8E5M2FNUZ to f32
return %0 : f32
}
// -----
func.func @scaling_extf_vector_to_f32(%arg0: vector<4xf4E2M1FN>, %arg1 : vector<4xf8E8M0FNU>) -> vector<4xf32> {
%0 = arith.scaling_extf %arg0, %arg1 : vector<4xf4E2M1FN>, vector<4xf8E8M0FNU> to vector<4xf32>
return %0 : vector<4xf32>
}
// SCHECK-LABEL: @scaling_extf_vector_to_f32
// SCHECK: %[[EXT_SCALE:.+]] = arith.extf %arg1 : vector<4xf8E8M0FNU> to vector<4xf32>
// SCHECK: %[[EXT_INPUT:.+]] = arith.extf %arg0 : vector<4xf4E2M1FN> to vector<4xf32>
// SCHECK: %[[RESULT:.+]] = arith.mulf %[[EXT_INPUT]], %[[EXT_SCALE]] : vector<4xf32>
// SCHECK: return %[[RESULT]]
// -----
func.func @scaling_extf_vector_to_f16(%arg0: vector<4xf4E2M1FN>, %arg1 : vector<4xf8E8M0FNU>) -> vector<4xf16> {
%0 = arith.scaling_extf %arg0, %arg1 : vector<4xf4E2M1FN>, vector<4xf8E8M0FNU> to vector<4xf16>
return %0 : vector<4xf16>
}
// SCHECK-LABEL: @scaling_extf_vector_to_f16
// SCHECK: %[[EXT_SCALE:.+]] = arith.extf %arg1 : vector<4xf8E8M0FNU> to vector<4xf16>
// SCHECK: %[[EXT_INPUT:.+]] = arith.extf %arg0 : vector<4xf4E2M1FN> to vector<4xf16>
// SCHECK: %[[RESULT:.+]] = arith.mulf %[[EXT_INPUT]], %[[EXT_SCALE]] : vector<4xf16>
// SCHECK: return %[[RESULT]]
// -----
func.func @scaling_extf_vector_to_bf16(%arg0: vector<4xf4E2M1FN>, %arg1 : vector<4xf8E8M0FNU>) -> vector<4xbf16> {
%0 = arith.scaling_extf %arg0, %arg1 : vector<4xf4E2M1FN>, vector<4xf8E8M0FNU> to vector<4xbf16>
return %0 : vector<4xbf16>
}
// SCHECK-LABEL: @scaling_extf_vector_to_bf16
// SCHECK: %[[EXT_SCALE:.+]] = arith.extf %arg1 : vector<4xf8E8M0FNU> to vector<4xbf16>
// SCHECK: %[[EXT_INPUT:.+]] = arith.extf %arg0 : vector<4xf4E2M1FN> to vector<4xbf16>
// SCHECK: %[[RESULT:.+]] = arith.mulf %[[EXT_INPUT]], %[[EXT_SCALE]] : vector<4xbf16>
// SCHECK: return %[[RESULT]]
// -----
func.func @scaling_extf_vector_to_f32_using_f16_scales(%arg0: vector<4xf4E2M1FN>, %arg1 : vector<4xf16>) -> vector<4xf32> {
%0 = arith.scaling_extf %arg0, %arg1 : vector<4xf4E2M1FN>, vector<4xf16> to vector<4xf32>
return %0 : vector<4xf32>
}
// SCHECK-LABEL: @scaling_extf_vector_to_f32_using_f16_scales
// SCHECK: %[[TRUNCF_SCALE:.+]] = arith.truncf %arg1 : vector<4xf16> to vector<4xf8E8M0FNU>
// SCHECK: %[[EXT_SCALE:.+]] = arith.extf %[[TRUNCF_SCALE]] : vector<4xf8E8M0FNU> to vector<4xf32>
// SCHECK: %[[EXT_INPUT:.+]] = arith.extf %arg0 : vector<4xf4E2M1FN> to vector<4xf32>
// SCHECK: %[[RESULT:.+]] = arith.mulf %[[EXT_INPUT]], %[[EXT_SCALE]] : vector<4xf32>
// SCHECK: return %[[RESULT]]
// -----
func.func @scaling_extf_vector_to_f32_using_f16_scales_fastmath(%arg0: vector<4xf4E2M1FN>, %arg1 : vector<4xf16>) -> vector<4xf32> {
%0 = arith.scaling_extf %arg0, %arg1 fastmath<fast> : vector<4xf4E2M1FN>, vector<4xf16> to vector<4xf32>
return %0 : vector<4xf32>
}
// SCHECK-LABEL: @scaling_extf_vector_to_f32_using_f16_scales_fastmath
// SCHECK: %[[TRUNCF_SCALE:.+]] = arith.truncf %arg1 fastmath<fast> : vector<4xf16> to vector<4xf8E8M0FNU>
// SCHECK: %[[EXT_SCALE:.+]] = arith.extf %[[TRUNCF_SCALE]] fastmath<fast> : vector<4xf8E8M0FNU> to vector<4xf32>
// SCHECK: %[[EXT_INPUT:.+]] = arith.extf %arg0 fastmath<fast> : vector<4xf4E2M1FN> to vector<4xf32>
// SCHECK: %[[RESULT:.+]] = arith.mulf %[[EXT_INPUT]], %[[EXT_SCALE]] fastmath<fast> : vector<4xf32>
// SCHECK: return %[[RESULT]]
// -----
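
To make the `scaling_extf` tests above more concrete, the following sketch
evaluates the op numerically for a single `f4E2M1FN` element. It assumes the
OCP MX encoding of FP4 E2M1 (1 sign bit, 2 exponent bits, 1 mantissa bit, with
representable magnitudes 0, 0.5, 1, 1.5, 2, 3, 4, 6) and the `2^(e - 127)`
reading of an `f8E8M0FNU` scale byte, and it ignores the NaN scale encoding.
All names are hypothetical, and the lookup table is a shortcut, not how the
lowering extends the input.

```cpp
#include <array>
#include <cmath>
#include <cstdint>

// Hypothetical helper: decode a 4-bit E2M1 pattern (low nibble of `bits`).
float decodeF4E2M1(std::uint8_t bits) {
  static constexpr std::array<float, 8> magnitudes = {
      0.0f, 0.5f, 1.0f, 1.5f, 2.0f, 3.0f, 4.0f, 6.0f};
  float magnitude = magnitudes[bits & 0x7];     // exponent and mantissa bits
  return (bits & 0x8) ? -magnitude : magnitude; // bit 3 is the sign
}

// Numeric model of
//   arith.scaling_extf %in, %scale : f4E2M1FN, f8E8M0FNU to f32
// for non-NaN scales: extend the input, then multiply by 2^(scale - 127).
float scalingExtF4ToF32(std::uint8_t inBits, std::uint8_t scaleBits) {
  return std::ldexp(decodeF4E2M1(inBits), static_cast<int>(scaleBits) - 127);
}
```

For example, the input pattern `0b0011` (1.5) with scale byte 129 (a factor of
`2^2`) yields `6.0` in this model.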