//===- SparseTensorInterfaces.cpp - SparseTensor interfaces impl ----------===// // // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. // See https://llvm.org/LICENSE.txt for license information. // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // //===----------------------------------------------------------------------===// #include "mlir/Dialect/SparseTensor/IR/SparseTensorInterfaces.h" #include "mlir/Dialect/SparseTensor/IR/SparseTensor.h" #include "mlir/Dialect/SparseTensor/IR/SparseTensorType.h" #include "mlir/IR/PatternMatch.h" using namespace mlir; using namespace mlir::sparse_tensor; #include "mlir/Dialect/SparseTensor/IR/SparseTensorInterfaces.cpp.inc" /// Stage the operations into a sequence of simple operations as follow: /// op -> unsorted_coo + /// unsorted_coo -> sorted_coo + /// sorted_coo -> dstTp. /// /// return `tmpBuf` if a intermediate memory is allocated. LogicalResult sparse_tensor::detail::stageWithSortImpl( StageWithSortSparseOp op, PatternRewriter &rewriter, Value &tmpBufs) { if (!op.needsExtraSort()) return failure(); Location loc = op.getLoc(); Type finalTp = op->getOpResult(0).getType(); SparseTensorType dstStt(cast(finalTp)); Type srcCOOTp = dstStt.getCOOType(/*ordered=*/false); // Clones the original operation but changing the output to an unordered COO. Operation *cloned = rewriter.clone(*op.getOperation()); rewriter.modifyOpInPlace(cloned, [cloned, srcCOOTp]() { cloned->getOpResult(0).setType(srcCOOTp); }); Value srcCOO = cloned->getOpResult(0); // -> sort Type dstCOOTp = dstStt.getCOOType(/*ordered=*/true); Value dstCOO = rewriter.create( loc, dstCOOTp, srcCOO, SparseTensorSortKind::HybridQuickSort); // -> dest. if (dstCOO.getType() == finalTp) { rewriter.replaceOp(op, dstCOO); } else { // Need an extra conversion if the target type is not COO. auto c = rewriter.replaceOpWithNewOp(op, finalTp, dstCOO); rewriter.setInsertionPointAfter(c); // Informs the caller about the intermediate buffer we allocated. We can not // create a bufferization::DeallocateTensorOp here because it would // introduce cyclic dependency between the SparseTensorDialect and the // BufferizationDialect. Besides, whether the buffer need to be deallocated // by SparseTensorDialect or by BufferDeallocationPass is still TBD. tmpBufs = dstCOO; } return success(); }