[mlir][sparse] fix a bug in UnpackOp converter.

The UnpackOp converter used to create a reallocOp unconditionally, but this might cause issues when the requested memory size is smaller than the actual storage.

Reviewed By: aartbik

Differential Revision: https://reviews.llvm.org/D144065
This commit is contained in:
Peiming Liu
2023-02-15 02:18:54 +00:00
parent dd31a3b3a5
commit 81cb70e46e
2 changed files with 57 additions and 13 deletions

View File

@@ -575,6 +575,34 @@ static void genEndInsert(OpBuilder &builder, Location loc,
}
}
/// Returns a memref that fits the requested length (reallocates if requested
/// length is larger, or creates a subview if it is smaller or equal).
///
/// The decision is made at runtime: an `scf.if` is emitted whose condition
/// compares the static requested length `len` against the dynamic size of
/// dimension 0 of `buffer`. Both branches yield a value of the same static
/// 1-D memref type (`len` elements of the buffer's element type).
///
/// On return, the builder's insertion point is placed right after the
/// generated `scf.if`.
static Value reallocOrSubView(OpBuilder &builder, Location loc, int64_t len,
                              Value buffer) {
  MemRefType memTp = getMemRefType(buffer);
  auto retTp = MemRefType::get(ArrayRef{len}, memTp.getElementType());

  Value targetLen = constantIndex(builder, loc, len);
  Value bufferLen = linalg::createOrFoldDimOp(builder, loc, buffer, 0);
  // Reallocate only if the target length is strictly greater than the actual
  // buffer length. The predicate must be `ugt` (not `ult`): taking a subview
  // of `len` elements from a shorter buffer would reach past its end.
  Value reallocP = builder.create<arith::CmpIOp>(loc, arith::CmpIPredicate::ugt,
                                                 targetLen, bufferLen);
  scf::IfOp ifOp =
      builder.create<scf::IfOp>(loc, retTp, reallocP, /*withElseRegion=*/true);

  // If targetLen > bufferLen, reallocate to get enough space to return.
  builder.setInsertionPointToStart(&ifOp.getThenRegion().front());
  Value reallocBuf = builder.create<memref::ReallocOp>(loc, retTp, buffer);
  builder.create<scf::YieldOp>(loc, reallocBuf);

  // Else, return a subview to fit the size.
  builder.setInsertionPointToStart(&ifOp.getElseRegion().front());
  Value subViewBuf = builder.create<memref::SubViewOp>(
      loc, retTp, buffer, /*offset=*/ArrayRef<int64_t>{0},
      /*size=*/ArrayRef<int64_t>{len},
      /*stride=*/ArrayRef<int64_t>{1});
  builder.create<scf::YieldOp>(loc, subViewBuf);

  // Resets insertion point.
  builder.setInsertionPointAfter(ifOp);
  return ifOp.getResult(0);
}
//===----------------------------------------------------------------------===//
// Codegen rules.
//===----------------------------------------------------------------------===//
@@ -1174,16 +1202,13 @@ struct SparseUnpackOpConverter : public OpConversionPattern<UnpackOp> {
// to ensure that we meet their need.
TensorType dataTp = op.getData().getType();
if (dataTp.hasStaticShape()) {
dataBuf = rewriter.create<memref::ReallocOp>(
loc, MemRefType::get(dataTp.getShape(), dataTp.getElementType()),
dataBuf);
dataBuf = reallocOrSubView(rewriter, loc, dataTp.getShape()[0], dataBuf);
}
TensorType indicesTp = op.getIndices().getType();
if (indicesTp.hasStaticShape()) {
auto len = indicesTp.getShape()[0] * indicesTp.getShape()[1];
flatBuf = rewriter.create<memref::ReallocOp>(
loc, MemRefType::get({len}, indicesTp.getElementType()), flatBuf);
flatBuf = reallocOrSubView(rewriter, loc, len, flatBuf);
}
Value idxBuf = rewriter.create<memref::ExpandShapeOp>(