Use heap memory for the position buffer allocated for PackOp.
Reviewed By: aartbik Differential Revision: https://reviews.llvm.org/D148818
This commit is contained in:
@@ -827,7 +827,7 @@ public:
|
||||
}
|
||||
|
||||
private:
|
||||
bool createDeallocs;
|
||||
const bool createDeallocs;
|
||||
};
|
||||
|
||||
/// Sparse codegen rule for tensor rematerialization.
|
||||
@@ -1343,29 +1343,23 @@ struct SparsePackOpConverter : public OpConversionPattern<PackOp> {
|
||||
break;
|
||||
case SparseTensorFieldKind::PosMemRef: {
|
||||
// TACO-style COO starts with a PosBuffer
|
||||
// By creating a constant value for it, we avoid the complexity of
|
||||
// memory management.
|
||||
const auto posTp = stt.getPosType();
|
||||
if (isCompressedDLT(dlt)) {
|
||||
RankedTensorType tensorType;
|
||||
SmallVector<Attribute> posAttr;
|
||||
tensorType = RankedTensorType::get({batchedCount + 1}, posTp);
|
||||
posAttr.push_back(IntegerAttr::get(posTp, 0));
|
||||
for (unsigned i = 0; i < batchedCount; i++) {
|
||||
auto memrefType = MemRefType::get({batchedCount + 1}, posTp);
|
||||
field = rewriter.create<memref::AllocOp>(loc, memrefType);
|
||||
Value c0 = constantIndex(rewriter, loc, 0);
|
||||
genStore(rewriter, loc, c0, field, c0);
|
||||
for (unsigned i = 1; i <= batchedCount; i++) {
|
||||
// The position memref will have values as
|
||||
// [0, nse, 2 * nse, ..., batchedCount * nse]
|
||||
posAttr.push_back(IntegerAttr::get(posTp, nse * (i + 1)));
|
||||
Value idx = constantIndex(rewriter, loc, i);
|
||||
Value val = constantIndex(rewriter, loc, nse * i);
|
||||
genStore(rewriter, loc, val, field, idx);
|
||||
}
|
||||
MemRefType memrefType = MemRefType::get(
|
||||
tensorType.getShape(), tensorType.getElementType());
|
||||
auto cstPtr = rewriter.create<arith::ConstantOp>(
|
||||
loc, tensorType, DenseElementsAttr::get(tensorType, posAttr));
|
||||
field = rewriter.create<bufferization::ToMemrefOp>(
|
||||
loc, memrefType, cstPtr);
|
||||
} else {
|
||||
assert(isCompressedWithHiDLT(dlt) && !batchDimSzs.empty());
|
||||
MemRefType posMemTp = MemRefType::get({batchedCount * 2}, posTp);
|
||||
field = rewriter.create<memref::AllocaOp>(loc, posMemTp);
|
||||
field = rewriter.create<memref::AllocOp>(loc, posMemTp);
|
||||
populateCompressedWithHiPosArray(rewriter, loc, batchDimSzs,
|
||||
field, nse, op);
|
||||
}
|
||||
@@ -1430,6 +1424,11 @@ struct SparsePackOpConverter : public OpConversionPattern<PackOp> {
|
||||
|
||||
struct SparseUnpackOpConverter : public OpConversionPattern<UnpackOp> {
|
||||
using OpConversionPattern::OpConversionPattern;
|
||||
SparseUnpackOpConverter(TypeConverter &typeConverter, MLIRContext *context,
|
||||
bool createDeallocs)
|
||||
: OpConversionPattern(typeConverter, context),
|
||||
createDeallocs(createDeallocs) {}
|
||||
|
||||
LogicalResult
|
||||
matchAndRewrite(UnpackOp op, OpAdaptor adaptor,
|
||||
ConversionPatternRewriter &rewriter) const override {
|
||||
@@ -1443,6 +1442,13 @@ struct SparseUnpackOpConverter : public OpConversionPattern<UnpackOp> {
|
||||
Value flatBuf = lvlRank == 1 ? desc.getCrdMemRefOrView(rewriter, loc, 0)
|
||||
: desc.getAOSMemRef();
|
||||
Value valuesBuf = desc.getValMemRef();
|
||||
Value posBuf = desc.getPosMemRef(0);
|
||||
if (createDeallocs) {
|
||||
// Unpack ends the lifetime of the sparse tensor. While the value array
|
||||
// and coordinate array are unpacked and returned, the position array
|
||||
// becomes useless and needs to be freed (if the user requests it).
|
||||
rewriter.create<memref::DeallocOp>(loc, posBuf);
|
||||
}
|
||||
|
||||
// If frontend requests a static buffer, we reallocate the
|
||||
// values/coordinates to ensure that we meet their need.
|
||||
@@ -1474,6 +1480,9 @@ struct SparseUnpackOpConverter : public OpConversionPattern<UnpackOp> {
|
||||
rewriter.replaceOp(op, {values, coordinates, nse});
|
||||
return success();
|
||||
}
|
||||
|
||||
private:
|
||||
const bool createDeallocs;
|
||||
};
|
||||
|
||||
struct SparseNewOpConverter : public OpConversionPattern<NewOp> {
|
||||
@@ -1627,11 +1636,11 @@ struct SparseNewOpConverter : public OpConversionPattern<NewOp> {
|
||||
void mlir::populateSparseTensorCodegenPatterns(
|
||||
TypeConverter &typeConverter, RewritePatternSet &patterns,
|
||||
bool createSparseDeallocs, bool enableBufferInitialization) {
|
||||
patterns.add<SparsePackOpConverter, SparseUnpackOpConverter,
|
||||
SparseReturnConverter, SparseCallConverter, SparseDimOpConverter,
|
||||
SparseCastConverter, SparseExtractSliceConverter,
|
||||
SparseTensorLoadConverter, SparseExpandConverter,
|
||||
SparseCompressConverter, SparseInsertConverter,
|
||||
patterns.add<SparsePackOpConverter, SparseReturnConverter,
|
||||
SparseCallConverter, SparseDimOpConverter, SparseCastConverter,
|
||||
SparseExtractSliceConverter, SparseTensorLoadConverter,
|
||||
SparseExpandConverter, SparseCompressConverter,
|
||||
SparseInsertConverter,
|
||||
SparseSliceGetterOpConverter<ToSliceOffsetOp,
|
||||
StorageSpecifierKind::DimOffset>,
|
||||
SparseSliceGetterOpConverter<ToSliceStrideOp,
|
||||
@@ -1641,7 +1650,7 @@ void mlir::populateSparseTensorCodegenPatterns(
|
||||
SparseConvertConverter, SparseNewOpConverter,
|
||||
SparseNumberOfEntriesConverter>(typeConverter,
|
||||
patterns.getContext());
|
||||
patterns.add<SparseTensorDeallocConverter>(
|
||||
patterns.add<SparseTensorDeallocConverter, SparseUnpackOpConverter>(
|
||||
typeConverter, patterns.getContext(), createSparseDeallocs);
|
||||
patterns.add<SparseTensorAllocConverter>(typeConverter, patterns.getContext(),
|
||||
enableBufferInitialization);
|
||||
|
||||
Reference in New Issue
Block a user