[mlir][sparse] add create-sparse-deallocs options to match the create-deallocs in BufferizationOption.

Reviewed By: aartbik

Differential Revision: https://reviews.llvm.org/D147010
This commit is contained in:
Peiming Liu
2023-03-27 22:56:52 +00:00
parent 1eb9948f02
commit c44d307c55
8 changed files with 95 additions and 30 deletions

View File

@@ -780,6 +780,11 @@ class SparseTensorDeallocConverter
: public OpConversionPattern<bufferization::DeallocTensorOp> {
public:
using OpConversionPattern::OpConversionPattern;
SparseTensorDeallocConverter(TypeConverter &typeConverter,
MLIRContext *context, bool createDeallocs)
: OpConversionPattern(typeConverter, context),
createDeallocs(createDeallocs) {}
LogicalResult
matchAndRewrite(bufferization::DeallocTensorOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
@@ -787,16 +792,22 @@ public:
if (!enc)
return failure();
// Replace the sparse tensor deallocation with field deallocations.
Location loc = op.getLoc();
auto desc = getDescriptorFromTensorTuple(adaptor.getTensor());
for (auto input : desc.getMemRefFields())
// Deallocate every buffer used to store the sparse tensor handler.
rewriter.create<memref::DeallocOp>(loc, input);
// If user requests not to deallocate sparse tensors, simply erase the
// operation.
if (createDeallocs) {
// Replace the sparse tensor deallocation with field deallocations.
Location loc = op.getLoc();
auto desc = getDescriptorFromTensorTuple(adaptor.getTensor());
for (auto input : desc.getMemRefFields())
// Deallocate every buffer used to store the sparse tensor handler.
rewriter.create<memref::DeallocOp>(loc, input);
}
rewriter.eraseOp(op);
return success();
}
private:
bool createDeallocs;
};
/// Sparse codegen rule for tensor rematerialization.
@@ -1492,13 +1503,12 @@ struct SparseNewOpConverter : public OpConversionPattern<NewOp> {
/// the sparsification of linear algebra operations.
void mlir::populateSparseTensorCodegenPatterns(
TypeConverter &typeConverter, RewritePatternSet &patterns,
bool enableBufferInitialization) {
bool createSparseDeallocs, bool enableBufferInitialization) {
patterns.add<SparsePackOpConverter, SparseUnpackOpConverter,
SparseReturnConverter, SparseCallConverter, SparseDimOpConverter,
SparseCastConverter, SparseTensorDeallocConverter,
SparseExtractSliceConverter, SparseTensorLoadConverter,
SparseExpandConverter, SparseCompressConverter,
SparseInsertConverter,
SparseCastConverter, SparseExtractSliceConverter,
SparseTensorLoadConverter, SparseExpandConverter,
SparseCompressConverter, SparseInsertConverter,
SparseSliceGetterOpConverter<ToSliceOffsetOp,
StorageSpecifierKind::DimOffset>,
SparseSliceGetterOpConverter<ToSliceStrideOp,
@@ -1508,6 +1518,8 @@ void mlir::populateSparseTensorCodegenPatterns(
SparseConvertConverter, SparseNewOpConverter,
SparseNumberOfEntriesConverter>(typeConverter,
patterns.getContext());
patterns.add<SparseTensorDeallocConverter>(
typeConverter, patterns.getContext(), createSparseDeallocs);
patterns.add<SparseTensorAllocConverter>(typeConverter, patterns.getContext(),
enableBufferInitialization);
}