[mlir][Linalg] Fix fusing of indexed linalg consumer with different axes (#140892)

When fusing two `linalg.generic` ops where the producer has index
semantics, invalid `affine.apply` ops can be generated in which the number
of index operands does not match the number of loops in the fused `linalg.generic`.

This patch fixes the issue by directly using the number of loops from
the generated fused op.
Author: Simone Pellegrini
Date:   2025-06-13 11:03:09 +02:00
Commit: 4b59b7b946 (parent: 2d49bc01cf)
2 changed files with 38 additions and 2 deletions


@@ -231,8 +231,7 @@ static void generateFusedElementwiseOpRegion(
   // `consumerToProducerLoopsMap` to map the producer indices.
   if (producer.hasIndexSemantics()) {
     // Add an index operation for every fused loop dimension.
-    unsigned numFusedOpLoops =
-        std::max(producer.getNumLoops(), consumer.getNumLoops());
+    unsigned numFusedOpLoops = fusedOp.getNumLoops();
     SmallVector<Value> fusedIndices;
     fusedIndices.reserve(numFusedOpLoops);
     llvm::transform(llvm::seq<uint64_t>(0, numFusedOpLoops),
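
For context, the value computed above determines how many `linalg.index` ops are materialized and how the producer's original indices are remapped through `consumerToProducerLoopsMap`. Below is a condensed sketch of that surrounding logic, simplified from the context lines above; `rewriter`, `fusedOp`, `producerBlock`, `mapper`, and `consumerToProducerLoopsMap` are assumed to be in scope as in generateFusedElementwiseOpRegion, and the exact upstream code may differ in detail.

if (producer.hasIndexSemantics()) {
  // Create one linalg.index per loop of the *fused* op, not
  // max(producer, consumer), which can exceed the fused loop count.
  unsigned numFusedOpLoops = fusedOp.getNumLoops();
  SmallVector<Value> fusedIndices;
  fusedIndices.reserve(numFusedOpLoops);
  for (unsigned dim = 0; dim < numFusedOpLoops; ++dim)
    fusedIndices.push_back(
        rewriter.create<linalg::IndexOp>(producer.getLoc(), dim));

  // Replace each producer linalg.index with an affine.apply over the fused
  // indices. The affine.apply is only valid when the number of operands
  // (fusedIndices) matches the dimension count of the applied map, i.e. the
  // fused op's loop count -- which is what the fix guarantees.
  for (linalg::IndexOp indexOp : producerBlock.getOps<linalg::IndexOp>()) {
    unsigned producerDim = indexOp.getDim();
    Value remapped = rewriter.create<affine::AffineApplyOp>(
        producer.getLoc(),
        consumerToProducerLoopsMap.getSubMap({producerDim}), fusedIndices);
    mapper.map(indexOp.getResult(), remapped);
  }
}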


@@ -860,6 +860,43 @@ func.func @fusion_different_axes(%arg0 : tensor<5000xi64>, %arg1 : tensor<5000xi
// -----
func.func @fusion_different_axes_indexed(%arg0: tensor<2x2xi32>) -> tensor<2xi32> {
  %0 = tensor.empty() : tensor<2x2xi32>
  %1 = linalg.generic {
    indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d0, d1)>],
    iterator_types = ["parallel", "parallel"]}
    ins(%arg0 : tensor<2x2xi32>) outs(%0 : tensor<2x2xi32>) {
  ^bb0(%in: i32, %out: i32):
    %2 = linalg.index 1 : index
    %3 = arith.index_cast %2 : index to i32
    linalg.yield %3 : i32
  } -> tensor<2x2xi32>
  %4 = tensor.empty() : tensor<2xi32>
  %5 = linalg.generic {
    indexing_maps = [affine_map<(d0) -> (d0, 1)>, affine_map<(d0) -> (d0)>],
    iterator_types = ["parallel"]}
    ins(%1 : tensor<2x2xi32>) outs(%4 : tensor<2xi32>) {
  ^bb0(%in: i32, %out: i32):
    linalg.yield %in : i32
  } -> tensor<2xi32>
  return %5 : tensor<2xi32>
}
// CHECK-DAG: #[[MAP:.+]] = affine_map<(d0) -> (d0)>
// CHECK: func @fusion_different_axes_indexed(
// CHECK-SAME: %[[ARG0:.+]]: tensor<2x2xi32>
// CHECK-DAG: %[[CST:.+]] = arith.constant 1 : i32
// CHECK-DAG: %[[INIT:.+]] = tensor.empty() : tensor<2xi32>
// CHECK: %[[RESULT:.+]] = linalg.generic
// CHECK-SAME: indexing_maps = [#[[MAP]]]
// CHECK-SAME: outs(%[[INIT]] :
// CHECK-NEXT: ^bb0(
// CHECK-SAME: %[[B0:.+]]: i32
// CHECK: linalg.yield %[[CST]] : i32
// CHECK: return %[[RESULT]]
// -----
// CHECK-LABEL: func @fold_fill_generic_basic
// CHECK-SAME: (%[[ARG0:.*]]: tensor<?xf32>) -> tensor<?xf32> {
// CHECK-NOT: linalg.fill
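
To exercise the new case locally, the test is driven through mlir-opt and FileCheck by the test file's RUN line. A typical RUN line for this kind of elementwise-fusion test looks like the following; the exact line in the file may differ.

// RUN: mlir-opt %s -split-input-file -linalg-fuse-elementwise-ops | FileCheck %s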