This fixes all the places in MLIR that hit the new assertion added in #106524, in preparation for enabling it by default. That is, cases where the value passed to the APInt constructor is not an N-bit signed/unsigned integer, where N is the bit width and signedness is determined by the isSigned flag. The fixes either set the correct value for isSigned, or set the implicitTrunc flag to retain the old behavior. I've left TODOs for the latter case in some places, where I think that it may be worthwhile to stop doing implicit truncation in the future. Note that the assertion is currently still disabled by default, so this patch is mostly NFC. This is just the MLIR changes split off from https://github.com/llvm/llvm-project/pull/80309.
253 lines
9.8 KiB
C++
253 lines
9.8 KiB
C++
//===- TosaToArith.cpp - Lowering Tosa to Arith Dialect -------------===//
|
|
//
|
|
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
|
// See https://llvm.org/LICENSE.txt for license information.
|
|
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
|
//
|
|
//===----------------------------------------------------------------------===//
|
|
//
|
|
// These rewriters lower from the Tosa to the Arith dialect.
|
|
//
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
#include "mlir/Conversion/TosaToArith/TosaToArith.h"
|
|
#include "mlir/Dialect/Arith/IR/Arith.h"
|
|
#include "mlir/Dialect/Tosa/IR/TosaOps.h"
|
|
#include "mlir/IR/PatternMatch.h"
|
|
#include "mlir/IR/TypeUtilities.h"
|
|
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
|
|
|
|
using namespace mlir;
|
|
using namespace tosa;
|
|
|
|
namespace {
|
|
|
|
class ConstOpConverter : public OpRewritePattern<tosa::ConstOp> {
|
|
public:
|
|
using OpRewritePattern<tosa::ConstOp>::OpRewritePattern;
|
|
|
|
LogicalResult matchAndRewrite(tosa::ConstOp op,
|
|
PatternRewriter &rewriter) const final {
|
|
rewriter.replaceOpWithNewOp<arith::ConstantOp>(op, op.getValue());
|
|
return success();
|
|
}
|
|
};
|
|
|
|
Type matchContainerType(Type element, Type container) {
|
|
if (auto shapedTy = dyn_cast<ShapedType>(container))
|
|
return shapedTy.clone(element);
|
|
|
|
return element;
|
|
}
|
|
|
|
TypedAttr getConstantAttr(Type type, int64_t value, PatternRewriter &rewriter) {
|
|
if (auto shapedTy = dyn_cast<ShapedType>(type)) {
|
|
Type eTy = shapedTy.getElementType();
|
|
APInt valueInt(eTy.getIntOrFloatBitWidth(), value, /*isSigned=*/true);
|
|
return DenseIntElementsAttr::get(shapedTy, valueInt);
|
|
}
|
|
|
|
return rewriter.getIntegerAttr(type, value);
|
|
}
|
|
|
|
Value getConstantValue(Location loc, Type type, int64_t value,
|
|
PatternRewriter &rewriter) {
|
|
return rewriter.create<arith::ConstantOp>(
|
|
loc, getConstantAttr(type, value, rewriter));
|
|
}
|
|
|
|
// This converts the TOSA ApplyScale operator to a set of arithmetic ops,
|
|
// using 64-bit operations to perform the necessary multiply, bias, and shift.
|
|
class ApplyScaleGenericOpConverter
|
|
: public OpRewritePattern<tosa::ApplyScaleOp> {
|
|
public:
|
|
using OpRewritePattern<tosa::ApplyScaleOp>::OpRewritePattern;
|
|
|
|
LogicalResult matchAndRewrite(tosa::ApplyScaleOp op,
|
|
PatternRewriter &rewriter) const final {
|
|
Location loc = op.getLoc();
|
|
Value value = op.getValue();
|
|
Value multiplier32 = op.getMultiplier();
|
|
|
|
Type resultTy = op.getType();
|
|
Type valueTy = value.getType();
|
|
Type i32Ty = matchContainerType(rewriter.getI32Type(), resultTy);
|
|
Type i64Ty = matchContainerType(rewriter.getI64Type(), resultTy);
|
|
|
|
Value zero = getConstantValue(loc, valueTy, 0, rewriter);
|
|
Value one64 = getConstantValue(loc, i64Ty, 1, rewriter);
|
|
Value thirtyOne32 = getConstantValue(loc, i32Ty, 31, rewriter);
|
|
|
|
Value shift32 = rewriter.create<arith::ExtUIOp>(loc, i32Ty, op.getShift());
|
|
|
|
// Compute the multiplication in 64-bits then select the high / low parts.
|
|
Value value64 = value;
|
|
if (getElementTypeOrSelf(valueTy) != rewriter.getI64Type())
|
|
value64 = rewriter.create<arith::ExtSIOp>(loc, i64Ty, value);
|
|
Value multiplier64 =
|
|
rewriter.create<arith::ExtSIOp>(loc, i64Ty, multiplier32);
|
|
Value multiply64 =
|
|
rewriter.create<arith::MulIOp>(loc, value64, multiplier64);
|
|
|
|
// Apply normal rounding.
|
|
Value shift64 = rewriter.create<arith::ExtUIOp>(loc, i64Ty, shift32);
|
|
Value round = rewriter.create<arith::ShLIOp>(loc, one64, shift64);
|
|
round = rewriter.create<arith::ShRUIOp>(loc, round, one64);
|
|
multiply64 = rewriter.create<arith::AddIOp>(loc, multiply64, round);
|
|
|
|
// Apply double rounding if necessary.
|
|
if (op.getDoubleRound()) {
|
|
int64_t roundInt = 1 << 30;
|
|
Value roundUp = getConstantValue(loc, i64Ty, roundInt, rewriter);
|
|
Value roundDown = getConstantValue(loc, i64Ty, -roundInt, rewriter);
|
|
Value positive = rewriter.create<arith::CmpIOp>(
|
|
loc, arith::CmpIPredicate::sge, value, zero);
|
|
Value dir =
|
|
rewriter.create<arith::SelectOp>(loc, positive, roundUp, roundDown);
|
|
Value val = rewriter.create<arith::AddIOp>(loc, dir, multiply64);
|
|
Value valid = rewriter.create<arith::CmpIOp>(
|
|
loc, arith::CmpIPredicate::sgt, shift32, thirtyOne32);
|
|
multiply64 =
|
|
rewriter.create<arith::SelectOp>(loc, valid, val, multiply64);
|
|
}
|
|
|
|
Value result64 = rewriter.create<arith::ShRSIOp>(loc, multiply64, shift64);
|
|
Value result32 = rewriter.create<arith::TruncIOp>(loc, i32Ty, result64);
|
|
|
|
rewriter.replaceOp(op, result32);
|
|
return success();
|
|
}
|
|
};
|
|
|
|
class ApplyScale32BitOpConverter : public OpRewritePattern<tosa::ApplyScaleOp> {
|
|
public:
|
|
using OpRewritePattern<tosa::ApplyScaleOp>::OpRewritePattern;
|
|
|
|
LogicalResult matchAndRewrite(tosa::ApplyScaleOp op,
|
|
PatternRewriter &rewriter) const final {
|
|
Location loc = op.getLoc();
|
|
|
|
Type resultTy = op.getType();
|
|
Type i32Ty = matchContainerType(rewriter.getI32Type(), resultTy);
|
|
|
|
Value value = op.getValue();
|
|
if (getElementTypeOrSelf(value.getType()).getIntOrFloatBitWidth() > 32) {
|
|
return failure();
|
|
}
|
|
|
|
Value value32 = op.getValue();
|
|
Value multiplier32 = op.getMultiplier();
|
|
Value shift32 = rewriter.create<arith::ExtUIOp>(loc, i32Ty, op.getShift());
|
|
|
|
// Constants used during the scaling operation.
|
|
Value zero32 = getConstantValue(loc, i32Ty, 0, rewriter);
|
|
Value one32 = getConstantValue(loc, i32Ty, 1, rewriter);
|
|
Value two32 = getConstantValue(loc, i32Ty, 2, rewriter);
|
|
Value thirty32 = getConstantValue(loc, i32Ty, 30, rewriter);
|
|
Value thirtyTwo32 = getConstantValue(loc, i32Ty, 32, rewriter);
|
|
|
|
// Compute the multiplication in 64-bits then select the high / low parts.
|
|
// Grab out the high/low of the computation
|
|
auto value64 =
|
|
rewriter.create<arith::MulSIExtendedOp>(loc, value32, multiplier32);
|
|
Value low32 = value64.getLow();
|
|
Value high32 = value64.getHigh();
|
|
|
|
// Determine the direction and amount to shift the high bits.
|
|
Value shiftOver32 = rewriter.create<arith::CmpIOp>(
|
|
loc, arith::CmpIPredicate::sge, shift32, thirtyTwo32);
|
|
Value roundHighBits = rewriter.create<arith::CmpIOp>(
|
|
loc, arith::CmpIPredicate::sgt, shift32, thirtyTwo32);
|
|
|
|
Value shiftHighL =
|
|
rewriter.create<arith::SubIOp>(loc, thirtyTwo32, shift32);
|
|
Value shiftHighR =
|
|
rewriter.create<arith::SubIOp>(loc, shift32, thirtyTwo32);
|
|
|
|
shiftHighL =
|
|
rewriter.create<arith::SelectOp>(loc, shiftOver32, zero32, shiftHighL);
|
|
shiftHighR =
|
|
rewriter.create<arith::SelectOp>(loc, shiftOver32, shiftHighR, zero32);
|
|
|
|
// Conditionally perform our double round.
|
|
if (op.getDoubleRound()) {
|
|
Value negOne32 = getConstantValue(loc, i32Ty, -1, rewriter);
|
|
Value valuePositive = rewriter.create<arith::CmpIOp>(
|
|
loc, arith::CmpIPredicate::sge, value32, zero32);
|
|
|
|
Value roundDir =
|
|
rewriter.create<arith::SelectOp>(loc, valuePositive, one32, negOne32);
|
|
roundDir =
|
|
rewriter.create<arith::SelectOp>(loc, shiftOver32, roundDir, zero32);
|
|
|
|
Value shiftLow = rewriter.create<arith::ShRUIOp>(loc, low32, thirty32);
|
|
Value rounded = rewriter.create<arith::AddIOp>(loc, shiftLow, roundDir);
|
|
Value carry = rewriter.create<arith::ShRSIOp>(loc, rounded, two32);
|
|
|
|
Value shiftRound =
|
|
rewriter.create<arith::ShLIOp>(loc, roundDir, thirty32);
|
|
|
|
low32 = rewriter.create<arith::AddIOp>(loc, low32, shiftRound);
|
|
high32 = rewriter.create<arith::AddIOp>(loc, high32, carry);
|
|
}
|
|
|
|
// Conditionally apply rounding in the low bits.
|
|
{
|
|
Value shiftSubOne = rewriter.create<arith::SubIOp>(loc, shift32, one32);
|
|
Value roundBit = rewriter.create<arith::ShLIOp>(loc, one32, shiftSubOne);
|
|
roundBit = rewriter.create<arith::SelectOp>(loc, roundHighBits, zero32,
|
|
roundBit);
|
|
|
|
Value newLow32 = rewriter.create<arith::AddIOp>(loc, low32, roundBit);
|
|
Value wasRounded = rewriter.create<arith::CmpIOp>(
|
|
loc, arith::CmpIPredicate::ugt, low32, newLow32);
|
|
low32 = newLow32;
|
|
|
|
Value rounded32 = rewriter.create<arith::ExtUIOp>(loc, i32Ty, wasRounded);
|
|
high32 = rewriter.create<arith::AddIOp>(loc, high32, rounded32);
|
|
}
|
|
|
|
// Conditionally apply rounding in the high bits.
|
|
{
|
|
Value shiftSubOne =
|
|
rewriter.create<arith::SubIOp>(loc, shiftHighR, one32);
|
|
Value roundBit = rewriter.create<arith::ShLIOp>(loc, one32, shiftSubOne);
|
|
roundBit = rewriter.create<arith::SelectOp>(loc, roundHighBits, roundBit,
|
|
zero32);
|
|
high32 = rewriter.create<arith::AddIOp>(loc, high32, roundBit);
|
|
}
|
|
|
|
// Combine the correct high/low bits into the final rescale result.
|
|
high32 = rewriter.create<arith::ShLIOp>(loc, high32, shiftHighL);
|
|
high32 = rewriter.create<arith::ShRSIOp>(loc, high32, shiftHighR);
|
|
low32 = rewriter.create<arith::ShRUIOp>(loc, low32, shift32);
|
|
low32 = rewriter.create<arith::SelectOp>(loc, shiftOver32, zero32, low32);
|
|
|
|
// Apply the rounding behavior and shift to the final alignment.
|
|
Value result = rewriter.create<arith::AddIOp>(loc, low32, high32);
|
|
|
|
// Truncate if necessary.
|
|
if (!getElementTypeOrSelf(resultTy).isInteger(32)) {
|
|
result = rewriter.create<arith::TruncIOp>(loc, resultTy, result);
|
|
}
|
|
|
|
rewriter.replaceOp(op, result);
|
|
return success();
|
|
}
|
|
};
|
|
|
|
} // namespace
|
|
|
|
void mlir::tosa::populateTosaToArithConversionPatterns(
|
|
RewritePatternSet *patterns) {
|
|
patterns->add<ConstOpConverter>(patterns->getContext());
|
|
}
|
|
|
|
void mlir::tosa::populateTosaRescaleToArithConversionPatterns(
|
|
RewritePatternSet *patterns, bool include32Bit) {
|
|
patterns->add<ApplyScaleGenericOpConverter>(patterns->getContext(), 100);
|
|
if (include32Bit) {
|
|
patterns->add<ApplyScale32BitOpConverter>(patterns->getContext(), 200);
|
|
}
|
|
}
|