This revision refactors ElementsAttr into an Attribute Interface.
This enables a common interface with which to interact with
element attributes, without needing to modify the builtin
dialect. It also removes most, if not all, of the need for
the current OpaqueElementsAttr, which was originally intended as
a way to opaquely represent data that was not representable by
the other builtin constructs.
The new ElementsAttr interface not only allows users to
natively represent their data in the way that best suits them,
but also allows for efficient opaque access and iteration of the
underlying data. Attributes using the ElementsAttr interface
can directly expose support for interacting with the held
elements using any C++ data type they claim to support. For
example, DenseIntOrFpElementsAttr supports iteration using
various native C++ integer/float data types, as well as
APInt/APFloat, and more. ElementsAttr instances that refer to
DenseIntOrFpElementsAttr can use all of these data types for
iteration:
```c++
DenseIntOrFpElementsAttr intElementsAttr = ...;
ElementsAttr attr = intElementsAttr;
for (uint64_t value : attr.getValues<uint64_t>())
...;
for (APInt value : attr.getValues<APInt>())
...;
for (IntegerAttr value : attr.getValues<IntegerAttr>())
...;
```
ElementsAttr also supports failable range/iterator access,
allowing for selective code paths depending on data type
support:
```c++
ElementsAttr attr = ...;
if (auto range = attr.tryGetValues<uint64_t>()) {
for (uint64_t value : *range)
...;
}
```
Differential Revision: https://reviews.llvm.org/D109190
75 lines
2.7 KiB
C++
75 lines
2.7 KiB
C++
//===- BuiltinAttributeInterfaces.cpp -------------------------------------===//
|
|
//
|
|
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
|
// See https://llvm.org/LICENSE.txt for license information.
|
|
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
|
//
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
#include "mlir/IR/BuiltinAttributeInterfaces.h"
|
|
#include "mlir/IR/BuiltinTypes.h"
|
|
#include "llvm/ADT/Sequence.h"
|
|
|
|
using namespace mlir;
|
|
using namespace mlir::detail;
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
// Tablegen Interface Definitions
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
#include "mlir/IR/BuiltinAttributeInterfaces.cpp.inc"
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
// ElementsAttr
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
/// Return the type held by this attribute, cast to a ShapedType.
ShapedType ElementsAttr::getType() const {
  Type heldType = Attribute::getType();
  return heldType.cast<ShapedType>();
}
|
|
|
|
/// Return the element type of the shaped type held by `elementsAttr`.
Type ElementsAttr::getElementType(Attribute elementsAttr) {
  ShapedType shapedType = elementsAttr.getType().cast<ShapedType>();
  return shapedType.getElementType();
}
|
|
|
|
/// Return the number of elements reported by the shaped type held by
/// `elementsAttr`.
int64_t ElementsAttr::getNumElements(Attribute elementsAttr) {
  ShapedType shapedType = elementsAttr.getType().cast<ShapedType>();
  return shapedType.getNumElements();
}
|
|
|
|
/// Return true if `index` is a valid multi-dimensional index into `type`.
///
/// The index is valid when it has exactly one entry per dimension of `type`
/// and every entry lies within `[0, dimension size)`. As a special case, a
/// rank-0 type also accepts the single-entry index `[0]`.
bool ElementsAttr::isValidIndex(ShapedType type, ArrayRef<uint64_t> index) {
  const int64_t rank = type.getRank();

  // A rank-0 type may be addressed with the lone index `0`.
  if (rank == 0 && index.size() == 1 && index[0] == 0)
    return true;

  // Otherwise the number of indices has to agree with the rank.
  if (static_cast<int64_t>(index.size()) != rank)
    return false;

  // Each index must fall inside its corresponding dimension. Values that
  // become negative when reinterpreted as int64_t are rejected by the
  // lower-bound check.
  ArrayRef<int64_t> dims = type.getShape();
  for (int i : llvm::seq<int>(0, rank)) {
    int64_t position = static_cast<int64_t>(index[i]);
    if (position < 0 || position >= dims[i])
      return false;
  }
  return true;
}
|
|
/// Return true if `index` is a valid multi-dimensional index into the shaped
/// type held by `elementsAttr`. Forwards to the ShapedType overload.
bool ElementsAttr::isValidIndex(Attribute elementsAttr,
                                ArrayRef<uint64_t> index) {
  ShapedType shapedType = elementsAttr.getType().cast<ShapedType>();
  return isValidIndex(shapedType, index);
}
|
|
|
|
/// Convert the multi-dimensional `index` into a flattened 1-D row-major index
/// for the shaped type held by `elementsAttr`.
///
/// `index` must satisfy `isValidIndex` for that type; this is asserted.
uint64_t ElementsAttr::getFlattenedIndex(Attribute elementsAttr,
                                         ArrayRef<uint64_t> index) {
  ShapedType type = elementsAttr.getType().cast<ShapedType>();
  assert(isValidIndex(type, index) && "expected valid multi-dimensional index");

  auto shape = type.getShape();
  auto rank = type.getRank();

  // Walk the dimensions from innermost to outermost, accumulating each index
  // scaled by the product of the sizes of all dimensions inside it.
  uint64_t flatIndex = 0;
  uint64_t stride = 1;
  int dim = rank - 1;
  while (dim >= 0) {
    flatIndex += index[dim] * stride;
    stride *= shape[dim];
    --dim;
  }
  return flatIndex;
}
|