[mlir][tosa] Require signless types in validation and add corresponding conversion pass (#144367)
Firstly, this commit requires that all types are signless in the strict mode of the validation pass. This is because signless types on operations are required by the TOSA specification. The "strict" mode in the validation pass is the final check for TOSA conformance to the specification, which can often be used for conversion to other formats. In addition, a conversion pass `--tosa-convert-integer-type-to-signless` is provided to allow a user to convert all integer types to signless. The intention is that this pass can be run before the validation pass. Following use of this pass, input/output information should be carried independently by the user.
This commit is contained in:
@@ -127,4 +127,18 @@ def TosaReduceTransposes : Pass<"tosa-reduce-transposes", "func::FuncOp"> {
|
||||
}];
|
||||
}
|
||||
|
||||
def TosaConvertIntegerTypeToSignless : Pass<"tosa-convert-integer-type-to-signless", "func::FuncOp"> {
|
||||
let summary = "Convert integer types to signless";
|
||||
let description = [{
|
||||
This pass converts signed or unsigned integer types to signless. It
|
||||
currently does this greedily for all operators and can also change the
|
||||
signature of the function. Should the signature of the entrypoint
|
||||
function change, it will be the responsibility of the user to carry
|
||||
signedness information of the inputs and outputs independently.
|
||||
|
||||
This can be a useful transformation for conversion to other formats
|
||||
that require strict adherence to the TOSA specification.
|
||||
}];
|
||||
}
|
||||
|
||||
#endif // MLIR_DIALECT_TOSA_TRANSFORMS_PASSES
|
||||
|
||||
@@ -1,4 +1,5 @@
|
||||
add_mlir_dialect_library(MLIRTosaTransforms
|
||||
TosaConvertIntegerTypeToSignless.cpp
|
||||
TosaDecomposeTransposeConv.cpp
|
||||
TosaDecomposeDepthwise.cpp
|
||||
TosaFolders.cpp
|
||||
|
||||
@@ -0,0 +1,139 @@
|
||||
//===- TosaConvertIntegerTypeToSignless.cpp
|
||||
//-------------------------------------------===//
|
||||
//
|
||||
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
||||
// See https://llvm.org/LICENSE.txt for license information.
|
||||
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
||||
//
|
||||
//===-------------------------------------------------------------------------------===//
|
||||
|
||||
// -----------
|
||||
// Motivation:
|
||||
// -----------
|
||||
|
||||
// The TOSA specification uses a signless type system, which means that
|
||||
// information about signedness must be encapsulated by the operations
|
||||
// themselves. For example, tosa.rescale provides the attributes
|
||||
// `input_unsigned` and `output_unsigned` to indicate whether the input/output
|
||||
// should be interpreted as unsigned or signed.
|
||||
|
||||
// The TOSA dialect, on the other hand, allows the use of signed or unsigned
|
||||
// types in addition to signless. As such, when converting from TOSA dialect to
|
||||
// other formats, we need to ensure that we conform to the TOSA specification.
|
||||
|
||||
// ---------
|
||||
// Overview:
|
||||
// ---------
|
||||
|
||||
// This pass converts signed or unsigned integer types to signless. It currently
|
||||
// does this greedily for all operators and can also change the signature of the
|
||||
// function. Should the signature of the entrypoint function change, it will be
|
||||
// the responsibility of the user to carry signedness information of the inputs
|
||||
// and outputs independently.
|
||||
|
||||
#include "mlir/Dialect/Func/IR/FuncOps.h"
|
||||
#include "mlir/Dialect/Func/Transforms/FuncConversions.h"
|
||||
#include "mlir/Dialect/Tosa/IR/TosaOps.h"
|
||||
#include "mlir/Dialect/Tosa/Transforms/Passes.h"
|
||||
#include "mlir/Transforms/DialectConversion.h"
|
||||
|
||||
namespace mlir {
|
||||
namespace tosa {
|
||||
|
||||
#define GEN_PASS_DEF_TOSACONVERTINTEGERTYPETOSIGNLESS
|
||||
#include "mlir/Dialect/Tosa/Transforms/Passes.h.inc"
|
||||
|
||||
namespace {
|
||||
class ToSignlessTensorTypeConverter : public TypeConverter {
|
||||
static Type convertType(Type type) {
|
||||
const auto tensorType = dyn_cast<TensorType>(type);
|
||||
if (!tensorType)
|
||||
return type;
|
||||
|
||||
const auto intType = dyn_cast<IntegerType>(tensorType.getElementType());
|
||||
if (!intType ||
|
||||
intType.getSignedness() == IntegerType::SignednessSemantics::Signless)
|
||||
return type;
|
||||
|
||||
const auto signlessType = IntegerType::get(
|
||||
intType.getContext(), intType.getWidth(), IntegerType::Signless);
|
||||
return tensorType.cloneWith(std::nullopt, signlessType);
|
||||
}
|
||||
|
||||
public:
|
||||
explicit ToSignlessTensorTypeConverter() { addConversion(convertType); }
|
||||
};
|
||||
|
||||
class ConvertGenericOpWithIntegerTensorType : public ConversionPattern {
|
||||
public:
|
||||
ConvertGenericOpWithIntegerTensorType(TypeConverter &typeConverter,
|
||||
MLIRContext *context)
|
||||
: ConversionPattern(typeConverter, MatchAnyOpTypeTag{}, 0, context) {}
|
||||
|
||||
LogicalResult
|
||||
matchAndRewrite(Operation *op, ArrayRef<Value> operands,
|
||||
ConversionPatternRewriter &rewriter) const final {
|
||||
// Typically TOSA operators have a single result, but some have an
|
||||
// arbitrary number. 4 seems like a good balance as an optimization
|
||||
// hint for storing result types.
|
||||
constexpr unsigned int numResults = 4;
|
||||
|
||||
// Convert integer types to signless
|
||||
SmallVector<Type, numResults> resultTypes;
|
||||
if (failed(typeConverter->convertTypes(op->getResultTypes(), resultTypes)))
|
||||
return failure();
|
||||
|
||||
// Create new op with replaced operands and results
|
||||
auto *newOp = Operation::create(
|
||||
op->getLoc(), op->getName(), resultTypes, operands, op->getAttrs(),
|
||||
op->getPropertiesStorage(), op->getSuccessors(), op->getNumRegions());
|
||||
|
||||
// Handle regions in e.g. tosa.cond_if and tosa.while_loop
|
||||
for (auto regions : llvm::zip(op->getRegions(), newOp->getRegions())) {
|
||||
Region &before = std::get<0>(regions);
|
||||
Region &parent = std::get<1>(regions);
|
||||
rewriter.inlineRegionBefore(before, parent, parent.end());
|
||||
if (failed(rewriter.convertRegionTypes(&parent, *typeConverter)))
|
||||
return failure();
|
||||
}
|
||||
|
||||
// Replace with rewritten op
|
||||
rewriter.insert(newOp);
|
||||
rewriter.replaceOp(op, newOp->getResults());
|
||||
return success();
|
||||
}
|
||||
};
|
||||
|
||||
class TosaConvertIntegerTypeToSignless
|
||||
: public impl::TosaConvertIntegerTypeToSignlessBase<
|
||||
TosaConvertIntegerTypeToSignless> {
|
||||
public:
|
||||
void runOnOperation() override {
|
||||
MLIRContext *context = &getContext();
|
||||
ConversionTarget target(*context);
|
||||
ToSignlessTensorTypeConverter typeConverter;
|
||||
|
||||
target.addDynamicallyLegalOp<func::FuncOp>([&](func::FuncOp op) {
|
||||
return typeConverter.isSignatureLegal(op.getFunctionType()) &&
|
||||
typeConverter.isLegal(&op.getBody());
|
||||
});
|
||||
target.markUnknownOpDynamicallyLegal([&](Operation *op) {
|
||||
return typeConverter.isLegal(op->getOperandTypes()) &&
|
||||
typeConverter.isLegal(op->getResultTypes());
|
||||
});
|
||||
|
||||
RewritePatternSet patterns(context);
|
||||
populateFunctionOpInterfaceTypeConversionPattern<func::FuncOp>(
|
||||
patterns, typeConverter);
|
||||
patterns.add<ConvertGenericOpWithIntegerTensorType>(typeConverter, context);
|
||||
|
||||
if (failed(
|
||||
applyFullConversion(getOperation(), target, std::move(patterns))))
|
||||
signalPassFailure();
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace
|
||||
|
||||
} // namespace tosa
|
||||
} // namespace mlir
|
||||
@@ -1320,13 +1320,14 @@ void TosaValidation::runOnOperation() {
|
||||
|
||||
// validate operator element types:
|
||||
// - rescale operator is allowed to have ui8/ui16/ui32
|
||||
// operands/results
|
||||
// operands/results when strictOpSpecAlignment is false
|
||||
// - perform valid element type check at the beginning to
|
||||
// protect rest of code against quantized element types
|
||||
const bool opIsRescale = isa<tosa::RescaleOp>(op);
|
||||
const bool allowUnsigned =
|
||||
!strictOpSpecAlignment && isa<tosa::RescaleOp>(op);
|
||||
for (Value operand : op->getOperands()) {
|
||||
auto elementTy = getElementTypeOrSelf(operand);
|
||||
if (!isValidElementType(elementTy, opIsRescale)) {
|
||||
if (!isValidElementType(elementTy, allowUnsigned)) {
|
||||
op->emitOpError() << "is not profile-aligned: element type "
|
||||
<< elementTy << " is not legal";
|
||||
return signalPassFailure();
|
||||
@@ -1334,7 +1335,7 @@ void TosaValidation::runOnOperation() {
|
||||
}
|
||||
for (Type resultTy : op->getResultTypes()) {
|
||||
auto elementTy = getElementTypeOrSelf(resultTy);
|
||||
if (!isValidElementType(elementTy, opIsRescale)) {
|
||||
if (!isValidElementType(elementTy, allowUnsigned)) {
|
||||
op->emitOpError() << "is not profile-aligned: element type "
|
||||
<< elementTy << " is not legal";
|
||||
return signalPassFailure();
|
||||
|
||||
@@ -2000,6 +2000,7 @@ func.func @test_rescale_input_unsigned(%arg0: tensor<1x1xui8>) -> (tensor<1x1xi8
|
||||
%1 = "tosa.const"() <{values = dense<2> : tensor<1xi32>}> : () -> tensor<1xi32>
|
||||
%2 = "tosa.const"() <{values = dense<3> : tensor<1xi8>}> : () -> tensor<1xi8>
|
||||
%3 = "tosa.const"() <{values = dense<-128> : tensor<1xi8>}> : () -> tensor<1xi8>
|
||||
// expected-error@+1 {{'tosa.rescale' op is not profile-aligned: element type 'ui8' is not legal}}
|
||||
%r = tosa.rescale %arg0, %1, %0, %3, %2 {input_unsigned = true, output_unsigned = false, per_channel = false, rounding_mode = "SINGLE_ROUND", scale32 = true} : (tensor<1x1xui8>, tensor<1xi32>, tensor<1xi8>, tensor<1xi8>, tensor<1xi8>) -> tensor<1x1xi8>
|
||||
return %r : tensor<1x1xi8>
|
||||
}
|
||||
@@ -2012,6 +2013,7 @@ func.func @test_rescale_output_unsigned(%arg0: tensor<1x1xi8>) -> (tensor<1x1xui
|
||||
%1 = "tosa.const"() <{values = dense<2> : tensor<1xi32>}> : () -> tensor<1xi32>
|
||||
%2 = "tosa.const"() <{values = dense<3> : tensor<1xi8>}> : () -> tensor<1xi8>
|
||||
%3 = "tosa.const"() <{values = dense<-128> : tensor<1xi8>}> : () -> tensor<1xi8>
|
||||
// expected-error@+1 {{'tosa.rescale' op is not profile-aligned: element type 'ui8' is not legal}}
|
||||
%r = tosa.rescale %arg0, %1, %0, %3, %2 {input_unsigned = false, output_unsigned = true, per_channel = false, rounding_mode = "SINGLE_ROUND", scale32 = true} : (tensor<1x1xi8>, tensor<1xi32>, tensor<1xi8>, tensor<1xi8>, tensor<1xi8>) -> tensor<1x1xui8>
|
||||
return %r : tensor<1x1xui8>
|
||||
}
|
||||
|
||||
@@ -0,0 +1,73 @@
|
||||
// RUN: mlir-opt --split-input-file --tosa-convert-integer-type-to-signless %s | FileCheck %s
|
||||
|
||||
// -----
|
||||
|
||||
// CHECK-LABEL: test_rescale_output_unsigned
|
||||
// CHECK: %arg0: tensor<1x1xi8>
|
||||
func.func @test_rescale_output_unsigned(%arg0: tensor<1x1xi8>) -> (tensor<1x1xui8>) {
|
||||
%0 = "tosa.const"() <{values = dense<1> : tensor<1xi8>}> : () -> tensor<1xi8>
|
||||
%1 = "tosa.const"() <{values = dense<2> : tensor<1xi32>}> : () -> tensor<1xi32>
|
||||
%2 = "tosa.const"() <{values = dense<3> : tensor<1xi8>}> : () -> tensor<1xi8>
|
||||
%3 = "tosa.const"() <{values = dense<-128> : tensor<1xi8>}> : () -> tensor<1xi8>
|
||||
// CHECK: %[[RESCALE:.*]] = tosa.rescale %arg0, %1, %0, %3, %2 {input_unsigned = false, output_unsigned = true, per_channel = false, rounding_mode = "SINGLE_ROUND", scale32 = true} : (tensor<1x1xi8>, tensor<1xi32>, tensor<1xi8>, tensor<1xi8>, tensor<1xi8>) -> tensor<1x1xi8>
|
||||
%r = tosa.rescale %arg0, %1, %0, %3, %2 {input_unsigned = false, output_unsigned = true, per_channel = false, rounding_mode = "SINGLE_ROUND", scale32 = true} : (tensor<1x1xi8>, tensor<1xi32>, tensor<1xi8>, tensor<1xi8>, tensor<1xi8>) -> tensor<1x1xui8>
|
||||
// CHECK: return %[[RESCALE]] : tensor<1x1xi8>
|
||||
return %r : tensor<1x1xui8>
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
// CHECK-LABEL: test_rescale_input_unsigned
|
||||
// CHECK: %arg0: tensor<1x1xi16>
|
||||
func.func @test_rescale_input_unsigned(%arg0: tensor<1x1xui16>) -> (tensor<1x1xi8>) {
|
||||
%0 = "tosa.const"() <{values = dense<1> : tensor<1xi8>}> : () -> tensor<1xi8>
|
||||
%1 = "tosa.const"() <{values = dense<2> : tensor<1xi32>}> : () -> tensor<1xi32>
|
||||
%2 = "tosa.const"() <{values = dense<3> : tensor<1xi8>}> : () -> tensor<1xi8>
|
||||
%3 = "tosa.const"() <{values = dense<32768> : tensor<1xi16>}> : () -> tensor<1xi16>
|
||||
// CHECK: %[[RESCALE:.*]] = tosa.rescale %arg0, %1, %0, %3, %2 {input_unsigned = true, output_unsigned = false, per_channel = false, rounding_mode = "SINGLE_ROUND", scale32 = true} : (tensor<1x1xi16>, tensor<1xi32>, tensor<1xi8>, tensor<1xi16>, tensor<1xi8>) -> tensor<1x1xi8>
|
||||
%r = tosa.rescale %arg0, %1, %0, %3, %2 {input_unsigned = true, output_unsigned = false, per_channel = false, rounding_mode = "SINGLE_ROUND", scale32 = true} : (tensor<1x1xui16>, tensor<1xi32>, tensor<1xi8>, tensor<1xi16>, tensor<1xi8>) -> tensor<1x1xi8>
|
||||
// CHECK: return %[[RESCALE]] : tensor<1x1xi8>
|
||||
return %r : tensor<1x1xi8>
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
// CHECK-LABEL: test_unsigned_function_signature
|
||||
// CHECK: %arg0: tensor<1xi8>, %arg1: tensor<1xi8>
|
||||
func.func @test_unsigned_function_signature(%arg0: tensor<1xui8>, %arg1: tensor<1xui8>) -> (tensor<1xui8>, tensor<1xui8>) {
|
||||
// CHECK: return %arg0, %arg1 : tensor<1xi8>, tensor<1xi8>
|
||||
return %arg0, %arg1 : tensor<1xui8>, tensor<1xui8>
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
// CHECK-LABEL: test_no_change
|
||||
// CHECK: %arg0: tensor<13x21x3xi8>
|
||||
func.func @test_no_change(%arg0: tensor<13x21x3xi8>) -> tensor<13x21x3xi8> {
|
||||
%0 = tosa.reverse %arg0 {axis = 0 : i32} : (tensor<13x21x3xi8>) -> tensor<13x21x3xi8>
|
||||
// CHECK: return %0 : tensor<13x21x3xi8>
|
||||
return %0 : tensor<13x21x3xi8>
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
// CHECK-LABEL: test_regions
|
||||
// CHECK: %arg0: tensor<i8>, %arg1: tensor<i8>
|
||||
func.func @test_regions(%arg0: tensor<ui8>, %arg1: tensor<ui8>, %arg2: tensor<i1>) -> tensor<ui8> {
|
||||
// CHECK: tosa.cond_if %arg2 -> (tensor<i8>)
|
||||
%0 = "tosa.cond_if"(%arg2, %arg0, %arg1) ({
|
||||
^bb0(%arg3: tensor<ui8>, %arg4: tensor<ui8>):
|
||||
// CHECK: %1 = tosa.add %arg0, %arg1 : (tensor<i8>, tensor<i8>) -> tensor<i8>
|
||||
%1 = tosa.add %arg0, %arg1 : (tensor<ui8>, tensor<ui8>) -> tensor<ui8>
|
||||
// CHECK: tosa.yield %1 : tensor<i8>
|
||||
tosa.yield %1 : tensor<ui8>
|
||||
}, {
|
||||
^bb0(%arg3: tensor<ui8>, %arg4: tensor<ui8>):
|
||||
// CHECK: %1 = tosa.sub %arg0, %arg1 : (tensor<i8>, tensor<i8>) -> tensor<i8>
|
||||
%1 = tosa.sub %arg0, %arg1 : (tensor<ui8>, tensor<ui8>) -> tensor<ui8>
|
||||
// CHECK: tosa.yield %1 : tensor<i8>
|
||||
tosa.yield %1 : tensor<ui8>
|
||||
}) : (tensor<i1>, tensor<ui8>, tensor<ui8>) -> tensor<ui8>
|
||||
// CHECK: return %0 : tensor<i8>
|
||||
return %0 : tensor<ui8>
|
||||
}
|
||||
31
mlir/test/Dialect/Tosa/tosa-validation-valid.mlir
Normal file
31
mlir/test/Dialect/Tosa/tosa-validation-valid.mlir
Normal file
@@ -0,0 +1,31 @@
|
||||
//--------------------------------------------------------------------------------------------------
|
||||
// Test valid IR in terms of the shape and type of tensor, and the argument type of
|
||||
// operation. Excludes the profile compilance checking since it is performed earlier in the
|
||||
// validation flow.
|
||||
//--------------------------------------------------------------------------------------------------
|
||||
|
||||
// RUN: mlir-opt %s -split-input-file -verify-diagnostics --tosa-validate="profile=pro_int,pro_fp extension=int16,int4,bf16,fp8e4m3,fp8e5m2,fft,variable,controlflow,doubleround,inexactround" | FileCheck %s
|
||||
|
||||
// -----
|
||||
|
||||
// CHECK-LABEL: test_rescale_input_unsigned
|
||||
func.func @test_rescale_input_unsigned(%arg0: tensor<1x1xui8>) -> (tensor<1x1xi8>) {
|
||||
%0 = "tosa.const"() <{values = dense<1> : tensor<1xi8>}> : () -> tensor<1xi8>
|
||||
%1 = "tosa.const"() <{values = dense<2> : tensor<1xi32>}> : () -> tensor<1xi32>
|
||||
%2 = "tosa.const"() <{values = dense<3> : tensor<1xi8>}> : () -> tensor<1xi8>
|
||||
%3 = "tosa.const"() <{values = dense<-128> : tensor<1xi8>}> : () -> tensor<1xi8>
|
||||
%r = tosa.rescale %arg0, %1, %0, %3, %2 {input_unsigned = true, output_unsigned = false, per_channel = false, rounding_mode = "SINGLE_ROUND", scale32 = true} : (tensor<1x1xui8>, tensor<1xi32>, tensor<1xi8>, tensor<1xi8>, tensor<1xi8>) -> tensor<1x1xi8>
|
||||
return %r : tensor<1x1xi8>
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
// CHECK-LABEL: test_rescale_output_unsigned
|
||||
func.func @test_rescale_output_unsigned(%arg0: tensor<1x1xi8>) -> (tensor<1x1xui8>) {
|
||||
%0 = "tosa.const"() <{values = dense<1> : tensor<1xi8>}> : () -> tensor<1xi8>
|
||||
%1 = "tosa.const"() <{values = dense<2> : tensor<1xi32>}> : () -> tensor<1xi32>
|
||||
%2 = "tosa.const"() <{values = dense<3> : tensor<1xi8>}> : () -> tensor<1xi8>
|
||||
%3 = "tosa.const"() <{values = dense<-128> : tensor<1xi8>}> : () -> tensor<1xi8>
|
||||
%r = tosa.rescale %arg0, %1, %0, %3, %2 {input_unsigned = false, output_unsigned = true, per_channel = false, rounding_mode = "SINGLE_ROUND", scale32 = true} : (tensor<1x1xi8>, tensor<1xi32>, tensor<1xi8>, tensor<1xi8>, tensor<1xi8>) -> tensor<1x1xui8>
|
||||
return %r : tensor<1x1xui8>
|
||||
}
|
||||
Reference in New Issue
Block a user