[mlir][tensor][transform][python] Add mix-in class.
This patch adds a mix-in class for the only transform op of the tensor dialect that can benefit from one: the MakeLoopIndependentOp. It adds an overload that makes providing the return type optional. Reviewed By: ftynse Differential Revision: https://reviews.llvm.org/D156918
This commit is contained in:
@@ -207,6 +207,7 @@ declare_mlir_dialect_extension_python_bindings(
|
||||
ROOT_DIR "${CMAKE_CURRENT_SOURCE_DIR}/mlir"
|
||||
TD_FILE dialects/TensorTransformOps.td
|
||||
SOURCES
|
||||
dialects/_tensor_transform_ops_ext.py
|
||||
dialects/transform/tensor.py
|
||||
DIALECT_NAME transform
|
||||
EXTENSION_NAME tensor_transform)
|
||||
|
||||
64
mlir/python/mlir/dialects/_tensor_transform_ops_ext.py
Normal file
64
mlir/python/mlir/dialects/_tensor_transform_ops_ext.py
Normal file
@@ -0,0 +1,64 @@
|
||||
# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
||||
# See https://llvm.org/LICENSE.txt for license information.
|
||||
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
||||
|
||||
try:
|
||||
from ..ir import *
|
||||
from ..dialects import transform
|
||||
except ImportError as e:
|
||||
raise RuntimeError("Error loading imports from extension module") from e
|
||||
|
||||
from typing import Optional, overload, Union
|
||||
|
||||
|
||||
class MakeLoopIndependentOp:
|
||||
"""Specialization for MakeLoopIndependentOp class."""
|
||||
|
||||
@overload
|
||||
def __init__(
|
||||
self,
|
||||
transformed_type: Type,
|
||||
target: Union[Operation, OpView, Value],
|
||||
num_loops: Union[int, IntegerAttr],
|
||||
*,
|
||||
loc=None,
|
||||
ip=None
|
||||
):
|
||||
...
|
||||
|
||||
@overload
|
||||
def __init__(
|
||||
self,
|
||||
target: Union[Operation, OpView, Value],
|
||||
num_loops: Union[int, IntegerAttr],
|
||||
*,
|
||||
loc=None,
|
||||
ip=None
|
||||
):
|
||||
...
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
transformed_type_or_target: Type,
|
||||
target_or_num_loops: Union[int, IntegerAttr, Operation, OpView, Value] = None,
|
||||
num_loops_or_none: Optional[Union[int, IntegerAttr]] = None,
|
||||
*,
|
||||
loc=None,
|
||||
ip=None
|
||||
):
|
||||
if isinstance(transformed_type_or_target, Type):
|
||||
transformed_type = transformed_type_or_target
|
||||
target = target_or_num_loops
|
||||
num_loops = num_loops_or_none
|
||||
else:
|
||||
transformed_type = transform.AnyOpType.get()
|
||||
target = transformed_type_or_target
|
||||
num_loops = target_or_num_loops
|
||||
|
||||
super().__init__(
|
||||
transformed_type,
|
||||
target,
|
||||
num_loops,
|
||||
loc=loc,
|
||||
ip=ip,
|
||||
)
|
||||
40
mlir/test/python/dialects/transform_tensor_ext.py
Normal file
40
mlir/test/python/dialects/transform_tensor_ext.py
Normal file
@@ -0,0 +1,40 @@
|
||||
# RUN: %PYTHON %s | FileCheck %s
|
||||
|
||||
from mlir.ir import *
|
||||
from mlir.dialects import transform
|
||||
from mlir.dialects.transform import tensor
|
||||
|
||||
|
||||
def run(f):
|
||||
print("\nTEST:", f.__name__)
|
||||
with Context(), Location.unknown():
|
||||
module = Module.create()
|
||||
with InsertionPoint(module.body):
|
||||
sequence = transform.SequenceOp(
|
||||
transform.FailurePropagationMode.PROPAGATE,
|
||||
[],
|
||||
transform.AnyOpType.get(),
|
||||
)
|
||||
with InsertionPoint(sequence.body):
|
||||
f(sequence.bodyTarget)
|
||||
transform.YieldOp()
|
||||
print(module)
|
||||
return f
|
||||
|
||||
|
||||
@run
|
||||
def testMakeLoopIndependentOpCompact(target):
|
||||
tensor.MakeLoopIndependentOp(target, 4)
|
||||
# CHECK-LABEL: TEST: testMakeLoopIndependentOpCompact
|
||||
# CHECK: = transform.tensor.make_loop_independent
|
||||
# CHECK-SAME: num_loops = 4 : i64
|
||||
# CHECK-SAME: (!transform.any_op) -> !transform.any_op
|
||||
|
||||
|
||||
@run
|
||||
def testMakeLoopIndependentOpTyped(target):
|
||||
tensor.MakeLoopIndependentOp(transform.OperationType.get("test.dummy"), target, 4)
|
||||
# CHECK-LABEL: TEST: testMakeLoopIndependentOpTyped
|
||||
# CHECK: = transform.tensor.make_loop_independent
|
||||
# CHECK-SAME: num_loops = 4 : i64
|
||||
# CHECK-SAME: (!transform.any_op) -> !transform.op<"test.dummy">
|
||||
@@ -975,6 +975,7 @@ filegroup(
|
||||
"mlir/dialects/_loop_transform_ops_ext.py",
|
||||
"mlir/dialects/_memref_transform_ops_ext.py",
|
||||
"mlir/dialects/_structured_transform_ops_ext.py",
|
||||
"mlir/dialects/_tensor_ops_ext.py",
|
||||
"mlir/dialects/_transform_ops_ext.py",
|
||||
"mlir/dialects/_transform_pdl_extension_ops_ext.py",
|
||||
":BufferizationTransformOpsPyGen",
|
||||
|
||||
Reference in New Issue
Block a user