[MLIR] Make generic skip packing init operand when not used in DataLayoutPropagation (#146139)

In both `bubbleUpPackOpThroughGenericOp()` or
`pushDownUnPackOpThroughGenericOp()`, we can simplify the lowered IR by
removing the pack of an empty when the init tensor isn't used in generic
op. Instead of packing an empty tensor, the empty tensor can be
forwarded to the generic output. This allows cleaner result after data
layout propagation.
This commit is contained in:
Zhuoran Yin
2025-07-01 09:39:30 -04:00
committed by GitHub
parent 08cf6ae537
commit 8cfd9b8821
2 changed files with 62 additions and 22 deletions

View File

@@ -358,6 +358,12 @@ static GenericOp packGenericOp(RewriterBase &rewriter, GenericOp genericOp,
return newGenericOp;
}
static bool isGenericOutsNotUsed(linalg::GenericOp genericOp) {
return llvm::all_of(genericOp.getDpsInitsMutable(), [&](OpOperand &operand) {
return genericOp.getMatchingBlockArgument(&operand).use_empty();
});
}
/// Bubbles up linalg.pack op through a producer generic op. This
/// swap pack(generic) to generic(pack). The new generic op works on packed
/// domain; pack ops are created for input and output operands. E.g.,
@@ -470,12 +476,15 @@ bubbleUpPackOpThroughGenericOp(RewriterBase &rewriter, linalg::PackOp packOp,
getOrCreatePackedViewOfOperand(rewriter, genericOp.getLoc(), *packInfo,
genericOp, opOperand);
// If the dps init operand of the generic is a tensor.empty forward the pack
// op destination.
// Forward the new tensor.empty as a destination if it is one of the following
// situations:
// 1) The dps init operand is a tensor.empty.
// 2) The dps init is a write-only operand, i.e., it is not used in the
// genericOp
Value dest = packedOutOperand;
if (auto initTensor = genericOp.getDpsInitOperand(0)
->get()
.getDefiningOp<tensor::EmptyOp>()) {
auto initTensor =
genericOp.getDpsInitOperand(0)->get().getDefiningOp<tensor::EmptyOp>();
if (initTensor || isGenericOutsNotUsed(genericOp)) {
dest = packOpDest;
}
// pack(unpack) isn't naively foldable because the unpack op can be from
@@ -1101,12 +1110,15 @@ pushDownUnPackOpThroughGenericOp(RewriterBase &rewriter, GenericOp genericOp,
genericOp, genericOp.getDpsInitOperand(0));
auto destPack = packedOutOperand.getDefiningOp<linalg::PackOp>();
// If the dps init operand of the generic is a tensor.empty, do not pack it
// and forward the new tensor.empty as a destination.
// Forward the new tensor.empty as a destination if it is one of the following
// situations:
// 1) The dps init operand is a tensor.empty.
// 2) The dps init is a write-only operand, i.e., it is not used in the
// genericOp
Value dest = packedOutOperand;
if (auto initTensor = genericOp.getDpsInitOperand(0)
->get()
.getDefiningOp<tensor::EmptyOp>()) {
auto initTensor =
genericOp.getDpsInitOperand(0)->get().getDefiningOp<tensor::EmptyOp>();
if (initTensor || isGenericOutsNotUsed(genericOp)) {
if (destPack)
dest = destPack.getDest();
}

View File

@@ -435,6 +435,40 @@ func.func @elem_pack_transpose_outer_dims(%arg0: tensor<128x256xi32>, %init: ten
// -----
#map0 = affine_map<(d0, d1) -> (d0, d1)>
func.func @elem_pack_transpose_outer_dims_unused_init(%arg0: tensor<128x256xi32>, %init: tensor<128x256xi32>) -> tensor<16x4x32x16xi32>{
%elem = linalg.generic {indexing_maps = [#map0, #map0], iterator_types = ["parallel", "parallel"]}
ins(%arg0 : tensor<128x256xi32>)
outs(%init : tensor<128x256xi32>) {
^bb0(%arg3: i32, %arg4: i32):
%4 = arith.addi %arg3, %arg3 : i32
linalg.yield %4 : i32
} -> tensor<128x256xi32>
%empty = tensor.empty() : tensor<16x4x32x16xi32>
%pack = linalg.pack %elem
outer_dims_perm = [1, 0]
inner_dims_pos = [0, 1]
inner_tiles = [32, 16]
into %empty : tensor<128x256xi32> -> tensor<16x4x32x16xi32>
return %pack : tensor<16x4x32x16xi32>
}
// CHECK: #[[$MAP:.+]] = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>
// CHECK-LABEL: func.func @elem_pack_transpose_outer_dims
// CHECK-SAME: %[[ARG0:[a-zA-Z0-9]+]]
// CHECK-SAME: %[[ARG1:[a-zA-Z0-9]+]]
// CHECK: %[[ARG1_EMPTY:.+]] = tensor.empty() : tensor<16x4x32x16xi32>
// CHECK: %[[ARG0_EMPTY:.+]] = tensor.empty() : tensor<16x4x32x16xi32>
// CHECK: %[[PACKED_ARG0:.+]] = linalg.pack %[[ARG0]]
// CHECK-SAME: outer_dims_perm = [1, 0] inner_dims_pos = [0, 1] inner_tiles = [32, 16]
// CHECK-SAME: into %[[ARG0_EMPTY]]
// CHECK: %[[RES:.+]] = linalg.generic
// CHECK-SAME: indexing_maps = [#[[$MAP]], #[[$MAP]]]
// CHECK-SAME: ins(%[[PACKED_ARG0]]
// CHECK-SAME: outs(%[[ARG1_EMPTY]]
// -----
#map = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>
func.func @unpack_on_output(%arg0: tensor<12x2x56x56x32xf32>) -> tensor<12x56x56x64xf32> {
@@ -497,7 +531,7 @@ func.func @unpack_on_input(%arg0: tensor<12x2x56x56x32xf32>, %init: tensor<12x56
#map = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>
func.func @unpack_element_type_change(%arg0: tensor<12x2x56x56x32xf32>, %init: tensor<12x56x56x64xf16>) -> tensor<12x56x56x64xf16> {
func.func @unpack_element_type_change_no_use(%arg0: tensor<12x2x56x56x32xf32>, %init: tensor<12x56x56x64xf16>) -> tensor<12x56x56x64xf16> {
%0 = tensor.empty() : tensor<12x56x56x64xf32>
%1 = linalg.unpack %arg0 outer_dims_perm = [0, 3, 1, 2] inner_dims_pos = [3] inner_tiles = [32] into %0 : tensor<12x2x56x56x32xf32> -> tensor<12x56x56x64xf32>
%2 = linalg.generic {indexing_maps = [#map, #map], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%1: tensor<12x56x56x64xf32>) outs(%init : tensor<12x56x56x64xf16>) {
@@ -509,17 +543,14 @@ func.func @unpack_element_type_change(%arg0: tensor<12x2x56x56x32xf32>, %init: t
}
// CHECK: #[[$MAP:.+]] = affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d2, d3, d4)>
// CHECK-LABEL: func.func @unpack_element_type_change
// CHECK-LABEL: func.func @unpack_element_type_change_no_use
// CHECK-SAME: %[[ARG0:[a-zA-Z0-9]+]]
// CHECK-SAME: %[[ARG1:[a-zA-Z0-9]+]]
// CHECK: %[[ARG1_PACK_EMPTY:.+]] = tensor.empty() : tensor<12x2x56x56x32xf16>
// CHECK: %[[ARG1_PACK:.+]] = linalg.pack %[[ARG1]]
// CHECK-SAME: outer_dims_perm = [0, 3, 1, 2] inner_dims_pos = [3] inner_tiles = [32]
// CHECK-SAME: into %[[ARG1_PACK_EMPTY]]
// CHECK: %[[EMPTY:.+]] = tensor.empty() : tensor<12x2x56x56x32xf16>
// CHECK: %[[RES:.+]] = linalg.generic
// CHECK-SAME: indexing_maps = [#[[$MAP]], #[[$MAP]]]
// CHECK-SAME: ins(%[[ARG0]]
// CHECK-SAME: outs(%[[ARG1_PACK]]
// CHECK-SAME: outs(%[[EMPTY]]
// CHECK: %[[UNPACK:.+]] = linalg.unpack %[[RES]]
// CHECK-SAME: outer_dims_perm = [0, 3, 1, 2] inner_dims_pos = [3] inner_tiles = [32]
// CHECK-SAME: into %[[ARG1]]
@@ -1402,13 +1433,10 @@ func.func @push_unpack_in_padded_domain_foldable(%arg0: tensor<8x8x4x8xf32>, %de
// CHECK-SAME: %[[ARG0:[a-zA-Z0-9]+]]
// CHECK-SAME: %[[ARG1:[a-zA-Z0-9]+]]
// CHECK-SAME: %[[ARG2:[a-zA-Z0-9]+]]
// CHECK: %[[ARG2_PACK_EMPTY:.+]] = tensor.empty
// CHECK: %[[ARG2_PACK:.+]] = linalg.pack %[[ARG2]]
// CHECK-SAME: inner_dims_pos = [0, 1] inner_tiles = [4, 8]
// CHECK-SAME: into %[[ARG2_PACK_EMPTY]]
// CHECK: %[[EMPTY:.+]] = tensor.empty
// CHECK: %[[GENERIC:.+]] = linalg.generic
// CHECK-SAME: ins(%[[ARG0]] : tensor<8x8x4x8xf32>)
// CHECK-SAME: outs(%[[ARG2_PACK]] : tensor<?x8x4x8xbf16>)
// CHECK-SAME: outs(%[[EMPTY]] : tensor<?x8x4x8xbf16>)
// CHECK: %[[UNPACK:.+]] = linalg.unpack %[[GENERIC]]
// CHECK-SAME: into %[[ARG2]]
// CHECK: return %[[UNPACK]] : tensor<?x64xbf16>