[mlir][tosa] Support DenseResourceElementsAttr in TOSA transpose folders (#124532)

Handle dense resource attributes in the transpose TOSA folder.
Currently their interface does not align with the rest of the
`ElementsAttr` when it comes to data accessing hence the special
handling.

Signed-off-by: Georgios Pinitas <georgios.pinitas@arm.com>
This commit is contained in:
Georgios Pinitas
2025-03-24 21:48:22 +00:00
committed by GitHub
parent 5a668bdb98
commit 3df92197bb
2 changed files with 57 additions and 4 deletions

View File

@@ -18,6 +18,7 @@
#include "mlir/Dialect/Utils/IndexingUtils.h"
#include "mlir/IR/BuiltinAttributes.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/DialectResourceBlobManager.h"
#include "mlir/IR/Matchers.h"
#include "mlir/Pass/Pass.h"
#include "llvm/ADT/APFloat.h"
@@ -176,6 +177,28 @@ DenseElementsAttr transposeType(const RangeType &data, ShapedType inputType,
llvm::ArrayRef<ElementType>(outputValues));
}
// Try to get the values of a DenseResourceElementsAttr construct
template <typename T>
std::optional<ArrayRef<T>> tryGetDenseResourceValues(ElementsAttr attr) {
if (auto denseResource = dyn_cast<DenseResourceElementsAttr>(attr)) {
// Check that the resource memory blob exists
AsmResourceBlob *blob = denseResource.getRawHandle().getBlob();
if (!blob)
return std::nullopt;
// Check that the data are in a valid form
bool isSplat = false;
if (!DenseElementsAttr::isValidRawBuffer(attr.getShapedType(),
blob->getData(), isSplat)) {
return std::nullopt;
}
return blob->template getDataAs<T>();
}
return std::nullopt;
}
// A type specialized transposition of an ElementsAttr.
// This implementation tries to operate on the underlying data in its raw
// representation when possible to avoid allocating a large number of Attribute
@@ -183,6 +206,7 @@ DenseElementsAttr transposeType(const RangeType &data, ShapedType inputType,
DenseElementsAttr transpose(ElementsAttr attr, ShapedType inputType,
ShapedType outputType,
llvm::ArrayRef<int64_t> permValues) {
// Handle generic ElementsAttr
if (auto data = attr.tryGetValues<bool>())
return transposeType(*data, inputType, outputType, permValues);
@@ -204,6 +228,35 @@ DenseElementsAttr transpose(ElementsAttr attr, ShapedType inputType,
if (auto data = attr.tryGetValues<APFloat>())
return transposeType(*data, inputType, outputType, permValues);
// Handle DenseResourceElementsAttr
if (isa<DenseResourceElementsAttr>(attr)) {
auto elementTy = attr.getElementType();
if (auto data = tryGetDenseResourceValues<bool>(attr);
data && elementTy.isInteger(1))
return transposeType(*data, inputType, outputType, permValues);
if (auto data = tryGetDenseResourceValues<int8_t>(attr);
data && elementTy.isInteger(8))
return transposeType(*data, inputType, outputType, permValues);
if (auto data = tryGetDenseResourceValues<int16_t>(attr);
data && elementTy.isInteger(16))
return transposeType(*data, inputType, outputType, permValues);
if (auto data = tryGetDenseResourceValues<int32_t>(attr);
data && elementTy.isInteger(32))
return transposeType(*data, inputType, outputType, permValues);
if (auto data = tryGetDenseResourceValues<int64_t>(attr);
data && elementTy.isInteger(64))
return transposeType(*data, inputType, outputType, permValues);
if (auto data = tryGetDenseResourceValues<float>(attr);
data && elementTy.isF32())
return transposeType(*data, inputType, outputType, permValues);
}
return nullptr;
}

View File

@@ -108,18 +108,18 @@ func.func @transpose_nofold_quantized_types() -> tensor<1x1x2x2x!quant.uniform<i
return %0: tensor<1x1x2x2x!quant.uniform<i8<-127:127>:f32:3, {1.000000e-01,1.000000e-01}>>
}
// CHECK-LABEL: @transpose_nofold_dense_resource
func.func @transpose_nofold_dense_resource() -> tensor<2x2xf32> {
// CHECK-LABEL: @transpose_fold_dense_resource
func.func @transpose_fold_dense_resource() -> tensor<2x2xf32> {
%0 = "tosa.const"() <{values = dense_resource<resource> : tensor<2x2xf32>}> : () -> tensor<2x2xf32>
// CHECK: tosa.transpose
// CHECK-NOT: tosa.transpose
%2 = tosa.transpose %0 { perms = array<i32: 1, 0> }: (tensor<2x2xf32>) -> tensor<2x2xf32>
return %2 : tensor<2x2xf32>
}
{-#
dialect_resources: {
builtin: {
resource: "0x08000000010000000000000002000000000000000300000000000000"
resource: "0x040000003f800000400000004040000040800000"
}
}
#-}