[mlir][vector] Avoid setting padding by default to 0 in vector.transfer_read prefer ub.poison (#146088)

Context:
`vector.transfer_read` always requires a padding value. Most of its
builders take no `padding` value and assume the safe value of `0`.
However, this should be a conscious choice by the API user, as it makes
it easy to introduce bugs.
For example, I found several occasions while making this patch that the
padding value was not getting propagated (`vector.transfer_read` was
transformed into another `vector.transfer_read`). These bugs, were
always caused because of constructors that don't require specifying
padding.

Additionally, using `ub.poison` as a possible default value is better,
as it indicates the user "doesn't care" about the actual padding value,
forcing users to specify the actual padding semantics they want.

With that in mind, this patch changes the builders in
`vector.transfer_read` to always having a `std::optional<Value> padding`
argument. This argument is never optional, but for convenience users can
pass `std::nullopt`, padding the transfer read with `ub.poison`.

---------

Signed-off-by: Fabian Mora <fabian.mora-cordero@amd.com>
This commit is contained in:
Fabian Mora
2025-06-30 15:20:42 -04:00
committed by GitHub
parent 6a57af8d03
commit 878d3594ed
15 changed files with 108 additions and 79 deletions

View File

@@ -21,7 +21,7 @@ func.func @vec1d_1(%A : memref<?x?xf32>, %B : memref<?x?x?xf32>) {
// CHECK: for {{.*}} step 128
// CHECK-NEXT: %{{.*}} = affine.apply #[[$map_id1]](%[[C0]])
// CHECK-NEXT: %{{.*}} = affine.apply #[[$map_id1]](%[[C0]])
// CHECK-NEXT: %{{.*}} = arith.constant 0.0{{.*}}: f32
// CHECK-NEXT: %{{.*}} = ub.poison : f32
// CHECK-NEXT: {{.*}} = vector.transfer_read %{{.*}}[%{{.*}}, %{{.*}}], %{{.*}} {permutation_map = #[[$map_proj_d0d1_0]]} : memref<?x?xf32>, vector<128xf32>
affine.for %i0 = 0 to %M { // vectorized due to scalar -> vector
%a0 = affine.load %A[%c0, %c0] : memref<?x?xf32>
@@ -47,7 +47,7 @@ func.func @vec1d_2(%A : memref<?x?xf32>, %B : memref<?x?x?xf32>) {
%P = memref.dim %B, %c2 : memref<?x?x?xf32>
// CHECK:for [[IV3:%[a-zA-Z0-9]+]] = 0 to [[ARG_M]] step 128
// CHECK-NEXT: %[[CST:.*]] = arith.constant 0.0{{.*}}: f32
// CHECK-NEXT: %[[CST:.*]] = ub.poison : f32
// CHECK-NEXT: {{.*}} = vector.transfer_read %{{.*}}[%{{.*}}, %{{.*}}], %[[CST]] : memref<?x?xf32>, vector<128xf32>
affine.for %i3 = 0 to %M { // vectorized
%a3 = affine.load %A[%c0, %i3] : memref<?x?xf32>
@@ -76,7 +76,7 @@ func.func @vec1d_3(%A : memref<?x?xf32>, %B : memref<?x?x?xf32>) {
// CHECK-NEXT: for [[IV9:%[0-9a-zA-Z_]*]] = 0 to [[ARG_N]] {
// CHECK-NEXT: %[[APP9_0:[0-9a-zA-Z_]+]] = affine.apply {{.*}}([[IV9]], [[IV8]])
// CHECK-NEXT: %[[APP9_1:[0-9a-zA-Z_]+]] = affine.apply {{.*}}([[IV9]], [[IV8]])
// CHECK-NEXT: %[[CST:.*]] = arith.constant 0.0{{.*}}: f32
// CHECK-NEXT: %[[CST:.*]] = ub.poison : f32
// CHECK-NEXT: {{.*}} = vector.transfer_read %{{.*}}[%[[APP9_0]], %[[APP9_1]]], %[[CST]] : memref<?x?xf32>, vector<128xf32>
affine.for %i8 = 0 to %M { // vectorized
affine.for %i9 = 0 to %N {
@@ -280,7 +280,7 @@ func.func @vec_rejected_3(%A : memref<?x?xf32>, %B : memref<?x?x?xf32>) {
// CHECK:for [[IV4:%[0-9a-zA-Z_]+]] = 0 to [[ARG_M]] step 128 {
// CHECK-NEXT: for [[IV5:%[0-9a-zA-Z_]*]] = 0 to [[ARG_N]] {
// CHECK-NEXT: %{{.*}} = arith.constant 0.0{{.*}}: f32
// CHECK-NEXT: %{{.*}} = ub.poison : f32
// CHECK-NEXT: {{.*}} = vector.transfer_read %{{.*}}[%{{.*}}, %{{.*}}], %{{[a-zA-Z0-9_]*}} : memref<?x?xf32>, vector<128xf32>
affine.for %i4 = 0 to %M { // vectorized
affine.for %i5 = 0 to %N { // not vectorized, would vectorize with --test-fastest-varying=1
@@ -424,7 +424,7 @@ func.func @vec_rejected_8(%A : memref<?x?xf32>, %B : memref<?x?x?xf32>) {
// CHECK: for [[IV18:%[a-zA-Z0-9]+]] = 0 to [[ARG_M]] step 128
// CHECK: %{{.*}} = affine.apply #[[$map_id1]](%{{.*}})
// CHECK: %{{.*}} = affine.apply #[[$map_id1]](%{{.*}})
// CHECK: %{{.*}} = arith.constant 0.0{{.*}}: f32
// CHECK: %{{.*}} = ub.poison : f32
// CHECK: {{.*}} = vector.transfer_read %{{.*}}[%{{.*}}, %{{.*}}], %{{.*}} {permutation_map = #[[$map_proj_d0d1_0]]} : memref<?x?xf32>, vector<128xf32>
affine.for %i17 = 0 to %M { // not vectorized, the 1-D pattern that matched %{{.*}} in DFS post-order prevents vectorizing %{{.*}}
affine.for %i18 = 0 to %M { // vectorized due to scalar -> vector
@@ -458,7 +458,7 @@ func.func @vec_rejected_9(%A : memref<?x?xf32>, %B : memref<?x?x?xf32>) {
// CHECK: for [[IV18:%[a-zA-Z0-9]+]] = 0 to [[ARG_M]] step 128
// CHECK: %{{.*}} = affine.apply #[[$map_id1]](%{{.*}})
// CHECK-NEXT: %{{.*}} = affine.apply #[[$map_id1]](%{{.*}})
// CHECK-NEXT: %{{.*}} = arith.constant 0.0{{.*}}: f32
// CHECK-NEXT: %{{.*}} = ub.poison : f32
// CHECK-NEXT: {{.*}} = vector.transfer_read %{{.*}}[%{{.*}}, %{{.*}}], %{{.*}} {permutation_map = #[[$map_proj_d0d1_0]]} : memref<?x?xf32>, vector<128xf32>
affine.for %i17 = 0 to %M { // not vectorized, the 1-D pattern that matched %i18 in DFS post-order prevents vectorizing %{{.*}}
affine.for %i18 = 0 to %M { // vectorized due to scalar -> vector

View File

@@ -11,7 +11,7 @@ func.func @vec_affine_apply(%arg0: memref<8x12x16xf32>, %arg1: memref<8x24x48xf3
// CHECK-NEXT: affine.for %[[ARG4:.*]] = 0 to 48 step 8 {
// CHECK-NEXT: %[[S0:.*]] = affine.apply #[[$MAP_ID0]](%[[ARG3]])
// CHECK-NEXT: %[[S1:.*]] = affine.apply #[[$MAP_ID1]](%[[ARG4]])
// CHECK-NEXT: %[[CST:.*]] = arith.constant 0.000000e+00 : f32
// CHECK-NEXT: %[[CST:.*]] = ub.poison : f32
// CHECK-NEXT: %[[S2:.*]] = vector.transfer_read %[[ARG0]][%[[ARG2]], %[[S0]], %[[S1]]], %[[CST]] : memref<8x12x16xf32>, vector<8xf32>
// CHECK-NEXT: vector.transfer_write %[[S2]], %[[ARG1]][%[[ARG2]], %[[ARG3]], %[[ARG4]]] : vector<8xf32>, memref<8x24x48xf32>
// CHECK-NEXT: }
@@ -42,7 +42,7 @@ func.func @vec_affine_apply_2(%arg0: memref<8x12x16xf32>, %arg1: memref<8x24x48x
// CHECK-NEXT: affine.for %[[ARG3:.*]] = 0 to 12 {
// CHECK-NEXT: affine.for %[[ARG4:.*]] = 0 to 48 step 8 {
// CHECK-NEXT: %[[S0:.*]] = affine.apply #[[$MAP_ID2]](%[[ARG4]])
// CHECK-NEXT: %[[CST:.*]] = arith.constant 0.000000e+00 : f32
// CHECK-NEXT: %[[CST:.*]] = ub.poison : f32
// CHECK-NEXT: %[[S1:.*]] = vector.transfer_read %[[ARG0]][%[[ARG2]], %[[ARG3]], %[[S0]]], %[[CST]] : memref<8x12x16xf32>, vector<8xf32>
// CHECK-NEXT: vector.transfer_write %[[S1]], %[[ARG1]][%[[ARG2]], %[[ARG3]], %[[ARG4]]] : vector<8xf32>, memref<8x24x48xf32>
// CHECK-NEXT: }
@@ -140,7 +140,7 @@ func.func @affine_map_with_expr_2(%arg0: memref<8x12x16xf32>, %arg1: memref<8x24
// CHECK-NEXT: %[[S0:.*]] = affine.apply #[[$MAP_ID3]](%[[ARG3]], %[[ARG4]], %[[I0]])
// CHECK-NEXT: %[[S1:.*]] = affine.apply #[[$MAP_ID4]](%[[ARG3]], %[[ARG4]], %[[I0]])
// CHECK-NEXT: %[[S2:.*]] = affine.apply #[[$MAP_ID5]](%[[ARG3]], %[[ARG4]], %[[I0]])
// CHECK-NEXT: %[[CST:.*]] = arith.constant 0.000000e+00 : f32
// CHECK-NEXT: %[[CST:.*]] = ub.poison : f32
// CHECK-NEXT: %[[S3:.*]] = vector.transfer_read %[[ARG0]][%[[S0]], %[[S1]], %[[S2]]], %[[CST]] {permutation_map = #[[$MAP_ID6]]} : memref<8x12x16xf32>, vector<8xf32>
// CHECK-NEXT: vector.transfer_write %[[S3]], %[[ARG1]][%[[ARG3]], %[[ARG4]], %[[ARG5]]] : vector<8xf32>, memref<8x24x48xf32>
// CHECK-NEXT: }

View File

@@ -11,8 +11,8 @@
// CHECK-LABEL: @base_case
// CHECK-SAME: %[[I:.+]]: index, %[[J:.+]]: index, %[[M:.+]]:
// CHECK: %[[PAD:.+]] = arith.constant 0 : i8
// CHECK: %[[C0:.+]] = arith.constant 0 : index
// CHECK-DAG: %[[PAD:.+]] = arith.constant 123 : i8
// CHECK-DAG: %[[C0:.+]] = arith.constant 0 : index
// CHECK: %[[COLLAPSE:.+]] = memref.collapse_shape %[[M]]
// CHECK-SAME{LITERAL}: [[0], [1], [2, 3]]
// CHECK-SAME: : memref<?x?x?x8xi8> into memref<?x?x?xi8>
@@ -36,8 +36,8 @@ func.func @base_case(%i : index, %j : index, %M : memref<?x?x?x8xi8>) -> vector<
// CHECK-LABEL: @with_3d_vector
// CHECK-SAME: %[[I:.+]]: index, %[[J:.+]]: index, %[[M:.+]]:
// CHECK: %[[PAD:.+]] = arith.constant 0 : i8
// CHECK: %[[COLLAPSED:.+]] = memref.collapse_shape %[[M]]
// CHECK-DAG: %[[PAD:.+]] = arith.constant 123 : i8
// CHECK-DAG: %[[COLLAPSED:.+]] = memref.collapse_shape %[[M]]
// CHECK-SAME{LITERAL}: [[0], [1, 2, 3]]
// CHECK-SAME: : memref<?x?x2x8xi8> into memref<?x?xi8>
// CHECK-NEXT: %[[T0:.+]] = vector.transfer_read %[[COLLAPSED]][%[[I]], %[[J]]], %[[PAD]] {in_bounds = [true]}

View File

@@ -85,8 +85,8 @@ func.func @transfer_read_dims_mismatch_contiguous(
// CHECK-LABEL: func.func @transfer_read_dims_mismatch_contiguous(
// CHECK-SAME: %[[MEM:.+]]: memref<5x4x3x2xi8, {{.+}}>) -> vector<2x3x2xi8> {
// CHECK: %[[C0_I8:.+]] = arith.constant 0 : i8
// CHECK: %[[C0:.+]] = arith.constant 0 : index
// CHECK-DAG: %[[C0_I8:.+]] = arith.constant 0 : i8
// CHECK-DAG: %[[C0:.+]] = arith.constant 0 : index
// CHECK: %[[COLLAPSED_MEM:.+]] = memref.collapse_shape %[[MEM]]
// CHECK-SAME{LITERAL}: [[0], [1, 2, 3]]
// CHECK-SAME: : memref<5x4x3x2xi8, {{.+}}> into memref<5x24xi8, {{.+}}>
@@ -116,8 +116,8 @@ func.func @transfer_read_dims_mismatch_contiguous_unit_dims(
// CHECK-LABEL: func.func @transfer_read_dims_mismatch_contiguous_unit_dims(
// CHECK-SAME: %[[MEM:.+]]: memref<6x5x4x3x2xi8, strided<[120, 24, 6, 2, 1], offset: ?>>)
// CHECK-SAME: -> vector<1x1x4x3x2xi8>
// CHECK: %[[C0_I8:.+]] = arith.constant 0 : i8
// CHECK: %[[C0:.+]] = arith.constant 0 : index
// CHECK-DAG: %[[C0_I8:.+]] = arith.constant 0 : i8
// CHECK-DAG: %[[C0:.+]] = arith.constant 0 : index
// CHECK: %[[COLLAPSED:.+]] = memref.collapse_shape %[[MEM]]
// CHECK-SAME{LITERAL}: [[0], [1], [2, 3, 4]]
// CHECK-SAME: : memref<6x5x4x3x2xi8, strided<[120, 24, 6, 2, 1], offset: ?>>
@@ -149,8 +149,8 @@ func.func @transfer_read_non_contiguous_unit_dims(
// CHECK-LABEL: func.func @transfer_read_non_contiguous_unit_dims(
// CHECK-SAME: %[[MEM:.*]]: memref<5x4x3x2xi8, strided<[48, 6, 2, 1], offset: ?>>) -> vector<1x1x3x2xi8> {
// CHECK: %[[VAL_1:.*]] = arith.constant 0 : i8
// CHECK: %[[VAL_2:.*]] = arith.constant 0 : index
// CHECK-DAG: %[[VAL_1:.*]] = arith.constant 0 : i8
// CHECK-DAG: %[[VAL_2:.*]] = arith.constant 0 : index
// CHECK: %[[VAL_3:.*]] = memref.collapse_shape %[[MEM]]
// CHECK-SAME{LITERAL}: [[0], [1], [2, 3]]
// CHECK-SAME: : memref<5x4x3x2xi8, strided<[48, 6, 2, 1], offset: ?>> into memref<5x4x6xi8, strided<[48, 6, 1], offset: ?>>
@@ -182,8 +182,8 @@ func.func @transfer_read_dims_mismatch_non_zero_indices(
// CHECK-LABEL: func.func @transfer_read_dims_mismatch_non_zero_indices(
// CHECK-SAME: %[[IDX_1:.+]]: index, %[[IDX_2:.+]]: index,
// CHECK-SAME: %[[MEM:.+]]: memref<1x43x4x6xi32>
// CHECK: %[[C0_I32:.+]] = arith.constant 0 : i32
// CHECK: %[[C_0:.+]] = arith.constant 0 : index
// CHECK-DAG: %[[C0_I32:.+]] = arith.constant 0 : i32
// CHECK-DAG: %[[C_0:.+]] = arith.constant 0 : index
// CHECK: %[[COLLAPSED_IN:.+]] = memref.collapse_shape %[[MEM]]
// CHECK-SAME{LITERAL}: [[0], [1], [2, 3]]
// CHECK-SAME: : memref<1x43x4x6xi32> into memref<1x43x24xi32>
@@ -241,8 +241,8 @@ func.func @transfer_read_leading_dynamic_dims(
// CHECK-LABEL: func @transfer_read_leading_dynamic_dims
// CHECK-SAME: %[[MEM:.+]]: memref<?x?x8x4xi8, {{.+}}>, %[[IDX_1:.+]]: index, %[[IDX_2:.+]]: index
// CHECK: %[[C0_I8:.+]] = arith.constant 0 : i8
// CHECK: %[[C0:.+]] = arith.constant 0 : index
// CHECK-DAG: %[[C0_I8:.+]] = arith.constant 0 : i8
// CHECK-DAG: %[[C0:.+]] = arith.constant 0 : index
// CHECK: %[[COLLAPSED:.+]] = memref.collapse_shape %[[MEM]]
// CHECK-SAME{LITERAL}: [[0], [1], [2, 3]]
// CHECK-SAME: : memref<?x?x8x4xi8, {{.+}}> into memref<?x?x32xi8, {{.+}}>
@@ -304,8 +304,8 @@ func.func @transfer_read_dynamic_dim_to_flatten(
// CHECK-SAME: %[[IDX_1:arg0]]
// CHECK-SAME: %[[IDX_2:arg1]]
// CHECK-SAME: %[[MEM:arg2]]
// CHECK: %[[C0_I32:.+]] = arith.constant 0 : i32
// CHECK: %[[C0:.+]] = arith.constant 0 : index
// CHECK-DAG: %[[C0_I32:.+]] = arith.constant 0 : i32
// CHECK-DAG: %[[C0:.+]] = arith.constant 0 : index
// CHECK: %[[COLLAPSED:.+]] = memref.collapse_shape %[[MEM]]
// CHECK-SAME{LITERAL}: [[0], [1], [2, 3]]
// CHECK-SAME: memref<1x?x4x6xi32> into memref<1x?x24xi32>

View File

@@ -1132,8 +1132,8 @@ func.func @warp_execute_nd_distribute(%laneid: index, %v0: vector<1x64x1xf32>, %
// CHECK-SCF-IF: gpu.barrier
// CHECK-SCF-IF: %[[WID:.*]] = affine.apply #[[$TIMES2]]()[%[[LANEID]]]
// CHECK-SCF-IF-DAG: %[[R0:.*]] = vector.transfer_read %{{.*}}[%[[LANEID]], %[[C0]], %[[C0]]], %cst {in_bounds = [true, true, true]} : memref<32x64x1xf32, 3>, vector<1x64x1xf32>
// CHECK-SCF-IF-DAG: %[[R1:.*]] = vector.transfer_read %{{.*}}[%[[C0]], %[[WID]], %[[C0]]], %cst {in_bounds = [true, true, true]} : memref<1x64x128xf32, 3>, vector<1x2x128xf32>
// CHECK-SCF-IF-DAG: %[[R0:.*]] = vector.transfer_read %{{.*}}[%[[LANEID]], %[[C0]], %[[C0]]], %{{.*}} {in_bounds = [true, true, true]} : memref<32x64x1xf32, 3>, vector<1x64x1xf32>
// CHECK-SCF-IF-DAG: %[[R1:.*]] = vector.transfer_read %{{.*}}[%[[C0]], %[[WID]], %[[C0]]], %{{.*}} {in_bounds = [true, true, true]} : memref<1x64x128xf32, 3>, vector<1x2x128xf32>
// CHECK-SCF-IF: return %[[R0]], %[[R1]] : vector<1x64x1xf32>, vector<1x2x128xf32>
return %r#0, %r#1 : vector<1x64x1xf32>, vector<1x2x128xf32>
}