[mlir][scf] Improve scf.parallel fusion pass (#75852)
Abort fusion if a memref load may alias a written memref that is not the exact same value. Add an alias-check hook to `naivelyFuseParallelOps` so that users can customize alias checking. Use the built-in alias analysis in the `ParallelLoopFusion` pass.
This commit is contained in:
@@ -34,7 +34,10 @@ class ParallelOp;
|
||||
/// Fuses all adjacent scf.parallel operations with identical bounds and step
|
||||
/// into one scf.parallel operations. Uses a naive aliasing and dependency
|
||||
/// analysis.
|
||||
void naivelyFuseParallelOps(Region &region);
|
||||
/// User can additionally customize alias checking with `mayAlias` hook.
|
||||
/// `mayAlias` must return false if 2 values are guaranteed to not alias.
|
||||
void naivelyFuseParallelOps(Region &region,
|
||||
llvm::function_ref<bool(Value, Value)> mayAlias);
|
||||
|
||||
/// Rewrite a for loop with bounds/step that potentially do not divide evenly
|
||||
/// into a for loop where the step divides the iteration space evenly, followed
|
||||
|
||||
@@ -12,6 +12,7 @@
|
||||
|
||||
#include "mlir/Dialect/SCF/Transforms/Passes.h"
|
||||
|
||||
#include "mlir/Analysis/AliasAnalysis.h"
|
||||
#include "mlir/Dialect/MemRef/IR/MemRef.h"
|
||||
#include "mlir/Dialect/SCF/IR/SCF.h"
|
||||
#include "mlir/Dialect/SCF/Transforms/Transforms.h"
|
||||
@@ -58,19 +59,27 @@ static bool equalIterationSpaces(ParallelOp firstPloop,
|
||||
/// loop reads.
|
||||
static bool haveNoReadsAfterWriteExceptSameIndex(
|
||||
ParallelOp firstPloop, ParallelOp secondPloop,
|
||||
const IRMapping &firstToSecondPloopIndices) {
|
||||
const IRMapping &firstToSecondPloopIndices,
|
||||
llvm::function_ref<bool(Value, Value)> mayAlias) {
|
||||
DenseMap<Value, SmallVector<ValueRange, 1>> bufferStores;
|
||||
SmallVector<Value> bufferStoresVec;
|
||||
firstPloop.getBody()->walk([&](memref::StoreOp store) {
|
||||
bufferStores[store.getMemRef()].push_back(store.getIndices());
|
||||
bufferStoresVec.emplace_back(store.getMemRef());
|
||||
});
|
||||
auto walkResult = secondPloop.getBody()->walk([&](memref::LoadOp load) {
|
||||
Value loadMem = load.getMemRef();
|
||||
// Stop if the memref is defined in secondPloop body. Careful alias analysis
|
||||
// is needed.
|
||||
auto *memrefDef = load.getMemRef().getDefiningOp();
|
||||
auto *memrefDef = loadMem.getDefiningOp();
|
||||
if (memrefDef && memrefDef->getBlock() == load->getBlock())
|
||||
return WalkResult::interrupt();
|
||||
|
||||
auto write = bufferStores.find(load.getMemRef());
|
||||
for (Value store : bufferStoresVec)
|
||||
if (store != loadMem && mayAlias(store, loadMem))
|
||||
return WalkResult::interrupt();
|
||||
|
||||
auto write = bufferStores.find(loadMem);
|
||||
if (write == bufferStores.end())
|
||||
return WalkResult::advance();
|
||||
|
||||
@@ -98,35 +107,39 @@ static bool haveNoReadsAfterWriteExceptSameIndex(
|
||||
/// write patterns.
|
||||
static LogicalResult
|
||||
verifyDependencies(ParallelOp firstPloop, ParallelOp secondPloop,
|
||||
const IRMapping &firstToSecondPloopIndices) {
|
||||
if (!haveNoReadsAfterWriteExceptSameIndex(firstPloop, secondPloop,
|
||||
firstToSecondPloopIndices))
|
||||
const IRMapping &firstToSecondPloopIndices,
|
||||
llvm::function_ref<bool(Value, Value)> mayAlias) {
|
||||
if (!haveNoReadsAfterWriteExceptSameIndex(
|
||||
firstPloop, secondPloop, firstToSecondPloopIndices, mayAlias))
|
||||
return failure();
|
||||
|
||||
IRMapping secondToFirstPloopIndices;
|
||||
secondToFirstPloopIndices.map(secondPloop.getBody()->getArguments(),
|
||||
firstPloop.getBody()->getArguments());
|
||||
return success(haveNoReadsAfterWriteExceptSameIndex(
|
||||
secondPloop, firstPloop, secondToFirstPloopIndices));
|
||||
secondPloop, firstPloop, secondToFirstPloopIndices, mayAlias));
|
||||
}
|
||||
|
||||
static bool isFusionLegal(ParallelOp firstPloop, ParallelOp secondPloop,
|
||||
const IRMapping &firstToSecondPloopIndices) {
|
||||
const IRMapping &firstToSecondPloopIndices,
|
||||
llvm::function_ref<bool(Value, Value)> mayAlias) {
|
||||
return !hasNestedParallelOp(firstPloop) &&
|
||||
!hasNestedParallelOp(secondPloop) &&
|
||||
equalIterationSpaces(firstPloop, secondPloop) &&
|
||||
succeeded(verifyDependencies(firstPloop, secondPloop,
|
||||
firstToSecondPloopIndices));
|
||||
firstToSecondPloopIndices, mayAlias));
|
||||
}
|
||||
|
||||
/// Prepends operations of firstPloop's body into secondPloop's body.
|
||||
static void fuseIfLegal(ParallelOp firstPloop, ParallelOp secondPloop,
|
||||
OpBuilder b) {
|
||||
OpBuilder b,
|
||||
llvm::function_ref<bool(Value, Value)> mayAlias) {
|
||||
IRMapping firstToSecondPloopIndices;
|
||||
firstToSecondPloopIndices.map(firstPloop.getBody()->getArguments(),
|
||||
secondPloop.getBody()->getArguments());
|
||||
|
||||
if (!isFusionLegal(firstPloop, secondPloop, firstToSecondPloopIndices))
|
||||
if (!isFusionLegal(firstPloop, secondPloop, firstToSecondPloopIndices,
|
||||
mayAlias))
|
||||
return;
|
||||
|
||||
b.setInsertionPointToStart(secondPloop.getBody());
|
||||
@@ -135,7 +148,8 @@ static void fuseIfLegal(ParallelOp firstPloop, ParallelOp secondPloop,
|
||||
firstPloop.erase();
|
||||
}
|
||||
|
||||
void mlir::scf::naivelyFuseParallelOps(Region &region) {
|
||||
void mlir::scf::naivelyFuseParallelOps(
|
||||
Region &region, llvm::function_ref<bool(Value, Value)> mayAlias) {
|
||||
OpBuilder b(region);
|
||||
// Consider every single block and attempt to fuse adjacent loops.
|
||||
for (auto &block : region) {
|
||||
@@ -159,7 +173,7 @@ void mlir::scf::naivelyFuseParallelOps(Region &region) {
|
||||
}
|
||||
for (ArrayRef<ParallelOp> ploops : ploopChains) {
|
||||
for (int i = 0, e = ploops.size(); i + 1 < e; ++i)
|
||||
fuseIfLegal(ploops[i], ploops[i + 1], b);
|
||||
fuseIfLegal(ploops[i], ploops[i + 1], b, mayAlias);
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -168,9 +182,15 @@ namespace {
|
||||
struct ParallelLoopFusion
|
||||
: public impl::SCFParallelLoopFusionBase<ParallelLoopFusion> {
|
||||
void runOnOperation() override {
|
||||
auto &AA = getAnalysis<AliasAnalysis>();
|
||||
|
||||
auto mayAlias = [&](Value val1, Value val2) -> bool {
|
||||
return !AA.alias(val1, val2).isNo();
|
||||
};
|
||||
|
||||
getOperation()->walk([&](Operation *child) {
|
||||
for (Region &region : child->getRegions())
|
||||
naivelyFuseParallelOps(region);
|
||||
naivelyFuseParallelOps(region, mayAlias);
|
||||
});
|
||||
}
|
||||
};
|
||||
|
||||
@@ -357,3 +357,33 @@ func.func @nested_fuse(%A: memref<2x2xf32>, %B: memref<2x2xf32>,
|
||||
// CHECK: }
|
||||
// CHECK: }
|
||||
// CHECK: memref.dealloc [[SUM]]
|
||||
|
||||
// -----
|
||||
|
||||
func.func @do_not_fuse_alias(%A: memref<2x2xf32>, %B: memref<2x2xf32>,
|
||||
%C: memref<2x2xf32>, %result: memref<2x2xf32>,
|
||||
%sum: memref<2x2xf32>) {
|
||||
%c2 = arith.constant 2 : index
|
||||
%c0 = arith.constant 0 : index
|
||||
%c1 = arith.constant 1 : index
|
||||
scf.parallel (%i, %j) = (%c0, %c0) to (%c2, %c2) step (%c1, %c1) {
|
||||
%B_elem = memref.load %B[%i, %j] : memref<2x2xf32>
|
||||
%C_elem = memref.load %C[%i, %j] : memref<2x2xf32>
|
||||
%sum_elem = arith.addf %B_elem, %C_elem : f32
|
||||
memref.store %sum_elem, %sum[%i, %j] : memref<2x2xf32>
|
||||
scf.yield
|
||||
}
|
||||
scf.parallel (%i, %j) = (%c0, %c0) to (%c2, %c2) step (%c1, %c1) {
|
||||
%sum_elem = memref.load %sum[%i, %j] : memref<2x2xf32>
|
||||
%A_elem = memref.load %A[%i, %j] : memref<2x2xf32>
|
||||
%product_elem = arith.mulf %sum_elem, %A_elem : f32
|
||||
memref.store %product_elem, %result[%i, %j] : memref<2x2xf32>
|
||||
scf.yield
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
// %sum and %result may alias with other args, do not fuse loops
|
||||
// CHECK-LABEL: func @do_not_fuse_alias
|
||||
// CHECK: scf.parallel
|
||||
// CHECK: scf.parallel
|
||||
|
||||
Reference in New Issue
Block a user