[mlir][linalg] Prevent hoisting of transfer pairs in the presence of aliases (#145235)

This patch adds additional checks to the hoisting logic to prevent hoisting of
`vector.transfer_read` / `vector.transfer_write` pairs when the underlying
memref has users that introduce aliases via operations implementing
`ViewLikeOpInterface`.

Note: This may conservatively block some valid hoisting opportunities and could
affect performance. However, as demonstrated by the included tests, the current
logic is too permissive and can lead to incorrect transformations.

If this change prevents hoisting in cases that are provably safe, please share
a minimal repro - I'm happy to explore ways to relax the check.

Special treatment is given to `memref.assume_alignment`, mainly to accommodate
recent updates in:

* https://github.com/llvm/llvm-project/pull/139521

Note that such special casing does not scale and should generally be avoided.
The current hoisting logic lacks robust alias analysis. While better support
would require more work, the broader semantics of `memref.assume_alignment`
remain somewhat unclear. It's possible this op may eventually be replaced with
the "alignment" attribute added in:

* https://github.com/llvm/llvm-project/pull/144344
This commit is contained in:
Andrzej Warzyński
2025-06-27 13:18:15 +01:00
committed by GitHub
parent 7e2e030121
commit 541f33e075
2 changed files with 266 additions and 8 deletions

View File

@@ -303,23 +303,51 @@ void mlir::linalg::hoistRedundantVectorTransfers(Operation *root,
// 1. indices, vector type and permutation map are the same (i.e., the
// transfer_read/transfer_write ops are matching),
// 2. source operands for transfer.{read|write} do not originate from
// Ops implementing ViewLikeOpInterface.
// nor have users that are Ops implementing ViewLikeOpInterface.
// 3. no other operations in the loop access the same memref except
// for transfer_read/transfer_write accessing statically disjoint
// slices.
// Check 1.
if (transferRead.getIndices() != transferWrite.getIndices() ||
transferRead.getVectorType() != transferWrite.getVectorType() ||
transferRead.getPermutationMap() != transferWrite.getPermutationMap())
return WalkResult::advance();
auto *source = transferRead.getBase().getDefiningOp();
if (source && isa_and_nonnull<ViewLikeOpInterface>(source))
return WalkResult::advance();
source = transferWrite.getBase().getDefiningOp();
if (source && isa_and_nonnull<ViewLikeOpInterface>(source))
// Check 2. Note, since both xfer Ops share the source, we only need to
// look at one of them.
auto base = transferRead.getBase();
auto *source = base.getDefiningOp();
if (source) {
// NOTE: We treat `memref.assume_alignment` as a special case.
//
// The idea is that it is safe to look past AssumeAlignmemtOp (i.e.
// MemRef _before_ alignment) iff:
// 1. It has exactly two uses (these have to be the xfer Ops
// being looked at).
// 2. The original MemRef has only one use (i.e.
// AssumeAlignmentOp).
//
// Relaxing these conditions will most likely require proper alias
// analysis.
if (auto assume = dyn_cast<memref::AssumeAlignmentOp>(source)) {
Value memPreAlignment = assume.getMemref();
auto numInLoopUses =
llvm::count_if(base.getUses(), [&loop](OpOperand &use) {
return loop->isAncestor(use.getOwner());
});
if (numInLoopUses && memPreAlignment.hasOneUse())
source = memPreAlignment.getDefiningOp();
}
if (isa_and_nonnull<ViewLikeOpInterface>(source))
return WalkResult::advance();
}
if (llvm::any_of(base.getUsers(), llvm::IsaPred<ViewLikeOpInterface>))
return WalkResult::advance();
// Check 3.
// TODO: may want to memoize this information for performance but it
// likely gets invalidated often.
DominanceInfo dom(loop);
@@ -358,7 +386,8 @@ void mlir::linalg::hoistRedundantVectorTransfers(Operation *root,
// Hoist write after.
transferWrite->moveAfter(loop);
// Rewrite `loop` with new yields by cloning and erase the original loop.
// Rewrite `loop` with new yields by cloning and erase the original
// loop.
IRRewriter rewriter(transferRead.getContext());
NewYieldValuesFn yieldFn = [&](OpBuilder &b, Location loc,
ArrayRef<BlockArgument> newBBArgs) {

View File

@@ -1,5 +1,234 @@
// RUN: mlir-opt -transform-interpreter -canonicalize --split-input-file --allow-unregistered-dialect %s | FileCheck %s
///----------------------------------------------------------------------------------------
/// Tests for vector.transfer_read + vector.transfer_write pairs
///
/// * Nested inside a single loop
// * Indices are constant
///----------------------------------------------------------------------------------------
// The most basic example - hoisting is safe.
// CHECK-LABEL: func.func @hoist_basic_vector_xfer_pair(
// CHECK-SAME: %[[MEM:[a-zA-Z0-9]+]]: memref<?x?xf32>,
// CHECK-SAME: %[[LB:[a-zA-Z0-9]+]]: index,
// CHECK-SAME: %[[UB:[a-zA-Z0-9]+]]: index,
// CHECK-SAME: %[[STEP:[a-zA-Z0-9]+]]: index) {
func.func @hoist_basic_vector_xfer_pair(
%mem: memref<?x?xf32>, %lb : index, %ub : index, %step: index) {
%c0 = arith.constant 0 : index
%pad = arith.constant 0.0 : f32
// CHECK: %[[C0:.*]] = arith.constant 0 : index
// CHECK: %[[PAD:.*]] = arith.constant 0.000000e+00 : f32
// CHECK: %[[READ:.*]] = vector.transfer_read %[[MEM]][%[[C0]], %[[C0]]], %[[PAD]] : memref<?x?xf32>, vector<1xf32>
// CHECK: %[[SCF:.*]] = scf.for %[[I:.*]] = %[[LB]] to %[[UB]] step %[[STEP]] iter_args(%[[INIT:.*]] = %[[READ]]) -> (vector<1xf32>) {
// CHECK: %[[VAL_6:.*]] = "val_use"(%[[INIT]]) : (vector<1xf32>) -> vector<1xf32>
// CHECK: scf.yield %[[VAL_6]] : vector<1xf32>
// CHECK: }
// CHECK: vector.transfer_write %[[SCF]], %[[MEM]][%[[C0]], %[[C0]]] : vector<1xf32>, memref<?x?xf32>
scf.for %i = %lb to %ub step %step {
%r0 = vector.transfer_read %mem[%c0, %c0], %pad: memref<?x?xf32>, vector<1xf32>
%u0 = "val_use"(%r0) : (vector<1xf32>) -> vector<1xf32>
vector.transfer_write %u0, %mem[%c0, %c0] : vector<1xf32>, memref<?x?xf32>
}
return
}
module attributes {transform.with_named_sequence} {
transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
%0 = transform.structured.match ops{["func.func"]} in %arg1
: (!transform.any_op) -> !transform.any_op
transform.structured.hoist_redundant_vector_transfers %0
: (!transform.any_op) -> !transform.any_op
transform.yield
}
}
// -----
// Similar as the example above, but hoisting is no longer safe. That's due to
// an extra xfer_write inside the loop.
// CHECK-LABEL: func.func @negative_hoist_basic_vector_xfer_pair_extra_write(
// CHECK-SAME: %[[MEM:[a-zA-Z0-9]+]]: memref<?x?xf32>,
// CHECK-SAME: %[[LB:[a-zA-Z0-9]+]]: index,
// CHECK-SAME: %[[UB:[a-zA-Z0-9]+]]: index,
// CHECK-SAME: %[[STEP:[a-zA-Z0-9]+]]: index,
// CHECK-SAME: %[[IN:[a-zA-Z0-9]+]]: vector<1xf32>) {
func.func @negative_hoist_basic_vector_xfer_pair_extra_write(
%mem: memref<?x?xf32>, %lb : index, %ub : index, %step: index, %in: vector<1xf32>) {
%c0 = arith.constant 0 : index
%pad = arith.constant 0.0 : f32
// CHECK: %[[C0:.*]] = arith.constant 0 : index
// CHECK: %[[PAD:.*]] = arith.constant 0.000000e+00 : f32
// CHECK: scf.for %[[I:.*]] = %[[LB]] to %[[UB]] step %[[STEP]] {
// CHECK: vector.transfer_write %[[IN]], %[[MEM]][%[[C0]], %[[C0]]] : vector<1xf32>, memref<?x?xf32>
// CHECK: %[[READ:.*]] = vector.transfer_read %[[MEM]][%[[C0]], %[[C0]]], %[[PAD]] : memref<?x?xf32>, vector<1xf32>
// CHECK: %[[USE:.*]] = "val_use"(%[[READ]]) : (vector<1xf32>) -> vector<1xf32>
// CHECK: vector.transfer_write %[[USE]], %[[MEM]][%[[C0]], %[[C0]]] : vector<1xf32>, memref<?x?xf32>
// CHECK: }
scf.for %i = %lb to %ub step %step {
vector.transfer_write %in, %mem[%c0, %c0] : vector<1xf32>, memref<?x?xf32>
%r0 = vector.transfer_read %mem[%c0, %c0], %pad: memref<?x?xf32>, vector<1xf32>
%u0 = "val_use"(%r0) : (vector<1xf32>) -> vector<1xf32>
vector.transfer_write %u0, %mem[%c0, %c0] : vector<1xf32>, memref<?x?xf32>
}
return
}
module attributes {transform.with_named_sequence} {
transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
%0 = transform.structured.match ops{["func.func"]} in %arg1
: (!transform.any_op) -> !transform.any_op
transform.structured.hoist_redundant_vector_transfers %0
: (!transform.any_op) -> !transform.any_op
transform.yield
}
}
// -----
// Similar as the example above, but hoisting is no longer safe. That's due to
// an extra xfer_write into _an alias_ of the %mem Op that is used by the
// original xfer pair.
// CHECK-LABEL: func.func @negative_hoist_basic_vector_xfer_pair_extra_write_into_alias(
// CHECK-SAME: %[[MEM:[a-zA-Z0-9]+]]: memref<?x?xf32>,
// CHECK-SAME: %[[LB:[a-zA-Z0-9]+]]: index,
// CHECK-SAME: %[[UB:[a-zA-Z0-9]+]]: index,
// CHECK-SAME: %[[STEP:[a-zA-Z0-9]+]]: index,
// CHECK-SAME: %[[IN:[a-zA-Z0-9]+]]: vector<1xf32>) {
func.func @negative_hoist_basic_vector_xfer_pair_extra_write_into_alias(
%mem: memref<?x?xf32>, %lb : index, %ub : index, %step: index, %in: vector<1xf32>) {
%c0 = arith.constant 0 : index
%pad = arith.constant 0.0 : f32
// CHECK: %[[C0:.*]] = arith.constant 0 : index
// CHECK: %[[PAD:.*]] = arith.constant 0.000000e+00 : f32
// CHECK: %[[SV:.*]] = memref.subview %[[MEM]][0, 0] [1, 1] [1, 1] : memref<?x?xf32> to memref<1x1xf32, strided<[?, 1]>>
// CHECK: scf.for %[[I:.*]] = %[[LB]] to %[[UB]] step %[[STEP]] {
// CHECK: vector.transfer_write %[[IN]], %[[SV]][%[[C0]], %[[C0]]] {{.*}} : vector<1xf32>, memref<1x1xf32, strided<[?, 1]>>
// CHECK: %[[READ:.*]] = vector.transfer_read %[[MEM]][%[[C0]], %[[C0]]], %[[PAD]] : memref<?x?xf32>, vector<1xf32>
// CHECK: %[[USE:.*]] = "val_use"(%[[READ]]) : (vector<1xf32>) -> vector<1xf32>
// CHECK: vector.transfer_write %[[USE]], %[[MEM]][%[[C0]], %[[C0]]] : vector<1xf32>, memref<?x?xf32>
// CHECK: }
%sv = memref.subview %mem[0, 0][1, 1][1, 1] : memref<?x?xf32> to memref<1x1xf32, strided<[?, 1]>>
scf.for %i = %lb to %ub step %step {
vector.transfer_write %in, %sv[%c0, %c0] : vector<1xf32>, memref<1x1xf32, strided<[?, 1]>>
%r0 = vector.transfer_read %mem[%c0, %c0], %pad: memref<?x?xf32>, vector<1xf32>
%u0 = "val_use"(%r0) : (vector<1xf32>) -> vector<1xf32>
vector.transfer_write %u0, %mem[%c0, %c0] : vector<1xf32>, memref<?x?xf32>
}
return
}
module attributes {transform.with_named_sequence} {
transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
%0 = transform.structured.match ops{["func.func"]} in %arg1
: (!transform.any_op) -> !transform.any_op
transform.structured.hoist_redundant_vector_transfers %0
: (!transform.any_op) -> !transform.any_op
transform.yield
}
}
// -----
// Similar as the example above, but the memory access is done via
// memref.assume_alignment. Hoisting is safe as the only users of the
// "allignment" Op are the xfer Ops within the loop that we want to hoist.
// CHECK-LABEL: func.func @hoist_basic_vector_xfer_pair_with_assume_align(
// CHECK-SAME: %[[MEM:[a-zA-Z0-9]+]]: memref<?x?xf32>,
// CHECK-SAME: %[[LB:[a-zA-Z0-9]+]]: index,
// CHECK-SAME: %[[UB:[a-zA-Z0-9]+]]: index,
// CHECK-SAME: %[[STEP:[a-zA-Z0-9]+]]: index,
// CHECK-SAME: %[[IN:[a-zA-Z0-9]+]]: vector<1xf32>) {
func.func @hoist_basic_vector_xfer_pair_with_assume_align(
%mem: memref<?x?xf32>, %lb : index, %ub : index, %step: index, %in: vector<1xf32>) {
%c0 = arith.constant 0 : index
%pad = arith.constant 0.0 : f32
// CHECK: %[[C0:.*]] = arith.constant 0 : index
// CHECK: %[[PAD:.*]] = arith.constant 0.000000e+00 : f32
// CHECK: %[[AA:.*]] = memref.assume_alignment %[[MEM]], 4 : memref<?x?xf32>
// CHECK: %[[READ:.*]] = vector.transfer_read %[[AA]][%[[C0]], %[[C0]]], %[[PAD]] : memref<?x?xf32>, vector<1xf32>
// CHECK: %[[SCF:.*]] = scf.for %[[I:.*]] = %[[LB]] to %[[UB]] step %[[STEP]] iter_args(%[[INIT:.*]] = %[[READ]]) -> (vector<1xf32>) {
// CHECK: %[[USE:.*]] = "val_use"(%[[INIT]]) : (vector<1xf32>) -> vector<1xf32>
// CHECK: }
// CHECK: vector.transfer_write %[[SCF]], %[[AA]][%[[C0]], %[[C0]]] : vector<1xf32>, memref<?x?xf32>
%aa = memref.assume_alignment %mem, 4 : memref<?x?xf32>
scf.for %i = %lb to %ub step %step {
%r0 = vector.transfer_read %aa[%c0, %c0], %pad: memref<?x?xf32>, vector<1xf32>
%u0 = "val_use"(%r0) : (vector<1xf32>) -> vector<1xf32>
vector.transfer_write %u0, %aa[%c0, %c0] : vector<1xf32>, memref<?x?xf32>
}
return
}
module attributes {transform.with_named_sequence} {
transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
%0 = transform.structured.match ops{["func.func"]} in %arg1
: (!transform.any_op) -> !transform.any_op
transform.structured.hoist_redundant_vector_transfers %0
: (!transform.any_op) -> !transform.any_op
transform.yield
}
}
// -----
// Similar as the example above, but hoisting is not safe due to extra memory
// access inside the loop via the original memref.
// CHECK-LABEL: func.func @negative_hoist_basic_vector_xfer_pair_with_assume_align(
// CHECK-SAME: %[[MEM:[a-zA-Z0-9]+]]: memref<?x?xf32>,
// CHECK-SAME: %[[LB:[a-zA-Z0-9]+]]: index,
// CHECK-SAME: %[[UB:[a-zA-Z0-9]+]]: index,
// CHECK-SAME: %[[STEP:[a-zA-Z0-9]+]]: index,
// CHECK-SAME: %[[IN:[a-zA-Z0-9]+]]: vector<1xf32>) {
func.func @negative_hoist_basic_vector_xfer_pair_with_assume_align(
%mem: memref<?x?xf32>, %lb : index, %ub : index, %step: index, %in: vector<1xf32>) {
%c0 = arith.constant 0 : index
%pad = arith.constant 0.0 : f32
// CHECK: %[[C0:.*]] = arith.constant 0 : index
// CHECK: %[[PAD:.*]] = arith.constant 0.000000e+00 : f32
// CHECK: %[[AA:.*]] = memref.assume_alignment %[[MEM]], 4 : memref<?x?xf32>
// CHECK: scf.for %[[I:.*]] = %[[LB]] to %[[UB]] step %[[STEP]] {
// CHECK: %[[READ:.*]] = vector.transfer_read %[[AA]][%[[C0]], %[[C0]]], %[[PAD]] : memref<?x?xf32>, vector<1xf32>
// CHECK: "mem_use"(%[[MEM]])
// CHECK: vector.transfer_write %[[READ]], %[[AA]][%[[C0]], %[[C0]]] : vector<1xf32>, memref<?x?xf32>
// CHECK: }
%aa = memref.assume_alignment %mem, 4 : memref<?x?xf32>
scf.for %i = %lb to %ub step %step {
%r0 = vector.transfer_read %aa[%c0, %c0], %pad: memref<?x?xf32>, vector<1xf32>
"mem_use"(%mem) : (memref<?x?xf32>) -> ()
vector.transfer_write %r0, %aa[%c0, %c0] : vector<1xf32>, memref<?x?xf32>
}
return
}
module attributes {transform.with_named_sequence} {
transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
%0 = transform.structured.match ops{["func.func"]} in %arg1
: (!transform.any_op) -> !transform.any_op
transform.structured.hoist_redundant_vector_transfers %0
: (!transform.any_op) -> !transform.any_op
transform.yield
}
}
// -----
///----------------------------------------------------------------------------------------
/// Tests for vector.transfer_read + vector.transfer_write pairs
///