diff --git a/mlir/include/mlir/Dialect/Transform/DebugExtension/DebugExtensionOps.td b/mlir/include/mlir/Dialect/Transform/DebugExtension/DebugExtensionOps.td index 0275f241fda3..4a6898e36d34 100644 --- a/mlir/include/mlir/Dialect/Transform/DebugExtension/DebugExtensionOps.td +++ b/mlir/include/mlir/Dialect/Transform/DebugExtension/DebugExtensionOps.td @@ -20,7 +20,7 @@ include "mlir/Dialect/Transform/Interfaces/MatchInterfaces.td" include "mlir/Dialect/Transform/Interfaces/TransformInterfaces.td" include "mlir/Dialect/Transform/IR/TransformDialect.td" -def DebugEmitRemarkAtOp : TransformDialectOp<"debug.emit_remark_at", +def EmitRemarkAtOp : TransformDialectOp<"debug.emit_remark_at", [MatchOpInterface, DeclareOpInterfaceMethods, MemoryEffectsOpInterface, NavigationTransformOpTrait]> { @@ -39,7 +39,7 @@ def DebugEmitRemarkAtOp : TransformDialectOp<"debug.emit_remark_at", let assemblyFormat = "$at `,` $message attr-dict `:` type($at)"; } -def DebugEmitParamAsRemarkOp +def EmitParamAsRemarkOp : TransformDialectOp<"debug.emit_param_as_remark", [MatchOpInterface, DeclareOpInterfaceMethods, diff --git a/mlir/lib/Dialect/Transform/DebugExtension/DebugExtensionOps.cpp b/mlir/lib/Dialect/Transform/DebugExtension/DebugExtensionOps.cpp index 7a9f8f4b1b52..12257da878a4 100644 --- a/mlir/lib/Dialect/Transform/DebugExtension/DebugExtensionOps.cpp +++ b/mlir/lib/Dialect/Transform/DebugExtension/DebugExtensionOps.cpp @@ -19,9 +19,9 @@ using namespace mlir; #include "mlir/Dialect/Transform/DebugExtension/DebugExtensionOps.cpp.inc" DiagnosedSilenceableFailure -transform::DebugEmitRemarkAtOp::apply(transform::TransformRewriter &rewriter, - transform::TransformResults &results, - transform::TransformState &state) { +transform::EmitRemarkAtOp::apply(transform::TransformRewriter &rewriter, + transform::TransformResults &results, + transform::TransformState &state) { if (isa(getAt().getType())) { auto payload = state.getPayloadOps(getAt()); for (Operation *op : payload) @@ -52,9 +52,10 @@ transform::DebugEmitRemarkAtOp::apply(transform::TransformRewriter &rewriter, return DiagnosedSilenceableFailure::success(); } -DiagnosedSilenceableFailure transform::DebugEmitParamAsRemarkOp::apply( - transform::TransformRewriter &rewriter, - transform::TransformResults &results, transform::TransformState &state) { +DiagnosedSilenceableFailure +transform::EmitParamAsRemarkOp::apply(transform::TransformRewriter &rewriter, + transform::TransformResults &results, + transform::TransformState &state) { std::string str; llvm::raw_string_ostream os(str); if (getMessage()) diff --git a/mlir/python/CMakeLists.txt b/mlir/python/CMakeLists.txt index ee07081246fc..b2daabb2a595 100644 --- a/mlir/python/CMakeLists.txt +++ b/mlir/python/CMakeLists.txt @@ -171,6 +171,15 @@ ROOT_DIR "${CMAKE_CURRENT_SOURCE_DIR}/mlir" DIALECT_NAME transform EXTENSION_NAME transform_pdl_extension) +declare_mlir_dialect_extension_python_bindings( +ADD_TO_PARENT MLIRPythonSources.Dialects +ROOT_DIR "${CMAKE_CURRENT_SOURCE_DIR}/mlir" + TD_FILE dialects/TransformDebugExtensionOps.td + SOURCES + dialects/transform/debug.py + DIALECT_NAME transform + EXTENSION_NAME transform_debug_extension) + declare_mlir_dialect_python_bindings( ADD_TO_PARENT MLIRPythonSources.Dialects ROOT_DIR "${CMAKE_CURRENT_SOURCE_DIR}/mlir" diff --git a/mlir/python/mlir/dialects/TransformDebugExtensionOps.td b/mlir/python/mlir/dialects/TransformDebugExtensionOps.td new file mode 100644 index 000000000000..22a85d236699 --- /dev/null +++ b/mlir/python/mlir/dialects/TransformDebugExtensionOps.td @@ -0,0 +1,19 @@ +//===-- TransformDebugExtensionOps.td - Binding entry point *- tablegen -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// Entry point of the generated Python bindings for the Debug extension of the +// Transform dialect. +// +//===----------------------------------------------------------------------===// + +#ifndef PYTHON_BINDINGS_TRANSFORM_DEBUG_EXTENSION_OPS +#define PYTHON_BINDINGS_TRANSFORM_DEBUG_EXTENSION_OPS + +include "mlir/Dialect/Transform/DebugExtension/DebugExtensionOps.td" + +#endif // PYTHON_BINDINGS_TRANSFORM_DEBUG_EXTENSION_OPS diff --git a/mlir/python/mlir/dialects/transform/debug.py b/mlir/python/mlir/dialects/transform/debug.py new file mode 100644 index 000000000000..f7c04268dc03 --- /dev/null +++ b/mlir/python/mlir/dialects/transform/debug.py @@ -0,0 +1,81 @@ +# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +# See https://llvm.org/LICENSE.txt for license information. +# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + +from typing import Optional + +from ...ir import Attribute, Operation, Value, StringAttr +from .._transform_debug_extension_ops_gen import * +from .._transform_pdl_extension_ops_gen import _Dialect + +try: + from .._ods_common import _cext as _ods_cext +except ImportError as e: + raise RuntimeError("Error loading imports from extension module") from e + +from typing import Union + + +@_ods_cext.register_operation(_Dialect, replace=True) +class EmitParamAsRemarkOp(EmitParamAsRemarkOp): + def __init__( + self, + param: Attribute, + *, + anchor: Optional[Operation] = None, + message: Optional[Union[StringAttr, str]] = None, + loc=None, + ip=None, + ): + if isinstance(message, str): + message = StringAttr.get(message) + + super().__init__( + param, + anchor=anchor, + message=message, + loc=loc, + ip=ip, + ) + + +def emit_param_as_remark( + param: Attribute, + *, + anchor: Optional[Operation] = None, + message: Optional[Union[StringAttr, str]] = None, + loc=None, + ip=None, +): + return EmitParamAsRemarkOp(param, anchor=anchor, message=message, loc=loc, ip=ip) + + +@_ods_cext.register_operation(_Dialect, replace=True) +class EmitRemarkAtOp(EmitRemarkAtOp): + def __init__( + self, + at: Union[Operation, Value], + message: Optional[Union[StringAttr, str]] = None, + *, + loc=None, + ip=None, + ): + if isinstance(message, str): + message = StringAttr.get(message) + + super().__init__( + at, + message, + loc=loc, + ip=ip, + ) + + +def emit_remark_at( + at: Union[Operation, Value], + message: Optional[Union[StringAttr, str]] = None, + *, + loc=None, + ip=None, +): + return EmitRemarkAtOp(at, message, loc=loc, ip=ip) diff --git a/mlir/test/python/dialects/transform_debug_ext.py b/mlir/test/python/dialects/transform_debug_ext.py new file mode 100644 index 000000000000..2dfdaed34386 --- /dev/null +++ b/mlir/test/python/dialects/transform_debug_ext.py @@ -0,0 +1,45 @@ +# RUN: %PYTHON %s | FileCheck %s + +from mlir.ir import * +from mlir.dialects import transform +from mlir.dialects.transform import debug + + +def run(f): + print("\nTEST:", f.__name__) + with Context(), Location.unknown(): + module = Module.create() + with InsertionPoint(module.body): + sequence = transform.SequenceOp( + transform.FailurePropagationMode.Propagate, + [], + transform.AnyOpType.get(), + ) + with InsertionPoint(sequence.body): + f(sequence.bodyTarget) + transform.YieldOp() + print(module) + return f + + +@run +def testDebugEmitParamAsRemark(target): + i0 = IntegerAttr.get(IntegerType.get_signless(32), 0) + i0_param = transform.ParamConstantOp(transform.AnyParamType.get(), i0) + debug.emit_param_as_remark(i0_param) + debug.emit_param_as_remark(i0_param, anchor=target, message="some text") + # CHECK-LABEL: TEST: testDebugEmitParamAsRemark + # CHECK: ^{{.*}}(%[[ARG0:.+]]: !transform.any_op): + # CHECK: %[[PARAM:.*]] = transform.param.constant + # CHECK: transform.debug.emit_param_as_remark %[[PARAM]] + # CHECK: transform.debug.emit_param_as_remark %[[PARAM]] + # CHECK-SAME: "some text" + # CHECK-SAME: at %[[ARG0]] + + +@run +def testDebugEmitRemarkAtOp(target): + debug.emit_remark_at(target, "some text") + # CHECK-LABEL: TEST: testDebugEmitRemarkAtOp + # CHECK: ^{{.*}}(%[[ARG0:.+]]: !transform.any_op): + # CHECK: transform.debug.emit_remark_at %[[ARG0]], "some text"