//===- LoopUtils.cpp ---- Misc utilities for loop transformation ----------===// // // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. // See https://llvm.org/LICENSE.txt for license information. // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // //===----------------------------------------------------------------------===// // // This file implements miscellaneous loop transformation routines. // //===----------------------------------------------------------------------===// #include "mlir/Dialect/SCF/Utils.h" #include "mlir/Dialect/SCF/SCF.h" #include "mlir/IR/BlockAndValueMapping.h" using namespace mlir; scf::ForOp mlir::cloneWithNewYields(OpBuilder &b, scf::ForOp loop, ValueRange newIterOperands, ValueRange newYieldedValues, bool replaceLoopResults) { assert(newIterOperands.size() == newYieldedValues.size() && "newIterOperands must be of the same size as newYieldedValues"); // Create a new loop before the existing one, with the extra operands. OpBuilder::InsertionGuard g(b); b.setInsertionPoint(loop); auto operands = llvm::to_vector<4>(loop.getIterOperands()); operands.append(newIterOperands.begin(), newIterOperands.end()); scf::ForOp newLoop = b.create(loop.getLoc(), loop.lowerBound(), loop.upperBound(), loop.step(), operands); auto &loopBody = *loop.getBody(); auto &newLoopBody = *newLoop.getBody(); // Clone / erase the yield inside the original loop to both: // 1. augment its operands with the newYieldedValues. // 2. automatically apply the BlockAndValueMapping on its operand auto yield = cast(loopBody.getTerminator()); b.setInsertionPoint(yield); auto yieldOperands = llvm::to_vector<4>(yield.getOperands()); yieldOperands.append(newYieldedValues.begin(), newYieldedValues.end()); auto newYield = b.create(yield.getLoc(), yieldOperands); // Clone the loop body with remaps. BlockAndValueMapping bvm; // a. remap the induction variable. bvm.map(loop.getInductionVar(), newLoop.getInductionVar()); // b. remap the BB args. bvm.map(loopBody.getArguments(), newLoopBody.getArguments().take_front(loopBody.getNumArguments())); // c. remap the iter args. bvm.map(newIterOperands, newLoop.getRegionIterArgs().take_back(newIterOperands.size())); b.setInsertionPointToStart(&newLoopBody); // Skip the original yield terminator which does not have enough operands. for (auto &o : loopBody.without_terminator()) b.clone(o, bvm); // Replace `loop`'s results if requested. if (replaceLoopResults) { for (auto it : llvm::zip(loop.getResults(), newLoop.getResults().take_front( loop.getNumResults()))) std::get<0>(it).replaceAllUsesWith(std::get<1>(it)); } // TODO: this is unsafe in the context of a PatternRewrite. newYield.erase(); return newLoop; }