[mlir][sparse] add memSizes array to sparse storage format

Rationale:
For every dynamic memref (memref<?xtype>), the stored size really
indicates the capacity, while the entry in the memSizes array indicates
the actual size. This allows us to use memrefs as "vectors".

Reviewed By: Peiming

Differential Revision: https://reviews.llvm.org/D133724
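To make the "memrefs as vectors" idea concrete, here is an illustrative C++ analogy (not part of the patch; all names are ours): the dynamic memref corresponds to an allocated buffer plus its capacity, while the matching memSizes entry plays the role of a vector's size().

    #include <cstdint>

    // Illustration only: the capacity/size split introduced by memSizes.
    struct VectorLikeMemref {
      int64_t *data;     // memref<?xindex>: the allocated storage
      uint64_t capacity; // the memref's stored dynamic size
      uint64_t used;     // the corresponding entry in the memSizes array
    };

    // Appending checks capacity first, the discipline sparse codegen
    // will need once it grows these buffers (see the TODO below on an
    // improved allocation/reallocation scheme).
    void pushBack(VectorLikeMemref &v, int64_t x) {
      if (v.used == v.capacity) {
        /* grow: reallocate with a larger capacity */
      }
      v.data[v.used++] = x;
    }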
@@ -101,76 +101,12 @@ static Optional<Value> sizeFromTensorAtDim(OpBuilder &rewriter, Location loc,
       .getResult();
 }
 
-/// Maps a sparse tensor type to the appropriate compounded buffers.
-static Optional<LogicalResult>
-convertSparseTensorType(Type type, SmallVectorImpl<Type> &fields) {
-  auto enc = getSparseTensorEncoding(type);
-  if (!enc)
-    return llvm::None;
-  // Construct the basic types.
-  auto context = type.getContext();
-  unsigned idxWidth = enc.getIndexBitWidth();
-  unsigned ptrWidth = enc.getPointerBitWidth();
-  RankedTensorType rType = type.cast<RankedTensorType>();
-  Type indexType = IndexType::get(context);
-  Type idxType = idxWidth ? IntegerType::get(context, idxWidth) : indexType;
-  Type ptrType = ptrWidth ? IntegerType::get(context, ptrWidth) : indexType;
-  Type eltType = rType.getElementType();
-  //
-  // Sparse tensor storage for rank-dimensional tensor is organized as a
-  // single compound type with the following fields:
-  //
-  // struct {
-  //   memref<rank x index> dimSizes     ; size in each dimension
-  //   ; per-dimension d:
-  //   ;  if dense:
-  //        <nothing>
-  //   ;  if compressed:
-  //        memref<? x ptr>  pointers-d  ; pointers for sparse dim d
-  //        memref<? x idx>  indices-d   ; indices for sparse dim d
-  //   ;  if singleton:
-  //        memref<? x idx>  indices-d   ; indices for singleton dim d
-  //   memref<? x eltType> values        ; values
-  // };
-  //
-  unsigned rank = rType.getShape().size();
-  // The dimSizes array.
-  fields.push_back(MemRefType::get({rank}, indexType));
-  // Per-dimension storage.
-  for (unsigned r = 0; r < rank; r++) {
-    // Dimension level types apply in order to the reordered dimension.
-    // As a result, the compound type can be constructed directly in the given
-    // order. Clients of this type know what field is what from the sparse
-    // tensor type.
-    switch (enc.getDimLevelType()[r]) {
-    case SparseTensorEncodingAttr::DimLevelType::Dense:
-      break; // no fields
-    case SparseTensorEncodingAttr::DimLevelType::Compressed:
-    case SparseTensorEncodingAttr::DimLevelType::CompressedNu:
-    case SparseTensorEncodingAttr::DimLevelType::CompressedNo:
-    case SparseTensorEncodingAttr::DimLevelType::CompressedNuNo:
-      fields.push_back(MemRefType::get({ShapedType::kDynamicSize}, ptrType));
-      fields.push_back(MemRefType::get({ShapedType::kDynamicSize}, idxType));
-      break;
-    case SparseTensorEncodingAttr::DimLevelType::Singleton:
-    case SparseTensorEncodingAttr::DimLevelType::SingletonNu:
-    case SparseTensorEncodingAttr::DimLevelType::SingletonNo:
-    case SparseTensorEncodingAttr::DimLevelType::SingletonNuNo:
-      fields.push_back(MemRefType::get({ShapedType::kDynamicSize}, idxType));
-      break;
-    }
-  }
-  // The values array.
-  fields.push_back(MemRefType::get({ShapedType::kDynamicSize}, eltType));
-  return success();
-}
-
-// Returns field index of sparse tensor type for pointers/indices, when set.
+/// Returns field index of sparse tensor type for pointers/indices, when set.
 static unsigned getFieldIndex(Type type, unsigned ptrDim, unsigned idxDim) {
   auto enc = getSparseTensorEncoding(type);
   assert(enc);
   RankedTensorType rType = type.cast<RankedTensorType>();
-  unsigned field = 1; // start at DimSizes;
+  unsigned field = 2; // start past sizes
   unsigned ptr = 0;
   unsigned idx = 0;
   for (unsigned r = 0, rank = rType.getShape().size(); r < rank; r++) {
@@ -198,8 +134,78 @@ static unsigned getFieldIndex(Type type, unsigned ptrDim, unsigned idxDim) {
       break;
     }
   }
-  llvm_unreachable("failed to find ptr/idx field index");
-  return -1;
+  return field + 1; // return values field index
 }
 
+/// Maps a sparse tensor type to the appropriate compounded buffers.
+static Optional<LogicalResult>
+convertSparseTensorType(Type type, SmallVectorImpl<Type> &fields) {
+  auto enc = getSparseTensorEncoding(type);
+  if (!enc)
+    return llvm::None;
+  // Construct the basic types.
+  auto context = type.getContext();
+  unsigned idxWidth = enc.getIndexBitWidth();
+  unsigned ptrWidth = enc.getPointerBitWidth();
+  RankedTensorType rType = type.cast<RankedTensorType>();
+  Type indexType = IndexType::get(context);
+  Type idxType = idxWidth ? IntegerType::get(context, idxWidth) : indexType;
+  Type ptrType = ptrWidth ? IntegerType::get(context, ptrWidth) : indexType;
+  Type eltType = rType.getElementType();
+  //
+  // Sparse tensor storage for rank-dimensional tensor is organized as a
+  // single compound type with the following fields. Note that every
+  // memref with ? size actually behaves as a "vector", i.e. the stored
+  // size is the capacity and the used size resides in the memSizes array.
+  //
+  // struct {
+  //   memref<rank x index> dimSizes     ; size in each dimension
+  //   memref<n x index> memSizes        ; sizes of ptrs/inds/values
+  //   ; per-dimension d:
+  //   ;  if dense:
+  //        <nothing>
+  //   ;  if compressed:
+  //        memref<? x ptr>  pointers-d  ; pointers for sparse dim d
+  //        memref<? x idx>  indices-d   ; indices for sparse dim d
+  //   ;  if singleton:
+  //        memref<? x idx>  indices-d   ; indices for singleton dim d
+  //   memref<? x eltType> values        ; values
+  // };
+  //
+  unsigned rank = rType.getShape().size();
+  // The dimSizes array.
+  fields.push_back(MemRefType::get({rank}, indexType));
+  // The memSizes array.
+  unsigned lastField = getFieldIndex(type, -1, -1);
+  fields.push_back(MemRefType::get({lastField - 2}, indexType));
+  // Per-dimension storage.
+  for (unsigned r = 0; r < rank; r++) {
+    // Dimension level types apply in order to the reordered dimension.
+    // As a result, the compound type can be constructed directly in the given
+    // order. Clients of this type know what field is what from the sparse
+    // tensor type.
+    switch (enc.getDimLevelType()[r]) {
+    case SparseTensorEncodingAttr::DimLevelType::Dense:
+      break; // no fields
+    case SparseTensorEncodingAttr::DimLevelType::Compressed:
+    case SparseTensorEncodingAttr::DimLevelType::CompressedNu:
+    case SparseTensorEncodingAttr::DimLevelType::CompressedNo:
+    case SparseTensorEncodingAttr::DimLevelType::CompressedNuNo:
+      fields.push_back(MemRefType::get({ShapedType::kDynamicSize}, ptrType));
+      fields.push_back(MemRefType::get({ShapedType::kDynamicSize}, idxType));
+      break;
+    case SparseTensorEncodingAttr::DimLevelType::Singleton:
+    case SparseTensorEncodingAttr::DimLevelType::SingletonNu:
+    case SparseTensorEncodingAttr::DimLevelType::SingletonNo:
+    case SparseTensorEncodingAttr::DimLevelType::SingletonNuNo:
+      fields.push_back(MemRefType::get({ShapedType::kDynamicSize}, idxType));
+      break;
+    }
+  }
+  // The values array.
+  fields.push_back(MemRefType::get({ShapedType::kDynamicSize}, eltType));
+  assert(fields.size() == lastField);
+  return success();
+}
+
 /// Create allocation operation.
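To illustrate the layout the new convertSparseTensorType produces, here is a hedged sketch (not part of the patch) for a 2-d CSR-style tensor, dim 0 dense and dim 1 compressed; `csrType` is an assumed RankedTensorType carrying such an encoding with 64-bit pointer/index widths.

    // Sketch only: field layout for tensor<?x?xf64, #CSR>.
    SmallVector<Type, 8> fields;
    if (succeeded(*convertSparseTensorType(csrType, fields))) {
      // fields[0] : memref<2xindex>  dimSizes
      // fields[1] : memref<3xindex>  memSizes (ptrs-1, inds-1, values)
      // fields[2] : memref<?xi64>    pointers-1 (stored size = capacity)
      // fields[3] : memref<?xi64>    indices-1
      // fields[4] : memref<?xf64>    values
      //
      // getFieldIndex(csrType, -1, -1) == 5 == lastField, so memSizes
      // holds lastField - 2 == 3 entries, one per dynamic memref.
      assert(fields.size() == 5);
    }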
@@ -209,11 +215,12 @@ static Value createAllocation(OpBuilder &builder, Location loc, Type type,
   return builder.create<memref::AllocOp>(loc, memType, sz);
 }
 
-/// Creates allocation for each field in sparse tensor type.
+/// Creates allocation for each field in sparse tensor type. Note that
+/// for all dynamic memrefs, the memory size is really the capacity of
+/// the "vector", while the actual size resides in the sizes array.
 ///
 /// TODO: for efficiency, we will need heuristics to make educated guesses
-///       on the required final sizes; also, we will need an improved
-///       memory allocation scheme with capacity and reallocation
+///       on the required capacities
 ///
 static void createAllocFields(OpBuilder &builder, Location loc, Type type,
                               ValueRange dynSizes,
@@ -246,6 +253,11 @@ static void createAllocFields(OpBuilder &builder, Location loc, Type type,
   Value dimSizes =
       builder.create<memref::AllocOp>(loc, MemRefType::get({rank}, indexType));
   fields.push_back(dimSizes);
+  // The sizes array.
+  unsigned lastField = getFieldIndex(type, -1, -1);
+  Value memSizes = builder.create<memref::AllocOp>(
+      loc, MemRefType::get({lastField - 2}, indexType));
+  fields.push_back(memSizes);
   // Per-dimension storage.
   for (unsigned r = 0; r < rank; r++) {
     // Get the original dimension (ro) for the current stored dimension.
@@ -278,6 +290,16 @@ static void createAllocFields(OpBuilder &builder, Location loc, Type type,
   // In all other cases, we resort to the heuristic initial value.
   Value valuesSz = allDense ? linear : heuristic;
   fields.push_back(createAllocation(builder, loc, eltType, valuesSz));
+  // Set memSizes.
+  if (allDense)
+    builder.create<memref::StoreOp>(
+        loc, valuesSz, memSizes,
+        constantIndex(builder, loc, 0)); // TODO: avoid memSizes in this case?
+  else
+    builder.create<linalg::FillOp>(
+        loc, ValueRange{constantZero(builder, loc, indexType)},
+        ValueRange{memSizes});
+  assert(fields.size() == lastField);
 }
 
 //===----------------------------------------------------------------------===//
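Downstream consumers read the actual sizes back out of field 1. A hedged sketch (not in this patch) of loading the in-use size of the values array, reusing the field numbering above; `fields`, `lastField`, `builder`, and `loc` are assumed in scope:

    // Sketch only: the values entry is the last slot of memSizes
    // (index lastField - 3; that is index 0 in the all-dense case,
    // where lastField == 3).
    Value memSizes = fields[1];
    Value pos = constantIndex(builder, loc, lastField - 3);
    Value valuesInUse = builder.create<memref::LoadOp>(loc, memSizes, pos);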
@@ -467,28 +489,6 @@ public:
   }
 };
 
-/// Base class for getter-like operations, e.g., to_indices, to_pointers.
-template <typename SourceOp, typename Base>
-class SparseGetterOpConverter : public OpConversionPattern<SourceOp> {
-public:
-  using OpAdaptor = typename SourceOp::Adaptor;
-  using OpConversionPattern<SourceOp>::OpConversionPattern;
-  LogicalResult
-  matchAndRewrite(SourceOp op, OpAdaptor adaptor,
-                  ConversionPatternRewriter &rewriter) const override {
-    // Replace the requested pointer access with corresponding field.
-    // The cast_op is inserted by type converter to intermix 1:N type
-    // conversion.
-    auto tuple = llvm::cast<UnrealizedConversionCastOp>(
-        adaptor.getTensor().getDefiningOp());
-    unsigned idx = Base::getIndexForOp(tuple, op);
-    auto fields = tuple.getInputs();
-    assert(idx < fields.size());
-    rewriter.replaceOp(op, fields[idx]);
-    return success();
-  }
-};
-
 /// Sparse codegen rule for the expand op.
 class SparseExpandConverter : public OpConversionPattern<ExpandOp> {
 public:
@@ -543,6 +543,28 @@ public:
   }
 };
 
+/// Base class for getter-like operations, e.g., to_indices, to_pointers.
+template <typename SourceOp, typename Base>
+class SparseGetterOpConverter : public OpConversionPattern<SourceOp> {
+public:
+  using OpAdaptor = typename SourceOp::Adaptor;
+  using OpConversionPattern<SourceOp>::OpConversionPattern;
+  LogicalResult
+  matchAndRewrite(SourceOp op, OpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    // Replace the requested pointer access with corresponding field.
+    // The cast_op is inserted by type converter to intermix 1:N type
+    // conversion.
+    auto tuple = llvm::cast<UnrealizedConversionCastOp>(
+        adaptor.getTensor().getDefiningOp());
+    unsigned idx = Base::getIndexForOp(tuple, op);
+    auto fields = tuple.getInputs();
+    assert(idx < fields.size());
+    rewriter.replaceOp(op, fields[idx]);
+    return success();
+  }
+};
+
 /// Sparse codegen rule for pointer accesses.
 class SparseToPointersConverter
     : public SparseGetterOpConverter<ToPointersOp, SparseToPointersConverter> {
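Each derived getter supplies the getIndexForOp callback that the CRTP base dispatches to. A hedged sketch of what the pointers getter declared above would provide (details of the accessors are assumptions, not confirmed by this diff):

    // Sketch only: map the op's requested dimension to its field position.
    static unsigned getIndexForOp(UnrealizedConversionCastOp /*tuple*/,
                                  ToPointersOp op) {
      uint64_t dim = op.getDimension().getZExtValue();
      return getFieldIndex(op.getTensor().getType(), /*ptrDim=*/dim, -1);
    }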
@@ -602,9 +624,9 @@ mlir::SparseTensorTypeToBufferConverter::SparseTensorTypeToBufferConverter() {
 void mlir::populateSparseTensorCodegenPatterns(TypeConverter &typeConverter,
                                                RewritePatternSet &patterns) {
   patterns.add<SparseReturnConverter, SparseCallConverter, SparseDimOpConverter,
-               SparseCastConverter, SparseExpandConverter,
-               SparseTensorAllocConverter, SparseTensorDeallocConverter,
-               SparseToPointersConverter, SparseToIndicesConverter,
-               SparseToValuesConverter, SparseTensorLoadConverter>(
+               SparseCastConverter, SparseTensorAllocConverter,
+               SparseTensorDeallocConverter, SparseTensorLoadConverter,
+               SparseExpandConverter, SparseToPointersConverter,
+               SparseToIndicesConverter, SparseToValuesConverter>(
       typeConverter, patterns.getContext());
 }
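For context, these patterns are driven by a dialect conversion pass. A hedged sketch of the wiring (the wrapper function and legality setup are ours; only the converter and populate call come from this file):

    void runSparseCodegen(ModuleOp module, MLIRContext *context) {
      // Maps sparse tensor types to the compound fields described above.
      SparseTensorTypeToBufferConverter converter;
      RewritePatternSet patterns(context);
      populateSparseTensorCodegenPatterns(converter, patterns);
      ConversionTarget target(*context);
      // ... mark sparse_tensor ops illegal / memref ops legal on `target` ...
      if (failed(applyPartialConversion(module, target, std::move(patterns))))
        module.emitError("sparse codegen conversion failed");
    }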