So far, only the custom dialect types are exposed. The build and packaging is the same as for Linalg and SparseTensor, and is in need of refactoring that is beyond the scope of this patch.

Reviewed By: stellaraccident

Differential Revision: https://reviews.llvm.org/D116605
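As a quick orientation for the test below, here is a minimal sketch of the exposed API, assuming a build in which the bindings are importable as mlir.dialects.quant (as in the in-tree test): types are created or parsed inside a Context, and each binding class exposes an isinstance predicate.

# Minimal sketch, not part of the test file: construct a uniform quantized
# type and round-trip it through the type parser.
from mlir.ir import Context, Type, IntegerType, F32Type
from mlir.dialects import quant

with Context():
  i8 = IntegerType.get_signless(8)
  f32 = F32Type.get()
  t = quant.UniformQuantizedType.get(
      quant.UniformQuantizedType.FLAG_SIGNED, i8, f32, 0.99872, 127, -8, 7)
  assert quant.QuantizedType.isinstance(t)
  assert t == Type.parse("!quant.uniform<i8<-8:7>:f32, 0.99872:127>")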
# RUN: %PYTHON %s | FileCheck %s

from mlir.ir import *
from mlir.dialects import quant

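# Harness: runs each test immediately and prints a "TEST: <name>" label for
# the CHECK-LABEL directives to match.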
def run(f):
  print("\nTEST:", f.__name__)
  f()
  return f

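# Parse one example of each quantized type and check the isinstance
# predicates, both up the QuantizedType hierarchy and across subclasses.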
# CHECK-LABEL: TEST: test_type_hierarchy
@run
def test_type_hierarchy():
  with Context():
    i8 = IntegerType.get_signless(8)
    any = Type.parse("!quant.any<i8<-8:7>:f32>")
    uniform = Type.parse("!quant.uniform<i8<-8:7>:f32, 0.99872:127>")
    per_axis = Type.parse("!quant.uniform<i8:f32:1, {2.0e+2,0.99872:120}>")
    calibrated = Type.parse("!quant.calibrated<f32<-0.998:1.2321>>")

    assert not quant.QuantizedType.isinstance(i8)
    assert quant.QuantizedType.isinstance(any)
    assert quant.QuantizedType.isinstance(uniform)
    assert quant.QuantizedType.isinstance(per_axis)
    assert quant.QuantizedType.isinstance(calibrated)

    assert quant.AnyQuantizedType.isinstance(any)
    assert quant.UniformQuantizedType.isinstance(uniform)
    assert quant.UniformQuantizedPerAxisType.isinstance(per_axis)
    assert quant.CalibratedQuantizedType.isinstance(calibrated)

    assert not quant.AnyQuantizedType.isinstance(uniform)
    assert not quant.UniformQuantizedType.isinstance(per_axis)

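# Build an AnyQuantizedType programmatically, verify its accessors match the
# values passed to get(), then compare against the parsed form.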
# CHECK-LABEL: TEST: test_any_quantized_type
@run
def test_any_quantized_type():
  with Context():
    i8 = IntegerType.get_signless(8)
    f32 = F32Type.get()
    any = quant.AnyQuantizedType.get(quant.QuantizedType.FLAG_SIGNED, i8, f32,
                                     -8, 7)

    # CHECK: flags: 1
    print(f"flags: {any.flags}")
    # CHECK: signed: True
    print(f"signed: {any.is_signed}")
    # CHECK: storage type: i8
    print(f"storage type: {any.storage_type}")
    # CHECK: expressed type: f32
    print(f"expressed type: {any.expressed_type}")
    # CHECK: storage min: -8
    print(f"storage min: {any.storage_type_min}")
    # CHECK: storage max: 7
    print(f"storage max: {any.storage_type_max}")
    # CHECK: storage width: 8
    print(f"storage width: {any.storage_type_integral_width}")
    # CHECK: quantized element type: !quant.any<i8<-8:7>:f32>
    print(f"quantized element type: {any.quantized_element_type}")
    # CHECK: !quant.any<i8<-8:7>:f32>
    print(any)
    assert any == Type.parse("!quant.any<i8<-8:7>:f32>")

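# Same pattern for UniformQuantizedType; note that the printer normalizes the
# scale 0.99872 to 9.987200e-01, yet the two spellings parse to equal types.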
# CHECK-LABEL: TEST: test_uniform_type
@run
def test_uniform_type():
  with Context():
    i8 = IntegerType.get_signless(8)
    f32 = F32Type.get()
    uniform = quant.UniformQuantizedType.get(
        quant.UniformQuantizedType.FLAG_SIGNED, i8, f32, 0.99872, 127, -8, 7)

    # CHECK: scale: 0.99872
    print(f"scale: {uniform.scale}")
    # CHECK: zero point: 127
    print(f"zero point: {uniform.zero_point}")
    # CHECK: fixed point: False
    print(f"fixed point: {uniform.is_fixed_point}")
    # CHECK: !quant.uniform<i8<-8:7>:f32, 9.987200e-01:127>
    print(uniform)
    assert uniform == Type.parse("!quant.uniform<i8<-8:7>:f32, 0.99872:127>")

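# Per-axis construction takes lists of scales and zero points plus the
# quantized dimension; the scales and zero_points getters are not yet plumbed
# through the bindings, so they print None below.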
# CHECK-LABEL: TEST: test_uniform_per_axis_type
@run
def test_uniform_per_axis_type():
  with Context():
    i8 = IntegerType.get_signless(8)
    f32 = F32Type.get()
    per_axis = quant.UniformQuantizedPerAxisType.get(
        quant.QuantizedType.FLAG_SIGNED,
        i8,
        f32, [200, 0.99872], [0, 120],
        quantized_dimension=1,
        storage_type_min=quant.QuantizedType.default_minimum_for_integer(
            is_signed=True, integral_width=8),
        storage_type_max=quant.QuantizedType.default_maximum_for_integer(
            is_signed=True, integral_width=8))

    # CHECK: scales: None
    print(f"scales: {per_axis.scales}")
    # CHECK: zero_points: None
    print(f"zero_points: {per_axis.zero_points}")
    # CHECK: quantized dim: 1
    print(f"quantized dim: {per_axis.quantized_dimension}")
    # CHECK: fixed point: False
    print(f"fixed point: {per_axis.is_fixed_point}")
    # CHECK: !quant.uniform<i8:f32:1, {2.000000e+02,9.987200e-01:120}>
    print(per_axis)
    assert per_axis == Type.parse(
        "!quant.uniform<i8:f32:1, {2.0e+2,0.99872:120}>")

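# CalibratedQuantizedType records only the expressed type and a [min, max]
# calibration range; there is no storage type.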
# CHECK-LABEL: TEST: test_calibrated_type
@run
def test_calibrated_type():
  with Context():
    f32 = F32Type.get()
    calibrated = quant.CalibratedQuantizedType.get(f32, -0.998, 1.2321)

    # CHECK: min: -0.998
    print(f"min: {calibrated.min}")
    # CHECK: max: 1.2321
    print(f"max: {calibrated.max}")
    # CHECK: !quant.calibrated<f32<-0.998:1.232100e+00>>
    print(calibrated)
    assert calibrated == Type.parse("!quant.calibrated<f32<-0.998:1.2321>>")