[mlir] Add OpAsmTypeInterface for pretty-print (#121187)

See
https://discourse.llvm.org/t/rfc-introduce-opasm-type-attr-interface-for-pretty-print-in-asmprinter/83792
for detailed introduction.

This PR acts as the first part of it
* Add `OpAsmTypeInterface` and `getAsmName` API for deducing ASM name
from type
* Add default impl in `OpAsmOpInterface` to respect this API when
available.

The `OpAsmAttrInterface` / hooking into Alias system part should be
another PR, using a `getAlias` API.

### Discussion

* Instead of using `StringRef getAsmName()` as the API, I use `void
getAsmName(OpAsmSetNameFn)`, as returning StringRef might be unsafe
(std::string constructed inside then returned a _ref_; and this aligns
with the design of `getAsmResultNames`.
* On the result packing of an op, the current approach is that when not
all of the result types are `OpAsmTypeInterface`, then do nothing (old
default impl)

### Review 

Cc @j2kun and @Alexanderviand-intel for downstream; Cc @River707 and
@joker-eph for relevent commit history; Cc @ftynse for discourse.
This commit is contained in:
Hongren Zheng
2025-01-28 13:31:41 +08:00
committed by GitHub
parent 7f37b34d31
commit 3c64f86314
9 changed files with 120 additions and 5 deletions

View File

@@ -1,7 +1,14 @@
add_mlir_interface(OpAsmInterface)
add_mlir_interface(SymbolInterfaces)
add_mlir_interface(RegionKindInterface)
set(LLVM_TARGET_DEFINITIONS OpAsmInterface.td)
mlir_tablegen(OpAsmOpInterface.h.inc -gen-op-interface-decls)
mlir_tablegen(OpAsmOpInterface.cpp.inc -gen-op-interface-defs)
mlir_tablegen(OpAsmTypeInterface.h.inc -gen-type-interface-decls)
mlir_tablegen(OpAsmTypeInterface.cpp.inc -gen-type-interface-defs)
add_public_tablegen_target(MLIROpAsmInterfaceIncGen)
add_dependencies(mlir-generic-headers MLIROpAsmInterfaceIncGen)
set(LLVM_TARGET_DEFINITIONS BuiltinAttributes.td)
mlir_tablegen(BuiltinAttributes.h.inc -gen-attrdef-decls)
mlir_tablegen(BuiltinAttributes.cpp.inc -gen-attrdef-defs)

View File

@@ -109,6 +109,27 @@ def OpAsmOpInterface : OpInterface<"OpAsmOpInterface"> {
];
}
//===----------------------------------------------------------------------===//
// OpAsmTypeInterface
//===----------------------------------------------------------------------===//
def OpAsmTypeInterface : TypeInterface<"OpAsmTypeInterface"> {
let description = [{
This interface provides hooks to interact with the AsmPrinter and AsmParser
classes.
}];
let cppNamespace = "::mlir";
let methods = [
InterfaceMethod<[{
Get a name to use when printing a value of this type.
}],
"void", "getAsmName",
(ins "::mlir::OpAsmSetNameFn":$setNameFn), "", ";"
>,
];
}
//===----------------------------------------------------------------------===//
// ResourceHandleParameter
//===----------------------------------------------------------------------===//

View File

@@ -734,7 +734,7 @@ public:
virtual OptionalParseResult parseOptionalInteger(APInt &result) = 0;
virtual OptionalParseResult parseOptionalDecimalInteger(APInt &result) = 0;
private:
private:
template <typename IntT, typename ParseFn>
OptionalParseResult parseOptionalIntegerAndCheck(IntT &result,
ParseFn &&parseFn) {
@@ -756,7 +756,7 @@ public:
return success();
}
public:
public:
template <typename IntT>
OptionalParseResult parseOptionalInteger(IntT &result) {
return parseOptionalIntegerAndCheck(
@@ -1727,6 +1727,10 @@ public:
// Dialect OpAsm interface.
//===--------------------------------------------------------------------===//
/// A functor used to set the name of the result. See 'getAsmResultNames' below
/// for more details.
using OpAsmSetNameFn = function_ref<void(StringRef)>;
/// A functor used to set the name of the start of a result group of an
/// operation. See 'getAsmResultNames' below for more details.
using OpAsmSetValueNameFn = function_ref<void(Value, StringRef)>;
@@ -1820,7 +1824,8 @@ ParseResult parseDimensionList(OpAsmParser &parser,
//===--------------------------------------------------------------------===//
/// The OpAsmOpInterface, see OpAsmInterface.td for more details.
#include "mlir/IR/OpAsmInterface.h.inc"
#include "mlir/IR/OpAsmOpInterface.h.inc"
#include "mlir/IR/OpAsmTypeInterface.h.inc"
namespace llvm {
template <>

View File

@@ -125,7 +125,8 @@ void OpAsmPrinter::printFunctionalType(Operation *op) {
//===----------------------------------------------------------------------===//
/// The OpAsmOpInterface, see OpAsmInterface.td for more details.
#include "mlir/IR/OpAsmInterface.cpp.inc"
#include "mlir/IR/OpAsmOpInterface.cpp.inc"
#include "mlir/IR/OpAsmTypeInterface.cpp.inc"
LogicalResult
OpAsmDialectInterface::parseResource(AsmParsedResourceEntry &entry) const {

View File

@@ -0,0 +1,24 @@
// RUN: mlir-opt %s -split-input-file -verify-diagnostics | FileCheck %s
//===----------------------------------------------------------------------===//
// Test OpAsmOpInterface
//===----------------------------------------------------------------------===//
func.func @result_name_from_op_asm_type_interface() {
// CHECK-LABEL: @result_name_from_op_asm_type_interface
// CHECK: %op_asm_type_interface
%0 = "test.result_name_from_type"() : () -> !test.op_asm_type_interface
return
}
// -----
func.func @block_argument_name_from_op_asm_type_interface() {
// CHECK-LABEL: @block_argument_name_from_op_asm_type_interface
// CHECK: ^bb0(%op_asm_type_interface
test.block_argument_name_from_type {
^bb0(%arg0: !test.op_asm_type_interface):
"test.terminator"() : ()->()
}
return
}

View File

@@ -506,6 +506,38 @@ void CustomResultsNameOp::getAsmResultNames(
setNameFn(getResult(i), str.getValue());
}
//===----------------------------------------------------------------------===//
// ResultNameFromTypeOp
//===----------------------------------------------------------------------===//
void ResultNameFromTypeOp::getAsmResultNames(
function_ref<void(Value, StringRef)> setNameFn) {
auto result = getResult();
auto setResultNameFn = [&](::llvm::StringRef name) {
setNameFn(result, name);
};
auto opAsmTypeInterface =
::mlir::cast<::mlir::OpAsmTypeInterface>(result.getType());
opAsmTypeInterface.getAsmName(setResultNameFn);
}
//===----------------------------------------------------------------------===//
// BlockArgumentNameFromTypeOp
//===----------------------------------------------------------------------===//
void BlockArgumentNameFromTypeOp::getAsmBlockArgumentNames(
::mlir::Region &region, ::mlir::OpAsmSetValueNameFn setNameFn) {
for (auto &block : region) {
for (auto arg : block.getArguments()) {
if (auto opAsmTypeInterface =
::mlir::dyn_cast<::mlir::OpAsmTypeInterface>(arg.getType())) {
auto setArgNameFn = [&](StringRef name) { setNameFn(arg, name); };
opAsmTypeInterface.getAsmName(setArgNameFn);
}
}
}
}
//===----------------------------------------------------------------------===//
// ResultTypeWithTraitOp
//===----------------------------------------------------------------------===//

View File

@@ -924,6 +924,21 @@ def CustomResultsNameOp
let results = (outs Variadic<AnyInteger>:$r);
}
// This is used to test OpAsmTypeInterface::getAsmName for op result name,
def ResultNameFromTypeOp
: TEST_Op<"result_name_from_type",
[DeclareOpInterfaceMethods<OpAsmOpInterface, ["getAsmResultNames"]>]> {
let results = (outs AnyType:$r);
}
// This is used to test OpAsmTypeInterface::getAsmName for block argument,
def BlockArgumentNameFromTypeOp
: TEST_Op<"block_argument_name_from_type",
[DeclareOpInterfaceMethods<OpAsmOpInterface, ["getAsmBlockArgumentNames"]>]> {
let regions = (region AnyRegion:$body);
let assemblyFormat = "regions attr-dict-with-keyword";
}
// This is used to test the OpAsmOpInterface::getDefaultDialect() feature:
// operations nested in a region under this op will drop the "test." dialect
// prefix.

View File

@@ -398,4 +398,9 @@ def TestTypeVerification : Test_Type<"TestTypeVerification"> {
let assemblyFormat = "`<` $param `>`";
}
def TestTypeOpAsmTypeInterface : Test_Type<"TestTypeOpAsmTypeInterface",
[DeclareTypeInterfaceMethods<OpAsmTypeInterface, ["getAsmName"]>]> {
let mnemonic = "op_asm_type_interface";
}
#endif // TEST_TYPEDEFS

View File

@@ -532,3 +532,8 @@ void TestRecursiveAliasType::print(AsmPrinter &printer) const {
}
printer << ">";
}
void TestTypeOpAsmTypeInterfaceType::getAsmName(
OpAsmSetNameFn setNameFn) const {
setNameFn("op_asm_type_interface");
}