[mlir] load dialects for non-namespaced attrs (#96242)
The mlir-translate tool calls into the parser without loading registered dependent dialects, and the parser only loads attributes if the fully-namespaced attribute is present in the textual IR. This causes parsing to break when an op has an attribute that prints/parses without the namespaced attribute. Co-authored-by: Jeremy Kun <jkun@google.com>
This commit is contained in:
@@ -1465,3 +1465,14 @@ test.dialect_custom_format_fallback custom_format_fallback
|
||||
// CHECK: test.format_optional_result_d_op : f80
|
||||
test.format_optional_result_d_op : f80
|
||||
|
||||
|
||||
// -----
|
||||
|
||||
// This is a testing that a non-qualified attribute in a custom format
|
||||
// correctly preload the dialect before creating the attribute.
|
||||
#attr = #test.nested_polynomial<<1 + x**2>>
|
||||
// CHECK-lABLE: @parse_correctly
|
||||
llvm.func @parse_correctly() {
|
||||
test.containing_int_polynomial_attr #attr
|
||||
llvm.return
|
||||
}
|
||||
|
||||
@@ -16,8 +16,8 @@ mlir_tablegen(TestOpInterfaces.cpp.inc -gen-op-interface-defs)
|
||||
add_public_tablegen_target(MLIRTestInterfaceIncGen)
|
||||
|
||||
set(LLVM_TARGET_DEFINITIONS TestOps.td)
|
||||
mlir_tablegen(TestAttrDefs.h.inc -gen-attrdef-decls)
|
||||
mlir_tablegen(TestAttrDefs.cpp.inc -gen-attrdef-defs)
|
||||
mlir_tablegen(TestAttrDefs.h.inc -gen-attrdef-decls -attrdefs-dialect=test)
|
||||
mlir_tablegen(TestAttrDefs.cpp.inc -gen-attrdef-defs -attrdefs-dialect=test)
|
||||
add_public_tablegen_target(MLIRTestAttrDefIncGen)
|
||||
|
||||
set(LLVM_TARGET_DEFINITIONS TestTypeDefs.td)
|
||||
@@ -86,6 +86,7 @@ add_mlir_library(MLIRTestDialect
|
||||
MLIRLinalgTransforms
|
||||
MLIRLLVMDialect
|
||||
MLIRPass
|
||||
MLIRPolynomialDialect
|
||||
MLIRReduce
|
||||
MLIRTensorDialect
|
||||
MLIRTransformUtils
|
||||
|
||||
@@ -16,6 +16,7 @@
|
||||
// To get the test dialect definition.
|
||||
include "TestDialect.td"
|
||||
include "TestEnumDefs.td"
|
||||
include "mlir/Dialect/Polynomial/IR/PolynomialAttributes.td"
|
||||
include "mlir/Dialect/Utils/StructuredOpsUtils.td"
|
||||
include "mlir/IR/AttrTypeBase.td"
|
||||
include "mlir/IR/BuiltinAttributeInterfaces.td"
|
||||
@@ -351,4 +352,12 @@ def TestCustomFloatAttr : Test_Attr<"TestCustomFloat"> {
|
||||
}];
|
||||
}
|
||||
|
||||
def NestedPolynomialAttr : Test_Attr<"NestedPolynomialAttr"> {
|
||||
let mnemonic = "nested_polynomial";
|
||||
let parameters = (ins Polynomial_IntPolynomialAttr:$poly);
|
||||
let assemblyFormat = [{
|
||||
`<` $poly `>`
|
||||
}];
|
||||
}
|
||||
|
||||
#endif // TEST_ATTRDEFS
|
||||
|
||||
@@ -17,6 +17,7 @@
|
||||
#include <tuple>
|
||||
|
||||
#include "TestTraits.h"
|
||||
#include "mlir/Dialect/Polynomial/IR/PolynomialAttributes.h"
|
||||
#include "mlir/Dialect/Utils/StructuredOpsUtils.h"
|
||||
#include "mlir/IR/Attributes.h"
|
||||
#include "mlir/IR/Diagnostics.h"
|
||||
|
||||
@@ -232,6 +232,11 @@ def FloatElementsAttrOp : TEST_Op<"float_elements_attr"> {
|
||||
);
|
||||
}
|
||||
|
||||
def ContainingIntPolynomialAttrOp : TEST_Op<"containing_int_polynomial_attr"> {
|
||||
let arguments = (ins NestedPolynomialAttr:$attr);
|
||||
let assemblyFormat = "$attr attr-dict";
|
||||
}
|
||||
|
||||
// A pattern that updates dense<[3.0, 4.0]> to dense<[5.0, 6.0]>.
|
||||
// This tests both matching and generating float elements attributes.
|
||||
def UpdateFloatElementsAttr : Pat<
|
||||
@@ -2215,7 +2220,7 @@ def ForwardBufferOp : TEST_Op<"forward_buffer", [Pure]> {
|
||||
def ReifyBoundOp : TEST_Op<"reify_bound", [Pure]> {
|
||||
let description = [{
|
||||
Reify a bound for the given index-typed value or dimension size of a shaped
|
||||
value. "LB", "EQ" and "UB" bounds are supported. If `scalable` is set,
|
||||
value. "LB", "EQ" and "UB" bounds are supported. If `scalable` is set,
|
||||
`vscale_min` and `vscale_max` must be provided, which allows computing
|
||||
a bound in terms of "vector.vscale" for a given range of vscale.
|
||||
}];
|
||||
|
||||
@@ -164,8 +164,9 @@ static const char *const parserErrorStr =
|
||||
/// {2}: Code template for printing an error.
|
||||
/// {3}: Name of the attribute or type.
|
||||
/// {4}: C++ class of the parameter.
|
||||
/// {5}: Optional code to preload the dialect for this variable.
|
||||
static const char *const variableParser = R"(
|
||||
// Parse variable '{0}'
|
||||
// Parse variable '{0}'{5}
|
||||
_result_{0} = {1};
|
||||
if (::mlir::failed(_result_{0})) {{
|
||||
{2}"failed to parse {3} parameter '{0}' which is to be a `{4}`");
|
||||
@@ -411,9 +412,28 @@ void DefFormat::genVariableParser(ParameterElement *el, FmtContext &ctx,
|
||||
auto customParser = param.getParser();
|
||||
auto parser =
|
||||
customParser ? *customParser : StringRef(defaultParameterParser);
|
||||
|
||||
// If the variable points to a dialect specific entity (type of attribute),
|
||||
// we force load the dialect now before trying to parse it.
|
||||
std::string dialectLoading;
|
||||
if (auto *defInit = dyn_cast<llvm::DefInit>(param.getDef())) {
|
||||
auto *dialectValue = defInit->getDef()->getValue("dialect");
|
||||
if (dialectValue) {
|
||||
if (auto *dialectInit =
|
||||
dyn_cast<llvm::DefInit>(dialectValue->getValue())) {
|
||||
Dialect dialect(dialectInit->getDef());
|
||||
auto cppNamespace = dialect.getCppNamespace();
|
||||
std::string name = dialect.getCppClassName();
|
||||
dialectLoading = ("\nodsParser.getContext()->getOrLoadDialect<" +
|
||||
cppNamespace + "::" + name + ">();")
|
||||
.str();
|
||||
}
|
||||
}
|
||||
}
|
||||
os << formatv(variableParser, param.getName(),
|
||||
tgfmt(parser, &ctx, param.getCppStorageType()),
|
||||
tgfmt(parserErrorStr, &ctx), def.getName(), param.getCppType());
|
||||
tgfmt(parserErrorStr, &ctx), def.getName(), param.getCppType(),
|
||||
dialectLoading);
|
||||
}
|
||||
|
||||
void DefFormat::genParamsParser(ParamsDirective *el, FmtContext &ctx,
|
||||
|
||||
Reference in New Issue
Block a user