While working on an integration, I found a lot of inconsistencies in IR printing and verification. It turns out that:

* We were only doing "soft fail" verification on IR printing of an Operation, not of a Module.
* Failed verification was interacting badly with binary=True IR printing (causing a TypeError when trying to pass a `str` to a `bytes`-based handle).
* For systematic integrations, it is often desirable to control verification yourself so that you can explicitly handle errors.

This patch:

* Trues up the "soft fail" semantics by having `Module.__str__` delegate to `Operation.__str__` instead of having a shortcut implementation.
* Fixes soft fail in the presence of binary=True (and adds an additional happy-path test case to make sure the binary functionality works).
* Adds an `assume_verified` boolean flag to the `print`/`get_asm` methods which disables internal verification, presupposing that the caller has taken care of it.

It turns out that we had a number of tests which were generating illegal IR, but it wasn't being caught because they were printing the `Module` rather than the operation. All except two were trivially fixed:

* linalg/ops.py: Had two tests that directly constructed a Matmul incorrectly. Fixing them made them just like the next two tests, so they were deleted (no need to test the verifier only at this level).
* linalg/opdsl/emit_structured_generic.py: Hand-coded conv and pooling tests appear to be using illegally shaped inputs/outputs, causing a verification failure. I just used the `assume_verified=` flag to restore the original behavior and left a TODO. Will get someone who owns that to fix it properly in a follow-up (it would also be nice to break this file up into multiple test modules, as it is hard to tell exactly what is failing).

Notes to downstreams:

* If, like some of our tests, you get verification failures after this patch, it is likely that your IR was always invalid and you will need to fix the root cause. To temporarily revert to the prior (broken) behavior, replace calls like `print(module)` with `print(module.operation.get_asm(assume_verified=True))`.

Differential Revision: https://reviews.llvm.org/D114680
224 lines
7.3 KiB
Python
224 lines
7.3 KiB
Python
# RUN: %PYTHON %s | FileCheck %s
|
|
|
|
from mlir.ir import *
|
|
import mlir.dialects.builtin as builtin
|
|
import mlir.dialects.std as std
|
|
|
|
|
|
def run(f):
  """Print a `TEST: <name>` banner for `f`, call it once, and return it.

  Intended for use as a decorator, so each test executes at definition time.
  The banner is the text the `CHECK-LABEL: TEST: ...` directives in this file
  match against.
  """
  print(f"\nTEST: {f.__name__}")
  f()
  return f
|
|
|
|
|
|
# CHECK-LABEL: TEST: testFromPyFunc
@run
def testFromPyFunc():
  """Capture Python functions as `func` ops via `FuncOp.from_py_func`.

  Each decorated function below is emitted into a single module, which is
  printed at the end so the CHECK directives can match the generated IR.
  """
  with Context() as context, Location.unknown():
    # Presumably required because "custom.op1" below is created without a
    # registered dialect.
    context.allow_unregistered_dialects = True
    module = builtin.ModuleOp()
    float32 = F32Type.get()
    float64 = F64Type.get()
    with InsertionPoint(module.body):
      # CHECK-LABEL: func @unary_return(%arg0: f64) -> f64
      # CHECK: return %arg0 : f64
      @builtin.FuncOp.from_py_func(float64)
      def unary_return(x):
        return x

      # CHECK-LABEL: func @binary_return(%arg0: f32, %arg1: f64) -> (f32, f64)
      # CHECK: return %arg0, %arg1 : f32, f64
      @builtin.FuncOp.from_py_func(float32, float64)
      def binary_return(x, y):
        return x, y

      # CHECK-LABEL: func @none_return(%arg0: f32, %arg1: f64)
      # CHECK: return
      @builtin.FuncOp.from_py_func(float32, float64)
      def none_return(x, y):
        return None

      # CHECK-LABEL: func @call_unary
      # CHECK: %0 = call @unary_return(%arg0) : (f64) -> f64
      # CHECK: return %0 : f64
      @builtin.FuncOp.from_py_func(float64)
      def call_unary(x):
        return unary_return(x)

      # CHECK-LABEL: func @call_binary
      # CHECK: %0:2 = call @binary_return(%arg0, %arg1) : (f32, f64) -> (f32, f64)
      # CHECK: return %0#0, %0#1 : f32, f64
      @builtin.FuncOp.from_py_func(float32, float64)
      def call_binary(x, y):
        return binary_return(x, y)

      # We expect coercion of a single result operation to a returned value.
      # CHECK-LABEL: func @single_result_op
      # CHECK: %0 = "custom.op1"() : () -> f32
      # CHECK: return %0 : f32
      @builtin.FuncOp.from_py_func()
      def single_result_op():
        return Operation.create("custom.op1", results=[float32])

      # CHECK-LABEL: func @call_none
      # CHECK: call @none_return(%arg0, %arg1) : (f32, f64) -> ()
      # CHECK: return
      @builtin.FuncOp.from_py_func(float32, float64)
      def call_none(x, y):
        return none_return(x, y)

      ## Variants and optional feature tests.
      # CHECK-LABEL: func @from_name_arg
      @builtin.FuncOp.from_py_func(float32, float64, name="from_name_arg")
      def explicit_name(x, y):
        return y

      # The parameter name `func_op` is significant: the captured function
      # asserts it receives the FuncOp being built.
      @builtin.FuncOp.from_py_func(float32, float64)
      def positional_func_op(x, y, func_op):
        assert isinstance(func_op, builtin.FuncOp)
        return y

      @builtin.FuncOp.from_py_func(float32, float64)
      def kw_func_op(x, y=None, func_op=None):
        assert isinstance(func_op, builtin.FuncOp)
        return y

      @builtin.FuncOp.from_py_func(float32, float64)
      def kwargs_func_op(x, y=None, **kwargs):
        assert isinstance(kwargs["func_op"], builtin.FuncOp)
        return y

      # CHECK-LABEL: func @explicit_results(%arg0: f32, %arg1: f64) -> f64
      # CHECK: return %arg1 : f64
      @builtin.FuncOp.from_py_func(float32, float64, results=[float64])
      def explicit_results(x, y):
        # With explicit `results=`, the terminator is emitted by hand and the
        # Python function itself returns None.
        std.ReturnOp([y])

    print(module)
|
|
|
|
|
|
# CHECK-LABEL: TEST: testFromPyFuncErrors
@run
def testFromPyFuncErrors():
  """Check the error printed when `results=` is combined with a returned value."""
  with Context(), Location.unknown():
    module = builtin.ModuleOp()
    # Not used by this test; kept so the setup sequence matches the original.
    float32 = F32Type.get()
    float64 = F64Type.get()
    with InsertionPoint(module.body):
      try:
        @builtin.FuncOp.from_py_func(float64, results=[float64])
        def unary_return(x):
          return x
      except AssertionError as error:
        # CHECK: Capturing a python function with explicit `results=` requires that the wrapped function returns None.
        print(error)
|
|
|
|
|
|
# CHECK-LABEL: TEST: testBuildFuncOp
@run
def testBuildFuncOp():
  """Construct `FuncOp`s directly and check their accessors and entry blocks.

  Covers the `name`/`type`/`visibility` accessors, the `IndexError`s printed
  when reading a missing entry block or adding a second one, and the
  `body_builder` callback form with the type given as a tuple.
  """
  context = Context()
  with Location.unknown(context):
    module = builtin.ModuleOp()

    element_type = F32Type.get()
    tensor_t = RankedTensorType.get((2, 3, 4), element_type)
    with InsertionPoint.at_block_begin(module.body):
      some_func_type = FunctionType.get(
          inputs=[tensor_t, tensor_t],
          results=[tensor_t])
      some_func = builtin.FuncOp(
          name="some_func",
          type=some_func_type,
          visibility="nested")
      # CHECK: Name is: "some_func"
      print("Name is: ", some_func.name)

      # CHECK: Type is: (tensor<2x3x4xf32>, tensor<2x3x4xf32>) -> tensor<2x3x4xf32>
      print("Type is: ", some_func.type)

      # CHECK: Visibility is: "nested"
      print("Visibility is: ", some_func.visibility)

      # No entry block has been added yet, so this access is expected to raise.
      try:
        _ = some_func.entry_block
      except IndexError as error:
        # CHECK: External function does not have a body
        print(error)

      entry = some_func.add_entry_block()
      with InsertionPoint(entry):
        std.ReturnOp([some_func.entry_block.arguments[0]])

      # A second entry block is expected to be rejected.
      try:
        some_func.add_entry_block()
      except IndexError as error:
        # CHECK: The function already has an entry block!
        print(error)

      # Try the callback builder and passing type as tuple.
      def return_first_argument(func):
        return std.ReturnOp([func.entry_block.arguments[0]])

      other_func = builtin.FuncOp(
          name="some_other_func",
          type=([tensor_t, tensor_t], [tensor_t]),
          visibility="nested",
          body_builder=return_first_argument)

  # CHECK: module {
  # CHECK: func nested @some_func(%arg0: tensor<2x3x4xf32>, %arg1: tensor<2x3x4xf32>) -> tensor<2x3x4xf32> {
  # CHECK: return %arg0 : tensor<2x3x4xf32>
  # CHECK: }
  # CHECK: func nested @some_other_func(%arg0: tensor<2x3x4xf32>, %arg1: tensor<2x3x4xf32>) -> tensor<2x3x4xf32> {
  # CHECK: return %arg0 : tensor<2x3x4xf32>
  # CHECK: }
  print(module)
|
|
|
|
|
|
# CHECK-LABEL: TEST: testFuncArgumentAccess
@run
def testFuncArgumentAccess():
  """Set `arg_attrs`/`result_attrs` on `FuncOp`s and check how they print.

  `some_func` gets its attributes as `ArrayAttr`s; `other_func` gets its
  argument attributes as a plain Python list of `DictAttr`s.
  """
  with Context() as context, Location.unknown():
    # Presumably required for the unregistered `custom_dialect.*` attribute
    # names used below.
    context.allow_unregistered_dialects = True
    module = Module.create()
    float32 = F32Type.get()
    float64 = F64Type.get()
    with InsertionPoint(module.body):
      some_func = builtin.FuncOp(
          "some_func", ([float32, float32], [float32, float32]))
      with InsertionPoint(some_func.add_entry_block()):
        std.ReturnOp(some_func.arguments)

      first_arg_attrs = DictAttr.get({
          "custom_dialect.foo": StringAttr.get("bar"),
          "custom_dialect.baz": UnitAttr.get(),
      })
      second_arg_attrs = DictAttr.get({"custom_dialect.qux": ArrayAttr.get([])})
      some_func.arg_attrs = ArrayAttr.get([first_arg_attrs, second_arg_attrs])

      first_result_attrs = DictAttr.get(
          {"custom_dialect.res1": FloatAttr.get(float32, 42.0)})
      second_result_attrs = DictAttr.get(
          {"custom_dialect.res2": FloatAttr.get(float64, 256.0)})
      some_func.result_attrs = ArrayAttr.get(
          [first_result_attrs, second_result_attrs])

      other_func = builtin.FuncOp("other_func", ([float32, float32], []))
      with InsertionPoint(other_func.add_entry_block()):
        std.ReturnOp([])

      # Assigned as a plain list (not an ArrayAttr), as in the original test.
      other_func.arg_attrs = [
          DictAttr.get({"custom_dialect.foo": StringAttr.get("qux")}),
          DictAttr.get(),
      ]

    # CHECK: [{custom_dialect.baz, custom_dialect.foo = "bar"}, {custom_dialect.qux = []}]
    print(some_func.arg_attrs)

    # CHECK: [{custom_dialect.res1 = 4.200000e+01 : f32}, {custom_dialect.res2 = 2.560000e+02 : f64}]
    print(some_func.result_attrs)

    # CHECK: func @some_func(
    # CHECK: %[[ARG0:.*]]: f32 {custom_dialect.baz, custom_dialect.foo = "bar"},
    # CHECK: %[[ARG1:.*]]: f32 {custom_dialect.qux = []}) ->
    # CHECK: f32 {custom_dialect.res1 = 4.200000e+01 : f32},
    # CHECK: f32 {custom_dialect.res2 = 2.560000e+02 : f64})
    # CHECK: return %[[ARG0]], %[[ARG1]] : f32, f32
    #
    # CHECK: func @other_func(
    # CHECK: %{{.*}}: f32 {custom_dialect.foo = "qux"},
    # CHECK: %{{.*}}: f32)
    print(module)
|