[mlir][tosa] Change Transpose perms operand to attribute (#128115)

This patch changes the perms operand for Tosa Transpose operator to an
i32 array attribute

Signed-off-by: Tai Ly <tai.ly@arm.com>
This commit is contained in:
Tai Ly
2025-02-25 12:00:26 -06:00
committed by GitHub
parent d2d469eb79
commit 48db4e8377
20 changed files with 245 additions and 504 deletions

View File

@@ -2023,7 +2023,7 @@ def Tosa_TransposeOp : Tosa_InferShapedTypeOp<"transpose",
let arguments = (ins
Tosa_Tensor:$input1,
Tosa_Int32Tensor:$perms
DenseI32ArrayAttr:$perms
);
let results = (
@@ -2035,10 +2035,6 @@ def Tosa_TransposeOp : Tosa_InferShapedTypeOp<"transpose",
Extension<[Tosa_EXT_FP8E4M3, Tosa_EXT_FP8E5M2, Tosa_EXT_BF16]>,
];
let extraClassDeclaration = [{
LogicalResult getConstantPerms(llvm::SmallVector<int32_t> &perms);
}];
let hasCanonicalizer = 1;
let hasFolder = 1;
let hasVerifier = 1;

View File

@@ -329,13 +329,11 @@ public:
SmallVector<int64_t> newWeightShape;
for (auto dim : weightPerm)
newWeightShape.push_back(weightShape[dim]);
auto weightPermAttr = rewriter.getI32TensorAttr(weightPerm);
Value weightPermValue =
rewriter.create<arith::ConstantOp>(loc, weightPermAttr);
auto weightPermAttr = rewriter.getDenseI32ArrayAttr(weightPerm);
Type newWeightTy =
RankedTensorType::get(newWeightShape, weightTy.getElementType());
weight = rewriter.create<tosa::TransposeOp>(loc, newWeightTy, weight,
weightPermValue);
weightPermAttr);
}
}
@@ -353,13 +351,11 @@ public:
SmallVector<int64_t> newWeightShape;
for (auto dim : weightPerm)
newWeightShape.push_back(weightShape[dim]);
auto weightPermAttr = rewriter.getI32TensorAttr(weightPerm);
Value weightPermValue =
rewriter.create<arith::ConstantOp>(loc, weightPermAttr);
auto weightPermAttr = rewriter.getDenseI32ArrayAttr(weightPerm);
Type newWeightTy =
RankedTensorType::get(newWeightShape, weightTy.getElementType());
weight = rewriter.create<tosa::TransposeOp>(loc, newWeightTy, weight,
weightPermValue);
weightPermAttr);
}
// Extract the attributes for convolution.
@@ -1003,9 +999,7 @@ public:
LogicalResult matchAndRewrite(tosa::TransposeOp op,
PatternRewriter &rewriter) const final {
SmallVector<int32_t> constantPerms;
if (failed(op.getConstantPerms(constantPerms)))
return failure();
const llvm::ArrayRef<int32_t> constantPerms = op.getPerms();
Location loc = op.getLoc();
// The verifier should have made sure we have a valid TOSA permutation

View File

@@ -88,13 +88,10 @@ struct ConsolidateTransposeOptimization
return rewriter.notifyMatchFailure(transposeOp,
"input must be transpose operation");
SmallVector<int32_t> transposePerms, innerTransposePerms;
if (transposeOp.getConstantPerms(transposePerms).failed())
return rewriter.notifyMatchFailure(transposeOp,
"transpose perms must be constant");
if (innerTranspose.getConstantPerms(innerTransposePerms).failed())
return rewriter.notifyMatchFailure(
transposeOp, "inner transpose perms must be constant");
const llvm::ArrayRef<int32_t> transposePerms = transposeOp.getPerms();
const llvm::ArrayRef<int32_t> innerTransposePerms =
innerTranspose.getPerms();
if (transposePerms.size() != innerTransposePerms.size())
return rewriter.notifyMatchFailure(
transposeOp,
@@ -108,15 +105,9 @@ struct ConsolidateTransposeOptimization
for (int i = 0, s = transposePerms.size(); i < s; ++i)
perms[i] = innerTransposePerms[transposePerms[i]];
auto permsTy =
RankedTensorType::get(transposePerms.size(), rewriter.getI32Type());
auto permsAttr = DenseIntElementsAttr::get(permsTy, perms);
Value permsValue = rewriter.create<tosa::ConstOp>(transposeOp.getLoc(),
permsTy, permsAttr);
rewriter.replaceOpWithNewOp<tosa::TransposeOp>(
transposeOp, transposeOp.getResult().getType(),
innerTranspose.getInput1(), permsValue);
innerTranspose.getInput1(), rewriter.getDenseI32ArrayAttr(perms));
return success();
}
@@ -128,10 +119,6 @@ struct TransposeIsReshape : public OpRewritePattern<tosa::TransposeOp> {
LogicalResult matchAndRewrite(tosa::TransposeOp op,
PatternRewriter &rewriter) const override {
DenseIntElementsAttr permAttr;
if (!matchPattern(op.getPerms(), m_Constant(&permAttr)))
return rewriter.notifyMatchFailure(op, "Non-constant permutation");
if (op.getInput1().getDefiningOp<tosa::TransposeOp>())
return rewriter.notifyMatchFailure(
op, "Src is from transpose, can compose transposes");
@@ -156,9 +143,7 @@ struct TransposeIsReshape : public OpRewritePattern<tosa::TransposeOp> {
if (numDynDims > 1)
return rewriter.notifyMatchFailure(op, "Has more than one dynamic dim.");
SmallVector<int64_t> permValues = llvm::to_vector<6>(
llvm::map_range(permAttr.getValues<APInt>(),
[](const APInt &val) { return val.getSExtValue(); }));
const llvm::ArrayRef<int32_t> permValues = op.getPerms();
SmallVector<int64_t> nonZeroPerms;
nonZeroPerms.reserve(permValues.size());
@@ -1176,9 +1161,7 @@ OpFoldResult TransposeOp::fold(FoldAdaptor adaptor) {
}
// Transpose is not the identity transpose.
SmallVector<int32_t> perms;
if (getConstantPerms(perms).failed())
return {};
const llvm::ArrayRef<int32_t> perms = getPerms();
if (!llvm::equal(llvm::seq<int32_t>(0, perms.size()), perms))
return {};

View File

@@ -1372,54 +1372,37 @@ llvm::LogicalResult tosa::ReshapeOp::verify() {
return mlir::success();
}
LogicalResult tosa::TransposeOp::getConstantPerms(SmallVector<int32_t> &perms) {
// Perms must be constants.
DenseIntElementsAttr permsAttr;
if (!matchPattern(getPerms(), m_Constant(&permsAttr)))
return failure();
perms.clear();
for (auto v : permsAttr.getValues<APInt>())
perms.push_back(v.getSExtValue());
return success();
}
LogicalResult tosa::TransposeOp::inferReturnTypeComponents(
MLIRContext *context, ::std::optional<Location> location,
TransposeOp::Adaptor adaptor,
SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
ShapeAdaptor inputShape(adaptor.getInput1().getType());
ShapeAdaptor permsShape(adaptor.getPerms().getType());
// We cannot infer anything from a rank-0 "permutation" tensor.
if (permsShape.hasRank() && permsShape.getRank() == 0)
return failure();
// If input rank and permutation length is unknown, the output rank is
// unknown.
if (!inputShape.hasRank() || !permsShape.hasRank() ||
permsShape.isDynamicDim(0)) {
if (!inputShape.hasRank()) {
inferredReturnShapes.push_back(ShapedTypeComponents());
return success();
}
const auto inputRank = inputShape.getRank();
// This would imply the number of permutations does not match the rank of
// the input which is illegal.
if (permsShape.getDimSize(0) != inputShape.getRank()) {
if (adaptor.getPerms().size() != static_cast<size_t>(inputRank)) {
return failure();
}
SmallVector<int64_t> outputShape;
// Rank-0 means no permutations matter.
if (inputShape.getRank() == 0) {
if (inputRank == 0) {
inferredReturnShapes.push_back(ShapedTypeComponents(outputShape));
return success();
}
// Check whether the input dimensions are all the same.
bool allTheSame = true;
for (int i = 1, s = inputShape.getRank(); i < s; i++) {
for (int i = 1, s = inputRank; i < s; i++) {
if (inputShape.getDimSize(0) != inputShape.getDimSize(i)) {
allTheSame = false;
break;
@@ -1429,34 +1412,21 @@ LogicalResult tosa::TransposeOp::inferReturnTypeComponents(
// If all of the input dimensions are the same we don't care about the
// permutation.
if (allTheSame) {
outputShape.resize(inputShape.getRank(), inputShape.getDimSize(0));
outputShape.resize(inputRank, inputShape.getDimSize(0));
inferredReturnShapes.push_back(ShapedTypeComponents(outputShape));
return success();
}
outputShape.resize(inputShape.getRank(), ShapedType::kDynamic);
// If the permuations are a constant we can directly determine the output
// shape.
DenseIntElementsAttr attr;
if (matchPattern(adaptor.getPerms(), m_Constant(&attr)) &&
attr.getType().getRank() == 1) {
ShapeAdaptor permShape = attr;
// Constant permutation must be the same length as the input rank.
if (inputShape.getRank() != permShape.getRank())
return emitOptionalError(location,
"constant permutation must be the same length"
" as the input rank");
outputShape.resize(inputRank, ShapedType::kDynamic);
// Constant permutation values must be within the input rank.
for (int i = 0, e = inputShape.getRank(); i < e; i++) {
if (inputShape.getRank() <= permShape.getDimSize(i))
return failure();
}
// Constant permutation values must be within the input rank.
if (llvm::any_of(adaptor.getPerms(),
[inputRank](const auto i) { return i >= inputRank; }))
return failure();
outputShape.reserve(inputShape.getRank());
for (int i = 0, s = inputShape.getRank(); i < s; i++) {
outputShape[i] = inputShape.getDimSize(permShape.getDimSize(i));
}
outputShape.reserve(inputRank);
for (int i = 0, s = inputRank; i < s; i++) {
outputShape[i] = inputShape.getDimSize(adaptor.getPerms()[i]);
}
inferredReturnShapes.push_back(ShapedTypeComponents(outputShape));
@@ -1465,75 +1435,60 @@ LogicalResult tosa::TransposeOp::inferReturnTypeComponents(
LogicalResult tosa::TransposeOp::verify() {
TensorType inputType = getInput1().getType();
TensorType permType = getPerms().getType();
TensorType outputType = getOutput().getType();
const llvm::ArrayRef<int32_t> constantPerms = getPerms();
if (permType.hasRank() && permType.getRank() != 1)
return emitOpError()
<< "expected permutation tensor to be rank 1 but got rank "
<< permType.getRank();
if (inputType.hasRank() && permType.hasRank())
if (!permType.isDynamicDim(0) &&
permType.getDimSize(0) != inputType.getRank())
return emitOpError() << "expected permutation tensor dim 0 to have size "
<< inputType.getRank()
<< " (input rank) but got size "
<< permType.getDimSize(0);
if (inputType.hasRank() &&
constantPerms.size() != static_cast<size_t>(inputType.getRank()))
return emitOpError() << "expected perms attribute to have size "
<< inputType.getRank() << " (input rank) but got size "
<< constantPerms.size();
if (inputType.hasRank() && outputType.hasRank() &&
inputType.getRank() != outputType.getRank())
return emitOpError()
<< "expected input tensor rank to equal result tensor rank";
if (outputType.hasRank() && permType.hasRank())
if (!permType.isDynamicDim(0) &&
permType.getDimSize(0) != outputType.getRank())
return emitOpError() << "expected permutation tensor dim 0 to have size "
<< outputType.getRank()
<< " (output rank) but got size "
<< permType.getDimSize(0);
if (outputType.hasRank() &&
constantPerms.size() != static_cast<size_t>(outputType.getRank()))
return emitOpError() << "expected perms attribute to have size "
<< outputType.getRank()
<< " (output rank) but got size "
<< constantPerms.size();
SmallVector<int32_t> constantPerms;
if (succeeded(getConstantPerms(constantPerms))) {
// Assert that the permutation tensor has a rank, which means that the
// rank has been verified above.
assert(permType.hasRank() &&
"Unexpectedly found permutation tensor without rank");
if (!llvm::all_of(constantPerms,
[&constantPerms](int32_t s) {
return s >= 0 &&
static_cast<size_t>(s) < constantPerms.size();
}) ||
!isPermutationVector(llvm::to_vector(llvm::map_range(
constantPerms, [](int32_t v) -> int64_t { return v; }))))
return emitOpError() << "expected valid permutation tensor";
if (!llvm::all_of(constantPerms,
[&constantPerms](int32_t s) {
return s >= 0 &&
static_cast<size_t>(s) < constantPerms.size();
}) ||
!isPermutationVector(llvm::to_vector(llvm::map_range(
constantPerms, [](int32_t v) -> int64_t { return v; }))))
return emitOpError() << "expected valid permutation indices";
// Verify that the types of the input and output tensors are properly
// permuted.
if (inputType.hasRank() && outputType.hasRank()) {
assert(constantPerms.size() == static_cast<size_t>(inputType.getRank()) &&
inputType.getRank() == outputType.getRank());
// Verify that the types of the input and output tensors are properly
// permuted.
if (inputType.hasRank() && outputType.hasRank()) {
assert(constantPerms.size() == static_cast<size_t>(inputType.getRank()) &&
inputType.getRank() == outputType.getRank());
for (auto i = 0; i < outputType.getRank(); i++) {
if (inputType.isDynamicDim(constantPerms[i]) ||
outputType.isDynamicDim(i))
continue;
for (auto i = 0; i < outputType.getRank(); i++) {
if (inputType.isDynamicDim(constantPerms[i]) ||
outputType.isDynamicDim(i))
continue;
if (inputType.getDimSize(constantPerms[i]) != outputType.getDimSize(i))
return emitOpError()
<< "expected output tensor dim " << i << " to match "
<< "input dim " << constantPerms[i] << " with value of "
<< inputType.getDimSize(constantPerms[i]);
}
if (inputType.getDimSize(constantPerms[i]) != outputType.getDimSize(i))
return emitOpError()
<< "expected output tensor dim " << i << " to match "
<< "input dim " << constantPerms[i] << " with value of "
<< inputType.getDimSize(constantPerms[i]);
}
}
return success();
}
LogicalResult TransposeOp::reifyResultShapes(
OpBuilder &builder, ReifiedRankedShapedTypeDims &reifiedReturnShapes) {
SmallVector<int32_t> transposePerms;
if (getConstantPerms(transposePerms).failed())
return failure();
const llvm::ArrayRef<int32_t> transposePerms = getPerms();
Value input = getInput1();
auto inputType = cast<TensorType>(input.getType());

View File

@@ -166,13 +166,9 @@ public:
getTosaConstShape(rewriter, loc, weightReshapeDims0));
// Transpose the factored-out stride to the output channels.
Value transposeWeightVal = rewriter.create<tosa::ConstOp>(
loc, RankedTensorType::get({6}, rewriter.getI32Type()),
rewriter.getI32TensorAttr({2, 4, 0, 1, 3, 5}));
weight = CreateOpAndInferShape<tosa::TransposeOp>(
rewriter, loc, UnrankedTensorType::get(weightETy), weight,
transposeWeightVal);
rewriter.getDenseI32ArrayAttr({2, 4, 0, 1, 3, 5}));
// Collapse the strides and output channels into a single dimension.
llvm::SmallVector<int64_t, 4> weightReshapeDims1 = {
@@ -269,13 +265,9 @@ public:
convReshapeDims0Value);
// Transpose the factored-out stride to the output channels.
Value transposeConvVal = rewriter.create<tosa::ConstOp>(
loc, RankedTensorType::get({6}, rewriter.getI32Type()),
rewriter.getI32TensorAttr({0, 1, 3, 2, 4, 5}));
conv2d = CreateOpAndInferShape<tosa::TransposeOp>(
rewriter, loc, UnrankedTensorType::get(convETy), conv2d,
transposeConvVal);
rewriter.getDenseI32ArrayAttr({0, 1, 3, 2, 4, 5}));
// Fuse striding behavior back into width / height.
llvm::SmallVector<int64_t, 6> convReshapeDims1 = {

View File

@@ -224,13 +224,8 @@ struct TosaFoldConstantTranspose : public OpRewritePattern<tosa::TransposeOp> {
if (!llvm::hasSingleElement(op.getInput1().getDefiningOp()->getUsers()))
return failure();
DenseIntElementsAttr permAttr;
if (!matchPattern(op.getPerms(), m_Constant(&permAttr)))
return failure();
auto permValues = llvm::map_to_vector(
// TOSA allows both 32- and 64-bit integer tensors here.
permAttr.getValues<APInt>(),
[](const APInt &val) { return val.getSExtValue(); });
op.getPerms(), [](const int32_t v) { return static_cast<int64_t>(v); });
auto inputType = cast<ShapedType>(op.getInput1().getType());

View File

@@ -367,9 +367,7 @@ std::optional<Value> TosaReduceTransposes::buildMappedToValue(
std::optional<Value> TosaReduceTransposes::buildMappedToValue(
TransposeOp transposeOp, const DenseMap<Value, Value> &valuesMap,
IRRewriter &rewriter, ArrayRef<int32_t> hoistedPerms) {
SmallVector<int32_t> perms;
if (failed(transposeOp.getConstantPerms(perms)) ||
!areInvolutionTransposes(hoistedPerms, perms))
if (!areInvolutionTransposes(hoistedPerms, transposeOp.getPerms()))
return std::nullopt;
return transposeOp.getInput1();
}
@@ -506,14 +504,11 @@ bool TosaReduceTransposes::dependenciesAreValid(
// replaced.
Operation *user = use.getOwner();
if (auto otherTranspose = llvm::dyn_cast<TransposeOp>(user)) {
SmallVector<int32_t> otherPerms;
// Can later think about cases where transpose -> transpose
// or reshape -> transpose, where the transposes are not necessarily
// the same perms as the hoisted, if implementing a more general
// transform. These could be permitted.
if (failed(otherTranspose.getConstantPerms(otherPerms)) ||
!llvm::equal(perms, otherPerms))
if (!llvm::equal(perms, otherTranspose.getPerms()))
return false;
} else if (userNotContainedInValidTransposeDependencies(
user, validTransposes, transposeInfo)) {
@@ -607,9 +602,8 @@ void TosaReduceTransposes::runOnOperation() {
!llvm::isa<RankedTensorType>(output.getType()))
return;
// No transformation when transpose permutation non-constant.
if (failed(transposeOp.getConstantPerms(perms)))
return;
llvm::for_each(transposeOp.getPerms(),
[&perms](const auto i) { perms.emplace_back(i); });
// We let --canonicalize deal with identity transpose.
if (llvm::equal(llvm::seq<int32_t>(0, perms.size()), perms))

View File

@@ -56,15 +56,6 @@ static LogicalResult checkConstantOperandPad(Operation *op) {
return success();
}
static LogicalResult checkConstantOperandTranspose(Operation *op) {
if (auto transposeOp = dyn_cast<tosa::TransposeOp>(op)) {
DenseElementsAttr perms;
if (!matchPattern(transposeOp.getPerms(), m_Constant(&perms)))
return op->emitOpError("perms of transpose is not constant");
}
return success();
}
struct TosaLevel {
int32_t MAX_RANK = 0;
int32_t MAX_KERNEL = 0;
@@ -118,7 +109,6 @@ public:
private:
void populateConstantOperandChecks() {
constCheckers.emplace_back(checkConstantOperandPad);
constCheckers.emplace_back(checkConstantOperandTranspose);
}
bool levelCheckKernel(Operation *op, int32_t v,

View File

@@ -463,7 +463,6 @@ func.func @conv2d_scalar_bias_f32(%input: tensor<1x49x42x27xf32>, %weights: tens
// CHECK-LABEL: @conv2d_i8
func.func @conv2d_i8(%input: tensor<1x49x42x27xi8>, %weights: tensor<28x1x1x27xi8>, %bias: tensor<28xi8>) -> () {
// HWCF: %[[TRANSPOSE_DIMS:.+]] = arith.constant dense<[1, 2, 3, 0]> : tensor<4xi32>
// HWCF: %[[TRANSPOSE:.+]] = linalg.transpose ins(%arg1 : tensor<28x1x1x27xi8>) outs(%[[TRANSPOSEDINIT:.+]] : tensor<1x1x27x28xi8>) permutation = [1, 2, 3, 0]
// CHECK: %[[INIT:.+]] = tensor.empty() : tensor<1x45x40x28xi32>
// CHECK: %[[BROADCAST:.+]] = linalg.generic {indexing_maps = [#[[$MAP1]], #[[$MAP2]]], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%arg2 : tensor<28xi8>) outs(%[[INIT]] : tensor<1x45x40x28xi32>) {
@@ -485,7 +484,6 @@ func.func @conv2d_i8(%input: tensor<1x49x42x27xi8>, %weights: tensor<28x1x1x27xi
// CHECK-LABEL: @conv2d_f32
func.func @conv2d_f32(%input: tensor<1x49x42x27xf32>, %weights: tensor<28x3x3x27xf32>, %bias: tensor<28xf32>) -> () {
// HWCF: %[[TRANSPOSE_DIMS:.+]] = arith.constant dense<[1, 2, 3, 0]> : tensor<4xi32>
// HWCF: %[[TRANSPOSE:.+]] = linalg.transpose ins(%arg1 : tensor<28x3x3x27xf32>) outs(%[[TRANSPOSEDINIT:.+]] : tensor<3x3x27x28xf32>) permutation = [1, 2, 3, 0]
// CHECK: %[[INIT:.+]] = tensor.empty() : tensor<1x45x40x28xf32>
@@ -786,7 +784,6 @@ func.func @depthwise_conv2d_dyn_w_h(%arg0: tensor<2x?x?x3xf32>, %arg1: tensor<3x
// CHECK-LABEL: @conv3d_f32
func.func @conv3d_f32(%input: tensor<1x49x48x47x27xf32>, %weights: tensor<28x3x4x5x27xf32>, %bias: tensor<28xf32>) -> () {
// CHECK-DAG: %[[PERMS:.+]] = arith.constant dense<[1, 2, 3, 4, 0]>
// CHECK-DAG: %[[TRANSPOSE:.+]] = linalg.transpose ins(%arg1 : tensor<28x3x4x5x27xf32>) outs(%[[TRANSPOSEDINIT:.+]] : tensor<3x4x5x27x28xf32>) permutation = [1, 2, 3, 4, 0]
// CHECK-DAG: %[[INIT:.+]] = tensor.empty() : tensor<1x47x45x43x28xf32>
// CHECK: %[[BROADCAST:.+]] = linalg.generic
@@ -824,7 +821,6 @@ func.func @conv3d_scalar_bias_f32(%input: tensor<1x49x48x47x27xf32>, %weights: t
// CHECK-LABEL: @conv3d_i8
func.func @conv3d_i8(%input: tensor<1x49x48x47x27xi8>, %weights: tensor<28x3x4x5x27xi8>, %bias: tensor<28xi32>) -> () {
// CHECK-DAG: %[[PERMS:.+]] = arith.constant dense<[1, 2, 3, 4, 0]>
// CHECK-DAG: %[[TRANSPOSE:.+]] = linalg.transpose ins(%arg1 : tensor<28x3x4x5x27xi8>) outs(%[[TRANSPOSEDINIT:.+]] : tensor<3x4x5x27x28xi8>) permutation = [1, 2, 3, 4, 0]
// CHECK-DAG: %[[INIT:.+]] = tensor.empty() : tensor<1x47x45x43x28xi32>
// CHECK: %[[BROADCAST:.+]] = linalg.generic
@@ -852,10 +848,9 @@ func.func @conv3d_i8(%input: tensor<1x49x48x47x27xi8>, %weights: tensor<28x3x4x5
// CHECK-LABEL: @test_transpose
// CHECK-SAME: (%[[ARG0:.+]]: tensor<1x2x3xi32>)
func.func @test_transpose(%arg0: tensor<1x2x3xi32>) -> () {
%0 = arith.constant dense<[1, 2, 0]> : tensor<3xi32>
// CHECK: %[[INIT:.+]] = tensor.empty() : tensor<2x3x1xi32>
// CHECK: %[[TRANSPOSE:.+]] = linalg.transpose ins(%[[ARG0]] : tensor<1x2x3xi32>) outs(%[[INIT]] : tensor<2x3x1xi32>) permutation = [1, 2, 0]
%1 = tosa.transpose %arg0, %0 : (tensor<1x2x3xi32>, tensor<3xi32>) -> tensor<2x3x1xi32>
%1 = tosa.transpose %arg0 {perms = array<i32: 1, 2, 0>}: (tensor<1x2x3xi32>) -> tensor<2x3x1xi32>
return
}
@@ -864,12 +859,11 @@ func.func @test_transpose(%arg0: tensor<1x2x3xi32>) -> () {
// CHECK-LABEL: @test_transpose_dyn
// CHECK-SAME: (%[[ARG0:.+]]: tensor<1x?x3x4xi32>)
func.func @test_transpose_dyn(%arg0: tensor<1x?x3x4xi32>) -> () {
%0 = arith.constant dense<[1, 3, 0, 2]> : tensor<4xi32>
// CHECK: %[[C1:.+]] = arith.constant 1
// CHECK: %[[DIM:.+]] = tensor.dim %[[ARG0]], %[[C1]]
// CHECK: %[[INIT:.+]] = tensor.empty(%[[DIM]]) : tensor<?x4x1x3xi32>
// CHECK: %[[TRANSPOSE:.+]] = linalg.transpose ins(%[[ARG0]] : tensor<1x?x3x4xi32>) outs(%[[INIT]] : tensor<?x4x1x3xi32>) permutation = [1, 3, 0, 2]
%1 = tosa.transpose %arg0, %0 : (tensor<1x?x3x4xi32>, tensor<4xi32>) -> tensor<?x4x1x3xi32>
%1 = tosa.transpose %arg0 {perms = array<i32: 1, 3, 0, 2>}: (tensor<1x?x3x4xi32>) -> tensor<?x4x1x3xi32>
return
}
@@ -878,14 +872,13 @@ func.func @test_transpose_dyn(%arg0: tensor<1x?x3x4xi32>) -> () {
// CHECK-LABEL: @test_transpose_dyn_multiple_2d
// CHECK-SAME: (%[[ARG0:.+]]: tensor<?x?xf32>)
func.func @test_transpose_dyn_multiple_2d(%arg0: tensor<?x?xf32>) -> () {
%0 = arith.constant dense<[1, 0]> : tensor<2xi32>
// CHECK-DAG: %[[C0:.+]] = arith.constant 0
// CHECK-DAG: %[[DIM0:.+]] = tensor.dim %[[ARG0]], %[[C0]]
// CHECK-DAG: %[[C1:.+]] = arith.constant 1
// CHECK-DAG: %[[DIM1:.+]] = tensor.dim %[[ARG0]], %[[C1]]
// CHECK: %[[INIT:.+]] = tensor.empty(%[[DIM1]], %[[DIM0]])
// CHECK: %[[TRANSPOSE:.+]] = linalg.transpose ins(%[[ARG0]] : tensor<?x?xf32>) outs(%[[INIT]] : tensor<?x?xf32>) permutation = [1, 0]
%1 = tosa.transpose %arg0, %0 : (tensor<?x?xf32>, tensor<2xi32>) -> tensor<?x?xf32>
%1 = tosa.transpose %arg0 {perms = array<i32: 1, 0>}: (tensor<?x?xf32>) -> tensor<?x?xf32>
return
}
@@ -894,7 +887,6 @@ func.func @test_transpose_dyn_multiple_2d(%arg0: tensor<?x?xf32>) -> () {
// CHECK-LABEL: @test_transpose_dyn_multiple_3d
// CHECK-SAME: (%[[ARG0:.+]]: tensor<?x?x?xf32>)
func.func @test_transpose_dyn_multiple_3d(%arg0: tensor<?x?x?xf32>) {
%0 = arith.constant dense<[2, 0, 1]> : tensor<3xi32>
// CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index
// CHECK-DAG: %[[DIM0:.*]] = tensor.dim %[[ARG0]], %[[C0]] : tensor<?x?x?xf32>
// CHECK-DAG: %[[C1:.*]] = arith.constant 1 : index
@@ -903,7 +895,7 @@ func.func @test_transpose_dyn_multiple_3d(%arg0: tensor<?x?x?xf32>) {
// CHECK-DAG: %[[DIM2:.*]] = tensor.dim %[[ARG0]], %[[C2]] : tensor<?x?x?xf32>
// CHECK: %[[INIT:.*]] = tensor.empty(%[[DIM2]], %[[DIM0]], %[[DIM1]]) : tensor<?x?x?xf32>
// CHECK: %[[TRANSPOSE:.*]] = linalg.transpose ins(%[[ARG0]] : tensor<?x?x?xf32>) outs(%[[INIT]] : tensor<?x?x?xf32>) permutation = [2, 0, 1]
%1 = "tosa.transpose"(%arg0, %0) : (tensor<?x?x?xf32>, tensor<3xi32>) -> tensor<?x?x?xf32>
%1 = "tosa.transpose"(%arg0) {perms = array<i32: 2, 0, 1>} : (tensor<?x?x?xf32>) -> tensor<?x?x?xf32>
return
}

View File

@@ -34,8 +34,7 @@ func.func @dim_out_of_bounds_2(%idx1 : index, %idx2 : index) -> index {
// CHECK-NEXT: tensor.dim %[[arg]], %[[c2]]
// CHECK-NEXT: return
func.func @dynamic_dim_of_transpose_op(%arg0: tensor<1x2x?x8xi8>) -> index {
%0 = "tosa.const"() <{value = dense<[0, 3, 1, 2]> : tensor<4xi32>}> : () -> tensor<4xi32>
%1 = tosa.transpose %arg0, %0 : (tensor<1x2x?x8xi8>, tensor<4xi32>) -> tensor<1x8x2x?xi8>
%1 = tosa.transpose %arg0 { perms = array<i32: 0, 3, 1, 2> }: (tensor<1x2x?x8xi8>) -> tensor<1x8x2x?xi8>
%c3 = arith.constant 3 : index
%dim = tensor.dim %1, %c3 : tensor<1x8x2x?xi8>
return %dim : index
@@ -47,8 +46,7 @@ func.func @dynamic_dim_of_transpose_op(%arg0: tensor<1x2x?x8xi8>) -> index {
// CHECK: arith.constant 100 : index
// CHECK: return
func.func @static_dim_of_transpose_op(%arg0: tensor<1x100x?x8xi8>) -> index {
%0 = "tosa.const"() <{value = dense<[0, 3, 1, 2]> : tensor<4xi32>}> : () -> tensor<4xi32>
%1 = tosa.transpose %arg0, %0 : (tensor<1x100x?x8xi8>, tensor<4xi32>) -> tensor<1x8x100x?xi8>
%1 = tosa.transpose %arg0 { perms = array<i32: 0, 3, 1, 2> }: (tensor<1x100x?x8xi8>) -> tensor<1x8x100x?xi8>
%c2 = arith.constant 2 : index
%dim = tensor.dim %1, %c2 : tensor<1x8x100x?xi8>
return %dim : index

View File

@@ -553,10 +553,9 @@ func.func @test_tile(%arg0: tensor<13x21x3xf32>) -> tensor<39x21x6xf32> {
// -----
// CHECK-LABEL: transpose
func.func @test_transpose(%arg0: tensor<13x21x3xf32>) -> tensor<3x13x21xf32> {
%0 = "tosa.const"() {value = dense<[2, 0, 1]> : tensor<3xi32>} : () -> tensor<3xi32>
// CHECK: profiles: [ [pro_int, pro_fp] ]
// CHECK: extensions: [ [fp8e4m3, fp8e5m2, bf16] ]
%1 = tosa.transpose %arg0, %0 : (tensor<13x21x3xf32>, tensor<3xi32>) -> tensor<3x13x21xf32>
%1 = tosa.transpose %arg0 {perms = array<i32: 2, 0, 1>}: (tensor<13x21x3xf32>) -> tensor<3x13x21xf32>
return %1 : tensor<3x13x21xf32>
}

View File

@@ -631,12 +631,11 @@ func.func @reshape_canonicalize_quant_nofold() -> (tensor<1x3x!quant.uniform<i8:
// CHECK-LABEL: @transpose_canonicalize_strip_quant
func.func @transpose_canonicalize_strip_quant() -> (tensor<2x1x3x!quant.uniform<i8:f32, 1.000000e+00>>) {
// CHECK-DAG: tosa.const_shape {value = dense<[2, 1, 3]> : tensor<3xindex>} : () -> !tosa.shape<3>
// CHECK-DAG: "tosa.const"() <{value = dense<0> : tensor<1x2x3xi8>}> : () -> tensor<1x2x3x!quant.uniform<i8:f32, 1.000000e+00>>
// CHECK: tosa.reshape %0, %1 : (tensor<1x2x3x!quant.uniform<i8:f32, 1.000000e+00>>, !tosa.shape<3>) -> tensor<2x1x3x!quant.uniform<i8:f32, 1.000000e+00>>
%perms = "tosa.const"() {value = dense<[1, 0, 2]> : tensor<3xi32>} : () -> tensor<3xi32>
// CHECK-DAG: %[[SHAPE:.*]] = tosa.const_shape {value = dense<[2, 1, 3]> : tensor<3xindex>} : () -> !tosa.shape<3>
// CHECK-DAG: %[[CONST:.*]] = "tosa.const"() <{value = dense<0> : tensor<1x2x3xi8>}> : () -> tensor<1x2x3x!quant.uniform<i8:f32, 1.000000e+00>>
// CHECK: tosa.reshape %[[CONST]], %[[SHAPE]] : (tensor<1x2x3x!quant.uniform<i8:f32, 1.000000e+00>>, !tosa.shape<3>) -> tensor<2x1x3x!quant.uniform<i8:f32, 1.000000e+00>>
%0 = "tosa.const"() {value = dense<0> : tensor<1x2x3xi8>} : ()-> tensor<1x2x3x!quant.uniform<i8:f32, 1.000000e+00>>
%1 = tosa.transpose %0, %perms : (tensor<1x2x3x!quant.uniform<i8:f32, 1.000000e+00>>, tensor<3xi32>) -> tensor<2x1x3x!quant.uniform<i8:f32, 1.000000e+00>>
%1 = tosa.transpose %0 { perms = array<i32: 1, 0, 2> }: (tensor<1x2x3x!quant.uniform<i8:f32, 1.000000e+00>>) -> tensor<2x1x3x!quant.uniform<i8:f32, 1.000000e+00>>
return %1 : tensor<2x1x3x!quant.uniform<i8:f32, 1.000000e+00>>
}
@@ -688,8 +687,7 @@ func.func @tile_nofold(%arg0: tensor<3x4xf32>) -> tensor<3x8xf32> {
func.func @transpose_no_op(%arg0: tensor<3x4x5x6xf32>) -> tensor<3x4x5x6xf32> {
// CHECK: return %arg0
// CHECK-NOT: tosa.transpose
%perms = "tosa.const"() {value = dense<[0, 1, 2, 3]> : tensor<4xi32>} : () -> tensor<4xi32>
%1 = tosa.transpose %arg0, %perms : (tensor<3x4x5x6xf32>, tensor<4xi32>) -> tensor<3x4x5x6xf32>
%1 = tosa.transpose %arg0 { perms = array<i32: 0, 1, 2, 3> }: (tensor<3x4x5x6xf32>) -> tensor<3x4x5x6xf32>
return %1 : tensor<3x4x5x6xf32>
}
@@ -699,13 +697,22 @@ func.func @transpose_no_op(%arg0: tensor<3x4x5x6xf32>) -> tensor<3x4x5x6xf32> {
func.func @transpose_is_reshape(%arg0: tensor<1x4x5x1xf32>) -> tensor<1x4x1x5xf32> {
// CHECK: %[[CONST0:.+]] = tosa.const_shape {value = dense<[1, 4, 1, 5]> : tensor<4xindex>} : () -> !tosa.shape<4>
// CHECK: tosa.reshape %arg0, %[[CONST0]]
%perms = "tosa.const"() <{value = dense<[3, 1, 0, 2]> : tensor<4xi32>}> : () -> tensor<4xi32>
%0 = tosa.transpose %arg0, %perms : (tensor<1x4x5x1xf32>, tensor<4xi32>) -> tensor<1x4x1x5xf32>
%0 = tosa.transpose %arg0 { perms = array<i32: 3, 1, 0, 2> }: (tensor<1x4x5x1xf32>) -> tensor<1x4x1x5xf32>
return %0 : tensor<1x4x1x5xf32>
}
// -----
// CHECK-LABEL: @transpose_is_reshape_unknown_dim
func.func @transpose_is_reshape_unknown_dim(%arg0: tensor<1x4x?x1xf32>) -> tensor<1x4x1x?xf32> {
// CHECK: %[[CONST0:.+]] = tosa.const_shape {value = dense<[1, 4, 1, -1]> : tensor<4xindex>} : () -> !tosa.shape<4>
// CHECK: tosa.reshape %arg0, %[[CONST0]]
%0 = tosa.transpose %arg0 { perms = array<i32: 3, 1, 0, 2> }: (tensor<1x4x?x1xf32>) -> tensor<1x4x1x?xf32>
return %0 : tensor<1x4x1x?xf32>
}
// -----
// CHECK-LABEL: @single_bit_reshape
// https://github.com/llvm/llvm-project/issues/55440
func.func @single_bit_reshape() -> tensor<1xi1> {

View File

@@ -20,34 +20,30 @@ func.func @argmax_dynamic_shape_no_fold_dim_size_1(%arg0: tensor<?x1x3xf32>) ->
// CHECK-LABEL: @transpose_fold
func.func @transpose_fold(%arg0: tensor<3x4xf32>) -> tensor<3x4xf32> {
// CHECK: return %arg0
%0 = arith.constant dense<[0, 1]> : tensor<2xi32>
%1 = tosa.transpose %arg0, %0 { perms = [1, 0] }: (tensor<3x4xf32>, tensor<2xi32>) -> tensor<3x4xf32>
%1 = tosa.transpose %arg0 { perms = array<i32: 0, 1> }: (tensor<3x4xf32>) -> tensor<3x4xf32>
return %1 : tensor<3x4xf32>
}
// CHECK-LABEL: @transpose_nofold
func.func @transpose_nofold(%arg0: tensor<3x3xf32>) -> tensor<3x3xf32> {
// CHECK: tosa.transpose
%0 = arith.constant dense<[1, 0]> : tensor<2xi32>
%1 = tosa.transpose %arg0, %0 { perms = [1, 0] }: (tensor<3x3xf32>, tensor<2xi32>) -> tensor<3x3xf32>
%1 = tosa.transpose %arg0 { perms = array<i32: 1, 0> }: (tensor<3x3xf32>) -> tensor<3x3xf32>
return %1 : tensor<3x3xf32>
}
// CHECK-LABEL: @transpose_nofold_shape
func.func @transpose_nofold_shape(%arg0: tensor<3x4xf32>) -> tensor<?x?xf32> {
// CHECK: tosa.transpose
%0 = arith.constant dense<[1, 0]> : tensor<2xi32>
%1 = tosa.transpose %arg0, %0 { perms = [1, 0] }: (tensor<3x4xf32>, tensor<2xi32>) -> tensor<?x?xf32>
%1 = tosa.transpose %arg0 { perms = array<i32: 1, 0> }: (tensor<3x4xf32>) -> tensor<?x?xf32>
return %1 : tensor<?x?xf32>
}
// CHECK-LABEL: @transpose_fold_splat
func.func @transpose_fold_splat() -> tensor<3x2xf32> {
%input = "tosa.const"() {value = dense<4.0> : tensor<2x3xf32>} : () -> tensor<2x3xf32>
%perms = "tosa.const"() {value = dense<[1, 0]> : tensor<2xi32>} : () -> tensor<2xi32>
// CHECK: %[[CST:.+]] = "tosa.const"() <{
// CHECK-SAME{LITERAL}: value = dense<4.000000e+00> : tensor<3x2xf32>
%1 = tosa.transpose %input, %perms : (tensor<2x3xf32>, tensor<2xi32>) -> tensor<3x2xf32>
%1 = tosa.transpose %input { perms = array<i32: 1, 0> }: (tensor<2x3xf32>) -> tensor<3x2xf32>
// CHECK: return %[[CST]]
return %1 : tensor<3x2xf32>
}
@@ -55,10 +51,9 @@ func.func @transpose_fold_splat() -> tensor<3x2xf32> {
// CHECK-LABEL: @transpose_fold_2d_float
func.func @transpose_fold_2d_float() -> tensor<3x2xf32> {
%input = "tosa.const"() {value = dense<[[0.0, 1.0, 2.0], [3.0, 4.0, 5.0]]> : tensor<2x3xf32>} : () -> tensor<2x3xf32>
%perms = "tosa.const"() {value = dense<[1, 0]> : tensor<2xi32>} : () -> tensor<2xi32>
// CHECK: %[[CST:.+]] = "tosa.const"() <{
// CHECK-SAME{LITERAL}: value = dense<[[0.000000e+00, 3.000000e+00], [1.000000e+00, 4.000000e+00], [2.000000e+00, 5.000000e+00]]> : tensor<3x2xf32>
%1 = tosa.transpose %input, %perms : (tensor<2x3xf32>, tensor<2xi32>) -> tensor<3x2xf32>
%1 = tosa.transpose %input { perms = array<i32: 1, 0> }: (tensor<2x3xf32>) -> tensor<3x2xf32>
// CHECK: return %[[CST]]
return %1 : tensor<3x2xf32>
}
@@ -66,10 +61,9 @@ func.func @transpose_fold_2d_float() -> tensor<3x2xf32> {
// CHECK-LABEL: @transpose_fold_2d_bool
func.func @transpose_fold_2d_bool() -> tensor<3x2xi1> {
%input = "tosa.const"() {value = dense<[[true, false, false], [false, false, true]]> : tensor<2x3xi1>} : () -> tensor<2x3xi1>
%perms = "tosa.const"() {value = dense<[1, 0]> : tensor<2xi32>} : () -> tensor<2xi32>
// CHECK: %[[CST:.+]] = "tosa.const"() <{
// CHECK-SAME{LITERAL}: value = dense<[[true, false], [false, false], [false, true]]> : tensor<3x2xi1>
%1 = tosa.transpose %input, %perms : (tensor<2x3xi1>, tensor<2xi32>) -> tensor<3x2xi1>
%1 = tosa.transpose %input { perms = array<i32: 1, 0> }: (tensor<2x3xi1>) -> tensor<3x2xi1>
// CHECK: return %[[CST]]
return %1 : tensor<3x2xi1>
}
@@ -80,50 +74,46 @@ func.func @transpose_fold_4d_int() -> tensor<3x1x4x2xi32> {
[[ 0, 1, 2, 3], [ 4, 5, 6, 7], [ 8, 9, 10, 11]],
[[12, 13, 14, 15], [16, 17, 18, 19], [20, 21, 22, 23]]
]]> : tensor<1x2x3x4xi32>} : () -> tensor<1x2x3x4xi32>
%perms = "tosa.const"() {value = dense<[2, 0, 3, 1]> : tensor<4xi32>} : () -> tensor<4xi32>
// CHECK: %[[CST:.+]] = "tosa.const"() <{
// CHECK-SAME{LITERAL}: value = dense<[
// CHECK-SAME{LITERAL}: [[[0, 12], [1, 13], [2, 14], [3, 15]]],
// CHECK-SAME{LITERAL}: [[[4, 16], [5, 17], [6, 18], [7, 19]]],
// CHECK-SAME{LITERAL}: [[[8, 20], [9, 21], [10, 22], [11, 23]]]
// CHECK-SAME{LITERAL}: ]>
%1 = tosa.transpose %input, %perms : (tensor<1x2x3x4xi32>, tensor<4xi32>) -> tensor<3x1x4x2xi32>
%1 = tosa.transpose %input { perms = array<i32: 2, 0, 3, 1> }: (tensor<1x2x3x4xi32>) -> tensor<3x1x4x2xi32>
// CHECK: return %[[CST]]
return %1 : tensor<3x1x4x2xi32>
}
// CHECK-LABEL: @transpose_nofold_non_cst_input
func.func @transpose_nofold_non_cst_input(%input: tensor<2x3xf32>) -> tensor<3x2xf32> {
%perms = "tosa.const"() {value = dense<[1, 0]> : tensor<2xi32>} : () -> tensor<2xi32>
// CHECK: tosa.transpose
%1 = tosa.transpose %input, %perms : (tensor<2x3xf32>, tensor<2xi32>) -> tensor<3x2xf32>
return %1 : tensor<3x2xf32>
}
// CHECK-LABEL: @transpose_nofold_non_cst_perms
func.func @transpose_nofold_non_cst_perms(%perms: tensor<2xi32>) -> tensor<3x2xf32> {
%input = "tosa.const"() {value = dense<[[0.0, 1.0, 2.0], [3.0, 4.0, 5.0]]> : tensor<2x3xf32>} : () -> tensor<2x3xf32>
// CHECK: tosa.transpose
%1 = tosa.transpose %input, %perms : (tensor<2x3xf32>, tensor<2xi32>) -> tensor<3x2xf32>
%1 = tosa.transpose %input { perms = array<i32: 1, 0> }: (tensor<2x3xf32>) -> tensor<3x2xf32>
return %1 : tensor<3x2xf32>
}
// CHECK-LABEL: @transpose_nofold_multi_users
func.func @transpose_nofold_multi_users() -> (tensor<3x2xf32>, tensor<2x3xf32>) {
%input = "tosa.const"() {value = dense<[[0.0, 1.0, 2.0], [3.0, 4.0, 5.0]]> : tensor<2x3xf32>} : () -> tensor<2x3xf32>
%perms = "tosa.const"() {value = dense<[1, 0]> : tensor<2xi32>} : () -> tensor<2xi32>
// CHECK: tosa.transpose
%1 = tosa.transpose %input, %perms : (tensor<2x3xf32>, tensor<2xi32>) -> tensor<3x2xf32>
%1 = tosa.transpose %input { perms = array<i32: 1, 0> }: (tensor<2x3xf32>) -> tensor<3x2xf32>
return %1, %input : tensor<3x2xf32>, tensor<2x3xf32>
}
// CHECK-LABEL: @transpose_nofold_quantized_types
func.func @transpose_nofold_quantized_types() -> tensor<1x1x2x2x!quant.uniform<i8<-127:127>:f32:3, {1.000000e-01,1.000000e-01}>> {
%input = "tosa.const"() {value = dense<-127> : tensor<2x1x1x2xi8>} : () -> tensor<2x1x1x2x!quant.uniform<i8<-127:127>:f32:3, {1.000000e-01,1.000000e-01}>>
// CHECK: tosa.transpose
%0 = tosa.transpose %input { perms = array<i32: 1, 2, 3, 0> }: (tensor<2x1x1x2x!quant.uniform<i8<-127:127>:f32:3, {1.000000e-01,1.000000e-01}>>) -> tensor<1x1x2x2x!quant.uniform<i8<-127:127>:f32:3, {1.000000e-01,1.000000e-01}>>
return %0: tensor<1x1x2x2x!quant.uniform<i8<-127:127>:f32:3, {1.000000e-01,1.000000e-01}>>
}
// CHECK-LABEL: @transpose_nofold_dense_resource
func.func @transpose_nofold_dense_resource() -> tensor<2x2xf32> {
%0 = "tosa.const"() <{value = dense_resource<resource> : tensor<2x2xf32>}> : () -> tensor<2x2xf32>
%1 = "tosa.const"() <{value = dense<[1, 0]> : tensor<2xi32>}> : () -> tensor<2xi32>
// CHECK: tosa.transpose
%2 = tosa.transpose %0, %1 : (tensor<2x2xf32>, tensor<2xi32>) -> tensor<2x2xf32>
%2 = tosa.transpose %0 { perms = array<i32: 1, 0> }: (tensor<2x2xf32>) -> tensor<2x2xf32>
return %2 : tensor<2x2xf32>
}
{-#

View File

@@ -235,97 +235,74 @@ func.func @test_pad_padding_shape_mismatch(%arg0: tensor<13x21x3xf32>) -> tensor
// -----
func.func @test_transpose_non_const(%arg0: tensor<13x21x3xf32>, %arg1: tensor<3xi32>) -> tensor<3x13x21xf32> {
// expected-error@+1 {{'tosa.transpose' op perms of transpose is not constant}}
%0 = tosa.transpose %arg0, %arg1 : (tensor<13x21x3xf32>, tensor<3xi32>) -> tensor<3x13x21xf32>
return %0 : tensor<3x13x21xf32>
}
// -----
func.func @test_transpose_io_rank_mismatch(%arg0: tensor<13x21x3xf32>, %arg1: tensor<3xi32>) -> tensor<3x13x21x1xf32> {
// expected-error@+1 {{'tosa.transpose' op expected input tensor rank to equal result tensor rank}}
%0 = tosa.transpose %arg0, %arg1 : (tensor<13x21x3xf32>, tensor<3xi32>) -> tensor<3x13x21x1xf32>
%0 = tosa.transpose %arg0 {perms = array<i32: 2, 1, 0>}: (tensor<13x21x3xf32>) -> tensor<3x13x21x1xf32>
return %0 : tensor<3x13x21x1xf32>
}
// -----
func.func @test_transpose_invalid_perms_rank(%arg0: tensor<13x21x3xf32>, %arg1: tensor<3x2xi32>) -> tensor<3x13x21xf32> {
// expected-error@+1 {{'tosa.transpose' op expected permutation tensor to be rank 1 but got rank 2}}
%0 = tosa.transpose %arg0, %arg1 : (tensor<13x21x3xf32>, tensor<3x2xi32>) -> tensor<3x13x21xf32>
return %0 : tensor<3x13x21xf32>
}
// -----
func.func @test_transpose_rank0_perms() {
%14 = tensor.empty() : tensor<5x27xi64>
%cst = tensor.empty() : tensor<i32>
// expected-error@+1 {{'tosa.transpose' op expected permutation tensor to be rank 1 but got rank 0}}
%72 = tosa.transpose %14, %cst : (tensor<5x27xi64>, tensor<i32>) -> tensor<?x?xi64>
// expected-error@+1 {{'tosa.transpose' op expected perms attribute to have size 2 (input rank) but got size 0}}
%72 = tosa.transpose %14 {perms = array<i32> }: (tensor<5x27xi64>) -> tensor<?x?xi64>
return
}
// -----
func.func @test_transpose_invalid_perms_size(%arg0: tensor<13x21x3xf32>, %arg1: tensor<7xi32>) -> tensor<3x13x21xf32> {
// expected-error@+1 {{'tosa.transpose' op expected permutation tensor dim 0 to have size 3 (input rank) but got size 7}}
%0 = tosa.transpose %arg0, %arg1 : (tensor<13x21x3xf32>, tensor<7xi32>) -> tensor<3x13x21xf32>
func.func @test_transpose_invalid_perms_size(%arg0: tensor<13x21x3xf32>) -> tensor<3x13x21xf32> {
// expected-error@+1 {{'tosa.transpose' op expected perms attribute to have size 3 (input rank) but got size 7}}
%0 = tosa.transpose %arg0 {perms = array<i32: 6, 5, 4, 3, 2, 1, 0> }: (tensor<13x21x3xf32>) -> tensor<3x13x21xf32>
return %0 : tensor<3x13x21xf32>
}
// -----
func.func @test_transpose_invalid_permutation_tensor(%arg0: tensor<13x21x3xf32>) -> tensor<?x?x?xf32> {
%perms = arith.constant dense<[2, 0, 0]> : tensor<3xi32>
// expected-error@+1 {{'tosa.transpose' op expected valid permutation tensor}}
%0 = tosa.transpose %arg0, %perms : (tensor<13x21x3xf32>, tensor<3xi32>) -> tensor<?x?x?xf32>
// expected-error@+1 {{'tosa.transpose' op expected valid permutation indices}}
%0 = tosa.transpose %arg0 {perms = array<i32: 2, 0, 0> }: (tensor<13x21x3xf32>) -> tensor<?x?x?xf32>
return %0 : tensor<?x?x?xf32>
}
// -----
func.func @test_transpose_invalid_permutation_negative(%arg0: tensor<3x2xi32>) -> tensor<*xi32> {
%perms = "tosa.const"() {value = dense<[-1, 0]> : tensor<2xi32>} : () -> tensor<2xi32>
// expected-error@+1 {{'tosa.transpose' op expected valid permutation tensor}}
%1 = tosa.transpose %arg0, %perms : (tensor<3x2xi32>, tensor<2xi32>) -> tensor<*xi32>
// expected-error@+1 {{'tosa.transpose' op expected valid permutation indices}}
%1 = tosa.transpose %arg0 {perms = array<i32: -1, 0> }: (tensor<3x2xi32>) -> tensor<*xi32>
return %1 : tensor<*xi32>
}
// -----
func.func @test_transpose_invalid_permutation_tensor_above_range(%arg0: tensor<3x2xi32>) -> tensor<*xi32> {
%perms = "tosa.const"() {value = dense<[2, 0]> : tensor<2xi32>} : () -> tensor<2xi32>
// expected-error@+1 {{'tosa.transpose' op expected valid permutation tensor}}
%1 = tosa.transpose %arg0, %perms : (tensor<3x2xi32>, tensor<2xi32>) -> tensor<*xi32>
// expected-error@+1 {{'tosa.transpose' op expected valid permutation indices}}
%1 = tosa.transpose %arg0 {perms = array<i32: 2, 0> }: (tensor<3x2xi32>) -> tensor<*xi32>
return %1 : tensor<*xi32>
}
// -----
func.func @test_transpose_invalid_permutation_types(%arg0: tensor<3x2xi32>) -> tensor<3x4xi32> {
%perms = "tosa.const"() {value = dense<[1, 0]> : tensor<2xi32>} : () -> tensor<2xi32>
// expected-error@+1 {{'tosa.transpose' op expected output tensor dim 0 to match input dim 1 with value of 2}}
%1 = tosa.transpose %arg0, %perms : (tensor<3x2xi32>, tensor<2xi32>) -> tensor<3x4xi32>
%1 = tosa.transpose %arg0 {perms = array<i32: 1, 0> }: (tensor<3x2xi32>) -> tensor<3x4xi32>
return %1 : tensor<3x4xi32>
}
// -----
func.func @test_transpose_invalid_permutation_types_dynamic_dim_ok(%arg0: tensor<2x?xi32>) -> tensor<3x4xi32> {
%perms = "tosa.const"() {value = dense<[1, 0]> : tensor<2xi32>} : () -> tensor<2xi32>
// expected-error@+1 {{'tosa.transpose' op expected output tensor dim 1 to match input dim 0 with value of 2}}
%1 = tosa.transpose %arg0, %perms : (tensor<2x?xi32>, tensor<2xi32>) -> tensor<3x4xi32>
%1 = tosa.transpose %arg0 {perms = array<i32: 1, 0> }: (tensor<2x?xi32>) -> tensor<3x4xi32>
return %1 : tensor<3x4xi32>
}
// -----
func.func @test_transpose_element_type_mismatch(%arg0: tensor<2x3xi32>) -> tensor<3x2xf32> {
%perms = "tosa.const"() {value = dense<[1, 0]> : tensor<2xi32>} : () -> tensor<2xi32>
// expected-error@+1 {{'tosa.transpose' op failed to verify that all of {input1, output} have same element type}}
%1 = tosa.transpose %arg0, %perms : (tensor<2x3xi32>, tensor<2xi32>) -> tensor<3x2xf32>
%1 = tosa.transpose %arg0 {perms = array<i32: 1, 0>} : (tensor<2x3xi32>) -> tensor<3x2xf32>
return %1 : tensor<3x2xf32>
}
@@ -674,10 +651,9 @@ func.func @test_tile_io_rank_mismatch() {
// CHECK-LABEL: @test_invalid_constant_permutation
func.func @test_invalid_constant_permutation() {
// expected-error@+3 {{'tosa.transpose' op expected valid permutation tensor}}
%0 = tensor.empty() : tensor<3x4x5xi32>
%1 = arith.constant dense<[3, 0, 1]> : tensor<3xi32>
%2 = tosa.transpose %0, %1 : (tensor<3x4x5xi32>, tensor<3xi32>) -> tensor<3x4x5xi32>
// expected-error@+1 {{'tosa.transpose' op expected valid permutation indices}}
%2 = tosa.transpose %0 {perms = array<i32: 3, 0, 1>}: (tensor<3x4x5xi32>) -> tensor<3x4x5xi32>
return
}
@@ -685,11 +661,10 @@ func.func @test_invalid_constant_permutation() {
// CHECK-LABEL: test_rank_size_constant_permutation
func.func @test_rank_size_constant_permutation() {
// expected-error@+4 {{'tosa.transpose' op expected valid permutation tensor}}
%0 = arith.constant 6 : index
%1 = arith.constant dense<[0, 2]> : tensor<2xi32>
%2 = tensor.empty(%0) : tensor<?x27xi64>
%3 = tosa.transpose %2, %1 : (tensor<?x27xi64>, tensor<2xi32>) -> tensor<?x27xi64>
// expected-error@+1 {{'tosa.transpose' op expected valid permutation indices}}
%3 = tosa.transpose %2 {perms = array<i32: 0, 2>}: (tensor<?x27xi64>) -> tensor<?x27xi64>
return
}
@@ -697,11 +672,10 @@ func.func @test_rank_size_constant_permutation() {
// CHECK-LABEL: test_large_constant_permutation
func.func @test_large_constant_permutation() {
// expected-error@+4 {{'tosa.transpose' op expected valid permutation tensor}}
%0 = arith.constant 6 : index
%1 = arith.constant dense<[1185677355, 332462212]> : tensor<2xi32>
%2 = tensor.empty(%0) : tensor<?x27xi64>
%3 = tosa.transpose %2, %1 : (tensor<?x27xi64>, tensor<2xi32>) -> tensor<?x27xi64>
// expected-error@+1 {{'tosa.transpose' op expected valid permutation indices}}
%3 = tosa.transpose %2 {perms = array<i32: 1185677355, 332462212>}: (tensor<?x27xi64>) -> tensor<?x27xi64>
return
}

View File

@@ -105,9 +105,8 @@ func.func @test_tile(%arg0: tensor<1x1x1x1x13x21x3xf32>) -> tensor<1x1x1x1x39x21
// -----
func.func @test_transpose(%arg0: tensor<13x21x3x1x1x1x1xf32>) -> tensor<3x13x21x1x1x1x1xf32> {
%0 = "tosa.const"() {value = dense<[2, 0, 1, 3, 4, 5, 6]> : tensor<7xi32>} : () -> tensor<7xi32>
// expected-error@+1 {{'tosa.transpose' op failed level check: operand rank(shape) <= MAX_RANK}}
%1 = "tosa.transpose"(%arg0, %0) : (tensor<13x21x3x1x1x1x1xf32>, tensor<7xi32>) -> tensor<3x13x21x1x1x1x1xf32>
%1 = "tosa.transpose"(%arg0) {perms = array<i32: 2, 0, 1, 3, 4, 5, 6>} : (tensor<13x21x3x1x1x1x1xf32>) -> tensor<3x13x21x1x1x1x1xf32>
return %1 : tensor<3x13x21x1x1x1x1xf32>
}

View File

@@ -640,24 +640,21 @@ func.func @test_tile(%arg0: tensor<13x21x3xf32>) -> tensor<39x21x6xf32> {
// -----
// CHECK-LABEL: transpose
func.func @test_transpose(%arg0: tensor<13x21x3xf32>) -> tensor<3x13x21xf32> {
%0 = "tosa.const"() {value = dense<[2, 0, 1]> : tensor<3xi32>} : () -> tensor<3xi32>
%1 = tosa.transpose %arg0, %0 : (tensor<13x21x3xf32>, tensor<3xi32>) -> tensor<3x13x21xf32>
%1 = tosa.transpose %arg0 {perms = array<i32: 2, 0, 1>} : (tensor<13x21x3xf32>) -> tensor<3x13x21xf32>
return %1 : tensor<3x13x21xf32>
}
// -----
// CHECK-LABEL: transpose_dynamic_dim
func.func @test_transpose_dynamic_dim(%arg0: tensor<13x?x3xf32>) -> tensor<3x13x?xf32> {
%0 = "tosa.const"() {value = dense<[2, 0, 1]> : tensor<3xi32>} : () -> tensor<3xi32>
%1 = tosa.transpose %arg0, %0 : (tensor<13x?x3xf32>, tensor<3xi32>) -> tensor<3x13x?xf32>
%1 = tosa.transpose %arg0 {perms = array<i32: 2, 0, 1>} : (tensor<13x?x3xf32>) -> tensor<3x13x?xf32>
return %1 : tensor<3x13x?xf32>
}
// -----
// CHECK-LABEL: transpose_half_dynamic_dim
func.func @test_transpose_half_dynamic_dim(%arg0: tensor<13x3x3xf32>) -> tensor<3x13x?xf32> {
%0 = "tosa.const"() {value = dense<[2, 0, 1]> : tensor<3xi32>} : () -> tensor<3xi32>
%1 = tosa.transpose %arg0, %0 : (tensor<13x3x3xf32>, tensor<3xi32>) -> tensor<3x13x?xf32>
%1 = tosa.transpose %arg0 {perms = array<i32: 2, 0, 1>} : (tensor<13x3x3xf32>) -> tensor<3x13x?xf32>
return %1 : tensor<3x13x?xf32>
}

View File

@@ -54,11 +54,10 @@ func.func @transpose_conv2d_quantized_padded(%arg0: tensor<2x16x14x3xi8>, %arg1:
func.func @transpose_conv2d_strided(%arg0: tensor<2x17x15x3xf32>, %arg1: tensor<5x3x5x3xf32>, %arg2: tensor<5xf32>) -> tensor<2x?x?x5xf32> {
// Manipulate the weight matrix to handle striding.
// CHECK-DAG: %[[PADV:.+]] = tosa.const_shape {value = dense<[0, 0, 0, 1, 0, 1, 0, 0]> : tensor<8xindex>} : () -> !tosa.shape<8>
// CHECK-DAG: %[[TRANSV:.+]] = "tosa.const"() <{value = dense<[2, 4, 0, 1, 3, 5]> : tensor<6xi32>}
// CHECK-DAG: %[[PADW:.+]] = tosa.pad %arg1, %[[PADV]]
// CHECK-DAG: %[[CONST1:.+]] = tosa.const_shape {value = dense<[5, 2, 2, 2, 3, 3]> : tensor<6xindex>}
// CHECK-DAG: %[[RESW1:.+]] = tosa.reshape %[[PADW]], %[[CONST1]]
// CHECK-DAG: %[[TRANS:.+]] = tosa.transpose %[[RESW1]], %[[TRANSV]]
// CHECK-DAG: %[[TRANS:.+]] = tosa.transpose %[[RESW1]] {perms = array<i32: 2, 4, 0, 1, 3, 5>}
// CHECK-DAG: %[[CONST3:.+]] = tosa.const_shape {value = dense<[30, 2, 2, 3]> : tensor<4xindex>}
// CHECK-DAG: %[[RESW2:.+]] = tosa.reshape %[[TRANS]], %[[CONST3]]
// CHECK-DAG: %[[REV1:.+]] = tosa.reverse %[[RESW2]] {axis = 1 : i32}
@@ -68,7 +67,6 @@ func.func @transpose_conv2d_strided(%arg0: tensor<2x17x15x3xf32>, %arg1: tensor<
// Pad out the input matrix to handle the transpose conv.
// CHECK-DAG: %[[PAD:.+]] = tosa.const_shape {value = dense<[0, 0, 1, 1, 1, 1, 0, 0]> : tensor<8xindex>} : () -> !tosa.shape<8>
// CHECK-DAG: %[[TRANS2:.+]] = "tosa.const"() <{value = dense<[0, 1, 3, 2, 4, 5]> : tensor<6xi32>}
// CHECK-DAG: %[[NEWINPUT:.+]] = tosa.pad %arg0, %[[PAD]]
// Manipulate the final shape.
@@ -76,7 +74,7 @@ func.func @transpose_conv2d_strided(%arg0: tensor<2x17x15x3xf32>, %arg1: tensor<
// CHECK-DAG: %[[CONV:.+]] = tosa.conv2d %[[NEWINPUT]], %[[NEWWEIGHT]], %[[BIAS]] {acc_type = f32, dilation = array<i64: 1, 1>, pad = array<i64: 0, 0, 0, 0>, stride = array<i64: 1, 1>}
// CHECK-DAG: %[[CONST6:.+]] = tosa.const_shape {value = dense<[2, 18, 16, 2, 3, 5]> : tensor<6xindex>}
// CHECK-DAG: %[[RESHAPE_OUT_1:.+]] = tosa.reshape %[[CONV]], %[[CONST6]]
// CHECK-DAG: %[[TRANS_OUT:.+]] = tosa.transpose %[[RESHAPE_OUT_1]], %[[TRANS2]]
// CHECK-DAG: %[[TRANS_OUT:.+]] = tosa.transpose %[[RESHAPE_OUT_1]] {perms = array<i32: 0, 1, 3, 2, 4, 5>}
// CHECK-DAG: %[[CONST8:.+]] = tosa.const_shape {value = dense<[2, 36, 48, 5]> : tensor<4xindex>
// CHECK-DAG: %[[RESHAPE_OUT_2:.+]] = tosa.reshape %[[TRANS_OUT]], %[[CONST8]]
// CHECK-DAG: %[[SLICE:.+]] = tosa.slice %[[RESHAPE_OUT_2]], %[[START]], %[[SIZE]]
@@ -95,11 +93,10 @@ func.func @transpose_conv2d_strided(%arg0: tensor<2x17x15x3xf32>, %arg1: tensor<
func.func @transpose_conv2d_strided_quantized(%arg0: tensor<2x17x15x3xi8>, %arg1: tensor<5x3x5x3xi8>, %arg2: tensor<5xi32>) -> (tensor<2x35x47x5xi32>) {
// Manipulate the weight matrix to handle striding.
// CHECK-DAG: %[[PADV:.+]] = tosa.const_shape {value = dense<[0, 0, 0, 1, 0, 1, 0, 0]> : tensor<8xindex>} : () -> !tosa.shape<8>
// CHECK-DAG: %[[TRANSV:.+]] = "tosa.const"() <{value = dense<[2, 4, 0, 1, 3, 5]> : tensor<6xi32>}
// CHECK-DAG: %[[PADW:.+]] = tosa.pad %arg1, %[[PADV]] {input_zp = 42 : i32}
// CHECK-DAG: %[[CONST1:.+]] = tosa.const_shape {value = dense<[5, 2, 2, 2, 3, 3]> : tensor<6xindex>}
// CHECK-DAG: %[[RESW1:.+]] = tosa.reshape %[[PADW]], %[[CONST1]]
// CHECK-DAG: %[[TRANS:.+]] = tosa.transpose %[[RESW1]], %[[TRANSV]]
// CHECK-DAG: %[[TRANS:.+]] = tosa.transpose %[[RESW1]] {perms = array<i32: 2, 4, 0, 1, 3, 5>}
// CHECK-DAG: %[[CONST3:.+]] = tosa.const_shape {value = dense<[30, 2, 2, 3]> : tensor<4xindex>}
// CHECK-DAG: %[[RESW2:.+]] = tosa.reshape %[[TRANS]], %[[CONST3]]
// CHECK-DAG: %[[REV1:.+]] = tosa.reverse %[[RESW2]] {axis = 1 : i32}
@@ -109,7 +106,6 @@ func.func @transpose_conv2d_strided_quantized(%arg0: tensor<2x17x15x3xi8>, %arg1
// Pad out the input matrix to handle the transpose conv.
// CHECK-DAG: %[[PAD:.+]] = tosa.const_shape {value = dense<[0, 0, 1, 1, 1, 1, 0, 0]> : tensor<8xindex>} : () -> !tosa.shape<8>
// CHECK-DAG: %[[TRANS2:.+]] = "tosa.const"() <{value = dense<[0, 1, 3, 2, 4, 5]> : tensor<6xi32>}
// CHECK-DAG: %[[NEWINPUT:.+]] = tosa.pad %arg0, %[[PAD]] {input_zp = -22 : i32}
// Manipulate the final shape.
@@ -119,7 +115,7 @@ func.func @transpose_conv2d_strided_quantized(%arg0: tensor<2x17x15x3xi8>, %arg1
// CHECK-DAG: %[[CONV:.+]] = tosa.conv2d %[[NEWINPUT]], %[[NEWWEIGHT]], %[[BIAS]], %[[INPUT_ZP]], %[[WEIGHT_ZP]] {acc_type = i32, dilation = array<i64: 1, 1>, pad = array<i64: 0, 0, 0, 0>, stride = array<i64: 1, 1>}
// CHECK-DAG: %[[CONV_NEW_SHAPE:.*]] = tosa.const_shape {value = dense<[2, 18, 16, 2, 3, 5]> : tensor<6xindex>}
// CHECK-DAG: %[[RESHAPE_OUT_1:.+]] = tosa.reshape %[[CONV]], %[[CONV_NEW_SHAPE]]
// CHECK-DAG: %[[TRANS_OUT:.+]] = tosa.transpose %[[RESHAPE_OUT_1]], %[[TRANS2]]
// CHECK-DAG: %[[TRANS_OUT:.+]] = tosa.transpose %[[RESHAPE_OUT_1]] {perms = array<i32: 0, 1, 3, 2, 4, 5>}
// CHECK-DAG: %[[TRANS_NEW_SHAPE:.+]] = tosa.const_shape {value = dense<[2, 36, 48, 5]> : tensor<4xindex>}
// CHECK-DAG: %[[RESHAPE_OUT_2:.+]] = tosa.reshape %[[TRANS_OUT]], %[[TRANS_NEW_SHAPE]]
// CHECK-DAG: %[[SLICE:.+]] = tosa.slice %[[RESHAPE_OUT_2]], %[[START]], %[[SIZE]]
@@ -138,12 +134,10 @@ func.func @transpose_conv2d_strided_quantized(%arg0: tensor<2x17x15x3xi8>, %arg1
func.func @transpose_conv2d_strided_overpad(%arg0 : tensor<1x16x1x1xi8>, %arg1 : tensor<1x2x1x1xi8>, %arg2 : tensor<1xi32>) -> (tensor<1x19x2x1xi32>) {
// CHECK-DAG: %[[WEIGHT_PAD:.+]] = tosa.const_shape {value = dense<[0, 0, 0, 0, 0, 1, 0, 0]> : tensor<8xindex>}
// CHECK-DAG: %[[CONST1:.+]] = tosa.const_shape {value = dense<[1, 2, 1, 1, 2, 1]> : tensor<6xindex>}
// CHECK-DAG: %[[WEIGHT_PERMS:.+]] = "tosa.const"() <{value = dense<[2, 4, 0, 1, 3, 5]> : tensor<6xi32>}
// CHECK-DAG: %[[CONST3:.+]] = tosa.const_shape {value = dense<[2, 2, 1, 1]> : tensor<4xindex>}
// CHECK-DAG: %[[INPUT_PAD:.+]] = tosa.const_shape {value = dense<[0, 0, 1, 1, 0, 0, 0, 0]> : tensor<8xindex>}
// CHECK-DAG: %[[ZERO:.+]] = "tosa.const"() <{value = dense<0> : tensor<2xi32>}
// CHECK-DAG: %[[CONST6:.+]] = tosa.const_shape {value = dense<[1, 17, 1, 1, 2, 1]> : tensor<6xindex>}
// CHECK-DAG: %[[RESULT_PERMS:.+]] = "tosa.const"() <{value = dense<[0, 1, 3, 2, 4, 5]> : tensor<6xi32>}
// CHECK-DAG: %[[CONST8:.+]] = tosa.const_shape {value = dense<[1, 17, 2, 1]> : tensor<4xindex>}
// CHECK-DAG: %[[RESULT_PAD:.+]] = tosa.const_shape {value = dense<[0, 0, 2, 0, 0, 0, 0, 0]> : tensor<8xindex>}
// CHECK-DAG: %[[CONST10:.+]] = tosa.const_shape {value = dense<1> : tensor<4xindex>}
@@ -151,13 +145,13 @@ func.func @transpose_conv2d_strided_overpad(%arg0 : tensor<1x16x1x1xi8>, %arg1 :
// CHECK-DAG: %[[WEIGHT_ZP:.*]] = "tosa.const"() <{value = dense<93> : tensor<1xi8>}>
// CHECK: %[[PAD_WEIGHT:.+]] = tosa.pad %arg1, %[[WEIGHT_PAD]] {input_zp = 93 : i32}
// CHECK: %[[RESHAPE_WEIGHT_0:.+]] = tosa.reshape %[[PAD_WEIGHT]], %[[CONST1]]
// CHECK: %[[TRANSPOSE_WEIGHT:.+]] = tosa.transpose %[[RESHAPE_WEIGHT_0]], %[[WEIGHT_PERMS]]
// CHECK: %[[TRANSPOSE_WEIGHT:.+]] = tosa.transpose %[[RESHAPE_WEIGHT_0]] {perms = array<i32: 2, 4, 0, 1, 3, 5>}
// CHECK: %[[RESHAPE_WEIGHT_1:.+]] = tosa.reshape %[[TRANSPOSE_WEIGHT]], %[[CONST3]]
// CHECK: %[[REVERSE:.+]] = tosa.reverse %[[RESHAPE_WEIGHT_1]] {axis = 1 : i32}
// CHECK: %[[PAD_INPUT:.+]] = tosa.pad %arg0, %[[INPUT_PAD]] {input_zp = -103 : i32}
// CHECK: %[[CONV:.+]] = tosa.conv2d %[[PAD_INPUT]], %[[REVERSE]], %[[ZERO]], %[[INPUT_ZP]], %[[WEIGHT_ZP]] {acc_type = i32, dilation = array<i64: 1, 1>, pad = array<i64: 0, 0, 0, 0>, stride = array<i64: 1, 1>}
// CHECK: %[[RESHAPE_RESULT_0:.+]] = tosa.reshape %[[CONV]], %[[CONST6]]
// CHECK: %[[TRANSPOSE_RESULT:.+]] = tosa.transpose %[[RESHAPE_RESULT_0]], %[[RESULT_PERMS]]
// CHECK: %[[TRANSPOSE_RESULT:.+]] = tosa.transpose %[[RESHAPE_RESULT_0]] {perms = array<i32: 0, 1, 3, 2, 4, 5>}
// CHECK: %[[RESHAPE_RESULT_1:.+]] = tosa.reshape %[[TRANSPOSE_RESULT]], %[[CONST8]]
// CHECK: %[[PAD_RESULT:.+]] = tosa.pad %[[RESHAPE_RESULT_1]], %[[RESULT_PAD]]
// CHECK: %[[RESHAPE_ARG2:.+]] = tosa.reshape %arg2, %[[CONST10]]

View File

@@ -577,29 +577,10 @@ func.func @test_tile(%arg0 : tensor<2x3x?xi32>) -> () {
// -----
// CHECK-LABEL: @test_transpose_same
func.func @test_transpose_same(%arg0 : tensor<4x4x4xi32>, %arg1 : tensor<3xi32>) -> () {
// CHECK: tosa.transpose %arg0, %arg1 : (tensor<4x4x4xi32>, tensor<3xi32>) -> tensor<4x4x4xi32>
%0 = tosa.transpose %arg0, %arg1 : (tensor<4x4x4xi32>, tensor<3xi32>) -> tensor<?x?x?xi32>
return
}
// -----
// CHECK-LABEL: @test_transpose_perm_unknown
func.func @test_transpose_perm_unknown(%arg0 : tensor<4x4x5xi32>, %arg1 : tensor<3xi32>) -> () {
// CHECK: tosa.transpose %arg0, %arg1 : (tensor<4x4x5xi32>, tensor<3xi32>) -> tensor<?x?x?xi32>
%0 = tosa.transpose %arg0, %arg1 : (tensor<4x4x5xi32>, tensor<3xi32>) -> tensor<?x?x?xi32>
return
}
// -----
// CHECK-LABEL: @test_transpose_static
func.func @test_transpose_static(%arg0 : tensor<3x4x5xi32>) -> () {
%0 = arith.constant dense<[2, 1, 0]> : tensor<3xi32>
// CHECK: tosa.transpose %arg0, %cst : (tensor<3x4x5xi32>, tensor<3xi32>) -> tensor<5x4x3xi32>
%1 = tosa.transpose %arg0, %0 : (tensor<3x4x5xi32>, tensor<3xi32>) -> tensor<?x?x?xi32>
// CHECK: tosa.transpose %arg0 {perms = array<i32: 2, 1, 0>} : (tensor<3x4x5xi32>) -> tensor<5x4x3xi32>
%1 = tosa.transpose %arg0 { perms = array<i32: 2, 1, 0> }: (tensor<3x4x5xi32>) -> tensor<?x?x?xi32>
return
}

View File

@@ -4,11 +4,9 @@
// CHECK-NEXT: %[[RESULT:.*]] = tosa.ceil %arg0
// CHECK-NEXT: return %[[RESULT]]
func.func @test_transpose_tracks_to_nullifying_single_step(%arg0: tensor<1x2x3x4xi32>) -> tensor<1x2x3x4xi32> {
%perms0 = "tosa.const"() {value = dense<[0, 2, 3, 1]> : tensor<4xi32>} : () -> tensor<4xi32>
%0 = tosa.transpose %arg0, %perms0 : (tensor<1x2x3x4xi32>, tensor<4xi32>) -> tensor<1x3x4x2xi32>
%0 = tosa.transpose %arg0 {perms = array<i32: 0, 2, 3, 1>}: (tensor<1x2x3x4xi32>) -> tensor<1x3x4x2xi32>
%ceil = tosa.ceil %0 : (tensor<1x3x4x2xi32>) -> tensor<1x3x4x2xi32>
%perms1 = "tosa.const"() {value = dense<[0, 3, 1, 2]> : tensor<4xi32>} : () -> tensor<4xi32>
%1 = tosa.transpose %ceil, %perms1 : (tensor<1x3x4x2xi32>, tensor<4xi32>) -> tensor<1x2x3x4xi32>
%1 = tosa.transpose %ceil {perms = array<i32: 0, 3, 1, 2>}: (tensor<1x3x4x2xi32>) -> tensor<1x2x3x4xi32>
return %1 : tensor<1x2x3x4xi32>
}
@@ -20,13 +18,11 @@ func.func @test_transpose_tracks_to_nullifying_single_step(%arg0: tensor<1x2x3x4
// CHECK-NEXT: %[[NOT:.*]] = tosa.bitwise_not %[[ABS]]
// CHECK-NEXT: return %[[NOT]]
func.func @test_transpose_tracks_to_nullifying_multi_unary_step(%arg0: tensor<1x2x3x4xi32>) -> tensor<1x2x3x4xi32> {
%perms0 = "tosa.const"() {value = dense<[0, 2, 3, 1]> : tensor<4xi32>} : () -> tensor<4xi32>
%0 = tosa.transpose %arg0, %perms0 : (tensor<1x2x3x4xi32>, tensor<4xi32>) -> tensor<1x3x4x2xi32>
%0 = tosa.transpose %arg0 {perms = array<i32: 0, 2, 3, 1>}: (tensor<1x2x3x4xi32>) -> tensor<1x3x4x2xi32>
%clamp = tosa.clamp %0 {max_val = 1 : i32, min_val = 0 : i32} : (tensor<1x3x4x2xi32>) -> tensor<1x3x4x2xi32>
%abs = tosa.abs %clamp : (tensor<1x3x4x2xi32>) -> tensor<1x3x4x2xi32>
%bitwise_not = tosa.bitwise_not %abs : (tensor<1x3x4x2xi32>) -> tensor<1x3x4x2xi32>
%perms1 = "tosa.const"() {value = dense<[0, 3, 1, 2]> : tensor<4xi32>} : () -> tensor<4xi32>
%1 = tosa.transpose %bitwise_not, %perms1 : (tensor<1x3x4x2xi32>, tensor<4xi32>) -> tensor<1x2x3x4xi32>
%1 = tosa.transpose %bitwise_not {perms = array<i32: 0, 3, 1, 2>}: (tensor<1x3x4x2xi32>) -> tensor<1x2x3x4xi32>
return %1 : tensor<1x2x3x4xi32>
}
@@ -38,14 +34,12 @@ func.func @test_transpose_tracks_to_nullifying_multi_unary_step(%arg0: tensor<1x
// CHECK-NEXT: %[[ADD:.*]] = tosa.add %[[CLAMP]], %[[ABS]]
// CHECK-NEXT: return %[[ADD]]
func.func @test_transpose_tracks_to_nullifying_diverging_binary(%arg0: tensor<1x2x3x4xi32>, %arg1: tensor<1x2x3x4xi32>) -> tensor<1x2x3x4xi32> {
%perms0 = "tosa.const"() {value = dense<[0, 2, 3, 1]> : tensor<4xi32>} : () -> tensor<4xi32>
%transpose0 = tosa.transpose %arg0, %perms0 : (tensor<1x2x3x4xi32>, tensor<4xi32>) -> tensor<1x3x4x2xi32>
%transpose1 = tosa.transpose %arg1, %perms0 : (tensor<1x2x3x4xi32>, tensor<4xi32>) -> tensor<1x3x4x2xi32>
%transpose0 = tosa.transpose %arg0 {perms = array<i32: 0, 2, 3, 1>}: (tensor<1x2x3x4xi32>) -> tensor<1x3x4x2xi32>
%transpose1 = tosa.transpose %arg1 {perms = array<i32: 0, 2, 3, 1>}: (tensor<1x2x3x4xi32>) -> tensor<1x3x4x2xi32>
%clamp = tosa.clamp %transpose0 {max_val = 1 : i32, min_val = 0 : i32} : (tensor<1x3x4x2xi32>) -> tensor<1x3x4x2xi32>
%abs = tosa.abs %transpose1 : (tensor<1x3x4x2xi32>) -> tensor<1x3x4x2xi32>
%add = tosa.add %clamp, %abs : (tensor<1x3x4x2xi32>, tensor<1x3x4x2xi32>) -> tensor<1x3x4x2xi32>
%perms1 = "tosa.const"() {value = dense<[0, 3, 1, 2]> : tensor<4xi32>} : () -> tensor<4xi32>
%result = tosa.transpose %add, %perms1 : (tensor<1x3x4x2xi32>, tensor<4xi32>) -> tensor<1x2x3x4xi32>
%result = tosa.transpose %add {perms = array<i32: 0, 3, 1, 2>}: (tensor<1x3x4x2xi32>) -> tensor<1x2x3x4xi32>
return %result : tensor<1x2x3x4xi32>
}
@@ -58,14 +52,12 @@ func.func @test_transpose_tracks_to_nullifying_diverging_binary(%arg0: tensor<1x
// CHECK-NEXT: %[[ADD:.*]] = tosa.add %[[CLAMP]], %[[ABS]]
// CHECK-NEXT: return %[[ADD]]
func.func @test_transpose_tracks_to_nullifying_diverging_binary_with_broadcasting(%arg0: tensor<1x2x3x4xi32>, %arg1: tensor<1x2x1x4xi32>) -> tensor<1x2x3x4xi32> {
%perms0 = "tosa.const"() {value = dense<[0, 2, 3, 1]> : tensor<4xi32>} : () -> tensor<4xi32>
%transpose0 = tosa.transpose %arg0, %perms0 : (tensor<1x2x3x4xi32>, tensor<4xi32>) -> tensor<1x3x4x2xi32>
%transpose1 = tosa.transpose %arg1, %perms0 : (tensor<1x2x1x4xi32>, tensor<4xi32>) -> tensor<1x1x4x2xi32>
%transpose0 = tosa.transpose %arg0 {perms = array<i32: 0, 2, 3, 1>}: (tensor<1x2x3x4xi32>) -> tensor<1x3x4x2xi32>
%transpose1 = tosa.transpose %arg1 {perms = array<i32: 0, 2, 3, 1>}: (tensor<1x2x1x4xi32>) -> tensor<1x1x4x2xi32>
%clamp = tosa.clamp %transpose0 {max_val = 1 : i32, min_val = 0 : i32} : (tensor<1x3x4x2xi32>) -> tensor<1x3x4x2xi32>
%abs = tosa.abs %transpose1 : (tensor<1x1x4x2xi32>) -> tensor<1x1x4x2xi32>
%add = tosa.add %clamp, %abs : (tensor<1x3x4x2xi32>, tensor<1x1x4x2xi32>) -> tensor<1x3x4x2xi32>
%perms1 = "tosa.const"() {value = dense<[0, 3, 1, 2]> : tensor<4xi32>} : () -> tensor<4xi32>
%result = tosa.transpose %add, %perms1 : (tensor<1x3x4x2xi32>, tensor<4xi32>) -> tensor<1x2x3x4xi32>
%result = tosa.transpose %add {perms = array<i32: 0, 3, 1, 2>}: (tensor<1x3x4x2xi32>) -> tensor<1x2x3x4xi32>
return %result : tensor<1x2x3x4xi32>
}
@@ -75,11 +67,9 @@ func.func @test_transpose_tracks_to_nullifying_diverging_binary_with_broadcastin
// CHECK-NEXT: %[[RESULT:.*]] = tosa.add %arg0, %arg0
// CHECK-NEXT: return %[[RESULT]]
func.func @test_transpose_tracks_to_nullifying__converging_binary(%arg0: tensor<1x2x3x4xi32>) -> tensor<1x2x3x4xi32> {
%perms0 = "tosa.const"() {value = dense<[0, 2, 3, 1]> : tensor<4xi32>} : () -> tensor<4xi32>
%0 = tosa.transpose %arg0, %perms0 : (tensor<1x2x3x4xi32>, tensor<4xi32>) -> tensor<1x3x4x2xi32>
%0 = tosa.transpose %arg0 {perms = array<i32: 0, 2, 3, 1>}: (tensor<1x2x3x4xi32>) -> tensor<1x3x4x2xi32>
%clamp = tosa.add %0, %0 : (tensor<1x3x4x2xi32>, tensor<1x3x4x2xi32>) -> tensor<1x3x4x2xi32>
%perms1 = "tosa.const"() {value = dense<[0, 3, 1, 2]> : tensor<4xi32>} : () -> tensor<4xi32>
%1 = tosa.transpose %clamp, %perms1 : (tensor<1x3x4x2xi32>, tensor<4xi32>) -> tensor<1x2x3x4xi32>
%1 = tosa.transpose %clamp {perms = array<i32: 0, 3, 1, 2>}: (tensor<1x3x4x2xi32>) -> tensor<1x2x3x4xi32>
return %1 : tensor<1x2x3x4xi32>
}
@@ -102,20 +92,20 @@ func.func @test_torch_conv2d_with_elementwise_in_between(%arg0: tensor<3x3x10x10
%5 = "tosa.const"() <{value = dense_resource<torch_tensor_3_torch.float32> : tensor<3xf32>}> : () -> tensor<3xf32>
%6 = "tosa.const"() <{value = dense<[0, 2, 3, 1]> : tensor<4xi32>}> : () -> tensor<4xi32>
%7 = "tosa.const"() <{value = dense<[0, 3, 1, 2]> : tensor<4xi32>}> : () -> tensor<4xi32>
%8 = tosa.transpose %arg0, %6 : (tensor<3x3x10x10xf32>, tensor<4xi32>) -> tensor<3x10x10x3xf32>
%9 = tosa.transpose %4, %6 : (tensor<3x3x2x2xf32>, tensor<4xi32>) -> tensor<3x2x2x3xf32>
%8 = tosa.transpose %arg0 {perms = array<i32: 0, 2, 3, 1>}: (tensor<3x3x10x10xf32>) -> tensor<3x10x10x3xf32>
%9 = tosa.transpose %4 {perms = array<i32: 0, 2, 3, 1>}: (tensor<3x3x2x2xf32>) -> tensor<3x2x2x3xf32>
%10 = tosa.conv2d %8, %9, %5 {acc_type = f32, dilation = array<i64: 1, 1>, pad = array<i64: 0, 0, 0, 0>, stride = array<i64: 1, 1>} : (tensor<3x10x10x3xf32>, tensor<3x2x2x3xf32>, tensor<3xf32>) -> tensor<3x9x9x3xf32>
%11 = tosa.transpose %10, %7 : (tensor<3x9x9x3xf32>, tensor<4xi32>) -> tensor<3x3x9x9xf32>
%11 = tosa.transpose %10 {perms = array<i32: 0, 3, 1, 2>}: (tensor<3x9x9x3xf32>) -> tensor<3x3x9x9xf32>
%12 = tosa.ceil %11 : (tensor<3x3x9x9xf32>) -> tensor<3x3x9x9xf32>
%13 = tosa.transpose %12, %6 : (tensor<3x3x9x9xf32>, tensor<4xi32>) -> tensor<3x9x9x3xf32>
%14 = tosa.transpose %3, %6 : (tensor<3x3x2x2xf32>, tensor<4xi32>) -> tensor<3x2x2x3xf32>
%13 = tosa.transpose %12 {perms = array<i32: 0, 2, 3, 1>}: (tensor<3x3x9x9xf32>) -> tensor<3x9x9x3xf32>
%14 = tosa.transpose %3 {perms = array<i32: 0, 2, 3, 1>}: (tensor<3x3x2x2xf32>) -> tensor<3x2x2x3xf32>
%15 = tosa.conv2d %13, %14, %2 {acc_type = f32, dilation = array<i64: 1, 1>, pad = array<i64: 0, 0, 0, 0>, stride = array<i64: 1, 1>} : (tensor<3x9x9x3xf32>, tensor<3x2x2x3xf32>, tensor<3xf32>) -> tensor<3x8x8x3xf32>
%16 = tosa.transpose %15, %7 : (tensor<3x8x8x3xf32>, tensor<4xi32>) -> tensor<3x3x8x8xf32>
%16 = tosa.transpose %15 {perms = array<i32: 0, 3, 1, 2>}: (tensor<3x8x8x3xf32>) -> tensor<3x3x8x8xf32>
%17 = tosa.floor %16 : (tensor<3x3x8x8xf32>) -> tensor<3x3x8x8xf32>
%18 = tosa.transpose %17, %6 : (tensor<3x3x8x8xf32>, tensor<4xi32>) -> tensor<3x8x8x3xf32>
%19 = tosa.transpose %1, %6 : (tensor<3x3x2x2xf32>, tensor<4xi32>) -> tensor<3x2x2x3xf32>
%18 = tosa.transpose %17 {perms = array<i32: 0, 2, 3, 1>}: (tensor<3x3x8x8xf32>) -> tensor<3x8x8x3xf32>
%19 = tosa.transpose %1 {perms = array<i32: 0, 2, 3, 1>}: (tensor<3x3x2x2xf32>) -> tensor<3x2x2x3xf32>
%20 = tosa.conv2d %18, %19, %0 {acc_type = f32, dilation = array<i64: 1, 1>, pad = array<i64: 0, 0, 0, 0>, stride = array<i64: 1, 1>} : (tensor<3x8x8x3xf32>, tensor<3x2x2x3xf32>, tensor<3xf32>) -> tensor<3x7x7x3xf32>
%21 = tosa.transpose %20, %7 : (tensor<3x7x7x3xf32>, tensor<4xi32>) -> tensor<3x3x7x7xf32>
%21 = tosa.transpose %20 {perms = array<i32: 0, 3, 1, 2>}: (tensor<3x7x7x3xf32>) -> tensor<3x3x7x7xf32>
return %21 : tensor<3x3x7x7xf32>
}
@@ -126,13 +116,11 @@ func.func @test_torch_conv2d_with_elementwise_in_between(%arg0: tensor<3x3x10x10
// CHECK-NEXT: %[[RES:.*]] = tosa.mul %arg0, %arg1, %[[SHIFT]]
// CHECK-NEXT: return %[[RES]]
func.func @test_mulop_conversion(%arg0: tensor<1x2x3x4xi32>, %arg1: tensor<1x2x3x4xi32>) -> tensor<1x2x3x4xi32> {
%perms0 = "tosa.const"() {value = dense<[0, 2, 3, 1]> : tensor<4xi32>} : () -> tensor<4xi32>
%transpose0 = tosa.transpose %arg0, %perms0 : (tensor<1x2x3x4xi32>, tensor<4xi32>) -> tensor<1x3x4x2xi32>
%transpose1 = tosa.transpose %arg1, %perms0 : (tensor<1x2x3x4xi32>, tensor<4xi32>) -> tensor<1x3x4x2xi32>
%transpose0 = tosa.transpose %arg0 {perms = array<i32: 0, 2, 3, 1>}: (tensor<1x2x3x4xi32>) -> tensor<1x3x4x2xi32>
%transpose1 = tosa.transpose %arg1 {perms = array<i32: 0, 2, 3, 1>}: (tensor<1x2x3x4xi32>) -> tensor<1x3x4x2xi32>
%shift = "tosa.const"() <{value = dense<0> : tensor<1xi8>}> : () -> tensor<1xi8>
%mul = tosa.mul %transpose0, %transpose1, %shift : (tensor<1x3x4x2xi32>, tensor<1x3x4x2xi32>, tensor<1xi8>) -> tensor<1x3x4x2xi32>
%perms1 = "tosa.const"() {value = dense<[0, 3, 1, 2]> : tensor<4xi32>} : () -> tensor<4xi32>
%result = tosa.transpose %mul, %perms1 : (tensor<1x3x4x2xi32>, tensor<4xi32>) -> tensor<1x2x3x4xi32>
%result = tosa.transpose %mul {perms = array<i32: 0, 3, 1, 2>}: (tensor<1x3x4x2xi32>) -> tensor<1x2x3x4xi32>
return %result : tensor<1x2x3x4xi32>
}
@@ -141,15 +129,14 @@ func.func @test_mulop_conversion(%arg0: tensor<1x2x3x4xi32>, %arg1: tensor<1x2x3
// COM: this case is a reshape we don't convert, since can't fold the transpose into it.
// COM: a transform actually occurs underneath the hood, but it results in identical IR.
// CHECK-LABEL: @test_basic_non_broadcasting_reshape
// CHECK-DAG: %[[VAL_1:.*]] = tosa.const_shape {value = dense<[1, 3, 2]> : tensor<3xindex>}
// CHECK-DAG: %[[VAL_2:.*]] = "tosa.const"() <{value = dense<[0, 2, 1]> : tensor<3xi32>}>
// CHECK: %[[VAL_3:.*]] = tosa.reshape %arg0, %[[VAL_1]] : (tensor<2x3xi32>, !tosa.shape<3>) -> tensor<1x3x2xi32>
// CHECK: %[[VAL_4:.*]] = tosa.transpose %[[VAL_3]], %[[VAL_2]] : (tensor<1x3x2xi32>, tensor<3xi32>) -> tensor<1x2x3xi32>
// CHECK: %[[SHAPE:.+]] = tosa.const_shape {value = dense<[1, 3, 2]> : tensor<3xindex>}
// CHECK: %[[RESHAPED:.+]] = tosa.reshape %arg0, %[[SHAPE]] : (tensor<2x3xi32>, !tosa.shape<3>) -> tensor<1x3x2xi32>
// CHECK: tosa.transpose %[[RESHAPED]] {perms = array<i32: 0, 2, 1>} : (tensor<1x3x2xi32>) -> tensor<1x2x3xi32>
func.func @test_basic_non_broadcasting_reshape(%arg0: tensor<2x3xi32>) -> tensor<1x2x3xi32> {
%shape = tosa.const_shape {value = dense<[1, 3, 2]> : tensor<3xindex>} : () -> !tosa.shape<3>
%perms = "tosa.const"() {value = dense<[0, 2, 1]> : tensor<3xi32>} : () -> tensor<3xi32>
%1 = tosa.reshape %arg0, %shape : (tensor<2x3xi32>, !tosa.shape<3>) -> tensor<1x3x2xi32>
%2 = tosa.transpose %1, %perms : (tensor<1x3x2xi32>, tensor<3xi32>) -> tensor<1x2x3xi32>
%2 = tosa.transpose %1 {perms = array<i32: 0, 2, 1>}: (tensor<1x3x2xi32>) -> tensor<1x2x3xi32>
return %2 : tensor<1x2x3xi32>
}
@@ -163,7 +150,7 @@ func.func @test_dynamic_broadcasting_reshape(%arg0: tensor<?xi32>) -> tensor<1x1
%shape = tosa.const_shape {value = dense<[1, -1, 1]> : tensor<3xindex>} : () -> !tosa.shape<3>
%perms = "tosa.const"() {value = dense<[0, 2, 1]> : tensor<3xi32>} : () -> tensor<3xi32>
%1 = tosa.reshape %arg0, %shape : (tensor<?xi32>, !tosa.shape<3>) -> tensor<1x?x1xi32>
%2 = tosa.transpose %1, %perms : (tensor<1x?x1xi32>, tensor<3xi32>) -> tensor<1x1x?xi32>
%2 = tosa.transpose %1 {perms = array<i32: 0, 2, 1>}: (tensor<1x?x1xi32>) -> tensor<1x1x?xi32>
return %2 : tensor<1x1x?xi32>
}
@@ -179,10 +166,9 @@ func.func @test_reshape_for_broadcast(%arg0: tensor<4x3x2xi32>) -> tensor<4x3x2x
%0 = "tosa.const"() {value = dense<[1,2,3,4]> : tensor<4xi32>} : () -> tensor<4xi32>
%1 = tosa.const_shape {value = dense<[1, 1, 4]> : tensor<3xindex>} : () -> !tosa.shape<3>
%reshape = tosa.reshape %0, %1 : (tensor<4xi32>, !tosa.shape<3>) -> tensor<1x1x4xi32>
%perms0 = "tosa.const"() {value = dense<[2, 1, 0]> : tensor<3xi32>} : () -> tensor<3xi32>
%transpose0 = tosa.transpose %arg0, %perms0 : (tensor<4x3x2xi32>, tensor<3xi32>) -> tensor<2x3x4xi32>
%transpose0 = tosa.transpose %arg0 {perms = array<i32: 2, 1, 0>}: (tensor<4x3x2xi32>) -> tensor<2x3x4xi32>
%add = tosa.add %transpose0, %reshape : (tensor<2x3x4xi32>, tensor<1x1x4xi32>) -> tensor<2x3x4xi32>
%transpose1 = tosa.transpose %add, %perms0 : (tensor<2x3x4xi32>, tensor<3xi32>) -> tensor<4x3x2xi32>
%transpose1 = tosa.transpose %add {perms = array<i32: 2, 1, 0>}: (tensor<2x3x4xi32>) -> tensor<4x3x2xi32>
return %transpose1 : tensor<4x3x2xi32>
}
@@ -223,7 +209,7 @@ func.func @test_resnet18_common_case(%arg0: tensor<64xf32>, %arg1: tensor<64xf32
%64 = "tosa.const"() <{value = dense<[0, 3, 1, 2]> : tensor<4xi32>}> : () -> tensor<4xi32>
%69 = "tosa.const"() <{value = dense<9.99999974E-6> : tensor<1xf32>}> : () -> tensor<1xf32>
%70 = "tosa.const"() <{value = dense<5.000000e-01> : tensor<1xf32>}> : () -> tensor<1xf32>
%75 = tosa.transpose %74, %64 : (tensor<1x112x112x64xf32>, tensor<4xi32>) -> tensor<1x64x112x112xf32>
%75 = tosa.transpose %74 {perms = array<i32: 0, 3, 1, 2>}: (tensor<1x112x112x64xf32>) -> tensor<1x64x112x112xf32>
%76 = tosa.add %arg1, %69 : (tensor<64xf32>, tensor<1xf32>) -> tensor<64xf32>
%77 = tosa.pow %76, %70 : (tensor<64xf32>, tensor<1xf32>) -> tensor<64xf32>
%78 = tosa.reciprocal %77 : (tensor<64xf32>) -> tensor<64xf32>
@@ -236,54 +222,45 @@ func.func @test_resnet18_common_case(%arg0: tensor<64xf32>, %arg1: tensor<64xf32
%85 = tosa.reshape %59, %58 : (tensor<64xf32>, !tosa.shape<4>) -> tensor<1x64x1x1xf32>
%86 = tosa.add %84, %85 : (tensor<1x64x112x112xf32>, tensor<1x64x1x1xf32>) -> tensor<1x64x112x112xf32>
%87 = tosa.clamp %86 {max_val = 3.40282347E+38 : f32, min_val = 0.000000e+00 : f32} : (tensor<1x64x112x112xf32>) -> tensor<1x64x112x112xf32>
%88 = tosa.transpose %87, %63 : (tensor<1x64x112x112xf32>, tensor<4xi32>) -> tensor<1x112x112x64xf32>
%88 = tosa.transpose %87 {perms = array<i32: 0, 2, 3, 1>}: (tensor<1x64x112x112xf32>) -> tensor<1x112x112x64xf32>
return %88 : tensor<1x112x112x64xf32>
}
// -----
// CHECK-LABEL: @test_back_to_back_nullifiers
// CHECK: %[[PERMS:.*]] = "tosa.const"
// CHECK: %[[RES:.*]] = tosa.transpose %arg0, %[[PERMS]]
// CHECK: %[[RES:.*]] = tosa.transpose %arg0 {perms = array<i32: 1, 0>}
// CHECK: return %[[RES]]
func.func @test_back_to_back_nullifiers(%arg0: tensor<2x3xi32>) -> tensor<3x2xi32> {
%perms = "tosa.const"() {value = dense<[1, 0]> : tensor<2xi32>} : () -> tensor<2xi32>
%0 = tosa.transpose %arg0, %perms : (tensor<2x3xi32>, tensor<2xi32>) -> tensor<3x2xi32>
%1 = tosa.transpose %0, %perms : (tensor<3x2xi32>, tensor<2xi32>) -> tensor<2x3xi32>
%2 = tosa.transpose %1, %perms : (tensor<2x3xi32>, tensor<2xi32>) -> tensor<3x2xi32>
%0 = tosa.transpose %arg0 {perms = array<i32: 1, 0>}: (tensor<2x3xi32>) -> tensor<3x2xi32>
%1 = tosa.transpose %0 {perms = array<i32: 1, 0>}: (tensor<3x2xi32>) -> tensor<2x3xi32>
%2 = tosa.transpose %1 {perms = array<i32: 1, 0>}: (tensor<2x3xi32>) -> tensor<3x2xi32>
return %2 : tensor<3x2xi32>
}
// -----
// CHECK-LABEL: @test_back_to_back_nullifiers_different_transposes
// CHECK: %[[PERMS:.*]] = "tosa.const"
// CHECK: %[[RES:.*]] = tosa.transpose %arg0, %[[PERMS]]
// CHECK: %[[RES:.*]] = tosa.transpose %arg0 {perms = array<i32: 0, 2, 3, 1>}
// CHECK: return %[[RES]]
func.func @test_back_to_back_nullifiers_different_transposes(%arg0: tensor<2x3x4x5xi32>) -> tensor<2x4x5x3xi32> {
%perms0 = "tosa.const"() {value = dense<[0, 2, 3, 1]> : tensor<4xi32>} : () -> tensor<4xi32>
%perms1 = "tosa.const"() {value = dense<[0, 3, 1, 2]> : tensor<4xi32>} : () -> tensor<4xi32>
%0 = tosa.transpose %arg0, %perms0 : (tensor<2x3x4x5xi32>, tensor<4xi32>) -> tensor<2x4x5x3xi32>
%1 = tosa.transpose %0, %perms1 : (tensor<2x4x5x3xi32>, tensor<4xi32>) -> tensor<2x3x4x5xi32>
%2 = tosa.transpose %1, %perms0 : (tensor<2x3x4x5xi32>, tensor<4xi32>) -> tensor<2x4x5x3xi32>
%0 = tosa.transpose %arg0 {perms = array<i32: 0, 2, 3, 1>}: (tensor<2x3x4x5xi32>) -> tensor<2x4x5x3xi32>
%1 = tosa.transpose %0 {perms = array<i32: 0, 3, 1, 2>}: (tensor<2x4x5x3xi32>) -> tensor<2x3x4x5xi32>
%2 = tosa.transpose %1 {perms = array<i32: 0, 2, 3, 1>}: (tensor<2x3x4x5xi32>) -> tensor<2x4x5x3xi32>
return %2 : tensor<2x4x5x3xi32>
}
// -----
// CHECK-LABEL: @test_no_transform_if_outside_fan_in_cone
// CHECK: tosa.const
// CHECK: %[[CLAMP_IN:.*]] = tosa.transpose
// CHECK: %[[RES2:.*]] = tosa.clamp %[[CLAMP_IN]]
// CHECK: tosa.const
// CHECK: %[[RES1:.*]] = tosa.transpose
// CHECK: return %[[RES1]], %[[RES2]]
func.func @test_no_transform_if_outside_fan_in_cone(%arg0: tensor<3x3x3x3xi32>) -> (tensor<3x3x3x3xi32>, tensor<3x3x3x3xi32>) {
%perms0 = "tosa.const"() {value = dense<[0, 2, 3, 1]> : tensor<4xi32>} : () -> tensor<4xi32>
%0 = tosa.transpose %arg0, %perms0 : (tensor<3x3x3x3xi32>, tensor<4xi32>) -> tensor<3x3x3x3xi32>
%0 = tosa.transpose %arg0 {perms = array<i32: 0, 2, 3, 1>} : (tensor<3x3x3x3xi32>) -> tensor<3x3x3x3xi32>
%clamp = tosa.clamp %0 {max_val = 2147483647 : i32, min_val = 0 : i32} : (tensor<3x3x3x3xi32>) -> tensor<3x3x3x3xi32>
%perms1 = "tosa.const"() {value = dense<[0, 3, 1, 2]> : tensor<4xi32>} : () -> tensor<4xi32>
%1 = tosa.transpose %clamp, %perms1 : (tensor<3x3x3x3xi32>, tensor<4xi32>) -> tensor<3x3x3x3xi32>
%1 = tosa.transpose %clamp {perms = array<i32: 0, 3, 1, 2>} : (tensor<3x3x3x3xi32>) -> tensor<3x3x3x3xi32>
return %1, %clamp : tensor<3x3x3x3xi32>, tensor<3x3x3x3xi32>
}
@@ -298,29 +275,25 @@ func.func @test_two_different_downstream_converge_to_reshape_same_perms(%arg0: t
%shape = tosa.const_shape {value = dense<[1, 64, 1]> : tensor<3xindex>} : () -> !tosa.shape<3>
%1 = tosa.reshape %arg0, %shape : (tensor<64xf32>, !tosa.shape<3>) -> tensor<1x64x1xf32>
%2 = tosa.clamp %1 {max_val = 3.40282347E+38 : f32, min_val = 0.000000e+00 : f32} : (tensor<1x64x1xf32>) -> tensor<1x64x1xf32>
%3 = tosa.transpose %1, %0 : (tensor<1x64x1xf32>, tensor<3xi32>) -> tensor<1x1x64xf32>
%4 = tosa.transpose %2, %0 : (tensor<1x64x1xf32>, tensor<3xi32>) -> tensor<1x1x64xf32>
%3 = tosa.transpose %1 {perms = array<i32: 0, 2, 1>}: (tensor<1x64x1xf32>) -> tensor<1x1x64xf32>
%4 = tosa.transpose %2 {perms = array<i32: 0, 2, 1>}: (tensor<1x64x1xf32>) -> tensor<1x1x64xf32>
return %3, %4 : tensor<1x1x64xf32>, tensor<1x1x64xf32>
}
// -----
// CHECK-LABEL: @test_two_different_downstream_converge_to_reshape_different_perms
// CHECK-DAG: tosa.const
// CHECK-DAG: tosa.const
// CHECK-DAG: %[[RESHAPE:.*]] = tosa.reshape
// CHECK-DAG: %[[CLAMP:.*]] = tosa.clamp %[[RESHAPE]]
// CHECK-DAG: %[[RET1:.*]] = tosa.transpose
// CHECK-DAG: %[[RET2:.*]] = tosa.transpose
// CHECK-DAG: return %[[RET1]], %[[RET2]]
func.func @test_two_different_downstream_converge_to_reshape_different_perms(%arg0: tensor<64xf32>) -> (tensor<1x1x64xf32>, tensor<64x1x1xf32>) {
%0 = "tosa.const"() <{value = dense<[1, 2, 0]> : tensor<3xi32>}> : () -> tensor<3xi32>
%1 = "tosa.const"() <{value = dense<[0, 2, 1]> : tensor<3xi32>}> : () -> tensor<3xi32>
%shape = tosa.const_shape {value = dense<[1, 64, 1]> : tensor<3xindex>} : () -> !tosa.shape<3>
%2 = tosa.reshape %arg0, %shape : (tensor<64xf32>, !tosa.shape<3>) -> tensor<1x64x1xf32>
%3 = tosa.clamp %2 {max_val = 3.40282347E+38 : f32, min_val = 0.000000e+00 : f32} : (tensor<1x64x1xf32>) -> tensor<1x64x1xf32>
%4 = tosa.transpose %2, %1 : (tensor<1x64x1xf32>, tensor<3xi32>) -> tensor<1x1x64xf32>
%5 = tosa.transpose %3, %0 : (tensor<1x64x1xf32>, tensor<3xi32>) -> tensor<64x1x1xf32>
%4 = tosa.transpose %2 {perms = array<i32: 0, 2, 1>}: (tensor<1x64x1xf32>) -> tensor<1x1x64xf32>
%5 = tosa.transpose %3 {perms = array<i32: 1, 2, 0>}: (tensor<1x64x1xf32>) -> tensor<64x1x1xf32>
return %4, %5 : tensor<1x1x64xf32>, tensor<64x1x1xf32>
}
@@ -328,16 +301,15 @@ func.func @test_two_different_downstream_converge_to_reshape_different_perms(%ar
// COM: no transform
// CHECK-LABEL: @test_outside_perms_usage_of_fan_in
// CHECK: tosa.const
// CHECK: tosa.transpose
// CHECK: tosa.clamp
// CHECK: %[[RES1:.*]] = tosa.transpose
// CHECK: %[[RES2:.*]] = tosa.add
// CHECK: return %[[RES1]], %[[RES2]]
func.func @test_outside_perms_usage_of_fan_in(%arg0: tensor<2x3xf32>, %arg1: tensor<3x2xf32>) -> (tensor<2x3xf32>, tensor<3x2xf32>) { %0 = "tosa.const"() <{value = dense<[1, 0]> : tensor<2xi32>}> : () -> tensor<2xi32>
%1 = tosa.transpose %arg0, %0 : (tensor<2x3xf32>, tensor<2xi32>) -> tensor<3x2xf32>
func.func @test_outside_perms_usage_of_fan_in(%arg0: tensor<2x3xf32>, %arg1: tensor<3x2xf32>) -> (tensor<2x3xf32>, tensor<3x2xf32>) {
%1 = tosa.transpose %arg0 {perms = array<i32: 1, 0>}: (tensor<2x3xf32>) -> tensor<3x2xf32>
%2 = tosa.clamp %1 {max_val = 3.40282347E+38 : f32, min_val = 0.000000e+00 : f32} : (tensor<3x2xf32>) -> tensor<3x2xf32>
%3 = tosa.transpose %2, %0 : (tensor<3x2xf32>, tensor<2xi32>) -> tensor<2x3xf32>
%3 = tosa.transpose %2 {perms = array<i32: 1, 0>}: (tensor<3x2xf32>) -> tensor<2x3xf32>
%4 = tosa.add %arg1, %2 : (tensor<3x2xf32>, tensor<3x2xf32>) -> tensor<3x2xf32>
return %3, %4: tensor<2x3xf32>, tensor<3x2xf32>
}
@@ -351,13 +323,12 @@ func.func @test_outside_perms_usage_of_fan_in(%arg0: tensor<2x3xf32>, %arg1: ten
// CHECK-DAG: %[[NEW_ADD:.*]] = tosa.add %arg1, %[[NEW_CLAMP]]
// CHECK: return %[[NEW_CLAMP]], %[[NEW_ADD]]
func.func @test_use_present_in_another_valid_perms_fan_in(%arg0: tensor<2x3xf32>, %arg1: tensor<2x3xf32>) -> (tensor<2x3xf32>, tensor<2x3xf32>) {
%0 = "tosa.const"() <{value = dense<[1, 0]> : tensor<2xi32>}> : () -> tensor<2xi32>
%1 = tosa.transpose %arg0, %0 : (tensor<2x3xf32>, tensor<2xi32>) -> tensor<3x2xf32>
%1 = tosa.transpose %arg0 {perms = array<i32: 1, 0>}: (tensor<2x3xf32>) -> tensor<3x2xf32>
%2 = tosa.clamp %1 {max_val = 3.40282347E+38 : f32, min_val = 0.000000e+00 : f32} : (tensor<3x2xf32>) -> tensor<3x2xf32>
%3 = tosa.transpose %2, %0 : (tensor<3x2xf32>, tensor<2xi32>) -> tensor<2x3xf32>
%4 = tosa.transpose %arg1, %0 : (tensor<2x3xf32>, tensor<2xi32>) -> tensor<3x2xf32>
%3 = tosa.transpose %2 {perms = array<i32: 1, 0>}: (tensor<3x2xf32>) -> tensor<2x3xf32>
%4 = tosa.transpose %arg1 {perms = array<i32: 1, 0>}: (tensor<2x3xf32>) -> tensor<3x2xf32>
%5 = tosa.add %4, %2 : (tensor<3x2xf32>, tensor<3x2xf32>) -> tensor<3x2xf32>
%6 = tosa.transpose %5, %0 : (tensor<3x2xf32>, tensor<2xi32>) -> tensor<2x3xf32>
%6 = tosa.transpose %5 {perms = array<i32: 1, 0>}: (tensor<3x2xf32>) -> tensor<2x3xf32>
return %3, %6: tensor<2x3xf32>, tensor<2x3xf32>
}
@@ -365,7 +336,6 @@ func.func @test_use_present_in_another_valid_perms_fan_in(%arg0: tensor<2x3xf32>
// COM: no transform, since we would get duplicates
// CHECK-LABEL: @test_two_same_perms_fan_in_but_one_doesnt_convert_dependents
// CHECK: tosa.const
// CHECK: tosa.transpose
// CHECK: %[[CEIL:.*]] = tosa.ceil
// CHECK: %[[ADD:.*]] = tosa.add %[[CEIL]]
@@ -373,12 +343,11 @@ func.func @test_use_present_in_another_valid_perms_fan_in(%arg0: tensor<2x3xf32>
// CHECK: %[[RES2:.*]] = tosa.transpose %[[ADD]]
// CHECK: return %[[RES1]], %[[RES2]]
func.func @test_two_same_perms_fan_in_but_one_doesnt_convert_dependents(%arg0: tensor<2x3xi32>, %arg1: tensor<3x2xi32>) -> (tensor<2x3xi32>, tensor<2x3xi32>) {
%0 = "tosa.const"() <{value = dense<[1, 0]> : tensor<2xi32>}> : () -> tensor<2xi32>
%1 = tosa.transpose %arg0, %0 : (tensor<2x3xi32>, tensor<2xi32>) -> tensor<3x2xi32>
%1 = tosa.transpose %arg0 {perms = array<i32: 1, 0>}: (tensor<2x3xi32>) -> tensor<3x2xi32>
%2 = tosa.ceil %1 : (tensor<3x2xi32>) -> tensor<3x2xi32>
%3 = tosa.add %2, %arg1 : (tensor<3x2xi32>, tensor<3x2xi32>) -> tensor<3x2xi32>
%4 = tosa.transpose %2, %0 : (tensor<3x2xi32>, tensor<2xi32>) -> tensor<2x3xi32>
%5 = tosa.transpose %3, %0 : (tensor<3x2xi32>, tensor<2xi32>) -> tensor<2x3xi32>
%4 = tosa.transpose %2 {perms = array<i32: 1, 0>}: (tensor<3x2xi32>) -> tensor<2x3xi32>
%5 = tosa.transpose %3 {perms = array<i32: 1, 0>}: (tensor<3x2xi32>) -> tensor<2x3xi32>
return %4, %5 : tensor<2x3xi32>, tensor<2x3xi32>
}
@@ -388,12 +357,10 @@ func.func @test_two_same_perms_fan_in_but_one_doesnt_convert_dependents(%arg0: t
// CHECK-NEXT: %[[RES:.*]] = tosa.clamp %arg0
// CHECK-NEXT: return %[[RES]], %[[RES]]
func.func @test_direct_use_in_other_transpose_with_same_perms(%arg0: tensor<3x3x3x3xi32>) -> (tensor<3x3x3x3xi32>, tensor<3x3x3x3xi32>) {
%perms0 = "tosa.const"() {value = dense<[0, 2, 3, 1]> : tensor<4xi32>} : () -> tensor<4xi32>
%0 = tosa.transpose %arg0, %perms0 : (tensor<3x3x3x3xi32>, tensor<4xi32>) -> tensor<3x3x3x3xi32>
%0 = tosa.transpose %arg0 {perms = array<i32: 0, 2, 3, 1>}: (tensor<3x3x3x3xi32>) -> tensor<3x3x3x3xi32>
%clamp = tosa.clamp %0 {max_val = 2147483647 : i32, min_val = 0 : i32} : (tensor<3x3x3x3xi32>) -> tensor<3x3x3x3xi32>
%perms1 = "tosa.const"() {value = dense<[0, 3, 1, 2]> : tensor<4xi32>} : () -> tensor<4xi32>
%1 = tosa.transpose %clamp, %perms1 : (tensor<3x3x3x3xi32>, tensor<4xi32>) -> tensor<3x3x3x3xi32>
%2 = tosa.transpose %clamp, %perms1 : (tensor<3x3x3x3xi32>, tensor<4xi32>) -> tensor<3x3x3x3xi32>
%1 = tosa.transpose %clamp {perms = array<i32: 0, 3, 1, 2>}: (tensor<3x3x3x3xi32>) -> tensor<3x3x3x3xi32>
%2 = tosa.transpose %clamp {perms = array<i32: 0, 3, 1, 2>}: (tensor<3x3x3x3xi32>) -> tensor<3x3x3x3xi32>
return %1, %2 : tensor<3x3x3x3xi32>, tensor<3x3x3x3xi32>
}
@@ -405,8 +372,7 @@ func.func @test_direct_use_in_other_transpose_with_same_perms(%arg0: tensor<3x3x
// CHECK: return %[[NEW]]
func.func @test_const_transpose() -> tensor<2x3xi32> {
%0 = "tosa.const"() {value = dense<0> : tensor<3x2xi32>} : () -> tensor<3x2xi32>
%perms = "tosa.const"() {value = dense<[1, 0]> : tensor<2xi32>} : () -> tensor<2xi32>
%1 = tosa.transpose %0, %perms : (tensor<3x2xi32>, tensor<2xi32>) -> tensor<2x3xi32>
%1 = tosa.transpose %0 {perms = array<i32: 1, 0>}: (tensor<3x2xi32>) -> tensor<2x3xi32>
return %1 : tensor<2x3xi32>
}
@@ -420,8 +386,7 @@ func.func @test_const_transpose() -> tensor<2x3xi32> {
func.func @test_transpose_tracks_to_const_single_step() -> tensor<1x2x3x4xi32> {
%0 = "tosa.const"() {value = dense<0> : tensor<1x3x4x2xi32>} : () -> tensor<1x3x4x2xi32>
%clamp = tosa.clamp %0 {max_val = 2147483647 : i32, min_val = 0 : i32} : (tensor<1x3x4x2xi32>) -> tensor<1x3x4x2xi32>
%perms1 = "tosa.const"() {value = dense<[0, 3, 1, 2]> : tensor<4xi32>} : () -> tensor<4xi32>
%1 = tosa.transpose %clamp, %perms1 : (tensor<1x3x4x2xi32>, tensor<4xi32>) -> tensor<1x2x3x4xi32>
%1 = tosa.transpose %clamp {perms = array<i32: 0, 3, 1, 2>}: (tensor<1x3x4x2xi32>) -> tensor<1x2x3x4xi32>
return %1 : tensor<1x2x3x4xi32>
}
@@ -434,13 +399,11 @@ func.func @test_transpose_tracks_to_const_single_step() -> tensor<1x2x3x4xi32> {
// CHECK: %[[NEW_NOT:.*]] = tosa.bitwise_not %[[NEW_ABS]] : (tensor<1x2x3x4xi32>) -> tensor<1x2x3x4xi32>
// CHECK: return %[[NEW_NOT]]
func.func @test_static_unary_path_to_const() -> tensor<1x2x3x4xi32> {
%perms0 = "tosa.const"() {value = dense<[0, 2, 3, 1]> : tensor<4xi32>} : () -> tensor<4xi32>
%0 = "tosa.const"() {value = dense<1> : tensor<1x3x4x2xi32>} : () -> tensor<1x3x4x2xi32>
%clamp = tosa.clamp %0 {max_val = 2147483647 : i32, min_val = 0 : i32} : (tensor<1x3x4x2xi32>) -> tensor<1x3x4x2xi32>
%abs = tosa.abs %clamp : (tensor<1x3x4x2xi32>) -> tensor<1x3x4x2xi32>
%bitwise_not = tosa.bitwise_not %abs : (tensor<1x3x4x2xi32>) -> tensor<1x3x4x2xi32>
%perms1 = "tosa.const"() {value = dense<[0, 3, 1, 2]> : tensor<4xi32>} : () -> tensor<4xi32>
%1 = tosa.transpose %bitwise_not, %perms1 : (tensor<1x3x4x2xi32>, tensor<4xi32>) -> tensor<1x2x3x4xi32>
%1 = tosa.transpose %bitwise_not {perms = array<i32: 0, 3, 1, 2>}: (tensor<1x3x4x2xi32>) -> tensor<1x2x3x4xi32>
return %1 : tensor<1x2x3x4xi32>
}
@@ -455,16 +418,14 @@ func.func @test_static_unary_path_to_const() -> tensor<1x2x3x4xi32> {
// CHECK: %[[NEW_ADD:.*]] = tosa.add %[[NEW_ABS]], %[[NEW_CLAMP]] : (tensor<1x2x3x4xi32>, tensor<1x2x3x4xi32>) -> tensor<1x2x3x4xi32>
// CHECK: return %[[NEW_ADD]]
func.func @test_static_diverges_to_non_splat_const_and_nullifying(%arg0: tensor<1x2x3x4xi32>) -> tensor<1x2x3x4xi32> {
%perms0 = "tosa.const"() {value = dense<[0, 2, 3, 1]> : tensor<4xi32>} : () -> tensor<4xi32>
%transpose0 = tosa.transpose %arg0, %perms0 : (tensor<1x2x3x4xi32>, tensor<4xi32>) -> tensor<1x3x4x2xi32>
%transpose0 = tosa.transpose %arg0 {perms = array<i32: 0, 2, 3, 1>}: (tensor<1x2x3x4xi32>) -> tensor<1x3x4x2xi32>
%const = "tosa.const"() {value = dense<[[[[1, 2], [3, 4], [5, 6], [7, 8]],
[[9, 10], [11, 12], [13, 14], [15, 16]],
[[17, 18], [19, 20], [21, 22], [23, 24]]]]> : tensor<1x3x4x2xi32>} : () -> tensor<1x3x4x2xi32>
%clamp = tosa.clamp %transpose0 {max_val = 2147483647 : i32, min_val = 0 : i32} : (tensor<1x3x4x2xi32>) -> tensor<1x3x4x2xi32>
%abs = tosa.abs %const : (tensor<1x3x4x2xi32>) -> tensor<1x3x4x2xi32>
%add = tosa.add %abs, %clamp : (tensor<1x3x4x2xi32>, tensor<1x3x4x2xi32>) -> tensor<1x3x4x2xi32>
%perms2 = "tosa.const"() {value = dense<[0, 3, 1, 2]> : tensor<4xi32>} : () -> tensor<4xi32>
%result = tosa.transpose %add, %perms2 : (tensor<1x3x4x2xi32>, tensor<4xi32>) -> tensor<1x2x3x4xi32>
%result = tosa.transpose %add {perms = array<i32: 0, 3, 1, 2>}: (tensor<1x3x4x2xi32>) -> tensor<1x2x3x4xi32>
return %result : tensor<1x2x3x4xi32>
}
@@ -474,12 +435,10 @@ func.func @test_static_diverges_to_non_splat_const_and_nullifying(%arg0: tensor<
// CHECK-NEXT: %[[RES:.*]] = tosa.clamp %arg0
// CHECK-NEXT: return %[[RES]], %[[RES]]
func.func @test_multi_downstream_both_nullify(%arg0: tensor<3x3x3x3xi32>) -> (tensor<3x3x3x3xi32>, tensor<3x3x3x3xi32>) {
%perms0 = "tosa.const"() {value = dense<[0, 2, 3, 1]> : tensor<4xi32>} : () -> tensor<4xi32>
%0 = tosa.transpose %arg0, %perms0 : (tensor<3x3x3x3xi32>, tensor<4xi32>) -> tensor<3x3x3x3xi32>
%0 = tosa.transpose %arg0 {perms = array<i32: 0, 2, 3, 1>}: (tensor<3x3x3x3xi32>) -> tensor<3x3x3x3xi32>
%clamp = tosa.clamp %0 {max_val = 2147483647 : i32, min_val = 0 : i32} : (tensor<3x3x3x3xi32>) -> tensor<3x3x3x3xi32>
%perms1 = "tosa.const"() {value = dense<[0, 3, 1, 2]> : tensor<4xi32>} : () -> tensor<4xi32>
%1 = tosa.transpose %clamp, %perms1 : (tensor<3x3x3x3xi32>, tensor<4xi32>) -> tensor<3x3x3x3xi32>
%2 = tosa.transpose %clamp, %perms1 : (tensor<3x3x3x3xi32>, tensor<4xi32>) -> tensor<3x3x3x3xi32>
%1 = tosa.transpose %clamp {perms = array<i32: 0, 3, 1, 2>}: (tensor<3x3x3x3xi32>) -> tensor<3x3x3x3xi32>
%2 = tosa.transpose %clamp {perms = array<i32: 0, 3, 1, 2>}: (tensor<3x3x3x3xi32>) -> tensor<3x3x3x3xi32>
return %1, %2 : tensor<3x3x3x3xi32>, tensor<3x3x3x3xi32>
}
@@ -487,19 +446,15 @@ func.func @test_multi_downstream_both_nullify(%arg0: tensor<3x3x3x3xi32>) -> (te
// COM: we don't perform this transformation intentionally, since we would then get duplicates
// CHECK-LABEL: @test_multi_downstream_one_nullifies_upstream_other_does_not
// CHECK: tosa.const
// CHECK: tosa.transpose
// CHECK: tosa.clamp
// CHECK: tosa.const
// CHECK: tosa.transpose
// CHECK: tosa.transpose
func.func @test_multi_downstream_one_nullifies_upstream_other_does_not(%arg0: tensor<3x3x3x3xi32>) -> (tensor<3x3x3x3xi32>, tensor<3x3x3x3xi32>) {
%perms0 = "tosa.const"() {value = dense<[0, 2, 3, 1]> : tensor<4xi32>} : () -> tensor<4xi32>
%0 = tosa.transpose %arg0, %perms0 : (tensor<3x3x3x3xi32>, tensor<4xi32>) -> tensor<3x3x3x3xi32>
%0 = tosa.transpose %arg0 {perms = array<i32: 0, 2, 3, 1>}: (tensor<3x3x3x3xi32>) -> tensor<3x3x3x3xi32>
%clamp = tosa.clamp %0 {max_val = 2147483647 : i32, min_val = 0 : i32} : (tensor<3x3x3x3xi32>) -> tensor<3x3x3x3xi32>
%perms1 = "tosa.const"() {value = dense<[0, 3, 1, 2]> : tensor<4xi32>} : () -> tensor<4xi32>
%1 = tosa.transpose %clamp, %perms1 : (tensor<3x3x3x3xi32>, tensor<4xi32>) -> tensor<3x3x3x3xi32>
%2 = tosa.transpose %clamp, %perms0 : (tensor<3x3x3x3xi32>, tensor<4xi32>) -> tensor<3x3x3x3xi32>
%1 = tosa.transpose %clamp {perms = array<i32: 0, 3, 1, 2>}: (tensor<3x3x3x3xi32>) -> tensor<3x3x3x3xi32>
%2 = tosa.transpose %clamp {perms = array<i32: 0, 2, 3, 1>}: (tensor<3x3x3x3xi32>) -> tensor<3x3x3x3xi32>
return %1, %2 : tensor<3x3x3x3xi32>, tensor<3x3x3x3xi32>
}
@@ -508,9 +463,8 @@ func.func @test_multi_downstream_one_nullifies_upstream_other_does_not(%arg0: te
// CHECK-LABEL: @test_unknown_dim_inner_replacement_matches
// CHECK-NEXT: return %arg0
func.func @test_unknown_dim_inner_replacement_matches(%arg0: tensor<3x2xi32>) -> tensor<3x2xi32> {
%perms = "tosa.const"() {value = dense<[1, 0]> : tensor<2xi32>} : () -> tensor<2xi32>
%0 = tosa.transpose %arg0, %perms : (tensor<3x2xi32>, tensor<2xi32>) -> tensor<?x3xi32>
%1 = tosa.transpose %0, %perms : (tensor<?x3xi32>, tensor<2xi32>) -> tensor<3x2xi32>
%0 = tosa.transpose %arg0 {perms = array<i32: 1, 0>}: (tensor<3x2xi32>) -> tensor<?x3xi32>
%1 = tosa.transpose %0 {perms = array<i32: 1, 0>}: (tensor<?x3xi32>) -> tensor<3x2xi32>
return %1 : tensor<3x2xi32>
}
@@ -520,9 +474,8 @@ func.func @test_unknown_dim_inner_replacement_matches(%arg0: tensor<3x2xi32>) ->
// CHECK-LABEL: @test_unknown_dim_outer_replacement_matches
// CHECK-NEXT: return %arg0
func.func @test_unknown_dim_outer_replacement_matches(%arg0: tensor<3x?xi32>) -> tensor<3x?xi32> {
%perms = "tosa.const"() {value = dense<[1, 0]> : tensor<2xi32>} : () -> tensor<2xi32>
%0 = tosa.transpose %arg0, %perms : (tensor<3x?xi32>, tensor<2xi32>) -> tensor<2x3xi32>
%1 = tosa.transpose %0, %perms : (tensor<2x3xi32>, tensor<2xi32>) -> tensor<3x?xi32>
%0 = tosa.transpose %arg0 {perms = array<i32: 1, 0>}: (tensor<3x?xi32>) -> tensor<2x3xi32>
%1 = tosa.transpose %0 {perms = array<i32: 1, 0>}: (tensor<2x3xi32>) -> tensor<3x?xi32>
return %1 : tensor<3x?xi32>
}
@@ -534,47 +487,28 @@ func.func @test_unknown_dim_outer_replacement_matches(%arg0: tensor<3x?xi32>) ->
// CHECK-NEXT: %[[ADD:.*]] = tosa.add %[[CLAMP]], %[[ABS]]
// CHECK-NEXT: return %[[ADD]]
func.func @test_transpose_tracks_to_nullifying_diverging_binary_unknown_dim_replacements_match(%arg0: tensor<1x?x3x4xi32>, %arg1: tensor<1x2x?x4xi32>) -> tensor<1x2x3x4xi32> {
%perms0 = "tosa.const"() {value = dense<[0, 2, 3, 1]> : tensor<4xi32>} : () -> tensor<4xi32>
%transpose0 = tosa.transpose %arg0, %perms0 : (tensor<1x?x3x4xi32>, tensor<4xi32>) -> tensor<?x3x4x?xi32>
%transpose1 = tosa.transpose %arg1, %perms0 : (tensor<1x2x?x4xi32>, tensor<4xi32>) -> tensor<1x?x?x2xi32>
%transpose0 = tosa.transpose %arg0 {perms = array<i32: 0, 2, 3, 1>}: (tensor<1x?x3x4xi32>) -> tensor<?x3x4x?xi32>
%transpose1 = tosa.transpose %arg1 {perms = array<i32: 0, 2, 3, 1>}: (tensor<1x2x?x4xi32>) -> tensor<1x?x?x2xi32>
%clamp = tosa.clamp %transpose0 {min_val = 0 : i32, max_val = 1 : i32} : (tensor<?x3x4x?xi32>) -> tensor<?x3x4x?xi32>
%abs = tosa.abs %transpose1 : (tensor<1x?x?x2xi32>) -> tensor<1x?x?x2xi32>
%add = tosa.add %clamp, %abs : (tensor<?x3x4x?xi32>, tensor<1x?x?x2xi32>) -> tensor<1x3x4x2xi32>
%perms1 = "tosa.const"() {value = dense<[0, 3, 1, 2]> : tensor<4xi32>} : () -> tensor<4xi32>
%result = tosa.transpose %add, %perms1 : (tensor<1x3x4x2xi32>, tensor<4xi32>) -> tensor<1x2x3x4xi32>
%result = tosa.transpose %add {perms = array<i32: 0, 3, 1, 2>}: (tensor<1x3x4x2xi32>) -> tensor<1x2x3x4xi32>
return %result : tensor<1x2x3x4xi32>
}
// -----
// COM: we cannot do anything to the transpose in this case.
// CHECK-LABEL: @test_unimplemented_non_const_perms
// CHECK: tosa.const
// CHECK-NEXT: tosa.transpose
// CHECK-NEXT: return
func.func @test_unimplemented_non_const_perms(%perms: tensor<2xi32>) -> tensor<?x?xi32> {
%0 = "tosa.const"() {value = dense<0> : tensor<3x2xi32>} : () -> tensor<3x2xi32>
%1 = tosa.transpose %0, %perms : (tensor<3x2xi32>, tensor<2xi32>) -> tensor<?x?xi32>
return %1 : tensor<?x?xi32>
}
// -----
// COM: due to tracking back to a non-nullifying transpose, we can't get rid of the transposes entirely.
// COM: later editions of the pass may wish to fold these into a single transpose.
// CHECK-LABEL: @test_unimplemented_transpose_tracks_to_non_nullifying_transpose_single_step
// CHECK: tosa.const
// CHECK-NEXT: tosa.transpose
// CHECK-NEXT: tosa.clamp
// CHECK-NEXT: tosa.const
// CHECK-NEXT: tosa.transpose
// CHECK-NEXT: return
func.func @test_unimplemented_transpose_tracks_to_non_nullifying_transpose_single_step(%arg0: tensor<1x2x3x4xi32>) -> tensor<1x2x4x3xi32> {
%perms0 = "tosa.const"() {value = dense<[0, 3, 2, 1]> : tensor<4xi32>} : () -> tensor<4xi32>
%0 = tosa.transpose %arg0, %perms0 : (tensor<1x2x3x4xi32>, tensor<4xi32>) -> tensor<1x4x3x2xi32>
%0 = tosa.transpose %arg0 {perms = array<i32: 0, 3, 2, 1>}: (tensor<1x2x3x4xi32>) -> tensor<1x4x3x2xi32>
%clamp = tosa.clamp %0 {min_val = 0 : i32, max_val = 1 : i32} : (tensor<1x4x3x2xi32>) -> tensor<1x4x3x2xi32>
%perms1 = "tosa.const"() {value = dense<[0, 3, 1, 2]> : tensor<4xi32>} : () -> tensor<4xi32>
%1 = tosa.transpose %clamp, %perms1 : (tensor<1x4x3x2xi32>, tensor<4xi32>) -> tensor<1x2x4x3xi32>
%1 = tosa.transpose %clamp {perms = array<i32: 0, 3, 1, 2>} : (tensor<1x4x3x2xi32>) -> tensor<1x2x4x3xi32>
return %1 : tensor<1x2x4x3xi32>
}
@@ -582,28 +516,24 @@ func.func @test_unimplemented_transpose_tracks_to_non_nullifying_transpose_singl
// COM: we don't deal with this case. resolution of shapes required.
// CHECK-LABEL: @test_unimplemented_unknown_dim_input_nullifying_pair
// CHECK: tosa.const
// CHECK-NEXT: tosa.transpose
// CHECK-NEXT: tosa.transpose
// CHECK-NEXT: return
func.func @test_unimplemented_unknown_dim_input_nullifying_pair(%arg0: tensor<3x?xi32>) -> tensor<3x2xi32> {
%perms = "tosa.const"() {value = dense<[1, 0]> : tensor<2xi32>} : () -> tensor<2xi32>
%0 = tosa.transpose %arg0, %perms : (tensor<3x?xi32>, tensor<2xi32>) -> tensor<2x3xi32>
%1 = tosa.transpose %0, %perms : (tensor<2x3xi32>, tensor<2xi32>) -> tensor<3x2xi32>
%0 = tosa.transpose %arg0 {perms = array<i32: 1, 0>}: (tensor<3x?xi32>) -> tensor<2x3xi32>
%1 = tosa.transpose %0 {perms = array<i32: 1, 0>}: (tensor<2x3xi32>) -> tensor<3x2xi32>
return %1 : tensor<3x2xi32>
}
// -----
// CHECK-LABEL: @test_unimplemented_unknown_dim_replacement_does_not_match
// CHECK: tosa.const
// CHECK-NEXT: tosa.transpose
// CHECK-NEXT: tosa.transpose
// CHECK-NEXT: return
func.func @test_unimplemented_unknown_dim_replacement_does_not_match(%arg0: tensor<3x?xi32>) -> tensor<?x?xi32> {
%perms = "tosa.const"() {value = dense<[1, 0]> : tensor<2xi32>} : () -> tensor<2xi32>
%0 = tosa.transpose %arg0, %perms : (tensor<3x?xi32>, tensor<2xi32>) -> tensor<?x3xi32>
%1 = tosa.transpose %0, %perms : (tensor<?x3xi32>, tensor<2xi32>) -> tensor<?x?xi32>
%0 = tosa.transpose %arg0 {perms = array<i32: 1, 0>}: (tensor<3x?xi32>) -> tensor<?x3xi32>
%1 = tosa.transpose %0 {perms = array<i32: 1, 0>}: (tensor<?x3xi32>) -> tensor<?x?xi32>
return %1 : tensor<?x?xi32>
}
@@ -611,53 +541,43 @@ func.func @test_unimplemented_unknown_dim_replacement_does_not_match(%arg0: tens
// COM: this would be able to be converted if --tosa-infer-shapes was run beforehand
// CHECK-LABEL: @test_unimplemented_unranked_tensors_present
// CHECK: tosa.const
// CHECK-NEXT: tosa.transpose
// CHECK-NEXT: tosa.transpose
// CHECK-NEXT: return
func.func @test_unimplemented_unranked_tensors_present(%arg0: tensor<3x2xi32>) -> tensor<*xi32> {
%perms = "tosa.const"() {value = dense<[0, 1]> : tensor<2xi32>} : () -> tensor<2xi32>
%0 = tosa.transpose %arg0, %perms : (tensor<3x2xi32>, tensor<2xi32>) -> tensor<*xi32>
%1 = tosa.transpose %0, %perms : (tensor<*xi32>, tensor<2xi32>) -> tensor<*xi32>
%0 = tosa.transpose %arg0 {perms = array<i32: 0, 1>}: (tensor<3x2xi32>) -> tensor<*xi32>
%1 = tosa.transpose %0 {perms = array<i32: 0, 1>}: (tensor<*xi32>) -> tensor<*xi32>
return %1 : tensor<*xi32>
}
// -----
// CHECK-LABEL: @test_unimplemented_unranked_everything
// CHECK: tosa.const
// CHECK-NEXT: tosa.transpose
// CHECK-NEXT: tosa.transpose
// CHECK-NEXT: return
func.func @test_unimplemented_unranked_everything(%arg0: tensor<*xi32>) -> tensor<*xi32> {
%perms = "tosa.const"() {value = dense<[1, 0]> : tensor<2xi32>} : () -> tensor<2xi32>
%0 = tosa.transpose %arg0, %perms : (tensor<*xi32>, tensor<2xi32>) -> tensor<*xi32>
%1 = tosa.transpose %0, %perms : (tensor<*xi32>, tensor<2xi32>) -> tensor<*xi32>
%0 = tosa.transpose %arg0 {perms = array<i32: 1, 0>}: (tensor<*xi32>) -> tensor<*xi32>
%1 = tosa.transpose %0 {perms = array<i32: 1, 0>}: (tensor<*xi32>) -> tensor<*xi32>
return %1 : tensor<*xi32>
}
// -----
// CHECK-LABEL: @test_unimplemented_static_diverges_to_one_nullifying_one_non_nullifying
// CHECK: tosa.const
// CHECK-NEXT: tosa.const
// CHECK-NEXT: tosa.transpose
// CHECK-NEXT: tosa.transpose
// CHECK-NEXT: tosa.clamp
// CHECK-NEXT: tosa.abs
// CHECK-NEXT: tosa.add
// CHECK-NEXT: tosa.const
// CHECK-NEXT: tosa.transpose
// CHECK-NEXT: return
func.func @test_unimplemented_static_diverges_to_one_nullifying_one_non_nullifying(%arg0: tensor<1x2x3x4xi32>, %arg1: tensor<1x2x4x3xi32>) -> tensor<1x2x3x4xi32> {
%perms0 = "tosa.const"() {value = dense<[0, 2, 3, 1]> : tensor<4xi32>} : () -> tensor<4xi32>
%perms1 = "tosa.const"() {value = dense<[0, 3, 2, 1]> : tensor<4xi32>} : () -> tensor<4xi32>
%transpose0 = tosa.transpose %arg0, %perms0 : (tensor<1x2x3x4xi32>, tensor<4xi32>) -> tensor<1x3x4x2xi32>
%transpose1 = tosa.transpose %arg1, %perms1 : (tensor<1x2x4x3xi32>, tensor<4xi32>) -> tensor<1x3x4x2xi32>
%transpose0 = tosa.transpose %arg0 {perms = array<i32: 0, 2, 3, 1>}: (tensor<1x2x3x4xi32>) -> tensor<1x3x4x2xi32>
%transpose1 = tosa.transpose %arg1 {perms = array<i32: 0, 3, 2, 1>}: (tensor<1x2x4x3xi32>) -> tensor<1x3x4x2xi32>
%clamp = tosa.clamp %transpose0 {min_val = 0 : i32, max_val = 1 : i32} : (tensor<1x3x4x2xi32>) -> tensor<1x3x4x2xi32>
%abs = tosa.abs %transpose1 : (tensor<1x3x4x2xi32>) -> tensor<1x3x4x2xi32>
%add = tosa.add %clamp, %abs : (tensor<1x3x4x2xi32>, tensor<1x3x4x2xi32>) -> tensor<1x3x4x2xi32>
%perms2 = "tosa.const"() {value = dense<[0, 3, 1, 2]> : tensor<4xi32>} : () -> tensor<4xi32>
%result = tosa.transpose %add, %perms2 : (tensor<1x3x4x2xi32>, tensor<4xi32>) -> tensor<1x2x3x4xi32>
%result = tosa.transpose %add {perms = array<i32: 0, 3, 1, 2>}: (tensor<1x3x4x2xi32>) -> tensor<1x2x3x4xi32>
return %result : tensor<1x2x3x4xi32>
}

View File

@@ -6,10 +6,8 @@
// CHECK: }
func.func @test_cancel_transpose_transpose(%arg0: tensor<1x2x3xi32>) -> (tensor<1x2x3xi32>) {
%0 = "tosa.const"() {value = dense<[1, 2, 0]> : tensor<3xi32>} : () -> tensor<3xi32>
%1 = tosa.transpose %arg0, %0 : (tensor<1x2x3xi32>, tensor<3xi32>) -> tensor<2x3x1xi32>
%2 = "tosa.const"() {value = dense<[2, 0, 1]> : tensor<3xi32>} : () -> tensor<3xi32>
%3 = tosa.transpose %1, %2 : (tensor<2x3x1xi32>, tensor<3xi32>) -> tensor<1x2x3xi32>
%1 = tosa.transpose %arg0 { perms = array<i32: 1, 2, 0> }: (tensor<1x2x3xi32>) -> tensor<2x3x1xi32>
%3 = tosa.transpose %1 { perms = array<i32: 2, 0, 1> }: (tensor<2x3x1xi32>) -> tensor<1x2x3xi32>
return %3 : tensor<1x2x3xi32>
}
@@ -21,8 +19,7 @@ func.func @test_cancel_transpose_transpose(%arg0: tensor<1x2x3xi32>) -> (tensor<
// CHECK: }
func.func @test_remove_identity_transpose(%arg0: tensor<1x2x3xi32>) -> (tensor<1x2x3xi32>) {
%0 = "tosa.const"() {value = dense<[0, 1, 2]> : tensor<3xi32>} : () -> tensor<3xi32>
%1 = tosa.transpose %arg0, %0 : (tensor<1x2x3xi32>, tensor<3xi32>) -> tensor<1x2x3xi32>
%1 = tosa.transpose %arg0 { perms = array<i32: 0, 1, 2> }: (tensor<1x2x3xi32>) -> tensor<1x2x3xi32>
return %1 : tensor<1x2x3xi32>
}
@@ -30,16 +27,13 @@ func.func @test_remove_identity_transpose(%arg0: tensor<1x2x3xi32>) -> (tensor<1
// CHECK-LABEL: func.func @test_do_not_cancel_different_transpose(
// CHECK-SAME: %[[VAL_0:.*]]: tensor<2x3x4x5xi32>) -> tensor<5x4x3x2xi32> {
// CHECK: %[[VAL_1:.*]] = "tosa.const"() <{value = dense<[3, 2, 1, 0]> : tensor<4xi32>}> : () -> tensor<4xi32>
// CHECK: %[[VAL_2:.*]] = tosa.transpose %[[VAL_0]], %[[VAL_1]] : (tensor<2x3x4x5xi32>, tensor<4xi32>) -> tensor<5x4x3x2xi32>
// CHECK: %[[VAL_2:.*]] = tosa.transpose %[[VAL_0]] {perms = array<i32: 3, 2, 1, 0>} : (tensor<2x3x4x5xi32>) -> tensor<5x4x3x2xi32>
// CHECK: return %[[VAL_2]] : tensor<5x4x3x2xi32>
// CHECK: }
func.func @test_do_not_cancel_different_transpose(%arg0: tensor<2x3x4x5xi32>) -> (tensor<5x4x3x2xi32>) {
%0 = "tosa.const"() {value = dense<[1, 2, 0, 3]> : tensor<4xi32>} : () -> tensor<4xi32>
%1 = tosa.transpose %arg0, %0 : (tensor<2x3x4x5xi32>, tensor<4xi32>) -> tensor<3x4x2x5xi32>
%2 = "tosa.const"() {value = dense<[3, 1, 0, 2]> : tensor<4xi32>} : () -> tensor<4xi32>
%3 = tosa.transpose %1, %2 : (tensor<3x4x2x5xi32>, tensor<4xi32>) -> tensor<5x4x3x2xi32>
%1 = tosa.transpose %arg0 { perms = array<i32: 1, 2, 0, 3> }: (tensor<2x3x4x5xi32>) -> tensor<3x4x2x5xi32>
%3 = tosa.transpose %1 { perms = array<i32: 3, 1, 0, 2> }: (tensor<3x4x2x5xi32>) -> tensor<5x4x3x2xi32>
return %3 : tensor<5x4x3x2xi32>
}
@@ -47,15 +41,12 @@ func.func @test_do_not_cancel_different_transpose(%arg0: tensor<2x3x4x5xi32>) ->
// CHECK-LABEL: func.func @test_prefer_compose_transpose(
// CHECK-SAME: %[[VAL_0:.*]]: tensor<1x2x3x4xi32>) -> tensor<4x3x2x1xi32> {
// CHECK: %[[VAL_1:.*]] = "tosa.const"() <{value = dense<[3, 2, 1, 0]> : tensor<4xi32>}> : () -> tensor<4xi32>
// CHECK: %[[VAL_2:.*]] = tosa.transpose %[[VAL_0]], %[[VAL_1]] : (tensor<1x2x3x4xi32>, tensor<4xi32>) -> tensor<4x3x2x1xi32>
// CHECK: %[[VAL_2:.*]] = tosa.transpose %[[VAL_0]] {perms = array<i32: 3, 2, 1, 0>} : (tensor<1x2x3x4xi32>) -> tensor<4x3x2x1xi32>
// CHECK: return %[[VAL_2]] : tensor<4x3x2x1xi32>
// CHECK: }
func.func @test_prefer_compose_transpose(%arg0: tensor<1x2x3x4xi32>) -> (tensor<4x3x2x1xi32>) {
%0 = "tosa.const"() {value = dense<[1, 2, 0, 3]> : tensor<4xi32>} : () -> tensor<4xi32>
%1 = tosa.transpose %arg0, %0 : (tensor<1x2x3x4xi32>, tensor<4xi32>) -> tensor<2x3x1x4xi32>
%2 = "tosa.const"() {value = dense<[3, 1, 0, 2]> : tensor<4xi32>} : () -> tensor<4xi32>
%3 = tosa.transpose %1, %2 : (tensor<2x3x1x4xi32>, tensor<4xi32>) -> tensor<4x3x2x1xi32>
%1 = tosa.transpose %arg0 { perms = array<i32: 1, 2, 0, 3> }: (tensor<1x2x3x4xi32>) -> tensor<2x3x1x4xi32>
%3 = tosa.transpose %1 { perms = array<i32: 3, 1, 0, 2> }: (tensor<2x3x1x4xi32>) -> tensor<4x3x2x1xi32>
return %3 : tensor<4x3x2x1xi32>
}