Files
clang-p2996/mlir/test/python/ir/operation.py
Nikhil Kalra 0ad1f8369c [mlir] Python: Extend print large elements limit to resources (#125738)
If the large element limit is specified, large elements are hidden from
the asm but large resources are not. This change extends the large
elements limit to apply to printed resources as well.
2025-02-05 11:48:11 -08:00

1112 lines
33 KiB
Python

# RUN: %PYTHON %s | FileCheck %s
import gc
import io
import itertools
from mlir.ir import *
from mlir.dialects.builtin import ModuleOp
from mlir.dialects import arith
from mlir.dialects._ods_common import _cext
def run(f):
    """Execute test function ``f`` immediately and check for leaked contexts.

    Used as a decorator. It prints the ``TEST: <name>`` banner that the
    FileCheck label directives anchor on, calls ``f``, forces a garbage
    collection, and asserts that no ``Context`` objects remain alive.

    Args:
      f: Zero-argument test function.

    Returns:
      ``f`` unchanged, so the decorated name stays bound to the function.

    Raises:
      AssertionError: if any ``Context`` is still alive after ``f`` returns.
        The message names the offending test and the live count.
    """
    print("\nTEST:", f.__name__)
    f()
    # Free objects that are only kept alive by reference cycles before counting.
    gc.collect()
    live_count = Context._get_live_count()
    assert live_count == 0, f"{f.__name__} leaked {live_count} live Context(s)"
    return f
def expect_index_error(callback):
    """Assert that calling ``callback()`` raises ``IndexError``.

    Args:
      callback: Zero-argument callable expected to raise ``IndexError``.

    Raises:
      RuntimeError: if ``callback`` returns normally. Any exception other
        than ``IndexError`` propagates unchanged.
    """
    try:
        callback()
    except IndexError:
        return
    raise RuntimeError("Expected IndexError")
# Verify iterator based traversal of the op/region/block hierarchy.
# CHECK-LABEL: TEST: testTraverseOpRegionBlockIterators
@run
def testTraverseOpRegionBlockIterators():
    """Traverse a parsed module via the iterator protocol of each collection.

    Regions, blocks and operations obtained through the named collections
    (``.regions``, ``.blocks``, ``.operations``) must compare equal to those
    obtained by iterating the owning object directly. A nested
    REGION/BLOCK/OP dump and the iterator reprs are printed for FileCheck.
    """
    ctx = Context()
    # Needed because "custom.addi" is not from a registered dialect.
    ctx.allow_unregistered_dialects = True
    module = Module.parse(
        r"""
func.func @f1(%arg0: i32) -> i32 {
%1 = "custom.addi"(%arg0, %arg0) : (i32, i32) -> i32
return %1 : i32
}
""",
        ctx,
    )
    op = module.operation
    assert op.context is ctx
    # Get the block using iterators off of the named collections.
    regions = list(op.regions)
    blocks = list(regions[0].blocks)
    # CHECK: MODULE REGIONS=1 BLOCKS=1
    print(f"MODULE REGIONS={len(regions)} BLOCKS={len(blocks)}")
    # Should verify.
    # CHECK: .verify = True
    print(f".verify = {module.operation.verify()}")
    # Get the blocks from the default collection.
    default_blocks = list(regions[0])
    # They should compare equal regardless of how obtained.
    assert default_blocks == blocks
    # Should be able to get the operations from either the named collection
    # or the block.
    operations = list(blocks[0].operations)
    default_operations = list(blocks[0])
    assert default_operations == operations

    def walk_operations(indent, op):
        # Recursively print every region/block/op nested under `op`, iterating
        # each container directly (no indexing).
        for i, region in enumerate(op.regions):
            print(f"{indent}REGION {i}:")
            for j, block in enumerate(region):
                print(f"{indent} BLOCK {j}:")
                for k, child_op in enumerate(block):
                    print(f"{indent} OP {k}: {child_op}")
                    walk_operations(indent + " ", child_op)

    # CHECK: REGION 0:
    # CHECK: BLOCK 0:
    # CHECK: OP 0: func
    # CHECK: REGION 0:
    # CHECK: BLOCK 0:
    # CHECK: OP 0: %0 = "custom.addi"
    # CHECK: OP 1: func.return
    walk_operations("", op)
    # CHECK: Region iter: <mlir.{{.+}}.RegionIterator
    # CHECK: Block iter: <mlir.{{.+}}.BlockIterator
    # CHECK: Operation iter: <mlir.{{.+}}.OperationIterator
    print(" Region iter:", iter(op.regions))
    print(" Block iter:", iter(op.regions[0]))
    print("Operation iter:", iter(op.regions[0].blocks[0]))
# Verify index based traversal of the op/region/block hierarchy.
# CHECK-LABEL: TEST: testTraverseOpRegionBlockIndices
@run
def testTraverseOpRegionBlockIndices():
    """Traverse a parsed module using only ``len()`` and integer indexing.

    Each op is printed along with the name of its parent op.
    """
    ctx = Context()
    ctx.allow_unregistered_dialects = True
    module = Module.parse(
        r"""
func.func @f1(%arg0: i32) -> i32 {
%1 = "custom.addi"(%arg0, %arg0) : (i32, i32) -> i32
return %1 : i32
}
""",
        ctx,
    )

    def walk_operations(indent, op):
        # range(len(...)) is deliberate: this test exercises __len__ and
        # __getitem__ on the region, block and operation collections.
        for i in range(len(op.regions)):
            region = op.regions[i]
            print(f"{indent}REGION {i}:")
            for j in range(len(region.blocks)):
                block = region.blocks[j]
                print(f"{indent} BLOCK {j}:")
                for k in range(len(block.operations)):
                    child_op = block.operations[k]
                    print(f"{indent} OP {k}: {child_op}")
                    print(
                        f"{indent} OP {k}: parent {child_op.operation.parent.name}"
                    )
                    walk_operations(indent + " ", child_op)

    # CHECK: REGION 0:
    # CHECK: BLOCK 0:
    # CHECK: OP 0: func
    # CHECK: OP 0: parent builtin.module
    # CHECK: REGION 0:
    # CHECK: BLOCK 0:
    # CHECK: OP 0: %0 = "custom.addi"
    # CHECK: OP 0: parent func.func
    # CHECK: OP 1: func.return
    # CHECK: OP 1: parent func.func
    walk_operations("", module.operation)
# CHECK-LABEL: TEST: testBlockAndRegionOwners
@run
def testBlockAndRegionOwners():
    """Check that ``Region.owner`` and ``Block.owner`` point at the parent op."""
    context = Context()
    context.allow_unregistered_dialects = True
    module = Module.parse(
        r"""
builtin.module {
func.func @f() {
func.return
}
}
""",
        context,
    )
    module_op = module.operation
    module_region = module_op.regions[0]
    assert module_region.owner == module_op
    assert module_region.blocks[0].owner == module_op

    func_op = module.body.operations[0]
    func_region = func_op.operation.regions[0]
    assert func_region.owner == func_op
    assert func_region.blocks[0].owner == func_op
# CHECK-LABEL: TEST: testBlockArgumentList
@run
def testBlockArgumentList():
    """Exercise ``Block.arguments`` end to end.

    Covers iteration, ``set_type``, slicing, ``+`` concatenation of slices,
    and the ``.types`` view (also on a slice).
    """
    with Context() as ctx:
        module = Module.parse(
            r"""
func.func @f1(%arg0: i32, %arg1: f64, %arg2: index) {
return
}
""",
            ctx,
        )
        func = module.body.operations[0]
        entry_block = func.regions[0].blocks[0]
        assert len(entry_block.arguments) == 3
        # CHECK: Argument 0, type i32
        # CHECK: Argument 1, type f64
        # CHECK: Argument 2, type index
        for arg in entry_block.arguments:
            print(f"Argument {arg.arg_number}, type {arg.type}")
            # Retype argument N to a signless integer of width 8 * (N + 1),
            # i.e. i8, i16, i24.
            new_type = IntegerType.get_signless(8 * (arg.arg_number + 1))
            arg.set_type(new_type)
        # CHECK: Argument 0, type i8
        # CHECK: Argument 1, type i16
        # CHECK: Argument 2, type i24
        for arg in entry_block.arguments:
            print(f"Argument {arg.arg_number}, type {arg.type}")
        # Check that slicing works for block argument lists.
        # CHECK: Argument 1, type i16
        # CHECK: Argument 2, type i24
        for arg in entry_block.arguments[1:]:
            print(f"Argument {arg.arg_number}, type {arg.type}")
        # Check that we can concatenate slices of argument lists.
        # CHECK: Length: 4
        print("Length: ", len(entry_block.arguments[:2] + entry_block.arguments[1:]))
        # CHECK: Type: i8
        # CHECK: Type: i16
        # CHECK: Type: i24
        for t in entry_block.arguments.types:
            print("Type: ", t)
        # Check that slicing and type access compose.
        # CHECK: Sliced type: i16
        # CHECK: Sliced type: i24
        for t in entry_block.arguments[1:].types:
            print("Sliced type: ", t)
        # Check that slice addition works as expected.
        # CHECK: Argument 2, type i24
        # CHECK: Argument 0, type i8
        restructured = entry_block.arguments[-1:] + entry_block.arguments[:1]
        for arg in restructured:
            print(f"Argument {arg.arg_number}, type {arg.type}")
# CHECK-LABEL: TEST: testOperationOperands
@run
def testOperationOperands():
    """Check the length and per-operand types of an operation's operand list."""
    with Context() as ctx:
        ctx.allow_unregistered_dialects = True
        module = Module.parse(
            r"""
func.func @f1(%arg0: i32) {
%0 = "test.producer"() : () -> i64
"test.consumer"(%arg0, %0) : (i32, i64) -> ()
return
}"""
        )
        func_op = module.body.operations[0]
        body_block = func_op.regions[0].blocks[0]
        consumer_operands = body_block.operations[1].operands
        assert len(consumer_operands) == 2

        # CHECK: Operand 0, type i32
        # CHECK: Operand 1, type i64
        for index, value in enumerate(consumer_operands):
            print(f"Operand {index}, type {value.type}")
# CHECK-LABEL: TEST: testOperationOperandsSlice
@run
def testOperationOperandsSlice():
    """Exercise slicing of an op's operand list with various start/stop/step."""
    with Context() as ctx:
        ctx.allow_unregistered_dialects = True
        module = Module.parse(
            r"""
func.func @f1() {
%0 = "test.producer0"() : () -> i64
%1 = "test.producer1"() : () -> i64
%2 = "test.producer2"() : () -> i64
%3 = "test.producer3"() : () -> i64
%4 = "test.producer4"() : () -> i64
"test.consumer"(%0, %1, %2, %3, %4) : (i64, i64, i64, i64, i64) -> ()
return
}"""
        )
        func = module.body.operations[0]
        entry_block = func.regions[0].blocks[0]
        consumer = entry_block.operations[5]
        assert len(consumer.operands) == 5
        # Reversing twice must give back the operands in their original order.
        for left, right in zip(consumer.operands, consumer.operands[::-1][::-1]):
            assert left == right
        # CHECK: test.producer0
        # CHECK: test.producer1
        # CHECK: test.producer2
        # CHECK: test.producer3
        # CHECK: test.producer4
        full_slice = consumer.operands[:]
        for operand in full_slice:
            print(operand)
        # CHECK: test.producer0
        # CHECK: test.producer1
        first_two = consumer.operands[0:2]
        for operand in first_two:
            print(operand)
        # CHECK: test.producer3
        # CHECK: test.producer4
        last_two = consumer.operands[3:]
        for operand in last_two:
            print(operand)
        # CHECK: test.producer0
        # CHECK: test.producer2
        # CHECK: test.producer4
        even = consumer.operands[::2]
        for operand in even:
            print(operand)
        # CHECK: test.producer2
        # Slicing a slice: [::2] gives operands 0, 2, 4; taking [1::2] of
        # that keeps only operand 2.
        fourth = consumer.operands[::2][1::2]
        for operand in fourth:
            print(operand)
# CHECK-LABEL: TEST: testOperationOperandsSet
@run
def testOperationOperandsSet():
    """Replace an operation's operand in place via ``operands[i] = value``.

    Both a non-negative and a negative index are exercised. After each
    assignment the consumer's single operand is printed for FileCheck.
    """
    with Context() as ctx, Location.unknown(ctx):
        ctx.allow_unregistered_dialects = True
        module = Module.parse(
            r"""
func.func @f1() {
%0 = "test.producer0"() : () -> i64
%1 = "test.producer1"() : () -> i64
%2 = "test.producer2"() : () -> i64
"test.consumer"(%0) : (i64) -> ()
return
}"""
        )
        func = module.body.operations[0]
        entry_block = func.regions[0].blocks[0]
        producer1 = entry_block.operations[1]
        producer2 = entry_block.operations[2]
        consumer = entry_block.operations[3]
        assert len(consumer.operands) == 1
        # The operand about to be replaced is the i64 result of test.producer0.
        assert str(consumer.operands[0].type) == "i64"
        # CHECK: test.producer1
        consumer.operands[0] = producer1.result
        print(consumer.operands[0])
        # CHECK: test.producer2
        # -1 addresses the same (only) operand, as the print of index 0 shows.
        consumer.operands[-1] = producer2.result
        print(consumer.operands[0])
# CHECK-LABEL: TEST: testDetachedOperation
@run
def testDetachedOperation():
    """Create and print a standalone op that is never inserted into a block.

    The op has results, a region and attributes.
    """
    ctx = Context()
    ctx.allow_unregistered_dialects = True
    with Location.unknown(ctx):
        # Despite the variable name this is a *signed* 32-bit type, printed
        # as si32 below.
        i32 = IntegerType.get_signed(32)
        op1 = Operation.create(
            "custom.op1",
            results=[i32, i32],
            regions=1,
            attributes={
                "foo": StringAttr.get("foo_value"),
                "bar": StringAttr.get("bar_value"),
            },
        )
        # CHECK: %0:2 = "custom.op1"() ({
        # CHECK: }) {bar = "bar_value", foo = "foo_value"} : () -> (si32, si32)
        print(op1)
    # TODO: Check successors once enough infra exists to do it properly.
# CHECK-LABEL: TEST: testOperationInsertionPoint
@run
def testOperationInsertionPoint():
    """Insert detached ops at the start of a block; reject re-insertion.

    Ops inserted one after another through the same ``at_block_begin``
    insertion point appear in insertion order, ahead of the block's existing
    ops. Inserting an already-attached op must raise ``ValueError``.
    """
    ctx = Context()
    ctx.allow_unregistered_dialects = True
    module = Module.parse(
        r"""
func.func @f1(%arg0: i32) -> i32 {
%1 = "custom.addi"(%arg0, %arg0) : (i32, i32) -> i32
return %1 : i32
}
""",
        ctx,
    )
    # Create test op.
    with Location.unknown(ctx):
        # No insertion point is active here, so both ops start out detached.
        op1 = Operation.create("custom.op1")
        op2 = Operation.create("custom.op2")
        func = module.body.operations[0]
        entry_block = func.regions[0].blocks[0]
        ip = InsertionPoint.at_block_begin(entry_block)
        ip.insert(op1)
        ip.insert(op2)
        # CHECK: func @f1
        # CHECK: "custom.op1"()
        # CHECK: "custom.op2"()
        # CHECK: %0 = "custom.addi"
        print(module)
    # Trying to add a previously added op should raise.
    try:
        ip.insert(op1)
    except ValueError:
        pass
    else:
        assert False, "expected insert of attached op to raise"
# CHECK-LABEL: TEST: testOperationWithRegion
@run
def testOperationWithRegion():
    """Build an op with a populated region, then move it into a parsed module.

    The region gets a block with two arguments and a terminator. The whole
    op, including its nested terminator, is then inserted at the start of
    @f1's entry block.
    """
    ctx = Context()
    ctx.allow_unregistered_dialects = True
    with Location.unknown(ctx):
        i32 = IntegerType.get_signed(32)
        op1 = Operation.create("custom.op1", regions=1)
        # Append a block with two si32 arguments to the op's only region.
        block = op1.regions[0].blocks.append(i32, i32)
        # CHECK: "custom.op1"() ({
        # CHECK: ^bb0(%arg0: si32, %arg1: si32):
        # CHECK: "custom.terminator"() : () -> ()
        # CHECK: }) : () -> ()
        terminator = Operation.create("custom.terminator")
        ip = InsertionPoint(block)
        ip.insert(terminator)
        print(op1)
        # Now add the whole operation to another op.
        # TODO: Verify lifetime hazard by nulling out the new owning module and
        # accessing op1.
        # TODO: Also verify accessing the terminator once both parents are nulled
        # out.
        # No context argument: the parse uses the context made current by the
        # enclosing `with Location.unknown(ctx)`.
        module = Module.parse(
            r"""
func.func @f1(%arg0: i32) -> i32 {
%1 = "custom.addi"(%arg0, %arg0) : (i32, i32) -> i32
return %1 : i32
}
"""
        )
        func = module.body.operations[0]
        entry_block = func.regions[0].blocks[0]
        ip = InsertionPoint.at_block_begin(entry_block)
        ip.insert(op1)
        # CHECK: func @f1
        # CHECK: "custom.op1"()
        # CHECK: "custom.terminator"
        # CHECK: %0 = "custom.addi"
        print(module)
# CHECK-LABEL: TEST: testOperationResultList
@run
def testOperationResultList():
    """Exercise an op's result list.

    Covers iteration, ``.types``, out-of-range indexing (positive and
    negative) and the ``owner`` of an empty result list.
    """
    ctx = Context()
    module = Module.parse(
        r"""
func.func @f1() {
%0:3 = call @f2() : () -> (i32, f64, index)
call @f3() : () -> ()
return
}
func.func private @f2() -> (i32, f64, index)
func.func private @f3() -> ()
""",
        ctx,
    )
    caller = module.body.operations[0]
    call = caller.regions[0].blocks[0].operations[0]
    assert len(call.results) == 3
    # CHECK: Result 0, type i32
    # CHECK: Result 1, type f64
    # CHECK: Result 2, type index
    for res in call.results:
        print(f"Result {res.result_number}, type {res.type}")
    # CHECK: Result type i32
    # CHECK: Result type f64
    # CHECK: Result type index
    for t in call.results.types:
        print(f"Result type {t}")
    # Out of range
    # For a 3-element list, 3 is one past the end and -4 is one before the
    # start; both must raise IndexError.
    expect_index_error(lambda: call.results[3])
    expect_index_error(lambda: call.results[-4])
    no_results_call = caller.regions[0].blocks[0].operations[1]
    assert len(no_results_call.results) == 0
    # Even an empty result list reports which operation owns it.
    assert no_results_call.results.owner == no_results_call
# CHECK-LABEL: TEST: testOperationResultListSlice
@run
def testOperationResultListSlice():
    """Exercise slicing of an op's result list, including steps and a
    negative step.

    Sliced results keep their original ``result_number``.
    """
    with Context() as ctx:
        ctx.allow_unregistered_dialects = True
        module = Module.parse(
            r"""
func.func @f1() {
"some.op"() : () -> (i1, i2, i3, i4, i5)
return
}
"""
        )
        func = module.body.operations[0]
        entry_block = func.regions[0].blocks[0]
        producer = entry_block.operations[0]
        assert len(producer.results) == 5
        # Reversing twice must give back the results in their original order.
        for left, right in zip(producer.results, producer.results[::-1][::-1]):
            assert left == right
            assert left.result_number == right.result_number
        # CHECK: Result 0, type i1
        # CHECK: Result 1, type i2
        # CHECK: Result 2, type i3
        # CHECK: Result 3, type i4
        # CHECK: Result 4, type i5
        full_slice = producer.results[:]
        for res in full_slice:
            print(f"Result {res.result_number}, type {res.type}")
        # CHECK: Result 1, type i2
        # CHECK: Result 2, type i3
        # CHECK: Result 3, type i4
        middle = producer.results[1:4]
        for res in middle:
            print(f"Result {res.result_number}, type {res.type}")
        # CHECK: Result 1, type i2
        # CHECK: Result 3, type i4
        odd = producer.results[1::2]
        for res in odd:
            print(f"Result {res.result_number}, type {res.type}")
        # CHECK: Result 3, type i4
        # CHECK: Result 1, type i2
        # Start at index 3 (-2), step back by 2, stop before index 0.
        inverted_middle = producer.results[-2:0:-2]
        for res in inverted_middle:
            print(f"Result {res.result_number}, type {res.type}")
# CHECK-LABEL: TEST: testOperationAttributes
@run
def testOperationAttributes():
    """Exercise ``Operation.attributes``.

    Covers lookup by name, the typed attribute accessors, iteration as
    NamedAttributes, and the error types for a missing name (``KeyError``)
    and an out-of-range integer index (``IndexError``).
    """
    ctx = Context()
    ctx.allow_unregistered_dialects = True
    module = Module.parse(
        r"""
"some.op"() { some.attribute = 1 : i8,
other.attribute = 3.0,
dependent = "text" } : () -> ()
""",
        ctx,
    )
    op = module.body.operations[0]
    assert len(op.attributes) == 3
    iattr = op.attributes["some.attribute"]
    fattr = op.attributes["other.attribute"]
    sattr = op.attributes["dependent"]
    # CHECK: Attribute type i8, value 1
    print(f"Attribute type {iattr.type}, value {iattr.value}")
    # CHECK: Attribute type f64, value 3.0
    print(f"Attribute type {fattr.type}, value {fattr.value}")
    # CHECK: Attribute value text
    print(f"Attribute value {sattr.value}")
    # CHECK: Attribute value b'text'
    print(f"Attribute value {sattr.value_bytes}")
    # We don't know in which order the attributes are stored.
    # CHECK-DAG: NamedAttribute(dependent="text")
    # CHECK-DAG: NamedAttribute(other.attribute=3.000000e+00 : f64)
    # CHECK-DAG: NamedAttribute(some.attribute=1 : i8)
    for attr in op.attributes:
        print(str(attr))
    # Check that exceptions are raised as expected.
    try:
        op.attributes["does_not_exist"]
    except KeyError:
        pass
    else:
        assert False, "expected KeyError on accessing a non-existent attribute"
    try:
        # An integer subscript is a position; 42 is past the 3 attributes.
        op.attributes[42]
    except IndexError:
        pass
    else:
        assert False, "expected IndexError on accessing an out-of-bounds attribute"
# CHECK-LABEL: TEST: testOperationPrint
@run
def testOperationPrint():
    """Exercise the printing and bytecode APIs of ``Operation``.

    Covers:
      - printing to stdout, to a text stream and to a binary stream;
      - a bytecode round trip into a second context;
      - ``use_local_scope`` with debug info;
      - printing through an ``AsmState``;
      - ``large_elements_limit``, which elides both the large dense constant
        and the dialect resource blob;
      - ``skip_regions``.
    """
    ctx = Context()
    module = Module.parse(
        r"""
func.func @f1(%arg0: i32) -> i32 {
%0 = arith.constant dense<[1, 2, 3, 4]> : tensor<4xi32> loc("nom")
%1 = arith.constant dense_resource<resource1> : tensor<3xi64>
return %arg0 : i32
}
{-#
dialect_resources: {
builtin: {
resource1: "0x08000000010000000000000002000000000000000300000000000000"
}
}
#-}
""",
        ctx,
    )
    # Test print to stdout.
    # CHECK: return %arg0 : i32
    # CHECK: resource1: "0x08
    module.operation.print()
    # Test print to text file.
    f = io.StringIO()
    # CHECK: <class 'str'>
    # CHECK: return %arg0 : i32
    module.operation.print(file=f)
    str_value = f.getvalue()
    print(str_value.__class__)
    print(f.getvalue())
    # Test roundtrip to bytecode.
    bytecode_stream = io.BytesIO()
    module.operation.write_bytecode(bytecode_stream, desired_version=1)
    bytecode = bytecode_stream.getvalue()
    assert bytecode.startswith(b"ML\xefR"), "Expected bytecode to start with MLïR"
    # Parse the bytecode into a fresh context, then compare its textual form
    # with the text printed from the original module above.
    ctx2 = Context()
    module_roundtrip = Module.parse(bytecode, ctx2)
    f = io.StringIO()
    module_roundtrip.operation.print(file=f)
    roundtrip_value = f.getvalue()
    assert str_value == roundtrip_value, "Mismatch after roundtrip bytecode"
    # Test print to binary file.
    f = io.BytesIO()
    # CHECK: <class 'bytes'>
    # CHECK: return %arg0 : i32
    module.operation.print(file=f, binary=True)
    bytes_value = f.getvalue()
    print(bytes_value.__class__)
    print(bytes_value)
    # Test print local_scope.
    # CHECK: constant dense<[1, 2, 3, 4]> : tensor<4xi32> loc("nom")
    module.operation.print(enable_debug_info=True, use_local_scope=True)
    # Test printing using state.
    state = AsmState(module.operation)
    # CHECK: constant dense<[1, 2, 3, 4]> : tensor<4xi32>
    module.operation.print(state)
    # Test print with options.
    # With large_elements_limit=2, the 4-element dense constant is expected
    # to print as an elided resource and the resource1 blob to be omitted.
    # CHECK: value = dense_resource<__elided__> : tensor<4xi32>
    # CHECK: "func.return"(%arg0) : (i32) -> () -:5:7
    # CHECK-NOT: resource1: "0x08
    module.operation.print(
        large_elements_limit=2,
        enable_debug_info=True,
        pretty_debug_info=True,
        print_generic_op_form=True,
        use_local_scope=True,
    )
    # Test print with skip_regions option
    # CHECK: func.func @f1(%arg0: i32) -> i32
    # CHECK-NOT: func.return
    module.body.operations[0].print(
        skip_regions=True,
    )
# CHECK-LABEL: TEST: testKnownOpView
@run
def testKnownOpView():
    """Check which Python class wraps each parsed operation.

    - arith ops resolve to their generated or extension OpView classes.
    - Unregistered ops fall back to the plain ``OpView``, including on a
      repeated lookup.
    - A class registered afterwards with ``replace=True`` is used for ops
      fetched after the registration.
    """
    with Context(), Location.unknown():
        Context.current.allow_unregistered_dialects = True
        module = Module.parse(
            r"""
%1 = "custom.f32"() : () -> f32
%2 = "custom.f32"() : () -> f32
%3 = arith.addf %1, %2 : f32
%4 = arith.constant 0 : i32
"""
        )
        print(module)
        # addf should map to a known OpView class in the arithmetic dialect.
        # We know the OpView for it defines an 'lhs' attribute.
        addf = module.body.operations[2]
        # CHECK: <mlir.dialects._arith_ops_gen.AddFOp object
        print(repr(addf))
        # CHECK: "custom.f32"()
        print(addf.lhs)
        # One of the custom ops should resolve to the default OpView.
        custom = module.body.operations[0]
        # CHECK: OpView object
        print(repr(custom))
        # Check again to make sure negative caching works.
        custom = module.body.operations[0]
        # CHECK: OpView object
        print(repr(custom))
        # constant should map to an extension OpView class in the arithmetic dialect.
        constant = module.body.operations[3]
        # CHECK: <mlir.dialects.arith.ConstantOp object
        print(repr(constant))
        # Checks that the arith extension is being registered successfully
        # (literal_value is a property on the extension class but not on the default OpView).
        # CHECK: literal value 0
        print("literal value", constant.literal_value)
        # Checks that "late" registration/replacement (i.e., post all module loading/initialization)
        # is working correctly.
        # NOTE(review): nothing in this test restores the previous registration
        # afterwards. The replacement class stays registered for the remainder
        # of the process.
        @_cext.register_operation(arith._Dialect, replace=True)
        class ConstantOp(arith.ConstantOp):
            def __init__(self, result, value, *, loc=None, ip=None):
                # Wrap a plain int or float into the matching attribute for
                # `result`. Any other value is forwarded unchanged.
                if isinstance(value, int):
                    super().__init__(IntegerAttr.get(result, value), loc=loc, ip=ip)
                elif isinstance(value, float):
                    super().__init__(FloatAttr.get(result, value), loc=loc, ip=ip)
                else:
                    super().__init__(value, loc=loc, ip=ip)
        constant = module.body.operations[3]
        # CHECK: <__main__.testKnownOpView.<locals>.ConstantOp object
        print(repr(constant))
# CHECK-LABEL: TEST: testSingleResultProperty
@run
def testSingleResultProperty():
    """Check that ``.result`` is only usable on ops with exactly one result."""
    with Context(), Location.unknown():
        Context.current.allow_unregistered_dialects = True
        module = Module.parse(
            r"""
"custom.no_result"() : () -> ()
%0:2 = "custom.two_result"() : () -> (f32, f32)
%1 = "custom.one_result"() : () -> f32
"""
        )
        print(module)
        top_level_ops = module.body.operations

        # Ops 0 and 1 have zero and two results respectively, so `.result`
        # must raise ValueError naming the op and its result count.
        # CHECK: Cannot call .result on operation custom.no_result which has 0 results
        # CHECK: Cannot call .result on operation custom.two_result which has 2 results
        for bad_index in (0, 1):
            try:
                top_level_ops[bad_index].result
            except ValueError as e:
                print(e)
            else:
                assert False, "Expected exception"

        # CHECK: %1 = "custom.one_result"() : () -> f32
        print(top_level_ops[2])
def create_invalid_operation():
    """Build a ``builtin.module`` op that fails verification.

    ``builtin.module`` requires exactly one region. This op is created with
    two, the first holding one empty block. Tests use it to check that
    printing an invalid op falls back to the generic printer, for safety.

    ``Operation.create`` is called without ``loc=``. Callers in this file
    invoke the helper under ``with Location.unknown(...)``. If an insertion
    point is active, the op is inserted there, as ``Operation.create`` does.
    """
    invalid = Operation.create("builtin.module", regions=2)
    first_region = invalid.regions[0]
    first_region.blocks.append()
    return invalid
# CHECK-LABEL: TEST: testInvalidOperationStrSoftFails
@run
def testInvalidOperationStrSoftFails():
    """Printing an invalid op falls back to generic form; ``verify()`` raises.

    ``verify()`` must raise ``MLIRError`` whose text contains the diagnostic
    and a generic-form dump of the offending op. If it returns normally, the
    test now fails right away rather than only through missing FileCheck
    output.
    """
    ctx = Context()
    with Location.unknown(ctx):
        invalid_op = create_invalid_operation()
        # Verify that we fallback to the generic printer for safety.
        # CHECK: "builtin.module"() ({
        # CHECK: }) : () -> ()
        print(invalid_op)
        try:
            invalid_op.verify()
        except MLIRError as e:
            # CHECK: Exception: <
            # CHECK: Verification failed:
            # CHECK: error: unknown: 'builtin.module' op requires one region
            # CHECK: note: unknown: see current operation:
            # CHECK: "builtin.module"() ({
            # CHECK: ^bb0:
            # CHECK: }, {
            # CHECK: }) : () -> ()
            # CHECK: >
            print(f"Exception: <{e}>")
        else:
            assert False, "expected verify() of an invalid operation to raise MLIRError"
# CHECK-LABEL: TEST: testInvalidModuleStrSoftFails
@run
def testInvalidModuleStrSoftFails():
    """A module containing an invalid op still prints, in generic form."""
    context = Context()
    with Location.unknown(context):
        outer = Module.create()
        # Operation.create inside the helper runs under this insertion point,
        # so the invalid op is placed in `outer`'s body.
        with InsertionPoint(outer.body):
            nested_invalid = create_invalid_operation()
        # Verify that we fallback to the generic printer for safety.
        # CHECK: "builtin.module"() ({
        # CHECK: }) : () -> ()
        print(outer)
# CHECK-LABEL: TEST: testInvalidOperationGetAsmBinarySoftFails
@run
def testInvalidOperationGetAsmBinarySoftFails():
    """``get_asm(binary=True)`` on an invalid op returns generic-form bytes."""
    context = Context()
    with Location.unknown(context):
        bad_op = create_invalid_operation()
        asm_bytes = bad_op.get_asm(binary=True)
        # Verify that we fallback to the generic printer for safety.
        # CHECK: b'"builtin.module"() ({\n^bb0:\n}, {\n}) : () -> ()\n'
        print(asm_bytes)
# CHECK-LABEL: TEST: testCreateWithInvalidAttributes
@run
def testCreateWithInvalidAttributes():
    """Check the error messages for malformed ``attributes=`` dicts.

    Four cases are covered, in order:
      - a ``None`` key;
      - an integer key;
      - a value that is not an attribute (a ``Context``);
      - a ``None`` value.
    The exception message from each ``Operation.create`` call is printed.
    """
    ctx = Context()
    with Location.unknown(ctx):
        # CHECK: Invalid attribute key (not a string) when attempting to create the operation "builtin.module"
        # CHECK: Invalid attribute key (not a string) when attempting to create the operation "builtin.module"
        # CHECK: Invalid attribute value for the key "some_key" when attempting to create the operation "builtin.module"
        # CHECK: Found an invalid (`None`?) attribute value for the key "some_key" when attempting to create the operation "builtin.module"
        bad_attribute_dicts = [
            {None: StringAttr.get("name")},
            {42: StringAttr.get("name")},
            {"some_key": ctx},
            {"some_key": None},
        ]
        for attrs in bad_attribute_dicts:
            try:
                Operation.create("builtin.module", attributes=attrs)
            except Exception as e:
                print(e)
# CHECK-LABEL: TEST: testOperationName
@run
def testOperationName():
    """Print ``Operation.name`` for each top-level op, in block order."""
    context = Context()
    context.allow_unregistered_dialects = True
    module = Module.parse(
        r"""
%0 = "custom.op1"() : () -> f32
%1 = "custom.op2"() : () -> i32
%2 = "custom.op1"() : () -> f32
""",
        context,
    )
    op_names = [child.operation.name for child in module.body.operations]
    # CHECK: custom.op1
    # CHECK: custom.op2
    # CHECK: custom.op1
    for name in op_names:
        print(name)
# CHECK-LABEL: TEST: testCapsuleConversions
@run
def testCapsuleConversions():
    """Round-trip an Operation through its C-API capsule.

    ``Operation._CAPICreate`` on the capsule from ``_CAPIPtr`` must return
    the very same Python object, not a new wrapper.
    """
    context = Context()
    context.allow_unregistered_dialects = True
    with Location.unknown(context):
        original = Operation.create("custom.op1").operation
        capsule = original._CAPIPtr
        assert '"mlir.ir.Operation._CAPIPtr"' in repr(capsule)
        restored = Operation._CAPICreate(capsule)
        assert restored is original
# CHECK-LABEL: TEST: testOperationErase
@run
def testOperationErase():
    """Erase an op inserted into a module, then create another op."""
    ctx = Context()
    ctx.allow_unregistered_dialects = True
    with Location.unknown(ctx):
        m = Module.create()
        with InsertionPoint(m.body):
            # Created under the insertion point, so it lands in m.body.
            op = Operation.create("custom.op1")
            # CHECK: "custom.op1"
            print(m)
            op.operation.erase()
            # CHECK-NOT: "custom.op1"
            print(m)
            # Ensure we can create another operation
            Operation.create("custom.op2")
# CHECK-LABEL: TEST: testOperationClone
@run
def testOperationClone():
    """Clone an op, erase the original, and expect the clone to remain."""
    ctx = Context()
    ctx.allow_unregistered_dialects = True
    with Location.unknown(ctx):
        m = Module.create()
        with InsertionPoint(m.body):
            op = Operation.create("custom.op1")
            # CHECK: "custom.op1"
            print(m)
            # `clone` is not read again. The clone is expected to be placed at
            # the active insertion point (m.body), which is why "custom.op1"
            # is still printed after the original is erased.
            clone = op.operation.clone()
            op.operation.erase()
            # CHECK: "custom.op1"
            print(m)
# CHECK-LABEL: TEST: testOperationLoc
@run
def testOperationLoc():
    """The ``loc=`` passed to Operation.create is reported by ``.location``."""
    context = Context()
    context.allow_unregistered_dialects = True
    with context:
        named = Location.name("loc")
        created = Operation.create("custom.op", loc=named)
        for view in (created, created.operation):
            assert view.location == named
# CHECK-LABEL: TEST: testModuleMerge
@run
def testModuleMerge():
    """Move ops between modules with ``move_before`` / ``move_after``."""
    with Context():
        target = Module.parse("func.func private @foo()")
        source = Module.parse(
            """
func.func private @bar()
func.func private @qux()
"""
        )
        anchor = target.body.operations[0]
        first, second = source.body.operations[0], source.body.operations[1]
        # Put @bar before @foo and @qux after it, leaving `source` empty.
        first.move_before(anchor)
        second.move_after(anchor)
        # CHECK: module
        # CHECK: func private @bar
        # CHECK: func private @foo
        # CHECK: func private @qux
        print(target)
        # CHECK: module {
        # CHECK-NEXT: }
        print(source)
# CHECK-LABEL: TEST: testAppendMoveFromAnotherBlock
@run
def testAppendMoveFromAnotherBlock():
    """Appending an attached op to another block moves it out of its old one."""
    with Context():
        donor = Module.parse("func.func private @foo()")
        receiver = Module.parse("func.func private @bar()")
        moved = donor.body.operations[0]
        receiver.body.append(moved)
        # CHECK: module
        # CHECK: func private @bar
        # CHECK: func private @foo
        print(receiver)
        # CHECK: module {
        # CHECK-NEXT: }
        print(donor)
# CHECK-LABEL: TEST: testDetachFromParent
@run
def testDetachFromParent():
    """Detach an op from its block; detaching it a second time must fail.

    The second ``detach_from_parent`` must raise a ``ValueError`` whose
    message contains "has no parent". Any other ``ValueError`` is re-raised.
    """
    with Context():
        m1 = Module.parse("func.func private @foo()")
        func = m1.body.operations[0].detach_from_parent()
        try:
            func.detach_from_parent()
        except ValueError as e:
            if "has no parent" not in str(e):
                raise
        else:
            assert False, "expected ValueError when detaching a detached operation"
        # The detached function must no longer appear in the module.
        print(m1)
        # CHECK-NOT: func private @foo
# CHECK-LABEL: TEST: testOperationHash
@run
def testOperationHash():
    """An op returned by Operation.create hashes like its ``.operation``."""
    context = Context()
    context.allow_unregistered_dialects = True
    with context, Location.unknown():
        created = Operation.create("custom.op1")
        assert hash(created.operation) == hash(created)
# CHECK-LABEL: TEST: testOperationParse
@run
def testOperationParse():
    """Exercise ``Operation.parse`` and ``ModuleOp.parse``.

    - Generic parsing returns a ``ModuleOp`` for a module and a plain
      ``OpView`` for an unregistered op.
    - Class-specific parsing rejects an op of another kind with ``MLIRError``.
    - ``source_name`` shows up in the parsed op's location.
    """
    with Context() as ctx:
        ctx.allow_unregistered_dialects = True
        # Generic operation parsing.
        m = Operation.parse("module {}")
        o = Operation.parse('"test.foo"() : () -> ()')
        assert isinstance(m, ModuleOp)
        # Exact type check: an unregistered op must not get a subclass view.
        assert type(o) is OpView
        # Parsing specific operation.
        m = ModuleOp.parse("module {}")
        assert isinstance(m, ModuleOp)
        try:
            ModuleOp.parse('"test.foo"() : () -> ()')
        except MLIRError as e:
            # CHECK: error: Expected a 'builtin.module' op, got: 'test.foo'
            print(f"error: {e}")
        else:
            assert False, "expected error"
        o = Operation.parse('"test.foo"() : () -> ()', source_name="my-source-string")
        # CHECK: op_with_source_name: "test.foo"() : () -> () loc("my-source-string":1:1)
        print(
            f"op_with_source_name: {o.get_asm(enable_debug_info=True, use_local_scope=True)}"
        )
# CHECK-LABEL: TEST: testOpWalk
@run
def testOpWalk():
    """Exercise ``Operation.walk`` orders and ``WalkResult`` control values.

    Covers:
      - the default post-order walk;
      - a pre-order walk;
      - ``INTERRUPT``, which stops after the first visited op;
      - ``SKIP`` returned in a pre-order walk;
      - a callback that raises, which the caller sees as ``RuntimeError``.
    """
    ctx = Context()
    ctx.allow_unregistered_dialects = True
    module = Module.parse(
        r"""
builtin.module {
func.func @f() {
func.return
}
}
""",
        ctx,
    )

    def print_and_advance(op):
        print(op.name)
        return WalkResult.ADVANCE

    # Test post-order walk (default).
    # CHECK-NEXT: Post-order
    # CHECK-NEXT: func.return
    # CHECK-NEXT: func.func
    # CHECK-NEXT: builtin.module
    print("Post-order")
    module.operation.walk(print_and_advance)

    # Test pre-order walk.
    # CHECK-NEXT: Pre-order
    # CHECK-NEXT: builtin.module
    # CHECK-NEXT: func.func
    # CHECK-NEXT: func.return
    print("Pre-order")
    module.operation.walk(print_and_advance, WalkOrder.PRE_ORDER)

    # Test interrupt.
    # CHECK-NEXT: Interrupt post-order
    # CHECK-NEXT: func.return
    print("Interrupt post-order")

    def print_and_interrupt(op):
        print(op.name)
        return WalkResult.INTERRUPT

    module.operation.walk(print_and_interrupt)

    # Test skip.
    # CHECK-NEXT: Skip pre-order
    # CHECK-NEXT: builtin.module
    print("Skip pre-order")

    def print_and_skip(op):
        print(op.name)
        return WalkResult.SKIP

    module.operation.walk(print_and_skip, WalkOrder.PRE_ORDER)

    # Test exception.
    # CHECK: Exception
    # CHECK-NEXT: func.return
    # CHECK-NEXT: Exception raised
    print("Exception")

    def print_and_raise(op):
        print(op.name)
        raise ValueError

    # The binding is expected to re-raise a callback's exception as
    # RuntimeError.
    try:
        module.operation.walk(print_and_raise)
    except RuntimeError:
        print("Exception raised")
    else:
        assert False, "expected walk() to propagate the callback's exception"