From 4b59b7b94608ddbd21d14bec68400f2eb21f510d Mon Sep 17 00:00:00 2001
From: Simone Pellegrini
Date: Fri, 13 Jun 2025 11:03:09 +0200
Subject: [PATCH] [mlir][Linalg] Fix fusing of indexed linalg consumer with
 different axes (#140892)

When fusing two `linalg.generic` ops where the producer has index
semantics, invalid `affine.apply` ops can be generated in which the
number of indices does not match the number of loops in the fused
genericOp.

This patch fixes the issue by directly using the number of loops from
the generated fused op.
---
 .../Linalg/Transforms/ElementwiseOpFusion.cpp |  3 +-
 .../Linalg/fusion-elementwise-ops.mlir        | 37 +++++++++++++++++++
 2 files changed, 38 insertions(+), 2 deletions(-)

diff --git a/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp b/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp
index 1f5af39e604e..f97ed3d6d511 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp
@@ -231,8 +231,7 @@ static void generateFusedElementwiseOpRegion(
   // `consumerToProducerLoopsMap` to map the producer indices.
   if (producer.hasIndexSemantics()) {
     // Add an index operation for every fused loop dimension.
-    unsigned numFusedOpLoops =
-        std::max(producer.getNumLoops(), consumer.getNumLoops());
+    unsigned numFusedOpLoops = fusedOp.getNumLoops();
     SmallVector<Value> fusedIndices;
     fusedIndices.reserve(numFusedOpLoops);
     llvm::transform(llvm::seq<uint64_t>(0, numFusedOpLoops),
diff --git a/mlir/test/Dialect/Linalg/fusion-elementwise-ops.mlir b/mlir/test/Dialect/Linalg/fusion-elementwise-ops.mlir
index 28e1291bce1f..66fc55fadf8f 100644
--- a/mlir/test/Dialect/Linalg/fusion-elementwise-ops.mlir
+++ b/mlir/test/Dialect/Linalg/fusion-elementwise-ops.mlir
@@ -860,6 +860,43 @@ func.func @fusion_different_axes(%arg0 : tensor<5000xi64>, %arg1 : tensor<5000xi
 
 // -----
 
+func.func @fusion_different_axes_indexed(%arg0: tensor<2x2xi32>) -> tensor<2xi32> {
+  %0 = tensor.empty() : tensor<2x2xi32>
+  %1 = linalg.generic {
+      indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d0, d1)>],
+      iterator_types = ["parallel", "parallel"]}
+      ins(%arg0 : tensor<2x2xi32>) outs(%0 : tensor<2x2xi32>) {
+  ^bb0(%in: i32, %out: i32):
+    %2 = linalg.index 1 : index
+    %3 = arith.index_cast %2 : index to i32
+    linalg.yield %3 : i32
+  } -> tensor<2x2xi32>
+  %4 = tensor.empty() : tensor<2xi32>
+  %5 = linalg.generic {
+      indexing_maps = [affine_map<(d0) -> (d0, 1)>, affine_map<(d0) -> (d0)>],
+      iterator_types = ["parallel"]}
+      ins(%1 : tensor<2x2xi32>) outs(%4 : tensor<2xi32>) {
+  ^bb0(%in: i32, %out: i32):
+    linalg.yield %in : i32
+  } -> tensor<2xi32>
+  return %5 : tensor<2xi32>
+}
+
+// CHECK-DAG: #[[MAP:.+]] = affine_map<(d0) -> (d0)>
+// CHECK: func @fusion_different_axes_indexed(
+// CHECK-SAME: %[[ARG0:.+]]: tensor<2x2xi32>
+// CHECK-DAG: %[[CST:.+]] = arith.constant 1 : i32
+// CHECK-DAG: %[[INIT:.+]] = tensor.empty() : tensor<2xi32>
+// CHECK: %[[RESULT:.+]] = linalg.generic
+// CHECK-SAME: indexing_maps = [#[[MAP]]]
+// CHECK-SAME: outs(%[[INIT]] :
+// CHECK-NEXT: ^bb0(
+// CHECK-SAME: %[[B0:.+]]: i32
+// CHECK: linalg.yield %[[CST]] : i32
+// CHECK: return %[[RESULT]]
+
+// -----
+
 // CHECK-LABEL: func @fold_fill_generic_basic
 //  CHECK-SAME: (%[[ARG0:.*]]: tensor<?xf32>) -> tensor<?xf32> {
 //   CHECK-NOT: linalg.fill
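Note on the new test: with the fix, numFusedOpLoops is taken from the
fused op (1 loop) rather than max(producer = 2, consumer = 1) = 2, so no
out-of-range `linalg.index 1` is materialized in the single-loop fused
op. The fused op expected by the CHECK lines is roughly the following (a
sketch reconstructed from the test expectations, with illustrative SSA
names, not verbatim output of the fusion pass). Since the consumer's map
affine_map<(d0) -> (d0, 1)> pins the producer's d1 to the constant 1,
the producer's `linalg.index 1` folds away into a constant:

  %cst = arith.constant 1 : i32
  %init = tensor.empty() : tensor<2xi32>
  // Single-loop fused op: the producer's index computation has been
  // folded, so the body just yields the constant.
  %result = linalg.generic {
      indexing_maps = [affine_map<(d0) -> (d0)>],
      iterator_types = ["parallel"]}
      outs(%init : tensor<2xi32>) {
  ^bb0(%b0: i32):
    linalg.yield %cst : i32
  } -> tensor<2xi32>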