[MLIR][Transform] expose transform.debug extension in Python (#145550)
Removes the Debug... prefix on the ops in tablegen, in line with pretty much all other Transform-dialect extension ops. This means that the ops in Python look like `debug.EmitParamAsRemarkOp`/`debug.emit_param_as_remark` instead of `debug.DebugEmitParamAsRemarkOp`/`debug.debug_emit_param_as_remark`.
This commit is contained in:
@@ -20,7 +20,7 @@ include "mlir/Dialect/Transform/Interfaces/MatchInterfaces.td"
|
||||
include "mlir/Dialect/Transform/Interfaces/TransformInterfaces.td"
|
||||
include "mlir/Dialect/Transform/IR/TransformDialect.td"
|
||||
|
||||
def DebugEmitRemarkAtOp : TransformDialectOp<"debug.emit_remark_at",
|
||||
def EmitRemarkAtOp : TransformDialectOp<"debug.emit_remark_at",
|
||||
[MatchOpInterface,
|
||||
DeclareOpInterfaceMethods<TransformOpInterface>,
|
||||
MemoryEffectsOpInterface, NavigationTransformOpTrait]> {
|
||||
@@ -39,7 +39,7 @@ def DebugEmitRemarkAtOp : TransformDialectOp<"debug.emit_remark_at",
|
||||
let assemblyFormat = "$at `,` $message attr-dict `:` type($at)";
|
||||
}
|
||||
|
||||
def DebugEmitParamAsRemarkOp
|
||||
def EmitParamAsRemarkOp
|
||||
: TransformDialectOp<"debug.emit_param_as_remark",
|
||||
[MatchOpInterface,
|
||||
DeclareOpInterfaceMethods<TransformOpInterface>,
|
||||
|
||||
@@ -19,9 +19,9 @@ using namespace mlir;
|
||||
#include "mlir/Dialect/Transform/DebugExtension/DebugExtensionOps.cpp.inc"
|
||||
|
||||
DiagnosedSilenceableFailure
|
||||
transform::DebugEmitRemarkAtOp::apply(transform::TransformRewriter &rewriter,
|
||||
transform::TransformResults &results,
|
||||
transform::TransformState &state) {
|
||||
transform::EmitRemarkAtOp::apply(transform::TransformRewriter &rewriter,
|
||||
transform::TransformResults &results,
|
||||
transform::TransformState &state) {
|
||||
if (isa<TransformHandleTypeInterface>(getAt().getType())) {
|
||||
auto payload = state.getPayloadOps(getAt());
|
||||
for (Operation *op : payload)
|
||||
@@ -52,9 +52,10 @@ transform::DebugEmitRemarkAtOp::apply(transform::TransformRewriter &rewriter,
|
||||
return DiagnosedSilenceableFailure::success();
|
||||
}
|
||||
|
||||
DiagnosedSilenceableFailure transform::DebugEmitParamAsRemarkOp::apply(
|
||||
transform::TransformRewriter &rewriter,
|
||||
transform::TransformResults &results, transform::TransformState &state) {
|
||||
DiagnosedSilenceableFailure
|
||||
transform::EmitParamAsRemarkOp::apply(transform::TransformRewriter &rewriter,
|
||||
transform::TransformResults &results,
|
||||
transform::TransformState &state) {
|
||||
std::string str;
|
||||
llvm::raw_string_ostream os(str);
|
||||
if (getMessage())
|
||||
|
||||
@@ -171,6 +171,15 @@ ROOT_DIR "${CMAKE_CURRENT_SOURCE_DIR}/mlir"
|
||||
DIALECT_NAME transform
|
||||
EXTENSION_NAME transform_pdl_extension)
|
||||
|
||||
declare_mlir_dialect_extension_python_bindings(
|
||||
ADD_TO_PARENT MLIRPythonSources.Dialects
|
||||
ROOT_DIR "${CMAKE_CURRENT_SOURCE_DIR}/mlir"
|
||||
TD_FILE dialects/TransformDebugExtensionOps.td
|
||||
SOURCES
|
||||
dialects/transform/debug.py
|
||||
DIALECT_NAME transform
|
||||
EXTENSION_NAME transform_debug_extension)
|
||||
|
||||
declare_mlir_dialect_python_bindings(
|
||||
ADD_TO_PARENT MLIRPythonSources.Dialects
|
||||
ROOT_DIR "${CMAKE_CURRENT_SOURCE_DIR}/mlir"
|
||||
|
||||
19
mlir/python/mlir/dialects/TransformDebugExtensionOps.td
Normal file
19
mlir/python/mlir/dialects/TransformDebugExtensionOps.td
Normal file
@@ -0,0 +1,19 @@
|
||||
//===-- TransformDebugExtensionOps.td - Binding entry point *- tablegen -*-===//
|
||||
//
|
||||
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
||||
// See https://llvm.org/LICENSE.txt for license information.
|
||||
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
||||
//
|
||||
//===----------------------------------------------------------------------===//
|
||||
//
|
||||
// Entry point of the generated Python bindings for the Debug extension of the
|
||||
// Transform dialect.
|
||||
//
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
#ifndef PYTHON_BINDINGS_TRANSFORM_DEBUG_EXTENSION_OPS
|
||||
#define PYTHON_BINDINGS_TRANSFORM_DEBUG_EXTENSION_OPS
|
||||
|
||||
include "mlir/Dialect/Transform/DebugExtension/DebugExtensionOps.td"
|
||||
|
||||
#endif // PYTHON_BINDINGS_TRANSFORM_DEBUG_EXTENSION_OPS
|
||||
81
mlir/python/mlir/dialects/transform/debug.py
Normal file
81
mlir/python/mlir/dialects/transform/debug.py
Normal file
@@ -0,0 +1,81 @@
|
||||
# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
||||
# See https://llvm.org/LICENSE.txt for license information.
|
||||
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
||||
|
||||
from typing import Optional
|
||||
|
||||
from ...ir import Attribute, Operation, Value, StringAttr
|
||||
from .._transform_debug_extension_ops_gen import *
|
||||
from .._transform_pdl_extension_ops_gen import _Dialect
|
||||
|
||||
try:
|
||||
from .._ods_common import _cext as _ods_cext
|
||||
except ImportError as e:
|
||||
raise RuntimeError("Error loading imports from extension module") from e
|
||||
|
||||
from typing import Union
|
||||
|
||||
|
||||
@_ods_cext.register_operation(_Dialect, replace=True)
|
||||
class EmitParamAsRemarkOp(EmitParamAsRemarkOp):
|
||||
def __init__(
|
||||
self,
|
||||
param: Attribute,
|
||||
*,
|
||||
anchor: Optional[Operation] = None,
|
||||
message: Optional[Union[StringAttr, str]] = None,
|
||||
loc=None,
|
||||
ip=None,
|
||||
):
|
||||
if isinstance(message, str):
|
||||
message = StringAttr.get(message)
|
||||
|
||||
super().__init__(
|
||||
param,
|
||||
anchor=anchor,
|
||||
message=message,
|
||||
loc=loc,
|
||||
ip=ip,
|
||||
)
|
||||
|
||||
|
||||
def emit_param_as_remark(
|
||||
param: Attribute,
|
||||
*,
|
||||
anchor: Optional[Operation] = None,
|
||||
message: Optional[Union[StringAttr, str]] = None,
|
||||
loc=None,
|
||||
ip=None,
|
||||
):
|
||||
return EmitParamAsRemarkOp(param, anchor=anchor, message=message, loc=loc, ip=ip)
|
||||
|
||||
|
||||
@_ods_cext.register_operation(_Dialect, replace=True)
|
||||
class EmitRemarkAtOp(EmitRemarkAtOp):
|
||||
def __init__(
|
||||
self,
|
||||
at: Union[Operation, Value],
|
||||
message: Optional[Union[StringAttr, str]] = None,
|
||||
*,
|
||||
loc=None,
|
||||
ip=None,
|
||||
):
|
||||
if isinstance(message, str):
|
||||
message = StringAttr.get(message)
|
||||
|
||||
super().__init__(
|
||||
at,
|
||||
message,
|
||||
loc=loc,
|
||||
ip=ip,
|
||||
)
|
||||
|
||||
|
||||
def emit_remark_at(
|
||||
at: Union[Operation, Value],
|
||||
message: Optional[Union[StringAttr, str]] = None,
|
||||
*,
|
||||
loc=None,
|
||||
ip=None,
|
||||
):
|
||||
return EmitRemarkAtOp(at, message, loc=loc, ip=ip)
|
||||
45
mlir/test/python/dialects/transform_debug_ext.py
Normal file
45
mlir/test/python/dialects/transform_debug_ext.py
Normal file
@@ -0,0 +1,45 @@
|
||||
# RUN: %PYTHON %s | FileCheck %s
|
||||
|
||||
from mlir.ir import *
|
||||
from mlir.dialects import transform
|
||||
from mlir.dialects.transform import debug
|
||||
|
||||
|
||||
def run(f):
|
||||
print("\nTEST:", f.__name__)
|
||||
with Context(), Location.unknown():
|
||||
module = Module.create()
|
||||
with InsertionPoint(module.body):
|
||||
sequence = transform.SequenceOp(
|
||||
transform.FailurePropagationMode.Propagate,
|
||||
[],
|
||||
transform.AnyOpType.get(),
|
||||
)
|
||||
with InsertionPoint(sequence.body):
|
||||
f(sequence.bodyTarget)
|
||||
transform.YieldOp()
|
||||
print(module)
|
||||
return f
|
||||
|
||||
|
||||
@run
|
||||
def testDebugEmitParamAsRemark(target):
|
||||
i0 = IntegerAttr.get(IntegerType.get_signless(32), 0)
|
||||
i0_param = transform.ParamConstantOp(transform.AnyParamType.get(), i0)
|
||||
debug.emit_param_as_remark(i0_param)
|
||||
debug.emit_param_as_remark(i0_param, anchor=target, message="some text")
|
||||
# CHECK-LABEL: TEST: testDebugEmitParamAsRemark
|
||||
# CHECK: ^{{.*}}(%[[ARG0:.+]]: !transform.any_op):
|
||||
# CHECK: %[[PARAM:.*]] = transform.param.constant
|
||||
# CHECK: transform.debug.emit_param_as_remark %[[PARAM]]
|
||||
# CHECK: transform.debug.emit_param_as_remark %[[PARAM]]
|
||||
# CHECK-SAME: "some text"
|
||||
# CHECK-SAME: at %[[ARG0]]
|
||||
|
||||
|
||||
@run
|
||||
def testDebugEmitRemarkAtOp(target):
|
||||
debug.emit_remark_at(target, "some text")
|
||||
# CHECK-LABEL: TEST: testDebugEmitRemarkAtOp
|
||||
# CHECK: ^{{.*}}(%[[ARG0:.+]]: !transform.any_op):
|
||||
# CHECK: transform.debug.emit_remark_at %[[ARG0]], "some text"
|
||||
Reference in New Issue
Block a user