[mlir][sparse] fix a bug in UnpackOp converter.
UnpackOp Converter used to create reallocOp unconditionally, but it might cause issue when the requested memory size is smaller than the actually storage. Reviewed By: aartbik Differential Revision: https://reviews.llvm.org/D144065
This commit is contained in:
@@ -575,6 +575,34 @@ static void genEndInsert(OpBuilder &builder, Location loc,
|
||||
}
|
||||
}
|
||||
|
||||
/// Returns a memref that fits the requested length (reallocates if requested
|
||||
/// length is larger, or creates a subview if it is smaller).
|
||||
static Value reallocOrSubView(OpBuilder &builder, Location loc, int64_t len,
|
||||
Value buffer) {
|
||||
MemRefType memTp = getMemRefType(buffer);
|
||||
auto retTp = MemRefType::get(ArrayRef{len}, memTp.getElementType());
|
||||
|
||||
Value targetLen = constantIndex(builder, loc, len);
|
||||
Value bufferLen = linalg::createOrFoldDimOp(builder, loc, buffer, 0);
|
||||
Value reallocP = builder.create<arith::CmpIOp>(loc, arith::CmpIPredicate::ult,
|
||||
targetLen, bufferLen);
|
||||
scf::IfOp ifOp = builder.create<scf::IfOp>(loc, retTp, reallocP, true);
|
||||
// If targetLen > bufferLen, reallocate to get enough sparse to return.
|
||||
builder.setInsertionPointToStart(&ifOp.getThenRegion().front());
|
||||
Value reallocBuf = builder.create<memref::ReallocOp>(loc, retTp, buffer);
|
||||
builder.create<scf::YieldOp>(loc, reallocBuf);
|
||||
// Else, return a subview to fit the size.
|
||||
builder.setInsertionPointToStart(&ifOp.getElseRegion().front());
|
||||
Value subViewBuf = builder.create<memref::SubViewOp>(
|
||||
loc, retTp, buffer, /*offset=*/ArrayRef<int64_t>{0},
|
||||
/*size=*/ArrayRef<int64_t>{len},
|
||||
/*stride=*/ArrayRef<int64_t>{1});
|
||||
builder.create<scf::YieldOp>(loc, subViewBuf);
|
||||
// Resets insertion point.
|
||||
builder.setInsertionPointAfter(ifOp);
|
||||
return ifOp.getResult(0);
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// Codegen rules.
|
||||
//===----------------------------------------------------------------------===//
|
||||
@@ -1174,16 +1202,13 @@ struct SparseUnpackOpConverter : public OpConversionPattern<UnpackOp> {
|
||||
// to ensure that we meet their need.
|
||||
TensorType dataTp = op.getData().getType();
|
||||
if (dataTp.hasStaticShape()) {
|
||||
dataBuf = rewriter.create<memref::ReallocOp>(
|
||||
loc, MemRefType::get(dataTp.getShape(), dataTp.getElementType()),
|
||||
dataBuf);
|
||||
dataBuf = reallocOrSubView(rewriter, loc, dataTp.getShape()[0], dataBuf);
|
||||
}
|
||||
|
||||
TensorType indicesTp = op.getIndices().getType();
|
||||
if (indicesTp.hasStaticShape()) {
|
||||
auto len = indicesTp.getShape()[0] * indicesTp.getShape()[1];
|
||||
flatBuf = rewriter.create<memref::ReallocOp>(
|
||||
loc, MemRefType::get({len}, indicesTp.getElementType()), flatBuf);
|
||||
flatBuf = reallocOrSubView(rewriter, loc, len, flatBuf);
|
||||
}
|
||||
|
||||
Value idxBuf = rewriter.create<memref::ExpandShapeOp>(
|
||||
|
||||
Reference in New Issue
Block a user