[MLIR] Make generic skip packing init operand when not used in DataLayoutPropagation (#146139)
In both `bubbleUpPackOpThroughGenericOp()` or `pushDownUnPackOpThroughGenericOp()`, we can simplify the lowered IR by removing the pack of an empty when the init tensor isn't used in generic op. Instead of packing an empty tensor, the empty tensor can be forwarded to the generic output. This allows cleaner result after data layout propagation.
This commit is contained in:
@@ -358,6 +358,12 @@ static GenericOp packGenericOp(RewriterBase &rewriter, GenericOp genericOp,
|
||||
return newGenericOp;
|
||||
}
|
||||
|
||||
static bool isGenericOutsNotUsed(linalg::GenericOp genericOp) {
|
||||
return llvm::all_of(genericOp.getDpsInitsMutable(), [&](OpOperand &operand) {
|
||||
return genericOp.getMatchingBlockArgument(&operand).use_empty();
|
||||
});
|
||||
}
|
||||
|
||||
/// Bubbles up linalg.pack op through a producer generic op. This
|
||||
/// swap pack(generic) to generic(pack). The new generic op works on packed
|
||||
/// domain; pack ops are created for input and output operands. E.g.,
|
||||
@@ -470,12 +476,15 @@ bubbleUpPackOpThroughGenericOp(RewriterBase &rewriter, linalg::PackOp packOp,
|
||||
getOrCreatePackedViewOfOperand(rewriter, genericOp.getLoc(), *packInfo,
|
||||
genericOp, opOperand);
|
||||
|
||||
// If the dps init operand of the generic is a tensor.empty forward the pack
|
||||
// op destination.
|
||||
// Forward the new tensor.empty as a destination if it is one of the following
|
||||
// situations:
|
||||
// 1) The dps init operand is a tensor.empty.
|
||||
// 2) The dps init is a write-only operand, i.e., it is not used in the
|
||||
// genericOp
|
||||
Value dest = packedOutOperand;
|
||||
if (auto initTensor = genericOp.getDpsInitOperand(0)
|
||||
->get()
|
||||
.getDefiningOp<tensor::EmptyOp>()) {
|
||||
auto initTensor =
|
||||
genericOp.getDpsInitOperand(0)->get().getDefiningOp<tensor::EmptyOp>();
|
||||
if (initTensor || isGenericOutsNotUsed(genericOp)) {
|
||||
dest = packOpDest;
|
||||
}
|
||||
// pack(unpack) isn't naively foldable because the unpack op can be from
|
||||
@@ -1101,12 +1110,15 @@ pushDownUnPackOpThroughGenericOp(RewriterBase &rewriter, GenericOp genericOp,
|
||||
genericOp, genericOp.getDpsInitOperand(0));
|
||||
auto destPack = packedOutOperand.getDefiningOp<linalg::PackOp>();
|
||||
|
||||
// If the dps init operand of the generic is a tensor.empty, do not pack it
|
||||
// and forward the new tensor.empty as a destination.
|
||||
// Forward the new tensor.empty as a destination if it is one of the following
|
||||
// situations:
|
||||
// 1) The dps init operand is a tensor.empty.
|
||||
// 2) The dps init is a write-only operand, i.e., it is not used in the
|
||||
// genericOp
|
||||
Value dest = packedOutOperand;
|
||||
if (auto initTensor = genericOp.getDpsInitOperand(0)
|
||||
->get()
|
||||
.getDefiningOp<tensor::EmptyOp>()) {
|
||||
auto initTensor =
|
||||
genericOp.getDpsInitOperand(0)->get().getDefiningOp<tensor::EmptyOp>();
|
||||
if (initTensor || isGenericOutsNotUsed(genericOp)) {
|
||||
if (destPack)
|
||||
dest = destPack.getDest();
|
||||
}
|
||||
|
||||
@@ -435,6 +435,40 @@ func.func @elem_pack_transpose_outer_dims(%arg0: tensor<128x256xi32>, %init: ten
|
||||
|
||||
// -----
|
||||
|
||||
#map0 = affine_map<(d0, d1) -> (d0, d1)>
|
||||
func.func @elem_pack_transpose_outer_dims_unused_init(%arg0: tensor<128x256xi32>, %init: tensor<128x256xi32>) -> tensor<16x4x32x16xi32>{
|
||||
%elem = linalg.generic {indexing_maps = [#map0, #map0], iterator_types = ["parallel", "parallel"]}
|
||||
ins(%arg0 : tensor<128x256xi32>)
|
||||
outs(%init : tensor<128x256xi32>) {
|
||||
^bb0(%arg3: i32, %arg4: i32):
|
||||
%4 = arith.addi %arg3, %arg3 : i32
|
||||
linalg.yield %4 : i32
|
||||
} -> tensor<128x256xi32>
|
||||
%empty = tensor.empty() : tensor<16x4x32x16xi32>
|
||||
%pack = linalg.pack %elem
|
||||
outer_dims_perm = [1, 0]
|
||||
inner_dims_pos = [0, 1]
|
||||
inner_tiles = [32, 16]
|
||||
into %empty : tensor<128x256xi32> -> tensor<16x4x32x16xi32>
|
||||
return %pack : tensor<16x4x32x16xi32>
|
||||
}
|
||||
|
||||
// CHECK: #[[$MAP:.+]] = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>
|
||||
// CHECK-LABEL: func.func @elem_pack_transpose_outer_dims
|
||||
// CHECK-SAME: %[[ARG0:[a-zA-Z0-9]+]]
|
||||
// CHECK-SAME: %[[ARG1:[a-zA-Z0-9]+]]
|
||||
// CHECK: %[[ARG1_EMPTY:.+]] = tensor.empty() : tensor<16x4x32x16xi32>
|
||||
// CHECK: %[[ARG0_EMPTY:.+]] = tensor.empty() : tensor<16x4x32x16xi32>
|
||||
// CHECK: %[[PACKED_ARG0:.+]] = linalg.pack %[[ARG0]]
|
||||
// CHECK-SAME: outer_dims_perm = [1, 0] inner_dims_pos = [0, 1] inner_tiles = [32, 16]
|
||||
// CHECK-SAME: into %[[ARG0_EMPTY]]
|
||||
// CHECK: %[[RES:.+]] = linalg.generic
|
||||
// CHECK-SAME: indexing_maps = [#[[$MAP]], #[[$MAP]]]
|
||||
// CHECK-SAME: ins(%[[PACKED_ARG0]]
|
||||
// CHECK-SAME: outs(%[[ARG1_EMPTY]]
|
||||
|
||||
// -----
|
||||
|
||||
#map = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>
|
||||
|
||||
func.func @unpack_on_output(%arg0: tensor<12x2x56x56x32xf32>) -> tensor<12x56x56x64xf32> {
|
||||
@@ -497,7 +531,7 @@ func.func @unpack_on_input(%arg0: tensor<12x2x56x56x32xf32>, %init: tensor<12x56
|
||||
|
||||
#map = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>
|
||||
|
||||
func.func @unpack_element_type_change(%arg0: tensor<12x2x56x56x32xf32>, %init: tensor<12x56x56x64xf16>) -> tensor<12x56x56x64xf16> {
|
||||
func.func @unpack_element_type_change_no_use(%arg0: tensor<12x2x56x56x32xf32>, %init: tensor<12x56x56x64xf16>) -> tensor<12x56x56x64xf16> {
|
||||
%0 = tensor.empty() : tensor<12x56x56x64xf32>
|
||||
%1 = linalg.unpack %arg0 outer_dims_perm = [0, 3, 1, 2] inner_dims_pos = [3] inner_tiles = [32] into %0 : tensor<12x2x56x56x32xf32> -> tensor<12x56x56x64xf32>
|
||||
%2 = linalg.generic {indexing_maps = [#map, #map], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%1: tensor<12x56x56x64xf32>) outs(%init : tensor<12x56x56x64xf16>) {
|
||||
@@ -509,17 +543,14 @@ func.func @unpack_element_type_change(%arg0: tensor<12x2x56x56x32xf32>, %init: t
|
||||
}
|
||||
|
||||
// CHECK: #[[$MAP:.+]] = affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d2, d3, d4)>
|
||||
// CHECK-LABEL: func.func @unpack_element_type_change
|
||||
// CHECK-LABEL: func.func @unpack_element_type_change_no_use
|
||||
// CHECK-SAME: %[[ARG0:[a-zA-Z0-9]+]]
|
||||
// CHECK-SAME: %[[ARG1:[a-zA-Z0-9]+]]
|
||||
// CHECK: %[[ARG1_PACK_EMPTY:.+]] = tensor.empty() : tensor<12x2x56x56x32xf16>
|
||||
// CHECK: %[[ARG1_PACK:.+]] = linalg.pack %[[ARG1]]
|
||||
// CHECK-SAME: outer_dims_perm = [0, 3, 1, 2] inner_dims_pos = [3] inner_tiles = [32]
|
||||
// CHECK-SAME: into %[[ARG1_PACK_EMPTY]]
|
||||
// CHECK: %[[EMPTY:.+]] = tensor.empty() : tensor<12x2x56x56x32xf16>
|
||||
// CHECK: %[[RES:.+]] = linalg.generic
|
||||
// CHECK-SAME: indexing_maps = [#[[$MAP]], #[[$MAP]]]
|
||||
// CHECK-SAME: ins(%[[ARG0]]
|
||||
// CHECK-SAME: outs(%[[ARG1_PACK]]
|
||||
// CHECK-SAME: outs(%[[EMPTY]]
|
||||
// CHECK: %[[UNPACK:.+]] = linalg.unpack %[[RES]]
|
||||
// CHECK-SAME: outer_dims_perm = [0, 3, 1, 2] inner_dims_pos = [3] inner_tiles = [32]
|
||||
// CHECK-SAME: into %[[ARG1]]
|
||||
@@ -1402,13 +1433,10 @@ func.func @push_unpack_in_padded_domain_foldable(%arg0: tensor<8x8x4x8xf32>, %de
|
||||
// CHECK-SAME: %[[ARG0:[a-zA-Z0-9]+]]
|
||||
// CHECK-SAME: %[[ARG1:[a-zA-Z0-9]+]]
|
||||
// CHECK-SAME: %[[ARG2:[a-zA-Z0-9]+]]
|
||||
// CHECK: %[[ARG2_PACK_EMPTY:.+]] = tensor.empty
|
||||
// CHECK: %[[ARG2_PACK:.+]] = linalg.pack %[[ARG2]]
|
||||
// CHECK-SAME: inner_dims_pos = [0, 1] inner_tiles = [4, 8]
|
||||
// CHECK-SAME: into %[[ARG2_PACK_EMPTY]]
|
||||
// CHECK: %[[EMPTY:.+]] = tensor.empty
|
||||
// CHECK: %[[GENERIC:.+]] = linalg.generic
|
||||
// CHECK-SAME: ins(%[[ARG0]] : tensor<8x8x4x8xf32>)
|
||||
// CHECK-SAME: outs(%[[ARG2_PACK]] : tensor<?x8x4x8xbf16>)
|
||||
// CHECK-SAME: outs(%[[EMPTY]] : tensor<?x8x4x8xbf16>)
|
||||
// CHECK: %[[UNPACK:.+]] = linalg.unpack %[[GENERIC]]
|
||||
// CHECK-SAME: into %[[ARG2]]
|
||||
// CHECK: return %[[UNPACK]] : tensor<?x64xbf16>
|
||||
|
||||
Reference in New Issue
Block a user