[mlir][sparse] add a cursor to sparse storage scheme
This prepare a subsequent revision that will generalize the insertion code generation. Similar to the support lib, insertions become much easier to perform with some "cursor" bookkeeping. Note that we, in the long run, could perhaps avoid storing the "cursor" permanently and use some retricted-scope solution (alloca?) instead. However, that puts harder restrictions on insertion-chain operations, so for now we follow the more straightforward approach. Reviewed By: Peiming Differential Revision: https://reviews.llvm.org/D136800
This commit is contained in:
@@ -31,6 +31,11 @@ using namespace mlir::sparse_tensor;
|
||||
|
||||
namespace {
|
||||
|
||||
static constexpr uint64_t DimSizesIdx = 0;
|
||||
static constexpr uint64_t DimCursorIdx = 1;
|
||||
static constexpr uint64_t MemSizesIdx = 2;
|
||||
static constexpr uint64_t FieldsIdx = 3;
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// Helper methods.
|
||||
//===----------------------------------------------------------------------===//
|
||||
@@ -90,11 +95,17 @@ static Optional<Value> sizeFromTensorAtDim(OpBuilder &rewriter, Location loc,
|
||||
.getResult();
|
||||
}
|
||||
|
||||
/// Translates field index to memSizes index.
|
||||
static unsigned getMemSizesIndex(unsigned field) {
|
||||
assert(FieldsIdx <= field);
|
||||
return field - FieldsIdx;
|
||||
}
|
||||
|
||||
/// Returns field index of sparse tensor type for pointers/indices, when set.
|
||||
static unsigned getFieldIndex(Type type, unsigned ptrDim, unsigned idxDim) {
|
||||
assert(getSparseTensorEncoding(type));
|
||||
RankedTensorType rType = type.cast<RankedTensorType>();
|
||||
unsigned field = 2; // start past sizes
|
||||
unsigned field = FieldsIdx; // start past header
|
||||
unsigned ptr = 0;
|
||||
unsigned idx = 0;
|
||||
for (unsigned r = 0, rank = rType.getShape().size(); r < rank; r++) {
|
||||
@@ -140,6 +151,7 @@ convertSparseTensorType(Type type, SmallVectorImpl<Type> &fields) {
|
||||
//
|
||||
// struct {
|
||||
// memref<rank x index> dimSizes ; size in each dimension
|
||||
// memref<rank x index> dimCursor ; cursor in each dimension
|
||||
// memref<n x index> memSizes ; sizes of ptrs/inds/values
|
||||
// ; per-dimension d:
|
||||
// ; if dense:
|
||||
@@ -153,11 +165,11 @@ convertSparseTensorType(Type type, SmallVectorImpl<Type> &fields) {
|
||||
// };
|
||||
//
|
||||
unsigned rank = rType.getShape().size();
|
||||
// The dimSizes array.
|
||||
fields.push_back(MemRefType::get({rank}, indexType));
|
||||
// The memSizes array.
|
||||
unsigned lastField = getFieldIndex(type, -1u, -1u);
|
||||
fields.push_back(MemRefType::get({lastField - 2}, indexType));
|
||||
// The dimSizes array, dimCursor array, and memSizes array.
|
||||
fields.push_back(MemRefType::get({rank}, indexType));
|
||||
fields.push_back(MemRefType::get({rank}, indexType));
|
||||
fields.push_back(MemRefType::get({getMemSizesIndex(lastField)}, indexType));
|
||||
// Per-dimension storage.
|
||||
for (unsigned r = 0; r < rank; r++) {
|
||||
// Dimension level types apply in order to the reordered dimension.
|
||||
@@ -179,7 +191,7 @@ convertSparseTensorType(Type type, SmallVectorImpl<Type> &fields) {
|
||||
return success();
|
||||
}
|
||||
|
||||
/// Create allocation operation.
|
||||
/// Creates allocation operation.
|
||||
static Value createAllocation(OpBuilder &builder, Location loc, Type type,
|
||||
Value sz) {
|
||||
auto memType = MemRefType::get({ShapedType::kDynamicSize}, type);
|
||||
@@ -220,14 +232,16 @@ static void createAllocFields(OpBuilder &builder, Location loc, Type type,
|
||||
else
|
||||
sizes.push_back(constantIndex(builder, loc, shape[r]));
|
||||
}
|
||||
// The dimSizes array.
|
||||
// The dimSizes array, dimCursor array, and memSizes array.
|
||||
unsigned lastField = getFieldIndex(type, -1u, -1u);
|
||||
Value dimSizes =
|
||||
builder.create<memref::AllocOp>(loc, MemRefType::get({rank}, indexType));
|
||||
fields.push_back(dimSizes);
|
||||
// The sizes array.
|
||||
unsigned lastField = getFieldIndex(type, -1u, -1u);
|
||||
Value dimCursor =
|
||||
builder.create<memref::AllocOp>(loc, MemRefType::get({rank}, indexType));
|
||||
Value memSizes = builder.create<memref::AllocOp>(
|
||||
loc, MemRefType::get({lastField - 2}, indexType));
|
||||
loc, MemRefType::get({getMemSizesIndex(lastField)}, indexType));
|
||||
fields.push_back(dimSizes);
|
||||
fields.push_back(dimCursor);
|
||||
fields.push_back(memSizes);
|
||||
// Per-dimension storage.
|
||||
for (unsigned r = 0; r < rank; r++) {
|
||||
@@ -277,23 +291,17 @@ static scf::ForOp createFor(OpBuilder &builder, Location loc, Value count,
|
||||
return forOp;
|
||||
}
|
||||
|
||||
/// Translates field index to memSizes index.
|
||||
static unsigned getMemSizesIndex(unsigned field) {
|
||||
assert(2 <= field);
|
||||
return field - 2;
|
||||
}
|
||||
|
||||
/// Creates a pushback op for given field and updates the fields array
|
||||
/// accordingly.
|
||||
static void createPushback(OpBuilder &builder, Location loc,
|
||||
SmallVectorImpl<Value> &fields, unsigned field,
|
||||
Value value) {
|
||||
assert(2 <= field && field < fields.size());
|
||||
assert(FieldsIdx <= field && field < fields.size());
|
||||
Type etp = fields[field].getType().cast<ShapedType>().getElementType();
|
||||
if (value.getType() != etp)
|
||||
value = builder.create<arith::IndexCastOp>(loc, etp, value);
|
||||
fields[field] = builder.create<PushBackOp>(
|
||||
loc, fields[field].getType(), fields[1], fields[field], value,
|
||||
loc, fields[field].getType(), fields[MemSizesIdx], fields[field], value,
|
||||
APInt(64, getMemSizesIndex(field)));
|
||||
}
|
||||
|
||||
@@ -312,8 +320,8 @@ static void genInsert(OpBuilder &builder, Location loc, RankedTensorType rtp,
|
||||
return; // TODO: add codegen
|
||||
// push_back memSizes indices-0 index
|
||||
// push_back memSizes values value
|
||||
createPushback(builder, loc, fields, 3, indices[0]);
|
||||
createPushback(builder, loc, fields, 4, value);
|
||||
createPushback(builder, loc, fields, FieldsIdx + 1, indices[0]);
|
||||
createPushback(builder, loc, fields, FieldsIdx + 2, value);
|
||||
}
|
||||
|
||||
/// Generations insertion finalization code.
|
||||
@@ -329,9 +337,9 @@ static void genEndInsert(OpBuilder &builder, Location loc, RankedTensorType rtp,
|
||||
// push_back memSizes pointers-0 memSizes[2]
|
||||
Value zero = constantIndex(builder, loc, 0);
|
||||
Value two = constantIndex(builder, loc, 2);
|
||||
Value size = builder.create<memref::LoadOp>(loc, fields[1], two);
|
||||
createPushback(builder, loc, fields, 2, zero);
|
||||
createPushback(builder, loc, fields, 2, size);
|
||||
Value size = builder.create<memref::LoadOp>(loc, fields[MemSizesIdx], two);
|
||||
createPushback(builder, loc, fields, FieldsIdx, zero);
|
||||
createPushback(builder, loc, fields, FieldsIdx, size);
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
@@ -759,7 +767,7 @@ public:
|
||||
unsigned lastField = fields.size() - 1;
|
||||
Value field =
|
||||
constantIndex(rewriter, op.getLoc(), getMemSizesIndex(lastField));
|
||||
rewriter.replaceOpWithNewOp<memref::LoadOp>(op, fields[1], field);
|
||||
rewriter.replaceOpWithNewOp<memref::LoadOp>(op, fields[MemSizesIdx], field);
|
||||
return success();
|
||||
}
|
||||
};
|
||||
|
||||
Reference in New Issue
Block a user