[mlir][sparse] Refactoring: abstract sparse tensor memory scheme into a SparseTensorDescriptor class.
This patch abstracts the sparse tensor memory scheme into a SparseTensorDescriptor class. Previously, field accesses were performed in a relatively error-prone way through hard-coded indices; this patch hides the hairy details behind a SparseTensorDescriptor class so that users can access sparse tensor fields in a more cohesive way.

Reviewed By: aartbik

Differential Revision: https://reviews.llvm.org/D138627
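To make the new abstraction concrete before diving into the diff, here is a minimal sketch of the descriptor idea. It is illustrative only: the real SparseTensorDescriptor/MutSparseTensorDescriptor classes live in a header that is not part of this excerpt, and the sketch shows just the accessors this file relies on (the constant indices 0 and 1 correspond to the dimSizesIdx/memSizesIdx constants removed below).

// Minimal sketch, not the actual class (assumes the usual MLIR includes,
// e.g. mlir/IR/BuiltinTypes.h and llvm/ADT/SmallVector.h).
class SparseTensorDescriptorSketch {
public:
  SparseTensorDescriptorSketch(Type tp, ValueRange buffers)
      : rtp(tp.cast<RankedTensorType>()),
        fields(buffers.begin(), buffers.end()) {}

  RankedTensorType getTensorType() const { return rtp; }
  // Header fields: previously addressed as dimSizesIdx = 0, memSizesIdx = 1.
  Value getDimSizesMemRef() const { return fields[0]; }
  Value getMemSizesMemRef() const { return fields[1]; }
  Value getField(unsigned i) const { return fields[i]; }
  unsigned getNumFields() const { return fields.size(); }

private:
  RankedTensorType rtp;      // the sparse tensor type the fields implement
  SmallVector<Value> fields; // flat list of storage buffers
};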
@@ -36,10 +36,6 @@ using FuncGeneratorType =
 static constexpr const char kInsertFuncNamePrefix[] = "_insert_";
 
-static constexpr uint64_t dimSizesIdx = 0;
-static constexpr uint64_t memSizesIdx = 1;
-static constexpr uint64_t fieldsIdx = 2;
-
 //===----------------------------------------------------------------------===//
 // Helper methods.
 //===----------------------------------------------------------------------===//
@@ -49,6 +45,18 @@ static UnrealizedConversionCastOp getTuple(Value tensor) {
   return llvm::cast<UnrealizedConversionCastOp>(tensor.getDefiningOp());
 }
 
+static SparseTensorDescriptor getDescriptorFromTensorTuple(Value tensor) {
+  auto tuple = getTuple(tensor);
+  return SparseTensorDescriptor(tuple.getResultTypes()[0], tuple.getInputs());
+}
+
+static MutSparseTensorDescriptor
+getMutDescriptorFromTensorTuple(Value tensor, SmallVectorImpl<Value> &fields) {
+  auto tuple = getTuple(tensor);
+  fields.assign(tuple.getInputs().begin(), tuple.getInputs().end());
+  return MutSparseTensorDescriptor(tuple.getResultTypes()[0], fields);
+}
+
 /// Packs the given values as a "tuple" value.
 static Value genTuple(OpBuilder &builder, Location loc, Type tp,
                       ValueRange values) {
@@ -56,6 +64,14 @@ static Value genTuple(OpBuilder &builder, Location loc, Type tp,
       .getResult(0);
 }
 
+static Value genTuple(OpBuilder &builder, Location loc,
+                      SparseTensorDescriptor desc) {
+  return builder
+      .create<UnrealizedConversionCastOp>(loc, desc.getTensorType(),
+                                          desc.getFields())
+      .getResult(0);
+}
+
 /// Flatten a list of operands that may contain sparse tensors.
 static void flattenOperands(ValueRange operands,
                             SmallVectorImpl<Value> &flattened) {
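With the helpers above, every conversion pattern in this file follows the same round-trip; a condensed sketch of the idiom (pattern boilerplate elided, names exactly as in the diff):

// Inside some matchAndRewrite(op, adaptor, rewriter):
SmallVector<Value> fields;
auto desc = getMutDescriptorFromTensorTuple(adaptor.getTensor(), fields);
// ... read or update storage buffers through named accessors on desc ...
rewriter.replaceOp(op, genTuple(rewriter, op.getLoc(), desc));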
@@ -101,7 +117,7 @@ static void genStore(OpBuilder &builder, Location loc, Value val, Value mem,
 
 /// Creates a straightforward counting for-loop.
 static scf::ForOp createFor(OpBuilder &builder, Location loc, Value upper,
-                            SmallVectorImpl<Value> &fields,
+                            MutableArrayRef<Value> fields,
                             Value lower = Value()) {
   Type indexType = builder.getIndexType();
   if (!lower)
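A note on the createFor change above (an inference from the call sites in this patch, not stated in the commit message): the loop helper only rewrites existing loop-carried values and never grows the container, so the non-owning MutableArrayRef suffices, and it binds both to a local SmallVector<Value> and to the descriptor's field list.

// Sketch of why the relaxed parameter type works:
static void rewriteAll(MutableArrayRef<Value> fields, Value v) {
  for (Value &f : fields)
    f = v; // in-place update through the view; no resizing possible
}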
@@ -118,81 +134,46 @@ static scf::ForOp createFor(OpBuilder &builder, Location loc, Value upper,
 /// original dimension 'dim'. Returns std::nullopt if no sparse encoding is
 /// attached to the given tensor type.
 static Optional<Value> sizeFromTensorAtDim(OpBuilder &builder, Location loc,
-                                           RankedTensorType tensorTp,
-                                           Value adaptedValue, unsigned dim) {
-  auto enc = getSparseTensorEncoding(tensorTp);
-  if (!enc)
-    return std::nullopt;
-
+                                           SparseTensorDescriptor desc,
+                                           unsigned dim) {
+  RankedTensorType rtp = desc.getTensorType();
   // Access into static dimension can query original type directly.
   // Note that this is typically already done by DimOp's folding.
-  auto shape = tensorTp.getShape();
+  auto shape = rtp.getShape();
   if (!ShapedType::isDynamic(shape[dim]))
     return constantIndex(builder, loc, shape[dim]);
 
   // Any other query can consult the dimSizes array at field DimSizesIdx,
   // accounting for the reordering applied to the sparse storage.
-  auto tuple = getTuple(adaptedValue);
-  Value idx = constantIndex(builder, loc, toStoredDim(tensorTp, dim));
-  return builder
-      .create<memref::LoadOp>(loc, tuple.getInputs()[dimSizesIdx], idx)
+  Value idx = constantIndex(builder, loc, toStoredDim(rtp, dim));
+  return builder.create<memref::LoadOp>(loc, desc.getDimSizesMemRef(), idx)
       .getResult();
 }
 
 // Gets the dimension size at the given stored dimension 'd', either as a
 // constant for a static size, or otherwise dynamically through dimSizes.
-Value sizeAtStoredDim(OpBuilder &builder, Location loc, RankedTensorType rtp,
-                      SmallVectorImpl<Value> &fields, unsigned d) {
+Value sizeAtStoredDim(OpBuilder &builder, Location loc,
+                      SparseTensorDescriptor desc, unsigned d) {
+  RankedTensorType rtp = desc.getTensorType();
   unsigned dim = toOrigDim(rtp, d);
   auto shape = rtp.getShape();
   if (!ShapedType::isDynamic(shape[dim]))
     return constantIndex(builder, loc, shape[dim]);
-  return genLoad(builder, loc, fields[dimSizesIdx],
+
+  return genLoad(builder, loc, desc.getDimSizesMemRef(),
                  constantIndex(builder, loc, d));
 }
 
-/// Translates field index to memSizes index.
-static unsigned getMemSizesIndex(unsigned field) {
-  assert(fieldsIdx <= field);
-  return field - fieldsIdx;
-}
-
 /// Creates a pushback op for given field and updates the fields array
 /// accordingly. This operation also updates the memSizes contents.
 static void createPushback(OpBuilder &builder, Location loc,
-                           SmallVectorImpl<Value> &fields, unsigned field,
+                           MutSparseTensorDescriptor desc, unsigned fidx,
                            Value value, Value repeat = Value()) {
-  assert(fieldsIdx <= field && field < fields.size());
-  Type etp = fields[field].getType().cast<ShapedType>().getElementType();
-  fields[field] = builder.create<PushBackOp>(
-      loc, fields[field].getType(), fields[memSizesIdx], fields[field],
-      toType(builder, loc, value, etp), APInt(64, getMemSizesIndex(field)),
+  Type etp = desc.getElementType(fidx);
+  Value field = desc.getField(fidx);
+  Value newField = builder.create<PushBackOp>(
+      loc, field.getType(), desc.getMemSizesMemRef(), field,
+      toType(builder, loc, value, etp), APInt(64, getFieldMemSizesIndex(fidx)),
       repeat);
-}
-
-/// Returns field index of sparse tensor type for pointers/indices, when set.
-static unsigned getFieldIndex(Type type, unsigned ptrDim, unsigned idxDim) {
-  assert(getSparseTensorEncoding(type));
-  RankedTensorType rType = type.cast<RankedTensorType>();
-  unsigned field = fieldsIdx; // start past header
-  for (unsigned r = 0, rank = rType.getShape().size(); r < rank; r++) {
-    if (isCompressedDim(rType, r)) {
-      if (r == ptrDim)
-        return field;
-      field++;
-      if (r == idxDim)
-        return field;
-      field++;
-    } else if (isSingletonDim(rType, r)) {
-      if (r == idxDim)
-        return field;
-      field++;
-    } else {
-      assert(isDenseDim(rType, r)); // no fields
-    }
-  }
-  assert(ptrDim == -1u && idxDim == -1u);
-  return field + 1; // return values field index
+  desc.setField(fidx, newField);
 }
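For reference, the translation that the removed getMemSizesIndex performed, and that getFieldMemSizesIndex (defined alongside the descriptor, outside this excerpt) is assumed to preserve:

// memSizes bookkeeping skips the two-entry header (dimSizes, memSizes):
//   field 2 (first ptr/idx/val buffer) -> memSizes[0]
//   field 3                            -> memSizes[1]
//   field k (k >= 2)                   -> memSizes[k - 2]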
 /// Maps a sparse tensor type to the appropriate compounded buffers.
@@ -201,64 +182,24 @@ convertSparseTensorType(Type type, SmallVectorImpl<Type> &fields) {
   auto enc = getSparseTensorEncoding(type);
   if (!enc)
     return std::nullopt;
-  // Construct the basic types.
-  auto *context = type.getContext();
+
   RankedTensorType rType = type.cast<RankedTensorType>();
-  Type indexType = IndexType::get(context);
-  Type idxType = enc.getIndexType();
-  Type ptrType = enc.getPointerType();
-  Type eltType = rType.getElementType();
-  //
-  // Sparse tensor storage scheme for rank-dimensional tensor is organized
-  // as a single compound type with the following fields. Note that every
-  // memref with ? size actually behaves as a "vector", i.e. the stored
-  // size is the capacity and the used size resides in the memSizes array.
-  //
-  // struct {
-  //   memref<rank x index> dimSizes     ; size in each dimension
-  //   memref<n x index> memSizes        ; sizes of ptrs/inds/values
-  //   ; per-dimension d:
-  //   ;  if dense:
-  //        <nothing>
-  //   ;  if compressed:
-  //        memref<? x ptr> pointers-d   ; pointers for sparse dim d
-  //        memref<? x idx> indices-d    ; indices for sparse dim d
-  //   ;  if singleton:
-  //        memref<? x idx> indices-d    ; indices for singleton dim d
-  //   memref<? x eltType> values        ; values
-  // };
-  //
-  unsigned rank = rType.getShape().size();
-  unsigned lastField = getFieldIndex(type, -1u, -1u);
-  // The dimSizes array and memSizes array.
-  fields.push_back(MemRefType::get({rank}, indexType));
-  fields.push_back(MemRefType::get({getMemSizesIndex(lastField)}, indexType));
-  // Per-dimension storage.
-  for (unsigned r = 0; r < rank; r++) {
-    // Dimension level types apply in order to the reordered dimension.
-    // As a result, the compound type can be constructed directly in the given
-    // order. Clients of this type know what field is what from the sparse
-    // tensor type.
-    if (isCompressedDim(rType, r)) {
-      fields.push_back(MemRefType::get({ShapedType::kDynamic}, ptrType));
-      fields.push_back(MemRefType::get({ShapedType::kDynamic}, idxType));
-    } else if (isSingletonDim(rType, r)) {
-      fields.push_back(MemRefType::get({ShapedType::kDynamic}, idxType));
-    } else {
-      assert(isDenseDim(rType, r)); // no fields
-    }
-  }
-  // The values array.
-  fields.push_back(MemRefType::get({ShapedType::kDynamic}, eltType));
-  assert(fields.size() == lastField);
+  foreachFieldAndTypeInSparseTensor(
+      rType,
+      [&fields](Type fieldType, unsigned fieldIdx,
+                SparseTensorFieldKind /*fieldKind*/, unsigned /*dim*/,
+                DimLevelType /*dlt*/) -> bool {
+        assert(fieldIdx == fields.size());
+        fields.push_back(fieldType);
+        return true;
+      });
   return success();
 }
 
 /// Generates code that allocates a sparse storage scheme for given rank.
 static void allocSchemeForRank(OpBuilder &builder, Location loc,
-                               RankedTensorType rtp,
-                               SmallVectorImpl<Value> &fields, unsigned field,
-                               unsigned r0) {
+                               MutSparseTensorDescriptor desc, unsigned r0) {
+  RankedTensorType rtp = desc.getTensorType();
   unsigned rank = rtp.getShape().size();
   Value linear = constantIndex(builder, loc, 1);
   for (unsigned r = r0; r < rank; r++) {
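To keep the deleted storage-scheme comment easy to consult, here is the field list it implies for one common case; an illustrative example under the scheme above, not generated output:

// 2-d tensor, dense outer / compressed inner dimension (CSR-like), f64:
//   field 0: memref<2 x index>  dimSizes
//   field 1: memref<3 x index>  memSizes   (ptrs-1, indices-1, values)
//   field 2: memref<? x ptr>    pointers-1
//   field 3: memref<? x idx>    indices-1
//   field 4: memref<? x f64>    values
// foreachFieldAndTypeInSparseTensor enumerates exactly this sequence, which
// is why convertSparseTensorType reduces to pushing each callback type.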
@@ -268,7 +209,8 @@ static void allocSchemeForRank(OpBuilder &builder, Location loc,
       // the desired "linear + 1" length property at all times.
       Type ptrType = getSparseTensorEncoding(rtp).getPointerType();
       Value ptrZero = constantZero(builder, loc, ptrType);
-      createPushback(builder, loc, fields, field, ptrZero, linear);
+      createPushback(builder, loc, desc, desc.getPtrMemRefIndex(r), ptrZero,
+                     linear);
       return;
     }
     if (isSingletonDim(rtp, r)) {
@@ -278,23 +220,23 @@ static void allocSchemeForRank(OpBuilder &builder, Location loc,
     // at this level. We will eventually reach a compressed level or
     // otherwise the values array for the from-here "all-dense" case.
     assert(isDenseDim(rtp, r));
-    Value size = sizeAtStoredDim(builder, loc, rtp, fields, r);
+    Value size = sizeAtStoredDim(builder, loc, desc, r);
     linear = builder.create<arith::MulIOp>(loc, linear, size);
   }
   // Reached values array so prepare for an insertion.
   Value valZero = constantZero(builder, loc, rtp.getElementType());
-  createPushback(builder, loc, fields, field, valZero, linear);
-  assert(fields.size() == ++field);
+  createPushback(builder, loc, desc, desc.getValMemRefIndex(), valZero, linear);
 }
 
 /// Creates allocation operation.
-static Value createAllocation(OpBuilder &builder, Location loc, Type type,
-                              Value sz, bool enableInit) {
-  auto memType = MemRefType::get({ShapedType::kDynamic}, type);
-  Value buffer = builder.create<memref::AllocOp>(loc, memType, sz);
+static Value createAllocation(OpBuilder &builder, Location loc,
+                              MemRefType memRefType, Value sz,
+                              bool enableInit) {
+  Value buffer = builder.create<memref::AllocOp>(loc, memRefType, sz);
+  Type elemType = memRefType.getElementType();
   if (enableInit) {
-    Value fillValue =
-        builder.create<arith::ConstantOp>(loc, type, builder.getZeroAttr(type));
+    Value fillValue = builder.create<arith::ConstantOp>(
+        loc, elemType, builder.getZeroAttr(elemType));
     builder.create<linalg::FillOp>(loc, fillValue, buffer);
   }
   return buffer;
@@ -310,69 +252,68 @@ static Value createAllocation(OpBuilder &builder, Location loc, Type type,
 static void createAllocFields(OpBuilder &builder, Location loc, Type type,
                               ValueRange dynSizes, bool enableInit,
                               SmallVectorImpl<Value> &fields) {
   auto enc = getSparseTensorEncoding(type);
   assert(enc);
   RankedTensorType rtp = type.cast<RankedTensorType>();
-  Type indexType = builder.getIndexType();
-  Type idxType = enc.getIndexType();
-  Type ptrType = enc.getPointerType();
-  Type eltType = rtp.getElementType();
-  auto shape = rtp.getShape();
-  unsigned rank = shape.size();
   Value heuristic = constantIndex(builder, loc, 16);
 
+  foreachFieldAndTypeInSparseTensor(
+      rtp,
+      [&builder, &fields, loc, heuristic,
+       enableInit](Type fType, unsigned fIdx, SparseTensorFieldKind fKind,
+                   unsigned /*dim*/, DimLevelType /*dlt*/) -> bool {
+        assert(fields.size() == fIdx);
+        auto memRefTp = fType.cast<MemRefType>();
+        Value field;
+        switch (fKind) {
+        case SparseTensorFieldKind::DimSizes:
+        case SparseTensorFieldKind::MemSizes:
+          field = builder.create<memref::AllocOp>(loc, memRefTp);
+          break;
+        case SparseTensorFieldKind::PtrMemRef:
+        case SparseTensorFieldKind::IdxMemRef:
+        case SparseTensorFieldKind::ValMemRef:
+          field =
+              createAllocation(builder, loc, memRefTp, heuristic, enableInit);
+          break;
+        }
+        assert(field);
+        fields.push_back(field);
+        // Returns true to continue the iteration.
+        return true;
+      });
+
+  MutSparseTensorDescriptor desc(rtp, fields);
+
   // Build original sizes.
   SmallVector<Value> sizes;
+  auto shape = rtp.getShape();
+  unsigned rank = shape.size();
   for (unsigned r = 0, o = 0; r < rank; r++) {
     if (ShapedType::isDynamic(shape[r]))
       sizes.push_back(dynSizes[o++]);
     else
       sizes.push_back(constantIndex(builder, loc, shape[r]));
   }
-  // The dimSizes array and memSizes array.
-  unsigned lastField = getFieldIndex(type, -1u, -1u);
-  Value dimSizes =
-      builder.create<memref::AllocOp>(loc, MemRefType::get({rank}, indexType));
-  Value memSizes = builder.create<memref::AllocOp>(
-      loc, MemRefType::get({getMemSizesIndex(lastField)}, indexType));
-  fields.push_back(dimSizes);
-  fields.push_back(memSizes);
-  // Per-dimension storage.
-  for (unsigned r = 0; r < rank; r++) {
-    if (isCompressedDim(rtp, r)) {
-      fields.push_back(
-          createAllocation(builder, loc, ptrType, heuristic, enableInit));
-      fields.push_back(
-          createAllocation(builder, loc, idxType, heuristic, enableInit));
-    } else if (isSingletonDim(rtp, r)) {
-      fields.push_back(
-          createAllocation(builder, loc, idxType, heuristic, enableInit));
-    } else {
-      assert(isDenseDim(rtp, r)); // no fields
-    }
-  }
-  // The values array.
-  fields.push_back(
-      createAllocation(builder, loc, eltType, heuristic, enableInit));
-  assert(fields.size() == lastField);
+
   // Initialize the storage scheme to an empty tensor. Initializes memSizes
   // to all zeros, sets the dimSizes to known values and gives all pointer
   // fields an initial zero entry, so that it is easier to maintain the
   // "linear + 1" length property.
   builder.create<linalg::FillOp>(
-      loc, ValueRange{constantZero(builder, loc, indexType)},
-      ValueRange{memSizes}); // zero memSizes
-  Value ptrZero = constantZero(builder, loc, ptrType);
-  for (unsigned r = 0, field = fieldsIdx; r < rank; r++) {
+      loc, constantZero(builder, loc, builder.getIndexType()),
+      desc.getMemSizesMemRef()); // zero memSizes
+
+  Value ptrZero =
+      constantZero(builder, loc, getSparseTensorEncoding(rtp).getPointerType());
+  for (unsigned r = 0; r < rank; r++) {
     unsigned ro = toOrigDim(rtp, r);
-    genStore(builder, loc, sizes[ro], dimSizes, constantIndex(builder, loc, r));
-    if (isCompressedDim(rtp, r)) {
-      createPushback(builder, loc, fields, field, ptrZero);
-      field += 2;
-    } else if (isSingletonDim(rtp, r)) {
-      field += 1;
-    }
+    // Fills dim sizes array.
+    genStore(builder, loc, sizes[ro], desc.getDimSizesMemRef(),
+             constantIndex(builder, loc, r));
+
+    // Pushes a leading zero to pointers memref.
+    if (isCompressedDim(rtp, r))
+      createPushback(builder, loc, desc, desc.getPtrMemRefIndex(r), ptrZero);
   }
-  allocSchemeForRank(builder, loc, rtp, fields, fieldsIdx, /*rank=*/0);
+  allocSchemeForRank(builder, loc, desc, /*rank=*/0);
 }
 
 /// Helper method that generates block specific to compressed case:
@@ -396,19 +337,22 @@ static void createAllocFields(OpBuilder &builder, Location loc, Type type,
 ///  }
 ///  pos[d] = next
 static Value genCompressed(OpBuilder &builder, Location loc,
-                           RankedTensorType rtp, SmallVectorImpl<Value> &fields,
+                           MutSparseTensorDescriptor desc,
                            SmallVectorImpl<Value> &indices, Value value,
-                           Value pos, unsigned field, unsigned d) {
+                           Value pos, unsigned d) {
+  RankedTensorType rtp = desc.getTensorType();
   unsigned rank = rtp.getShape().size();
   SmallVector<Type> types;
   Type indexType = builder.getIndexType();
   Type boolType = builder.getIntegerType(1);
+  unsigned idxIndex = desc.getIdxMemRefIndex(d);
+  unsigned ptrIndex = desc.getPtrMemRefIndex(d);
   Value one = constantIndex(builder, loc, 1);
   Value pp1 = builder.create<arith::AddIOp>(loc, pos, one);
-  Value plo = genLoad(builder, loc, fields[field], pos);
-  Value phi = genLoad(builder, loc, fields[field], pp1);
-  Value psz = constantIndex(builder, loc, getMemSizesIndex(field + 1));
-  Value msz = genLoad(builder, loc, fields[memSizesIdx], psz);
+  Value plo = genLoad(builder, loc, desc.getField(ptrIndex), pos);
+  Value phi = genLoad(builder, loc, desc.getField(ptrIndex), pp1);
+  Value psz = constantIndex(builder, loc, getFieldMemSizesIndex(idxIndex));
+  Value msz = genLoad(builder, loc, desc.getMemSizesMemRef(), psz);
   Value phim1 = builder.create<arith::SubIOp>(
       loc, toType(builder, loc, phi, indexType), one);
   // Conditional expression.
@@ -418,49 +362,55 @@ static Value genCompressed(OpBuilder &builder, Location loc,
   scf::IfOp ifOp1 = builder.create<scf::IfOp>(loc, types, lt, /*else*/ true);
   types.pop_back();
   builder.setInsertionPointToStart(&ifOp1.getThenRegion().front());
-  Value crd = genLoad(builder, loc, fields[field + 1], phim1);
+  Value crd = genLoad(builder, loc, desc.getField(idxIndex), phim1);
   Value eq = builder.create<arith::CmpIOp>(loc, arith::CmpIPredicate::eq,
                                            toType(builder, loc, crd, indexType),
                                            indices[d]);
   builder.create<scf::YieldOp>(loc, eq);
   builder.setInsertionPointToStart(&ifOp1.getElseRegion().front());
   if (d > 0)
-    genStore(builder, loc, msz, fields[field], pos);
+    genStore(builder, loc, msz, desc.getField(ptrIndex), pos);
   builder.create<scf::YieldOp>(loc, constantI1(builder, loc, false));
   builder.setInsertionPointAfter(ifOp1);
   Value p = ifOp1.getResult(0);
-  // If present construct. Note that for a non-unique dimension level, we simply
-  // set the condition to false and rely on CSE/DCE to clean up the IR.
+  // If present construct. Note that for a non-unique dimension level, we
+  // simply set the condition to false and rely on CSE/DCE to clean up the IR.
   //
   // TODO: generate less temporary IR?
   //
-  for (unsigned i = 0, e = fields.size(); i < e; i++)
-    types.push_back(fields[i].getType());
+  for (unsigned i = 0, e = desc.getNumFields(); i < e; i++)
+    types.push_back(desc.getField(i).getType());
   types.push_back(indexType);
   if (!isUniqueDim(rtp, d))
     p = constantI1(builder, loc, false);
   scf::IfOp ifOp2 = builder.create<scf::IfOp>(loc, types, p, /*else*/ true);
   // If present (fields unaffected, update next to phim1).
   builder.setInsertionPointToStart(&ifOp2.getThenRegion().front());
-  fields.push_back(phim1);
-  builder.create<scf::YieldOp>(loc, fields);
-  fields.pop_back();
+
+  // FIXME: This does not look like a clean way, but it is probably the most
+  // efficient way.
+  desc.getFields().push_back(phim1);
+  builder.create<scf::YieldOp>(loc, desc.getFields());
+  desc.getFields().pop_back();
+
   // If !present (changes fields, update next).
   builder.setInsertionPointToStart(&ifOp2.getElseRegion().front());
   Value mszp1 = builder.create<arith::AddIOp>(loc, msz, one);
-  genStore(builder, loc, mszp1, fields[field], pp1);
-  createPushback(builder, loc, fields, field + 1, indices[d]);
+  genStore(builder, loc, mszp1, desc.getField(ptrIndex), pp1);
+  createPushback(builder, loc, desc, idxIndex, indices[d]);
   // Prepare the next dimension "as needed".
   if ((d + 1) < rank)
-    allocSchemeForRank(builder, loc, rtp, fields, field + 2, d + 1);
-  fields.push_back(msz);
-  builder.create<scf::YieldOp>(loc, fields);
-  fields.pop_back();
+    allocSchemeForRank(builder, loc, desc, d + 1);
+
+  desc.getFields().push_back(msz);
+  builder.create<scf::YieldOp>(loc, desc.getFields());
+  desc.getFields().pop_back();
+
   // Update fields and return next pos.
   builder.setInsertionPointAfter(ifOp2);
   unsigned o = 0;
-  for (unsigned i = 0, e = fields.size(); i < e; i++)
-    fields[i] = ifOp2.getResult(o++);
+  for (unsigned i = 0, e = desc.getNumFields(); i < e; i++)
+    desc.setField(i, ifOp2.getResult(o++));
   return ifOp2.getResult(o);
 }
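The push/yield/pop idiom flagged by the FIXME above, isolated (`extra` is a placeholder standing in for phim1 or msz): all current fields plus one extra SSA value are yielded out of the scf.if branch, then the temporary append is undone so the descriptor's field list is left unchanged.

desc.getFields().push_back(extra);                    // temporarily widen
builder.create<scf::YieldOp>(loc, desc.getFields());  // yield fields + extra
desc.getFields().pop_back();                          // restore field list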
@@ -488,11 +438,10 @@ static void genInsertBody(OpBuilder &builder, ModuleOp module,
   // Construct fields and indices arrays from parameters.
   ValueRange tmp = args.drop_back(rank + 1);
   SmallVector<Value> fields(tmp.begin(), tmp.end());
+  MutSparseTensorDescriptor desc(rtp, fields);
   tmp = args.take_back(rank + 1).drop_back();
   SmallVector<Value> indices(tmp.begin(), tmp.end());
   Value value = args.back();
-
-  unsigned field = fieldsIdx; // Start past header.
   Value pos = constantZero(builder, loc, builder.getIndexType());
   // Generate code for every dimension.
   for (unsigned d = 0; d < rank; d++) {
@@ -504,39 +453,35 @@ static void genInsertBody(OpBuilder &builder, ModuleOp module,
       //   }
      //   pos[d] = indices.size() - 1
       //   <insert @ pos[d] at next dimension d + 1>
-      pos = genCompressed(builder, loc, rtp, fields, indices, value, pos, field,
-                          d);
-      field += 2;
+      pos = genCompressed(builder, loc, desc, indices, value, pos, d);
     } else if (isSingletonDim(rtp, d)) {
       // Create:
       //   indices[d].push_back(i[d])
       //   pos[d] = pos[d-1]
       //   <insert @ pos[d] at next dimension d + 1>
-      createPushback(builder, loc, fields, field, indices[d]);
-      field += 1;
+      createPushback(builder, loc, desc, desc.getIdxMemRefIndex(d), indices[d]);
     } else {
       assert(isDenseDim(rtp, d));
       // Construct the new position as:
       //   pos[d] = size * pos[d-1] + i[d]
       //   <insert @ pos[d] at next dimension d + 1>
-      Value size = sizeAtStoredDim(builder, loc, rtp, fields, d);
+      Value size = sizeAtStoredDim(builder, loc, desc, d);
       Value mult = builder.create<arith::MulIOp>(loc, size, pos);
       pos = builder.create<arith::AddIOp>(loc, mult, indices[d]);
     }
   }
   // Reached the actual value append/insert.
   if (!isDenseDim(rtp, rank - 1))
-    createPushback(builder, loc, fields, field++, value);
+    createPushback(builder, loc, desc, desc.getValMemRefIndex(), value);
   else
-    genStore(builder, loc, value, fields[field++], pos);
-  assert(fields.size() == field);
+    genStore(builder, loc, value, desc.getValMemRef(), pos);
   builder.create<func::ReturnOp>(loc, fields);
 }
 
 /// Generates a call to a function to perform an insertion operation. If the
 /// function doesn't exist yet, call `createFunc` to generate the function.
-static void genInsertionCallHelper(OpBuilder &builder, RankedTensorType rtp,
-                                   SmallVectorImpl<Value> &fields,
+static void genInsertionCallHelper(OpBuilder &builder,
+                                   MutSparseTensorDescriptor desc,
                                    SmallVectorImpl<Value> &indices, Value value,
                                    func::FuncOp insertPoint,
                                    StringRef namePrefix,
@@ -544,6 +489,7 @@ static void genInsertionCallHelper(OpBuilder &builder, RankedTensorType rtp,
   // The mangled name of the function has this format:
   //   <namePrefix>_[C|S|D]_<shape>_<ordering>_<eltType>
   //   _<indexBitWidth>_<pointerBitWidth>
+  RankedTensorType rtp = desc.getTensorType();
   SmallString<32> nameBuffer;
   llvm::raw_svector_ostream nameOstream(nameBuffer);
   nameOstream << namePrefix;
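As an illustration of the mangling (the exact rendering of each piece is an assumption; only the format comment above is authoritative): inserting into an 8x8 tensor with a dense outer and compressed inner dimension, f64 elements, identity ordering, and default index/pointer bit widths might produce a name along the lines of:

//   _insert_dense_compressed_8_8_f64_0_0
// where the trailing zeros would encode the default bit widths.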
@@ -577,7 +523,7 @@ static void genInsertionCallHelper(OpBuilder &builder, RankedTensorType rtp,
   auto func = module.lookupSymbol<func::FuncOp>(result.getAttr());
 
   // Construct parameters for fields and indices.
-  SmallVector<Value> operands(fields.begin(), fields.end());
+  SmallVector<Value> operands(desc.getFields().begin(), desc.getFields().end());
   operands.append(indices.begin(), indices.end());
   operands.push_back(value);
   Location loc = insertPoint.getLoc();
@@ -590,7 +536,7 @@ static void genInsertionCallHelper(OpBuilder &builder, RankedTensorType rtp,
     func = builder.create<func::FuncOp>(
         loc, nameOstream.str(),
         FunctionType::get(context, ValueRange(operands).getTypes(),
-                          ValueRange(fields).getTypes()));
+                          ValueRange(desc.getFields()).getTypes()));
     func.setPrivate();
     createFunc(builder, module, func, rtp);
   }
@@ -598,42 +544,44 @@ static void genInsertionCallHelper(OpBuilder &builder, RankedTensorType rtp,
   // Generate a call to perform the insertion and update `fields` with values
   // returned from the call.
   func::CallOp call = builder.create<func::CallOp>(loc, func, operands);
-  for (size_t i = 0; i < fields.size(); i++) {
-    fields[i] = call.getResult(i);
+  for (size_t i = 0, e = desc.getNumFields(); i < e; i++) {
+    desc.getFields()[i] = call.getResult(i);
   }
 }
 
 /// Generates insertion finalization code.
-static void genEndInsert(OpBuilder &builder, Location loc, RankedTensorType rtp,
-                         SmallVectorImpl<Value> &fields) {
+static void genEndInsert(OpBuilder &builder, Location loc,
+                         MutSparseTensorDescriptor desc) {
+  RankedTensorType rtp = desc.getTensorType();
   unsigned rank = rtp.getShape().size();
-  unsigned field = fieldsIdx; // start past header
   for (unsigned d = 0; d < rank; d++) {
     if (isCompressedDim(rtp, d)) {
       // Compressed dimensions need a pointer cleanup for all entries
       // that were not visited during the insertion pass.
       //
-      // TODO: avoid cleanup and keep compressed scheme consistent at all times?
+      // TODO: avoid cleanup and keep compressed scheme consistent at all
+      // times?
      //
      if (d > 0) {
         Type ptrType = getSparseTensorEncoding(rtp).getPointerType();
-        Value mz = constantIndex(builder, loc, getMemSizesIndex(field));
-        Value hi = genLoad(builder, loc, fields[memSizesIdx], mz);
+        Value ptrMemRef = desc.getPtrMemRef(d);
+        Value mz = constantIndex(builder, loc, desc.getPtrMemSizesIndex(d));
+        Value hi = genLoad(builder, loc, desc.getMemSizesMemRef(), mz);
         Value zero = constantIndex(builder, loc, 0);
         Value one = constantIndex(builder, loc, 1);
         // Vector of only one, but needed by createFor's prototype.
-        SmallVector<Value, 1> inits{genLoad(builder, loc, fields[field], zero)};
+        SmallVector<Value, 1> inits{genLoad(builder, loc, ptrMemRef, zero)};
         scf::ForOp loop = createFor(builder, loc, hi, inits, one);
         Value i = loop.getInductionVar();
         Value oldv = loop.getRegionIterArg(0);
-        Value newv = genLoad(builder, loc, fields[field], i);
+        Value newv = genLoad(builder, loc, ptrMemRef, i);
         Value ptrZero = constantZero(builder, loc, ptrType);
         Value cond = builder.create<arith::CmpIOp>(
             loc, arith::CmpIPredicate::eq, newv, ptrZero);
         scf::IfOp ifOp = builder.create<scf::IfOp>(loc, TypeRange(ptrType),
                                                    cond, /*else*/ true);
         builder.setInsertionPointToStart(&ifOp.getThenRegion().front());
-        genStore(builder, loc, oldv, fields[field], i);
+        genStore(builder, loc, oldv, ptrMemRef, i);
         builder.create<scf::YieldOp>(loc, oldv);
         builder.setInsertionPointToStart(&ifOp.getElseRegion().front());
         builder.create<scf::YieldOp>(loc, newv);
@@ -641,14 +589,10 @@ static void genEndInsert(OpBuilder &builder, Location loc, RankedTensorType rtp,
         builder.create<scf::YieldOp>(loc, ifOp.getResult(0));
         builder.setInsertionPointAfter(loop);
       }
-      field += 2;
-    } else if (isSingletonDim(rtp, d)) {
-      field++;
     } else {
-      assert(isDenseDim(rtp, d));
+      assert(isDenseDim(rtp, d) || isSingletonDim(rtp, d));
     }
   }
-  assert(fields.size() == ++field);
 }
 
 //===----------------------------------------------------------------------===//
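An illustrative trace of the cleanup loop above (values invented for exposition, behavior as the comments describe): pointer entries that were never visited during insertion still hold their initial zero, and the loop back-fills each with the most recent nonzero value to its left so the compressed scheme reads consistently.

//   before cleanup: pointers-d = [0, 2, 0, 0, 5]
//   after cleanup:  pointers-d = [0, 2, 2, 2, 5]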
@@ -739,12 +683,12 @@ public:
   matchAndRewrite(tensor::DimOp op, OpAdaptor adaptor,
                   ConversionPatternRewriter &rewriter) const override {
     Optional<int64_t> index = op.getConstantIndex();
-    if (!index)
+    if (!index || !getSparseTensorEncoding(adaptor.getSource().getType()))
       return failure();
-    auto sz =
-        sizeFromTensorAtDim(rewriter, op.getLoc(),
-                            op.getSource().getType().cast<RankedTensorType>(),
-                            adaptor.getSource(), *index);
+
+    auto desc = getDescriptorFromTensorTuple(adaptor.getSource());
+    auto sz = sizeFromTensorAtDim(rewriter, op.getLoc(), desc, *index);
+
     if (!sz)
       return failure();
 
@@ -834,16 +778,14 @@ public:
   LogicalResult
   matchAndRewrite(LoadOp op, OpAdaptor adaptor,
                   ConversionPatternRewriter &rewriter) const override {
-    RankedTensorType srcType =
-        op.getTensor().getType().cast<RankedTensorType>();
-    auto tuple = getTuple(adaptor.getTensor());
-    // Prepare fields.
-    SmallVector<Value> fields(tuple.getInputs());
+    // Prepare descriptor.
+    SmallVector<Value> fields;
+    auto desc = getMutDescriptorFromTensorTuple(adaptor.getTensor(), fields);
     // Generate optional insertion finalization code.
     if (op.getHasInserts())
-      genEndInsert(rewriter, op.getLoc(), srcType, fields);
+      genEndInsert(rewriter, op.getLoc(), desc);
     // Replace operation with resulting memrefs.
-    rewriter.replaceOp(op, genTuple(rewriter, op.getLoc(), srcType, fields));
+    rewriter.replaceOp(op, genTuple(rewriter, op.getLoc(), desc));
     return success();
   }
 };
@@ -855,7 +797,10 @@ public:
   LogicalResult
   matchAndRewrite(ExpandOp op, OpAdaptor adaptor,
                   ConversionPatternRewriter &rewriter) const override {
+    if (!getSparseTensorEncoding(op.getTensor().getType()))
+      return failure();
     Location loc = op->getLoc();
+    auto desc = getDescriptorFromTensorTuple(adaptor.getTensor());
     RankedTensorType srcType =
         op.getTensor().getType().cast<RankedTensorType>();
     Type eltType = srcType.getElementType();
@@ -867,8 +812,7 @@ public:
     // dimension size, translated back to original dimension). Note that we
     // recursively rewrite the new DimOp on the **original** tensor.
     unsigned innerDim = toOrigDim(srcType, srcType.getRank() - 1);
-    auto sz = sizeFromTensorAtDim(rewriter, loc, srcType, adaptor.getTensor(),
-                                  innerDim);
+    auto sz = sizeFromTensorAtDim(rewriter, loc, desc, innerDim);
     assert(sz); // This for sure is a sparse tensor
     // Generate a memref for `sz` elements of type `t`.
     auto genAlloc = [&](Type t) {
@@ -908,16 +852,15 @@ public:
   matchAndRewrite(CompressOp op, OpAdaptor adaptor,
                   ConversionPatternRewriter &rewriter) const override {
     Location loc = op->getLoc();
-    RankedTensorType dstType =
-        op.getTensor().getType().cast<RankedTensorType>();
-    Type eltType = dstType.getElementType();
-    auto tuple = getTuple(adaptor.getTensor());
+    SmallVector<Value> fields;
+    auto desc = getMutDescriptorFromTensorTuple(adaptor.getTensor(), fields);
     Value values = adaptor.getValues();
     Value filled = adaptor.getFilled();
     Value added = adaptor.getAdded();
     Value count = adaptor.getCount();
-    // Prepare fields and indices.
-    SmallVector<Value> fields(tuple.getInputs());
+    RankedTensorType dstType = desc.getTensorType();
+    Type eltType = dstType.getElementType();
+    // Prepare indices.
     SmallVector<Value> indices(adaptor.getIndices());
     // If the innermost dimension is ordered, we need to sort the indices
     // in the "added" array prior to applying the compression.
@@ -939,19 +882,19 @@ public:
     //     filled[index] = false;
     //     yield new_memrefs
     //   }
-    scf::ForOp loop = createFor(rewriter, loc, count, fields);
+    scf::ForOp loop = createFor(rewriter, loc, count, desc.getFields());
     Value i = loop.getInductionVar();
     Value index = genLoad(rewriter, loc, added, i);
     Value value = genLoad(rewriter, loc, values, index);
     indices.push_back(index);
     // TODO: faster for subsequent insertions?
     auto insertPoint = op->template getParentOfType<func::FuncOp>();
-    genInsertionCallHelper(rewriter, dstType, fields, indices, value,
-                           insertPoint, kInsertFuncNamePrefix, genInsertBody);
+    genInsertionCallHelper(rewriter, desc, indices, value, insertPoint,
+                           kInsertFuncNamePrefix, genInsertBody);
     genStore(rewriter, loc, constantZero(rewriter, loc, eltType), values,
              index);
     genStore(rewriter, loc, constantI1(rewriter, loc, false), filled, index);
-    rewriter.create<scf::YieldOp>(loc, fields);
+    rewriter.create<scf::YieldOp>(loc, desc.getFields());
     rewriter.setInsertionPointAfter(loop);
     Value result = genTuple(rewriter, loc, dstType, loop->getResults());
     // Deallocate the buffers on exit of the full loop nest.
@@ -973,20 +916,18 @@ public:
   LogicalResult
   matchAndRewrite(InsertOp op, OpAdaptor adaptor,
                   ConversionPatternRewriter &rewriter) const override {
-    RankedTensorType dstType =
-        op.getTensor().getType().cast<RankedTensorType>();
-    auto tuple = getTuple(adaptor.getTensor());
-    // Prepare fields and indices.
-    SmallVector<Value> fields(tuple.getInputs());
+    SmallVector<Value> fields;
+    auto desc = getMutDescriptorFromTensorTuple(adaptor.getTensor(), fields);
+    // Prepare indices.
     SmallVector<Value> indices(adaptor.getIndices());
     // Generate insertion.
     Value value = adaptor.getValue();
     auto insertPoint = op->template getParentOfType<func::FuncOp>();
-    genInsertionCallHelper(rewriter, dstType, fields, indices, value,
-                           insertPoint, kInsertFuncNamePrefix, genInsertBody);
+    genInsertionCallHelper(rewriter, desc, indices, value, insertPoint,
+                           kInsertFuncNamePrefix, genInsertBody);
 
     // Replace operation with resulting memrefs.
-    rewriter.replaceOp(op, genTuple(rewriter, op.getLoc(), dstType, fields));
+    rewriter.replaceOp(op, genTuple(rewriter, op.getLoc(), desc));
     return success();
   }
 };
@@ -1003,11 +944,9 @@ public:
     // Replace the requested pointer access with corresponding field.
     // The cast_op is inserted by type converter to intermix 1:N type
     // conversion.
-    auto tuple = getTuple(adaptor.getTensor());
-    unsigned idx = Base::getIndexForOp(tuple, op);
-    auto fields = tuple.getInputs();
-    assert(idx < fields.size());
-    rewriter.replaceOp(op, fields[idx]);
+    auto desc = getDescriptorFromTensorTuple(adaptor.getTensor());
+    Value field = Base::getFieldForOp(desc, op);
+    rewriter.replaceOp(op, field);
     return success();
   }
 };
@@ -1018,10 +957,10 @@ class SparseToPointersConverter
 public:
   using SparseGetterOpConverter::SparseGetterOpConverter;
   // Callback for SparseGetterOpConverter.
-  static unsigned getIndexForOp(UnrealizedConversionCastOp /*tuple*/,
-                                ToPointersOp op) {
+  static Value getFieldForOp(const SparseTensorDescriptor &desc,
+                             ToPointersOp op) {
     uint64_t dim = op.getDimension().getZExtValue();
-    return getFieldIndex(op.getTensor().getType(), /*ptrDim=*/dim, -1u);
+    return desc.getPtrMemRef(dim);
   }
 };
@@ -1031,10 +970,10 @@ class SparseToIndicesConverter
 public:
   using SparseGetterOpConverter::SparseGetterOpConverter;
   // Callback for SparseGetterOpConverter.
-  static unsigned getIndexForOp(UnrealizedConversionCastOp /*tuple*/,
-                                ToIndicesOp op) {
+  static Value getFieldForOp(const SparseTensorDescriptor &desc,
+                             ToIndicesOp op) {
     uint64_t dim = op.getDimension().getZExtValue();
-    return getFieldIndex(op.getTensor().getType(), -1u, /*idxDim=*/dim);
+    return desc.getIdxMemRef(dim);
   }
 };
@@ -1044,10 +983,9 @@ class SparseToValuesConverter
 public:
   using SparseGetterOpConverter::SparseGetterOpConverter;
   // Callback for SparseGetterOpConverter.
-  static unsigned getIndexForOp(UnrealizedConversionCastOp tuple,
-                                ToValuesOp /*op*/) {
-    // The last field holds the value buffer.
-    return tuple.getInputs().size() - 1;
+  static Value getFieldForOp(const SparseTensorDescriptor &desc,
+                             ToValuesOp /*op*/) {
+    return desc.getValMemRef();
   }
 };
@@ -1079,12 +1017,11 @@ public:
   matchAndRewrite(NumberOfEntriesOp op, OpAdaptor adaptor,
                   ConversionPatternRewriter &rewriter) const override {
     // Query memSizes for the actually stored values size.
-    auto tuple = getTuple(adaptor.getTensor());
-    auto fields = tuple.getInputs();
-    unsigned lastField = fields.size() - 1;
+    auto desc = getDescriptorFromTensorTuple(adaptor.getTensor());
     Value field =
-        constantIndex(rewriter, op.getLoc(), getMemSizesIndex(lastField));
-    rewriter.replaceOpWithNewOp<memref::LoadOp>(op, fields[memSizesIdx], field);
+        constantIndex(rewriter, op.getLoc(), desc.getValMemSizesIndex());
+    rewriter.replaceOpWithNewOp<memref::LoadOp>(op, desc.getMemSizesMemRef(),
+                                                field);
     return success();
   }
 };
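Tying this back to the illustrative CSR-like layout sketched earlier: the number of stored entries is the current size of the values array, i.e. the last memSizes entry. For that example (values buffer at field 4), desc.getValMemSizesIndex() would resolve to 4 - 2 = 2, so the rewrite amounts to:

//   %n = memref.load %memSizes[%c2] : memref<3 x index>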