This attribute is similar to DenseElementsAttr but does not support splat. As such it has a much simpler API and does not need any smart iterator: it exposes direct ArrayRef access. A new syntax is introduced so that the generic printing/parsing looks like: [:i64 1, -2, 3] This attribute begins like an ArrayAttr but has a `:` token after the opening square brace to introduce the element type (the supported element types are I8, I16, I32, I64, F32, and F64), followed by the comma-separated list of data values. This is particularly convenient for attributes intended to be small, like those referring to shapes. For example, a `transpose` operation with a `dims` attribute could be defined as follows: let arguments = (ins AnyTensor:$input, DenseI64ArrayAttr:$dims); let assemblyFormat = "$input `dims` `=` $dims attr-dict : type($input)"; It would then be printed this way (the element type is elided in this case): transpose %input dims = [0, 2, 1] : tensor<2x3x4xf32> The C++ API for dims would directly return an ArrayRef<int64_t>. RFC: https://discourse.llvm.org/t/rfc-introduce-a-new-dense-array-attribute/63279 Recommit with a custom DenseArrayBaseAttrStorage class to ensure over-alignment of the storage to the largest type. Reviewed By: rriddle Differential Revision: https://reviews.llvm.org/D123774
100 lines
3.5 KiB
C++
100 lines
3.5 KiB
C++
//===- TestBuiltinAttributeInterfaces.cpp ---------------------------------===//
|
|
//
|
|
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
|
// See https://llvm.org/LICENSE.txt for license information.
|
|
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
|
//
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
#include "TestAttributes.h"
|
|
#include "mlir/IR/BuiltinOps.h"
|
|
#include "mlir/Pass/Pass.h"
|
|
#include "llvm/Support/FormatVariadic.h"
|
|
|
|
using namespace mlir;
|
|
using namespace test;
|
|
|
|
// Helper to print one scalar value, force int8_t to print as integer instead of
|
|
// char.
|
|
template <typename T>
|
|
static void printOneElement(InFlightDiagnostic &os, T value) {
|
|
os << llvm::formatv("{0}", value).str();
|
|
}
|
|
template <>
|
|
void printOneElement<int8_t>(InFlightDiagnostic &os, int8_t value) {
|
|
os << llvm::formatv("{0}", static_cast<int64_t>(value)).str();
|
|
}
|
|
|
|
namespace {
|
|
struct TestElementsAttrInterface
|
|
: public PassWrapper<TestElementsAttrInterface, OperationPass<ModuleOp>> {
|
|
MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestElementsAttrInterface)
|
|
|
|
StringRef getArgument() const final { return "test-elements-attr-interface"; }
|
|
StringRef getDescription() const final {
|
|
return "Test ElementsAttr interface support.";
|
|
}
|
|
void runOnOperation() override {
|
|
getOperation().walk([&](Operation *op) {
|
|
for (NamedAttribute attr : op->getAttrs()) {
|
|
auto elementsAttr = attr.getValue().dyn_cast<ElementsAttr>();
|
|
if (!elementsAttr)
|
|
continue;
|
|
if (auto concreteAttr =
|
|
attr.getValue().dyn_cast<DenseArrayBaseAttr>()) {
|
|
switch (concreteAttr.getElementType()) {
|
|
case DenseArrayBaseAttr::EltType::I8:
|
|
testElementsAttrIteration<int8_t>(op, elementsAttr, "int8_t");
|
|
break;
|
|
case DenseArrayBaseAttr::EltType::I16:
|
|
testElementsAttrIteration<int16_t>(op, elementsAttr, "int16_t");
|
|
break;
|
|
case DenseArrayBaseAttr::EltType::I32:
|
|
testElementsAttrIteration<int32_t>(op, elementsAttr, "int32_t");
|
|
break;
|
|
case DenseArrayBaseAttr::EltType::I64:
|
|
testElementsAttrIteration<int64_t>(op, elementsAttr, "int64_t");
|
|
break;
|
|
case DenseArrayBaseAttr::EltType::F32:
|
|
testElementsAttrIteration<float>(op, elementsAttr, "float");
|
|
break;
|
|
case DenseArrayBaseAttr::EltType::F64:
|
|
testElementsAttrIteration<double>(op, elementsAttr, "double");
|
|
break;
|
|
}
|
|
continue;
|
|
}
|
|
testElementsAttrIteration<int64_t>(op, elementsAttr, "int64_t");
|
|
testElementsAttrIteration<uint64_t>(op, elementsAttr, "uint64_t");
|
|
testElementsAttrIteration<APInt>(op, elementsAttr, "APInt");
|
|
testElementsAttrIteration<IntegerAttr>(op, elementsAttr, "IntegerAttr");
|
|
}
|
|
});
|
|
}
|
|
|
|
template <typename T>
|
|
void testElementsAttrIteration(Operation *op, ElementsAttr attr,
|
|
StringRef type) {
|
|
InFlightDiagnostic diag = op->emitError()
|
|
<< "Test iterating `" << type << "`: ";
|
|
|
|
auto values = attr.tryGetValues<T>();
|
|
if (!values) {
|
|
diag << "unable to iterate type";
|
|
return;
|
|
}
|
|
|
|
llvm::interleaveComma(*values, diag,
|
|
[&](T value) { printOneElement(diag, value); });
|
|
}
|
|
};
|
|
} // namespace
|
|
|
|
namespace mlir {
|
|
namespace test {
|
|
void registerTestBuiltinAttributeInterfaces() {
|
|
PassRegistration<TestElementsAttrInterface>();
|
|
}
|
|
} // namespace test
|
|
} // namespace mlir
|