Files
clang-p2996/mlir/lib/Analysis/SliceAnalysis.cpp
Tres Popp 5550c82189 [mlir] Move casting calls from methods to function calls
The MLIR classes Type/Attribute/Operation/Op/Value support
cast/dyn_cast/isa/dyn_cast_or_null functionality through llvm's doCast
functionality in addition to defining methods with the same name.
This change begins the migration of uses of the method to the
corresponding function call as has been decided as more consistent.

Note that there still exist classes that only define methods directly,
such as AffineExpr, and this does not include work currently to support
a functional cast/isa call.

Caveats include:
- This clang-tidy script probably has more problems.
- This only touches C++ code, so nothing that is being generated.

Context:
- https://mlir.llvm.org/deprecation/ at "Use the free function variants
  for dyn_cast/cast/isa/…"
- Original discussion at https://discourse.llvm.org/t/preferred-casting-style-going-forward/68443

Implementation:
This first patch was created with the following steps. The intention is
to only do automated changes at first, so I waste less time if it's
reverted, and so the first mass change is more clear as an example to
other teams that will need to follow similar steps.

Steps are described per line, as comments are removed by git:
0. Retrieve the change from the following to build clang-tidy with an
   additional check:
   https://github.com/llvm/llvm-project/compare/main...tpopp:llvm-project:tidy-cast-check
1. Build clang-tidy
2. Run clang-tidy over your entire codebase with all checks disabled except
   the relevant one. Run it on all header files as well.
3. Delete .inc files that were also modified, so the next build rebuilds
   them to a pure state.
4. Some changes have been deleted for the following reasons:
   - Some files had a variable also named cast
   - Some files did not include a header file that defines the cast
     functions
   - Some files are definitions of the classes that have the casting
     methods, so the code still refers to the method instead of the
     function without adding a prefix or removing the method declaration
     at the same time.

```
ninja -C $BUILD_DIR clang-tidy

run-clang-tidy -clang-tidy-binary=$BUILD_DIR/bin/clang-tidy -checks='-*,misc-cast-functions'\
               -header-filter=mlir/ mlir/* -fix

rm -rf $BUILD_DIR/tools/mlir/**/*.inc

git restore mlir/lib/IR mlir/lib/Dialect/DLTI/DLTI.cpp\
            mlir/lib/Dialect/Complex/IR/ComplexDialect.cpp\
            mlir/lib/**/IR/\
            mlir/lib/Dialect/SparseTensor/Transforms/SparseVectorization.cpp\
            mlir/lib/Dialect/Vector/Transforms/LowerVectorMultiReduction.cpp\
            mlir/test/lib/Dialect/Test/TestTypes.cpp\
            mlir/test/lib/Dialect/Transform/TestTransformDialectExtension.cpp\
            mlir/test/lib/Dialect/Test/TestAttributes.cpp\
            mlir/unittests/TableGen/EnumsGenTest.cpp\
            mlir/test/python/lib/PythonTestCAPI.cpp\
            mlir/include/mlir/IR/
```

Differential Revision: https://reviews.llvm.org/D150123
2023-05-12 11:21:25 +02:00

321 lines
12 KiB
C++

//===- SliceAnalysis.cpp - Analysis for Transitive UseDef chains ----------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file implements Analysis functions specific to slicing in Function.
//
//===----------------------------------------------------------------------===//
#include "mlir/Analysis/SliceAnalysis.h"
#include "mlir/IR/BuiltinOps.h"
#include "mlir/IR/Operation.h"
#include "mlir/Interfaces/SideEffectInterfaces.h"
#include "mlir/Support/LLVM.h"
#include "llvm/ADT/SetVector.h"
#include "llvm/ADT/SmallPtrSet.h"
///
/// Implements Analysis functions specific to slicing in Function.
///
using namespace mlir;
/// Recursively adds to `forwardSlice` every operation reachable from `op`
/// through the operations nested in its regions and through the users of its
/// results. A branch stops at a null operation or at one rejected by `filter`.
/// Each operation is inserted only after the recursive visits of its nested
/// operations and users have returned (post-order); the public getForwardSlice
/// overloads reverse the set afterwards.
static void getForwardSliceImpl(Operation *op,
                                SetVector<Operation *> *forwardSlice,
                                TransitiveFilter filter) {
  // The filter implements scoping: it lets callers restrict the transitive
  // forward slice to the current scope.
  if (!op || (filter && !filter(op)))
    return;

  // Recurse into `next` unless it has already been collected.
  auto visit = [&](Operation *next) {
    if (!forwardSlice->count(next))
      getForwardSliceImpl(next, forwardSlice, filter);
  };

  for (Region &region : op->getRegions())
    for (Block &block : region)
      for (Operation &nested : block)
        visit(&nested);

  for (Value result : op->getResults())
    for (Operation *user : result.getUsers())
      visit(user);

  forwardSlice->insert(op);
}
/// Computes the forward slice of `op` into `forwardSlice`, in topological
/// order. `op` itself is kept only when `inclusive` is true.
void mlir::getForwardSlice(Operation *op, SetVector<Operation *> *forwardSlice,
                           TransitiveFilter filter, bool inclusive) {
  getForwardSliceImpl(op, forwardSlice, filter);

  // The caller queried on `op`; leave it out unless explicitly requested.
  if (!inclusive)
    forwardSlice->remove(op);

  // The traversal inserts in post-order. SetVector has no in-place reverse,
  // so drain it into a vector and re-insert from the back to obtain the
  // topological order.
  auto postOrder = forwardSlice->takeVector();
  forwardSlice->insert(postOrder.rbegin(), postOrder.rend());
}
/// Computes the forward slice of the value `root` into `forwardSlice`, in
/// topological order, seeding the traversal from every user of `root`.
/// `inclusive` is not read by this overload: the root is a value rather than
/// an operation, so there is no root operation to keep or drop.
void mlir::getForwardSlice(Value root, SetVector<Operation *> *forwardSlice,
                           TransitiveFilter filter, bool inclusive) {
  for (Operation *user : root.getUsers())
    getForwardSliceImpl(user, forwardSlice, filter);

  // The traversal inserts in post-order. SetVector has no in-place reverse,
  // so drain it into a vector and re-insert from the back to obtain the
  // topological order.
  auto postOrder = forwardSlice->takeVector();
  forwardSlice->insert(postOrder.rbegin(), postOrder.rend());
}
/// Recursively adds to `backwardSlice` the operations `op` transitively
/// depends on through its operands, then `op` itself (post-order). For an
/// operand that is a block argument, the walk continues into the operation
/// owning the argument's block.
static void getBackwardSliceImpl(Operation *op,
                                 SetVector<Operation *> *backwardSlice,
                                 TransitiveFilter filter) {
  // Null ops and ops with the IsIsolatedFromAbove trait end the walk and are
  // not added to the slice.
  if (!op || op->hasTrait<OpTrait::IsIsolatedFromAbove>())
    return;

  // The filter decides whether this def is kept; it lets callers return the
  // transitive backward slice in the current scope only.
  if (filter && !filter(op))
    return;

  for (Value operand : op->getOperands()) {
    if (Operation *producer = operand.getDefiningOp()) {
      if (!backwardSlice->count(producer))
        getBackwardSliceImpl(producer, backwardSlice, filter);
    } else if (auto blockArg = dyn_cast<BlockArgument>(operand)) {
      Operation *owner = blockArg.getOwner()->getParentOp();
      // TODO: determine whether we want to recurse backward into the other
      // blocks of the owning op, which are not technically backward unless
      // they flow into us. For now, just bail.
      if (owner && !backwardSlice->count(owner)) {
        assert(owner->getNumRegions() == 1 &&
               owner->getRegion(0).getBlocks().size() == 1);
        getBackwardSliceImpl(owner, backwardSlice, filter);
      }
    } else {
      llvm_unreachable("No definingOp and not a block argument.");
    }
  }

  // Recorded only after the recursive visits of its dependencies return.
  backwardSlice->insert(op);
}
/// Computes the backward slice of `op` into `backwardSlice`. `op` itself is
/// kept only when `inclusive` is true.
void mlir::getBackwardSlice(Operation *op,
                            SetVector<Operation *> *backwardSlice,
                            TransitiveFilter filter, bool inclusive) {
  getBackwardSliceImpl(op, backwardSlice, filter);
  if (inclusive)
    return;
  // The caller queried on `op`; drop it from the results.
  backwardSlice->remove(op);
}
/// Computes the backward slice of the value `root` into `backwardSlice`.
/// The slice is taken from the defining operation of `root`. If `root` is a
/// block argument, it is taken from the operation owning the argument's block.
void mlir::getBackwardSlice(Value root, SetVector<Operation *> *backwardSlice,
                            TransitiveFilter filter, bool inclusive) {
  Operation *sliceRoot = root.getDefiningOp();
  if (!sliceRoot)
    sliceRoot = cast<BlockArgument>(root).getOwner()->getParentOp();
  getBackwardSlice(sliceRoot, backwardSlice, filter, inclusive);
}
/// Computes the closure of `op` under repeated backward and forward slicing
/// and returns it topologically sorted. `slice` doubles as the worklist: each
/// entry is expanded once, and newly discovered operations are appended.
SetVector<Operation *> mlir::getSlice(Operation *op,
                                      TransitiveFilter backwardFilter,
                                      TransitiveFilter forwardFilter,
                                      bool inclusive) {
  SetVector<Operation *> slice;
  slice.insert(op);

  // Scratch sets, cleared and reused for every expanded operation.
  SetVector<Operation *> backwardSlice;
  SetVector<Operation *> forwardSlice;
  for (unsigned idx = 0; idx != slice.size(); ++idx) {
    Operation *currentOp = slice[idx];

    backwardSlice.clear();
    getBackwardSlice(currentOp, &backwardSlice, backwardFilter, inclusive);
    slice.insert(backwardSlice.begin(), backwardSlice.end());

    forwardSlice.clear();
    getForwardSlice(currentOp, &forwardSlice, forwardFilter, inclusive);
    slice.insert(forwardSlice.begin(), forwardSlice.end());
  }
  return topologicalSort(slice);
}
namespace {
/// State shared by the successive dfsPostorder calls issued by topologicalSort.
/// Sharing it lets a single order be built over DAGs with several roots. The
/// traversal may visit operations outside `toSort`, but only members of
/// `toSort` are recorded.
struct DFSState {
  explicit DFSState(const SetVector<Operation *> &set) : toSort(set) {}

  /// Operations to order. Held by reference, so the set must outlive the state.
  const SetVector<Operation *> &toSort;
  /// Members of `toSort` in the order the traversal recorded them. Despite the
  /// name, this holds operations, not counts.
  SmallVector<Operation *, 16> topologicalCounts;
  /// Operations already visited. It is shared across roots so that an
  /// operation is recorded at most once.
  DenseSet<Operation *> seen;
};
} // namespace
/// Iterative depth-first traversal from `root` that records every reachable,
/// not-yet-recorded member of `state->toSort` in post-order.
///
/// An operation's successors are the users of its results and the operations
/// immediately nested in its regions. On acyclic graphs, each operation is
/// therefore recorded after every recorded operation reachable from it, which
/// gives reverse topological order.
///
/// `state->seen` is shared by all calls and is checked before an operation is
/// expanded. Each operation is thus expanded at most once across all roots, and
/// the total work is linear in the reachable operations and edges. Skipping
/// that check would re-expand an operation once per path reaching it. That is
/// exponential on chains of diamond-shaped def-use graphs and does not
/// terminate on use cycles in graph regions.
static void dfsPostorder(Operation *root, DFSState *state) {
  // Each stack entry pairs an operation with a flag telling whether its
  // successors have already been pushed. The flagged entry sits below those
  // successors, so it is popped only after all of them are finished.
  SmallVector<std::pair<Operation *, bool>> stack;
  SmallVector<Operation *> successors;
  stack.emplace_back(root, false);
  while (!stack.empty()) {
    auto [current, expanded] = stack.pop_back_val();
    if (expanded) {
      // Every successor is finished: record `current` in post-order.
      if (state->toSort.count(current) > 0)
        state->topologicalCounts.push_back(current);
      continue;
    }
    // Skip operations already expanded by this or an earlier traversal.
    if (!state->seen.insert(current).second)
      continue;
    stack.emplace_back(current, true);

    successors.clear();
    for (Operation *user : current->getUsers())
      successors.push_back(user);
    for (Region &region : current->getRegions())
      for (Operation &nested : region.getOps())
        successors.push_back(&nested);
    // Push in reverse so that successors are expanded in gathering order.
    for (Operation *succ : llvm::reverse(successors))
      if (state->seen.count(succ) == 0)
        stack.emplace_back(succ, false);
  }
}
/// Returns the operations of `toSort` in a topological order with respect to
/// use-def and region-nesting edges. An empty input is returned unchanged.
SetVector<Operation *>
mlir::topologicalSort(const SetVector<Operation *> &toSort) {
  if (toSort.empty())
    return toSort;

  // Traverse from every member. `state` carries the visited set and the
  // recorded operations from one traversal to the next.
  DFSState state(toSort);
  for (Operation *op : toSort) {
    assert(toSort.count(op) == 1 && "NYI: multi-sets not supported");
    dfsPostorder(op, &state);
  }

  // The traversal records operations after the ones that depend on them, so
  // emitting the recorded list back to front yields the topological order.
  SetVector<Operation *> sorted;
  sorted.insert(state.topologicalCounts.rbegin(),
                state.topologicalCounts.rend());
  return sorted;
}
/// Returns true if `value` (transitively) depends on iteration-carried values
/// of the given `ancestorOp`. That is, returns true if `value` is itself one of
/// `iterCarriedArgs`, or if an operation in its filtered backward slice has
/// one of them as an operand.
static bool dependsOnCarriedVals(Value value,
                                 ArrayRef<BlockArgument> iterCarriedArgs,
                                 Operation *ancestorOp) {
  SmallPtrSet<Value, 8> iterCarriedValSet(iterCarriedArgs.begin(),
                                          iterCarriedArgs.end());
  // Cheap direct check first: no slice is needed when `value` is itself an
  // iteration-carried argument.
  if (iterCarriedValSet.contains(value))
    return true;

  // Compute the backward slice of the value.
  // NOTE(review): this filter keeps only operations that are NOT nested in
  // `ancestorOp`. The operands of such operations cannot be block arguments of
  // `ancestorOp`'s regions, which makes the operand check below look
  // ineffective. Confirm whether `ancestorOp->isProperAncestor(op)` was
  // intended.
  SetVector<Operation *> slice;
  getBackwardSlice(value, &slice,
                   [&](Operation *op) { return !ancestorOp->isAncestor(op); });

  // Check that none of the operands of the operations in the backward slice
  // are loop iteration arguments.
  for (Operation *op : slice)
    for (Value operand : op->getOperands())
      if (iterCarriedValSet.contains(operand))
        return true;
  return false;
}
/// Utility to match a generic reduction given a list of iteration-carried
/// arguments, `iterCarriedArgs` and the position of the potential reduction
/// argument within the list, `redPos`. If a reduction is matched, returns the
/// reduced value and the topologically-sorted list of combiner operations
/// involved in the reduction. Otherwise, returns a null value; in that case
/// `combinerOps` may have been partially populated.
///
/// The matching algorithm relies on the following invariants, which are subject
/// to change:
///  1. The first combiner operation must be a binary operation with the
///     iteration-carried value and the reduced value as operands.
///  2. The iteration-carried value and combiner operations must be side
///     effect-free, have single result and a single use.
///  3. Combiner operations must be immediately nested in the region op
///     performing the reduction.
///  4. Reduction def-use chain must end in the terminator of the region op
///     performing the reduction, which yields the next iteration/output values
///     in the same order as the iteration-carried values in `iterCarriedArgs`.
///  5. `iterCarriedArgs` must contain all the iteration-carried/output values
///     of the region op performing the reduction.
///
/// This utility is generic enough to detect reductions involving multiple
/// combiner operations (disabled for now) across multiple dialects, including
/// Linalg, Affine and SCF. For the sake of genericity, it does not return
/// specific enum values for the combiner operations since its goal is also
/// matching reductions without pre-defined semantics in core MLIR. It's up to
/// each client to make sense out of the list of combiner operations. It's also
/// up to each client to check for additional invariants on the expected
/// reductions not covered by this generic matching.
Value mlir::matchReduction(ArrayRef<BlockArgument> iterCarriedArgs,
                           unsigned redPos,
                           SmallVectorImpl<Operation *> &combinerOps) {
  assert(redPos < iterCarriedArgs.size() && "'redPos' is out of bounds");

  BlockArgument redCarriedVal = iterCarriedArgs[redPos];
  if (!redCarriedVal.hasOneUse())
    return nullptr;

  // For now, the first combiner op must be a binary op.
  Operation *combinerOp = *redCarriedVal.getUsers().begin();
  if (combinerOp->getNumOperands() != 2)
    return nullptr;
  Value reducedVal = combinerOp->getOperand(0) == redCarriedVal
                         ? combinerOp->getOperand(1)
                         : combinerOp->getOperand(0);

  Operation *redRegionOp =
      iterCarriedArgs.front().getOwner()->getParent()->getParentOp();
  if (dependsOnCarriedVals(reducedVal, iterCarriedArgs, redRegionOp))
    return nullptr;

  // Traverse the def-use chain starting from the first combiner op until a
  // terminator is found. Gather all the combiner ops along the way in
  // topological order.
  while (!combinerOp->mightHaveTrait<OpTrait::IsTerminator>()) {
    if (!isMemoryEffectFree(combinerOp) || combinerOp->getNumResults() != 1 ||
        !combinerOp->hasOneUse() || combinerOp->getParentOp() != redRegionOp)
      return nullptr;

    combinerOps.push_back(combinerOp);
    combinerOp = *combinerOp->getUsers().begin();
  }

  // Limit matching to single combiner op until we can properly test reductions
  // involving multiple combiners.
  if (combinerOps.size() != 1)
    return nullptr;

  // Invariant 4: the chain must end in the terminator of the reduction region
  // op itself. A terminator of a nested region (e.g. a yield inside a
  // conditional in the loop body) makes the update conditional. Its operand
  // list also need not line up with `iterCarriedArgs`, so indexing it with
  // `redPos` could go out of bounds.
  Operation *terminatorOp = combinerOp;
  if (terminatorOp->getParentOp() != redRegionOp ||
      redPos >= terminatorOp->getNumOperands())
    return nullptr;

  // Check that the yielded value is in the same position as in
  // `iterCarriedArgs`.
  if (terminatorOp->getOperand(redPos) != combinerOps.back()->getResults()[0])
    return nullptr;

  return reducedVal;
}