[mlir][sparse] codegen for sparse dealloc
Reviewed By: bixia Differential Revision: https://reviews.llvm.org/D133171
This commit is contained in:
@@ -17,6 +17,7 @@
|
||||
|
||||
#include "CodegenUtils.h"
|
||||
|
||||
#include "mlir/Dialect/Bufferization/IR/Bufferization.h"
|
||||
#include "mlir/Dialect/Func/IR/FuncOps.h"
|
||||
#include "mlir/Dialect/MemRef/IR/MemRef.h"
|
||||
#include "mlir/Dialect/SparseTensor/IR/SparseTensor.h"
|
||||
@@ -232,7 +233,31 @@ public:
|
||||
}
|
||||
};
|
||||
|
||||
/// Sparse conversion rule for pointer accesses.
|
||||
/// Sparse codegen rule for the dealloc operator.
|
||||
class SparseTensorDeallocConverter
|
||||
: public OpConversionPattern<bufferization::DeallocTensorOp> {
|
||||
public:
|
||||
using OpConversionPattern::OpConversionPattern;
|
||||
LogicalResult
|
||||
matchAndRewrite(bufferization::DeallocTensorOp op, OpAdaptor adaptor,
|
||||
ConversionPatternRewriter &rewriter) const override {
|
||||
auto enc = getSparseTensorEncoding(op.getTensor().getType());
|
||||
if (!enc)
|
||||
return failure();
|
||||
// Replace the tuple deallocation with field deallocations.
|
||||
Location loc = op->getLoc();
|
||||
Value tuple = adaptor.getTensor();
|
||||
for (unsigned i = 0, sz = tuple.getType().cast<TupleType>().size(); i < sz;
|
||||
i++) {
|
||||
Value mem = createTupleGet(rewriter, loc, tuple, i);
|
||||
rewriter.create<memref::DeallocOp>(loc, mem);
|
||||
}
|
||||
rewriter.eraseOp(op);
|
||||
return success();
|
||||
}
|
||||
};
|
||||
|
||||
/// Sparse codegen rule for pointer accesses.
|
||||
class SparseToPointersConverter : public OpConversionPattern<ToPointersOp> {
|
||||
public:
|
||||
using OpConversionPattern::OpConversionPattern;
|
||||
@@ -251,7 +276,7 @@ public:
|
||||
}
|
||||
};
|
||||
|
||||
/// Sparse conversion rule for index accesses.
|
||||
/// Sparse codegen rule for index accesses.
|
||||
class SparseToIndicesConverter : public OpConversionPattern<ToIndicesOp> {
|
||||
public:
|
||||
using OpConversionPattern::OpConversionPattern;
|
||||
@@ -270,7 +295,7 @@ public:
|
||||
}
|
||||
};
|
||||
|
||||
/// Sparse conversion rule for value accesses.
|
||||
/// Sparse codegen rule for value accesses.
|
||||
class SparseToValuesConverter : public OpConversionPattern<ToValuesOp> {
|
||||
public:
|
||||
using OpConversionPattern::OpConversionPattern;
|
||||
@@ -280,7 +305,7 @@ public:
|
||||
// Replace the requested values access with corresponding field.
|
||||
Location loc = op->getLoc();
|
||||
Value tuple = adaptor.getTensor();
|
||||
unsigned i = tuple.getType().cast<TupleType>().size() - 1; // last
|
||||
unsigned i = tuple.getType().cast<TupleType>().size() - 1; // last
|
||||
rewriter.replaceOp(op, createTupleGet(rewriter, loc, tuple, i));
|
||||
return success();
|
||||
}
|
||||
@@ -306,6 +331,7 @@ mlir::SparseTensorTypeToBufferConverter::SparseTensorTypeToBufferConverter() {
|
||||
void mlir::populateSparseTensorCodegenPatterns(TypeConverter &typeConverter,
|
||||
RewritePatternSet &patterns) {
|
||||
patterns.add<SparseReturnConverter, SparseDimOpConverter, SparseCastConverter,
|
||||
SparseToPointersConverter, SparseToIndicesConverter,
|
||||
SparseToValuesConverter>(typeConverter, patterns.getContext());
|
||||
SparseTensorDeallocConverter, SparseToPointersConverter,
|
||||
SparseToIndicesConverter, SparseToValuesConverter>(
|
||||
typeConverter, patterns.getContext());
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user