[mlir][tensor] Fold producer linalg transpose with consumer tensor pack (#75658)
Successor to https://github.com/llvm/llvm-project/pull/74206 Partial fix to https://github.com/openxla/iree/issues/15367
This commit is contained in:
committed by
GitHub
parent
9bde5becb4
commit
113bce0c79
@@ -9,6 +9,7 @@
|
||||
#include "mlir/Dialect/Linalg/IR/Linalg.h"
|
||||
#include "mlir/Dialect/Tensor/IR/Tensor.h"
|
||||
#include "mlir/Dialect/Tensor/Transforms/Transforms.h"
|
||||
#include "mlir/Dialect/Utils/IndexingUtils.h"
|
||||
#include "mlir/IR/PatternMatch.h"
|
||||
#include "llvm/Support/Debug.h"
|
||||
|
||||
@@ -223,11 +224,52 @@ struct FoldProducerPackWithConsumerLinalgTransposeOp
|
||||
return success();
|
||||
}
|
||||
};
|
||||
|
||||
/// Fold 'transpose' -> 'pack' into 'pack' since 'pack' already has transpose
|
||||
/// semantics.
|
||||
struct FoldConsumerPackWithProducerLinalgTransposeOp
|
||||
: public OpRewritePattern<PackOp> {
|
||||
using OpRewritePattern<PackOp>::OpRewritePattern;
|
||||
|
||||
LogicalResult matchAndRewrite(PackOp packOp,
|
||||
PatternRewriter &rewriter) const override {
|
||||
auto transposeOp = packOp.getSource().getDefiningOp<linalg::TransposeOp>();
|
||||
|
||||
if (!transposeOp)
|
||||
return failure();
|
||||
|
||||
auto transposePermutation = transposeOp.getPermutation();
|
||||
auto outerDimsPerm = packOp.getOuterDimsPerm();
|
||||
auto innerDimsPos = packOp.getInnerDimsPos();
|
||||
SmallVector<int64_t> newInnerDimsPosVec;
|
||||
SmallVector<int64_t> newOuterDimsPermVec =
|
||||
llvm::to_vector(transposePermutation);
|
||||
|
||||
if (!outerDimsPerm.empty())
|
||||
applyPermutationToVector(newOuterDimsPermVec, outerDimsPerm);
|
||||
|
||||
// Can't use applyPermutationToVector for newInnerDimsPosVec since input and
|
||||
// permutation rank won't necessarily be equal in all cases.
|
||||
for (auto dim : innerDimsPos)
|
||||
newInnerDimsPosVec.push_back(transposePermutation[dim]);
|
||||
|
||||
Value output = packOp.createDestinationTensor(
|
||||
rewriter, packOp.getLoc(), transposeOp.getOperand(0),
|
||||
packOp.getMixedTiles(), newInnerDimsPosVec, newOuterDimsPermVec);
|
||||
|
||||
rewriter.replaceOpWithNewOp<PackOp>(
|
||||
packOp, transposeOp.getOperand(0), output, newInnerDimsPosVec,
|
||||
packOp.getMixedTiles(), packOp.getPaddingValue(), newOuterDimsPermVec);
|
||||
|
||||
return success();
|
||||
}
|
||||
};
|
||||
} // namespace
|
||||
|
||||
void populateFoldIntoPackAndUnpackPatterns(RewritePatternSet &patterns) {
|
||||
patterns.insert<FoldUnpackWithExtractSliceOp, FoldPadWithPackOp,
|
||||
FoldProducerPackWithConsumerLinalgTransposeOp>(
|
||||
FoldProducerPackWithConsumerLinalgTransposeOp,
|
||||
FoldConsumerPackWithProducerLinalgTransposeOp>(
|
||||
patterns.getContext());
|
||||
}
|
||||
|
||||
|
||||
@@ -345,3 +345,180 @@ func.func @tensor_pack_linalg_transpose_fold_dynamic_outer_dims_tile_dims_tile_s
|
||||
// CHECK: %[[PACK:.+]] = tensor.pack %[[ARG0]] outer_dims_perm = [2, 1, 3, 0] inner_dims_pos = [3, 1, 2] inner_tiles = [%[[ARG3]], %[[ARG1]], %[[ARG2]]] into %[[INIT]] : tensor<?x?x?x?xf32> -> tensor<?x?x?x?x?x?x?xf32>
|
||||
// CHECK: return %[[PACK]] : tensor<?x?x?x?x?x?x?xf32>
|
||||
// CHECK: }
|
||||
|
||||
// -----
|
||||
|
||||
func.func @linalg_transpose_tensor_pack_fold(%arg0: tensor<56x57x1x64xf32>) -> tensor<1x57x56x2x32xf32> {
|
||||
%0 = tensor.empty() : tensor<1x56x57x64xf32>
|
||||
%transposed = linalg.transpose
|
||||
ins(%arg0 : tensor<56x57x1x64xf32>)
|
||||
outs(%0 : tensor<1x56x57x64xf32>)
|
||||
permutation = [2, 0, 1, 3]
|
||||
|
||||
%1 = tensor.empty() : tensor<1x57x56x2x32xf32>
|
||||
%pack = tensor.pack %transposed
|
||||
outer_dims_perm = [0, 2, 1, 3]
|
||||
inner_dims_pos = [3]
|
||||
inner_tiles = [32]
|
||||
into %1 : tensor<1x56x57x64xf32> -> tensor<1x57x56x2x32xf32>
|
||||
return %pack : tensor<1x57x56x2x32xf32>
|
||||
}
|
||||
//CHECK-LABEL: func @linalg_transpose_tensor_pack_fold(
|
||||
// CHECK-SAME: %[[ARG0:.+]]: tensor<56x57x1x64xf32>)
|
||||
// CHECK: %[[INIT:.+]] = tensor.empty() : tensor<1x57x56x2x32xf32>
|
||||
// CHECK: %[[PACK:.+]] = tensor.pack %[[ARG0]]
|
||||
// CHECK-SAME: outer_dims_perm = [2, 1, 0, 3]
|
||||
// CHECK-SAME: inner_dims_pos = [3] inner_tiles = [32]
|
||||
// CHECK-SAME: into %[[INIT]]
|
||||
// CHECK: return %[[PACK]]
|
||||
|
||||
// -----
|
||||
|
||||
func.func @linalg_transpose_tensor_pack_fold_with_padding(%arg0: tensor<56x57x1x55xf32>, %padding: f32) -> tensor<1x57x56x2x32xf32> {
|
||||
%0 = tensor.empty() : tensor<1x56x57x55xf32>
|
||||
%transpose = linalg.transpose
|
||||
ins(%arg0 : tensor<56x57x1x55xf32>)
|
||||
outs(%0 : tensor<1x56x57x55xf32>)
|
||||
permutation = [2, 0, 1, 3]
|
||||
|
||||
%1 = tensor.empty() : tensor<1x57x56x2x32xf32>
|
||||
%pack = tensor.pack %transpose padding_value(%padding : f32)
|
||||
outer_dims_perm = [0, 2, 1, 3]
|
||||
inner_dims_pos = [3]
|
||||
inner_tiles = [32]
|
||||
into %1 : tensor<1x56x57x55xf32> -> tensor<1x57x56x2x32xf32>
|
||||
return %pack : tensor<1x57x56x2x32xf32>
|
||||
}
|
||||
//CHECK-LABEL: func @linalg_transpose_tensor_pack_fold_with_padding(
|
||||
// CHECK-SAME: %[[ARG0:.+]]: tensor<56x57x1x55xf32>, %[[PADDING:.+]]: f32)
|
||||
// CHECK: %[[INIT:.+]] = tensor.empty() : tensor<1x57x56x2x32xf32>
|
||||
// CHECK: %[[PACK:.+]] = tensor.pack %[[ARG0]] padding_value(%[[PADDING]] : f32)
|
||||
// CHECK-SAME: outer_dims_perm = [2, 1, 0, 3]
|
||||
// CHECK-SAME: inner_dims_pos = [3] inner_tiles = [32]
|
||||
// CHECK-SAME: into %[[INIT]]
|
||||
// CHECK: return %[[PACK]]
|
||||
|
||||
// -----
|
||||
|
||||
func.func @linalg_transpose_tensor_pack_fold_no_outer_dims_perm(%arg0: tensor<56x57x1x64xf32>) -> tensor<1x56x57x2x32xf32> {
|
||||
%0 = tensor.empty() : tensor<1x56x57x64xf32>
|
||||
%transposed = linalg.transpose
|
||||
ins(%arg0 : tensor<56x57x1x64xf32>)
|
||||
outs(%0 : tensor<1x56x57x64xf32>)
|
||||
permutation = [2, 0, 1, 3]
|
||||
|
||||
%1 = tensor.empty() : tensor<1x56x57x2x32xf32>
|
||||
%pack = tensor.pack %transposed
|
||||
inner_dims_pos = [3]
|
||||
inner_tiles = [32]
|
||||
into %1 : tensor<1x56x57x64xf32> -> tensor<1x56x57x2x32xf32>
|
||||
return %pack : tensor<1x56x57x2x32xf32>
|
||||
}
|
||||
//CHECK-LABEL: func @linalg_transpose_tensor_pack_fold_no_outer_dims_perm(
|
||||
// CHECK-SAME: %[[ARG0:.+]]: tensor<56x57x1x64xf32>)
|
||||
// CHECK: %[[INIT:.+]] = tensor.empty() : tensor<1x56x57x2x32xf32>
|
||||
// CHECK: %[[PACK:.+]] = tensor.pack %[[ARG0]]
|
||||
// CHECK-SAME: outer_dims_perm = [2, 0, 1, 3]
|
||||
// CHECK-SAME: inner_dims_pos = [3] inner_tiles = [32]
|
||||
// CHECK-SAME: into %[[INIT]]
|
||||
// CHECK: return %[[PACK]]
|
||||
|
||||
// -----
|
||||
|
||||
func.func @linalg_transpose_tensor_pack_fold_complex_inner_dims_change(%arg0: tensor<25x30x35x40xf32>, %transpose_dest: tensor<35x40x25x30xf32>, %pack_dest: tensor<3x35x5x8x5x10x5xf32>) -> tensor<3x35x5x8x5x10x5xf32> {
|
||||
%transposed = linalg.transpose
|
||||
ins(%arg0 : tensor<25x30x35x40xf32>)
|
||||
outs(%transpose_dest : tensor<35x40x25x30xf32>)
|
||||
permutation = [2, 3, 0, 1]
|
||||
|
||||
%pack = tensor.pack %transposed
|
||||
outer_dims_perm = [3, 0, 2, 1]
|
||||
inner_dims_pos = [1, 3, 2]
|
||||
inner_tiles = [5, 10, 5]
|
||||
into %pack_dest : tensor<35x40x25x30xf32> -> tensor<3x35x5x8x5x10x5xf32>
|
||||
return %pack : tensor<3x35x5x8x5x10x5xf32>
|
||||
}
|
||||
//CHECK-LABEL: func.func @linalg_transpose_tensor_pack_fold_complex_inner_dims_change(
|
||||
// CHECK-SAME: %[[ARG0:.+]]: tensor<25x30x35x40xf32>,
|
||||
// CHECK-SAME: %[[ARG1:.+]]: tensor<35x40x25x30xf32>,
|
||||
// CHECK-SAME: %[[ARG2:.+]]: tensor<3x35x5x8x5x10x5xf32>) -> tensor<3x35x5x8x5x10x5xf32> {
|
||||
// CHECK: %[[VAL0:.+]] = tensor.empty() : tensor<3x35x5x8x5x10x5xf32>
|
||||
// CHECK: %[[PACK:.+]] = tensor.pack %[[ARG0]]
|
||||
// CHECK-SAME: outer_dims_perm = [1, 2, 0, 3]
|
||||
// CHECK-SAME: inner_dims_pos = [3, 1, 0]
|
||||
// CHECK-SAME: inner_tiles = [5, 10, 5]
|
||||
// CHECK-SAME: into %[[VAL0]]
|
||||
// CHECK: return %[[PACK]]
|
||||
|
||||
// -----
|
||||
|
||||
func.func @linalg_transpose_tensor_pack_fold_dynamic_outer_dims_tile_dims_tile_sizes(%arg0: tensor<?x?x?x?xf32>, %transpose_dest: tensor<?x?x?x?xf32>, %pack_dest: tensor<?x?x?x?x?x?x?xf32>, %tile_p : index, %tile_q : index, %tile_r : index) -> tensor<?x?x?x?x?x?x?xf32> {
|
||||
%transposed = linalg.transpose
|
||||
ins(%arg0 : tensor<?x?x?x?xf32>)
|
||||
outs(%transpose_dest : tensor<?x?x?x?xf32>)
|
||||
permutation = [2, 3, 0, 1]
|
||||
|
||||
%pack = tensor.pack %transposed
|
||||
outer_dims_perm = [3, 0, 2, 1]
|
||||
inner_dims_pos = [1, 3, 2]
|
||||
inner_tiles = [%tile_p, %tile_q, %tile_r]
|
||||
into %pack_dest : tensor<?x?x?x?xf32> -> tensor<?x?x?x?x?x?x?xf32>
|
||||
return %pack : tensor<?x?x?x?x?x?x?xf32>
|
||||
}
|
||||
// CHECK: #[[map:.+]] = affine_map<()[s0, s1] -> (s0 ceildiv s1)>
|
||||
//CHECK-LABEL: func.func @linalg_transpose_tensor_pack_fold_dynamic_outer_dims_tile_dims_tile_sizes(
|
||||
// CHECK-SAME: %[[ARG0:.+]]: tensor<?x?x?x?xf32>, %[[ARG1:.+]]: tensor<?x?x?x?xf32>,
|
||||
// CHECK-SAME: %[[ARG2:.+]]: tensor<?x?x?x?x?x?x?xf32>, %[[ARG3:.+]]: index, %[[ARG4:.+]]: index, %[[ARG5:.+]]: index) -> tensor<?x?x?x?x?x?x?xf32> {
|
||||
// CHECK: %[[C0:.+]] = arith.constant 0 : index
|
||||
// CHECK: %[[C1:.+]] = arith.constant 1 : index
|
||||
// CHECK: %[[C2:.+]] = arith.constant 2 : index
|
||||
// CHECK: %[[C3:.+]] = arith.constant 3 : index
|
||||
// CHECK: %[[DIM:.+]] = tensor.dim %[[ARG0]], %[[C0]] : tensor<?x?x?x?xf32>
|
||||
// CHECK: %[[DIM0:.+]] = tensor.dim %[[ARG0]], %[[C1]] : tensor<?x?x?x?xf32>
|
||||
// CHECK: %[[DIM1:.+]] = tensor.dim %[[ARG0]], %[[C2]] : tensor<?x?x?x?xf32>
|
||||
// CHECK: %[[DIM2:.+]] = tensor.dim %[[ARG0]], %[[C3]] : tensor<?x?x?x?xf32>
|
||||
// CHECK: %[[VAL0:.+]] = affine.apply #[[map:.+]]()[%[[DIM2]], %[[ARG3]]]
|
||||
// CHECK: %[[VAL1:.+]] = affine.apply #[[map:.+]]()[%[[DIM0]], %[[ARG4]]]
|
||||
// CHECK: %[[VAL2:.+]] = affine.apply #[[map:.+]]()[%[[DIM]], %[[ARG5]]]
|
||||
// CHECK: %[[VAL3:.+]] = tensor.empty(%[[VAL1]], %[[DIM1]], %[[VAL2]], %[[VAL0]], %[[ARG3]], %[[ARG4]], %[[ARG5]]) : tensor<?x?x?x?x?x?x?xf32>
|
||||
// CHECK: %[[PACK:.+]] = tensor.pack %[[ARG0]] outer_dims_perm = [1, 2, 0, 3] inner_dims_pos = [3, 1, 0] inner_tiles = [%[[ARG3]], %[[ARG4]], %[[ARG5]]] into %[[VAL3]] : tensor<?x?x?x?xf32> -> tensor<?x?x?x?x?x?x?xf32>
|
||||
// CHECK: return %[[PACK]] : tensor<?x?x?x?x?x?x?xf32>
|
||||
|
||||
// -----
|
||||
|
||||
func.func @linalg_transpose_tensor_pack_multiple_tiles(%arg0: tensor<?x32x128xbf16>) -> tensor<32x?x64x16x2xbf16> {
|
||||
%c0 = arith.constant 0 : index
|
||||
%cst = arith.constant 0.000000e+00 : bf16
|
||||
%dim = tensor.dim %arg0, %c0 : tensor<?x32x128xbf16>
|
||||
|
||||
%0 = tensor.empty(%dim) : tensor<32x128x?xbf16>
|
||||
%transposed = linalg.transpose
|
||||
ins(%arg0 : tensor<?x32x128xbf16>)
|
||||
outs(%0 : tensor<32x128x?xbf16>)
|
||||
permutation = [1, 2, 0]
|
||||
|
||||
%2 = tensor.empty(%dim) : tensor<32x?x64x16x2xbf16>
|
||||
%pack = tensor.pack %transposed
|
||||
padding_value(%cst : bf16)
|
||||
outer_dims_perm = [0, 2, 1]
|
||||
inner_dims_pos = [2, 1]
|
||||
inner_tiles = [16, 2]
|
||||
into %2 : tensor<32x128x?xbf16> -> tensor<32x?x64x16x2xbf16>
|
||||
return %pack : tensor<32x?x64x16x2xbf16>
|
||||
}
|
||||
// CHECK: #[[map:.+]] = affine_map<()[s0] -> (s0 ceildiv 16)>
|
||||
//CHECK-LABEL: func.func @linalg_transpose_tensor_pack_multiple_tiles(
|
||||
// CHECK-SAME: %[[ARG0:.+]]: tensor<?x32x128xbf16>) -> tensor<32x?x64x16x2xbf16> {
|
||||
// CHECK: %[[C0:.+]] = arith.constant 0 : index
|
||||
// CHECK: %[[CST:.+]] = arith.constant 0.000000e+00 : bf16
|
||||
// CHECK: %[[DIM:.+]] = tensor.dim %[[ARG0]], %[[C0]] : tensor<?x32x128xbf16>
|
||||
// CHECK: %[[VAL0:.+]] = affine.apply #[[map:.+]]()[%[[DIM]]]
|
||||
// CHECK: %[[VAL1:.+]] = tensor.empty(%[[VAL0]]) : tensor<32x?x64x16x2xbf16>
|
||||
// CHECK: %[[PACK:.+]] = tensor.pack %[[ARG0]]
|
||||
// CHECK-SAME: padding_value(%[[CST]] : bf16)
|
||||
// CHECK-SAME: outer_dims_perm = [1, 0, 2]
|
||||
// CHECK-SAME: inner_dims_pos = [0, 2]
|
||||
// CHECK-SAME: inner_tiles = [16, 2]
|
||||
// CHECK-SAME: into %[[VAL1]] : tensor<?x32x128xbf16> -> tensor<32x?x64x16x2xbf16>
|
||||
// CHECK: return %[[PACK]] : tensor<32x?x64x16x2xbf16>
|
||||
// CHECK: }
|
||||
|
||||
Reference in New Issue
Block a user