[MLIR][python bindings] TypeCasters for Attributes
Differential Revision: https://reviews.llvm.org/D151840
This commit is contained in:
@@ -45,6 +45,9 @@ MLIR_CAPI_EXPORTED MlirAttribute mlirAffineMapAttrGet(MlirAffineMap map);
|
||||
/// Returns the affine map wrapped in the given affine map attribute.
|
||||
MLIR_CAPI_EXPORTED MlirAffineMap mlirAffineMapAttrGetValue(MlirAttribute attr);
|
||||
|
||||
/// Returns the typeID of an AffineMap attribute.
|
||||
MLIR_CAPI_EXPORTED MlirTypeID mlirAffineMapAttrGetTypeID(void);
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// Array attribute.
|
||||
//===----------------------------------------------------------------------===//
|
||||
@@ -64,6 +67,9 @@ MLIR_CAPI_EXPORTED intptr_t mlirArrayAttrGetNumElements(MlirAttribute attr);
|
||||
MLIR_CAPI_EXPORTED MlirAttribute mlirArrayAttrGetElement(MlirAttribute attr,
|
||||
intptr_t pos);
|
||||
|
||||
/// Returns the typeID of an Array attribute.
|
||||
MLIR_CAPI_EXPORTED MlirTypeID mlirArrayAttrGetTypeID(void);
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// Dictionary attribute.
|
||||
//===----------------------------------------------------------------------===//
|
||||
@@ -89,6 +95,9 @@ mlirDictionaryAttrGetElement(MlirAttribute attr, intptr_t pos);
|
||||
MLIR_CAPI_EXPORTED MlirAttribute
|
||||
mlirDictionaryAttrGetElementByName(MlirAttribute attr, MlirStringRef name);
|
||||
|
||||
/// Returns the typeID of a Dictionary attribute.
|
||||
MLIR_CAPI_EXPORTED MlirTypeID mlirDictionaryAttrGetTypeID(void);
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// Floating point attribute.
|
||||
//===----------------------------------------------------------------------===//
|
||||
@@ -115,6 +124,9 @@ MLIR_CAPI_EXPORTED MlirAttribute mlirFloatAttrDoubleGetChecked(MlirLocation loc,
|
||||
/// the value as double.
|
||||
MLIR_CAPI_EXPORTED double mlirFloatAttrGetValueDouble(MlirAttribute attr);
|
||||
|
||||
/// Returns the typeID of a Float attribute.
|
||||
MLIR_CAPI_EXPORTED MlirTypeID mlirFloatAttrGetTypeID(void);
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// Integer attribute.
|
||||
//===----------------------------------------------------------------------===//
|
||||
@@ -142,6 +154,9 @@ MLIR_CAPI_EXPORTED int64_t mlirIntegerAttrGetValueSInt(MlirAttribute attr);
|
||||
/// is of unsigned type and fits into an unsigned 64-bit integer.
|
||||
MLIR_CAPI_EXPORTED uint64_t mlirIntegerAttrGetValueUInt(MlirAttribute attr);
|
||||
|
||||
/// Returns the typeID of an Integer attribute.
|
||||
MLIR_CAPI_EXPORTED MlirTypeID mlirIntegerAttrGetTypeID(void);
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// Bool attribute.
|
||||
//===----------------------------------------------------------------------===//
|
||||
@@ -162,6 +177,9 @@ MLIR_CAPI_EXPORTED bool mlirBoolAttrGetValue(MlirAttribute attr);
|
||||
/// Checks whether the given attribute is an integer set attribute.
|
||||
MLIR_CAPI_EXPORTED bool mlirAttributeIsAIntegerSet(MlirAttribute attr);
|
||||
|
||||
/// Returns the typeID of an IntegerSet attribute.
|
||||
MLIR_CAPI_EXPORTED MlirTypeID mlirIntegerSetAttrGetTypeID(void);
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// Opaque attribute.
|
||||
//===----------------------------------------------------------------------===//
|
||||
@@ -185,6 +203,9 @@ mlirOpaqueAttrGetDialectNamespace(MlirAttribute attr);
|
||||
/// the context in which the attribute lives.
|
||||
MLIR_CAPI_EXPORTED MlirStringRef mlirOpaqueAttrGetData(MlirAttribute attr);
|
||||
|
||||
/// Returns the typeID of an Opaque attribute.
|
||||
MLIR_CAPI_EXPORTED MlirTypeID mlirOpaqueAttrGetTypeID(void);
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// String attribute.
|
||||
//===----------------------------------------------------------------------===//
|
||||
@@ -206,6 +227,9 @@ MLIR_CAPI_EXPORTED MlirAttribute mlirStringAttrTypedGet(MlirType type,
|
||||
/// long as the context in which the attribute lives.
|
||||
MLIR_CAPI_EXPORTED MlirStringRef mlirStringAttrGetValue(MlirAttribute attr);
|
||||
|
||||
/// Returns the typeID of a String attribute.
|
||||
MLIR_CAPI_EXPORTED MlirTypeID mlirStringAttrGetTypeID(void);
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// SymbolRef attribute.
|
||||
//===----------------------------------------------------------------------===//
|
||||
@@ -239,6 +263,9 @@ mlirSymbolRefAttrGetNumNestedReferences(MlirAttribute attr);
|
||||
MLIR_CAPI_EXPORTED MlirAttribute
|
||||
mlirSymbolRefAttrGetNestedReference(MlirAttribute attr, intptr_t pos);
|
||||
|
||||
/// Returns the typeID of an SymbolRef attribute.
|
||||
MLIR_CAPI_EXPORTED MlirTypeID mlirSymbolRefAttrGetTypeID(void);
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// Flat SymbolRef attribute.
|
||||
//===----------------------------------------------------------------------===//
|
||||
@@ -256,6 +283,9 @@ MLIR_CAPI_EXPORTED MlirAttribute mlirFlatSymbolRefAttrGet(MlirContext ctx,
|
||||
MLIR_CAPI_EXPORTED MlirStringRef
|
||||
mlirFlatSymbolRefAttrGetValue(MlirAttribute attr);
|
||||
|
||||
/// Returns the typeID of an FlatSymbolRef attribute.
|
||||
MLIR_CAPI_EXPORTED MlirTypeID mlirFlatSymbolRefAttrGetTypeID(void);
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// Type attribute.
|
||||
//===----------------------------------------------------------------------===//
|
||||
@@ -270,6 +300,9 @@ MLIR_CAPI_EXPORTED MlirAttribute mlirTypeAttrGet(MlirType type);
|
||||
/// Returns the type stored in the given type attribute.
|
||||
MLIR_CAPI_EXPORTED MlirType mlirTypeAttrGetValue(MlirAttribute attr);
|
||||
|
||||
/// Returns the typeID of a Type attribute.
|
||||
MLIR_CAPI_EXPORTED MlirTypeID mlirTypeAttrGetTypeID(void);
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// Unit attribute.
|
||||
//===----------------------------------------------------------------------===//
|
||||
@@ -280,6 +313,9 @@ MLIR_CAPI_EXPORTED bool mlirAttributeIsAUnit(MlirAttribute attr);
|
||||
/// Creates a unit attribute in the given context.
|
||||
MLIR_CAPI_EXPORTED MlirAttribute mlirUnitAttrGet(MlirContext ctx);
|
||||
|
||||
/// Returns the typeID of a Unit attribute.
|
||||
MLIR_CAPI_EXPORTED MlirTypeID mlirUnitAttrGetTypeID(void);
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// Elements attributes.
|
||||
//===----------------------------------------------------------------------===//
|
||||
@@ -306,6 +342,8 @@ MLIR_CAPI_EXPORTED int64_t mlirElementsAttrGetNumElements(MlirAttribute attr);
|
||||
// Dense array attribute.
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
MLIR_CAPI_EXPORTED MlirTypeID mlirDenseArrayAttrGetTypeID(void);
|
||||
|
||||
/// Checks whether the given attribute is a dense array attribute.
|
||||
MLIR_CAPI_EXPORTED bool mlirAttributeIsADenseBoolArray(MlirAttribute attr);
|
||||
MLIR_CAPI_EXPORTED bool mlirAttributeIsADenseI8Array(MlirAttribute attr);
|
||||
@@ -370,6 +408,9 @@ MLIR_CAPI_EXPORTED bool mlirAttributeIsADenseElements(MlirAttribute attr);
|
||||
MLIR_CAPI_EXPORTED bool mlirAttributeIsADenseIntElements(MlirAttribute attr);
|
||||
MLIR_CAPI_EXPORTED bool mlirAttributeIsADenseFPElements(MlirAttribute attr);
|
||||
|
||||
/// Returns the typeID of an DenseIntOrFPElements attribute.
|
||||
MLIR_CAPI_EXPORTED MlirTypeID mlirDenseIntOrFPElementsAttrGetTypeID(void);
|
||||
|
||||
/// Creates a dense elements attribute with the given Shaped type and elements
|
||||
/// in the same context as the type.
|
||||
MLIR_CAPI_EXPORTED MlirAttribute mlirDenseElementsAttrGet(
|
||||
@@ -612,6 +653,9 @@ mlirSparseElementsAttrGetIndices(MlirAttribute attr);
|
||||
MLIR_CAPI_EXPORTED MlirAttribute
|
||||
mlirSparseElementsAttrGetValues(MlirAttribute attr);
|
||||
|
||||
/// Returns the typeID of a SparseElements attribute.
|
||||
MLIR_CAPI_EXPORTED MlirTypeID mlirSparseElementsAttrGetTypeID(void);
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// Strided layout attribute.
|
||||
//===----------------------------------------------------------------------===//
|
||||
@@ -635,6 +679,9 @@ mlirStridedLayoutAttrGetNumStrides(MlirAttribute attr);
|
||||
MLIR_CAPI_EXPORTED int64_t mlirStridedLayoutAttrGetStride(MlirAttribute attr,
|
||||
intptr_t pos);
|
||||
|
||||
/// Returns the typeID of a StridedLayout attribute.
|
||||
MLIR_CAPI_EXPORTED MlirTypeID mlirStridedLayoutAttrGetTypeID(void);
|
||||
|
||||
#ifdef __cplusplus
|
||||
}
|
||||
#endif
|
||||
|
||||
@@ -860,6 +860,9 @@ MLIR_CAPI_EXPORTED MlirType mlirAttributeGetType(MlirAttribute attribute);
|
||||
/// Gets the type id of the attribute.
|
||||
MLIR_CAPI_EXPORTED MlirTypeID mlirAttributeGetTypeID(MlirAttribute attribute);
|
||||
|
||||
/// Gets the dialect of the attribute.
|
||||
MLIR_CAPI_EXPORTED MlirDialect mlirAttributeGetDialect(MlirAttribute attribute);
|
||||
|
||||
/// Checks whether an attribute is null.
|
||||
static inline bool mlirAttributeIsNull(MlirAttribute attr) { return !attr.ptr; }
|
||||
|
||||
|
||||
@@ -97,6 +97,7 @@ struct type_caster<MlirAttribute> {
|
||||
return py::module::import(MAKE_MLIR_PYTHON_QUALNAME("ir"))
|
||||
.attr("Attribute")
|
||||
.attr(MLIR_PYTHON_CAPI_FACTORY_ATTR)(capsule)
|
||||
.attr(MLIR_PYTHON_MAYBE_DOWNCAST_ATTR)()
|
||||
.release();
|
||||
}
|
||||
};
|
||||
|
||||
@@ -80,6 +80,8 @@ public:
|
||||
static constexpr IsAFunctionTy isaFunction = mlirAttributeIsAAffineMap;
|
||||
static constexpr const char *pyClassName = "AffineMapAttr";
|
||||
using PyConcreteAttribute::PyConcreteAttribute;
|
||||
static constexpr GetTypeIDFunctionTy getTypeIdFunction =
|
||||
mlirAffineMapAttrGetTypeID;
|
||||
|
||||
static void bindDerived(ClassTy &c) {
|
||||
c.def_static(
|
||||
@@ -259,6 +261,8 @@ public:
|
||||
static constexpr IsAFunctionTy isaFunction = mlirAttributeIsAArray;
|
||||
static constexpr const char *pyClassName = "ArrayAttr";
|
||||
using PyConcreteAttribute::PyConcreteAttribute;
|
||||
static constexpr GetTypeIDFunctionTy getTypeIdFunction =
|
||||
mlirArrayAttrGetTypeID;
|
||||
|
||||
class PyArrayAttributeIterator {
|
||||
public:
|
||||
@@ -339,6 +343,8 @@ public:
|
||||
static constexpr IsAFunctionTy isaFunction = mlirAttributeIsAFloat;
|
||||
static constexpr const char *pyClassName = "FloatAttr";
|
||||
using PyConcreteAttribute::PyConcreteAttribute;
|
||||
static constexpr GetTypeIDFunctionTy getTypeIdFunction =
|
||||
mlirFloatAttrGetTypeID;
|
||||
|
||||
static void bindDerived(ClassTy &c) {
|
||||
c.def_static(
|
||||
@@ -406,6 +412,10 @@ public:
|
||||
return mlirIntegerAttrGetValueUInt(self);
|
||||
},
|
||||
"Returns the value of the integer attribute");
|
||||
c.def_property_readonly_static("static_typeid",
|
||||
[](py::object & /*class*/) -> MlirTypeID {
|
||||
return mlirIntegerAttrGetTypeID();
|
||||
});
|
||||
}
|
||||
};
|
||||
|
||||
@@ -438,6 +448,8 @@ public:
|
||||
static constexpr IsAFunctionTy isaFunction = mlirAttributeIsAFlatSymbolRef;
|
||||
static constexpr const char *pyClassName = "FlatSymbolRefAttr";
|
||||
using PyConcreteAttribute::PyConcreteAttribute;
|
||||
static constexpr GetTypeIDFunctionTy getTypeIdFunction =
|
||||
mlirFlatSymbolRefAttrGetTypeID;
|
||||
|
||||
static void bindDerived(ClassTy &c) {
|
||||
c.def_static(
|
||||
@@ -464,6 +476,8 @@ public:
|
||||
static constexpr IsAFunctionTy isaFunction = mlirAttributeIsAOpaque;
|
||||
static constexpr const char *pyClassName = "OpaqueAttr";
|
||||
using PyConcreteAttribute::PyConcreteAttribute;
|
||||
static constexpr GetTypeIDFunctionTy getTypeIdFunction =
|
||||
mlirOpaqueAttrGetTypeID;
|
||||
|
||||
static void bindDerived(ClassTy &c) {
|
||||
c.def_static(
|
||||
@@ -501,6 +515,8 @@ public:
|
||||
static constexpr IsAFunctionTy isaFunction = mlirAttributeIsAString;
|
||||
static constexpr const char *pyClassName = "StringAttr";
|
||||
using PyConcreteAttribute::PyConcreteAttribute;
|
||||
static constexpr GetTypeIDFunctionTy getTypeIdFunction =
|
||||
mlirStringAttrGetTypeID;
|
||||
|
||||
static void bindDerived(ClassTy &c) {
|
||||
c.def_static(
|
||||
@@ -921,6 +937,8 @@ public:
|
||||
static constexpr IsAFunctionTy isaFunction = mlirAttributeIsADictionary;
|
||||
static constexpr const char *pyClassName = "DictAttr";
|
||||
using PyConcreteAttribute::PyConcreteAttribute;
|
||||
static constexpr GetTypeIDFunctionTy getTypeIdFunction =
|
||||
mlirDictionaryAttrGetTypeID;
|
||||
|
||||
intptr_t dunderLen() { return mlirDictionaryAttrGetNumElements(*this); }
|
||||
|
||||
@@ -1013,6 +1031,8 @@ public:
|
||||
static constexpr IsAFunctionTy isaFunction = mlirAttributeIsAType;
|
||||
static constexpr const char *pyClassName = "TypeAttr";
|
||||
using PyConcreteAttribute::PyConcreteAttribute;
|
||||
static constexpr GetTypeIDFunctionTy getTypeIdFunction =
|
||||
mlirTypeAttrGetTypeID;
|
||||
|
||||
static void bindDerived(ClassTy &c) {
|
||||
c.def_static(
|
||||
@@ -1035,6 +1055,8 @@ public:
|
||||
static constexpr IsAFunctionTy isaFunction = mlirAttributeIsAUnit;
|
||||
static constexpr const char *pyClassName = "UnitAttr";
|
||||
using PyConcreteAttribute::PyConcreteAttribute;
|
||||
static constexpr GetTypeIDFunctionTy getTypeIdFunction =
|
||||
mlirUnitAttrGetTypeID;
|
||||
|
||||
static void bindDerived(ClassTy &c) {
|
||||
c.def_static(
|
||||
@@ -1054,6 +1076,8 @@ public:
|
||||
static constexpr IsAFunctionTy isaFunction = mlirAttributeIsAStridedLayout;
|
||||
static constexpr const char *pyClassName = "StridedLayoutAttr";
|
||||
using PyConcreteAttribute::PyConcreteAttribute;
|
||||
static constexpr GetTypeIDFunctionTy getTypeIdFunction =
|
||||
mlirStridedLayoutAttrGetTypeID;
|
||||
|
||||
static void bindDerived(ClassTy &c) {
|
||||
c.def_static(
|
||||
@@ -1099,6 +1123,50 @@ public:
|
||||
}
|
||||
};
|
||||
|
||||
py::object denseArrayAttributeCaster(PyAttribute &pyAttribute) {
|
||||
if (PyDenseBoolArrayAttribute::isaFunction(pyAttribute))
|
||||
return py::cast(PyDenseBoolArrayAttribute(pyAttribute));
|
||||
if (PyDenseI8ArrayAttribute::isaFunction(pyAttribute))
|
||||
return py::cast(PyDenseI8ArrayAttribute(pyAttribute));
|
||||
if (PyDenseI16ArrayAttribute::isaFunction(pyAttribute))
|
||||
return py::cast(PyDenseI16ArrayAttribute(pyAttribute));
|
||||
if (PyDenseI32ArrayAttribute::isaFunction(pyAttribute))
|
||||
return py::cast(PyDenseI32ArrayAttribute(pyAttribute));
|
||||
if (PyDenseI64ArrayAttribute::isaFunction(pyAttribute))
|
||||
return py::cast(PyDenseI64ArrayAttribute(pyAttribute));
|
||||
if (PyDenseF32ArrayAttribute::isaFunction(pyAttribute))
|
||||
return py::cast(PyDenseF32ArrayAttribute(pyAttribute));
|
||||
if (PyDenseF64ArrayAttribute::isaFunction(pyAttribute))
|
||||
return py::cast(PyDenseF64ArrayAttribute(pyAttribute));
|
||||
std::string msg =
|
||||
std::string("Can't cast unknown element type DenseArrayAttr (") +
|
||||
std::string(py::repr(py::cast(pyAttribute))) + ")";
|
||||
throw py::cast_error(msg);
|
||||
}
|
||||
|
||||
py::object denseIntOrFPElementsAttributeCaster(PyAttribute &pyAttribute) {
|
||||
if (PyDenseFPElementsAttribute::isaFunction(pyAttribute))
|
||||
return py::cast(PyDenseFPElementsAttribute(pyAttribute));
|
||||
if (PyDenseIntElementsAttribute::isaFunction(pyAttribute))
|
||||
return py::cast(PyDenseIntElementsAttribute(pyAttribute));
|
||||
std::string msg =
|
||||
std::string(
|
||||
"Can't cast unknown element type DenseIntOrFPElementsAttr (") +
|
||||
std::string(py::repr(py::cast(pyAttribute))) + ")";
|
||||
throw py::cast_error(msg);
|
||||
}
|
||||
|
||||
py::object integerOrBoolAttributeCaster(PyAttribute &pyAttribute) {
|
||||
if (PyBoolAttribute::isaFunction(pyAttribute))
|
||||
return py::cast(PyBoolAttribute(pyAttribute));
|
||||
if (PyIntegerAttribute::isaFunction(pyAttribute))
|
||||
return py::cast(PyIntegerAttribute(pyAttribute));
|
||||
std::string msg =
|
||||
std::string("Can't cast unknown element type DenseArrayAttr (") +
|
||||
std::string(py::repr(py::cast(pyAttribute))) + ")";
|
||||
throw py::cast_error(msg);
|
||||
}
|
||||
|
||||
} // namespace
|
||||
|
||||
void mlir::python::populateIRAttributes(py::module &m) {
|
||||
@@ -1118,6 +1186,9 @@ void mlir::python::populateIRAttributes(py::module &m) {
|
||||
PyDenseF32ArrayAttribute::PyDenseArrayIterator::bind(m);
|
||||
PyDenseF64ArrayAttribute::bind(m);
|
||||
PyDenseF64ArrayAttribute::PyDenseArrayIterator::bind(m);
|
||||
PyGlobals::get().registerTypeCaster(
|
||||
mlirDenseArrayAttrGetTypeID(),
|
||||
pybind11::cpp_function(denseArrayAttributeCaster));
|
||||
|
||||
PyArrayAttribute::bind(m);
|
||||
PyArrayAttribute::PyArrayAttributeIterator::bind(m);
|
||||
@@ -1125,6 +1196,10 @@ void mlir::python::populateIRAttributes(py::module &m) {
|
||||
PyDenseElementsAttribute::bind(m);
|
||||
PyDenseFPElementsAttribute::bind(m);
|
||||
PyDenseIntElementsAttribute::bind(m);
|
||||
PyGlobals::get().registerTypeCaster(
|
||||
mlirDenseIntOrFPElementsAttrGetTypeID(),
|
||||
pybind11::cpp_function(denseIntOrFPElementsAttributeCaster));
|
||||
|
||||
PyDictAttribute::bind(m);
|
||||
PyFlatSymbolRefAttribute::bind(m);
|
||||
PyOpaqueAttribute::bind(m);
|
||||
@@ -1132,6 +1207,9 @@ void mlir::python::populateIRAttributes(py::module &m) {
|
||||
PyIntegerAttribute::bind(m);
|
||||
PyStringAttribute::bind(m);
|
||||
PyTypeAttribute::bind(m);
|
||||
PyGlobals::get().registerTypeCaster(
|
||||
mlirIntegerAttrGetTypeID(),
|
||||
pybind11::cpp_function(integerOrBoolAttributeCaster));
|
||||
PyUnitAttribute::bind(m);
|
||||
|
||||
PyStridedLayoutAttribute::bind(m);
|
||||
|
||||
@@ -2640,10 +2640,7 @@ void mlir::python::populateIRCore(py::module &m) {
|
||||
"Context that owns the Location")
|
||||
.def_property_readonly(
|
||||
"attr",
|
||||
[](PyLocation &self) {
|
||||
return PyAttribute(self.getContext(),
|
||||
mlirLocationGetAttribute(self));
|
||||
},
|
||||
[](PyLocation &self) { return mlirLocationGetAttribute(self); },
|
||||
"Get the underlying LocationAttr")
|
||||
.def(
|
||||
"emit_error",
|
||||
@@ -3139,7 +3136,7 @@ void mlir::python::populateIRCore(py::module &m) {
|
||||
context->get(), toMlirStringRef(attrSpec));
|
||||
if (mlirAttributeIsNull(type))
|
||||
throw MLIRError("Unable to parse attribute", errors.take());
|
||||
return PyAttribute(context->getRef(), type);
|
||||
return type;
|
||||
},
|
||||
py::arg("asm"), py::arg("context") = py::none(),
|
||||
"Parses an attribute from an assembly form. Raises an MLIRError on "
|
||||
@@ -3175,18 +3172,38 @@ void mlir::python::populateIRCore(py::module &m) {
|
||||
return printAccum.join();
|
||||
},
|
||||
"Returns the assembly form of the Attribute.")
|
||||
.def("__repr__", [](PyAttribute &self) {
|
||||
// Generally, assembly formats are not printed for __repr__ because
|
||||
// this can cause exceptionally long debug output and exceptions.
|
||||
// However, attribute values are generally considered useful and are
|
||||
// printed. This may need to be re-evaluated if debug dumps end up
|
||||
// being excessive.
|
||||
PyPrintAccumulator printAccum;
|
||||
printAccum.parts.append("Attribute(");
|
||||
mlirAttributePrint(self, printAccum.getCallback(),
|
||||
printAccum.getUserData());
|
||||
printAccum.parts.append(")");
|
||||
return printAccum.join();
|
||||
.def("__repr__",
|
||||
[](PyAttribute &self) {
|
||||
// Generally, assembly formats are not printed for __repr__ because
|
||||
// this can cause exceptionally long debug output and exceptions.
|
||||
// However, attribute values are generally considered useful and
|
||||
// are printed. This may need to be re-evaluated if debug dumps end
|
||||
// up being excessive.
|
||||
PyPrintAccumulator printAccum;
|
||||
printAccum.parts.append("Attribute(");
|
||||
mlirAttributePrint(self, printAccum.getCallback(),
|
||||
printAccum.getUserData());
|
||||
printAccum.parts.append(")");
|
||||
return printAccum.join();
|
||||
})
|
||||
.def_property_readonly(
|
||||
"typeid",
|
||||
[](PyAttribute &self) -> MlirTypeID {
|
||||
MlirTypeID mlirTypeID = mlirAttributeGetTypeID(self);
|
||||
assert(!mlirTypeIDIsNull(mlirTypeID) &&
|
||||
"mlirTypeID was expected to be non-null.");
|
||||
return mlirTypeID;
|
||||
})
|
||||
.def(MLIR_PYTHON_MAYBE_DOWNCAST_ATTR, [](PyAttribute &self) {
|
||||
MlirTypeID mlirTypeID = mlirAttributeGetTypeID(self);
|
||||
assert(!mlirTypeIDIsNull(mlirTypeID) &&
|
||||
"mlirTypeID was expected to be non-null.");
|
||||
std::optional<pybind11::function> typeCaster =
|
||||
PyGlobals::get().lookupTypeCaster(mlirTypeID,
|
||||
mlirAttributeGetDialect(self));
|
||||
if (!typeCaster)
|
||||
return py::cast(self);
|
||||
return typeCaster.value()(self);
|
||||
});
|
||||
|
||||
//----------------------------------------------------------------------------
|
||||
@@ -3216,13 +3233,7 @@ void mlir::python::populateIRCore(py::module &m) {
|
||||
"The name of the NamedAttribute binding")
|
||||
.def_property_readonly(
|
||||
"attr",
|
||||
[](PyNamedAttribute &self) {
|
||||
// TODO: When named attribute is removed/refactored, also remove
|
||||
// this constructor (it does an inefficient table lookup).
|
||||
auto contextRef = PyMlirContext::forContext(
|
||||
mlirAttributeGetContext(self.namedAttr.attribute));
|
||||
return PyAttribute(std::move(contextRef), self.namedAttr.attribute);
|
||||
},
|
||||
[](PyNamedAttribute &self) { return self.namedAttr.attribute; },
|
||||
py::keep_alive<0, 1>(),
|
||||
"The underlying generic attribute of the NamedAttribute binding");
|
||||
|
||||
|
||||
@@ -986,6 +986,8 @@ public:
|
||||
// const char *pyClassName
|
||||
using ClassTy = pybind11::class_<DerivedTy, BaseTy>;
|
||||
using IsAFunctionTy = bool (*)(MlirAttribute);
|
||||
using GetTypeIDFunctionTy = MlirTypeID (*)();
|
||||
static constexpr GetTypeIDFunctionTy getTypeIdFunction = nullptr;
|
||||
|
||||
PyConcreteAttribute() = default;
|
||||
PyConcreteAttribute(PyMlirContextRef contextRef, MlirAttribute attr)
|
||||
@@ -1017,6 +1019,34 @@ public:
|
||||
pybind11::arg("other"));
|
||||
cls.def_property_readonly(
|
||||
"type", [](PyAttribute &attr) { return mlirAttributeGetType(attr); });
|
||||
cls.def_property_readonly_static(
|
||||
"static_typeid", [](py::object & /*class*/) -> MlirTypeID {
|
||||
if (DerivedTy::getTypeIdFunction)
|
||||
return DerivedTy::getTypeIdFunction();
|
||||
throw py::attribute_error(
|
||||
(DerivedTy::pyClassName + llvm::Twine(" has no typeid.")).str());
|
||||
});
|
||||
cls.def_property_readonly("typeid", [](PyAttribute &self) {
|
||||
return py::cast(self).attr("typeid").cast<MlirTypeID>();
|
||||
});
|
||||
cls.def("__repr__", [](DerivedTy &self) {
|
||||
PyPrintAccumulator printAccum;
|
||||
printAccum.parts.append(DerivedTy::pyClassName);
|
||||
printAccum.parts.append("(");
|
||||
mlirAttributePrint(self, printAccum.getCallback(),
|
||||
printAccum.getUserData());
|
||||
printAccum.parts.append(")");
|
||||
return printAccum.join();
|
||||
});
|
||||
|
||||
if (DerivedTy::getTypeIdFunction) {
|
||||
PyGlobals::get().registerTypeCaster(
|
||||
DerivedTy::getTypeIdFunction(),
|
||||
pybind11::cpp_function([](PyAttribute pyAttribute) -> DerivedTy {
|
||||
return pyAttribute;
|
||||
}));
|
||||
}
|
||||
|
||||
DerivedTy::bindDerived(cls);
|
||||
}
|
||||
|
||||
|
||||
@@ -44,6 +44,10 @@ MlirAffineMap mlirAffineMapAttrGetValue(MlirAttribute attr) {
|
||||
return wrap(llvm::cast<AffineMapAttr>(unwrap(attr)).getValue());
|
||||
}
|
||||
|
||||
MlirTypeID mlirAffineMapAttrGetTypeID(void) {
|
||||
return wrap(AffineMapAttr::getTypeID());
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// Array attribute.
|
||||
//===----------------------------------------------------------------------===//
|
||||
@@ -68,6 +72,8 @@ MlirAttribute mlirArrayAttrGetElement(MlirAttribute attr, intptr_t pos) {
|
||||
return wrap(llvm::cast<ArrayAttr>(unwrap(attr)).getValue()[pos]);
|
||||
}
|
||||
|
||||
MlirTypeID mlirArrayAttrGetTypeID(void) { return wrap(ArrayAttr::getTypeID()); }
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// Dictionary attribute.
|
||||
//===----------------------------------------------------------------------===//
|
||||
@@ -102,6 +108,10 @@ MlirAttribute mlirDictionaryAttrGetElementByName(MlirAttribute attr,
|
||||
return wrap(llvm::cast<DictionaryAttr>(unwrap(attr)).get(unwrap(name)));
|
||||
}
|
||||
|
||||
MlirTypeID mlirDictionaryAttrGetTypeID(void) {
|
||||
return wrap(DictionaryAttr::getTypeID());
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// Floating point attribute.
|
||||
//===----------------------------------------------------------------------===//
|
||||
@@ -124,6 +134,8 @@ double mlirFloatAttrGetValueDouble(MlirAttribute attr) {
|
||||
return llvm::cast<FloatAttr>(unwrap(attr)).getValueAsDouble();
|
||||
}
|
||||
|
||||
MlirTypeID mlirFloatAttrGetTypeID(void) { return wrap(FloatAttr::getTypeID()); }
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// Integer attribute.
|
||||
//===----------------------------------------------------------------------===//
|
||||
@@ -148,6 +160,10 @@ uint64_t mlirIntegerAttrGetValueUInt(MlirAttribute attr) {
|
||||
return llvm::cast<IntegerAttr>(unwrap(attr)).getUInt();
|
||||
}
|
||||
|
||||
MlirTypeID mlirIntegerAttrGetTypeID(void) {
|
||||
return wrap(IntegerAttr::getTypeID());
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// Bool attribute.
|
||||
//===----------------------------------------------------------------------===//
|
||||
@@ -172,6 +188,10 @@ bool mlirAttributeIsAIntegerSet(MlirAttribute attr) {
|
||||
return llvm::isa<IntegerSetAttr>(unwrap(attr));
|
||||
}
|
||||
|
||||
MlirTypeID mlirIntegerSetAttrGetTypeID(void) {
|
||||
return wrap(IntegerSetAttr::getTypeID());
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// Opaque attribute.
|
||||
//===----------------------------------------------------------------------===//
|
||||
@@ -197,6 +217,10 @@ MlirStringRef mlirOpaqueAttrGetData(MlirAttribute attr) {
|
||||
return wrap(llvm::cast<OpaqueAttr>(unwrap(attr)).getAttrData());
|
||||
}
|
||||
|
||||
MlirTypeID mlirOpaqueAttrGetTypeID(void) {
|
||||
return wrap(OpaqueAttr::getTypeID());
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// String attribute.
|
||||
//===----------------------------------------------------------------------===//
|
||||
@@ -217,6 +241,10 @@ MlirStringRef mlirStringAttrGetValue(MlirAttribute attr) {
|
||||
return wrap(llvm::cast<StringAttr>(unwrap(attr)).getValue());
|
||||
}
|
||||
|
||||
MlirTypeID mlirStringAttrGetTypeID(void) {
|
||||
return wrap(StringAttr::getTypeID());
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// SymbolRef attribute.
|
||||
//===----------------------------------------------------------------------===//
|
||||
@@ -257,6 +285,10 @@ MlirAttribute mlirSymbolRefAttrGetNestedReference(MlirAttribute attr,
|
||||
llvm::cast<SymbolRefAttr>(unwrap(attr)).getNestedReferences()[pos]);
|
||||
}
|
||||
|
||||
MlirTypeID mlirSymbolRefAttrGetTypeID(void) {
|
||||
return wrap(SymbolRefAttr::getTypeID());
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// Flat SymbolRef attribute.
|
||||
//===----------------------------------------------------------------------===//
|
||||
@@ -273,6 +305,10 @@ MlirStringRef mlirFlatSymbolRefAttrGetValue(MlirAttribute attr) {
|
||||
return wrap(llvm::cast<FlatSymbolRefAttr>(unwrap(attr)).getValue());
|
||||
}
|
||||
|
||||
MlirTypeID mlirFlatSymbolRefAttrGetTypeID(void) {
|
||||
return wrap(FlatSymbolRefAttr::getTypeID());
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// Type attribute.
|
||||
//===----------------------------------------------------------------------===//
|
||||
@@ -289,6 +325,8 @@ MlirType mlirTypeAttrGetValue(MlirAttribute attr) {
|
||||
return wrap(llvm::cast<TypeAttr>(unwrap(attr)).getValue());
|
||||
}
|
||||
|
||||
MlirTypeID mlirTypeAttrGetTypeID(void) { return wrap(TypeAttr::getTypeID()); }
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// Unit attribute.
|
||||
//===----------------------------------------------------------------------===//
|
||||
@@ -301,6 +339,8 @@ MlirAttribute mlirUnitAttrGet(MlirContext ctx) {
|
||||
return wrap(UnitAttr::get(unwrap(ctx)));
|
||||
}
|
||||
|
||||
MlirTypeID mlirUnitAttrGetTypeID(void) { return wrap(UnitAttr::getTypeID()); }
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// Elements attributes.
|
||||
//===----------------------------------------------------------------------===//
|
||||
@@ -329,8 +369,13 @@ int64_t mlirElementsAttrGetNumElements(MlirAttribute attr) {
|
||||
// Dense array attribute.
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
MlirTypeID mlirDenseArrayAttrGetTypeID() {
|
||||
return wrap(DenseArrayAttr::getTypeID());
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// IsA support.
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
bool mlirAttributeIsADenseBoolArray(MlirAttribute attr) {
|
||||
return llvm::isa<DenseBoolArrayAttr>(unwrap(attr));
|
||||
@@ -356,6 +401,7 @@ bool mlirAttributeIsADenseF64Array(MlirAttribute attr) {
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// Constructors.
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
MlirAttribute mlirDenseBoolArrayGet(MlirContext ctx, intptr_t size,
|
||||
int const *values) {
|
||||
@@ -395,6 +441,7 @@ MlirAttribute mlirDenseF64ArrayGet(MlirContext ctx, intptr_t size,
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// Accessors.
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
intptr_t mlirDenseArrayGetNumElements(MlirAttribute attr) {
|
||||
return llvm::cast<DenseArrayAttr>(unwrap(attr)).size();
|
||||
@@ -402,6 +449,7 @@ intptr_t mlirDenseArrayGetNumElements(MlirAttribute attr) {
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// Indexed accessors.
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
bool mlirDenseBoolArrayGetElement(MlirAttribute attr, intptr_t pos) {
|
||||
return llvm::cast<DenseBoolArrayAttr>(unwrap(attr))[pos];
|
||||
@@ -431,19 +479,27 @@ double mlirDenseF64ArrayGetElement(MlirAttribute attr, intptr_t pos) {
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// IsA support.
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
bool mlirAttributeIsADenseElements(MlirAttribute attr) {
|
||||
return llvm::isa<DenseElementsAttr>(unwrap(attr));
|
||||
}
|
||||
|
||||
bool mlirAttributeIsADenseIntElements(MlirAttribute attr) {
|
||||
return llvm::isa<DenseIntElementsAttr>(unwrap(attr));
|
||||
}
|
||||
|
||||
bool mlirAttributeIsADenseFPElements(MlirAttribute attr) {
|
||||
return llvm::isa<DenseFPElementsAttr>(unwrap(attr));
|
||||
}
|
||||
|
||||
MlirTypeID mlirDenseIntOrFPElementsAttrGetTypeID(void) {
|
||||
return wrap(DenseIntOrFPElementsAttr::getTypeID());
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// Constructors.
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
MlirAttribute mlirDenseElementsAttrGet(MlirType shapedType,
|
||||
intptr_t numElements,
|
||||
@@ -620,6 +676,7 @@ MlirAttribute mlirDenseElementsAttrReshapeGet(MlirAttribute attr,
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// Splat accessors.
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
bool mlirDenseElementsAttrIsSplat(MlirAttribute attr) {
|
||||
return llvm::cast<DenseElementsAttr>(unwrap(attr)).isSplat();
|
||||
@@ -663,6 +720,7 @@ MlirStringRef mlirDenseElementsAttrGetStringSplatValue(MlirAttribute attr) {
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// Indexed accessors.
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
bool mlirDenseElementsAttrGetBoolValue(MlirAttribute attr, intptr_t pos) {
|
||||
return llvm::cast<DenseElementsAttr>(unwrap(attr)).getValues<bool>()[pos];
|
||||
@@ -705,6 +763,7 @@ MlirStringRef mlirDenseElementsAttrGetStringValue(MlirAttribute attr,
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// Raw data accessors.
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
const void *mlirDenseElementsAttrGetRawData(MlirAttribute attr) {
|
||||
return static_cast<const void *>(
|
||||
@@ -876,6 +935,10 @@ MlirAttribute mlirSparseElementsAttrGetValues(MlirAttribute attr) {
|
||||
return wrap(llvm::cast<SparseElementsAttr>(unwrap(attr)).getValues());
|
||||
}
|
||||
|
||||
MlirTypeID mlirSparseElementsAttrGetTypeID(void) {
|
||||
return wrap(SparseElementsAttr::getTypeID());
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// Strided layout attribute.
|
||||
//===----------------------------------------------------------------------===//
|
||||
@@ -903,3 +966,7 @@ intptr_t mlirStridedLayoutAttrGetNumStrides(MlirAttribute attr) {
|
||||
int64_t mlirStridedLayoutAttrGetStride(MlirAttribute attr, intptr_t pos) {
|
||||
return llvm::cast<StridedLayoutAttr>(unwrap(attr)).getStrides()[pos];
|
||||
}
|
||||
|
||||
MlirTypeID mlirStridedLayoutAttrGetTypeID(void) {
|
||||
return wrap(StridedLayoutAttr::getTypeID());
|
||||
}
|
||||
|
||||
@@ -870,6 +870,10 @@ MlirTypeID mlirAttributeGetTypeID(MlirAttribute attr) {
|
||||
return wrap(unwrap(attr).getTypeID());
|
||||
}
|
||||
|
||||
MlirDialect mlirAttributeGetDialect(MlirAttribute attr) {
|
||||
return wrap(&unwrap(attr).getDialect());
|
||||
}
|
||||
|
||||
bool mlirAttributeEqual(MlirAttribute a1, MlirAttribute a2) {
|
||||
return unwrap(a1) == unwrap(a2);
|
||||
}
|
||||
|
||||
@@ -23,7 +23,7 @@ def testParsePrint():
|
||||
gc.collect()
|
||||
# CHECK: "hello"
|
||||
print(str(t))
|
||||
# CHECK: Attribute("hello")
|
||||
# CHECK: StringAttr("hello")
|
||||
print(repr(t))
|
||||
|
||||
|
||||
@@ -134,7 +134,7 @@ def testStandardAttrCasts():
|
||||
a1 = Attribute.parse('"attr1"')
|
||||
astr = StringAttr(a1)
|
||||
aself = StringAttr(astr)
|
||||
# CHECK: Attribute("attr1")
|
||||
# CHECK: StringAttr("attr1")
|
||||
print(repr(astr))
|
||||
try:
|
||||
tillegal = StringAttr(Attribute.parse("1.0"))
|
||||
@@ -324,32 +324,32 @@ def testDenseIntAttr():
|
||||
|
||||
@run
|
||||
def testDenseArrayGetItem():
|
||||
def print_item(AttrClass, attr_asm):
|
||||
attr = AttrClass(Attribute.parse(attr_asm))
|
||||
def print_item(attr_asm):
|
||||
attr = Attribute.parse(attr_asm)
|
||||
print(f"{len(attr)}: {attr[0]}, {attr[1]}")
|
||||
|
||||
with Context():
|
||||
# CHECK: 2: 0, 1
|
||||
print_item(DenseBoolArrayAttr, "array<i1: false, true>")
|
||||
print_item("array<i1: false, true>")
|
||||
# CHECK: 2: 2, 3
|
||||
print_item(DenseI8ArrayAttr, "array<i8: 2, 3>")
|
||||
print_item("array<i8: 2, 3>")
|
||||
# CHECK: 2: 4, 5
|
||||
print_item(DenseI16ArrayAttr, "array<i16: 4, 5>")
|
||||
print_item("array<i16: 4, 5>")
|
||||
# CHECK: 2: 6, 7
|
||||
print_item(DenseI32ArrayAttr, "array<i32: 6, 7>")
|
||||
print_item("array<i32: 6, 7>")
|
||||
# CHECK: 2: 8, 9
|
||||
print_item(DenseI64ArrayAttr, "array<i64: 8, 9>")
|
||||
print_item("array<i64: 8, 9>")
|
||||
# CHECK: 2: 1.{{0+}}, 2.{{0+}}
|
||||
print_item(DenseF32ArrayAttr, "array<f32: 1.0, 2.0>")
|
||||
print_item("array<f32: 1.0, 2.0>")
|
||||
# CHECK: 2: 3.{{0+}}, 4.{{0+}}
|
||||
print_item(DenseF64ArrayAttr, "array<f64: 3.0, 4.0>")
|
||||
print_item("array<f64: 3.0, 4.0>")
|
||||
|
||||
|
||||
# CHECK-LABEL: TEST: testDenseIntAttrGetItem
|
||||
@run
|
||||
def testDenseIntAttrGetItem():
|
||||
def print_item(attr_asm):
|
||||
attr = DenseIntElementsAttr(Attribute.parse(attr_asm))
|
||||
attr = Attribute.parse(attr_asm)
|
||||
dtype = ShapedType(attr.type).element_type
|
||||
try:
|
||||
item = attr[0]
|
||||
@@ -592,3 +592,14 @@ def testConcreteTypesRoundTrip():
|
||||
print(repr(type_attr.value))
|
||||
# CHECK: F32Type(f32)
|
||||
print(repr(type_attr.value.element_type))
|
||||
|
||||
|
||||
# CHECK-LABEL: TEST: testConcreteAttributesRoundTrip
|
||||
@run
|
||||
def testConcreteAttributesRoundTrip():
|
||||
with Context(), Location.unknown():
|
||||
|
||||
# CHECK: FloatAttr(4.200000e+01 : f32)
|
||||
print(repr(Attribute.parse("42.0 : f32")))
|
||||
|
||||
assert IntegerAttr.static_typeid is not None
|
||||
|
||||
Reference in New Issue
Block a user