[mlir][Linalg] Make Elementwise op fusion return a map from existing values to values in the fused op.

This replacement map can be used to eliminate all uses of the
producer/consumer in cases where the producer/consumer has other uses
outside of the producer/consumer pair, which makes the
producer/consumer dead.

Add a test and a minor fixup to the test harness.

Reviewed By: hanchung

Differential Revision: https://reviews.llvm.org/D142848
This commit is contained in:
Mahesh Ravishankar
2023-01-31 20:33:19 +00:00
parent 9271c5da43
commit 69011a2ad0
4 changed files with 135 additions and 20 deletions

View File

@@ -162,8 +162,12 @@ bool areElementwiseOpsFusable(OpOperand *fusedOperand);
/// Fuse two `linalg.generic` operations that have a producer-consumer
/// relationship captured through `fusedOperand`. The method expects
/// that `areElementwiseOpsFusable` returns true for the given `fusedOperand`.
FailureOr<Operation *> fuseElementwiseOps(RewriterBase &rewriter,
OpOperand *fusedOperand);
/// Result of `fuseElementwiseOps`: the fused operation together with a map
/// from values of the original producer/consumer to the values of the fused
/// op that can replace them.
struct ElementwiseOpFusionResult {
/// The operation created by fusing the producer into the consumer.
Operation *fusedOp;
/// Maps results of the original ops to the corresponding results of
/// `fusedOp`. Callers can use it to redirect remaining uses of the original
/// producer/consumer to the fused op.
llvm::DenseMap<Value, Value> replacements;
};
FailureOr<ElementwiseOpFusionResult>
fuseElementwiseOps(RewriterBase &rewriter, OpOperand *fusedOperand);
/// Split the given `op` into two parts along the given iteration space
/// `dimension` at the specified `splitPoint`, and return the two parts.

View File

@@ -23,8 +23,8 @@
#include "mlir/IR/PatternMatch.h"
#include "mlir/Support/LLVM.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
#include <utility>
#include <optional>
#include <utility>
namespace mlir {
#define GEN_PASS_DEF_LINALGFOLDUNITEXTENTDIMS
@@ -73,6 +73,9 @@ static AffineMap getIndexingMapOfProducerOperandsInCoordinatesOfFusedOp(
/// Conditions for elementwise fusion of generic operations.
bool mlir::linalg::areElementwiseOpsFusable(OpOperand *fusedOperand) {
if (!fusedOperand)
return false;
auto producer = fusedOperand->get().getDefiningOp<GenericOp>();
auto consumer = dyn_cast<GenericOp>(fusedOperand->getOwner());
@@ -270,7 +273,7 @@ static void generateFusedElementwiseOpRegion(
"Ill-formed GenericOp region");
}
FailureOr<Operation *>
FailureOr<mlir::linalg::ElementwiseOpFusionResult>
mlir::linalg::fuseElementwiseOps(RewriterBase &rewriter,
OpOperand *fusedOperand) {
assert(areElementwiseOpsFusable(fusedOperand) &&
@@ -390,7 +393,15 @@ mlir::linalg::fuseElementwiseOps(RewriterBase &rewriter,
generateFusedElementwiseOpRegion(
rewriter, fusedOp, consumerToProducerLoopsMap, fusedOperand,
consumer.getNumLoops(), preservedProducerResults);
return fusedOp.getOperation();
ElementwiseOpFusionResult result;
result.fusedOp = fusedOp;
int resultNum = 0;
for (auto [index, producerResult] : llvm::enumerate(producer->getResults()))
if (preservedProducerResults.count(index))
result.replacements[producerResult] = fusedOp->getResult(resultNum++);
for (auto consumerResult : consumer->getResults())
result.replacements[consumerResult] = fusedOp->getResult(resultNum++);
return result;
}
namespace {
@@ -411,13 +422,20 @@ public:
if (!controlFn(&opOperand))
continue;
FailureOr<Operation *> fusedOp = fuseElementwiseOps(rewriter, &opOperand);
if (succeeded(fusedOp)) {
auto replacements =
(*fusedOp)->getResults().take_back(genericOp.getNumResults());
rewriter.replaceOp(genericOp, replacements);
return success();
FailureOr<ElementwiseOpFusionResult> fusionResult =
    fuseElementwiseOps(rewriter, &opOperand);
// Propagate the failure to the driver. Without the `return`, execution would
// fall through and dereference the failed result when reading
// `fusionResult->replacements`.
if (failed(fusionResult))
  return rewriter.notifyMatchFailure(genericOp, "fusion failed");
Operation *producer = opOperand.get().getDefiningOp();
for (auto [origVal, replacement] : fusionResult->replacements) {
Value origValCopy = origVal;
rewriter.replaceUseIf(origVal, replacement, [&](OpOperand &use) {
// Only replace consumer uses.
return use.get().getDefiningOp() != producer;
});
}
rewriter.eraseOp(genericOp);
return success();
}
return failure();
}

View File

@@ -0,0 +1,34 @@
// RUN: mlir-opt -test-linalg-elementwise-fusion-patterns=fuse-multiuse-producer -split-input-file %s | FileCheck %s
#map = affine_map<(d0, d1) -> (d0, d1)>
// Producer `%0` yields two results. `%0#1` feeds the second `linalg.generic`
// and both `%0#0` and `%0#1` are also returned, so the producer has uses
// outside of the producer/consumer pair.
func.func @multi_use_producer(%arg0 : tensor<?x?xf32>, %arg1 : tensor<?x?xf32>,
%arg2 : tensor<?x?xf32>, %arg3 : tensor<?x?xf32>, %arg4 : tensor<?x?xf32>)
-> (tensor<?x?xf32>, tensor<?x?xf32>, tensor<?x?xf32>) {
// Producer: elementwise add, yielding the same value into both outputs.
%0:2 = linalg.generic {
indexing_maps = [#map, #map, #map],
iterator_types = ["parallel", "parallel"]}
ins(%arg0 : tensor<?x?xf32>)
outs(%arg1, %arg2 : tensor<?x?xf32>, tensor<?x?xf32>) {
^bb0(%b0: f32, %b1 : f32, %b2 : f32):
%1 = arith.addf %b0, %b1 : f32
linalg.yield %1, %1 : f32, f32
} -> (tensor<?x?xf32>, tensor<?x?xf32>)
// Consumer: elementwise multiply of the producer's second result with %arg3.
%2 = linalg.generic {
indexing_maps = [#map, #map, #map],
iterator_types = ["parallel", "parallel"]}
ins(%0#1, %arg3 : tensor<?x?xf32>, tensor<?x?xf32>)
outs(%arg4 : tensor<?x?xf32>) {
^bb0(%b0 : f32, %b1 : f32, %b2 : f32):
%3 = arith.mulf %b0, %b1 : f32
linalg.yield %3 : f32
} -> tensor<?x?xf32>
return %0#0, %0#1, %2 : tensor<?x?xf32>, tensor<?x?xf32>, tensor<?x?xf32>
}
// CHECK: func @multi_use_producer(
// CHECK-SAME: %[[ARG0:[a-zA-Z0-9]+]]: tensor<?x?xf32>
// CHECK-SAME: %[[ARG1:[a-zA-Z0-9]+]]: tensor<?x?xf32>
// CHECK-SAME: %[[ARG2:[a-zA-Z0-9]+]]: tensor<?x?xf32>
// CHECK-SAME: %[[ARG3:[a-zA-Z0-9]+]]: tensor<?x?xf32>
// CHECK-SAME: %[[ARG4:[a-zA-Z0-9]+]]: tensor<?x?xf32>)
// CHECK: %[[RESULT:.+]]:3 = linalg.generic
// CHECK: return %[[RESULT]]#0, %[[RESULT]]#1, %[[RESULT]]#2

View File

@@ -51,6 +51,38 @@ static bool setFusedOpOperandLimit(OpOperand *fusedOperand) {
}
namespace {
/// Pattern to test fusion of a producer with its consumer, even if the
/// producer has multiple uses.
///
/// Fuses `genericOp` with the producer of its first operand for which
/// `linalg::areElementwiseOpsFusable` returns true. Every use of a value in
/// the returned replacement map, other than uses owned by `genericOp` itself,
/// is then redirected to the corresponding value of the fused op, and
/// `genericOp` is erased.
struct TestMultiUseProducerFusion : public OpRewritePattern<linalg::GenericOp> {
  using OpRewritePattern<linalg::GenericOp>::OpRewritePattern;

  /// Returns failure (without modifying the IR in this pattern) when no
  /// operand is fusable or when the fusion itself fails; success otherwise.
  LogicalResult matchAndRewrite(linalg::GenericOp genericOp,
                                PatternRewriter &rewriter) const override {
    // Pick the first operand that can be fused with its producer.
    OpOperand *fusableOperand = nullptr;
    for (OpOperand &operand : genericOp->getOpOperands()) {
      if (linalg::areElementwiseOpsFusable(&operand)) {
        fusableOperand = &operand;
        break;
      }
    }
    if (!fusableOperand) {
      return rewriter.notifyMatchFailure(genericOp, "no fusable operand found");
    }

    std::optional<linalg::ElementwiseOpFusionResult> fusionResult =
        linalg::fuseElementwiseOps(rewriter, fusableOperand);
    // Bail out on failure. Without the `return`, the empty optional would be
    // dereferenced by the loop below.
    if (!fusionResult)
      return rewriter.notifyMatchFailure(genericOp, "fusion failed");

    for (auto [origValue, replacement] : fusionResult->replacements) {
      rewriter.replaceUseIf(origValue, replacement, [&](OpOperand &use) {
        // Skip uses owned by `genericOp`; that op is erased right after.
        return use.getOwner() != genericOp.getOperation();
      });
    }
    rewriter.eraseOp(genericOp);
    return success();
  }
};
struct TestLinalgElementwiseFusion
: public PassWrapper<TestLinalgElementwiseFusion,
OperationPass<func::FuncOp>> {
@@ -105,6 +137,12 @@ struct TestLinalgElementwiseFusion
"fusion patterns that "
"collapse the iteration space of the consumer"),
llvm::cl::init(false)};
Option<bool> fuseMultiUseProducer{
*this, "fuse-multiuse-producer",
llvm::cl::desc("Test fusion of producer ops with multiple uses"),
llvm::cl::init(false)};
ListOption<int64_t> collapseDimensions{
*this, "collapse-dimensions-control",
llvm::cl::desc("Test controlling dimension collapse pattern")};
@@ -117,8 +155,9 @@ struct TestLinalgElementwiseFusion
RewritePatternSet fusionPatterns(context);
auto controlFn = [](OpOperand *operand) { return true; };
linalg::populateElementwiseOpsFusionPatterns(fusionPatterns, controlFn);
(void)applyPatternsAndFoldGreedily(funcOp.getBody(),
std::move(fusionPatterns));
if (failed(applyPatternsAndFoldGreedily(funcOp.getBody(),
std::move(fusionPatterns))))
return signalPassFailure();
return;
}
@@ -127,8 +166,9 @@ struct TestLinalgElementwiseFusion
linalg::populateElementwiseOpsFusionPatterns(fusionPatterns,
setFusedOpOperandLimit<4>);
(void)applyPatternsAndFoldGreedily(funcOp.getBody(),
std::move(fusionPatterns));
if (failed(applyPatternsAndFoldGreedily(funcOp.getBody(),
std::move(fusionPatterns))))
return signalPassFailure();
return;
}
@@ -172,8 +212,9 @@ struct TestLinalgElementwiseFusion
linalg::populateFoldReshapeOpsByExpansionPatterns(fusionPatterns,
controlReshapeFusionFn);
(void)applyPatternsAndFoldGreedily(funcOp.getBody(),
std::move(fusionPatterns));
if (failed(applyPatternsAndFoldGreedily(funcOp.getBody(),
std::move(fusionPatterns))))
return signalPassFailure();
return;
}
@@ -181,7 +222,10 @@ struct TestLinalgElementwiseFusion
RewritePatternSet patterns(context);
linalg::populateFoldReshapeOpsByCollapsingPatterns(
patterns, [](OpOperand * /*fusedOperand */) { return true; });
(void)applyPatternsAndFoldGreedily(funcOp.getBody(), std::move(patterns));
if (failed(applyPatternsAndFoldGreedily(funcOp.getBody(),
std::move(patterns))))
return signalPassFailure();
return;
}
if (fuseWithReshapeByCollapsingWithControlFn) {
@@ -195,7 +239,19 @@ struct TestLinalgElementwiseFusion
return true;
};
linalg::populateFoldReshapeOpsByCollapsingPatterns(patterns, controlFn);
(void)applyPatternsAndFoldGreedily(funcOp.getBody(), std::move(patterns));
if (failed(applyPatternsAndFoldGreedily(funcOp.getBody(),
std::move(patterns))))
return signalPassFailure();
return;
}
if (fuseMultiUseProducer) {
RewritePatternSet patterns(context);
patterns.insert<TestMultiUseProducerFusion>(context);
if (failed(applyPatternsAndFoldGreedily(funcOp.getBody(),
std::move(patterns))))
return signalPassFailure();
return;
}
if (!collapseDimensions.empty()) {
@@ -209,7 +265,10 @@ struct TestLinalgElementwiseFusion
};
RewritePatternSet patterns(context);
linalg::populateCollapseDimensions(patterns, collapseFn);
(void)applyPatternsAndFoldGreedily(funcOp.getBody(), std::move(patterns));
if (failed(applyPatternsAndFoldGreedily(funcOp.getBody(),
std::move(patterns))))
return signalPassFailure();
return;
}
}
};