Files
clang-p2996/mlir/test/python/ir/context_lifecycle.py
vfdev ab18cc246c [MLIR][py] Add PyThreadPool as wrapper around MlirLlvmThreadPool in MLIR python bindings (#130109)
In some projects, such as JAX, ir.Context objects are used with multi-threading
disabled to avoid caching multiple thread pools:

623865fe95/jax/_src/interpreters/mlir.py (L606-L611)

However, when a context has multi-threading enabled, it also uses locks on
the StorageUniquers, which can help avoid data races in multi-threaded
execution (for example, with free-threaded CPython,
https://github.com/jax-ml/jax/issues/26272).
With this PR, the user can enable multi-threading, which 1) enables the
additional locking and 2) allows setting a shared thread pool so that cached
contexts can share one global pool.
2025-03-10 11:19:23 +01:00

73 lines
1.8 KiB
Python

# RUN: %PYTHON %s
# Standalone sanity check of context life-cycle.
import gc
import mlir.ir
# Nothing has been created yet, so no context may be reported as live.
assert mlir.ir.Context._get_live_count() == 0

# Create first context.
print("CREATE C1")
c1 = mlir.ir.Context()
assert mlir.ir.Context._get_live_count() == 1
c1_repr = repr(c1)
print("C1 = ", c1_repr)

# Fetching the context again must not create a new live context: the count
# stays at 1 and both names have the same repr.
print("GETTING AGAIN...")
c2 = c1._get_context_again()
c2_repr = repr(c2)
assert mlir.ir.Context._get_live_count() == 1
assert c1_repr == c2_repr
print("C2 =", c2)

# Make sure the constructor creates a new, distinct context.
print("CREATE C3")
c3 = mlir.ir.Context()
assert mlir.ir.Context._get_live_count() == 2
c3_repr = repr(c3)
print("C3 =", c3)
assert c3_repr != c1_repr

# Dropping the only reference to C3 must release it.
print("FREE C3")
c3 = None
gc.collect()
assert mlir.ir.Context._get_live_count() == 1

# C2 still refers to the first context, so dropping C1 alone must keep it live.
print("Free C1")
c1 = None
gc.collect()
assert mlir.ir.Context._get_live_count() == 1

# Dropping the last reference releases the first context as well.
print("Free C2")
c2 = None
gc.collect()
assert mlir.ir.Context._get_live_count() == 0

# Create a context, get its capsule and create from capsule.
c4 = mlir.ir.Context()
c4_capsule = c4._CAPIPtr
assert '"mlir.ir.Context._CAPIPtr"' in repr(c4_capsule)
c5 = mlir.ir.Context._CAPICreate(c4_capsule)
# The capsule round-trip must hand back the very same Python object.
assert c4 is c5
c4 = None
c5 = None
gc.collect()
# The capsule round-trip must not leak the context. Check it here so a leak is
# reported in this section rather than by the count check further below.
assert mlir.ir.Context._get_live_count() == 0

# Create a global thread pool and share it between three contexts: two attach
# it via set_thread_pool(), one receives it through the constructor.
tp = mlir.ir.ThreadPool()
assert tp.get_max_concurrency() > 0
c5 = mlir.ir.Context()
c5.set_thread_pool(tp)
assert c5.get_num_threads() == tp.get_max_concurrency()
assert c5._mlir_thread_pool_ptr() == tp._mlir_thread_pool_ptr()
c6 = mlir.ir.Context()
c6.set_thread_pool(tp)
assert c6.get_num_threads() == tp.get_max_concurrency()
assert c6._mlir_thread_pool_ptr() == tp._mlir_thread_pool_ptr()
c7 = mlir.ir.Context(thread_pool=tp)
assert c7.get_num_threads() == tp.get_max_concurrency()
assert c7._mlir_thread_pool_ptr() == tp._mlir_thread_pool_ptr()
assert mlir.ir.Context._get_live_count() == 3
c5 = None
c6 = None
c7 = None
gc.collect()
# Contexts that used the shared pool must be released like any other context.
# `tp` is still referenced at this point, so the pool outlives all of them.
assert mlir.ir.Context._get_live_count() == 0