Files
clang-p2996/mlir/test/lib/IR/TestBuiltinAttributeInterfaces.cpp
Mehdi Amini 7faf75bb3e Introduce a new Dense Array attribute
This attribute is similar to DenseElementsAttr but does not support
splat. As such it has a much simpler API and does not need any smart
iterator: it exposes direct ArrayRef access.

A new syntax is introduced so that the generic printing/parsing looks
like:

  [:i64 1, -2, 3]

This attribute begins like an ArrayAttr but has a `:` token after the
opening square brace to introduce the element type (supported types are
I8, I16, I32, I64, F32, F64), followed by the comma-separated list of data
values.

This is particularly convenient for attributes intended to be small,
like those referring to shapes.
For example a `transpose` operation with a `dims` attribute could be
defined as such:

  let arguments = (ins AnyTensor:$input, DenseI64ArrayAttr:$dims);
  let assemblyFormat = "$input `dims` `=` $dims attr-dict : type($input)";

And printed this way (the element type is elided in this case):

  transpose %input dims = [0, 2, 1] : tensor<2x3x4xf32>

The C++ API for dims would directly return an ArrayRef<int64_t>.

RFC: https://discourse.llvm.org/t/rfc-introduce-a-new-dense-array-attribute/63279

Recommit with a custom DenseArrayBaseAttrStorage class to ensure
over-alignment of the storage to the largest type.

Reviewed By: rriddle

Differential Revision: https://reviews.llvm.org/D123774
2022-06-28 13:28:06 +00:00

100 lines
3.5 KiB
C++

//===- TestBuiltinAttributeInterfaces.cpp ---------------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
#include "TestAttributes.h"
#include "mlir/IR/BuiltinOps.h"
#include "mlir/Pass/Pass.h"
#include "llvm/Support/FormatVariadic.h"
using namespace mlir;
using namespace test;
/// Streams a single scalar element into the diagnostic `os`, formatted with
/// `llvm::formatv("{0}", ...)`.
template <typename T>
static void printOneElement(InFlightDiagnostic &os, T value) {
  std::string text = llvm::formatv("{0}", value).str();
  os << text;
}

/// Overload for `int8_t`: `int8_t` is a character type, so the generic path
/// would render it as a char. Widening it to `int64_t` first makes it print as
/// a number. Overload resolution picks this non-template only for exact
/// `int8_t` arguments; the widened call below resolves back to the template.
static void printOneElement(InFlightDiagnostic &os, int8_t value) {
  printOneElement(os, static_cast<int64_t>(value));
}
namespace {
/// Pass registered as `test-elements-attr-interface`. For every attribute of
/// every operation in the module that implements `ElementsAttr`, it emits one
/// error diagnostic per element type it attempts to iterate the attribute as.
/// Each diagnostic lists the iterated values, or reports
/// "unable to iterate type" when `tryGetValues<T>()` yields nothing.
struct TestElementsAttrInterface
    : public PassWrapper<TestElementsAttrInterface, OperationPass<ModuleOp>> {
  MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestElementsAttrInterface)

  StringRef getArgument() const final { return "test-elements-attr-interface"; }
  StringRef getDescription() const final {
    return "Test ElementsAttr interface support.";
  }

  void runOnOperation() override {
    getOperation().walk([&](Operation *op) {
      for (NamedAttribute namedAttr : op->getAttrs()) {
        Attribute rawValue = namedAttr.getValue();
        auto elements = rawValue.dyn_cast<ElementsAttr>();
        if (!elements)
          continue;
        // Dense arrays are iterated only as their exact storage type; every
        // other ElementsAttr is probed with a fixed list of element types.
        if (auto denseArray = rawValue.dyn_cast<DenseArrayBaseAttr>())
          testDenseArrayIteration(op, elements, denseArray);
        else
          testGenericIteration(op, elements);
      }
    });
  }

  /// Iterates `elements` using the C++ type matching the storage element type
  /// reported by `denseArray` (the same attribute viewed as a dense array).
  void testDenseArrayIteration(Operation *op, ElementsAttr elements,
                               DenseArrayBaseAttr denseArray) {
    using EltType = DenseArrayBaseAttr::EltType;
    switch (denseArray.getElementType()) {
    case EltType::I8:
      return testElementsAttrIteration<int8_t>(op, elements, "int8_t");
    case EltType::I16:
      return testElementsAttrIteration<int16_t>(op, elements, "int16_t");
    case EltType::I32:
      return testElementsAttrIteration<int32_t>(op, elements, "int32_t");
    case EltType::I64:
      return testElementsAttrIteration<int64_t>(op, elements, "int64_t");
    case EltType::F32:
      return testElementsAttrIteration<float>(op, elements, "float");
    case EltType::F64:
      return testElementsAttrIteration<double>(op, elements, "double");
    }
  }

  /// Probes a non-dense-array ElementsAttr with each of the integer-like
  /// element types, in a fixed order (one diagnostic per type).
  void testGenericIteration(Operation *op, ElementsAttr elements) {
    testElementsAttrIteration<int64_t>(op, elements, "int64_t");
    testElementsAttrIteration<uint64_t>(op, elements, "uint64_t");
    testElementsAttrIteration<APInt>(op, elements, "APInt");
    testElementsAttrIteration<IntegerAttr>(op, elements, "IntegerAttr");
  }

  /// Emits an error on `op` labelled with `type`. The error holds either the
  /// comma-separated values of `attr` iterated as `T`, or a message saying `T`
  /// cannot be used to iterate it.
  template <typename T>
  void testElementsAttrIteration(Operation *op, ElementsAttr attr,
                                 StringRef type) {
    InFlightDiagnostic diag = op->emitError()
                              << "Test iterating `" << type << "`: ";
    auto maybeValues = attr.tryGetValues<T>();
    if (maybeValues) {
      llvm::interleaveComma(*maybeValues, diag, [&](T element) {
        printOneElement(diag, element);
      });
      return;
    }
    diag << "unable to iterate type";
  }
};
} // namespace
namespace mlir {
namespace test {
void registerTestBuiltinAttributeInterfaces() {
PassRegistration<TestElementsAttrInterface>();
}
} // namespace test
} // namespace mlir