Files
clang-p2996/mlir/test/lib/Dialect/Linalg/TestLinalgElementwiseFusion.cpp
Mahesh Ravishankar 2d4b998697 [mlir][Linalg] Avoid unnecessary propagating producer result to fused op result.
Elementwise op fusion conserves the result of the producer in the
fused op, relying on later clean up patterns to drop unused results of
the fused op. Instead, if the producer result has no other use apart
from the consumer op, avoid making the producer result available in
the fused node. This saves some unnecessary IR manipulations.

Differential Revision: https://reviews.llvm.org/D138096
2022-11-22 07:08:17 +00:00

226 lines
8.4 KiB
C++

//===- TestLinalgElementwiseFusion.cpp - Test Linalg elementwise fusion ---===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file implements a pass for testing fusion of elementwise operations in
// Linalg, mainly linalg options.
//
//===----------------------------------------------------------------------===//
#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/Linalg/Transforms/Transforms.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Pass/PassManager.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
#include "llvm/ADT/TypeSwitch.h"
using namespace mlir;
/// Inserts into `operandSet` the values consumed by `op`. For ops implementing
/// `linalg::LinalgOp`, only the DPS input operands are collected (init operands
/// are skipped); for any other op, every operand is collected. A null `op` is
/// ignored.
static void addOperands(Operation *op, SetVector<Value> &operandSet) {
  if (op == nullptr)
    return;

  if (auto linalgOp = dyn_cast<linalg::LinalgOp>(op)) {
    SmallVector<Value> dpsInputs{linalgOp.getDpsInputOperands()};
    operandSet.insert(dpsInputs.begin(), dpsInputs.end());
    return;
  }

  // Generic fallback: all operands of a non-Linalg op.
  operandSet.insert(op->operand_begin(), op->operand_end());
}
/// Fusion control function: accepts fusing the producer of `fusedOperand` into
/// its owner only if
///   * the operand has a defining op (i.e. it is not a block argument),
///   * that producer has exactly one result, and
///   * the set of distinct operands the fused op would consume (as collected
///     by `addOperands`) has at most `limit` entries.
template <int limit = 3>
static bool setFusedOpOperandLimit(OpOperand *fusedOperand) {
  Operation *producer = fusedOperand->get().getDefiningOp();
  if (!producer || producer->getNumResults() != 1)
    return false;

  SetVector<Value> combinedOperands;
  addOperands(fusedOperand->getOwner(), combinedOperands);
  // The producer's result is computed inside the fused op, so it is no longer
  // an operand of it.
  combinedOperands.remove(producer->getResult(0));
  addOperands(producer, combinedOperands);

  return combinedOperands.size() <= limit;
}
namespace {
/// Test pass that populates one of the Linalg fusion pattern sets and applies
/// it greedily to the body of a `func.func`. The pattern set, and the control
/// function that decides which candidate operands get fused, are selected via
/// the pass options declared below.
struct TestLinalgElementwiseFusion
    : public PassWrapper<TestLinalgElementwiseFusion,
                         OperationPass<func::FuncOp>> {
  MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestLinalgElementwiseFusion)

  TestLinalgElementwiseFusion() = default;
  /// Copies only the `PassWrapper` base; the option members below are
  /// initialized from their in-class declarations, not copied from `pass`.
  TestLinalgElementwiseFusion(const TestLinalgElementwiseFusion &pass)
      : PassWrapper(pass) {}

  /// Declares the dialects this pass depends on so they are loaded before it
  /// runs.
  void getDependentDialects(DialectRegistry &registry) const override {
    registry.insert<AffineDialect, linalg::LinalgDialect, memref::MemRefDialect,
                    tensor::TensorDialect>();
  }

  StringRef getArgument() const final {
    return "test-linalg-elementwise-fusion-patterns";
  }
  StringRef getDescription() const final {
    return "Test Linalg element wise operation fusion patterns";
  }

  /// Elementwise fusion with a control function that always fuses.
  Option<bool> fuseGenericOps{
      *this, "fuse-generic-ops",
      llvm::cl::desc("Test fusion of generic operations."),
      llvm::cl::init(false)};

  /// Elementwise fusion controlled by `setFusedOpOperandLimit<4>`.
  Option<bool> fuseGenericOpsControl{
      *this, "fuse-generic-ops-control",
      llvm::cl::desc(
          "Test fusion of generic operations with a control function."),
      llvm::cl::init(false)};

  /// Reshape fusion by expansion with a control function that always fuses.
  Option<bool> fuseWithReshapeByExpansion{
      *this, "fuse-with-reshape-by-expansion",
      llvm::cl::desc(
          "Test fusion of generic operations with reshape by expansion"),
      llvm::cl::init(false)};

  /// Reshape fusion by expansion with a restrictive control function.
  Option<bool> controlFuseByExpansion{
      *this, "control-fusion-by-expansion",
      llvm::cl::desc(
          "Test controlling fusion of reshape with generic op by expansion"),
      llvm::cl::init(false)};

  /// Reshape fusion by collapsing with a control function that always fuses.
  Option<bool> fuseWithReshapeByCollapsing{
      *this, "fuse-with-reshape-by-collapsing",
      llvm::cl::desc("Test linalg expand_shape -> generic fusion patterns that "
                     "collapse the iteration space of the consumer"),
      llvm::cl::init(false)};

  /// Reshape fusion by collapsing that skips operand number 0.
  Option<bool> fuseWithReshapeByCollapsingWithControlFn{
      *this, "fuse-with-reshape-by-collapsing-control",
      llvm::cl::desc("Test controlling the linalg expand_shape -> generic "
                     "fusion patterns that "
                     "collapse the iteration space of the consumer"),
      llvm::cl::init(false)};

  /// Dimensions forming the single reassociation group passed to the
  /// dimension-collapsing patterns; empty disables them.
  ListOption<int64_t> collapseDimensions{
      *this, "collapse-dimensions-control",
      llvm::cl::desc("Test controlling dimension collapse pattern")};

  void runOnOperation() override {
    MLIRContext *context = &getContext();
    func::FuncOp funcOp = getOperation();

    if (fuseGenericOps) {
      RewritePatternSet fusionPatterns(context);
      auto controlFn = [](OpOperand * /*fusedOperand*/) { return true; };
      linalg::populateElementwiseOpsFusionPatterns(fusionPatterns, controlFn);
      // The greedy driver's convergence result is deliberately ignored here.
      (void)applyPatternsAndFoldGreedily(funcOp.getBody(),
                                         std::move(fusionPatterns));
      return;
    }

    if (fuseGenericOpsControl) {
      RewritePatternSet fusionPatterns(context);
      linalg::populateElementwiseOpsFusionPatterns(fusionPatterns,
                                                   setFusedOpOperandLimit<4>);
      (void)applyPatternsAndFoldGreedily(funcOp.getBody(),
                                         std::move(fusionPatterns));
      return;
    }

    if (fuseWithReshapeByExpansion) {
      RewritePatternSet fusionPatterns(context);
      linalg::populateFoldReshapeOpsByExpansionPatterns(
          fusionPatterns, [](OpOperand * /*fusedOperand*/) { return true; });
      // Only this mode reports a non-converging rewrite as a pass failure.
      if (failed(applyPatternsAndFoldGreedily(funcOp.getBody(),
                                              std::move(fusionPatterns))))
        return signalPassFailure();
      return;
    }

    if (controlFuseByExpansion) {
      RewritePatternSet fusionPatterns(context);
      linalg::ControlFusionFn controlReshapeFusionFn =
          [](OpOperand *fusedOperand) {
            Operation *producer = fusedOperand->get().getDefiningOp();
            // Block arguments have no producer to fuse.
            if (!producer)
              return false;
            // A collapse_shape producer is only fused when its source is
            // itself produced by a Linalg op.
            if (auto collapseOp = dyn_cast<tensor::CollapseShapeOp>(producer)) {
              if (!collapseOp.getSrc().getDefiningOp<linalg::LinalgOp>()) {
                return false;
              }
            }
            // An expand_shape consumer is only fused when its single use is
            // as a DPS init operand of a Linalg op.
            Operation *consumer = fusedOperand->getOwner();
            if (auto expandOp = dyn_cast<tensor::ExpandShapeOp>(consumer)) {
              if (expandOp->hasOneUse()) {
                OpOperand &use = *expandOp->getUses().begin();
                auto linalgOp = dyn_cast<linalg::LinalgOp>(use.getOwner());
                if (linalgOp && linalgOp.isDpsInit(&use))
                  return true;
              }
              return false;
            }
            return true;
          };
      linalg::populateFoldReshapeOpsByExpansionPatterns(fusionPatterns,
                                                        controlReshapeFusionFn);
      (void)applyPatternsAndFoldGreedily(funcOp.getBody(),
                                         std::move(fusionPatterns));
      return;
    }

    // Unlike the options above, the remaining options do not return early, so
    // several of them can take effect in a single run, in this order.
    if (fuseWithReshapeByCollapsing) {
      RewritePatternSet patterns(context);
      linalg::populateFoldReshapeOpsByCollapsingPatterns(
          patterns, [](OpOperand * /*fusedOperand */) { return true; });
      (void)applyPatternsAndFoldGreedily(funcOp.getBody(), std::move(patterns));
    }

    if (fuseWithReshapeByCollapsingWithControlFn) {
      RewritePatternSet patterns(context);
      linalg::ControlFusionFn controlFn = [](OpOperand *fusedOperand) -> bool {
        Operation *producer = fusedOperand->get().getDefiningOp();
        // `producer` is null when the operand is a block argument; plain
        // `isa<>` asserts on a null pointer, so use the null-tolerant form.
        if (isa_and_nonnull<tensor::ExpandShapeOp>(producer)) {
          // Skip fusing the first operand (operand number 0).
          return fusedOperand->getOperandNumber() != 0;
        }
        return true;
      };
      linalg::populateFoldReshapeOpsByCollapsingPatterns(patterns, controlFn);
      (void)applyPatternsAndFoldGreedily(funcOp.getBody(), std::move(patterns));
    }

    if (!collapseDimensions.empty()) {
      SmallVector<int64_t, 2> dims(collapseDimensions.begin(),
                                   collapseDimensions.end());
      // `dims` is captured by reference; the patterns are applied within this
      // scope, so it outlives every invocation of `collapseFn`.
      linalg::GetCollapsableDimensionsFn collapseFn =
          [&dims](linalg::GenericOp /*genericOp*/) {
            SmallVector<ReassociationIndices> reassociations;
            reassociations.emplace_back(dims);
            return reassociations;
          };
      RewritePatternSet patterns(context);
      linalg::populateCollapseDimensions(patterns, collapseFn);
      (void)applyPatternsAndFoldGreedily(funcOp.getBody(), std::move(patterns));
    }
  }
};
} // namespace
namespace mlir {
namespace test {
void registerTestLinalgElementwiseFusion() {
PassRegistration<TestLinalgElementwiseFusion>();
}
} // namespace test
} // namespace mlir