Introduce arith.scaling_extf and arith.scaling_truncf (#141965)
This PR adds `arith.scaling_truncf` and `arith.scaling_extf` operations
which supports the block quantization following OCP MXFP specs listed
here
https://www.opencompute.org/documents/ocp-microscaling-formats-mx-v1-0-spec-final-pdf
OCP MXFP Spec comes with reference implementation here
https://github.com/microsoft/microxcaling/tree/main
Interesting piece of reference code is this method `_quantize_mx`
7bc41952de/mx/mx_ops.py (L173).
Both `arith.scaling_truncf` and `arith.scaling_extf` are designed to be
an elementwise operation. Please see description about them in
`ArithOps.td` file for more details.
Internally,
`arith.scaling_truncf` does the
`arith.truncf(arith.divf(input/(2^scale)))`. `scale` should have
necessary broadcast, clamping, normalization and NaN propagation done
before callling into `arith.scaling_truncf`.
`arith.scaling_extf` does the `arith.mulf(2^scale, input)` after taking
care of necessary data type conversions.
CC: @krzysz00 @dhernandez0 @bjacob @pashu123 @MaheshRavishankar
@tgymnich
---------
Co-authored-by: Prashant Kumar <pk5561@gmail.com>
Co-authored-by: Krzysztof Drewniak <Krzysztof.Drewniak@amd.com>
This commit is contained in:
@@ -1215,6 +1215,58 @@ def Arith_ExtFOp : Arith_FToFCastOp<"extf", [DeclareOpInterfaceMethods<ArithFast
|
||||
attr-dict `:` type($in) `to` type($out) }];
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// Scaling ExtFOp
|
||||
//===----------------------------------------------------------------------===//
|
||||
def Arith_ScalingExtFOp
|
||||
: Arith_Op<
|
||||
"scaling_extf", [Pure, SameInputOutputTensorDims,
|
||||
DeclareOpInterfaceMethods<ArithFastMathInterface>,
|
||||
DeclareOpInterfaceMethods<CastOpInterface>]>,
|
||||
Arguments<(ins FloatLike:$in, FloatLike:$scale,
|
||||
OptionalAttr<Arith_FastMathAttr>:$fastmath)>,
|
||||
Results<(outs FloatLike:$out)> {
|
||||
let summary = "Upcasts input floats using provided scales values following "
|
||||
"OCP MXFP Spec";
|
||||
let description = [{
|
||||
This operation upcasts input floating-point values using provided scale
|
||||
values. It expects both scales and the input operand to be of the same shape,
|
||||
making the operation elementwise. Scales are usually calculated per block
|
||||
following the OCP MXFP spec as described in https://arxiv.org/abs/2310.10537.
|
||||
|
||||
If scales are calculated per block where blockSize != 1, then scales may
|
||||
require broadcasting to make this operation elementwise. For example, let's
|
||||
say the input is of shape `<dim1 x dim2 x ... dimN>`. Given blockSize != 1 and
|
||||
assuming quantization happens on the last axis, the input can be reshaped to
|
||||
`<dim1 x dim2 x ... (dimN/blockSize) x blockSize>`. Scales will be calculated
|
||||
per block on the last axis. Therefore, scales will be of shape
|
||||
`<dim1 x dim2 x ... (dimN/blockSize) x 1>`. Scales could also be of some other
|
||||
shape as long as it is broadcast compatible with the input, e.g.,
|
||||
`<1 x 1 x ... (dimN/blockSize) x 1>`.
|
||||
|
||||
In this example, before calling into `arith.scaling_extf`, scales must be
|
||||
broadcasted to `<dim1 x dim2 x dim3 ... (dimN/blockSize) x blockSize>`. Note
|
||||
that there could be multiple quantization axes. Internally,
|
||||
`arith.scaling_extf` would perform the following:
|
||||
|
||||
```
|
||||
resultTy = get_type(result)
|
||||
scaleTy = get_type(scale)
|
||||
inputTy = get_type(input)
|
||||
scale.exponent = arith.truncf(scale) : scaleTy to f8E8M0
|
||||
scale.extf = arith.extf(scale.exponent) : f8E8M0 to resultTy
|
||||
input.extf = arith.extf(input) : inputTy to resultTy
|
||||
result = arith.mulf(scale.extf, input.extf)
|
||||
```
|
||||
It propagates NaN values. Therefore, if either scale or the input element
|
||||
contains NaN, then the output element value will also be a NaN.
|
||||
}];
|
||||
let hasVerifier = 1;
|
||||
let assemblyFormat =
|
||||
[{ $in `,` $scale (`fastmath` `` $fastmath^)? attr-dict `:`
|
||||
type($in) `,` type($scale) `to` type($out)}];
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// TruncIOp
|
||||
//===----------------------------------------------------------------------===//
|
||||
@@ -1280,6 +1332,63 @@ def Arith_TruncFOp :
|
||||
attr-dict `:` type($in) `to` type($out) }];
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// Scaling TruncFOp
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
def Arith_ScalingTruncFOp
|
||||
: Arith_Op<"scaling_truncf",
|
||||
[Pure, SameInputOutputTensorDims,
|
||||
DeclareOpInterfaceMethods<ArithRoundingModeInterface>,
|
||||
DeclareOpInterfaceMethods<ArithFastMathInterface>,
|
||||
DeclareOpInterfaceMethods<CastOpInterface>]>,
|
||||
Arguments<(ins FloatLike:$in, FloatLike:$scale,
|
||||
OptionalAttr<Arith_RoundingModeAttr>:$roundingmode,
|
||||
OptionalAttr<Arith_FastMathAttr>:$fastmath)>,
|
||||
Results<(outs FloatLike:$out)> {
|
||||
let summary = "Downcasts input floating point values using provided scales "
|
||||
"values following OCP MXFP Spec";
|
||||
let description = [{
|
||||
This operation downcasts input using the provided scale values. It expects
|
||||
both scales and the input operand to be of the same shape and, therefore,
|
||||
makes the operation elementwise. Scales are usually calculated per block
|
||||
following the OCP MXFP spec as described in https://arxiv.org/abs/2310.10537.
|
||||
Users are required to normalize and clamp the scales as necessary before calling
|
||||
passing them to this operation. OCP MXFP spec also does the flushing of denorms
|
||||
on the input operand, which should be handled during lowering by passing appropriate
|
||||
fastMath flag to this operation.
|
||||
|
||||
If scales are calculated per block where blockSize != 1, scales may require
|
||||
broadcasting to make this operation elementwise. For example, let's say the
|
||||
input is of shape `<dim1 x dim2 x ... dimN>`. Given blockSize != 1 and
|
||||
assuming quantization happens on the last axis, the input can be reshaped to
|
||||
`<dim1 x dim2 x ... (dimN/blockSize) x blockSize>`. Scales will be calculated
|
||||
per block on the last axis. Therefore, scales will be of shape
|
||||
`<dim1 x dim2 x ... (dimN/blockSize) x 1>`. Scales could also be of some other
|
||||
shape as long as it is broadcast compatible with the input, e.g.,
|
||||
`<1 x 1 x ... (dimN/blockSize) x 1>`.
|
||||
|
||||
In this example, before calling into `arith.scaling_truncf`, scales must be
|
||||
broadcasted to `<dim1 x dim2 x dim3 ... (dimN/blockSize) x blockSize>`. Note
|
||||
that there could be multiple quantization axes. Internally,
|
||||
`arith.scaling_truncf` would perform the following:
|
||||
|
||||
```
|
||||
scaleTy = get_type(scale)
|
||||
inputTy = get_type(input)
|
||||
resultTy = get_type(result)
|
||||
scale.exponent = arith.truncf(scale) : scaleTy to f8E8M0
|
||||
scale.extf = arith.extf(scale.exponent) : f8E8M0 to inputTy
|
||||
result = arith.divf(input, scale.extf)
|
||||
result.cast = arith.truncf(result, resultTy)
|
||||
```
|
||||
}];
|
||||
let hasVerifier = 1;
|
||||
let assemblyFormat =
|
||||
[{ $in `,` $scale ($roundingmode^)? (`fastmath` `` $fastmath^)? attr-dict `:`
|
||||
type($in) `,` type($scale) `to` type($out)}];
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// UIToFPOp
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
@@ -62,6 +62,9 @@ void populateExpandBFloat16Patterns(RewritePatternSet &patterns);
|
||||
/// Add patterns to expand Arith f8e8m0 patterns to lower level bitcasts/shifts.
|
||||
void populateExpandF8E8M0Patterns(RewritePatternSet &patterns);
|
||||
|
||||
/// Add patterns to expand scaling ExtF/TruncF ops to equivalent arith ops
|
||||
void populateExpandScalingExtTruncPatterns(RewritePatternSet &patterns);
|
||||
|
||||
/// Add patterns to expand Arith ops.
|
||||
void populateArithExpandOpsPatterns(RewritePatternSet &patterns);
|
||||
|
||||
|
||||
@@ -60,6 +60,7 @@ public:
|
||||
Attribute metadata = Attribute());
|
||||
|
||||
// Types.
|
||||
FloatType getF8E8M0Type();
|
||||
FloatType getBF16Type();
|
||||
FloatType getF16Type();
|
||||
FloatType getTF32Type();
|
||||
|
||||
@@ -1451,6 +1451,19 @@ bool arith::ExtFOp::areCastCompatible(TypeRange inputs, TypeRange outputs) {
|
||||
|
||||
LogicalResult arith::ExtFOp::verify() { return verifyExtOp<FloatType>(*this); }
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// ScalingExtFOp
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
bool arith::ScalingExtFOp::areCastCompatible(TypeRange inputs,
|
||||
TypeRange outputs) {
|
||||
return checkWidthChangeCast<std::greater, FloatType>(inputs.front(), outputs);
|
||||
}
|
||||
|
||||
LogicalResult arith::ScalingExtFOp::verify() {
|
||||
return verifyExtOp<FloatType>(*this);
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// TruncIOp
|
||||
//===----------------------------------------------------------------------===//
|
||||
@@ -1565,6 +1578,19 @@ LogicalResult arith::TruncFOp::verify() {
|
||||
return verifyTruncateOp<FloatType>(*this);
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// ScalingTruncFOp
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
bool arith::ScalingTruncFOp::areCastCompatible(TypeRange inputs,
|
||||
TypeRange outputs) {
|
||||
return checkWidthChangeCast<std::less, FloatType>(inputs.front(), outputs);
|
||||
}
|
||||
|
||||
LogicalResult arith::ScalingTruncFOp::verify() {
|
||||
return verifyTruncateOp<FloatType>(*this);
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// AndIOp
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
@@ -6,10 +6,10 @@
|
||||
//
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
#include "mlir/Dialect/Arith/Transforms/Passes.h"
|
||||
|
||||
#include "mlir/Dialect/Arith/IR/Arith.h"
|
||||
#include "mlir/Dialect/Arith/Transforms/Passes.h"
|
||||
#include "mlir/Dialect/Vector/IR/VectorOps.h"
|
||||
#include "mlir/IR/BuiltinTypeInterfaces.h"
|
||||
#include "mlir/IR/ImplicitLocOpBuilder.h"
|
||||
#include "mlir/IR/TypeUtilities.h"
|
||||
#include "mlir/Transforms/DialectConversion.h"
|
||||
@@ -31,7 +31,6 @@ static Value createConst(Location loc, Type type, int value,
|
||||
return rewriter.create<arith::ConstantOp>(
|
||||
loc, DenseElementsAttr::get(shapedTy, attr));
|
||||
}
|
||||
|
||||
return rewriter.create<arith::ConstantOp>(loc, attr);
|
||||
}
|
||||
|
||||
@@ -357,9 +356,10 @@ struct F8E8M0ExtFOpConverter : public OpRewritePattern<arith::ExtFOp> {
|
||||
f32Bits = b.create<arith::SelectOp>(isNan, cF32NaN, f32Bits);
|
||||
Value result = b.create<arith::BitcastOp>(f32Ty, f32Bits);
|
||||
if (resultETy.getIntOrFloatBitWidth() < 32) {
|
||||
result = b.create<arith::TruncFOp>(resultTy, result);
|
||||
result = b.create<arith::TruncFOp>(resultTy, result, nullptr,
|
||||
op.getFastmathAttr());
|
||||
} else if (resultETy.getIntOrFloatBitWidth() > 32) {
|
||||
result = b.create<arith::ExtFOp>(resultTy, result);
|
||||
result = b.create<arith::ExtFOp>(resultTy, result, op.getFastmathAttr());
|
||||
}
|
||||
rewriter.replaceOp(op, result);
|
||||
return success();
|
||||
@@ -395,9 +395,10 @@ struct F8E8M0TruncFOpConverter : public OpRewritePattern<arith::TruncFOp> {
|
||||
Type f32Ty = cloneToShapedType(operandTy, b.getF32Type());
|
||||
|
||||
if (operandETy.getIntOrFloatBitWidth() < 32) {
|
||||
operand = b.create<arith::ExtFOp>(f32Ty, operand);
|
||||
operand = b.create<arith::ExtFOp>(f32Ty, operand, op.getFastmathAttr());
|
||||
} else if (operandETy.getIntOrFloatBitWidth() > 32) {
|
||||
operand = b.create<arith::TruncFOp>(f32Ty, operand);
|
||||
operand = b.create<arith::TruncFOp>(
|
||||
f32Ty, operand, op.getRoundingmodeAttr(), op.getFastmathAttr());
|
||||
}
|
||||
Value f32Bits = b.create<arith::BitcastOp>(i32Ty, operand);
|
||||
Value cF32MantissaWidth = createConst(op->getLoc(), i32Ty, 23, rewriter);
|
||||
@@ -409,6 +410,83 @@ struct F8E8M0TruncFOpConverter : public OpRewritePattern<arith::TruncFOp> {
|
||||
}
|
||||
};
|
||||
|
||||
struct ScalingExtFOpConverter : public OpRewritePattern<arith::ScalingExtFOp> {
|
||||
using OpRewritePattern::OpRewritePattern;
|
||||
LogicalResult matchAndRewrite(arith::ScalingExtFOp op,
|
||||
PatternRewriter &rewriter) const final {
|
||||
ImplicitLocOpBuilder b(op.getLoc(), rewriter);
|
||||
Value inputOperand = op.getIn();
|
||||
Value scaleOperand = op.getScale();
|
||||
Type scaleTy = scaleOperand.getType();
|
||||
Type scaleETy = getElementTypeOrSelf(scaleOperand);
|
||||
// allow implicit exponent extraction from 16/32 bits floats
|
||||
if (scaleETy.getIntOrFloatBitWidth() >= 16) {
|
||||
scaleETy = b.getF8E8M0Type();
|
||||
scaleTy = cloneToShapedType(scaleTy, scaleETy);
|
||||
scaleOperand = b.create<arith::TruncFOp>(scaleTy, scaleOperand, nullptr,
|
||||
op.getFastmathAttr());
|
||||
}
|
||||
if (!llvm::isa<Float8E8M0FNUType>(scaleETy)) {
|
||||
return rewriter.notifyMatchFailure(
|
||||
op, "scaling_extf is using scales of type which can not be converted "
|
||||
"to f8E8M0FNU");
|
||||
}
|
||||
Type resultTy = op.getType();
|
||||
// extf on scale will essentially create floating point number
|
||||
// of type resulTy that is 2^scale and will also propagate NaNs
|
||||
Value scaleExt =
|
||||
b.create<arith::ExtFOp>(resultTy, scaleOperand, op.getFastmathAttr());
|
||||
Value inputExt =
|
||||
b.create<arith::ExtFOp>(resultTy, inputOperand, op.getFastmathAttr());
|
||||
Value result =
|
||||
b.create<arith::MulFOp>(inputExt, scaleExt, op.getFastmathAttr());
|
||||
rewriter.replaceOp(op, result);
|
||||
return success();
|
||||
}
|
||||
};
|
||||
|
||||
/*
|
||||
Expands arith.ScalingTruncFOp(in, scale) into
|
||||
scale = arith.truncf(scale) : scaleTy -> f8E8M0FNU
|
||||
result = arith.truncf(in / (2^scale))
|
||||
*/
|
||||
struct ScalingTruncFOpConverter
|
||||
: public OpRewritePattern<arith::ScalingTruncFOp> {
|
||||
using OpRewritePattern::OpRewritePattern;
|
||||
LogicalResult matchAndRewrite(arith::ScalingTruncFOp op,
|
||||
PatternRewriter &rewriter) const final {
|
||||
ImplicitLocOpBuilder b(op.getLoc(), rewriter);
|
||||
Value inputOperand = op.getIn();
|
||||
Value scaleOperand = op.getScale();
|
||||
Type scaleTy = scaleOperand.getType();
|
||||
Type scaleETy = getElementTypeOrSelf(scaleOperand);
|
||||
// allow implicit exponent extraction from 16/32 bits floats
|
||||
if (scaleETy.getIntOrFloatBitWidth() >= 16) {
|
||||
scaleETy = b.getF8E8M0Type();
|
||||
scaleTy = cloneToShapedType(scaleTy, scaleETy);
|
||||
scaleOperand = b.create<arith::TruncFOp>(scaleTy, scaleOperand, nullptr,
|
||||
op.getFastmathAttr());
|
||||
}
|
||||
if (!llvm::isa<Float8E8M0FNUType>(scaleETy)) {
|
||||
return rewriter.notifyMatchFailure(
|
||||
op, "scaling_truncf is using scales type which can not be converted "
|
||||
"to f8E8M0FNU");
|
||||
}
|
||||
Type resultTy = op.getType();
|
||||
Type inputTy = inputOperand.getType();
|
||||
// this will create a floating point number of type
|
||||
// inputTy that is 2^scale and will also propagate NaNs
|
||||
scaleOperand =
|
||||
b.create<arith::ExtFOp>(inputTy, scaleOperand, op.getFastmathAttr());
|
||||
Value result = b.create<arith::DivFOp>(inputOperand, scaleOperand,
|
||||
op.getFastmathAttr());
|
||||
Value resultCast = b.create<arith::TruncFOp>(
|
||||
resultTy, result, op.getRoundingmodeAttr(), op.getFastmathAttr());
|
||||
rewriter.replaceOp(op, resultCast);
|
||||
return success();
|
||||
}
|
||||
};
|
||||
|
||||
struct ArithExpandOpsPass
|
||||
: public arith::impl::ArithExpandOpsPassBase<ArithExpandOpsPass> {
|
||||
using ArithExpandOpsPassBase::ArithExpandOpsPassBase;
|
||||
@@ -432,7 +510,9 @@ struct ArithExpandOpsPass
|
||||
arith::MaximumFOp,
|
||||
arith::MinimumFOp,
|
||||
arith::MaxNumFOp,
|
||||
arith::MinNumFOp
|
||||
arith::MinNumFOp,
|
||||
arith::ScalingExtFOp,
|
||||
arith::ScalingTruncFOp
|
||||
>();
|
||||
|
||||
if (includeBf16) {
|
||||
@@ -492,8 +572,15 @@ void mlir::arith::populateExpandF8E8M0Patterns(RewritePatternSet &patterns) {
|
||||
patterns.getContext());
|
||||
}
|
||||
|
||||
void mlir::arith::populateExpandScalingExtTruncPatterns(
|
||||
RewritePatternSet &patterns) {
|
||||
patterns.add<ScalingExtFOpConverter, ScalingTruncFOpConverter>(
|
||||
patterns.getContext());
|
||||
}
|
||||
|
||||
void mlir::arith::populateArithExpandOpsPatterns(RewritePatternSet &patterns) {
|
||||
populateCeilFloorDivExpandOpsPatterns(patterns);
|
||||
populateExpandScalingExtTruncPatterns(patterns);
|
||||
// clang-format off
|
||||
patterns.add<
|
||||
MaxMinIOpConverter<MaxSIOp, arith::CmpIPredicate::sgt>,
|
||||
@@ -503,7 +590,7 @@ void mlir::arith::populateArithExpandOpsPatterns(RewritePatternSet &patterns) {
|
||||
MaximumMinimumFOpConverter<MaximumFOp, arith::CmpFPredicate::UGT>,
|
||||
MaximumMinimumFOpConverter<MinimumFOp, arith::CmpFPredicate::ULT>,
|
||||
MaxNumMinNumFOpConverter<MaxNumFOp, arith::CmpFPredicate::UGT>,
|
||||
MaxNumMinNumFOpConverter<MinNumFOp, arith::CmpFPredicate::ULT>
|
||||
MaxNumMinNumFOpConverter<MinNumFOp, arith::CmpFPredicate::ULT>
|
||||
>(patterns.getContext());
|
||||
// clang-format on
|
||||
}
|
||||
|
||||
@@ -34,6 +34,8 @@ Location Builder::getFusedLoc(ArrayRef<Location> locs, Attribute metadata) {
|
||||
// Types.
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
FloatType Builder::getF8E8M0Type() { return Float8E8M0FNUType::get(context); }
|
||||
|
||||
FloatType Builder::getBF16Type() { return BFloat16Type::get(context); }
|
||||
|
||||
FloatType Builder::getF16Type() { return Float16Type::get(context); }
|
||||
|
||||
@@ -1,4 +1,5 @@
|
||||
// RUN: mlir-opt %s -arith-expand="include-bf16=true include-f8e8m0=true" -split-input-file | FileCheck %s
|
||||
// RUN: mlir-opt %s -arith-expand="include-bf16=true include-f8e8m0=true" -verify-diagnostics -split-input-file | FileCheck %s
|
||||
// RUN: mlir-opt %s -arith-expand -split-input-file -verify-diagnostics | FileCheck %s --check-prefix=SCHECK
|
||||
|
||||
// Test ceil divide with signed integer
|
||||
// CHECK-LABEL: func @ceildivi
|
||||
@@ -253,7 +254,7 @@ func.func @truncf_f32_to_f8E8M0FNU(%arg0 : f32) -> f8E8M0FNU {
|
||||
%0 = arith.truncf %arg0 : f32 to f8E8M0FNU
|
||||
return %0 : f8E8M0FNU
|
||||
}
|
||||
// CHECK-LABLE: @truncf_f32_to_f8E8M0FNU
|
||||
// CHECK-LABEL: @truncf_f32_to_f8E8M0FNU
|
||||
// CHECK: %[[BITCAST:.+]] = arith.bitcast %arg0 : f32 to i32
|
||||
// CHECK: %[[C23_i32:.+]] = arith.constant 23 : i32
|
||||
// CHECK: %[[SHRUI:.+]] = arith.shrui %[[BITCAST]], %[[C23_i32]] : i32
|
||||
@@ -267,7 +268,7 @@ func.func @truncf_f16_to_f8E8M0FNU(%arg0 : f16) -> f8E8M0FNU {
|
||||
%0 = arith.truncf %arg0 : f16 to f8E8M0FNU
|
||||
return %0 : f8E8M0FNU
|
||||
}
|
||||
// CHECK-LABLE: @truncf_f16_to_f8E8M0FNU
|
||||
// CHECK-LABEL: @truncf_f16_to_f8E8M0FNU
|
||||
// CHECK: %[[EXTF:.+]] = arith.extf %arg0 : f16 to f32
|
||||
// CHECK: %[[BITCAST:.+]] = arith.bitcast %[[EXTF]] : f32 to i32
|
||||
// CHECK: %[[C23_i32:.+]] = arith.constant 23 : i32
|
||||
@@ -305,9 +306,76 @@ func.func @truncf_vector_bf16_to_f8E8M0FNU(%arg0 : vector<4xbf16>) -> vector<4xf
|
||||
|
||||
// CHECK-LABEL: @truncf_vector_bf16_to_f8E8M0FNU
|
||||
// CHECK-NOT: arith.truncf
|
||||
|
||||
// CHECK: return
|
||||
|
||||
// -----
|
||||
|
||||
func.func @scaling_truncf_f32_to_f4E2M1FN(%arg0 : f32, %arg1: f8E8M0FNU) -> f4E2M1FN {
|
||||
%0 = arith.scaling_truncf %arg0, %arg1 : f32, f8E8M0FNU to f4E2M1FN
|
||||
return %0 : f4E2M1FN
|
||||
}
|
||||
|
||||
// SCHECK-LABEL: @scaling_truncf_f32_to_f4E2M1FN
|
||||
// SCHECK: %[[SCALEF32:.+]] = arith.extf %arg1 : f8E8M0FNU to f32
|
||||
// SCHECK: %[[DIVF:.+]] = arith.divf %arg0, %[[SCALEF32]] : f32
|
||||
// SCHECK: %[[RESULT:.+]] = arith.truncf %[[DIVF]] : f32 to f4E2M1FN
|
||||
// SCHECK: return %[[RESULT]]
|
||||
|
||||
// -----
|
||||
|
||||
func.func @scaling_truncf_vector_f16_to_f6E3M2FN(%arg0 : vector<4xf16>, %arg1: vector<4xf8E8M0FNU>) -> vector<4xf6E3M2FN> {
|
||||
%0 = arith.scaling_truncf %arg0, %arg1 : vector<4xf16>, vector<4xf8E8M0FNU> to vector<4xf6E3M2FN>
|
||||
return %0 : vector<4xf6E3M2FN>
|
||||
}
|
||||
|
||||
// SCHECK-LABEL: @scaling_truncf_vector_f16_to_f6E3M2FN
|
||||
// SCHECK: %[[SCALEF16:.+]] = arith.extf %arg1 : vector<4xf8E8M0FNU> to vector<4xf16>
|
||||
// SCHECK: %[[DIVF:.+]] = arith.divf %arg0, %[[SCALEF16]] : vector<4xf16>
|
||||
// SCHECK: %[[RESULT:.+]] = arith.truncf %[[DIVF]] : vector<4xf16> to vector<4xf6E3M2FN>
|
||||
// SCHECK: return %[[RESULT]] : vector<4xf6E3M2FN>
|
||||
|
||||
// -----
|
||||
|
||||
func.func @scaling_truncf_propagate_rounding_mode_fast_math(%arg0 : vector<4xf16>, %arg1: vector<4xf16>) -> vector<4xf6E3M2FN> {
|
||||
%0 = arith.scaling_truncf %arg0, %arg1 to_nearest_even fastmath<fast> : vector<4xf16>, vector<4xf16> to vector<4xf6E3M2FN>
|
||||
return %0 : vector<4xf6E3M2FN>
|
||||
}
|
||||
// SCHECK-LABEL: @scaling_truncf_propagate_rounding_mode_fast_math
|
||||
// SCHECK: %[[SCALEF8:.+]] = arith.truncf %arg1 fastmath<fast> : vector<4xf16> to vector<4xf8E8M0FNU>
|
||||
// SCHECK: %[[SCALEINTY:.+]] = arith.extf %[[SCALEF8]] fastmath<fast> : vector<4xf8E8M0FNU> to vector<4xf16>
|
||||
// SCHECK: %[[DIVF:.+]] = arith.divf %arg0, %[[SCALEINTY]] fastmath<fast> : vector<4xf16>
|
||||
// SCHECK: %[[TRUNCF:.+]] = arith.truncf [[_:%[a-zA-Z0-9_]+]] to_nearest_even fastmath<fast> : vector<4xf16> to vector<4xf6E3M2FN>
|
||||
// SCHECK: return %[[TRUNCF]] : vector<4xf6E3M2FN>
|
||||
|
||||
// -----
|
||||
|
||||
func.func @scaling_truncf_f16_to_f4E2M1FN_using_f16_scales(%arg0: f16, %arg1 : f16) -> f4E2M1FN {
|
||||
%0 = arith.scaling_truncf %arg0, %arg1 : f16, f16 to f4E2M1FN
|
||||
return %0 : f4E2M1FN
|
||||
}
|
||||
// SCHECK-LABEL: @scaling_truncf_f16_to_f4E2M1FN_using_f16_scales
|
||||
// SCHECK: %[[SCALETRUNCF:.+]] = arith.truncf %arg1 : f16 to f8E8M0FN
|
||||
// SCHECK: return
|
||||
|
||||
// -----
|
||||
func.func @scaling_truncf_vector_f16_to_f4E2M1FN_using_f16_scales(%arg0: vector<4xf16>, %arg1 : vector<4xf16>) -> vector<4xf4E2M1FN> {
|
||||
%0 = arith.scaling_truncf %arg0, %arg1 : vector<4xf16>, vector<4xf16> to vector<4xf4E2M1FN>
|
||||
return %0 : vector<4xf4E2M1FN>
|
||||
}
|
||||
// SCHECK-LABEL: @scaling_truncf_vector_f16_to_f4E2M1FN_using_f16_scales
|
||||
// SCHECK: %[[SCALETRUNCF:.+]] = arith.truncf %arg1 : vector<4xf16> to vector<4xf8E8M0FNU>
|
||||
// SCHECK: return
|
||||
|
||||
// -----
|
||||
|
||||
func.func @invalid_scaling_truncf_to_f4E2M1FN(%arg0: f16, %arg1 : f8E5M2FNUZ) -> f4E2M1FN {
|
||||
// expected-error@+1 {{failed to legalize operation 'arith.scaling_truncf' that was explicitly marked illegal}}
|
||||
%0 = arith.scaling_truncf %arg0, %arg1 : f16, f8E5M2FNUZ to f4E2M1FN
|
||||
return %0 : f4E2M1FN
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
func.func @extf_f8E8M0FNU_to_f32(%arg0 : f8E8M0FNU) -> f32 {
|
||||
%0 = arith.extf %arg0 : f8E8M0FNU to f32
|
||||
return %0 : f32
|
||||
@@ -332,7 +400,7 @@ func.func @extf_f8E8M0FNU_to_f16(%arg0 : f8E8M0FNU) -> f16 {
|
||||
return %0 : f16
|
||||
}
|
||||
|
||||
// CHECK-LABLE: @extf_f8E8M0FNU_to_f16
|
||||
// CHECK-LABEL: @extf_f8E8M0FNU_to_f16
|
||||
// CHECK: %[[BITCAST:.+]] = arith.bitcast %arg0 : f8E8M0FNU to i8
|
||||
// CHECK-DAG: %[[CF8NAN:.+]] = arith.constant -1 : i8
|
||||
// CHECK-DAG: %[[CF32NAN:.+]] = arith.constant -1 : i32
|
||||
@@ -374,7 +442,109 @@ func.func @extf_vector_f8E8M0FNU_to_bf16(%arg0 : vector<4xf8E8M0FNU>) -> vector<
|
||||
|
||||
// CHECK-LABEL: @extf_vector_f8E8M0FNU_to_bf16
|
||||
// CHECK-NOT: arith.extf
|
||||
// CHECK: return
|
||||
|
||||
// -----
|
||||
|
||||
func.func @scaling_extf_to_f32(%arg0: f4E2M1FN, %arg1 : f8E8M0FNU) -> f32 {
|
||||
%0 = arith.scaling_extf %arg0, %arg1 : f4E2M1FN, f8E8M0FNU to f32
|
||||
return %0 : f32
|
||||
}
|
||||
|
||||
// SCHECK-LABEL: @scaling_extf_to_f32
|
||||
// SCHECK: %[[EXT_SCALE:.+]] = arith.extf %arg1 : f8E8M0FNU to f32
|
||||
// SCHECK: %[[EXT_INPUT:.+]] = arith.extf %arg0 : f4E2M1FN to f32
|
||||
// SCHECK: %[[RESULT:.+]] = arith.mulf %[[EXT_INPUT]], %[[EXT_SCALE]] : f32
|
||||
// SCHECK: return %[[RESULT]]
|
||||
|
||||
// -----
|
||||
|
||||
func.func @scaling_extf_to_f32_using_f16_scales(%arg0: f4E2M1FN, %arg1 : f16) -> f32 {
|
||||
%0 = arith.scaling_extf %arg0, %arg1 : f4E2M1FN, f16 to f32
|
||||
return %0 : f32
|
||||
}
|
||||
|
||||
// SCHECK-LABEL: @scaling_extf_to_f32_using_f16_scales
|
||||
// SCHECK: %[[TRUNCF_SCALE:.+]] = arith.truncf %arg1 : f16 to f8E8M0FNU
|
||||
// SCHECK: %[[EXT_SCALE:.+]] = arith.extf %[[TRUNCF_SCALE]] : f8E8M0FNU to f32
|
||||
// SCHECK: %[[EXT_INPUT:.+]] = arith.extf %arg0 : f4E2M1FN to f32
|
||||
// SCHECK: %[[RESULT:.+]] = arith.mulf %[[EXT_INPUT]], %[[EXT_SCALE]] : f32
|
||||
// SCHECK: return %[[RESULT]]
|
||||
|
||||
// -----
|
||||
|
||||
func.func @invalid_scaling_extf_to_f32(%arg0: f4E2M1FN, %arg1 : f8E5M2FNUZ) -> f32 {
|
||||
// expected-error@+1 {{failed to legalize operation 'arith.scaling_extf' that was explicitly marked illegal}}
|
||||
%0 = arith.scaling_extf %arg0, %arg1 : f4E2M1FN, f8E5M2FNUZ to f32
|
||||
return %0 : f32
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
func.func @scaling_extf_vector_to_f32(%arg0: vector<4xf4E2M1FN>, %arg1 : vector<4xf8E8M0FNU>) -> vector<4xf32> {
|
||||
%0 = arith.scaling_extf %arg0, %arg1 : vector<4xf4E2M1FN>, vector<4xf8E8M0FNU> to vector<4xf32>
|
||||
return %0 : vector<4xf32>
|
||||
}
|
||||
|
||||
// SCHECK-LABEL: @scaling_extf_vector_to_f32
|
||||
// SCHECK: %[[EXT_SCALE:.+]] = arith.extf %arg1 : vector<4xf8E8M0FNU> to vector<4xf32>
|
||||
// SCHECK: %[[EXT_INPUT:.+]] = arith.extf %arg0 : vector<4xf4E2M1FN> to vector<4xf32>
|
||||
// SCHECK: %[[RESULT:.+]] = arith.mulf %[[EXT_INPUT]], %[[EXT_SCALE]] : vector<4xf32>
|
||||
// SCHECK: return %[[RESULT]]
|
||||
|
||||
// -----
|
||||
|
||||
func.func @scaling_extf_vector_to_f16(%arg0: vector<4xf4E2M1FN>, %arg1 : vector<4xf8E8M0FNU>) -> vector<4xf16> {
|
||||
%0 = arith.scaling_extf %arg0, %arg1 : vector<4xf4E2M1FN>, vector<4xf8E8M0FNU> to vector<4xf16>
|
||||
return %0 : vector<4xf16>
|
||||
}
|
||||
|
||||
// SCHECK-LABEL: @scaling_extf_vector_to_f16
|
||||
// SCHECK: %[[EXT_SCALE:.+]] = arith.extf %arg1 : vector<4xf8E8M0FNU> to vector<4xf16>
|
||||
// SCHECK: %[[EXT_INPUT:.+]] = arith.extf %arg0 : vector<4xf4E2M1FN> to vector<4xf16>
|
||||
// SCHECK: %[[RESULT:.+]] = arith.mulf %[[EXT_INPUT]], %[[EXT_SCALE]] : vector<4xf16>
|
||||
// SCHECK: return %[[RESULT]]
|
||||
|
||||
// -----
|
||||
|
||||
func.func @scaling_extf_vector_to_bf16(%arg0: vector<4xf4E2M1FN>, %arg1 : vector<4xf8E8M0FNU>) -> vector<4xbf16> {
|
||||
%0 = arith.scaling_extf %arg0, %arg1 : vector<4xf4E2M1FN>, vector<4xf8E8M0FNU> to vector<4xbf16>
|
||||
return %0 : vector<4xbf16>
|
||||
}
|
||||
|
||||
// SCHECK-LABEL: @scaling_extf_vector_to_bf16
|
||||
// SCHECK: %[[EXT_SCALE:.+]] = arith.extf %arg1 : vector<4xf8E8M0FNU> to vector<4xbf16>
|
||||
// SCHECK: %[[EXT_INPUT:.+]] = arith.extf %arg0 : vector<4xf4E2M1FN> to vector<4xbf16>
|
||||
// SCHECK: %[[RESULT:.+]] = arith.mulf %[[EXT_INPUT]], %[[EXT_SCALE]] : vector<4xbf16>
|
||||
// SCHECK: return %[[RESULT]]
|
||||
|
||||
// -----
|
||||
|
||||
func.func @scaling_extf_vector_to_f32_using_f16_scales(%arg0: vector<4xf4E2M1FN>, %arg1 : vector<4xf16>) -> vector<4xf32> {
|
||||
%0 = arith.scaling_extf %arg0, %arg1 : vector<4xf4E2M1FN>, vector<4xf16> to vector<4xf32>
|
||||
return %0 : vector<4xf32>
|
||||
}
|
||||
|
||||
// SCHECK-LABEL: @scaling_extf_vector_to_f32_using_f16_scales
|
||||
// SCHECK: %[[TRUNCF_SCALE:.+]] = arith.truncf %arg1 : vector<4xf16> to vector<4xf8E8M0FNU>
|
||||
// SCHECK: %[[EXT_SCALE:.+]] = arith.extf %[[TRUNCF_SCALE]] : vector<4xf8E8M0FNU> to vector<4xf32>
|
||||
// SCHECK: %[[EXT_INPUT:.+]] = arith.extf %arg0 : vector<4xf4E2M1FN> to vector<4xf32>
|
||||
// SCHECK: %[[RESULT:.+]] = arith.mulf %[[EXT_INPUT]], %[[EXT_SCALE]] : vector<4xf32>
|
||||
// SCHECK: return %[[RESULT]]
|
||||
|
||||
// -----
|
||||
|
||||
func.func @scaling_extf_vector_to_f32_using_f16_scales_fastmath(%arg0: vector<4xf4E2M1FN>, %arg1 : vector<4xf16>) -> vector<4xf32> {
|
||||
%0 = arith.scaling_extf %arg0, %arg1 fastmath<fast> : vector<4xf4E2M1FN>, vector<4xf16> to vector<4xf32>
|
||||
return %0 : vector<4xf32>
|
||||
}
|
||||
|
||||
// SCHECK-LABEL: @scaling_extf_vector_to_f32_using_f16_scales_fastmath
|
||||
// SCHECK: %[[TRUNCF_SCALE:.+]] = arith.truncf %arg1 fastmath<fast> : vector<4xf16> to vector<4xf8E8M0FNU>
|
||||
// SCHECK: %[[EXT_SCALE:.+]] = arith.extf %[[TRUNCF_SCALE]] fastmath<fast> : vector<4xf8E8M0FNU> to vector<4xf32>
|
||||
// SCHECK: %[[EXT_INPUT:.+]] = arith.extf %arg0 fastmath<fast> : vector<4xf4E2M1FN> to vector<4xf32>
|
||||
// SCHECK: %[[RESULT:.+]] = arith.mulf %[[EXT_INPUT]], %[[EXT_SCALE]] fastmath<fast> : vector<4xf32>
|
||||
// SCHECK: return %[[RESULT]]
|
||||
|
||||
// -----
|
||||
|
||||
|
||||
Reference in New Issue
Block a user