Following a rather direct approach to expose PDL usage from C and then Python. This doesn't yes plumb through adding support for custom matchers through this interface, so constrained to basics initially. This also exposes greedy rewrite driver. Only way currently to define patterns is via PDL (just to keep small). The creation of the PDL pattern module could be improved to avoid folks potentially accessing the module used to construct it post construction. No ergonomic work done yet. --------- Signed-off-by: Jacques Pienaar <jpienaar@google.com>
68 lines
2.3 KiB
Python
68 lines
2.3 KiB
Python
# RUN: %PYTHON %s 2>&1 | FileCheck %s
|
|
|
|
from mlir.dialects import arith, func, pdl
|
|
from mlir.dialects.builtin import module
|
|
from mlir.ir import *
|
|
from mlir.rewrite import *
|
|
|
|
|
|
def construct_and_print_in_module(f):
|
|
print("\nTEST:", f.__name__)
|
|
with Context(), Location.unknown():
|
|
module = Module.create()
|
|
with InsertionPoint(module.body):
|
|
module = f(module)
|
|
if module is not None:
|
|
print(module)
|
|
return f
|
|
|
|
|
|
# CHECK-LABEL: TEST: test_add_to_mul
|
|
# CHECK: arith.muli
|
|
@construct_and_print_in_module
|
|
def test_add_to_mul(module_):
|
|
index_type = IndexType.get()
|
|
|
|
# Create a test case.
|
|
@module(sym_name="ir")
|
|
def ir():
|
|
@func.func(index_type, index_type)
|
|
def add_func(a, b):
|
|
return arith.addi(a, b)
|
|
|
|
# Create a rewrite from add to mul. This will match
|
|
# - operation name is arith.addi
|
|
# - operands are index types.
|
|
# - there are two operands.
|
|
with Location.unknown():
|
|
m = Module.create()
|
|
with InsertionPoint(m.body):
|
|
# Change all arith.addi with index types to arith.muli.
|
|
@pdl.pattern(benefit=1, sym_name="addi_to_mul")
|
|
def pat():
|
|
# Match arith.addi with index types.
|
|
index_type = pdl.TypeOp(IndexType.get())
|
|
operand0 = pdl.OperandOp(index_type)
|
|
operand1 = pdl.OperandOp(index_type)
|
|
op0 = pdl.OperationOp(
|
|
name="arith.addi", args=[operand0, operand1], types=[index_type]
|
|
)
|
|
|
|
# Replace the matched op with arith.muli.
|
|
@pdl.rewrite()
|
|
def rew():
|
|
newOp = pdl.OperationOp(
|
|
name="arith.muli", args=[operand0, operand1], types=[index_type]
|
|
)
|
|
pdl.ReplaceOp(op0, with_op=newOp)
|
|
|
|
# Create a PDL module from module and freeze it. At this point the ownership
|
|
# of the module is transferred to the PDL module. This ownership transfer is
|
|
# not yet captured Python side/has sharp edges. So best to construct the
|
|
# module and PDL module in same scope.
|
|
# FIXME: This should be made more robust.
|
|
frozen = PDLModule(m).freeze()
|
|
# Could apply frozen pattern set multiple times.
|
|
apply_patterns_and_fold_greedily(module_, frozen)
|
|
return module_
|