40 lines
1.1 KiB
Python
40 lines
1.1 KiB
Python
# RUN: %PYTHON %s | FileCheck %s
|
|
|
|
from mlir.ir import *
|
|
import mlir.dialects.arith as arith
|
|
import mlir.dialects.builtin as builtin
|
|
import mlir.dialects.tensor as tensor
|
|
|
|
|
|
def run(f):
    """Invoke the test function *f* immediately and return it unchanged.

    Meant to be used as a decorator on module-level test functions.
    Before calling *f*, a banner line ``TEST: <function name>`` (preceded
    by a blank line) is printed to stdout. That banner is what the
    FileCheck label directives in this file anchor on.
    """
    test_name = f.__name__
    print("\nTEST:", test_name)
    f()
    # Hand the function back so the decorated name still refers to it.
    return f
|
|
|
|
|
|
# CHECK-LABEL: TEST: testDimOp
@run
def testDimOp():
    """Emit a function querying both dimensions of a 2-D tensor and print it.

    The printed module is matched by FileCheck against the directive
    comments written above ``tensor_static_dim`` below.
    """
    with Context(), Location.unknown():
        module = Module.create()
        element_type = F32Type.get()
        index_type = IndexType.get()
        with InsertionPoint(module.body):
            # Shape (-1, -1) yields the tensor<?x?xf32> argument type that
            # the directives below expect.
            # NOTE(review): despite its name, tensor_static_dim takes a
            # fully dynamic tensor; the name is kept because the expected
            # output references it.
            arg_type = RankedTensorType.get((-1, -1), element_type)

            @builtin.FuncOp.from_py_func(arg_type)
            # CHECK: func @tensor_static_dim
            # CHECK-SAME: %[[ARG0:.+]]: tensor<?x?xf32>
            # CHECK-DAG: %[[C0:.+]] = arith.constant 0 : index
            # CHECK-DAG: %[[C1:.+]] = arith.constant 1 : index
            # CHECK: %[[D0:.+]] = tensor.dim %[[ARG0]], %[[C0]]
            # CHECK: %[[D1:.+]] = tensor.dim %[[ARG0]], %[[C1]]
            # CHECK: return %[[D0]], %[[D1]]
            def tensor_static_dim(t):
                # One index constant per dimension, then one tensor.dim per
                # constant; ops are created in the order c0, c1, d0, d1.
                positions = [arith.ConstantOp(index_type, i) for i in range(2)]
                extents = [tensor.DimOp(t, pos) for pos in positions]
                return [op.result for op in extents]

        print(module)
|