[mlir][sparse] added codegen for dimop, pointers, indices, values

Demonstrates how the sparse tensor type -> tuple -> getter chain
will eventually yield actual code on the memrefs directly.

Reviewed By: Peiming

Differential Revision: https://reviews.llvm.org/D133143
Aart Bik
2022-09-01 12:34:58 -07:00
parent 54c47ff939
commit 3ae98fd259
4 changed files with 208 additions and 36 deletions
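To make the type -> tuple -> getter chain concrete, here is an illustrative sketch (not part of the patch) of the source-level getters involved, assuming a simple #SparseVector encoding, default pointer/index bit widths, and made-up SSA names:

    #SparseVector = #sparse_tensor.encoding<{ dimLevelType = [ "compressed" ] }>

    %p = sparse_tensor.pointers %t, %c0 : tensor<?xf64, #SparseVector> to memref<?xindex>
    %i = sparse_tensor.indices %t, %c0 : tensor<?xf64, #SparseVector> to memref<?xindex>
    %v = sparse_tensor.values %t : tensor<?xf64, #SparseVector> to memref<?xf64>

After type conversion, %t is backed by a tuple of memrefs, and each getter becomes a direct access to one tuple field.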


@@ -33,14 +33,24 @@ namespace {
// Helper methods.
//===----------------------------------------------------------------------===//
-/// Reorders stored dimension to logical dimension.
-static unsigned reorder(const SparseTensorEncodingAttr &enc, unsigned d) {
+/// Reorders stored dimension to original dimension.
+static unsigned toOrig(const SparseTensorEncodingAttr &enc, unsigned i) {
  auto order = enc.getDimOrdering();
  if (order) {
    assert(order.isPermutation());
-    return order.getDimPosition(d);
+    return order.getDimPosition(i);
  }
-  return d;
+  return i;
}

+/// Reorders original dimension to stored dimension.
+static unsigned toStored(const SparseTensorEncodingAttr &enc, unsigned i) {
+  auto order = enc.getDimOrdering();
+  if (order) {
+    assert(order.isPermutation());
+    return order.getPermutedPosition(i);
+  }
+  return i;
+}
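A small worked example of the two helpers, assuming an illustrative CSC-style encoding (the #CSC name is not from this patch):

    #CSC = #sparse_tensor.encoding<{
      dimLevelType = [ "dense", "compressed" ],
      dimOrdering = affine_map<(i, j) -> (j, i)>
    }>

Here stored dimension 0 holds original dimension 1 (the columns), so toOrig(enc, 0) == 1 via getDimPosition(0), and toStored(enc, 1) == 0 via getPermutedPosition(1). Without a dimOrdering, both helpers are the identity.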
/// Maps a sparse tensor type to the appropriate compounded buffers.
@@ -63,14 +73,13 @@ static Optional<Type> convertSparseTensorType(Type type) {
// single compound type with the following fields:
//
// struct {
-//   ; if dynamic shape:
-//   memref<rank x index> dimSize      ; size in each dimension
+//   memref<rank x index> dimSizes     ; size in each dimension
//   ; per-dimension d:
//   ;  if dense:
//        <nothing>
//   ;  if compressed:
-//        memref<? x idx> indices-d    ; indices for sparse dim d
//        memref<? x ptr> pointers-d    ; pointers for sparse dim d
+//        memref<? x idx> indices-d    ; indices for sparse dim d
//   ;  if singleton:
//        memref<? x idx> indices-d    ; indices for singleton dim d
//   memref<? x eltType> values        ; values
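As an illustration of this layout (assuming default pointer/index bit widths, so ptr and idx both lower to index), a dynamically shaped CSR matrix

    #CSR = #sparse_tensor.encoding<{ dimLevelType = [ "dense", "compressed" ] }>

converts tensor<?x?xf64, #CSR> to the compound type

    tuple<memref<2xindex>,   // dimSizes
          memref<?xindex>,   // pointers-1
          memref<?xindex>,   // indices-1
          memref<?xf64>>     // values

with the dense dimension 0 contributing no per-dimension fields.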
@@ -81,12 +90,11 @@ static Optional<Type> convertSparseTensorType(Type type) {
  unsigned rank = rType.getShape().size();
  SmallVector<Type, 8> fields;
  // The dimSizes array.
-  if (!rType.hasStaticShape())
-    fields.push_back(MemRefType::get({rank}, indexType));
+  fields.push_back(MemRefType::get({rank}, indexType));
  // Per-dimension storage.
  for (unsigned r = 0; r < rank; r++) {
    // Get the original dimension (ro) for the current stored dimension (r).
-    unsigned ro = reorder(enc, r);
+    unsigned ro = toOrig(enc, r);
    // Dimension level types apply in order to the reordered dimension.
    // As a result, the compound type can be constructed directly in the given
    // order. Clients of this type know what field is what from the sparse
@@ -103,8 +111,8 @@ static Optional<Type> convertSparseTensorType(Type type) {
    case SparseTensorEncodingAttr::DimLevelType::CompressedNu:
    case SparseTensorEncodingAttr::DimLevelType::CompressedNo:
    case SparseTensorEncodingAttr::DimLevelType::CompressedNuNo:
-      fields.push_back(MemRefType::get({ShapedType::kDynamicSize}, idxType));
      fields.push_back(MemRefType::get({ShapedType::kDynamicSize}, ptrType));
+      fields.push_back(MemRefType::get({ShapedType::kDynamicSize}, idxType));
      allDense = false;
      linear = 1;
      break;
@@ -128,6 +136,63 @@ static Optional<Type> convertSparseTensorType(Type type) {
  return TupleType::get(context, fields);
}

+/// Returns the tuple-field index for the requested pointers(d) or
+/// indices(d) access, whichever of ptrDim/idxDim is set.
+static unsigned getFieldIndex(Type type, unsigned ptrDim, unsigned idxDim) {
+  auto enc = getSparseTensorEncoding(type);
+  assert(enc);
+  RankedTensorType rType = type.cast<RankedTensorType>();
+  unsigned field = 1; // start past the dimSizes field at index 0
+  unsigned ptr = 0;
+  unsigned idx = 0;
+  for (unsigned r = 0, rank = rType.getShape().size(); r < rank; r++) {
+    switch (enc.getDimLevelType()[r]) {
+    case SparseTensorEncodingAttr::DimLevelType::Dense:
+      break; // no fields
+    case SparseTensorEncodingAttr::DimLevelType::Compressed:
+    case SparseTensorEncodingAttr::DimLevelType::CompressedNu:
+    case SparseTensorEncodingAttr::DimLevelType::CompressedNo:
+    case SparseTensorEncodingAttr::DimLevelType::CompressedNuNo:
+      if (ptr++ == ptrDim)
+        return field;
+      field++;
+      if (idx++ == idxDim)
+        return field;
+      field++;
+      break;
+    case SparseTensorEncodingAttr::DimLevelType::Singleton:
+    case SparseTensorEncodingAttr::DimLevelType::SingletonNu:
+    case SparseTensorEncodingAttr::DimLevelType::SingletonNo:
+    case SparseTensorEncodingAttr::DimLevelType::SingletonNuNo:
+      if (idx++ == idxDim)
+        return field;
+      field++;
+      break;
+    }
+  }
+  llvm_unreachable("failed to find ptr/idx field index");
+  return -1;
+}
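For instance, for the rank-1 #SparseVector type sketched earlier (a single compressed dimension), the walk above yields the following numbering; the calls shown are illustrative:

    // field 0: dimSizes
    // field 1: pointers-0  <- getFieldIndex(type, /*ptrDim=*/0, -1) == 1
    // field 2: indices-0   <- getFieldIndex(type, -1, /*idxDim=*/0) == 2
    // field 3: values      (by construction, always the last field)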
+/// Returns field type in tuple at given index.
+static Type getFieldType(Value tuple, unsigned field) {
+  return tuple.getType().cast<TupleType>().getType(field);
+}
+
+/// Creates tuple get operation at given index.
+static Value createTupleGet(OpBuilder &builder, Location loc, Value tuple,
+                            unsigned field) {
+  Type indexType = builder.getIndexType();
+  return builder.create<StorageGetOp>(loc, getFieldType(tuple, field), tuple,
+                                      builder.getIntegerAttr(indexType, field));
+}
+
+/// Returns integral constant, if defined.
+static Optional<int64_t> getConstantInt(Value val) {
+  if (auto constantOp = val.getDefiningOp<arith::ConstantOp>())
+    return constantOp.getValue().cast<IntegerAttr>().getInt();
+  return {};
+}
//===----------------------------------------------------------------------===//
// Codegen rules.
//===----------------------------------------------------------------------===//
@@ -151,26 +216,82 @@ public:
  LogicalResult
  matchAndRewrite(tensor::DimOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
-    Location loc = op->getLoc();
-    Type type = op.getSource().getType();
    // Only rewrite annotated DimOp with constant index.
-    auto enc = getSparseTensorEncoding(type);
+    auto enc = getSparseTensorEncoding(op.getSource().getType());
    if (!enc)
      return failure();
-    Optional<int64_t> index = op.getConstantIndex();
+    Optional<int64_t> index = getConstantInt(adaptor.getIndex());
    if (!index)
      return failure();
-    // Access into static shape can query original type directly.
+    // Access into static dimension can query original type directly.
    // Note that this is typically already done by DimOp's folding.
-    RankedTensorType rType = type.cast<RankedTensorType>();
-    if (rType.hasStaticShape()) {
-      rewriter.replaceOp(
-          op, constantIndex(rewriter, loc, rType.getShape()[*index]));
+    Location loc = op->getLoc();
+    auto shape = op.getSource().getType().cast<RankedTensorType>().getShape();
+    if (!ShapedType::isDynamic(shape[*index])) {
+      rewriter.replaceOp(op, constantIndex(rewriter, loc, shape[*index]));
      return success();
    }
-    // Any other query can consult the dimSize array.
-    // TODO: this needs tuple access
-    return failure();
+    // Any other query can consult the dimSizes array at field 0,
+    // accounting for the reordering applied to the sparse storage.
+    Value tuple = adaptor.getSource();
+    Value dimSizes = createTupleGet(rewriter, loc, tuple, 0);
+    rewriter.replaceOpWithNewOp<memref::LoadOp>(
+        op, dimSizes, constantIndex(rewriter, loc, toStored(enc, *index)));
+    return success();
  }
};
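With this rule, a dynamic-size query such as

    %d = tensor.dim %t, %c0 : tensor<?x?xf64, #CSR>

lowers to a load from the dimSizes buffer. The tuple access below is pseudo-IR, since StorageGetOp's textual form is not spelled out in this patch; with no dimOrdering, toStored(enc, 0) == 0:

    %sizes = <StorageGetOp: field 0 of the converted %t>   // memref<2xindex>
    %dim   = arith.constant 0 : index
    %d     = memref.load %sizes[%dim] : memref<2xindex>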
+/// Sparse conversion rule for pointer accesses.
+class SparseToPointersConverter : public OpConversionPattern<ToPointersOp> {
+public:
+  using OpConversionPattern::OpConversionPattern;
+  LogicalResult
+  matchAndRewrite(ToPointersOp op, OpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    Optional<int64_t> index = getConstantInt(adaptor.getOperands()[1]);
+    if (!index)
+      return failure();
+    // Replace the requested pointer access with the corresponding field.
+    Location loc = op->getLoc();
+    Value tuple = adaptor.getTensor();
+    unsigned i = getFieldIndex(op.getTensor().getType(), /*ptrDim=*/*index, -1);
+    rewriter.replaceOp(op, createTupleGet(rewriter, loc, tuple, i));
+    return success();
+  }
+};
+
+/// Sparse conversion rule for index accesses.
+class SparseToIndicesConverter : public OpConversionPattern<ToIndicesOp> {
+public:
+  using OpConversionPattern::OpConversionPattern;
+  LogicalResult
+  matchAndRewrite(ToIndicesOp op, OpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    Optional<int64_t> index = getConstantInt(adaptor.getOperands()[1]);
+    if (!index)
+      return failure();
+    // Replace the requested indices access with the corresponding field.
+    Location loc = op->getLoc();
+    Value tuple = adaptor.getTensor();
+    unsigned i = getFieldIndex(op.getTensor().getType(), -1, /*idxDim=*/*index);
+    rewriter.replaceOp(op, createTupleGet(rewriter, loc, tuple, i));
+    return success();
+  }
+};
+
+/// Sparse conversion rule for value accesses.
+class SparseToValuesConverter : public OpConversionPattern<ToValuesOp> {
+public:
+  using OpConversionPattern::OpConversionPattern;
+  LogicalResult
+  matchAndRewrite(ToValuesOp op, OpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    // Replace the requested values access with the corresponding field.
+    Location loc = op->getLoc();
+    Value tuple = adaptor.getTensor();
+    unsigned i = tuple.getType().cast<TupleType>().size() - 1; // the last field
+    rewriter.replaceOp(op, createTupleGet(rewriter, loc, tuple, i));
+    return success();
+  }
+};
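End to end, each getter thus collapses to a single tuple-field access. For the #SparseVector example (field numbering as derived for getFieldIndex above; the StorageGetOp lines are pseudo-IR):

    %p = sparse_tensor.pointers %t, %c0 : tensor<?xf64, #SparseVector> to memref<?xindex>
    // -> <StorageGetOp: field 1>
    %i = sparse_tensor.indices %t, %c0 : tensor<?xf64, #SparseVector> to memref<?xindex>
    // -> <StorageGetOp: field 2>
    %v = sparse_tensor.values %t : tensor<?xf64, #SparseVector> to memref<?xf64>
    // -> <StorageGetOp: field 3, the last field>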
@@ -193,6 +314,7 @@ mlir::SparseTensorTypeToBufferConverter::SparseTensorTypeToBufferConverter() {
/// the sparsification of linear algebra operations.
void mlir::populateSparseTensorCodegenPatterns(TypeConverter &typeConverter,
RewritePatternSet &patterns) {
-  patterns.add<SparseReturnConverter, SparseDimOpConverter>(
-      typeConverter, patterns.getContext());
+  patterns.add<SparseReturnConverter, SparseDimOpConverter,
+               SparseToPointersConverter, SparseToIndicesConverter,
+               SparseToValuesConverter>(typeConverter, patterns.getContext());
}