[mlir] Fix DataLayoutPropagation foldings invalidating IR (#140103)
Fixes a bug in DataLayoutPropagation that was replacing generic op
destinations with tensor.empty() ops, even when the destination operand
was being used.
Addresses post-merge comment:
a9c1dccc3f (r2091193712)
Signed-off-by: Max Dawkins <maxdawkins19@gmail.com>
Co-authored-by: Max Dawkins <maxdawkins19@gmail.com>
This commit is contained in:
@@ -312,10 +312,17 @@ static GenericOp packGenericOp(RewriterBase &rewriter, GenericOp genericOp,
|
||||
SmallVector<Value> inputOperands;
|
||||
SmallVector<Value> inputOperandsFromUnpackedSource;
|
||||
SmallVector<AffineMap> indexingMaps;
|
||||
auto hasEquivalentTiles = [](PackOp packOp, UnPackOp unPackOp) {
|
||||
return packOp.getOuterDimsPerm() == unPackOp.getOuterDimsPerm() &&
|
||||
packOp.getInnerDimsPos() == unPackOp.getInnerDimsPos() &&
|
||||
llvm::equal(packOp.getMixedTiles(), unPackOp.getMixedTiles());
|
||||
};
|
||||
for (OpOperand *inputOperand : genericOp.getDpsInputOperands()) {
|
||||
auto [packedOperand, packedIndexingMap] = getOrCreatePackedViewOfOperand(
|
||||
rewriter, loc, packInfo, genericOp, inputOperand);
|
||||
if (auto unpackOp = inputOperand->get().getDefiningOp<linalg::UnPackOp>()) {
|
||||
auto unpackOp = inputOperand->get().getDefiningOp<linalg::UnPackOp>();
|
||||
auto packOp = packedOperand.getDefiningOp<linalg::PackOp>();
|
||||
if (packOp && unpackOp && hasEquivalentTiles(packOp, unpackOp)) {
|
||||
inputOperandsFromUnpackedSource.push_back(unpackOp.getSource());
|
||||
} else {
|
||||
inputOperandsFromUnpackedSource.push_back(packedOperand);
|
||||
@@ -324,14 +331,16 @@ static GenericOp packGenericOp(RewriterBase &rewriter, GenericOp genericOp,
|
||||
indexingMaps.push_back(packedIndexingMap);
|
||||
}
|
||||
|
||||
// If the pack and unpack op can be folded:
|
||||
// 1) use the unpack op's source as the operand to fold the unpack -> pack sequence.
|
||||
// 2) init tensor of the generic op can be replaced by the destination of the
|
||||
// pack op.
|
||||
// If the unpack->pack sequences can be folded, use the sources of
|
||||
// the unpack ops in any unpack->pack chains on the generic op operands.
|
||||
if (isFoldableUnpackPack) {
|
||||
inputOperands = inputOperandsFromUnpackedSource;
|
||||
if (auto destPack = dest.getDefiningOp<linalg::PackOp>())
|
||||
dest = destPack.getDest();
|
||||
if (auto destPack = dest.getDefiningOp<linalg::PackOp>()) {
|
||||
auto destUnPack = destPack.getSource().getDefiningOp<linalg::UnPackOp>();
|
||||
if (destUnPack && hasEquivalentTiles(destPack, destUnPack)) {
|
||||
dest = destUnPack.getSource();
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
int64_t numInnerLoops = packInfo.getNumTiledLoops();
|
||||
|
||||
@@ -455,10 +455,9 @@ func.func @unpack_on_output(%arg0: tensor<12x2x56x56x32xf32>) -> tensor<12x56x56
|
||||
// CHECK: %[[UNPACKED_ARG0:.+]] = linalg.unpack %[[ARG0]]
|
||||
// CHECK-SAME: outer_dims_perm = [0, 3, 1, 2] inner_dims_pos = [3] inner_tiles = [32]
|
||||
// CHECK-SAME: into %[[ARG0_EMPTY_UNPACK]]
|
||||
// CHECK: %[[EMPTY:.+]] = tensor.empty() : tensor<12x2x56x56x32xf32>
|
||||
// CHECK: %[[RES:.+]] = linalg.generic
|
||||
// CHECK-SAME: indexing_maps = [#[[$MAP]]]
|
||||
// CHECK-SAME: outs(%[[EMPTY]]
|
||||
// CHECK-SAME: outs(%[[ARG0]]
|
||||
// CHECK: %[[UNPACK:.+]] = linalg.unpack %[[RES]]
|
||||
// CHECK-SAME: outer_dims_perm = [0, 3, 1, 2] inner_dims_pos = [3] inner_tiles = [32]
|
||||
// CHECK-SAME: into %[[UNPACKED_ARG0]]
|
||||
@@ -482,11 +481,14 @@ func.func @unpack_on_input(%arg0: tensor<12x2x56x56x32xf32>, %init: tensor<12x56
|
||||
// CHECK-LABEL: func.func @unpack_on_input
|
||||
// CHECK-SAME: %[[ARG0:[a-zA-Z0-9]+]]
|
||||
// CHECK-SAME: %[[ARG1:[a-zA-Z0-9]+]]
|
||||
// CHECK: %[[EMPTY:.+]] = tensor.empty() : tensor<12x2x56x56x32xf32>
|
||||
// CHECK: %[[ARG1_PACK_EMPTY:.+]] = tensor.empty() : tensor<12x2x56x56x32xf32>
|
||||
// CHECK: %[[ARG1_PACK:.+]] = linalg.pack %[[ARG1]]
|
||||
// CHECK-SAME: outer_dims_perm = [0, 3, 1, 2] inner_dims_pos = [3] inner_tiles = [32]
|
||||
// CHECK-SAME: into %[[ARG1_PACK_EMPTY]]
|
||||
// CHECK: %[[RES:.+]] = linalg.generic
|
||||
// CHECK-SAME: indexing_maps = [#[[$MAP]], #[[$MAP]]]
|
||||
// CHECK-SAME: ins(%[[ARG0]]
|
||||
// CHECK-SAME: outs(%[[EMPTY]]
|
||||
// CHECK-SAME: outs(%[[ARG1_PACK]]
|
||||
// CHECK: %[[UNPACK:.+]] = linalg.unpack %[[RES]]
|
||||
// CHECK-SAME: outer_dims_perm = [0, 3, 1, 2] inner_dims_pos = [3] inner_tiles = [32]
|
||||
// CHECK-SAME: into %[[ARG1]]
|
||||
@@ -510,11 +512,14 @@ func.func @unpack_element_type_change(%arg0: tensor<12x2x56x56x32xf32>, %init: t
|
||||
// CHECK-LABEL: func.func @unpack_element_type_change
|
||||
// CHECK-SAME: %[[ARG0:[a-zA-Z0-9]+]]
|
||||
// CHECK-SAME: %[[ARG1:[a-zA-Z0-9]+]]
|
||||
// CHECK: %[[EMPTY:.+]] = tensor.empty() : tensor<12x2x56x56x32xf16>
|
||||
// CHECK: %[[ARG1_PACK_EMPTY:.+]] = tensor.empty() : tensor<12x2x56x56x32xf16>
|
||||
// CHECK: %[[ARG1_PACK:.+]] = linalg.pack %[[ARG1]]
|
||||
// CHECK-SAME: outer_dims_perm = [0, 3, 1, 2] inner_dims_pos = [3] inner_tiles = [32]
|
||||
// CHECK-SAME: into %[[ARG1_PACK_EMPTY]]
|
||||
// CHECK: %[[RES:.+]] = linalg.generic
|
||||
// CHECK-SAME: indexing_maps = [#[[$MAP]], #[[$MAP]]]
|
||||
// CHECK-SAME: ins(%[[ARG0]]
|
||||
// CHECK-SAME: outs(%[[EMPTY]]
|
||||
// CHECK-SAME: outs(%[[ARG1_PACK]]
|
||||
// CHECK: %[[UNPACK:.+]] = linalg.unpack %[[RES]]
|
||||
// CHECK-SAME: outer_dims_perm = [0, 3, 1, 2] inner_dims_pos = [3] inner_tiles = [32]
|
||||
// CHECK-SAME: into %[[ARG1]]
|
||||
@@ -1397,10 +1402,13 @@ func.func @push_unpack_in_padded_domain_foldable(%arg0: tensor<8x8x4x8xf32>, %de
|
||||
// CHECK-SAME: %[[ARG0:[a-zA-Z0-9]+]]
|
||||
// CHECK-SAME: %[[ARG1:[a-zA-Z0-9]+]]
|
||||
// CHECK-SAME: %[[ARG2:[a-zA-Z0-9]+]]
|
||||
// CHECK: %[[EMPTY:.+]] = tensor.empty
|
||||
// CHECK: %[[ARG2_PACK_EMPTY:.+]] = tensor.empty
|
||||
// CHECK: %[[ARG2_PACK:.+]] = linalg.pack %[[ARG2]]
|
||||
// CHECK-SAME: inner_dims_pos = [0, 1] inner_tiles = [4, 8]
|
||||
// CHECK-SAME: into %[[ARG2_PACK_EMPTY]]
|
||||
// CHECK: %[[GENERIC:.+]] = linalg.generic
|
||||
// CHECK-SAME: ins(%[[ARG0]] : tensor<8x8x4x8xf32>)
|
||||
// CHECK-SAME: outs(%[[EMPTY]] : tensor<?x8x4x8xbf16>)
|
||||
// CHECK-SAME: outs(%[[ARG2_PACK]] : tensor<?x8x4x8xbf16>)
|
||||
// CHECK: %[[UNPACK:.+]] = linalg.unpack %[[GENERIC]]
|
||||
// CHECK-SAME: into %[[ARG2]]
|
||||
// CHECK: return %[[UNPACK]] : tensor<?x64xbf16>
|
||||
@@ -1419,10 +1427,13 @@ func.func @push_unpack_in_padded_domain_out_used(%arg0: tensor<8x8x4x8xf32>, %ar
|
||||
// CHECK-LABEL: func.func @push_unpack_in_padded_domain_out_used
|
||||
// CHECK-SAME: %[[ARG0:[a-zA-Z0-9]+]]
|
||||
// CHECK-SAME: %[[ARG1:[a-zA-Z0-9]+]]
|
||||
// CHECK: %[[EMPTY:.+]] = tensor.empty
|
||||
// CHECK: %[[ARG1_PACK_EMPTY:.+]] = tensor.empty
|
||||
// CHECK: %[[ARG1_PACK:.+]] = linalg.pack %[[ARG1]]
|
||||
// CHECK-SAME: inner_dims_pos = [0, 1] inner_tiles = [4, 8]
|
||||
// CHECK-SAME: into %[[ARG1_PACK_EMPTY]]
|
||||
// CHECK: %[[GENERIC:.+]] = linalg.generic
|
||||
// CHECK-SAME: ins(%[[ARG0]] : tensor<8x8x4x8xf32>)
|
||||
// CHECK-SAME: outs(%[[EMPTY]] : tensor<?x8x4x8xf32>)
|
||||
// CHECK-SAME: outs(%[[ARG1_PACK]] : tensor<?x8x4x8xf32>)
|
||||
// CHECK: %[[UNPACK2:.+]] = linalg.unpack %[[GENERIC]]
|
||||
// CHECK-SAME: into %[[ARG1]]
|
||||
// CHECK: return %[[UNPACK2]] : tensor<?x64xf32>
|
||||
|
||||
Reference in New Issue
Block a user