Files
clang-p2996/mlir/test/lib/Transforms/TestTopologicalSort.cpp
Matthias Springer 31fbdab376 [mlir][transforms] Add topological sort analysis
This change adds a helper function for computing a topological sorting of a list of ops. For example, this can be useful in transforms where a subset of ops should be cloned without dominance errors.

The analysis reuses the existing implementation in TopologicalSortUtils.cpp.

Differential Revision: https://reviews.llvm.org/D131669
2022-08-15 21:09:18 +02:00

63 lines
1.9 KiB
C++

//===- TestTopologicalSort.cpp - Pass to test topological sort analysis ---===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
#include "mlir/IR/Builders.h"
#include "mlir/IR/BuiltinOps.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Transforms/TopologicalSortUtils.h"
using namespace mlir;
namespace {
/// Test-only pass that exercises `computeTopologicalSorting`.
///
/// For every op visited while walking the module that carries a "root"
/// attribute, the pass gathers all ops found by walking the root's single
/// block that carry a "selected" attribute, hands them to
/// `computeTopologicalSorting`, and then tags each gathered op with an index
/// attribute named "pos" holding its position in the vector after that call.
struct TestTopologicalSortAnalysisPass
    : public PassWrapper<TestTopologicalSortAnalysisPass,
                         OperationPass<ModuleOp>> {
  MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestTopologicalSortAnalysisPass)

  /// Argument string identifying this pass.
  StringRef getArgument() const final {
    return "test-topological-sort-analysis";
  }

  /// Short human-readable summary of the pass.
  StringRef getDescription() const final {
    return "Test topological sorting of ops";
  }

  /// Annotates the "selected" ops under each "root" op with their "pos".
  void runOnOperation() override {
    Operation *moduleOp = getOperation();
    OpBuilder attrBuilder(moduleOp->getContext());

    moduleOp->walk([&](Operation *root) {
      // Only ops explicitly marked as roots are processed; every visited op,
      // root or not, lets the walk continue.
      if (root->hasAttr("root")) {
        assert(root->getNumRegions() == 1 &&
               root->getRegion(0).hasOneBlock() && "expected one block");
        Block *body = &root->getRegion(0).front();

        // Gather the ops to be ordered, in walk order.
        SmallVector<Operation *> picked;
        body->walk([&](Operation *nested) {
          if (nested->hasAttr("selected"))
            picked.push_back(nested);
        });

        // NOTE(review): whatever this call returns is ignored here; the
        // vector is presumably reordered in place -- confirm against the
        // utility's declaration.
        computeTopologicalSorting(body, picked);

        // Record each op's final position in the vector on the op itself.
        for (size_t pos = 0, count = picked.size(); pos < count; ++pos)
          picked[pos]->setAttr("pos", attrBuilder.getIndexAttr(pos));
      }
      return WalkResult::advance();
    });
  }
};
} // namespace
namespace mlir {
namespace test {
void registerTestTopologicalSortAnalysisPass() {
PassRegistration<TestTopologicalSortAnalysisPass>();
}
} // namespace test
} // namespace mlir