[mlir][tblgen] Add custom parsing and printing within struct (#133939)

This PR extends the `struct` directive in tablegen to support nested
`custom` directives. Note that this assumes/verifies that that `custom`
directive has a single parameter.

This enables defining custom field parsing and printing functions if the
`struct` directive doesn't suffice. There is some existing potential
downstream usage for it:
a3c7de9242/stablehlo/dialect/StablehloOps.cpp (L3102)
This commit is contained in:
Jorn Tuyls
2025-04-30 14:43:03 +02:00
committed by GitHub
parent 3b12bac6d1
commit de6d010f4e
10 changed files with 440 additions and 117 deletions

View File

@@ -842,9 +842,9 @@ if they are not present.
###### `struct` Directive
The `struct` directive accepts a list of variables to capture and will generate
a parser and printer for a comma-separated list of key-value pairs. If an
optional parameter is included in the `struct`, it can be elided. The variables
The `struct` directive accepts a list of variables or directives to capture and
will generate a parser and printer for a comma-separated list of key-value pairs.
If an optional parameter is included in the `struct`, it can be elided. The variables
are printed in the order they are specified in the argument list **but can be
parsed in any order**. For example:
@@ -876,6 +876,13 @@ assembly format of `` `<` struct(params) `>` `` will result in:
The order in which the parameters are printed is the order in which they are
declared in the attribute's or type's `parameter` list.
Passing `custom<Foo>($variable)` allows providing a custom printer and parser
for the encapsulated variable. Check the
[custom and ref directive](#custom-and-ref-directive) section for more
information about how to define the printer and parser functions. Note that a
custom directive within a struct directive can only encapsulate a single
variable.
###### `custom` and `ref` directive
The `custom` directive is used to dispatch calls to user-defined printer and

View File

@@ -0,0 +1,68 @@
// RUN: mlir-opt %s -split-input-file -verify-diagnostics| FileCheck %s
// CHECK-LABEL: @test_struct_attr_roundtrip
func.func @test_struct_attr_roundtrip() -> () {
// CHECK: attr = #test.custom_struct<type_str = "struct", value = 2, opt_value = [3, 3]>
"test.op"() {attr = #test.custom_struct<type_str = "struct", value = 2, opt_value = [3, 3]>} : () -> ()
// CHECK: attr = #test.custom_struct<type_str = "struct", value = ?, opt_value = 1>
"test.op"() {attr = #test.custom_struct<type_str = "struct", value = ?, opt_value = 1>} : () -> ()
// CHECK: attr = #test.custom_struct<type_str = "struct", value = 2, opt_value = [3, 3]>
"test.op"() {attr = #test.custom_struct<value = 2, type_str = "struct", opt_value = [3, 3]>} : () -> ()
// CHECK: attr = #test.custom_struct<type_str = "struct", value = 2>
"test.op"() {attr = #test.custom_struct<type_str = "struct", value = 2>} : () -> ()
// CHECK: attr = #test.custom_struct<type_str = "struct", value = 2>
"test.op"() {attr = #test.custom_struct<value = 2, type_str = "struct">} : () -> ()
return
}
// -----
// Verify all required parameters must be provided. `value` is missing.
// expected-error @below {{struct is missing required parameter: value}}
"test.op"() {attr = #test.custom_struct<type_str = "struct">} : () -> ()
// -----
// Verify all keywords must be provided. All missing.
// expected-error @below {{expected valid keyword}}
// expected-error @below {{expected a parameter name in struct}}
"test.op"() {attr = #test.custom_struct<"struct", 2>} : () -> ()
// -----
// Verify all keywords must be provided. `type_str` missing.
// expected-error @below {{expected valid keyword}}
// expected-error @below {{expected a parameter name in struct}}
"test.op"() {attr = #test.custom_struct<"struct", value = 2, opt_value = [3, 3]>} : () -> ()
// -----
// Verify all keywords must be provided. `value` missing.
// expected-error @below {{expected valid keyword}}
// expected-error @below {{expected a parameter name in struct}}
"test.op"() {attr = #test.custom_struct<type_str = "struct", 2>} : () -> ()
// -----
// Verify invalid keyword provided.
// expected-error @below {{duplicate or unknown struct parameter name: type_str2}}
"test.op"() {attr = #test.custom_struct<type_str2 = "struct", value = 2>} : () -> ()
// -----
// Verify duplicated keyword provided.
// expected-error @below {{duplicate or unknown struct parameter name: type_str}}
"test.op"() {attr = #test.custom_struct<type_str = "struct", type_str = "struct2", value = 2>} : () -> ()
// -----
// Verify equals missing.
// expected-error @below {{expected '='}}
"test.op"() {attr = #test.custom_struct<type_str "struct", value = 2>} : () -> ()

View File

@@ -369,6 +369,16 @@ def TestCustomFloatAttr : Test_Attr<"TestCustomFloat"> {
}];
}
// Test `struct` with nested `custom` assembly format.
def TestCustomStructAttr : Test_Attr<"TestCustomStruct"> {
let mnemonic = "custom_struct";
let parameters = (ins "mlir::StringAttr":$type_str, "int64_t":$value,
OptionalParameter<"mlir::ArrayAttr">:$opt_value);
let assemblyFormat = [{
`<` struct($type_str, custom<CustomStructAttr>($value), custom<CustomOptStructFieldAttr>($opt_value)) `>`
}];
}
def NestedPolynomialAttr : Test_Attr<"NestedPolynomialAttr"> {
let mnemonic = "nested_polynomial";
let parameters = (ins Polynomial_IntPolynomialAttr:$poly);

View File

@@ -316,6 +316,49 @@ static ParseResult parseCustomFloatAttr(AsmParser &p, StringAttr &typeStrAttr,
return success();
}
//===----------------------------------------------------------------------===//
// TestCustomStructAttr
//===----------------------------------------------------------------------===//
static void printCustomStructAttr(AsmPrinter &p, int64_t value) {
if (ShapedType::isDynamic(value)) {
p << "?";
} else {
p.printStrippedAttrOrType(value);
}
}
static ParseResult parseCustomStructAttr(AsmParser &p, int64_t &value) {
if (succeeded(p.parseOptionalQuestion())) {
value = ShapedType::kDynamic;
return success();
}
return p.parseInteger(value);
}
static void printCustomOptStructFieldAttr(AsmPrinter &p, ArrayAttr attr) {
if (attr && attr.size() == 1 && isa<IntegerAttr>(attr[0])) {
p << cast<IntegerAttr>(attr[0]).getInt();
} else {
p.printStrippedAttrOrType(attr);
}
}
static ParseResult parseCustomOptStructFieldAttr(AsmParser &p,
ArrayAttr &attr) {
int64_t value;
OptionalParseResult result = p.parseOptionalInteger(value);
if (result.has_value()) {
if (failed(result.value()))
return failure();
attr = ArrayAttr::get(
p.getContext(),
{IntegerAttr::get(IntegerType::get(p.getContext(), 64), value)});
return success();
}
return p.parseAttribute(attr);
}
//===----------------------------------------------------------------------===//
// TestOpAsmAttrInterfaceAttr
//===----------------------------------------------------------------------===//

View File

@@ -37,14 +37,14 @@ def InvalidTypeC : InvalidType<"InvalidTypeC", "invalid_c"> {
def InvalidTypeD : InvalidType<"InvalidTypeD", "invalid_d"> {
let parameters = (ins "int":$v0);
// CHECK: literals may only be used in the top-level section of the format
// CHECK: expected a variable in `struct` argument list
// CHECK: expected a parameter or `custom` directive in `struct` argument list
let assemblyFormat = "`<` struct($v0, `,`) `>`";
}
// Test struct directive cannot capture zero parameters.
def InvalidTypeE : InvalidType<"InvalidTypeE", "invalid_e"> {
let parameters = (ins "int":$v0);
// CHECK: `struct` argument list expected a variable or directive
// CHECK: `struct` argument list expected a parameter or directive
let assemblyFormat = "`<` struct() $v0 `>`";
}
@@ -144,3 +144,24 @@ def InvalidTypeT : InvalidType<"InvalidTypeT", "invalid_t"> {
// CHECK: `custom` directive with no bound parameters cannot be used as optional group anchor
let assemblyFormat = "$a (`(` custom<Foo>(ref($a))^ `)`)?";
}
// Test `struct` with nested `custom` directive with multiple fields.
def InvalidTypeU : InvalidType<"InvalidTypeU", "invalid_u"> {
let parameters = (ins OptionalParameter<"int">:$a, "int":$b);
// CHECK: `struct` can only contain `custom` directives with a single argument
let assemblyFormat = "struct(custom<Foo>($a, $b))";
}
// Test `struct` with nested `custom` directive invalid parameter.
def InvalidTypeV : InvalidType<"InvalidTypeV", "invalid_v"> {
let parameters = (ins OptionalParameter<"int">:$a);
// CHECK: a `custom` directive nested within a `struct` must be passed a parameter
let assemblyFormat = "struct($a, custom<Foo>(ref($a)))";
}
// Test `custom` with nested `custom` directive invalid parameter.
def InvalidTypeW : InvalidType<"InvalidTypeV", "invalid_v"> {
let parameters = (ins OptionalParameter<"int">:$a, "int":$b);
// CHECK: `custom` can only be used at the top-level context or within a `struct` directive
let assemblyFormat = "custom<Foo>($a, custom<Bar>($b))";
}

View File

@@ -736,6 +736,84 @@ def TypeS : TestType<"TestS"> {
let assemblyFormat = "$a";
}
/// Test that a `struct` with nested `custom` parser and printer are generated correctly.
// ATTR: ::mlir::Attribute TestTAttr::parse(::mlir::AsmParser &odsParser,
// ATTR: ::mlir::Type odsType) {
// ATTR: bool _seen_v0 = false;
// ATTR: bool _seen_v1 = false;
// ATTR: bool _seen_v2 = false;
// ATTR: const auto _loop_body = [&](::llvm::StringRef _paramKey) -> bool {
// ATTR: if (odsParser.parseEqual())
// ATTR: return {};
// ATTR: if (!_seen_v0 && _paramKey == "v0") {
// ATTR: _seen_v0 = true;
// ATTR: _result_v0 = ::parseAttrParamA(odsParser, odsType);
// ATTR: if (::mlir::failed(_result_v0))
// ATTR: return {};
// ATTR: } else if (!_seen_v1 && _paramKey == "v1") {
// ATTR: _seen_v1 = true;
// ATTR: {
// ATTR: auto odsCustomResult = parseNestedCustom(odsParser,
// ATTR-NEXT: ::mlir::detail::unwrapForCustomParse(_result_v1));
// ATTR: if (::mlir::failed(odsCustomResult)) return {};
// ATTR: if (::mlir::failed(_result_v1)) {
// ATTR: odsParser.emitError(odsCustomLoc, "custom parser failed to parse parameter 'v1'");
// ATTR: return {};
// ATTR: }
// ATTR: }
// ATTR: } else if (!_seen_v2 && _paramKey == "v2") {
// ATTR: _seen_v2 = true;
// ATTR: _result_v2 = ::mlir::FieldParser<AttrParamB>::parse(odsParser);
// ATTR: if (::mlir::failed(_result_v2)) {
// ATTR: odsParser.emitError(odsParser.getCurrentLocation(), "failed to parse AttrT parameter 'v2' which is to be a `AttrParamB`");
// ATTR: return {};
// ATTR: }
// ATTR: } else {
// ATTR: return {};
// ATTR: }
// ATTR: return true;
// ATTR: }
// ATTR: do {
// ATTR: ::llvm::StringRef _paramKey;
// ATTR: if (odsParser.parseKeyword(&_paramKey)) {
// ATTR: odsParser.emitError(odsParser.getCurrentLocation(),
// ATTR-NEXT: "expected a parameter name in struct");
// ATTR: return {};
// ATTR: }
// ATTR: if (!_loop_body(_paramKey)) return {};
// ATTR: } while(!odsParser.parseOptionalComma());
// ATTR: if (!_seen_v0)
// ATTR: if (!_seen_v1)
// ATTR: return TestTAttr::get(odsParser.getContext(),
// ATTR: TestParamA((*_result_v0)),
// ATTR: TestParamB((*_result_v1)),
// ATTR: AttrParamB((_result_v2.value_or(AttrParamB()))));
// ATTR: }
// ATTR: void TestTAttr::print(::mlir::AsmPrinter &odsPrinter) const {
// ATTR: odsPrinter << "v0 = ";
// ATTR: ::printAttrParamA(odsPrinter, getV0());
// ATTR: odsPrinter << ", ";
// ATTR: odsPrinter << "v1 = ";
// ATTR: printNestedCustom(odsPrinter,
// ATTR-NEXT: getV1());
// ATTR: if (!(getV2() == AttrParamB())) {
// ATTR: odsPrinter << "v2 = ";
// ATTR: odsPrinter.printStrippedAttrOrType(getV2());
// ATTR: }
def AttrT : TestAttr<"TestT"> {
let parameters = (ins
AttrParamA:$v0,
AttrParamB:$v1,
OptionalParameter<"AttrParamB">:$v2
);
let mnemonic = "attr_t";
let assemblyFormat = "`{` struct($v0, custom<NestedCustom>($v1), $v2) `}`";
}
// DEFAULT_TYPE_PARSER: TestDialect::parseType(::mlir::DialectAsmParser &parser)
// DEFAULT_TYPE_PARSER: auto parseResult = parseOptionalDynamicType(mnemonic, parser, genType);
// DEFAULT_TYPE_PARSER: if (parseResult.has_value()) {

View File

@@ -13,6 +13,7 @@
#include "mlir/TableGen/Format.h"
#include "mlir/TableGen/GenInfo.h"
#include "llvm/ADT/BitVector.h"
#include "llvm/ADT/SmallVectorExtras.h"
#include "llvm/ADT/StringExtras.h"
#include "llvm/ADT/StringSwitch.h"
#include "llvm/ADT/TypeSwitch.h"
@@ -75,37 +76,35 @@ private:
AttrOrTypeParameter param;
};
/// Utility to return the encapsulated parameter element for the provided format
/// element. This parameter can originate from either a `ParameterElement`,
/// `CustomDirective` with a single parameter argument or `RefDirective`.
static ParameterElement *getEncapsulatedParameterElement(FormatElement *el) {
return TypeSwitch<FormatElement *, ParameterElement *>(el)
.Case<CustomDirective>([&](auto custom) {
FailureOr<ParameterElement *> maybeParam =
custom->template getFrontAs<ParameterElement>();
return *maybeParam;
})
.Case<ParameterElement>([&](auto param) { return param; })
.Case<RefDirective>(
[&](auto ref) { return cast<ParameterElement>(ref->getArg()); })
.Default([&](auto el) {
assert(false && "unexpected struct element type");
return nullptr;
});
}
/// Shorthand functions that can be used with ranged-based conditions.
static bool paramIsOptional(ParameterElement *el) { return el->isOptional(); }
static bool formatIsOptional(FormatElement *el) {
ParameterElement *param = getEncapsulatedParameterElement(el);
return param != nullptr && param->isOptional();
}
static bool paramNotOptional(ParameterElement *el) { return !el->isOptional(); }
/// Base class for a directive that contains references to multiple variables.
template <DirectiveElement::Kind DirectiveKind>
class ParamsDirectiveBase : public DirectiveElementBase<DirectiveKind> {
public:
using Base = ParamsDirectiveBase<DirectiveKind>;
ParamsDirectiveBase(std::vector<ParameterElement *> &&params)
: params(std::move(params)) {}
/// Get the parameters contained in this directive.
ArrayRef<ParameterElement *> getParams() const { return params; }
/// Get the number of parameters.
unsigned getNumParams() const { return params.size(); }
/// Take all of the parameters from this directive.
std::vector<ParameterElement *> takeParams() { return std::move(params); }
/// Returns true if there are optional parameters present.
bool hasOptionalParams() const {
return llvm::any_of(getParams(), paramIsOptional);
}
private:
/// The parameters captured by this directive.
std::vector<ParameterElement *> params;
};
static bool formatNotOptional(FormatElement *el) {
return !formatIsOptional(el);
}
/// This class represents a `params` directive that refers to all parameters
/// of an attribute or type. When used as a top-level directive, it generates
@@ -116,9 +115,15 @@ private:
/// When used as an argument to another directive that accepts variables,
/// `params` can be used in place of manually listing all parameters of an
/// attribute or type.
class ParamsDirective : public ParamsDirectiveBase<DirectiveElement::Params> {
class ParamsDirective
: public VectorDirectiveBase<DirectiveElement::Params, ParameterElement *> {
public:
using Base::Base;
/// Returns true if there are optional parameters present.
bool hasOptionalElements() const {
return llvm::any_of(getElements(), paramIsOptional);
}
};
/// This class represents a `struct` directive that generates a struct format
@@ -126,9 +131,15 @@ public:
///
/// `{` param-name `=` param-value (`,` param-name `=` param-value)* `}`
///
class StructDirective : public ParamsDirectiveBase<DirectiveElement::Struct> {
class StructDirective
: public VectorDirectiveBase<DirectiveElement::Struct, FormatElement *> {
public:
using Base::Base;
/// Returns true if there are optional format elements present.
bool hasOptionalElements() const {
return llvm::any_of(getElements(), formatIsOptional);
}
};
} // namespace
@@ -214,10 +225,10 @@ private:
/// Generate the printer code for a variable.
void genVariablePrinter(ParameterElement *el, FmtContext &ctx, MethodBody &os,
bool skipGuard = false);
/// Generate a printer for comma-separated parameters.
void genCommaSeparatedPrinter(ArrayRef<ParameterElement *> params,
/// Generate a printer for comma-separated format elements.
void genCommaSeparatedPrinter(ArrayRef<FormatElement *> params,
FmtContext &ctx, MethodBody &os,
function_ref<void(ParameterElement *)> extra);
function_ref<void(FormatElement *)> extra);
/// Generate the printer code for a `params` directive.
void genParamsPrinter(ParamsDirective *el, FmtContext &ctx, MethodBody &os);
/// Generate the printer code for a `struct` directive.
@@ -443,14 +454,14 @@ void DefFormat::genParamsParser(ParamsDirective *el, FmtContext &ctx,
// If there are optional parameters, we need to switch to `parseOptionalComma`
// if there are no more required parameters after a certain point.
bool hasOptional = el->hasOptionalParams();
bool hasOptional = el->hasOptionalElements();
if (hasOptional) {
// Wrap everything in a do-while so that we can `break`.
os << "do {\n";
os.indent();
}
ArrayRef<ParameterElement *> params = el->getParams();
ArrayRef<ParameterElement *> params = el->getElements();
using IteratorT = ParameterElement *const *;
IteratorT it = params.begin();
@@ -551,22 +562,31 @@ void DefFormat::genStructParser(StructDirective *el, FmtContext &ctx,
while (!$_parser.parseOptionalComma()) {
)";
const char *const checkParamKey = R"(
if (!_seen_{0} && _paramKey == "{0}") {
_seen_{0} = true;
)";
os << "// Parse parameter struct\n";
// Declare a "seen" variable for each key.
for (ParameterElement *param : el->getParams())
for (FormatElement *arg : el->getElements()) {
ParameterElement *param = getEncapsulatedParameterElement(arg);
os << formatv("bool _seen_{0} = false;\n", param->getName());
}
// Generate the body of the parsing loop inside a lambda.
os << "{\n";
os.indent()
<< "const auto _loop_body = [&](::llvm::StringRef _paramKey) -> bool {\n";
genLiteralParser("=", ctx, os.indent());
for (ParameterElement *param : el->getParams()) {
os << formatv("if (!_seen_{0} && _paramKey == \"{0}\") {\n"
" _seen_{0} = true;\n",
param->getName());
genVariableParser(param, ctx, os.indent());
for (FormatElement *arg : el->getElements()) {
ParameterElement *param = getEncapsulatedParameterElement(arg);
os.getStream().printReindented(strfmt(checkParamKey, param->getName()));
if (auto realParam = dyn_cast<ParameterElement>(arg))
genVariableParser(param, ctx, os.indent());
else if (auto custom = dyn_cast<CustomDirective>(arg))
genCustomParser(custom, ctx, os.indent());
os.unindent() << "} else ";
// Print the check for duplicate or unknown parameter.
}
@@ -576,10 +596,10 @@ void DefFormat::genStructParser(StructDirective *el, FmtContext &ctx,
// Generate the parsing loop. If optional parameters are present, then the
// parse loop is guarded by commas.
unsigned numOptional = llvm::count_if(el->getParams(), paramIsOptional);
unsigned numOptional = llvm::count_if(el->getElements(), formatIsOptional);
if (numOptional) {
// If the struct itself is optional, pull out the first iteration.
if (numOptional == el->getNumParams()) {
if (numOptional == el->getNumElements()) {
os.getStream().printReindented(tgfmt(optionalStructFirst, &ctx).str());
os.indent();
} else {
@@ -587,7 +607,7 @@ void DefFormat::genStructParser(StructDirective *el, FmtContext &ctx,
}
} else {
os.getStream().printReindented(
tgfmt(loopHeader, &ctx, el->getNumParams()).str());
tgfmt(loopHeader, &ctx, el->getNumElements()).str());
}
os.indent();
os.getStream().printReindented(tgfmt(loopStart, &ctx).str());
@@ -597,12 +617,13 @@ void DefFormat::genStructParser(StructDirective *el, FmtContext &ctx,
// all mandatory parameters have been parsed.
// The whole struct is optional if all its parameters are optional.
if (numOptional) {
if (numOptional == el->getNumParams()) {
if (numOptional == el->getNumElements()) {
os << "}\n";
os.unindent() << "}\n";
} else {
os << tgfmt("} while(!$_parser.parseOptionalComma());\n", &ctx);
for (ParameterElement *param : el->getParams()) {
for (FormatElement *arg : el->getElements()) {
ParameterElement *param = getEncapsulatedParameterElement(arg);
if (param->isOptional())
continue;
os.getStream().printReindented(
@@ -614,7 +635,8 @@ void DefFormat::genStructParser(StructDirective *el, FmtContext &ctx,
// N flags, successfully exiting the loop means that all parameters have
// been seen. `parseOptionalComma` would cause issues with any formats that
// use "struct(...) `,`" beacuse structs aren't sounded by braces.
os.getStream().printReindented(strfmt(loopTerminator, el->getNumParams()));
os.getStream().printReindented(
strfmt(loopTerminator, el->getNumElements()));
}
os.unindent() << "}\n";
}
@@ -631,7 +653,7 @@ void DefFormat::genCustomParser(CustomDirective *el, FmtContext &ctx,
os << "(void)odsCustomLoc;\n";
os << tgfmt("auto odsCustomResult = parse$0($_parser", &ctx, el->getName());
os.indent();
for (FormatElement *arg : el->getArguments()) {
for (FormatElement *arg : el->getElements()) {
os << ",\n";
if (auto *param = dyn_cast<ParameterElement>(arg))
os << "::mlir::detail::unwrapForCustomParse(_result_" << param->getName()
@@ -648,7 +670,7 @@ void DefFormat::genCustomParser(CustomDirective *el, FmtContext &ctx,
} else {
os << "if (::mlir::failed(odsCustomResult)) return {};\n";
}
for (FormatElement *arg : el->getArguments()) {
for (FormatElement *arg : el->getElements()) {
if (auto *param = dyn_cast<ParameterElement>(arg)) {
if (param->isOptional())
continue;
@@ -689,7 +711,7 @@ void DefFormat::genOptionalGroupParser(OptionalElement *el, FmtContext &ctx,
guardOn(llvm::ArrayRef(param));
} else if (auto *params = dyn_cast<ParamsDirective>(first)) {
genParamsParser(params, ctx, os);
guardOn(params->getParams());
guardOn(params->getElements());
} else if (auto *custom = dyn_cast<CustomDirective>(first)) {
os << "if (auto result = [&]() -> ::mlir::OptionalParseResult {\n";
os.indent();
@@ -704,7 +726,7 @@ void DefFormat::genOptionalGroupParser(OptionalElement *el, FmtContext &ctx,
} else {
auto *strct = cast<StructDirective>(first);
genStructParser(strct, ctx, os);
guardOn(params->getParams());
guardOn(params->getElements());
}
os.indent();
@@ -816,14 +838,26 @@ static void guardOnAny(FmtContext &ctx, MethodBody &os, ParameterRange &&params,
os.indent();
}
/// Generate code to guard printing on the presence of any optional format
/// elements.
template <typename FormatElemRange>
static void guardOnAnyOptional(FmtContext &ctx, MethodBody &os,
FormatElemRange &&args, bool inverted = false) {
guardOnAny(ctx, os,
llvm::make_filter_range(
llvm::map_range(args, getEncapsulatedParameterElement),
[](ParameterElement *param) { return param->isOptional(); }),
inverted);
}
void DefFormat::genCommaSeparatedPrinter(
ArrayRef<ParameterElement *> params, FmtContext &ctx, MethodBody &os,
function_ref<void(ParameterElement *)> extra) {
ArrayRef<FormatElement *> args, FmtContext &ctx, MethodBody &os,
function_ref<void(FormatElement *)> extra) {
// Emit a space if necessary, but only if the struct is present.
if (shouldEmitSpace || !lastWasPunctuation) {
bool allOptional = llvm::all_of(params, paramIsOptional);
bool allOptional = llvm::all_of(args, formatIsOptional);
if (allOptional)
guardOnAny(ctx, os, params);
guardOnAnyOptional(ctx, os, args);
os << tgfmt("$_printer << ' ';\n", &ctx);
if (allOptional)
os.unindent() << "}\n";
@@ -832,17 +866,21 @@ void DefFormat::genCommaSeparatedPrinter(
// The first printed element does not need to emit a comma.
os << "{\n";
os.indent() << "bool _firstPrinted = true;\n";
for (ParameterElement *param : params) {
for (FormatElement *arg : args) {
ParameterElement *param = getEncapsulatedParameterElement(arg);
if (param->isOptional()) {
param->genPrintGuard(ctx, os << "if (") << ") {\n";
os.indent();
}
os << tgfmt("if (!_firstPrinted) $_printer << \", \";\n", &ctx);
os << "_firstPrinted = false;\n";
extra(param);
extra(arg);
shouldEmitSpace = false;
lastWasPunctuation = true;
genVariablePrinter(param, ctx, os);
if (auto realParam = dyn_cast<ParameterElement>(arg))
genVariablePrinter(realParam, ctx, os);
else if (auto custom = dyn_cast<CustomDirective>(arg))
genCustomPrinter(custom, ctx, os);
if (param->isOptional())
os.unindent() << "}\n";
}
@@ -851,16 +889,19 @@ void DefFormat::genCommaSeparatedPrinter(
void DefFormat::genParamsPrinter(ParamsDirective *el, FmtContext &ctx,
MethodBody &os) {
genCommaSeparatedPrinter(llvm::to_vector(el->getParams()), ctx, os,
[&](ParameterElement *param) {});
SmallVector<FormatElement *> args = llvm::map_to_vector(
el->getElements(), [](ParameterElement *param) -> FormatElement * {
return static_cast<FormatElement *>(param);
});
genCommaSeparatedPrinter(args, ctx, os, [&](FormatElement *param) {});
}
void DefFormat::genStructPrinter(StructDirective *el, FmtContext &ctx,
MethodBody &os) {
genCommaSeparatedPrinter(
llvm::to_vector(el->getParams()), ctx, os, [&](ParameterElement *param) {
os << tgfmt("$_printer << \"$0 = \";\n", &ctx, param->getName());
});
genCommaSeparatedPrinter(el->getElements(), ctx, os, [&](FormatElement *arg) {
ParameterElement *param = getEncapsulatedParameterElement(arg);
os << tgfmt("$_printer << \"$0 = \";\n", &ctx, param->getName());
});
}
void DefFormat::genCustomPrinter(CustomDirective *el, FmtContext &ctx,
@@ -873,7 +914,7 @@ void DefFormat::genCustomPrinter(CustomDirective *el, FmtContext &ctx,
os << tgfmt("print$0($_printer", &ctx, el->getName());
os.indent();
for (FormatElement *arg : el->getArguments()) {
for (FormatElement *arg : el->getElements()) {
os << ",\n";
if (auto *param = dyn_cast<ParameterElement>(arg)) {
os << param->getParam().getAccessorName() << "()";
@@ -893,19 +934,12 @@ void DefFormat::genOptionalGroupPrinter(OptionalElement *el, FmtContext &ctx,
if (auto *param = dyn_cast<ParameterElement>(anchor)) {
guardOnAny(ctx, os, llvm::ArrayRef(param), el->isInverted());
} else if (auto *params = dyn_cast<ParamsDirective>(anchor)) {
guardOnAny(ctx, os, params->getParams(), el->isInverted());
guardOnAny(ctx, os, params->getElements(), el->isInverted());
} else if (auto *strct = dyn_cast<StructDirective>(anchor)) {
guardOnAny(ctx, os, strct->getParams(), el->isInverted());
guardOnAnyOptional(ctx, os, strct->getElements(), el->isInverted());
} else {
auto *custom = cast<CustomDirective>(anchor);
guardOnAny(ctx, os,
llvm::make_filter_range(
llvm::map_range(custom->getArguments(),
[](FormatElement *el) {
return dyn_cast<ParameterElement>(el);
}),
[](ParameterElement *param) { return !!param; }),
el->isInverted());
guardOnAnyOptional(ctx, os, custom->getElements(), el->isInverted());
}
// Generate the printer for the contained elements.
{
@@ -960,6 +994,9 @@ protected:
LogicalResult verifyOptionalGroupElements(SMLoc loc,
ArrayRef<FormatElement *> elements,
FormatElement *anchor) override;
/// Verify the arguments to a struct directive.
LogicalResult verifyStructArguments(SMLoc loc,
ArrayRef<FormatElement *> arguments);
LogicalResult markQualified(SMLoc loc, FormatElement *element) override;
@@ -1010,7 +1047,7 @@ LogicalResult DefFormatParser::verify(SMLoc loc,
auto *literalEl = dyn_cast<LiteralElement>(std::get<1>(it));
if (!structEl || !literalEl)
continue;
if (literalEl->getSpelling() == "," && structEl->hasOptionalParams()) {
if (literalEl->getSpelling() == "," && structEl->hasOptionalElements()) {
return emitError(loc, "`struct` directive with optional parameters "
"cannot be followed by a comma literal");
}
@@ -1037,17 +1074,17 @@ DefFormatParser::verifyOptionalGroupElements(llvm::SMLoc loc,
"parameters in an optional group must be optional");
}
} else if (auto *params = dyn_cast<ParamsDirective>(el)) {
if (llvm::any_of(params->getParams(), paramNotOptional)) {
if (llvm::any_of(params->getElements(), paramNotOptional)) {
return emitError(loc, "`params` directive allowed in optional group "
"only if all parameters are optional");
}
} else if (auto *strct = dyn_cast<StructDirective>(el)) {
if (llvm::any_of(strct->getParams(), paramNotOptional)) {
if (llvm::any_of(strct->getElements(), formatNotOptional)) {
return emitError(loc, "`struct` is only allowed in an optional group "
"if all captured parameters are optional");
}
} else if (auto *custom = dyn_cast<CustomDirective>(el)) {
for (FormatElement *el : custom->getArguments()) {
for (FormatElement *el : custom->getElements()) {
// If the custom argument is a variable, then it must be optional.
if (auto *param = dyn_cast<ParameterElement>(el))
if (!param->isOptional())
@@ -1068,10 +1105,10 @@ DefFormatParser::verifyOptionalGroupElements(llvm::SMLoc loc,
// arguments is a bound parameter.
if (auto *custom = dyn_cast<CustomDirective>(anchor)) {
const auto *bound =
llvm::find_if(custom->getArguments(), [](FormatElement *el) {
llvm::find_if(custom->getElements(), [](FormatElement *el) {
return isa<ParameterElement>(el);
});
if (bound == custom->getArguments().end())
if (bound == custom->getElements().end())
return emitError(loc, "`custom` directive with no bound parameters "
"cannot be used as optional group anchor");
}
@@ -1079,6 +1116,28 @@ DefFormatParser::verifyOptionalGroupElements(llvm::SMLoc loc,
return success();
}
LogicalResult
DefFormatParser::verifyStructArguments(SMLoc loc,
ArrayRef<FormatElement *> arguments) {
for (FormatElement *el : arguments) {
if (!isa<ParameterElement, CustomDirective, ParamsDirective>(el)) {
return emitError(loc, "expected a parameter, custom directive or params "
"directive in `struct` arguments list");
}
if (auto custom = dyn_cast<CustomDirective>(el)) {
if (custom->getNumElements() != 1) {
return emitError(loc, "`struct` can only contain `custom` directives "
"with a single argument");
}
if (failed(custom->getFrontAs<ParameterElement>())) {
return emitError(loc, "a `custom` directive nested within a `struct` "
"must be passed a parameter");
}
}
}
return success();
}
LogicalResult DefFormatParser::markQualified(SMLoc loc,
FormatElement *element) {
if (!isa<ParameterElement>(element))
@@ -1172,37 +1231,45 @@ FailureOr<FormatElement *> DefFormatParser::parseStructDirective(SMLoc loc,
return emitError(loc, "`struct` can only be used at the top-level context");
if (failed(parseToken(FormatToken::l_paren,
"expected '(' before `struct` argument list")))
"expected '(' before `struct` argument list"))) {
return failure();
}
// Parse variables captured by `struct`.
std::vector<ParameterElement *> vars;
std::vector<FormatElement *> vars;
// Parse first captured parameter or a `params` directive.
FailureOr<FormatElement *> var = parseElement(StructDirectiveContext);
if (failed(var) || !isa<VariableElement, ParamsDirective>(*var)) {
return emitError(loc,
"`struct` argument list expected a variable or directive");
if (failed(var) ||
!isa<ParameterElement, ParamsDirective, CustomDirective>(*var)) {
return emitError(
loc, "`struct` argument list expected a parameter or directive");
}
if (isa<VariableElement>(*var)) {
if (isa<ParameterElement, CustomDirective>(*var)) {
// Parse any other parameters.
vars.push_back(cast<ParameterElement>(*var));
vars.push_back(*var);
while (peekToken().is(FormatToken::comma)) {
consumeToken();
var = parseElement(StructDirectiveContext);
if (failed(var) || !isa<VariableElement>(*var))
return emitError(loc, "expected a variable in `struct` argument list");
vars.push_back(cast<ParameterElement>(*var));
if (failed(var) || !isa<ParameterElement, CustomDirective>(*var))
return emitError(loc, "expected a parameter or `custom` directive in "
"`struct` argument list");
vars.push_back(*var);
}
} else {
// `struct(params)` captures all parameters in the attribute or type.
vars = cast<ParamsDirective>(*var)->takeParams();
ParamsDirective *params = cast<ParamsDirective>(*var);
vars.reserve(params->getNumElements());
for (ParameterElement *el : params->takeElements())
vars.push_back(cast<FormatElement>(el));
}
if (failed(parseToken(FormatToken::r_paren,
"expected ')' at the end of an argument list")))
"expected ')' at the end of an argument list"))) {
return failure();
}
if (failed(verifyStructArguments(loc, vars)))
return failure();
return create<StructDirective>(std::move(vars));
}

View File

@@ -400,8 +400,10 @@ FailureOr<FormatElement *> FormatParser::parseOptionalGroup(Context ctx) {
FailureOr<FormatElement *> FormatParser::parseCustomDirective(SMLoc loc,
Context ctx) {
if (ctx != TopLevelContext)
return emitError(loc, "'custom' is only valid as a top-level directive");
if (ctx != TopLevelContext && ctx != StructDirectiveContext) {
return emitError(loc, "`custom` can only be used at the top-level context "
"or within a `struct` directive");
}
FailureOr<FormatToken> nameTok;
if (failed(parseToken(FormatToken::less,

View File

@@ -338,29 +338,56 @@ public:
}
};
/// Base class for a directive that contains references to elements of type `T`
/// in a vector.
template <DirectiveElement::Kind DirectiveKind, typename T>
class VectorDirectiveBase : public DirectiveElementBase<DirectiveKind> {
public:
using Base = VectorDirectiveBase<DirectiveKind, T>;
VectorDirectiveBase(std::vector<T> &&elems) : elems(std::move(elems)) {}
/// Get the elements contained in this directive.
ArrayRef<T> getElements() const { return elems; }
/// Get the number of elements.
unsigned getNumElements() const { return elems.size(); }
/// Take all of the elements from this directive.
std::vector<T> takeElements() { return std::move(elems); }
protected:
/// The elements captured by this directive.
std::vector<T> elems;
};
/// This class represents a custom format directive that is implemented by the
/// user in C++. The directive accepts a list of arguments that is passed to the
/// C++ function.
class CustomDirective : public DirectiveElementBase<DirectiveElement::Custom> {
class CustomDirective
: public VectorDirectiveBase<DirectiveElement::Custom, FormatElement *> {
public:
using Base::Base;
/// Create a custom directive with a name and list of arguments.
CustomDirective(StringRef name, std::vector<FormatElement *> &&arguments)
: name(name), arguments(std::move(arguments)) {}
: Base(std::move(arguments)), name(name) {}
/// Get the custom directive name.
StringRef getName() const { return name; }
/// Get the arguments to the custom directive.
ArrayRef<FormatElement *> getArguments() const { return arguments; }
template <typename T>
FailureOr<T *> getFrontAs() const {
if (getNumElements() != 1)
return failure();
if (T *elem = dyn_cast<T>(getElements()[0]))
return elem;
return failure();
}
private:
/// The name of the custom directive. The name is used to call two C++
/// methods: `parse{name}` and `print{name}` with the given arguments.
StringRef name;
/// The arguments with which to call the custom functions. These are either
/// variables (for which the functions are responsible for populating) or
/// references to variables.
std::vector<FormatElement *> arguments;
};
/// This class represents a reference directive. This directive can be used to

View File

@@ -882,7 +882,7 @@ static void genElementParserStorage(FormatElement *element, const Operator &op,
}
} else if (auto *custom = dyn_cast<CustomDirective>(element)) {
for (FormatElement *paramElement : custom->getArguments())
for (FormatElement *paramElement : custom->getElements())
genElementParserStorage(paramElement, op, body);
} else if (isa<OperandsDirective>(element)) {
@@ -1037,7 +1037,7 @@ static void genCustomDirectiveParser(CustomDirective *dir, MethodBody &body,
// * Add a local variable for optional operands and types. This provides a
// better API to the user defined parser methods.
// * Set the location of operand variables.
for (FormatElement *param : dir->getArguments()) {
for (FormatElement *param : dir->getElements()) {
if (auto *operand = dyn_cast<OperandVariable>(param)) {
auto *var = operand->getVar();
body << " " << var->name
@@ -1089,7 +1089,7 @@ static void genCustomDirectiveParser(CustomDirective *dir, MethodBody &body,
}
body << " auto odsResult = parse" << dir->getName() << "(parser";
for (FormatElement *param : dir->getArguments()) {
for (FormatElement *param : dir->getElements()) {
body << ", ";
genCustomParameterParser(param, body);
}
@@ -1103,7 +1103,7 @@ static void genCustomDirectiveParser(CustomDirective *dir, MethodBody &body,
}
// After parsing, add handling for any of the optional constructs.
for (FormatElement *param : dir->getArguments()) {
for (FormatElement *param : dir->getElements()) {
if (auto *attr = dyn_cast<AttributeVariable>(param)) {
const NamedAttribute *var = attr->getVar();
if (var->attr.isOptional() || var->attr.hasDefaultValue())
@@ -2215,7 +2215,7 @@ static void genCustomDirectiveParameterPrinter(FormatElement *element,
static void genCustomDirectivePrinter(CustomDirective *customDir,
const Operator &op, MethodBody &body) {
body << " print" << customDir->getName() << "(_odsPrinter, *this";
for (FormatElement *param : customDir->getArguments()) {
for (FormatElement *param : customDir->getElements()) {
body << ", ";
genCustomDirectiveParameterPrinter(param, op, body);
}
@@ -2359,7 +2359,7 @@ static void genOptionalGroupPrinterAnchor(FormatElement *anchor,
.Case([&](CustomDirective *ele) {
body << '(';
llvm::interleave(
ele->getArguments(), body,
ele->getElements(), body,
[&](FormatElement *child) {
body << '(';
genOptionalGroupPrinterAnchor(child, op, body);
@@ -2375,7 +2375,7 @@ void collect(FormatElement *element,
TypeSwitch<FormatElement *>(element)
.Case([&](VariableElement *var) { variables.emplace_back(var); })
.Case([&](CustomDirective *ele) {
for (FormatElement *arg : ele->getArguments())
for (FormatElement *arg : ele->getElements())
collect(arg, variables);
})
.Case([&](OptionalElement *ele) {
@@ -3774,7 +3774,7 @@ LogicalResult OpFormatParser::verifyOptionalGroupElement(SMLoc loc,
return success();
// Verify each child as being valid in an optional group. They are all
// potential anchors if the custom directive was marked as one.
for (FormatElement *child : ele->getArguments()) {
for (FormatElement *child : ele->getElements()) {
if (isa<RefDirective>(child))
continue;
if (failed(verifyOptionalGroupElement(loc, child, /*isAnchor=*/true)))