//===- SparseTensorInterfaces.cpp - SparseTensor interfaces impl ----------===// // // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. // See https://llvm.org/LICENSE.txt for license information. // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // //===----------------------------------------------------------------------===// #include "mlir/Dialect/SparseTensor/IR/SparseTensorInterfaces.h" #include "mlir/Dialect/SparseTensor/IR/SparseTensor.h" #include "mlir/Dialect/SparseTensor/IR/SparseTensorType.h" #include "mlir/IR/PatternMatch.h" using namespace mlir; using namespace mlir::sparse_tensor; #include "mlir/Dialect/SparseTensor/IR/SparseTensorInterfaces.cpp.inc" LogicalResult sparse_tensor::detail::stageWithSortImpl(StageWithSortSparseOp op, PatternRewriter &rewriter) { if (!op.needsExtraSort()) return failure(); Location loc = op.getLoc(); Type finalTp = op->getOpResult(0).getType(); SparseTensorType dstStt(finalTp.cast()); Type srcCOOTp = dstStt.getCOOType(/*ordered=*/false); // Clones the original operation but changing the output to an unordered COO. Operation *cloned = rewriter.clone(*op.getOperation()); rewriter.modifyOpInPlace(cloned, [cloned, srcCOOTp]() { cloned->getOpResult(0).setType(srcCOOTp); }); Value srcCOO = cloned->getOpResult(0); // -> sort Type dstCOOTp = dstStt.getCOOType(/*ordered=*/true); Value dstCOO = rewriter.create( loc, dstCOOTp, srcCOO, SparseTensorSortKind::HybridQuickSort); // -> dest. if (dstCOO.getType() == finalTp) { rewriter.replaceOp(op, dstCOO); } else { // Need an extra conversion if the target type is not COO. rewriter.replaceOpWithNewOp(op, finalTp, dstCOO); } // TODO: deallocate extra COOs, we should probably delegate it to buffer // deallocation pass. return success(); }