From 541f33e0751d60b33e75efe0cd436396f27b91ca Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Andrzej=20Warzy=C5=84ski?= Date: Fri, 27 Jun 2025 13:18:15 +0100 Subject: [PATCH] [mlir][linalg] Prevent hoisting of transfer pairs in the presence of aliases (#145235) This patch adds additional checks to the hoisting logic to prevent hoisting of `vector.transfer_read` / `vector.transfer_write` pairs when the underlying memref has users that introduce aliases via operations implementing `ViewLikeOpInterface`. Note: This may conservatively block some valid hoisting opportunities and could affect performance. However, as demonstrated by the included tests, the current logic is too permissive and can lead to incorrect transformations. If this change prevents hoisting in cases that are provably safe, please share a minimal repro - I'm happy to explore ways to relax the check. Special treatment is given to `memref.assume_alignment`, mainly to accommodate recent updates in: * https://github.com/llvm/llvm-project/pull/139521 Note that such special casing does not scale and should generally be avoided. The current hoisting logic lacks robust alias analysis. While better support would require more work, the broader semantics of `memref.assume_alignment` remain somewhat unclear. It's possible this op may eventually be replaced with the "alignment" attribute added in: * https://github.com/llvm/llvm-project/pull/144344 --- .../Dialect/Linalg/Transforms/Hoisting.cpp | 45 +++- mlir/test/Dialect/Linalg/hoisting.mlir | 229 ++++++++++++++++++ 2 files changed, 266 insertions(+), 8 deletions(-) diff --git a/mlir/lib/Dialect/Linalg/Transforms/Hoisting.cpp b/mlir/lib/Dialect/Linalg/Transforms/Hoisting.cpp index 707b63ff9335..d833e04d6026 100644 --- a/mlir/lib/Dialect/Linalg/Transforms/Hoisting.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/Hoisting.cpp @@ -303,23 +303,51 @@ void mlir::linalg::hoistRedundantVectorTransfers(Operation *root, // 1. indices, vector type and permutation map are the same (i.e., the // transfer_read/transfer_write ops are matching), // 2. source operands for transfer.{read|write} do not originate from - // Ops implementing ViewLikeOpInterface. + // nor have users that are Ops implementing ViewLikeOpInterface. // 3. no other operations in the loop access the same memref except // for transfer_read/transfer_write accessing statically disjoint // slices. + + // Check 1. if (transferRead.getIndices() != transferWrite.getIndices() || transferRead.getVectorType() != transferWrite.getVectorType() || transferRead.getPermutationMap() != transferWrite.getPermutationMap()) return WalkResult::advance(); - auto *source = transferRead.getBase().getDefiningOp(); - if (source && isa_and_nonnull(source)) - return WalkResult::advance(); - - source = transferWrite.getBase().getDefiningOp(); - if (source && isa_and_nonnull(source)) + // Check 2. Note, since both xfer Ops share the source, we only need to + // look at one of them. + auto base = transferRead.getBase(); + auto *source = base.getDefiningOp(); + if (source) { + // NOTE: We treat `memref.assume_alignment` as a special case. + // + // The idea is that it is safe to look past AssumeAlignmemtOp (i.e. + // MemRef _before_ alignment) iff: + // 1. It has exactly two uses (these have to be the xfer Ops + // being looked at). + // 2. The original MemRef has only one use (i.e. + // AssumeAlignmentOp). + // + // Relaxing these conditions will most likely require proper alias + // analysis. + if (auto assume = dyn_cast(source)) { + Value memPreAlignment = assume.getMemref(); + auto numInLoopUses = + llvm::count_if(base.getUses(), [&loop](OpOperand &use) { + return loop->isAncestor(use.getOwner()); + }); + + if (numInLoopUses && memPreAlignment.hasOneUse()) + source = memPreAlignment.getDefiningOp(); + } + if (isa_and_nonnull(source)) + return WalkResult::advance(); + } + + if (llvm::any_of(base.getUsers(), llvm::IsaPred)) return WalkResult::advance(); + // Check 3. // TODO: may want to memoize this information for performance but it // likely gets invalidated often. DominanceInfo dom(loop); @@ -358,7 +386,8 @@ void mlir::linalg::hoistRedundantVectorTransfers(Operation *root, // Hoist write after. transferWrite->moveAfter(loop); - // Rewrite `loop` with new yields by cloning and erase the original loop. + // Rewrite `loop` with new yields by cloning and erase the original + // loop. IRRewriter rewriter(transferRead.getContext()); NewYieldValuesFn yieldFn = [&](OpBuilder &b, Location loc, ArrayRef newBBArgs) { diff --git a/mlir/test/Dialect/Linalg/hoisting.mlir b/mlir/test/Dialect/Linalg/hoisting.mlir index 8be4e1b79c52..aa0b97a4787f 100644 --- a/mlir/test/Dialect/Linalg/hoisting.mlir +++ b/mlir/test/Dialect/Linalg/hoisting.mlir @@ -1,5 +1,234 @@ // RUN: mlir-opt -transform-interpreter -canonicalize --split-input-file --allow-unregistered-dialect %s | FileCheck %s +///---------------------------------------------------------------------------------------- +/// Tests for vector.transfer_read + vector.transfer_write pairs +/// +/// * Nested inside a single loop +// * Indices are constant +///---------------------------------------------------------------------------------------- + +// The most basic example - hoisting is safe. + +// CHECK-LABEL: func.func @hoist_basic_vector_xfer_pair( +// CHECK-SAME: %[[MEM:[a-zA-Z0-9]+]]: memref, +// CHECK-SAME: %[[LB:[a-zA-Z0-9]+]]: index, +// CHECK-SAME: %[[UB:[a-zA-Z0-9]+]]: index, +// CHECK-SAME: %[[STEP:[a-zA-Z0-9]+]]: index) { +func.func @hoist_basic_vector_xfer_pair( + %mem: memref, %lb : index, %ub : index, %step: index) { + %c0 = arith.constant 0 : index + %pad = arith.constant 0.0 : f32 + +// CHECK: %[[C0:.*]] = arith.constant 0 : index +// CHECK: %[[PAD:.*]] = arith.constant 0.000000e+00 : f32 +// CHECK: %[[READ:.*]] = vector.transfer_read %[[MEM]][%[[C0]], %[[C0]]], %[[PAD]] : memref, vector<1xf32> +// CHECK: %[[SCF:.*]] = scf.for %[[I:.*]] = %[[LB]] to %[[UB]] step %[[STEP]] iter_args(%[[INIT:.*]] = %[[READ]]) -> (vector<1xf32>) { +// CHECK: %[[VAL_6:.*]] = "val_use"(%[[INIT]]) : (vector<1xf32>) -> vector<1xf32> +// CHECK: scf.yield %[[VAL_6]] : vector<1xf32> +// CHECK: } +// CHECK: vector.transfer_write %[[SCF]], %[[MEM]][%[[C0]], %[[C0]]] : vector<1xf32>, memref + scf.for %i = %lb to %ub step %step { + %r0 = vector.transfer_read %mem[%c0, %c0], %pad: memref, vector<1xf32> + %u0 = "val_use"(%r0) : (vector<1xf32>) -> vector<1xf32> + vector.transfer_write %u0, %mem[%c0, %c0] : vector<1xf32>, memref + } + return +} + +module attributes {transform.with_named_sequence} { + transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) { + %0 = transform.structured.match ops{["func.func"]} in %arg1 + : (!transform.any_op) -> !transform.any_op + transform.structured.hoist_redundant_vector_transfers %0 + : (!transform.any_op) -> !transform.any_op + transform.yield + } +} + +// ----- + +// Similar as the example above, but hoisting is no longer safe. That's due to +// an extra xfer_write inside the loop. + +// CHECK-LABEL: func.func @negative_hoist_basic_vector_xfer_pair_extra_write( +// CHECK-SAME: %[[MEM:[a-zA-Z0-9]+]]: memref, +// CHECK-SAME: %[[LB:[a-zA-Z0-9]+]]: index, +// CHECK-SAME: %[[UB:[a-zA-Z0-9]+]]: index, +// CHECK-SAME: %[[STEP:[a-zA-Z0-9]+]]: index, +// CHECK-SAME: %[[IN:[a-zA-Z0-9]+]]: vector<1xf32>) { +func.func @negative_hoist_basic_vector_xfer_pair_extra_write( + %mem: memref, %lb : index, %ub : index, %step: index, %in: vector<1xf32>) { + %c0 = arith.constant 0 : index + %pad = arith.constant 0.0 : f32 + +// CHECK: %[[C0:.*]] = arith.constant 0 : index +// CHECK: %[[PAD:.*]] = arith.constant 0.000000e+00 : f32 +// CHECK: scf.for %[[I:.*]] = %[[LB]] to %[[UB]] step %[[STEP]] { +// CHECK: vector.transfer_write %[[IN]], %[[MEM]][%[[C0]], %[[C0]]] : vector<1xf32>, memref +// CHECK: %[[READ:.*]] = vector.transfer_read %[[MEM]][%[[C0]], %[[C0]]], %[[PAD]] : memref, vector<1xf32> +// CHECK: %[[USE:.*]] = "val_use"(%[[READ]]) : (vector<1xf32>) -> vector<1xf32> +// CHECK: vector.transfer_write %[[USE]], %[[MEM]][%[[C0]], %[[C0]]] : vector<1xf32>, memref +// CHECK: } + + scf.for %i = %lb to %ub step %step { + vector.transfer_write %in, %mem[%c0, %c0] : vector<1xf32>, memref + + %r0 = vector.transfer_read %mem[%c0, %c0], %pad: memref, vector<1xf32> + %u0 = "val_use"(%r0) : (vector<1xf32>) -> vector<1xf32> + vector.transfer_write %u0, %mem[%c0, %c0] : vector<1xf32>, memref + } + return +} + +module attributes {transform.with_named_sequence} { + transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) { + %0 = transform.structured.match ops{["func.func"]} in %arg1 + : (!transform.any_op) -> !transform.any_op + transform.structured.hoist_redundant_vector_transfers %0 + : (!transform.any_op) -> !transform.any_op + transform.yield + } +} + +// ----- + +// Similar as the example above, but hoisting is no longer safe. That's due to +// an extra xfer_write into _an alias_ of the %mem Op that is used by the +// original xfer pair. + +// CHECK-LABEL: func.func @negative_hoist_basic_vector_xfer_pair_extra_write_into_alias( +// CHECK-SAME: %[[MEM:[a-zA-Z0-9]+]]: memref, +// CHECK-SAME: %[[LB:[a-zA-Z0-9]+]]: index, +// CHECK-SAME: %[[UB:[a-zA-Z0-9]+]]: index, +// CHECK-SAME: %[[STEP:[a-zA-Z0-9]+]]: index, +// CHECK-SAME: %[[IN:[a-zA-Z0-9]+]]: vector<1xf32>) { +func.func @negative_hoist_basic_vector_xfer_pair_extra_write_into_alias( + %mem: memref, %lb : index, %ub : index, %step: index, %in: vector<1xf32>) { + %c0 = arith.constant 0 : index + %pad = arith.constant 0.0 : f32 + +// CHECK: %[[C0:.*]] = arith.constant 0 : index +// CHECK: %[[PAD:.*]] = arith.constant 0.000000e+00 : f32 +// CHECK: %[[SV:.*]] = memref.subview %[[MEM]][0, 0] [1, 1] [1, 1] : memref to memref<1x1xf32, strided<[?, 1]>> +// CHECK: scf.for %[[I:.*]] = %[[LB]] to %[[UB]] step %[[STEP]] { +// CHECK: vector.transfer_write %[[IN]], %[[SV]][%[[C0]], %[[C0]]] {{.*}} : vector<1xf32>, memref<1x1xf32, strided<[?, 1]>> +// CHECK: %[[READ:.*]] = vector.transfer_read %[[MEM]][%[[C0]], %[[C0]]], %[[PAD]] : memref, vector<1xf32> +// CHECK: %[[USE:.*]] = "val_use"(%[[READ]]) : (vector<1xf32>) -> vector<1xf32> +// CHECK: vector.transfer_write %[[USE]], %[[MEM]][%[[C0]], %[[C0]]] : vector<1xf32>, memref +// CHECK: } + + %sv = memref.subview %mem[0, 0][1, 1][1, 1] : memref to memref<1x1xf32, strided<[?, 1]>> + scf.for %i = %lb to %ub step %step { + vector.transfer_write %in, %sv[%c0, %c0] : vector<1xf32>, memref<1x1xf32, strided<[?, 1]>> + + %r0 = vector.transfer_read %mem[%c0, %c0], %pad: memref, vector<1xf32> + %u0 = "val_use"(%r0) : (vector<1xf32>) -> vector<1xf32> + vector.transfer_write %u0, %mem[%c0, %c0] : vector<1xf32>, memref + } + return +} + +module attributes {transform.with_named_sequence} { + transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) { + %0 = transform.structured.match ops{["func.func"]} in %arg1 + : (!transform.any_op) -> !transform.any_op + transform.structured.hoist_redundant_vector_transfers %0 + : (!transform.any_op) -> !transform.any_op + transform.yield + } +} + +// ----- + +// Similar as the example above, but the memory access is done via +// memref.assume_alignment. Hoisting is safe as the only users of the +// "allignment" Op are the xfer Ops within the loop that we want to hoist. + +// CHECK-LABEL: func.func @hoist_basic_vector_xfer_pair_with_assume_align( +// CHECK-SAME: %[[MEM:[a-zA-Z0-9]+]]: memref, +// CHECK-SAME: %[[LB:[a-zA-Z0-9]+]]: index, +// CHECK-SAME: %[[UB:[a-zA-Z0-9]+]]: index, +// CHECK-SAME: %[[STEP:[a-zA-Z0-9]+]]: index, +// CHECK-SAME: %[[IN:[a-zA-Z0-9]+]]: vector<1xf32>) { +func.func @hoist_basic_vector_xfer_pair_with_assume_align( + %mem: memref, %lb : index, %ub : index, %step: index, %in: vector<1xf32>) { + %c0 = arith.constant 0 : index + %pad = arith.constant 0.0 : f32 + +// CHECK: %[[C0:.*]] = arith.constant 0 : index +// CHECK: %[[PAD:.*]] = arith.constant 0.000000e+00 : f32 +// CHECK: %[[AA:.*]] = memref.assume_alignment %[[MEM]], 4 : memref +// CHECK: %[[READ:.*]] = vector.transfer_read %[[AA]][%[[C0]], %[[C0]]], %[[PAD]] : memref, vector<1xf32> +// CHECK: %[[SCF:.*]] = scf.for %[[I:.*]] = %[[LB]] to %[[UB]] step %[[STEP]] iter_args(%[[INIT:.*]] = %[[READ]]) -> (vector<1xf32>) { +// CHECK: %[[USE:.*]] = "val_use"(%[[INIT]]) : (vector<1xf32>) -> vector<1xf32> +// CHECK: } +// CHECK: vector.transfer_write %[[SCF]], %[[AA]][%[[C0]], %[[C0]]] : vector<1xf32>, memref + + %aa = memref.assume_alignment %mem, 4 : memref + scf.for %i = %lb to %ub step %step { + %r0 = vector.transfer_read %aa[%c0, %c0], %pad: memref, vector<1xf32> + %u0 = "val_use"(%r0) : (vector<1xf32>) -> vector<1xf32> + vector.transfer_write %u0, %aa[%c0, %c0] : vector<1xf32>, memref + } + return +} + +module attributes {transform.with_named_sequence} { + transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) { + %0 = transform.structured.match ops{["func.func"]} in %arg1 + : (!transform.any_op) -> !transform.any_op + transform.structured.hoist_redundant_vector_transfers %0 + : (!transform.any_op) -> !transform.any_op + transform.yield + } +} + +// ----- + +// Similar as the example above, but hoisting is not safe due to extra memory +// access inside the loop via the original memref. + +// CHECK-LABEL: func.func @negative_hoist_basic_vector_xfer_pair_with_assume_align( +// CHECK-SAME: %[[MEM:[a-zA-Z0-9]+]]: memref, +// CHECK-SAME: %[[LB:[a-zA-Z0-9]+]]: index, +// CHECK-SAME: %[[UB:[a-zA-Z0-9]+]]: index, +// CHECK-SAME: %[[STEP:[a-zA-Z0-9]+]]: index, +// CHECK-SAME: %[[IN:[a-zA-Z0-9]+]]: vector<1xf32>) { +func.func @negative_hoist_basic_vector_xfer_pair_with_assume_align( + %mem: memref, %lb : index, %ub : index, %step: index, %in: vector<1xf32>) { + %c0 = arith.constant 0 : index + %pad = arith.constant 0.0 : f32 + +// CHECK: %[[C0:.*]] = arith.constant 0 : index +// CHECK: %[[PAD:.*]] = arith.constant 0.000000e+00 : f32 +// CHECK: %[[AA:.*]] = memref.assume_alignment %[[MEM]], 4 : memref +// CHECK: scf.for %[[I:.*]] = %[[LB]] to %[[UB]] step %[[STEP]] { +// CHECK: %[[READ:.*]] = vector.transfer_read %[[AA]][%[[C0]], %[[C0]]], %[[PAD]] : memref, vector<1xf32> +// CHECK: "mem_use"(%[[MEM]]) +// CHECK: vector.transfer_write %[[READ]], %[[AA]][%[[C0]], %[[C0]]] : vector<1xf32>, memref +// CHECK: } + + %aa = memref.assume_alignment %mem, 4 : memref + scf.for %i = %lb to %ub step %step { + %r0 = vector.transfer_read %aa[%c0, %c0], %pad: memref, vector<1xf32> + "mem_use"(%mem) : (memref) -> () + vector.transfer_write %r0, %aa[%c0, %c0] : vector<1xf32>, memref + } + return +} + +module attributes {transform.with_named_sequence} { + transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) { + %0 = transform.structured.match ops{["func.func"]} in %arg1 + : (!transform.any_op) -> !transform.any_op + transform.structured.hoist_redundant_vector_transfers %0 + : (!transform.any_op) -> !transform.any_op + transform.yield + } +} + +// ----- + ///---------------------------------------------------------------------------------------- /// Tests for vector.transfer_read + vector.transfer_write pairs ///