[mlir] load dialects for non-namespaced attrs (#96242)

The mlir-translate tool calls into the parser without loading registered
dependent dialects, and the parser only loads attributes if the
fully-namespaced attribute is present in the textual IR. This causes
parsing to break when an op has an attribute that prints/parses without
the namespaced attribute.

Co-authored-by: Jeremy Kun <jkun@google.com>
This commit is contained in:
Mehdi Amini
2024-06-21 13:23:45 +02:00
committed by GitHub
parent 739a960567
commit bc82793b30
6 changed files with 52 additions and 5 deletions

View File

@@ -1465,3 +1465,14 @@ test.dialect_custom_format_fallback custom_format_fallback
// CHECK: test.format_optional_result_d_op : f80
test.format_optional_result_d_op : f80
// -----
// This is a testing that a non-qualified attribute in a custom format
// correctly preload the dialect before creating the attribute.
#attr = #test.nested_polynomial<<1 + x**2>>
// CHECK-lABLE: @parse_correctly
llvm.func @parse_correctly() {
test.containing_int_polynomial_attr #attr
llvm.return
}

View File

@@ -16,8 +16,8 @@ mlir_tablegen(TestOpInterfaces.cpp.inc -gen-op-interface-defs)
add_public_tablegen_target(MLIRTestInterfaceIncGen)
set(LLVM_TARGET_DEFINITIONS TestOps.td)
mlir_tablegen(TestAttrDefs.h.inc -gen-attrdef-decls)
mlir_tablegen(TestAttrDefs.cpp.inc -gen-attrdef-defs)
mlir_tablegen(TestAttrDefs.h.inc -gen-attrdef-decls -attrdefs-dialect=test)
mlir_tablegen(TestAttrDefs.cpp.inc -gen-attrdef-defs -attrdefs-dialect=test)
add_public_tablegen_target(MLIRTestAttrDefIncGen)
set(LLVM_TARGET_DEFINITIONS TestTypeDefs.td)
@@ -86,6 +86,7 @@ add_mlir_library(MLIRTestDialect
MLIRLinalgTransforms
MLIRLLVMDialect
MLIRPass
MLIRPolynomialDialect
MLIRReduce
MLIRTensorDialect
MLIRTransformUtils

View File

@@ -16,6 +16,7 @@
// To get the test dialect definition.
include "TestDialect.td"
include "TestEnumDefs.td"
include "mlir/Dialect/Polynomial/IR/PolynomialAttributes.td"
include "mlir/Dialect/Utils/StructuredOpsUtils.td"
include "mlir/IR/AttrTypeBase.td"
include "mlir/IR/BuiltinAttributeInterfaces.td"
@@ -351,4 +352,12 @@ def TestCustomFloatAttr : Test_Attr<"TestCustomFloat"> {
}];
}
def NestedPolynomialAttr : Test_Attr<"NestedPolynomialAttr"> {
let mnemonic = "nested_polynomial";
let parameters = (ins Polynomial_IntPolynomialAttr:$poly);
let assemblyFormat = [{
`<` $poly `>`
}];
}
#endif // TEST_ATTRDEFS

View File

@@ -17,6 +17,7 @@
#include <tuple>
#include "TestTraits.h"
#include "mlir/Dialect/Polynomial/IR/PolynomialAttributes.h"
#include "mlir/Dialect/Utils/StructuredOpsUtils.h"
#include "mlir/IR/Attributes.h"
#include "mlir/IR/Diagnostics.h"

View File

@@ -232,6 +232,11 @@ def FloatElementsAttrOp : TEST_Op<"float_elements_attr"> {
);
}
def ContainingIntPolynomialAttrOp : TEST_Op<"containing_int_polynomial_attr"> {
let arguments = (ins NestedPolynomialAttr:$attr);
let assemblyFormat = "$attr attr-dict";
}
// A pattern that updates dense<[3.0, 4.0]> to dense<[5.0, 6.0]>.
// This tests both matching and generating float elements attributes.
def UpdateFloatElementsAttr : Pat<
@@ -2215,7 +2220,7 @@ def ForwardBufferOp : TEST_Op<"forward_buffer", [Pure]> {
def ReifyBoundOp : TEST_Op<"reify_bound", [Pure]> {
let description = [{
Reify a bound for the given index-typed value or dimension size of a shaped
value. "LB", "EQ" and "UB" bounds are supported. If `scalable` is set,
value. "LB", "EQ" and "UB" bounds are supported. If `scalable` is set,
`vscale_min` and `vscale_max` must be provided, which allows computing
a bound in terms of "vector.vscale" for a given range of vscale.
}];

View File

@@ -164,8 +164,9 @@ static const char *const parserErrorStr =
/// {2}: Code template for printing an error.
/// {3}: Name of the attribute or type.
/// {4}: C++ class of the parameter.
/// {5}: Optional code to preload the dialect for this variable.
static const char *const variableParser = R"(
// Parse variable '{0}'
// Parse variable '{0}'{5}
_result_{0} = {1};
if (::mlir::failed(_result_{0})) {{
{2}"failed to parse {3} parameter '{0}' which is to be a `{4}`");
@@ -411,9 +412,28 @@ void DefFormat::genVariableParser(ParameterElement *el, FmtContext &ctx,
auto customParser = param.getParser();
auto parser =
customParser ? *customParser : StringRef(defaultParameterParser);
// If the variable points to a dialect specific entity (type of attribute),
// we force load the dialect now before trying to parse it.
std::string dialectLoading;
if (auto *defInit = dyn_cast<llvm::DefInit>(param.getDef())) {
auto *dialectValue = defInit->getDef()->getValue("dialect");
if (dialectValue) {
if (auto *dialectInit =
dyn_cast<llvm::DefInit>(dialectValue->getValue())) {
Dialect dialect(dialectInit->getDef());
auto cppNamespace = dialect.getCppNamespace();
std::string name = dialect.getCppClassName();
dialectLoading = ("\nodsParser.getContext()->getOrLoadDialect<" +
cppNamespace + "::" + name + ">();")
.str();
}
}
}
os << formatv(variableParser, param.getName(),
tgfmt(parser, &ctx, param.getCppStorageType()),
tgfmt(parserErrorStr, &ctx), def.getName(), param.getCppType());
tgfmt(parserErrorStr, &ctx), def.getName(), param.getCppType(),
dialectLoading);
}
void DefFormat::genParamsParser(ParamsDirective *el, FmtContext &ctx,