This function (`replaceAllSymbolUses`) has several overloads that allow
specifying the symbol that should be renamed and the scope for that
renaming in different ways. The overloads were inconsistent in the
following way (quoted strings are `StringAttr`s, other variables are
`Operation *`):
* `replaceAllSymbolUses(symbolOp, "new_symbol", scopeOp)` would traverse
into the nested regions of `scopeOp` and hence rename the symbol inside
of `scopeOp`.
* `replaceAllSymbolUses("symbol", "new_symbol", scopeOp)` would *not*
traverse into the nested regions of `scopeOp` and hence *not* rename the
symbol.
The underlying behavior was spread over different places and is somewhat
hard to understand. The two overloads above mainly differed by what
`collectSymbolScopes` computed, which is itself overloaded. If `scopeOp`
is a top-level module, then the overload on `(Operation *, Operation
*)`, which is used in the first of the above cases, computes a scope
where the body region of the module is the `limit`; however, the
overload on `(StringAttr, Operation *)` computed the module op itself as
the `limit`. Later, `walkSymbolTable` would walk the body of the module
if it was given as a region but it would *not* enter the regions of the
module op because that op has a symbol table (which was assumed to be a
*different* scope).
The fix in this commit is to change the behavior of `collectSymbolScopes`
such that the `(StringAttr, Operation *)` overload returns a scope for
each region in the `limit` argument.
181 lines
5.1 KiB
Python
181 lines
5.1 KiB
Python
# RUN: %PYTHON %s | FileCheck %s
|
|
|
|
import gc
|
|
import io
|
|
import itertools
|
|
from mlir.ir import *
|
|
|
|
|
|
def run(f):
    """Run a test function immediately and check that no Context leaked.

    Intended as a decorator. Prints a ``TEST: <name>`` header for the
    FileCheck directives to anchor on, calls ``f``, forces a garbage
    collection, and asserts that no ``Context`` objects remain alive.

    Returns:
        ``f`` itself, so the decorated name still refers to the function.
    """
    test_name = f.__name__
    print("\nTEST:", test_name)
    f()
    # Collect before counting so that unreachable objects (e.g. reference
    # cycles) are freed and do not show up as live contexts.
    gc.collect()
    live_contexts = Context._get_live_count()
    assert live_contexts == 0
    return f


# CHECK-LABEL: TEST: testSymbolTableInsert
@run
def testSymbolTableInsert():
    """Exercise SymbolTable lookup, deletion, insertion and auto-renaming."""
    with Context() as ctx:
        ctx.allow_unregistered_dialects = True
        m1 = Module.parse(
            """
      func.func private @foo()
      func.func private @bar()"""
        )
        m2 = Module.parse(
            """
      func.func private @qux()
      func.func private @foo()
      "foo.bar"() : () -> ()"""
        )

        symbol_table = SymbolTable(m1.operation)

        # CHECK: func private @foo
        # CHECK: func private @bar
        assert "foo" in symbol_table
        print(symbol_table["foo"])
        assert "bar" in symbol_table
        bar = symbol_table["bar"]
        print(symbol_table["bar"])

        assert "qux" not in symbol_table

        # Once "bar" is deleted, looking it up must raise KeyError. The lookup
        # inside the call raises, so `erase` itself is never reached.
        del symbol_table["bar"]
        try:
            symbol_table.erase(symbol_table["bar"])
        except KeyError:
            pass
        else:
            assert False, "expected KeyError"

        # CHECK: module
        # CHECK: func private @foo()
        print(m1)
        assert "bar" not in symbol_table

        # `bar` was removed from the table above, so using the Python handle
        # to it must fail with an "invalidated" RuntimeError.
        try:
            print(bar)
        except RuntimeError as e:
            if "the operation has been invalidated" not in str(e):
                raise
        else:
            assert False, "expected RuntimeError due to invalidated operation"

        # Move @qux from m2 into m1 and register it in m1's symbol table.
        qux = m2.body.operations[0]
        m1.body.append(qux)
        symbol_table.insert(qux)
        assert "qux" in symbol_table

        # Check that insertion actually renames this symbol in the symbol table.
        foo2 = m2.body.operations[0]
        m1.body.append(foo2)
        updated_name = symbol_table.insert(foo2)
        assert foo2.name.value != "foo"
        assert foo2.name == updated_name
        assert isinstance(updated_name, StringAttr)

        # CHECK: module
        # CHECK: func private @foo()
        # CHECK: func private @qux()
        # CHECK: func private @foo{{.*}}
        print(m1)

        # The only op left in m2 is the unregistered "foo.bar", which has no
        # symbol name, so inserting it must be rejected.
        try:
            symbol_table.insert(m2.body.operations[0])
        except ValueError as e:
            if "Expected operation to have a symbol name" not in str(e):
                raise
        else:
            assert False, "expected ValueError when adding a non-symbol"


# CHECK-LABEL: testSymbolTableRAUW
@run
def testSymbolTableRAUW():
    """Rename a symbol and its uses: first scoped to one op, then module-wide."""
    with Context():
        m = Module.parse(
            """
      func.func private @foo() {
        call @bar() : () -> ()
        return
      }
      func.func private @bar()
      """
        )
        top_block = m.operation.regions[0].blocks[0]
        foo, bar = list(top_block.operations)[:2]

        def print_symbol_names():
            # Report the current symbol name of both functions.
            print(f"Foo symbol: {repr(SymbolTable.get_symbol_name(foo))}")
            print(f"Bar symbol: {repr(SymbolTable.get_symbol_name(bar))}")

        # Do renaming just within `foo`.
        SymbolTable.set_symbol_name(bar, "bam")
        SymbolTable.replace_all_symbol_uses("bar", "bam", foo)
        # CHECK: call @bam()
        # CHECK: func private @bam
        print(m)
        # CHECK: Foo symbol: StringAttr("foo")
        # CHECK: Bar symbol: StringAttr("bam")
        print_symbol_names()

        # Do renaming within the module: passing the module op as the scope
        # must also rewrite the uses nested inside its body.
        SymbolTable.set_symbol_name(bar, "baz")
        SymbolTable.replace_all_symbol_uses("bam", "baz", m.operation)
        # CHECK: call @baz()
        # CHECK: func private @baz
        print(m)
        # CHECK: Foo symbol: StringAttr("foo")
        # CHECK: Bar symbol: StringAttr("baz")
        print_symbol_names()


# CHECK-LABEL: testSymbolTableVisibility
@run
def testSymbolTableVisibility():
    """Read a symbol's visibility, then switch it from private to public."""
    with Context():
        m = Module.parse(
            """
      func.func private @foo() {
        return
      }
      """
        )
        body_ops = m.operation.regions[0].blocks[0].operations
        foo = body_ops[0]
        # CHECK: Existing visibility: StringAttr("private")
        current_visibility = SymbolTable.get_visibility(foo)
        print(f"Existing visibility: {repr(current_visibility)}")
        SymbolTable.set_visibility(foo, "public")
        # CHECK: func public @foo
        print(m)


# CHECK: testWalkSymbolTables
@run
def testWalkSymbolTables():
    """Walk nested symbol tables and check that callback errors propagate."""
    with Context():
        m = Module.parse(
            """
      module @outer {
        module @inner{
        }
      }
      """
        )

        def callback(symbol_table_op, uses_visible):
            print(f"SYMBOL TABLE: {uses_visible}: {symbol_table_op}")

        # CHECK: SYMBOL TABLE: True: module @inner
        # CHECK: SYMBOL TABLE: True: module @outer
        SymbolTable.walk_symbol_tables(m.operation, True, callback)

        # Make sure exceptions in the callback are handled: they must surface
        # to the caller as a RuntimeError carrying the original message.
        def error_callback(symbol_table_op, uses_visible):
            # Raise explicitly instead of `assert False, ...` so the exception
            # is still raised when Python runs with -O (asserts stripped).
            raise AssertionError("Raised from python")

        try:
            SymbolTable.walk_symbol_tables(m.operation, True, error_callback)
        except RuntimeError as e:
            # CHECK: GOT EXCEPTION: Exception raised in callback: AssertionError: Raised from python
            print(f"GOT EXCEPTION: {e}")
        else:
            # Previously a non-raising walk passed silently at the Python level.
            assert False, "expected RuntimeError from error_callback"