[mlir][ods] Fix generation of optional custom parsers (#84821)
We need to generate `.has_value` for `OptionalParseResult`, also ensure that `auto result` doesn't conflict with `result` which is the variable name for `OperationState`.
This commit is contained in:
@@ -14,4 +14,9 @@ module @dimension_list {
|
||||
test.custom_dimension_list_attr dimension_list = ?
|
||||
// CHECK: test.custom_dimension_list_attr dimension_list = ?x?
|
||||
test.custom_dimension_list_attr dimension_list = ?x?
|
||||
|
||||
// CHECK: test.optional_custom_attr
|
||||
test.optional_custom_attr bar
|
||||
// CHECK: test.optional_custom_attr foo false
|
||||
test.optional_custom_attr foo false
|
||||
}
|
||||
|
||||
@@ -14,3 +14,8 @@ test.custom_dimension_list_attr dimension_list = -1
|
||||
// expected-error@+2 {{expected ']'}}
|
||||
// expected-error@+1 {{custom op 'test.custom_dimension_list_attr' Failed parsing dimension list.}}
|
||||
test.custom_dimension_list_attr dimension_list = [2x3]
|
||||
|
||||
// -----
|
||||
|
||||
// expected-error @below {{expected attribute value}}
|
||||
test.optional_custom_attr foo
|
||||
|
||||
@@ -499,6 +499,23 @@ void AffineScopeOp::print(OpAsmPrinter &p) {
|
||||
p.printRegion(getRegion(), /*printEntryBlockArgs=*/false);
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// Test OptionalCustomAttrOp
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
static OptionalParseResult parseOptionalCustomParser(AsmParser &p,
|
||||
IntegerAttr &result) {
|
||||
if (succeeded(p.parseOptionalKeyword("foo")))
|
||||
return p.parseAttribute(result);
|
||||
return {};
|
||||
}
|
||||
|
||||
static void printOptionalCustomParser(AsmPrinter &p, Operation *,
|
||||
IntegerAttr result) {
|
||||
p << "foo ";
|
||||
p.printAttribute(result);
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// Test removing op with inner ops.
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
@@ -2048,6 +2048,17 @@ def CustomDimensionListAttrOp : TEST_Op<"custom_dimension_list_attr"> {
|
||||
}];
|
||||
}
|
||||
|
||||
def OptionalCustomAttrOp : TEST_Op<"optional_custom_attr"> {
|
||||
let description = [{
|
||||
Test using a custom directive as the optional group anchor and the first
|
||||
element to parse. It is expected to return an `OptionalParseResult`.
|
||||
}];
|
||||
let arguments = (ins OptionalAttr<I1Attr>:$attr);
|
||||
let assemblyFormat = [{
|
||||
attr-dict (custom<OptionalCustomParser>($attr)^) : (`bar`)?
|
||||
}];
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// Test OpAsmInterface.
|
||||
|
||||
|
||||
@@ -648,7 +648,7 @@ def TypeN : TestType<"TestP"> {
|
||||
// TYPE-LABEL: TestQType::parse
|
||||
// TYPE: if (auto result = [&]() -> ::mlir::OptionalParseResult {
|
||||
// TYPE: auto odsCustomResult = parseAB(odsParser
|
||||
// TYPE: if (!odsCustomResult) return {};
|
||||
// TYPE: if (!odsCustomResult.has_value()) return {};
|
||||
// TYPE: if (::mlir::failed(*odsCustomResult)) return ::mlir::failure();
|
||||
// TYPE: return ::mlir::success();
|
||||
// TYPE: }(); result.has_value() && ::mlir::failed(*result)) {
|
||||
|
||||
@@ -93,14 +93,14 @@ def OptionalGroupC : TestFormat_Op<[{
|
||||
}]>, Arguments<(ins DefaultValuedStrAttr<StrAttr, "default">:$a)>;
|
||||
|
||||
// CHECK-LABEL: OptionalGroupD::parse
|
||||
// CHECK: if (auto result = [&]() -> ::mlir::OptionalParseResult {
|
||||
// CHECK: if (auto optResult = [&]() -> ::mlir::OptionalParseResult {
|
||||
// CHECK: auto odsResult = parseCustom(parser, aOperand, bOperand);
|
||||
// CHECK: if (!odsResult) return {};
|
||||
// CHECK: if (!odsResult.has_value()) return {};
|
||||
// CHECK: if (::mlir::failed(*odsResult)) return ::mlir::failure();
|
||||
// CHECK: return ::mlir::success();
|
||||
// CHECK: }(); result.has_value() && ::mlir::failed(*result)) {
|
||||
// CHECK: }(); optResult.has_value() && ::mlir::failed(*optResult)) {
|
||||
// CHECK: return ::mlir::failure();
|
||||
// CHECK: } else if (result.has_value()) {
|
||||
// CHECK: } else if (optResult.has_value()) {
|
||||
|
||||
// CHECK-LABEL: OptionalGroupD::print
|
||||
// CHECK-NEXT: if (((getA()) || (getB()))) {
|
||||
|
||||
@@ -622,7 +622,7 @@ void DefFormat::genCustomParser(CustomDirective *el, FmtContext &ctx,
|
||||
}
|
||||
os.unindent() << ");\n";
|
||||
if (isOptional) {
|
||||
os << "if (!odsCustomResult) return {};\n";
|
||||
os << "if (!odsCustomResult.has_value()) return {};\n";
|
||||
os << "if (::mlir::failed(*odsCustomResult)) return ::mlir::failure();\n";
|
||||
} else {
|
||||
os << "if (::mlir::failed(odsCustomResult)) return {};\n";
|
||||
|
||||
@@ -1025,7 +1025,7 @@ static void genCustomDirectiveParser(CustomDirective *dir, MethodBody &body,
|
||||
body << ");\n";
|
||||
|
||||
if (isOptional) {
|
||||
body << " if (!odsResult) return {};\n"
|
||||
body << " if (!odsResult.has_value()) return {};\n"
|
||||
<< " if (::mlir::failed(*odsResult)) return ::mlir::failure();\n";
|
||||
} else {
|
||||
body << " if (odsResult) return ::mlir::failure();\n";
|
||||
@@ -1285,13 +1285,13 @@ void OperationFormat::genElementParser(FormatElement *element, MethodBody &body,
|
||||
region->name);
|
||||
}
|
||||
} else if (auto *custom = dyn_cast<CustomDirective>(firstElement)) {
|
||||
body << " if (auto result = [&]() -> ::mlir::OptionalParseResult {\n";
|
||||
body << " if (auto optResult = [&]() -> ::mlir::OptionalParseResult {\n";
|
||||
genCustomDirectiveParser(custom, body, useProperties, opCppClassName,
|
||||
/*isOptional=*/true);
|
||||
body << " return ::mlir::success();\n"
|
||||
<< " }(); result.has_value() && ::mlir::failed(*result)) {\n"
|
||||
<< " }(); optResult.has_value() && ::mlir::failed(*optResult)) {\n"
|
||||
<< " return ::mlir::failure();\n"
|
||||
<< " } else if (result.has_value()) {\n";
|
||||
<< " } else if (optResult.has_value()) {\n";
|
||||
}
|
||||
|
||||
genElementParsers(firstElement, thenElements.drop_front(),
|
||||
|
||||
Reference in New Issue
Block a user