[MLIR][Transform] expose transform.debug extension in Python (#145550)
Removes the Debug... prefix on the ops in tablegen, in line with pretty much all other Transform-dialect extension ops. This means that the ops in Python look like `debug.EmitParamAsRemarkOp`/`debug.emit_param_as_remark` instead of `debug.DebugEmitParamAsRemarkOp`/`debug.debug_emit_param_as_remark`.
This commit is contained in:
@@ -171,6 +171,15 @@ ROOT_DIR "${CMAKE_CURRENT_SOURCE_DIR}/mlir"
|
||||
DIALECT_NAME transform
|
||||
EXTENSION_NAME transform_pdl_extension)
|
||||
|
||||
declare_mlir_dialect_extension_python_bindings(
|
||||
ADD_TO_PARENT MLIRPythonSources.Dialects
|
||||
ROOT_DIR "${CMAKE_CURRENT_SOURCE_DIR}/mlir"
|
||||
TD_FILE dialects/TransformDebugExtensionOps.td
|
||||
SOURCES
|
||||
dialects/transform/debug.py
|
||||
DIALECT_NAME transform
|
||||
EXTENSION_NAME transform_debug_extension)
|
||||
|
||||
declare_mlir_dialect_python_bindings(
|
||||
ADD_TO_PARENT MLIRPythonSources.Dialects
|
||||
ROOT_DIR "${CMAKE_CURRENT_SOURCE_DIR}/mlir"
|
||||
|
||||
19
mlir/python/mlir/dialects/TransformDebugExtensionOps.td
Normal file
19
mlir/python/mlir/dialects/TransformDebugExtensionOps.td
Normal file
@@ -0,0 +1,19 @@
|
||||
//===-- TransformDebugExtensionOps.td - Binding entry point *- tablegen -*-===//
|
||||
//
|
||||
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
||||
// See https://llvm.org/LICENSE.txt for license information.
|
||||
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
||||
//
|
||||
//===----------------------------------------------------------------------===//
|
||||
//
|
||||
// Entry point of the generated Python bindings for the Debug extension of the
|
||||
// Transform dialect.
|
||||
//
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
#ifndef PYTHON_BINDINGS_TRANSFORM_DEBUG_EXTENSION_OPS
|
||||
#define PYTHON_BINDINGS_TRANSFORM_DEBUG_EXTENSION_OPS
|
||||
|
||||
include "mlir/Dialect/Transform/DebugExtension/DebugExtensionOps.td"
|
||||
|
||||
#endif // PYTHON_BINDINGS_TRANSFORM_DEBUG_EXTENSION_OPS
|
||||
81
mlir/python/mlir/dialects/transform/debug.py
Normal file
81
mlir/python/mlir/dialects/transform/debug.py
Normal file
@@ -0,0 +1,81 @@
|
||||
# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
||||
# See https://llvm.org/LICENSE.txt for license information.
|
||||
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
||||
|
||||
from typing import Optional
|
||||
|
||||
from ...ir import Attribute, Operation, Value, StringAttr
|
||||
from .._transform_debug_extension_ops_gen import *
|
||||
from .._transform_pdl_extension_ops_gen import _Dialect
|
||||
|
||||
try:
|
||||
from .._ods_common import _cext as _ods_cext
|
||||
except ImportError as e:
|
||||
raise RuntimeError("Error loading imports from extension module") from e
|
||||
|
||||
from typing import Union
|
||||
|
||||
|
||||
@_ods_cext.register_operation(_Dialect, replace=True)
|
||||
class EmitParamAsRemarkOp(EmitParamAsRemarkOp):
|
||||
def __init__(
|
||||
self,
|
||||
param: Attribute,
|
||||
*,
|
||||
anchor: Optional[Operation] = None,
|
||||
message: Optional[Union[StringAttr, str]] = None,
|
||||
loc=None,
|
||||
ip=None,
|
||||
):
|
||||
if isinstance(message, str):
|
||||
message = StringAttr.get(message)
|
||||
|
||||
super().__init__(
|
||||
param,
|
||||
anchor=anchor,
|
||||
message=message,
|
||||
loc=loc,
|
||||
ip=ip,
|
||||
)
|
||||
|
||||
|
||||
def emit_param_as_remark(
|
||||
param: Attribute,
|
||||
*,
|
||||
anchor: Optional[Operation] = None,
|
||||
message: Optional[Union[StringAttr, str]] = None,
|
||||
loc=None,
|
||||
ip=None,
|
||||
):
|
||||
return EmitParamAsRemarkOp(param, anchor=anchor, message=message, loc=loc, ip=ip)
|
||||
|
||||
|
||||
@_ods_cext.register_operation(_Dialect, replace=True)
|
||||
class EmitRemarkAtOp(EmitRemarkAtOp):
|
||||
def __init__(
|
||||
self,
|
||||
at: Union[Operation, Value],
|
||||
message: Optional[Union[StringAttr, str]] = None,
|
||||
*,
|
||||
loc=None,
|
||||
ip=None,
|
||||
):
|
||||
if isinstance(message, str):
|
||||
message = StringAttr.get(message)
|
||||
|
||||
super().__init__(
|
||||
at,
|
||||
message,
|
||||
loc=loc,
|
||||
ip=ip,
|
||||
)
|
||||
|
||||
|
||||
def emit_remark_at(
|
||||
at: Union[Operation, Value],
|
||||
message: Optional[Union[StringAttr, str]] = None,
|
||||
*,
|
||||
loc=None,
|
||||
ip=None,
|
||||
):
|
||||
return EmitRemarkAtOp(at, message, loc=loc, ip=ip)
|
||||
Reference in New Issue
Block a user