Files
clang-p2996/mlir/lib/Dialect/Linalg/Transforms/CodegenStrategy.cpp
River Riddle 3fffffa882 [mlir][Pattern] Add a new FrozenRewritePatternList class
This class represents a rewrite pattern list that has been frozen, and thus immutable. This replaces the uses of OwningRewritePatternList in pattern driver related API, such as dialect conversion. When PDL becomes more prevalent, this API will allow for optimizing a set of patterns once without the need to do this per run of a pass.

Differential Revision: https://reviews.llvm.org/D89104
2020-10-26 18:01:06 -07:00

98 lines
4.2 KiB
C++

//===- CodegenStrategy.cpp - Linalg programmable codegen strategy ---------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file implements logic and helpers to expose Linalg transforms as
// composable rewrite patterns through a programmable CodegenStrategy object.
//
//===----------------------------------------------------------------------===//
#include "mlir/Dialect/Linalg/Transforms/CodegenStrategy.h"
#include "mlir/Dialect/Linalg/Transforms/Hoisting.h"
#include "mlir/Dialect/Vector/VectorOps.h"
#include "mlir/Dialect/Vector/VectorTransforms.h"
#include "mlir/Interfaces/LoopLikeInterface.h"
#include "mlir/Pass/PassManager.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
#include "mlir/Transforms/LoopUtils.h"
#include "mlir/Transforms/Passes.h"
using namespace mlir;
using namespace mlir::linalg;
#define DEBUG_TYPE "linalg-codegen-strategy"
/// Apply the programmed codegen strategy to `func`.
///
/// The transformation proceeds in three phases:
///   1. Staged pattern application via `linalg::applyStagedPatterns`. Every
///      entry of `transformationSequence` contributes one stage-1 pattern list.
///      Consecutive stages are chained through LinalgMarker states "0", "1",
///      ..., "N": the first stage matches ops without a marker and tags them
///      "1"; stage k matches ops tagged "k-1" and retags them "k". After each
///      stage-1 application, the tiling canonicalization patterns (stage 2) and
///      the cleanups in `stage3Transforms` (stage 3) are run.
///   2. Post-staged vector transforms: slow/fast path splitting of vector
///      transfers, vector.contract lowering and vector.transfer -> SCF lowering.
///   3. Removal of the linalg transformation marker attribute.
///
/// Every rewrite is scoped to `func`. The parent module is intentionally not
/// touched: when this is invoked from a function pass, rewriting (or running a
/// nested pass manager on) the module would also mutate sibling functions that
/// may be processed concurrently.
void mlir::linalg::CodegenStrategy::transform(FuncOp func) const {
  MLIRContext *context = func.getContext();

  // Emplace patterns one at a time while also maintaining a simple chained
  // state transition.
  unsigned stepCount = 0;
  SmallVector<FrozenRewritePatternList, 4> stage1Patterns;
  auto zeroState = Identifier::get(std::to_string(stepCount), context);
  auto currentState = zeroState;
  for (const std::unique_ptr<Transformation> &t : transformationSequence) {
    auto nextState = Identifier::get(std::to_string(++stepCount), context);
    // The first transformation has no predecessor state to match on, so it
    // uses an empty match list; later ones only match their predecessor's tag.
    auto marker = (currentState == zeroState)
                      ? linalg::LinalgMarker({}, nextState)
                      : linalg::LinalgMarker(currentState, nextState);
    stage1Patterns.emplace_back(t->buildRewritePatterns(context, marker));
    currentState = nextState;
  }

  OwningRewritePatternList stage2Patterns =
      linalg::getLinalgTilingCanonicalizationPatterns(context);
  stage2Patterns.insert<AffineMinSCFCanonicalizationPattern>(context);

  auto stage3Transforms = [](Operation *op) -> LogicalResult {
    // Some of these may be too aggressive as a stage 3 that is applied on each
    // stage 1 application and may have to be split out to post staged patterns
    // application (in which case they could just be passes, TBD).
    //
    // Loop-invariant code motion is applied directly to the loops nested in
    // `op` rather than by running a PassManager on the parent module, which
    // would rewrite every function of the module. A failure is reported to
    // `applyStagedPatterns` instead of aborting via llvm_unreachable (which is
    // undefined behavior in release builds).
    WalkResult licmResult = op->walk([](LoopLikeOpInterface loopLike) {
      if (failed(moveLoopInvariantCode(loopLike)))
        return WalkResult::interrupt();
      return WalkResult::advance();
    });
    if (licmResult.wasInterrupted())
      return failure();

    FuncOp funcOp = cast<FuncOp>(op);
    promoteSingleIterationLoops(funcOp);
    hoistViewAllocOps(funcOp);
    hoistRedundantVectorTransfers(funcOp);
    return success();
  };
  linalg::applyStagedPatterns(func, stage1Patterns, std::move(stage2Patterns),
                              stage3Transforms);

  //===--------------------------------------------------------------------===//
  // Post staged patterns transforms (scoped to `func`, not its parent module).
  //===--------------------------------------------------------------------===//

  // Programmatic splitting of slow/fast path vector transfers.
  OwningRewritePatternList patterns;
  patterns.insert<vector::VectorTransferFullPartialRewriter>(
      context, vectorTransformsOptions);
  applyPatternsAndFoldGreedily(func, std::move(patterns));

  // Programmatic controlled lowering of vector.contract only.
  OwningRewritePatternList vectorContractLoweringPatterns;
  vectorContractLoweringPatterns
      .insert<ContractionOpToOuterProductOpLowering,
              ContractionOpToMatmulOpLowering, ContractionOpLowering>(
          vectorTransformsOptions, context);
  applyPatternsAndFoldGreedily(func, std::move(vectorContractLoweringPatterns));

  // Programmatic controlled lowering of vector.transfer only.
  OwningRewritePatternList vectorToLoopsPatterns;
  populateVectorToSCFConversionPatterns(vectorToLoopsPatterns, context,
                                        vectorToSCFOptions);
  applyPatternsAndFoldGreedily(func, std::move(vectorToLoopsPatterns));

  // Ensure we drop the marker in the end.
  func.walk([](LinalgOp op) {
    op.removeAttr(LinalgTransforms::kLinalgTransformMarker);
  });
}