[MLIR][python bindings] TypeCasters for Attributes

Differential Revision: https://reviews.llvm.org/D151840
This commit is contained in:
max
2023-05-31 15:52:46 -05:00
parent 31fbfa57e7
commit 9566ee2806
9 changed files with 288 additions and 36 deletions

View File

@@ -45,6 +45,9 @@ MLIR_CAPI_EXPORTED MlirAttribute mlirAffineMapAttrGet(MlirAffineMap map);
/// Returns the affine map wrapped in the given affine map attribute.
MLIR_CAPI_EXPORTED MlirAffineMap mlirAffineMapAttrGetValue(MlirAttribute attr);
/// Returns the typeID of an AffineMap attribute.
MLIR_CAPI_EXPORTED MlirTypeID mlirAffineMapAttrGetTypeID(void);
//===----------------------------------------------------------------------===//
// Array attribute.
//===----------------------------------------------------------------------===//
@@ -64,6 +67,9 @@ MLIR_CAPI_EXPORTED intptr_t mlirArrayAttrGetNumElements(MlirAttribute attr);
MLIR_CAPI_EXPORTED MlirAttribute mlirArrayAttrGetElement(MlirAttribute attr,
intptr_t pos);
/// Returns the typeID of an Array attribute.
MLIR_CAPI_EXPORTED MlirTypeID mlirArrayAttrGetTypeID(void);
//===----------------------------------------------------------------------===//
// Dictionary attribute.
//===----------------------------------------------------------------------===//
@@ -89,6 +95,9 @@ mlirDictionaryAttrGetElement(MlirAttribute attr, intptr_t pos);
MLIR_CAPI_EXPORTED MlirAttribute
mlirDictionaryAttrGetElementByName(MlirAttribute attr, MlirStringRef name);
/// Returns the typeID of a Dictionary attribute.
MLIR_CAPI_EXPORTED MlirTypeID mlirDictionaryAttrGetTypeID(void);
//===----------------------------------------------------------------------===//
// Floating point attribute.
//===----------------------------------------------------------------------===//
@@ -115,6 +124,9 @@ MLIR_CAPI_EXPORTED MlirAttribute mlirFloatAttrDoubleGetChecked(MlirLocation loc,
/// the value as double.
MLIR_CAPI_EXPORTED double mlirFloatAttrGetValueDouble(MlirAttribute attr);
/// Returns the typeID of a Float attribute.
MLIR_CAPI_EXPORTED MlirTypeID mlirFloatAttrGetTypeID(void);
//===----------------------------------------------------------------------===//
// Integer attribute.
//===----------------------------------------------------------------------===//
@@ -142,6 +154,9 @@ MLIR_CAPI_EXPORTED int64_t mlirIntegerAttrGetValueSInt(MlirAttribute attr);
/// is of unsigned type and fits into an unsigned 64-bit integer.
MLIR_CAPI_EXPORTED uint64_t mlirIntegerAttrGetValueUInt(MlirAttribute attr);
/// Returns the typeID of an Integer attribute.
MLIR_CAPI_EXPORTED MlirTypeID mlirIntegerAttrGetTypeID(void);
//===----------------------------------------------------------------------===//
// Bool attribute.
//===----------------------------------------------------------------------===//
@@ -162,6 +177,9 @@ MLIR_CAPI_EXPORTED bool mlirBoolAttrGetValue(MlirAttribute attr);
/// Checks whether the given attribute is an integer set attribute.
MLIR_CAPI_EXPORTED bool mlirAttributeIsAIntegerSet(MlirAttribute attr);
/// Returns the typeID of an IntegerSet attribute.
MLIR_CAPI_EXPORTED MlirTypeID mlirIntegerSetAttrGetTypeID(void);
//===----------------------------------------------------------------------===//
// Opaque attribute.
//===----------------------------------------------------------------------===//
@@ -185,6 +203,9 @@ mlirOpaqueAttrGetDialectNamespace(MlirAttribute attr);
/// the context in which the attribute lives.
MLIR_CAPI_EXPORTED MlirStringRef mlirOpaqueAttrGetData(MlirAttribute attr);
/// Returns the typeID of an Opaque attribute.
MLIR_CAPI_EXPORTED MlirTypeID mlirOpaqueAttrGetTypeID(void);
//===----------------------------------------------------------------------===//
// String attribute.
//===----------------------------------------------------------------------===//
@@ -206,6 +227,9 @@ MLIR_CAPI_EXPORTED MlirAttribute mlirStringAttrTypedGet(MlirType type,
/// long as the context in which the attribute lives.
MLIR_CAPI_EXPORTED MlirStringRef mlirStringAttrGetValue(MlirAttribute attr);
/// Returns the typeID of a String attribute.
MLIR_CAPI_EXPORTED MlirTypeID mlirStringAttrGetTypeID(void);
//===----------------------------------------------------------------------===//
// SymbolRef attribute.
//===----------------------------------------------------------------------===//
@@ -239,6 +263,9 @@ mlirSymbolRefAttrGetNumNestedReferences(MlirAttribute attr);
MLIR_CAPI_EXPORTED MlirAttribute
mlirSymbolRefAttrGetNestedReference(MlirAttribute attr, intptr_t pos);
/// Returns the typeID of an SymbolRef attribute.
MLIR_CAPI_EXPORTED MlirTypeID mlirSymbolRefAttrGetTypeID(void);
//===----------------------------------------------------------------------===//
// Flat SymbolRef attribute.
//===----------------------------------------------------------------------===//
@@ -256,6 +283,9 @@ MLIR_CAPI_EXPORTED MlirAttribute mlirFlatSymbolRefAttrGet(MlirContext ctx,
MLIR_CAPI_EXPORTED MlirStringRef
mlirFlatSymbolRefAttrGetValue(MlirAttribute attr);
/// Returns the typeID of an FlatSymbolRef attribute.
MLIR_CAPI_EXPORTED MlirTypeID mlirFlatSymbolRefAttrGetTypeID(void);
//===----------------------------------------------------------------------===//
// Type attribute.
//===----------------------------------------------------------------------===//
@@ -270,6 +300,9 @@ MLIR_CAPI_EXPORTED MlirAttribute mlirTypeAttrGet(MlirType type);
/// Returns the type stored in the given type attribute.
MLIR_CAPI_EXPORTED MlirType mlirTypeAttrGetValue(MlirAttribute attr);
/// Returns the typeID of a Type attribute.
MLIR_CAPI_EXPORTED MlirTypeID mlirTypeAttrGetTypeID(void);
//===----------------------------------------------------------------------===//
// Unit attribute.
//===----------------------------------------------------------------------===//
@@ -280,6 +313,9 @@ MLIR_CAPI_EXPORTED bool mlirAttributeIsAUnit(MlirAttribute attr);
/// Creates a unit attribute in the given context.
MLIR_CAPI_EXPORTED MlirAttribute mlirUnitAttrGet(MlirContext ctx);
/// Returns the typeID of a Unit attribute.
MLIR_CAPI_EXPORTED MlirTypeID mlirUnitAttrGetTypeID(void);
//===----------------------------------------------------------------------===//
// Elements attributes.
//===----------------------------------------------------------------------===//
@@ -306,6 +342,8 @@ MLIR_CAPI_EXPORTED int64_t mlirElementsAttrGetNumElements(MlirAttribute attr);
// Dense array attribute.
//===----------------------------------------------------------------------===//
MLIR_CAPI_EXPORTED MlirTypeID mlirDenseArrayAttrGetTypeID(void);
/// Checks whether the given attribute is a dense array attribute.
MLIR_CAPI_EXPORTED bool mlirAttributeIsADenseBoolArray(MlirAttribute attr);
MLIR_CAPI_EXPORTED bool mlirAttributeIsADenseI8Array(MlirAttribute attr);
@@ -370,6 +408,9 @@ MLIR_CAPI_EXPORTED bool mlirAttributeIsADenseElements(MlirAttribute attr);
MLIR_CAPI_EXPORTED bool mlirAttributeIsADenseIntElements(MlirAttribute attr);
MLIR_CAPI_EXPORTED bool mlirAttributeIsADenseFPElements(MlirAttribute attr);
/// Returns the typeID of an DenseIntOrFPElements attribute.
MLIR_CAPI_EXPORTED MlirTypeID mlirDenseIntOrFPElementsAttrGetTypeID(void);
/// Creates a dense elements attribute with the given Shaped type and elements
/// in the same context as the type.
MLIR_CAPI_EXPORTED MlirAttribute mlirDenseElementsAttrGet(
@@ -612,6 +653,9 @@ mlirSparseElementsAttrGetIndices(MlirAttribute attr);
MLIR_CAPI_EXPORTED MlirAttribute
mlirSparseElementsAttrGetValues(MlirAttribute attr);
/// Returns the typeID of a SparseElements attribute.
MLIR_CAPI_EXPORTED MlirTypeID mlirSparseElementsAttrGetTypeID(void);
//===----------------------------------------------------------------------===//
// Strided layout attribute.
//===----------------------------------------------------------------------===//
@@ -635,6 +679,9 @@ mlirStridedLayoutAttrGetNumStrides(MlirAttribute attr);
MLIR_CAPI_EXPORTED int64_t mlirStridedLayoutAttrGetStride(MlirAttribute attr,
intptr_t pos);
/// Returns the typeID of a StridedLayout attribute.
MLIR_CAPI_EXPORTED MlirTypeID mlirStridedLayoutAttrGetTypeID(void);
#ifdef __cplusplus
}
#endif

View File

@@ -860,6 +860,9 @@ MLIR_CAPI_EXPORTED MlirType mlirAttributeGetType(MlirAttribute attribute);
/// Gets the type id of the attribute.
MLIR_CAPI_EXPORTED MlirTypeID mlirAttributeGetTypeID(MlirAttribute attribute);
/// Gets the dialect of the attribute.
MLIR_CAPI_EXPORTED MlirDialect mlirAttributeGetDialect(MlirAttribute attribute);
/// Checks whether an attribute is null.
static inline bool mlirAttributeIsNull(MlirAttribute attr) { return !attr.ptr; }

View File

@@ -97,6 +97,7 @@ struct type_caster<MlirAttribute> {
return py::module::import(MAKE_MLIR_PYTHON_QUALNAME("ir"))
.attr("Attribute")
.attr(MLIR_PYTHON_CAPI_FACTORY_ATTR)(capsule)
.attr(MLIR_PYTHON_MAYBE_DOWNCAST_ATTR)()
.release();
}
};

View File

@@ -80,6 +80,8 @@ public:
static constexpr IsAFunctionTy isaFunction = mlirAttributeIsAAffineMap;
static constexpr const char *pyClassName = "AffineMapAttr";
using PyConcreteAttribute::PyConcreteAttribute;
static constexpr GetTypeIDFunctionTy getTypeIdFunction =
mlirAffineMapAttrGetTypeID;
static void bindDerived(ClassTy &c) {
c.def_static(
@@ -259,6 +261,8 @@ public:
static constexpr IsAFunctionTy isaFunction = mlirAttributeIsAArray;
static constexpr const char *pyClassName = "ArrayAttr";
using PyConcreteAttribute::PyConcreteAttribute;
static constexpr GetTypeIDFunctionTy getTypeIdFunction =
mlirArrayAttrGetTypeID;
class PyArrayAttributeIterator {
public:
@@ -339,6 +343,8 @@ public:
static constexpr IsAFunctionTy isaFunction = mlirAttributeIsAFloat;
static constexpr const char *pyClassName = "FloatAttr";
using PyConcreteAttribute::PyConcreteAttribute;
static constexpr GetTypeIDFunctionTy getTypeIdFunction =
mlirFloatAttrGetTypeID;
static void bindDerived(ClassTy &c) {
c.def_static(
@@ -406,6 +412,10 @@ public:
return mlirIntegerAttrGetValueUInt(self);
},
"Returns the value of the integer attribute");
c.def_property_readonly_static("static_typeid",
[](py::object & /*class*/) -> MlirTypeID {
return mlirIntegerAttrGetTypeID();
});
}
};
@@ -438,6 +448,8 @@ public:
static constexpr IsAFunctionTy isaFunction = mlirAttributeIsAFlatSymbolRef;
static constexpr const char *pyClassName = "FlatSymbolRefAttr";
using PyConcreteAttribute::PyConcreteAttribute;
static constexpr GetTypeIDFunctionTy getTypeIdFunction =
mlirFlatSymbolRefAttrGetTypeID;
static void bindDerived(ClassTy &c) {
c.def_static(
@@ -464,6 +476,8 @@ public:
static constexpr IsAFunctionTy isaFunction = mlirAttributeIsAOpaque;
static constexpr const char *pyClassName = "OpaqueAttr";
using PyConcreteAttribute::PyConcreteAttribute;
static constexpr GetTypeIDFunctionTy getTypeIdFunction =
mlirOpaqueAttrGetTypeID;
static void bindDerived(ClassTy &c) {
c.def_static(
@@ -501,6 +515,8 @@ public:
static constexpr IsAFunctionTy isaFunction = mlirAttributeIsAString;
static constexpr const char *pyClassName = "StringAttr";
using PyConcreteAttribute::PyConcreteAttribute;
static constexpr GetTypeIDFunctionTy getTypeIdFunction =
mlirStringAttrGetTypeID;
static void bindDerived(ClassTy &c) {
c.def_static(
@@ -921,6 +937,8 @@ public:
static constexpr IsAFunctionTy isaFunction = mlirAttributeIsADictionary;
static constexpr const char *pyClassName = "DictAttr";
using PyConcreteAttribute::PyConcreteAttribute;
static constexpr GetTypeIDFunctionTy getTypeIdFunction =
mlirDictionaryAttrGetTypeID;
intptr_t dunderLen() { return mlirDictionaryAttrGetNumElements(*this); }
@@ -1013,6 +1031,8 @@ public:
static constexpr IsAFunctionTy isaFunction = mlirAttributeIsAType;
static constexpr const char *pyClassName = "TypeAttr";
using PyConcreteAttribute::PyConcreteAttribute;
static constexpr GetTypeIDFunctionTy getTypeIdFunction =
mlirTypeAttrGetTypeID;
static void bindDerived(ClassTy &c) {
c.def_static(
@@ -1035,6 +1055,8 @@ public:
static constexpr IsAFunctionTy isaFunction = mlirAttributeIsAUnit;
static constexpr const char *pyClassName = "UnitAttr";
using PyConcreteAttribute::PyConcreteAttribute;
static constexpr GetTypeIDFunctionTy getTypeIdFunction =
mlirUnitAttrGetTypeID;
static void bindDerived(ClassTy &c) {
c.def_static(
@@ -1054,6 +1076,8 @@ public:
static constexpr IsAFunctionTy isaFunction = mlirAttributeIsAStridedLayout;
static constexpr const char *pyClassName = "StridedLayoutAttr";
using PyConcreteAttribute::PyConcreteAttribute;
static constexpr GetTypeIDFunctionTy getTypeIdFunction =
mlirStridedLayoutAttrGetTypeID;
static void bindDerived(ClassTy &c) {
c.def_static(
@@ -1099,6 +1123,50 @@ public:
}
};
py::object denseArrayAttributeCaster(PyAttribute &pyAttribute) {
if (PyDenseBoolArrayAttribute::isaFunction(pyAttribute))
return py::cast(PyDenseBoolArrayAttribute(pyAttribute));
if (PyDenseI8ArrayAttribute::isaFunction(pyAttribute))
return py::cast(PyDenseI8ArrayAttribute(pyAttribute));
if (PyDenseI16ArrayAttribute::isaFunction(pyAttribute))
return py::cast(PyDenseI16ArrayAttribute(pyAttribute));
if (PyDenseI32ArrayAttribute::isaFunction(pyAttribute))
return py::cast(PyDenseI32ArrayAttribute(pyAttribute));
if (PyDenseI64ArrayAttribute::isaFunction(pyAttribute))
return py::cast(PyDenseI64ArrayAttribute(pyAttribute));
if (PyDenseF32ArrayAttribute::isaFunction(pyAttribute))
return py::cast(PyDenseF32ArrayAttribute(pyAttribute));
if (PyDenseF64ArrayAttribute::isaFunction(pyAttribute))
return py::cast(PyDenseF64ArrayAttribute(pyAttribute));
std::string msg =
std::string("Can't cast unknown element type DenseArrayAttr (") +
std::string(py::repr(py::cast(pyAttribute))) + ")";
throw py::cast_error(msg);
}
py::object denseIntOrFPElementsAttributeCaster(PyAttribute &pyAttribute) {
if (PyDenseFPElementsAttribute::isaFunction(pyAttribute))
return py::cast(PyDenseFPElementsAttribute(pyAttribute));
if (PyDenseIntElementsAttribute::isaFunction(pyAttribute))
return py::cast(PyDenseIntElementsAttribute(pyAttribute));
std::string msg =
std::string(
"Can't cast unknown element type DenseIntOrFPElementsAttr (") +
std::string(py::repr(py::cast(pyAttribute))) + ")";
throw py::cast_error(msg);
}
py::object integerOrBoolAttributeCaster(PyAttribute &pyAttribute) {
if (PyBoolAttribute::isaFunction(pyAttribute))
return py::cast(PyBoolAttribute(pyAttribute));
if (PyIntegerAttribute::isaFunction(pyAttribute))
return py::cast(PyIntegerAttribute(pyAttribute));
std::string msg =
std::string("Can't cast unknown element type DenseArrayAttr (") +
std::string(py::repr(py::cast(pyAttribute))) + ")";
throw py::cast_error(msg);
}
} // namespace
void mlir::python::populateIRAttributes(py::module &m) {
@@ -1118,6 +1186,9 @@ void mlir::python::populateIRAttributes(py::module &m) {
PyDenseF32ArrayAttribute::PyDenseArrayIterator::bind(m);
PyDenseF64ArrayAttribute::bind(m);
PyDenseF64ArrayAttribute::PyDenseArrayIterator::bind(m);
PyGlobals::get().registerTypeCaster(
mlirDenseArrayAttrGetTypeID(),
pybind11::cpp_function(denseArrayAttributeCaster));
PyArrayAttribute::bind(m);
PyArrayAttribute::PyArrayAttributeIterator::bind(m);
@@ -1125,6 +1196,10 @@ void mlir::python::populateIRAttributes(py::module &m) {
PyDenseElementsAttribute::bind(m);
PyDenseFPElementsAttribute::bind(m);
PyDenseIntElementsAttribute::bind(m);
PyGlobals::get().registerTypeCaster(
mlirDenseIntOrFPElementsAttrGetTypeID(),
pybind11::cpp_function(denseIntOrFPElementsAttributeCaster));
PyDictAttribute::bind(m);
PyFlatSymbolRefAttribute::bind(m);
PyOpaqueAttribute::bind(m);
@@ -1132,6 +1207,9 @@ void mlir::python::populateIRAttributes(py::module &m) {
PyIntegerAttribute::bind(m);
PyStringAttribute::bind(m);
PyTypeAttribute::bind(m);
PyGlobals::get().registerTypeCaster(
mlirIntegerAttrGetTypeID(),
pybind11::cpp_function(integerOrBoolAttributeCaster));
PyUnitAttribute::bind(m);
PyStridedLayoutAttribute::bind(m);

View File

@@ -2640,10 +2640,7 @@ void mlir::python::populateIRCore(py::module &m) {
"Context that owns the Location")
.def_property_readonly(
"attr",
[](PyLocation &self) {
return PyAttribute(self.getContext(),
mlirLocationGetAttribute(self));
},
[](PyLocation &self) { return mlirLocationGetAttribute(self); },
"Get the underlying LocationAttr")
.def(
"emit_error",
@@ -3139,7 +3136,7 @@ void mlir::python::populateIRCore(py::module &m) {
context->get(), toMlirStringRef(attrSpec));
if (mlirAttributeIsNull(type))
throw MLIRError("Unable to parse attribute", errors.take());
return PyAttribute(context->getRef(), type);
return type;
},
py::arg("asm"), py::arg("context") = py::none(),
"Parses an attribute from an assembly form. Raises an MLIRError on "
@@ -3175,18 +3172,38 @@ void mlir::python::populateIRCore(py::module &m) {
return printAccum.join();
},
"Returns the assembly form of the Attribute.")
.def("__repr__", [](PyAttribute &self) {
// Generally, assembly formats are not printed for __repr__ because
// this can cause exceptionally long debug output and exceptions.
// However, attribute values are generally considered useful and are
// printed. This may need to be re-evaluated if debug dumps end up
// being excessive.
PyPrintAccumulator printAccum;
printAccum.parts.append("Attribute(");
mlirAttributePrint(self, printAccum.getCallback(),
printAccum.getUserData());
printAccum.parts.append(")");
return printAccum.join();
.def("__repr__",
[](PyAttribute &self) {
// Generally, assembly formats are not printed for __repr__ because
// this can cause exceptionally long debug output and exceptions.
// However, attribute values are generally considered useful and
// are printed. This may need to be re-evaluated if debug dumps end
// up being excessive.
PyPrintAccumulator printAccum;
printAccum.parts.append("Attribute(");
mlirAttributePrint(self, printAccum.getCallback(),
printAccum.getUserData());
printAccum.parts.append(")");
return printAccum.join();
})
.def_property_readonly(
"typeid",
[](PyAttribute &self) -> MlirTypeID {
MlirTypeID mlirTypeID = mlirAttributeGetTypeID(self);
assert(!mlirTypeIDIsNull(mlirTypeID) &&
"mlirTypeID was expected to be non-null.");
return mlirTypeID;
})
.def(MLIR_PYTHON_MAYBE_DOWNCAST_ATTR, [](PyAttribute &self) {
MlirTypeID mlirTypeID = mlirAttributeGetTypeID(self);
assert(!mlirTypeIDIsNull(mlirTypeID) &&
"mlirTypeID was expected to be non-null.");
std::optional<pybind11::function> typeCaster =
PyGlobals::get().lookupTypeCaster(mlirTypeID,
mlirAttributeGetDialect(self));
if (!typeCaster)
return py::cast(self);
return typeCaster.value()(self);
});
//----------------------------------------------------------------------------
@@ -3216,13 +3233,7 @@ void mlir::python::populateIRCore(py::module &m) {
"The name of the NamedAttribute binding")
.def_property_readonly(
"attr",
[](PyNamedAttribute &self) {
// TODO: When named attribute is removed/refactored, also remove
// this constructor (it does an inefficient table lookup).
auto contextRef = PyMlirContext::forContext(
mlirAttributeGetContext(self.namedAttr.attribute));
return PyAttribute(std::move(contextRef), self.namedAttr.attribute);
},
[](PyNamedAttribute &self) { return self.namedAttr.attribute; },
py::keep_alive<0, 1>(),
"The underlying generic attribute of the NamedAttribute binding");

View File

@@ -986,6 +986,8 @@ public:
// const char *pyClassName
using ClassTy = pybind11::class_<DerivedTy, BaseTy>;
using IsAFunctionTy = bool (*)(MlirAttribute);
using GetTypeIDFunctionTy = MlirTypeID (*)();
static constexpr GetTypeIDFunctionTy getTypeIdFunction = nullptr;
PyConcreteAttribute() = default;
PyConcreteAttribute(PyMlirContextRef contextRef, MlirAttribute attr)
@@ -1017,6 +1019,34 @@ public:
pybind11::arg("other"));
cls.def_property_readonly(
"type", [](PyAttribute &attr) { return mlirAttributeGetType(attr); });
cls.def_property_readonly_static(
"static_typeid", [](py::object & /*class*/) -> MlirTypeID {
if (DerivedTy::getTypeIdFunction)
return DerivedTy::getTypeIdFunction();
throw py::attribute_error(
(DerivedTy::pyClassName + llvm::Twine(" has no typeid.")).str());
});
cls.def_property_readonly("typeid", [](PyAttribute &self) {
return py::cast(self).attr("typeid").cast<MlirTypeID>();
});
cls.def("__repr__", [](DerivedTy &self) {
PyPrintAccumulator printAccum;
printAccum.parts.append(DerivedTy::pyClassName);
printAccum.parts.append("(");
mlirAttributePrint(self, printAccum.getCallback(),
printAccum.getUserData());
printAccum.parts.append(")");
return printAccum.join();
});
if (DerivedTy::getTypeIdFunction) {
PyGlobals::get().registerTypeCaster(
DerivedTy::getTypeIdFunction(),
pybind11::cpp_function([](PyAttribute pyAttribute) -> DerivedTy {
return pyAttribute;
}));
}
DerivedTy::bindDerived(cls);
}

View File

@@ -44,6 +44,10 @@ MlirAffineMap mlirAffineMapAttrGetValue(MlirAttribute attr) {
return wrap(llvm::cast<AffineMapAttr>(unwrap(attr)).getValue());
}
MlirTypeID mlirAffineMapAttrGetTypeID(void) {
return wrap(AffineMapAttr::getTypeID());
}
//===----------------------------------------------------------------------===//
// Array attribute.
//===----------------------------------------------------------------------===//
@@ -68,6 +72,8 @@ MlirAttribute mlirArrayAttrGetElement(MlirAttribute attr, intptr_t pos) {
return wrap(llvm::cast<ArrayAttr>(unwrap(attr)).getValue()[pos]);
}
MlirTypeID mlirArrayAttrGetTypeID(void) { return wrap(ArrayAttr::getTypeID()); }
//===----------------------------------------------------------------------===//
// Dictionary attribute.
//===----------------------------------------------------------------------===//
@@ -102,6 +108,10 @@ MlirAttribute mlirDictionaryAttrGetElementByName(MlirAttribute attr,
return wrap(llvm::cast<DictionaryAttr>(unwrap(attr)).get(unwrap(name)));
}
MlirTypeID mlirDictionaryAttrGetTypeID(void) {
return wrap(DictionaryAttr::getTypeID());
}
//===----------------------------------------------------------------------===//
// Floating point attribute.
//===----------------------------------------------------------------------===//
@@ -124,6 +134,8 @@ double mlirFloatAttrGetValueDouble(MlirAttribute attr) {
return llvm::cast<FloatAttr>(unwrap(attr)).getValueAsDouble();
}
MlirTypeID mlirFloatAttrGetTypeID(void) { return wrap(FloatAttr::getTypeID()); }
//===----------------------------------------------------------------------===//
// Integer attribute.
//===----------------------------------------------------------------------===//
@@ -148,6 +160,10 @@ uint64_t mlirIntegerAttrGetValueUInt(MlirAttribute attr) {
return llvm::cast<IntegerAttr>(unwrap(attr)).getUInt();
}
MlirTypeID mlirIntegerAttrGetTypeID(void) {
return wrap(IntegerAttr::getTypeID());
}
//===----------------------------------------------------------------------===//
// Bool attribute.
//===----------------------------------------------------------------------===//
@@ -172,6 +188,10 @@ bool mlirAttributeIsAIntegerSet(MlirAttribute attr) {
return llvm::isa<IntegerSetAttr>(unwrap(attr));
}
MlirTypeID mlirIntegerSetAttrGetTypeID(void) {
return wrap(IntegerSetAttr::getTypeID());
}
//===----------------------------------------------------------------------===//
// Opaque attribute.
//===----------------------------------------------------------------------===//
@@ -197,6 +217,10 @@ MlirStringRef mlirOpaqueAttrGetData(MlirAttribute attr) {
return wrap(llvm::cast<OpaqueAttr>(unwrap(attr)).getAttrData());
}
MlirTypeID mlirOpaqueAttrGetTypeID(void) {
return wrap(OpaqueAttr::getTypeID());
}
//===----------------------------------------------------------------------===//
// String attribute.
//===----------------------------------------------------------------------===//
@@ -217,6 +241,10 @@ MlirStringRef mlirStringAttrGetValue(MlirAttribute attr) {
return wrap(llvm::cast<StringAttr>(unwrap(attr)).getValue());
}
MlirTypeID mlirStringAttrGetTypeID(void) {
return wrap(StringAttr::getTypeID());
}
//===----------------------------------------------------------------------===//
// SymbolRef attribute.
//===----------------------------------------------------------------------===//
@@ -257,6 +285,10 @@ MlirAttribute mlirSymbolRefAttrGetNestedReference(MlirAttribute attr,
llvm::cast<SymbolRefAttr>(unwrap(attr)).getNestedReferences()[pos]);
}
MlirTypeID mlirSymbolRefAttrGetTypeID(void) {
return wrap(SymbolRefAttr::getTypeID());
}
//===----------------------------------------------------------------------===//
// Flat SymbolRef attribute.
//===----------------------------------------------------------------------===//
@@ -273,6 +305,10 @@ MlirStringRef mlirFlatSymbolRefAttrGetValue(MlirAttribute attr) {
return wrap(llvm::cast<FlatSymbolRefAttr>(unwrap(attr)).getValue());
}
MlirTypeID mlirFlatSymbolRefAttrGetTypeID(void) {
return wrap(FlatSymbolRefAttr::getTypeID());
}
//===----------------------------------------------------------------------===//
// Type attribute.
//===----------------------------------------------------------------------===//
@@ -289,6 +325,8 @@ MlirType mlirTypeAttrGetValue(MlirAttribute attr) {
return wrap(llvm::cast<TypeAttr>(unwrap(attr)).getValue());
}
MlirTypeID mlirTypeAttrGetTypeID(void) { return wrap(TypeAttr::getTypeID()); }
//===----------------------------------------------------------------------===//
// Unit attribute.
//===----------------------------------------------------------------------===//
@@ -301,6 +339,8 @@ MlirAttribute mlirUnitAttrGet(MlirContext ctx) {
return wrap(UnitAttr::get(unwrap(ctx)));
}
MlirTypeID mlirUnitAttrGetTypeID(void) { return wrap(UnitAttr::getTypeID()); }
//===----------------------------------------------------------------------===//
// Elements attributes.
//===----------------------------------------------------------------------===//
@@ -329,8 +369,13 @@ int64_t mlirElementsAttrGetNumElements(MlirAttribute attr) {
// Dense array attribute.
//===----------------------------------------------------------------------===//
MlirTypeID mlirDenseArrayAttrGetTypeID() {
return wrap(DenseArrayAttr::getTypeID());
}
//===----------------------------------------------------------------------===//
// IsA support.
//===----------------------------------------------------------------------===//
bool mlirAttributeIsADenseBoolArray(MlirAttribute attr) {
return llvm::isa<DenseBoolArrayAttr>(unwrap(attr));
@@ -356,6 +401,7 @@ bool mlirAttributeIsADenseF64Array(MlirAttribute attr) {
//===----------------------------------------------------------------------===//
// Constructors.
//===----------------------------------------------------------------------===//
MlirAttribute mlirDenseBoolArrayGet(MlirContext ctx, intptr_t size,
int const *values) {
@@ -395,6 +441,7 @@ MlirAttribute mlirDenseF64ArrayGet(MlirContext ctx, intptr_t size,
//===----------------------------------------------------------------------===//
// Accessors.
//===----------------------------------------------------------------------===//
intptr_t mlirDenseArrayGetNumElements(MlirAttribute attr) {
return llvm::cast<DenseArrayAttr>(unwrap(attr)).size();
@@ -402,6 +449,7 @@ intptr_t mlirDenseArrayGetNumElements(MlirAttribute attr) {
//===----------------------------------------------------------------------===//
// Indexed accessors.
//===----------------------------------------------------------------------===//
bool mlirDenseBoolArrayGetElement(MlirAttribute attr, intptr_t pos) {
return llvm::cast<DenseBoolArrayAttr>(unwrap(attr))[pos];
@@ -431,19 +479,27 @@ double mlirDenseF64ArrayGetElement(MlirAttribute attr, intptr_t pos) {
//===----------------------------------------------------------------------===//
// IsA support.
//===----------------------------------------------------------------------===//
bool mlirAttributeIsADenseElements(MlirAttribute attr) {
return llvm::isa<DenseElementsAttr>(unwrap(attr));
}
bool mlirAttributeIsADenseIntElements(MlirAttribute attr) {
return llvm::isa<DenseIntElementsAttr>(unwrap(attr));
}
bool mlirAttributeIsADenseFPElements(MlirAttribute attr) {
return llvm::isa<DenseFPElementsAttr>(unwrap(attr));
}
MlirTypeID mlirDenseIntOrFPElementsAttrGetTypeID(void) {
return wrap(DenseIntOrFPElementsAttr::getTypeID());
}
//===----------------------------------------------------------------------===//
// Constructors.
//===----------------------------------------------------------------------===//
MlirAttribute mlirDenseElementsAttrGet(MlirType shapedType,
intptr_t numElements,
@@ -620,6 +676,7 @@ MlirAttribute mlirDenseElementsAttrReshapeGet(MlirAttribute attr,
//===----------------------------------------------------------------------===//
// Splat accessors.
//===----------------------------------------------------------------------===//
bool mlirDenseElementsAttrIsSplat(MlirAttribute attr) {
return llvm::cast<DenseElementsAttr>(unwrap(attr)).isSplat();
@@ -663,6 +720,7 @@ MlirStringRef mlirDenseElementsAttrGetStringSplatValue(MlirAttribute attr) {
//===----------------------------------------------------------------------===//
// Indexed accessors.
//===----------------------------------------------------------------------===//
bool mlirDenseElementsAttrGetBoolValue(MlirAttribute attr, intptr_t pos) {
return llvm::cast<DenseElementsAttr>(unwrap(attr)).getValues<bool>()[pos];
@@ -705,6 +763,7 @@ MlirStringRef mlirDenseElementsAttrGetStringValue(MlirAttribute attr,
//===----------------------------------------------------------------------===//
// Raw data accessors.
//===----------------------------------------------------------------------===//
const void *mlirDenseElementsAttrGetRawData(MlirAttribute attr) {
return static_cast<const void *>(
@@ -876,6 +935,10 @@ MlirAttribute mlirSparseElementsAttrGetValues(MlirAttribute attr) {
return wrap(llvm::cast<SparseElementsAttr>(unwrap(attr)).getValues());
}
MlirTypeID mlirSparseElementsAttrGetTypeID(void) {
return wrap(SparseElementsAttr::getTypeID());
}
//===----------------------------------------------------------------------===//
// Strided layout attribute.
//===----------------------------------------------------------------------===//
@@ -903,3 +966,7 @@ intptr_t mlirStridedLayoutAttrGetNumStrides(MlirAttribute attr) {
int64_t mlirStridedLayoutAttrGetStride(MlirAttribute attr, intptr_t pos) {
return llvm::cast<StridedLayoutAttr>(unwrap(attr)).getStrides()[pos];
}
MlirTypeID mlirStridedLayoutAttrGetTypeID(void) {
return wrap(StridedLayoutAttr::getTypeID());
}

View File

@@ -870,6 +870,10 @@ MlirTypeID mlirAttributeGetTypeID(MlirAttribute attr) {
return wrap(unwrap(attr).getTypeID());
}
MlirDialect mlirAttributeGetDialect(MlirAttribute attr) {
return wrap(&unwrap(attr).getDialect());
}
bool mlirAttributeEqual(MlirAttribute a1, MlirAttribute a2) {
return unwrap(a1) == unwrap(a2);
}

View File

@@ -23,7 +23,7 @@ def testParsePrint():
gc.collect()
# CHECK: "hello"
print(str(t))
# CHECK: Attribute("hello")
# CHECK: StringAttr("hello")
print(repr(t))
@@ -134,7 +134,7 @@ def testStandardAttrCasts():
a1 = Attribute.parse('"attr1"')
astr = StringAttr(a1)
aself = StringAttr(astr)
# CHECK: Attribute("attr1")
# CHECK: StringAttr("attr1")
print(repr(astr))
try:
tillegal = StringAttr(Attribute.parse("1.0"))
@@ -324,32 +324,32 @@ def testDenseIntAttr():
@run
def testDenseArrayGetItem():
def print_item(AttrClass, attr_asm):
attr = AttrClass(Attribute.parse(attr_asm))
def print_item(attr_asm):
attr = Attribute.parse(attr_asm)
print(f"{len(attr)}: {attr[0]}, {attr[1]}")
with Context():
# CHECK: 2: 0, 1
print_item(DenseBoolArrayAttr, "array<i1: false, true>")
print_item("array<i1: false, true>")
# CHECK: 2: 2, 3
print_item(DenseI8ArrayAttr, "array<i8: 2, 3>")
print_item("array<i8: 2, 3>")
# CHECK: 2: 4, 5
print_item(DenseI16ArrayAttr, "array<i16: 4, 5>")
print_item("array<i16: 4, 5>")
# CHECK: 2: 6, 7
print_item(DenseI32ArrayAttr, "array<i32: 6, 7>")
print_item("array<i32: 6, 7>")
# CHECK: 2: 8, 9
print_item(DenseI64ArrayAttr, "array<i64: 8, 9>")
print_item("array<i64: 8, 9>")
# CHECK: 2: 1.{{0+}}, 2.{{0+}}
print_item(DenseF32ArrayAttr, "array<f32: 1.0, 2.0>")
print_item("array<f32: 1.0, 2.0>")
# CHECK: 2: 3.{{0+}}, 4.{{0+}}
print_item(DenseF64ArrayAttr, "array<f64: 3.0, 4.0>")
print_item("array<f64: 3.0, 4.0>")
# CHECK-LABEL: TEST: testDenseIntAttrGetItem
@run
def testDenseIntAttrGetItem():
def print_item(attr_asm):
attr = DenseIntElementsAttr(Attribute.parse(attr_asm))
attr = Attribute.parse(attr_asm)
dtype = ShapedType(attr.type).element_type
try:
item = attr[0]
@@ -592,3 +592,14 @@ def testConcreteTypesRoundTrip():
print(repr(type_attr.value))
# CHECK: F32Type(f32)
print(repr(type_attr.value.element_type))
# CHECK-LABEL: TEST: testConcreteAttributesRoundTrip
@run
def testConcreteAttributesRoundTrip():
with Context(), Location.unknown():
# CHECK: FloatAttr(4.200000e+01 : f32)
print(repr(Attribute.parse("42.0 : f32")))
assert IntegerAttr.static_typeid is not None