[MLIR][python bindings] TypeCasters for Attributes

Differential Revision: https://reviews.llvm.org/D151840
This commit is contained in:
max
2023-05-31 15:52:46 -05:00
parent 31fbfa57e7
commit 9566ee2806
9 changed files with 288 additions and 36 deletions

View File

@@ -23,7 +23,7 @@ def testParsePrint():
gc.collect()
# CHECK: "hello"
print(str(t))
# CHECK: Attribute("hello")
# CHECK: StringAttr("hello")
print(repr(t))
@@ -134,7 +134,7 @@ def testStandardAttrCasts():
a1 = Attribute.parse('"attr1"')
astr = StringAttr(a1)
aself = StringAttr(astr)
# CHECK: Attribute("attr1")
# CHECK: StringAttr("attr1")
print(repr(astr))
try:
tillegal = StringAttr(Attribute.parse("1.0"))
@@ -324,32 +324,32 @@ def testDenseIntAttr():
@run
def testDenseArrayGetItem():
def print_item(AttrClass, attr_asm):
attr = AttrClass(Attribute.parse(attr_asm))
def print_item(attr_asm):
attr = Attribute.parse(attr_asm)
print(f"{len(attr)}: {attr[0]}, {attr[1]}")
with Context():
# CHECK: 2: 0, 1
print_item(DenseBoolArrayAttr, "array<i1: false, true>")
print_item("array<i1: false, true>")
# CHECK: 2: 2, 3
print_item(DenseI8ArrayAttr, "array<i8: 2, 3>")
print_item("array<i8: 2, 3>")
# CHECK: 2: 4, 5
print_item(DenseI16ArrayAttr, "array<i16: 4, 5>")
print_item("array<i16: 4, 5>")
# CHECK: 2: 6, 7
print_item(DenseI32ArrayAttr, "array<i32: 6, 7>")
print_item("array<i32: 6, 7>")
# CHECK: 2: 8, 9
print_item(DenseI64ArrayAttr, "array<i64: 8, 9>")
print_item("array<i64: 8, 9>")
# CHECK: 2: 1.{{0+}}, 2.{{0+}}
print_item(DenseF32ArrayAttr, "array<f32: 1.0, 2.0>")
print_item("array<f32: 1.0, 2.0>")
# CHECK: 2: 3.{{0+}}, 4.{{0+}}
print_item(DenseF64ArrayAttr, "array<f64: 3.0, 4.0>")
print_item("array<f64: 3.0, 4.0>")
# CHECK-LABEL: TEST: testDenseIntAttrGetItem
@run
def testDenseIntAttrGetItem():
def print_item(attr_asm):
attr = DenseIntElementsAttr(Attribute.parse(attr_asm))
attr = Attribute.parse(attr_asm)
dtype = ShapedType(attr.type).element_type
try:
item = attr[0]
@@ -592,3 +592,14 @@ def testConcreteTypesRoundTrip():
print(repr(type_attr.value))
# CHECK: F32Type(f32)
print(repr(type_attr.value.element_type))
# CHECK-LABEL: TEST: testConcreteAttributesRoundTrip
@run
def testConcreteAttributesRoundTrip():
with Context(), Location.unknown():
# CHECK: FloatAttr(4.200000e+01 : f32)
print(repr(Attribute.parse("42.0 : f32")))
assert IntegerAttr.static_typeid is not None