The greedy rewriter is used in many different flows and offers a lot of
convenience (worklist management, debugging actions, tracing, etc.). But
it combines two kinds of greedy behavior: 1) how ops are matched, and 2)
folding wherever it can.
These are independent forms of greediness, and coupling them leads to
inefficiency. For example, one may need to create different phases in
lowering and be required to apply patterns in a specific order, split
across different passes. Using the driver, one ends up needlessly retrying
folding (i.e., running multiple rounds of folding attempts) where a single
final run would have sufficed.
Of course, folks can locally avoid this behavior by just building their
own driver, but this is also a commonly requested feature that folks keep
working around locally in suboptimal ways.
For downstream users, there should be no behavioral change. Updating
from the deprecated API should just be a find-and-replace (e.g., something
of the `find ./ -type f -exec sed -i
's|applyPatternsAndFoldGreedily|applyPatternsGreedily|g' {} \;` variety),
as the API arguments haven't changed between the two.
326 lines · 13 KiB · C++
//===- Rewrite.cpp - C API for Rewrite Patterns ---------------------------===//
|
|
//
|
|
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
|
// See https://llvm.org/LICENSE.txt for license information.
|
|
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
|
//
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
#include "mlir-c/Rewrite.h"
|
|
|
|
#include "mlir-c/Transforms.h"
|
|
#include "mlir/CAPI/IR.h"
|
|
#include "mlir/CAPI/Rewrite.h"
|
|
#include "mlir/CAPI/Support.h"
|
|
#include "mlir/CAPI/Wrap.h"
|
|
#include "mlir/IR/PatternMatch.h"
|
|
#include "mlir/Rewrite/FrozenRewritePatternSet.h"
|
|
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
|
|
|
|
using namespace mlir;
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
/// RewriterBase API inherited from OpBuilder
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
/// Returns the MLIRContext that `rewriter` builds IR into.
MlirContext mlirRewriterBaseGetContext(MlirRewriterBase rewriter) {
  MLIRContext *context = unwrap(rewriter)->getContext();
  return wrap(context);
}
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
/// Insertion points methods
|
|
|
|
void mlirRewriterBaseClearInsertionPoint(MlirRewriterBase rewriter) {
|
|
unwrap(rewriter)->clearInsertionPoint();
|
|
}
|
|
|
|
void mlirRewriterBaseSetInsertionPointBefore(MlirRewriterBase rewriter,
|
|
MlirOperation op) {
|
|
unwrap(rewriter)->setInsertionPoint(unwrap(op));
|
|
}
|
|
|
|
void mlirRewriterBaseSetInsertionPointAfter(MlirRewriterBase rewriter,
|
|
MlirOperation op) {
|
|
unwrap(rewriter)->setInsertionPointAfter(unwrap(op));
|
|
}
|
|
|
|
void mlirRewriterBaseSetInsertionPointAfterValue(MlirRewriterBase rewriter,
|
|
MlirValue value) {
|
|
unwrap(rewriter)->setInsertionPointAfterValue(unwrap(value));
|
|
}
|
|
|
|
void mlirRewriterBaseSetInsertionPointToStart(MlirRewriterBase rewriter,
|
|
MlirBlock block) {
|
|
unwrap(rewriter)->setInsertionPointToStart(unwrap(block));
|
|
}
|
|
|
|
void mlirRewriterBaseSetInsertionPointToEnd(MlirRewriterBase rewriter,
|
|
MlirBlock block) {
|
|
unwrap(rewriter)->setInsertionPointToEnd(unwrap(block));
|
|
}
|
|
|
|
MlirBlock mlirRewriterBaseGetInsertionBlock(MlirRewriterBase rewriter) {
|
|
return wrap(unwrap(rewriter)->getInsertionBlock());
|
|
}
|
|
|
|
MlirBlock mlirRewriterBaseGetBlock(MlirRewriterBase rewriter) {
|
|
return wrap(unwrap(rewriter)->getBlock());
|
|
}
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
/// Block and operation creation/insertion/cloning
|
|
|
|
MlirBlock mlirRewriterBaseCreateBlockBefore(MlirRewriterBase rewriter,
|
|
MlirBlock insertBefore,
|
|
intptr_t nArgTypes,
|
|
MlirType const *argTypes,
|
|
MlirLocation const *locations) {
|
|
SmallVector<Type, 4> args;
|
|
ArrayRef<Type> unwrappedArgs = unwrapList(nArgTypes, argTypes, args);
|
|
SmallVector<Location, 4> locs;
|
|
ArrayRef<Location> unwrappedLocs = unwrapList(nArgTypes, locations, locs);
|
|
return wrap(unwrap(rewriter)->createBlock(unwrap(insertBefore), unwrappedArgs,
|
|
unwrappedLocs));
|
|
}
|
|
|
|
MlirOperation mlirRewriterBaseInsert(MlirRewriterBase rewriter,
|
|
MlirOperation op) {
|
|
return wrap(unwrap(rewriter)->insert(unwrap(op)));
|
|
}
|
|
|
|
// Other methods of OpBuilder
|
|
|
|
MlirOperation mlirRewriterBaseClone(MlirRewriterBase rewriter,
|
|
MlirOperation op) {
|
|
return wrap(unwrap(rewriter)->clone(*unwrap(op)));
|
|
}
|
|
|
|
MlirOperation mlirRewriterBaseCloneWithoutRegions(MlirRewriterBase rewriter,
|
|
MlirOperation op) {
|
|
return wrap(unwrap(rewriter)->cloneWithoutRegions(*unwrap(op)));
|
|
}
|
|
|
|
void mlirRewriterBaseCloneRegionBefore(MlirRewriterBase rewriter,
|
|
MlirRegion region, MlirBlock before) {
|
|
|
|
unwrap(rewriter)->cloneRegionBefore(*unwrap(region), unwrap(before));
|
|
}
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
/// RewriterBase API
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
void mlirRewriterBaseInlineRegionBefore(MlirRewriterBase rewriter,
|
|
MlirRegion region, MlirBlock before) {
|
|
unwrap(rewriter)->inlineRegionBefore(*unwrap(region), unwrap(before));
|
|
}
|
|
|
|
void mlirRewriterBaseReplaceOpWithValues(MlirRewriterBase rewriter,
|
|
MlirOperation op, intptr_t nValues,
|
|
MlirValue const *values) {
|
|
SmallVector<Value, 4> vals;
|
|
ArrayRef<Value> unwrappedVals = unwrapList(nValues, values, vals);
|
|
unwrap(rewriter)->replaceOp(unwrap(op), unwrappedVals);
|
|
}
|
|
|
|
void mlirRewriterBaseReplaceOpWithOperation(MlirRewriterBase rewriter,
|
|
MlirOperation op,
|
|
MlirOperation newOp) {
|
|
unwrap(rewriter)->replaceOp(unwrap(op), unwrap(newOp));
|
|
}
|
|
|
|
void mlirRewriterBaseEraseOp(MlirRewriterBase rewriter, MlirOperation op) {
|
|
unwrap(rewriter)->eraseOp(unwrap(op));
|
|
}
|
|
|
|
void mlirRewriterBaseEraseBlock(MlirRewriterBase rewriter, MlirBlock block) {
|
|
unwrap(rewriter)->eraseBlock(unwrap(block));
|
|
}
|
|
|
|
void mlirRewriterBaseInlineBlockBefore(MlirRewriterBase rewriter,
|
|
MlirBlock source, MlirOperation op,
|
|
intptr_t nArgValues,
|
|
MlirValue const *argValues) {
|
|
SmallVector<Value, 4> vals;
|
|
ArrayRef<Value> unwrappedVals = unwrapList(nArgValues, argValues, vals);
|
|
|
|
unwrap(rewriter)->inlineBlockBefore(unwrap(source), unwrap(op),
|
|
unwrappedVals);
|
|
}
|
|
|
|
void mlirRewriterBaseMergeBlocks(MlirRewriterBase rewriter, MlirBlock source,
|
|
MlirBlock dest, intptr_t nArgValues,
|
|
MlirValue const *argValues) {
|
|
SmallVector<Value, 4> args;
|
|
ArrayRef<Value> unwrappedArgs = unwrapList(nArgValues, argValues, args);
|
|
unwrap(rewriter)->mergeBlocks(unwrap(source), unwrap(dest), unwrappedArgs);
|
|
}
|
|
|
|
void mlirRewriterBaseMoveOpBefore(MlirRewriterBase rewriter, MlirOperation op,
|
|
MlirOperation existingOp) {
|
|
unwrap(rewriter)->moveOpBefore(unwrap(op), unwrap(existingOp));
|
|
}
|
|
|
|
void mlirRewriterBaseMoveOpAfter(MlirRewriterBase rewriter, MlirOperation op,
|
|
MlirOperation existingOp) {
|
|
unwrap(rewriter)->moveOpAfter(unwrap(op), unwrap(existingOp));
|
|
}
|
|
|
|
void mlirRewriterBaseMoveBlockBefore(MlirRewriterBase rewriter, MlirBlock block,
|
|
MlirBlock existingBlock) {
|
|
unwrap(rewriter)->moveBlockBefore(unwrap(block), unwrap(existingBlock));
|
|
}
|
|
|
|
void mlirRewriterBaseStartOpModification(MlirRewriterBase rewriter,
|
|
MlirOperation op) {
|
|
unwrap(rewriter)->startOpModification(unwrap(op));
|
|
}
|
|
|
|
void mlirRewriterBaseFinalizeOpModification(MlirRewriterBase rewriter,
|
|
MlirOperation op) {
|
|
unwrap(rewriter)->finalizeOpModification(unwrap(op));
|
|
}
|
|
|
|
void mlirRewriterBaseCancelOpModification(MlirRewriterBase rewriter,
|
|
MlirOperation op) {
|
|
unwrap(rewriter)->cancelOpModification(unwrap(op));
|
|
}
|
|
|
|
void mlirRewriterBaseReplaceAllUsesWith(MlirRewriterBase rewriter,
|
|
MlirValue from, MlirValue to) {
|
|
unwrap(rewriter)->replaceAllUsesWith(unwrap(from), unwrap(to));
|
|
}
|
|
|
|
void mlirRewriterBaseReplaceAllValueRangeUsesWith(MlirRewriterBase rewriter,
|
|
intptr_t nValues,
|
|
MlirValue const *from,
|
|
MlirValue const *to) {
|
|
SmallVector<Value, 4> fromVals;
|
|
ArrayRef<Value> unwrappedFromVals = unwrapList(nValues, from, fromVals);
|
|
SmallVector<Value, 4> toVals;
|
|
ArrayRef<Value> unwrappedToVals = unwrapList(nValues, to, toVals);
|
|
unwrap(rewriter)->replaceAllUsesWith(unwrappedFromVals, unwrappedToVals);
|
|
}
|
|
|
|
void mlirRewriterBaseReplaceAllOpUsesWithValueRange(MlirRewriterBase rewriter,
|
|
MlirOperation from,
|
|
intptr_t nTo,
|
|
MlirValue const *to) {
|
|
SmallVector<Value, 4> toVals;
|
|
ArrayRef<Value> unwrappedToVals = unwrapList(nTo, to, toVals);
|
|
unwrap(rewriter)->replaceAllOpUsesWith(unwrap(from), unwrappedToVals);
|
|
}
|
|
|
|
void mlirRewriterBaseReplaceAllOpUsesWithOperation(MlirRewriterBase rewriter,
|
|
MlirOperation from,
|
|
MlirOperation to) {
|
|
unwrap(rewriter)->replaceAllOpUsesWith(unwrap(from), unwrap(to));
|
|
}
|
|
|
|
void mlirRewriterBaseReplaceOpUsesWithinBlock(MlirRewriterBase rewriter,
|
|
MlirOperation op,
|
|
intptr_t nNewValues,
|
|
MlirValue const *newValues,
|
|
MlirBlock block) {
|
|
SmallVector<Value, 4> vals;
|
|
ArrayRef<Value> unwrappedVals = unwrapList(nNewValues, newValues, vals);
|
|
unwrap(rewriter)->replaceOpUsesWithinBlock(unwrap(op), unwrappedVals,
|
|
unwrap(block));
|
|
}
|
|
|
|
void mlirRewriterBaseReplaceAllUsesExcept(MlirRewriterBase rewriter,
|
|
MlirValue from, MlirValue to,
|
|
MlirOperation exceptedUser) {
|
|
unwrap(rewriter)->replaceAllUsesExcept(unwrap(from), unwrap(to),
|
|
unwrap(exceptedUser));
|
|
}
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
/// IRRewriter API
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
MlirRewriterBase mlirIRRewriterCreate(MlirContext context) {
|
|
return wrap(new IRRewriter(unwrap(context)));
|
|
}
|
|
|
|
MlirRewriterBase mlirIRRewriterCreateFromOp(MlirOperation op) {
|
|
return wrap(new IRRewriter(unwrap(op)));
|
|
}
|
|
|
|
void mlirIRRewriterDestroy(MlirRewriterBase rewriter) {
|
|
delete static_cast<IRRewriter *>(unwrap(rewriter));
|
|
}
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
/// RewritePatternSet and FrozenRewritePatternSet API
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
inline mlir::RewritePatternSet &unwrap(MlirRewritePatternSet module) {
|
|
assert(module.ptr && "unexpected null module");
|
|
return *(static_cast<mlir::RewritePatternSet *>(module.ptr));
|
|
}
|
|
|
|
inline MlirRewritePatternSet wrap(mlir::RewritePatternSet *module) {
|
|
return {module};
|
|
}
|
|
|
|
inline mlir::FrozenRewritePatternSet *
|
|
unwrap(MlirFrozenRewritePatternSet module) {
|
|
assert(module.ptr && "unexpected null module");
|
|
return static_cast<mlir::FrozenRewritePatternSet *>(module.ptr);
|
|
}
|
|
|
|
inline MlirFrozenRewritePatternSet wrap(mlir::FrozenRewritePatternSet *module) {
|
|
return {module};
|
|
}
|
|
|
|
MlirFrozenRewritePatternSet mlirFreezeRewritePattern(MlirRewritePatternSet op) {
|
|
auto *m = new mlir::FrozenRewritePatternSet(std::move(unwrap(op)));
|
|
op.ptr = nullptr;
|
|
return wrap(m);
|
|
}
|
|
|
|
void mlirFrozenRewritePatternSetDestroy(MlirFrozenRewritePatternSet op) {
|
|
delete unwrap(op);
|
|
op.ptr = nullptr;
|
|
}
|
|
|
|
/// Greedily applies the frozen `patterns` to the module `op` and returns
/// the driver's result.
///
/// The C-level name keeps the historical "AndFold" spelling; it forwards to
/// mlir::applyPatternsGreedily, the non-deprecated spelling of the same
/// driver. The MlirGreedyRewriteDriverConfig argument is currently ignored:
/// no config is forwarded, so the driver's defaults apply.
MlirLogicalResult
mlirApplyPatternsAndFoldGreedily(MlirModule op,
                                 MlirFrozenRewritePatternSet patterns,
                                 MlirGreedyRewriteDriverConfig) {
  const FrozenRewritePatternSet &frozen = *unwrap(patterns);
  auto outcome = mlir::applyPatternsGreedily(unwrap(op), frozen);
  return wrap(outcome);
}
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
/// PDLPatternModule API
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
#if MLIR_ENABLE_PDL_IN_PATTERNMATCH
/// Unwraps a C handle into the PDLPatternModule it points to.
/// Asserts that the handle is non-null.
inline mlir::PDLPatternModule *unwrap(MlirPDLPatternModule module) {
  assert(module.ptr && "unexpected null PDL pattern module");
  return static_cast<mlir::PDLPatternModule *>(module.ptr);
}

/// Wraps a PDLPatternModule pointer in a C handle; ownership is unchanged.
inline MlirPDLPatternModule wrap(mlir::PDLPatternModule *module) {
  return {module};
}

/// Creates a heap-allocated PDLPatternModule from the module `op`.
///
/// `op` is handed to an OwningOpRef, so the PDLPatternModule takes ownership
/// of it; the caller must not destroy `op` separately. The result must be
/// released with mlirPDLPatternModuleDestroy.
MlirPDLPatternModule mlirPDLPatternModuleFromModule(MlirModule op) {
  return wrap(new mlir::PDLPatternModule(
      mlir::OwningOpRef<mlir::ModuleOp>(unwrap(op))));
}

/// Deletes a PDLPatternModule created by mlirPDLPatternModuleFromModule.
/// Passing a null handle trips the assertion in unwrap.
void mlirPDLPatternModuleDestroy(MlirPDLPatternModule op) {
  delete unwrap(op);
}

/// Builds a new heap-allocated RewritePatternSet by moving the patterns out
/// of the PDLPatternModule `op`.
///
/// `op` is left in a moved-from state, but its heap object is not deleted
/// here. The handle is passed by value, so it cannot be nulled out for the
/// caller.
/// NOTE(review): callers presumably still need mlirPDLPatternModuleDestroy
/// on `op` to avoid a leak; confirm.
MlirRewritePatternSet
mlirRewritePatternSetFromPDLPatternModule(MlirPDLPatternModule op) {
  auto *patterns = new mlir::RewritePatternSet(std::move(*unwrap(op)));
  return wrap(patterns);
}
#endif // MLIR_ENABLE_PDL_IN_PATTERNMATCH
|