[MLIR][TOSA] add additional verification to TOSA (#108133)
---------- Motivation: ---------- Spec conformance. Allows assumptions to be made in TOSA code. ------------ Changes Made: ------------ Add full permutation tensor verification to tosa.TRANSPOSE. Priorly would not verify that permuted values were between 0 - (rank - 1). Update tosa.TRANSPOSE perms data type to be strictly i32. Verify input/output shapes for tosa.TRANSPOSE. Add verifier to tosa.CONST, with consideration for quantization. Fix TOSA conformance of tensor type to disallow dimensions with size 0 for ranked tensors, per spec. This is not the same as rank 0 tensors. Here is an example of a disallowed tensor: tensor<3x0xi32>. Naturally, this means that the number of elements in a TOSA tensor will always be greater than 0. Signed-off-by: Arteen Abrishami <arteen.abrishami@arm.com>
This commit is contained in:
@@ -3,8 +3,8 @@ add_mlir_doc(TosaOps TosaOps Dialects/ -gen-op-doc)
|
||||
add_mlir_interface(TosaInterfaces)
|
||||
|
||||
set(LLVM_TARGET_DEFINITIONS TosaOps.td)
|
||||
mlir_tablegen(TosaAttributes.h.inc -gen-attrdef-decls)
|
||||
mlir_tablegen(TosaAttributes.cpp.inc -gen-attrdef-defs)
|
||||
mlir_tablegen(TosaAttributes.h.inc -gen-attrdef-decls -attrdefs-dialect=tosa)
|
||||
mlir_tablegen(TosaAttributes.cpp.inc -gen-attrdef-defs -attrdefs-dialect=tosa)
|
||||
add_public_tablegen_target(MLIRTosaAttributesIncGen)
|
||||
|
||||
set(LLVM_TARGET_DEFINITIONS TosaDialectBytecode.td)
|
||||
|
||||
@@ -73,7 +73,6 @@ def Tosa_AvgPool2dOp : Tosa_InferShapedTypeOp<"avg_pool2d"> {
|
||||
|
||||
let arguments = (ins
|
||||
Tosa_Tensor4D:$input,
|
||||
|
||||
Tosa_IntArrayAttr2:$kernel,
|
||||
Tosa_IntArrayAttr2:$stride,
|
||||
Tosa_IntArrayAttr4:$pad,
|
||||
@@ -102,9 +101,8 @@ def Tosa_Conv2DOp : Tosa_InferShapedTypeOp<"conv2d"> {
|
||||
|
||||
let arguments = (ins
|
||||
Tosa_Tensor4D:$input,
|
||||
4DTensorOf<[Tosa_Weight]>:$weight,
|
||||
TosaTensorRankOf<[Tosa_Weight], [4]>:$weight,
|
||||
Tosa_Tensor1D:$bias,
|
||||
|
||||
Tosa_IntArrayAttr4:$pad,
|
||||
Tosa_IntArrayAttr2:$stride,
|
||||
Tosa_IntArrayAttr2:$dilation,
|
||||
@@ -132,9 +130,8 @@ def Tosa_Conv3DOp : Tosa_InferShapedTypeOp<"conv3d"> {
|
||||
|
||||
let arguments = (ins
|
||||
Tosa_Tensor5D:$input,
|
||||
TensorRankOf<[Tosa_Weight], [5]>:$weight,
|
||||
TosaTensorRankOf<[Tosa_Weight], [5]>:$weight,
|
||||
Tosa_Tensor1D:$bias,
|
||||
|
||||
Tosa_IntArrayAttr6:$pad,
|
||||
Tosa_IntArrayAttr3:$stride,
|
||||
Tosa_IntArrayAttr3:$dilation,
|
||||
@@ -163,9 +160,8 @@ def Tosa_DepthwiseConv2DOp : Tosa_InferShapedTypeOp<"depthwise_conv2d"> {
|
||||
|
||||
let arguments = (ins
|
||||
Tosa_Tensor4D:$input,
|
||||
4DTensorOf<[Tosa_Weight]>:$weight,
|
||||
TosaTensorRankOf<[Tosa_Weight], [4]>:$weight,
|
||||
Tosa_Tensor1D:$bias,
|
||||
|
||||
Tosa_IntArrayAttr4:$pad,
|
||||
Tosa_IntArrayAttr2:$stride,
|
||||
Tosa_IntArrayAttr2:$dilation,
|
||||
@@ -232,7 +228,7 @@ def Tosa_FullyConnectedOp : Tosa_InferShapedTypeOp<"fully_connected"> {
|
||||
|
||||
let arguments = (ins
|
||||
Tosa_Tensor2D:$input,
|
||||
2DTensorOf<[Tosa_Weight]>:$weight,
|
||||
TosaTensorRankOf<[Tosa_Weight], [2]>:$weight,
|
||||
Tosa_Tensor1D:$bias,
|
||||
OptionalAttr<Tosa_ConvOpQuantizationAttr>:$quantization_info
|
||||
);
|
||||
@@ -347,9 +343,8 @@ def Tosa_TransposeConv2DOp : Tosa_InferShapedTypeOp<"transpose_conv2d"> {
|
||||
|
||||
let arguments = (ins
|
||||
Tosa_Tensor4D:$input,
|
||||
4DTensorOf<[Tosa_Weight]>:$filter,
|
||||
TosaTensorRankOf<[Tosa_Weight], [4]>:$filter,
|
||||
Tosa_Tensor1D:$bias,
|
||||
|
||||
Tosa_IntArrayAttr4:$out_pad,
|
||||
Tosa_IntArrayAttr2:$stride,
|
||||
Tosa_IntArrayAttrUpto4:$out_shape,
|
||||
@@ -641,12 +636,12 @@ def Tosa_LogicalAndOp : Tosa_ElementwiseOp<"logical_and", [
|
||||
}];
|
||||
|
||||
let arguments = (ins
|
||||
I1Tensor:$input1,
|
||||
I1Tensor:$input2
|
||||
Tosa_I1Tensor:$input1,
|
||||
Tosa_I1Tensor:$input2
|
||||
);
|
||||
|
||||
let results = (outs
|
||||
I1Tensor:$z
|
||||
Tosa_I1Tensor:$z
|
||||
);
|
||||
}
|
||||
|
||||
@@ -708,12 +703,12 @@ def Tosa_LogicalOrOp : Tosa_ElementwiseOp<"logical_or", [
|
||||
}];
|
||||
|
||||
let arguments = (ins
|
||||
I1Tensor:$input1,
|
||||
I1Tensor:$input2
|
||||
Tosa_I1Tensor:$input1,
|
||||
Tosa_I1Tensor:$input2
|
||||
);
|
||||
|
||||
let results = (outs
|
||||
I1Tensor:$z
|
||||
Tosa_I1Tensor:$z
|
||||
);
|
||||
}
|
||||
|
||||
@@ -731,12 +726,12 @@ def Tosa_LogicalXorOp : Tosa_ElementwiseOp<"logical_xor", [
|
||||
}];
|
||||
|
||||
let arguments = (ins
|
||||
I1Tensor:$input1,
|
||||
I1Tensor:$input2
|
||||
Tosa_I1Tensor:$input1,
|
||||
Tosa_I1Tensor:$input2
|
||||
);
|
||||
|
||||
let results = (outs
|
||||
I1Tensor:$z
|
||||
Tosa_I1Tensor:$z
|
||||
);
|
||||
}
|
||||
|
||||
@@ -1085,11 +1080,11 @@ def Tosa_LogicalNotOp : Tosa_ElementwiseOp<"logical_not",
|
||||
}];
|
||||
|
||||
let arguments = (ins
|
||||
I1Tensor:$input1
|
||||
Tosa_I1Tensor:$input1
|
||||
);
|
||||
|
||||
let results = (outs
|
||||
I1Tensor:$output
|
||||
Tosa_I1Tensor:$output
|
||||
);
|
||||
}
|
||||
|
||||
@@ -1208,7 +1203,7 @@ def Tosa_SelectOp : Tosa_ElementwiseOp<"select"> {
|
||||
}];
|
||||
|
||||
let arguments = (ins
|
||||
I1Tensor:$pred,
|
||||
Tosa_I1Tensor:$pred,
|
||||
Tosa_Tensor:$on_true,
|
||||
Tosa_Tensor:$on_false
|
||||
);
|
||||
@@ -1249,7 +1244,7 @@ def Tosa_EqualOp : Tosa_ElementwiseOp<"equal", [
|
||||
);
|
||||
|
||||
let results = (outs
|
||||
I1Tensor:$output
|
||||
Tosa_I1Tensor:$output
|
||||
);
|
||||
|
||||
let extraClassDeclaration = [{
|
||||
@@ -1277,7 +1272,7 @@ def Tosa_GreaterOp : Tosa_ElementwiseOp<"greater", [SameOperandsElementType]> {
|
||||
);
|
||||
|
||||
let results = (outs
|
||||
I1Tensor:$output
|
||||
Tosa_I1Tensor:$output
|
||||
);
|
||||
|
||||
let hasFolder = 1;
|
||||
@@ -1300,7 +1295,7 @@ def Tosa_GreaterEqualOp : Tosa_ElementwiseOp<"greater_equal",
|
||||
);
|
||||
|
||||
let results = (outs
|
||||
I1Tensor:$output
|
||||
Tosa_I1Tensor:$output
|
||||
);
|
||||
|
||||
let hasFolder = 1;
|
||||
@@ -1721,7 +1716,7 @@ def Tosa_TransposeOp : Tosa_InferShapedTypeOp<"transpose",
|
||||
|
||||
let arguments = (ins
|
||||
Tosa_Tensor:$input1,
|
||||
Tosa_Int32Or64Tensor:$perms
|
||||
Tosa_Int32Tensor:$perms
|
||||
);
|
||||
|
||||
let results = (
|
||||
@@ -1729,7 +1724,7 @@ def Tosa_TransposeOp : Tosa_InferShapedTypeOp<"transpose",
|
||||
);
|
||||
|
||||
let extraClassDeclaration = [{
|
||||
LogicalResult getConstantPerms(llvm::SmallVector<int64_t> &perms);
|
||||
LogicalResult getConstantPerms(llvm::SmallVector<int32_t> &perms);
|
||||
}];
|
||||
|
||||
let hasCanonicalizer = 1;
|
||||
@@ -1755,7 +1750,7 @@ def Tosa_GatherOp : Tosa_InferShapedTypeOp<"gather"> {
|
||||
|
||||
let arguments = (ins
|
||||
Tosa_Tensor3D:$values,
|
||||
2DTensorOf<[Tosa_Int32]>:$indices
|
||||
TosaTensorRankOf<[Tosa_Int32], [2]>:$indices
|
||||
);
|
||||
|
||||
let results = (outs
|
||||
@@ -1776,7 +1771,7 @@ def Tosa_ScatterOp : Tosa_InferShapedTypeOp<"scatter"> {
|
||||
|
||||
let arguments = (ins
|
||||
Tosa_Tensor3D:$values_in,
|
||||
2DTensorOf<[Tosa_Int32]>:$indices,
|
||||
TosaTensorRankOf<[Tosa_Int32], [2]>:$indices,
|
||||
Tosa_Tensor3D:$input
|
||||
);
|
||||
|
||||
@@ -1947,10 +1942,11 @@ def Tosa_ConstOp : Tosa_Op<"const", [ConstantLike, Pure,
|
||||
);
|
||||
|
||||
let results = (outs
|
||||
TensorOf<[AnyTypeOf<[Tosa_AnyNumber]>]>:$output
|
||||
TosaTensorOf<[AnyTypeOf<[Tosa_AnyNumber]>]>:$output
|
||||
);
|
||||
|
||||
let hasFolder = 1;
|
||||
let hasVerifier = 1;
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
@@ -2054,7 +2050,7 @@ def Tosa_IfOp : Tosa_Op<"cond_if",
|
||||
}];
|
||||
|
||||
let arguments = (ins
|
||||
I1Tensor:$cond,
|
||||
Tosa_I1Tensor:$cond,
|
||||
Variadic<Tosa_Tensor>:$inputs
|
||||
);
|
||||
|
||||
|
||||
@@ -82,58 +82,83 @@ def Tosa_AnyNumber : AnyTypeOf<[Tosa_Int, Tosa_QuantizedInt, AnyFloat],
|
||||
def Tosa_Weight : AnyTypeOf<[Tosa_Int4, Tosa_Int8,
|
||||
Tosa_QuantizedInt, AnyFloat]>;
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// TOSA Tensor Conformance
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
def HasNo0Dimensions : And<[
|
||||
IsRankedTensorTypePred,
|
||||
CPred<"::llvm::all_of(::llvm::cast<::mlir::RankedTensorType>($_self).getShape(), [](auto v) { return v != 0; })">]>;
|
||||
|
||||
class TosaTensorOf<
|
||||
list<Type> allowedTypes, string summary = "tosa-conformant tensor">
|
||||
: TensorOf<allowedTypes, [Or<[HasNo0Dimensions, IsUnrankedTensorTypePred]>], summary>;
|
||||
|
||||
class TosaRankedTensorOf<
|
||||
list<Type> allowedTypes, list<Pred> preds = [], string summary = "tosa-conformant ranked tensor">
|
||||
: RankedTensorOf<allowedTypes, !listconcat([HasNo0Dimensions], preds), summary>;
|
||||
|
||||
class TosaUnrankedTensorOf<list<Type> allowedTypes, list<Pred> preds = [], string summary = "tosa-conformant unranked tensor">
|
||||
: UnrankedTensorOf<allowedTypes, preds, summary>;
|
||||
|
||||
class TosaTensorRankOf<list<Type> allowedTypes, list<int> ranks>
|
||||
: TosaRankedTensorOf<allowedTypes,
|
||||
[HasAnyRankOfPred<ranks>],
|
||||
!interleave(!foreach(rank, ranks, rank # "D"), "/") # " tensor">;
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// Tensor types
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
def Tosa_Int32Tensor : TensorOf<[Tosa_Int32]>;
|
||||
def Tosa_Int32Or64Tensor : TensorOf<[Tosa_Int32Or64]>;
|
||||
def Tosa_I1Tensor : TosaTensorOf<[I1]>;
|
||||
def Tosa_Int32Tensor : TosaTensorOf<[Tosa_Int32]>;
|
||||
def Tosa_Int32Or64Tensor :TosaTensorOf<[Tosa_Int32Or64]>;
|
||||
|
||||
def Tosa_FloatTensor : TensorOf<[AnyFloat]>;
|
||||
def Tosa_FloatTensor : TosaTensorOf<[AnyFloat]>;
|
||||
|
||||
// Either ranked or unranked tensor of TOSA supported element types.
|
||||
def Tosa_Tensor : TensorOf<[Tosa_AnyNumber]>;
|
||||
def Tosa_Tensor : TosaTensorOf<[Tosa_AnyNumber]>;
|
||||
|
||||
// Must be ranked but no further constraints
|
||||
def Tosa_RankedTensor : RankedTensorOf<[Tosa_AnyNumber]>;
|
||||
def Tosa_RankedTensor : TosaRankedTensorOf<[Tosa_AnyNumber]>;
|
||||
|
||||
// Any tensor element type allowed in Tosa ops.
|
||||
def Tosa_ElementType : Type<Or<[Tosa_Int.predicate, Tosa_QuantizedInt.predicate,
|
||||
AnyFloat.predicate]>, "tosa.dtype">;
|
||||
|
||||
class Tosa_TensorOfOrNone<list<Type> allowedTypes, string description = ""> :
|
||||
AnyTypeOf<[TensorOf<allowedTypes>, NoneType], description>;
|
||||
AnyTypeOf<[TosaTensorOf<allowedTypes>, NoneType], description>;
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// Tensor types with constrained ranks.
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
// Rank-0 (scalar) tensor
|
||||
def Tosa_ScalarTensor : TensorRankOf<[Tosa_AnyNumber], [0]>;
|
||||
def Tosa_ScalarTensor : TosaTensorRankOf<[Tosa_AnyNumber], [0]>;
|
||||
|
||||
// We include unranked tensors as a supported type for all possible tosa
|
||||
// Tensors as unranked does not guarantee invalid. If unranked tensors exist
|
||||
// they should be shape propagate used Tosa's shape inference pass and verified
|
||||
// to not include any remaining unranked tensors.
|
||||
def Tosa_UnrankedTensor : UnrankedTensorOf<[Tosa_AnyNumber]>;
|
||||
def Tosa_UnrankedTensor : TosaUnrankedTensorOf<[Tosa_AnyNumber]>;
|
||||
|
||||
def Tosa_Tensor1D : AnyTypeOf<[Tosa_UnrankedTensor, 1DTensorOf<[Tosa_AnyNumber]>], "1-d tensor", "::mlir::TensorType">;
|
||||
def Tosa_Tensor2D : AnyTypeOf<[Tosa_UnrankedTensor, 2DTensorOf<[Tosa_AnyNumber]>], "2-d tensor", "::mlir::TensorType">;
|
||||
def Tosa_Tensor3D : AnyTypeOf<[Tosa_UnrankedTensor, 3DTensorOf<[Tosa_AnyNumber]>], "3-d tensor", "::mlir::TensorType">;
|
||||
def Tosa_Tensor4D : AnyTypeOf<[Tosa_UnrankedTensor, 4DTensorOf<[Tosa_AnyNumber]>], "4-d tensor", "::mlir::TensorType">;
|
||||
def Tosa_Tensor5D : AnyTypeOf<[Tosa_UnrankedTensor, TensorRankOf<[Tosa_AnyNumber], [5]>], "5-d tensor", "::mlir::TensorType">;
|
||||
def Tosa_Tensor1D : AnyTypeOf<[Tosa_UnrankedTensor, TosaTensorRankOf<[Tosa_AnyNumber], [1]>], "1-d tosa-conformant tensor", "::mlir::TensorType">;
|
||||
def Tosa_Tensor2D : AnyTypeOf<[Tosa_UnrankedTensor, TosaTensorRankOf<[Tosa_AnyNumber], [2]>], "2-d tosa-conformant tensor", "::mlir::TensorType">;
|
||||
def Tosa_Tensor3D : AnyTypeOf<[Tosa_UnrankedTensor, TosaTensorRankOf<[Tosa_AnyNumber], [3]>], "3-d tosa-conformant tensor", "::mlir::TensorType">;
|
||||
def Tosa_Tensor4D : AnyTypeOf<[Tosa_UnrankedTensor, TosaTensorRankOf<[Tosa_AnyNumber], [4]>], "4-d tosa-conformant tensor", "::mlir::TensorType">;
|
||||
def Tosa_Tensor5D : AnyTypeOf<[Tosa_UnrankedTensor, TosaTensorRankOf<[Tosa_AnyNumber], [5]>], "5-d tosa-conformant tensor", "::mlir::TensorType">;
|
||||
|
||||
// Ranked tensors up to given rank.
|
||||
def Tosa_Tensor1Dto4D : AnyTypeOf<[
|
||||
Tosa_UnrankedTensor, TensorRankOf<[Tosa_AnyNumber], [1,2,3,4]>]>;
|
||||
Tosa_UnrankedTensor, TosaTensorRankOf<[Tosa_AnyNumber], [1,2,3,4]>]>;
|
||||
def Tosa_Tensor1Dto6D : AnyTypeOf<[
|
||||
Tosa_UnrankedTensor, TensorRankOf<[Tosa_AnyNumber], [1,2,3,4,5,6]>]>;
|
||||
Tosa_UnrankedTensor, TosaTensorRankOf<[Tosa_AnyNumber], [1,2,3,4,5,6]>]>;
|
||||
|
||||
def Tosa_TensorUpto4D : AnyTypeOf<[
|
||||
Tosa_UnrankedTensor, TensorRankOf<[Tosa_AnyNumber], [0,1,2,3,4]>]>;
|
||||
Tosa_UnrankedTensor, TosaTensorRankOf<[Tosa_AnyNumber], [0,1,2,3,4]>]>;
|
||||
|
||||
def Tosa_Int32TensorUpto4D : AnyTypeOf<[
|
||||
Tosa_UnrankedTensor, TensorRankOf<[Tosa_Int32], [0,1,2,3,4]>]>;
|
||||
Tosa_UnrankedTensor, TosaTensorRankOf<[Tosa_Int32], [0,1,2,3,4]>]>;
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// Generic scalar, vector, or tensor of a particular type.
|
||||
@@ -142,7 +167,7 @@ def Tosa_Int32TensorUpto4D : AnyTypeOf<[
|
||||
class Tosa_TypeLike<list<Type> types, string description = ""> : TypeConstraint<Or<[
|
||||
AnyTypeOf<types>.predicate,
|
||||
VectorOf<types>.predicate,
|
||||
TensorOf<types>.predicate]>,
|
||||
TosaTensorOf<types>.predicate]>,
|
||||
description>;
|
||||
|
||||
def Tosa_IntLike : Tosa_TypeLike<[Tosa_Int], "signless-integer-like">;
|
||||
|
||||
@@ -216,6 +216,19 @@ TosaOp CreateOpAndInferShape(PatternRewriter &rewriter, Location loc,
|
||||
return CreateOpAndInferShape<TosaOp>(builder, resultTy, args...);
|
||||
}
|
||||
|
||||
// Apply an int32_t permutation to some input, that should be of the same
|
||||
// size as perms. Perms should contain some permutation of 0 - perms.size() - 1.
|
||||
template <typename T>
|
||||
SmallVector<T> applyTOSAPermutation(ArrayRef<T> input,
|
||||
ArrayRef<int32_t> perms) {
|
||||
SmallVector<T> permuted;
|
||||
size_t N = input.size();
|
||||
permuted.resize_for_overwrite(N);
|
||||
for (size_t i = 0; i < N; i++)
|
||||
permuted[i] = input[perms[i]];
|
||||
return permuted;
|
||||
}
|
||||
|
||||
} // namespace tosa
|
||||
} // namespace mlir
|
||||
|
||||
|
||||
@@ -313,7 +313,7 @@ public:
|
||||
// convolution operation.
|
||||
// TODO(suderman): See if this can be efficiently folded - check whether
|
||||
// the input is used anywhere else, if not fold the constant.
|
||||
SmallVector<int64_t> weightPerm;
|
||||
SmallVector<int32_t> weightPerm;
|
||||
for (int i = 1; i < resultTy.getRank(); i++)
|
||||
weightPerm.push_back(i);
|
||||
weightPerm.push_back(0);
|
||||
@@ -321,7 +321,7 @@ public:
|
||||
SmallVector<int64_t> newWeightShape;
|
||||
for (auto dim : weightPerm)
|
||||
newWeightShape.push_back(weightShape[dim]);
|
||||
auto weightPermAttr = rewriter.getI64TensorAttr(weightPerm);
|
||||
auto weightPermAttr = rewriter.getI32TensorAttr(weightPerm);
|
||||
Value weightPermValue =
|
||||
rewriter.create<arith::ConstantOp>(loc, weightPermAttr);
|
||||
Type newWeightTy =
|
||||
@@ -337,7 +337,7 @@ public:
|
||||
if (5 == inputTy.getRank()) {
|
||||
// TODO(suderman): See if this can be efficiently folded - check whether
|
||||
// the input is used anywhere else, if not fold the constant.
|
||||
SmallVector<int64_t> weightPerm;
|
||||
SmallVector<int32_t> weightPerm;
|
||||
for (int i = 1; i < resultTy.getRank(); i++)
|
||||
weightPerm.push_back(i);
|
||||
weightPerm.push_back(0);
|
||||
@@ -345,7 +345,7 @@ public:
|
||||
SmallVector<int64_t> newWeightShape;
|
||||
for (auto dim : weightPerm)
|
||||
newWeightShape.push_back(weightShape[dim]);
|
||||
auto weightPermAttr = rewriter.getI64TensorAttr(weightPerm);
|
||||
auto weightPermAttr = rewriter.getI32TensorAttr(weightPerm);
|
||||
Value weightPermValue =
|
||||
rewriter.create<arith::ConstantOp>(loc, weightPermAttr);
|
||||
Type newWeightTy =
|
||||
@@ -1040,22 +1040,25 @@ public:
|
||||
|
||||
LogicalResult matchAndRewrite(tosa::TransposeOp op,
|
||||
PatternRewriter &rewriter) const final {
|
||||
SmallVector<int64_t> constantPerms;
|
||||
SmallVector<int32_t> constantPerms;
|
||||
if (failed(op.getConstantPerms(constantPerms)))
|
||||
return failure();
|
||||
|
||||
Location loc = op.getLoc();
|
||||
// The verifier should have made sure we have a valid permutation tensor.
|
||||
assert(isPermutationVector(constantPerms) && "Expected valid permutation");
|
||||
// The verifier should have made sure we have a valid TOSA permutation
|
||||
// tensor. isPermutationVector doesn't actually check the TOSA perms we
|
||||
// expect.
|
||||
SmallVector<OpFoldResult> inputSizes =
|
||||
tensor::getMixedSizes(rewriter, loc, op.getInput1());
|
||||
auto permutedSizes =
|
||||
applyPermutation<OpFoldResult>(inputSizes, constantPerms);
|
||||
applyTOSAPermutation<OpFoldResult>(inputSizes, constantPerms);
|
||||
|
||||
auto permutedInit = rewriter.create<tensor::EmptyOp>(
|
||||
loc, permutedSizes, op.getInput1().getType().getElementType());
|
||||
rewriter.replaceOpWithNewOp<linalg::TransposeOp>(
|
||||
op, op.getInput1(), permutedInit, constantPerms);
|
||||
op, op.getInput1(), permutedInit,
|
||||
llvm::to_vector(llvm::map_range(
|
||||
constantPerms, [](int32_t v) -> int64_t { return v; })));
|
||||
return success();
|
||||
}
|
||||
};
|
||||
|
||||
@@ -88,7 +88,7 @@ struct ConsolidateTransposeOptimization
|
||||
return rewriter.notifyMatchFailure(transposeOp,
|
||||
"input must be transpose operation");
|
||||
|
||||
SmallVector<int64_t> transposePerms, innerTransposePerms;
|
||||
SmallVector<int32_t> transposePerms, innerTransposePerms;
|
||||
if (transposeOp.getConstantPerms(transposePerms).failed())
|
||||
return rewriter.notifyMatchFailure(transposeOp,
|
||||
"transpose perms must be constant");
|
||||
@@ -497,8 +497,10 @@ OpFoldResult AddOp::fold(FoldAdaptor adaptor) {
|
||||
return {};
|
||||
|
||||
auto resultETy = resultTy.getElementType();
|
||||
auto lhsAttr = llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getInput1());
|
||||
auto rhsAttr = llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getInput2());
|
||||
auto lhsAttr =
|
||||
llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getInput1());
|
||||
auto rhsAttr =
|
||||
llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getInput2());
|
||||
|
||||
if (lhsTy == resultTy && isSplatZero(resultETy, rhsAttr))
|
||||
return getInput1();
|
||||
@@ -536,8 +538,10 @@ OpFoldResult IntDivOp::fold(FoldAdaptor adaptor) {
|
||||
|
||||
// IntDivOp inputs must be integer type, no need to check for quantized type
|
||||
auto resultETy = resultTy.getElementType();
|
||||
auto lhsAttr = llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getInput1());
|
||||
auto rhsAttr = llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getInput2());
|
||||
auto lhsAttr =
|
||||
llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getInput1());
|
||||
auto rhsAttr =
|
||||
llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getInput2());
|
||||
if (lhsAttr && lhsAttr.isSplat()) {
|
||||
if (llvm::isa<IntegerType>(resultETy) &&
|
||||
lhsAttr.getSplatValue<APInt>().isZero())
|
||||
@@ -605,10 +609,13 @@ OpFoldResult MulOp::fold(FoldAdaptor adaptor) {
|
||||
return {};
|
||||
|
||||
auto resultETy = resultTy.getElementType();
|
||||
auto lhsAttr = llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getInput1());
|
||||
auto rhsAttr = llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getInput2());
|
||||
auto lhsAttr =
|
||||
llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getInput1());
|
||||
auto rhsAttr =
|
||||
llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getInput2());
|
||||
|
||||
const int64_t shift = llvm::isa<IntegerType>(resultETy) ? getShift() : 0;
|
||||
|
||||
if (rhsTy == resultTy) {
|
||||
if (isSplatZero(resultETy, lhsAttr))
|
||||
return lhsAttr.resizeSplat(resultTy);
|
||||
@@ -638,8 +645,10 @@ OpFoldResult SubOp::fold(FoldAdaptor adaptor) {
|
||||
return {};
|
||||
|
||||
auto resultETy = resultTy.getElementType();
|
||||
auto lhsAttr = llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getInput1());
|
||||
auto rhsAttr = llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getInput2());
|
||||
auto lhsAttr =
|
||||
llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getInput1());
|
||||
auto rhsAttr =
|
||||
llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getInput2());
|
||||
|
||||
if (lhsTy == resultTy && isSplatZero(resultETy, rhsAttr))
|
||||
return getInput1();
|
||||
@@ -681,8 +690,10 @@ struct APIntFoldGreaterEqual {
|
||||
|
||||
OpFoldResult GreaterOp::fold(FoldAdaptor adaptor) {
|
||||
auto resultTy = llvm::dyn_cast<RankedTensorType>(getType());
|
||||
auto lhsAttr = llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getInput1());
|
||||
auto rhsAttr = llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getInput2());
|
||||
auto lhsAttr =
|
||||
llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getInput1());
|
||||
auto rhsAttr =
|
||||
llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getInput2());
|
||||
|
||||
if (!lhsAttr || !rhsAttr)
|
||||
return {};
|
||||
@@ -693,8 +704,10 @@ OpFoldResult GreaterOp::fold(FoldAdaptor adaptor) {
|
||||
|
||||
OpFoldResult GreaterEqualOp::fold(FoldAdaptor adaptor) {
|
||||
auto resultTy = llvm::dyn_cast<RankedTensorType>(getType());
|
||||
auto lhsAttr = llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getInput1());
|
||||
auto rhsAttr = llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getInput2());
|
||||
auto lhsAttr =
|
||||
llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getInput1());
|
||||
auto rhsAttr =
|
||||
llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getInput2());
|
||||
|
||||
if (!lhsAttr || !rhsAttr)
|
||||
return {};
|
||||
@@ -706,8 +719,10 @@ OpFoldResult GreaterEqualOp::fold(FoldAdaptor adaptor) {
|
||||
|
||||
OpFoldResult EqualOp::fold(FoldAdaptor adaptor) {
|
||||
auto resultTy = llvm::dyn_cast<RankedTensorType>(getType());
|
||||
auto lhsAttr = llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getInput1());
|
||||
auto rhsAttr = llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getInput2());
|
||||
auto lhsAttr =
|
||||
llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getInput1());
|
||||
auto rhsAttr =
|
||||
llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getInput2());
|
||||
Value lhs = getInput1();
|
||||
Value rhs = getInput2();
|
||||
auto lhsTy = llvm::cast<ShapedType>(lhs.getType());
|
||||
@@ -838,14 +853,16 @@ OpFoldResult ReshapeOp::fold(FoldAdaptor adaptor) {
|
||||
return {};
|
||||
|
||||
// reshape(const(x)) -> const(reshape-attr(x))
|
||||
if (auto operand = llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getInput1())) {
|
||||
if (auto operand =
|
||||
llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getInput1())) {
|
||||
// Constants must have static shape.
|
||||
if (!outputTy.hasStaticShape())
|
||||
return {};
|
||||
|
||||
// Okay to duplicate splat constants.
|
||||
if (operand.isSplat())
|
||||
return SplatElementsAttr::get(outputTy, operand.getSplatValue<Attribute>());
|
||||
return SplatElementsAttr::get(outputTy,
|
||||
operand.getSplatValue<Attribute>());
|
||||
|
||||
// Don't duplicate other constants.
|
||||
if (!getInput1().hasOneUse())
|
||||
@@ -905,7 +922,8 @@ OpFoldResult ReverseOp::fold(FoldAdaptor adaptor) {
|
||||
auto operand = getInput();
|
||||
auto operandTy = llvm::cast<ShapedType>(operand.getType());
|
||||
auto axis = getAxis();
|
||||
auto operandAttr = llvm::dyn_cast_if_present<SplatElementsAttr>(adaptor.getInput());
|
||||
auto operandAttr =
|
||||
llvm::dyn_cast_if_present<SplatElementsAttr>(adaptor.getInput());
|
||||
if (operandAttr)
|
||||
return operandAttr;
|
||||
|
||||
@@ -954,7 +972,8 @@ OpFoldResult tosa::SelectOp::fold(FoldAdaptor adaptor) {
|
||||
if (getOnTrue() == getOnFalse())
|
||||
return getOnTrue();
|
||||
|
||||
auto predicate = llvm::dyn_cast_if_present<DenseIntElementsAttr>(adaptor.getPred());
|
||||
auto predicate =
|
||||
llvm::dyn_cast_if_present<DenseIntElementsAttr>(adaptor.getPred());
|
||||
if (!predicate)
|
||||
return {};
|
||||
|
||||
@@ -975,7 +994,8 @@ OpFoldResult TransposeOp::fold(FoldAdaptor adaptor) {
|
||||
auto resultTy = llvm::cast<ShapedType>(getType());
|
||||
|
||||
// Transposing splat values just means reshaping.
|
||||
if (auto input = llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getInput1())) {
|
||||
if (auto input =
|
||||
llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getInput1())) {
|
||||
if (input.isSplat() && resultTy.hasStaticShape() &&
|
||||
input.getType().getElementType() == resultTy.getElementType())
|
||||
return input.reshape(resultTy);
|
||||
@@ -986,11 +1006,11 @@ OpFoldResult TransposeOp::fold(FoldAdaptor adaptor) {
|
||||
return {};
|
||||
|
||||
// Transpose is not the identity transpose.
|
||||
SmallVector<int64_t> perms;
|
||||
SmallVector<int32_t> perms;
|
||||
if (getConstantPerms(perms).failed())
|
||||
return {};
|
||||
|
||||
if (!llvm::equal(llvm::seq<int64_t>(0, perms.size()), perms))
|
||||
if (!llvm::equal(llvm::seq<int32_t>(0, perms.size()), perms))
|
||||
return {};
|
||||
|
||||
return getInput1();
|
||||
|
||||
@@ -204,22 +204,6 @@ void mlir::tosa::printTypeOrAttr(OpAsmPrinter &p, Operation *op, TypeAttr type,
|
||||
// TOSA Operator Verifiers.
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
static bool hasZeroDimension(ShapedType shapedType) {
|
||||
if (!shapedType.hasRank())
|
||||
return false;
|
||||
|
||||
auto rank = shapedType.getRank();
|
||||
|
||||
for (int i = 0; i < rank; i++) {
|
||||
if (shapedType.isDynamicDim(i))
|
||||
continue;
|
||||
if (shapedType.getDimSize(i) == 0)
|
||||
return true;
|
||||
}
|
||||
|
||||
return false;
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
static LogicalResult verifyConvOp(T op) {
|
||||
// All TOSA conv ops have an input() and weight().
|
||||
@@ -236,10 +220,6 @@ static LogicalResult verifyConvOp(T op) {
|
||||
return failure();
|
||||
}
|
||||
|
||||
if (hasZeroDimension(inputType))
|
||||
return op.emitOpError() << "tensor has a dimension with size zero. Each "
|
||||
"dimension of a tensor must have size >= 1";
|
||||
|
||||
auto inputEType = inputType.getElementType();
|
||||
auto weightEType = weightType.getElementType();
|
||||
|
||||
@@ -262,6 +242,29 @@ static LogicalResult verifyConvOp(T op) {
|
||||
"allowed for float type");
|
||||
return failure();
|
||||
}
|
||||
return success();
|
||||
}
|
||||
|
||||
LogicalResult tosa::ConstOp::verify() {
|
||||
|
||||
auto attrType = llvm::dyn_cast<TensorType>(getValueAttr().getType());
|
||||
auto outputType = llvm::dyn_cast<TensorType>(getOutput().getType());
|
||||
|
||||
if (!attrType || !outputType) {
|
||||
emitOpError("expected tensors for attr/result type");
|
||||
return failure();
|
||||
}
|
||||
|
||||
if (auto result = llvm::dyn_cast<mlir::quant::QuantizedType>(
|
||||
outputType.getElementType())) {
|
||||
if (result.getStorageType() == attrType.getElementType())
|
||||
return success();
|
||||
}
|
||||
|
||||
if (attrType.getElementType() != outputType.getElementType()) {
|
||||
emitOpError("expected same attr/result element types");
|
||||
return failure();
|
||||
}
|
||||
|
||||
return success();
|
||||
}
|
||||
@@ -283,9 +286,6 @@ LogicalResult tosa::ArgMaxOp::verify() {
|
||||
|
||||
LogicalResult tosa::AvgPool2dOp::verify() {
|
||||
auto inputType = llvm::cast<ShapedType>(getInput().getType());
|
||||
if (hasZeroDimension(inputType))
|
||||
return emitOpError() << "tensor has a dimension with size zero. Each "
|
||||
"dimension of a tensor must have size >= 1";
|
||||
|
||||
auto inputETy = inputType.getElementType();
|
||||
auto resultETy = llvm::cast<ShapedType>(getType()).getElementType();
|
||||
@@ -341,9 +341,9 @@ LogicalResult tosa::ClampOp::verify() {
|
||||
if (inputETy != outputETy)
|
||||
return emitOpError("input/output element types are incompatible.");
|
||||
|
||||
// if input datatype is float, check that the two min/max_fp attributes share
|
||||
// the same type and that their type is either the same of the input's
|
||||
// datatype, or a float type whose bitwidth > input datatype bitwidth
|
||||
// If input datatype is float, check that the two min/max_fp attributes
|
||||
// share the same type and that their type is either the same of the input's
|
||||
// datatype, or a float type whose bitwidth > input datatype bitwidth.
|
||||
if (!inputETy.isInteger(dataTypeBitWidth)) {
|
||||
if (((maxFpType != minFpType) ||
|
||||
(maxFpType != inputETy && maxFpType.getIntOrFloatBitWidth() <=
|
||||
@@ -383,7 +383,8 @@ static void buildConvOpWithQuantInfo(OpBuilder &builder, OperationState &result,
|
||||
}
|
||||
}
|
||||
|
||||
/// Handles tosa.transpose_conv2d which has outpad and output shape attributes.
|
||||
/// Handles tosa.transpose_conv2d which has outpad and output shape
|
||||
/// attributes.
|
||||
static void buildTransConvOpWithQuantInfo(
|
||||
OpBuilder &builder, OperationState &result, Type outputType, Value input,
|
||||
Value weight, Value bias, DenseI64ArrayAttr outpad,
|
||||
@@ -420,9 +421,9 @@ static void buildFCOpWithQuantInfo(OpBuilder &builder, OperationState &result,
|
||||
}
|
||||
}
|
||||
|
||||
/// The tosa.matmul op is also intended to be generated where a fully_connected
|
||||
/// op must be constructed where the weight is not a constant. In this case,
|
||||
/// the fully_connected op must be expressed using matmul.
|
||||
/// The tosa.matmul op is also intended to be generated where a
|
||||
/// fully_connected op must be constructed where the weight is not a constant.
|
||||
/// In this case, the fully_connected op must be expressed using matmul.
|
||||
/// TODO: Add link to the leglization document explaining this.
|
||||
static void buildMatMulOpWithQuantInfo(OpBuilder &builder,
|
||||
OperationState &result, Type outputType,
|
||||
@@ -457,9 +458,9 @@ static void buildMatMulOpWithQuantInfo(OpBuilder &builder,
|
||||
}
|
||||
}
|
||||
|
||||
/// Both the tosa.avg_pool2d and unary ops use the same UnaruOpQuantizationAttr
|
||||
/// but avg_pool operator has its own builder as it has additional parameters
|
||||
/// not part of the unary ops.
|
||||
/// Both the tosa.avg_pool2d and unary ops use the same
|
||||
/// UnaruOpQuantizationAttr but avg_pool operator has its own builder as it
|
||||
/// has additional parameters not part of the unary ops.
|
||||
static void
|
||||
buildAvgPool2dOpWithQuantInfo(OpBuilder &builder, OperationState &result,
|
||||
Type outputType, Value input,
|
||||
@@ -526,8 +527,8 @@ static LogicalResult resolveBroadcastShape(const ValueShapeRange &operands,
|
||||
for (int i = 0, e = operands.size(); i != e; ++i) {
|
||||
auto shape = operands.getShape(i);
|
||||
if (!shape.hasRank()) {
|
||||
// TODO(jennik): Update function to have better case handling for invalid
|
||||
// operands and for ranked tensors.
|
||||
// TODO(jennik): Update function to have better case handling for
|
||||
// invalid operands and for ranked tensors.
|
||||
return failure();
|
||||
}
|
||||
outRank = std::max<int64_t>(outRank, shape.getRank());
|
||||
@@ -776,8 +777,8 @@ LogicalResult tosa::PadOp::inferReturnTypeComponents(
|
||||
return success();
|
||||
}
|
||||
|
||||
// If the input rank is unknown we can info the output rank using the padding
|
||||
// shape's first dim.
|
||||
// If the input rank is unknown we can info the output rank using the
|
||||
// padding shape's first dim.
|
||||
if (!inputShape.hasRank()) {
|
||||
if (paddingShape.isDynamicDim(0)) {
|
||||
inferredReturnShapes.push_back(ShapedTypeComponents());
|
||||
@@ -1000,10 +1001,6 @@ llvm::LogicalResult tosa::ReshapeOp::verify() {
|
||||
TensorType inputType = getInput1().getType();
|
||||
RankedTensorType outputType = getType();
|
||||
|
||||
if (hasZeroDimension(inputType) || hasZeroDimension(outputType))
|
||||
return emitOpError() << "tensor has a dimension with size zero. Each "
|
||||
"dimension of a tensor must have size >= 1";
|
||||
|
||||
if ((int64_t)getNewShape().size() != outputType.getRank())
|
||||
return emitOpError() << "new shape does not match result rank";
|
||||
|
||||
@@ -1034,16 +1031,15 @@ llvm::LogicalResult tosa::ReshapeOp::verify() {
|
||||
return mlir::success();
|
||||
}
|
||||
|
||||
LogicalResult tosa::TransposeOp::getConstantPerms(SmallVector<int64_t> &perms) {
|
||||
LogicalResult tosa::TransposeOp::getConstantPerms(SmallVector<int32_t> &perms) {
|
||||
// Perms must be constants.
|
||||
DenseIntElementsAttr permsAttr;
|
||||
if (!matchPattern(getPerms(), m_Constant(&permsAttr)))
|
||||
return failure();
|
||||
|
||||
// Transpose is not the identity transpose.
|
||||
perms = llvm::to_vector(
|
||||
llvm::map_range(permsAttr.getValues<APInt>(),
|
||||
[](const APInt &val) { return val.getSExtValue(); }));
|
||||
perms.clear();
|
||||
for (auto v : permsAttr.getValues<APInt>())
|
||||
perms.push_back(v.getSExtValue());
|
||||
|
||||
return success();
|
||||
}
|
||||
@@ -1067,8 +1063,8 @@ LogicalResult tosa::TransposeOp::inferReturnTypeComponents(
|
||||
return success();
|
||||
}
|
||||
|
||||
// This would imply the number of permutations does not match the rank of the
|
||||
// input which is illegal.
|
||||
// This would imply the number of permutations does not match the rank of
|
||||
// the input which is illegal.
|
||||
if (permsShape.getDimSize(0) != inputShape.getRank()) {
|
||||
return failure();
|
||||
}
|
||||
@@ -1154,19 +1150,38 @@ LogicalResult tosa::TransposeOp::verify() {
|
||||
<< " (output rank) but got size "
|
||||
<< permType.getDimSize(0);
|
||||
|
||||
SmallVector<int64_t> constantPerms;
|
||||
SmallVector<int32_t> constantPerms;
|
||||
if (succeeded(getConstantPerms(constantPerms))) {
|
||||
// Assert that the permutation tensor has a rank, which means that the rank
|
||||
// has been verified above.
|
||||
// Assert that the permutation tensor has a rank, which means that the
|
||||
// rank has been verified above.
|
||||
assert(permType.hasRank() &&
|
||||
"Unexpectedly found permutation tensor without rank");
|
||||
if (!isPermutationVector(constantPerms))
|
||||
if (!llvm::all_of(constantPerms,
|
||||
[&constantPerms](int32_t s) {
|
||||
return s >= 0 &&
|
||||
static_cast<size_t>(s) < constantPerms.size();
|
||||
}) ||
|
||||
!isPermutationVector(llvm::to_vector(llvm::map_range(
|
||||
constantPerms, [](int32_t v) -> int64_t { return v; }))))
|
||||
return emitOpError() << "expected valid permutation tensor";
|
||||
|
||||
if (inputType.hasRank() && !llvm::all_of(constantPerms, [&](int64_t s) {
|
||||
return s < inputType.getRank();
|
||||
})) {
|
||||
return emitOpError() << "permutation must be within input bounds";
|
||||
// Verify that the types of the input and output tensors are properly
|
||||
// permuted.
|
||||
if (inputType.hasRank() && outputType.hasRank()) {
|
||||
assert(constantPerms.size() == static_cast<size_t>(inputType.getRank()) &&
|
||||
inputType.getRank() == outputType.getRank());
|
||||
|
||||
for (auto i = 0; i < outputType.getRank(); i++) {
|
||||
if (inputType.isDynamicDim(constantPerms[i]) ||
|
||||
outputType.isDynamicDim(i))
|
||||
continue;
|
||||
|
||||
if (inputType.getDimSize(constantPerms[i]) != outputType.getDimSize(i))
|
||||
return emitOpError()
|
||||
<< "expected output tensor dim " << i << " to match "
|
||||
<< "input dim " << constantPerms[i] << " with value of "
|
||||
<< inputType.getDimSize(constantPerms[i]);
|
||||
}
|
||||
}
|
||||
}
|
||||
return success();
|
||||
@@ -1175,7 +1190,7 @@ LogicalResult tosa::TransposeOp::verify() {
|
||||
LogicalResult TransposeOp::reifyResultShapes(
|
||||
OpBuilder &builder, ReifiedRankedShapedTypeDims &reifiedReturnShapes) {
|
||||
|
||||
SmallVector<int64_t> transposePerms;
|
||||
SmallVector<int32_t> transposePerms;
|
||||
if (getConstantPerms(transposePerms).failed())
|
||||
return failure();
|
||||
|
||||
@@ -1184,7 +1199,7 @@ LogicalResult TransposeOp::reifyResultShapes(
|
||||
|
||||
SmallVector<OpFoldResult> returnedDims(inputType.getRank());
|
||||
for (auto dim : transposePerms) {
|
||||
int64_t dimInInput = transposePerms[dim];
|
||||
int32_t dimInInput = transposePerms[dim];
|
||||
if (inputType.isDynamicDim(dimInInput))
|
||||
returnedDims[dim] =
|
||||
builder.create<tensor::DimOp>(getLoc(), input, dimInInput)
|
||||
@@ -1378,8 +1393,8 @@ static LogicalResult verifyReduceOp(T op) {
|
||||
<< ")";
|
||||
return failure();
|
||||
}
|
||||
// We can only verify the reduced dimension size to be 1 if this is not the
|
||||
// special case of output rank == 0.
|
||||
// We can only verify the reduced dimension size to be 1 if this is not
|
||||
// the special case of output rank == 0.
|
||||
if (outputRank != 0) {
|
||||
auto outputShape = outputType.getShape();
|
||||
if (!outputType.isDynamicDim(reduceAxis) &&
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
// RUN: mlir-opt --split-input-file -pass-pipeline="builtin.module(func.func(tosa-to-linalg-named))" %s -verify-diagnostics -o -| FileCheck %s
|
||||
// RUN: mlir-opt --split-input-file -pass-pipeline="builtin.module(func.func(tosa-to-linalg-named{prefer-conv2d-kernel-layout-hwcf=true}))" %s -verify-diagnostics -o -| FileCheck --check-prefix="HWCF" %s
|
||||
// RUN: mlir-opt --split-input-file -pass-pipeline="builtin.module(func.func(tosa-to-linalg-named,cse))" %s -verify-diagnostics -o -| FileCheck --check-prefix="CHECK-CSE" %s
|
||||
// RUN: mlir-opt --verify-each --split-input-file -pass-pipeline="builtin.module(func.func(tosa-to-linalg-named))" %s -verify-diagnostics -o -| FileCheck %s
|
||||
// RUN: mlir-opt --verify-each --split-input-file -pass-pipeline="builtin.module(func.func(tosa-to-linalg-named{prefer-conv2d-kernel-layout-hwcf=true}))" %s -verify-diagnostics -o -| FileCheck --check-prefix="HWCF" %s
|
||||
// RUN: mlir-opt --verify-each --split-input-file -pass-pipeline="builtin.module(func.func(tosa-to-linalg-named,cse))" %s -verify-diagnostics -o -| FileCheck --check-prefix="CHECK-CSE" %s
|
||||
|
||||
// CHECK-LABEL: @matmul
|
||||
func.func @matmul(%arg0: tensor<1x5x3xf32>, %arg1: tensor<1x3x6xf32>) -> (tensor<1x5x6xf32>) {
|
||||
@@ -521,7 +521,7 @@ func.func @conv2d_scalar_bias_f32(%input: tensor<1x49x42x27xf32>, %weights: tens
|
||||
|
||||
// CHECK-LABEL: @conv2d_i8
|
||||
func.func @conv2d_i8(%input: tensor<1x49x42x27xi8>, %weights: tensor<28x1x1x27xi8>, %bias: tensor<28xi8>) -> () {
|
||||
// HWCF: %[[TRANSPOSE_DIMS:.+]] = arith.constant dense<[1, 2, 3, 0]> : tensor<4xi64>
|
||||
// HWCF: %[[TRANSPOSE_DIMS:.+]] = arith.constant dense<[1, 2, 3, 0]> : tensor<4xi32>
|
||||
// HWCF: %[[TRANSPOSE:.+]] = linalg.transpose ins(%arg1 : tensor<28x1x1x27xi8>) outs(%[[TRANSPOSEDINIT:.+]] : tensor<1x1x27x28xi8>) permutation = [1, 2, 3, 0]
|
||||
// CHECK: %[[INIT:.+]] = tensor.empty() : tensor<1x45x40x28xi32>
|
||||
// CHECK: %[[BROADCAST:.+]] = linalg.generic {indexing_maps = [#[[$MAP1]], #[[$MAP2]]], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%arg2 : tensor<28xi8>) outs(%[[INIT]] : tensor<1x45x40x28xi32>) {
|
||||
@@ -542,7 +542,7 @@ func.func @conv2d_i8(%input: tensor<1x49x42x27xi8>, %weights: tensor<28x1x1x27xi
|
||||
|
||||
// CHECK-LABEL: @conv2d_f32
|
||||
func.func @conv2d_f32(%input: tensor<1x49x42x27xf32>, %weights: tensor<28x3x3x27xf32>, %bias: tensor<28xf32>) -> () {
|
||||
// HWCF: %[[TRANSPOSE_DIMS:.+]] = arith.constant dense<[1, 2, 3, 0]> : tensor<4xi64>
|
||||
// HWCF: %[[TRANSPOSE_DIMS:.+]] = arith.constant dense<[1, 2, 3, 0]> : tensor<4xi32>
|
||||
// HWCF: %[[TRANSPOSE:.+]] = linalg.transpose ins(%arg1 : tensor<28x3x3x27xf32>) outs(%[[TRANSPOSEDINIT:.+]] : tensor<3x3x27x28xf32>) permutation = [1, 2, 3, 0]
|
||||
|
||||
// CHECK: %[[INIT:.+]] = tensor.empty() : tensor<1x45x40x28xf32>
|
||||
|
||||
@@ -24,7 +24,7 @@ func.func @tensor_with_unknown_rank(%arg0: tensor<*xi8>) -> tensor<*xi8> {
|
||||
|
||||
// check that tosa verify kick in
|
||||
func.func @test_avg_pool2d_zero_dim_input(%arg0: tensor<1x0x?x9xf32>) -> tensor<1x7x7x9xf32> {
|
||||
// expected-error@+1 {{'tosa.avg_pool2d' op tensor has a dimension with size zero. Each dimension of a tensor must have size >= 1}}
|
||||
// expected-error@+1 {{'tosa.avg_pool2d' op operand #0 must be 4-d tosa-conformant tensor, but got 'tensor<1x0x?x9xf32>'}}
|
||||
%0 = "tosa.avg_pool2d"(%arg0) {acc_type = f32, kernel = array<i64: 2, 2>, pad = array<i64: 0, 1, 0, 1>, stride = array<i64: 1, 1>}
|
||||
: (tensor<1x0x?x9xf32>) -> tensor<1x7x7x9xf32>
|
||||
return %0 : tensor<1x7x7x9xf32>
|
||||
|
||||
@@ -80,14 +80,14 @@ func.func @transpose_fold_4d_int() -> tensor<3x1x4x2xi32> {
|
||||
[[ 0, 1, 2, 3], [ 4, 5, 6, 7], [ 8, 9, 10, 11]],
|
||||
[[12, 13, 14, 15], [16, 17, 18, 19], [20, 21, 22, 23]]
|
||||
]]> : tensor<1x2x3x4xi32>} : () -> tensor<1x2x3x4xi32>
|
||||
%perms = "tosa.const"() {value = dense<[2, 0, 3, 1]> : tensor<4xi64>} : () -> tensor<4xi64>
|
||||
%perms = "tosa.const"() {value = dense<[2, 0, 3, 1]> : tensor<4xi32>} : () -> tensor<4xi32>
|
||||
// CHECK: %[[CST:.+]] = "tosa.const"() <{
|
||||
// CHECK-SAME{LITERAL}: value = dense<[
|
||||
// CHECK-SAME{LITERAL}: [[[0, 12], [1, 13], [2, 14], [3, 15]]],
|
||||
// CHECK-SAME{LITERAL}: [[[4, 16], [5, 17], [6, 18], [7, 19]]],
|
||||
// CHECK-SAME{LITERAL}: [[[8, 20], [9, 21], [10, 22], [11, 23]]]
|
||||
// CHECK-SAME{LITERAL}: ]>
|
||||
%1 = tosa.transpose %input, %perms : (tensor<1x2x3x4xi32>, tensor<4xi64>) -> tensor<3x1x4x2xi32>
|
||||
%1 = tosa.transpose %input, %perms : (tensor<1x2x3x4xi32>, tensor<4xi32>) -> tensor<3x1x4x2xi32>
|
||||
// CHECK: return %[[CST]]
|
||||
return %1 : tensor<3x1x4x2xi32>
|
||||
}
|
||||
|
||||
@@ -1,6 +1,22 @@
|
||||
// RUN: mlir-opt %s -split-input-file -verify-diagnostics --tosa-validate=strict-op-spec-alignment
|
||||
|
||||
|
||||
func.func @test_const() -> tensor<1xf32> {
|
||||
// expected-error@+1{{'tosa.const' op expected same attr/result element types}}
|
||||
%0 = "tosa.const"() {value = dense<1> : tensor<1xi32>} : () -> tensor<1xf32>
|
||||
return %0 : tensor<1xf32>
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
func.func @test_const_non_tensor_attr() {
|
||||
// expected-error@+1{{tosa.const' op expected tensors for attr/result type}}
|
||||
%0 = "tosa.const"() {value = dense<1.0> : vector<f32>} : () -> tensor<f32>
|
||||
return
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
func.func @test_conv2d(%arg0: tensor<1x29x29x4xf32>, %arg1: tensor<16x3x3x4xi8>, %arg2: tensor<16xi8>) -> tensor<1x27x27x16xi8> {
|
||||
// expected-error@+1 {{expect both input and weight to be float or not together, got 'f32' and 'i8'}}
|
||||
%0 = tosa.conv2d %arg0, %arg1, %arg2 {dilation = array<i64: 1, 1>, pad = array<i64: 0, 0, 0, 0>, stride = array<i64: 1, 1>}
|
||||
@@ -148,6 +164,42 @@ func.func @test_transpose_invalid_permutation_tensor(%arg0: tensor<13x21x3xf32>)
|
||||
|
||||
// -----
|
||||
|
||||
func.func @test_transpose_invalid_permutation_negative(%arg0: tensor<3x2xi32>) -> tensor<*xi32> {
|
||||
%perms = "tosa.const"() {value = dense<[-1, 0]> : tensor<2xi32>} : () -> tensor<2xi32>
|
||||
// expected-error@+1 {{'tosa.transpose' op expected valid permutation tensor}}
|
||||
%1 = tosa.transpose %arg0, %perms : (tensor<3x2xi32>, tensor<2xi32>) -> tensor<*xi32>
|
||||
return %1 : tensor<*xi32>
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
func.func @test_transpose_invalid_permutation_tensor_above_range(%arg0: tensor<3x2xi32>) -> tensor<*xi32> {
|
||||
%perms = "tosa.const"() {value = dense<[2, 0]> : tensor<2xi32>} : () -> tensor<2xi32>
|
||||
// expected-error@+1 {{'tosa.transpose' op expected valid permutation tensor}}
|
||||
%1 = tosa.transpose %arg0, %perms : (tensor<3x2xi32>, tensor<2xi32>) -> tensor<*xi32>
|
||||
return %1 : tensor<*xi32>
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
func.func @test_transpose_invalid_permutation_types(%arg0: tensor<3x2xi32>) -> tensor<3x4xi32> {
|
||||
%perms = "tosa.const"() {value = dense<[1, 0]> : tensor<2xi32>} : () -> tensor<2xi32>
|
||||
// expected-error@+1 {{'tosa.transpose' op expected output tensor dim 0 to match input dim 1 with value of 2}}
|
||||
%1 = tosa.transpose %arg0, %perms : (tensor<3x2xi32>, tensor<2xi32>) -> tensor<3x4xi32>
|
||||
return %1 : tensor<3x4xi32>
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
func.func @test_transpose_invalid_permutation_types_dynamic_dim_ok(%arg0: tensor<2x?xi32>) -> tensor<3x4xi32> {
|
||||
%perms = "tosa.const"() {value = dense<[1, 0]> : tensor<2xi32>} : () -> tensor<2xi32>
|
||||
// expected-error@+1 {{'tosa.transpose' op expected output tensor dim 1 to match input dim 0 with value of 2}}
|
||||
%1 = tosa.transpose %arg0, %perms : (tensor<2x?xi32>, tensor<2xi32>) -> tensor<3x4xi32>
|
||||
return %1 : tensor<3x4xi32>
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
func.func @test_fully_connected_non_const(%arg0: tensor<13x21x3xf32>, %arg1: tensor<2x3xf32>) -> tensor<273x2xf32> {
|
||||
%0 = "tosa.const"() {value = dense<0.000000e+00> : tensor<2xf32>} : () -> tensor<2xf32>
|
||||
%1 = tosa.reshape %arg0 {new_shape = array<i64: 273, 3>} : (tensor<13x21x3xf32>) -> tensor<273x3xf32>
|
||||
@@ -269,7 +321,7 @@ func.func @test_reshape_type_mismatch(%arg0 : tensor<13x21x3xf32>) -> () {
|
||||
// -----
|
||||
|
||||
func.func @test_reshape_static_zero_dim_input(%arg0 : tensor<13x0x3xf32>) -> () {
|
||||
// expected-error@+1 {{'tosa.reshape' op tensor has a dimension with size zero. Each dimension of a tensor must have size >= 1}}
|
||||
// expected-error@+1 {{'tosa.reshape' op operand #0 must be tosa-conformant tensor of number values, but got 'tensor<13x0x3xf32>'}}
|
||||
%0 = "tosa.reshape"(%arg0) {new_shape = array<i64: 13, 21, 3>} : (tensor<13x0x3xf32>) -> tensor<13x0x3xf32>
|
||||
return
|
||||
}
|
||||
@@ -277,7 +329,7 @@ func.func @test_reshape_static_zero_dim_input(%arg0 : tensor<13x0x3xf32>) -> ()
|
||||
// -----
|
||||
|
||||
func.func @test_reshape_zero_dim_input(%arg0 : tensor<?x0x3xf32>) -> () {
|
||||
// expected-error@+1 {{'tosa.reshape' op tensor has a dimension with size zero. Each dimension of a tensor must have size >= 1}}
|
||||
// expected-error@+1 {{'tosa.reshape' op operand #0 must be tosa-conformant tensor of number values, but got 'tensor<?x0x3xf32>'}}
|
||||
%0 = "tosa.reshape"(%arg0) {new_shape = array<i64: 13, 21, 3>} : (tensor<?x0x3xf32>) -> tensor<13x0x3xf32>
|
||||
return
|
||||
}
|
||||
@@ -341,7 +393,7 @@ func.func @test_const_attribute_type_mismatch() -> tensor<100x100xf32> {
|
||||
// -----
|
||||
|
||||
func.func @test_conv2d_static_zero_dim_input(%arg0: tensor<1x29x0x4xf32>, %arg1: tensor<16x3x3x4xf32>, %arg2: tensor<16xf32>) -> tensor<1x27x27x16xf32> {
|
||||
// expected-error@+1 {{'tosa.conv2d' op tensor has a dimension with size zero. Each dimension of a tensor must have size >= 1}}
|
||||
// expected-error@+1 {{'tosa.conv2d' op operand #0 must be 4-d tosa-conformant tensor, but got 'tensor<1x29x0x4xf32>'}}
|
||||
%0 = "tosa.conv2d"(%arg0, %arg1, %arg2) {dilation = array<i64: 1, 1>, pad = array<i64: 0, 0, 0, 0>, stride = array<i64: 1, 1>}
|
||||
: (tensor<1x29x0x4xf32>, tensor<16x3x3x4xf32>, tensor<16xf32>) -> tensor<1x27x27x16xf32>
|
||||
return %0 : tensor<1x27x27x16xf32>
|
||||
@@ -350,8 +402,8 @@ func.func @test_conv2d_static_zero_dim_input(%arg0: tensor<1x29x0x4xf32>, %arg1:
|
||||
// -----
|
||||
|
||||
func.func @test_conv2d_zero_dim_input(%arg0: tensor<1x?x0x4xf32>, %arg1: tensor<16x3x3x4xf32>, %arg2: tensor<16xf32>) -> tensor<1x27x27x16xf32> {
|
||||
// expected-error@+1 {{'tosa.conv2d' op tensor has a dimension with size zero. Each dimension of a tensor must have size >= 1}}
|
||||
%0 = "tosa.conv2d"(%arg0, %arg1, %arg2) {dilation = array<i64: 1, 1>, pad = array<i64: 0, 0, 0, 0>, stride = array<i64: 1, 1>}
|
||||
// expected-error@+1 {{'tosa.conv2d' op operand #0 must be 4-d tosa-conformant tensor, but got 'tensor<1x?x0x4xf32>'}}
|
||||
%0 = "tosa.conv2d"(%arg0, %arg1, %arg2) {acc_type = f32, dilation = array<i64: 1, 1>, pad = array<i64: 0, 0, 0, 0>, stride = array<i64: 1, 1>}
|
||||
: (tensor<1x?x0x4xf32>, tensor<16x3x3x4xf32>, tensor<16xf32>) -> tensor<1x27x27x16xf32>
|
||||
return %0 : tensor<1x27x27x16xf32>
|
||||
}
|
||||
@@ -360,7 +412,7 @@ func.func @test_conv2d_zero_dim_input(%arg0: tensor<1x?x0x4xf32>, %arg1: tensor<
|
||||
// -----
|
||||
|
||||
func.func @test_avg_pool2d_static_zero_dim_input(%arg0: tensor<1x0x7x9xf32>) -> tensor<1x7x7x9xf32> {
|
||||
// expected-error@+1 {{'tosa.avg_pool2d' op tensor has a dimension with size zero. Each dimension of a tensor must have size >= 1}}
|
||||
// expected-error@+1 {{'tosa.avg_pool2d' op operand #0 must be 4-d tosa-conformant tensor, but got 'tensor<1x0x7x9xf32>'}}
|
||||
%0 = "tosa.avg_pool2d"(%arg0) {acc_type = f32, kernel = array<i64: 2, 2>, pad = array<i64: 0, 1, 0, 1>, stride = array<i64: 1, 1>}
|
||||
: (tensor<1x0x7x9xf32>) -> tensor<1x7x7x9xf32>
|
||||
return %0 : tensor<1x7x7x9xf32>
|
||||
@@ -369,7 +421,7 @@ func.func @test_avg_pool2d_static_zero_dim_input(%arg0: tensor<1x0x7x9xf32>) ->
|
||||
// -----
|
||||
|
||||
func.func @test_avg_pool2d_zero_dim_input(%arg0: tensor<1x0x?x9xf32>) -> tensor<1x7x7x9xf32> {
|
||||
// expected-error@+1 {{'tosa.avg_pool2d' op tensor has a dimension with size zero. Each dimension of a tensor must have size >= 1}}
|
||||
// expected-error@+1 {{'tosa.avg_pool2d' op operand #0 must be 4-d tosa-conformant tensor, but got 'tensor<1x0x?x9xf32>'}}
|
||||
%0 = "tosa.avg_pool2d"(%arg0) {acc_type = f32, kernel = array<i64: 2, 2>, pad = array<i64: 0, 1, 0, 1>, stride = array<i64: 1, 1>}
|
||||
: (tensor<1x0x?x9xf32>) -> tensor<1x7x7x9xf32>
|
||||
return %0 : tensor<1x7x7x9xf32>
|
||||
@@ -469,7 +521,7 @@ func.func @test_tile_io_rank_mismatch() {
|
||||
|
||||
// CHECK-LABEL: @test_invalid_constant_permutation
|
||||
func.func @test_invalid_constant_permutation() {
|
||||
// expected-error@+3 {{permutation must be within input bounds}}
|
||||
// expected-error@+3 {{'tosa.transpose' op expected valid permutation tensor}}
|
||||
%0 = tensor.empty() : tensor<3x4x5xi32>
|
||||
%1 = arith.constant dense<[3, 0, 1]> : tensor<3xi32>
|
||||
%2 = tosa.transpose %0, %1 : (tensor<3x4x5xi32>, tensor<3xi32>) -> tensor<3x4x5xi32>
|
||||
@@ -480,7 +532,7 @@ func.func @test_invalid_constant_permutation() {
|
||||
|
||||
// CHECK-LABEL: test_rank_size_constant_permutation
|
||||
func.func @test_rank_size_constant_permutation() {
|
||||
// expected-error@+4 {{permutation must be within input bounds}}
|
||||
// expected-error@+4 {{'tosa.transpose' op expected valid permutation tensor}}
|
||||
%0 = arith.constant 6 : index
|
||||
%1 = arith.constant dense<[0, 2]> : tensor<2xi32>
|
||||
%2 = tensor.empty(%0) : tensor<?x27xi64>
|
||||
@@ -492,7 +544,7 @@ func.func @test_rank_size_constant_permutation() {
|
||||
|
||||
// CHECK-LABEL: test_large_constant_permutation
|
||||
func.func @test_large_constant_permutation() {
|
||||
// expected-error@+4 {{permutation must be within input bounds}}
|
||||
// expected-error@+4 {{'tosa.transpose' op expected valid permutation tensor}}
|
||||
%0 = arith.constant 6 : index
|
||||
%1 = arith.constant dense<[1185677355, 332462212]> : tensor<2xi32>
|
||||
%2 = tensor.empty(%0) : tensor<?x27xi64>
|
||||
@@ -504,7 +556,7 @@ func.func @test_large_constant_permutation() {
|
||||
|
||||
// CHECK-LABEL: test_table_rank0_table
|
||||
func.func @test_table_rank0_table(%arg0: tensor<64xi16>, %arg1: tensor<i16>) {
|
||||
// expected-error@+1 {{'tosa.table' op operand #1 must be 1-d tensor, but got 'tensor<i16>'}}
|
||||
// expected-error@+1 {{'tosa.table' op operand #1 must be 1-d tosa-conformant tensor, but got 'tensor<i16>'}}
|
||||
%0 = tosa.table %arg0, %arg1 : (tensor<64xi16>, tensor<i16>) -> tensor<64xi16>
|
||||
return
|
||||
}
|
||||
|
||||
@@ -573,6 +573,22 @@ func.func @test_transpose(%arg0: tensor<13x21x3xf32>) -> tensor<3x13x21xf32> {
|
||||
return %1 : tensor<3x13x21xf32>
|
||||
}
|
||||
|
||||
// -----
|
||||
// CHECK-LABEL: transpose_dynamic_dim
|
||||
func.func @test_transpose_dynamic_dim(%arg0: tensor<13x?x3xf32>) -> tensor<3x13x?xf32> {
|
||||
%0 = "tosa.const"() {value = dense<[2, 0, 1]> : tensor<3xi32>} : () -> tensor<3xi32>
|
||||
%1 = tosa.transpose %arg0, %0 : (tensor<13x?x3xf32>, tensor<3xi32>) -> tensor<3x13x?xf32>
|
||||
return %1 : tensor<3x13x?xf32>
|
||||
}
|
||||
|
||||
// -----
|
||||
// CHECK-LABEL: transpose_half_dynamic_dim
|
||||
func.func @test_transpose_half_dynamic_dim(%arg0: tensor<13x3x3xf32>) -> tensor<3x13x?xf32> {
|
||||
%0 = "tosa.const"() {value = dense<[2, 0, 1]> : tensor<3xi32>} : () -> tensor<3xi32>
|
||||
%1 = tosa.transpose %arg0, %0 : (tensor<13x3x3xf32>, tensor<3xi32>) -> tensor<3x13x?xf32>
|
||||
return %1 : tensor<3x13x?xf32>
|
||||
}
|
||||
|
||||
// -----
|
||||
// CHECK-LABEL: gather
|
||||
func.func @test_gather(%arg0: tensor<13x21x3xf32>, %arg1: tensor<13x26xi32>) -> tensor<13x26x3xf32> {
|
||||
|
||||
Reference in New Issue
Block a user