Elementwise op fusion preserves the result of the producer in the fused op, relying on later cleanup patterns to drop unused results of the fused op. Instead, if the producer result has no uses other than the consumer op, avoid making the producer result available in the fused op. This saves some unnecessary IR manipulation. Differential Revision: https://reviews.llvm.org/D138096
226 lines
8.4 KiB
C++
226 lines
8.4 KiB
C++
//===- TestLinalgElementwiseFusion.cpp - Test Linalg elementwise fusion ---===//
|
|
//
|
|
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
|
// See https://llvm.org/LICENSE.txt for license information.
|
|
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
|
//
|
|
//===----------------------------------------------------------------------===//
|
|
//
|
|
// This file implements a pass for testing fusion of elementwise operations in
|
|
// Linalg, mainly linalg options.
|
|
//
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
#include "mlir/Dialect/Affine/IR/AffineOps.h"
|
|
#include "mlir/Dialect/Func/IR/FuncOps.h"
|
|
#include "mlir/Dialect/Linalg/Transforms/Transforms.h"
|
|
#include "mlir/Pass/Pass.h"
|
|
#include "mlir/Pass/PassManager.h"
|
|
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
|
|
#include "llvm/ADT/TypeSwitch.h"
|
|
|
|
using namespace mlir;
|
|
|
|
/// Inserts into `operandSet` the values that `op` contributes as operands when
/// counting the operands of a would-be fused operation.
///
/// For an op implementing `linalg::LinalgOp`, only its DPS input operands are
/// inserted (its init operands are not). For any other op, every operand is
/// inserted. A null `op` is ignored. Values already present in the set are not
/// duplicated, because `SetVector` keeps unique entries in insertion order.
static void addOperands(Operation *op, SetVector<Value> &operandSet) {
  if (op == nullptr)
    return;

  if (auto linalgOp = dyn_cast<linalg::LinalgOp>(op)) {
    // Only the inputs are counted for Linalg ops; init operands are skipped.
    for (OpOperand *input : linalgOp.getDpsInputOperands())
      operandSet.insert(input->get());
    return;
  }

  for (Value operand : op->getOperands())
    operandSet.insert(operand);
}
|
|
|
|
template <int limit = 3>
|
|
static bool setFusedOpOperandLimit(OpOperand *fusedOperand) {
|
|
Operation *producer = fusedOperand->get().getDefiningOp();
|
|
if (!producer)
|
|
return false;
|
|
|
|
Operation *consumer = fusedOperand->getOwner();
|
|
SetVector<Value> fusedOpOperands;
|
|
if (producer->getNumResults() != 1)
|
|
return false;
|
|
addOperands(consumer, fusedOpOperands);
|
|
fusedOpOperands.remove(producer->getResult(0));
|
|
addOperands(producer, fusedOpOperands);
|
|
return fusedOpOperands.size() <= limit;
|
|
}
|
|
|
|
namespace {
|
|
struct TestLinalgElementwiseFusion
|
|
: public PassWrapper<TestLinalgElementwiseFusion,
|
|
OperationPass<func::FuncOp>> {
|
|
MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestLinalgElementwiseFusion)
|
|
|
|
TestLinalgElementwiseFusion() = default;
|
|
TestLinalgElementwiseFusion(const TestLinalgElementwiseFusion &pass)
|
|
: PassWrapper(pass) {}
|
|
void getDependentDialects(DialectRegistry ®istry) const override {
|
|
registry.insert<AffineDialect, linalg::LinalgDialect, memref::MemRefDialect,
|
|
tensor::TensorDialect>();
|
|
}
|
|
StringRef getArgument() const final {
|
|
return "test-linalg-elementwise-fusion-patterns";
|
|
}
|
|
StringRef getDescription() const final {
|
|
return "Test Linalg element wise operation fusion patterns";
|
|
}
|
|
|
|
Option<bool> fuseGenericOps{
|
|
*this, "fuse-generic-ops",
|
|
llvm::cl::desc("Test fusion of generic operations."),
|
|
llvm::cl::init(false)};
|
|
|
|
Option<bool> fuseGenericOpsControl{
|
|
*this, "fuse-generic-ops-control",
|
|
llvm::cl::desc(
|
|
"Test fusion of generic operations with a control function."),
|
|
llvm::cl::init(false)};
|
|
|
|
Option<bool> fuseWithReshapeByExpansion{
|
|
*this, "fuse-with-reshape-by-expansion",
|
|
llvm::cl::desc(
|
|
"Test fusion of generic operations with reshape by expansion"),
|
|
llvm::cl::init(false)};
|
|
|
|
Option<bool> controlFuseByExpansion{
|
|
*this, "control-fusion-by-expansion",
|
|
llvm::cl::desc(
|
|
"Test controlling fusion of reshape with generic op by expansion"),
|
|
llvm::cl::init(false)};
|
|
|
|
Option<bool> fuseWithReshapeByCollapsing{
|
|
*this, "fuse-with-reshape-by-collapsing",
|
|
llvm::cl::desc("Test linalg expand_shape -> generic fusion patterns that "
|
|
"collapse the iteration space of the consumer"),
|
|
llvm::cl::init(false)};
|
|
|
|
Option<bool> fuseWithReshapeByCollapsingWithControlFn{
|
|
*this, "fuse-with-reshape-by-collapsing-control",
|
|
llvm::cl::desc("Test controlling the linalg expand_shape -> generic "
|
|
"fusion patterns that "
|
|
"collapse the iteration space of the consumer"),
|
|
llvm::cl::init(false)};
|
|
ListOption<int64_t> collapseDimensions{
|
|
*this, "collapse-dimensions-control",
|
|
llvm::cl::desc("Test controlling dimension collapse pattern")};
|
|
|
|
void runOnOperation() override {
|
|
MLIRContext *context = &this->getContext();
|
|
func::FuncOp funcOp = this->getOperation();
|
|
|
|
if (fuseGenericOps) {
|
|
RewritePatternSet fusionPatterns(context);
|
|
auto controlFn = [](OpOperand *operand) { return true; };
|
|
linalg::populateElementwiseOpsFusionPatterns(fusionPatterns, controlFn);
|
|
(void)applyPatternsAndFoldGreedily(funcOp.getBody(),
|
|
std::move(fusionPatterns));
|
|
return;
|
|
}
|
|
|
|
if (fuseGenericOpsControl) {
|
|
RewritePatternSet fusionPatterns(context);
|
|
linalg::populateElementwiseOpsFusionPatterns(fusionPatterns,
|
|
setFusedOpOperandLimit<4>);
|
|
|
|
(void)applyPatternsAndFoldGreedily(funcOp.getBody(),
|
|
std::move(fusionPatterns));
|
|
return;
|
|
}
|
|
|
|
if (fuseWithReshapeByExpansion) {
|
|
RewritePatternSet fusionPatterns(context);
|
|
linalg::populateFoldReshapeOpsByExpansionPatterns(
|
|
fusionPatterns, [](OpOperand * /*fusedOperand*/) { return true; });
|
|
if (failed(applyPatternsAndFoldGreedily(funcOp.getBody(),
|
|
std::move(fusionPatterns))))
|
|
return signalPassFailure();
|
|
return;
|
|
}
|
|
|
|
if (controlFuseByExpansion) {
|
|
RewritePatternSet fusionPatterns(context);
|
|
|
|
linalg::ControlFusionFn controlReshapeFusionFn =
|
|
[](OpOperand *fusedOperand) {
|
|
Operation *producer = fusedOperand->get().getDefiningOp();
|
|
if (!producer)
|
|
return false;
|
|
|
|
if (auto collapseOp = dyn_cast<tensor::CollapseShapeOp>(producer)) {
|
|
if (!collapseOp.getSrc().getDefiningOp<linalg::LinalgOp>()) {
|
|
return false;
|
|
}
|
|
}
|
|
|
|
Operation *consumer = fusedOperand->getOwner();
|
|
if (auto expandOp = dyn_cast<tensor::ExpandShapeOp>(consumer)) {
|
|
if (expandOp->hasOneUse()) {
|
|
OpOperand &use = *expandOp->getUses().begin();
|
|
auto linalgOp = dyn_cast<linalg::LinalgOp>(use.getOwner());
|
|
if (linalgOp && linalgOp.isDpsInit(&use))
|
|
return true;
|
|
}
|
|
return false;
|
|
}
|
|
return true;
|
|
};
|
|
|
|
linalg::populateFoldReshapeOpsByExpansionPatterns(fusionPatterns,
|
|
controlReshapeFusionFn);
|
|
(void)applyPatternsAndFoldGreedily(funcOp.getBody(),
|
|
std::move(fusionPatterns));
|
|
return;
|
|
}
|
|
|
|
if (fuseWithReshapeByCollapsing) {
|
|
RewritePatternSet patterns(context);
|
|
linalg::populateFoldReshapeOpsByCollapsingPatterns(
|
|
patterns, [](OpOperand * /*fusedOperand */) { return true; });
|
|
(void)applyPatternsAndFoldGreedily(funcOp.getBody(), std::move(patterns));
|
|
}
|
|
|
|
if (fuseWithReshapeByCollapsingWithControlFn) {
|
|
RewritePatternSet patterns(context);
|
|
linalg::ControlFusionFn controlFn = [](OpOperand *fusedOperand) -> bool {
|
|
Operation *producer = fusedOperand->get().getDefiningOp();
|
|
if (isa<tensor::ExpandShapeOp>(producer)) {
|
|
// Skip fusing the first operand.
|
|
return fusedOperand->getOperandNumber();
|
|
}
|
|
return true;
|
|
};
|
|
linalg::populateFoldReshapeOpsByCollapsingPatterns(patterns, controlFn);
|
|
(void)applyPatternsAndFoldGreedily(funcOp.getBody(), std::move(patterns));
|
|
}
|
|
|
|
if (!collapseDimensions.empty()) {
|
|
SmallVector<int64_t, 2> dims(collapseDimensions.begin(),
|
|
collapseDimensions.end());
|
|
linalg::GetCollapsableDimensionsFn collapseFn =
|
|
[&dims](linalg::GenericOp op) {
|
|
SmallVector<ReassociationIndices> reassociations;
|
|
reassociations.emplace_back(dims);
|
|
return reassociations;
|
|
};
|
|
RewritePatternSet patterns(context);
|
|
linalg::populateCollapseDimensions(patterns, collapseFn);
|
|
(void)applyPatternsAndFoldGreedily(funcOp.getBody(), std::move(patterns));
|
|
}
|
|
}
|
|
};
|
|
|
|
} // namespace
|
|
|
|
namespace mlir {
|
|
namespace test {
|
|
void registerTestLinalgElementwiseFusion() {
|
|
PassRegistration<TestLinalgElementwiseFusion>();
|
|
}
|
|
} // namespace test
|
|
} // namespace mlir
|