[mlir][Linalg] Fix fusing of indexed linalg consumer with different axes (#140892)
When fusing two `linalg.generic` ops where the producer has index semantics, invalid `affine.apply` ops could be generated: the number of index operands did not match the number of loops of the fused op, because the index count was derived from `std::max(producer.getNumLoops(), consumer.getNumLoops())`. This patch fixes the issue by taking the loop count directly from the generated fused op.
commit 4b59b7b946 (parent 2d49bc01cf), committed by GitHub
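For illustration, the sketch below shows roughly the fused op the pattern should now produce for the new test case added further down. It is reconstructed from the test's CHECK lines rather than copied from actual pass output, and the function name @fused_sketch is invented: the consumer accesses the producer result through affine_map<(d0) -> (d0, 1)>, so the producer's `linalg.index 1` folds to the constant 1 and the fused op is left with a single parallel loop.

// A minimal sketch (hypothetical names) of the expected fusion result for the
// new test case below, reconstructed from its CHECK lines.
#map = affine_map<(d0) -> (d0)>
func.func @fused_sketch(%arg0: tensor<2x2xi32>) -> tensor<2xi32> {
  // The producer's `linalg.index 1` resolves through the consumer's access
  // map (d0) -> (d0, 1) and folds to the constant 1.
  %c1_i32 = arith.constant 1 : i32
  %init = tensor.empty() : tensor<2xi32>
  %res = linalg.generic {indexing_maps = [#map], iterator_types = ["parallel"]}
      outs(%init : tensor<2xi32>) {
  ^bb0(%out: i32):
    // Before the fix, max(producer loops, consumer loops) = 2 index values
    // were materialized in this single-loop body, so the generated
    // affine.apply ops had more operands than the fused op has loops.
    linalg.yield %c1_i32 : i32
  } -> tensor<2xi32>
  return %res : tensor<2xi32>
}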
@@ -231,8 +231,7 @@ static void generateFusedElementwiseOpRegion(
   // `consumerToProducerLoopsMap` to map the producer indices.
   if (producer.hasIndexSemantics()) {
     // Add an index operation for every fused loop dimension.
-    unsigned numFusedOpLoops =
-        std::max(producer.getNumLoops(), consumer.getNumLoops());
+    unsigned numFusedOpLoops = fusedOp.getNumLoops();
     SmallVector<Value> fusedIndices;
     fusedIndices.reserve(numFusedOpLoops);
     llvm::transform(llvm::seq<uint64_t>(0, numFusedOpLoops),
@@ -860,6 +860,43 @@ func.func @fusion_different_axes(%arg0 : tensor<5000xi64>, %arg1 : tensor<5000xi
 
 // -----
 
+func.func @fusion_different_axes_indexed(%arg0: tensor<2x2xi32>) -> tensor<2xi32> {
+  %0 = tensor.empty() : tensor<2x2xi32>
+  %1 = linalg.generic {
+      indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d0, d1)>],
+      iterator_types = ["parallel", "parallel"]}
+      ins(%arg0 : tensor<2x2xi32>) outs(%0 : tensor<2x2xi32>) {
+  ^bb0(%in: i32, %out: i32):
+    %2 = linalg.index 1 : index
+    %3 = arith.index_cast %2 : index to i32
+    linalg.yield %3 : i32
+  } -> tensor<2x2xi32>
+  %4 = tensor.empty() : tensor<2xi32>
+  %5 = linalg.generic {
+      indexing_maps = [affine_map<(d0) -> (d0, 1)>, affine_map<(d0) -> (d0)>],
+      iterator_types = ["parallel"]}
+      ins(%1 : tensor<2x2xi32>) outs(%4 : tensor<2xi32>) {
+  ^bb0(%in: i32, %out: i32):
+    linalg.yield %in : i32
+  } -> tensor<2xi32>
+  return %5 : tensor<2xi32>
+}
+
+// CHECK-DAG: #[[MAP:.+]] = affine_map<(d0) -> (d0)>
+// CHECK: func @fusion_different_axes_indexed(
+// CHECK-SAME: %[[ARG0:.+]]: tensor<2x2xi32>
+// CHECK-DAG: %[[CST:.+]] = arith.constant 1 : i32
+// CHECK-DAG: %[[INIT:.+]] = tensor.empty() : tensor<2xi32>
+// CHECK: %[[RESULT:.+]] = linalg.generic
+// CHECK-SAME: indexing_maps = [#[[MAP]]]
+// CHECK-SAME: outs(%[[INIT]] :
+// CHECK-NEXT: ^bb0(
+// CHECK-SAME: %[[B0:.+]]: i32
+// CHECK: linalg.yield %[[CST]] : i32
+// CHECK: return %[[RESULT]]
+
+// -----
+
 // CHECK-LABEL: func @fold_fill_generic_basic
 // CHECK-SAME: (%[[ARG0:.*]]: tensor<?xf32>) -> tensor<?xf32> {
 // CHECK-NOT: linalg.fill