[mlir][sparse] codegen for sparse dealloc

Reviewed By: bixia

Differential Revision: https://reviews.llvm.org/D133171
This commit is contained in:
Aart Bik
2022-09-01 17:18:56 -07:00
parent 11881a8f3f
commit 2ddfacd95c
3 changed files with 56 additions and 9 deletions

View File

@@ -17,6 +17,7 @@
#include "CodegenUtils.h"
#include "mlir/Dialect/Bufferization/IR/Bufferization.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/SparseTensor/IR/SparseTensor.h"
@@ -232,7 +233,31 @@ public:
}
};
/// Sparse conversion rule for pointer accesses.
/// Sparse codegen rule for the dealloc operator.
class SparseTensorDeallocConverter
: public OpConversionPattern<bufferization::DeallocTensorOp> {
public:
using OpConversionPattern::OpConversionPattern;
LogicalResult
matchAndRewrite(bufferization::DeallocTensorOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
auto enc = getSparseTensorEncoding(op.getTensor().getType());
if (!enc)
return failure();
// Replace the tuple deallocation with field deallocations.
Location loc = op->getLoc();
Value tuple = adaptor.getTensor();
for (unsigned i = 0, sz = tuple.getType().cast<TupleType>().size(); i < sz;
i++) {
Value mem = createTupleGet(rewriter, loc, tuple, i);
rewriter.create<memref::DeallocOp>(loc, mem);
}
rewriter.eraseOp(op);
return success();
}
};
/// Sparse codegen rule for pointer accesses.
class SparseToPointersConverter : public OpConversionPattern<ToPointersOp> {
public:
using OpConversionPattern::OpConversionPattern;
@@ -251,7 +276,7 @@ public:
}
};
/// Sparse conversion rule for index accesses.
/// Sparse codegen rule for index accesses.
class SparseToIndicesConverter : public OpConversionPattern<ToIndicesOp> {
public:
using OpConversionPattern::OpConversionPattern;
@@ -270,7 +295,7 @@ public:
}
};
/// Sparse conversion rule for value accesses.
/// Sparse codegen rule for value accesses.
class SparseToValuesConverter : public OpConversionPattern<ToValuesOp> {
public:
using OpConversionPattern::OpConversionPattern;
@@ -280,7 +305,7 @@ public:
// Replace the requested values access with corresponding field.
Location loc = op->getLoc();
Value tuple = adaptor.getTensor();
unsigned i = tuple.getType().cast<TupleType>().size() - 1; // last
unsigned i = tuple.getType().cast<TupleType>().size() - 1; // last
rewriter.replaceOp(op, createTupleGet(rewriter, loc, tuple, i));
return success();
}
@@ -306,6 +331,7 @@ mlir::SparseTensorTypeToBufferConverter::SparseTensorTypeToBufferConverter() {
void mlir::populateSparseTensorCodegenPatterns(TypeConverter &typeConverter,
RewritePatternSet &patterns) {
patterns.add<SparseReturnConverter, SparseDimOpConverter, SparseCastConverter,
SparseToPointersConverter, SparseToIndicesConverter,
SparseToValuesConverter>(typeConverter, patterns.getContext());
SparseTensorDeallocConverter, SparseToPointersConverter,
SparseToIndicesConverter, SparseToValuesConverter>(
typeConverter, patterns.getContext());
}