[mlir] Add OpAsmTypeInterface for pretty-print (#121187)
See https://discourse.llvm.org/t/rfc-introduce-opasm-type-attr-interface-for-pretty-print-in-asmprinter/83792 for detailed introduction. This PR acts as the first part of it * Add `OpAsmTypeInterface` and `getAsmName` API for deducing ASM name from type * Add default impl in `OpAsmOpInterface` to respect this API when available. The `OpAsmAttrInterface` / hooking into Alias system part should be another PR, using a `getAlias` API. ### Discussion * Instead of using `StringRef getAsmName()` as the API, I use `void getAsmName(OpAsmSetNameFn)`, as returning StringRef might be unsafe (std::string constructed inside then returned a _ref_; and this aligns with the design of `getAsmResultNames`. * On the result packing of an op, the current approach is that when not all of the result types are `OpAsmTypeInterface`, then do nothing (old default impl) ### Review Cc @j2kun and @Alexanderviand-intel for downstream; Cc @River707 and @joker-eph for relevent commit history; Cc @ftynse for discourse.
This commit is contained in:
@@ -1,7 +1,14 @@
|
||||
add_mlir_interface(OpAsmInterface)
|
||||
add_mlir_interface(SymbolInterfaces)
|
||||
add_mlir_interface(RegionKindInterface)
|
||||
|
||||
set(LLVM_TARGET_DEFINITIONS OpAsmInterface.td)
|
||||
mlir_tablegen(OpAsmOpInterface.h.inc -gen-op-interface-decls)
|
||||
mlir_tablegen(OpAsmOpInterface.cpp.inc -gen-op-interface-defs)
|
||||
mlir_tablegen(OpAsmTypeInterface.h.inc -gen-type-interface-decls)
|
||||
mlir_tablegen(OpAsmTypeInterface.cpp.inc -gen-type-interface-defs)
|
||||
add_public_tablegen_target(MLIROpAsmInterfaceIncGen)
|
||||
add_dependencies(mlir-generic-headers MLIROpAsmInterfaceIncGen)
|
||||
|
||||
set(LLVM_TARGET_DEFINITIONS BuiltinAttributes.td)
|
||||
mlir_tablegen(BuiltinAttributes.h.inc -gen-attrdef-decls)
|
||||
mlir_tablegen(BuiltinAttributes.cpp.inc -gen-attrdef-defs)
|
||||
|
||||
@@ -109,6 +109,27 @@ def OpAsmOpInterface : OpInterface<"OpAsmOpInterface"> {
|
||||
];
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// OpAsmTypeInterface
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
def OpAsmTypeInterface : TypeInterface<"OpAsmTypeInterface"> {
|
||||
let description = [{
|
||||
This interface provides hooks to interact with the AsmPrinter and AsmParser
|
||||
classes.
|
||||
}];
|
||||
let cppNamespace = "::mlir";
|
||||
|
||||
let methods = [
|
||||
InterfaceMethod<[{
|
||||
Get a name to use when printing a value of this type.
|
||||
}],
|
||||
"void", "getAsmName",
|
||||
(ins "::mlir::OpAsmSetNameFn":$setNameFn), "", ";"
|
||||
>,
|
||||
];
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// ResourceHandleParameter
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
@@ -734,7 +734,7 @@ public:
|
||||
virtual OptionalParseResult parseOptionalInteger(APInt &result) = 0;
|
||||
virtual OptionalParseResult parseOptionalDecimalInteger(APInt &result) = 0;
|
||||
|
||||
private:
|
||||
private:
|
||||
template <typename IntT, typename ParseFn>
|
||||
OptionalParseResult parseOptionalIntegerAndCheck(IntT &result,
|
||||
ParseFn &&parseFn) {
|
||||
@@ -756,7 +756,7 @@ public:
|
||||
return success();
|
||||
}
|
||||
|
||||
public:
|
||||
public:
|
||||
template <typename IntT>
|
||||
OptionalParseResult parseOptionalInteger(IntT &result) {
|
||||
return parseOptionalIntegerAndCheck(
|
||||
@@ -1727,6 +1727,10 @@ public:
|
||||
// Dialect OpAsm interface.
|
||||
//===--------------------------------------------------------------------===//
|
||||
|
||||
/// A functor used to set the name of the result. See 'getAsmResultNames' below
|
||||
/// for more details.
|
||||
using OpAsmSetNameFn = function_ref<void(StringRef)>;
|
||||
|
||||
/// A functor used to set the name of the start of a result group of an
|
||||
/// operation. See 'getAsmResultNames' below for more details.
|
||||
using OpAsmSetValueNameFn = function_ref<void(Value, StringRef)>;
|
||||
@@ -1820,7 +1824,8 @@ ParseResult parseDimensionList(OpAsmParser &parser,
|
||||
//===--------------------------------------------------------------------===//
|
||||
|
||||
/// The OpAsmOpInterface, see OpAsmInterface.td for more details.
|
||||
#include "mlir/IR/OpAsmInterface.h.inc"
|
||||
#include "mlir/IR/OpAsmOpInterface.h.inc"
|
||||
#include "mlir/IR/OpAsmTypeInterface.h.inc"
|
||||
|
||||
namespace llvm {
|
||||
template <>
|
||||
|
||||
@@ -125,7 +125,8 @@ void OpAsmPrinter::printFunctionalType(Operation *op) {
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
/// The OpAsmOpInterface, see OpAsmInterface.td for more details.
|
||||
#include "mlir/IR/OpAsmInterface.cpp.inc"
|
||||
#include "mlir/IR/OpAsmOpInterface.cpp.inc"
|
||||
#include "mlir/IR/OpAsmTypeInterface.cpp.inc"
|
||||
|
||||
LogicalResult
|
||||
OpAsmDialectInterface::parseResource(AsmParsedResourceEntry &entry) const {
|
||||
|
||||
24
mlir/test/IR/op-asm-interface.mlir
Normal file
24
mlir/test/IR/op-asm-interface.mlir
Normal file
@@ -0,0 +1,24 @@
|
||||
// RUN: mlir-opt %s -split-input-file -verify-diagnostics | FileCheck %s
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// Test OpAsmOpInterface
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
func.func @result_name_from_op_asm_type_interface() {
|
||||
// CHECK-LABEL: @result_name_from_op_asm_type_interface
|
||||
// CHECK: %op_asm_type_interface
|
||||
%0 = "test.result_name_from_type"() : () -> !test.op_asm_type_interface
|
||||
return
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
func.func @block_argument_name_from_op_asm_type_interface() {
|
||||
// CHECK-LABEL: @block_argument_name_from_op_asm_type_interface
|
||||
// CHECK: ^bb0(%op_asm_type_interface
|
||||
test.block_argument_name_from_type {
|
||||
^bb0(%arg0: !test.op_asm_type_interface):
|
||||
"test.terminator"() : ()->()
|
||||
}
|
||||
return
|
||||
}
|
||||
@@ -506,6 +506,38 @@ void CustomResultsNameOp::getAsmResultNames(
|
||||
setNameFn(getResult(i), str.getValue());
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// ResultNameFromTypeOp
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
void ResultNameFromTypeOp::getAsmResultNames(
|
||||
function_ref<void(Value, StringRef)> setNameFn) {
|
||||
auto result = getResult();
|
||||
auto setResultNameFn = [&](::llvm::StringRef name) {
|
||||
setNameFn(result, name);
|
||||
};
|
||||
auto opAsmTypeInterface =
|
||||
::mlir::cast<::mlir::OpAsmTypeInterface>(result.getType());
|
||||
opAsmTypeInterface.getAsmName(setResultNameFn);
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// BlockArgumentNameFromTypeOp
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
void BlockArgumentNameFromTypeOp::getAsmBlockArgumentNames(
|
||||
::mlir::Region ®ion, ::mlir::OpAsmSetValueNameFn setNameFn) {
|
||||
for (auto &block : region) {
|
||||
for (auto arg : block.getArguments()) {
|
||||
if (auto opAsmTypeInterface =
|
||||
::mlir::dyn_cast<::mlir::OpAsmTypeInterface>(arg.getType())) {
|
||||
auto setArgNameFn = [&](StringRef name) { setNameFn(arg, name); };
|
||||
opAsmTypeInterface.getAsmName(setArgNameFn);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// ResultTypeWithTraitOp
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
@@ -924,6 +924,21 @@ def CustomResultsNameOp
|
||||
let results = (outs Variadic<AnyInteger>:$r);
|
||||
}
|
||||
|
||||
// This is used to test OpAsmTypeInterface::getAsmName for op result name,
|
||||
def ResultNameFromTypeOp
|
||||
: TEST_Op<"result_name_from_type",
|
||||
[DeclareOpInterfaceMethods<OpAsmOpInterface, ["getAsmResultNames"]>]> {
|
||||
let results = (outs AnyType:$r);
|
||||
}
|
||||
|
||||
// This is used to test OpAsmTypeInterface::getAsmName for block argument,
|
||||
def BlockArgumentNameFromTypeOp
|
||||
: TEST_Op<"block_argument_name_from_type",
|
||||
[DeclareOpInterfaceMethods<OpAsmOpInterface, ["getAsmBlockArgumentNames"]>]> {
|
||||
let regions = (region AnyRegion:$body);
|
||||
let assemblyFormat = "regions attr-dict-with-keyword";
|
||||
}
|
||||
|
||||
// This is used to test the OpAsmOpInterface::getDefaultDialect() feature:
|
||||
// operations nested in a region under this op will drop the "test." dialect
|
||||
// prefix.
|
||||
|
||||
@@ -398,4 +398,9 @@ def TestTypeVerification : Test_Type<"TestTypeVerification"> {
|
||||
let assemblyFormat = "`<` $param `>`";
|
||||
}
|
||||
|
||||
def TestTypeOpAsmTypeInterface : Test_Type<"TestTypeOpAsmTypeInterface",
|
||||
[DeclareTypeInterfaceMethods<OpAsmTypeInterface, ["getAsmName"]>]> {
|
||||
let mnemonic = "op_asm_type_interface";
|
||||
}
|
||||
|
||||
#endif // TEST_TYPEDEFS
|
||||
|
||||
@@ -532,3 +532,8 @@ void TestRecursiveAliasType::print(AsmPrinter &printer) const {
|
||||
}
|
||||
printer << ">";
|
||||
}
|
||||
|
||||
void TestTypeOpAsmTypeInterfaceType::getAsmName(
|
||||
OpAsmSetNameFn setNameFn) const {
|
||||
setNameFn("op_asm_type_interface");
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user