[mlir] Introduce OpAsmAttrInterface for pretty-print (#124721)
See https://discourse.llvm.org/t/rfc-introduce-opasm-type-attr-interface-for-pretty-print-in-asmprinter/83792 for detailed introduction. This PR adds * Definition of `OpAsmAttrInterface` * Integration of `OpAsmAttrInterface` with `AsmPrinter` In https://github.com/llvm/llvm-project/pull/121187#discussion_r1931472250 I mentioned splitting them into two PRs, but I realized that a PR with only definition of `OpAsmAttrInterface` is hard to test as it requires a custom Dialect with `OpAsmDialectInterface` to hook with `AsmPrinter`, so I just put them together to have a e2e test. Cc @River707 @jpienaar @ftynse for review.
This commit is contained in:
@@ -2,6 +2,8 @@ add_mlir_interface(SymbolInterfaces)
|
||||
add_mlir_interface(RegionKindInterface)
|
||||
|
||||
set(LLVM_TARGET_DEFINITIONS OpAsmInterface.td)
|
||||
mlir_tablegen(OpAsmAttrInterface.h.inc -gen-attr-interface-decls)
|
||||
mlir_tablegen(OpAsmAttrInterface.cpp.inc -gen-attr-interface-defs)
|
||||
mlir_tablegen(OpAsmOpInterface.h.inc -gen-op-interface-decls)
|
||||
mlir_tablegen(OpAsmOpInterface.cpp.inc -gen-op-interface-defs)
|
||||
mlir_tablegen(OpAsmTypeInterface.h.inc -gen-type-interface-decls)
|
||||
|
||||
@@ -130,6 +130,28 @@ def OpAsmTypeInterface : TypeInterface<"OpAsmTypeInterface"> {
|
||||
];
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// OpAsmAttrInterface
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
def OpAsmAttrInterface : AttrInterface<"OpAsmAttrInterface"> {
|
||||
let description = [{
|
||||
This interface provides hooks to interact with the AsmPrinter and AsmParser
|
||||
classes.
|
||||
}];
|
||||
let cppNamespace = "::mlir";
|
||||
|
||||
let methods = [
|
||||
InterfaceMethod<[{
|
||||
Get a name to use when generating an alias for this attribute.
|
||||
}],
|
||||
"::mlir::OpAsmDialectInterface::AliasResult", "getAlias",
|
||||
(ins "::llvm::raw_ostream&":$os), "",
|
||||
"return ::mlir::OpAsmDialectInterface::AliasResult::NoAlias;"
|
||||
>,
|
||||
];
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// ResourceHandleParameter
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
@@ -1825,6 +1825,7 @@ ParseResult parseDimensionList(OpAsmParser &parser,
|
||||
//===--------------------------------------------------------------------===//
|
||||
|
||||
/// The OpAsmOpInterface, see OpAsmInterface.td for more details.
|
||||
#include "mlir/IR/OpAsmAttrInterface.h.inc"
|
||||
#include "mlir/IR/OpAsmOpInterface.h.inc"
|
||||
#include "mlir/IR/OpAsmTypeInterface.h.inc"
|
||||
|
||||
|
||||
@@ -125,6 +125,7 @@ void OpAsmPrinter::printFunctionalType(Operation *op) {
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
/// The OpAsmOpInterface, see OpAsmInterface.td for more details.
|
||||
#include "mlir/IR/OpAsmAttrInterface.cpp.inc"
|
||||
#include "mlir/IR/OpAsmOpInterface.cpp.inc"
|
||||
#include "mlir/IR/OpAsmTypeInterface.cpp.inc"
|
||||
|
||||
@@ -1159,15 +1160,31 @@ template <typename T>
|
||||
void AliasInitializer::generateAlias(T symbol, InProgressAliasInfo &alias,
|
||||
bool canBeDeferred) {
|
||||
SmallString<32> nameBuffer;
|
||||
for (const auto &interface : interfaces) {
|
||||
OpAsmDialectInterface::AliasResult result =
|
||||
interface.getAlias(symbol, aliasOS);
|
||||
if (result == OpAsmDialectInterface::AliasResult::NoAlias)
|
||||
continue;
|
||||
nameBuffer = std::move(aliasBuffer);
|
||||
assert(!nameBuffer.empty() && "expected valid alias name");
|
||||
if (result == OpAsmDialectInterface::AliasResult::FinalAlias)
|
||||
break;
|
||||
|
||||
OpAsmDialectInterface::AliasResult symbolInterfaceResult =
|
||||
OpAsmDialectInterface::AliasResult::NoAlias;
|
||||
if constexpr (std::is_base_of_v<Attribute, T>) {
|
||||
if (auto symbolInterface = dyn_cast<OpAsmAttrInterface>(symbol)) {
|
||||
symbolInterfaceResult = symbolInterface.getAlias(aliasOS);
|
||||
if (symbolInterfaceResult !=
|
||||
OpAsmDialectInterface::AliasResult::NoAlias) {
|
||||
nameBuffer = std::move(aliasBuffer);
|
||||
assert(!nameBuffer.empty() && "expected valid alias name");
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if (symbolInterfaceResult != OpAsmDialectInterface::AliasResult::FinalAlias) {
|
||||
for (const auto &interface : interfaces) {
|
||||
OpAsmDialectInterface::AliasResult result =
|
||||
interface.getAlias(symbol, aliasOS);
|
||||
if (result == OpAsmDialectInterface::AliasResult::NoAlias)
|
||||
continue;
|
||||
nameBuffer = std::move(aliasBuffer);
|
||||
assert(!nameBuffer.empty() && "expected valid alias name");
|
||||
if (result == OpAsmDialectInterface::AliasResult::FinalAlias)
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
if (nameBuffer.empty())
|
||||
|
||||
@@ -58,3 +58,17 @@ func.func @block_argument_name_from_op_asm_type_interface_asmprinter() {
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// Test OpAsmAttrInterface
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
// CHECK: #op_asm_attr_interface_test
|
||||
#attr = #test.op_asm_attr_interface<value = "test">
|
||||
|
||||
func.func @test_op_asm_attr_interface() {
|
||||
%1 = "test.result_name_from_type"() {attr = #attr} : () -> !test.op_asm_type_interface
|
||||
return
|
||||
}
|
||||
|
||||
@@ -395,4 +395,14 @@ def TestCustomLocationAttr : Test_LocAttr<"TestCustomLocation"> {
|
||||
let assemblyFormat = "`<` $file `*` $line `>`";
|
||||
}
|
||||
|
||||
// Test OpAsmAttrInterface.
|
||||
def TestOpAsmAttrInterfaceAttr : Test_Attr<"TestOpAsmAttrInterface",
|
||||
[DeclareAttrInterfaceMethods<OpAsmAttrInterface, ["getAlias"]>]> {
|
||||
let mnemonic = "op_asm_attr_interface";
|
||||
let parameters = (ins "mlir::StringAttr":$value);
|
||||
let assemblyFormat = [{
|
||||
`<` struct(params) `>`
|
||||
}];
|
||||
}
|
||||
|
||||
#endif // TEST_ATTRDEFS
|
||||
|
||||
@@ -67,7 +67,7 @@ void CompoundAAttr::print(AsmPrinter &printer) const {
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
Attribute TestDecimalShapeAttr::parse(AsmParser &parser, Type type) {
|
||||
if (parser.parseLess()){
|
||||
if (parser.parseLess()) {
|
||||
return Attribute();
|
||||
}
|
||||
SmallVector<int64_t> shape;
|
||||
@@ -316,6 +316,17 @@ static ParseResult parseCustomFloatAttr(AsmParser &p, StringAttr &typeStrAttr,
|
||||
return success();
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// TestOpAsmAttrInterfaceAttr
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
::mlir::OpAsmDialectInterface::AliasResult
|
||||
TestOpAsmAttrInterfaceAttr::getAlias(::llvm::raw_ostream &os) const {
|
||||
os << "op_asm_attr_interface_";
|
||||
os << getValue().getValue();
|
||||
return ::mlir::OpAsmDialectInterface::AliasResult::FinalAlias;
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// Tablegen Generated Definitions
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
Reference in New Issue
Block a user