The greedy rewriter is used in many different flows and it offers a lot of
conveniences (work list management, debugging actions, tracing, etc.). But
it combines two kinds of greedy behavior: 1) how ops are matched, and 2)
folding wherever it can.
These are independent forms of greediness, and combining them leads to
inefficiency. E.g., cases where one needs to create different phases in
lowering and is required to apply patterns in a specific order split across
different passes. Using the driver, one ends up needlessly retrying
folding/having multiple rounds of folding attempts, where one final run
would have sufficed.
Of course folks can locally avoid this behavior by just building their
own, but this is also a commonly requested feature that folks keep on
working around locally in suboptimal ways.
For downstream users, there should be no behavioral change. Updating
from the deprecated API should just be a find and replace (e.g., `find ./
-type f -exec sed -i
's|applyPatternsAndFoldGreedily|applyPatternsGreedily|g' {} \;` variety)
as the API arguments haven't changed between the two.
248 lines
9.2 KiB
C++
248 lines
9.2 KiB
C++
//===- TestSCFUtils.cpp --- Pass to test independent SCF dialect utils ----===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file implements a pass to test SCF dialect utils.
//
//===----------------------------------------------------------------------===//
|
|
|
|
#include "mlir/Dialect/Arith/IR/Arith.h"
|
|
#include "mlir/Dialect/Func/IR/FuncOps.h"
|
|
#include "mlir/Dialect/MemRef/IR/MemRef.h"
|
|
#include "mlir/Dialect/SCF/IR/SCF.h"
|
|
#include "mlir/Dialect/SCF/Transforms/Patterns.h"
|
|
#include "mlir/Dialect/SCF/Utils/Utils.h"
|
|
#include "mlir/IR/Builders.h"
|
|
#include "mlir/IR/PatternMatch.h"
|
|
#include "mlir/Pass/Pass.h"
|
|
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
|
|
|
|
using namespace mlir;
|
|
|
|
namespace {
|
|
struct TestSCFForUtilsPass
|
|
: public PassWrapper<TestSCFForUtilsPass, OperationPass<func::FuncOp>> {
|
|
MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestSCFForUtilsPass)
|
|
|
|
StringRef getArgument() const final { return "test-scf-for-utils"; }
|
|
StringRef getDescription() const final { return "test scf.for utils"; }
|
|
explicit TestSCFForUtilsPass() = default;
|
|
TestSCFForUtilsPass(const TestSCFForUtilsPass &pass) : PassWrapper(pass) {}
|
|
|
|
Option<bool> testReplaceWithNewYields{
|
|
*this, "test-replace-with-new-yields",
|
|
llvm::cl::desc("Test replacing a loop with a new loop that returns new "
|
|
"additional yield values"),
|
|
llvm::cl::init(false)};
|
|
|
|
void runOnOperation() override {
|
|
func::FuncOp func = getOperation();
|
|
SmallVector<scf::ForOp, 4> toErase;
|
|
|
|
if (testReplaceWithNewYields) {
|
|
func.walk([&](scf::ForOp forOp) {
|
|
if (forOp.getNumResults() == 0)
|
|
return;
|
|
auto newInitValues = forOp.getInitArgs();
|
|
if (newInitValues.empty())
|
|
return;
|
|
SmallVector<Value> oldYieldValues =
|
|
llvm::to_vector(forOp.getYieldedValues());
|
|
NewYieldValuesFn fn = [&](OpBuilder &b, Location loc,
|
|
ArrayRef<BlockArgument> newBBArgs) {
|
|
SmallVector<Value> newYieldValues;
|
|
for (auto yieldVal : oldYieldValues) {
|
|
newYieldValues.push_back(
|
|
b.create<arith::AddFOp>(loc, yieldVal, yieldVal));
|
|
}
|
|
return newYieldValues;
|
|
};
|
|
IRRewriter rewriter(forOp.getContext());
|
|
if (failed(forOp.replaceWithAdditionalYields(
|
|
rewriter, newInitValues, /*replaceInitOperandUsesInLoop=*/true,
|
|
fn)))
|
|
signalPassFailure();
|
|
});
|
|
}
|
|
}
|
|
};
|
|
|
|
struct TestSCFIfUtilsPass
|
|
: public PassWrapper<TestSCFIfUtilsPass, OperationPass<ModuleOp>> {
|
|
MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestSCFIfUtilsPass)
|
|
|
|
StringRef getArgument() const final { return "test-scf-if-utils"; }
|
|
StringRef getDescription() const final { return "test scf.if utils"; }
|
|
explicit TestSCFIfUtilsPass() = default;
|
|
|
|
void getDependentDialects(DialectRegistry ®istry) const override {
|
|
registry.insert<func::FuncDialect>();
|
|
}
|
|
|
|
void runOnOperation() override {
|
|
int count = 0;
|
|
getOperation().walk([&](scf::IfOp ifOp) {
|
|
auto strCount = std::to_string(count++);
|
|
func::FuncOp thenFn, elseFn;
|
|
OpBuilder b(ifOp);
|
|
IRRewriter rewriter(b);
|
|
if (failed(outlineIfOp(rewriter, ifOp, &thenFn,
|
|
std::string("outlined_then") + strCount, &elseFn,
|
|
std::string("outlined_else") + strCount))) {
|
|
this->signalPassFailure();
|
|
return WalkResult::interrupt();
|
|
}
|
|
return WalkResult::advance();
|
|
});
|
|
}
|
|
};
|
|
|
|
/// Attribute name that must be present on an `scf.for` for the test schedule
/// callback to produce a schedule for it.
static const StringLiteral kTestPipeliningLoopMarker =
    "__test_pipelining_loop__";
/// Integer attribute name giving the pipeline stage of an operation.
static const StringLiteral kTestPipeliningStageMarker =
    "__test_pipelining_stage__";
/// Marker to express the order in which operations should be after
/// pipelining.
static const StringLiteral kTestPipeliningOpOrderMarker =
    "__test_pipelining_op_order__";

/// Attribute name under which the pipeline part ("prologue", "kernel" or
/// "epilogue") is recorded when annotation is enabled.
static const StringLiteral kTestPipeliningAnnotationPart =
    "__test_pipelining_part";
/// Attribute name under which the iteration number is recorded when
/// annotation is enabled.
static const StringLiteral kTestPipeliningAnnotationIteration =
    "__test_pipelining_iteration";
|
|
|
|
struct TestSCFPipeliningPass
|
|
: public PassWrapper<TestSCFPipeliningPass, OperationPass<func::FuncOp>> {
|
|
MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestSCFPipeliningPass)
|
|
|
|
TestSCFPipeliningPass() = default;
|
|
TestSCFPipeliningPass(const TestSCFPipeliningPass &) {}
|
|
StringRef getArgument() const final { return "test-scf-pipelining"; }
|
|
StringRef getDescription() const final { return "test scf.forOp pipelining"; }
|
|
|
|
Option<bool> annotatePipeline{
|
|
*this, "annotate",
|
|
llvm::cl::desc("Annote operations during loop pipelining transformation"),
|
|
llvm::cl::init(false)};
|
|
|
|
Option<bool> noEpiloguePeeling{
|
|
*this, "no-epilogue-peeling",
|
|
llvm::cl::desc("Use predicates instead of peeling the epilogue."),
|
|
llvm::cl::init(false)};
|
|
|
|
static void
|
|
getSchedule(scf::ForOp forOp,
|
|
std::vector<std::pair<Operation *, unsigned>> &schedule) {
|
|
if (!forOp->hasAttr(kTestPipeliningLoopMarker))
|
|
return;
|
|
|
|
schedule.resize(forOp.getBody()->getOperations().size() - 1);
|
|
forOp.walk([&schedule](Operation *op) {
|
|
auto attrStage =
|
|
op->getAttrOfType<IntegerAttr>(kTestPipeliningStageMarker);
|
|
auto attrCycle =
|
|
op->getAttrOfType<IntegerAttr>(kTestPipeliningOpOrderMarker);
|
|
if (attrCycle && attrStage) {
|
|
// TODO: Index can be out-of-bounds if ops of the loop body disappear
|
|
// due to folding.
|
|
schedule[attrCycle.getInt()] =
|
|
std::make_pair(op, unsigned(attrStage.getInt()));
|
|
}
|
|
});
|
|
}
|
|
|
|
/// Helper to generate "predicated" version of `op`. For simplicity we just
|
|
/// wrap the operation in a scf.ifOp operation.
|
|
static Operation *predicateOp(RewriterBase &rewriter, Operation *op,
|
|
Value pred) {
|
|
Location loc = op->getLoc();
|
|
auto ifOp =
|
|
rewriter.create<scf::IfOp>(loc, op->getResultTypes(), pred, true);
|
|
// True branch.
|
|
rewriter.moveOpBefore(op, &ifOp.getThenRegion().front(),
|
|
ifOp.getThenRegion().front().begin());
|
|
rewriter.setInsertionPointAfter(op);
|
|
if (op->getNumResults() > 0)
|
|
rewriter.create<scf::YieldOp>(loc, op->getResults());
|
|
// False branch.
|
|
rewriter.setInsertionPointToStart(&ifOp.getElseRegion().front());
|
|
SmallVector<Value> elseYieldOperands;
|
|
elseYieldOperands.reserve(ifOp.getNumResults());
|
|
if (auto viewOp = dyn_cast<memref::SubViewOp>(op)) {
|
|
// For sub-views, just clone the op.
|
|
// NOTE: This is okay in the test because we use dynamic memref sizes, so
|
|
// the verifier will not complain. Otherwise, we may create a logically
|
|
// out-of-bounds view and a different technique should be used.
|
|
Operation *opClone = rewriter.clone(*op);
|
|
elseYieldOperands.append(opClone->result_begin(), opClone->result_end());
|
|
} else {
|
|
// Default to assuming constant numeric values.
|
|
for (Type type : op->getResultTypes()) {
|
|
elseYieldOperands.push_back(rewriter.create<arith::ConstantOp>(
|
|
loc, rewriter.getZeroAttr(type)));
|
|
}
|
|
}
|
|
if (op->getNumResults() > 0)
|
|
rewriter.create<scf::YieldOp>(loc, elseYieldOperands);
|
|
return ifOp.getOperation();
|
|
}
|
|
|
|
static void annotate(Operation *op,
|
|
mlir::scf::PipeliningOption::PipelinerPart part,
|
|
unsigned iteration) {
|
|
OpBuilder b(op);
|
|
switch (part) {
|
|
case mlir::scf::PipeliningOption::PipelinerPart::Prologue:
|
|
op->setAttr(kTestPipeliningAnnotationPart, b.getStringAttr("prologue"));
|
|
break;
|
|
case mlir::scf::PipeliningOption::PipelinerPart::Kernel:
|
|
op->setAttr(kTestPipeliningAnnotationPart, b.getStringAttr("kernel"));
|
|
break;
|
|
case mlir::scf::PipeliningOption::PipelinerPart::Epilogue:
|
|
op->setAttr(kTestPipeliningAnnotationPart, b.getStringAttr("epilogue"));
|
|
break;
|
|
}
|
|
op->setAttr(kTestPipeliningAnnotationIteration,
|
|
b.getI32IntegerAttr(iteration));
|
|
}
|
|
|
|
void getDependentDialects(DialectRegistry ®istry) const override {
|
|
registry.insert<arith::ArithDialect, memref::MemRefDialect>();
|
|
}
|
|
|
|
void runOnOperation() override {
|
|
RewritePatternSet patterns(&getContext());
|
|
mlir::scf::PipeliningOption options;
|
|
options.getScheduleFn = getSchedule;
|
|
options.supportDynamicLoops = true;
|
|
options.predicateFn = predicateOp;
|
|
if (annotatePipeline)
|
|
options.annotateFn = annotate;
|
|
if (noEpiloguePeeling) {
|
|
options.peelEpilogue = false;
|
|
}
|
|
scf::populateSCFLoopPipeliningPatterns(patterns, options);
|
|
(void)applyPatternsGreedily(getOperation(), std::move(patterns));
|
|
getOperation().walk([](Operation *op) {
|
|
// Clean up the markers.
|
|
op->removeAttr(kTestPipeliningStageMarker);
|
|
op->removeAttr(kTestPipeliningOpOrderMarker);
|
|
});
|
|
}
|
|
};
|
|
} // namespace
|
|
|
|
namespace mlir {
|
|
namespace test {
|
|
void registerTestSCFUtilsPass() {
|
|
PassRegistration<TestSCFForUtilsPass>();
|
|
PassRegistration<TestSCFIfUtilsPass>();
|
|
PassRegistration<TestSCFPipeliningPass>();
|
|
}
|
|
} // namespace test
|
|
} // namespace mlir
|