//===- SCFToStandard.cpp - ControlFlow to CFG conversion ------------------===// // // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. // See https://llvm.org/LICENSE.txt for license information. // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // //===----------------------------------------------------------------------===// // // This file implements a pass to convert scf.for, scf.if and loop.terminator // ops into standard CFG ops. // //===----------------------------------------------------------------------===// #include "mlir/Conversion/SCFToStandard/SCFToStandard.h" #include "../PassDetail.h" #include "mlir/Dialect/SCF/SCF.h" #include "mlir/Dialect/StandardOps/IR/Ops.h" #include "mlir/IR/BlockAndValueMapping.h" #include "mlir/IR/Builders.h" #include "mlir/IR/MLIRContext.h" #include "mlir/IR/Module.h" #include "mlir/IR/PatternMatch.h" #include "mlir/Transforms/DialectConversion.h" #include "mlir/Transforms/Passes.h" #include "mlir/Transforms/Utils.h" using namespace mlir; using namespace mlir::scf; namespace { struct SCFToStandardPass : public SCFToStandardBase { void runOnOperation() override; }; // Create a CFG subgraph for the loop around its body blocks (if the body // contained other loops, they have been already lowered to a flow of blocks). // Maintain the invariants that a CFG subgraph created for any loop has a single // entry and a single exit, and that the entry/exit blocks are respectively // first/last blocks in the parent region. The original loop operation is // replaced by the initialization operations that set up the initial value of // the loop induction variable (%iv) and computes the loop bounds that are loop- // invariant for affine loops. The operations following the original scf.for // are split out into a separate continuation (exit) block. A condition block is // created before the continuation block. It checks the exit condition of the // loop and branches either to the continuation block, or to the first block of // the body. The condition block takes as arguments the values of the induction // variable followed by loop-carried values. Since it dominates both the body // blocks and the continuation block, loop-carried values are visible in all of // those blocks. Induction variable modification is appended to the last block // of the body (which is the exit block from the body subgraph thanks to the // invariant we maintain) along with a branch that loops back to the condition // block. Loop-carried values are the loop terminator operands, which are // forwarded to the branch. // // +---------------------------------+ // | | // | | // | | // | br cond(%iv, %init...) | // +---------------------------------+ // | // -------| | // | v v // | +--------------------------------+ // | | cond(%iv, %init...): | // | | | // | | cond_br %r, body, end | // | +--------------------------------+ // | | | // | | -------------| // | v | // | +--------------------------------+ | // | | body-first: | | // | | <%init visible by dominance> | | // | | | | // | +--------------------------------+ | // | | | // | ... | // | | | // | +--------------------------------+ | // | | body-last: | | // | | | | // | | | | // | | %new_iv = | | // | | br cond(%new_iv, %yields) | | // | +--------------------------------+ | // | | | // |----------- |-------------------- // v // +--------------------------------+ // | end: | // | | // | <%init visible by dominance> | // +--------------------------------+ // struct ForLowering : public OpRewritePattern { using OpRewritePattern::OpRewritePattern; LogicalResult matchAndRewrite(ForOp forOp, PatternRewriter &rewriter) const override; }; // Create a CFG subgraph for the scf.if operation (including its "then" and // optional "else" operation blocks). We maintain the invariants that the // subgraph has a single entry and a single exit point, and that the entry/exit // blocks are respectively the first/last block of the enclosing region. The // operations following the scf.if are split into a continuation (subgraph // exit) block. The condition is lowered to a chain of blocks that implement the // short-circuit scheme. The "scf.if" operation is replaced with a conditional // branch to either the first block of the "then" region, or to the first block // of the "else" region. In these blocks, "scf.yield" is unconditional branches // to the post-dominating block. When the "scf.if" does not return values, the // post-dominating block is the same as the continuation block. When it returns // values, the post-dominating block is a new block with arguments that // correspond to the values returned by the "scf.if" that unconditionally // branches to the continuation block. This allows block arguments to dominate // any uses of the hitherto "scf.if" results that they replaced. (Inserting a // new block allows us to avoid modifying the argument list of an existing // block, which is illegal in a conversion pattern). When the "else" region is // empty, which is only allowed for "scf.if"s that don't return values, the // condition branches directly to the continuation block. // // CFG for a scf.if with else and without results. // // +--------------------------------+ // | | // | cond_br %cond, %then, %else | // +--------------------------------+ // | | // | --------------| // v | // +--------------------------------+ | // | then: | | // | | | // | br continue | | // +--------------------------------+ | // | | // |---------- |------------- // | V // | +--------------------------------+ // | | else: | // | | | // | | br continue | // | +--------------------------------+ // | | // ------| | // v v // +--------------------------------+ // | continue: | // | | // +--------------------------------+ // // CFG for a scf.if with results. // // +--------------------------------+ // | | // | cond_br %cond, %then, %else | // +--------------------------------+ // | | // | --------------| // v | // +--------------------------------+ | // | then: | | // | | | // | br dom(%args...) | | // +--------------------------------+ | // | | // |---------- |------------- // | V // | +--------------------------------+ // | | else: | // | | | // | | br dom(%args...) | // | +--------------------------------+ // | | // ------| | // v v // +--------------------------------+ // | dom(%args...): | // | br continue | // +--------------------------------+ // | // v // +--------------------------------+ // | continue: | // | | // +--------------------------------+ // struct IfLowering : public OpRewritePattern { using OpRewritePattern::OpRewritePattern; LogicalResult matchAndRewrite(IfOp ifOp, PatternRewriter &rewriter) const override; }; struct ParallelLowering : public OpRewritePattern { using OpRewritePattern::OpRewritePattern; LogicalResult matchAndRewrite(mlir::scf::ParallelOp parallelOp, PatternRewriter &rewriter) const override; }; } // namespace LogicalResult ForLowering::matchAndRewrite(ForOp forOp, PatternRewriter &rewriter) const { Location loc = forOp.getLoc(); // Start by splitting the block containing the 'scf.for' into two parts. // The part before will get the init code, the part after will be the end // point. auto *initBlock = rewriter.getInsertionBlock(); auto initPosition = rewriter.getInsertionPoint(); auto *endBlock = rewriter.splitBlock(initBlock, initPosition); // Use the first block of the loop body as the condition block since it is the // block that has the induction variable and loop-carried values as arguments. // Split out all operations from the first block into a new block. Move all // body blocks from the loop body region to the region containing the loop. auto *conditionBlock = &forOp.region().front(); auto *firstBodyBlock = rewriter.splitBlock(conditionBlock, conditionBlock->begin()); auto *lastBodyBlock = &forOp.region().back(); rewriter.inlineRegionBefore(forOp.region(), endBlock); auto iv = conditionBlock->getArgument(0); // Append the induction variable stepping logic to the last body block and // branch back to the condition block. Loop-carried values are taken from // operands of the loop terminator. Operation *terminator = lastBodyBlock->getTerminator(); rewriter.setInsertionPointToEnd(lastBodyBlock); auto step = forOp.step(); auto stepped = rewriter.create(loc, iv, step).getResult(); if (!stepped) return failure(); SmallVector loopCarried; loopCarried.push_back(stepped); loopCarried.append(terminator->operand_begin(), terminator->operand_end()); rewriter.create(loc, conditionBlock, loopCarried); rewriter.eraseOp(terminator); // Compute loop bounds before branching to the condition. rewriter.setInsertionPointToEnd(initBlock); Value lowerBound = forOp.lowerBound(); Value upperBound = forOp.upperBound(); if (!lowerBound || !upperBound) return failure(); // The initial values of loop-carried values is obtained from the operands // of the loop operation. SmallVector destOperands; destOperands.push_back(lowerBound); auto iterOperands = forOp.getIterOperands(); destOperands.append(iterOperands.begin(), iterOperands.end()); rewriter.create(loc, conditionBlock, destOperands); // With the body block done, we can fill in the condition block. rewriter.setInsertionPointToEnd(conditionBlock); auto comparison = rewriter.create(loc, CmpIPredicate::slt, iv, upperBound); rewriter.create(loc, comparison, firstBodyBlock, ArrayRef(), endBlock, ArrayRef()); // The result of the loop operation is the values of the condition block // arguments except the induction variable on the last iteration. rewriter.replaceOp(forOp, conditionBlock->getArguments().drop_front()); return success(); } LogicalResult IfLowering::matchAndRewrite(IfOp ifOp, PatternRewriter &rewriter) const { auto loc = ifOp.getLoc(); // Start by splitting the block containing the 'scf.if' into two parts. // The part before will contain the condition, the part after will be the // continuation point. auto *condBlock = rewriter.getInsertionBlock(); auto opPosition = rewriter.getInsertionPoint(); auto *remainingOpsBlock = rewriter.splitBlock(condBlock, opPosition); Block *continueBlock; if (ifOp.getNumResults() == 0) { continueBlock = remainingOpsBlock; } else { continueBlock = rewriter.createBlock(remainingOpsBlock, ifOp.getResultTypes()); rewriter.create(loc, remainingOpsBlock); } // Move blocks from the "then" region to the region containing 'scf.if', // place it before the continuation block, and branch to it. auto &thenRegion = ifOp.thenRegion(); auto *thenBlock = &thenRegion.front(); Operation *thenTerminator = thenRegion.back().getTerminator(); ValueRange thenTerminatorOperands = thenTerminator->getOperands(); rewriter.setInsertionPointToEnd(&thenRegion.back()); rewriter.create(loc, continueBlock, thenTerminatorOperands); rewriter.eraseOp(thenTerminator); rewriter.inlineRegionBefore(thenRegion, continueBlock); // Move blocks from the "else" region (if present) to the region containing // 'scf.if', place it before the continuation block and branch to it. It // will be placed after the "then" regions. auto *elseBlock = continueBlock; auto &elseRegion = ifOp.elseRegion(); if (!elseRegion.empty()) { elseBlock = &elseRegion.front(); Operation *elseTerminator = elseRegion.back().getTerminator(); ValueRange elseTerminatorOperands = elseTerminator->getOperands(); rewriter.setInsertionPointToEnd(&elseRegion.back()); rewriter.create(loc, continueBlock, elseTerminatorOperands); rewriter.eraseOp(elseTerminator); rewriter.inlineRegionBefore(elseRegion, continueBlock); } rewriter.setInsertionPointToEnd(condBlock); rewriter.create(loc, ifOp.condition(), thenBlock, /*trueArgs=*/ArrayRef(), elseBlock, /*falseArgs=*/ArrayRef()); // Ok, we're done! rewriter.replaceOp(ifOp, continueBlock->getArguments()); return success(); } LogicalResult ParallelLowering::matchAndRewrite(ParallelOp parallelOp, PatternRewriter &rewriter) const { Location loc = parallelOp.getLoc(); BlockAndValueMapping mapping; // For a parallel loop, we essentially need to create an n-dimensional loop // nest. We do this by translating to scf.for ops and have those lowered in // a further rewrite. If a parallel loop contains reductions (and thus returns // values), forward the initial values for the reductions down the loop // hierarchy and bubble up the results by modifying the "yield" terminator. SmallVector iterArgs = llvm::to_vector<4>(parallelOp.initVals()); bool first = true; SmallVector loopResults(iterArgs); for (auto loop_operands : llvm::zip(parallelOp.getInductionVars(), parallelOp.lowerBound(), parallelOp.upperBound(), parallelOp.step())) { Value iv, lower, upper, step; std::tie(iv, lower, upper, step) = loop_operands; ForOp forOp = rewriter.create(loc, lower, upper, step, iterArgs); mapping.map(iv, forOp.getInductionVar()); auto iterRange = forOp.getRegionIterArgs(); iterArgs.assign(iterRange.begin(), iterRange.end()); if (first) { // Store the results of the outermost loop that will be used to replace // the results of the parallel loop when it is fully rewritten. loopResults.assign(forOp.result_begin(), forOp.result_end()); first = false; } else if (!forOp.getResults().empty()) { // A loop is constructed with an empty "yield" terminator if there are // no results. rewriter.setInsertionPointToEnd(rewriter.getInsertionBlock()); rewriter.create(loc, forOp.getResults()); } rewriter.setInsertionPointToStart(forOp.getBody()); } // Now copy over the contents of the body. SmallVector yieldOperands; yieldOperands.reserve(parallelOp.getNumResults()); for (auto &op : parallelOp.getBody()->without_terminator()) { // Reduction blocks are handled differently. auto reduce = dyn_cast(op); if (!reduce) { rewriter.clone(op, mapping); continue; } // Clone the body of the reduction operation into the body of the loop, // using operands of "scf.reduce" and iteration arguments corresponding // to the reduction value to replace arguments of the reduction block. // Collect operands of "scf.reduce.return" to be returned by a final // "scf.yield" instead. Value arg = iterArgs[yieldOperands.size()]; Block &reduceBlock = reduce.reductionOperator().front(); mapping.map(reduceBlock.getArgument(0), mapping.lookupOrDefault(arg)); mapping.map(reduceBlock.getArgument(1), mapping.lookupOrDefault(reduce.operand())); for (auto &nested : reduceBlock.without_terminator()) rewriter.clone(nested, mapping); yieldOperands.push_back( mapping.lookup(reduceBlock.getTerminator()->getOperand(0))); } if (!yieldOperands.empty()) { rewriter.setInsertionPointToEnd(rewriter.getInsertionBlock()); rewriter.create(loc, yieldOperands); } rewriter.replaceOp(parallelOp, loopResults); return success(); } void mlir::populateLoopToStdConversionPatterns( OwningRewritePatternList &patterns, MLIRContext *ctx) { patterns.insert(ctx); } void SCFToStandardPass::runOnOperation() { OwningRewritePatternList patterns; populateLoopToStdConversionPatterns(patterns, &getContext()); // Configure conversion to lower out scf.for, scf.if and scf.parallel. // Anything else is fine. ConversionTarget target(getContext()); target.addIllegalOp(); target.markUnknownOpDynamicallyLegal([](Operation *) { return true; }); if (failed(applyPartialConversion(getOperation(), target, patterns))) signalPassFailure(); } std::unique_ptr mlir::createLowerToCFGPass() { return std::make_unique(); }