[mlir][tblgen] Add custom parsing and printing within struct (#133939)
This PR extends the `struct` directive in tablegen to support nested
`custom` directives. Note that this assumes/verifies that that `custom`
directive has a single parameter.
This enables defining custom field parsing and printing functions if the
`struct` directive doesn't suffice. There is some existing potential
downstream usage for it:
a3c7de9242/stablehlo/dialect/StablehloOps.cpp (L3102)
This commit is contained in:
@@ -842,9 +842,9 @@ if they are not present.
|
||||
|
||||
###### `struct` Directive
|
||||
|
||||
The `struct` directive accepts a list of variables to capture and will generate
|
||||
a parser and printer for a comma-separated list of key-value pairs. If an
|
||||
optional parameter is included in the `struct`, it can be elided. The variables
|
||||
The `struct` directive accepts a list of variables or directives to capture and
|
||||
will generate a parser and printer for a comma-separated list of key-value pairs.
|
||||
If an optional parameter is included in the `struct`, it can be elided. The variables
|
||||
are printed in the order they are specified in the argument list **but can be
|
||||
parsed in any order**. For example:
|
||||
|
||||
@@ -876,6 +876,13 @@ assembly format of `` `<` struct(params) `>` `` will result in:
|
||||
The order in which the parameters are printed is the order in which they are
|
||||
declared in the attribute's or type's `parameter` list.
|
||||
|
||||
Passing `custom<Foo>($variable)` allows providing a custom printer and parser
|
||||
for the encapsulated variable. Check the
|
||||
[custom and ref directive](#custom-and-ref-directive) section for more
|
||||
information about how to define the printer and parser functions. Note that a
|
||||
custom directive within a struct directive can only encapsulate a single
|
||||
variable.
|
||||
|
||||
###### `custom` and `ref` directive
|
||||
|
||||
The `custom` directive is used to dispatch calls to user-defined printer and
|
||||
|
||||
68
mlir/test/IR/custom-struct-attr-roundtrip.mlir
Normal file
68
mlir/test/IR/custom-struct-attr-roundtrip.mlir
Normal file
@@ -0,0 +1,68 @@
|
||||
// RUN: mlir-opt %s -split-input-file -verify-diagnostics| FileCheck %s
|
||||
|
||||
// CHECK-LABEL: @test_struct_attr_roundtrip
|
||||
func.func @test_struct_attr_roundtrip() -> () {
|
||||
// CHECK: attr = #test.custom_struct<type_str = "struct", value = 2, opt_value = [3, 3]>
|
||||
"test.op"() {attr = #test.custom_struct<type_str = "struct", value = 2, opt_value = [3, 3]>} : () -> ()
|
||||
// CHECK: attr = #test.custom_struct<type_str = "struct", value = ?, opt_value = 1>
|
||||
"test.op"() {attr = #test.custom_struct<type_str = "struct", value = ?, opt_value = 1>} : () -> ()
|
||||
// CHECK: attr = #test.custom_struct<type_str = "struct", value = 2, opt_value = [3, 3]>
|
||||
"test.op"() {attr = #test.custom_struct<value = 2, type_str = "struct", opt_value = [3, 3]>} : () -> ()
|
||||
// CHECK: attr = #test.custom_struct<type_str = "struct", value = 2>
|
||||
"test.op"() {attr = #test.custom_struct<type_str = "struct", value = 2>} : () -> ()
|
||||
// CHECK: attr = #test.custom_struct<type_str = "struct", value = 2>
|
||||
"test.op"() {attr = #test.custom_struct<value = 2, type_str = "struct">} : () -> ()
|
||||
return
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
// Verify all required parameters must be provided. `value` is missing.
|
||||
|
||||
// expected-error @below {{struct is missing required parameter: value}}
|
||||
"test.op"() {attr = #test.custom_struct<type_str = "struct">} : () -> ()
|
||||
|
||||
// -----
|
||||
|
||||
// Verify all keywords must be provided. All missing.
|
||||
|
||||
// expected-error @below {{expected valid keyword}}
|
||||
// expected-error @below {{expected a parameter name in struct}}
|
||||
"test.op"() {attr = #test.custom_struct<"struct", 2>} : () -> ()
|
||||
|
||||
// -----
|
||||
|
||||
// Verify all keywords must be provided. `type_str` missing.
|
||||
|
||||
// expected-error @below {{expected valid keyword}}
|
||||
// expected-error @below {{expected a parameter name in struct}}
|
||||
"test.op"() {attr = #test.custom_struct<"struct", value = 2, opt_value = [3, 3]>} : () -> ()
|
||||
|
||||
// -----
|
||||
|
||||
// Verify all keywords must be provided. `value` missing.
|
||||
|
||||
// expected-error @below {{expected valid keyword}}
|
||||
// expected-error @below {{expected a parameter name in struct}}
|
||||
"test.op"() {attr = #test.custom_struct<type_str = "struct", 2>} : () -> ()
|
||||
|
||||
// -----
|
||||
|
||||
// Verify invalid keyword provided.
|
||||
|
||||
// expected-error @below {{duplicate or unknown struct parameter name: type_str2}}
|
||||
"test.op"() {attr = #test.custom_struct<type_str2 = "struct", value = 2>} : () -> ()
|
||||
|
||||
// -----
|
||||
|
||||
// Verify duplicated keyword provided.
|
||||
|
||||
// expected-error @below {{duplicate or unknown struct parameter name: type_str}}
|
||||
"test.op"() {attr = #test.custom_struct<type_str = "struct", type_str = "struct2", value = 2>} : () -> ()
|
||||
|
||||
// -----
|
||||
|
||||
// Verify equals missing.
|
||||
|
||||
// expected-error @below {{expected '='}}
|
||||
"test.op"() {attr = #test.custom_struct<type_str "struct", value = 2>} : () -> ()
|
||||
@@ -369,6 +369,16 @@ def TestCustomFloatAttr : Test_Attr<"TestCustomFloat"> {
|
||||
}];
|
||||
}
|
||||
|
||||
// Test `struct` with nested `custom` assembly format.
|
||||
def TestCustomStructAttr : Test_Attr<"TestCustomStruct"> {
|
||||
let mnemonic = "custom_struct";
|
||||
let parameters = (ins "mlir::StringAttr":$type_str, "int64_t":$value,
|
||||
OptionalParameter<"mlir::ArrayAttr">:$opt_value);
|
||||
let assemblyFormat = [{
|
||||
`<` struct($type_str, custom<CustomStructAttr>($value), custom<CustomOptStructFieldAttr>($opt_value)) `>`
|
||||
}];
|
||||
}
|
||||
|
||||
def NestedPolynomialAttr : Test_Attr<"NestedPolynomialAttr"> {
|
||||
let mnemonic = "nested_polynomial";
|
||||
let parameters = (ins Polynomial_IntPolynomialAttr:$poly);
|
||||
|
||||
@@ -316,6 +316,49 @@ static ParseResult parseCustomFloatAttr(AsmParser &p, StringAttr &typeStrAttr,
|
||||
return success();
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// TestCustomStructAttr
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
static void printCustomStructAttr(AsmPrinter &p, int64_t value) {
|
||||
if (ShapedType::isDynamic(value)) {
|
||||
p << "?";
|
||||
} else {
|
||||
p.printStrippedAttrOrType(value);
|
||||
}
|
||||
}
|
||||
|
||||
static ParseResult parseCustomStructAttr(AsmParser &p, int64_t &value) {
|
||||
if (succeeded(p.parseOptionalQuestion())) {
|
||||
value = ShapedType::kDynamic;
|
||||
return success();
|
||||
}
|
||||
return p.parseInteger(value);
|
||||
}
|
||||
|
||||
static void printCustomOptStructFieldAttr(AsmPrinter &p, ArrayAttr attr) {
|
||||
if (attr && attr.size() == 1 && isa<IntegerAttr>(attr[0])) {
|
||||
p << cast<IntegerAttr>(attr[0]).getInt();
|
||||
} else {
|
||||
p.printStrippedAttrOrType(attr);
|
||||
}
|
||||
}
|
||||
|
||||
static ParseResult parseCustomOptStructFieldAttr(AsmParser &p,
|
||||
ArrayAttr &attr) {
|
||||
int64_t value;
|
||||
OptionalParseResult result = p.parseOptionalInteger(value);
|
||||
if (result.has_value()) {
|
||||
if (failed(result.value()))
|
||||
return failure();
|
||||
attr = ArrayAttr::get(
|
||||
p.getContext(),
|
||||
{IntegerAttr::get(IntegerType::get(p.getContext(), 64), value)});
|
||||
return success();
|
||||
}
|
||||
return p.parseAttribute(attr);
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// TestOpAsmAttrInterfaceAttr
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
@@ -37,14 +37,14 @@ def InvalidTypeC : InvalidType<"InvalidTypeC", "invalid_c"> {
|
||||
def InvalidTypeD : InvalidType<"InvalidTypeD", "invalid_d"> {
|
||||
let parameters = (ins "int":$v0);
|
||||
// CHECK: literals may only be used in the top-level section of the format
|
||||
// CHECK: expected a variable in `struct` argument list
|
||||
// CHECK: expected a parameter or `custom` directive in `struct` argument list
|
||||
let assemblyFormat = "`<` struct($v0, `,`) `>`";
|
||||
}
|
||||
|
||||
// Test struct directive cannot capture zero parameters.
|
||||
def InvalidTypeE : InvalidType<"InvalidTypeE", "invalid_e"> {
|
||||
let parameters = (ins "int":$v0);
|
||||
// CHECK: `struct` argument list expected a variable or directive
|
||||
// CHECK: `struct` argument list expected a parameter or directive
|
||||
let assemblyFormat = "`<` struct() $v0 `>`";
|
||||
}
|
||||
|
||||
@@ -144,3 +144,24 @@ def InvalidTypeT : InvalidType<"InvalidTypeT", "invalid_t"> {
|
||||
// CHECK: `custom` directive with no bound parameters cannot be used as optional group anchor
|
||||
let assemblyFormat = "$a (`(` custom<Foo>(ref($a))^ `)`)?";
|
||||
}
|
||||
|
||||
// Test `struct` with nested `custom` directive with multiple fields.
|
||||
def InvalidTypeU : InvalidType<"InvalidTypeU", "invalid_u"> {
|
||||
let parameters = (ins OptionalParameter<"int">:$a, "int":$b);
|
||||
// CHECK: `struct` can only contain `custom` directives with a single argument
|
||||
let assemblyFormat = "struct(custom<Foo>($a, $b))";
|
||||
}
|
||||
|
||||
// Test `struct` with nested `custom` directive invalid parameter.
|
||||
def InvalidTypeV : InvalidType<"InvalidTypeV", "invalid_v"> {
|
||||
let parameters = (ins OptionalParameter<"int">:$a);
|
||||
// CHECK: a `custom` directive nested within a `struct` must be passed a parameter
|
||||
let assemblyFormat = "struct($a, custom<Foo>(ref($a)))";
|
||||
}
|
||||
|
||||
// Test `custom` with nested `custom` directive invalid parameter.
|
||||
def InvalidTypeW : InvalidType<"InvalidTypeV", "invalid_v"> {
|
||||
let parameters = (ins OptionalParameter<"int">:$a, "int":$b);
|
||||
// CHECK: `custom` can only be used at the top-level context or within a `struct` directive
|
||||
let assemblyFormat = "custom<Foo>($a, custom<Bar>($b))";
|
||||
}
|
||||
|
||||
@@ -736,6 +736,84 @@ def TypeS : TestType<"TestS"> {
|
||||
let assemblyFormat = "$a";
|
||||
}
|
||||
|
||||
/// Test that a `struct` with nested `custom` parser and printer are generated correctly.
|
||||
|
||||
// ATTR: ::mlir::Attribute TestTAttr::parse(::mlir::AsmParser &odsParser,
|
||||
// ATTR: ::mlir::Type odsType) {
|
||||
// ATTR: bool _seen_v0 = false;
|
||||
// ATTR: bool _seen_v1 = false;
|
||||
// ATTR: bool _seen_v2 = false;
|
||||
// ATTR: const auto _loop_body = [&](::llvm::StringRef _paramKey) -> bool {
|
||||
// ATTR: if (odsParser.parseEqual())
|
||||
// ATTR: return {};
|
||||
// ATTR: if (!_seen_v0 && _paramKey == "v0") {
|
||||
// ATTR: _seen_v0 = true;
|
||||
// ATTR: _result_v0 = ::parseAttrParamA(odsParser, odsType);
|
||||
// ATTR: if (::mlir::failed(_result_v0))
|
||||
// ATTR: return {};
|
||||
// ATTR: } else if (!_seen_v1 && _paramKey == "v1") {
|
||||
// ATTR: _seen_v1 = true;
|
||||
// ATTR: {
|
||||
// ATTR: auto odsCustomResult = parseNestedCustom(odsParser,
|
||||
// ATTR-NEXT: ::mlir::detail::unwrapForCustomParse(_result_v1));
|
||||
// ATTR: if (::mlir::failed(odsCustomResult)) return {};
|
||||
// ATTR: if (::mlir::failed(_result_v1)) {
|
||||
// ATTR: odsParser.emitError(odsCustomLoc, "custom parser failed to parse parameter 'v1'");
|
||||
// ATTR: return {};
|
||||
// ATTR: }
|
||||
// ATTR: }
|
||||
// ATTR: } else if (!_seen_v2 && _paramKey == "v2") {
|
||||
// ATTR: _seen_v2 = true;
|
||||
// ATTR: _result_v2 = ::mlir::FieldParser<AttrParamB>::parse(odsParser);
|
||||
// ATTR: if (::mlir::failed(_result_v2)) {
|
||||
// ATTR: odsParser.emitError(odsParser.getCurrentLocation(), "failed to parse AttrT parameter 'v2' which is to be a `AttrParamB`");
|
||||
// ATTR: return {};
|
||||
// ATTR: }
|
||||
// ATTR: } else {
|
||||
// ATTR: return {};
|
||||
// ATTR: }
|
||||
// ATTR: return true;
|
||||
// ATTR: }
|
||||
// ATTR: do {
|
||||
// ATTR: ::llvm::StringRef _paramKey;
|
||||
// ATTR: if (odsParser.parseKeyword(&_paramKey)) {
|
||||
// ATTR: odsParser.emitError(odsParser.getCurrentLocation(),
|
||||
// ATTR-NEXT: "expected a parameter name in struct");
|
||||
// ATTR: return {};
|
||||
// ATTR: }
|
||||
// ATTR: if (!_loop_body(_paramKey)) return {};
|
||||
// ATTR: } while(!odsParser.parseOptionalComma());
|
||||
// ATTR: if (!_seen_v0)
|
||||
// ATTR: if (!_seen_v1)
|
||||
// ATTR: return TestTAttr::get(odsParser.getContext(),
|
||||
// ATTR: TestParamA((*_result_v0)),
|
||||
// ATTR: TestParamB((*_result_v1)),
|
||||
// ATTR: AttrParamB((_result_v2.value_or(AttrParamB()))));
|
||||
// ATTR: }
|
||||
|
||||
// ATTR: void TestTAttr::print(::mlir::AsmPrinter &odsPrinter) const {
|
||||
// ATTR: odsPrinter << "v0 = ";
|
||||
// ATTR: ::printAttrParamA(odsPrinter, getV0());
|
||||
// ATTR: odsPrinter << ", ";
|
||||
// ATTR: odsPrinter << "v1 = ";
|
||||
// ATTR: printNestedCustom(odsPrinter,
|
||||
// ATTR-NEXT: getV1());
|
||||
// ATTR: if (!(getV2() == AttrParamB())) {
|
||||
// ATTR: odsPrinter << "v2 = ";
|
||||
// ATTR: odsPrinter.printStrippedAttrOrType(getV2());
|
||||
// ATTR: }
|
||||
|
||||
def AttrT : TestAttr<"TestT"> {
|
||||
let parameters = (ins
|
||||
AttrParamA:$v0,
|
||||
AttrParamB:$v1,
|
||||
OptionalParameter<"AttrParamB">:$v2
|
||||
);
|
||||
|
||||
let mnemonic = "attr_t";
|
||||
let assemblyFormat = "`{` struct($v0, custom<NestedCustom>($v1), $v2) `}`";
|
||||
}
|
||||
|
||||
// DEFAULT_TYPE_PARSER: TestDialect::parseType(::mlir::DialectAsmParser &parser)
|
||||
// DEFAULT_TYPE_PARSER: auto parseResult = parseOptionalDynamicType(mnemonic, parser, genType);
|
||||
// DEFAULT_TYPE_PARSER: if (parseResult.has_value()) {
|
||||
|
||||
@@ -13,6 +13,7 @@
|
||||
#include "mlir/TableGen/Format.h"
|
||||
#include "mlir/TableGen/GenInfo.h"
|
||||
#include "llvm/ADT/BitVector.h"
|
||||
#include "llvm/ADT/SmallVectorExtras.h"
|
||||
#include "llvm/ADT/StringExtras.h"
|
||||
#include "llvm/ADT/StringSwitch.h"
|
||||
#include "llvm/ADT/TypeSwitch.h"
|
||||
@@ -75,37 +76,35 @@ private:
|
||||
AttrOrTypeParameter param;
|
||||
};
|
||||
|
||||
/// Utility to return the encapsulated parameter element for the provided format
|
||||
/// element. This parameter can originate from either a `ParameterElement`,
|
||||
/// `CustomDirective` with a single parameter argument or `RefDirective`.
|
||||
static ParameterElement *getEncapsulatedParameterElement(FormatElement *el) {
|
||||
return TypeSwitch<FormatElement *, ParameterElement *>(el)
|
||||
.Case<CustomDirective>([&](auto custom) {
|
||||
FailureOr<ParameterElement *> maybeParam =
|
||||
custom->template getFrontAs<ParameterElement>();
|
||||
return *maybeParam;
|
||||
})
|
||||
.Case<ParameterElement>([&](auto param) { return param; })
|
||||
.Case<RefDirective>(
|
||||
[&](auto ref) { return cast<ParameterElement>(ref->getArg()); })
|
||||
.Default([&](auto el) {
|
||||
assert(false && "unexpected struct element type");
|
||||
return nullptr;
|
||||
});
|
||||
}
|
||||
|
||||
/// Shorthand functions that can be used with ranged-based conditions.
|
||||
static bool paramIsOptional(ParameterElement *el) { return el->isOptional(); }
|
||||
static bool formatIsOptional(FormatElement *el) {
|
||||
ParameterElement *param = getEncapsulatedParameterElement(el);
|
||||
return param != nullptr && param->isOptional();
|
||||
}
|
||||
static bool paramNotOptional(ParameterElement *el) { return !el->isOptional(); }
|
||||
|
||||
/// Base class for a directive that contains references to multiple variables.
|
||||
template <DirectiveElement::Kind DirectiveKind>
|
||||
class ParamsDirectiveBase : public DirectiveElementBase<DirectiveKind> {
|
||||
public:
|
||||
using Base = ParamsDirectiveBase<DirectiveKind>;
|
||||
|
||||
ParamsDirectiveBase(std::vector<ParameterElement *> &¶ms)
|
||||
: params(std::move(params)) {}
|
||||
|
||||
/// Get the parameters contained in this directive.
|
||||
ArrayRef<ParameterElement *> getParams() const { return params; }
|
||||
|
||||
/// Get the number of parameters.
|
||||
unsigned getNumParams() const { return params.size(); }
|
||||
|
||||
/// Take all of the parameters from this directive.
|
||||
std::vector<ParameterElement *> takeParams() { return std::move(params); }
|
||||
|
||||
/// Returns true if there are optional parameters present.
|
||||
bool hasOptionalParams() const {
|
||||
return llvm::any_of(getParams(), paramIsOptional);
|
||||
}
|
||||
|
||||
private:
|
||||
/// The parameters captured by this directive.
|
||||
std::vector<ParameterElement *> params;
|
||||
};
|
||||
static bool formatNotOptional(FormatElement *el) {
|
||||
return !formatIsOptional(el);
|
||||
}
|
||||
|
||||
/// This class represents a `params` directive that refers to all parameters
|
||||
/// of an attribute or type. When used as a top-level directive, it generates
|
||||
@@ -116,9 +115,15 @@ private:
|
||||
/// When used as an argument to another directive that accepts variables,
|
||||
/// `params` can be used in place of manually listing all parameters of an
|
||||
/// attribute or type.
|
||||
class ParamsDirective : public ParamsDirectiveBase<DirectiveElement::Params> {
|
||||
class ParamsDirective
|
||||
: public VectorDirectiveBase<DirectiveElement::Params, ParameterElement *> {
|
||||
public:
|
||||
using Base::Base;
|
||||
|
||||
/// Returns true if there are optional parameters present.
|
||||
bool hasOptionalElements() const {
|
||||
return llvm::any_of(getElements(), paramIsOptional);
|
||||
}
|
||||
};
|
||||
|
||||
/// This class represents a `struct` directive that generates a struct format
|
||||
@@ -126,9 +131,15 @@ public:
|
||||
///
|
||||
/// `{` param-name `=` param-value (`,` param-name `=` param-value)* `}`
|
||||
///
|
||||
class StructDirective : public ParamsDirectiveBase<DirectiveElement::Struct> {
|
||||
class StructDirective
|
||||
: public VectorDirectiveBase<DirectiveElement::Struct, FormatElement *> {
|
||||
public:
|
||||
using Base::Base;
|
||||
|
||||
/// Returns true if there are optional format elements present.
|
||||
bool hasOptionalElements() const {
|
||||
return llvm::any_of(getElements(), formatIsOptional);
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace
|
||||
@@ -214,10 +225,10 @@ private:
|
||||
/// Generate the printer code for a variable.
|
||||
void genVariablePrinter(ParameterElement *el, FmtContext &ctx, MethodBody &os,
|
||||
bool skipGuard = false);
|
||||
/// Generate a printer for comma-separated parameters.
|
||||
void genCommaSeparatedPrinter(ArrayRef<ParameterElement *> params,
|
||||
/// Generate a printer for comma-separated format elements.
|
||||
void genCommaSeparatedPrinter(ArrayRef<FormatElement *> params,
|
||||
FmtContext &ctx, MethodBody &os,
|
||||
function_ref<void(ParameterElement *)> extra);
|
||||
function_ref<void(FormatElement *)> extra);
|
||||
/// Generate the printer code for a `params` directive.
|
||||
void genParamsPrinter(ParamsDirective *el, FmtContext &ctx, MethodBody &os);
|
||||
/// Generate the printer code for a `struct` directive.
|
||||
@@ -443,14 +454,14 @@ void DefFormat::genParamsParser(ParamsDirective *el, FmtContext &ctx,
|
||||
|
||||
// If there are optional parameters, we need to switch to `parseOptionalComma`
|
||||
// if there are no more required parameters after a certain point.
|
||||
bool hasOptional = el->hasOptionalParams();
|
||||
bool hasOptional = el->hasOptionalElements();
|
||||
if (hasOptional) {
|
||||
// Wrap everything in a do-while so that we can `break`.
|
||||
os << "do {\n";
|
||||
os.indent();
|
||||
}
|
||||
|
||||
ArrayRef<ParameterElement *> params = el->getParams();
|
||||
ArrayRef<ParameterElement *> params = el->getElements();
|
||||
using IteratorT = ParameterElement *const *;
|
||||
IteratorT it = params.begin();
|
||||
|
||||
@@ -551,22 +562,31 @@ void DefFormat::genStructParser(StructDirective *el, FmtContext &ctx,
|
||||
while (!$_parser.parseOptionalComma()) {
|
||||
)";
|
||||
|
||||
const char *const checkParamKey = R"(
|
||||
if (!_seen_{0} && _paramKey == "{0}") {
|
||||
_seen_{0} = true;
|
||||
)";
|
||||
|
||||
os << "// Parse parameter struct\n";
|
||||
|
||||
// Declare a "seen" variable for each key.
|
||||
for (ParameterElement *param : el->getParams())
|
||||
for (FormatElement *arg : el->getElements()) {
|
||||
ParameterElement *param = getEncapsulatedParameterElement(arg);
|
||||
os << formatv("bool _seen_{0} = false;\n", param->getName());
|
||||
}
|
||||
|
||||
// Generate the body of the parsing loop inside a lambda.
|
||||
os << "{\n";
|
||||
os.indent()
|
||||
<< "const auto _loop_body = [&](::llvm::StringRef _paramKey) -> bool {\n";
|
||||
genLiteralParser("=", ctx, os.indent());
|
||||
for (ParameterElement *param : el->getParams()) {
|
||||
os << formatv("if (!_seen_{0} && _paramKey == \"{0}\") {\n"
|
||||
" _seen_{0} = true;\n",
|
||||
param->getName());
|
||||
genVariableParser(param, ctx, os.indent());
|
||||
for (FormatElement *arg : el->getElements()) {
|
||||
ParameterElement *param = getEncapsulatedParameterElement(arg);
|
||||
os.getStream().printReindented(strfmt(checkParamKey, param->getName()));
|
||||
if (auto realParam = dyn_cast<ParameterElement>(arg))
|
||||
genVariableParser(param, ctx, os.indent());
|
||||
else if (auto custom = dyn_cast<CustomDirective>(arg))
|
||||
genCustomParser(custom, ctx, os.indent());
|
||||
os.unindent() << "} else ";
|
||||
// Print the check for duplicate or unknown parameter.
|
||||
}
|
||||
@@ -576,10 +596,10 @@ void DefFormat::genStructParser(StructDirective *el, FmtContext &ctx,
|
||||
|
||||
// Generate the parsing loop. If optional parameters are present, then the
|
||||
// parse loop is guarded by commas.
|
||||
unsigned numOptional = llvm::count_if(el->getParams(), paramIsOptional);
|
||||
unsigned numOptional = llvm::count_if(el->getElements(), formatIsOptional);
|
||||
if (numOptional) {
|
||||
// If the struct itself is optional, pull out the first iteration.
|
||||
if (numOptional == el->getNumParams()) {
|
||||
if (numOptional == el->getNumElements()) {
|
||||
os.getStream().printReindented(tgfmt(optionalStructFirst, &ctx).str());
|
||||
os.indent();
|
||||
} else {
|
||||
@@ -587,7 +607,7 @@ void DefFormat::genStructParser(StructDirective *el, FmtContext &ctx,
|
||||
}
|
||||
} else {
|
||||
os.getStream().printReindented(
|
||||
tgfmt(loopHeader, &ctx, el->getNumParams()).str());
|
||||
tgfmt(loopHeader, &ctx, el->getNumElements()).str());
|
||||
}
|
||||
os.indent();
|
||||
os.getStream().printReindented(tgfmt(loopStart, &ctx).str());
|
||||
@@ -597,12 +617,13 @@ void DefFormat::genStructParser(StructDirective *el, FmtContext &ctx,
|
||||
// all mandatory parameters have been parsed.
|
||||
// The whole struct is optional if all its parameters are optional.
|
||||
if (numOptional) {
|
||||
if (numOptional == el->getNumParams()) {
|
||||
if (numOptional == el->getNumElements()) {
|
||||
os << "}\n";
|
||||
os.unindent() << "}\n";
|
||||
} else {
|
||||
os << tgfmt("} while(!$_parser.parseOptionalComma());\n", &ctx);
|
||||
for (ParameterElement *param : el->getParams()) {
|
||||
for (FormatElement *arg : el->getElements()) {
|
||||
ParameterElement *param = getEncapsulatedParameterElement(arg);
|
||||
if (param->isOptional())
|
||||
continue;
|
||||
os.getStream().printReindented(
|
||||
@@ -614,7 +635,8 @@ void DefFormat::genStructParser(StructDirective *el, FmtContext &ctx,
|
||||
// N flags, successfully exiting the loop means that all parameters have
|
||||
// been seen. `parseOptionalComma` would cause issues with any formats that
|
||||
// use "struct(...) `,`" beacuse structs aren't sounded by braces.
|
||||
os.getStream().printReindented(strfmt(loopTerminator, el->getNumParams()));
|
||||
os.getStream().printReindented(
|
||||
strfmt(loopTerminator, el->getNumElements()));
|
||||
}
|
||||
os.unindent() << "}\n";
|
||||
}
|
||||
@@ -631,7 +653,7 @@ void DefFormat::genCustomParser(CustomDirective *el, FmtContext &ctx,
|
||||
os << "(void)odsCustomLoc;\n";
|
||||
os << tgfmt("auto odsCustomResult = parse$0($_parser", &ctx, el->getName());
|
||||
os.indent();
|
||||
for (FormatElement *arg : el->getArguments()) {
|
||||
for (FormatElement *arg : el->getElements()) {
|
||||
os << ",\n";
|
||||
if (auto *param = dyn_cast<ParameterElement>(arg))
|
||||
os << "::mlir::detail::unwrapForCustomParse(_result_" << param->getName()
|
||||
@@ -648,7 +670,7 @@ void DefFormat::genCustomParser(CustomDirective *el, FmtContext &ctx,
|
||||
} else {
|
||||
os << "if (::mlir::failed(odsCustomResult)) return {};\n";
|
||||
}
|
||||
for (FormatElement *arg : el->getArguments()) {
|
||||
for (FormatElement *arg : el->getElements()) {
|
||||
if (auto *param = dyn_cast<ParameterElement>(arg)) {
|
||||
if (param->isOptional())
|
||||
continue;
|
||||
@@ -689,7 +711,7 @@ void DefFormat::genOptionalGroupParser(OptionalElement *el, FmtContext &ctx,
|
||||
guardOn(llvm::ArrayRef(param));
|
||||
} else if (auto *params = dyn_cast<ParamsDirective>(first)) {
|
||||
genParamsParser(params, ctx, os);
|
||||
guardOn(params->getParams());
|
||||
guardOn(params->getElements());
|
||||
} else if (auto *custom = dyn_cast<CustomDirective>(first)) {
|
||||
os << "if (auto result = [&]() -> ::mlir::OptionalParseResult {\n";
|
||||
os.indent();
|
||||
@@ -704,7 +726,7 @@ void DefFormat::genOptionalGroupParser(OptionalElement *el, FmtContext &ctx,
|
||||
} else {
|
||||
auto *strct = cast<StructDirective>(first);
|
||||
genStructParser(strct, ctx, os);
|
||||
guardOn(params->getParams());
|
||||
guardOn(params->getElements());
|
||||
}
|
||||
os.indent();
|
||||
|
||||
@@ -816,14 +838,26 @@ static void guardOnAny(FmtContext &ctx, MethodBody &os, ParameterRange &¶ms,
|
||||
os.indent();
|
||||
}
|
||||
|
||||
/// Generate code to guard printing on the presence of any optional format
|
||||
/// elements.
|
||||
template <typename FormatElemRange>
|
||||
static void guardOnAnyOptional(FmtContext &ctx, MethodBody &os,
|
||||
FormatElemRange &&args, bool inverted = false) {
|
||||
guardOnAny(ctx, os,
|
||||
llvm::make_filter_range(
|
||||
llvm::map_range(args, getEncapsulatedParameterElement),
|
||||
[](ParameterElement *param) { return param->isOptional(); }),
|
||||
inverted);
|
||||
}
|
||||
|
||||
void DefFormat::genCommaSeparatedPrinter(
|
||||
ArrayRef<ParameterElement *> params, FmtContext &ctx, MethodBody &os,
|
||||
function_ref<void(ParameterElement *)> extra) {
|
||||
ArrayRef<FormatElement *> args, FmtContext &ctx, MethodBody &os,
|
||||
function_ref<void(FormatElement *)> extra) {
|
||||
// Emit a space if necessary, but only if the struct is present.
|
||||
if (shouldEmitSpace || !lastWasPunctuation) {
|
||||
bool allOptional = llvm::all_of(params, paramIsOptional);
|
||||
bool allOptional = llvm::all_of(args, formatIsOptional);
|
||||
if (allOptional)
|
||||
guardOnAny(ctx, os, params);
|
||||
guardOnAnyOptional(ctx, os, args);
|
||||
os << tgfmt("$_printer << ' ';\n", &ctx);
|
||||
if (allOptional)
|
||||
os.unindent() << "}\n";
|
||||
@@ -832,17 +866,21 @@ void DefFormat::genCommaSeparatedPrinter(
|
||||
// The first printed element does not need to emit a comma.
|
||||
os << "{\n";
|
||||
os.indent() << "bool _firstPrinted = true;\n";
|
||||
for (ParameterElement *param : params) {
|
||||
for (FormatElement *arg : args) {
|
||||
ParameterElement *param = getEncapsulatedParameterElement(arg);
|
||||
if (param->isOptional()) {
|
||||
param->genPrintGuard(ctx, os << "if (") << ") {\n";
|
||||
os.indent();
|
||||
}
|
||||
os << tgfmt("if (!_firstPrinted) $_printer << \", \";\n", &ctx);
|
||||
os << "_firstPrinted = false;\n";
|
||||
extra(param);
|
||||
extra(arg);
|
||||
shouldEmitSpace = false;
|
||||
lastWasPunctuation = true;
|
||||
genVariablePrinter(param, ctx, os);
|
||||
if (auto realParam = dyn_cast<ParameterElement>(arg))
|
||||
genVariablePrinter(realParam, ctx, os);
|
||||
else if (auto custom = dyn_cast<CustomDirective>(arg))
|
||||
genCustomPrinter(custom, ctx, os);
|
||||
if (param->isOptional())
|
||||
os.unindent() << "}\n";
|
||||
}
|
||||
@@ -851,16 +889,19 @@ void DefFormat::genCommaSeparatedPrinter(
|
||||
|
||||
void DefFormat::genParamsPrinter(ParamsDirective *el, FmtContext &ctx,
|
||||
MethodBody &os) {
|
||||
genCommaSeparatedPrinter(llvm::to_vector(el->getParams()), ctx, os,
|
||||
[&](ParameterElement *param) {});
|
||||
SmallVector<FormatElement *> args = llvm::map_to_vector(
|
||||
el->getElements(), [](ParameterElement *param) -> FormatElement * {
|
||||
return static_cast<FormatElement *>(param);
|
||||
});
|
||||
genCommaSeparatedPrinter(args, ctx, os, [&](FormatElement *param) {});
|
||||
}
|
||||
|
||||
void DefFormat::genStructPrinter(StructDirective *el, FmtContext &ctx,
|
||||
MethodBody &os) {
|
||||
genCommaSeparatedPrinter(
|
||||
llvm::to_vector(el->getParams()), ctx, os, [&](ParameterElement *param) {
|
||||
os << tgfmt("$_printer << \"$0 = \";\n", &ctx, param->getName());
|
||||
});
|
||||
genCommaSeparatedPrinter(el->getElements(), ctx, os, [&](FormatElement *arg) {
|
||||
ParameterElement *param = getEncapsulatedParameterElement(arg);
|
||||
os << tgfmt("$_printer << \"$0 = \";\n", &ctx, param->getName());
|
||||
});
|
||||
}
|
||||
|
||||
void DefFormat::genCustomPrinter(CustomDirective *el, FmtContext &ctx,
|
||||
@@ -873,7 +914,7 @@ void DefFormat::genCustomPrinter(CustomDirective *el, FmtContext &ctx,
|
||||
|
||||
os << tgfmt("print$0($_printer", &ctx, el->getName());
|
||||
os.indent();
|
||||
for (FormatElement *arg : el->getArguments()) {
|
||||
for (FormatElement *arg : el->getElements()) {
|
||||
os << ",\n";
|
||||
if (auto *param = dyn_cast<ParameterElement>(arg)) {
|
||||
os << param->getParam().getAccessorName() << "()";
|
||||
@@ -893,19 +934,12 @@ void DefFormat::genOptionalGroupPrinter(OptionalElement *el, FmtContext &ctx,
|
||||
if (auto *param = dyn_cast<ParameterElement>(anchor)) {
|
||||
guardOnAny(ctx, os, llvm::ArrayRef(param), el->isInverted());
|
||||
} else if (auto *params = dyn_cast<ParamsDirective>(anchor)) {
|
||||
guardOnAny(ctx, os, params->getParams(), el->isInverted());
|
||||
guardOnAny(ctx, os, params->getElements(), el->isInverted());
|
||||
} else if (auto *strct = dyn_cast<StructDirective>(anchor)) {
|
||||
guardOnAny(ctx, os, strct->getParams(), el->isInverted());
|
||||
guardOnAnyOptional(ctx, os, strct->getElements(), el->isInverted());
|
||||
} else {
|
||||
auto *custom = cast<CustomDirective>(anchor);
|
||||
guardOnAny(ctx, os,
|
||||
llvm::make_filter_range(
|
||||
llvm::map_range(custom->getArguments(),
|
||||
[](FormatElement *el) {
|
||||
return dyn_cast<ParameterElement>(el);
|
||||
}),
|
||||
[](ParameterElement *param) { return !!param; }),
|
||||
el->isInverted());
|
||||
guardOnAnyOptional(ctx, os, custom->getElements(), el->isInverted());
|
||||
}
|
||||
// Generate the printer for the contained elements.
|
||||
{
|
||||
@@ -960,6 +994,9 @@ protected:
|
||||
LogicalResult verifyOptionalGroupElements(SMLoc loc,
|
||||
ArrayRef<FormatElement *> elements,
|
||||
FormatElement *anchor) override;
|
||||
/// Verify the arguments to a struct directive.
|
||||
LogicalResult verifyStructArguments(SMLoc loc,
|
||||
ArrayRef<FormatElement *> arguments);
|
||||
|
||||
LogicalResult markQualified(SMLoc loc, FormatElement *element) override;
|
||||
|
||||
@@ -1010,7 +1047,7 @@ LogicalResult DefFormatParser::verify(SMLoc loc,
|
||||
auto *literalEl = dyn_cast<LiteralElement>(std::get<1>(it));
|
||||
if (!structEl || !literalEl)
|
||||
continue;
|
||||
if (literalEl->getSpelling() == "," && structEl->hasOptionalParams()) {
|
||||
if (literalEl->getSpelling() == "," && structEl->hasOptionalElements()) {
|
||||
return emitError(loc, "`struct` directive with optional parameters "
|
||||
"cannot be followed by a comma literal");
|
||||
}
|
||||
@@ -1037,17 +1074,17 @@ DefFormatParser::verifyOptionalGroupElements(llvm::SMLoc loc,
|
||||
"parameters in an optional group must be optional");
|
||||
}
|
||||
} else if (auto *params = dyn_cast<ParamsDirective>(el)) {
|
||||
if (llvm::any_of(params->getParams(), paramNotOptional)) {
|
||||
if (llvm::any_of(params->getElements(), paramNotOptional)) {
|
||||
return emitError(loc, "`params` directive allowed in optional group "
|
||||
"only if all parameters are optional");
|
||||
}
|
||||
} else if (auto *strct = dyn_cast<StructDirective>(el)) {
|
||||
if (llvm::any_of(strct->getParams(), paramNotOptional)) {
|
||||
if (llvm::any_of(strct->getElements(), formatNotOptional)) {
|
||||
return emitError(loc, "`struct` is only allowed in an optional group "
|
||||
"if all captured parameters are optional");
|
||||
}
|
||||
} else if (auto *custom = dyn_cast<CustomDirective>(el)) {
|
||||
for (FormatElement *el : custom->getArguments()) {
|
||||
for (FormatElement *el : custom->getElements()) {
|
||||
// If the custom argument is a variable, then it must be optional.
|
||||
if (auto *param = dyn_cast<ParameterElement>(el))
|
||||
if (!param->isOptional())
|
||||
@@ -1068,10 +1105,10 @@ DefFormatParser::verifyOptionalGroupElements(llvm::SMLoc loc,
|
||||
// arguments is a bound parameter.
|
||||
if (auto *custom = dyn_cast<CustomDirective>(anchor)) {
|
||||
const auto *bound =
|
||||
llvm::find_if(custom->getArguments(), [](FormatElement *el) {
|
||||
llvm::find_if(custom->getElements(), [](FormatElement *el) {
|
||||
return isa<ParameterElement>(el);
|
||||
});
|
||||
if (bound == custom->getArguments().end())
|
||||
if (bound == custom->getElements().end())
|
||||
return emitError(loc, "`custom` directive with no bound parameters "
|
||||
"cannot be used as optional group anchor");
|
||||
}
|
||||
@@ -1079,6 +1116,28 @@ DefFormatParser::verifyOptionalGroupElements(llvm::SMLoc loc,
|
||||
return success();
|
||||
}
|
||||
|
||||
LogicalResult
|
||||
DefFormatParser::verifyStructArguments(SMLoc loc,
|
||||
ArrayRef<FormatElement *> arguments) {
|
||||
for (FormatElement *el : arguments) {
|
||||
if (!isa<ParameterElement, CustomDirective, ParamsDirective>(el)) {
|
||||
return emitError(loc, "expected a parameter, custom directive or params "
|
||||
"directive in `struct` arguments list");
|
||||
}
|
||||
if (auto custom = dyn_cast<CustomDirective>(el)) {
|
||||
if (custom->getNumElements() != 1) {
|
||||
return emitError(loc, "`struct` can only contain `custom` directives "
|
||||
"with a single argument");
|
||||
}
|
||||
if (failed(custom->getFrontAs<ParameterElement>())) {
|
||||
return emitError(loc, "a `custom` directive nested within a `struct` "
|
||||
"must be passed a parameter");
|
||||
}
|
||||
}
|
||||
}
|
||||
return success();
|
||||
}
|
||||
|
||||
LogicalResult DefFormatParser::markQualified(SMLoc loc,
|
||||
FormatElement *element) {
|
||||
if (!isa<ParameterElement>(element))
|
||||
@@ -1172,37 +1231,45 @@ FailureOr<FormatElement *> DefFormatParser::parseStructDirective(SMLoc loc,
|
||||
return emitError(loc, "`struct` can only be used at the top-level context");
|
||||
|
||||
if (failed(parseToken(FormatToken::l_paren,
|
||||
"expected '(' before `struct` argument list")))
|
||||
"expected '(' before `struct` argument list"))) {
|
||||
return failure();
|
||||
}
|
||||
|
||||
// Parse variables captured by `struct`.
|
||||
std::vector<ParameterElement *> vars;
|
||||
std::vector<FormatElement *> vars;
|
||||
|
||||
// Parse first captured parameter or a `params` directive.
|
||||
FailureOr<FormatElement *> var = parseElement(StructDirectiveContext);
|
||||
if (failed(var) || !isa<VariableElement, ParamsDirective>(*var)) {
|
||||
return emitError(loc,
|
||||
"`struct` argument list expected a variable or directive");
|
||||
if (failed(var) ||
|
||||
!isa<ParameterElement, ParamsDirective, CustomDirective>(*var)) {
|
||||
return emitError(
|
||||
loc, "`struct` argument list expected a parameter or directive");
|
||||
}
|
||||
if (isa<VariableElement>(*var)) {
|
||||
if (isa<ParameterElement, CustomDirective>(*var)) {
|
||||
// Parse any other parameters.
|
||||
vars.push_back(cast<ParameterElement>(*var));
|
||||
vars.push_back(*var);
|
||||
while (peekToken().is(FormatToken::comma)) {
|
||||
consumeToken();
|
||||
var = parseElement(StructDirectiveContext);
|
||||
if (failed(var) || !isa<VariableElement>(*var))
|
||||
return emitError(loc, "expected a variable in `struct` argument list");
|
||||
vars.push_back(cast<ParameterElement>(*var));
|
||||
if (failed(var) || !isa<ParameterElement, CustomDirective>(*var))
|
||||
return emitError(loc, "expected a parameter or `custom` directive in "
|
||||
"`struct` argument list");
|
||||
vars.push_back(*var);
|
||||
}
|
||||
} else {
|
||||
// `struct(params)` captures all parameters in the attribute or type.
|
||||
vars = cast<ParamsDirective>(*var)->takeParams();
|
||||
ParamsDirective *params = cast<ParamsDirective>(*var);
|
||||
vars.reserve(params->getNumElements());
|
||||
for (ParameterElement *el : params->takeElements())
|
||||
vars.push_back(cast<FormatElement>(el));
|
||||
}
|
||||
|
||||
if (failed(parseToken(FormatToken::r_paren,
|
||||
"expected ')' at the end of an argument list")))
|
||||
"expected ')' at the end of an argument list"))) {
|
||||
return failure();
|
||||
}
|
||||
if (failed(verifyStructArguments(loc, vars)))
|
||||
return failure();
|
||||
|
||||
return create<StructDirective>(std::move(vars));
|
||||
}
|
||||
|
||||
|
||||
@@ -400,8 +400,10 @@ FailureOr<FormatElement *> FormatParser::parseOptionalGroup(Context ctx) {
|
||||
|
||||
FailureOr<FormatElement *> FormatParser::parseCustomDirective(SMLoc loc,
|
||||
Context ctx) {
|
||||
if (ctx != TopLevelContext)
|
||||
return emitError(loc, "'custom' is only valid as a top-level directive");
|
||||
if (ctx != TopLevelContext && ctx != StructDirectiveContext) {
|
||||
return emitError(loc, "`custom` can only be used at the top-level context "
|
||||
"or within a `struct` directive");
|
||||
}
|
||||
|
||||
FailureOr<FormatToken> nameTok;
|
||||
if (failed(parseToken(FormatToken::less,
|
||||
|
||||
@@ -338,29 +338,56 @@ public:
|
||||
}
|
||||
};
|
||||
|
||||
/// Base class for a directive that contains references to elements of type `T`
|
||||
/// in a vector.
|
||||
template <DirectiveElement::Kind DirectiveKind, typename T>
|
||||
class VectorDirectiveBase : public DirectiveElementBase<DirectiveKind> {
|
||||
public:
|
||||
using Base = VectorDirectiveBase<DirectiveKind, T>;
|
||||
|
||||
VectorDirectiveBase(std::vector<T> &&elems) : elems(std::move(elems)) {}
|
||||
|
||||
/// Get the elements contained in this directive.
|
||||
ArrayRef<T> getElements() const { return elems; }
|
||||
|
||||
/// Get the number of elements.
|
||||
unsigned getNumElements() const { return elems.size(); }
|
||||
|
||||
/// Take all of the elements from this directive.
|
||||
std::vector<T> takeElements() { return std::move(elems); }
|
||||
|
||||
protected:
|
||||
/// The elements captured by this directive.
|
||||
std::vector<T> elems;
|
||||
};
|
||||
|
||||
/// This class represents a custom format directive that is implemented by the
|
||||
/// user in C++. The directive accepts a list of arguments that is passed to the
|
||||
/// C++ function.
|
||||
class CustomDirective : public DirectiveElementBase<DirectiveElement::Custom> {
|
||||
class CustomDirective
|
||||
: public VectorDirectiveBase<DirectiveElement::Custom, FormatElement *> {
|
||||
public:
|
||||
using Base::Base;
|
||||
/// Create a custom directive with a name and list of arguments.
|
||||
CustomDirective(StringRef name, std::vector<FormatElement *> &&arguments)
|
||||
: name(name), arguments(std::move(arguments)) {}
|
||||
: Base(std::move(arguments)), name(name) {}
|
||||
|
||||
/// Get the custom directive name.
|
||||
StringRef getName() const { return name; }
|
||||
|
||||
/// Get the arguments to the custom directive.
|
||||
ArrayRef<FormatElement *> getArguments() const { return arguments; }
|
||||
template <typename T>
|
||||
FailureOr<T *> getFrontAs() const {
|
||||
if (getNumElements() != 1)
|
||||
return failure();
|
||||
if (T *elem = dyn_cast<T>(getElements()[0]))
|
||||
return elem;
|
||||
return failure();
|
||||
}
|
||||
|
||||
private:
|
||||
/// The name of the custom directive. The name is used to call two C++
|
||||
/// methods: `parse{name}` and `print{name}` with the given arguments.
|
||||
StringRef name;
|
||||
/// The arguments with which to call the custom functions. These are either
|
||||
/// variables (for which the functions are responsible for populating) or
|
||||
/// references to variables.
|
||||
std::vector<FormatElement *> arguments;
|
||||
};
|
||||
|
||||
/// This class represents a reference directive. This directive can be used to
|
||||
|
||||
@@ -882,7 +882,7 @@ static void genElementParserStorage(FormatElement *element, const Operator &op,
|
||||
}
|
||||
|
||||
} else if (auto *custom = dyn_cast<CustomDirective>(element)) {
|
||||
for (FormatElement *paramElement : custom->getArguments())
|
||||
for (FormatElement *paramElement : custom->getElements())
|
||||
genElementParserStorage(paramElement, op, body);
|
||||
|
||||
} else if (isa<OperandsDirective>(element)) {
|
||||
@@ -1037,7 +1037,7 @@ static void genCustomDirectiveParser(CustomDirective *dir, MethodBody &body,
|
||||
// * Add a local variable for optional operands and types. This provides a
|
||||
// better API to the user defined parser methods.
|
||||
// * Set the location of operand variables.
|
||||
for (FormatElement *param : dir->getArguments()) {
|
||||
for (FormatElement *param : dir->getElements()) {
|
||||
if (auto *operand = dyn_cast<OperandVariable>(param)) {
|
||||
auto *var = operand->getVar();
|
||||
body << " " << var->name
|
||||
@@ -1089,7 +1089,7 @@ static void genCustomDirectiveParser(CustomDirective *dir, MethodBody &body,
|
||||
}
|
||||
|
||||
body << " auto odsResult = parse" << dir->getName() << "(parser";
|
||||
for (FormatElement *param : dir->getArguments()) {
|
||||
for (FormatElement *param : dir->getElements()) {
|
||||
body << ", ";
|
||||
genCustomParameterParser(param, body);
|
||||
}
|
||||
@@ -1103,7 +1103,7 @@ static void genCustomDirectiveParser(CustomDirective *dir, MethodBody &body,
|
||||
}
|
||||
|
||||
// After parsing, add handling for any of the optional constructs.
|
||||
for (FormatElement *param : dir->getArguments()) {
|
||||
for (FormatElement *param : dir->getElements()) {
|
||||
if (auto *attr = dyn_cast<AttributeVariable>(param)) {
|
||||
const NamedAttribute *var = attr->getVar();
|
||||
if (var->attr.isOptional() || var->attr.hasDefaultValue())
|
||||
@@ -2215,7 +2215,7 @@ static void genCustomDirectiveParameterPrinter(FormatElement *element,
|
||||
static void genCustomDirectivePrinter(CustomDirective *customDir,
|
||||
const Operator &op, MethodBody &body) {
|
||||
body << " print" << customDir->getName() << "(_odsPrinter, *this";
|
||||
for (FormatElement *param : customDir->getArguments()) {
|
||||
for (FormatElement *param : customDir->getElements()) {
|
||||
body << ", ";
|
||||
genCustomDirectiveParameterPrinter(param, op, body);
|
||||
}
|
||||
@@ -2359,7 +2359,7 @@ static void genOptionalGroupPrinterAnchor(FormatElement *anchor,
|
||||
.Case([&](CustomDirective *ele) {
|
||||
body << '(';
|
||||
llvm::interleave(
|
||||
ele->getArguments(), body,
|
||||
ele->getElements(), body,
|
||||
[&](FormatElement *child) {
|
||||
body << '(';
|
||||
genOptionalGroupPrinterAnchor(child, op, body);
|
||||
@@ -2375,7 +2375,7 @@ void collect(FormatElement *element,
|
||||
TypeSwitch<FormatElement *>(element)
|
||||
.Case([&](VariableElement *var) { variables.emplace_back(var); })
|
||||
.Case([&](CustomDirective *ele) {
|
||||
for (FormatElement *arg : ele->getArguments())
|
||||
for (FormatElement *arg : ele->getElements())
|
||||
collect(arg, variables);
|
||||
})
|
||||
.Case([&](OptionalElement *ele) {
|
||||
@@ -3774,7 +3774,7 @@ LogicalResult OpFormatParser::verifyOptionalGroupElement(SMLoc loc,
|
||||
return success();
|
||||
// Verify each child as being valid in an optional group. They are all
|
||||
// potential anchors if the custom directive was marked as one.
|
||||
for (FormatElement *child : ele->getArguments()) {
|
||||
for (FormatElement *child : ele->getElements()) {
|
||||
if (isa<RefDirective>(child))
|
||||
continue;
|
||||
if (failed(verifyOptionalGroupElement(loc, child, /*isAnchor=*/true)))
|
||||
|
||||
Reference in New Issue
Block a user