use heap memory for the position buffer allocated for PackOp.
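
Previously, the compressed case materialized the position buffer as a
read-only constant (an arith::ConstantOp bufferized through
bufferization::ToMemrefOp), and the compressed-with-hi case used a stack
allocation (memref::AllocaOp); neither buffer can later be freed when the
tensor is unpacked. The buffer is now heap-allocated with memref::AllocOp
and filled with explicit stores, and SparseUnpackOpConverter deallocates
it when createSparseDeallocs is set.

Roughly, for the compressed case the generated IR now looks like the
sketch below (hand-written for illustration, assuming posTp is index,
batchedCount = 2, and nse = 4; none of these values appear in the patch):

    %pos = memref.alloc() : memref<3xindex>
    %c0 = arith.constant 0 : index
    memref.store %c0, %pos[%c0] : memref<3xindex>  // pos[0] = 0
    %c1 = arith.constant 1 : index
    %c4 = arith.constant 4 : index
    memref.store %c4, %pos[%c1] : memref<3xindex>  // pos[1] = nse
    %c2 = arith.constant 2 : index
    %c8 = arith.constant 8 : index
    memref.store %c8, %pos[%c2] : memref<3xindex>  // pos[2] = 2 * nse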

Reviewed By: aartbik

Differential Revision: https://reviews.llvm.org/D148818
Peiming Liu
2023-04-20 17:42:19 +00:00
parent 1a0a0305c0
commit fd2211d84a
4 changed files with 60 additions and 46 deletions

@@ -827,7 +827,7 @@ public:
   }
 
 private:
-  bool createDeallocs;
+  const bool createDeallocs;
 };
 
 /// Sparse codegen rule for tensor rematerialization.
@@ -1343,29 +1343,23 @@ struct SparsePackOpConverter : public OpConversionPattern<PackOp> {
       break;
     case SparseTensorFieldKind::PosMemRef: {
       // TACO-style COO starts with a PosBuffer
-      // By creating a constant value for it, we avoid the complexity of
-      // memory management.
       const auto posTp = stt.getPosType();
       if (isCompressedDLT(dlt)) {
-        RankedTensorType tensorType;
-        SmallVector<Attribute> posAttr;
-        tensorType = RankedTensorType::get({batchedCount + 1}, posTp);
-        posAttr.push_back(IntegerAttr::get(posTp, 0));
-        for (unsigned i = 0; i < batchedCount; i++) {
+        auto memrefType = MemRefType::get({batchedCount + 1}, posTp);
+        field = rewriter.create<memref::AllocOp>(loc, memrefType);
+        Value c0 = constantIndex(rewriter, loc, 0);
+        genStore(rewriter, loc, c0, field, c0);
+        for (unsigned i = 1; i <= batchedCount; i++) {
           // The position memref will have values as
           // [0, nse, 2 * nse, ..., batchedCount * nse]
-          posAttr.push_back(IntegerAttr::get(posTp, nse * (i + 1)));
+          Value idx = constantIndex(rewriter, loc, i);
+          Value val = constantIndex(rewriter, loc, nse * i);
+          genStore(rewriter, loc, val, field, idx);
         }
-        MemRefType memrefType = MemRefType::get(
-            tensorType.getShape(), tensorType.getElementType());
-        auto cstPtr = rewriter.create<arith::ConstantOp>(
-            loc, tensorType, DenseElementsAttr::get(tensorType, posAttr));
-        field = rewriter.create<bufferization::ToMemrefOp>(
-            loc, memrefType, cstPtr);
       } else {
         assert(isCompressedWithHiDLT(dlt) && !batchDimSzs.empty());
         MemRefType posMemTp = MemRefType::get({batchedCount * 2}, posTp);
-        field = rewriter.create<memref::AllocaOp>(loc, posMemTp);
+        field = rewriter.create<memref::AllocOp>(loc, posMemTp);
         populateCompressedWithHiPosArray(rewriter, loc, batchDimSzs,
                                          field, nse, op);
       }
@@ -1430,6 +1424,11 @@ struct SparsePackOpConverter : public OpConversionPattern<PackOp> {
 
 struct SparseUnpackOpConverter : public OpConversionPattern<UnpackOp> {
   using OpConversionPattern::OpConversionPattern;
+  SparseUnpackOpConverter(TypeConverter &typeConverter, MLIRContext *context,
+                          bool createDeallocs)
+      : OpConversionPattern(typeConverter, context),
+        createDeallocs(createDeallocs) {}
+
   LogicalResult
   matchAndRewrite(UnpackOp op, OpAdaptor adaptor,
                   ConversionPatternRewriter &rewriter) const override {
@@ -1443,6 +1442,13 @@ struct SparseUnpackOpConverter : public OpConversionPattern<UnpackOp> {
     Value flatBuf = lvlRank == 1 ? desc.getCrdMemRefOrView(rewriter, loc, 0)
                                  : desc.getAOSMemRef();
     Value valuesBuf = desc.getValMemRef();
+    Value posBuf = desc.getPosMemRef(0);
+    if (createDeallocs) {
+      // Unpack ends the lifetime of the sparse tensor. While the value array
+      // and coordinate array are unpacked and returned, the position array
+      // becomes useless and needs to be freed (if the user requests it).
+      rewriter.create<memref::DeallocOp>(loc, posBuf);
+    }
 
     // If the frontend requests a static buffer, we reallocate the
     // values/coordinates to ensure that we meet their needs.
@@ -1474,6 +1480,9 @@ struct SparseUnpackOpConverter : public OpConversionPattern<UnpackOp> {
     rewriter.replaceOp(op, {values, coordinates, nse});
     return success();
   }
+
+private:
+  const bool createDeallocs;
 };
 
 struct SparseNewOpConverter : public OpConversionPattern<NewOp> {
@@ -1627,11 +1636,11 @@
 void mlir::populateSparseTensorCodegenPatterns(
     TypeConverter &typeConverter, RewritePatternSet &patterns,
     bool createSparseDeallocs, bool enableBufferInitialization) {
-  patterns.add<SparsePackOpConverter, SparseUnpackOpConverter,
-               SparseReturnConverter, SparseCallConverter, SparseDimOpConverter,
-               SparseCastConverter, SparseExtractSliceConverter,
-               SparseTensorLoadConverter, SparseExpandConverter,
-               SparseCompressConverter, SparseInsertConverter,
+  patterns.add<SparsePackOpConverter, SparseReturnConverter,
+               SparseCallConverter, SparseDimOpConverter, SparseCastConverter,
+               SparseExtractSliceConverter, SparseTensorLoadConverter,
+               SparseExpandConverter, SparseCompressConverter,
+               SparseInsertConverter,
                SparseSliceGetterOpConverter<ToSliceOffsetOp,
                                             StorageSpecifierKind::DimOffset>,
                SparseSliceGetterOpConverter<ToSliceStrideOp,
@@ -1641,7 +1650,7 @@ void mlir::populateSparseTensorCodegenPatterns(
                SparseConvertConverter, SparseNewOpConverter,
                SparseNumberOfEntriesConverter>(typeConverter,
                                                patterns.getContext());
-  patterns.add<SparseTensorDeallocConverter>(
+  patterns.add<SparseTensorDeallocConverter, SparseUnpackOpConverter>(
       typeConverter, patterns.getContext(), createSparseDeallocs);
   patterns.add<SparseTensorAllocConverter>(typeConverter, patterns.getContext(),
                                            enableBufferInitialization);
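
For context, these patterns are driven by a dialect-conversion pass; the
registration change above is what routes createSparseDeallocs into
SparseUnpackOpConverter. A minimal call-site sketch follows; everything
except populateSparseTensorCodegenPatterns and its parameters is an
illustrative assumption (pass boilerplate, converter name, target setup),
not part of this commit:

    // Inside a hypothetical pass's runOnOperation():
    MLIRContext *ctx = &getContext();
    SparseTensorTypeToBufferConverter typeConverter; // assumed converter name
    ConversionTarget target(*ctx);                   // legality setup elided
    RewritePatternSet patterns(ctx);
    populateSparseTensorCodegenPatterns(
        typeConverter, patterns,
        /*createSparseDeallocs=*/true, // lets SparseUnpackOpConverter free posBuf
        /*enableBufferInitialization=*/false);
    if (failed(applyPartialConversion(getOperation(), target,
                                      std::move(patterns))))
      signalPassFailure();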