clang-p2996/mlir/test/lib/IR/TestBuiltinAttributeInterfaces.cpp
Rik Huijzer c836b4ad6c [mlir] Verify TestBuiltinAttributeInterfaces eltype (#69878)
Fixes #61871 and fixes #60581.

This PR fixes two small things. First and foremost, it emits a clear
error from the `-test-elements-attr-interface` pass when its tests are
run on attributes whose element type is not an integer type. I've looked
through the introduction of the attribute interface
(https://reviews.llvm.org/D109190) and the later commits, and I see no
evidence that the interface (`attr.tryGetValues<T>()`) is expected to
handle mismatching types.
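The fix follows a "check the precondition, then call" pattern. Below is a standalone sketch of that pattern, assuming a heavily simplified model: `ElemKind`, `ToyElementsAttr` and `testIteration` are hypothetical stand-ins invented for illustration, not MLIR APIs, and the toy `tryGetValues<T>()` only imitates the shape of `ElementsAttr::tryGetValues<T>()`.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <optional>
#include <sstream>
#include <string>
#include <vector>

// Hypothetical toy model, NOT an MLIR API: an element type is either an
// integer type or a floating-point type.
enum class ElemKind { Integer, Float };

// Hypothetical, heavily simplified stand-in for an elements attribute.
struct ToyElementsAttr {
  ElemKind kind;
  std::vector<int64_t> intValues;   // meaningful only when kind == Integer
  std::vector<double> floatValues;  // meaningful only when kind == Float

  // Imitates the assumed contract: callers must only request an integer
  // view of integer-typed data. Violations trip the assert instead of being
  // reported gracefully.
  template <typename T>
  std::optional<std::vector<T>> tryGetValues() const {
    assert(kind == ElemKind::Integer && "integer view of non-integer data");
    return std::vector<T>(intValues.begin(), intValues.end());
  }
};

// Hypothetical analogue of testElementsAttrIteration: builds the diagnostic
// text as a string instead of emitting an MLIR error.
template <typename T>
std::string testIteration(const ToyElementsAttr &attr,
                          const std::string &type) {
  std::ostringstream diag;
  diag << "Test iterating `" << type << "`: ";
  // Check the precondition first so the call below is always legal.
  if (attr.kind != ElemKind::Integer) {
    diag << "expected element type to be an integer type";
    return diag.str();
  }
  std::optional<std::vector<T>> values = attr.tryGetValues<T>();
  if (!values) {
    diag << "unable to iterate type";
    return diag.str();
  }
  for (std::size_t i = 0; i < values->size(); ++i) {
    if (i != 0)
      diag << ", ";
    // Unary + promotes int8_t to int, so it prints as a number, not a char.
    diag << +(*values)[i];
  }
  return diag.str();
}
```

The ordering is the point: the element-type check runs before `tryGetValues`, so a floating-point attribute yields a readable message instead of violating the interface's assumption.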

For example, the case given in #61871 is:
```mlir
arith.constant sparse<[[0, 0, 5]],  -2.0> : vector<1x1x10xf16>
```
So, this is a sparse vector containing `f16` elements. Running the test
on it crashes at various locations because the test iterates over the
elements as integer types (`int64_t`, `uint64_t`, `APInt`,
`IntegerAttr`). But, as I said in the previous paragraph, I see no
reason to believe that the interface implementation is wrong here. The
interface simply assumes that clients don't call something like
`attr.tryGetValues<APInt>()` on a floating-point `attr`.
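To make the example concrete: a `sparse` elements attribute lists coordinates and the values stored at them, and every other element is zero. The sketch below is a plain C++ model of that meaning under stated assumptions, not MLIR's `SparseElementsAttr` implementation. `densify` is a hypothetical helper, `float` stands in for `f16` (which standard C++ before C++23 lacks), and the dense result is laid out in row-major order.

```cpp
#include <array>
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <vector>

// Hypothetical helper, NOT an MLIR API: expands a COO-style sparse value
// (one coordinate per stored value, zero elsewhere) into a dense row-major
// vector. `float` stands in for f16.
template <std::size_t Rank>
std::vector<float>
densify(const std::array<int64_t, Rank> &shape,
        const std::vector<std::array<int64_t, Rank>> &indices,
        const std::vector<float> &values) {
  assert(indices.size() == values.size() && "one value per coordinate");
  std::size_t total = 1;
  for (int64_t dim : shape)
    total *= static_cast<std::size_t>(dim);
  std::vector<float> dense(total, 0.0f); // unspecified elements are zero
  for (std::size_t i = 0; i < indices.size(); ++i) {
    // Row-major linearization: ((i0 * d1) + i1) * d2 + i2 for rank 3.
    std::size_t offset = 0;
    for (std::size_t d = 0; d < Rank; ++d)
      offset = offset * static_cast<std::size_t>(shape[d]) +
               static_cast<std::size_t>(indices[i][d]);
    assert(offset < dense.size() && "coordinate out of bounds");
    dense[offset] = values[i];
  }
  return dense;
}
```

For `vector<1x1x10xf16>` with the single coordinate `[0, 0, 5]`, the offset is `(0 * 1 + 0) * 10 + 5 = 5`, so the model yields ten elements where only element 5 is `-2.0`; none of them is an integer, which is why integer iteration is not meaningful here.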

I've also added a test for the implementation of this interface by the
builtin `sparse` elements attribute. It had no problems, but the extra
test still increases code coverage.
2023-10-23 17:02:52 +02:00

//===- TestBuiltinAttributeInterfaces.cpp ---------------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include "TestAttributes.h"
#include "mlir/IR/BuiltinOps.h"
#include "mlir/Pass/Pass.h"
#include "llvm/ADT/TypeSwitch.h"
#include "llvm/Support/FormatVariadic.h"

using namespace mlir;
using namespace test;

// Helper to print one scalar value, force int8_t to print as integer instead of
// char.
template <typename T>
static void printOneElement(InFlightDiagnostic &os, T value) {
  os << llvm::formatv("{0}", value).str();
}

namespace {
struct TestElementsAttrInterface
    : public PassWrapper<TestElementsAttrInterface, OperationPass<ModuleOp>> {
  MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestElementsAttrInterface)

  StringRef getArgument() const final { return "test-elements-attr-interface"; }
  StringRef getDescription() const final {
    return "Test ElementsAttr interface support.";
  }

  void runOnOperation() override {
    // Exercise every ElementsAttr attached to every operation in the module.
    getOperation().walk([&](Operation *op) {
      for (NamedAttribute attr : op->getAttrs()) {
        auto elementsAttr = dyn_cast<ElementsAttr>(attr.getValue());
        if (!elementsAttr)
          continue;
        testElementsAttrIteration<int64_t>(op, elementsAttr, "int64_t");
        testElementsAttrIteration<uint64_t>(op, elementsAttr, "uint64_t");
        testElementsAttrIteration<APInt>(op, elementsAttr, "APInt");
        testElementsAttrIteration<IntegerAttr>(op, elementsAttr, "IntegerAttr");
      }
    });
  }

  template <typename T>
  void testElementsAttrIteration(Operation *op, ElementsAttr attr,
                                 StringRef type) {
    InFlightDiagnostic diag = op->emitError()
                              << "Test iterating `" << type << "`: ";

    // tryGetValues<T>() assumes that T matches the element type, so only
    // attempt integer iteration on integer-typed attributes.
    if (!isa<mlir::IntegerType>(attr.getElementType())) {
      diag << "expected element type to be an integer type";
      return;
    }

    auto values = attr.tryGetValues<T>();
    if (!values) {
      diag << "unable to iterate type";
      return;
    }

    llvm::interleaveComma(*values, diag,
                          [&](T value) { printOneElement(diag, value); });
  }
};
} // namespace

namespace mlir {
namespace test {
void registerTestBuiltinAttributeInterfaces() {
  PassRegistration<TestElementsAttrInterface>();
}
} // namespace test
} // namespace mlir