[mlir][sparse] misc code cleanup

* Flattening/simplifying some nested conditionals
* const-ifying some local variables

Depends On D143800

Reviewed By: aartbik

Differential Revision: https://reviews.llvm.org/D143949
This commit is contained in:
wren romano
2023-02-15 13:28:11 -08:00
parent 74a5d7471f
commit d950bdc73e
2 changed files with 25 additions and 30 deletions

View File

@@ -513,13 +513,11 @@ struct ConcatenateRewriter : public OpRewritePattern<ConcatenateOp> {
}
needTmpCOO = !allDense && !allOrdered;
const RankedTensorType tp = needTmpCOO ? getUnorderedCOOFromType(dstTp)
: dstTp.getRankedTensorType();
encDst = needTmpCOO ? getSparseTensorEncoding(tp) : encDst;
SmallVector<Value> dynSizes;
getDynamicSizes(dstTp, sizes, dynSizes);
RankedTensorType tp = dstTp;
if (needTmpCOO) {
tp = getUnorderedCOOFromType(dstTp);
encDst = getSparseTensorEncoding(tp);
}
dst = rewriter.create<AllocTensorOp>(loc, tp, dynSizes).getResult();
if (allDense) {
// Create a view of the values buffer to match the unannotated dense
@@ -592,21 +590,20 @@ struct ConcatenateRewriter : public OpRewritePattern<ConcatenateOp> {
// Temp variable to avoid needing to call `getRankedTensorType`
// in the three use-sites below.
const RankedTensorType dstRTT = dstTp;
if (encDst) {
if (!allDense) {
dst = rewriter.create<LoadOp>(loc, dst, true);
if (needTmpCOO) {
Value tmpCoo = dst;
dst = rewriter.create<ConvertOp>(loc, dstRTT, tmpCoo).getResult();
rewriter.create<DeallocTensorOp>(loc, tmpCoo);
}
} else {
dst = rewriter.create<ConvertOp>(loc, dstRTT, annotatedDenseDst)
.getResult();
if (!encDst) {
rewriter.replaceOpWithNewOp<bufferization::ToTensorOp>(op, dstRTT, dst);
} else if (allDense) {
rewriter.replaceOp(
op, rewriter.create<ConvertOp>(loc, dstRTT, annotatedDenseDst)
.getResult());
} else {
dst = rewriter.create<LoadOp>(loc, dst, true);
if (needTmpCOO) {
Value tmpCoo = dst;
dst = rewriter.create<ConvertOp>(loc, dstRTT, tmpCoo).getResult();
rewriter.create<DeallocTensorOp>(loc, tmpCoo);
}
rewriter.replaceOp(op, dst);
} else {
rewriter.replaceOpWithNewOp<bufferization::ToTensorOp>(op, dstRTT, dst);
}
return success();
}