Files
clang-p2996/mlir/test/python/pass_manager.py
Maksim Levental bdc3e6cb45 [MLIR][python bindings] invalidate ops after PassManager run (#69746)
Fixes https://github.com/llvm/llvm-project/issues/69730 (also see
https://reviews.llvm.org/D155543).

There are two things outstanding (which is why I didn't land this before):

1. add some C API tests for `mlirOperationWalk`;
2. potentially refactor how the invalidation in `run` works; the first
version of the code looked like this:
    ```cpp
    if (invalidateOps) {
      auto *context = op.getOperation().getContext().get();
      MlirOperationWalkCallback invalidatingCallback =
          [](MlirOperation op, void *userData) {
            PyMlirContext *context =
                static_cast<PyMlirContext *>(userData);
            context->setOperationInvalid(op);
          };
      auto numRegions =
          mlirOperationGetNumRegions(op.getOperation().get());
      for (int i = 0; i < numRegions; ++i) {
        MlirRegion region =
            mlirOperationGetRegion(op.getOperation().get(), i);
        for (MlirBlock block = mlirRegionGetFirstBlock(region);
             !mlirBlockIsNull(block);
             block = mlirBlockGetNextInRegion(block))
          for (MlirOperation childOp =
                   mlirBlockGetFirstOperation(block);
               !mlirOperationIsNull(childOp);
               childOp = mlirOperationGetNextInBlock(childOp))
            mlirOperationWalk(childOp, invalidatingCallback, context,
                              MlirWalkPostOrder);
      }
    }
    ```
This is verbose and ugly but it has the important benefit of not
executing `mlirOperationEqual(rootOp->get(), op)` for every op
underneath the root op.

Supposing there's no desire for the slightly more efficient but highly
convoluted approach, I can land this "posthaste".
But, since we have eyes on this now, any suggestions or approaches (or
needs/concerns) are welcome.
2023-10-20 20:28:32 -05:00

263 lines
7.2 KiB
Python

# RUN: %PYTHON %s 2>&1 | FileCheck %s
import gc, sys
from mlir.ir import *
from mlir.passmanager import *
from mlir.dialects.func import FuncOp
from mlir.dialects.builtin import ModuleOp
# Log everything to stderr and flush so that we have a unified stream to match
# errors/info emitted by MLIR to stderr.
def log(*args):
    """Print ``args`` to stderr, space-separated, and flush right away.

    Writing through stderr and flushing after every call keeps this output in
    one stream with whatever MLIR itself emits on stderr.
    """
    print(*args, file=sys.stderr, flush=True)
def run(f):
    """Execute the test function ``f`` and verify that it leaks no Context.

    A "TEST: <name>" banner is logged first; it is the line that the label
    directives in this file match against. After ``f`` returns, a garbage
    collection pass is forced and the number of live ``Context`` objects must
    be zero.

    Args:
        f: A zero-argument test callable.

    Returns:
        ``f`` itself, so that a name decorated with ``@run`` stays bound to
        the function instead of being rebound to ``None``.

    Raises:
        AssertionError: If any ``Context`` is still alive after collection.
    """
    log("\nTEST:", f.__name__)
    f()
    # Collect before counting so that unreachable cycles do not show up as
    # live contexts.
    gc.collect()
    live_contexts = Context._get_live_count()
    assert live_contexts == 0, (
        f"{f.__name__} left {live_contexts} live Context object(s) behind"
    )
    return f
# Verify capsule interop.
# CHECK-LABEL: TEST: testCapsule
def testCapsule():
    """Round-trip a PassManager through its C API capsule."""
    with Context():
        original = PassManager()
        capsule = original._CAPIPtr
        assert '"mlir.passmanager.PassManager._CAPIPtr"' in repr(capsule)
        # NOTE(review): presumably this drops the wrapper's ownership so that
        # the object rebuilt from the capsule is the only owner - confirm.
        original._testing_release()
        restored = PassManager._CAPICreate(capsule)
        assert restored is not None  # And does not crash.


run(testCapsule)
# CHECK-LABEL: TEST: testConstruct
@run
def testConstruct():
    """Log the printed form of a default and an anchored PassManager."""
    with Context():
        # CHECK: pm1: 'any()'
        # CHECK: pm2: 'builtin.module()'
        default_pm = PassManager()
        log(f"pm1: '{default_pm}'")
        module_pm = PassManager("builtin.module")
        log(f"pm2: '{module_pm}'")
# Verify successful round-trip.
# CHECK-LABEL: TEST: testParseSuccess
def testParseSuccess():
    """Parse one unregistered and one registered pass pipeline.

    The unregistered pipeline is expected to be rejected with ``ValueError``;
    the registered one is logged so its printed form can be compared with the
    text it was parsed from.
    """
    with Context():
        # An unregistered pass should not parse.
        unregistered = "builtin.module(func.func(not-existing-pass{json=false}))"
        try:
            pm = PassManager.parse(unregistered)
        except ValueError as err:
            # CHECK: ValueError exception: {{.+}} 'not-existing-pass' does not refer to a registered pass
            log("ValueError exception:", err)
        else:
            log("Exception not produced")

        # A registered pass should parse successfully.
        registered = "builtin.module(func.func(print-op-stats{json=false}))"
        pm = PassManager.parse(registered)
        # CHECK: Roundtrip: builtin.module(func.func(print-op-stats{json=false}))
        log("Roundtrip: ", pm)


run(testParseSuccess)
# Verify that a pipeline written with extra whitespace parses and round-trips.
# CHECK-LABEL: TEST: testParseSpacedPipeline
def testParseSpacedPipeline():
    """Parse a pipeline spread over several lines and log its printed form."""
    with Context():
        # A registered pass should parse successfully even if it has extra
        # spaces and line breaks for readability.
        spaced_pipeline = (
            "builtin.module(\n"
            "func.func( print-op-stats{ json=false } )\n"
            ")"
        )
        pm = PassManager.parse(spaced_pipeline)
        # CHECK: Roundtrip: builtin.module(func.func(print-op-stats{json=false}))
        log("Roundtrip: ", pm)


run(testParseSpacedPipeline)
# Verify failure on unregistered pass.
# CHECK-LABEL: TEST: testParseFail
def testParseFail():
    """Expect ``ValueError`` with the parser diagnostic for an unknown pass."""
    with Context():
        bad_pipeline = "any(unknown-pass)"
        try:
            pm = PassManager.parse(bad_pipeline)
        except ValueError as err:
            # CHECK: ValueError exception: MLIR Textual PassPipeline Parser:1:1: error:
            # CHECK-SAME: 'unknown-pass' does not refer to a registered pass or pass pipeline
            # CHECK: unknown-pass
            # CHECK: ^
            log("ValueError exception:", err)
        else:
            log("Exception not produced")


run(testParseFail)
# Check that adding to a pass manager works
# CHECK-LABEL: TEST: testAdd
@run
def testAdd():
    """Add the "cse" pass twice, logging the pipeline before and after each."""
    pm = PassManager("any", Context())
    # CHECK: pm: 'any()'
    log(f"pm: '{pm}'")
    # Each iteration appends one more "cse" and logs the grown pipeline.
    # CHECK: pm: 'any(cse)'
    # CHECK: pm: 'any(cse,cse)'
    for _ in range(2):
        pm.add("cse")
        log(f"pm: '{pm}'")
# Verify failure on incorrect level of nesting.
# CHECK-LABEL: TEST: testInvalidNesting
def testInvalidNesting():
    """Expect ``ValueError`` when a pass is anchored on the wrong op type."""
    with Context():
        misnested_pipeline = "func.func(normalize-memrefs)"
        try:
            pm = PassManager.parse(misnested_pipeline)
        except ValueError as err:
            # CHECK: ValueError exception: Can't add pass 'NormalizeMemRefs' restricted to 'builtin.module' on a PassManager intended to run on 'func.func', did you intend to nest?
            log("ValueError exception:", err)
        else:
            log("Exception not produced")


run(testInvalidNesting)
# Verify that a pass manager can execute on IR
# CHECK-LABEL: TEST: testRunPipeline
def testRunPipeline():
    """Run the op-statistics pass on a trivial function.

    The function logs nothing itself; the lines matched below are expected to
    be emitted by the pass while it runs.
    """
    with Context():
        stats_pipeline = PassManager.parse("any(print-op-stats{json=false})")
        func_op = FuncOp.parse(r"""func.func @successfulParse() { return }""")
        stats_pipeline.run(func_op)
    # CHECK: Operations encountered:
    # CHECK: func.func , 1
    # CHECK: func.return , 1


run(testRunPipeline)
# CHECK-LABEL: TEST: testRunPipelineError
@run
def testRunPipelineError():
    """Run "cse" on an unregistered op and log the resulting MLIRError."""
    with Context() as ctx:
        # Without this flag the unregistered "test.op" could not be parsed.
        ctx.allow_unregistered_dialects = True
        unregistered_op = Operation.parse('"test.op"() : () -> ()')
        cse_pipeline = PassManager.parse("any(cse)")
        try:
            cse_pipeline.run(unregistered_op)
        except MLIRError as err:
            # CHECK: Exception: <
            # CHECK: Failure while executing pass pipeline:
            # CHECK: error: "-":1:1: 'test.op' op trying to schedule a pass on an unregistered operation
            # CHECK: note: "-":1:1: see current operation: "test.op"() : () -> ()
            # CHECK: >
            log(f"Exception: <{err}>")
# CHECK-LABEL: TEST: testPostPassOpInvalidation
@run
def testPostPassOpInvalidation():
    """Compare op handles after a pass run with and without invalidation.

    The same module is parsed twice and canonicalized each time. The first
    run passes ``invalidate_ops=False`` and then prints the function handle
    and the module. The second run uses the default and then tries to print
    every handle taken before the run, logging the ``RuntimeError`` each one
    is expected to raise; the module handle itself is printed afterwards.
    """
    # Both halves of the test parse exactly this text.
    module_asm = (
        "\n"
        "module {\n"
        "arith.constant 10\n"
        "func.func @foo() {\n"
        "arith.constant 10\n"
        "return\n"
        "}\n"
        "}\n"
    )
    pipeline = "builtin.module(canonicalize)"

    with Context() as ctx:
        module = ModuleOp.parse(module_asm)

        # CHECK: invalidate_ops=False
        log("invalidate_ops=False")

        outer_const_op = module.body.operations[0]
        # CHECK: %[[VAL0:.*]] = arith.constant 10 : i64
        log(outer_const_op)

        func_op = module.body.operations[1]
        # CHECK: func.func @[[FOO:.*]]() {
        # CHECK: %[[VAL1:.*]] = arith.constant 10 : i64
        # CHECK: return
        # CHECK: }
        log(func_op)

        inner_const_op = func_op.body.blocks[0].operations[0]
        # CHECK: %[[VAL1]] = arith.constant 10 : i64
        log(inner_const_op)

        PassManager.parse(pipeline).run(module, invalidate_ops=False)
        # Only the function and the module are printed here; the two constant
        # handles are deliberately not touched after this run.
        # CHECK: func.func @foo() {
        # CHECK: return
        # CHECK: }
        log(func_op)
        # CHECK: func.func @foo() {
        # CHECK: return
        # CHECK: }
        log(module)

        # CHECK: invalidate_ops=True
        log("invalidate_ops=True")

        # Rebind the same names in the same order as the first half; handle
        # lifetimes are part of what this test exercises.
        module = ModuleOp.parse(module_asm)
        outer_const_op = module.body.operations[0]
        func_op = module.body.operations[1]
        inner_const_op = func_op.body.blocks[0].operations[0]

        PassManager.parse(pipeline).run(module)

        # One expected message per handle taken before the run, in loop order.
        # CHECK: the operation has been invalidated
        # CHECK: the operation has been invalidated
        # CHECK: the operation has been invalidated
        for stale_op in (func_op, outer_const_op, inner_const_op):
            try:
                log(stale_op)
            except RuntimeError as err:
                log(err)

        # CHECK: func.func @foo() {
        # CHECK: return
        # CHECK: }
        log(module)