[mlir][linalg] Allow promotion to use the original subview size (#144334)

linalg promotion attempts to compute a constant upper bound for the
allocated buffer size. Only when failed to compute an upperbound it
fallbacks to the original subview size, which may be dynamic.

Adding a promotion option to use the original subview size by default,
thus minimizing the allocation size.
Fixes #144268.
This commit is contained in:
zbenzion
2025-07-02 09:47:51 +03:00
committed by GitHub
parent 3c6cade485
commit b68e8f1de7
5 changed files with 80 additions and 4 deletions

View File

@@ -42,3 +42,59 @@ module attributes {transform.with_named_sequence} {
transform.yield
}
}
// -----
func.func @matmul_f32(%A: memref<512x256xf32>, %B: memref<256x512xf32>, %C: memref<256x256xf32>, %s0: index, %s1: index, %s2: index) {
%c0 = arith.constant 0 : index
%c256 = arith.constant 256 : index
%c512 = arith.constant 512 : index
scf.for %arg4 = %c0 to %c512 step %s0 {
scf.for %arg5 = %c0 to %c512 step %s1 {
scf.for %arg6 = %c0 to %c256 step %s2 {
%i0 = affine.min affine_map<(d0)[s0] -> (-d0 + 512, s0)>(%arg4)[%s0]
%i1 = affine.min affine_map<(d0)[s0] -> (-d0 + 512, s0)>(%arg5)[%s1]
%i2 = affine.min affine_map<(d0)[s0] -> (-d0 + 256, s0)>(%arg6)[%s2]
%0 = memref.subview %A[%arg4, %arg6][%i0, %i2][1, 1] : memref<512x256xf32> to memref<?x?xf32, strided<[256, 1], offset: ?>>
%1 = memref.subview %B[%arg6, %arg5][%i2, %i1][1, 1] : memref<256x512xf32> to memref<?x?xf32, strided<[512, 1], offset: ?>>
%2 = memref.subview %C[%arg4, %arg5][%i0, %i1][1, 1] : memref<256x256xf32> to memref<?x?xf32, strided<[256, 1], offset: ?>>
linalg.matmul
ins(%0, %1: memref<?x?xf32, strided<[256, 1], offset: ?>>,
memref<?x?xf32, strided<[512, 1], offset: ?>>)
outs(%2: memref<?x?xf32, strided<[256, 1], offset: ?>>)
}
}
}
return
}
// CHECK-LABEL: func.func @matmul_f32(
// CHECK-SAME: %[[ARG0:.*]]: memref<512x256xf32>
// CHECK-SAME: %[[ARG1:.*]]: memref<256x512xf32>
// CHECK-SAME: %[[ARG2:.*]]: memref<256x256xf32>
// CHECK-SAME: %[[ARG3:.*]]: index, %[[ARG4:.*]]: index, %[[ARG5:.*]]: index
// CHECK: %[[C4:.*]] = arith.constant 4 : index
// CHECK: %[[i0:.*]] = affine.min
// CHECK: %[[i1:.*]] = affine.min
// CHECK: %[[i2:.*]] = affine.min
// CHECK: %[[VAL_13:.*]] = arith.muli %[[i0]], %[[i2]] : index
// CHECK: %[[VAL_14:.*]] = arith.muli %[[VAL_13]], %[[C4]] : index
// CHECK: %[[VAL_15:.*]] = memref.alloc(%[[VAL_14]]) : memref<?xi8>
// CHECK: %[[VAL_18:.*]] = arith.muli %[[i2]], %[[i1]] : index
// CHECK: %[[VAL_19:.*]] = arith.muli %[[VAL_18]], %[[C4]] : index
// CHECK: %[[VAL_20:.*]] = memref.alloc(%[[VAL_19]]) : memref<?xi8>
// CHECK: %[[VAL_23:.*]] = arith.muli %[[i0]], %[[i1]] : index
// CHECK: %[[VAL_24:.*]] = arith.muli %[[VAL_23]], %[[C4]] : index
// CHECK: %[[VAL_25:.*]] = memref.alloc(%[[VAL_24]]) : memref<?xi8>
module attributes {transform.with_named_sequence} {
transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
%0 = transform.structured.match ops{["linalg.matmul"]} in %arg1 : (!transform.any_op) -> !transform.any_op
%1 = transform.structured.promote %0 { use_original_subview_size } : (!transform.any_op) -> !transform.any_op
transform.yield
}
}