[tosa] Change VariableOp to align with spec (#142240)

This fixes Tosa VariableOp to align with spec 1.0:
  - add a var_shape attribute to store the shape of the variable type
  - change the type attribute to store the element type of the variable type
  - add a builder so previous construction calls still work
  - fix up the level check of rank to run on the variable type instead of the
    initial value, which is optional
  - add a level check of size for the variable type
  - add lit tests for variable ops without initial values
  - add a lit test for a variable op with fixed rank but unknown dimension
  - add an invalid lit test for a variable op with unranked type

Signed-off-by: Tai Ly <tai.ly@arm.com>
Authored by Tai Ly on 2025-06-03 11:41:33 -05:00; committed by GitHub.
Parent: 4d42c8e184
Commit: 04b63ac1ab
11 changed files with 266 additions and 69 deletions
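For illustration, a declaration such as

  tosa.variable @stored_var : tensor<2x4x8xi32>

now carries its shape and element type as two separate attributes. In generic
form this looks roughly like the following (a sketch of the attribute layout,
not verbatim printer output):

  "tosa.variable"() <{name = "stored_var",
                      var_shape = dense<[2, 4, 8]> : tensor<3xindex>,
                      type = i32}> : () -> ()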

View File

@@ -197,6 +197,16 @@ def Tosa_PadOpQuantInfoBuilder : OpBuilder<
                               input, paddings);
   }]>;
 
+// This builder is called on the TOSA variable operator with a variable type
+// and optional initial value. The builder will extract var_shape and element type
+// attributes from variable type.
+def Tosa_VariableOpBuilder : OpBuilder<
+  (ins "StringRef":$name, "Type":$variable_type, "Attribute":$initial_value),
+  [{
+    buildVariableOp($_builder, $_state, name, variable_type, initial_value);
+  }]>;
+
 // Wrapper over base I32EnumAttr to set common fields.
 class Tosa_I32Enum<string name, string description, list<I32EnumAttrCase> cases>
     : I32EnumAttr<name, description, cases> {
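For context, a minimal sketch of how this builder is invoked from C++ once the
generated create method picks it up (the builder `b`, location `loc`, and the
variable name are illustrative):

  // The builder splits the ranked tensor type into the var_shape and
  // element-type attributes; the initial value may be null.
  auto varType = RankedTensorType::get({2, 4, 8}, b.getI32Type());
  b.create<tosa::VariableOp>(loc, "stored_var", varType,
                             /*initial_value=*/nullptr);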

View File

@@ -44,10 +44,14 @@ class PatternRewriter;
 
 namespace tosa {
 
-ParseResult parseTypeOrAttr(OpAsmParser &parser, TypeAttr &typeAttr,
-                            Attribute &attr);
-void printTypeOrAttr(OpAsmPrinter &p, Operation *op, TypeAttr type,
-                     Attribute attr);
+ParseResult parseVariableOpTypeOrInitialValue(OpAsmParser &parser,
+                                              DenseElementsAttr &varShapeAttr,
+                                              TypeAttr &typeAttr,
+                                              Attribute &initialValueAttr);
+void printVariableOpTypeOrInitialValue(OpAsmPrinter &p, Operation *op,
+                                       DenseElementsAttr varShapeAttr,
+                                       TypeAttr typeAttr,
+                                       Attribute initialValueAttr);
 
 #include "mlir/Dialect/Tosa/IR/TosaInterfaces.h.inc"
 
@@ -172,6 +176,9 @@ std::optional<Value> createZeroPointTensor(OpBuilder &builder, Location loc,
 Value createPadConstTensor(OpBuilder &builder, Location loc, Value src,
                            int32_t val = 0);
 
+// returns type of variable op
+RankedTensorType getVariableType(VariableOp variableOp);
+
 } // namespace tosa
 } // namespace mlir

View File

@@ -92,6 +92,7 @@ def Tosa_VariableOp : Tosa_Op<"variable", []> {
   let arguments = (ins
     SymbolNameAttr:$name,
+    IndexElementsAttr:$var_shape,
     TypeAttr:$type,
     OptionalAttr<AnyAttr>:$initial_value
   );
 
@@ -101,12 +102,16 @@ def Tosa_VariableOp : Tosa_Op<"variable", []> {
     Extension<[Tosa_EXT_VARIABLE]>,
   ];
 
   let hasCustomAssemblyFormat = 1;
 
   let assemblyFormat = [{
     $name
     attr-dict
-    custom<TypeOrAttr>($type, $initial_value)
+    custom<VariableOpTypeOrInitialValue>($var_shape, $type, $initial_value)
   }];
 
+  let builders = [Tosa_VariableOpBuilder];
+
   let hasVerifier = 1;
 }

View File

@@ -26,8 +26,9 @@ public:
   LogicalResult matchAndRewrite(tosa::VariableOp op,
                                 PatternRewriter &rewriter) const final {
+    auto variableType = tosa::getVariableType(op);
     auto newVariable = rewriter.create<mlir::ml_program::GlobalOp>(
-        op.getLoc(), op.getName(), op.getType(), /*is_mutable=*/true,
+        op.getLoc(), op.getName(), variableType, /*is_mutable=*/true,
        op.getInitialValueAttr(), /*sym_visibility=*/nullptr);
     newVariable.setPrivate();
     rewriter.replaceOp(op, newVariable);
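Because the global's type is now derived from the var_shape/type attributes via
tosa::getVariableType rather than from the initial value, declarations without
an initializer also convert cleanly; mirroring the lit test added below:

  tosa.variable @var_x : tensor<f32>
  // converts to:
  ml_program.global private mutable @var_x : tensor<f32>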

View File

@@ -131,6 +131,24 @@ SmallVector<Region *> tosa::WhileOp::getLoopRegions() {
   return {&getBodyGraph()};
 }
 
+//===----------------------------------------------------------------------===//
+// TOSA variable operator support.
+//===----------------------------------------------------------------------===//
+
+static SmallVector<int64_t> convertToMlirShape(ArrayRef<int64_t> shape) {
+  return to_vector(llvm::map_range(shape, [](int64_t dim) {
+    return dim == -1 ? ShapedType::kDynamic : dim;
+  }));
+}
+
+// returns type of variable op
+RankedTensorType mlir::tosa::getVariableType(tosa::VariableOp variableOp) {
+  Type elementType = variableOp.getType();
+  DenseIntElementsAttr varShapeAttr = variableOp.getVarShape();
+  auto shape = convertToMlirShape(to_vector(varShapeAttr.getValues<int64_t>()));
+  return RankedTensorType::get(shape, elementType);
+}
+
 //===----------------------------------------------------------------------===//
 // Tosa dialect initialization.
 //===----------------------------------------------------------------------===//
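As a worked example of the shape mapping above (a spec shape dimension of -1
becomes ShapedType::kDynamic): a var_shape of [2, -1, 8] with an i32 element
type yields tensor<2x?x8xi32>. A hypothetical call site:

  // variableOp was declared as: tosa.variable @v : tensor<2x?x8xi32>
  RankedTensorType varType = tosa::getVariableType(variableOp);
  // varType.getShape() == {2, ShapedType::kDynamic, 8}
  // varType.getElementType() is i32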
@@ -177,42 +195,80 @@ Operation *TosaDialect::materializeConstant(OpBuilder &builder, Attribute value,
 // Parsers and printers
 //===----------------------------------------------------------------------===//
 
-ParseResult mlir::tosa::parseTypeOrAttr(OpAsmParser &parser, TypeAttr &typeAttr,
-                                        Attribute &attr) {
+namespace {
+
+ParseResult getShapeAndElementType(OpAsmParser &parser, Type parsedType,
+                                   DenseElementsAttr &varShapeAttr,
+                                   TypeAttr &typeAttr) {
+  if (auto shapedType = dyn_cast<ShapedType>(parsedType)) {
+    if (!shapedType.hasRank())
+      return parser.emitError(parser.getCurrentLocation())
+             << "expected ranked type";
+
+    auto elementType = shapedType.getElementType();
+    typeAttr = TypeAttr::get(elementType);
+    ArrayRef<int64_t> shape = shapedType.getShape();
+    Builder builder(parser.getContext());
+    varShapeAttr = builder.getIndexTensorAttr(convertFromMlirShape(shape));
+    return success();
+  }
+  return parser.emitError(parser.getCurrentLocation())
+         << "expected shaped type";
+}
+
+} // namespace
+
+// parses the optional initial value or type for a tosa variable
+// with initial value:
+//   tosa.variable @name = dense<0.0> : tensor<1x8xf32>
+//
+// without initial value:
+//   tosa.variable @name : tensor<1x8xf32>
+ParseResult mlir::tosa::parseVariableOpTypeOrInitialValue(
+    OpAsmParser &parser, DenseElementsAttr &varShapeAttr, TypeAttr &typeAttr,
+    Attribute &initialValueAttr) {
   if (succeeded(parser.parseOptionalEqual())) {
-    if (failed(parser.parseAttribute(attr))) {
+    if (failed(parser.parseAttribute(initialValueAttr))) {
       return parser.emitError(parser.getCurrentLocation())
              << "expected attribute";
     }
-    if (auto typedAttr = dyn_cast<TypedAttr>(attr)) {
-      typeAttr = TypeAttr::get(typedAttr.getType());
+    if (auto typedAttr = dyn_cast<TypedAttr>(initialValueAttr)) {
+      return getShapeAndElementType(parser, typedAttr.getType(), varShapeAttr,
+                                    typeAttr);
     }
-    return success();
+    return parser.emitError(parser.getCurrentLocation())
+           << "expected Typed attr";
   }
 
-  Type type;
-  if (failed(parser.parseColonType(type))) {
-    return parser.emitError(parser.getCurrentLocation()) << "expected type";
+  initialValueAttr = nullptr;
+  Type parsedType;
+  if (failed(parser.parseColonType(parsedType))) {
+    return parser.emitError(parser.getCurrentLocation())
+           << "expected type after colon";
   }
-  typeAttr = TypeAttr::get(type);
-  return success();
+  return getShapeAndElementType(parser, parsedType, varShapeAttr, typeAttr);
 }
 
-void mlir::tosa::printTypeOrAttr(OpAsmPrinter &p, Operation *op, TypeAttr type,
-                                 Attribute attr) {
+void mlir::tosa::printVariableOpTypeOrInitialValue(
+    OpAsmPrinter &p, Operation *op, DenseElementsAttr varShapeAttr,
+    TypeAttr typeAttr, Attribute initialValueAttr) {
   bool needsSpace = false;
-  auto typedAttr = dyn_cast_or_null<TypedAttr>(attr);
-  if (!typedAttr || typedAttr.getType() != type.getValue()) {
+  if (!dyn_cast_or_null<TypedAttr>(initialValueAttr)) {
+    auto shape =
+        convertToMlirShape(to_vector(varShapeAttr.getValues<int64_t>()));
+    Type elementType = typeAttr.getValue();
+    RankedTensorType tensorType =
+        RankedTensorType::get(ArrayRef<int64_t>(shape), elementType);
+    auto tensorTypeAttr = TypeAttr::get(tensorType);
     p << ": ";
-    p.printAttribute(type);
+    p.printAttribute(tensorTypeAttr);
     needsSpace = true; // subsequent attr value needs a space separator
   }
-  if (attr) {
+  if (initialValueAttr) {
     if (needsSpace)
       p << ' ';
     p << "= ";
-    p.printAttribute(attr);
+    p.printAttribute(initialValueAttr);
   }
 }
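The two print forms round-trip the parser comments above: when a typed initial
value is present the variable type is implied by the attribute; otherwise it is
reconstructed from var_shape/type, with -1 dimensions printing as ?:

  tosa.variable @name = dense<0.0> : tensor<1x8xf32>   // type implied by value
  tosa.variable @name : tensor<2x?xf32>                // rebuilt from attributes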
@@ -657,8 +713,9 @@ static LogicalResult verifyVariableOpErrorIf(T op, Type type, StringRef name) {
            << symName << "' has not been declared by 'tosa.variable'";
 
   // Verify type and shape
-  Type varType = cast<tosa::VariableOp>(varOp.value()).getType();
-  if (errorIfTypeOrShapeMismatch(op, type, name, varType, "the input tensor")
+  auto variableType = getVariableType(varOp.value());
+  if (errorIfTypeOrShapeMismatch(op, type, name, variableType,
+                                 "the input tensor")
       .failed())
     return failure();
@@ -1103,6 +1160,33 @@ static void buildPadOpWithQuantInfo(OpBuilder &builder, OperationState &result,
   result.types.push_back(outputType);
 }
 
+static void buildVariableOp(OpBuilder &builder, OperationState &result,
+                            StringRef name, Type variableType,
+                            Attribute initialValue) {
+  const Location loc{result.location};
+  auto nameAttr = builder.getStringAttr(name);
+
+  auto shapedType = dyn_cast<ShapedType>(variableType);
+  if (!shapedType) {
+    (void)emitError(loc, "variable type must be a shaped type");
+    return;
+  }
+  if (!shapedType.hasRank()) {
+    (void)emitError(loc, "variable type must be a ranked type");
+    return;
+  }
+
+  auto elementType = shapedType.getElementType();
+  auto elementTypeAttr = TypeAttr::get(elementType);
+  ArrayRef<int64_t> shape = shapedType.getShape();
+  auto varShapeAttr = builder.getIndexTensorAttr(convertFromMlirShape(shape));
+
+  result.addAttribute("name", nameAttr);
+  result.addAttribute("var_shape", varShapeAttr);
+  result.addAttribute("type", elementTypeAttr);
+  result.addAttribute("initial_value", initialValue);
+}
+
 //===----------------------------------------------------------------------===//
 // TOSA Operator Return Type Inference.
 //===----------------------------------------------------------------------===//
@@ -1676,12 +1760,6 @@ LogicalResult tosa::PadOp::verify() {
   return success();
 }
 
-static SmallVector<int64_t> convertToMlirShape(ArrayRef<int64_t> shape) {
-  return to_vector(llvm::map_range(shape, [](int64_t dim) {
-    return dim == -1 ? ShapedType::kDynamic : dim;
-  }));
-}
-
 LogicalResult tosa::SliceOp::inferReturnTypeComponents(
     MLIRContext *context, ::std::optional<Location> location,
     SliceOp::Adaptor adaptor,

View File

@@ -215,15 +215,8 @@ LogicalResult ProfileInfoDepot::populateProfileInfo(tosa::MatMulOp op) {
 
 template <>
 LogicalResult ProfileInfoDepot::populateProfileInfo(tosa::VariableOp op) {
-  ::mlir::Attribute attr = op.getInitialValueAttr();
-  if (attr == nullptr)
-    return failure();
-
-  if (auto typedAttr = dyn_cast<TypedAttr>(attr)) {
-    addType(getElementTypeOrSelf(typedAttr));
-    return success();
-  }
-  return failure();
+  addType(op.getType());
+  return success();
 }
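The practical effect: a variable now contributes its element type to
profile/extension checking even when it is declared without an initial value.
An illustrative declaration that registers i8 for compliance purposes:

  tosa.variable @stored_var : tensor<2x4x8xi8>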
template <>

View File

@@ -238,10 +238,10 @@ private:
     return true;
   }
 
-  template <typename T>
-  bool levelCheckRank(Operation *op, const T &v,
+  // Perform the Level Rank check on the tensor type.
+  bool levelCheckRank(Operation *op, const Type typeToCheck,
                       const StringRef operandOrResult, int32_t highest_rank) {
-    if (ShapedType type = dyn_cast<ShapedType>(v.getType())) {
+    if (ShapedType type = dyn_cast<ShapedType>(typeToCheck)) {
       if (!type.hasRank()) {
         op->emitOpError() << "failed level check: unranked tensor";
         return false;
@@ -255,10 +255,22 @@ private:
     return true;
   }
 
-  // Perform the Level tensor size check on the input tensor.
-  bool levelCheckSize(Operation *op, const Value &v,
-                      const StringRef operandOrResult);
+  // Perform the Level Rank check on the tensor value.
+  bool levelCheckRank(Operation *op, const Value &v,
+                      const StringRef operandOrResult, int32_t highest_rank) {
+    return levelCheckRank(op, v.getType(), operandOrResult, highest_rank);
+  }
+
+  // Perform the Level tensor size check on the tensor type.
+  bool levelCheckSize(Operation *op, const Type &typeToCheck,
+                      const StringRef operandOrResult);
+
+  // Perform the Level tensor size check on the tensor value.
+  bool levelCheckSize(Operation *op, const Value &v,
+                      const StringRef operandOrResult) {
+    return levelCheckSize(op, v.getType(), operandOrResult);
+  }
 
   // Level check sizes of all operands and results of the operation.
   template <typename T>
   bool levelCheckSizes(T tosaOp) {
@@ -284,15 +296,6 @@ private:
       return false;
     }
 
-    if (!op->getAttrs().empty()) {
-      for (NamedAttribute attr : op->getAttrs()) {
-        if (auto elemAttr = dyn_cast<ElementsAttr>(attr.getValue())) {
-          if (!levelCheckRank(op, elemAttr, "attribute", tosaLevel.MAX_RANK))
-            return false;
-        }
-      }
-    }
-
     for (auto v : op->getResults()) {
       if (!levelCheckRank(op, v, "result", tosaLevel.MAX_RANK))
         return false;
@@ -596,6 +599,26 @@ bool TosaValidation::levelCheckRanks(tosa::IfOp tosaOp) {
   return true;
 }
 
+template <>
+bool TosaValidation::levelCheckRanks(tosa::VariableOp tosaOp) {
+  auto op = tosaOp.getOperation();
+  auto variableType = getVariableType(tosaOp);
+  if (!levelCheckRank(op, variableType, "variable type", tosaLevel.MAX_RANK))
+    return false;
+  return true;
+}
+
+template <>
+bool TosaValidation::levelCheckSizes(tosa::VariableOp tosaOp) {
+  auto op = tosaOp.getOperation();
+  auto variableType = getVariableType(tosaOp);
+  if (!levelCheckSize(op, variableType, "variable type"))
+    return false;
+  return true;
+}
+
 bool TosaValidation::levelCheckRanksAndSizes(Operation *op) {
 #define CHECK_RANKS_AND_SIZES(tosaOp)                                          \
   if (isa<tosa::tosaOp##Op>(op)) {                                             \
@@ -714,10 +737,10 @@ bool TosaValidation::levelCheckRanksAndSizes(Operation *op) {
   return true;
 }
 
-// Perform the Level tensor size check
-bool TosaValidation::levelCheckSize(Operation *op, const Value &v,
+// Perform the Level tensor size check on the tensor type.
+bool TosaValidation::levelCheckSize(Operation *op, const Type &typeToCheck,
                                     const StringRef operandOrResult) {
-  if (ShapedType type = dyn_cast<ShapedType>(v.getType())) {
+  if (ShapedType type = dyn_cast<ShapedType>(typeToCheck)) {
     if (!type.hasRank()) {
       op->emitOpError() << "failed level check: unranked tensor";
       return false;
@@ -800,18 +823,21 @@ inline bool CompatibleTypes(const mlir::Type &type,
 }
 
 bool TosaValidation::CheckVariable(Operation *op) {
-  if (isa<mlir::tosa::VariableOp>(op)) {
-    mlir::StringAttr nameAttr = cast<mlir::StringAttr>(op->getAttr("name"));
+  if (auto variableOp = dyn_cast<mlir::tosa::VariableOp>(op)) {
+    mlir::StringAttr nameAttr = variableOp.getNameAttr();
 
     if (variablesMap.count(nameAttr)) {
       op->emitOpError() << "name has already been declared";
       return false;
     }
 
-    auto typeAttr = cast<mlir::TypeAttr>(op->getAttr("type"));
-    mlir::Type type = typeAttr.getValue();
+    auto elementType = variableOp.getType();
+    DenseIntElementsAttr varShapeAttr = variableOp.getVarShape();
+    SmallVector<int64_t> shape = to_vector(varShapeAttr.getValues<int64_t>());
+    RankedTensorType variableType =
+        RankedTensorType::get(ArrayRef<int64_t>(shape), elementType);
 
-    variablesMap[nameAttr] = type;
+    variablesMap[nameAttr] = variableType;
   }
 
   return true;

View File

@@ -1,4 +1,4 @@
-// RUN: mlir-opt --tosa-to-mlprogram %s -o -| FileCheck %s
+// RUN: mlir-opt --tosa-to-mlprogram %s -split-input-file -o -| FileCheck %s
 
 module {
   // CHECK: ml_program.global private mutable @var_x(dense<7.000000e+00> : tensor<1xf32>) : tensor<1xf32>
@@ -10,4 +10,18 @@ module {
     %0 = tosa.variable_read @var_x : tensor<1xf32>
     return %0 : tensor<1xf32>
   }
 }
+
+// -----
+
+module {
+  // CHECK: ml_program.global private mutable @var_x : tensor<f32>
+  tosa.variable @var_x : tensor<f32>
+  func.func @test_stateful_ops(%arg0: tensor<f32>) -> (tensor<f32>) {
+    // CHECK: ml_program.global_store @var_x = %arg0 : tensor<f32>
+    tosa.variable_write @var_x, %arg0 : tensor<f32>
+    // CHECK: %[[LOAD:.+]] = ml_program.global_load @var_x : tensor<f32>
+    %0 = tosa.variable_read @var_x : tensor<f32>
+    return %0 : tensor<f32>
+  }
+}

View File

@@ -564,6 +564,23 @@ func.func @test_avg_pool2d_zero_dim_input(%arg0: tensor<1x0x?x9xf32>, %arg1: ten
 
 // -----
 
+func.func @test_variable_unranked(%arg0: tensor<2x4x8xi8>) -> () {
+  // expected-error@+1 {{custom op 'tosa.variable' expected ranked type}}
+  tosa.variable @stored_var : tensor<*xi8>
+  return
+}
+
+// -----
+
+func.func @test_variable_unranked_initial_value(%arg0: tensor<2x4x8xi8>) -> () {
+  // expected-error@+2 {{elements literal type must have static shape}}
+  // expected-error@+1 {{custom op 'tosa.variable' expected attribute}}
+  tosa.variable @stored_var = dense<0> : tensor<*xi8>
+  return
+}
+
+// -----
+
 func.func @test_variable_duplicates(%arg0: tensor<2x4x8xi8>) -> () {
   tosa.variable @stored_var = dense<-1> : tensor<2x4x8xi8>
   // expected-error@+1 {{'tosa.variable' op illegal to have multiple declaration of 'stored_var'}}

View File

@@ -443,7 +443,7 @@ func.func @test_rescale_rank_invalid(%arg0: tensor<1x1x1x1x13x21x3xi8>) -> tenso
 // -----
 
 func.func @test_const(%arg0 : tensor<1x1xi32>) -> tensor<1x1x1x1x1x1x1xi32> {
-  // expected-error@+1 {{'tosa.const' op failed level check: attribute rank(shape) <= MAX_RANK}}
+  // expected-error@+1 {{'tosa.const' op failed level check: result rank(shape) <= MAX_RANK}}
   %0 = "tosa.const"() {values = dense<0> : tensor<1x1x1x1x1x1x1xi32>} : () -> tensor<1x1x1x1x1x1x1xi32>
   return %0: tensor<1x1x1x1x1x1x1xi32>
 }
 
@@ -1089,7 +1089,8 @@ func.func @test_scatter_tensor_size_invalid(%arg0: tensor<13x210000000x3xf32>, %
 // -----
 
 func.func @test_variable_read_write_tensor_size_invalid() -> () {
-  tosa.variable @stored_var = dense<3.14> : tensor<536870912xf32>
+  // expected-error@+1 {{'tosa.variable' op failed level check: variable type tensor size (in bytes) <= (1 << MAX_LOG2_SIZE - 1)}}
+  tosa.variable @stored_var : tensor<536870912xf32>
   // expected-error@+1 {{'tosa.variable_read' op failed level check: result tensor size (in bytes) <= (1 << MAX_LOG2_SIZE - 1)}}
   %0 = tosa.variable_read @stored_var : tensor<536870912xf32>
   // expected-error@+1 {{'tosa.variable_write' op failed level check: operand tensor size (in bytes) <= (1 << MAX_LOG2_SIZE - 1)}}
 
@@ -1156,8 +1157,8 @@ func.func @test_cond_if_rank_invalid(%arg0: tensor<1x1x1x1x1x1x1x1xf32>, %arg1:
 // -----
 
 func.func @test_variable_read_write_rank_invalid() -> () {
-  // expected-error@+1 {{'tosa.variable' op failed level check: attribute rank(shape) <= MAX_RANK}}
-  tosa.variable @stored_var = dense<3.14> : tensor<1x1x1x1x1x1x1x1xf32>
+  // expected-error@+1 {{'tosa.variable' op failed level check: variable type rank(shape) <= MAX_RANK}}
+  tosa.variable @stored_var : tensor<1x1x1x1x1x1x1x1xf32>
   // expected-error@+1 {{'tosa.variable_read' op failed level check: result rank(shape) <= MAX_RANK}}
   %0 = tosa.variable_read @stored_var : tensor<1x1x1x1x1x1x1x1xf32>
   // expected-error@+1 {{'tosa.variable_write' op failed level check: operand rank(shape) <= MAX_RANK}}

View File

@@ -31,3 +31,48 @@ func.func @test_variable_tensor(%arg0: tensor<2x4x8xi32>) -> () {
   tosa.variable_write @stored_var, %1 : tensor<2x4x8xi32>
   return
 }
+
+// -----
+
+// CHECK-LABEL: @test_variable_scalar_no_initial_value(
+// CHECK-SAME: %[[ADD_VAL:.*]]: tensor<f32>) {
+func.func @test_variable_scalar_no_initial_value(%arg0: tensor<f32>) -> () {
+  // CHECK: tosa.variable @stored_var : tensor<f32>
+  tosa.variable @stored_var : tensor<f32>
+  // CHECK: %[[STORED_VAL:.*]] = tosa.variable_read @stored_var : tensor<f32>
+  %0 = tosa.variable_read @stored_var : tensor<f32>
+  // CHECK: %[[RESULT_ADD:.*]] = tosa.add %[[ADD_VAL]], %[[STORED_VAL]] : (tensor<f32>, tensor<f32>) -> tensor<f32>
+  %1 = "tosa.add"(%arg0, %0) : (tensor<f32>, tensor<f32>) -> tensor<f32>
+  // CHECK: tosa.variable_write @stored_var, %[[RESULT_ADD]] : tensor<f32>
+  tosa.variable_write @stored_var, %1 : tensor<f32>
+  return
+}
+
+// -----
+
+// CHECK-LABEL: @test_variable_tensor_no_initial_value(
+// CHECK-SAME: %[[ADD_VAL:.*]]: tensor<2x4x8xi32>) {
+func.func @test_variable_tensor_no_initial_value(%arg0: tensor<2x4x8xi32>) -> () {
+  // CHECK: tosa.variable @stored_var : tensor<2x4x8xi32>
+  tosa.variable @stored_var : tensor<2x4x8xi32>
+  // CHECK: %[[STORED_VAL:.*]] = tosa.variable_read @stored_var : tensor<2x4x8xi32>
+  %0 = tosa.variable_read @stored_var : tensor<2x4x8xi32>
+  // CHECK: %[[RESULT_ADD:.*]] = tosa.add %[[ADD_VAL]], %[[STORED_VAL]] : (tensor<2x4x8xi32>, tensor<2x4x8xi32>) -> tensor<2x4x8xi32>
+  %1 = "tosa.add"(%arg0, %0) : (tensor<2x4x8xi32>, tensor<2x4x8xi32>) -> tensor<2x4x8xi32>
+  // CHECK: tosa.variable_write @stored_var, %[[RESULT_ADD]] : tensor<2x4x8xi32>
+  tosa.variable_write @stored_var, %1 : tensor<2x4x8xi32>
+  return
+}
+
+// -----
+
+// CHECK-LABEL: @test_variable_tensor_with_unknowns(
+// CHECK-SAME: %[[ADD_VAL:.*]]: tensor<2x4x8xi32>) {
+func.func @test_variable_tensor_with_unknowns(%arg0: tensor<2x4x8xi32>) -> () {
+  // CHECK: tosa.variable @stored_var : tensor<2x?x8xi32>
+  tosa.variable @stored_var : tensor<2x?x8xi32>
+  // CHECK: %[[STORED_VAL:.*]] = tosa.variable_read @stored_var : tensor<2x4x8xi32>
+  %0 = tosa.variable_read @stored_var : tensor<2x4x8xi32>
+  // CHECK: %[[RESULT_ADD:.*]] = tosa.add %[[ADD_VAL]], %[[STORED_VAL]] : (tensor<2x4x8xi32>, tensor<2x4x8xi32>) -> tensor<2x4x8xi32>
+  %1 = "tosa.add"(%arg0, %0) : (tensor<2x4x8xi32>, tensor<2x4x8xi32>) -> tensor<2x4x8xi32>
+  // CHECK: tosa.variable_write @stored_var, %[[RESULT_ADD]] : tensor<2x4x8xi32>
+  tosa.variable_write @stored_var, %1 : tensor<2x4x8xi32>
+  return
+}