This change cleans up call sites. Next step is to mark the member functions deprecated. See https://mlir.llvm.org/deprecation and https://discourse.llvm.org/t/preferred-casting-style-going-forward.
64 lines
2.4 KiB
C++
64 lines
2.4 KiB
C++
//===- SparseTensorInterfaces.cpp - SparseTensor interfaces impl ----------===//
|
|
//
|
|
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
|
// See https://llvm.org/LICENSE.txt for license information.
|
|
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
|
//
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
#include "mlir/Dialect/SparseTensor/IR/SparseTensorInterfaces.h"
|
|
#include "mlir/Dialect/SparseTensor/IR/SparseTensor.h"
|
|
#include "mlir/Dialect/SparseTensor/IR/SparseTensorType.h"
|
|
#include "mlir/IR/PatternMatch.h"
|
|
|
|
using namespace mlir;
|
|
using namespace mlir::sparse_tensor;
|
|
|
|
#include "mlir/Dialect/SparseTensor/IR/SparseTensorInterfaces.cpp.inc"
|
|
|
|
/// Stage the operations into a sequence of simple operations as follow:
|
|
/// op -> unsorted_coo +
|
|
/// unsorted_coo -> sorted_coo +
|
|
/// sorted_coo -> dstTp.
|
|
///
|
|
/// return `tmpBuf` if a intermediate memory is allocated.
|
|
LogicalResult sparse_tensor::detail::stageWithSortImpl(
|
|
StageWithSortSparseOp op, PatternRewriter &rewriter, Value &tmpBufs) {
|
|
if (!op.needsExtraSort())
|
|
return failure();
|
|
|
|
Location loc = op.getLoc();
|
|
Type finalTp = op->getOpResult(0).getType();
|
|
SparseTensorType dstStt(cast<RankedTensorType>(finalTp));
|
|
Type srcCOOTp = dstStt.getCOOType(/*ordered=*/false);
|
|
|
|
// Clones the original operation but changing the output to an unordered COO.
|
|
Operation *cloned = rewriter.clone(*op.getOperation());
|
|
rewriter.modifyOpInPlace(cloned, [cloned, srcCOOTp]() {
|
|
cloned->getOpResult(0).setType(srcCOOTp);
|
|
});
|
|
Value srcCOO = cloned->getOpResult(0);
|
|
|
|
// -> sort
|
|
Type dstCOOTp = dstStt.getCOOType(/*ordered=*/true);
|
|
Value dstCOO = rewriter.create<ReorderCOOOp>(
|
|
loc, dstCOOTp, srcCOO, SparseTensorSortKind::HybridQuickSort);
|
|
|
|
// -> dest.
|
|
if (dstCOO.getType() == finalTp) {
|
|
rewriter.replaceOp(op, dstCOO);
|
|
} else {
|
|
// Need an extra conversion if the target type is not COO.
|
|
auto c = rewriter.replaceOpWithNewOp<ConvertOp>(op, finalTp, dstCOO);
|
|
rewriter.setInsertionPointAfter(c);
|
|
// Informs the caller about the intermediate buffer we allocated. We can not
|
|
// create a bufferization::DeallocateTensorOp here because it would
|
|
// introduce cyclic dependency between the SparseTensorDialect and the
|
|
// BufferizationDialect. Besides, whether the buffer need to be deallocated
|
|
// by SparseTensorDialect or by BufferDeallocationPass is still TBD.
|
|
tmpBufs = dstCOO;
|
|
}
|
|
|
|
return success();
|
|
}
|