[mlir][tosa] Support DenseResourceElementsAttr in TOSA transpose folders (#124532)
Handle dense resource attributes in the transpose TOSA folder. Currently their interface does not align with the rest of the `ElementsAttr` when it comes to data accessing hence the special handling. Signed-off-by: Georgios Pinitas <georgios.pinitas@arm.com>
This commit is contained in:
@@ -18,6 +18,7 @@
|
||||
#include "mlir/Dialect/Utils/IndexingUtils.h"
|
||||
#include "mlir/IR/BuiltinAttributes.h"
|
||||
#include "mlir/IR/BuiltinTypes.h"
|
||||
#include "mlir/IR/DialectResourceBlobManager.h"
|
||||
#include "mlir/IR/Matchers.h"
|
||||
#include "mlir/Pass/Pass.h"
|
||||
#include "llvm/ADT/APFloat.h"
|
||||
@@ -176,6 +177,28 @@ DenseElementsAttr transposeType(const RangeType &data, ShapedType inputType,
|
||||
llvm::ArrayRef<ElementType>(outputValues));
|
||||
}
|
||||
|
||||
// Try to get the values of a DenseResourceElementsAttr construct
|
||||
template <typename T>
|
||||
std::optional<ArrayRef<T>> tryGetDenseResourceValues(ElementsAttr attr) {
|
||||
if (auto denseResource = dyn_cast<DenseResourceElementsAttr>(attr)) {
|
||||
// Check that the resource memory blob exists
|
||||
AsmResourceBlob *blob = denseResource.getRawHandle().getBlob();
|
||||
if (!blob)
|
||||
return std::nullopt;
|
||||
|
||||
// Check that the data are in a valid form
|
||||
bool isSplat = false;
|
||||
if (!DenseElementsAttr::isValidRawBuffer(attr.getShapedType(),
|
||||
blob->getData(), isSplat)) {
|
||||
return std::nullopt;
|
||||
}
|
||||
|
||||
return blob->template getDataAs<T>();
|
||||
}
|
||||
|
||||
return std::nullopt;
|
||||
}
|
||||
|
||||
// A type specialized transposition of an ElementsAttr.
|
||||
// This implementation tries to operate on the underlying data in its raw
|
||||
// representation when possible to avoid allocating a large number of Attribute
|
||||
@@ -183,6 +206,7 @@ DenseElementsAttr transposeType(const RangeType &data, ShapedType inputType,
|
||||
DenseElementsAttr transpose(ElementsAttr attr, ShapedType inputType,
|
||||
ShapedType outputType,
|
||||
llvm::ArrayRef<int64_t> permValues) {
|
||||
// Handle generic ElementsAttr
|
||||
if (auto data = attr.tryGetValues<bool>())
|
||||
return transposeType(*data, inputType, outputType, permValues);
|
||||
|
||||
@@ -204,6 +228,35 @@ DenseElementsAttr transpose(ElementsAttr attr, ShapedType inputType,
|
||||
if (auto data = attr.tryGetValues<APFloat>())
|
||||
return transposeType(*data, inputType, outputType, permValues);
|
||||
|
||||
// Handle DenseResourceElementsAttr
|
||||
if (isa<DenseResourceElementsAttr>(attr)) {
|
||||
auto elementTy = attr.getElementType();
|
||||
|
||||
if (auto data = tryGetDenseResourceValues<bool>(attr);
|
||||
data && elementTy.isInteger(1))
|
||||
return transposeType(*data, inputType, outputType, permValues);
|
||||
|
||||
if (auto data = tryGetDenseResourceValues<int8_t>(attr);
|
||||
data && elementTy.isInteger(8))
|
||||
return transposeType(*data, inputType, outputType, permValues);
|
||||
|
||||
if (auto data = tryGetDenseResourceValues<int16_t>(attr);
|
||||
data && elementTy.isInteger(16))
|
||||
return transposeType(*data, inputType, outputType, permValues);
|
||||
|
||||
if (auto data = tryGetDenseResourceValues<int32_t>(attr);
|
||||
data && elementTy.isInteger(32))
|
||||
return transposeType(*data, inputType, outputType, permValues);
|
||||
|
||||
if (auto data = tryGetDenseResourceValues<int64_t>(attr);
|
||||
data && elementTy.isInteger(64))
|
||||
return transposeType(*data, inputType, outputType, permValues);
|
||||
|
||||
if (auto data = tryGetDenseResourceValues<float>(attr);
|
||||
data && elementTy.isF32())
|
||||
return transposeType(*data, inputType, outputType, permValues);
|
||||
}
|
||||
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
|
||||
@@ -108,18 +108,18 @@ func.func @transpose_nofold_quantized_types() -> tensor<1x1x2x2x!quant.uniform<i
|
||||
return %0: tensor<1x1x2x2x!quant.uniform<i8<-127:127>:f32:3, {1.000000e-01,1.000000e-01}>>
|
||||
}
|
||||
|
||||
// CHECK-LABEL: @transpose_nofold_dense_resource
|
||||
func.func @transpose_nofold_dense_resource() -> tensor<2x2xf32> {
|
||||
// CHECK-LABEL: @transpose_fold_dense_resource
|
||||
func.func @transpose_fold_dense_resource() -> tensor<2x2xf32> {
|
||||
%0 = "tosa.const"() <{values = dense_resource<resource> : tensor<2x2xf32>}> : () -> tensor<2x2xf32>
|
||||
|
||||
// CHECK: tosa.transpose
|
||||
// CHECK-NOT: tosa.transpose
|
||||
%2 = tosa.transpose %0 { perms = array<i32: 1, 0> }: (tensor<2x2xf32>) -> tensor<2x2xf32>
|
||||
return %2 : tensor<2x2xf32>
|
||||
}
|
||||
{-#
|
||||
dialect_resources: {
|
||||
builtin: {
|
||||
resource: "0x08000000010000000000000002000000000000000300000000000000"
|
||||
resource: "0x040000003f800000400000004040000040800000"
|
||||
}
|
||||
}
|
||||
#-}
|
||||
|
||||
Reference in New Issue
Block a user