Files
clang-p2996/mlir/lib/Dialect/SCF/Transforms/ForallToFor.cpp
Jorn Tuyls 286bd42a7a [mlir] Extract forall_to_for logic into reusable function and add pass (#89636)
This PR extracts the existing `scf.forall` to `scf.for` conversion logic
inside a transform op (https://github.com/llvm/llvm-project/pull/65474)
into a standalone function which can be used in other transformations
and adds a `scf-forall-to-for` pass.
2024-04-24 09:57:48 -07:00

80 lines
2.6 KiB
C++

//===- ForallToFor.cpp - scf.forall to scf.for loop conversion ------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// Transforms SCF.ForallOp's into SCF.ForOp's.
//
//===----------------------------------------------------------------------===//
#include "mlir/Dialect/SCF/Transforms/Passes.h"
#include "mlir/Dialect/SCF/IR/SCF.h"
#include "mlir/Dialect/SCF/Transforms/Transforms.h"
#include "mlir/IR/PatternMatch.h"
namespace mlir {
#define GEN_PASS_DEF_SCFFORALLTOFORLOOP
#include "mlir/Dialect/SCF/Transforms/Passes.h.inc"
} // namespace mlir
using namespace llvm;
using namespace mlir;
using scf::ForallOp;
using scf::ForOp;
using scf::LoopNest;
/// Rewrites `forallOp` into a nest of `scf.for` loops, one loop per
/// (lower bound, upper bound, step) triple, and moves the body of the
/// `scf.forall` into the innermost loop.
///
/// \param rewriter  Rewriter used to create the loops and to erase `forallOp`.
///                  Its insertion point is restored on return.
/// \param forallOp  The `scf.forall` to convert. It is erased on success.
/// \param results   Optional. When non-null, the created `scf.for` operations
///                  are appended to it in the order produced by
///                  `scf::buildLoopNest` (the last one is the innermost loop).
///
/// \returns failure, leaving the IR untouched, when `forallOp` has shared
/// outputs; success otherwise.
LogicalResult
mlir::scf::forallToForLoop(RewriterBase &rewriter, scf::ForallOp forallOp,
                           SmallVectorImpl<Operation *> *results) {
  // Shared outputs contribute extra body block arguments and op results that
  // the loop nest built below does not model (it has no iter_args). Bail out
  // before any IR is created rather than tripping assertions while inlining
  // the body or erasing an op that still has uses.
  if (!forallOp.getOutputs().empty())
    return rewriter.notifyMatchFailure(
        forallOp, "scf.forall with shared outputs is not supported");

  OpBuilder::InsertionGuard guard(rewriter);
  rewriter.setInsertionPoint(forallOp);

  Location loc = forallOp.getLoc();
  // Static bounds/steps are materialized as constant index values.
  SmallVector<Value> lbs = getValueOrCreateConstantIndexOp(
      rewriter, loc, forallOp.getMixedLowerBound());
  SmallVector<Value> ubs = getValueOrCreateConstantIndexOp(
      rewriter, loc, forallOp.getMixedUpperBound());
  SmallVector<Value> steps =
      getValueOrCreateConstantIndexOp(rewriter, loc, forallOp.getMixedStep());
  LoopNest loopNest = scf::buildLoopNest(rewriter, loc, lbs, ubs, steps);

  // The forall terminator has no counterpart in the sequential form; drop it
  // so that only the payload operations are moved.
  Block *forallBody = forallOp.getBody();
  rewriter.eraseOp(forallBody->getTerminator());

  if (loopNest.loops.empty()) {
    // No loop dimensions means no loops were built: the body runs exactly
    // once, so splice it in place of the forall instead of calling `back()`
    // on an empty loop list.
    rewriter.inlineBlockBefore(forallBody, forallOp.getOperation(),
                               ValueRange());
  } else {
    // The induction variables of the new loops replace the forall body's
    // block arguments.
    SmallVector<Value> ivs = llvm::map_to_vector(
        loopNest.loops, [](scf::ForOp loop) { return loop.getInductionVar(); });

    Block *innermostBlock = loopNest.loops.back().getBody();
    rewriter.inlineBlockBefore(forallBody, innermostBlock,
                               innermostBlock->getTerminator()->getIterator(),
                               ivs);
  }
  rewriter.eraseOp(forallOp);

  if (results) {
    llvm::move(loopNest.loops, std::back_inserter(*results));
  }

  return success();
}
namespace {
/// Pass that converts every `scf.forall` nested under the operation it runs on
/// into `scf.for` loops via `scf::forallToForLoop`.
struct ForallToForLoop : public impl::SCFForallToForLoopBase<ForallToForLoop> {
  void runOnOperation() override {
    Operation *parentOp = getOperation();
    IRRewriter rewriter(parentOp->getContext());

    // Stop at the first op that cannot be converted: once the pass has
    // failed there is no point in rewriting further ops, and the failure
    // only needs to be signalled once.
    WalkResult status = parentOp->walk([&](scf::ForallOp forallOp) {
      if (failed(scf::forallToForLoop(rewriter, forallOp)))
        return WalkResult::interrupt();
      return WalkResult::advance();
    });
    if (status.wasInterrupted())
      signalPassFailure();
  }
};
} // namespace
std::unique_ptr<Pass> mlir::createForallToForLoopPass() {
return std::make_unique<ForallToForLoop>();
}