Prefer to get the path to libmlir_runner_utils and libmlir_c_runner_utils via %mlir_runner_utils and %mlir_c_runner_utils. Fall back to the previous paths only if they aren't defined. This ensures the test will pass regardless of the build configuration used downstream.
803 lines
26 KiB
Python
803 lines
26 KiB
Python
# RUN: env MLIR_RUNNER_UTILS=%mlir_runner_utils MLIR_C_RUNNER_UTILS=%mlir_c_runner_utils %PYTHON %s 2>&1 | FileCheck %s
|
|
# REQUIRES: host-supports-jit
|
|
import gc, sys, os, tempfile
|
|
from mlir.ir import *
|
|
from mlir.passmanager import *
|
|
from mlir.execution_engine import *
|
|
from mlir.runtime import *
|
|
from ml_dtypes import bfloat16, float8_e5m2
|
|
|
|
# Locations of the runner-utility shared libraries. The values come from the
# environment when the variables are defined; the relative ".so" paths are
# only used as a fallback when they are not.
MLIR_RUNNER_UTILS = os.environ.get(
    "MLIR_RUNNER_UTILS", "../../../../lib/libmlir_runner_utils.so"
)
MLIR_C_RUNNER_UTILS = os.environ.get(
    "MLIR_C_RUNNER_UTILS", "../../../../lib/libmlir_c_runner_utils.so"
)
|
|
|
|
# Log everything to stderr and flush so that we have a unified stream to match
|
|
# errors/info emitted by MLIR to stderr.
|
|
def log(*args):
    """Print ``args`` to stderr, space-separated, and flush right away.

    Writing to stderr and flushing after every call keeps this script's
    messages in a single stream, in emission order, alongside anything else
    written to stderr.
    """
    print(*args, file=sys.stderr, flush=True)
|
|
|
|
|
|
def run(f):
    """Execute one test function and verify that no MLIR Context outlives it.

    Args:
        f: Zero-argument callable; its ``__name__`` is logged as the test
            banner before it is called.
    """
    test_name = f.__name__
    log("\nTEST:", test_name)
    f()
    # Collect garbage first so that only genuinely leaked contexts are counted.
    gc.collect()
    live_contexts = Context._get_live_count()
    assert live_contexts == 0
|
|
|
|
|
|
# Verify capsule interop.
|
|
# CHECK-LABEL: TEST: testCapsule
|
|
def testCapsule():
    """Round-trip an ExecutionEngine through its ``_CAPIPtr`` capsule."""
    llvm_ir = r"""
llvm.func @none() {
  llvm.return
}
    """
    with Context():
        module = Module.parse(llvm_ir)
        engine = ExecutionEngine(module)
        capsule = engine._CAPIPtr
        # CHECK: mlir.execution_engine.ExecutionEngine._CAPIPtr
        log(repr(capsule))
        # NOTE(review): presumably releases ownership so the capsule can be
        # wrapped again below -- confirm against the bindings.
        engine._testing_release()
        rewrapped_engine = ExecutionEngine._CAPICreate(capsule)
        # CHECK: _mlirExecutionEngine.ExecutionEngine
        log(repr(rewrapped_engine))
|
|
|
|
|
|
run(testCapsule)
|
|
|
|
|
|
# Test invalid ExecutionEngine creation
|
|
# CHECK-LABEL: TEST: testInvalidModule
|
|
def testInvalidModule():
    """Log the RuntimeError raised when an engine is built from un-lowered IR."""
    # Builtin (func dialect) function that has not been lowered to LLVM.
    builtin_func_ir = r"""
func.func @foo() { return }
    """
    with Context():
        module = Module.parse(builtin_func_ir)
        # CHECK: Got RuntimeError: Failure while creating the ExecutionEngine.
        try:
            engine = ExecutionEngine(module)
        except RuntimeError as err:
            log("Got RuntimeError: ", err)
|
|
|
|
|
|
run(testInvalidModule)
|
|
|
|
|
|
def lowerToLLVM(module):
    """Run the LLVM-lowering pass pipeline on ``module`` and return it.

    Args:
        module: The module whose ``operation`` the pipeline is run on.

    Returns:
        The same ``module`` object that was passed in.
    """
    pipeline = "builtin.module(convert-complex-to-llvm,finalize-memref-to-llvm,convert-func-to-llvm,reconcile-unrealized-casts)"
    pass_manager = PassManager.parse(pipeline)
    pass_manager.run(module.operation)
    return module
|
|
|
|
|
|
# Test simple ExecutionEngine execution
|
|
# CHECK-LABEL: TEST: testInvokeVoid
|
|
def testInvokeVoid():
    """Invoke a JIT-compiled function that takes and returns nothing."""
    void_ir = r"""
func.func @void() attributes { llvm.emit_c_interface } {
  return
}
    """
    with Context():
        module = Module.parse(void_ir)
        engine = ExecutionEngine(lowerToLLVM(module))
        # Success here simply means that the call does not raise.
        engine.invoke("void")
|
|
|
|
|
|
run(testInvokeVoid)
|
|
|
|
|
|
# Test argument passing and result with a simple float addition.
|
|
# CHECK-LABEL: TEST: testInvokeFloatAdd
|
|
def testInvokeFloatAdd():
    """Add two f32 values through the engine and log the result."""
    add_ir = r"""
func.func @add(%arg0: f32, %arg1: f32) -> f32 attributes { llvm.emit_c_interface } {
  %add = arith.addf %arg0, %arg1 : f32
  return %add : f32
}
    """
    with Context():
        module = Module.parse(add_ir)
        engine = ExecutionEngine(lowerToLLVM(module))
        # Inputs and the result slot are passed as pointers, so each one lives
        # in a single-element ctypes array.
        float_cell = ctypes.c_float * 1
        lhs = float_cell(42.0)
        rhs = float_cell(2.0)
        out = float_cell(-1.0)
        engine.invoke("add", lhs, rhs, out)
        # CHECK: 42.0 + 2.0 = 44.0
        log("{0} + {1} = {2}".format(lhs[0], rhs[0], out[0]))
|
|
|
|
|
|
run(testInvokeFloatAdd)
|
|
|
|
|
|
# Test callback
|
|
# CHECK-LABEL: TEST: testBasicCallback
|
|
def testBasicCallback():
    """Call back into Python from JIT-compiled code with scalar arguments."""

    # Python side of the callback: takes a float and an int, returns a float.
    def half_sum(a, b):
        return a / 2 + b / 2

    callback = ctypes.CFUNCTYPE(ctypes.c_float, ctypes.c_float, ctypes.c_int)(
        half_sum
    )

    # @add only forwards its arguments to the runtime symbol
    # "some_callback_into_python" and returns what that symbol returns.
    forwarding_ir = r"""
func.func @add(%arg0: f32, %arg1: i32) -> f32 attributes { llvm.emit_c_interface } {
  %resf = call @some_callback_into_python(%arg0, %arg1) : (f32, i32) -> (f32)
  return %resf : f32
}
func.func private @some_callback_into_python(f32, i32) -> f32 attributes { llvm.emit_c_interface }
    """
    with Context():
        module = Module.parse(forwarding_ir)
        engine = ExecutionEngine(lowerToLLVM(module))
        engine.register_runtime("some_callback_into_python", callback)

        # One float input, one int input and one float result slot, each
        # passed as a pointer via a single-element ctypes array.
        float_cell = ctypes.c_float * 1
        int_cell = ctypes.c_int * 1
        float_in = float_cell(42.0)
        int_in = int_cell(2)
        out = float_cell(-1.0)
        engine.invoke("add", float_in, int_in, out)
        # The callback halves both inputs, so the logged result is doubled.
        # CHECK: 42.0 + 2 = 44.0
        log("{0} + {1} = {2}".format(float_in[0], int_in[0], out[0] * 2))
|
|
|
|
|
|
run(testBasicCallback)
|
|
|
|
|
|
# Test callback with an unranked memref
|
|
# CHECK-LABEL: TEST: testUnrankedMemRefCallback
|
|
def testUnrankedMemRefCallback():
    """Pass unranked f32 memrefs to a Python callback that logs them."""

    # Python side of the callback: convert the unranked memref descriptor to
    # a numpy array and log it.
    def log_unranked(descriptor):
        arr = unranked_memref_to_numpy(descriptor, np.float32)
        log("Inside callback: ")
        log(arr)

    callback = ctypes.CFUNCTYPE(None, ctypes.POINTER(UnrankedMemRefDescriptor))(
        log_unranked
    )

    # @callback_memref only forwards its argument to the runtime symbol
    # "some_callback_into_python".
    forwarding_ir = r"""
func.func @callback_memref(%arg0: memref<*xf32>) attributes { llvm.emit_c_interface } {
  call @some_callback_into_python(%arg0) : (memref<*xf32>) -> ()
  return
}
func.func private @some_callback_into_python(memref<*xf32>) -> () attributes { llvm.emit_c_interface }
    """
    with Context():
        module = Module.parse(forwarding_ir)
        engine = ExecutionEngine(lowerToLLVM(module))
        engine.register_runtime("some_callback_into_python", callback)

        dense = np.array([[1.0, 2.0], [3.0, 4.0]], np.float32)
        dense_arg = ctypes.pointer(
            ctypes.pointer(get_unranked_memref_descriptor(dense))
        )
        # CHECK: Inside callback:
        # CHECK{LITERAL}: [[1. 2.]
        # CHECK{LITERAL}: [3. 4.]]
        engine.invoke("callback_memref", dense_arg)

        # A stride of 0 along the second axis repeats each source element
        # across its row.
        base = np.array([5, 6, 7], dtype=np.float32)
        broadcast_view = np.lib.stride_tricks.as_strided(
            base, strides=(4, 0), shape=(3, 4)
        )
        broadcast_arg = ctypes.pointer(
            ctypes.pointer(get_unranked_memref_descriptor(broadcast_view))
        )
        # CHECK: Inside callback:
        # CHECK{LITERAL}: [[5. 5. 5. 5.]
        # CHECK{LITERAL}: [6. 6. 6. 6.]
        # CHECK{LITERAL}: [7. 7. 7. 7.]]
        engine.invoke("callback_memref", broadcast_arg)
|
|
|
|
|
|
run(testUnrankedMemRefCallback)
|
|
|
|
|
|
# Test callback with a ranked memref.
|
|
# CHECK-LABEL: TEST: testRankedMemRefCallback
|
|
def testRankedMemRefCallback():
    """Pass a ranked 2x2 f32 memref to a Python callback that logs it."""

    # Python side of the callback: convert the ranked memref descriptor to a
    # numpy array and log it.
    def log_ranked(descriptor):
        arr = ranked_memref_to_numpy(descriptor)
        log("Inside Callback: ")
        log(arr)

    descriptor_type = make_nd_memref_descriptor(
        2, np.ctypeslib.as_ctypes_type(np.float32)
    )
    callback = ctypes.CFUNCTYPE(None, ctypes.POINTER(descriptor_type))(log_ranked)

    # @callback_memref only forwards its argument to the runtime symbol
    # "some_callback_into_python".
    forwarding_ir = r"""
func.func @callback_memref(%arg0: memref<2x2xf32>) attributes { llvm.emit_c_interface } {
  call @some_callback_into_python(%arg0) : (memref<2x2xf32>) -> ()
  return
}
func.func private @some_callback_into_python(memref<2x2xf32>) -> () attributes { llvm.emit_c_interface }
    """
    with Context():
        module = Module.parse(forwarding_ir)
        engine = ExecutionEngine(lowerToLLVM(module))
        engine.register_runtime("some_callback_into_python", callback)

        matrix = np.array([[1.0, 5.0], [6.0, 7.0]], np.float32)
        matrix_arg = ctypes.pointer(
            ctypes.pointer(get_ranked_memref_descriptor(matrix))
        )
        # CHECK: Inside Callback:
        # CHECK{LITERAL}: [[1. 5.]
        # CHECK{LITERAL}: [6. 7.]]
        engine.invoke("callback_memref", matrix_arg)
|
|
|
|
|
|
run(testRankedMemRefCallback)
|
|
|
|
|
|
# Test callback with a ranked memref with non-zero offset.
|
|
# CHECK-LABEL: TEST: testRankedMemRefWithOffsetCallback
|
|
def testRankedMemRefWithOffsetCallback():
    """Pass a ranked memref with a non-zero offset to a Python callback."""

    # Python side of the callback: convert the ranked memref descriptor to a
    # numpy array and log it.
    def log_ranked(descriptor):
        arr = ranked_memref_to_numpy(descriptor)
        log("Inside Callback: ")
        log(arr)

    descriptor_type = make_nd_memref_descriptor(
        1, np.ctypeslib.as_ctypes_type(np.float32)
    )
    callback = ctypes.CFUNCTYPE(None, ctypes.POINTER(descriptor_type))(log_ranked)

    # @callback_memref reinterprets its argument as the two elements starting
    # at offset 3 and hands that view to the callback.
    subview_ir = r"""
func.func @callback_memref(%arg0: memref<5xf32>) attributes {llvm.emit_c_interface} {
  %base_buffer, %offset, %sizes, %strides = memref.extract_strided_metadata %arg0 : memref<5xf32> -> memref<f32>, index, index, index
  %reinterpret_cast = memref.reinterpret_cast %base_buffer to offset: [3], sizes: [2], strides: [1] : memref<f32> to memref<2xf32, strided<[1], offset: 3>>
  %cast = memref.cast %reinterpret_cast : memref<2xf32, strided<[1], offset: 3>> to memref<?xf32, strided<[?], offset: ?>>
  call @some_callback_into_python(%cast) : (memref<?xf32, strided<[?], offset: ?>>) -> ()
  return
}
func.func private @some_callback_into_python(memref<?xf32, strided<[?], offset: ?>>) attributes {llvm.emit_c_interface}
    """
    with Context():
        module = Module.parse(subview_ir)
        engine = ExecutionEngine(lowerToLLVM(module))
        engine.register_runtime("some_callback_into_python", callback)

        source = np.array([0, 0, 0, 1, 2], np.float32)
        source_arg = ctypes.pointer(
            ctypes.pointer(get_ranked_memref_descriptor(source))
        )
        # CHECK: Inside Callback:
        # CHECK{LITERAL}: [1. 2.]
        engine.invoke("callback_memref", source_arg)
|
|
|
|
|
|
run(testRankedMemRefWithOffsetCallback)
|
|
|
|
|
|
# Test callback with an unranked memref with non-zero offset
|
|
# CHECK-LABEL: TEST: testUnrankedMemRefWithOffsetCallback
|
|
def testUnrankedMemRefWithOffsetCallback():
    """Pass an unranked memref with a non-zero offset to a Python callback."""

    # Python side of the callback: convert the unranked memref descriptor to
    # a numpy array and log it.
    def log_unranked(descriptor):
        arr = unranked_memref_to_numpy(descriptor, np.float32)
        log("Inside callback: ")
        log(arr)

    callback = ctypes.CFUNCTYPE(None, ctypes.POINTER(UnrankedMemRefDescriptor))(
        log_unranked
    )

    # @callback_memref reinterprets its argument as the two elements starting
    # at offset 3, casts that view to an unranked memref and hands it to the
    # callback.
    subview_ir = r"""
func.func @callback_memref(%arg0: memref<5xf32>) attributes {llvm.emit_c_interface} {
  %base_buffer, %offset, %sizes, %strides = memref.extract_strided_metadata %arg0 : memref<5xf32> -> memref<f32>, index, index, index
  %reinterpret_cast = memref.reinterpret_cast %base_buffer to offset: [3], sizes: [2], strides: [1] : memref<f32> to memref<2xf32, strided<[1], offset: 3>>
  %cast = memref.cast %reinterpret_cast : memref<2xf32, strided<[1], offset: 3>> to memref<*xf32>
  call @some_callback_into_python(%cast) : (memref<*xf32>) -> ()
  return
}
func.func private @some_callback_into_python(memref<*xf32>) attributes {llvm.emit_c_interface}
    """
    with Context():
        module = Module.parse(subview_ir)
        engine = ExecutionEngine(lowerToLLVM(module))
        engine.register_runtime("some_callback_into_python", callback)

        # The entry point itself takes a ranked memref<5xf32>.
        source = np.array([1, 2, 3, 4, 5], np.float32)
        source_arg = ctypes.pointer(
            ctypes.pointer(get_ranked_memref_descriptor(source))
        )
        # CHECK: Inside callback:
        # CHECK{LITERAL}: [4. 5.]
        engine.invoke("callback_memref", source_arg)
|
|
|
|
run(testUnrankedMemRefWithOffsetCallback)
|
|
|
|
|
|
# Test addition of two memrefs.
|
|
# CHECK-LABEL: TEST: testMemrefAdd
|
|
def testMemrefAdd():
    """Add a 1-element f32 memref and a 0-d f32 memref into a result memref."""
    add_ir = """
    module {
    func.func @main(%arg0: memref<1xf32>, %arg1: memref<f32>, %arg2: memref<1xf32>) attributes { llvm.emit_c_interface } {
      %0 = arith.constant 0 : index
      %1 = memref.load %arg0[%0] : memref<1xf32>
      %2 = memref.load %arg1[] : memref<f32>
      %3 = arith.addf %1, %2 : f32
      memref.store %3, %arg2[%0] : memref<1xf32>
      return
    }
     } """
    with Context():
        module = Module.parse(add_ir)
        vector_in = np.array([32.5]).astype(np.float32)
        scalar_in = np.array(6).astype(np.float32)
        out = np.array([0]).astype(np.float32)

        # Each array is handed to the engine as a pointer to a pointer to its
        # ranked memref descriptor.
        vector_ptr, scalar_ptr, out_ptr = [
            ctypes.pointer(ctypes.pointer(get_ranked_memref_descriptor(array)))
            for array in (vector_in, scalar_in, out)
        ]

        engine = ExecutionEngine(lowerToLLVM(module))
        engine.invoke("main", vector_ptr, scalar_ptr, out_ptr)
        # CHECK: [32.5] + 6.0 = [38.5]
        log("{0} + {1} = {2}".format(vector_in, scalar_in, out))
|
|
|
|
|
|
run(testMemrefAdd)
|
|
|
|
|
|
# Test addition of two f16 memrefs
|
|
# CHECK-LABEL: TEST: testF16MemrefAdd
|
|
def testF16MemrefAdd():
    """Add two 1-element f16 memrefs and read the result back as numpy."""
    add_ir = """
    module {
    func.func @main(%arg0: memref<1xf16>,
                    %arg1: memref<1xf16>,
                    %arg2: memref<1xf16>) attributes { llvm.emit_c_interface } {
      %0 = arith.constant 0 : index
      %1 = memref.load %arg0[%0] : memref<1xf16>
      %2 = memref.load %arg1[%0] : memref<1xf16>
      %3 = arith.addf %1, %2 : f16
      memref.store %3, %arg2[%0] : memref<1xf16>
      return
    }
     } """
    with Context():
        module = Module.parse(add_ir)

        lhs = np.array([11.0]).astype(np.float16)
        rhs = np.array([22.0]).astype(np.float16)
        out = np.array([0.0]).astype(np.float16)

        # Each array is handed to the engine as a pointer to a pointer to its
        # ranked memref descriptor.
        lhs_ptr, rhs_ptr, out_ptr = [
            ctypes.pointer(ctypes.pointer(get_ranked_memref_descriptor(array)))
            for array in (lhs, rhs, out)
        ]

        engine = ExecutionEngine(lowerToLLVM(module))
        engine.invoke("main", lhs_ptr, rhs_ptr, out_ptr)
        # CHECK: [11.] + [22.] = [33.]
        log("{0} + {1} = {2}".format(lhs, rhs, out))

        # Exercise the descriptor-to-numpy utility on the result.
        # CHECK: [33.]
        roundtrip = ranked_memref_to_numpy(out_ptr[0])
        log(roundtrip)
|
|
|
|
|
|
run(testF16MemrefAdd)
|
|
|
|
|
|
# Test addition of two complex memrefs
|
|
# CHECK-LABEL: TEST: testComplexMemrefAdd
|
|
def testComplexMemrefAdd():
    """Add two 1-element complex<f64> memrefs and read the result back."""
    add_ir = """
    module {
    func.func @main(%arg0: memref<1xcomplex<f64>>,
                    %arg1: memref<1xcomplex<f64>>,
                    %arg2: memref<1xcomplex<f64>>) attributes { llvm.emit_c_interface } {
      %0 = arith.constant 0 : index
      %1 = memref.load %arg0[%0] : memref<1xcomplex<f64>>
      %2 = memref.load %arg1[%0] : memref<1xcomplex<f64>>
      %3 = complex.add %1, %2 : complex<f64>
      memref.store %3, %arg2[%0] : memref<1xcomplex<f64>>
      return
    }
     } """
    with Context():
        module = Module.parse(add_ir)

        lhs = np.array([1.0 + 2.0j]).astype(np.complex128)
        rhs = np.array([3.0 + 4.0j]).astype(np.complex128)
        out = np.array([0.0 + 0.0j]).astype(np.complex128)

        # Each array is handed to the engine as a pointer to a pointer to its
        # ranked memref descriptor.
        lhs_ptr, rhs_ptr, out_ptr = [
            ctypes.pointer(ctypes.pointer(get_ranked_memref_descriptor(array)))
            for array in (lhs, rhs, out)
        ]

        engine = ExecutionEngine(lowerToLLVM(module))
        engine.invoke("main", lhs_ptr, rhs_ptr, out_ptr)
        # CHECK: [1.+2.j] + [3.+4.j] = [4.+6.j]
        log("{0} + {1} = {2}".format(lhs, rhs, out))

        # Exercise the descriptor-to-numpy utility on the result.
        # CHECK: [4.+6.j]
        roundtrip = ranked_memref_to_numpy(out_ptr[0])
        log(roundtrip)
|
|
|
|
|
|
run(testComplexMemrefAdd)
|
|
|
|
|
|
# Test addition of two complex unranked memrefs
|
|
# CHECK-LABEL: TEST: testComplexUnrankedMemrefAdd
|
|
def testComplexUnrankedMemrefAdd():
    """Add two unranked complex<f32> memrefs and read the result back."""
    add_ir = """
    module {
    func.func @main(%arg0: memref<*xcomplex<f32>>,
                    %arg1: memref<*xcomplex<f32>>,
                    %arg2: memref<*xcomplex<f32>>) attributes { llvm.emit_c_interface } {
      %A = memref.cast %arg0 : memref<*xcomplex<f32>> to memref<1xcomplex<f32>>
      %B = memref.cast %arg1 : memref<*xcomplex<f32>> to memref<1xcomplex<f32>>
      %C = memref.cast %arg2 : memref<*xcomplex<f32>> to memref<1xcomplex<f32>>
      %0 = arith.constant 0 : index
      %1 = memref.load %A[%0] : memref<1xcomplex<f32>>
      %2 = memref.load %B[%0] : memref<1xcomplex<f32>>
      %3 = complex.add %1, %2 : complex<f32>
      memref.store %3, %C[%0] : memref<1xcomplex<f32>>
      return
    }
     } """
    with Context():
        module = Module.parse(add_ir)

        lhs = np.array([5.0 + 6.0j]).astype(np.complex64)
        rhs = np.array([7.0 + 8.0j]).astype(np.complex64)
        out = np.array([0.0 + 0.0j]).astype(np.complex64)

        # Each array is handed to the engine as a pointer to a pointer to its
        # unranked memref descriptor.
        lhs_ptr, rhs_ptr, out_ptr = [
            ctypes.pointer(ctypes.pointer(get_unranked_memref_descriptor(array)))
            for array in (lhs, rhs, out)
        ]

        engine = ExecutionEngine(lowerToLLVM(module))
        engine.invoke("main", lhs_ptr, rhs_ptr, out_ptr)
        # CHECK: [5.+6.j] + [7.+8.j] = [12.+14.j]
        log("{0} + {1} = {2}".format(lhs, rhs, out))

        # Exercise the descriptor-to-numpy utility on the result.
        # CHECK: [12.+14.j]
        roundtrip = unranked_memref_to_numpy(out_ptr[0], np.dtype(np.complex64))
        log(roundtrip)
|
|
|
|
|
|
run(testComplexUnrankedMemrefAdd)
|
|
|
|
|
|
# Test bf16 memrefs
|
|
# CHECK-LABEL: TEST: testBF16Memref
|
|
def testBF16Memref():
    """Copy one bf16 element between memrefs and read it back as numpy."""
    copy_ir = """
    module {
    func.func @main(%arg0: memref<1xbf16>,
                    %arg1: memref<1xbf16>) attributes { llvm.emit_c_interface } {
      %0 = arith.constant 0 : index
      %1 = memref.load %arg0[%0] : memref<1xbf16>
      memref.store %1, %arg1[%0] : memref<1xbf16>
      return
    }
     } """
    with Context():
        module = Module.parse(copy_ir)

        source = np.array([0.5]).astype(bfloat16)
        destination = np.array([0.0]).astype(bfloat16)

        # Each array is handed to the engine as a pointer to a pointer to its
        # ranked memref descriptor.
        source_ptr, destination_ptr = [
            ctypes.pointer(ctypes.pointer(get_ranked_memref_descriptor(array)))
            for array in (source, destination)
        ]

        engine = ExecutionEngine(lowerToLLVM(module))
        engine.invoke("main", source_ptr, destination_ptr)

        # Exercise the descriptor-to-numpy utility on the destination.
        # CHECK: [0.5]
        roundtrip = ranked_memref_to_numpy(destination_ptr[0])
        log(roundtrip)
|
|
|
|
|
|
run(testBF16Memref)
|
|
|
|
|
|
# Test f8E5M2 memrefs
|
|
# CHECK-LABEL: TEST: testF8E5M2Memref
|
|
def testF8E5M2Memref():
    """Copy one f8E5M2 element between memrefs and read it back as numpy."""
    copy_ir = """
    module {
    func.func @main(%arg0: memref<1xf8E5M2>,
                    %arg1: memref<1xf8E5M2>) attributes { llvm.emit_c_interface } {
      %0 = arith.constant 0 : index
      %1 = memref.load %arg0[%0] : memref<1xf8E5M2>
      memref.store %1, %arg1[%0] : memref<1xf8E5M2>
      return
    }
     } """
    with Context():
        module = Module.parse(copy_ir)

        source = np.array([0.5]).astype(float8_e5m2)
        destination = np.array([0.0]).astype(float8_e5m2)

        # Each array is handed to the engine as a pointer to a pointer to its
        # ranked memref descriptor.
        source_ptr, destination_ptr = [
            ctypes.pointer(ctypes.pointer(get_ranked_memref_descriptor(array)))
            for array in (source, destination)
        ]

        engine = ExecutionEngine(lowerToLLVM(module))
        engine.invoke("main", source_ptr, destination_ptr)

        # Exercise the descriptor-to-numpy utility on the destination.
        # CHECK: [0.5]
        roundtrip = ranked_memref_to_numpy(destination_ptr[0])
        log(roundtrip)
|
|
|
|
|
|
run(testF8E5M2Memref)
|
|
|
|
|
|
# Test addition of two 2d_memref
|
|
# CHECK-LABEL: TEST: testDynamicMemrefAdd2D
|
|
def testDynamicMemrefAdd2D():
    """Add a static and a dynamically shaped 2x2 f32 memref element-wise."""
    add_2d_ir = """
      module {
        func.func @memref_add_2d(%arg0: memref<2x2xf32>, %arg1: memref<?x?xf32>, %arg2: memref<2x2xf32>) attributes {llvm.emit_c_interface} {
          %c0 = arith.constant 0 : index
          %c2 = arith.constant 2 : index
          %c1 = arith.constant 1 : index
          cf.br ^bb1(%c0 : index)
        ^bb1(%0: index):  // 2 preds: ^bb0, ^bb5
          %1 = arith.cmpi slt, %0, %c2 : index
          cf.cond_br %1, ^bb2, ^bb6
        ^bb2:  // pred: ^bb1
          %c0_0 = arith.constant 0 : index
          %c2_1 = arith.constant 2 : index
          %c1_2 = arith.constant 1 : index
          cf.br ^bb3(%c0_0 : index)
        ^bb3(%2: index):  // 2 preds: ^bb2, ^bb4
          %3 = arith.cmpi slt, %2, %c2_1 : index
          cf.cond_br %3, ^bb4, ^bb5
        ^bb4:  // pred: ^bb3
          %4 = memref.load %arg0[%0, %2] : memref<2x2xf32>
          %5 = memref.load %arg1[%0, %2] : memref<?x?xf32>
          %6 = arith.addf %4, %5 : f32
          memref.store %6, %arg2[%0, %2] : memref<2x2xf32>
          %7 = arith.addi %2, %c1_2 : index
          cf.br ^bb3(%7 : index)
        ^bb5:  // pred: ^bb3
          %8 = arith.addi %0, %c1 : index
          cf.br ^bb1(%8 : index)
        ^bb6:  // pred: ^bb1
          return
        }
      }
        """
    with Context():
        module = Module.parse(add_2d_ir)
        # Random inputs; the result buffer also starts out with random data
        # that the JIT-compiled function overwrites.
        lhs = np.random.randn(2, 2).astype(np.float32)
        rhs = np.random.randn(2, 2).astype(np.float32)
        out = np.random.randn(2, 2).astype(np.float32)

        # Each array is handed to the engine as a pointer to a pointer to its
        # ranked memref descriptor.
        lhs_ptr, rhs_ptr, out_ptr = [
            ctypes.pointer(ctypes.pointer(get_ranked_memref_descriptor(array)))
            for array in (lhs, rhs, out)
        ]

        engine = ExecutionEngine(lowerToLLVM(module))
        engine.invoke("memref_add_2d", lhs_ptr, rhs_ptr, out_ptr)
        # CHECK: True
        log(np.allclose(lhs + rhs, out))
|
|
|
|
|
|
run(testDynamicMemrefAdd2D)
|
|
|
|
|
|
# Test loading of shared libraries.
|
|
# CHECK-LABEL: TEST: testSharedLibLoad
|
|
def testSharedLibLoad():
    """Load the runner-utility libraries and call printMemrefF32 from the JIT."""
    print_ir = """
      module  {
      func.func @main(%arg0: memref<1xf32>) attributes { llvm.emit_c_interface } {
        %c0 = arith.constant 0 : index
        %cst42 = arith.constant 42.0 : f32
        memref.store %cst42, %arg0[%c0] : memref<1xf32>
        %u_memref = memref.cast %arg0 : memref<1xf32> to memref<*xf32>
        call @printMemrefF32(%u_memref) : (memref<*xf32>) -> ()
        return
      }
      func.func private @printMemrefF32(memref<*xf32>) attributes { llvm.emit_c_interface }
     } """
    with Context():
        module = Module.parse(print_ir)
        buffer = np.array([0.0]).astype(np.float32)

        buffer_ptr = ctypes.pointer(
            ctypes.pointer(get_ranked_memref_descriptor(buffer))
        )

        # Windows and macOS use fixed relative paths; every other platform
        # uses the module-level MLIR_RUNNER_UTILS / MLIR_C_RUNNER_UTILS values.
        fixed_paths_by_platform = {
            "win32": [
                "../../../../bin/mlir_runner_utils.dll",
                "../../../../bin/mlir_c_runner_utils.dll",
            ],
            "darwin": [
                "../../../../lib/libmlir_runner_utils.dylib",
                "../../../../lib/libmlir_c_runner_utils.dylib",
            ],
        }
        shared_libs = fixed_paths_by_platform.get(
            sys.platform,
            [
                MLIR_RUNNER_UTILS,
                MLIR_C_RUNNER_UTILS,
            ],
        )

        engine = ExecutionEngine(
            lowerToLLVM(module), opt_level=3, shared_libs=shared_libs
        )
        engine.invoke("main", buffer_ptr)
        # CHECK: Unranked Memref
        # CHECK-NEXT: [42]
|
|
|
|
|
|
run(testSharedLibLoad)
|
|
|
|
|
|
# Test that nano time clock is available.
|
|
# CHECK-LABEL: TEST: testNanoTime
|
|
def testNanoTime():
    """Call nanoTime from JIT-compiled code and print its value as a memref."""
    nano_time_ir = """
      module {
      func.func @main() attributes { llvm.emit_c_interface } {
        %now = call @nanoTime() : () -> i64
        %memref = memref.alloca() : memref<1xi64>
        %c0 = arith.constant 0 : index
        memref.store %now, %memref[%c0] : memref<1xi64>
        %u_memref = memref.cast %memref : memref<1xi64> to memref<*xi64>
        call @printMemrefI64(%u_memref) : (memref<*xi64>) -> ()
        return
      }
      func.func private @nanoTime() -> i64 attributes { llvm.emit_c_interface }
      func.func private @printMemrefI64(memref<*xi64>) attributes { llvm.emit_c_interface }
    }"""
    with Context():
        module = Module.parse(nano_time_ir)

        # Windows uses fixed relative DLL paths; every other platform uses the
        # module-level MLIR_RUNNER_UTILS / MLIR_C_RUNNER_UTILS values.
        on_windows = sys.platform == "win32"
        shared_libs = (
            [
                "../../../../bin/mlir_runner_utils.dll",
                "../../../../bin/mlir_c_runner_utils.dll",
            ]
            if on_windows
            else [
                MLIR_RUNNER_UTILS,
                MLIR_C_RUNNER_UTILS,
            ]
        )

        engine = ExecutionEngine(
            lowerToLLVM(module), opt_level=3, shared_libs=shared_libs
        )
        engine.invoke("main")
        # CHECK: Unranked Memref
        # CHECK: [{{.*}}]
|
|
|
|
|
|
run(testNanoTime)
|
|
|
|
|
|
# Test that the JIT-compiled module can be dumped to an object file.
|
|
# CHECK-LABEL: TEST: testDumpToObjectFile
|
|
def testDumpToObjectFile():
    """Check that ``dump_to_object_file`` fills a previously empty file.

    A temporary ``.o`` file is created up front, so it exists but is empty
    before the dump and must be non-empty afterwards. The file descriptor is
    closed and the file removed even if the test body raises.

    Status lines go through ``log`` (flushed stderr) like every other test in
    this file. A bare ``print`` writes to stdout, which is buffered when
    piped and therefore only reaches the merged output stream in the right
    place as long as no other test follows this one.
    """
    fd, object_path = tempfile.mkstemp(suffix=".o")

    try:
        with Context():
            module = Module.parse(
                """
    module {
    func.func @main() attributes { llvm.emit_c_interface } {
      return
    }
    }"""
            )

            execution_engine = ExecutionEngine(lowerToLLVM(module), opt_level=3)

            # CHECK: Object file exists: True
            log(f"Object file exists: {os.path.exists(object_path)}")
            # CHECK: Object file is empty: True
            log(f"Object file is empty: {os.path.getsize(object_path) == 0}")

            execution_engine.dump_to_object_file(object_path)

            # CHECK: Object file exists: True
            log(f"Object file exists: {os.path.exists(object_path)}")
            # CHECK: Object file is empty: False
            log(f"Object file is empty: {os.path.getsize(object_path) == 0}")

    finally:
        os.close(fd)
        os.remove(object_path)
|
|
|
|
|
|
run(testDumpToObjectFile)
|