The greedy rewriter is used in many different flows and offers a lot of
convenience (worklist management, debugging actions, tracing, etc.). But
it combines two kinds of greedy behavior: 1) how ops are matched, and 2)
folding wherever it can.
These are independent forms of greediness, and conflating them leads to
inefficiency. E.g., in cases where one needs to create different phases
in lowering and apply patterns in a specific order split across
different passes, using the driver one ends up needlessly retrying
folding/having multiple rounds of folding attempts, where one final run
would have sufficed (see the sketch below).
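
A minimal sketch of what the split enables, assuming the `fold` toggle
on `GreedyRewriteConfig` that this decoupling exposes (check the field
name in your MLIR revision; `phase1Patterns`/`phase2Patterns` are
hypothetical pattern sets):

```c++
// Sketch: apply an intermediate phase's patterns without folding, then a
// final phase with folding left at its default (enabled).
static LogicalResult runPhases(Operation *op,
                               RewritePatternSet &&phase1Patterns,
                               RewritePatternSet &&phase2Patterns) {
  GreedyRewriteConfig config;
  config.fold = false; // greedy matching only; skip folding attempts
  if (failed(applyPatternsGreedily(op, std::move(phase1Patterns), config)))
    return failure();
  // One final run with folding enabled (the default) suffices.
  return applyPatternsGreedily(op, std::move(phase2Patterns));
}
```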
Of course folks can locally avoid this behavior by building their own
driver, but this is also a commonly requested feature that folks keep
working around locally in suboptimal ways.
For downstream users, there should be no behavioral change. Updating
from the deprecated API should just be a find-and-replace (e.g., of the
`find ./ -type f -exec sed -i
's|applyPatternsAndFoldGreedily|applyPatternsGreedily|g' {} \;` variety)
as the API arguments haven't changed between the two.
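
At a call site the update is purely mechanical; for instance, in an
illustrative (hypothetical) pass snippet only the function name changes:

```c++
// Before (deprecated):
if (failed(applyPatternsAndFoldGreedily(funcOp, std::move(patterns))))
  return signalPassFailure();

// After (same arguments, same behavior):
if (failed(applyPatternsGreedily(funcOp, std::move(patterns))))
  return signalPassFailure();
```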
//====----- OutlineShapeComputation.cpp -----------------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/Shape/Analysis/ShapeMappingAnalysis.h"
#include "mlir/Dialect/Shape/IR/Shape.h"
#include "mlir/Dialect/Shape/Transforms/Passes.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/IR/IRMapping.h"
#include "mlir/IR/Matchers.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Transforms/DialectConversion.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
#include "llvm/ADT/DenseSet.h"
#include "llvm/Support/Debug.h"
#include <queue>
#include <unordered_set>
#include <vector>

namespace mlir {
#define GEN_PASS_DEF_OUTLINESHAPECOMPUTATION
#include "mlir/Dialect/Shape/Transforms/Passes.h.inc"
} // namespace mlir

#define DEBUG_TYPE "outline-shape-computation"

using namespace mlir;

namespace {

// A Value is an input of the cluster if it is an operand of an operation in
// the cluster and its defining operation is not in the cluster.
SmallVector<Value, 4>
getInputsOfCluster(const llvm::SmallVector<Operation *, 8> &cluster) {
  SmallVector<Value, 4> inputs;
  llvm::SmallDenseSet<Value> inputSet;
  llvm::SmallDenseSet<Operation *> opSet;
  for (Operation *op : cluster) {
    bool inserted = opSet.insert(op).second;
    (void)inserted;
    assert(inserted && "cluster contains duplicate operations");
  }

  for (Operation *op : cluster) {
    for (Value operand : op->getOperands()) {
      Operation *operandOp = operand.getDefiningOp();
      if (opSet.contains(operandOp)) {
        // Skip if the defining op is in the cluster.
        continue;
      }
      if (inputSet.insert(operand).second)
        inputs.push_back(operand);
    }
  }
  return inputs;
}

// Create a shape.func representing the shape computation for `shape`.
std::pair<shape::FuncOp, SmallVector<Value>>
createFuncFromCluster(OpBuilder &b, const SmallVector<Operation *, 8> &cluster,
                      Value shape, StringRef fnName, Location loc) {
  SmallVector<Value, 4> inputs = getInputsOfCluster(cluster);
  auto fnType =
      cluster.empty()
          ? b.getFunctionType(shape.getType(), shape.getType())
          : b.getFunctionType(ValueRange(inputs).getTypes(), shape.getType());
  shape::FuncOp fnOp = b.create<shape::FuncOp>(loc, fnName, fnType);
  Block *block = fnOp.addEntryBlock();
  b.setInsertionPointToEnd(block);
  IRMapping bvm;
  if (cluster.empty()) {
    bvm.map(shape, fnOp.getArgument(0));
  } else {
    for (auto inputAndArg : llvm::zip(inputs, fnOp.getArguments()))
      bvm.map(std::get<0>(inputAndArg), std::get<1>(inputAndArg));
  }

  for (Operation *op : cluster)
    b.clone(*op, bvm);
  llvm::SmallVector<Value, 4> fnReturns;
  fnReturns.push_back(bvm.lookupOrDefault(shape));

  b.create<shape::ReturnOp>(loc, fnReturns);
  fnOp.setPrivate();
  return std::make_pair(fnOp, inputs);
}

// The operations in the cluster might be unsorted, which could be inconvenient
// when creating the shape.func op.
DenseMap<Value, SmallVector<Operation *, 8>>
getOrderedClusters(const DenseMap<Value, DenseSet<Operation *>> &clusters,
                   func::FuncOp funcOp) {
  // Compute all clusters that each operation is in.
  DenseMap<Operation *, SmallVector<Value>> op2Shapes;
  for (const auto &it : clusters) {
    Value shape = it.first;
    const DenseSet<Operation *> &cluster = it.second;
    for (Operation *cOp : cluster)
      op2Shapes[cOp].push_back(shape);
  }

  // Iterate through all operations in order. Get all the clusters `cOp`
  // belongs to and construct the new ordered cluster as it traverses.
  DenseMap<Value, SmallVector<Operation *, 8>> orderedClusters;
  funcOp.walk([&](Operation *op) {
    auto it = op2Shapes.find(op);
    if (it != op2Shapes.end()) {
      Operation *cOp = it->first;
      for (Value shape : it->second)
        orderedClusters[shape].push_back(cOp);
    }
  });

  return orderedClusters;
}

void constructShapeFunc(
    const std::vector<shape::WithOp> &allWithOps, MLIRContext *context,
    DenseMap<Value, SmallVector<Operation *, 8>> &clusters,
    SymbolTable &symbolTable,
    DenseMap<Value, shape::ShapeMappingValue> &dynShape2ShapeFunc,
    func::FuncOp funcOp, shape::ShapeMappingAnalysis &shapeMappingAnalysis) {
  std::string shapeCalculationNamePrefix = "shape_cal_";
  int shapeCalculationNameIdx = 0;
  OpBuilder builder(context);

  // Construct a shape function for each shape.
  for (shape::WithOp withOp : allWithOps) {
    Value value = withOp.getOperand();
    Value shape = withOp.getShape();
    RankedTensorType rankedType = dyn_cast<RankedTensorType>(value.getType());
    if (rankedType == nullptr)
      continue;

    const SmallVector<Operation *, 8> &cluster = clusters[shape];
    shape::ShapeMappingValue shapeMappingValue;
    auto it = dynShape2ShapeFunc.find(shape);
    if (it == dynShape2ShapeFunc.end()) {
      std::string name = shapeCalculationNamePrefix +
                         std::to_string(shapeCalculationNameIdx++);
      Location loc = value.getLoc();
      builder.setInsertionPointAfter(funcOp);
      auto pair = createFuncFromCluster(builder, cluster, shape, name, loc);
      const SmallVector<Value> &inputs = pair.second;
      shape::FuncOp shapeFuncOp = pair.first;
      StringAttr insertedName = symbolTable.insert(shapeFuncOp);
      auto symbol = FlatSymbolRefAttr::get(context, insertedName);

      shapeMappingValue.funcSymbol = symbol;
      shapeMappingValue.inputs = inputs;
    } else {
      shapeMappingValue = it->second;
    }
    dynShape2ShapeFunc[shape] = shapeMappingValue;
    shapeMappingAnalysis.shapeMapping.insert(
        std::make_pair(value, shapeMappingValue));
  }
}

struct OutlineShapeComputationPass
    : public impl::OutlineShapeComputationBase<OutlineShapeComputationPass> {

  void runOnOperation() override;

private:
  bool calOnlyUsedByWithShapesRecursively(Operation *op, Value prevOutput);

  void getClusterFromValue(Value shape,
                           DenseMap<Value, DenseSet<Operation *>> &clusters);

  DenseMap<Value, SmallVector<Operation *, 8>>
  constructClustersForEachShape(const std::vector<shape::WithOp> &allWithOps,
                                func::FuncOp funcOp);

  DenseSet<Operation *> onlyUsedByWithShapes;
};

class TensorDimOpRewriter : public OpRewritePattern<tensor::DimOp> {
  using OpRewritePattern<tensor::DimOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(tensor::DimOp op,
                                PatternRewriter &rewriter) const override {
    auto shapeOf =
        rewriter.create<shape::ShapeOfOp>(op.getLoc(), op.getSource());
    rewriter.replaceOpWithNewOp<shape::GetExtentOp>(op, op.getType(), shapeOf,
                                                    op.getIndex());
    return success();
  }
};

void OutlineShapeComputationPass::runOnOperation() {
  ModuleOp moduleOp = getOperation();
  SymbolTable symbolTable(moduleOp);
  DenseMap<Value, shape::ShapeMappingValue> dynShape2ShapeFunc;
  auto &shapeMappingAnalysis = getAnalysis<shape::ShapeMappingAnalysis>();
  // TODO: This is needed as we populate this analysis during a pass that
  // mutates. This pass currently requires a single module being compiled.
  shapeMappingAnalysis.shapeMapping.clear();
  markAnalysesPreserved<shape::ShapeMappingAnalysis>();

  moduleOp.walk([&](func::FuncOp funcOp) {
    MLIRContext *context = funcOp.getContext();
    RewritePatternSet prevPatterns(context);
    prevPatterns.insert<TensorDimOpRewriter>(context);
    if (failed(applyPatternsGreedily(funcOp, std::move(prevPatterns))))
      return signalPassFailure();

    // Initialize the class member `onlyUsedByWithShapes`.
    onlyUsedByWithShapes.clear();
    funcOp.walk([&](Operation *op) {
      calOnlyUsedByWithShapesRecursively(op, /*prevOutput=*/nullptr);
    });
    LLVM_DEBUG({
      llvm::dbgs() << "onlyUsedByWithShapes table: \n";
      for (auto it : onlyUsedByWithShapes)
        llvm::dbgs() << *it << "\n";
    });

    // Collect all the shape.with_shape ops.
    std::vector<shape::WithOp> allWithOps;
    funcOp.walk([&](shape::WithOp withOp) { allWithOps.push_back(withOp); });

    DenseMap<Value, SmallVector<Operation *, 8>> clusters =
        constructClustersForEachShape(allWithOps, funcOp);
    constructShapeFunc(allWithOps, context, clusters, symbolTable,
                       dynShape2ShapeFunc, funcOp, shapeMappingAnalysis);

    for (shape::WithOp withOp : allWithOps) {
      Value value = withOp.getOperand();
      for (Operation *user :
           llvm::make_early_inc_range(withOp.getResult().getUsers())) {
        if (auto valueOf = llvm::dyn_cast<shape::ValueOfOp>(user)) {
          // For a pattern like
          //   %1 = shape.with_shape %arg1, %0
          //   %2 = shape.value_of %1
          // the shape.with_shape is redundant because shape.value_of does not
          // care about the shape.
          // If %arg1 and %2 have the same type, just replace %2 with %arg1.
          // If %arg1 has a different type, like !shape.value_shape, transform
          // into
          //   %2 = shape.value_of %arg1
          if (valueOf.getType() == value.getType())
            valueOf.replaceAllUsesWith(value);
          else
            valueOf.setOperand(value);
        }
      }
    }

    // Apply patterns; note this also performs DCE.
    if (failed(applyPatternsGreedily(funcOp, {})))
      return signalPassFailure();
  });
}

DenseMap<Value, SmallVector<Operation *, 8>>
OutlineShapeComputationPass::constructClustersForEachShape(
    const std::vector<shape::WithOp> &allWithOps, func::FuncOp funcOp) {
  DenseMap<Value, DenseSet<Operation *>> clusters;
  for (shape::WithOp withOp : allWithOps) {
    Value shape = withOp.getShape();
    if (clusters.count(shape) == 0)
      getClusterFromValue(shape, clusters);
  }
  return getOrderedClusters(clusters, funcOp);
}

// The output of a cluster is the `shape`, and the inputs are the outputs of
// operations that are not in `onlyUsedByWithShapes`.
void OutlineShapeComputationPass::getClusterFromValue(
    Value shape, DenseMap<Value, DenseSet<Operation *>> &clusters) {
  DenseSet<Operation *> cluster;

  DenseSet<Operation *> visited;
  std::queue<Operation *> queue;

  // defOp == nullptr means `shape` is an argument of the func op.
  if (Operation *defOp = shape.getDefiningOp()) {
    visited.insert(defOp);
    queue.push(defOp);
  }
  while (!queue.empty()) {
    Operation *op = queue.front();
    queue.pop();
    if (onlyUsedByWithShapes.contains(op)) {
      cluster.insert(op);
      for (Value inp : op->getOperands()) {
        Operation *inpDefOp = inp.getDefiningOp();
        if (nullptr != inpDefOp && visited.insert(inpDefOp).second)
          queue.push(inpDefOp);
      }
    }
  }

  clusters[shape] = std::move(cluster);
}

// Returns whether `op` is a shape.with_shape, or all the users of `op`
// eventually point to the shape operand of shape.with_shape ops.
bool OutlineShapeComputationPass::calOnlyUsedByWithShapesRecursively(
    Operation *op, Value prevOutput) {
  if (onlyUsedByWithShapes.contains(op))
    return true;

  if (auto withOp = llvm::dyn_cast<shape::WithOp>(op))
    return withOp.getShape() == prevOutput;

  if (op->use_empty())
    return false;

  for (Value oup : op->getResults())
    for (Operation *user : oup.getUsers())
      if (!calOnlyUsedByWithShapesRecursively(user, oup))
        return false;

  onlyUsedByWithShapes.insert(op);
  return true;
}

} // namespace

std::unique_ptr<OperationPass<ModuleOp>>
mlir::createOutlineShapeComputationPass() {
  return std::make_unique<OutlineShapeComputationPass>();
}
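
For reference, the factory function at the end of the file is how the pass
is typically wired up. A minimal sketch of running it in-process (the
surrounding harness, i.e. how `module` is obtained, is assumed):

```c++
// Run the outline-shape-computation pass over a parsed module.
mlir::PassManager pm(module->getContext());
pm.addPass(mlir::createOutlineShapeComputationPass());
if (failed(pm.run(*module)))
  llvm::errs() << "outline-shape-computation failed\n";
```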