[mlir][sparse] Replace sparse_tensor.sort with sparse_tensor.sort_coo for sorting COO tensors.

Add codegen pattern for sparse_tensor.indices_buffer.

Reviewed By: aartbik

Differential Revision: https://reviews.llvm.org/D140871
This commit is contained in:
bixia1
2023-01-05 09:39:23 -08:00
parent 47232bea9e
commit 81e3079d0f
5 changed files with 88 additions and 33 deletions

View File

@@ -937,6 +937,26 @@ public:
}
};
/// Sparse codegen rule for accessing the linear indices buffer.
class SparseToIndicesBufferConverter
: public OpConversionPattern<ToIndicesBufferOp> {
public:
using OpAdaptor = typename ToIndicesBufferOp::Adaptor;
using OpConversionPattern<ToIndicesBufferOp>::OpConversionPattern;
LogicalResult
matchAndRewrite(ToIndicesBufferOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
// Replace the requested pointer access with corresponding field.
// The cast_op is inserted by type converter to intermix 1:N type
// conversion.
SmallVector<Value> fields;
auto desc = getMutDescriptorFromTensorTuple(adaptor.getTensor(), fields);
rewriter.replaceOp(op, desc.getAOSMemRef());
return success();
}
};
/// Sparse codegen rule for value accesses.
class SparseToValuesConverter : public OpConversionPattern<ToValuesOp> {
public:
@@ -1005,9 +1025,9 @@ void mlir::populateSparseTensorCodegenPatterns(
SparseTensorLoadConverter, SparseExpandConverter,
SparseCompressConverter, SparseInsertConverter,
SparseToPointersConverter, SparseToIndicesConverter,
SparseToValuesConverter, SparseConvertConverter,
SparseNumberOfEntriesConverter>(typeConverter,
patterns.getContext());
SparseToIndicesBufferConverter, SparseToValuesConverter,
SparseConvertConverter, SparseNumberOfEntriesConverter>(
typeConverter, patterns.getContext());
patterns.add<SparseTensorAllocConverter>(typeConverter, patterns.getContext(),
enableBufferInitialization);
}