[mlir][tblgen] Fix region and successor references in custom directives (#146242)
Previously, references to regions and successors were incorrectly disallowed outside the top-level assembly form. This change enables the use of bound regions and successors as variables in custom directives.
This commit is contained in:
@@ -106,3 +106,10 @@ func.func @named_region_has_wrong_number_of_blocks() {
|
||||
test.single_no_terminator_custom_asm_op {
|
||||
"important_dont_drop"() : () -> ()
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
// CHECK: test.dummy_op_with_region_ref
|
||||
test.dummy_op_with_region_ref {
|
||||
^bb0:
|
||||
}
|
||||
|
||||
@@ -381,3 +381,26 @@ void test::printAttrElideType(AsmPrinter &printer, Operation *op, TypeAttr type,
|
||||
Attribute attr) {
|
||||
printer.printAttributeWithoutType(attr);
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// CustomDirectiveDummyRegionRef
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
ParseResult test::parseDummyRegionRef(OpAsmParser &parser, Region ®ion) {
|
||||
return success();
|
||||
}
|
||||
|
||||
void test::printDummyRegionRef(OpAsmPrinter &printer, Operation *op,
|
||||
Region ®ion) { /* do nothing */ }
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// CustomDirectiveDummySuccessorRef
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
ParseResult test::parseDummySuccessorRef(OpAsmParser &parser,
|
||||
Block *successor) {
|
||||
return success();
|
||||
}
|
||||
|
||||
void test::printDummySuccessorRef(OpAsmPrinter &printer, Operation *op,
|
||||
Block *successor) { /* do nothing */ }
|
||||
|
||||
@@ -207,6 +207,24 @@ mlir::ParseResult parseAttrElideType(mlir::AsmParser &parser,
|
||||
void printAttrElideType(mlir::AsmPrinter &printer, mlir::Operation *op,
|
||||
mlir::TypeAttr type, mlir::Attribute attr);
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// CustomDirectiveDummyRegionRef
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
mlir::ParseResult parseDummyRegionRef(mlir::OpAsmParser &parser,
|
||||
mlir::Region ®ion);
|
||||
void printDummyRegionRef(mlir::OpAsmPrinter &printer, mlir::Operation *op,
|
||||
mlir::Region ®ion);
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// CustomDirectiveDummySuccessorRef
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
mlir::ParseResult parseDummySuccessorRef(mlir::OpAsmParser &parser,
|
||||
mlir::Block *successor);
|
||||
void printDummySuccessorRef(mlir::OpAsmPrinter &printer, mlir::Operation *op,
|
||||
mlir::Block *successor);
|
||||
|
||||
} // end namespace test
|
||||
|
||||
#endif // MLIR_TESTFORMATUTILS_H
|
||||
|
||||
@@ -3665,4 +3665,22 @@ def TestDummyMemrefOp : TEST_Op<"dummy_memref_op", []> {
|
||||
);
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// Test assembly format references
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
def TestOpWithRegionRef : TEST_Op<"dummy_op_with_region_ref", [NoTerminator]> {
|
||||
let regions = (region AnyRegion:$body);
|
||||
let assemblyFormat = [{
|
||||
$body attr-dict custom<DummyRegionRef>(ref($body))
|
||||
}];
|
||||
}
|
||||
|
||||
def TestOpWithSuccessorRef : TEST_Op<"dummy_op_with_successor_ref"> {
|
||||
let successors = (successor AnySuccessor:$successor);
|
||||
let assemblyFormat = [{
|
||||
$successor attr-dict custom<DummySuccessorRef>(ref($successor))
|
||||
}];
|
||||
}
|
||||
|
||||
#endif // TEST_OPS
|
||||
|
||||
@@ -49,6 +49,19 @@ def DirectiveCustomValidD : TestFormat_Op<[{
|
||||
def DirectiveCustomValidE : TestFormat_Op<[{
|
||||
custom<MyDirective>(prop-dict) attr-dict
|
||||
}]>, Arguments<(ins UnitAttr:$flag)>;
|
||||
def DirectiveCustomValidF : TestFormat_Op<[{
|
||||
$operand custom<MyDirective>(ref($operand)) attr-dict
|
||||
}]>, Arguments<(ins Optional<I64>:$operand)>;
|
||||
def DirectiveCustomValidG : TestFormat_Op<[{
|
||||
$body custom<MyDirective>(ref($body)) attr-dict
|
||||
}]> {
|
||||
let regions = (region AnyRegion:$body);
|
||||
}
|
||||
def DirectiveCustomValidH : TestFormat_Op<[{
|
||||
$successor custom<MyDirective>(ref($successor)) attr-dict
|
||||
}]> {
|
||||
let successors = (successor AnySuccessor:$successor);
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// functional-type
|
||||
|
||||
@@ -109,3 +109,23 @@ def OptionalGroupC : TestFormat_Op<[{
|
||||
def OptionalGroupD : TestFormat_Op<[{
|
||||
(custom<Custom>($a, $b)^)? attr-dict
|
||||
}], [AttrSizedOperandSegments]>, Arguments<(ins Optional<I64>:$a, Optional<I64>:$b)>;
|
||||
|
||||
// CHECK-LABEL: RegionRef::parse
|
||||
// CHECK: auto odsResult = parseCustom(parser, *bodyRegion);
|
||||
// CHECK-LABEL: RegionRef::print
|
||||
// CHECK: printCustom(_odsPrinter, *this, getBody());
|
||||
def RegionRef : TestFormat_Op<[{
|
||||
$body custom<Custom>(ref($body)) attr-dict
|
||||
}]> {
|
||||
let regions = (region AnyRegion:$body);
|
||||
}
|
||||
|
||||
// CHECK-LABEL: SuccessorRef::parse
|
||||
// CHECK: auto odsResult = parseCustom(parser, successorSuccessor);
|
||||
// CHECK-LABEL: SuccessorRef::print
|
||||
// CHECK: printCustom(_odsPrinter, *this, getSuccessor());
|
||||
def SuccessorRef : TestFormat_Op<[{
|
||||
$successor custom<Custom>(ref($successor)) attr-dict
|
||||
}]> {
|
||||
let successors = (successor AnySuccessor:$successor);
|
||||
}
|
||||
|
||||
@@ -3376,11 +3376,13 @@ OpFormatParser::parseVariableImpl(SMLoc loc, StringRef name, Context ctx) {
|
||||
if (ctx == TopLevelContext || ctx == CustomDirectiveContext) {
|
||||
if (hasAllRegions || !seenRegions.insert(region).second)
|
||||
return emitError(loc, "region '" + name + "' is already bound");
|
||||
} else if (ctx == RefDirectiveContext && !seenRegions.count(region)) {
|
||||
} else if (ctx == RefDirectiveContext) {
|
||||
if (!seenRegions.count(region))
|
||||
return emitError(loc, "region '" + name +
|
||||
"' must be bound before it is referenced");
|
||||
} else {
|
||||
return emitError(loc, "regions can only be used at the top level");
|
||||
return emitError(loc, "regions can only be used at the top level "
|
||||
"or in a ref directive");
|
||||
}
|
||||
return create<RegionVariable>(region);
|
||||
}
|
||||
@@ -3396,11 +3398,13 @@ OpFormatParser::parseVariableImpl(SMLoc loc, StringRef name, Context ctx) {
|
||||
if (ctx == TopLevelContext || ctx == CustomDirectiveContext) {
|
||||
if (hasAllSuccessors || !seenSuccessors.insert(successor).second)
|
||||
return emitError(loc, "successor '" + name + "' is already bound");
|
||||
} else if (ctx == RefDirectiveContext && !seenSuccessors.count(successor)) {
|
||||
} else if (ctx == RefDirectiveContext) {
|
||||
if (!seenSuccessors.count(successor))
|
||||
return emitError(loc, "successor '" + name +
|
||||
"' must be bound before it is referenced");
|
||||
} else {
|
||||
return emitError(loc, "successors can only be used at the top level");
|
||||
return emitError(loc, "successors can only be used at the top level "
|
||||
"or in a ref directive");
|
||||
}
|
||||
|
||||
return create<SuccessorVariable>(successor);
|
||||
|
||||
Reference in New Issue
Block a user