[MLIR][Transform] expose transform.debug extension in Python (#145550)

Removes the Debug... prefix on the ops in tablegen, in line with pretty
much all other Transform-dialect extension ops. This means that the ops
in Python look like
`debug.EmitParamAsRemarkOp`/`debug.emit_param_as_remark` instead of
`debug.DebugEmitParamAsRemarkOp`/`debug.debug_emit_param_as_remark`.
This commit is contained in:
Rolf Morel
2025-06-25 17:39:01 +02:00
committed by GitHub
parent 46ee7f1908
commit c08502defe
6 changed files with 163 additions and 8 deletions

View File

@@ -20,7 +20,7 @@ include "mlir/Dialect/Transform/Interfaces/MatchInterfaces.td"
include "mlir/Dialect/Transform/Interfaces/TransformInterfaces.td" include "mlir/Dialect/Transform/Interfaces/TransformInterfaces.td"
include "mlir/Dialect/Transform/IR/TransformDialect.td" include "mlir/Dialect/Transform/IR/TransformDialect.td"
def DebugEmitRemarkAtOp : TransformDialectOp<"debug.emit_remark_at", def EmitRemarkAtOp : TransformDialectOp<"debug.emit_remark_at",
[MatchOpInterface, [MatchOpInterface,
DeclareOpInterfaceMethods<TransformOpInterface>, DeclareOpInterfaceMethods<TransformOpInterface>,
MemoryEffectsOpInterface, NavigationTransformOpTrait]> { MemoryEffectsOpInterface, NavigationTransformOpTrait]> {
@@ -39,7 +39,7 @@ def DebugEmitRemarkAtOp : TransformDialectOp<"debug.emit_remark_at",
let assemblyFormat = "$at `,` $message attr-dict `:` type($at)"; let assemblyFormat = "$at `,` $message attr-dict `:` type($at)";
} }
def DebugEmitParamAsRemarkOp def EmitParamAsRemarkOp
: TransformDialectOp<"debug.emit_param_as_remark", : TransformDialectOp<"debug.emit_param_as_remark",
[MatchOpInterface, [MatchOpInterface,
DeclareOpInterfaceMethods<TransformOpInterface>, DeclareOpInterfaceMethods<TransformOpInterface>,

View File

@@ -19,9 +19,9 @@ using namespace mlir;
#include "mlir/Dialect/Transform/DebugExtension/DebugExtensionOps.cpp.inc" #include "mlir/Dialect/Transform/DebugExtension/DebugExtensionOps.cpp.inc"
DiagnosedSilenceableFailure DiagnosedSilenceableFailure
transform::DebugEmitRemarkAtOp::apply(transform::TransformRewriter &rewriter, transform::EmitRemarkAtOp::apply(transform::TransformRewriter &rewriter,
transform::TransformResults &results, transform::TransformResults &results,
transform::TransformState &state) { transform::TransformState &state) {
if (isa<TransformHandleTypeInterface>(getAt().getType())) { if (isa<TransformHandleTypeInterface>(getAt().getType())) {
auto payload = state.getPayloadOps(getAt()); auto payload = state.getPayloadOps(getAt());
for (Operation *op : payload) for (Operation *op : payload)
@@ -52,9 +52,10 @@ transform::DebugEmitRemarkAtOp::apply(transform::TransformRewriter &rewriter,
return DiagnosedSilenceableFailure::success(); return DiagnosedSilenceableFailure::success();
} }
DiagnosedSilenceableFailure transform::DebugEmitParamAsRemarkOp::apply( DiagnosedSilenceableFailure
transform::TransformRewriter &rewriter, transform::EmitParamAsRemarkOp::apply(transform::TransformRewriter &rewriter,
transform::TransformResults &results, transform::TransformState &state) { transform::TransformResults &results,
transform::TransformState &state) {
std::string str; std::string str;
llvm::raw_string_ostream os(str); llvm::raw_string_ostream os(str);
if (getMessage()) if (getMessage())

View File

@@ -171,6 +171,15 @@ ROOT_DIR "${CMAKE_CURRENT_SOURCE_DIR}/mlir"
DIALECT_NAME transform DIALECT_NAME transform
EXTENSION_NAME transform_pdl_extension) EXTENSION_NAME transform_pdl_extension)
declare_mlir_dialect_extension_python_bindings(
ADD_TO_PARENT MLIRPythonSources.Dialects
ROOT_DIR "${CMAKE_CURRENT_SOURCE_DIR}/mlir"
TD_FILE dialects/TransformDebugExtensionOps.td
SOURCES
dialects/transform/debug.py
DIALECT_NAME transform
EXTENSION_NAME transform_debug_extension)
declare_mlir_dialect_python_bindings( declare_mlir_dialect_python_bindings(
ADD_TO_PARENT MLIRPythonSources.Dialects ADD_TO_PARENT MLIRPythonSources.Dialects
ROOT_DIR "${CMAKE_CURRENT_SOURCE_DIR}/mlir" ROOT_DIR "${CMAKE_CURRENT_SOURCE_DIR}/mlir"

View File

@@ -0,0 +1,19 @@
//===-- TransformDebugExtensionOps.td - Binding entry point *- tablegen -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// Entry point of the generated Python bindings for the Debug extension of the
// Transform dialect.
//
//===----------------------------------------------------------------------===//
#ifndef PYTHON_BINDINGS_TRANSFORM_DEBUG_EXTENSION_OPS
#define PYTHON_BINDINGS_TRANSFORM_DEBUG_EXTENSION_OPS
include "mlir/Dialect/Transform/DebugExtension/DebugExtensionOps.td"
#endif // PYTHON_BINDINGS_TRANSFORM_DEBUG_EXTENSION_OPS

View File

@@ -0,0 +1,81 @@
# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
# See https://llvm.org/LICENSE.txt for license information.
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
from typing import Optional
from ...ir import Attribute, Operation, Value, StringAttr
from .._transform_debug_extension_ops_gen import *
from .._transform_pdl_extension_ops_gen import _Dialect
try:
from .._ods_common import _cext as _ods_cext
except ImportError as e:
raise RuntimeError("Error loading imports from extension module") from e
from typing import Union
@_ods_cext.register_operation(_Dialect, replace=True)
class EmitParamAsRemarkOp(EmitParamAsRemarkOp):
def __init__(
self,
param: Attribute,
*,
anchor: Optional[Operation] = None,
message: Optional[Union[StringAttr, str]] = None,
loc=None,
ip=None,
):
if isinstance(message, str):
message = StringAttr.get(message)
super().__init__(
param,
anchor=anchor,
message=message,
loc=loc,
ip=ip,
)
def emit_param_as_remark(
param: Attribute,
*,
anchor: Optional[Operation] = None,
message: Optional[Union[StringAttr, str]] = None,
loc=None,
ip=None,
):
return EmitParamAsRemarkOp(param, anchor=anchor, message=message, loc=loc, ip=ip)
@_ods_cext.register_operation(_Dialect, replace=True)
class EmitRemarkAtOp(EmitRemarkAtOp):
def __init__(
self,
at: Union[Operation, Value],
message: Optional[Union[StringAttr, str]] = None,
*,
loc=None,
ip=None,
):
if isinstance(message, str):
message = StringAttr.get(message)
super().__init__(
at,
message,
loc=loc,
ip=ip,
)
def emit_remark_at(
at: Union[Operation, Value],
message: Optional[Union[StringAttr, str]] = None,
*,
loc=None,
ip=None,
):
return EmitRemarkAtOp(at, message, loc=loc, ip=ip)

View File

@@ -0,0 +1,45 @@
# RUN: %PYTHON %s | FileCheck %s
from mlir.ir import *
from mlir.dialects import transform
from mlir.dialects.transform import debug
def run(f):
print("\nTEST:", f.__name__)
with Context(), Location.unknown():
module = Module.create()
with InsertionPoint(module.body):
sequence = transform.SequenceOp(
transform.FailurePropagationMode.Propagate,
[],
transform.AnyOpType.get(),
)
with InsertionPoint(sequence.body):
f(sequence.bodyTarget)
transform.YieldOp()
print(module)
return f
@run
def testDebugEmitParamAsRemark(target):
i0 = IntegerAttr.get(IntegerType.get_signless(32), 0)
i0_param = transform.ParamConstantOp(transform.AnyParamType.get(), i0)
debug.emit_param_as_remark(i0_param)
debug.emit_param_as_remark(i0_param, anchor=target, message="some text")
# CHECK-LABEL: TEST: testDebugEmitParamAsRemark
# CHECK: ^{{.*}}(%[[ARG0:.+]]: !transform.any_op):
# CHECK: %[[PARAM:.*]] = transform.param.constant
# CHECK: transform.debug.emit_param_as_remark %[[PARAM]]
# CHECK: transform.debug.emit_param_as_remark %[[PARAM]]
# CHECK-SAME: "some text"
# CHECK-SAME: at %[[ARG0]]
@run
def testDebugEmitRemarkAtOp(target):
debug.emit_remark_at(target, "some text")
# CHECK-LABEL: TEST: testDebugEmitRemarkAtOp
# CHECK: ^{{.*}}(%[[ARG0:.+]]: !transform.any_op):
# CHECK: transform.debug.emit_remark_at %[[ARG0]], "some text"