[mlir][sparse] unify sparse_tensor.out rewriting rules (#70518)
This commit is contained in:
@@ -29,14 +29,14 @@ func.func @main(%p : !llvm.ptr<i8>) -> () attributes {{ llvm.emit_c_interface }}
|
||||
"""
|
||||
|
||||
|
||||
def expected():
|
||||
def expected(id_map):
|
||||
"""Returns expected contents of output.
|
||||
|
||||
Regardless of the dimension ordering, compression, and bitwidths that are
|
||||
used in the sparse tensor, the output is always lexicographically sorted
|
||||
by natural index order.
|
||||
Output appears as dimension coordinates but lexicographically
|
||||
sorted by level coordinates.
|
||||
"""
|
||||
return f"""; extended FROSTT format
|
||||
return (
|
||||
f"""# extended FROSTT format
|
||||
2 5
|
||||
10 10
|
||||
1 1 1
|
||||
@@ -45,13 +45,23 @@ def expected():
|
||||
5 5 5
|
||||
10 1 4
|
||||
"""
|
||||
if id_map
|
||||
else f"""# extended FROSTT format
|
||||
2 5
|
||||
10 10
|
||||
1 1 1
|
||||
10 1 4
|
||||
2 2 2
|
||||
5 5 5
|
||||
1 10 3
|
||||
"""
|
||||
)
|
||||
|
||||
|
||||
def build_compile_and_run_output(attr: st.EncodingAttr, compiler, expected):
|
||||
# Build and Compile.
|
||||
module = ir.Module.parse(boilerplate(attr))
|
||||
engine = compiler.compile_and_jit(module)
|
||||
|
||||
# Invoke the kernel and compare output.
|
||||
with tempfile.TemporaryDirectory() as test_dir:
|
||||
out = os.path.join(test_dir, "out.tns")
|
||||
@@ -83,20 +93,20 @@ def main():
|
||||
[st.DimLevelType.compressed, st.DimLevelType.compressed],
|
||||
]
|
||||
orderings = [
|
||||
ir.AffineMap.get_permutation([0, 1]),
|
||||
ir.AffineMap.get_permutation([1, 0]),
|
||||
(ir.AffineMap.get_permutation([0, 1]), True),
|
||||
(ir.AffineMap.get_permutation([1, 0]), False),
|
||||
]
|
||||
bitwidths = [8, 16, 32, 64]
|
||||
compiler = sparse_compiler.SparseCompiler(
|
||||
options="", opt_level=2, shared_libs=[support_lib]
|
||||
)
|
||||
for level in levels:
|
||||
for ordering in orderings:
|
||||
for ordering, id_map in orderings:
|
||||
for bwidth in bitwidths:
|
||||
attr = st.EncodingAttr.get(
|
||||
level, ordering, ordering, bwidth, bwidth
|
||||
)
|
||||
build_compile_and_run_output(attr, compiler, expected())
|
||||
build_compile_and_run_output(attr, compiler, expected(id_map))
|
||||
count = count + 1
|
||||
|
||||
# Now do the same for BSR.
|
||||
|
||||
Reference in New Issue
Block a user