[mlir][sparse] codegen for trivial tensor cast

Reviewed By: bixia

Differential Revision: https://reviews.llvm.org/D133176
Aart Bik
2022-09-01 18:44:48 -07:00
parent 1b726f0a4c
commit f27b806df5
2 changed files with 38 additions and 31 deletions


@@ -33,16 +33,6 @@ namespace {
 // Helper methods.
 //===----------------------------------------------------------------------===//
 
-/// Reorders stored dimension to original dimension.
-static unsigned toOrig(const SparseTensorEncodingAttr &enc, unsigned i) {
-  auto order = enc.getDimOrdering();
-  if (order) {
-    assert(order.isPermutation());
-    return order.getDimPosition(i);
-  }
-  return i;
-}
-
 /// Reorders original dimension to stored dimension.
 static unsigned toStored(const SparseTensorEncodingAttr &enc, unsigned i) {
   auto order = enc.getDimOrdering();
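For context (an illustrative sketch, not part of this commit): the ordering that toOrig/toStored translate comes from the optional dimOrdering permutation in the sparse tensor encoding. For a column-major encoding such as

  #CSC = #sparse_tensor.encoding<{
    dimLevelType = [ "dense", "compressed" ],
    dimOrdering = affine_map<(i, j) -> (j, i)>
  }>

stored dimension 0 is original dimension 1, so toOrig(enc, 0) yields 1 and toStored(enc, 1) yields 0; with no dimOrdering both helpers are the identity.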
@@ -67,7 +57,6 @@ static Optional<Type> convertSparseTensorType(Type type) {
   Type idxType = idxWidth ? IntegerType::get(context, idxWidth) : indexType;
   Type ptrType = ptrWidth ? IntegerType::get(context, ptrWidth) : indexType;
   Type eltType = rType.getElementType();
-  ArrayRef<int64_t> shape = rType.getShape();
   //
   // Sparse tensor storage for rank-dimensional tensor is organized as a
   // single compound type with the following fields:
@@ -85,27 +74,18 @@ static Optional<Type> convertSparseTensorType(Type type) {
   //   memref<? x eltType> values     ; values
   // };
   //
-  int64_t linear = 1;
-  bool allDense = true;
   unsigned rank = rType.getShape().size();
   SmallVector<Type, 8> fields;
   // The dimSizes array.
   fields.push_back(MemRefType::get({rank}, indexType));
   // Per-dimension storage.
   for (unsigned r = 0; r < rank; r++) {
-    // Get the original dimension (ro) for the current stored dimension (r).
-    unsigned ro = toOrig(enc, r);
     // Dimension level types apply in order to the reordered dimension.
     // As a result, the compound type can be constructed directly in the given
     // order. Clients of this type know what field is what from the sparse
     // tensor type.
     switch (enc.getDimLevelType()[r]) {
     case SparseTensorEncodingAttr::DimLevelType::Dense:
-      // Linearize the size of consecutive dense dimensions.
-      if (ShapedType::isDynamic(shape[ro]) || ShapedType::isDynamic(linear))
-        linear = ShapedType::kDynamicSize;
-      else
-        linear *= shape[ro];
       break;
     case SparseTensorEncodingAttr::DimLevelType::Compressed:
     case SparseTensorEncodingAttr::DimLevelType::CompressedNu:
@@ -113,23 +93,17 @@ static Optional<Type> convertSparseTensorType(Type type) {
     case SparseTensorEncodingAttr::DimLevelType::CompressedNuNo:
       fields.push_back(MemRefType::get({ShapedType::kDynamicSize}, ptrType));
       fields.push_back(MemRefType::get({ShapedType::kDynamicSize}, idxType));
-      allDense = false;
-      linear = 1;
       break;
     case SparseTensorEncodingAttr::DimLevelType::Singleton:
     case SparseTensorEncodingAttr::DimLevelType::SingletonNu:
     case SparseTensorEncodingAttr::DimLevelType::SingletonNo:
     case SparseTensorEncodingAttr::DimLevelType::SingletonNuNo:
       fields.push_back(MemRefType::get({ShapedType::kDynamicSize}, idxType));
-      allDense = false;
-      linear = 1;
       break;
     }
   }
   // The values array.
-  int64_t nnz =
-      (rType.hasStaticShape() && allDense) ? linear : ShapedType::kDynamicSize;
-  fields.push_back(MemRefType::get({nnz}, eltType));
+  fields.push_back(MemRefType::get({ShapedType::kDynamicSize}, eltType));
   // Sparse tensor storage (temporarily) lives in a tuple. This allows a
   // simple 1:1 type conversion during codegen. A subsequent pass uses
   // a 1:N type conversion to expand the tuple into its fields.
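For illustration (an assumed example, not part of this commit): with this layout and default pointer/index bit widths, a CSR tensor

  #CSR = #sparse_tensor.encoding<{ dimLevelType = [ "dense", "compressed" ] }>
  tensor<?x?xf64, #CSR>

converts to a tuple with roughly these fields:

  memref<2xindex>  // dimSizes
  memref<?xindex>  // pointers for the compressed dimension
  memref<?xindex>  // indices for the compressed dimension
  memref<?xf64>    // values (now always dynamically sized)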
@@ -241,6 +215,23 @@ public:
   }
 };
 
+/// Sparse codegen rule for trivial tensor casts.
+class SparseCastConverter : public OpConversionPattern<tensor::CastOp> {
+public:
+  using OpConversionPattern::OpConversionPattern;
+  LogicalResult
+  matchAndRewrite(tensor::CastOp op, OpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    // Only rewrite identically annotated source/dest.
+    auto encDst = getSparseTensorEncoding(op.getType());
+    auto encSrc = getSparseTensorEncoding(op.getSource().getType());
+    if (!encDst || encDst != encSrc)
+      return failure();
+    rewriter.replaceOp(op, adaptor.getOperands());
+    return success();
+  }
+};
+
 /// Sparse conversion rule for pointer accesses.
 class SparseToPointersConverter : public OpConversionPattern<ToPointersOp> {
 public:
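For illustration (an assumed example, not part of this commit): SparseCastConverter fires only for casts where source and destination share the same sparse encoding, i.e. casts that merely refine or relax static dimension sizes, such as

  %1 = tensor.cast %0 : tensor<8x8xf64, #CSR> to tensor<?x?xf64, #CSR>

Since such a cast does not change the underlying storage, the op is simply replaced by its already-converted operand.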
@@ -314,7 +305,7 @@ mlir::SparseTensorTypeToBufferConverter::SparseTensorTypeToBufferConverter() {
 /// the sparsification of linear algebra operations.
 void mlir::populateSparseTensorCodegenPatterns(TypeConverter &typeConverter,
                                                RewritePatternSet &patterns) {
-  patterns.add<SparseReturnConverter, SparseDimOpConverter,
+  patterns.add<SparseReturnConverter, SparseDimOpConverter, SparseCastConverter,
                SparseToPointersConverter, SparseToIndicesConverter,
                SparseToValuesConverter>(typeConverter, patterns.getContext());
 }
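End to end (an illustrative sketch, not the commit's test file): once the pattern is registered, a function like

  func.func @sparse_cast(%arg0: tensor<8x8xf64, #CSR>) -> tensor<?x?xf64, #CSR> {
    %0 = tensor.cast %arg0 : tensor<8x8xf64, #CSR> to tensor<?x?xf64, #CSR>
    return %0 : tensor<?x?xf64, #CSR>
  }

lowers so that the return forwards the converted tuple value of %arg0 directly, with no residual cast operation.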