diff --git a/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp b/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp index 1f5af39e604e..f97ed3d6d511 100644 --- a/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp @@ -231,8 +231,7 @@ static void generateFusedElementwiseOpRegion( // `consumerToProducerLoopsMap` to map the producer indices. if (producer.hasIndexSemantics()) { // Add an index operation for every fused loop dimension. - unsigned numFusedOpLoops = - std::max(producer.getNumLoops(), consumer.getNumLoops()); + unsigned numFusedOpLoops = fusedOp.getNumLoops(); SmallVector fusedIndices; fusedIndices.reserve(numFusedOpLoops); llvm::transform(llvm::seq(0, numFusedOpLoops), diff --git a/mlir/test/Dialect/Linalg/fusion-elementwise-ops.mlir b/mlir/test/Dialect/Linalg/fusion-elementwise-ops.mlir index 28e1291bce1f..66fc55fadf8f 100644 --- a/mlir/test/Dialect/Linalg/fusion-elementwise-ops.mlir +++ b/mlir/test/Dialect/Linalg/fusion-elementwise-ops.mlir @@ -860,6 +860,43 @@ func.func @fusion_different_axes(%arg0 : tensor<5000xi64>, %arg1 : tensor<5000xi // ----- +func.func @fusion_different_axes_indexed(%arg0: tensor<2x2xi32>) -> tensor<2xi32> { + %0 = tensor.empty() : tensor<2x2xi32> + %1 = linalg.generic { + indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d0, d1)>], + iterator_types = ["parallel", "parallel"]} + ins(%arg0 : tensor<2x2xi32>) outs(%0 : tensor<2x2xi32>) { + ^bb0(%in: i32, %out: i32): + %2 = linalg.index 1 : index + %3 = arith.index_cast %2 : index to i32 + linalg.yield %3 : i32 + } -> tensor<2x2xi32> + %4 = tensor.empty() : tensor<2xi32> + %5 = linalg.generic { + indexing_maps = [affine_map<(d0) -> (d0, 1)>, affine_map<(d0) -> (d0)>], + iterator_types = ["parallel"]} + ins(%1 : tensor<2x2xi32>) outs(%4 : tensor<2xi32>) { + ^bb0(%in: i32, %out: i32): + linalg.yield %in : i32 + } -> tensor<2xi32> + return %5 : tensor<2xi32> +} + +// CHECK-DAG: #[[MAP:.+]] = affine_map<(d0) -> (d0)> +// CHECK: func @fusion_different_axes_indexed( +// CHECK-SAME: %[[ARG0:.+]]: tensor<2x2xi32> +// CHECK-DAG: %[[CST:.+]] = arith.constant 1 : i32 +// CHECK-DAG: %[[INIT:.+]] = tensor.empty() : tensor<2xi32> +// CHECK: %[[RESULT:.+]] = linalg.generic +// CHECK-SAME: indexing_maps = [#[[MAP]]] +// CHECK-SAME: outs(%[[INIT]] : +// CHECK-NEXT: ^bb0( +// CHECK-SAME: %[[B0:.+]]: i32 +// CHECK: linalg.yield %[[CST]] : i32 +// CHECK: return %[[RESULT]] + +// ----- + // CHECK-LABEL: func @fold_fill_generic_basic // CHECK-SAME: (%[[ARG0:.*]]: tensor) -> tensor { // CHECK-NOT: linalg.fill