148 lines
4.4 KiB
Python
148 lines
4.4 KiB
Python
# RUN: %PYTHON %s | FileCheck %s
|
|
|
|
from mlir import ir
|
|
from mlir.dialects.transform import interpreter as interp
|
|
|
|
|
|
def test_in_context(f):
    """Decorator that runs ``f`` once, immediately, under a fresh MLIR context.

    An unknown location is entered as a context manager for the duration of
    the call, after the context itself has been entered. Because decoration
    happens at import time, every decorated test executes when this script
    runs; the script's stdout is what gets piped into FileCheck. ``f`` itself
    is returned unchanged so the module-level name still refers to it.
    """
    with ir.Context():
        # Location.unknown() is created only once the context above is active.
        with ir.Location.unknown():
            f()
    return f
|
|
|
|
|
|
# Transform script whose entry point prints its root operand. Tests replace the
# printer name "from interpreter" with their own name so each printed dump is
# tagged with the test that produced it.
print_root_module = """
module attributes {transform.with_named_sequence} {
  transform.named_sequence @__transform_main(%root: !transform.any_op) {
    transform.print %root { name = \"from interpreter\" }: !transform.any_op
    transform.yield
  }
}"""
|
|
|
|
|
|
@test_in_context
def print_self():
    """Apply the printing sequence with its own module as the payload.

    The printer is renamed to "print_self" so the dump is labelled with
    this test's name.
    """
    source = print_root_module.replace("from interpreter", "print_self")
    module = ir.Module.parse(source)
    entry_point = module.body.operations[0]
    interp.apply_named_sequence(module, entry_point, module)
|
|
|
|
|
|
# CHECK-LABEL: print_self
|
|
# CHECK: transform.named_sequence @__transform_main
|
|
# CHECK: transform.print
|
|
# CHECK: transform.yield
|
|
|
|
|
|
@test_in_context
def print_other():
    """Apply the printing sequence to a separate, empty payload module.

    The transform script and the payload are distinct modules, so only the
    payload (carrying the ``this.is.payload`` attribute) should be printed.
    """
    script = ir.Module.parse(
        print_root_module.replace("from interpreter", "print_other")
    )
    target = ir.Module.parse("module attributes { this.is.payload } {}")
    sequence = script.body.operations[0]
    interp.apply_named_sequence(target, sequence, script)
|
|
|
|
|
|
# CHECK-LABEL: print_other
|
|
# CHECK-NOT: transform
|
|
# CHECK: this.is.payload
|
|
|
|
|
|
@test_in_context
def transform_options():
    """Pass an explicit ``TransformOptions`` object through to the interpreter.

    Expensive checks are disabled and a single top-level transform op is
    enforced; the sequence is then applied to a separate payload module.
    """
    opts = interp.TransformOptions()
    opts.expensive_checks = False
    opts.enforce_single_top_level_transform_op = True
    script = ir.Module.parse(
        print_root_module.replace("from interpreter", "transform_options")
    )
    target = ir.Module.parse("module attributes { this.is.payload } {}")
    entry_point = script.body.operations[0]
    interp.apply_named_sequence(target, entry_point, script, opts)
|
|
|
|
|
|
# CHECK-LABEL: transform_options
|
|
|
|
|
|
@test_in_context
def failed():
    """Check that a non-transform op is rejected as the transform root.

    The payload module is passed in every role, including as the transform
    root. The interpreter is expected to raise ``ValueError`` whose message
    mentions ``TransformOpInterface``.
    """
    payload = ir.Module.parse("module attributes { this.is.payload } {}")
    try:
        interp.apply_named_sequence(payload, payload, payload)
    except ValueError as e:
        assert (
            "must implement TransformOpInterface to be used as transform root" in str(e)
        )
    else:
        # Without this branch the test passes vacuously when no error is raised.
        raise AssertionError(
            "expected apply_named_sequence to raise ValueError for a "
            "non-transform root"
        )
|
|
|
|
|
|
# Main transform script: declares @callee1 and @callee2 without bodies, and its
# entry point includes @callee2. The definitions must be merged in from other
# modules before the entry point can run.
print_root_via_include_module = """
module @print_root_via_include_module attributes {transform.with_named_sequence} {
  transform.named_sequence private @callee1(%root: !transform.any_op {transform.readonly})
  transform.named_sequence private @callee2(%root: !transform.any_op {transform.readonly})
  transform.named_sequence @__transform_main(%root: !transform.any_op) {
    transform.include @callee2 failures(propagate)
        (%root) : (!transform.any_op) -> ()
    transform.yield
  }
}"""
|
|
|
|
# Defines @callee2, which in turn includes @callee1 (declared here, not defined).
callee2_definition = """
module attributes {transform.with_named_sequence} {
  transform.named_sequence private @callee1(%root: !transform.any_op {transform.readonly})
  transform.named_sequence @callee2(%root: !transform.any_op {transform.readonly}) {
    transform.include @callee1 failures(propagate)
        (%root) : (!transform.any_op) -> ()
    transform.yield
  }
}
"""
|
|
|
|
# Defines @callee1, the leaf of the include chain; it prints its root operand.
callee1_definition = """
module attributes {transform.with_named_sequence} {
  transform.named_sequence @callee1(%root: !transform.any_op {transform.readonly}) {
    transform.print %root { name = \"from interpreter\" }: !transform.any_op
    transform.yield
  }
}
"""
|
|
|
|
|
|
@test_in_context
def include():
    """Resolve ``transform.include`` calls across separately parsed modules.

    Both callee definitions are merged into the main module (callee1 first,
    then callee2) before the entry point runs. This lets the include chain
    main -> callee2 -> callee1 resolve, and callee1 prints the main module.
    """
    root_module = ir.Module.parse(print_root_via_include_module)
    definitions = [
        ir.Module.parse(callee1_definition),
        ir.Module.parse(callee2_definition),
    ]
    for definition in definitions:
        interp.copy_symbols_and_merge_into(root_module, definition)

    # CHECK: @print_root_via_include_module
    # CHECK: transform.named_sequence @__transform_main
    # CHECK: transform.include @callee2
    #
    # CHECK: transform.named_sequence @callee1
    # CHECK: transform.print
    #
    # CHECK: transform.named_sequence @callee2
    # CHECK: transform.include @callee1
    entry_point = root_module.body.operations[0]
    interp.apply_named_sequence(root_module, entry_point, root_module)
|
|
|
|
|
|
@test_in_context
def partial_include():
    """Check that applying fails when an included sequence has no definition.

    Only callee2's definition is merged. callee1 remains a bodiless
    declaration, so running the entry point is expected to raise
    ``ValueError`` containing "Failed to apply".
    """
    main = ir.Module.parse(print_root_via_include_module)
    callee2 = ir.Module.parse(callee2_definition)
    interp.copy_symbols_and_merge_into(main, callee2)

    try:
        interp.apply_named_sequence(main, main.body.operations[0], main)
    except ValueError as e:
        assert "Failed to apply" in str(e)
    else:
        # Without this branch the test passes vacuously when no error is raised.
        raise AssertionError(
            "expected apply_named_sequence to raise ValueError for an "
            "unresolved include"
        )
|
|
|
|
|
|
@test_in_context
def repeated_include():
    """Check that merging the same definition twice is rejected.

    The second merge of callee2's module is expected to raise ``ValueError``
    reporting "doubly defined symbol @callee2".
    """
    main = ir.Module.parse(print_root_via_include_module)
    callee2 = ir.Module.parse(callee2_definition)
    interp.copy_symbols_and_merge_into(main, callee2)

    try:
        interp.copy_symbols_and_merge_into(main, callee2)
    except ValueError as e:
        assert "doubly defined symbol @callee2" in str(e)
    else:
        # Without this branch the test passes vacuously when no error is raised.
        raise AssertionError(
            "expected copy_symbols_and_merge_into to raise ValueError for a "
            "duplicate @callee2"
        )
|