This utility factors out the machinery required to add iterArgs and yield values to an scf.ForOp. Differential Revision: https://reviews.llvm.org/D80656
74 lines
3.0 KiB
C++
74 lines
3.0 KiB
C++
//===- LoopUtils.cpp ---- Misc utilities for loop transformation ----------===//
|
|
//
|
|
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
|
// See https://llvm.org/LICENSE.txt for license information.
|
|
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
|
//
|
|
//===----------------------------------------------------------------------===//
|
|
//
|
|
// This file implements miscellaneous loop transformation routines.
|
|
//
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
#include "mlir/Dialect/SCF/Utils.h"
|
|
|
|
#include "mlir/Dialect/SCF/SCF.h"
|
|
#include "mlir/IR/BlockAndValueMapping.h"
|
|
|
|
using namespace mlir;
|
|
|
|
// Creates, immediately before `loop`, a new scf.for with the same bounds and
// step whose iter operands are `loop`'s iter operands followed by
// `newIterOperands`, and whose body yields `loop`'s yielded values followed by
// `newYieldedValues`. Returns the new loop.
//
// The body of `loop` is cloned into the new loop. Inside the clone (including
// the appended `newYieldedValues`), every use of a value from
// `newIterOperands` is remapped to the corresponding new region iter arg.
//
// If `replaceLoopResults` is true, all uses of `loop`'s results are replaced by
// the leading results of the new loop. `loop` itself is neither modified nor
// erased here; erasing it is left to the caller. The builder's insertion point
// is restored on return.
scf::ForOp mlir::cloneWithNewYields(OpBuilder &b, scf::ForOp loop,
                                    ValueRange newIterOperands,
                                    ValueRange newYieldedValues,
                                    bool replaceLoopResults) {
  assert(newIterOperands.size() == newYieldedValues.size() &&
         "newIterOperands must be of the same size as newYieldedValues");

  // Create a new loop before the existing one, with the extra operands.
  OpBuilder::InsertionGuard g(b);
  b.setInsertionPoint(loop);
  auto operands = llvm::to_vector<4>(loop.getIterOperands());
  operands.append(newIterOperands.begin(), newIterOperands.end());
  scf::ForOp newLoop =
      b.create<scf::ForOp>(loop.getLoc(), loop.lowerBound(), loop.upperBound(),
                           loop.step(), operands);

  auto &loopBody = *loop.getBody();
  auto &newLoopBody = *newLoop.getBody();

  // The op builder may already have placed an implicit terminator in the fresh
  // body (presumably when there are no iter operands at all). The body is
  // populated entirely below, including its own scf.yield, so drop anything
  // pre-existing; otherwise the block would end up with two terminators.
  while (!newLoopBody.empty())
    newLoopBody.back().erase();

  // Map the original loop's values onto the new loop's.
  BlockAndValueMapping bvm;
  // a. remap the induction variable.
  bvm.map(loop.getInductionVar(), newLoop.getInductionVar());
  // b. remap the BB args positionally.
  bvm.map(loopBody.getArguments(),
          newLoopBody.getArguments().take_front(loopBody.getNumArguments()));
  // c. remap the new init values to the new region iter args, so the cloned
  //    body and the appended yielded values observe the loop-carried value.
  bvm.map(newIterOperands,
          newLoop.getRegionIterArgs().take_back(newIterOperands.size()));

  // Clone every op except the original terminator, which lacks the extra
  // operands.
  b.setInsertionPointToEnd(&newLoopBody);
  for (auto &o : loopBody.without_terminator())
    b.clone(o, bvm);

  // Build the augmented terminator directly in the new body, passing its
  // operands through the same mapping the cloned ops used. Doing it here,
  // instead of creating a temporary yield inside `loop` and erasing it
  // afterwards, means the original loop's body is never modified.
  auto yield = cast<scf::YieldOp>(loopBody.getTerminator());
  auto yieldOperands = llvm::to_vector<4>(yield.getOperands());
  yieldOperands.append(newYieldedValues.begin(), newYieldedValues.end());
  for (auto &operand : yieldOperands)
    operand = bvm.lookupOrDefault(operand);
  b.create<scf::YieldOp>(yield.getLoc(), yieldOperands);

  // Replace `loop`'s results if requested.
  // NOTE(review): these replacements act directly on the values, not through
  // `b`, so a PatternRewriter passed as `b` is not notified of them.
  if (replaceLoopResults) {
    for (auto it : llvm::zip(loop.getResults(), newLoop.getResults().take_front(
                                                    loop.getNumResults())))
      std::get<0>(it).replaceAllUsesWith(std::get<1>(it));
  }

  return newLoop;
}
|