feat(linalg): add a way to pass controlFn to foldIntoPackUnpackPatterns (#143685)
This PR adds a mechanism that lets downstream consumers pass in control functions governing the application of these patterns. This change shouldn't affect any consumers of this method that do not specify a controlFn. In each of the patterns, the controlFn always receives the consumer's source operand as its parameter. In IREE, we (will) use it to prevent folding patterns that would inhibit fusion. See IREE issue [#20896](https://github.com/iree-org/iree/issues/20896) for more details.
This commit is contained in:
@@ -1984,10 +1984,15 @@ void populateDecomposeWinogradOpsPatterns(RewritePatternSet &patterns);
|
||||
/// convert to a `linalg.dot`.
|
||||
void populateContractionOpRankReducingPatterns(RewritePatternSet &patterns);
|
||||
|
||||
/// Function type which is used to control folding operations like `tensor.pad`
|
||||
/// and `tensor.extract_slice` into linalg.pack/unpack ops. It is called with
/// the consumer's source operand; returning false prevents the fold.
|
||||
using ControlFoldIntoPackUnpackFn = std::function<bool(OpOperand *opOperand)>;
|
||||
/// Populates `patterns` with patterns that fold operations like `tensor.pad`
|
||||
/// and `tensor.extract_slice` into `linalg.pack` and `linalg.unpack` operations
|
||||
/// respectively. If `controlFn` is provided, a fold is applied only when it
/// returns true.
|
||||
void populateFoldIntoPackAndUnpackPatterns(RewritePatternSet &patterns);
|
||||
void populateFoldIntoPackAndUnpackPatterns(
|
||||
RewritePatternSet &patterns,
|
||||
const ControlFoldIntoPackUnpackFn &controlFn = nullptr);
|
||||
|
||||
/// Populates `patterns` with patterns that fold operations like `linalg.pack`
|
||||
/// and `linalg.unpack` into `tensor.empty`.
|
||||
|
||||
@@ -7,6 +7,7 @@
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
#include "mlir/Dialect/Linalg/IR/Linalg.h"
|
||||
#include "mlir/Dialect/Linalg/Transforms/Transforms.h"
|
||||
#include "mlir/Dialect/Tensor/IR/Tensor.h"
|
||||
#include "mlir/Dialect/Tensor/Transforms/Transforms.h"
|
||||
#include "mlir/Dialect/Utils/IndexingUtils.h"
|
||||
@@ -197,7 +198,9 @@ struct SimplifyUnPackToCollapseShape : public OpRewritePattern<UnPackOp> {
|
||||
/// Fold a `pad` -> `pack` into `pack` if they have the same padding values and
|
||||
/// the pad op has zero low paddings, or if `pack` has no padding values.
|
||||
struct FoldPadWithPackOp : public OpRewritePattern<PackOp> {
|
||||
using OpRewritePattern<PackOp>::OpRewritePattern;
|
||||
public:
|
||||
FoldPadWithPackOp(MLIRContext *context, ControlFoldIntoPackUnpackFn controlFn)
|
||||
: OpRewritePattern<PackOp>(context), controlFn(std::move(controlFn)) {}
|
||||
|
||||
LogicalResult matchAndRewrite(PackOp packOp,
|
||||
PatternRewriter &rewriter) const override {
|
||||
@@ -206,6 +209,10 @@ struct FoldPadWithPackOp : public OpRewritePattern<PackOp> {
|
||||
if (!padOp || padOp.getNofold() || !padOp.hasZeroLowPad())
|
||||
return failure();
|
||||
|
||||
// User controlled folding function.
|
||||
if (controlFn && !controlFn(&packOp.getSourceMutable()))
|
||||
return failure();
|
||||
|
||||
Value constantPaddingValue = padOp.getConstantPaddingValue();
|
||||
if (!constantPaddingValue)
|
||||
return failure();
|
||||
@@ -220,13 +227,20 @@ struct FoldPadWithPackOp : public OpRewritePattern<PackOp> {
|
||||
packOp.getOuterDimsPerm());
|
||||
return success();
|
||||
}
|
||||
|
||||
private:
|
||||
ControlFoldIntoPackUnpackFn controlFn;
|
||||
};
|
||||
|
||||
/// Fold a `unpack` -> `extract_slice` into the `unpack` since it already
|
||||
/// has extract_slice semantics.
|
||||
struct FoldUnpackWithExtractSliceOp
|
||||
: public OpRewritePattern<tensor::ExtractSliceOp> {
|
||||
using OpRewritePattern<tensor::ExtractSliceOp>::OpRewritePattern;
|
||||
public:
|
||||
FoldUnpackWithExtractSliceOp(MLIRContext *context,
|
||||
ControlFoldIntoPackUnpackFn controlFn)
|
||||
: OpRewritePattern<tensor::ExtractSliceOp>(context),
|
||||
controlFn(std::move(controlFn)) {}
|
||||
|
||||
LogicalResult matchAndRewrite(tensor::ExtractSliceOp sliceOp,
|
||||
PatternRewriter &rewriter) const override {
|
||||
@@ -234,6 +248,10 @@ struct FoldUnpackWithExtractSliceOp
|
||||
if (!unpackOp)
|
||||
return failure();
|
||||
|
||||
// User controlled folding function.
|
||||
if (controlFn && !controlFn(&sliceOp.getSourceMutable()))
|
||||
return failure();
|
||||
|
||||
if (sliceOp.getResultType().getRank() != unpackOp.getDestType().getRank()) {
|
||||
return rewriter.notifyMatchFailure(
|
||||
sliceOp, "rank-reduced folding is not supported");
|
||||
@@ -255,6 +273,9 @@ struct FoldUnpackWithExtractSliceOp
|
||||
unpackOp.getMixedTiles(), unpackOp.getOuterDimsPerm());
|
||||
return success();
|
||||
}
|
||||
|
||||
private:
|
||||
ControlFoldIntoPackUnpackFn controlFn;
|
||||
};
|
||||
|
||||
// Applies 'permutation' on 'inVec' and stores the result in resVec.
|
||||
@@ -284,7 +305,12 @@ static bool checkAndPermute(ArrayRef<int64_t> permutation,
|
||||
/// semantics.
|
||||
struct FoldProducerPackWithConsumerLinalgTransposeOp
|
||||
: public OpInterfaceRewritePattern<linalg::LinalgOp> {
|
||||
using OpInterfaceRewritePattern<linalg::LinalgOp>::OpInterfaceRewritePattern;
|
||||
|
||||
public:
|
||||
FoldProducerPackWithConsumerLinalgTransposeOp(
|
||||
MLIRContext *context, ControlFoldIntoPackUnpackFn controlFn)
|
||||
: OpInterfaceRewritePattern<linalg::LinalgOp>(context),
|
||||
controlFn(std::move(controlFn)) {}
|
||||
|
||||
LogicalResult matchAndRewrite(linalg::LinalgOp linalgOp,
|
||||
PatternRewriter &rewriter) const override {
|
||||
@@ -293,6 +319,10 @@ struct FoldProducerPackWithConsumerLinalgTransposeOp
|
||||
if (!packOp)
|
||||
return failure();
|
||||
|
||||
// User controlled folding function.
|
||||
if (controlFn && !controlFn(&linalgOp->getOpOperand(0)))
|
||||
return failure();
|
||||
|
||||
FailureOr<SmallVector<int64_t>> maybePerm =
|
||||
getTransposeOpPermutation(linalgOp);
|
||||
if (failed(maybePerm))
|
||||
@@ -331,13 +361,20 @@ struct FoldProducerPackWithConsumerLinalgTransposeOp
|
||||
|
||||
return success();
|
||||
}
|
||||
|
||||
private:
|
||||
ControlFoldIntoPackUnpackFn controlFn;
|
||||
};
|
||||
|
||||
/// Fold 'transpose' -> 'pack' into 'pack' since 'pack' already has transpose
|
||||
/// semantics.
|
||||
struct FoldConsumerPackWithProducerLinalgTransposeOp
|
||||
: public OpRewritePattern<PackOp> {
|
||||
using OpRewritePattern<PackOp>::OpRewritePattern;
|
||||
|
||||
public:
|
||||
FoldConsumerPackWithProducerLinalgTransposeOp(
|
||||
MLIRContext *context, ControlFoldIntoPackUnpackFn controlFn)
|
||||
: OpRewritePattern<PackOp>(context), controlFn(std::move(controlFn)) {}
|
||||
|
||||
LogicalResult matchAndRewrite(PackOp packOp,
|
||||
PatternRewriter &rewriter) const override {
|
||||
@@ -345,6 +382,10 @@ struct FoldConsumerPackWithProducerLinalgTransposeOp
|
||||
if (!linalgOp)
|
||||
return failure();
|
||||
|
||||
// User controlled folding function.
|
||||
if (controlFn && !controlFn(&packOp.getSourceMutable()))
|
||||
return failure();
|
||||
|
||||
FailureOr<SmallVector<int64_t>> maybePerm =
|
||||
getTransposeOpPermutation(linalgOp);
|
||||
if (failed(maybePerm))
|
||||
@@ -375,13 +416,21 @@ struct FoldConsumerPackWithProducerLinalgTransposeOp
|
||||
|
||||
return success();
|
||||
}
|
||||
|
||||
private:
|
||||
ControlFoldIntoPackUnpackFn controlFn;
|
||||
};
|
||||
|
||||
/// Fold 'unpack' -> 'transpose' into 'unpack' since 'unpack' already has
|
||||
/// transpose semantics.
|
||||
struct FoldProducerUnPackWithConsumerLinalgTransposeOp
|
||||
: public OpInterfaceRewritePattern<linalg::LinalgOp> {
|
||||
using OpInterfaceRewritePattern<linalg::LinalgOp>::OpInterfaceRewritePattern;
|
||||
|
||||
public:
|
||||
FoldProducerUnPackWithConsumerLinalgTransposeOp(
|
||||
MLIRContext *context, ControlFoldIntoPackUnpackFn controlFn)
|
||||
: OpInterfaceRewritePattern<linalg::LinalgOp>(context),
|
||||
controlFn(std::move(controlFn)) {}
|
||||
|
||||
LogicalResult matchAndRewrite(linalg::LinalgOp linalgOp,
|
||||
PatternRewriter &rewriter) const override {
|
||||
@@ -390,6 +439,10 @@ struct FoldProducerUnPackWithConsumerLinalgTransposeOp
|
||||
if (!unPackOp)
|
||||
return failure();
|
||||
|
||||
// User controlled folding function.
|
||||
if (controlFn && !controlFn(&linalgOp->getOpOperand(0)))
|
||||
return failure();
|
||||
|
||||
FailureOr<SmallVector<int64_t>> maybePerm =
|
||||
getTransposeOpPermutation(linalgOp);
|
||||
if (failed(maybePerm))
|
||||
@@ -416,6 +469,9 @@ struct FoldProducerUnPackWithConsumerLinalgTransposeOp
|
||||
|
||||
return success();
|
||||
}
|
||||
|
||||
private:
|
||||
ControlFoldIntoPackUnpackFn controlFn;
|
||||
};
|
||||
|
||||
/// Fold 'transpose' -> 'unpack' into 'unpack' since 'unpack' already has
|
||||
@@ -424,12 +480,21 @@ struct FoldConsumerUnPackWithProducerLinalgTransposeOp
|
||||
: public OpRewritePattern<UnPackOp> {
|
||||
using OpRewritePattern<UnPackOp>::OpRewritePattern;
|
||||
|
||||
public:
|
||||
FoldConsumerUnPackWithProducerLinalgTransposeOp(
|
||||
MLIRContext *context, ControlFoldIntoPackUnpackFn controlFn)
|
||||
: OpRewritePattern<UnPackOp>(context), controlFn(std::move(controlFn)) {}
|
||||
|
||||
LogicalResult matchAndRewrite(UnPackOp unPackOp,
|
||||
PatternRewriter &rewriter) const override {
|
||||
auto linalgOp = unPackOp.getSource().getDefiningOp<linalg::LinalgOp>();
|
||||
if (!linalgOp)
|
||||
return failure();
|
||||
|
||||
// User controlled folding function.
|
||||
if (controlFn && !controlFn(&unPackOp.getSourceMutable()))
|
||||
return failure();
|
||||
|
||||
FailureOr<SmallVector<int64_t>> maybePerm =
|
||||
getTransposeOpPermutation(linalgOp);
|
||||
if (failed(maybePerm))
|
||||
@@ -474,6 +539,9 @@ struct FoldConsumerUnPackWithProducerLinalgTransposeOp
|
||||
|
||||
return success();
|
||||
}
|
||||
|
||||
private:
|
||||
ControlFoldIntoPackUnpackFn controlFn;
|
||||
};
|
||||
|
||||
/// tensor.empty does not define any tensor contents, so an unpadded pack
|
||||
@@ -521,13 +589,14 @@ struct FoldEmptyTensorWithUnPackOp : public OpRewritePattern<UnPackOp> {
|
||||
|
||||
} // namespace
|
||||
|
||||
void populateFoldIntoPackAndUnpackPatterns(RewritePatternSet &patterns) {
|
||||
void populateFoldIntoPackAndUnpackPatterns(
|
||||
RewritePatternSet &patterns, const ControlFoldIntoPackUnpackFn &controlFn) {
|
||||
patterns.insert<FoldUnpackWithExtractSliceOp, FoldPadWithPackOp,
|
||||
FoldProducerPackWithConsumerLinalgTransposeOp,
|
||||
FoldConsumerPackWithProducerLinalgTransposeOp,
|
||||
FoldConsumerUnPackWithProducerLinalgTransposeOp,
|
||||
FoldProducerUnPackWithConsumerLinalgTransposeOp>(
|
||||
patterns.getContext());
|
||||
patterns.getContext(), controlFn);
|
||||
}
|
||||
|
||||
void populateSimplifyPackAndUnpackPatterns(RewritePatternSet &patterns) {
|
||||
|
||||
@@ -1,4 +1,5 @@
|
||||
// RUN: mlir-opt -split-input-file -test-linalg-transform-patterns=test-fold-into-pack-and-unpack %s | FileCheck %s
|
||||
// RUN: mlir-opt -split-input-file -test-linalg-transform-patterns=test-fold-into-pack-and-unpack-control %s | FileCheck %s --check-prefix=CONTROL
|
||||
|
||||
func.func @fold_unpack_slice(%arg0 : tensor<?x?x8x4xf32>, %arg1 : tensor<?x?xf32>,
|
||||
%arg2 : index, %arg3 : index) -> tensor<?x?xf32> {
|
||||
@@ -373,6 +374,36 @@ func.func @linalg_transpose_linalg.pack_fold(%arg0: tensor<56x57x1x64xf32>) -> t
|
||||
|
||||
// -----
|
||||
|
||||
func.func @linalg_transpose_linalg.pack_fold_multi_result(%arg0: tensor<56x57x1x64xf32>) -> (tensor<1x56x57x64xf32>, tensor<1x57x56x2x32xf32>) {
|
||||
%0 = tensor.empty() : tensor<1x56x57x64xf32>
|
||||
%transposed = linalg.transpose
|
||||
ins(%arg0 : tensor<56x57x1x64xf32>)
|
||||
outs(%0 : tensor<1x56x57x64xf32>)
|
||||
permutation = [2, 0, 1, 3]
|
||||
|
||||
%1 = tensor.empty() : tensor<1x57x56x2x32xf32>
|
||||
%pack = linalg.pack %transposed
|
||||
outer_dims_perm = [0, 2, 1, 3]
|
||||
inner_dims_pos = [3]
|
||||
inner_tiles = [32]
|
||||
into %1 : tensor<1x56x57x64xf32> -> tensor<1x57x56x2x32xf32>
|
||||
return %transposed, %pack : tensor<1x56x57x64xf32>, tensor<1x57x56x2x32xf32>
|
||||
}
|
||||
// CHECK-LABEL: func @linalg_transpose_linalg.pack_fold_multi_result(
|
||||
// CHECK-SAME: %[[ARG0:.+]]: tensor<56x57x1x64xf32>)
|
||||
// CHECK: %[[TRANSPOSE:.+]] = linalg.transpose
|
||||
// CHECK: %[[PACK:.+]] = linalg.pack %[[ARG0]]
|
||||
// CHECK-SAME: outer_dims_perm = [2, 1, 0, 3]
|
||||
// CHECK: return %[[TRANSPOSE]], %[[PACK]]
|
||||
|
||||
// CONTROL-LABEL: func @linalg_transpose_linalg.pack_fold_multi_result(
|
||||
// CONTROL: %[[TRANSPOSE:.+]] = linalg.transpose
|
||||
// CONTROL: %[[PACK:.+]] = linalg.pack %[[TRANSPOSE]]
|
||||
// CONTROL-SAME: outer_dims_perm = [0, 2, 1, 3]
|
||||
// CONTROL: return %[[TRANSPOSE]], %[[PACK]]
|
||||
|
||||
// -----
|
||||
|
||||
func.func @linalg_transpose_linalg.pack_fold_with_padding(%arg0: tensor<56x57x1x55xf32>, %padding: f32) -> tensor<1x57x56x2x32xf32> {
|
||||
%0 = tensor.empty() : tensor<1x56x57x55xf32>
|
||||
%transpose = linalg.transpose
|
||||
@@ -550,6 +581,36 @@ func.func @linalg_transpose_linalg.unpack_fold(%arg0: tensor<1x1x4x16xi32>) -> t
|
||||
|
||||
// -----
|
||||
|
||||
func.func @linalg_transpose_linalg.unpack_fold_multi_result(%arg0: tensor<1x1x4x16xi32>) -> (tensor<1x1x16x4xi32>, tensor<16x4xi32>) {
|
||||
%0 = tensor.empty() : tensor<1x1x16x4xi32>
|
||||
%transposed = linalg.transpose ins(%arg0 : tensor<1x1x4x16xi32>)
|
||||
outs(%0 : tensor<1x1x16x4xi32>)
|
||||
permutation = [1, 0, 3, 2]
|
||||
%1 = tensor.empty() : tensor<16x4xi32>
|
||||
%unpack = linalg.unpack %transposed
|
||||
outer_dims_perm = [0, 1]
|
||||
inner_dims_pos = [0, 1]
|
||||
inner_tiles = [16, 4] into
|
||||
%1 : tensor<1x1x16x4xi32> -> tensor<16x4xi32>
|
||||
return %transposed, %unpack : tensor<1x1x16x4xi32>, tensor<16x4xi32>
|
||||
}
|
||||
//CHECK-LABEL: func.func @linalg_transpose_linalg.unpack_fold_multi_result(
|
||||
// CHECK-SAME: %[[ARG0:.+]]: tensor<1x1x4x16xi32>)
|
||||
// CHECK: %[[TRANSPOSE:.+]] = linalg.transpose
|
||||
// CHECK: %[[UNPACK:.+]] = linalg.unpack %[[ARG0]]
|
||||
// CHECK-SAME: outer_dims_perm = [1, 0]
|
||||
// CHECK: return %[[TRANSPOSE]], %[[UNPACK]]
|
||||
// CHECK: }
|
||||
|
||||
//CONTROL-LABEL: func.func @linalg_transpose_linalg.unpack_fold_multi_result(
|
||||
// CONTROL: %[[TRANSPOSE:.+]] = linalg.transpose
|
||||
// CONTROL: %[[UNPACK:.+]] = linalg.unpack %[[TRANSPOSE]]
|
||||
// CONTROL-SAME: outer_dims_perm = [0, 1]
|
||||
// CONTROL: return %[[TRANSPOSE]], %[[UNPACK]]
|
||||
// CONTROL: }
|
||||
|
||||
// -----
|
||||
|
||||
func.func @linalg_transpose_linalg.unpack_fold_partial_tile(%arg0: tensor<1x1x4x16xi32>) -> tensor<15x3xi32> {
|
||||
%0 = tensor.empty() : tensor<1x1x16x4xi32>
|
||||
%transposed = linalg.transpose ins(%arg0 : tensor<1x1x4x16xi32>)
|
||||
|
||||
@@ -130,6 +130,11 @@ struct TestLinalgTransforms
|
||||
*this, "test-fold-into-pack-and-unpack",
|
||||
llvm::cl::desc("Test folding ops into linalg.pack and linalg.unpack"),
|
||||
llvm::cl::init(false)};
|
||||
Option<bool> testFoldIntoPackAndUnpackWithControlFn{
|
||||
*this, "test-fold-into-pack-and-unpack-control",
|
||||
llvm::cl::desc(
|
||||
"Test controlling folding ops into linalg.pack and linalg.unpack"),
|
||||
llvm::cl::init(false)};
|
||||
Option<bool> testSimplifyPackUnpackPatterns{
|
||||
*this, "test-simplify-pack-unpack-patterns",
|
||||
llvm::cl::desc("Test patterns to simplify linalg.pack and linalg.unpack"),
|
||||
@@ -222,9 +227,11 @@ static void applyDecomposeWinogradOps(func::FuncOp funcOp) {
|
||||
(void)applyPatternsGreedily(funcOp, std::move(patterns));
|
||||
}
|
||||
|
||||
static void applyFoldIntoPackAndUnpackPatterns(Operation *rootOp) {
|
||||
static void applyFoldIntoPackAndUnpackPatterns(
|
||||
Operation *rootOp,
|
||||
linalg::ControlFoldIntoPackUnpackFn controlFn = nullptr) {
|
||||
RewritePatternSet patterns(rootOp->getContext());
|
||||
linalg::populateFoldIntoPackAndUnpackPatterns(patterns);
|
||||
linalg::populateFoldIntoPackAndUnpackPatterns(patterns, controlFn);
|
||||
(void)applyPatternsGreedily(rootOp, std::move(patterns));
|
||||
}
|
||||
|
||||
@@ -263,6 +270,19 @@ void TestLinalgTransforms::runOnOperation() {
|
||||
Operation *rootOp = getOperation();
|
||||
if (testFoldIntoPackAndUnpack)
|
||||
applyFoldIntoPackAndUnpackPatterns(rootOp);
|
||||
if (testFoldIntoPackAndUnpackWithControlFn) {
|
||||
linalg::ControlFoldIntoPackUnpackFn controlFn = [](OpOperand *opOperand) {
|
||||
Operation *producer = opOperand->get().getDefiningOp();
|
||||
Operation *consumer = opOperand->getOwner();
|
||||
// If we have a pack/unpack consumer and a TilingInterface producer that has
|
||||
// multiple uses, do not apply the folding patterns.
|
||||
if (isa<linalg::PackOp, linalg::UnPackOp>(consumer) &&
|
||||
isa<TilingInterface>(producer) && !producer->hasOneUse())
|
||||
return false;
|
||||
return true;
|
||||
};
|
||||
applyFoldIntoPackAndUnpackPatterns(rootOp, controlFn);
|
||||
}
|
||||
if (testSimplifyPackUnpackPatterns)
|
||||
applySimplifyPackUnpackPatterns(rootOp);
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user