[mlir][sparse] added codegen for dimop, pointers, indices, values
Demonstrates how sparse tensor type -> tuple -> getter will eventually yield actual code on the memrefs directly Reviewed By: Peiming Differential Revision: https://reviews.llvm.org/D133143
This commit is contained in:
@@ -33,14 +33,24 @@ namespace {
|
||||
// Helper methods.
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
/// Reorders stored dimension to logical dimension.
|
||||
static unsigned reorder(const SparseTensorEncodingAttr &enc, unsigned d) {
|
||||
/// Reorders stored dimension to original dimension.
|
||||
static unsigned toOrig(const SparseTensorEncodingAttr &enc, unsigned i) {
|
||||
auto order = enc.getDimOrdering();
|
||||
if (order) {
|
||||
assert(order.isPermutation());
|
||||
return order.getDimPosition(d);
|
||||
return order.getDimPosition(i);
|
||||
}
|
||||
return d;
|
||||
return i;
|
||||
}
|
||||
|
||||
/// Reorders original dimension to stored dimension.
|
||||
static unsigned toStored(const SparseTensorEncodingAttr &enc, unsigned i) {
|
||||
auto order = enc.getDimOrdering();
|
||||
if (order) {
|
||||
assert(order.isPermutation());
|
||||
return order.getPermutedPosition(i);
|
||||
}
|
||||
return i;
|
||||
}
|
||||
|
||||
/// Maps a sparse tensor type to the appropriate compounded buffers.
|
||||
@@ -63,14 +73,13 @@ static Optional<Type> convertSparseTensorType(Type type) {
|
||||
// single compound type with the following fields:
|
||||
//
|
||||
// struct {
|
||||
// ; if dynamic shape:
|
||||
// memref<rank x index> dimSize ; size in each dimension
|
||||
// memref<rank x index> dimSizes ; size in each dimension
|
||||
// ; per-dimension d:
|
||||
// ; if dense:
|
||||
// <nothing>
|
||||
// ; if compresed:
|
||||
// memref<? x idx> indices-d ; indices for sparse dim d
|
||||
// memref<? x ptr> pointers-d ; pointers for sparse dim d
|
||||
// memref<? x idx> indices-d ; indices for sparse dim d
|
||||
// ; if singleton:
|
||||
// memref<? x idx> indices-d ; indices for singleton dim d
|
||||
// memref<? x eltType> values ; values
|
||||
@@ -81,12 +90,11 @@ static Optional<Type> convertSparseTensorType(Type type) {
|
||||
unsigned rank = rType.getShape().size();
|
||||
SmallVector<Type, 8> fields;
|
||||
// The dimSizes array.
|
||||
if (!rType.hasStaticShape())
|
||||
fields.push_back(MemRefType::get({rank}, indexType));
|
||||
fields.push_back(MemRefType::get({rank}, indexType));
|
||||
// Per-dimension storage.
|
||||
for (unsigned r = 0; r < rank; r++) {
|
||||
// Get the original dimension (ro) for the current stored dimension (r).
|
||||
unsigned ro = reorder(enc, r);
|
||||
unsigned ro = toOrig(enc, r);
|
||||
// Dimension level types apply in order to the reordered dimension.
|
||||
// As a result, the compound type can be constructed directly in the given
|
||||
// order. Clients of this type know what field is what from the sparse
|
||||
@@ -103,8 +111,8 @@ static Optional<Type> convertSparseTensorType(Type type) {
|
||||
case SparseTensorEncodingAttr::DimLevelType::CompressedNu:
|
||||
case SparseTensorEncodingAttr::DimLevelType::CompressedNo:
|
||||
case SparseTensorEncodingAttr::DimLevelType::CompressedNuNo:
|
||||
fields.push_back(MemRefType::get({ShapedType::kDynamicSize}, idxType));
|
||||
fields.push_back(MemRefType::get({ShapedType::kDynamicSize}, ptrType));
|
||||
fields.push_back(MemRefType::get({ShapedType::kDynamicSize}, idxType));
|
||||
allDense = false;
|
||||
linear = 1;
|
||||
break;
|
||||
@@ -128,6 +136,63 @@ static Optional<Type> convertSparseTensorType(Type type) {
|
||||
return TupleType::get(context, fields);
|
||||
}
|
||||
|
||||
// Returns field index for pointers (d), indices (d) for set field.
|
||||
static unsigned getFieldIndex(Type type, unsigned ptrDim, unsigned idxDim) {
|
||||
auto enc = getSparseTensorEncoding(type);
|
||||
assert(enc);
|
||||
RankedTensorType rType = type.cast<RankedTensorType>();
|
||||
unsigned field = 1; // start at DimSizes;
|
||||
unsigned ptr = 0;
|
||||
unsigned idx = 0;
|
||||
for (unsigned r = 0, rank = rType.getShape().size(); r < rank; r++) {
|
||||
switch (enc.getDimLevelType()[r]) {
|
||||
case SparseTensorEncodingAttr::DimLevelType::Dense:
|
||||
break; // no fields
|
||||
case SparseTensorEncodingAttr::DimLevelType::Compressed:
|
||||
case SparseTensorEncodingAttr::DimLevelType::CompressedNu:
|
||||
case SparseTensorEncodingAttr::DimLevelType::CompressedNo:
|
||||
case SparseTensorEncodingAttr::DimLevelType::CompressedNuNo:
|
||||
if (ptr++ == ptrDim)
|
||||
return field;
|
||||
field++;
|
||||
if (idx++ == idxDim)
|
||||
return field;
|
||||
field++;
|
||||
break;
|
||||
case SparseTensorEncodingAttr::DimLevelType::Singleton:
|
||||
case SparseTensorEncodingAttr::DimLevelType::SingletonNu:
|
||||
case SparseTensorEncodingAttr::DimLevelType::SingletonNo:
|
||||
case SparseTensorEncodingAttr::DimLevelType::SingletonNuNo:
|
||||
if (idx++ == idxDim)
|
||||
return field;
|
||||
field++;
|
||||
break;
|
||||
}
|
||||
}
|
||||
llvm_unreachable("failed to find ptr/idx field index");
|
||||
return -1;
|
||||
}
|
||||
|
||||
/// Returns field type in tuple at given index.
|
||||
static Type getFieldType(Value tuple, unsigned field) {
|
||||
return tuple.getType().cast<TupleType>().getType(field);
|
||||
}
|
||||
|
||||
/// Creates tuple get operation at given index.
|
||||
static Value createTupleGet(OpBuilder &builder, Location loc, Value tuple,
|
||||
unsigned field) {
|
||||
Type indexType = builder.getIndexType();
|
||||
return builder.create<StorageGetOp>(loc, getFieldType(tuple, field), tuple,
|
||||
builder.getIntegerAttr(indexType, field));
|
||||
}
|
||||
|
||||
/// Returns integral constant, if defined.
|
||||
static Optional<int64_t> getConstantInt(Value val) {
|
||||
if (auto constantOp = val.getDefiningOp<arith::ConstantOp>())
|
||||
return constantOp.getValue().cast<IntegerAttr>().getInt();
|
||||
return {};
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// Codegen rules.
|
||||
//===----------------------------------------------------------------------===//
|
||||
@@ -151,26 +216,82 @@ public:
|
||||
LogicalResult
|
||||
matchAndRewrite(tensor::DimOp op, OpAdaptor adaptor,
|
||||
ConversionPatternRewriter &rewriter) const override {
|
||||
Location loc = op->getLoc();
|
||||
Type type = op.getSource().getType();
|
||||
// Only rewrite annotated DimOp with constant index.
|
||||
auto enc = getSparseTensorEncoding(type);
|
||||
auto enc = getSparseTensorEncoding(op.getSource().getType());
|
||||
if (!enc)
|
||||
return failure();
|
||||
Optional<int64_t> index = op.getConstantIndex();
|
||||
Optional<int64_t> index = getConstantInt(adaptor.getIndex());
|
||||
if (!index)
|
||||
return failure();
|
||||
// Access into static shape can query original type directly.
|
||||
// Access into static dimension can query original type directly.
|
||||
// Note that this is typically already done by DimOp's folding.
|
||||
RankedTensorType rType = type.cast<RankedTensorType>();
|
||||
if (rType.hasStaticShape()) {
|
||||
rewriter.replaceOp(
|
||||
op, constantIndex(rewriter, loc, rType.getShape()[*index]));
|
||||
Location loc = op->getLoc();
|
||||
auto shape = op.getSource().getType().cast<RankedTensorType>().getShape();
|
||||
if (!ShapedType::isDynamic(shape[*index])) {
|
||||
rewriter.replaceOp(op, constantIndex(rewriter, loc, shape[*index]));
|
||||
return success();
|
||||
}
|
||||
// Any other query can consult the dimSize array.
|
||||
// TODO: this needs tuple access
|
||||
return failure();
|
||||
// Any other query can consult the dimSizes array at field 0 using,
|
||||
// accounting for the reordering applied to the sparse storage.
|
||||
Value tuple = adaptor.getSource();
|
||||
Value dimSizes = createTupleGet(rewriter, loc, tuple, 0);
|
||||
rewriter.replaceOpWithNewOp<memref::LoadOp>(
|
||||
op, dimSizes, constantIndex(rewriter, loc, toStored(enc, *index)));
|
||||
return success();
|
||||
}
|
||||
};
|
||||
|
||||
/// Sparse conversion rule for pointer accesses.
|
||||
class SparseToPointersConverter : public OpConversionPattern<ToPointersOp> {
|
||||
public:
|
||||
using OpConversionPattern::OpConversionPattern;
|
||||
LogicalResult
|
||||
matchAndRewrite(ToPointersOp op, OpAdaptor adaptor,
|
||||
ConversionPatternRewriter &rewriter) const override {
|
||||
Optional<int64_t> index = getConstantInt(adaptor.getOperands()[1]);
|
||||
if (!index)
|
||||
return failure();
|
||||
// Replace the requested pointer access with corresponding field.
|
||||
Location loc = op->getLoc();
|
||||
Value tuple = adaptor.getTensor();
|
||||
unsigned i = getFieldIndex(op.getTensor().getType(), /*ptrDim=*/*index, -1);
|
||||
rewriter.replaceOp(op, createTupleGet(rewriter, loc, tuple, i));
|
||||
return success();
|
||||
}
|
||||
};
|
||||
|
||||
/// Sparse conversion rule for index accesses.
|
||||
class SparseToIndicesConverter : public OpConversionPattern<ToIndicesOp> {
|
||||
public:
|
||||
using OpConversionPattern::OpConversionPattern;
|
||||
LogicalResult
|
||||
matchAndRewrite(ToIndicesOp op, OpAdaptor adaptor,
|
||||
ConversionPatternRewriter &rewriter) const override {
|
||||
Optional<int64_t> index = getConstantInt(adaptor.getOperands()[1]);
|
||||
if (!index)
|
||||
return failure();
|
||||
// Replace the requested indices access with corresponding field.
|
||||
Location loc = op->getLoc();
|
||||
Value tuple = adaptor.getTensor();
|
||||
unsigned i = getFieldIndex(op.getTensor().getType(), -1, /*idxDim=*/*index);
|
||||
rewriter.replaceOp(op, createTupleGet(rewriter, loc, tuple, i));
|
||||
return success();
|
||||
}
|
||||
};
|
||||
|
||||
/// Sparse conversion rule for value accesses.
|
||||
class SparseToValuesConverter : public OpConversionPattern<ToValuesOp> {
|
||||
public:
|
||||
using OpConversionPattern::OpConversionPattern;
|
||||
LogicalResult
|
||||
matchAndRewrite(ToValuesOp op, OpAdaptor adaptor,
|
||||
ConversionPatternRewriter &rewriter) const override {
|
||||
// Replace the requested values access with corresponding field.
|
||||
Location loc = op->getLoc();
|
||||
Value tuple = adaptor.getTensor();
|
||||
unsigned i = tuple.getType().cast<TupleType>().size() - 1; // last
|
||||
rewriter.replaceOp(op, createTupleGet(rewriter, loc, tuple, i));
|
||||
return success();
|
||||
}
|
||||
};
|
||||
|
||||
@@ -193,6 +314,7 @@ mlir::SparseTensorTypeToBufferConverter::SparseTensorTypeToBufferConverter() {
|
||||
/// the sparsification of linear algebra operations.
|
||||
void mlir::populateSparseTensorCodegenPatterns(TypeConverter &typeConverter,
|
||||
RewritePatternSet &patterns) {
|
||||
patterns.add<SparseReturnConverter, SparseDimOpConverter>(
|
||||
typeConverter, patterns.getContext());
|
||||
patterns.add<SparseReturnConverter, SparseDimOpConverter,
|
||||
SparseToPointersConverter, SparseToIndicesConverter,
|
||||
SparseToValuesConverter>(typeConverter, patterns.getContext());
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user