[tosa] Change VariableOp to align with spec (#142240)
This fixes Tosa VariableOp to align with spec 1.0 - add var_shape attribute to store shape of variable type - change type attribute to store element type of variable type - add a builder so previous construction calls still work - fix up level check of rank to be on variable type instead of initial value which is optional - add level check of size for variable type - add lit tests for variable op's without initial values - add lit test for variable op with fixed rank but unknown dimension - add invalid lit test for variable op with unranked type Signed-off-by: Tai Ly <tai.ly@arm.com>
This commit is contained in:
@@ -197,6 +197,16 @@ def Tosa_PadOpQuantInfoBuilder : OpBuilder<
|
||||
input, paddings);
|
||||
}]>;
|
||||
|
||||
// This builder is called on the TOSA variable operator with a variable type
|
||||
// and optional initial value. The builder will extract var_shape and element type
|
||||
// attributes from variable type.
|
||||
def Tosa_VariableOpBuilder : OpBuilder<
|
||||
(ins "StringRef":$name, "Type":$variable_type, "Attribute":$initial_value),
|
||||
[{
|
||||
buildVariableOp($_builder, $_state, name, variable_type, initial_value);
|
||||
}]>;
|
||||
|
||||
|
||||
// Wrapper over base I32EnumAttr to set common fields.
|
||||
class Tosa_I32Enum<string name, string description, list<I32EnumAttrCase> cases>
|
||||
: I32EnumAttr<name, description, cases> {
|
||||
|
||||
@@ -44,10 +44,14 @@ class PatternRewriter;
|
||||
|
||||
namespace tosa {
|
||||
|
||||
ParseResult parseTypeOrAttr(OpAsmParser &parser, TypeAttr &typeAttr,
|
||||
Attribute &attr);
|
||||
void printTypeOrAttr(OpAsmPrinter &p, Operation *op, TypeAttr type,
|
||||
Attribute attr);
|
||||
ParseResult parseVariableOpTypeOrInitialValue(OpAsmParser &parser,
|
||||
DenseElementsAttr &varShapeAttr,
|
||||
TypeAttr &typeAttr,
|
||||
Attribute &initialValueAttr);
|
||||
void printVariableOpTypeOrInitialValue(OpAsmPrinter &p, Operation *op,
|
||||
DenseElementsAttr varShapeAttr,
|
||||
TypeAttr typeAttr,
|
||||
Attribute initialValueAttr);
|
||||
|
||||
#include "mlir/Dialect/Tosa/IR/TosaInterfaces.h.inc"
|
||||
|
||||
@@ -172,6 +176,9 @@ std::optional<Value> createZeroPointTensor(OpBuilder &builder, Location loc,
|
||||
Value createPadConstTensor(OpBuilder &builder, Location loc, Value src,
|
||||
int32_t val = 0);
|
||||
|
||||
// returns type of variable op
|
||||
RankedTensorType getVariableType(VariableOp variableOp);
|
||||
|
||||
} // namespace tosa
|
||||
} // namespace mlir
|
||||
|
||||
|
||||
@@ -92,6 +92,7 @@ def Tosa_VariableOp : Tosa_Op<"variable", []> {
|
||||
|
||||
let arguments = (ins
|
||||
SymbolNameAttr:$name,
|
||||
IndexElementsAttr:$var_shape,
|
||||
TypeAttr:$type,
|
||||
OptionalAttr<AnyAttr>:$initial_value
|
||||
);
|
||||
@@ -101,12 +102,16 @@ def Tosa_VariableOp : Tosa_Op<"variable", []> {
|
||||
Extension<[Tosa_EXT_VARIABLE]>,
|
||||
];
|
||||
|
||||
let hasCustomAssemblyFormat = 1;
|
||||
|
||||
let assemblyFormat = [{
|
||||
$name
|
||||
attr-dict
|
||||
custom<TypeOrAttr>($type, $initial_value)
|
||||
custom<VariableOpTypeOrInitialValue>($var_shape, $type, $initial_value)
|
||||
}];
|
||||
|
||||
let builders = [Tosa_VariableOpBuilder];
|
||||
|
||||
let hasVerifier = 1;
|
||||
}
|
||||
|
||||
|
||||
@@ -26,8 +26,9 @@ public:
|
||||
|
||||
LogicalResult matchAndRewrite(tosa::VariableOp op,
|
||||
PatternRewriter &rewriter) const final {
|
||||
auto variableType = tosa::getVariableType(op);
|
||||
auto newVariable = rewriter.create<mlir::ml_program::GlobalOp>(
|
||||
op.getLoc(), op.getName(), op.getType(), /*is_mutable=*/true,
|
||||
op.getLoc(), op.getName(), variableType, /*is_mutable=*/true,
|
||||
op.getInitialValueAttr(), /*sym_visibility=*/nullptr);
|
||||
newVariable.setPrivate();
|
||||
rewriter.replaceOp(op, newVariable);
|
||||
|
||||
@@ -131,6 +131,24 @@ SmallVector<Region *> tosa::WhileOp::getLoopRegions() {
|
||||
return {&getBodyGraph()};
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// TOSA variable operator support.
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
static SmallVector<int64_t> convertToMlirShape(ArrayRef<int64_t> shape) {
|
||||
return to_vector(llvm::map_range(shape, [](int64_t dim) {
|
||||
return dim == -1 ? ShapedType::kDynamic : dim;
|
||||
}));
|
||||
}
|
||||
|
||||
// returns type of variable op
|
||||
RankedTensorType mlir::tosa::getVariableType(tosa::VariableOp variableOp) {
|
||||
Type elementType = variableOp.getType();
|
||||
DenseIntElementsAttr varShapeAttr = variableOp.getVarShape();
|
||||
auto shape = convertToMlirShape(to_vector(varShapeAttr.getValues<int64_t>()));
|
||||
return RankedTensorType::get(shape, elementType);
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// Tosa dialect initialization.
|
||||
//===----------------------------------------------------------------------===//
|
||||
@@ -177,42 +195,80 @@ Operation *TosaDialect::materializeConstant(OpBuilder &builder, Attribute value,
|
||||
// Parsers and printers
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
ParseResult mlir::tosa::parseTypeOrAttr(OpAsmParser &parser, TypeAttr &typeAttr,
|
||||
Attribute &attr) {
|
||||
namespace {
|
||||
|
||||
ParseResult getShapeAndElementType(OpAsmParser &parser, Type parsedType,
|
||||
DenseElementsAttr &varShapeAttr,
|
||||
TypeAttr &typeAttr) {
|
||||
if (auto shapedType = dyn_cast<ShapedType>(parsedType)) {
|
||||
if (!shapedType.hasRank())
|
||||
return parser.emitError(parser.getCurrentLocation())
|
||||
<< "expected ranked type";
|
||||
|
||||
auto elementType = shapedType.getElementType();
|
||||
typeAttr = TypeAttr::get(elementType);
|
||||
ArrayRef<int64_t> shape = shapedType.getShape();
|
||||
Builder builder(parser.getContext());
|
||||
varShapeAttr = builder.getIndexTensorAttr(convertFromMlirShape(shape));
|
||||
return success();
|
||||
}
|
||||
return parser.emitError(parser.getCurrentLocation())
|
||||
<< "expected shaped type";
|
||||
}
|
||||
|
||||
} // namespace
|
||||
|
||||
// parses the optional initial value or type for a tosa variable
|
||||
// with initial value:
|
||||
// tosa.variable @name = dense<0.0> : tensor<1x8xf32>
|
||||
//
|
||||
// without initial value:
|
||||
// tosa.variable @name : tensor<1x8xf32>
|
||||
ParseResult mlir::tosa::parseVariableOpTypeOrInitialValue(
|
||||
OpAsmParser &parser, DenseElementsAttr &varShapeAttr, TypeAttr &typeAttr,
|
||||
Attribute &initialValueAttr) {
|
||||
if (succeeded(parser.parseOptionalEqual())) {
|
||||
if (failed(parser.parseAttribute(attr))) {
|
||||
if (failed(parser.parseAttribute(initialValueAttr))) {
|
||||
return parser.emitError(parser.getCurrentLocation())
|
||||
<< "expected attribute";
|
||||
}
|
||||
if (auto typedAttr = dyn_cast<TypedAttr>(attr)) {
|
||||
typeAttr = TypeAttr::get(typedAttr.getType());
|
||||
if (auto typedAttr = dyn_cast<TypedAttr>(initialValueAttr)) {
|
||||
return getShapeAndElementType(parser, typedAttr.getType(), varShapeAttr,
|
||||
typeAttr);
|
||||
}
|
||||
return success();
|
||||
return parser.emitError(parser.getCurrentLocation())
|
||||
<< "expected Typed attr";
|
||||
}
|
||||
|
||||
Type type;
|
||||
if (failed(parser.parseColonType(type))) {
|
||||
return parser.emitError(parser.getCurrentLocation()) << "expected type";
|
||||
initialValueAttr = nullptr;
|
||||
Type parsedType;
|
||||
if (failed(parser.parseColonType(parsedType))) {
|
||||
return parser.emitError(parser.getCurrentLocation())
|
||||
<< "expected type after colon";
|
||||
}
|
||||
typeAttr = TypeAttr::get(type);
|
||||
|
||||
return success();
|
||||
return getShapeAndElementType(parser, parsedType, varShapeAttr, typeAttr);
|
||||
}
|
||||
|
||||
void mlir::tosa::printTypeOrAttr(OpAsmPrinter &p, Operation *op, TypeAttr type,
|
||||
Attribute attr) {
|
||||
void mlir::tosa::printVariableOpTypeOrInitialValue(
|
||||
OpAsmPrinter &p, Operation *op, DenseElementsAttr varShapeAttr,
|
||||
TypeAttr typeAttr, Attribute initialValueAttr) {
|
||||
bool needsSpace = false;
|
||||
auto typedAttr = dyn_cast_or_null<TypedAttr>(attr);
|
||||
if (!typedAttr || typedAttr.getType() != type.getValue()) {
|
||||
if (!dyn_cast_or_null<TypedAttr>(initialValueAttr)) {
|
||||
auto shape =
|
||||
convertToMlirShape(to_vector(varShapeAttr.getValues<int64_t>()));
|
||||
Type elementType = typeAttr.getValue();
|
||||
RankedTensorType tensorType =
|
||||
RankedTensorType::get(ArrayRef<int64_t>(shape), elementType);
|
||||
auto tensorTypeAttr = TypeAttr::get(tensorType);
|
||||
p << ": ";
|
||||
p.printAttribute(type);
|
||||
p.printAttribute(tensorTypeAttr);
|
||||
needsSpace = true; // subsequent attr value needs a space separator
|
||||
}
|
||||
if (attr) {
|
||||
if (initialValueAttr) {
|
||||
if (needsSpace)
|
||||
p << ' ';
|
||||
p << "= ";
|
||||
p.printAttribute(attr);
|
||||
p.printAttribute(initialValueAttr);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -657,8 +713,9 @@ static LogicalResult verifyVariableOpErrorIf(T op, Type type, StringRef name) {
|
||||
<< symName << "' has not been declared by 'tosa.variable'";
|
||||
|
||||
// Verify type and shape
|
||||
Type varType = cast<tosa::VariableOp>(varOp.value()).getType();
|
||||
if (errorIfTypeOrShapeMismatch(op, type, name, varType, "the input tensor")
|
||||
auto variableType = getVariableType(varOp.value());
|
||||
if (errorIfTypeOrShapeMismatch(op, type, name, variableType,
|
||||
"the input tensor")
|
||||
.failed())
|
||||
return failure();
|
||||
|
||||
@@ -1103,6 +1160,33 @@ static void buildPadOpWithQuantInfo(OpBuilder &builder, OperationState &result,
|
||||
result.types.push_back(outputType);
|
||||
}
|
||||
|
||||
static void buildVariableOp(OpBuilder &builder, OperationState &result,
|
||||
StringRef name, Type variableType,
|
||||
Attribute initialValue) {
|
||||
const Location loc{result.location};
|
||||
auto nameAttr = builder.getStringAttr(name);
|
||||
|
||||
auto shapedType = dyn_cast<ShapedType>(variableType);
|
||||
if (!shapedType) {
|
||||
(void)emitError(loc, "variable type must be a shaped type");
|
||||
return;
|
||||
}
|
||||
if (!shapedType.hasRank()) {
|
||||
(void)emitError(loc, "variable type must be a ranked type");
|
||||
return;
|
||||
}
|
||||
|
||||
auto elementType = shapedType.getElementType();
|
||||
auto elementTypeAttr = TypeAttr::get(elementType);
|
||||
ArrayRef<int64_t> shape = shapedType.getShape();
|
||||
auto varShapeAttr = builder.getIndexTensorAttr(convertFromMlirShape(shape));
|
||||
|
||||
result.addAttribute("name", nameAttr);
|
||||
result.addAttribute("var_shape", varShapeAttr);
|
||||
result.addAttribute("type", elementTypeAttr);
|
||||
result.addAttribute("initial_value", initialValue);
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// TOSA Operator Return Type Inference.
|
||||
//===----------------------------------------------------------------------===//
|
||||
@@ -1676,12 +1760,6 @@ LogicalResult tosa::PadOp::verify() {
|
||||
return success();
|
||||
}
|
||||
|
||||
static SmallVector<int64_t> convertToMlirShape(ArrayRef<int64_t> shape) {
|
||||
return to_vector(llvm::map_range(shape, [](int64_t dim) {
|
||||
return dim == -1 ? ShapedType::kDynamic : dim;
|
||||
}));
|
||||
}
|
||||
|
||||
LogicalResult tosa::SliceOp::inferReturnTypeComponents(
|
||||
MLIRContext *context, ::std::optional<Location> location,
|
||||
SliceOp::Adaptor adaptor,
|
||||
|
||||
@@ -215,15 +215,8 @@ LogicalResult ProfileInfoDepot::populateProfileInfo(tosa::MatMulOp op) {
|
||||
|
||||
template <>
|
||||
LogicalResult ProfileInfoDepot::populateProfileInfo(tosa::VariableOp op) {
|
||||
::mlir::Attribute attr = op.getInitialValueAttr();
|
||||
if (attr == nullptr)
|
||||
return failure();
|
||||
|
||||
if (auto typedAttr = dyn_cast<TypedAttr>(attr)) {
|
||||
addType(getElementTypeOrSelf(typedAttr));
|
||||
return success();
|
||||
}
|
||||
return failure();
|
||||
addType(op.getType());
|
||||
return success();
|
||||
}
|
||||
|
||||
template <>
|
||||
|
||||
@@ -238,10 +238,10 @@ private:
|
||||
return true;
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
bool levelCheckRank(Operation *op, const T &v,
|
||||
// Perform the Level Rank check on the tensor type.
|
||||
bool levelCheckRank(Operation *op, const Type typeToCheck,
|
||||
const StringRef operandOrResult, int32_t highest_rank) {
|
||||
if (ShapedType type = dyn_cast<ShapedType>(v.getType())) {
|
||||
if (ShapedType type = dyn_cast<ShapedType>(typeToCheck)) {
|
||||
if (!type.hasRank()) {
|
||||
op->emitOpError() << "failed level check: unranked tensor";
|
||||
return false;
|
||||
@@ -255,10 +255,22 @@ private:
|
||||
return true;
|
||||
}
|
||||
|
||||
// Perform the Level tensor size check on the input tensor.
|
||||
bool levelCheckSize(Operation *op, const Value &v,
|
||||
// Perform the Level Rank check on the tensor value.
|
||||
bool levelCheckRank(Operation *op, const Value &v,
|
||||
const StringRef operandOrResult, int32_t highest_rank) {
|
||||
return levelCheckRank(op, v.getType(), operandOrResult, highest_rank);
|
||||
}
|
||||
|
||||
// Perform the Level tensor size check on the tensor type.
|
||||
bool levelCheckSize(Operation *op, const Type &typeToCheck,
|
||||
const StringRef operandOrResult);
|
||||
|
||||
// Perform the Level tensor size check on the tensor value.
|
||||
bool levelCheckSize(Operation *op, const Value &v,
|
||||
const StringRef operandOrResult) {
|
||||
return levelCheckSize(op, v.getType(), operandOrResult);
|
||||
}
|
||||
|
||||
// Level check sizes of all operands and results of the operation.
|
||||
template <typename T>
|
||||
bool levelCheckSizes(T tosaOp) {
|
||||
@@ -284,15 +296,6 @@ private:
|
||||
return false;
|
||||
}
|
||||
|
||||
if (!op->getAttrs().empty()) {
|
||||
for (NamedAttribute attr : op->getAttrs()) {
|
||||
if (auto elemAttr = dyn_cast<ElementsAttr>(attr.getValue())) {
|
||||
if (!levelCheckRank(op, elemAttr, "attribute", tosaLevel.MAX_RANK))
|
||||
return false;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
for (auto v : op->getResults()) {
|
||||
if (!levelCheckRank(op, v, "result", tosaLevel.MAX_RANK))
|
||||
return false;
|
||||
@@ -596,6 +599,26 @@ bool TosaValidation::levelCheckRanks(tosa::IfOp tosaOp) {
|
||||
return true;
|
||||
}
|
||||
|
||||
template <>
|
||||
bool TosaValidation::levelCheckRanks(tosa::VariableOp tosaOp) {
|
||||
auto op = tosaOp.getOperation();
|
||||
auto variableType = getVariableType(tosaOp);
|
||||
if (!levelCheckRank(op, variableType, "variable type", tosaLevel.MAX_RANK))
|
||||
return false;
|
||||
|
||||
return true;
|
||||
}
|
||||
|
||||
template <>
|
||||
bool TosaValidation::levelCheckSizes(tosa::VariableOp tosaOp) {
|
||||
auto op = tosaOp.getOperation();
|
||||
auto variableType = getVariableType(tosaOp);
|
||||
if (!levelCheckSize(op, variableType, "variable type"))
|
||||
return false;
|
||||
|
||||
return true;
|
||||
}
|
||||
|
||||
bool TosaValidation::levelCheckRanksAndSizes(Operation *op) {
|
||||
#define CHECK_RANKS_AND_SIZES(tosaOp) \
|
||||
if (isa<tosa::tosaOp##Op>(op)) { \
|
||||
@@ -714,10 +737,10 @@ bool TosaValidation::levelCheckRanksAndSizes(Operation *op) {
|
||||
return true;
|
||||
}
|
||||
|
||||
// Perform the Level tensor size check
|
||||
bool TosaValidation::levelCheckSize(Operation *op, const Value &v,
|
||||
// Perform the Level tensor size check on the tensor type.
|
||||
bool TosaValidation::levelCheckSize(Operation *op, const Type &typeToCheck,
|
||||
const StringRef operandOrResult) {
|
||||
if (ShapedType type = dyn_cast<ShapedType>(v.getType())) {
|
||||
if (ShapedType type = dyn_cast<ShapedType>(typeToCheck)) {
|
||||
if (!type.hasRank()) {
|
||||
op->emitOpError() << "failed level check: unranked tensor";
|
||||
return false;
|
||||
@@ -800,18 +823,21 @@ inline bool CompatibleTypes(const mlir::Type &type,
|
||||
}
|
||||
|
||||
bool TosaValidation::CheckVariable(Operation *op) {
|
||||
if (isa<mlir::tosa::VariableOp>(op)) {
|
||||
mlir::StringAttr nameAttr = cast<mlir::StringAttr>(op->getAttr("name"));
|
||||
if (auto variableOp = dyn_cast<mlir::tosa::VariableOp>(op)) {
|
||||
mlir::StringAttr nameAttr = variableOp.getNameAttr();
|
||||
|
||||
if (variablesMap.count(nameAttr)) {
|
||||
op->emitOpError() << "name has already been declared";
|
||||
return false;
|
||||
}
|
||||
|
||||
auto typeAttr = cast<mlir::TypeAttr>(op->getAttr("type"));
|
||||
mlir::Type type = typeAttr.getValue();
|
||||
auto elementType = variableOp.getType();
|
||||
DenseIntElementsAttr varShapeAttr = variableOp.getVarShape();
|
||||
SmallVector<int64_t> shape = to_vector(varShapeAttr.getValues<int64_t>());
|
||||
RankedTensorType variableType =
|
||||
RankedTensorType::get(ArrayRef<int64_t>(shape), elementType);
|
||||
|
||||
variablesMap[nameAttr] = type;
|
||||
variablesMap[nameAttr] = variableType;
|
||||
}
|
||||
|
||||
return true;
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
// RUN: mlir-opt --tosa-to-mlprogram %s -o -| FileCheck %s
|
||||
// RUN: mlir-opt --tosa-to-mlprogram %s -split-input-file -o -| FileCheck %s
|
||||
|
||||
module {
|
||||
// CHECK: ml_program.global private mutable @var_x(dense<7.000000e+00> : tensor<1xf32>) : tensor<1xf32>
|
||||
@@ -10,4 +10,18 @@ module {
|
||||
%0 = tosa.variable_read @var_x : tensor<1xf32>
|
||||
return %0 : tensor<1xf32>
|
||||
}
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
module {
|
||||
// CHECK: ml_program.global private mutable @var_x : tensor<f32>
|
||||
tosa.variable @var_x : tensor<f32>
|
||||
func.func @test_stateful_ops(%arg0: tensor<f32>) -> (tensor<f32>) {
|
||||
// CHECK: ml_program.global_store @var_x = %arg0 : tensor<f32>
|
||||
tosa.variable_write @var_x, %arg0 : tensor<f32>
|
||||
// CHECK: %[[LOAD:.+]] = ml_program.global_load @var_x : tensor<f32>
|
||||
%0 = tosa.variable_read @var_x : tensor<f32>
|
||||
return %0 : tensor<f32>
|
||||
}
|
||||
}
|
||||
@@ -564,6 +564,23 @@ func.func @test_avg_pool2d_zero_dim_input(%arg0: tensor<1x0x?x9xf32>, %arg1: ten
|
||||
|
||||
// -----
|
||||
|
||||
func.func @test_variable_unranked(%arg0: tensor<2x4x8xi8>) -> () {
|
||||
tosa.variable @stored_var : tensor<*xi8>
|
||||
// expected-error@+1 {{custom op 'tosa.variable' expected ranked type}}
|
||||
return
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
func.func @test_variable_unranked_initial_value(%arg0: tensor<2x4x8xi8>) -> () {
|
||||
// expected-error@+1 {{elements literal type must have static shape}}
|
||||
tosa.variable @stored_var = dense<0> : tensor<*xi8>
|
||||
// expected-error@+1 {{custom op 'tosa.variable' expected attribute}}
|
||||
return
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
func.func @test_variable_duplicates(%arg0: tensor<2x4x8xi8>) -> () {
|
||||
tosa.variable @stored_var = dense<-1> : tensor<2x4x8xi8>
|
||||
// expected-error@+1 {{'tosa.variable' op illegal to have multiple declaration of 'stored_var'}}
|
||||
|
||||
@@ -443,7 +443,7 @@ func.func @test_rescale_rank_invalid(%arg0: tensor<1x1x1x1x13x21x3xi8>) -> tenso
|
||||
|
||||
// -----
|
||||
func.func @test_const(%arg0 : tensor<1x1xi32>) -> tensor<1x1x1x1x1x1x1xi32> {
|
||||
// expected-error@+1 {{'tosa.const' op failed level check: attribute rank(shape) <= MAX_RANK}}
|
||||
// expected-error@+1 {{'tosa.const' op failed level check: result rank(shape) <= MAX_RANK}}
|
||||
%0 = "tosa.const"() {values = dense<0> : tensor<1x1x1x1x1x1x1xi32>} : () -> tensor<1x1x1x1x1x1x1xi32>
|
||||
return %0: tensor<1x1x1x1x1x1x1xi32>
|
||||
}
|
||||
@@ -1089,7 +1089,8 @@ func.func @test_scatter_tensor_size_invalid(%arg0: tensor<13x210000000x3xf32>, %
|
||||
// -----
|
||||
|
||||
func.func @test_variable_read_write_tensor_size_invalid() -> () {
|
||||
tosa.variable @stored_var = dense<3.14> : tensor<536870912xf32>
|
||||
// expected-error@+1 {{'tosa.variable' op failed level check: variable type tensor size (in bytes) <= (1 << MAX_LOG2_SIZE - 1)}}
|
||||
tosa.variable @stored_var : tensor<536870912xf32>
|
||||
// expected-error@+1 {{'tosa.variable_read' op failed level check: result tensor size (in bytes) <= (1 << MAX_LOG2_SIZE - 1)}}
|
||||
%0 = tosa.variable_read @stored_var : tensor<536870912xf32>
|
||||
// expected-error@+1 {{'tosa.variable_write' op failed level check: operand tensor size (in bytes) <= (1 << MAX_LOG2_SIZE - 1)}}
|
||||
@@ -1156,8 +1157,8 @@ func.func @test_cond_if_rank_invalid(%arg0: tensor<1x1x1x1x1x1x1x1xf32>, %arg1:
|
||||
// -----
|
||||
|
||||
func.func @test_variable_read_write_rank_invalid() -> () {
|
||||
// expected-error@+1 {{'tosa.variable' op failed level check: attribute rank(shape) <= MAX_RANK}}
|
||||
tosa.variable @stored_var = dense<3.14> : tensor<1x1x1x1x1x1x1x1xf32>
|
||||
// expected-error@+1 {{'tosa.variable' op failed level check: variable type rank(shape) <= MAX_RANK}}
|
||||
tosa.variable @stored_var : tensor<1x1x1x1x1x1x1x1xf32>
|
||||
// expected-error@+1 {{'tosa.variable_read' op failed level check: result rank(shape) <= MAX_RANK}}
|
||||
%0 = tosa.variable_read @stored_var : tensor<1x1x1x1x1x1x1x1xf32>
|
||||
// expected-error@+1 {{'tosa.variable_write' op failed level check: operand rank(shape) <= MAX_RANK}}
|
||||
|
||||
@@ -31,3 +31,48 @@ func.func @test_variable_tensor(%arg0: tensor<2x4x8xi32>) -> () {
|
||||
tosa.variable_write @stored_var, %1 : tensor<2x4x8xi32>
|
||||
return
|
||||
}
|
||||
|
||||
// -----
|
||||
// CHECK-LABEL: @test_variable_scalar_no_initial_value(
|
||||
// CHECK-SAME: %[[ADD_VAL:.*]]: tensor<f32>) {
|
||||
func.func @test_variable_scalar_no_initial_value(%arg0: tensor<f32>) -> () {
|
||||
// CHECK: tosa.variable @stored_var : tensor<f32>
|
||||
tosa.variable @stored_var : tensor<f32>
|
||||
// CHECK: %[[STORED_VAL:.*]] = tosa.variable_read @stored_var : tensor<f32>
|
||||
%0 = tosa.variable_read @stored_var : tensor<f32>
|
||||
// CHECK: %[[RESULT_ADD:.*]] = tosa.add %[[ADD_VAL]], %[[STORED_VAL]] : (tensor<f32>, tensor<f32>) -> tensor<f32>
|
||||
%1 = "tosa.add"(%arg0, %0) : (tensor<f32>, tensor<f32>) -> tensor<f32>
|
||||
// CHECK: tosa.variable_write @stored_var, %[[RESULT_ADD]] : tensor<f32>
|
||||
tosa.variable_write @stored_var, %1 : tensor<f32>
|
||||
return
|
||||
}
|
||||
|
||||
// -----
|
||||
// CHECK-LABEL: @test_variable_tensor_no_initial_value(
|
||||
// CHECK-SAME: %[[ADD_VAL:.*]]: tensor<2x4x8xi32>) {
|
||||
func.func @test_variable_tensor_no_initial_value(%arg0: tensor<2x4x8xi32>) -> () {
|
||||
// CHECK: tosa.variable @stored_var : tensor<2x4x8xi32>
|
||||
tosa.variable @stored_var : tensor<2x4x8xi32>
|
||||
// CHECK: %[[STORED_VAL:.*]] = tosa.variable_read @stored_var : tensor<2x4x8xi32>
|
||||
%0 = tosa.variable_read @stored_var : tensor<2x4x8xi32>
|
||||
// CHECK: %[[RESULT_ADD:.*]] = tosa.add %[[ADD_VAL]], %[[STORED_VAL]] : (tensor<2x4x8xi32>, tensor<2x4x8xi32>) -> tensor<2x4x8xi32>
|
||||
%1 = "tosa.add"(%arg0, %0) : (tensor<2x4x8xi32>, tensor<2x4x8xi32>) -> tensor<2x4x8xi32>
|
||||
// CHECK: tosa.variable_write @stored_var, %[[RESULT_ADD]] : tensor<2x4x8xi32>
|
||||
tosa.variable_write @stored_var, %1 : tensor<2x4x8xi32>
|
||||
return
|
||||
}
|
||||
|
||||
// -----
|
||||
// CHECK-LABEL: @test_variable_tensor_with_unknowns(
|
||||
// CHECK-SAME: %[[ADD_VAL:.*]]: tensor<2x4x8xi32>) {
|
||||
func.func @test_variable_tensor_with_unknowns(%arg0: tensor<2x4x8xi32>) -> () {
|
||||
// CHECK: tosa.variable @stored_var : tensor<2x?x8xi32>
|
||||
tosa.variable @stored_var : tensor<2x?x8xi32>
|
||||
// CHECK: %[[STORED_VAL:.*]] = tosa.variable_read @stored_var : tensor<2x4x8xi32>
|
||||
%0 = tosa.variable_read @stored_var : tensor<2x4x8xi32>
|
||||
// CHECK: %[[RESULT_ADD:.*]] = tosa.add %[[ADD_VAL]], %[[STORED_VAL]] : (tensor<2x4x8xi32>, tensor<2x4x8xi32>) -> tensor<2x4x8xi32>
|
||||
%1 = "tosa.add"(%arg0, %0) : (tensor<2x4x8xi32>, tensor<2x4x8xi32>) -> tensor<2x4x8xi32>
|
||||
// CHECK: tosa.variable_write @stored_var, %[[RESULT_ADD]] : tensor<2x4x8xi32>
|
||||
tosa.variable_write @stored_var, %1 : tensor<2x4x8xi32>
|
||||
return
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user