[mlir][Linalg] Make Elementwise op fusion return a map from existing values to values in the fused op.
This replacement map can be used to eliminate all uses of the producer/consumer in the case where the producer/consumer has other uses outside of the producer/consumer pair. This makes the producer/consumer dead. Add a test and a minor fixup to the test harness. Reviewed By: hanchung Differential Revision: https://reviews.llvm.org/D142848
This commit is contained in:
@@ -162,8 +162,12 @@ bool areElementwiseOpsFusable(OpOperand *fusedOperand);
|
||||
/// Fuse two `linalg.generic` operations that have a producer-consumer
|
||||
/// relationship captured through `fusedOperand`. The method expects
|
||||
/// that `areElementwiseOpsFusable` returns true for the given `fusedOperand`.
|
||||
FailureOr<Operation *> fuseElementwiseOps(RewriterBase &rewriter,
|
||||
OpOperand *fusedOperand);
|
||||
/// Result returned by `fuseElementwiseOps` when fusion succeeds.
/// NOTE(review): the preceding "Fuse two `linalg.generic` operations ..."
/// comment describes `fuseElementwiseOps` (declared after this struct), not
/// the struct itself.
struct ElementwiseOpFusionResult {
|
||||
/// The operation created by fusing the producer into the consumer.
Operation *fusedOp;
|
||||
/// Maps each preserved producer result and each consumer result to the
/// result of `fusedOp` that can be used in its place.
llvm::DenseMap<Value, Value> replacements;
|
||||
};
|
||||
FailureOr<ElementwiseOpFusionResult>
|
||||
fuseElementwiseOps(RewriterBase &rewriter, OpOperand *fusedOperand);
|
||||
|
||||
/// Split the given `op` into two parts along the given iteration space
|
||||
/// `dimension` at the specified `splitPoint`, and return the two parts.
|
||||
|
||||
@@ -23,8 +23,8 @@
|
||||
#include "mlir/IR/PatternMatch.h"
|
||||
#include "mlir/Support/LLVM.h"
|
||||
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
|
||||
#include <utility>
|
||||
#include <optional>
|
||||
#include <utility>
|
||||
|
||||
namespace mlir {
|
||||
#define GEN_PASS_DEF_LINALGFOLDUNITEXTENTDIMS
|
||||
@@ -73,6 +73,9 @@ static AffineMap getIndexingMapOfProducerOperandsInCoordinatesOfFusedOp(
|
||||
|
||||
/// Conditions for elementwise fusion of generic operations.
|
||||
bool mlir::linalg::areElementwiseOpsFusable(OpOperand *fusedOperand) {
|
||||
if (!fusedOperand)
|
||||
return false;
|
||||
|
||||
auto producer = fusedOperand->get().getDefiningOp<GenericOp>();
|
||||
auto consumer = dyn_cast<GenericOp>(fusedOperand->getOwner());
|
||||
|
||||
@@ -270,7 +273,7 @@ static void generateFusedElementwiseOpRegion(
|
||||
"Ill-formed GenericOp region");
|
||||
}
|
||||
|
||||
FailureOr<Operation *>
|
||||
FailureOr<mlir::linalg::ElementwiseOpFusionResult>
|
||||
mlir::linalg::fuseElementwiseOps(RewriterBase &rewriter,
|
||||
OpOperand *fusedOperand) {
|
||||
assert(areElementwiseOpsFusable(fusedOperand) &&
|
||||
@@ -390,7 +393,15 @@ mlir::linalg::fuseElementwiseOps(RewriterBase &rewriter,
|
||||
generateFusedElementwiseOpRegion(
|
||||
rewriter, fusedOp, consumerToProducerLoopsMap, fusedOperand,
|
||||
consumer.getNumLoops(), preservedProducerResults);
|
||||
return fusedOp.getOperation();
|
||||
ElementwiseOpFusionResult result;
|
||||
result.fusedOp = fusedOp;
|
||||
int resultNum = 0;
|
||||
for (auto [index, producerResult] : llvm::enumerate(producer->getResults()))
|
||||
if (preservedProducerResults.count(index))
|
||||
result.replacements[producerResult] = fusedOp->getResult(resultNum++);
|
||||
for (auto consumerResult : consumer->getResults())
|
||||
result.replacements[consumerResult] = fusedOp->getResult(resultNum++);
|
||||
return result;
|
||||
}
|
||||
|
||||
namespace {
|
||||
@@ -411,13 +422,20 @@ public:
|
||||
if (!controlFn(&opOperand))
|
||||
continue;
|
||||
|
||||
FailureOr<Operation *> fusedOp = fuseElementwiseOps(rewriter, &opOperand);
|
||||
if (succeeded(fusedOp)) {
|
||||
auto replacements =
|
||||
(*fusedOp)->getResults().take_back(genericOp.getNumResults());
|
||||
rewriter.replaceOp(genericOp, replacements);
|
||||
return success();
|
||||
FailureOr<ElementwiseOpFusionResult> fusionResult =
|
||||
fuseElementwiseOps(rewriter, &opOperand);
|
||||
// Propagate the failure to the caller. Without the `return`, a failed fusion
// would fall through and dereference `fusionResult` in the loop below.
if (failed(fusionResult))
return rewriter.notifyMatchFailure(genericOp, "fusion failed");
|
||||
Operation *producer = opOperand.get().getDefiningOp();
|
||||
for (auto [origVal, replacement] : fusionResult->replacements) {
|
||||
Value origValCopy = origVal;
|
||||
rewriter.replaceUseIf(origVal, replacement, [&](OpOperand &use) {
|
||||
// Only replace consumer uses.
|
||||
return use.get().getDefiningOp() != producer;
|
||||
});
|
||||
}
|
||||
rewriter.eraseOp(genericOp);
|
||||
return success();
|
||||
}
|
||||
return failure();
|
||||
}
|
||||
|
||||
34
mlir/test/Dialect/Linalg/fusion-multiuse-producer.mlir
Normal file
34
mlir/test/Dialect/Linalg/fusion-multiuse-producer.mlir
Normal file
@@ -0,0 +1,34 @@
|
||||
// RUN: mlir-opt -test-linalg-elementwise-fusion-patterns=fuse-multiuse-producer -split-input-file %s | FileCheck %s
|
||||
|
||||
#map = affine_map<(d0, d1) -> (d0, d1)>
|
||||
// Exercises fusion when the producer has uses besides the fused consumer:
// the first generic op yields two results, the second generic op consumes
// result #1, and the function also returns both producer results directly.
// The expectations after this function look for a single three-result
// linalg.generic whose results feed the return.
func.func @multi_use_producer(%arg0 : tensor<?x?xf32>, %arg1 : tensor<?x?xf32>,
|
||||
%arg2 : tensor<?x?xf32>, %arg3 : tensor<?x?xf32>, %arg4 : tensor<?x?xf32>)
|
||||
-> (tensor<?x?xf32>, tensor<?x?xf32>, tensor<?x?xf32>) {
|
||||
// Producer: two results, both yielding the same addf value.
%0:2 = linalg.generic {
|
||||
indexing_maps = [#map, #map, #map],
|
||||
iterator_types = ["parallel", "parallel"]}
|
||||
ins(%arg0 : tensor<?x?xf32>)
|
||||
outs(%arg1, %arg2 : tensor<?x?xf32>, tensor<?x?xf32>) {
|
||||
^bb0(%b0: f32, %b1 : f32, %b2 : f32):
|
||||
%1 = arith.addf %b0, %b1 : f32
|
||||
linalg.yield %1, %1 : f32, f32
|
||||
} -> (tensor<?x?xf32>, tensor<?x?xf32>)
|
||||
// Consumer: reads the second producer result as its first input.
%2 = linalg.generic {
|
||||
indexing_maps = [#map, #map, #map],
|
||||
iterator_types = ["parallel", "parallel"]}
|
||||
ins(%0#1, %arg3 : tensor<?x?xf32>, tensor<?x?xf32>)
|
||||
outs(%arg4 : tensor<?x?xf32>) {
|
||||
^bb0(%b0 : f32, %b1 : f32, %b2 : f32):
|
||||
%3 = arith.mulf %b0, %b1 : f32
|
||||
linalg.yield %3 : f32
|
||||
} -> tensor<?x?xf32>
|
||||
// Both producer results stay live through the return, alongside the consumer result.
return %0#0, %0#1, %2 : tensor<?x?xf32>, tensor<?x?xf32>, tensor<?x?xf32>
|
||||
}
|
||||
// CHECK: func @multi_use_producer(
|
||||
// CHECK-SAME: %[[ARG0:[a-zA-Z0-9]+]]: tensor<?x?xf32>
|
||||
// CHECK-SAME: %[[ARG1:[a-zA-Z0-9]+]]: tensor<?x?xf32>
|
||||
// CHECK-SAME: %[[ARG2:[a-zA-Z0-9]+]]: tensor<?x?xf32>
|
||||
// CHECK-SAME: %[[ARG3:[a-zA-Z0-9]+]]: tensor<?x?xf32>
|
||||
// CHECK-SAME: %[[ARG4:[a-zA-Z0-9]+]]: tensor<?x?xf32>)
|
||||
// CHECK: %[[RESULT:.+]]:3 = linalg.generic
|
||||
// CHECK: return %[[RESULT]]#0, %[[RESULT]]#1, %[[RESULT]]#2
|
||||
@@ -51,6 +51,38 @@ static bool setFusedOpOperandLimit(OpOperand *fusedOperand) {
|
||||
}
|
||||
|
||||
namespace {
|
||||
|
||||
/// Pattern to test fusion of producer with consumer, even if producer has
|
||||
/// multiple uses.
|
||||
struct TestMultiUseProducerFusion : public OpRewritePattern<linalg::GenericOp> {
|
||||
using OpRewritePattern<linalg::GenericOp>::OpRewritePattern;
|
||||
|
||||
LogicalResult matchAndRewrite(linalg::GenericOp genericOp,
|
||||
PatternRewriter &rewriter) const override {
|
||||
OpOperand *fusableOperand = nullptr;
|
||||
for (OpOperand &operand : genericOp->getOpOperands()) {
|
||||
if (linalg::areElementwiseOpsFusable(&operand)) {
|
||||
fusableOperand = &operand;
|
||||
break;
|
||||
}
|
||||
}
|
||||
if (!fusableOperand) {
|
||||
return rewriter.notifyMatchFailure(genericOp, "no fusable operand found");
|
||||
}
|
||||
std::optional<linalg::ElementwiseOpFusionResult> fusionResult =
|
||||
linalg::fuseElementwiseOps(rewriter, fusableOperand);
|
||||
if (!fusionResult)
|
||||
rewriter.notifyMatchFailure(genericOp, "fusion failed");
|
||||
for (auto [origValue, replacement] : fusionResult->replacements) {
|
||||
rewriter.replaceUseIf(origValue, replacement, [&](OpOperand &use) {
|
||||
return use.getOwner() != genericOp.getOperation();
|
||||
});
|
||||
}
|
||||
rewriter.eraseOp(genericOp);
|
||||
return success();
|
||||
}
|
||||
};
|
||||
|
||||
struct TestLinalgElementwiseFusion
|
||||
: public PassWrapper<TestLinalgElementwiseFusion,
|
||||
OperationPass<func::FuncOp>> {
|
||||
@@ -105,6 +137,12 @@ struct TestLinalgElementwiseFusion
|
||||
"fusion patterns that "
|
||||
"collapse the iteration space of the consumer"),
|
||||
llvm::cl::init(false)};
|
||||
|
||||
Option<bool> fuseMultiUseProducer{
|
||||
*this, "fuse-multiuse-producer",
|
||||
llvm::cl::desc("Test fusion of producer ops with multiple uses"),
|
||||
llvm::cl::init(false)};
|
||||
|
||||
ListOption<int64_t> collapseDimensions{
|
||||
*this, "collapse-dimensions-control",
|
||||
llvm::cl::desc("Test controlling dimension collapse pattern")};
|
||||
@@ -117,8 +155,9 @@ struct TestLinalgElementwiseFusion
|
||||
RewritePatternSet fusionPatterns(context);
|
||||
auto controlFn = [](OpOperand *operand) { return true; };
|
||||
linalg::populateElementwiseOpsFusionPatterns(fusionPatterns, controlFn);
|
||||
(void)applyPatternsAndFoldGreedily(funcOp.getBody(),
|
||||
std::move(fusionPatterns));
|
||||
if (failed(applyPatternsAndFoldGreedily(funcOp.getBody(),
|
||||
std::move(fusionPatterns))))
|
||||
return signalPassFailure();
|
||||
return;
|
||||
}
|
||||
|
||||
@@ -127,8 +166,9 @@ struct TestLinalgElementwiseFusion
|
||||
linalg::populateElementwiseOpsFusionPatterns(fusionPatterns,
|
||||
setFusedOpOperandLimit<4>);
|
||||
|
||||
(void)applyPatternsAndFoldGreedily(funcOp.getBody(),
|
||||
std::move(fusionPatterns));
|
||||
if (failed(applyPatternsAndFoldGreedily(funcOp.getBody(),
|
||||
std::move(fusionPatterns))))
|
||||
return signalPassFailure();
|
||||
return;
|
||||
}
|
||||
|
||||
@@ -172,8 +212,9 @@ struct TestLinalgElementwiseFusion
|
||||
|
||||
linalg::populateFoldReshapeOpsByExpansionPatterns(fusionPatterns,
|
||||
controlReshapeFusionFn);
|
||||
(void)applyPatternsAndFoldGreedily(funcOp.getBody(),
|
||||
std::move(fusionPatterns));
|
||||
if (failed(applyPatternsAndFoldGreedily(funcOp.getBody(),
|
||||
std::move(fusionPatterns))))
|
||||
return signalPassFailure();
|
||||
return;
|
||||
}
|
||||
|
||||
@@ -181,7 +222,10 @@ struct TestLinalgElementwiseFusion
|
||||
RewritePatternSet patterns(context);
|
||||
linalg::populateFoldReshapeOpsByCollapsingPatterns(
|
||||
patterns, [](OpOperand * /*fusedOperand */) { return true; });
|
||||
(void)applyPatternsAndFoldGreedily(funcOp.getBody(), std::move(patterns));
|
||||
if (failed(applyPatternsAndFoldGreedily(funcOp.getBody(),
|
||||
std::move(patterns))))
|
||||
return signalPassFailure();
|
||||
return;
|
||||
}
|
||||
|
||||
if (fuseWithReshapeByCollapsingWithControlFn) {
|
||||
@@ -195,7 +239,19 @@ struct TestLinalgElementwiseFusion
|
||||
return true;
|
||||
};
|
||||
linalg::populateFoldReshapeOpsByCollapsingPatterns(patterns, controlFn);
|
||||
(void)applyPatternsAndFoldGreedily(funcOp.getBody(), std::move(patterns));
|
||||
if (failed(applyPatternsAndFoldGreedily(funcOp.getBody(),
|
||||
std::move(patterns))))
|
||||
return signalPassFailure();
|
||||
return;
|
||||
}
|
||||
|
||||
if (fuseMultiUseProducer) {
|
||||
RewritePatternSet patterns(context);
|
||||
patterns.insert<TestMultiUseProducerFusion>(context);
|
||||
if (failed(applyPatternsAndFoldGreedily(funcOp.getBody(),
|
||||
std::move(patterns))))
|
||||
return signalPassFailure();
|
||||
return;
|
||||
}
|
||||
|
||||
if (!collapseDimensions.empty()) {
|
||||
@@ -209,7 +265,10 @@ struct TestLinalgElementwiseFusion
|
||||
};
|
||||
RewritePatternSet patterns(context);
|
||||
linalg::populateCollapseDimensions(patterns, collapseFn);
|
||||
(void)applyPatternsAndFoldGreedily(funcOp.getBody(), std::move(patterns));
|
||||
if (failed(applyPatternsAndFoldGreedily(funcOp.getBody(),
|
||||
std::move(patterns))))
|
||||
return signalPassFailure();
|
||||
return;
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
Reference in New Issue
Block a user