[mlir][scf] Improve scf.parallel fusion pass (#75852)

Abort fusion if a memref load may alias a write without being the exact same memref.
Add an alias-check hook to `naivelyFuseParallelOps`, so that users can
customize alias checking.
Use the built-in alias analysis in the `ParallelLoopFusion` pass.
This commit is contained in:
Ivan Butygin
2023-12-19 18:07:46 +03:00
committed by GitHub
parent 9aeb3336fd
commit c0d2ea9d42
3 changed files with 68 additions and 15 deletions

View File

@@ -34,7 +34,10 @@ class ParallelOp;
/// Fuses all adjacent scf.parallel operations with identical bounds and step
/// into one scf.parallel operations. Uses a naive aliasing and dependency
/// analysis.
void naivelyFuseParallelOps(Region &region);
/// Users can additionally customize alias checking with the `mayAlias` hook.
/// `mayAlias` must return false if the two values are guaranteed not to alias.
void naivelyFuseParallelOps(Region &region,
llvm::function_ref<bool(Value, Value)> mayAlias);
/// Rewrite a for loop with bounds/step that potentially do not divide evenly
/// into a for loop where the step divides the iteration space evenly, followed

View File

@@ -12,6 +12,7 @@
#include "mlir/Dialect/SCF/Transforms/Passes.h"
#include "mlir/Analysis/AliasAnalysis.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/SCF/IR/SCF.h"
#include "mlir/Dialect/SCF/Transforms/Transforms.h"
@@ -58,19 +59,27 @@ static bool equalIterationSpaces(ParallelOp firstPloop,
/// loop reads.
static bool haveNoReadsAfterWriteExceptSameIndex(
ParallelOp firstPloop, ParallelOp secondPloop,
const IRMapping &firstToSecondPloopIndices) {
const IRMapping &firstToSecondPloopIndices,
llvm::function_ref<bool(Value, Value)> mayAlias) {
DenseMap<Value, SmallVector<ValueRange, 1>> bufferStores;
SmallVector<Value> bufferStoresVec;
firstPloop.getBody()->walk([&](memref::StoreOp store) {
bufferStores[store.getMemRef()].push_back(store.getIndices());
bufferStoresVec.emplace_back(store.getMemRef());
});
auto walkResult = secondPloop.getBody()->walk([&](memref::LoadOp load) {
Value loadMem = load.getMemRef();
// Stop if the memref is defined in secondPloop body. Careful alias analysis
// is needed.
auto *memrefDef = load.getMemRef().getDefiningOp();
auto *memrefDef = loadMem.getDefiningOp();
if (memrefDef && memrefDef->getBlock() == load->getBlock())
return WalkResult::interrupt();
auto write = bufferStores.find(load.getMemRef());
for (Value store : bufferStoresVec)
if (store != loadMem && mayAlias(store, loadMem))
return WalkResult::interrupt();
auto write = bufferStores.find(loadMem);
if (write == bufferStores.end())
return WalkResult::advance();
@@ -98,35 +107,39 @@ static bool haveNoReadsAfterWriteExceptSameIndex(
/// write patterns.
static LogicalResult
verifyDependencies(ParallelOp firstPloop, ParallelOp secondPloop,
const IRMapping &firstToSecondPloopIndices) {
if (!haveNoReadsAfterWriteExceptSameIndex(firstPloop, secondPloop,
firstToSecondPloopIndices))
const IRMapping &firstToSecondPloopIndices,
llvm::function_ref<bool(Value, Value)> mayAlias) {
if (!haveNoReadsAfterWriteExceptSameIndex(
firstPloop, secondPloop, firstToSecondPloopIndices, mayAlias))
return failure();
IRMapping secondToFirstPloopIndices;
secondToFirstPloopIndices.map(secondPloop.getBody()->getArguments(),
firstPloop.getBody()->getArguments());
return success(haveNoReadsAfterWriteExceptSameIndex(
secondPloop, firstPloop, secondToFirstPloopIndices));
secondPloop, firstPloop, secondToFirstPloopIndices, mayAlias));
}
static bool isFusionLegal(ParallelOp firstPloop, ParallelOp secondPloop,
const IRMapping &firstToSecondPloopIndices) {
const IRMapping &firstToSecondPloopIndices,
llvm::function_ref<bool(Value, Value)> mayAlias) {
return !hasNestedParallelOp(firstPloop) &&
!hasNestedParallelOp(secondPloop) &&
equalIterationSpaces(firstPloop, secondPloop) &&
succeeded(verifyDependencies(firstPloop, secondPloop,
firstToSecondPloopIndices));
firstToSecondPloopIndices, mayAlias));
}
/// Prepends operations of firstPloop's body into secondPloop's body.
static void fuseIfLegal(ParallelOp firstPloop, ParallelOp secondPloop,
OpBuilder b) {
OpBuilder b,
llvm::function_ref<bool(Value, Value)> mayAlias) {
IRMapping firstToSecondPloopIndices;
firstToSecondPloopIndices.map(firstPloop.getBody()->getArguments(),
secondPloop.getBody()->getArguments());
if (!isFusionLegal(firstPloop, secondPloop, firstToSecondPloopIndices))
if (!isFusionLegal(firstPloop, secondPloop, firstToSecondPloopIndices,
mayAlias))
return;
b.setInsertionPointToStart(secondPloop.getBody());
@@ -135,7 +148,8 @@ static void fuseIfLegal(ParallelOp firstPloop, ParallelOp secondPloop,
firstPloop.erase();
}
void mlir::scf::naivelyFuseParallelOps(Region &region) {
void mlir::scf::naivelyFuseParallelOps(
Region &region, llvm::function_ref<bool(Value, Value)> mayAlias) {
OpBuilder b(region);
// Consider every single block and attempt to fuse adjacent loops.
for (auto &block : region) {
@@ -159,7 +173,7 @@ void mlir::scf::naivelyFuseParallelOps(Region &region) {
}
for (ArrayRef<ParallelOp> ploops : ploopChains) {
for (int i = 0, e = ploops.size(); i + 1 < e; ++i)
fuseIfLegal(ploops[i], ploops[i + 1], b);
fuseIfLegal(ploops[i], ploops[i + 1], b, mayAlias);
}
}
}
@@ -168,9 +182,15 @@ namespace {
struct ParallelLoopFusion
: public impl::SCFParallelLoopFusionBase<ParallelLoopFusion> {
void runOnOperation() override {
auto &AA = getAnalysis<AliasAnalysis>();
auto mayAlias = [&](Value val1, Value val2) -> bool {
return !AA.alias(val1, val2).isNo();
};
getOperation()->walk([&](Operation *child) {
for (Region &region : child->getRegions())
naivelyFuseParallelOps(region);
naivelyFuseParallelOps(region, mayAlias);
});
}
};

View File

@@ -357,3 +357,33 @@ func.func @nested_fuse(%A: memref<2x2xf32>, %B: memref<2x2xf32>,
// CHECK: }
// CHECK: }
// CHECK: memref.dealloc [[SUM]]
// -----
// Negative fusion test. Every memref used below is a function argument. The
// first loop stores into %sum; the second loop reads %sum back and also reads
// %A, which is a different value from %sum.
func.func @do_not_fuse_alias(%A: memref<2x2xf32>, %B: memref<2x2xf32>,
%C: memref<2x2xf32>, %result: memref<2x2xf32>,
%sum: memref<2x2xf32>) {
// Shared loop bounds: both loops iterate over [0, 2) x [0, 2) with step 1.
%c2 = arith.constant 2 : index
%c0 = arith.constant 0 : index
%c1 = arith.constant 1 : index
// First loop: %sum[i, j] = %B[i, j] + %C[i, j].
scf.parallel (%i, %j) = (%c0, %c0) to (%c2, %c2) step (%c1, %c1) {
%B_elem = memref.load %B[%i, %j] : memref<2x2xf32>
%C_elem = memref.load %C[%i, %j] : memref<2x2xf32>
%sum_elem = arith.addf %B_elem, %C_elem : f32
memref.store %sum_elem, %sum[%i, %j] : memref<2x2xf32>
scf.yield
}
// Second loop, same bounds and steps: %result[i, j] = %sum[i, j] * %A[i, j].
scf.parallel (%i, %j) = (%c0, %c0) to (%c2, %c2) step (%c1, %c1) {
%sum_elem = memref.load %sum[%i, %j] : memref<2x2xf32>
%A_elem = memref.load %A[%i, %j] : memref<2x2xf32>
%product_elem = arith.mulf %sum_elem, %A_elem : f32
memref.store %product_elem, %result[%i, %j] : memref<2x2xf32>
scf.yield
}
return
}
// %sum and %result may alias other arguments, so the loops must not be fused.
// CHECK-LABEL: func @do_not_fuse_alias
// CHECK: scf.parallel
// CHECK: scf.parallel