[mlir][sparse] unifies sparse_tensor.sort_coo/sort into one operation. (#66722)

The use cases of the two operations are largely overlapped, let's
simplify it and only use one of them.
This commit is contained in:
Peiming Liu
2023-09-19 17:02:32 -07:00
committed by GitHub
parent 74338bfe0c
commit bfa3bc4378
14 changed files with 267 additions and 678 deletions

View File

@@ -890,8 +890,9 @@ public:
// If the innermost level is ordered, we need to sort the coordinates
// in the "added" array prior to applying the compression.
if (dstType.isOrderedLvl(dstType.getLvlRank() - 1))
rewriter.create<SortOp>(loc, count, ValueRange{added}, ValueRange{},
SparseTensorSortKind::HybridQuickSort);
rewriter.create<SortCooOp>(
loc, count, added, ValueRange{}, rewriter.getMultiDimIdentityMap(1),
rewriter.getIndexAttr(0), SparseTensorSortKind::HybridQuickSort);
// While performing the insertions, we also need to reset the elements
// of the values/filled-switch by only iterating over the set elements,
// to ensure that the runtime complexity remains proportional to the
@@ -1486,9 +1487,10 @@ struct SparseNewOpConverter : public OpConversionPattern<NewOp> {
scf::IfOp ifOp =
rewriter.create<scf::IfOp>(loc, notSorted, /*else*/ false);
rewriter.setInsertionPointToStart(&ifOp.getThenRegion().front());
rewriter.create<SortCooOp>(
loc, nse, xs, ValueRange{ys}, rewriter.getIndexAttr(lvlRank),
rewriter.getIndexAttr(0), SparseTensorSortKind::HybridQuickSort);
auto xPerm = rewriter.getMultiDimIdentityMap(lvlRank);
rewriter.create<SortCooOp>(loc, nse, xs, ValueRange{ys}, xPerm,
rewriter.getIndexAttr(0),
SparseTensorSortKind::HybridQuickSort);
rewriter.setInsertionPointAfter(ifOp);
}