[mlir][sparse] Macros to clean up StridedMemRefType in the SparseTensorRuntime
In particular, this silences warnings from [-Wsign-compare]. Depends On D137681 Reviewed By: aartbik Differential Revision: https://reviews.llvm.org/D137735
This commit is contained in:
@@ -57,6 +57,7 @@
|
||||
#include "mlir/ExecutionEngine/SparseTensor/Storage.h"
|
||||
|
||||
#include <cstring>
|
||||
#include <limits>
|
||||
#include <numeric>
|
||||
|
||||
using namespace mlir::sparse_tensor;
|
||||
@@ -213,6 +214,37 @@ fromMLIRSparseTensor(const SparseTensorStorage<uint64_t, uint64_t, V> *tensor,
|
||||
*pIndices = indices;
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
//
|
||||
// Utilities for manipulating `StridedMemRefType`.
|
||||
//
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
#define ASSERT_NO_STRIDE(MEMREF) \
|
||||
do { \
|
||||
assert((MEMREF) && "Memref is nullptr"); \
|
||||
assert(((MEMREF)->strides[0] == 1) && "Memref has non-trivial stride"); \
|
||||
} while (false)
|
||||
|
||||
// All our functions use `uint64_t` for ranks, but `StridedMemRefType::sizes`
|
||||
// uses `int64_t`. And we must make the cast explicit for the sake of
|
||||
// `operator==`, or else it will generate a [-Wsign-compare] warning.
|
||||
#define MEMREF_GET_USIZE(MEMREF) static_cast<uint64_t>((MEMREF)->sizes[0])
|
||||
|
||||
#define MEMREF_GET_PAYLOAD(MEMREF) ((MEMREF)->data + (MEMREF)->offset)
|
||||
|
||||
// We make this a function rather than a macro mainly for type safety
|
||||
// reasons. This function does not modify the vector, but it cannot
|
||||
// be marked `const` because it is stored into the non-`const` memref.
|
||||
template <typename T>
|
||||
static void vectorToMemref(std::vector<T> &v, StridedMemRefType<T, 1> &ref) {
|
||||
ref.basePtr = ref.data = v.data();
|
||||
ref.offset = 0;
|
||||
assert(v.size() <= std::numeric_limits<int64_t>::max() && "Size overflow");
|
||||
ref.sizes[0] = static_cast<int64_t>(v.size());
|
||||
ref.strides[0] = 1;
|
||||
}
|
||||
|
||||
} // anonymous namespace
|
||||
|
||||
extern "C" {
|
||||
@@ -286,20 +318,21 @@ void *_mlir_ciface_newSparseTensor( // NOLINT
|
||||
StridedMemRefType<index_type, 1> *lvl2dimRef,
|
||||
StridedMemRefType<index_type, 1> *dim2lvlRef, OverheadType ptrTp,
|
||||
OverheadType indTp, PrimaryType valTp, Action action, void *ptr) {
|
||||
assert(dimSizesRef && dimSizesRef->strides[0] == 1);
|
||||
assert(lvlSizesRef && lvlSizesRef->strides[0] == 1);
|
||||
assert(lvlTypesRef && lvlTypesRef->strides[0] == 1);
|
||||
assert(lvl2dimRef && lvl2dimRef->strides[0] == 1);
|
||||
assert(dim2lvlRef && dim2lvlRef->strides[0] == 1);
|
||||
const uint64_t dimRank = dimSizesRef->sizes[0];
|
||||
const uint64_t lvlRank = lvlSizesRef->sizes[0];
|
||||
assert(dim2lvlRef->sizes[0] == dimRank);
|
||||
assert(lvlTypesRef->sizes[0] == lvlRank && lvl2dimRef->sizes[0] == lvlRank);
|
||||
const index_type *dimSizes = dimSizesRef->data + dimSizesRef->offset;
|
||||
const index_type *lvlSizes = lvlSizesRef->data + lvlSizesRef->offset;
|
||||
const DimLevelType *lvlTypes = lvlTypesRef->data + lvlTypesRef->offset;
|
||||
const index_type *lvl2dim = lvl2dimRef->data + lvl2dimRef->offset;
|
||||
const index_type *dim2lvl = dim2lvlRef->data + dim2lvlRef->offset;
|
||||
ASSERT_NO_STRIDE(dimSizesRef);
|
||||
ASSERT_NO_STRIDE(lvlSizesRef);
|
||||
ASSERT_NO_STRIDE(lvlTypesRef);
|
||||
ASSERT_NO_STRIDE(lvl2dimRef);
|
||||
ASSERT_NO_STRIDE(dim2lvlRef);
|
||||
const uint64_t dimRank = MEMREF_GET_USIZE(dimSizesRef);
|
||||
const uint64_t lvlRank = MEMREF_GET_USIZE(lvlSizesRef);
|
||||
assert(MEMREF_GET_USIZE(dim2lvlRef) == dimRank);
|
||||
assert(MEMREF_GET_USIZE(lvlTypesRef) == lvlRank);
|
||||
assert(MEMREF_GET_USIZE(lvl2dimRef) == lvlRank);
|
||||
const index_type *dimSizes = MEMREF_GET_PAYLOAD(dimSizesRef);
|
||||
const index_type *lvlSizes = MEMREF_GET_PAYLOAD(lvlSizesRef);
|
||||
const DimLevelType *lvlTypes = MEMREF_GET_PAYLOAD(lvlTypesRef);
|
||||
const index_type *lvl2dim = MEMREF_GET_PAYLOAD(lvl2dimRef);
|
||||
const index_type *dim2lvl = MEMREF_GET_PAYLOAD(dim2lvlRef);
|
||||
|
||||
// Rewrite kIndex to kU64, to avoid introducing a bunch of new cases.
|
||||
// This is safe because of the static_assert above.
|
||||
@@ -424,10 +457,8 @@ void *_mlir_ciface_newSparseTensor( // NOLINT
|
||||
assert(ref &&tensor); \
|
||||
std::vector<V> *v; \
|
||||
static_cast<SparseTensorStorageBase *>(tensor)->getValues(&v); \
|
||||
ref->basePtr = ref->data = v->data(); \
|
||||
ref->offset = 0; \
|
||||
ref->sizes[0] = v->size(); \
|
||||
ref->strides[0] = 1; \
|
||||
assert(v); \
|
||||
vectorToMemref(*v, *ref); \
|
||||
}
|
||||
MLIR_SPARSETENSOR_FOREVERY_V(IMPL_SPARSEVALUES)
|
||||
#undef IMPL_SPARSEVALUES
|
||||
@@ -438,10 +469,8 @@ MLIR_SPARSETENSOR_FOREVERY_V(IMPL_SPARSEVALUES)
|
||||
assert(ref &&tensor); \
|
||||
std::vector<TYPE> *v; \
|
||||
static_cast<SparseTensorStorageBase *>(tensor)->LIB(&v, d); \
|
||||
ref->basePtr = ref->data = v->data(); \
|
||||
ref->offset = 0; \
|
||||
ref->sizes[0] = v->size(); \
|
||||
ref->strides[0] = 1; \
|
||||
assert(v); \
|
||||
vectorToMemref(*v, *ref); \
|
||||
}
|
||||
#define IMPL_SPARSEPOINTERS(PNAME, P) \
|
||||
IMPL_GETOVERHEAD(sparsePointers##PNAME, P, getPointers)
|
||||
@@ -462,16 +491,17 @@ MLIR_SPARSETENSOR_FOREVERY_O(IMPL_SPARSEINDICES)
|
||||
void *lvlCOO, StridedMemRefType<V, 0> *vref, \
|
||||
StridedMemRefType<index_type, 1> *dimIndRef, \
|
||||
StridedMemRefType<index_type, 1> *dim2lvlRef) { \
|
||||
assert(lvlCOO &&vref &&dimIndRef &&dim2lvlRef); \
|
||||
assert(dimIndRef->strides[0] == 1 && dim2lvlRef->strides[0] == 1); \
|
||||
const uint64_t rank = dimIndRef->sizes[0]; \
|
||||
assert(dim2lvlRef->sizes[0] == rank); \
|
||||
const index_type *dimInd = dimIndRef->data + dimIndRef->offset; \
|
||||
const index_type *dim2lvl = dim2lvlRef->data + dim2lvlRef->offset; \
|
||||
assert(lvlCOO &&vref); \
|
||||
ASSERT_NO_STRIDE(dimIndRef); \
|
||||
ASSERT_NO_STRIDE(dim2lvlRef); \
|
||||
const uint64_t rank = MEMREF_GET_USIZE(dimIndRef); \
|
||||
assert(MEMREF_GET_USIZE(dim2lvlRef) == rank); \
|
||||
const index_type *dimInd = MEMREF_GET_PAYLOAD(dimIndRef); \
|
||||
const index_type *dim2lvl = MEMREF_GET_PAYLOAD(dim2lvlRef); \
|
||||
std::vector<index_type> lvlInd(rank); \
|
||||
for (uint64_t d = 0; d < rank; ++d) \
|
||||
lvlInd[dim2lvl[d]] = dimInd[d]; \
|
||||
V *value = vref->data + vref->offset; \
|
||||
V *value = MEMREF_GET_PAYLOAD(vref); \
|
||||
static_cast<SparseTensorCOO<V> *>(lvlCOO)->add(lvlInd, *value); \
|
||||
return lvlCOO; \
|
||||
}
|
||||
@@ -482,11 +512,11 @@ MLIR_SPARSETENSOR_FOREVERY_V(IMPL_ADDELT)
|
||||
bool _mlir_ciface_getNext##VNAME(void *iter, \
|
||||
StridedMemRefType<index_type, 1> *iref, \
|
||||
StridedMemRefType<V, 0> *vref) { \
|
||||
assert(iter &&iref &&vref); \
|
||||
assert(iref->strides[0] == 1); \
|
||||
index_type *indx = iref->data + iref->offset; \
|
||||
V *value = vref->data + vref->offset; \
|
||||
const uint64_t isize = iref->sizes[0]; \
|
||||
assert(iter &&vref); \
|
||||
ASSERT_NO_STRIDE(iref); \
|
||||
index_type *indx = MEMREF_GET_PAYLOAD(iref); \
|
||||
V *value = MEMREF_GET_PAYLOAD(vref); \
|
||||
const uint64_t isize = MEMREF_GET_USIZE(iref); \
|
||||
const Element<V> *elem = \
|
||||
static_cast<SparseTensorIterator<V> *>(iter)->getNext(); \
|
||||
if (elem == nullptr) \
|
||||
@@ -503,11 +533,11 @@ MLIR_SPARSETENSOR_FOREVERY_V(IMPL_GETNEXT)
|
||||
void _mlir_ciface_lexInsert##VNAME(void *tensor, \
|
||||
StridedMemRefType<index_type, 1> *cref, \
|
||||
StridedMemRefType<V, 0> *vref) { \
|
||||
assert(tensor &&cref &&vref); \
|
||||
assert(cref->strides[0] == 1); \
|
||||
index_type *cursor = cref->data + cref->offset; \
|
||||
assert(tensor &&vref); \
|
||||
ASSERT_NO_STRIDE(cref); \
|
||||
index_type *cursor = MEMREF_GET_PAYLOAD(cref); \
|
||||
assert(cursor); \
|
||||
V *value = vref->data + vref->offset; \
|
||||
V *value = MEMREF_GET_PAYLOAD(vref); \
|
||||
static_cast<SparseTensorStorageBase *>(tensor)->lexInsert(cursor, *value); \
|
||||
}
|
||||
MLIR_SPARSETENSOR_FOREVERY_V(IMPL_LEXINSERT)
|
||||
@@ -518,16 +548,16 @@ MLIR_SPARSETENSOR_FOREVERY_V(IMPL_LEXINSERT)
|
||||
void *tensor, StridedMemRefType<index_type, 1> *cref, \
|
||||
StridedMemRefType<V, 1> *vref, StridedMemRefType<bool, 1> *fref, \
|
||||
StridedMemRefType<index_type, 1> *aref, index_type count) { \
|
||||
assert(tensor &&cref &&vref &&fref &&aref); \
|
||||
assert(cref->strides[0] == 1); \
|
||||
assert(vref->strides[0] == 1); \
|
||||
assert(fref->strides[0] == 1); \
|
||||
assert(aref->strides[0] == 1); \
|
||||
assert(vref->sizes[0] == fref->sizes[0]); \
|
||||
index_type *cursor = cref->data + cref->offset; \
|
||||
V *values = vref->data + vref->offset; \
|
||||
bool *filled = fref->data + fref->offset; \
|
||||
index_type *added = aref->data + aref->offset; \
|
||||
assert(tensor); \
|
||||
ASSERT_NO_STRIDE(cref); \
|
||||
ASSERT_NO_STRIDE(vref); \
|
||||
ASSERT_NO_STRIDE(fref); \
|
||||
ASSERT_NO_STRIDE(aref); \
|
||||
assert(MEMREF_GET_USIZE(vref) == MEMREF_GET_USIZE(fref)); \
|
||||
index_type *cursor = MEMREF_GET_PAYLOAD(cref); \
|
||||
V *values = MEMREF_GET_PAYLOAD(vref); \
|
||||
bool *filled = MEMREF_GET_PAYLOAD(fref); \
|
||||
index_type *added = MEMREF_GET_PAYLOAD(aref); \
|
||||
static_cast<SparseTensorStorageBase *>(tensor)->expInsert( \
|
||||
cursor, values, filled, added, count); \
|
||||
}
|
||||
@@ -646,9 +676,9 @@ index_type getSparseTensorReaderDimSize(void *p, index_type d) {
|
||||
|
||||
void _mlir_ciface_getSparseTensorReaderDimSizes(
|
||||
void *p, StridedMemRefType<index_type, 1> *dref) {
|
||||
assert(p && dref);
|
||||
assert(dref->strides[0] == 1);
|
||||
index_type *dimSizes = dref->data + dref->offset;
|
||||
assert(p);
|
||||
ASSERT_NO_STRIDE(dref);
|
||||
index_type *dimSizes = MEMREF_GET_PAYLOAD(dref);
|
||||
SparseTensorReader &file = *static_cast<SparseTensorReader *>(p);
|
||||
const index_type *sizes = file.getDimSizes();
|
||||
index_type rank = file.getRank();
|
||||
@@ -664,12 +694,12 @@ void delSparseTensorReader(void *p) {
|
||||
void _mlir_ciface_getSparseTensorReaderNext##VNAME( \
|
||||
void *p, StridedMemRefType<index_type, 1> *iref, \
|
||||
StridedMemRefType<V, 0> *vref) { \
|
||||
assert(p &&iref &&vref); \
|
||||
assert(iref->strides[0] == 1); \
|
||||
index_type *indices = iref->data + iref->offset; \
|
||||
assert(p &&vref); \
|
||||
ASSERT_NO_STRIDE(iref); \
|
||||
index_type *indices = MEMREF_GET_PAYLOAD(iref); \
|
||||
SparseTensorReader *stfile = static_cast<SparseTensorReader *>(p); \
|
||||
index_type rank = stfile->getRank(); \
|
||||
V *value = vref->data + vref->offset; \
|
||||
V *value = MEMREF_GET_PAYLOAD(vref); \
|
||||
*value = stfile->readCOOElement<V>(rank, indices); \
|
||||
}
|
||||
MLIR_SPARSETENSOR_FOREVERY_V(IMPL_GETNEXT)
|
||||
@@ -693,10 +723,10 @@ void delSparseTensorWriter(void *p) {
|
||||
void _mlir_ciface_outSparseTensorWriterMetaData(
|
||||
void *p, index_type rank, index_type nnz,
|
||||
StridedMemRefType<index_type, 1> *dref) {
|
||||
assert(p && dref);
|
||||
assert(dref->strides[0] == 1);
|
||||
assert(p);
|
||||
ASSERT_NO_STRIDE(dref);
|
||||
assert(rank != 0);
|
||||
index_type *dimSizes = dref->data + dref->offset;
|
||||
index_type *dimSizes = MEMREF_GET_PAYLOAD(dref);
|
||||
SparseTensorWriter &file = *static_cast<SparseTensorWriter *>(p);
|
||||
file << rank << " " << nnz << std::endl;
|
||||
for (index_type r = 0; r < rank - 1; ++r)
|
||||
@@ -708,13 +738,13 @@ void _mlir_ciface_outSparseTensorWriterMetaData(
|
||||
void _mlir_ciface_outSparseTensorWriterNext##VNAME( \
|
||||
void *p, index_type rank, StridedMemRefType<index_type, 1> *iref, \
|
||||
StridedMemRefType<V, 0> *vref) { \
|
||||
assert(p &&iref &&vref); \
|
||||
assert(iref->strides[0] == 1); \
|
||||
index_type *indices = iref->data + iref->offset; \
|
||||
assert(p &&vref); \
|
||||
ASSERT_NO_STRIDE(iref); \
|
||||
index_type *indices = MEMREF_GET_PAYLOAD(iref); \
|
||||
SparseTensorWriter &file = *static_cast<SparseTensorWriter *>(p); \
|
||||
for (uint64_t r = 0; r < rank; ++r) \
|
||||
file << (indices[r] + 1) << " "; \
|
||||
V *value = vref->data + vref->offset; \
|
||||
V *value = MEMREF_GET_PAYLOAD(vref); \
|
||||
file << *value << std::endl; \
|
||||
}
|
||||
MLIR_SPARSETENSOR_FOREVERY_V(IMPL_OUTNEXT)
|
||||
@@ -722,4 +752,8 @@ MLIR_SPARSETENSOR_FOREVERY_V(IMPL_OUTNEXT)
|
||||
|
||||
} // extern "C"
|
||||
|
||||
#undef MEMREF_GET_PAYLOAD
|
||||
#undef MEMREF_GET_USIZE
|
||||
#undef ASSERT_NO_STRIDE
|
||||
|
||||
#endif // MLIR_CRUNNERUTILS_DEFINE_FUNCTIONS
|
||||
|
||||
Reference in New Issue
Block a user