# RUN: %PYTHON %s | FileCheck %s
"""Tests for the MLIR Python ``SymbolTable`` bindings.

Each test prints to stdout; the FileCheck directives in the comments of this
file are matched, in order, against that output.
"""

import gc
import io
import itertools
from mlir.ir import *


def run(f):
    """Run test ``f`` immediately at decoration time and return it unchanged.

    Prints a ``TEST: <name>`` header (used as a FileCheck label), runs the
    test, then forces a garbage collection and asserts that no ``Context``
    objects are still alive, so leaked contexts are caught per test.
    """
    print("\nTEST:", f.__name__)
    f()
    gc.collect()
    assert Context._get_live_count() == 0
    return f


# CHECK-LABEL: TEST: testSymbolTableInsert
@run
def testSymbolTableInsert():
    """Exercise lookup, deletion, insertion and renaming on a SymbolTable.

    Also verifies that an erased operation is invalidated on the Python side
    and that inserting an operation without a symbol name is rejected.
    """
    with Context() as ctx:
        # Needed so that the generic "foo.bar" op in m2 can be parsed.
        ctx.allow_unregistered_dialects = True
        m1 = Module.parse(
            """
            func.func private @foo()
            func.func private @bar()"""
        )
        m2 = Module.parse(
            """
            func.func private @qux()
            func.func private @foo()
            "foo.bar"() : () -> ()"""
        )
        symbol_table = SymbolTable(m1.operation)

        # CHECK: func private @foo
        # CHECK: func private @bar
        assert "foo" in symbol_table
        print(symbol_table["foo"])
        assert "bar" in symbol_table
        # Keep a handle to @bar so we can observe its invalidation after erase.
        bar = symbol_table["bar"]
        print(symbol_table["bar"])

        assert "qux" not in symbol_table

        del symbol_table["bar"]
        # After `del`, looking "bar" up again must raise KeyError; the lookup
        # is evaluated before `erase` is ever invoked.
        try:
            symbol_table.erase(symbol_table["bar"])
        except KeyError:
            pass
        else:
            assert False, "expected KeyError"

        # CHECK: module
        # CHECK: func private @foo()
        print(m1)
        assert "bar" not in symbol_table

        # The Python handle to the erased op must refuse to be used.
        try:
            print(bar)
        except RuntimeError as e:
            if "the operation has been invalidated" not in str(e):
                raise
        else:
            assert False, "expected RuntimeError due to invalidated operation"

        # Move @qux from m2 into m1 and register it with the symbol table.
        qux = m2.body.operations[0]
        m1.body.append(qux)
        symbol_table.insert(qux)
        assert "qux" in symbol_table

        # Check that insertion actually renames this symbol in the symbol table.
        foo2 = m2.body.operations[0]
        m1.body.append(foo2)
        updated_name = symbol_table.insert(foo2)
        assert foo2.name.value != "foo"
        assert foo2.name == updated_name

        # CHECK: module
        # CHECK: func private @foo()
        # CHECK: func private @qux()
        # CHECK: func private @foo{{.*}}
        print(m1)

        # Only the symbol-less "foo.bar" op remains in m2; inserting it fails.
        try:
            symbol_table.insert(m2.body.operations[0])
        except ValueError as e:
            if "Expected operation to have a symbol name" not in str(e):
                raise
        else:
            assert False, "expected ValueError when adding a non-symbol"


# CHECK-LABEL: TEST: testSymbolTableRAUW
@run
def testSymbolTableRAUW():
    """Rename a symbol and rewrite its uses with replace_all_symbol_uses."""
    with Context() as ctx:
        m = Module.parse(
            """
            func.func private @foo() {
              call @bar() : () -> ()
              return
            }
            func.func private @bar()
            """
        )
        foo, bar = list(m.operation.regions[0].blocks[0].operations)[0:2]
        SymbolTable.set_symbol_name(bar, "bam")
        # Note that module.operation counts as a "nested symbol table" which won't
        # be traversed into, so it is necessary to traverse its children.
        SymbolTable.replace_all_symbol_uses("bar", "bam", foo)
        # CHECK: call @bam()
        # CHECK: func private @bam
        print(m)
        # CHECK: Foo symbol: "foo"
        # CHECK: Bar symbol: "bam"
        print(f"Foo symbol: {SymbolTable.get_symbol_name(foo)}")
        print(f"Bar symbol: {SymbolTable.get_symbol_name(bar)}")


# CHECK-LABEL: TEST: testSymbolTableVisibility
@run
def testSymbolTableVisibility():
    """Read and then change the visibility attribute of a symbol op."""
    with Context() as ctx:
        m = Module.parse(
            """
            func.func private @foo() {
              return
            }
            """
        )
        foo = m.operation.regions[0].blocks[0].operations[0]
        # CHECK: Existing visibility: "private"
        print(f"Existing visibility: {SymbolTable.get_visibility(foo)}")
        SymbolTable.set_visibility(foo, "public")
        # CHECK: func public @foo
        print(m)


# CHECK-LABEL: TEST: testWalkSymbolTables
@run
def testWalkSymbolTables():
    """Walk nested symbol tables and verify callback exceptions propagate.

    An AssertionError raised inside the Python callback is expected to surface
    to the caller as a RuntimeError carrying the original message.
    """
    with Context() as ctx:
        m = Module.parse(
            """
            module @outer {
              module @inner{
              }
            }
            """
        )

        def callback(symbol_table_op, uses_visible):
            print(f"SYMBOL TABLE: {uses_visible}: {symbol_table_op}")

        # CHECK: SYMBOL TABLE: True: module @inner
        # CHECK: SYMBOL TABLE: True: module @outer
        SymbolTable.walk_symbol_tables(m.operation, True, callback)

        # Make sure exceptions in the callback are handled.
        def error_callback(symbol_table_op, uses_visible):
            assert False, "Raised from python"

        try:
            SymbolTable.walk_symbol_tables(m.operation, True, error_callback)
        except RuntimeError as e:
            # CHECK: GOT EXCEPTION: Exception raised in callback: AssertionError: Raised from python
            print(f"GOT EXCEPTION: {e}")
        else:
            assert False, "expected RuntimeError from failing walk callback"