This dialect contains various structured control flow operations, not only loops; reflect this in the name. Drop the Ops suffix for consistency with other dialects. Note that this only moves the files and changes the C++ namespace from 'loop' to 'scf'. The visible IR prefix remains the same and will be updated separately. The conversions will also be updated separately. Differential Revision: https://reviews.llvm.org/D79578
69 lines
2.1 KiB
C++
69 lines
2.1 KiB
C++
//===-------- TestLoopUnrolling.cpp --- loop unrolling test pass ----------===//
|
|
//
|
|
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
|
// See https://llvm.org/LICENSE.txt for license information.
|
|
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
|
//
|
|
//===----------------------------------------------------------------------===//
|
|
//
|
|
// This file implements a pass to unroll loops by a specified unroll factor.
|
|
//
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
#include "mlir/Dialect/SCF/SCF.h"
|
|
#include "mlir/IR/Builders.h"
|
|
#include "mlir/Pass/Pass.h"
|
|
#include "mlir/Transforms/LoopUtils.h"
|
|
#include "mlir/Transforms/Passes.h"
|
|
|
|
using namespace mlir;
|
|
|
|
namespace {
|
|
|
|
/// Counts the `scf::ForOp` operations that enclose `op`.
///
/// Only strict ancestors are examined: the walk starts at `op`'s parent and
/// climbs until getParentOp() yields null, so `op` itself is never counted.
static unsigned getNestingDepth(Operation *op) {
  unsigned numEnclosingLoops = 0;
  for (Operation *ancestor = op->getParentOp(); ancestor != nullptr;
       ancestor = ancestor->getParentOp()) {
    if (isa<scf::ForOp>(ancestor))
      ++numEnclosingLoops;
  }
  return numEnclosingLoops;
}
|
|
|
|
class TestLoopUnrollingPass
|
|
: public PassWrapper<TestLoopUnrollingPass, FunctionPass> {
|
|
public:
|
|
TestLoopUnrollingPass() = default;
|
|
TestLoopUnrollingPass(const TestLoopUnrollingPass &) {}
|
|
explicit TestLoopUnrollingPass(uint64_t unrollFactorParam,
|
|
unsigned loopDepthParam) {
|
|
unrollFactor = unrollFactorParam;
|
|
loopDepth = loopDepthParam;
|
|
}
|
|
|
|
void runOnFunction() override {
|
|
FuncOp func = getFunction();
|
|
SmallVector<scf::ForOp, 4> loops;
|
|
func.walk([&](scf::ForOp forOp) {
|
|
if (getNestingDepth(forOp) == loopDepth)
|
|
loops.push_back(forOp);
|
|
});
|
|
for (auto loop : loops) {
|
|
loopUnrollByFactor(loop, unrollFactor);
|
|
}
|
|
}
|
|
Option<uint64_t> unrollFactor{*this, "unroll-factor",
|
|
llvm::cl::desc("Loop unroll factor."),
|
|
llvm::cl::init(1)};
|
|
Option<unsigned> loopDepth{*this, "loop-depth", llvm::cl::desc("Loop depth."),
|
|
llvm::cl::init(0)};
|
|
};
|
|
} // end namespace
|
|
|
|
namespace mlir {
|
|
void registerTestLoopUnrollingPass() {
|
|
PassRegistration<TestLoopUnrollingPass>(
|
|
"test-loop-unrolling", "Tests loop unrolling transformation");
|
|
}
|
|
} // namespace mlir
|