Files
clang-p2996/mlir/lib/Dialect/SparseTensor/IR/SparseTensorInterfaces.cpp
Matthias Springer 5fcf907b34 [mlir][IR] Rename "update root" to "modify op" in rewriter API (#78260)
This commit renames 4 pattern rewriter API functions:
* `updateRootInPlace` -> `modifyOpInPlace`
* `startRootUpdate` -> `startOpModification`
* `finalizeRootUpdate` -> `finalizeOpModification`
* `cancelRootUpdate` -> `cancelOpModification`

The term "root" is a misnomer. The root is the op that a rewrite pattern
matches against
(https://mlir.llvm.org/docs/PatternRewriter/#root-operation-name-optional).
A rewriter must be notified of all in-place op modifications, not just
in-place modifications of the root
(https://mlir.llvm.org/docs/PatternRewriter/#pattern-rewriter). The old
function names were confusing and have contributed to various broken
rewrite patterns.

Note: The new function names use the term "modify" instead of "update"
for consistency with the `RewriterBase::Listener` terminology
(`notifyOperationModified`).
2024-01-17 11:08:59 +01:00

53 lines
1.9 KiB
C++

//===- SparseTensorInterfaces.cpp - SparseTensor interfaces impl ----------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
#include "mlir/Dialect/SparseTensor/IR/SparseTensorInterfaces.h"
#include "mlir/Dialect/SparseTensor/IR/SparseTensor.h"
#include "mlir/Dialect/SparseTensor/IR/SparseTensorType.h"
#include "mlir/IR/PatternMatch.h"
using namespace mlir;
using namespace mlir::sparse_tensor;
#include "mlir/Dialect/SparseTensor/IR/SparseTensorInterfaces.cpp.inc"
/// Stages `op` into an explicit "produce unordered COO, then sort" sequence.
///
/// When `op.needsExtraSort()` is false the IR is left untouched and
/// `failure()` is returned. Otherwise `op` is replaced by:
///   1. a clone of `op` whose result 0 is retyped to the unordered COO type
///      derived from the original result type,
///   2. a `ReorderCOOOp` (hybrid quick sort) producing the ordered COO type,
///   3. a `ConvertOp` back to the original result type, emitted only when the
///      ordered COO type differs from that original type,
/// and `success()` is returned.
LogicalResult
sparse_tensor::detail::stageWithSortImpl(StageWithSortSparseOp op,
                                         PatternRewriter &rewriter) {
  if (!op.needsExtraSort())
    return failure();

  Location loc = op.getLoc();
  Type resultTp = op->getOpResult(0).getType();
  SparseTensorType resultStt(resultTp.cast<RankedTensorType>());
  Type unorderedCOOTp = resultStt.getCOOType(/*ordered=*/false);
  Type orderedCOOTp = resultStt.getCOOType(/*ordered=*/true);

  // Step 1: duplicate the original operation and retype its result so that it
  // yields an unordered COO. The retyping goes through the rewriter so that
  // the in-place modification is reported.
  Operation *unsortedOp = rewriter.clone(*op.getOperation());
  rewriter.modifyOpInPlace(
      unsortedOp, [&]() { unsortedOp->getOpResult(0).setType(unorderedCOOTp); });
  Value unorderedCOO = unsortedOp->getOpResult(0);

  // Step 2: sort the unordered COO into an ordered one.
  Value orderedCOO = rewriter.create<ReorderCOOOp>(
      loc, orderedCOOTp, unorderedCOO, SparseTensorSortKind::HybridQuickSort);

  // TODO: deallocate extra COOs, we should probably delegate it to buffer
  // deallocation pass.

  // Step 3: hand the result back in the type the original op promised.
  if (orderedCOO.getType() != resultTp) {
    // The target type is not the ordered COO itself: convert explicitly.
    rewriter.replaceOpWithNewOp<ConvertOp>(op, resultTp, orderedCOO);
    return success();
  }
  rewriter.replaceOp(op, orderedCOO);
  return success();
}