[mlir][Vector] Tighten up application conditions in TransferReadAfterWriteToBroadcast (#143869)

The pattern would previously apply in spurious cases and generate
incorrect IR.
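
One such case, condensed from one of the new negative tests added below (the function name is illustrative): the write carries no `in_bounds` attribute and may go out of bounds, so the read may see positions the write never stored and must not be folded into a broadcast of the written vector.

```mlir
func.func @spurious_fold_example(%t : tensor<?x?xf32>, %v0 : vector<4x2xf32>)
    -> vector<4x2x6xf32> {
  %c0 = arith.constant 0 : index
  %cf0 = arith.constant 0.0 : f32
  // No in_bounds on the write: it may be partially out of bounds.
  %w0 = vector.transfer_write %v0, %t[%c0, %c0]
      : vector<4x2xf32>, tensor<?x?xf32>
  // Broadcasting read; rewriting it to a broadcast of %v0 would be incorrect
  // if the write was (partially) out of bounds.
  %0 = vector.transfer_read %w0[%c0, %c0], %cf0
      {in_bounds = [true, true, true],
       permutation_map = affine_map<(d0, d1) -> (d0, d1, 0)>}
      : tensor<?x?xf32>, vector<4x2x6xf32>
  return %0 : vector<4x2x6xf32>
}
```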

In the process, we disable the application of this pattern in the case
where there is no broadcast; this should be handled separately and may
more easily support masking.
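
A minimal sketch of that no-broadcast shape (function name illustrative): the read mirrors the write exactly, so folding it would simply forward the written vector, which is left to that separate pattern.

```mlir
func.func @no_broadcast_example(%t : tensor<?x?xf32>, %v0 : vector<4x2xf32>)
    -> vector<4x2xf32> {
  %c0 = arith.constant 0 : index
  %cf0 = arith.constant 0.0 : f32
  %w0 = vector.transfer_write %v0, %t[%c0, %c0] {in_bounds = [true, true]}
      : vector<4x2xf32>, tensor<?x?xf32>
  // Same shape, same indices, minor-identity maps on both sides: no
  // broadcast is involved, so this pattern now leaves the IR untouched.
  %0 = vector.transfer_read %w0[%c0, %c0], %cf0 {in_bounds = [true, true]}
      : tensor<?x?xf32>, vector<4x2xf32>
  return %0 : vector<4x2xf32>
}
```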

The case {no-broadcast, yes-transpose} was previously caught by this
pattern, was untested, and arguably could also generate incorrect IR;
this case no longer applies.
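
For reference, a hypothetical IR shape for {no-broadcast, yes-transpose} (not taken from the test suite): the read only permutes the written dims, and the pattern no longer fires on it.

```mlir
func.func @transpose_no_broadcast_example(%t : tensor<?x?xf32>,
    %v0 : vector<4x2xf32>) -> vector<2x4xf32> {
  %c0 = arith.constant 0 : index
  %cf0 = arith.constant 0.0 : f32
  %w0 = vector.transfer_write %v0, %t[%c0, %c0] {in_bounds = [true, true]}
      : vector<4x2xf32>, tensor<?x?xf32>
  // The read transposes d0 and d1 but has no broadcast dim, so the pattern
  // bails instead of materializing a vector.transpose.
  %0 = vector.transfer_read %w0[%c0, %c0], %cf0
      {in_bounds = [true, true],
       permutation_map = affine_map<(d0, d1) -> (d1, d0)>}
      : tensor<?x?xf32>, vector<2x4xf32>
  return %0 : vector<2x4xf32>
}
```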

The last case {yes-broadcast, yes-transpose} continues to apply but
should arguably be removed in the future, because creating transposes
as part of canonicalization feels dangerous.
There are other patterns that move permutation logic:

- either into the transfer, or
- outside of the transfer
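
As an illustration of the two placements (names and shapes are mine, not from this patch), the same permuted read can be written either with the permutation folded into the transfer or as a plain minor-identity read followed by an explicit `vector.transpose`:

```mlir
func.func @permuted_read_forms(%t : tensor<?x?xf32>, %pad : f32)
    -> (vector<2x4xf32>, vector<2x4xf32>) {
  %c0 = arith.constant 0 : index
  // (a) Permutation carried by the transfer itself.
  %a = vector.transfer_read %t[%c0, %c0], %pad
      {in_bounds = [true, true],
       permutation_map = affine_map<(d0, d1) -> (d1, d0)>}
      : tensor<?x?xf32>, vector<2x4xf32>
  // (b) Minor-identity read followed by an explicit transpose.
  %b0 = vector.transfer_read %t[%c0, %c0], %pad {in_bounds = [true, true]}
      : tensor<?x?xf32>, vector<4x2xf32>
  %b = vector.transpose %b0, [1, 0] : vector<4x2xf32> to vector<2x4xf32>
  return %a, %b : vector<2x4xf32>, vector<2x4xf32>
}
```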

Ideally, this would be target-dependent and not a canonicalization (i.e.,
does your DMA HW allow transpose on the fly or not), but this is beyond
the scope of this PR.
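
To make the remaining {yes-broadcast, yes-transpose} case concrete, a hypothetical example (shapes, names, and the sketched rewrite are illustrative, not from the test suite): the read broadcasts a new trailing dim while keeping the written dim leading, so the fold has to introduce both a `vector.broadcast` and a `vector.transpose`.

```mlir
func.func @broadcast_and_transpose_example(%t : tensor<?xf32>, %v0 : vector<4xf32>)
    -> vector<4x6xf32> {
  %c0 = arith.constant 0 : index
  %cf0 = arith.constant 0.0 : f32
  %w0 = vector.transfer_write %v0, %t[%c0] {in_bounds = [true]}
      : vector<4xf32>, tensor<?xf32>
  // Folding this read is expected to produce roughly:
  //   %b = vector.broadcast %v0 : vector<4xf32> to vector<6x4xf32>
  //   %r = vector.transpose %b, [1, 0] : vector<6x4xf32> to vector<4x6xf32>
  %r = vector.transfer_read %w0[%c0], %cf0
      {in_bounds = [true, true],
       permutation_map = affine_map<(d0) -> (d0, 0)>}
      : tensor<?xf32>, vector<4x6xf32>
  return %r : vector<4x6xf32>
}
```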

Co-authored-by: Nicolas Vasilache <nicolasvasilache@users.noreply.github.com>
Author: Nicolas Vasilache
Date: 2025-06-12 17:11:06 +02:00 (committed by GitHub)
Parent: 62b6940900
Commit: e4de74ba11

2 changed files with 117 additions and 21 deletions


@@ -4668,12 +4668,15 @@ struct TransferReadAfterWriteToBroadcast
LogicalResult matchAndRewrite(TransferReadOp readOp,
PatternRewriter &rewriter) const override {
if (readOp.hasOutOfBoundsDim() ||
!llvm::isa<RankedTensorType>(readOp.getShapedType()))
return failure();
auto defWrite = readOp.getBase().getDefiningOp<vector::TransferWriteOp>();
if (!defWrite)
return failure();
// Bail if we need an alias analysis.
if (!readOp.hasPureTensorSemantics() || !defWrite.hasPureTensorSemantics())
return failure();
// Bail if we need a bounds analysis.
if (readOp.hasOutOfBoundsDim() || defWrite.hasOutOfBoundsDim())
return failure();
// TODO: If the written transfer chunk is a superset of the read transfer
// chunk we could do an extract_strided_slice.
if (readOp.getTransferChunkAccessed() !=
@@ -4684,15 +4687,28 @@ struct TransferReadAfterWriteToBroadcast
if (getUnusedDimsBitVector({readOp.getPermutationMap()}) !=
getUnusedDimsBitVector({defWrite.getPermutationMap()}))
return failure();
if (readOp.getIndices() != defWrite.getIndices() ||
readOp.getMask() != defWrite.getMask())
// This pattern should only catch the broadcast case, the non-broadcast case
// should be done separately to keep application conditions clean and
// separate.
AffineMap readMap = compressUnusedDims(readOp.getPermutationMap());
AffineMap writeMap = compressUnusedDims(defWrite.getPermutationMap());
bool bcast = !readMap.getBroadcastDims().empty() ||
!writeMap.getBroadcastDims().empty();
if (!bcast)
return failure();
// At this point, we know we have a bcast.
// Bail in the masked case (too complex atm and needed to properly account
// for padding).
if (readOp.getMask() || defWrite.getMask())
return failure();
// If indices are not the same a shift may be required, bail.
if (readOp.getIndices() != defWrite.getIndices())
return failure();
Value vec = defWrite.getVector();
// TODO: loop through the chain of transfer_write if we can prove that they
// don't overlap with the transfer_read. This requires improving
// `isDisjointTransferIndices` helper.
AffineMap readMap = compressUnusedDims(readOp.getPermutationMap());
AffineMap writeMap = compressUnusedDims(defWrite.getPermutationMap());
AffineMap map = readMap.compose(writeMap);
if (map.getNumResults() == 0)
return failure();


@@ -408,7 +408,7 @@ func.func @extract_strided_fold_insert(%a: vector<6x4xf32>, %b: vector<8x16xf32>
// -----
// Negative test where the extract is not a subset of the element inserted.
// CHECK-LABEL: extract_strided_fold_negative
// CHECK-LABEL: negative_extract_strided_fold
// CHECK-SAME: (%[[ARG0:.*]]: vector<4x4xf32>, %[[ARG1:.*]]: vector<8x16xf32>
// CHECK: %[[INS:.*]] = vector.insert_strided_slice %[[ARG0]], %[[ARG1]]
// CHECK-SAME: {offsets = [2, 2], strides = [1, 1]}
@@ -417,7 +417,7 @@ func.func @extract_strided_fold_insert(%a: vector<6x4xf32>, %b: vector<8x16xf32>
// CHECK-SAME: {offsets = [2, 2], sizes = [6, 4], strides = [1, 1]}
// CHECK-SAME: : vector<8x16xf32> to vector<6x4xf32>
// CHECK-NEXT: return %[[EXT]] : vector<6x4xf32>
func.func @extract_strided_fold_negative(%a: vector<4x4xf32>, %b: vector<8x16xf32>)
func.func @negative_extract_strided_fold(%a: vector<4x4xf32>, %b: vector<8x16xf32>)
-> (vector<6x4xf32>) {
%0 = vector.insert_strided_slice %a, %b {offsets = [2, 2], strides = [1, 1]}
: vector<4x4xf32> into vector<8x16xf32>
@@ -753,10 +753,10 @@ func.func @fold_extract_broadcast_0dvec_input_scalar_output(%a : vector<f32>,
// -----
// CHECK-LABEL: fold_extract_broadcast_negative
// CHECK-LABEL: negative_fold_extract_broadcast
// CHECK: vector.broadcast %{{.*}} : vector<1x1xf32> to vector<1x1x4xf32>
// CHECK: vector.extract %{{.*}}[0, 0] : vector<4xf32> from vector<1x1x4xf32>
func.func @fold_extract_broadcast_negative(%a : vector<1x1xf32>) -> vector<4xf32> {
func.func @negative_fold_extract_broadcast(%a : vector<1x1xf32>) -> vector<4xf32> {
%b = vector.broadcast %a : vector<1x1xf32> to vector<1x1x4xf32>
%r = vector.extract %b[0, 0] : vector<4xf32> from vector<1x1x4xf32>
return %r : vector<4xf32>
@@ -895,11 +895,11 @@ func.func @fold_extract_shapecast_0d_source(%arg0 : vector<f32>) -> f32 {
// -----
// CHECK-LABEL: fold_extract_shapecast_negative
// CHECK-LABEL: negative_fold_extract_shapecast
// CHECK: %[[V:.*]] = vector.shape_cast %{{.*}} : vector<16xf32> to vector<2x4x2xf32>
// CHECK: %[[R:.*]] = vector.extract %[[V]][1] : vector<4x2xf32> from vector<2x4x2xf32>
// CHECK: return %[[R]] : vector<4x2xf32>
func.func @fold_extract_shapecast_negative(%arg0 : vector<16xf32>) -> vector<4x2xf32> {
func.func @negative_fold_extract_shapecast(%arg0 : vector<16xf32>) -> vector<4x2xf32> {
%0 = vector.shape_cast %arg0 : vector<16xf32> to vector<2x4x2xf32>
%r = vector.extract %0[1] : vector<4x2xf32> from vector<2x4x2xf32>
return %r : vector<4x2xf32>
@@ -1460,11 +1460,11 @@ func.func @store_after_load_tensor(%arg0 : tensor<4x4xf32>) -> tensor<4x4xf32> {
// -----
// CHECK-LABEL: func @store_after_load_tensor_negative
// CHECK-LABEL: func @negative_store_after_load_tensor
// CHECK: vector.transfer_read
// CHECK: vector.transfer_write
// CHECK: return
func.func @store_after_load_tensor_negative(%arg0 : tensor<4x4xf32>) -> tensor<4x4xf32> {
func.func @negative_store_after_load_tensor(%arg0 : tensor<4x4xf32>) -> tensor<4x4xf32> {
%c1 = arith.constant 1 : index
%c0 = arith.constant 0 : index
%cf0 = arith.constant 0.0 : f32
@@ -1499,12 +1499,12 @@ func.func @store_to_load_tensor(%arg0 : tensor<4x4xf32>,
// -----
// CHECK-LABEL: func @store_to_load_negative_tensor
// CHECK-LABEL: func @negative_store_to_load_tensor
// CHECK: vector.transfer_write
// CHECK: vector.transfer_write
// CHECK: %[[V:.*]] = vector.transfer_read
// CHECK: return %[[V]] : vector<1x4xf32>
func.func @store_to_load_negative_tensor(%arg0 : tensor<4x4xf32>,
func.func @negative_store_to_load_tensor(%arg0 : tensor<4x4xf32>,
%v0 : vector<1x4xf32>, %v1 : vector<1x4xf32>, %i : index) -> vector<1x4xf32> {
%c1 = arith.constant 1 : index
%c2 = arith.constant 2 : index
@@ -1540,6 +1540,86 @@ func.func @store_to_load_tensor_broadcast(%arg0 : tensor<4x4xf32>,
// -----
// CHECK-LABEL: func @negative_store_to_load_tensor_memref
// CHECK-NOT: vector.broadcast
// CHECK-NOT: vector.transpose
// CHECK: vector.transfer_write
// CHECK: vector.transfer_read
func.func @negative_store_to_load_tensor_memref(
%arg0 : tensor<?x?xf32>,
%arg1 : memref<?x?xf32>,
%v0 : vector<4x2xf32>
) -> vector<4x2xf32>
{
%c0 = arith.constant 0 : index
%cf0 = arith.constant 0.0 : f32
vector.transfer_write %v0, %arg1[%c0, %c0] {in_bounds = [true, true]} :
vector<4x2xf32>, memref<?x?xf32>
%0 = vector.transfer_read %arg0[%c0, %c0], %cf0 {in_bounds = [true, true]} :
tensor<?x?xf32>, vector<4x2xf32>
return %0 : vector<4x2xf32>
}
// -----
// CHECK-LABEL: func @negative_store_to_load_tensor_no_actual_broadcast
// CHECK-NOT: vector.broadcast
// CHECK-NOT: vector.transpose
// CHECK: vector.transfer_write
// CHECK: vector.transfer_read
func.func @negative_store_to_load_tensor_no_actual_broadcast(%arg0 : tensor<?x?xf32>,
%v0 : vector<4x2xf32>) -> vector<4x2xf32> {
%c0 = arith.constant 0 : index
%cf0 = arith.constant 0.0 : f32
%w0 = vector.transfer_write %v0, %arg0[%c0, %c0] :
vector<4x2xf32>, tensor<?x?xf32>
%0 = vector.transfer_read %w0[%c0, %c0], %cf0 {in_bounds = [true, true]} :
tensor<?x?xf32>, vector<4x2xf32>
return %0 : vector<4x2xf32>
}
// -----
// CHECK-LABEL: func @negative_store_to_load_tensor_broadcast_out_of_bounds
// CHECK-NOT: vector.broadcast
// CHECK-NOT: vector.transpose
// CHECK: vector.transfer_write
// CHECK: vector.transfer_read
func.func @negative_store_to_load_tensor_broadcast_out_of_bounds(%arg0 : tensor<?x?xf32>,
%v0 : vector<4x2xf32>) -> vector<4x2x6xf32> {
%c0 = arith.constant 0 : index
%cf0 = arith.constant 0.0 : f32
%w0 = vector.transfer_write %v0, %arg0[%c0, %c0] :
vector<4x2xf32>, tensor<?x?xf32>
%0 = vector.transfer_read %w0[%c0, %c0], %cf0 {in_bounds = [true, true, true],
permutation_map = affine_map<(d0, d1) -> (d0, d1, 0)>} :
tensor<?x?xf32>, vector<4x2x6xf32>
return %0 : vector<4x2x6xf32>
}
// -----
// CHECK-LABEL: func @negative_store_to_load_tensor_broadcast_masked
// CHECK-NOT: vector.broadcast
// CHECK-NOT: vector.transpose
// CHECK: vector.transfer_write
// CHECK: vector.transfer_read
func.func @negative_store_to_load_tensor_broadcast_masked(
%arg0 : tensor<?x?xf32>, %v0 : vector<4x2xf32>, %mask : vector<4x2xi1>)
-> vector<4x2x6xf32>
{
%c0 = arith.constant 0 : index
%cf0 = arith.constant 0.0 : f32
%w0 = vector.transfer_write %v0, %arg0[%c0, %c0], %mask {in_bounds = [true, true]} :
vector<4x2xf32>, tensor<?x?xf32>
%0 = vector.transfer_read %w0[%c0, %c0], %cf0 {in_bounds = [true, true, true],
permutation_map = affine_map<(d0, d1) -> (d0, d1, 0)>} :
tensor<?x?xf32>, vector<4x2x6xf32>
return %0 : vector<4x2x6xf32>
}
// -----
// CHECK-LABEL: func @store_to_load_tensor_broadcast_scalable
// CHECK-SAME: (%[[ARG:.*]]: tensor<?xf32>, %[[V0:.*]]: vector<[4]xf32>)
// CHECK: %[[B:.*]] = vector.broadcast %[[V0]] : vector<[4]xf32> to vector<6x[4]xf32>
@@ -1604,7 +1684,7 @@ func.func @dead_store_tensor(%arg0 : tensor<4x4xf32>,
// -----
// CHECK-LABEL: func @dead_store_tensor_negative
// CHECK-LABEL: func @negative_dead_store_tensor
// CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index
// CHECK-DAG: %[[C1:.*]] = arith.constant 1 : index
// CHECK: vector.transfer_write
@@ -1612,7 +1692,7 @@ func.func @dead_store_tensor(%arg0 : tensor<4x4xf32>,
// CHECK: vector.transfer_read
// CHECK: %[[VTW:.*]] = vector.transfer_write {{.*}}, {{.*}}[%[[C1]], %[[C0]]]
// CHECK: return %[[VTW]] : tensor<4x4xf32>
func.func @dead_store_tensor_negative(%arg0 : tensor<4x4xf32>,
func.func @negative_dead_store_tensor(%arg0 : tensor<4x4xf32>,
%v0 : vector<1x4xf32>, %v1 : vector<1x4xf32>, %i : index) -> tensor<4x4xf32> {
%c1 = arith.constant 1 : index
%c2 = arith.constant 2 : index
@@ -2063,10 +2143,10 @@ func.func @extract_insert_rank_reduce(%a: vector<4xf32>, %b: vector<8x16xf32>)
// -----
// CHECK-LABEL: extract_insert_negative
// CHECK-LABEL: negative_extract_insert
// CHECK: vector.insert_strided_slice
// CHECK: vector.extract
func.func @extract_insert_negative(%a: vector<2x15xf32>, %b: vector<12x8x16xf32>)
func.func @negative_extract_insert(%a: vector<2x15xf32>, %b: vector<12x8x16xf32>)
-> vector<16xf32> {
%0 = vector.insert_strided_slice %a, %b {offsets = [4, 2, 0], strides = [1, 1]}
: vector<2x15xf32> into vector<12x8x16xf32>