Define SCF dialect patterns that rotate `scf.while` loops by leveraging the
existing `mlir::scf::wrapWhileLoopInZeroTripCheck`. `forceCreateCheck`
is always `false`, as the pattern would otherwise lead to infinite
recursion.
This pattern rotates `scf.while` ops, mutating them from "while" loops to
"do-while" loops. A guard checking the condition for the first iteration
is inserted. Note this guard can be optimized away if the compiler can
prove the loop will be executed at least once.
Using this pattern, the following while loop:
```mlir
scf.while (%arg0 = %init) : (i32) -> i64 {
%val = .., %arg0 : i64
%cond = arith.cmpi .., %arg0 : i32
scf.condition(%cond) %val : i64
} do {
^bb0(%arg1: i64):
%next = .., %arg1 : i32
scf.yield %next : i32
}
```
Can be transformed into:
```mlir
%pre_val = .., %init : i64
%pre_cond = arith.cmpi .., %init : i32
scf.if %pre_cond -> i64 {
%res = scf.while (%arg1 = %pre_val) : (i64) -> i64 {
// Original after block
%next = .., %arg1 : i32
// Original before block
%val = .., %next : i64
%cond = arith.cmpi .., %next : i32
scf.condition(%cond) %val : i64
} do {
^bb0(%arg2: i64):
scf.yield %arg2 : i64
}
scf.yield %res : i64
} else {
scf.yield %pre_val : i64
}
```
The test pass for `wrapWhileLoopInZeroTripCheck` has been modified to
use the new pattern when `forceCreateCheck=false`.
---------
Signed-off-by: Victor Perez <victor.perez@codeplay.com>
81 lines
2.6 KiB
C++
81 lines
2.6 KiB
C++
//===- TestSCFWrapInZeroTripCheck.cpp -- Pass to test SCF zero-trip-check -===//
|
|
//
|
|
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
|
// See https://llvm.org/LICENSE.txt for license information.
|
|
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
|
//
|
|
//===----------------------------------------------------------------------===//
|
|
//
|
|
// This file implements the passes to test wrap-in-zero-trip-check transforms on
|
|
// SCF loop ops.
|
|
//
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
#include "mlir/Dialect/Func/IR/FuncOps.h"
|
|
#include "mlir/Dialect/SCF/IR/SCF.h"
|
|
#include "mlir/Dialect/SCF/Transforms/Patterns.h"
|
|
#include "mlir/Dialect/SCF/Transforms/Transforms.h"
|
|
#include "mlir/IR/PatternMatch.h"
|
|
#include "mlir/Pass/Pass.h"
|
|
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
|
|
|
|
using namespace mlir;
|
|
|
|
namespace {
|
|
|
|
struct TestWrapWhileLoopInZeroTripCheckPass
|
|
: public PassWrapper<TestWrapWhileLoopInZeroTripCheckPass,
|
|
OperationPass<func::FuncOp>> {
|
|
MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(
|
|
TestWrapWhileLoopInZeroTripCheckPass)
|
|
|
|
StringRef getArgument() const final {
|
|
return "test-wrap-scf-while-loop-in-zero-trip-check";
|
|
}
|
|
|
|
StringRef getDescription() const final {
|
|
return "test scf::wrapWhileLoopInZeroTripCheck";
|
|
}
|
|
|
|
TestWrapWhileLoopInZeroTripCheckPass() = default;
|
|
TestWrapWhileLoopInZeroTripCheckPass(
|
|
const TestWrapWhileLoopInZeroTripCheckPass &) {}
|
|
explicit TestWrapWhileLoopInZeroTripCheckPass(bool forceCreateCheckParam) {
|
|
forceCreateCheck = forceCreateCheckParam;
|
|
}
|
|
|
|
void runOnOperation() override {
|
|
func::FuncOp func = getOperation();
|
|
MLIRContext *context = &getContext();
|
|
IRRewriter rewriter(context);
|
|
if (forceCreateCheck) {
|
|
func.walk([&](scf::WhileOp op) {
|
|
FailureOr<scf::WhileOp> result =
|
|
scf::wrapWhileLoopInZeroTripCheck(op, rewriter, forceCreateCheck);
|
|
// Ignore not implemented failure in tests. The expected output should
|
|
// catch problems (e.g. transformation doesn't happen).
|
|
(void)result;
|
|
});
|
|
} else {
|
|
RewritePatternSet patterns(context);
|
|
scf::populateSCFRotateWhileLoopPatterns(patterns);
|
|
(void)applyPatternsAndFoldGreedily(func, std::move(patterns));
|
|
}
|
|
}
|
|
|
|
Option<bool> forceCreateCheck{
|
|
*this, "force-create-check",
|
|
llvm::cl::desc("Force to create zero-trip-check."),
|
|
llvm::cl::init(false)};
|
|
};
|
|
|
|
} // namespace
|
|
|
|
namespace mlir {
|
|
namespace test {
|
|
void registerTestSCFWrapInZeroTripCheckPasses() {
|
|
PassRegistration<TestWrapWhileLoopInZeroTripCheckPass>();
|
|
}
|
|
} // namespace test
|
|
} // namespace mlir
|