[mlir][sparse] add create-sparse-deallocs options to match the create-deallocs in BufferizationOption.
Reviewed By: aartbik Differential Revision: https://reviews.llvm.org/D147010
This commit is contained in:
@@ -780,6 +780,11 @@ class SparseTensorDeallocConverter
|
||||
: public OpConversionPattern<bufferization::DeallocTensorOp> {
|
||||
public:
|
||||
using OpConversionPattern::OpConversionPattern;
|
||||
SparseTensorDeallocConverter(TypeConverter &typeConverter,
|
||||
MLIRContext *context, bool createDeallocs)
|
||||
: OpConversionPattern(typeConverter, context),
|
||||
createDeallocs(createDeallocs) {}
|
||||
|
||||
LogicalResult
|
||||
matchAndRewrite(bufferization::DeallocTensorOp op, OpAdaptor adaptor,
|
||||
ConversionPatternRewriter &rewriter) const override {
|
||||
@@ -787,16 +792,22 @@ public:
|
||||
if (!enc)
|
||||
return failure();
|
||||
|
||||
// Replace the sparse tensor deallocation with field deallocations.
|
||||
Location loc = op.getLoc();
|
||||
auto desc = getDescriptorFromTensorTuple(adaptor.getTensor());
|
||||
for (auto input : desc.getMemRefFields())
|
||||
// Deallocate every buffer used to store the sparse tensor handler.
|
||||
rewriter.create<memref::DeallocOp>(loc, input);
|
||||
|
||||
// If user requests not to deallocate sparse tensors, simply erase the
|
||||
// operation.
|
||||
if (createDeallocs) {
|
||||
// Replace the sparse tensor deallocation with field deallocations.
|
||||
Location loc = op.getLoc();
|
||||
auto desc = getDescriptorFromTensorTuple(adaptor.getTensor());
|
||||
for (auto input : desc.getMemRefFields())
|
||||
// Deallocate every buffer used to store the sparse tensor handler.
|
||||
rewriter.create<memref::DeallocOp>(loc, input);
|
||||
}
|
||||
rewriter.eraseOp(op);
|
||||
return success();
|
||||
}
|
||||
|
||||
private:
|
||||
bool createDeallocs;
|
||||
};
|
||||
|
||||
/// Sparse codegen rule for tensor rematerialization.
|
||||
@@ -1492,13 +1503,12 @@ struct SparseNewOpConverter : public OpConversionPattern<NewOp> {
|
||||
/// the sparsification of linear algebra operations.
|
||||
void mlir::populateSparseTensorCodegenPatterns(
|
||||
TypeConverter &typeConverter, RewritePatternSet &patterns,
|
||||
bool enableBufferInitialization) {
|
||||
bool createSparseDeallocs, bool enableBufferInitialization) {
|
||||
patterns.add<SparsePackOpConverter, SparseUnpackOpConverter,
|
||||
SparseReturnConverter, SparseCallConverter, SparseDimOpConverter,
|
||||
SparseCastConverter, SparseTensorDeallocConverter,
|
||||
SparseExtractSliceConverter, SparseTensorLoadConverter,
|
||||
SparseExpandConverter, SparseCompressConverter,
|
||||
SparseInsertConverter,
|
||||
SparseCastConverter, SparseExtractSliceConverter,
|
||||
SparseTensorLoadConverter, SparseExpandConverter,
|
||||
SparseCompressConverter, SparseInsertConverter,
|
||||
SparseSliceGetterOpConverter<ToSliceOffsetOp,
|
||||
StorageSpecifierKind::DimOffset>,
|
||||
SparseSliceGetterOpConverter<ToSliceStrideOp,
|
||||
@@ -1508,6 +1518,8 @@ void mlir::populateSparseTensorCodegenPatterns(
|
||||
SparseConvertConverter, SparseNewOpConverter,
|
||||
SparseNumberOfEntriesConverter>(typeConverter,
|
||||
patterns.getContext());
|
||||
patterns.add<SparseTensorDeallocConverter>(
|
||||
typeConverter, patterns.getContext(), createSparseDeallocs);
|
||||
patterns.add<SparseTensorAllocConverter>(typeConverter, patterns.getContext(),
|
||||
enableBufferInitialization);
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user