[mlir] add a fluent API to GreedyRewriterConfig (#137122)

This is similar to other configuration objects used across MLIR.

Rename some fields to better reflect that they are no longer booleans.

Reland 04d261101b4f229189463136a794e3e362a793af / #132253.
This commit is contained in:
Oleksandr "Alex" Zinenko
2025-04-24 09:51:42 +02:00
committed by GitHub
parent 15bb1db4a9
commit 0c61b24337
30 changed files with 225 additions and 169 deletions

View File

@@ -357,8 +357,8 @@ public:
patterns.insert<PackArrayConversion>(context);
patterns.insert<UnpackArrayConversion>(context);
mlir::GreedyRewriteConfig config;
config.enableRegionSimplification =
mlir::GreedySimplifyRegionLevel::Disabled;
config.setRegionSimplificationLevel(
mlir::GreedySimplifyRegionLevel::Disabled);
(void)applyPatternsGreedily(module, std::move(patterns), config);
}

View File

@@ -119,8 +119,8 @@ public:
mlir::GreedyRewriteConfig config;
// Prevent the pattern driver from merging blocks.
config.enableRegionSimplification =
mlir::GreedySimplifyRegionLevel::Disabled;
config.setRegionSimplificationLevel(
mlir::GreedySimplifyRegionLevel::Disabled);
mlir::RewritePatternSet patterns(context);
patterns.insert<InlineElementalConversion>(context);

View File

@@ -135,8 +135,8 @@ public:
mlir::GreedyRewriteConfig config;
// Prevent the pattern driver from merging blocks.
config.enableRegionSimplification =
mlir::GreedySimplifyRegionLevel::Disabled;
config.setRegionSimplificationLevel(
mlir::GreedySimplifyRegionLevel::Disabled);
mlir::RewritePatternSet patterns(context);
patterns.insert<InlineHLFIRAssignConversion>(context);

View File

@@ -557,8 +557,8 @@ public:
// Pattern rewriting only requires that the resulting IR is still valid
mlir::GreedyRewriteConfig config;
// Prevent the pattern driver from merging blocks
config.enableRegionSimplification =
mlir::GreedySimplifyRegionLevel::Disabled;
config.setRegionSimplificationLevel(
mlir::GreedySimplifyRegionLevel::Disabled);
if (mlir::failed(
mlir::applyPatternsGreedily(module, std::move(patterns), config))) {

View File

@@ -875,8 +875,8 @@ public:
mlir::GreedyRewriteConfig config;
// Prevent the pattern driver from merging blocks
config.enableRegionSimplification =
mlir::GreedySimplifyRegionLevel::Disabled;
config.setRegionSimplificationLevel(
mlir::GreedySimplifyRegionLevel::Disabled);
mlir::RewritePatternSet patterns(context);
// TODO: right now the patterns are non-conflicting,

View File

@@ -2132,8 +2132,8 @@ public:
mlir::GreedyRewriteConfig config;
// Prevent the pattern driver from merging blocks
config.enableRegionSimplification =
mlir::GreedySimplifyRegionLevel::Disabled;
config.setRegionSimplificationLevel(
mlir::GreedySimplifyRegionLevel::Disabled);
mlir::RewritePatternSet patterns(context);
patterns.insert<TransposeAsElementalConversion>(context);

View File

@@ -35,7 +35,8 @@ void addNestedPassToAllTopLevelOperationsConditionally(
void addCanonicalizerPassWithoutRegionSimplification(mlir::OpPassManager &pm) {
mlir::GreedyRewriteConfig config;
config.enableRegionSimplification = mlir::GreedySimplifyRegionLevel::Disabled;
config.setRegionSimplificationLevel(
mlir::GreedySimplifyRegionLevel::Disabled);
pm.addPass(mlir::createCanonicalizerPass(config));
}
@@ -163,7 +164,8 @@ void createDefaultFIROptimizerPassPipeline(mlir::PassManager &pm,
// simplify the IR
mlir::GreedyRewriteConfig config;
config.enableRegionSimplification = mlir::GreedySimplifyRegionLevel::Disabled;
config.setRegionSimplificationLevel(
mlir::GreedySimplifyRegionLevel::Disabled);
pm.addPass(mlir::createCSEPass());
fir::addAVC(pm, pc.OptLevel);
addNestedPassToAllTopLevelOperations<PassConstructor>(

View File

@@ -152,8 +152,8 @@ public:
patterns.insert<ReboxAssumedRankConv>(context, &symbolTable, kindMap);
patterns.insert<IsAssumedSizeConv>(context, &symbolTable, kindMap);
mlir::GreedyRewriteConfig config;
config.enableRegionSimplification =
mlir::GreedySimplifyRegionLevel::Disabled;
config.setRegionSimplificationLevel(
mlir::GreedySimplifyRegionLevel::Disabled);
(void)applyPatternsGreedily(mod, std::move(patterns), config);
}
};

View File

@@ -168,9 +168,9 @@ public:
auto *context = &getContext();
mlir::RewritePatternSet patterns(context);
mlir::GreedyRewriteConfig config;
config.enableRegionSimplification =
mlir::GreedySimplifyRegionLevel::Disabled;
config.strictMode = mlir::GreedyRewriteStrictness::ExistingOps;
config.setRegionSimplificationLevel(
mlir::GreedySimplifyRegionLevel::Disabled);
config.setStrictness(mlir::GreedyRewriteStrictness::ExistingOps);
patterns.insert<CallOpRewriter>(context, *di);
if (mlir::failed(

View File

@@ -205,7 +205,8 @@ void SimplifyFIROperationsPass::runOnOperation() {
fir::populateSimplifyFIROperationsPatterns(patterns,
preferInlineImplementation);
mlir::GreedyRewriteConfig config;
config.enableRegionSimplification = mlir::GreedySimplifyRegionLevel::Disabled;
config.setRegionSimplificationLevel(
mlir::GreedySimplifyRegionLevel::Disabled);
if (mlir::failed(
mlir::applyPatternsGreedily(module, std::move(patterns), config))) {

View File

@@ -806,7 +806,8 @@ void StackArraysPass::runOnOperation() {
mlir::RewritePatternSet patterns(&context);
mlir::GreedyRewriteConfig config;
// prevent the pattern driver form merging blocks
config.enableRegionSimplification = mlir::GreedySimplifyRegionLevel::Disabled;
config.setRegionSimplificationLevel(
mlir::GreedySimplifyRegionLevel::Disabled);
patterns.insert<AllocMemConversion>(&context, *candidateOps);
if (mlir::failed(mlir::applyOpPatternsGreedily(

View File

@@ -49,25 +49,43 @@ public:
/// larger patterns when given an ambiguous pattern set.
///
/// Note: Only applicable when simplifying entire regions.
bool useTopDownTraversal = false;
bool getUseTopDownTraversal() const { return useTopDownTraversal; }
GreedyRewriteConfig &setUseTopDownTraversal(bool use = true) {
useTopDownTraversal = use;
return *this;
}
/// Perform control flow optimizations to the region tree after applying all
/// patterns.
///
/// Note: Only applicable when simplifying entire regions.
GreedySimplifyRegionLevel enableRegionSimplification =
GreedySimplifyRegionLevel::Aggressive;
GreedySimplifyRegionLevel getRegionSimplificationLevel() const {
return regionSimplificationLevel;
}
GreedyRewriteConfig &
setRegionSimplificationLevel(GreedySimplifyRegionLevel level) {
regionSimplificationLevel = level;
return *this;
}
/// This specifies the maximum number of times the rewriter will iterate
/// between applying patterns and simplifying regions. Use `kNoLimit` to
/// disable this iteration limit.
///
/// Note: Only applicable when simplifying entire regions.
int64_t maxIterations = 10;
int64_t getMaxIterations() const { return maxIterations; }
GreedyRewriteConfig &setMaxIterations(int64_t iterations) {
maxIterations = iterations;
return *this;
}
/// This specifies the maximum number of rewrites within an iteration. Use
/// `kNoLimit` to disable this limit.
int64_t maxNumRewrites = kNoLimit;
int64_t getMaxNumRewrites() const { return maxNumRewrites; }
GreedyRewriteConfig &setMaxNumRewrites(int64_t limit) {
maxNumRewrites = limit;
return *this;
}
static constexpr int64_t kNoLimit = -1;
@@ -75,7 +93,11 @@ public:
/// specified, the closest enclosing region around the initial list of ops
/// (or the specified region, depending on which greedy rewrite entry point
/// is used) is used as a scope.
Region *scope = nullptr;
Region *getScope() const { return scope; }
GreedyRewriteConfig &setScope(Region *scope) {
this->scope = scope;
return *this;
}
/// Strict mode can restrict the ops that are added to the worklist during
/// the rewrite.
@@ -87,16 +109,44 @@ public:
/// * GreedyRewriteStrictness::ExistingOps: Only pre-existing ops (that were
/// were on the worklist at the very beginning) enqueued. All other ops are
/// excluded.
GreedyRewriteStrictness strictMode = GreedyRewriteStrictness::AnyOp;
GreedyRewriteStrictness getStrictness() const { return strictness; }
GreedyRewriteConfig &setStrictness(GreedyRewriteStrictness mode) {
strictness = mode;
return *this;
}
/// An optional listener that should be notified about IR modifications.
RewriterBase::Listener *listener = nullptr;
RewriterBase::Listener *getListener() const { return listener; }
GreedyRewriteConfig &setListener(RewriterBase::Listener *listener) {
this->listener = listener;
return *this;
}
/// Whether this should fold while greedily rewriting.
bool fold = true;
bool isFoldingEnabled() const { return fold; }
GreedyRewriteConfig &enableFolding(bool enable = true) {
fold = enable;
return *this;
}
/// If set to "true", constants are CSE'd (even across multiple regions that
/// are in a parent-ancestor relationship).
bool isConstantCSEEnabled() const { return cseConstants; }
GreedyRewriteConfig &enableConstantCSE(bool enable = true) {
cseConstants = enable;
return *this;
}
private:
Region *scope = nullptr;
bool useTopDownTraversal = false;
GreedySimplifyRegionLevel regionSimplificationLevel =
GreedySimplifyRegionLevel::Aggressive;
int64_t maxIterations = 10;
int64_t maxNumRewrites = kNoLimit;
GreedyRewriteStrictness strictness = GreedyRewriteStrictness::AnyOp;
RewriterBase::Listener *listener = nullptr;
bool fold = true;
bool cseConstants = true;
};
@@ -128,14 +178,14 @@ applyPatternsGreedily(Region &region, const FrozenRewritePatternSet &patterns,
GreedyRewriteConfig config = GreedyRewriteConfig(),
bool *changed = nullptr);
/// Same as `applyPatternsAndGreedily` above with folding.
/// FIXME: Remove this once transition to above is complieted.
/// FIXME: Remove this once transition to above is completed.
LLVM_DEPRECATED("Use applyPatternsGreedily() instead", "applyPatternsGreedily")
inline LogicalResult
applyPatternsAndFoldGreedily(Region &region,
const FrozenRewritePatternSet &patterns,
GreedyRewriteConfig config = GreedyRewriteConfig(),
bool *changed = nullptr) {
config.fold = true;
config.enableFolding();
return applyPatternsGreedily(region, patterns, config, changed);
}
@@ -187,7 +237,7 @@ applyPatternsAndFoldGreedily(Operation *op,
const FrozenRewritePatternSet &patterns,
GreedyRewriteConfig config = GreedyRewriteConfig(),
bool *changed = nullptr) {
config.fold = true;
config.enableFolding();
return applyPatternsGreedily(op, patterns, config, changed);
}
@@ -233,7 +283,7 @@ applyOpPatternsAndFold(ArrayRef<Operation *> ops,
const FrozenRewritePatternSet &patterns,
GreedyRewriteConfig config = GreedyRewriteConfig(),
bool *changed = nullptr, bool *allErased = nullptr) {
config.fold = true;
config.enableFolding();
return applyOpPatternsGreedily(ops, patterns, config, changed, allErased);
}

View File

@@ -33,7 +33,7 @@ def Canonicalizer : Pass<"canonicalize"> {
Option<"topDownProcessingEnabled", "top-down", "bool",
/*default=*/"true",
"Seed the worklist in general top-down order">,
Option<"enableRegionSimplification", "region-simplify", "mlir::GreedySimplifyRegionLevel",
Option<"regionSimplifyLevel", "region-simplify", "mlir::GreedySimplifyRegionLevel",
/*default=*/"mlir::GreedySimplifyRegionLevel::Normal",
"Perform control flow optimizations to the region tree",
[{::llvm::cl::values(

View File

@@ -127,12 +127,13 @@ SimplifyBoundedAffineOpsOp::apply(transform::TransformRewriter &rewriter,
patterns.insert<SimplifyAffineMinMaxOp<AffineMinOp>,
SimplifyAffineMinMaxOp<AffineMaxOp>>(getContext(), cstr);
FrozenRewritePatternSet frozenPatterns(std::move(patterns));
GreedyRewriteConfig config;
config.listener =
static_cast<RewriterBase::Listener *>(rewriter.getListener());
config.strictMode = GreedyRewriteStrictness::ExistingAndNewOps;
// Apply the simplification pattern to a fixpoint.
if (failed(applyOpPatternsGreedily(targets, frozenPatterns, config))) {
if (failed(applyOpPatternsGreedily(
targets, frozenPatterns,
GreedyRewriteConfig()
.setListener(
static_cast<RewriterBase::Listener *>(rewriter.getListener()))
.setStrictness(GreedyRewriteStrictness::ExistingAndNewOps)))) {
auto diag = emitDefiniteFailure()
<< "affine.min/max simplification did not converge";
return diag;

View File

@@ -237,7 +237,8 @@ void AffineDataCopyGeneration::runOnOperation() {
AffineLoadOp::getCanonicalizationPatterns(patterns, &getContext());
AffineStoreOp::getCanonicalizationPatterns(patterns, &getContext());
FrozenRewritePatternSet frozenPatterns(std::move(patterns));
GreedyRewriteConfig config;
config.strictMode = GreedyRewriteStrictness::ExistingAndNewOps;
(void)applyOpPatternsGreedily(copyOps, frozenPatterns, config);
(void)applyOpPatternsGreedily(
copyOps, frozenPatterns,
GreedyRewriteConfig().setStrictness(
GreedyRewriteStrictness::ExistingAndNewOps));
}

View File

@@ -109,7 +109,8 @@ void SimplifyAffineStructures::runOnOperation() {
if (isa<AffineForOp, AffineIfOp, AffineApplyOp>(op))
opsToSimplify.push_back(op);
});
GreedyRewriteConfig config;
config.strictMode = GreedyRewriteStrictness::ExistingAndNewOps;
(void)applyOpPatternsGreedily(opsToSimplify, frozenPatterns, config);
(void)applyOpPatternsGreedily(
opsToSimplify, frozenPatterns,
GreedyRewriteConfig().setStrictness(
GreedyRewriteStrictness::ExistingAndNewOps));
}

View File

@@ -315,11 +315,12 @@ LogicalResult mlir::affine::affineForOpBodySkew(AffineForOp forOp,
// Simplify/canonicalize the affine.for.
RewritePatternSet patterns(res.getContext());
AffineForOp::getCanonicalizationPatterns(patterns, res.getContext());
GreedyRewriteConfig config;
config.strictMode = GreedyRewriteStrictness::ExistingOps;
bool erased;
(void)applyOpPatternsGreedily(res.getOperation(), std::move(patterns),
config, /*changed=*/nullptr, &erased);
(void)applyOpPatternsGreedily(
res.getOperation(), std::move(patterns),
GreedyRewriteConfig().setStrictness(
GreedyRewriteStrictness::ExistingAndNewOps),
/*changed=*/nullptr, &erased);
if (!erased && !prologue)
prologue = res;
if (!erased)

View File

@@ -426,11 +426,11 @@ LogicalResult mlir::affine::hoistAffineIfOp(AffineIfOp ifOp, bool *folded) {
RewritePatternSet patterns(ifOp.getContext());
AffineIfOp::getCanonicalizationPatterns(patterns, ifOp.getContext());
FrozenRewritePatternSet frozenPatterns(std::move(patterns));
GreedyRewriteConfig config;
config.strictMode = GreedyRewriteStrictness::ExistingOps;
bool erased;
(void)applyOpPatternsGreedily(ifOp.getOperation(), frozenPatterns, config,
/*changed=*/nullptr, &erased);
(void)applyOpPatternsGreedily(
ifOp.getOperation(), frozenPatterns,
GreedyRewriteConfig().setStrictness(GreedyRewriteStrictness::ExistingOps),
/*changed=*/nullptr, &erased);
if (erased) {
if (folded)
*folded = true;

View File

@@ -494,10 +494,9 @@ struct IntRangeOptimizationsPass final
RewritePatternSet patterns(ctx);
populateIntRangeOptimizationsPatterns(patterns, solver);
GreedyRewriteConfig config;
config.listener = &listener;
if (failed(applyPatternsGreedily(op, std::move(patterns), config)))
if (failed(applyPatternsGreedily(
op, std::move(patterns),
GreedyRewriteConfig().setListener(&listener))))
signalPassFailure();
}
};
@@ -520,13 +519,12 @@ struct IntRangeNarrowingPass final
RewritePatternSet patterns(ctx);
populateIntRangeNarrowingPatterns(patterns, solver, bitwidthsSupported);
GreedyRewriteConfig config;
// We specifically need bottom-up traversal as cmpi pattern needs range
// data, attached to its original argument values.
config.useTopDownTraversal = false;
config.listener = &listener;
if (failed(applyPatternsGreedily(op, std::move(patterns), config)))
if (failed(applyPatternsGreedily(
op, std::move(patterns),
GreedyRewriteConfig().setUseTopDownTraversal(false).setListener(
&listener))))
signalPassFailure();
}
};

View File

@@ -463,15 +463,15 @@ struct BufferDeallocationSimplificationPass
SplitDeallocWhenNotAliasingAnyOther,
RetainedMemrefAliasingAlwaysDeallocatedMemref>(&getContext(),
analysis);
// We don't want that the block structure changes invalidating the
// `BufferOriginAnalysis` so we apply the rewrites witha `Normal` level of
// region simplification
GreedyRewriteConfig config;
config.enableRegionSimplification = GreedySimplifyRegionLevel::Normal;
populateDeallocOpCanonicalizationPatterns(patterns, &getContext());
if (failed(
applyPatternsGreedily(getOperation(), std::move(patterns), config)))
populateDeallocOpCanonicalizationPatterns(patterns, &getContext());
// We don't want that the block structure changes invalidating the
// `BufferOriginAnalysis` so we apply the rewrites with `Normal` level of
// region simplification
if (failed(applyPatternsGreedily(
getOperation(), std::move(patterns),
GreedyRewriteConfig().setRegionSimplificationLevel(
GreedySimplifyRegionLevel::Normal))))
signalPassFailure();
}
};

View File

@@ -3587,9 +3587,9 @@ transform::VectorizeChildrenAndApplyPatternsOp::applyToOne(
vector::populateVectorStepLoweringPatterns(patterns);
TrackingListener listener(state, *this);
GreedyRewriteConfig config;
config.listener = &listener;
if (failed(applyPatternsGreedily(target, std::move(patterns), config)))
if (failed(
applyPatternsGreedily(target, std::move(patterns),
GreedyRewriteConfig().setListener(&listener))))
return emitDefaultDefiniteFailure(target);
results.push_back(target);

View File

@@ -2327,10 +2327,9 @@ struct LinalgElementwiseOpFusionPass
// Add constant folding patterns.
populateConstantFoldLinalgOperations(patterns, defaultControlFn);
// Use TopDownTraversal for compile time reasons
GreedyRewriteConfig grc;
grc.useTopDownTraversal = true;
(void)applyPatternsGreedily(op, std::move(patterns), grc);
// Use TopDownTraversal for compile time reasons.
(void)applyPatternsGreedily(op, std::move(patterns),
GreedyRewriteConfig().setUseTopDownTraversal());
}
};

View File

@@ -1438,10 +1438,10 @@ SliceTrackingListener::insertAndApplyPatterns(ArrayRef<Operation *> ops) {
if (!patterns)
return success();
GreedyRewriteConfig config;
config.listener = this;
config.strictMode = GreedyRewriteStrictness::ExistingAndNewOps;
return applyOpPatternsGreedily(ops, patterns.value(), config);
return applyOpPatternsGreedily(
ops, patterns.value(),
GreedyRewriteConfig().setListener(this).setStrictness(
GreedyRewriteStrictness::ExistingAndNewOps));
}
void SliceTrackingListener::notifyOperationInserted(

View File

@@ -1353,9 +1353,9 @@ LogicalResult mlir::spirv::unrollVectorsInSignatures(Operation *op) {
// We only want to apply signature conversion once to the existing func ops.
// Without specifying strictMode, the greedy pattern rewriter will keep
// looking for newly created func ops.
GreedyRewriteConfig config;
config.strictMode = GreedyRewriteStrictness::ExistingOps;
return applyPatternsGreedily(op, std::move(patterns), config);
return applyPatternsGreedily(op, std::move(patterns),
GreedyRewriteConfig().setStrictness(
GreedyRewriteStrictness::ExistingOps));
}
LogicalResult mlir::spirv::unrollVectorsInFuncBodies(Operation *op) {

View File

@@ -394,16 +394,16 @@ DiagnosedSilenceableFailure transform::ApplyPatternsOp::applyToOne(
// Configure the GreedyPatternRewriteDriver.
GreedyRewriteConfig config;
config.listener =
static_cast<RewriterBase::Listener *>(rewriter.getListener());
config.setListener(
static_cast<RewriterBase::Listener *>(rewriter.getListener()));
FrozenRewritePatternSet frozenPatterns(std::move(patterns));
config.maxIterations = getMaxIterations() == static_cast<uint64_t>(-1)
? GreedyRewriteConfig::kNoLimit
: getMaxIterations();
config.maxNumRewrites = getMaxNumRewrites() == static_cast<uint64_t>(-1)
config.setMaxIterations(getMaxIterations() == static_cast<uint64_t>(-1)
? GreedyRewriteConfig::kNoLimit
: getMaxNumRewrites();
: getMaxIterations());
config.setMaxNumRewrites(getMaxNumRewrites() == static_cast<uint64_t>(-1)
? GreedyRewriteConfig::kNoLimit
: getMaxNumRewrites());
// Apply patterns and CSE repetitively until a fixpoint is reached. If no CSE
// was requested, apply the greedy pattern rewrite only once. (The greedy

View File

@@ -62,11 +62,11 @@ static void applyPatterns(Region &region,
// before that transform.
for (Operation *op : opsInRange) {
// `applyOpPatternsGreedily` with folding returns whether the op is
// convered. Omit it because we don't have expectation this reduction will
// converted. Omit it because we don't have expectation this reduction will
// be success or not.
GreedyRewriteConfig config;
config.strictMode = GreedyRewriteStrictness::ExistingOps;
(void)applyOpPatternsGreedily(op, patterns, config);
(void)applyOpPatternsGreedily(op, patterns,
GreedyRewriteConfig().setStrictness(
GreedyRewriteStrictness::ExistingOps));
}
if (eraseOpNotInRange)

View File

@@ -32,10 +32,10 @@ struct Canonicalizer : public impl::CanonicalizerBase<Canonicalizer> {
ArrayRef<std::string> disabledPatterns,
ArrayRef<std::string> enabledPatterns)
: config(config) {
this->topDownProcessingEnabled = config.useTopDownTraversal;
this->enableRegionSimplification = config.enableRegionSimplification;
this->maxIterations = config.maxIterations;
this->maxNumRewrites = config.maxNumRewrites;
this->topDownProcessingEnabled = config.getUseTopDownTraversal();
this->regionSimplifyLevel = config.getRegionSimplificationLevel();
this->maxIterations = config.getMaxIterations();
this->maxNumRewrites = config.getMaxNumRewrites();
this->disabledPatterns = disabledPatterns;
this->enabledPatterns = enabledPatterns;
}
@@ -44,10 +44,10 @@ struct Canonicalizer : public impl::CanonicalizerBase<Canonicalizer> {
/// execution.
LogicalResult initialize(MLIRContext *context) override {
// Set the config from possible pass options set in the meantime.
config.useTopDownTraversal = topDownProcessingEnabled;
config.enableRegionSimplification = enableRegionSimplification;
config.maxIterations = maxIterations;
config.maxNumRewrites = maxNumRewrites;
config.setUseTopDownTraversal(topDownProcessingEnabled);
config.setRegionSimplificationLevel(regionSimplifyLevel);
config.setMaxIterations(maxIterations);
config.setMaxNumRewrites(maxNumRewrites);
RewritePatternSet owningPatterns(context);
for (auto *dialect : context->getLoadedDialects())

View File

@@ -416,7 +416,8 @@ GreedyPatternRewriteDriver::GreedyPatternRewriteDriver(
// clang-format off
, expensiveChecks(
/*driver=*/this,
/*topLevel=*/config.scope ? config.scope->getParentOp() : nullptr)
/*topLevel=*/config.getScope() ? config.getScope()->getParentOp()
: nullptr)
// clang-format on
#endif // MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS
{
@@ -455,8 +456,8 @@ bool GreedyPatternRewriteDriver::processWorklist() {
bool changed = false;
int64_t numRewrites = 0;
while (!worklist.empty() &&
(numRewrites < config.maxNumRewrites ||
config.maxNumRewrites == GreedyRewriteConfig::kNoLimit)) {
(numRewrites < config.getMaxNumRewrites() ||
config.getMaxNumRewrites() == GreedyRewriteConfig::kNoLimit)) {
auto *op = worklist.pop();
LLVM_DEBUG({
@@ -488,7 +489,7 @@ bool GreedyPatternRewriteDriver::processWorklist() {
// infinite folding loop, as every constant op would be folded to an
// Attribute and then immediately be rematerialized as a constant op, which
// is then put on the worklist.
if (config.fold && !op->hasTrait<OpTrait::ConstantLike>()) {
if (config.isFoldingEnabled() && !op->hasTrait<OpTrait::ConstantLike>()) {
SmallVector<OpFoldResult> foldResults;
if (succeeded(op->fold(foldResults))) {
LLVM_DEBUG(logResultWithLine("success", "operation was folded"));
@@ -574,21 +575,21 @@ bool GreedyPatternRewriteDriver::processWorklist() {
logger.getOStream() << ")' {\n";
logger.indent();
});
if (config.listener)
config.listener->notifyPatternBegin(pattern, op);
if (RewriterBase::Listener *listener = config.getListener())
listener->notifyPatternBegin(pattern, op);
return true;
};
function_ref<bool(const Pattern &)> canApply = canApplyCallback;
auto onFailureCallback = [&](const Pattern &pattern) {
LLVM_DEBUG(logResult("failure", "pattern failed to match"));
if (config.listener)
config.listener->notifyPatternEnd(pattern, failure());
if (RewriterBase::Listener *listener = config.getListener())
listener->notifyPatternEnd(pattern, failure());
};
function_ref<void(const Pattern &)> onFailure = onFailureCallback;
auto onSuccessCallback = [&](const Pattern &pattern) {
LLVM_DEBUG(logResult("success", "pattern applied successfully"));
if (config.listener)
config.listener->notifyPatternEnd(pattern, success());
if (RewriterBase::Listener *listener = config.getListener())
listener->notifyPatternEnd(pattern, success());
return success();
};
function_ref<LogicalResult(const Pattern &)> onSuccess = onSuccessCallback;
@@ -596,7 +597,7 @@ bool GreedyPatternRewriteDriver::processWorklist() {
#ifdef NDEBUG
// Optimization: PatternApplicator callbacks are not needed when running in
// optimized mode and without a listener.
if (!config.listener) {
if (!config.getListener()) {
canApply = nullptr;
onFailure = nullptr;
onSuccess = nullptr;
@@ -604,8 +605,8 @@ bool GreedyPatternRewriteDriver::processWorklist() {
#endif // NDEBUG
#if MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS
if (config.scope) {
expensiveChecks.computeFingerPrints(config.scope->getParentOp());
if (config.getScope()) {
expensiveChecks.computeFingerPrints(config.getScope()->getParentOp());
}
auto clearFingerprints =
llvm::make_scope_exit([&]() { expensiveChecks.clear(); });
@@ -640,7 +641,7 @@ void GreedyPatternRewriteDriver::addToWorklist(Operation *op) {
do {
ancestors.push_back(op);
region = op->getParentRegion();
if (config.scope == region) {
if (config.getScope() == region) {
// Scope (can be `nullptr`) was reached. Stop traveral and enqueue ops.
for (Operation *op : ancestors)
addSingleOpToWorklist(op);
@@ -652,20 +653,20 @@ void GreedyPatternRewriteDriver::addToWorklist(Operation *op) {
}
void GreedyPatternRewriteDriver::addSingleOpToWorklist(Operation *op) {
if (config.strictMode == GreedyRewriteStrictness::AnyOp ||
if (config.getStrictness() == GreedyRewriteStrictness::AnyOp ||
strictModeFilteredOps.contains(op))
worklist.push(op);
}
void GreedyPatternRewriteDriver::notifyBlockInserted(
Block *block, Region *previous, Region::iterator previousIt) {
if (config.listener)
config.listener->notifyBlockInserted(block, previous, previousIt);
if (RewriterBase::Listener *listener = config.getListener())
listener->notifyBlockInserted(block, previous, previousIt);
}
void GreedyPatternRewriteDriver::notifyBlockErased(Block *block) {
if (config.listener)
config.listener->notifyBlockErased(block);
if (RewriterBase::Listener *listener = config.getListener())
listener->notifyBlockErased(block);
}
void GreedyPatternRewriteDriver::notifyOperationInserted(
@@ -674,9 +675,9 @@ void GreedyPatternRewriteDriver::notifyOperationInserted(
logger.startLine() << "** Insert : '" << op->getName() << "'(" << op
<< ")\n";
});
if (config.listener)
config.listener->notifyOperationInserted(op, previous);
if (config.strictMode == GreedyRewriteStrictness::ExistingAndNewOps)
if (RewriterBase::Listener *listener = config.getListener())
listener->notifyOperationInserted(op, previous);
if (config.getStrictness() == GreedyRewriteStrictness::ExistingAndNewOps)
strictModeFilteredOps.insert(op);
addToWorklist(op);
}
@@ -686,8 +687,8 @@ void GreedyPatternRewriteDriver::notifyOperationModified(Operation *op) {
logger.startLine() << "** Modified: '" << op->getName() << "'(" << op
<< ")\n";
});
if (config.listener)
config.listener->notifyOperationModified(op);
if (RewriterBase::Listener *listener = config.getListener())
listener->notifyOperationModified(op);
addToWorklist(op);
}
@@ -736,18 +737,18 @@ void GreedyPatternRewriteDriver::notifyOperationErased(Operation *op) {
// the part of the IR that is taken into account for the "expensive checks".
// A greedy pattern rewrite is not allowed to erase the parent op of the scope
// region, as that would break the worklist handling and the expensive checks.
if (config.scope && config.scope->getParentOp() == op)
if (Region *scope = config.getScope(); scope->getParentOp() == op)
llvm_unreachable(
"scope region must not be erased during greedy pattern rewrite");
#endif // NDEBUG
if (config.listener)
config.listener->notifyOperationErased(op);
if (RewriterBase::Listener *listener = config.getListener())
listener->notifyOperationErased(op);
addOperandsToWorklist(op);
worklist.remove(op);
if (config.strictMode != GreedyRewriteStrictness::AnyOp)
if (config.getStrictness() != GreedyRewriteStrictness::AnyOp)
strictModeFilteredOps.erase(op);
}
@@ -757,8 +758,8 @@ void GreedyPatternRewriteDriver::notifyOperationReplaced(
logger.startLine() << "** Replace : '" << op->getName() << "'(" << op
<< ")\n";
});
if (config.listener)
config.listener->notifyOperationReplaced(op, replacement);
if (RewriterBase::Listener *listener = config.getListener())
listener->notifyOperationReplaced(op, replacement);
}
void GreedyPatternRewriteDriver::notifyMatchFailure(
@@ -768,8 +769,8 @@ void GreedyPatternRewriteDriver::notifyMatchFailure(
reasonCallback(diag);
logger.startLine() << "** Match Failure : " << diag.str() << "\n";
});
if (config.listener)
config.listener->notifyMatchFailure(loc, reasonCallback);
if (RewriterBase::Listener *listener = config.getListener())
listener->notifyMatchFailure(loc, reasonCallback);
}
//===----------------------------------------------------------------------===//
@@ -800,7 +801,7 @@ RegionPatternRewriteDriver::RegionPatternRewriteDriver(
const GreedyRewriteConfig &config, Region &region)
: GreedyPatternRewriteDriver(ctx, patterns, config), region(region) {
// Populate strict mode ops.
if (config.strictMode != GreedyRewriteStrictness::AnyOp) {
if (config.getStrictness() != GreedyRewriteStrictness::AnyOp) {
region.walk([&](Operation *op) { strictModeFilteredOps.insert(op); });
}
}
@@ -829,8 +830,8 @@ LogicalResult RegionPatternRewriteDriver::simplify(bool *changed) && {
MLIRContext *ctx = rewriter.getContext();
do {
// Check if the iteration limit was reached.
if (++iteration > config.maxIterations &&
config.maxIterations != GreedyRewriteConfig::kNoLimit)
if (++iteration > config.getMaxIterations() &&
config.getMaxIterations() != GreedyRewriteConfig::kNoLimit)
break;
// New iteration: start with an empty worklist.
@@ -849,16 +850,16 @@ LogicalResult RegionPatternRewriteDriver::simplify(bool *changed) && {
return false;
};
if (!config.useTopDownTraversal) {
if (!config.getUseTopDownTraversal()) {
// Add operations to the worklist in postorder.
region.walk([&](Operation *op) {
if (!config.cseConstants || !insertKnownConstant(op))
if (!config.isConstantCSEEnabled() || !insertKnownConstant(op))
addToWorklist(op);
});
} else {
// Add all nested operations to the worklist in preorder.
region.walk<WalkOrder::PreOrder>([&](Operation *op) {
if (!config.cseConstants || !insertKnownConstant(op)) {
if (!config.isConstantCSEEnabled() || !insertKnownConstant(op)) {
addToWorklist(op);
return WalkResult::advance();
}
@@ -875,11 +876,11 @@ LogicalResult RegionPatternRewriteDriver::simplify(bool *changed) && {
// After applying patterns, make sure that the CFG of each of the
// regions is kept up to date.
if (config.enableRegionSimplification !=
if (config.getRegionSimplificationLevel() !=
GreedySimplifyRegionLevel::Disabled) {
continueRewrites |= succeeded(simplifyRegions(
rewriter, region,
/*mergeBlocks=*/config.enableRegionSimplification ==
/*mergeBlocks=*/config.getRegionSimplificationLevel() ==
GreedySimplifyRegionLevel::Aggressive));
}
},
@@ -904,11 +905,11 @@ mlir::applyPatternsGreedily(Region &region,
"patterns can only be applied to operations IsolatedFromAbove");
// Set scope if not specified.
if (!config.scope)
config.scope = &region;
if (!config.getScope())
config.setScope(&region);
#if MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS
if (failed(verify(config.scope->getParentOp())))
if (failed(verify(config.getScope()->getParentOp())))
llvm::report_fatal_error(
"greedy pattern rewriter input IR failed to verify");
#endif // MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS
@@ -919,7 +920,7 @@ mlir::applyPatternsGreedily(Region &region,
LogicalResult converged = std::move(driver).simplify(changed);
LLVM_DEBUG(if (failed(converged)) {
llvm::dbgs() << "The pattern rewrite did not converge after scanning "
<< config.maxIterations << " times\n";
<< config.getMaxIterations() << " times\n";
});
return converged;
}
@@ -960,7 +961,7 @@ MultiOpPatternRewriteDriver::MultiOpPatternRewriteDriver(
llvm::SmallDenseSet<Operation *, 4> *survivingOps)
: GreedyPatternRewriteDriver(ctx, patterns, config),
survivingOps(survivingOps) {
if (config.strictMode != GreedyRewriteStrictness::AnyOp)
if (config.getStrictness() != GreedyRewriteStrictness::AnyOp)
strictModeFilteredOps.insert_range(ops);
if (survivingOps) {
@@ -1024,22 +1025,22 @@ LogicalResult mlir::applyOpPatternsGreedily(
}
// Determine scope of rewrite.
if (!config.scope) {
if (!config.getScope()) {
// Compute scope if none was provided. The scope will remain `nullptr` if
// there is a top-level op among `ops`.
config.scope = findCommonAncestor(ops);
config.setScope(findCommonAncestor(ops));
} else {
// If a scope was provided, make sure that all ops are in scope.
#ifndef NDEBUG
bool allOpsInScope = llvm::all_of(ops, [&](Operation *op) {
return static_cast<bool>(config.scope->findAncestorOpInRegion(*op));
return static_cast<bool>(config.getScope()->findAncestorOpInRegion(*op));
});
assert(allOpsInScope && "ops must be within the specified scope");
#endif // NDEBUG
}
#if MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS
if (config.scope && failed(verify(config.scope->getParentOp())))
if (config.getScope() && failed(verify(config.getScope()->getParentOp())))
llvm::report_fatal_error(
"greedy pattern rewriter input IR failed to verify");
#endif // MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS
@@ -1054,7 +1055,7 @@ LogicalResult mlir::applyOpPatternsGreedily(
*allErased = surviving.empty();
LLVM_DEBUG(if (failed(converged)) {
llvm::dbgs() << "The pattern rewrite did not converge after "
<< config.maxNumRewrites << " rewrites";
<< config.getMaxNumRewrites() << " rewrites";
});
return converged;
}

View File

@@ -144,7 +144,7 @@ void TestAffineDataCopy::runOnOperation() {
}
}
GreedyRewriteConfig config;
config.strictMode = GreedyRewriteStrictness::ExistingAndNewOps;
config.setStrictness(GreedyRewriteStrictness::ExistingAndNewOps);
(void)applyOpPatternsGreedily(copyOps, std::move(patterns), config);
}

View File

@@ -386,26 +386,26 @@ struct TestGreedyPatternDriver
patterns.insert<IncrementIntAttribute<3>>(&getContext());
GreedyRewriteConfig config;
config.useTopDownTraversal = this->useTopDownTraversal;
config.maxIterations = this->maxIterations;
config.fold = this->fold;
config.cseConstants = this->cseConstants;
config.setUseTopDownTraversal(useTopDownTraversal)
.setMaxIterations(this->maxIterations)
.enableFolding(this->fold)
.enableConstantCSE(this->cseConstants);
(void)applyPatternsGreedily(getOperation(), std::move(patterns), config);
}
Option<bool> useTopDownTraversal{
*this, "top-down",
llvm::cl::desc("Seed the worklist in general top-down order"),
llvm::cl::init(GreedyRewriteConfig().useTopDownTraversal)};
llvm::cl::init(GreedyRewriteConfig().getUseTopDownTraversal())};
Option<int> maxIterations{
*this, "max-iterations",
llvm::cl::desc("Max. iterations in the GreedyRewriteConfig"),
llvm::cl::init(GreedyRewriteConfig().maxIterations)};
llvm::cl::init(GreedyRewriteConfig().getMaxIterations())};
Option<bool> fold{*this, "fold", llvm::cl::desc("Whether to fold"),
llvm::cl::init(GreedyRewriteConfig().fold)};
Option<bool> cseConstants{*this, "cse-constants",
llvm::cl::desc("Whether to CSE constants"),
llvm::cl::init(GreedyRewriteConfig().cseConstants)};
llvm::cl::init(GreedyRewriteConfig().isFoldingEnabled())};
Option<bool> cseConstants{
*this, "cse-constants", llvm::cl::desc("Whether to CSE constants"),
llvm::cl::init(GreedyRewriteConfig().isConstantCSEEnabled())};
};
struct DumpNotifications : public RewriterBase::Listener {
@@ -501,13 +501,13 @@ public:
DumpNotifications dumpNotifications;
GreedyRewriteConfig config;
config.listener = &dumpNotifications;
config.setListener(&dumpNotifications);
if (strictMode == "AnyOp") {
config.strictMode = GreedyRewriteStrictness::AnyOp;
config.setStrictness(GreedyRewriteStrictness::AnyOp);
} else if (strictMode == "ExistingAndNewOps") {
config.strictMode = GreedyRewriteStrictness::ExistingAndNewOps;
config.setStrictness(GreedyRewriteStrictness::ExistingAndNewOps);
} else if (strictMode == "ExistingOps") {
config.strictMode = GreedyRewriteStrictness::ExistingOps;
config.setStrictness(GreedyRewriteStrictness::ExistingOps);
} else {
llvm_unreachable("invalid strictness option");
}