Files
clang-p2996/mlir/lib/Dialect/SparseTensor/IR/SparseTensorInterfaces.cpp
Aart Bik 45288085b5 [mlir][sparse] move toCOOType into SparseTensorType class (#73708)
Migrates dangling convenience method into proper SparseTensorType class.
Also cleans up some details (picking right dim2lvl/lvl2dim). Removes
more dead code.
2023-11-28 16:04:01 -08:00

53 lines
1.9 KiB
C++

//===- SparseTensorInterfaces.cpp - SparseTensor interfaces impl ----------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
#include "mlir/Dialect/SparseTensor/IR/SparseTensorInterfaces.h"
#include "mlir/Dialect/SparseTensor/IR/SparseTensor.h"
#include "mlir/Dialect/SparseTensor/IR/SparseTensorType.h"
#include "mlir/IR/PatternMatch.h"
using namespace mlir;
using namespace mlir::sparse_tensor;
#include "mlir/Dialect/SparseTensor/IR/SparseTensorInterfaces.cpp.inc"
/// Shared rewrite for operations implementing `StageWithSortSparseOp`.
///
/// When `op.needsExtraSort()` is true, the operation is staged as:
///   1. a clone of `op` whose result #0 is retyped to the *unordered* COO type
///      derived from the original (final) result type;
///   2. a `ReorderCOOOp` (hybrid quick sort) producing the *ordered* COO type;
///   3. a `ConvertOp` to the final type, emitted only if the ordered COO type
///      differs from the final type.
/// The original `op` is then replaced by the last value of this chain.
///
/// Returns failure without touching the IR when `op.needsExtraSort()` is
/// false; returns success after the replacement otherwise.
LogicalResult
sparse_tensor::detail::stageWithSortImpl(StageWithSortSparseOp op,
                                         PatternRewriter &rewriter) {
  if (!op.needsExtraSort())
    return failure();

  Location loc = op.getLoc();
  // NOTE(review): only result #0 is staged; this assumes implementers have a
  // single ranked tensor result -- confirm for every op using this helper.
  Type finalTp = op->getOpResult(0).getType();
  // Use the free-function cast; the `Type::cast<>()` member form is deprecated.
  SparseTensorType dstStt(cast<RankedTensorType>(finalTp));
  Type srcCOOTp = dstStt.getCOOType(/*ordered=*/false);

  // Clones the original operation but changing the output to an unordered COO.
  Operation *cloned = rewriter.clone(*op.getOperation());
  rewriter.updateRootInPlace(cloned, [cloned, srcCOOTp]() {
    cloned->getOpResult(0).setType(srcCOOTp);
  });
  Value srcCOO = cloned->getOpResult(0);

  // -> sort: reorder the unordered COO into the ordered COO layout.
  Type dstCOOTp = dstStt.getCOOType(/*ordered=*/true);
  Value dstCOO = rewriter.create<ReorderCOOOp>(
      loc, dstCOOTp, srcCOO, SparseTensorSortKind::HybridQuickSort);

  // -> dest.
  if (dstCOO.getType() == finalTp) {
    // The sorted COO already has the requested type; use it directly.
    rewriter.replaceOp(op, dstCOO);
  } else {
    // Need an extra conversion if the target type is not COO.
    rewriter.replaceOpWithNewOp<ConvertOp>(op, finalTp, dstCOO);
  }
  // TODO: deallocate extra COOs, we should probably delegate it to buffer
  // deallocation pass.
  return success();
}