[mlir][sparse] misc code cleanup
* Flattening/simplifying some nested conditionals * const-ifying some local variables Depends On D143800 Reviewed By: aartbik Differential Revision: https://reviews.llvm.org/D143949
This commit is contained in:
@@ -513,13 +513,11 @@ struct ConcatenateRewriter : public OpRewritePattern<ConcatenateOp> {
|
||||
}
|
||||
|
||||
needTmpCOO = !allDense && !allOrdered;
|
||||
const RankedTensorType tp = needTmpCOO ? getUnorderedCOOFromType(dstTp)
|
||||
: dstTp.getRankedTensorType();
|
||||
encDst = needTmpCOO ? getSparseTensorEncoding(tp) : encDst;
|
||||
SmallVector<Value> dynSizes;
|
||||
getDynamicSizes(dstTp, sizes, dynSizes);
|
||||
RankedTensorType tp = dstTp;
|
||||
if (needTmpCOO) {
|
||||
tp = getUnorderedCOOFromType(dstTp);
|
||||
encDst = getSparseTensorEncoding(tp);
|
||||
}
|
||||
dst = rewriter.create<AllocTensorOp>(loc, tp, dynSizes).getResult();
|
||||
if (allDense) {
|
||||
// Create a view of the values buffer to match the unannotated dense
|
||||
@@ -592,21 +590,20 @@ struct ConcatenateRewriter : public OpRewritePattern<ConcatenateOp> {
|
||||
// Temp variable to avoid needing to call `getRankedTensorType`
|
||||
// in the three use-sites below.
|
||||
const RankedTensorType dstRTT = dstTp;
|
||||
if (encDst) {
|
||||
if (!allDense) {
|
||||
dst = rewriter.create<LoadOp>(loc, dst, true);
|
||||
if (needTmpCOO) {
|
||||
Value tmpCoo = dst;
|
||||
dst = rewriter.create<ConvertOp>(loc, dstRTT, tmpCoo).getResult();
|
||||
rewriter.create<DeallocTensorOp>(loc, tmpCoo);
|
||||
}
|
||||
} else {
|
||||
dst = rewriter.create<ConvertOp>(loc, dstRTT, annotatedDenseDst)
|
||||
.getResult();
|
||||
if (!encDst) {
|
||||
rewriter.replaceOpWithNewOp<bufferization::ToTensorOp>(op, dstRTT, dst);
|
||||
} else if (allDense) {
|
||||
rewriter.replaceOp(
|
||||
op, rewriter.create<ConvertOp>(loc, dstRTT, annotatedDenseDst)
|
||||
.getResult());
|
||||
} else {
|
||||
dst = rewriter.create<LoadOp>(loc, dst, true);
|
||||
if (needTmpCOO) {
|
||||
Value tmpCoo = dst;
|
||||
dst = rewriter.create<ConvertOp>(loc, dstRTT, tmpCoo).getResult();
|
||||
rewriter.create<DeallocTensorOp>(loc, tmpCoo);
|
||||
}
|
||||
rewriter.replaceOp(op, dst);
|
||||
} else {
|
||||
rewriter.replaceOpWithNewOp<bufferization::ToTensorOp>(op, dstRTT, dst);
|
||||
}
|
||||
return success();
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user