[TOSA] Use attributes for unsigned rescale (#118075)

Unsigned integer types are uncommon enough in MLIR that there is no
operation to cast a scalar from signless to unsigned and vice versa.
Currently tosa.rescale uses builtin.unrealized_conversion_cast which
does not lower. Instead, this commit introduces optional attributes to
indicate unsigned input or output, named similarly to those in the TOSA
specification. This is more in line with the rest of MLIR where specific
operations rather than values are signed/unsigned.
This commit is contained in:
Thomas Preud'homme
2024-12-04 09:17:55 +00:00
committed by GitHub
parent bba2507c19
commit 720864907d
3 changed files with 63 additions and 48 deletions

View File

@@ -1869,22 +1869,23 @@ def Tosa_RescaleOp: Tosa_Op<"rescale", [Pure,
let description = [{
Rescale quantized values into a new domain. Supported rescalings are:
| Mode | Input | Output |
|------------------------|-------|--------|
| signed 8 to 8 | int8 | int8 |
| signed 8 to 16 | int8 | int16 |
| signed 8 to 32 | int8 | int32 |
| signed 16 to 8 | int16 | int8 |
| signed 16 to 16 | int16 | int16 |
| signed 16 to 32 | int16 | int32 |
| signed 32 to 8 | int32 | int8 |
| signed 32 to 16 | int32 | int16 |
| signed 32 to 32 | int32 | int32 |
| signed 48 to 8 | int48 | int8 |
| signed 48 to 16 | int48 | int16 |
| signed 48 to 32 | int48 | int32 |
| unsigned 8 to signed 8 | uint8 | int8 |
| signed 8 to unsigned 8 | int8 | uint8 |
| Mode | Input | Output | Unsigned | Unsigned |
| | | | input | output |
|------------------------|-------|--------|----------|----------|
| signed 8 to 8 | int8 | int8 | false | false |
| signed 8 to 16 | int8 | int16 | false | false |
| signed 8 to 32 | int8 | int32 | false | false |
| signed 16 to 8 | int16 | int8 | false | false |
| signed 16 to 16 | int16 | int16 | false | false |
| signed 16 to 32 | int16 | int32 | false | false |
| signed 32 to 8 | int32 | int8 | false | false |
| signed 32 to 16 | int32 | int16 | false | false |
| signed 32 to 32 | int32 | int32 | false | false |
| signed 48 to 8 | int48 | int8 | false | false |
| signed 48 to 16 | int48 | int16 | false | false |
| signed 48 to 32 | int48 | int32 | false | false |
| unsigned 8 to signed 8 | uint8 | int8 | true | false |
| signed 8 to unsigned 8 | int8 | uint8 | false | true |
}];
let arguments = (ins
@@ -1895,13 +1896,33 @@ def Tosa_RescaleOp: Tosa_Op<"rescale", [Pure,
DenseI8ArrayAttr:$shift,
BoolAttr:$scale32,
BoolAttr:$double_round,
BoolAttr:$per_channel
BoolAttr:$per_channel,
DefaultValuedOptionalAttr<BoolAttr, "false">:$input_unsigned,
DefaultValuedOptionalAttr<BoolAttr, "false">:$output_unsigned
);
let results = (outs
Tosa_Tensor:$output
);
// Custom builder that does not require optional input_unsigned and
// output_unsigned.
let builders = [
OpBuilder<(ins "::mlir::Type":$output,
"::mlir::Value":$input,
"::mlir::IntegerAttr":$input_zp,
"::mlir::IntegerAttr":$output_zp,
"::mlir::DenseI32ArrayAttr":$multiplier,
"::mlir::DenseI8ArrayAttr":$shift,
"::mlir::BoolAttr":$scale32,
"::mlir::BoolAttr":$double_round,
"::mlir::BoolAttr":$per_channel), [{
auto FalseAttr = BoolAttr::get($_builder.getContext(), false);
build($_builder, $_state, output, input, input_zp, output_zp, multiplier,
shift, scale32, double_round, per_channel, FalseAttr, FalseAttr);
}]>
];
let assemblyFormat = "operands attr-dict `:` functional-type(operands, results)";
}

View File

@@ -1261,14 +1261,7 @@ public:
Value shift = shiftConstant ? shiftConstant : blockArgs[shiftArg];
if (valueTy.getIntOrFloatBitWidth() < 32) {
if (valueTy.isUnsignedInteger()) {
value = nestedBuilder
.create<UnrealizedConversionCastOp>(
nestedLoc,
nestedBuilder.getIntegerType(
valueTy.getIntOrFloatBitWidth()),
value)
.getResult(0);
if (op.getInputUnsigned()) {
value = nestedBuilder.create<arith::ExtUIOp>(
nestedLoc, nestedBuilder.getI32Type(), value);
} else {
@@ -1297,7 +1290,7 @@ public:
int32_t intMax = APInt::getSignedMaxValue(outBitWidth).getSExtValue();
// Unsigned integers have a difference output value.
if (outIntType.isUnsignedInteger()) {
if (op.getOutputUnsigned()) {
intMin = 0;
intMax = APInt::getMaxValue(outBitWidth).getZExtValue();
}
@@ -1314,13 +1307,6 @@ public:
value = nestedBuilder.create<arith::TruncIOp>(
nestedLoc, rewriter.getIntegerType(outIntType.getWidth()),
value);
if (outIntType.isUnsignedInteger()) {
value = nestedBuilder
.create<UnrealizedConversionCastOp>(nestedLoc,
outIntType, value)
.getResult(0);
}
}
nestedBuilder.create<linalg::YieldOp>(loc, value);

View File

@@ -1132,11 +1132,21 @@ func.func @rescale_i8(%arg0 : tensor<2xi8>) -> () {
// CHECK-DAG: linalg.yield [[TRUNC]]
%0 = tosa.rescale %arg0 {input_zp = 17 : i32, output_zp = 22 : i32, multiplier = array<i32: 19689>, shift = array<i8: 15>, scale32 = false, double_round = false, per_channel = false} : (tensor<2xi8>) -> tensor<2xi8>
// CHECK: return
return
}
// -----
// CHECK: #[[$MAP0:.*]] = affine_map<(d0) -> (d0)>
// CHECK-LABEL: @rescale_i8_unsigned_output
// CHECK-SAME: (%[[ARG0:[0-9a-zA-Z_]*]]:
func.func @rescale_i8_unsigned_output(%arg0 : tensor<2xi8>) -> () {
// CHECK: [[C0:%.+]] = arith.constant 19689
// CHECK: [[C1:%.+]] = arith.constant 15
// CHECK: [[INIT:%.+]] = tensor.empty()
// CHECK: [[GENERIC:%.+]] = linalg.generic {indexing_maps = [#[[$MAP0]], #[[$MAP0]]], iterator_types = ["parallel"]} ins(%[[ARG0]] : tensor<2xi8>) outs([[INIT]] : tensor<2xui8>)
// CHECK: ^bb0([[IN:%.+]]: i8, [[UNUSED:%.+]]: ui8):
// CHECK: [[GENERIC:%.+]] = linalg.generic {indexing_maps = [#[[$MAP0]], #[[$MAP0]]], iterator_types = ["parallel"]} ins(%[[ARG0]] : tensor<2xi8>) outs([[INIT]] : tensor<2xi8>)
// CHECK: ^bb0([[IN:%.+]]: i8, [[UNUSED:%.+]]: i8):
// CHECK: [[C17:%.+]] = arith.constant 17
// CHECK: [[C22:%.+]] = arith.constant 22
// CHECK-DAG: [[IN32:%.+]] = arith.extsi [[IN]]
@@ -1148,9 +1158,8 @@ func.func @rescale_i8(%arg0 : tensor<2xi8>) -> () {
// CHECK-DAG: [[LOWER:%.+]] = arith.maxsi [[CMIN]], [[SCALED_ZEROED]]
// CHECK-DAG: [[BOUNDED:%.+]] = arith.minsi [[CMAX]], [[LOWER]]
// CHECK-DAG: [[TRUNC:%.+]] = arith.trunci [[BOUNDED]]
// CHECK-DAG: [[CAST:%.+]] = builtin.unrealized_conversion_cast [[TRUNC]] : i8 to ui8
// CHECK: linalg.yield [[CAST]]
%1 = tosa.rescale %arg0 {input_zp = 17 : i32, output_zp = 22 : i32, multiplier = array<i32: 19689>, shift = array<i8: 15>, scale32 = false, double_round = false, per_channel = false} : (tensor<2xi8>) -> tensor<2xui8>
// CHECK: linalg.yield [[TRUNC]]
%1 = tosa.rescale %arg0 {input_zp = 17 : i32, output_zp = 22 : i32, multiplier = array<i32: 19689>, shift = array<i8: 15>, scale32 = false, double_round = false, per_channel = false, output_unsigned = true} : (tensor<2xi8>) -> tensor<2xi8>
// CHECK: return
return
@@ -1171,9 +1180,9 @@ func.func @rescale_i8_dyn_batch(%arg0 : tensor<?x2xi8>) -> () {
// CHECK: %[[C0:.+]] = arith.constant 0
// CHECK: %[[BATCH:.+]] = tensor.dim %[[ARG0]], %[[C0]]
// CHECK: %[[INIT:.+]] = tensor.empty(%[[BATCH]]) : tensor<?x2xui8>
// CHECK: [[GENERIC:%.+]] = linalg.generic {indexing_maps = [#[[$MAP0]], #[[$MAP0]]], iterator_types = ["parallel", "parallel"]} ins(%[[ARG0]] : tensor<?x2xi8>) outs(%[[INIT]] : tensor<?x2xui8>)
%1 = tosa.rescale %arg0 {input_zp = 17 : i32, output_zp = 22 : i32, multiplier = array<i32: 19689>, shift = array<i8: 15>, scale32 = false, double_round = false, per_channel = false} : (tensor<?x2xi8>) -> tensor<?x2xui8>
// CHECK: %[[INIT:.+]] = tensor.empty(%[[BATCH]]) : tensor<?x2xi8>
// CHECK: [[GENERIC:%.+]] = linalg.generic {indexing_maps = [#[[$MAP0]], #[[$MAP0]]], iterator_types = ["parallel", "parallel"]} ins(%[[ARG0]] : tensor<?x2xi8>) outs(%[[INIT]] : tensor<?x2xi8>)
%1 = tosa.rescale %arg0 {input_zp = 17 : i32, output_zp = 22 : i32, multiplier = array<i32: 19689>, shift = array<i8: 15>, scale32 = false, double_round = false, per_channel = false, output_unsigned = true} : (tensor<?x2xi8>) -> tensor<?x2xi8>
return
}
@@ -1199,18 +1208,17 @@ func.func @rescale_dyn(%arg0 : tensor<1x?x?x32xi32>) -> () {
// CHECK: #[[$MAP0:.*]] = affine_map<(d0) -> (d0)>
// CHECK-LABEL: @rescale_ui8
// CHECK-LABEL: @rescale_i8_unsigned_input
// CHECK-SAME: (%[[ARG0:[0-9a-zA-Z_]*]]:
func.func @rescale_ui8(%arg0 : tensor<2xui8>) -> () {
func.func @rescale_i8_unsigned_input(%arg0 : tensor<2xi8>) -> () {
// CHECK: [[C0:%.+]] = arith.constant 19689
// CHECK: [[C1:%.+]] = arith.constant 15
// CHECK: [[INIT:%.+]] = tensor.empty()
// CHECK: [[GENERIC:%.+]] = linalg.generic {indexing_maps = [#[[$MAP0]], #[[$MAP0]]], iterator_types = ["parallel"]} ins(%[[ARG0]] : tensor<2xui8>) outs([[INIT]] : tensor<2xi8>)
// CHECK: ^bb0([[IN:%.+]]: ui8, [[UNUSED:%.+]]: i8):
// CHECK: [[GENERIC:%.+]] = linalg.generic {indexing_maps = [#[[$MAP0]], #[[$MAP0]]], iterator_types = ["parallel"]} ins(%[[ARG0]] : tensor<2xi8>) outs([[INIT]] : tensor<2xi8>)
// CHECK: ^bb0([[IN:%.+]]: i8, [[UNUSED:%.+]]: i8):
// CHECK: [[C17:%.+]] = arith.constant 17
// CHECK: [[C22:%.+]] = arith.constant 22
// CHECK-DAG: [[CAST:%.+]] = builtin.unrealized_conversion_cast [[IN]] : ui8 to i8
// CHECK-DAG: [[IN32:%.+]] = arith.extui [[CAST]]
// CHECK-DAG: [[IN32:%.+]] = arith.extui [[IN]]
// CHECK-DAG: [[IN_ZEROED:%.+]] = arith.subi [[IN32]], [[C17]]
// CHECK-DAG: [[SCALED:%.+]] = tosa.apply_scale [[IN_ZEROED]], [[C0]], [[C1]] {double_round = false}
// CHECK-DAG: [[SCALED_ZEROED:%.+]] = arith.addi [[SCALED]], [[C22]]
@@ -1220,7 +1228,7 @@ func.func @rescale_ui8(%arg0 : tensor<2xui8>) -> () {
// CHECK-DAG: [[BOUNDED:%.+]] = arith.minsi [[CMAX]], [[LOWER]]
// CHECK-DAG: [[TRUNC:%.+]] = arith.trunci [[BOUNDED]]
// CHECK: linalg.yield [[TRUNC]]
%0 = tosa.rescale %arg0 {input_zp = 17 : i32, output_zp = 22 : i32, multiplier = array<i32: 19689>, shift = array<i8: 15>, scale32 = false, double_round = false, per_channel = false} : (tensor<2xui8>) -> tensor<2xi8>
%0 = tosa.rescale %arg0 {input_zp = 17 : i32, output_zp = 22 : i32, multiplier = array<i32: 19689>, shift = array<i8: 15>, scale32 = false, double_round = false, per_channel = false, input_unsigned = true} : (tensor<2xi8>) -> tensor<2xi8>
return
}