Files
clang-p2996/mlir/lib/Dialect/Async/Transforms/AsyncRefCountingOptimization.cpp
Eugene Zhulenev a86a9b5ef7 [mlir] Automatic reference counting for Async values + runtime support for ref counted objects
Depends On D89963

**Automatic reference counting algorithm outline:**

1. `ReturnLike` operations forward the reference counted values without
    modifying the reference count.
2. Use liveness analysis to find blocks in the CFG where the lifetime of
   reference counted values ends, and insert `drop_ref` operations after
   the last use of the value.
3. Insert `add_ref` before the `async.execute` operation capturing the
   value, and pairing `drop_ref` before the async body region terminator,
   to release the captured reference counted value when execution
   completes.
4. If the reference counted value is passed only to some of the block
   successors, insert `drop_ref` operations at the beginning of the blocks
   that do not have any uses of the reference counted value.

Reviewed By: silvas

Differential Revision: https://reviews.llvm.org/D90716
2020-11-20 03:08:44 -08:00

219 lines
7.0 KiB
C++

//===- AsyncRefCountingOptimization.cpp - Async Ref Counting --------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// Optimize Async dialect reference counting operations.
//
//===----------------------------------------------------------------------===//
#include "PassDetail.h"
#include "mlir/Dialect/Async/IR/Async.h"
#include "mlir/Dialect/Async/Passes.h"
#include "llvm/ADT/SmallSet.h"
using namespace mlir;
using namespace mlir::async;
#define DEBUG_TYPE "async-ref-counting"
namespace {
/// Function pass that erases redundant pairs of `async.add_ref` /
/// `async.drop_ref` operations on reference counted async values, when
/// cancelling the pair is provably safe within a single block.
class AsyncRefCountingOptimizationPass
: public AsyncRefCountingOptimizationBase<
AsyncRefCountingOptimizationPass> {
public:
AsyncRefCountingOptimizationPass() = default;
/// Runs the optimization over every reference counted block argument and
/// operation result in the function.
void runOnFunction() override;
private:
/// Finds and erases cancellable `add_ref` <-> `drop_ref` pairs that modify
/// the reference count of `value`. Currently always returns success().
LogicalResult optimizeReferenceCounting(Value value);
};
} // namespace
LogicalResult
AsyncRefCountingOptimizationPass::optimizeReferenceCounting(Value value) {
// The region that owns `value` (as a block argument or an operation result).
// Users nested deeper than this region are attributed to the ancestor
// operation that lives directly in one of this region's blocks.
Region *definingRegion = value.getParentRegion();
// Find all users of the `value` inside each block, including operations that
// do not use `value` directly, but have a direct use inside nested region(s).
//
// Example:
//
// ^bb1:
// %token = ...
// scf.if %cond {
// ^bb2:
// async.await %token : !async.token
// }
//
// %token has a use inside ^bb2 (`async.await`) and inside ^bb1 (`scf.if`).
//
// In addition to the operation that uses the `value` we also keep track if
// this user is an `async.execute` operation itself, or has `async.execute`
// operations in the nested regions that do use the `value`.
struct UserInfo {
// The user operation (possibly an ancestor of the direct use).
Operation *operation;
// True if `operation` is an `async.execute` itself, or transitively
// contains one that uses `value`.
bool hasExecuteUser;
};
// Per-block bookkeeping: the `add_ref` / `drop_ref` operations that modify
// the reference count of `value` in the block, plus all users in the block.
struct BlockUsersInfo {
llvm::SmallVector<AddRefOp, 4> addRefs;
llvm::SmallVector<DropRefOp, 4> dropRefs;
llvm::SmallVector<UserInfo, 4> users;
};
llvm::DenseMap<Block *, BlockUsersInfo> blockUsers;
// Records `user` in the info entry of the block it belongs to.
auto updateBlockUsersInfo = [&](UserInfo user) {
BlockUsersInfo &info = blockUsers[user.operation->getBlock()];
info.users.push_back(user);
if (auto addRef = dyn_cast<AddRefOp>(user.operation))
info.addRefs.push_back(addRef);
if (auto dropRef = dyn_cast<DropRefOp>(user.operation))
info.dropRefs.push_back(dropRef);
};
// Walk each direct user up to `definingRegion`, registering the user (and
// every ancestor operation on the way) in the block it occupies at each
// nesting level.
for (Operation *user : value.getUsers()) {
bool isAsyncUser = isa<ExecuteOp>(user);
while (user->getParentRegion() != definingRegion) {
updateBlockUsersInfo({user, isAsyncUser});
user = user->getParentOp();
// Once any level of the ancestor chain is an `async.execute`, all outer
// levels are considered execute users as well.
isAsyncUser |= isa<ExecuteOp>(user);
assert(user != nullptr && "value user lies outside of the value region");
}
updateBlockUsersInfo({user, isAsyncUser});
}
// Sort all operations found in the block.
auto preprocessBlockUsersInfo = [](BlockUsersInfo &info) -> BlockUsersInfo & {
auto isBeforeInBlock = [](Operation *a, Operation *b) -> bool {
return a->isBeforeInBlock(b);
};
llvm::sort(info.addRefs, isBeforeInBlock);
llvm::sort(info.dropRefs, isBeforeInBlock);
llvm::sort(info.users, [&](UserInfo a, UserInfo b) -> bool {
return isBeforeInBlock(a.operation, b.operation);
});
return info;
};
// Find and erase matching pairs of `add_ref` / `drop_ref` operations in the
// blocks that modify the reference count of the `value`.
for (auto &kv : blockUsers) {
BlockUsersInfo &info = preprocessBlockUsersInfo(kv.second);
// Find all cancellable pairs first and erase them later to keep all
// pointers in the `info` valid until the end.
//
// Mapping from `dropRef.getOperation()` to `addRef.getOperation()`.
llvm::SmallDenseMap<Operation *, Operation *> cancellable;
for (AddRefOp addRef : info.addRefs) {
// `info.dropRefs` is sorted by block order, so the first candidate found
// for this `addRef` is the closest one after it.
for (DropRefOp dropRef : info.dropRefs) {
// `drop_ref` operation after the `add_ref` with matching count.
if (dropRef.count() != addRef.count() ||
dropRef.getOperation()->isBeforeInBlock(addRef.getOperation()))
continue;
// `drop_ref` was already marked for removal.
if (cancellable.find(dropRef.getOperation()) != cancellable.end())
continue;
// Check `value` users between `addRef` and `dropRef` in the `block`.
Operation *addRefOp = addRef.getOperation();
Operation *dropRefOp = dropRef.getOperation();
// If there is a "regular" user after the `async.execute` user it is
// unsafe to erase cancellable reference counting operations pair,
// because async region can complete before the "regular" user and
// destroy the reference counted value.
bool hasExecuteUser = false;
bool unsafeToCancel = false;
for (UserInfo &user : info.users) {
Operation *op = user.operation;
// `user` operation lies after `addRef` ...
if (op == addRefOp || op->isBeforeInBlock(addRefOp))
continue;
// ... and before `dropRef`.
if (op == dropRefOp || dropRefOp->isBeforeInBlock(op))
break;
bool isRegularUser = !user.hasExecuteUser;
bool isExecuteUser = user.hasExecuteUser;
// It is unsafe to cancel `addRef` / `dropRef` pair.
if (isRegularUser && hasExecuteUser) {
unsafeToCancel = true;
break;
}
hasExecuteUser |= isExecuteUser;
}
// Mark the pair of reference counting operations for removal.
if (!unsafeToCancel)
cancellable[dropRef.getOperation()] = addRef.getOperation();
// Stop scanning `drop_ref` candidates for this `addRef` either way: if
// it is unsafe to cancel the `addRef <-> dropRef` pair at this point,
// all the following pairs will be also unsafe.
break;
}
}
// Erase all cancellable `addRef <-> dropRef` operation pairs.
for (auto &kv : cancellable) {
kv.first->erase();
kv.second->erase();
}
}
return success();
}
void AsyncRefCountingOptimizationPass::runOnFunction() {
FuncOp func = getFunction();
// Optimize reference counting for values defined by block arguments.
WalkResult blockWalk = func.walk([&](Block *block) -> WalkResult {
for (BlockArgument arg : block->getArguments())
if (isRefCounted(arg.getType()))
if (failed(optimizeReferenceCounting(arg)))
return WalkResult::interrupt();
return WalkResult::advance();
});
if (blockWalk.wasInterrupted())
signalPassFailure();
// Optimize reference counting for values defined by operation results.
WalkResult opWalk = func.walk([&](Operation *op) -> WalkResult {
for (unsigned i = 0; i < op->getNumResults(); ++i)
if (isRefCounted(op->getResultTypes()[i]))
if (failed(optimizeReferenceCounting(op->getResult(i))))
return WalkResult::interrupt();
return WalkResult::advance();
});
if (opWalk.wasInterrupted())
signalPassFailure();
}
std::unique_ptr<OperationPass<FuncOp>>
mlir::createAsyncRefCountingOptimizationPass() {
  // Ownership of the freshly constructed pass transfers to the caller.
  auto pass = std::make_unique<AsyncRefCountingOptimizationPass>();
  return pass;
}