[MLIR][Transform] expose transform.debug extension in Python (#145550)
Removes the Debug... prefix on the ops in tablegen, in line with pretty much all other Transform-dialect extension ops. This means that the ops in Python look like `debug.EmitParamAsRemarkOp`/`debug.emit_param_as_remark` instead of `debug.DebugEmitParamAsRemarkOp`/`debug.debug_emit_param_as_remark`.
This commit is contained in:
@@ -20,7 +20,7 @@ include "mlir/Dialect/Transform/Interfaces/MatchInterfaces.td"
|
|||||||
include "mlir/Dialect/Transform/Interfaces/TransformInterfaces.td"
|
include "mlir/Dialect/Transform/Interfaces/TransformInterfaces.td"
|
||||||
include "mlir/Dialect/Transform/IR/TransformDialect.td"
|
include "mlir/Dialect/Transform/IR/TransformDialect.td"
|
||||||
|
|
||||||
def DebugEmitRemarkAtOp : TransformDialectOp<"debug.emit_remark_at",
|
def EmitRemarkAtOp : TransformDialectOp<"debug.emit_remark_at",
|
||||||
[MatchOpInterface,
|
[MatchOpInterface,
|
||||||
DeclareOpInterfaceMethods<TransformOpInterface>,
|
DeclareOpInterfaceMethods<TransformOpInterface>,
|
||||||
MemoryEffectsOpInterface, NavigationTransformOpTrait]> {
|
MemoryEffectsOpInterface, NavigationTransformOpTrait]> {
|
||||||
@@ -39,7 +39,7 @@ def DebugEmitRemarkAtOp : TransformDialectOp<"debug.emit_remark_at",
|
|||||||
let assemblyFormat = "$at `,` $message attr-dict `:` type($at)";
|
let assemblyFormat = "$at `,` $message attr-dict `:` type($at)";
|
||||||
}
|
}
|
||||||
|
|
||||||
def DebugEmitParamAsRemarkOp
|
def EmitParamAsRemarkOp
|
||||||
: TransformDialectOp<"debug.emit_param_as_remark",
|
: TransformDialectOp<"debug.emit_param_as_remark",
|
||||||
[MatchOpInterface,
|
[MatchOpInterface,
|
||||||
DeclareOpInterfaceMethods<TransformOpInterface>,
|
DeclareOpInterfaceMethods<TransformOpInterface>,
|
||||||
|
|||||||
@@ -19,9 +19,9 @@ using namespace mlir;
|
|||||||
#include "mlir/Dialect/Transform/DebugExtension/DebugExtensionOps.cpp.inc"
|
#include "mlir/Dialect/Transform/DebugExtension/DebugExtensionOps.cpp.inc"
|
||||||
|
|
||||||
DiagnosedSilenceableFailure
|
DiagnosedSilenceableFailure
|
||||||
transform::DebugEmitRemarkAtOp::apply(transform::TransformRewriter &rewriter,
|
transform::EmitRemarkAtOp::apply(transform::TransformRewriter &rewriter,
|
||||||
transform::TransformResults &results,
|
transform::TransformResults &results,
|
||||||
transform::TransformState &state) {
|
transform::TransformState &state) {
|
||||||
if (isa<TransformHandleTypeInterface>(getAt().getType())) {
|
if (isa<TransformHandleTypeInterface>(getAt().getType())) {
|
||||||
auto payload = state.getPayloadOps(getAt());
|
auto payload = state.getPayloadOps(getAt());
|
||||||
for (Operation *op : payload)
|
for (Operation *op : payload)
|
||||||
@@ -52,9 +52,10 @@ transform::DebugEmitRemarkAtOp::apply(transform::TransformRewriter &rewriter,
|
|||||||
return DiagnosedSilenceableFailure::success();
|
return DiagnosedSilenceableFailure::success();
|
||||||
}
|
}
|
||||||
|
|
||||||
DiagnosedSilenceableFailure transform::DebugEmitParamAsRemarkOp::apply(
|
DiagnosedSilenceableFailure
|
||||||
transform::TransformRewriter &rewriter,
|
transform::EmitParamAsRemarkOp::apply(transform::TransformRewriter &rewriter,
|
||||||
transform::TransformResults &results, transform::TransformState &state) {
|
transform::TransformResults &results,
|
||||||
|
transform::TransformState &state) {
|
||||||
std::string str;
|
std::string str;
|
||||||
llvm::raw_string_ostream os(str);
|
llvm::raw_string_ostream os(str);
|
||||||
if (getMessage())
|
if (getMessage())
|
||||||
|
|||||||
@@ -171,6 +171,15 @@ ROOT_DIR "${CMAKE_CURRENT_SOURCE_DIR}/mlir"
|
|||||||
DIALECT_NAME transform
|
DIALECT_NAME transform
|
||||||
EXTENSION_NAME transform_pdl_extension)
|
EXTENSION_NAME transform_pdl_extension)
|
||||||
|
|
||||||
|
declare_mlir_dialect_extension_python_bindings(
|
||||||
|
ADD_TO_PARENT MLIRPythonSources.Dialects
|
||||||
|
ROOT_DIR "${CMAKE_CURRENT_SOURCE_DIR}/mlir"
|
||||||
|
TD_FILE dialects/TransformDebugExtensionOps.td
|
||||||
|
SOURCES
|
||||||
|
dialects/transform/debug.py
|
||||||
|
DIALECT_NAME transform
|
||||||
|
EXTENSION_NAME transform_debug_extension)
|
||||||
|
|
||||||
declare_mlir_dialect_python_bindings(
|
declare_mlir_dialect_python_bindings(
|
||||||
ADD_TO_PARENT MLIRPythonSources.Dialects
|
ADD_TO_PARENT MLIRPythonSources.Dialects
|
||||||
ROOT_DIR "${CMAKE_CURRENT_SOURCE_DIR}/mlir"
|
ROOT_DIR "${CMAKE_CURRENT_SOURCE_DIR}/mlir"
|
||||||
|
|||||||
19
mlir/python/mlir/dialects/TransformDebugExtensionOps.td
Normal file
19
mlir/python/mlir/dialects/TransformDebugExtensionOps.td
Normal file
@@ -0,0 +1,19 @@
|
|||||||
|
//===-- TransformDebugExtensionOps.td - Binding entry point *- tablegen -*-===//
|
||||||
|
//
|
||||||
|
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
||||||
|
// See https://llvm.org/LICENSE.txt for license information.
|
||||||
|
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
||||||
|
//
|
||||||
|
//===----------------------------------------------------------------------===//
|
||||||
|
//
|
||||||
|
// Entry point of the generated Python bindings for the Debug extension of the
|
||||||
|
// Transform dialect.
|
||||||
|
//
|
||||||
|
//===----------------------------------------------------------------------===//
|
||||||
|
|
||||||
|
#ifndef PYTHON_BINDINGS_TRANSFORM_DEBUG_EXTENSION_OPS
|
||||||
|
#define PYTHON_BINDINGS_TRANSFORM_DEBUG_EXTENSION_OPS
|
||||||
|
|
||||||
|
include "mlir/Dialect/Transform/DebugExtension/DebugExtensionOps.td"
|
||||||
|
|
||||||
|
#endif // PYTHON_BINDINGS_TRANSFORM_DEBUG_EXTENSION_OPS
|
||||||
81
mlir/python/mlir/dialects/transform/debug.py
Normal file
81
mlir/python/mlir/dialects/transform/debug.py
Normal file
@@ -0,0 +1,81 @@
|
|||||||
|
# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
||||||
|
# See https://llvm.org/LICENSE.txt for license information.
|
||||||
|
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
||||||
|
|
||||||
|
from typing import Optional
|
||||||
|
|
||||||
|
from ...ir import Attribute, Operation, Value, StringAttr
|
||||||
|
from .._transform_debug_extension_ops_gen import *
|
||||||
|
from .._transform_pdl_extension_ops_gen import _Dialect
|
||||||
|
|
||||||
|
try:
|
||||||
|
from .._ods_common import _cext as _ods_cext
|
||||||
|
except ImportError as e:
|
||||||
|
raise RuntimeError("Error loading imports from extension module") from e
|
||||||
|
|
||||||
|
from typing import Union
|
||||||
|
|
||||||
|
|
||||||
|
@_ods_cext.register_operation(_Dialect, replace=True)
|
||||||
|
class EmitParamAsRemarkOp(EmitParamAsRemarkOp):
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
param: Attribute,
|
||||||
|
*,
|
||||||
|
anchor: Optional[Operation] = None,
|
||||||
|
message: Optional[Union[StringAttr, str]] = None,
|
||||||
|
loc=None,
|
||||||
|
ip=None,
|
||||||
|
):
|
||||||
|
if isinstance(message, str):
|
||||||
|
message = StringAttr.get(message)
|
||||||
|
|
||||||
|
super().__init__(
|
||||||
|
param,
|
||||||
|
anchor=anchor,
|
||||||
|
message=message,
|
||||||
|
loc=loc,
|
||||||
|
ip=ip,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def emit_param_as_remark(
|
||||||
|
param: Attribute,
|
||||||
|
*,
|
||||||
|
anchor: Optional[Operation] = None,
|
||||||
|
message: Optional[Union[StringAttr, str]] = None,
|
||||||
|
loc=None,
|
||||||
|
ip=None,
|
||||||
|
):
|
||||||
|
return EmitParamAsRemarkOp(param, anchor=anchor, message=message, loc=loc, ip=ip)
|
||||||
|
|
||||||
|
|
||||||
|
@_ods_cext.register_operation(_Dialect, replace=True)
|
||||||
|
class EmitRemarkAtOp(EmitRemarkAtOp):
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
at: Union[Operation, Value],
|
||||||
|
message: Optional[Union[StringAttr, str]] = None,
|
||||||
|
*,
|
||||||
|
loc=None,
|
||||||
|
ip=None,
|
||||||
|
):
|
||||||
|
if isinstance(message, str):
|
||||||
|
message = StringAttr.get(message)
|
||||||
|
|
||||||
|
super().__init__(
|
||||||
|
at,
|
||||||
|
message,
|
||||||
|
loc=loc,
|
||||||
|
ip=ip,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def emit_remark_at(
|
||||||
|
at: Union[Operation, Value],
|
||||||
|
message: Optional[Union[StringAttr, str]] = None,
|
||||||
|
*,
|
||||||
|
loc=None,
|
||||||
|
ip=None,
|
||||||
|
):
|
||||||
|
return EmitRemarkAtOp(at, message, loc=loc, ip=ip)
|
||||||
45
mlir/test/python/dialects/transform_debug_ext.py
Normal file
45
mlir/test/python/dialects/transform_debug_ext.py
Normal file
@@ -0,0 +1,45 @@
|
|||||||
|
# RUN: %PYTHON %s | FileCheck %s
|
||||||
|
|
||||||
|
from mlir.ir import *
|
||||||
|
from mlir.dialects import transform
|
||||||
|
from mlir.dialects.transform import debug
|
||||||
|
|
||||||
|
|
||||||
|
def run(f):
|
||||||
|
print("\nTEST:", f.__name__)
|
||||||
|
with Context(), Location.unknown():
|
||||||
|
module = Module.create()
|
||||||
|
with InsertionPoint(module.body):
|
||||||
|
sequence = transform.SequenceOp(
|
||||||
|
transform.FailurePropagationMode.Propagate,
|
||||||
|
[],
|
||||||
|
transform.AnyOpType.get(),
|
||||||
|
)
|
||||||
|
with InsertionPoint(sequence.body):
|
||||||
|
f(sequence.bodyTarget)
|
||||||
|
transform.YieldOp()
|
||||||
|
print(module)
|
||||||
|
return f
|
||||||
|
|
||||||
|
|
||||||
|
@run
|
||||||
|
def testDebugEmitParamAsRemark(target):
|
||||||
|
i0 = IntegerAttr.get(IntegerType.get_signless(32), 0)
|
||||||
|
i0_param = transform.ParamConstantOp(transform.AnyParamType.get(), i0)
|
||||||
|
debug.emit_param_as_remark(i0_param)
|
||||||
|
debug.emit_param_as_remark(i0_param, anchor=target, message="some text")
|
||||||
|
# CHECK-LABEL: TEST: testDebugEmitParamAsRemark
|
||||||
|
# CHECK: ^{{.*}}(%[[ARG0:.+]]: !transform.any_op):
|
||||||
|
# CHECK: %[[PARAM:.*]] = transform.param.constant
|
||||||
|
# CHECK: transform.debug.emit_param_as_remark %[[PARAM]]
|
||||||
|
# CHECK: transform.debug.emit_param_as_remark %[[PARAM]]
|
||||||
|
# CHECK-SAME: "some text"
|
||||||
|
# CHECK-SAME: at %[[ARG0]]
|
||||||
|
|
||||||
|
|
||||||
|
@run
|
||||||
|
def testDebugEmitRemarkAtOp(target):
|
||||||
|
debug.emit_remark_at(target, "some text")
|
||||||
|
# CHECK-LABEL: TEST: testDebugEmitRemarkAtOp
|
||||||
|
# CHECK: ^{{.*}}(%[[ARG0:.+]]: !transform.any_op):
|
||||||
|
# CHECK: transform.debug.emit_remark_at %[[ARG0]], "some text"
|
||||||
Reference in New Issue
Block a user