diff --git a/mlir/examples/Linalg/Linalg1/include/linalg1/ConvertToLLVMDialect.h b/mlir/examples/Linalg/Linalg1/include/linalg1/ConvertToLLVMDialect.h index 2d4a4a2e6c29..8a5eddda6b02 100644 --- a/mlir/examples/Linalg/Linalg1/include/linalg1/ConvertToLLVMDialect.h +++ b/mlir/examples/Linalg/Linalg1/include/linalg1/ConvertToLLVMDialect.h @@ -31,7 +31,7 @@ class MLIRContext; class ModuleOp; class RewritePattern; class Type; -using OwningRewritePatternList = std::vector>; +class OwningRewritePatternList; namespace LLVM { class LLVMType; } // end namespace LLVM diff --git a/mlir/examples/Linalg/Linalg1/lib/ConvertToLLVMDialect.cpp b/mlir/examples/Linalg/Linalg1/lib/ConvertToLLVMDialect.cpp index 411a7afb2844..58e61596153f 100644 --- a/mlir/examples/Linalg/Linalg1/lib/ConvertToLLVMDialect.cpp +++ b/mlir/examples/Linalg/Linalg1/lib/ConvertToLLVMDialect.cpp @@ -395,8 +395,8 @@ public: void linalg::populateLinalg1ToLLVMConversionPatterns( mlir::OwningRewritePatternList &patterns, mlir::MLIRContext *context) { - RewriteListBuilder::build(patterns, context); + patterns.insert(context); } namespace { diff --git a/mlir/examples/Linalg/Linalg3/lib/ConvertToLLVMDialect.cpp b/mlir/examples/Linalg/Linalg3/lib/ConvertToLLVMDialect.cpp index 8c77737ff3d9..e4a401ea70fa 100644 --- a/mlir/examples/Linalg/Linalg3/lib/ConvertToLLVMDialect.cpp +++ b/mlir/examples/Linalg/Linalg3/lib/ConvertToLLVMDialect.cpp @@ -145,8 +145,7 @@ struct LinalgTypeConverter : public LLVMTypeConverter { // coverters to the list. static void populateLinalg3ToLLVMConversionPatterns( mlir::OwningRewritePatternList &patterns, mlir::MLIRContext *context) { - RewriteListBuilder::build(patterns, - context); + patterns.insert(context); } LogicalResult linalg::convertLinalg3ToLLVM(ModuleOp module) { diff --git a/mlir/examples/Linalg/Linalg3/lib/Transforms.cpp b/mlir/examples/Linalg/Linalg3/lib/Transforms.cpp index d81eec0a3705..8f97f4317f71 100644 --- a/mlir/examples/Linalg/Linalg3/lib/Transforms.cpp +++ b/mlir/examples/Linalg/Linalg3/lib/Transforms.cpp @@ -261,8 +261,8 @@ struct LowerLinalgLoadStorePass void runOnFunction() { OwningRewritePatternList patterns; auto *context = &getContext(); - patterns.push_back(llvm::make_unique>(context)); - patterns.push_back(llvm::make_unique>(context)); + patterns.insert, Rewriter>( + context); applyPatternsGreedily(getFunction(), std::move(patterns)); } }; diff --git a/mlir/examples/toy/Ch4/mlir/ToyCombine.cpp b/mlir/examples/toy/Ch4/mlir/ToyCombine.cpp index 92e80d2dfa3e..b89cb85ff06d 100644 --- a/mlir/examples/toy/Ch4/mlir/ToyCombine.cpp +++ b/mlir/examples/toy/Ch4/mlir/ToyCombine.cpp @@ -142,14 +142,14 @@ struct SimplifyNullReshape : public mlir::OpRewritePattern { // Register our patterns for rewrite by the Canonicalization framework. void TransposeOp::getCanonicalizationPatterns( mlir::OwningRewritePatternList &results, mlir::MLIRContext *context) { - results.push_back(llvm::make_unique(context)); + results.insert(context); } // Register our patterns for rewrite by the Canonicalization framework. void ReshapeOp::getCanonicalizationPatterns( mlir::OwningRewritePatternList &results, mlir::MLIRContext *context) { - mlir::RewriteListBuilder::build(results, context); + results.insert(context); } } // namespace toy diff --git a/mlir/examples/toy/Ch5/mlir/EarlyLowering.cpp b/mlir/examples/toy/Ch5/mlir/EarlyLowering.cpp index f3463ba4e0f3..72bc2891db6f 100644 --- a/mlir/examples/toy/Ch5/mlir/EarlyLowering.cpp +++ b/mlir/examples/toy/Ch5/mlir/EarlyLowering.cpp @@ -132,7 +132,7 @@ struct EarlyLoweringPass : public FunctionPass { target.addLegalOp(); OwningRewritePatternList patterns; - RewriteListBuilder::build(patterns, &getContext()); + patterns.insert(&getContext()); if (failed(applyPartialConversion(getFunction(), target, std::move(patterns)))) { emitError(mlir::UnknownLoc::get(&getContext()), "Error lowering Toy\n"); diff --git a/mlir/examples/toy/Ch5/mlir/LateLowering.cpp b/mlir/examples/toy/Ch5/mlir/LateLowering.cpp index 5a01122c28ad..8b2cc214a558 100644 --- a/mlir/examples/toy/Ch5/mlir/LateLowering.cpp +++ b/mlir/examples/toy/Ch5/mlir/LateLowering.cpp @@ -352,9 +352,9 @@ struct LateLoweringPass : public ModulePass { void runOnModule() override { ToyTypeConverter typeConverter; OwningRewritePatternList toyPatterns; - RewriteListBuilder::build(toyPatterns, &getContext()); + toyPatterns.insert( + &getContext()); mlir::populateFuncOpTypeConversionPattern(toyPatterns, &getContext(), typeConverter); diff --git a/mlir/examples/toy/Ch5/mlir/ToyCombine.cpp b/mlir/examples/toy/Ch5/mlir/ToyCombine.cpp index 8e9e8ebcd558..4798ad188d15 100644 --- a/mlir/examples/toy/Ch5/mlir/ToyCombine.cpp +++ b/mlir/examples/toy/Ch5/mlir/ToyCombine.cpp @@ -144,14 +144,14 @@ struct SimplifyNullReshape : public mlir::OpRewritePattern { // Register our patterns for rewrite by the Canonicalization framework. void TransposeOp::getCanonicalizationPatterns( mlir::OwningRewritePatternList &results, mlir::MLIRContext *context) { - results.push_back(llvm::make_unique(context)); + results.insert(context); } // Register our patterns for rewrite by the Canonicalization framework. void ReshapeOp::getCanonicalizationPatterns( mlir::OwningRewritePatternList &results, mlir::MLIRContext *context) { - mlir::RewriteListBuilder::build(results, context); + results.insert(context); } namespace { @@ -180,7 +180,7 @@ struct SimplifyIdentityTypeCast : public mlir::OpRewritePattern { void TypeCastOp::getCanonicalizationPatterns( mlir::OwningRewritePatternList &results, mlir::MLIRContext *context) { - results.push_back(llvm::make_unique(context)); + results.insert(context); } } // namespace toy diff --git a/mlir/include/mlir/Conversion/ControlFlowToCFG/ConvertControlFlowToCFG.h b/mlir/include/mlir/Conversion/ControlFlowToCFG/ConvertControlFlowToCFG.h index e8ab2732d312..78e4356607fe 100644 --- a/mlir/include/mlir/Conversion/ControlFlowToCFG/ConvertControlFlowToCFG.h +++ b/mlir/include/mlir/Conversion/ControlFlowToCFG/ConvertControlFlowToCFG.h @@ -29,7 +29,7 @@ class MLIRContext; class RewritePattern; // Owning list of rewriting patterns. -using OwningRewritePatternList = std::vector>; +class OwningRewritePatternList; /// Collect a set of patterns to lower from loop.for, loop.if, and /// loop.terminator to CFG operations within the Standard dialect, in particular diff --git a/mlir/include/mlir/Conversion/StandardToLLVM/ConvertStandardToLLVMPass.h b/mlir/include/mlir/Conversion/StandardToLLVM/ConvertStandardToLLVMPass.h index 361294a729ea..941e382905f4 100644 --- a/mlir/include/mlir/Conversion/StandardToLLVM/ConvertStandardToLLVMPass.h +++ b/mlir/include/mlir/Conversion/StandardToLLVM/ConvertStandardToLLVMPass.h @@ -38,7 +38,7 @@ class RewritePattern; class Type; // Owning list of rewriting patterns. -using OwningRewritePatternList = std::vector>; +class OwningRewritePatternList; /// Type for a callback constructing the owning list of patterns for the /// conversion to the LLVMIR dialect. The callback is expected to append diff --git a/mlir/include/mlir/IR/OperationSupport.h b/mlir/include/mlir/IR/OperationSupport.h index c76f1d620afc..204da29b39ad 100644 --- a/mlir/include/mlir/IR/OperationSupport.h +++ b/mlir/include/mlir/IR/OperationSupport.h @@ -57,9 +57,7 @@ class Value; /// either OpTy or OperandAdaptor seamlessly. template using OperandAdaptor = typename OpTy::OperandAdaptor; -/// This is a vector that owns the patterns inside of it. -using OwningPatternList = std::vector>; -using OwningRewritePatternList = std::vector>; +class OwningRewritePatternList; enum class OperationProperty { /// This bit is set for an operation if it is a commutative operation: that diff --git a/mlir/include/mlir/IR/PatternMatch.h b/mlir/include/mlir/IR/PatternMatch.h index d739a8044381..e3897b1d63a5 100644 --- a/mlir/include/mlir/IR/PatternMatch.h +++ b/mlir/include/mlir/IR/PatternMatch.h @@ -394,8 +394,39 @@ private: // Pattern-driven rewriters //===----------------------------------------------------------------------===// -/// This is a vector that owns the patterns inside of it. -using OwningRewritePatternList = std::vector>; +class OwningRewritePatternList { + using PatternListT = std::vector>; + +public: + PatternListT::iterator begin() { return patterns.begin(); } + PatternListT::iterator end() { return patterns.end(); } + PatternListT::const_iterator begin() const { return patterns.begin(); } + PatternListT::const_iterator end() const { return patterns.end(); } + + //===--------------------------------------------------------------------===// + // Pattern Insertion + //===--------------------------------------------------------------------===// + + void insert(RewritePattern *pattern) { patterns.emplace_back(pattern); } + + /// Add an instance of each of the pattern types 'Ts' to the pattern list with + /// the given arguments. + // Note: ConstructorArg is necessary here to separate the two variadic lists. + template + void insert(ConstructorArg &&arg, ConstructorArgs &&... args) { + // The following expands a call to emplace_back for each of the pattern + // types 'Ts'. This magic is necessary due to a limitation in the places + // that a parameter pack can be expanded in c++11. + // FIXME: In c++17 this can be simplified by using 'fold expressions'. + using dummy = int[]; + (void)dummy{ + 0, (patterns.emplace_back(llvm::make_unique(arg, args...)), 0)...}; + } + +private: + PatternListT patterns; +}; /// This class manages optimization and execution of a group of rewrite /// patterns, providing an API for finding and applying, the best match against @@ -404,7 +435,7 @@ using OwningRewritePatternList = std::vector>; class RewritePatternMatcher { public: /// Create a RewritePatternMatcher with the specified set of patterns. - explicit RewritePatternMatcher(OwningRewritePatternList &&patterns); + explicit RewritePatternMatcher(OwningRewritePatternList &patterns); /// Try to match the given operation to a pattern and rewrite it. Return /// true if any pattern matches. @@ -416,7 +447,7 @@ private: /// The group of patterns that are matched for optimization through this /// matcher. - OwningRewritePatternList patterns; + std::vector patterns; }; /// Rewrite the regions of the specified operation, which must be isolated from @@ -427,29 +458,6 @@ private: /// bool applyPatternsGreedily(Operation *op, OwningRewritePatternList &&patterns); -/// Helper class to create a list of rewrite patterns given a list of their -/// types and a list of attributes perfect-forwarded to each of the conversion -/// constructors. -template struct RewriteListBuilder { - template - static void build(OwningRewritePatternList &patterns, - ConstructorArgs &&... constructorArgs) { - RewriteListBuilder::build( - patterns, std::forward(constructorArgs)...); - RewriteListBuilder::build( - patterns, std::forward(constructorArgs)...); - } -}; - -// Template specialization to stop recursion. -template struct RewriteListBuilder { - template - static void build(OwningRewritePatternList &patterns, - ConstructorArgs &&... constructorArgs) { - patterns.emplace_back(llvm::make_unique( - std::forward(constructorArgs)...)); - } -}; } // end namespace mlir #endif // MLIR_PATTERN_MATCH_H diff --git a/mlir/include/mlir/Transforms/LowerAffine.h b/mlir/include/mlir/Transforms/LowerAffine.h index 9ad3f66def57..5fae4763bf7d 100644 --- a/mlir/include/mlir/Transforms/LowerAffine.h +++ b/mlir/include/mlir/Transforms/LowerAffine.h @@ -32,7 +32,7 @@ class RewritePattern; class Value; // Owning list of rewriting patterns. -using OwningRewritePatternList = std::vector>; +class OwningRewritePatternList; /// Emit code that computes the given affine expression using standard /// arithmetic operations applied to the provided dimension and symbol values. diff --git a/mlir/lib/AffineOps/AffineOps.cpp b/mlir/lib/AffineOps/AffineOps.cpp index 9a026231ab24..767c2e344d9a 100644 --- a/mlir/lib/AffineOps/AffineOps.cpp +++ b/mlir/lib/AffineOps/AffineOps.cpp @@ -708,7 +708,7 @@ struct SimplifyAffineApply : public OpRewritePattern { void AffineApplyOp::getCanonicalizationPatterns( OwningRewritePatternList &results, MLIRContext *context) { - results.push_back(llvm::make_unique(context)); + results.insert(context); } //===----------------------------------------------------------------------===// @@ -912,8 +912,7 @@ LogicalResult AffineDmaStartOp::verify() { void AffineDmaStartOp::getCanonicalizationPatterns( OwningRewritePatternList &results, MLIRContext *context) { /// dma_start(memrefcast) -> dma_start - results.push_back( - llvm::make_unique(getOperationName(), context)); + results.insert(getOperationName(), context); } //===----------------------------------------------------------------------===// @@ -989,8 +988,7 @@ LogicalResult AffineDmaWaitOp::verify() { void AffineDmaWaitOp::getCanonicalizationPatterns( OwningRewritePatternList &results, MLIRContext *context) { /// dma_wait(memrefcast) -> dma_wait - results.push_back( - llvm::make_unique(getOperationName(), context)); + results.insert(getOperationName(), context); } //===----------------------------------------------------------------------===// @@ -1333,7 +1331,7 @@ struct AffineForLoopBoundFolder : public OpRewritePattern { void AffineForOp::getCanonicalizationPatterns(OwningRewritePatternList &results, MLIRContext *context) { - results.push_back(llvm::make_unique(context)); + results.insert(context); } AffineBound AffineForOp::getLowerBound() { @@ -1659,8 +1657,7 @@ LogicalResult AffineLoadOp::verify() { void AffineLoadOp::getCanonicalizationPatterns( OwningRewritePatternList &results, MLIRContext *context) { /// load(memrefcast) -> load - results.push_back( - llvm::make_unique(getOperationName(), context)); + results.insert(getOperationName(), context); } //===----------------------------------------------------------------------===// @@ -1752,8 +1749,7 @@ LogicalResult AffineStoreOp::verify() { void AffineStoreOp::getCanonicalizationPatterns( OwningRewritePatternList &results, MLIRContext *context) { /// load(memrefcast) -> load - results.push_back( - llvm::make_unique(getOperationName(), context)); + results.insert(getOperationName(), context); } #define GET_OP_CLASSES diff --git a/mlir/lib/Conversion/ControlFlowToCFG/ConvertControlFlowToCFG.cpp b/mlir/lib/Conversion/ControlFlowToCFG/ConvertControlFlowToCFG.cpp index c37decf69e6b..034aa22f922c 100644 --- a/mlir/lib/Conversion/ControlFlowToCFG/ConvertControlFlowToCFG.cpp +++ b/mlir/lib/Conversion/ControlFlowToCFG/ConvertControlFlowToCFG.cpp @@ -258,8 +258,7 @@ IfLowering::matchAndRewrite(IfOp ifOp, PatternRewriter &rewriter) const { void mlir::populateLoopToStdConversionPatterns( OwningRewritePatternList &patterns, MLIRContext *ctx) { - RewriteListBuilder::build( - patterns, ctx); + patterns.insert(ctx); } void ControlFlowToCFGPass::runOnFunction() { diff --git a/mlir/lib/Conversion/GPUToSPIRV/GPUToSPIRV.cpp b/mlir/lib/Conversion/GPUToSPIRV/GPUToSPIRV.cpp index 4eadb8749082..58f01fc66892 100644 --- a/mlir/lib/Conversion/GPUToSPIRV/GPUToSPIRV.cpp +++ b/mlir/lib/Conversion/GPUToSPIRV/GPUToSPIRV.cpp @@ -104,8 +104,7 @@ void GPUToSPIRVPass::runOnModule() { SPIRVTypeConverter typeConverter(context); SPIRVEntryFnTypeConverter entryFnConverter(context); OwningRewritePatternList patterns; - RewriteListBuilder::build( - patterns, context, typeConverter, entryFnConverter); + patterns.insert(context, typeConverter, entryFnConverter); populateStandardToSPIRVPatterns(context, patterns); ConversionTarget target(*context); diff --git a/mlir/lib/Conversion/StandardToLLVM/ConvertStandardToLLVM.cpp b/mlir/lib/Conversion/StandardToLLVM/ConvertStandardToLLVM.cpp index af8812c8cf4f..09ddcd1e475b 100644 --- a/mlir/lib/Conversion/StandardToLLVM/ConvertStandardToLLVM.cpp +++ b/mlir/lib/Conversion/StandardToLLVM/ConvertStandardToLLVM.cpp @@ -1023,7 +1023,7 @@ void mlir::LLVM::ensureDistinctSuccessors(ModuleOp m) { void mlir::populateStdToLLVMConversionPatterns( LLVMTypeConverter &converter, OwningRewritePatternList &patterns) { // FIXME: this should be tablegen'ed - RewriteListBuilder< + patterns.insert< AddFOpLowering, AddIOpLowering, AndOpLowering, AllocOpLowering, BranchOpLowering, CallIndirectOpLowering, CallOpLowering, CmpIOpLowering, CondBranchOpLowering, ConstLLVMOpLowering, DeallocOpLowering, @@ -1032,8 +1032,7 @@ void mlir::populateStdToLLVMConversionPatterns( MemRefCastOpLowering, MulFOpLowering, MulIOpLowering, OrOpLowering, RemISOpLowering, RemIUOpLowering, RemFOpLowering, ReturnOpLowering, SelectOpLowering, SIToFPLowering, StoreOpLowering, SubFOpLowering, - SubIOpLowering, XOrOpLowering>::build(patterns, *converter.getDialect(), - converter); + SubIOpLowering, XOrOpLowering>(*converter.getDialect(), converter); } // Convert types using the stored LLVM IR module. diff --git a/mlir/lib/Conversion/StandardToSPIRV/ConvertStandardToSPIRV.cpp b/mlir/lib/Conversion/StandardToSPIRV/ConvertStandardToSPIRV.cpp index d32d8668046c..067f2aeda06d 100644 --- a/mlir/lib/Conversion/StandardToSPIRV/ConvertStandardToSPIRV.cpp +++ b/mlir/lib/Conversion/StandardToSPIRV/ConvertStandardToSPIRV.cpp @@ -201,6 +201,6 @@ void populateStandardToSPIRVPatterns(MLIRContext *context, OwningRewritePatternList &patterns) { populateWithGenerated(context, &patterns); // Add the return op conversion. - RewriteListBuilder::build(patterns, context); + patterns.insert(context); } } // namespace mlir diff --git a/mlir/lib/Dialect/FxpMathOps/Transforms/LowerUniformRealMath.cpp b/mlir/lib/Dialect/FxpMathOps/Transforms/LowerUniformRealMath.cpp index dafc8e711f50..d2f3881710c9 100644 --- a/mlir/lib/Dialect/FxpMathOps/Transforms/LowerUniformRealMath.cpp +++ b/mlir/lib/Dialect/FxpMathOps/Transforms/LowerUniformRealMath.cpp @@ -368,8 +368,7 @@ void LowerUniformRealMathPass::runOnFunction() { auto fn = getFunction(); OwningRewritePatternList patterns; auto *context = &getContext(); - patterns.push_back(llvm::make_unique(context)); - patterns.push_back(llvm::make_unique(context)); + patterns.insert(context); applyPatternsGreedily(fn, std::move(patterns)); } @@ -389,7 +388,7 @@ void LowerUniformCastsPass::runOnFunction() { auto fn = getFunction(); OwningRewritePatternList patterns; auto *context = &getContext(); - patterns.push_back(llvm::make_unique(context)); + patterns.insert(context); applyPatternsGreedily(fn, std::move(patterns)); } diff --git a/mlir/lib/Dialect/GPU/IR/GPUDialect.cpp b/mlir/lib/Dialect/GPU/IR/GPUDialect.cpp index bda5979939c7..2fbaa49f56e9 100644 --- a/mlir/lib/Dialect/GPU/IR/GPUDialect.cpp +++ b/mlir/lib/Dialect/GPU/IR/GPUDialect.cpp @@ -372,7 +372,7 @@ class PropagateConstantBounds : public OpRewritePattern { void LaunchOp::getCanonicalizationPatterns(OwningRewritePatternList &results, MLIRContext *context) { - RewriteListBuilder::build(results, context); + results.insert(context); } //===----------------------------------------------------------------------===// diff --git a/mlir/lib/Dialect/QuantOps/IR/QuantOps.cpp b/mlir/lib/Dialect/QuantOps/IR/QuantOps.cpp index e237e8b6eb26..3bd49d43adcf 100644 --- a/mlir/lib/Dialect/QuantOps/IR/QuantOps.cpp +++ b/mlir/lib/Dialect/QuantOps/IR/QuantOps.cpp @@ -60,8 +60,7 @@ public: void StorageCastOp::getCanonicalizationPatterns( OwningRewritePatternList &patterns, MLIRContext *context) { - patterns.push_back( - llvm::make_unique(context)); + patterns.insert(context); } QuantizationDialect::QuantizationDialect(MLIRContext *context) diff --git a/mlir/lib/Dialect/QuantOps/Transforms/ConvertConst.cpp b/mlir/lib/Dialect/QuantOps/Transforms/ConvertConst.cpp index 8469fa2ea70c..2276fbd21c92 100644 --- a/mlir/lib/Dialect/QuantOps/Transforms/ConvertConst.cpp +++ b/mlir/lib/Dialect/QuantOps/Transforms/ConvertConst.cpp @@ -108,7 +108,7 @@ void ConvertConstPass::runOnFunction() { OwningRewritePatternList patterns; auto func = getFunction(); auto *context = &getContext(); - patterns.push_back(llvm::make_unique(context)); + patterns.insert(context); applyPatternsGreedily(func, std::move(patterns)); } diff --git a/mlir/lib/Dialect/QuantOps/Transforms/ConvertSimQuant.cpp b/mlir/lib/Dialect/QuantOps/Transforms/ConvertSimQuant.cpp index 32d8c8a81c1e..8f5d1b33c64b 100644 --- a/mlir/lib/Dialect/QuantOps/Transforms/ConvertSimQuant.cpp +++ b/mlir/lib/Dialect/QuantOps/Transforms/ConvertSimQuant.cpp @@ -97,8 +97,7 @@ void ConvertSimulatedQuantPass::runOnFunction() { OwningRewritePatternList patterns; auto func = getFunction(); auto *context = &getContext(); - patterns.push_back( - llvm::make_unique(context, &hadFailure)); + patterns.insert(context, &hadFailure); applyPatternsGreedily(func, std::move(patterns)); if (hadFailure) signalPassFailure(); diff --git a/mlir/lib/IR/PatternMatch.cpp b/mlir/lib/IR/PatternMatch.cpp index 5010b845c78c..94fa7ab43f7d 100644 --- a/mlir/lib/IR/PatternMatch.cpp +++ b/mlir/lib/IR/PatternMatch.cpp @@ -149,12 +149,13 @@ void PatternRewriter::updatedRootInPlace( //===----------------------------------------------------------------------===// RewritePatternMatcher::RewritePatternMatcher( - OwningRewritePatternList &&patterns) - : patterns(std::move(patterns)) { + OwningRewritePatternList &patterns) { + for (auto &pattern : patterns) + this->patterns.push_back(pattern.get()); + // Sort the patterns by benefit to simplify the matching logic. std::stable_sort(this->patterns.begin(), this->patterns.end(), - [](const std::unique_ptr &l, - const std::unique_ptr &r) { + [](RewritePattern *l, RewritePattern *r) { return r->getBenefit() < l->getBenefit(); }); } @@ -162,7 +163,7 @@ RewritePatternMatcher::RewritePatternMatcher( /// Try to match the given operation to a pattern and rewrite it. bool RewritePatternMatcher::matchAndRewrite(Operation *op, PatternRewriter &rewriter) { - for (auto &pattern : patterns) { + for (auto *pattern : patterns) { // Ignore patterns that are for the wrong root or are impossible to match. if (pattern->getRootKind() != op->getName() || pattern->getBenefit().isImpossibleToMatch()) diff --git a/mlir/lib/Linalg/Transforms/LowerToLLVMDialect.cpp b/mlir/lib/Linalg/Transforms/LowerToLLVMDialect.cpp index 6b62a8e13404..7c2ea5945f4f 100644 --- a/mlir/lib/Linalg/Transforms/LowerToLLVMDialect.cpp +++ b/mlir/lib/Linalg/Transforms/LowerToLLVMDialect.cpp @@ -678,12 +678,11 @@ static void populateLinalgToLLVMConversionPatterns(LinalgTypeConverter &converter, OwningRewritePatternList &patterns, MLIRContext *ctx) { - RewriteListBuilder, LinalgOpConversion, - LoadOpConversion, RangeOpConversion, SliceOpConversion, - StoreOpConversion, ViewOpConversion>::build(patterns, ctx, - converter); + patterns.insert, LinalgOpConversion, + LoadOpConversion, RangeOpConversion, SliceOpConversion, + StoreOpConversion, ViewOpConversion>(ctx, converter); } namespace { diff --git a/mlir/lib/Quantizer/Transforms/RemoveInstrumentationPass.cpp b/mlir/lib/Quantizer/Transforms/RemoveInstrumentationPass.cpp index 6b376db85163..3de89137c3cf 100644 --- a/mlir/lib/Quantizer/Transforms/RemoveInstrumentationPass.cpp +++ b/mlir/lib/Quantizer/Transforms/RemoveInstrumentationPass.cpp @@ -60,12 +60,9 @@ void RemoveInstrumentationPass::runOnFunction() { OwningRewritePatternList patterns; auto func = getFunction(); auto *context = &getContext(); - patterns.push_back( - llvm::make_unique>(context)); - patterns.push_back( - llvm::make_unique>(context)); - patterns.push_back( - llvm::make_unique>(context)); + patterns.insert, + RemoveIdentityOpRewrite, + RemoveIdentityOpRewrite>(context); applyPatternsGreedily(func, std::move(patterns)); } diff --git a/mlir/lib/StandardOps/Ops.cpp b/mlir/lib/StandardOps/Ops.cpp index df99f00c1100..9ecd99a5169b 100644 --- a/mlir/lib/StandardOps/Ops.cpp +++ b/mlir/lib/StandardOps/Ops.cpp @@ -365,8 +365,7 @@ struct SimplifyDeadAlloc : public OpRewritePattern { void AllocOp::getCanonicalizationPatterns(OwningRewritePatternList &results, MLIRContext *context) { - RewriteListBuilder::build(results, - context); + results.insert(context); } //===----------------------------------------------------------------------===// @@ -544,8 +543,7 @@ static LogicalResult verify(CallIndirectOp op) { void CallIndirectOp::getCanonicalizationPatterns( OwningRewritePatternList &results, MLIRContext *context) { - results.push_back( - llvm::make_unique(context)); + results.insert(context); } //===----------------------------------------------------------------------===// @@ -1015,7 +1013,7 @@ static void print(OpAsmPrinter *p, CondBranchOp op) { void CondBranchOp::getCanonicalizationPatterns( OwningRewritePatternList &results, MLIRContext *context) { - results.push_back(llvm::make_unique(context)); + results.insert(context); } //===----------------------------------------------------------------------===// @@ -1231,9 +1229,8 @@ static LogicalResult verify(DeallocOp op) { void DeallocOp::getCanonicalizationPatterns(OwningRewritePatternList &results, MLIRContext *context) { /// dealloc(memrefcast) -> dealloc - results.push_back( - llvm::make_unique(getOperationName(), context)); - results.push_back(llvm::make_unique(context)); + results.insert(getOperationName(), context); + results.insert(context); } //===----------------------------------------------------------------------===// @@ -1497,8 +1494,7 @@ LogicalResult DmaStartOp::verify() { void DmaStartOp::getCanonicalizationPatterns(OwningRewritePatternList &results, MLIRContext *context) { /// dma_start(memrefcast) -> dma_start - results.push_back( - llvm::make_unique(getOperationName(), context)); + results.insert(getOperationName(), context); } // --------------------------------------------------------------------------- @@ -1561,8 +1557,7 @@ ParseResult DmaWaitOp::parse(OpAsmParser *parser, OperationState *result) { void DmaWaitOp::getCanonicalizationPatterns(OwningRewritePatternList &results, MLIRContext *context) { /// dma_wait(memrefcast) -> dma_wait - results.push_back( - llvm::make_unique(getOperationName(), context)); + results.insert(getOperationName(), context); } //===----------------------------------------------------------------------===// @@ -1695,8 +1690,7 @@ static LogicalResult verify(LoadOp op) { void LoadOp::getCanonicalizationPatterns(OwningRewritePatternList &results, MLIRContext *context) { /// load(memrefcast) -> load - results.push_back( - llvm::make_unique(getOperationName(), context)); + results.insert(getOperationName(), context); } //===----------------------------------------------------------------------===// @@ -2007,8 +2001,7 @@ static LogicalResult verify(StoreOp op) { void StoreOp::getCanonicalizationPatterns(OwningRewritePatternList &results, MLIRContext *context) { /// store(memrefcast) -> store - results.push_back( - llvm::make_unique(getOperationName(), context)); + results.insert(getOperationName(), context); } //===----------------------------------------------------------------------===// diff --git a/mlir/lib/Transforms/DialectConversion.cpp b/mlir/lib/Transforms/DialectConversion.cpp index 50c636f708e9..6f264b0af35b 100644 --- a/mlir/lib/Transforms/DialectConversion.cpp +++ b/mlir/lib/Transforms/DialectConversion.cpp @@ -1243,8 +1243,7 @@ struct FuncOpSignatureConversion : public ConversionPattern { void mlir::populateFuncOpTypeConversionPattern( OwningRewritePatternList &patterns, MLIRContext *ctx, TypeConverter &converter) { - RewriteListBuilder::build(patterns, ctx, - converter); + patterns.insert(ctx, converter); } /// This function converts the type signature of the given block, by invoking diff --git a/mlir/lib/Transforms/LowerAffine.cpp b/mlir/lib/Transforms/LowerAffine.cpp index f35f963b8aea..1c558efd8e45 100644 --- a/mlir/lib/Transforms/LowerAffine.cpp +++ b/mlir/lib/Transforms/LowerAffine.cpp @@ -507,10 +507,11 @@ public: void mlir::populateAffineToStdConversionPatterns( OwningRewritePatternList &patterns, MLIRContext *ctx) { - RewriteListBuilder::build(patterns, ctx); + patterns + .insert( + ctx); } namespace { diff --git a/mlir/lib/Transforms/LowerVectorTransfers.cpp b/mlir/lib/Transforms/LowerVectorTransfers.cpp index 3585e2befd6c..ef67488023f9 100644 --- a/mlir/lib/Transforms/LowerVectorTransfers.cpp +++ b/mlir/lib/Transforms/LowerVectorTransfers.cpp @@ -365,12 +365,8 @@ struct LowerVectorTransfersPass void runOnFunction() { OwningRewritePatternList patterns; auto *context = &getContext(); - patterns.push_back( - llvm::make_unique>( - context)); - patterns.push_back( - llvm::make_unique>( - context)); + patterns.insert, + VectorTransferRewriter>(context); applyPatternsGreedily(getFunction(), std::move(patterns)); } }; diff --git a/mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp b/mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp index 52952178b378..1df4ceec8f35 100644 --- a/mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp +++ b/mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp @@ -44,8 +44,8 @@ namespace { class GreedyPatternRewriteDriver : public PatternRewriter { public: explicit GreedyPatternRewriteDriver(MLIRContext *ctx, - OwningRewritePatternList &&patterns) - : PatternRewriter(ctx), matcher(std::move(patterns)) { + OwningRewritePatternList &patterns) + : PatternRewriter(ctx), matcher(patterns) { worklist.reserve(64); } @@ -224,7 +224,7 @@ bool mlir::applyPatternsGreedily(Operation *op, if (!op->isKnownIsolatedFromAbove()) return false; - GreedyPatternRewriteDriver driver(op->getContext(), std::move(patterns)); + GreedyPatternRewriteDriver driver(op->getContext(), patterns); bool converged = driver.simplify(op, maxPatternMatchIterations); LLVM_DEBUG(if (!converged) { llvm::dbgs() << "The pattern rewrite doesn't converge after scanning " diff --git a/mlir/test/lib/TestDialect/TestPatterns.cpp b/mlir/test/lib/TestDialect/TestPatterns.cpp index 201dfc3005ca..ed94eed4fdd4 100644 --- a/mlir/test/lib/TestDialect/TestPatterns.cpp +++ b/mlir/test/lib/TestDialect/TestPatterns.cpp @@ -41,7 +41,7 @@ struct TestPatternDriver : public FunctionPass { populateWithGenerated(&getContext(), &patterns); // Verify named pattern is generated with expected name. - RewriteListBuilder::build(patterns, &getContext()); + patterns.insert(&getContext()); applyPatternsGreedily(getFunction(), std::move(patterns)); } @@ -193,9 +193,9 @@ struct TestLegalizePatternDriver TestTypeConverter converter; mlir::OwningRewritePatternList patterns; populateWithGenerated(&getContext(), &patterns); - RewriteListBuilder::build(patterns, &getContext()); + patterns.insert( + &getContext()); mlir::populateFuncOpTypeConversionPattern(patterns, &getContext(), converter); diff --git a/mlir/tools/mlir-cuda-runner/mlir-cuda-runner.cpp b/mlir/tools/mlir-cuda-runner/mlir-cuda-runner.cpp index edf6aeae469a..f75413fdaed9 100644 --- a/mlir/tools/mlir-cuda-runner/mlir-cuda-runner.cpp +++ b/mlir/tools/mlir-cuda-runner/mlir-cuda-runner.cpp @@ -133,7 +133,7 @@ static LogicalResult runMLIRPasses(ModuleOp m) { pm.addPass(createConvertToLLVMIRPass([](LLVMTypeConverter &converter, OwningRewritePatternList &patterns) { populateStdToLLVMConversionPatterns(converter, patterns); - patterns.push_back(llvm::make_unique(converter)); + patterns.insert(converter); })); pm.addPass(createLowerGpuOpsToNVVMOpsPass()); pm.addPass(createConvertGPUKernelToCubinPass(&compilePtxToCubin)); diff --git a/mlir/tools/mlir-tblgen/RewriterGen.cpp b/mlir/tools/mlir-tblgen/RewriterGen.cpp index d408ecfa5eb9..24eeaf50d782 100644 --- a/mlir/tools/mlir-tblgen/RewriterGen.cpp +++ b/mlir/tools/mlir-tblgen/RewriterGen.cpp @@ -935,8 +935,7 @@ static void emitRewriters(const RecordKeeper &recordKeeper, raw_ostream &os) { os << "void populateWithGenerated(MLIRContext *context, " << "OwningRewritePatternList *patterns) {\n"; for (const auto &name : rewriterNames) { - os << " patterns->push_back(llvm::make_unique<" << name - << ">(context));\n"; + os << " patterns->insert<" << name << ">(context);\n"; } os << "}\n"; }