[mlir][sparse] add a cursor to sparse storage scheme

This prepares a subsequent revision that will generalize
the insertion code generation. Similar to the support lib,
insertions become much easier to perform with some "cursor"
bookkeeping. Note that we, in the long run, could perhaps
avoid storing the "cursor" permanently and use some
restricted-scope solution (alloca?) instead. However,
that puts harder restrictions on insertion-chain operations,
so for now we follow the more straightforward approach.

Reviewed By: Peiming

Differential Revision: https://reviews.llvm.org/D136800
This commit is contained in:
Aart Bik
2022-10-26 15:07:18 -07:00
parent 846904195b
commit 80b08b68f2
4 changed files with 283 additions and 222 deletions

View File

@@ -31,6 +31,11 @@ using namespace mlir::sparse_tensor;
namespace {
// Positions of the header fields within the flattened list of fields that
// makes up the sparse storage scheme. The per-dimension storage fields start
// at FieldsIdx.
static constexpr uint64_t DimSizesIdx = 0; // dimSizes array
static constexpr uint64_t DimCursorIdx = 1; // dimCursor array
static constexpr uint64_t MemSizesIdx = 2; // memSizes array
static constexpr uint64_t FieldsIdx = 3; // first field past the header
//===----------------------------------------------------------------------===//
// Helper methods.
//===----------------------------------------------------------------------===//
@@ -90,11 +95,17 @@ static Optional<Value> sizeFromTensorAtDim(OpBuilder &rewriter, Location loc,
.getResult();
}
/// Translates field index to memSizes index. The header fields that precede
/// FieldsIdx have no memSizes entry, so the memSizes index is the field index
/// rebased to start at FieldsIdx.
static unsigned getMemSizesIndex(unsigned field) {
// Only fields past the header may be translated.
assert(FieldsIdx <= field);
return field - FieldsIdx;
}
/// Returns field index of sparse tensor type for pointers/indices, when set.
static unsigned getFieldIndex(Type type, unsigned ptrDim, unsigned idxDim) {
assert(getSparseTensorEncoding(type));
RankedTensorType rType = type.cast<RankedTensorType>();
unsigned field = 2; // start past sizes
unsigned field = FieldsIdx; // start past header
unsigned ptr = 0;
unsigned idx = 0;
for (unsigned r = 0, rank = rType.getShape().size(); r < rank; r++) {
@@ -140,6 +151,7 @@ convertSparseTensorType(Type type, SmallVectorImpl<Type> &fields) {
//
// struct {
// memref<rank x index> dimSizes ; size in each dimension
// memref<rank x index> dimCursor ; cursor in each dimension
// memref<n x index> memSizes ; sizes of ptrs/inds/values
// ; per-dimension d:
// ; if dense:
@@ -153,11 +165,11 @@ convertSparseTensorType(Type type, SmallVectorImpl<Type> &fields) {
// };
//
unsigned rank = rType.getShape().size();
// The dimSizes array.
fields.push_back(MemRefType::get({rank}, indexType));
// The memSizes array.
unsigned lastField = getFieldIndex(type, -1u, -1u);
fields.push_back(MemRefType::get({lastField - 2}, indexType));
// The dimSizes array, dimCursor array, and memSizes array.
fields.push_back(MemRefType::get({rank}, indexType));
fields.push_back(MemRefType::get({rank}, indexType));
fields.push_back(MemRefType::get({getMemSizesIndex(lastField)}, indexType));
// Per-dimension storage.
for (unsigned r = 0; r < rank; r++) {
// Dimension level types apply in order to the reordered dimension.
@@ -179,7 +191,7 @@ convertSparseTensorType(Type type, SmallVectorImpl<Type> &fields) {
return success();
}
/// Create allocation operation.
/// Creates allocation operation.
static Value createAllocation(OpBuilder &builder, Location loc, Type type,
Value sz) {
auto memType = MemRefType::get({ShapedType::kDynamicSize}, type);
@@ -220,14 +232,16 @@ static void createAllocFields(OpBuilder &builder, Location loc, Type type,
else
sizes.push_back(constantIndex(builder, loc, shape[r]));
}
// The dimSizes array.
// The dimSizes array, dimCursor array, and memSizes array.
unsigned lastField = getFieldIndex(type, -1u, -1u);
Value dimSizes =
builder.create<memref::AllocOp>(loc, MemRefType::get({rank}, indexType));
fields.push_back(dimSizes);
// The sizes array.
unsigned lastField = getFieldIndex(type, -1u, -1u);
Value dimCursor =
builder.create<memref::AllocOp>(loc, MemRefType::get({rank}, indexType));
Value memSizes = builder.create<memref::AllocOp>(
loc, MemRefType::get({lastField - 2}, indexType));
loc, MemRefType::get({getMemSizesIndex(lastField)}, indexType));
fields.push_back(dimSizes);
fields.push_back(dimCursor);
fields.push_back(memSizes);
// Per-dimension storage.
for (unsigned r = 0; r < rank; r++) {
@@ -277,23 +291,17 @@ static scf::ForOp createFor(OpBuilder &builder, Location loc, Value count,
return forOp;
}
/// Translates field index to memSizes index.
static unsigned getMemSizesIndex(unsigned field) {
assert(2 <= field);
return field - 2;
}
/// Creates a pushback op for given field and updates the fields array
/// accordingly.
static void createPushback(OpBuilder &builder, Location loc,
SmallVectorImpl<Value> &fields, unsigned field,
Value value) {
assert(2 <= field && field < fields.size());
assert(FieldsIdx <= field && field < fields.size());
Type etp = fields[field].getType().cast<ShapedType>().getElementType();
if (value.getType() != etp)
value = builder.create<arith::IndexCastOp>(loc, etp, value);
fields[field] = builder.create<PushBackOp>(
loc, fields[field].getType(), fields[1], fields[field], value,
loc, fields[field].getType(), fields[MemSizesIdx], fields[field], value,
APInt(64, getMemSizesIndex(field)));
}
@@ -312,8 +320,8 @@ static void genInsert(OpBuilder &builder, Location loc, RankedTensorType rtp,
return; // TODO: add codegen
// push_back memSizes indices-0 index
// push_back memSizes values value
createPushback(builder, loc, fields, 3, indices[0]);
createPushback(builder, loc, fields, 4, value);
createPushback(builder, loc, fields, FieldsIdx + 1, indices[0]);
createPushback(builder, loc, fields, FieldsIdx + 2, value);
}
/// Generates insertion finalization code.
@@ -329,9 +337,9 @@ static void genEndInsert(OpBuilder &builder, Location loc, RankedTensorType rtp,
// push_back memSizes pointers-0 memSizes[2]
Value zero = constantIndex(builder, loc, 0);
Value two = constantIndex(builder, loc, 2);
Value size = builder.create<memref::LoadOp>(loc, fields[1], two);
createPushback(builder, loc, fields, 2, zero);
createPushback(builder, loc, fields, 2, size);
Value size = builder.create<memref::LoadOp>(loc, fields[MemSizesIdx], two);
createPushback(builder, loc, fields, FieldsIdx, zero);
createPushback(builder, loc, fields, FieldsIdx, size);
}
//===----------------------------------------------------------------------===//
@@ -759,7 +767,7 @@ public:
unsigned lastField = fields.size() - 1;
Value field =
constantIndex(rewriter, op.getLoc(), getMemSizesIndex(lastField));
rewriter.replaceOpWithNewOp<memref::LoadOp>(op, fields[1], field);
rewriter.replaceOpWithNewOp<memref::LoadOp>(op, fields[MemSizesIdx], field);
return success();
}
};