This PR attempts to consolidate the different topological sort utilities into one place. It adds them to the `Analysis` folder because `SliceAnalysis` uses some of them. There are now two sorting strategies: (1) sorting only according to SSA use-def chains, and (2) sorting while taking regions into account. The latter requires a much more elaborate traversal and cannot easily be applied to graph regions. This PR additionally reimplements the region-aware topological sort, because the previous implementation had exponential space complexity. I'm open to suggestions on how to consolidate these utilities further or how to fuse the test passes.
86 lines
2.7 KiB
C++
86 lines
2.7 KiB
C++
//===- TestTopologicalSort.cpp - Pass to test topological sort analysis ---===//
|
|
//
|
|
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
|
// See https://llvm.org/LICENSE.txt for license information.
|
|
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
|
//
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
#include "mlir/Analysis/TopologicalSortUtils.h"
|
|
#include "mlir/IR/Builders.h"
|
|
#include "mlir/IR/BuiltinOps.h"
|
|
#include "mlir/Pass/Pass.h"
|
|
|
|
using namespace mlir;
|
|
|
|
namespace {
|
|
struct TestTopologicalSortAnalysisPass
|
|
: public PassWrapper<TestTopologicalSortAnalysisPass,
|
|
OperationPass<ModuleOp>> {
|
|
MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestTopologicalSortAnalysisPass)
|
|
|
|
StringRef getArgument() const final {
|
|
return "test-topological-sort-analysis";
|
|
}
|
|
StringRef getDescription() const final {
|
|
return "Test topological sorting of ops";
|
|
}
|
|
|
|
void runOnOperation() override {
|
|
Operation *op = getOperation();
|
|
OpBuilder builder(op->getContext());
|
|
|
|
WalkResult result = op->walk([&](Operation *root) {
|
|
if (!root->hasAttr("root"))
|
|
return WalkResult::advance();
|
|
|
|
SmallVector<Operation *> selectedOps;
|
|
root->walk([&](Operation *selected) {
|
|
if (!selected->hasAttr("selected"))
|
|
return WalkResult::advance();
|
|
if (root->hasAttr("ordered")) {
|
|
// If the root has an "ordered" attribute, we fill the selectedOps
|
|
// vector in a certain order.
|
|
int64_t pos =
|
|
cast<IntegerAttr>(selected->getDiscardableAttr("selected"))
|
|
.getInt();
|
|
if (pos >= static_cast<int64_t>(selectedOps.size()))
|
|
selectedOps.append(pos + 1 - selectedOps.size(), nullptr);
|
|
selectedOps[pos] = selected;
|
|
} else {
|
|
selectedOps.push_back(selected);
|
|
}
|
|
return WalkResult::advance();
|
|
});
|
|
|
|
if (llvm::find(selectedOps, nullptr) != selectedOps.end()) {
|
|
root->emitError("invalid test case: some indices are missing among the "
|
|
"selected ops");
|
|
return WalkResult::skip();
|
|
}
|
|
|
|
if (!computeTopologicalSorting(selectedOps)) {
|
|
root->emitError("could not schedule all ops");
|
|
return WalkResult::skip();
|
|
}
|
|
|
|
for (const auto &it : llvm::enumerate(selectedOps))
|
|
it.value()->setAttr("pos", builder.getIndexAttr(it.index()));
|
|
|
|
return WalkResult::advance();
|
|
});
|
|
|
|
if (result.wasSkipped())
|
|
signalPassFailure();
|
|
}
|
|
};
|
|
} // namespace
|
|
|
|
namespace mlir {
|
|
namespace test {
|
|
void registerTestTopologicalSortAnalysisPass() {
|
|
PassRegistration<TestTopologicalSortAnalysisPass>();
|
|
}
|
|
} // namespace test
|
|
} // namespace mlir
|