[mlir][sparse] add memSizes array to sparse storage format

Rationale:
For every dynamic memref (memref<?xtype>), the stored size really
indicates the capacity, while the corresponding entry in the memSizes
array indicates the size actually in use. This allows us to use memrefs
as "vectors".
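
To make the idea concrete, here is a hedged C++ sketch of the discipline this enables (illustrative only; the struct and function names are made up for this explanation and appear nowhere in the patch):

  #include <cstdint>
  #include <cstdlib>

  // A dynamic memref used as a "vector": the allocation itself only knows
  // its capacity; the in-use size lives elsewhere (in memSizes).
  struct VectorLikeBuffer {
    double *data;      // underlying storage of the memref<?xf64>
    uint64_t capacity; // what the stored memref size reports
    uint64_t size;     // what the memSizes entry tracks
  };

  // Appending checks the tracked size against capacity and grows on demand.
  void pushBack(VectorLikeBuffer &v, double x) {
    if (v.size == v.capacity) {
      v.capacity = v.capacity ? 2 * v.capacity : 4; // grow geometrically
      v.data = static_cast<double *>(
          realloc(v.data, v.capacity * sizeof(double)));
    }
    v.data[v.size++] = x;
  }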

Reviewed By: Peiming

Differential Revision: https://reviews.llvm.org/D133724
Author: Aart Bik
Date: 2022-09-12 12:53:07 -07:00
Parent: 4763200ec9
Commit: 6607fdf749
2 changed files with 267 additions and 234 deletions

@@ -101,76 +101,12 @@ static Optional<Value> sizeFromTensorAtDim(OpBuilder &rewriter, Location loc,
.getResult();
}
-/// Maps a sparse tensor type to the appropriate compounded buffers.
-static Optional<LogicalResult>
-convertSparseTensorType(Type type, SmallVectorImpl<Type> &fields) {
-auto enc = getSparseTensorEncoding(type);
-if (!enc)
-return llvm::None;
-// Construct the basic types.
-auto context = type.getContext();
-unsigned idxWidth = enc.getIndexBitWidth();
-unsigned ptrWidth = enc.getPointerBitWidth();
-RankedTensorType rType = type.cast<RankedTensorType>();
-Type indexType = IndexType::get(context);
-Type idxType = idxWidth ? IntegerType::get(context, idxWidth) : indexType;
-Type ptrType = ptrWidth ? IntegerType::get(context, ptrWidth) : indexType;
-Type eltType = rType.getElementType();
-//
-// Sparse tensor storage for a rank-dimensional tensor is organized as a
-// single compound type with the following fields:
-//
-// struct {
-// memref<rank x index> dimSizes ; size in each dimension
-// ; per-dimension d:
-// ; if dense:
-// <nothing>
-// ; if compressed:
-// memref<? x ptr> pointers-d ; pointers for sparse dim d
-// memref<? x idx> indices-d ; indices for sparse dim d
-// ; if singleton:
-// memref<? x idx> indices-d ; indices for singleton dim d
-// memref<? x eltType> values ; values
-// };
-//
-unsigned rank = rType.getShape().size();
-// The dimSizes array.
-fields.push_back(MemRefType::get({rank}, indexType));
-// Per-dimension storage.
-for (unsigned r = 0; r < rank; r++) {
-// Dimension level types apply in order to the reordered dimension.
-// As a result, the compound type can be constructed directly in the given
-// order. Clients of this type know what field is what from the sparse
-// tensor type.
-switch (enc.getDimLevelType()[r]) {
-case SparseTensorEncodingAttr::DimLevelType::Dense:
-break; // no fields
-case SparseTensorEncodingAttr::DimLevelType::Compressed:
-case SparseTensorEncodingAttr::DimLevelType::CompressedNu:
-case SparseTensorEncodingAttr::DimLevelType::CompressedNo:
-case SparseTensorEncodingAttr::DimLevelType::CompressedNuNo:
-fields.push_back(MemRefType::get({ShapedType::kDynamicSize}, ptrType));
-fields.push_back(MemRefType::get({ShapedType::kDynamicSize}, idxType));
-break;
-case SparseTensorEncodingAttr::DimLevelType::Singleton:
-case SparseTensorEncodingAttr::DimLevelType::SingletonNu:
-case SparseTensorEncodingAttr::DimLevelType::SingletonNo:
-case SparseTensorEncodingAttr::DimLevelType::SingletonNuNo:
-fields.push_back(MemRefType::get({ShapedType::kDynamicSize}, idxType));
-break;
-}
-}
-// The values array.
-fields.push_back(MemRefType::get({ShapedType::kDynamicSize}, eltType));
-return success();
-}
-// Returns field index of sparse tensor type for pointers/indices, when set.
+/// Returns field index of sparse tensor type for pointers/indices, when set.
static unsigned getFieldIndex(Type type, unsigned ptrDim, unsigned idxDim) {
auto enc = getSparseTensorEncoding(type);
assert(enc);
RankedTensorType rType = type.cast<RankedTensorType>();
-unsigned field = 1; // start at DimSizes;
+unsigned field = 2; // start past sizes
unsigned ptr = 0;
unsigned idx = 0;
for (unsigned r = 0, rank = rType.getShape().size(); r < rank; r++) {
@@ -198,8 +134,78 @@ static unsigned getFieldIndex(Type type, unsigned ptrDim, unsigned idxDim) {
break;
}
}
-llvm_unreachable("failed to find ptr/idx field index");
-return -1;
+return field + 1; // return values field index
}
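
As a worked illustration of the new numbering (an editorial example, assuming a 2-D CSR-style encoding with dimLevelType = [dense, compressed]): fields 0 and 1 are now dimSizes and memSizes, so counting starts at 2. The dense dimension contributes no fields, while the compressed dimension contributes pointers-1 (field 2) and indices-1 (field 3). The sentinel query getFieldIndex(type, -1, -1) therefore falls through the loop with field == 4 and returns 5, i.e. one past the values field at index 4, which is exactly the total field count used as lastField below.
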
+/// Maps a sparse tensor type to the appropriate compounded buffers.
+static Optional<LogicalResult>
+convertSparseTensorType(Type type, SmallVectorImpl<Type> &fields) {
+auto enc = getSparseTensorEncoding(type);
+if (!enc)
+return llvm::None;
+// Construct the basic types.
+auto context = type.getContext();
+unsigned idxWidth = enc.getIndexBitWidth();
+unsigned ptrWidth = enc.getPointerBitWidth();
+RankedTensorType rType = type.cast<RankedTensorType>();
+Type indexType = IndexType::get(context);
+Type idxType = idxWidth ? IntegerType::get(context, idxWidth) : indexType;
+Type ptrType = ptrWidth ? IntegerType::get(context, ptrWidth) : indexType;
+Type eltType = rType.getElementType();
+//
+// Sparse tensor storage for a rank-dimensional tensor is organized as a
+// single compound type with the following fields. Note that every
+// memref with ? size actually behaves as a "vector", i.e. the stored
+// size is the capacity and the used size resides in the memSizes array.
+//
+// struct {
+// memref<rank x index> dimSizes ; size in each dimension
+// memref<n x index> memSizes ; sizes of ptrs/inds/values
+// ; per-dimension d:
+// ; if dense:
+// <nothing>
+// ; if compressed:
+// memref<? x ptr> pointers-d ; pointers for sparse dim d
+// memref<? x idx> indices-d ; indices for sparse dim d
+// ; if singleton:
+// memref<? x idx> indices-d ; indices for singleton dim d
+// memref<? x eltType> values ; values
+// };
+//
+unsigned rank = rType.getShape().size();
+// The dimSizes array.
+fields.push_back(MemRefType::get({rank}, indexType));
+// The memSizes array.
+unsigned lastField = getFieldIndex(type, -1, -1);
+fields.push_back(MemRefType::get({lastField - 2}, indexType));
+// Per-dimension storage.
+for (unsigned r = 0; r < rank; r++) {
+// Dimension level types apply in order to the reordered dimension.
+// As a result, the compound type can be constructed directly in the given
+// order. Clients of this type know what field is what from the sparse
+// tensor type.
+switch (enc.getDimLevelType()[r]) {
+case SparseTensorEncodingAttr::DimLevelType::Dense:
+break; // no fields
+case SparseTensorEncodingAttr::DimLevelType::Compressed:
+case SparseTensorEncodingAttr::DimLevelType::CompressedNu:
+case SparseTensorEncodingAttr::DimLevelType::CompressedNo:
+case SparseTensorEncodingAttr::DimLevelType::CompressedNuNo:
+fields.push_back(MemRefType::get({ShapedType::kDynamicSize}, ptrType));
+fields.push_back(MemRefType::get({ShapedType::kDynamicSize}, idxType));
+break;
+case SparseTensorEncodingAttr::DimLevelType::Singleton:
+case SparseTensorEncodingAttr::DimLevelType::SingletonNu:
+case SparseTensorEncodingAttr::DimLevelType::SingletonNo:
+case SparseTensorEncodingAttr::DimLevelType::SingletonNuNo:
+fields.push_back(MemRefType::get({ShapedType::kDynamicSize}, idxType));
+break;
+}
+}
+// The values array.
+fields.push_back(MemRefType::get({ShapedType::kDynamicSize}, eltType));
+assert(fields.size() == lastField);
+return success();
+}
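
Concretely (an editorial sketch, not part of the patch: assume a matrix type tensor<?x?xf64, #CSR> with dimLevelType = [dense, compressed] and 64-bit pointer and index widths), the compound storage now consists of five fields:

  memref<2 x index>  ; dimSizes: the two dimension sizes
  memref<3 x index>  ; memSizes: used sizes of pointers-1, indices-1, values
  memref<? x i64>    ; pointers-1, capacity-sized "vector"
  memref<? x i64>    ; indices-1, capacity-sized "vector"
  memref<? x f64>    ; values, capacity-sized "vector"

so fields.size() == lastField == 5 and memSizes holds lastField - 2 == 3 entries.
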
/// Create allocation operation.
@@ -209,11 +215,12 @@ static Value createAllocation(OpBuilder &builder, Location loc, Type type,
return builder.create<memref::AllocOp>(loc, memType, sz);
}
-/// Creates allocation for each field in sparse tensor type.
+/// Creates allocation for each field in sparse tensor type. Note that
+/// for all dynamic memrefs, the memory size is really the capacity of
+/// the "vector", while the actual size resides in the sizes array.
///
/// TODO: for efficiency, we will need heuristics to make educated guesses
-/// on the required final sizes; also, we will need an improved
-/// memory allocation scheme with capacity and reallocation
+/// on the required capacities
///
static void createAllocFields(OpBuilder &builder, Location loc, Type type,
ValueRange dynSizes,
@@ -246,6 +253,11 @@ static void createAllocFields(OpBuilder &builder, Location loc, Type type,
Value dimSizes =
builder.create<memref::AllocOp>(loc, MemRefType::get({rank}, indexType));
fields.push_back(dimSizes);
+// The sizes array.
+unsigned lastField = getFieldIndex(type, -1, -1);
+Value memSizes = builder.create<memref::AllocOp>(
+loc, MemRefType::get({lastField - 2}, indexType));
+fields.push_back(memSizes);
// Per-dimension storage.
for (unsigned r = 0; r < rank; r++) {
// Get the original dimension (ro) for the current stored dimension.
@@ -278,6 +290,16 @@ static void createAllocFields(OpBuilder &builder, Location loc, Type type,
// In all other case, we resort to the heuristical initial value.
Value valuesSz = allDense ? linear : heuristic;
fields.push_back(createAllocation(builder, loc, eltType, valuesSz));
+// Set memSizes.
+if (allDense)
+builder.create<memref::StoreOp>(
+loc, valuesSz, memSizes,
+constantIndex(builder, loc, 0)); // TODO: avoid memSizes in this case?
+else
+builder.create<linalg::FillOp>(
+loc, ValueRange{constantZero(builder, loc, indexType)},
+ValueRange{memSizes});
+assert(fields.size() == lastField);
}
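
With memSizes zero-initialized for the sparse case, a later insertion primitive can consult and update the tracked size rather than the memref's capacity. A hedged sketch in the style of the surrounding builder code (bumpMemSize is a hypothetical helper, not part of this patch; constantIndex is the existing utility already used above):

  // Returns the current used size of field entry `fid` in memSizes and
  // stores the incremented size back, i.e. reserves the position at which
  // a newly appended element would be written.
  static Value bumpMemSize(OpBuilder &builder, Location loc, Value memSizes,
                           unsigned fid) {
    Value idx = constantIndex(builder, loc, fid);
    Value used = builder.create<memref::LoadOp>(loc, memSizes, idx);
    Value one = constantIndex(builder, loc, 1);
    Value inc = builder.create<arith::AddIOp>(loc, used, one);
    builder.create<memref::StoreOp>(loc, inc, memSizes, idx);
    return used;
  }

A real implementation would additionally compare the old size against the capacity (memref::DimOp) and reallocate when full, as the TODO above anticipates.
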
//===----------------------------------------------------------------------===//
@@ -467,28 +489,6 @@ public:
}
};
-/// Base class for getter-like operations, e.g., to_indices, to_pointers.
-template <typename SourceOp, typename Base>
-class SparseGetterOpConverter : public OpConversionPattern<SourceOp> {
-public:
-using OpAdaptor = typename SourceOp::Adaptor;
-using OpConversionPattern<SourceOp>::OpConversionPattern;
-LogicalResult
-matchAndRewrite(SourceOp op, OpAdaptor adaptor,
-ConversionPatternRewriter &rewriter) const override {
-// Replace the requested pointer access with corresponding field.
-// The cast_op is inserted by type converter to intermix 1:N type
-// conversion.
-auto tuple = llvm::cast<UnrealizedConversionCastOp>(
-adaptor.getTensor().getDefiningOp());
-unsigned idx = Base::getIndexForOp(tuple, op);
-auto fields = tuple.getInputs();
-assert(idx < fields.size());
-rewriter.replaceOp(op, fields[idx]);
-return success();
-}
-};
/// Sparse codegen rule for the expand op.
class SparseExpandConverter : public OpConversionPattern<ExpandOp> {
public:
@@ -543,6 +543,28 @@ public:
}
};
+/// Base class for getter-like operations, e.g., to_indices, to_pointers.
+template <typename SourceOp, typename Base>
+class SparseGetterOpConverter : public OpConversionPattern<SourceOp> {
+public:
+using OpAdaptor = typename SourceOp::Adaptor;
+using OpConversionPattern<SourceOp>::OpConversionPattern;
+LogicalResult
+matchAndRewrite(SourceOp op, OpAdaptor adaptor,
+ConversionPatternRewriter &rewriter) const override {
+// Replace the requested pointer access with corresponding field.
+// The cast_op is inserted by type converter to intermix 1:N type
+// conversion.
+auto tuple = llvm::cast<UnrealizedConversionCastOp>(
+adaptor.getTensor().getDefiningOp());
+unsigned idx = Base::getIndexForOp(tuple, op);
+auto fields = tuple.getInputs();
+assert(idx < fields.size());
+rewriter.replaceOp(op, fields[idx]);
+return success();
+}
+};
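
Each concrete getter only supplies the field computation. As a sketch (reconstructed for illustration, not quoted from the patch; the accessor names are assumptions), the hook that SparseToPointersConverter below would provide looks roughly like:

  // Map a to_pointers op at dimension `dim` to its field position; the
  // numbering starts past the leading dimSizes and memSizes fields, which
  // getFieldIndex now accounts for.
  static unsigned getIndexForOp(UnrealizedConversionCastOp /*tuple*/,
                                ToPointersOp op) {
    uint64_t dim = op.getDimension().getZExtValue();
    return getFieldIndex(op.getTensor().getType(), /*ptrDim=*/dim,
                         /*idxDim=*/-1u);
  }
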
/// Sparse codegen rule for pointer accesses.
class SparseToPointersConverter
: public SparseGetterOpConverter<ToPointersOp, SparseToPointersConverter> {
@@ -602,9 +624,9 @@ mlir::SparseTensorTypeToBufferConverter::SparseTensorTypeToBufferConverter() {
void mlir::populateSparseTensorCodegenPatterns(TypeConverter &typeConverter,
RewritePatternSet &patterns) {
patterns.add<SparseReturnConverter, SparseCallConverter, SparseDimOpConverter,
-SparseCastConverter, SparseExpandConverter,
-SparseTensorAllocConverter, SparseTensorDeallocConverter,
-SparseToPointersConverter, SparseToIndicesConverter,
-SparseToValuesConverter, SparseTensorLoadConverter>(
+SparseCastConverter, SparseTensorAllocConverter,
+SparseTensorDeallocConverter, SparseTensorLoadConverter,
+SparseExpandConverter, SparseToPointersConverter,
+SparseToIndicesConverter, SparseToValuesConverter>(
typeConverter, patterns.getContext());
}
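
For reference, these patterns are applied in the usual MLIR conversion fashion; a minimal hedged sketch (pass boilerplate and ConversionTarget setup elided; runSparseCodegen is a made-up wrapper):

  // Wire the sparse codegen patterns into a dialect conversion.
  void runSparseCodegen(Operation *op, ConversionTarget &target) {
    MLIRContext *ctx = op->getContext();
    SparseTensorTypeToBufferConverter converter;
    RewritePatternSet patterns(ctx);
    populateSparseTensorCodegenPatterns(converter, patterns);
    if (failed(applyPartialConversion(op, target, std::move(patterns))))
      op->emitError("sparse tensor codegen conversion failed");
  }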