Files
clang-p2996/llvm/lib/Target/AMDGPU/AMDGPUReplaceLDSUseWithPointer.cpp
hsmahesha 98f4713122 [AMDGPU] Split entry basic block after alloca instructions.
While initializing the LDS pointers within entry basic block of kernel(s), make
sure that the entry basic block is split after alloca instructions.

Reviewed By: rampitec

Differential Revision: https://reviews.llvm.org/D108971
2021-09-01 10:18:44 +05:30

470 lines
17 KiB
C++

//===-- AMDGPUReplaceLDSUseWithPointer.cpp --------------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This pass replaces all uses of LDS within non-kernel functions by their
// corresponding pointer counterparts.
//
// The main motivation behind this pass is to *prevent* the subsequent LDS
// lowering pass from directly packing LDS (assume large LDS) into a struct
// type, which would otherwise cause a huge struct instance to be allocated
// within every kernel.
//
// Brief sketch of the algorithm implemented in this pass is as below:
//
// 1. Collect all the LDS defined in the module which qualify for pointer
// replacement, say it is, LDSGlobals set.
//
// 2. Collect all the reachable callees for each kernel defined in the module,
// say it is, KernelToCallees map.
//
// 3. FOR (each global GV from LDSGlobals set) DO
// LDSUsedNonKernels = Collect all non-kernel functions which use GV.
// FOR (each kernel K in KernelToCallees map) DO
// ReachableCallees = KernelToCallees[K]
// ReachableAndLDSUsedCallees =
// SetIntersect(LDSUsedNonKernels, ReachableCallees)
// IF (ReachableAndLDSUsedCallees is not empty) THEN
// Pointer = Create a pointer to point-to GV if not created.
// Initialize Pointer to point-to GV within kernel K.
// ENDIF
// ENDFOR
// Replace all uses of GV within non kernel functions by Pointer.
// ENDFOR
//
// LLVM IR example:
//
// Input IR:
//
// @lds = internal addrspace(3) global [4 x i32] undef, align 16
//
// define internal void @f0() {
// entry:
// %gep = getelementptr inbounds [4 x i32], [4 x i32] addrspace(3)* @lds,
// i32 0, i32 0
// ret void
// }
//
// define protected amdgpu_kernel void @k0() {
// entry:
// call void @f0()
// ret void
// }
//
// Output IR:
//
// @lds = internal addrspace(3) global [4 x i32] undef, align 16
// @lds.ptr = internal unnamed_addr addrspace(3) global i16 undef, align 2
//
// define internal void @f0() {
// entry:
// %0 = load i16, i16 addrspace(3)* @lds.ptr, align 2
// %1 = getelementptr i8, i8 addrspace(3)* null, i16 %0
// %2 = bitcast i8 addrspace(3)* %1 to [4 x i32] addrspace(3)*
// %gep = getelementptr inbounds [4 x i32], [4 x i32] addrspace(3)* %2,
// i32 0, i32 0
// ret void
// }
//
// define protected amdgpu_kernel void @k0() {
// entry:
// store i16 ptrtoint ([4 x i32] addrspace(3)* @lds to i16),
// i16 addrspace(3)* @lds.ptr, align 2
// call void @f0()
// ret void
// }
//
//===----------------------------------------------------------------------===//
#include "AMDGPU.h"
#include "GCNSubtarget.h"
#include "Utils/AMDGPUBaseInfo.h"
#include "Utils/AMDGPULDSUtils.h"
#include "llvm/ADT/DenseMap.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/SetOperations.h"
#include "llvm/CodeGen/TargetPassConfig.h"
#include "llvm/IR/Constants.h"
#include "llvm/IR/DerivedTypes.h"
#include "llvm/IR/IRBuilder.h"
#include "llvm/IR/InlineAsm.h"
#include "llvm/IR/Instructions.h"
#include "llvm/IR/IntrinsicsAMDGPU.h"
#include "llvm/IR/ReplaceConstant.h"
#include "llvm/InitializePasses.h"
#include "llvm/Pass.h"
#include "llvm/Support/Debug.h"
#include "llvm/Target/TargetMachine.h"
#include "llvm/Transforms/Utils/BasicBlockUtils.h"
#include "llvm/Transforms/Utils/ModuleUtils.h"
#include <algorithm>
#include <vector>
#define DEBUG_TYPE "amdgpu-replace-lds-use-with-pointer"
using namespace llvm;
namespace {
class ReplaceLDSUseImpl {
Module &M;
LLVMContext &Ctx;
const DataLayout &DL;
Constant *LDSMemBaseAddr;
DenseMap<GlobalVariable *, GlobalVariable *> LDSToPointer;
DenseMap<GlobalVariable *, SmallPtrSet<Function *, 8>> LDSToNonKernels;
DenseMap<Function *, SmallPtrSet<Function *, 8>> KernelToCallees;
DenseMap<Function *, SmallPtrSet<GlobalVariable *, 8>> KernelToLDSPointers;
DenseMap<Function *, BasicBlock *> KernelToInitBB;
DenseMap<Function *, DenseMap<GlobalVariable *, Value *>>
FunctionToLDSToReplaceInst;
// Collect LDS which requires their uses to be replaced by pointer.
std::vector<GlobalVariable *> collectLDSRequiringPointerReplace() {
// Collect LDS which requires module lowering.
std::vector<GlobalVariable *> LDSGlobals = AMDGPU::findVariablesToLower(M);
// Remove LDS which don't qualify for replacement.
LDSGlobals.erase(std::remove_if(LDSGlobals.begin(), LDSGlobals.end(),
[&](GlobalVariable *GV) {
return shouldIgnorePointerReplacement(GV);
}),
LDSGlobals.end());
return LDSGlobals;
}
// Returns true if uses of given LDS global within non-kernel functions should
// be keep as it is without pointer replacement.
bool shouldIgnorePointerReplacement(GlobalVariable *GV) {
// LDS whose size is very small and doesn`t exceed pointer size is not worth
// replacing.
if (DL.getTypeAllocSize(GV->getValueType()) <= 2)
return true;
// LDS which is not used from non-kernel function scope or it is used from
// global scope does not qualify for replacement.
LDSToNonKernels[GV] = AMDGPU::collectNonKernelAccessorsOfLDS(GV);
return LDSToNonKernels[GV].empty();
// FIXME: When GV is used within all (or within most of the kernels), then
// it does not make sense to create a pointer for it.
}
// Insert new global LDS pointer which points to LDS.
GlobalVariable *createLDSPointer(GlobalVariable *GV) {
// LDS pointer which points to LDS is already created? return it.
auto PointerEntry = LDSToPointer.insert(std::make_pair(GV, nullptr));
if (!PointerEntry.second)
return PointerEntry.first->second;
// We need to create new LDS pointer which points to LDS.
//
// Each CU owns at max 64K of LDS memory, so LDS address ranges from 0 to
// 2^16 - 1. Hence 16 bit pointer is enough to hold the LDS address.
auto *I16Ty = Type::getInt16Ty(Ctx);
GlobalVariable *LDSPointer = new GlobalVariable(
M, I16Ty, false, GlobalValue::InternalLinkage, UndefValue::get(I16Ty),
GV->getName() + Twine(".ptr"), nullptr, GlobalVariable::NotThreadLocal,
AMDGPUAS::LOCAL_ADDRESS);
LDSPointer->setUnnamedAddr(GlobalValue::UnnamedAddr::Global);
LDSPointer->setAlignment(AMDGPU::getAlign(DL, LDSPointer));
// Mark that an associated LDS pointer is created for LDS.
LDSToPointer[GV] = LDSPointer;
return LDSPointer;
}
// Split entry basic block in such a way that only lane 0 of each wave does
// the LDS pointer initialization, and return newly created basic block.
BasicBlock *activateLaneZero(Function *K) {
// If the entry basic block of kernel K is already splitted, then return
// newly created basic block.
auto BasicBlockEntry = KernelToInitBB.insert(std::make_pair(K, nullptr));
if (!BasicBlockEntry.second)
return BasicBlockEntry.first->second;
// Split entry basic block of kernel K just after alloca.
//
// Find the split point just after alloca.
auto &EBB = K->getEntryBlock();
auto *EI = &(*(EBB.getFirstInsertionPt()));
BasicBlock::reverse_iterator RIT(EBB.getTerminator());
while (!isa<AllocaInst>(*RIT) && (&*RIT != EI))
++RIT;
if (isa<AllocaInst>(*RIT))
--RIT;
// Split entry basic block.
IRBuilder<> Builder(&*RIT);
Value *Mbcnt =
Builder.CreateIntrinsic(Intrinsic::amdgcn_mbcnt_lo, {},
{Builder.getInt32(-1), Builder.getInt32(0)});
Value *Cond = Builder.CreateICmpEQ(Mbcnt, Builder.getInt32(0));
Instruction *WB = cast<Instruction>(
Builder.CreateIntrinsic(Intrinsic::amdgcn_wave_barrier, {}, {}));
BasicBlock *NBB = SplitBlockAndInsertIfThen(Cond, WB, false)->getParent();
// Mark that the entry basic block of kernel K is splitted.
KernelToInitBB[K] = NBB;
return NBB;
}
// Within given kernel, initialize given LDS pointer to point to given LDS.
void initializeLDSPointer(Function *K, GlobalVariable *GV,
GlobalVariable *LDSPointer) {
// If LDS pointer is already initialized within K, then nothing to do.
auto PointerEntry = KernelToLDSPointers.insert(
std::make_pair(K, SmallPtrSet<GlobalVariable *, 8>()));
if (!PointerEntry.second)
if (PointerEntry.first->second.contains(LDSPointer))
return;
// Insert instructions at EI which initialize LDS pointer to point-to LDS
// within kernel K.
//
// That is, convert pointer type of GV to i16, and then store this converted
// i16 value within LDSPointer which is of type i16*.
auto *EI = &(*(activateLaneZero(K)->getFirstInsertionPt()));
IRBuilder<> Builder(EI);
Builder.CreateStore(Builder.CreatePtrToInt(GV, Type::getInt16Ty(Ctx)),
LDSPointer);
// Mark that LDS pointer is initialized within kernel K.
KernelToLDSPointers[K].insert(LDSPointer);
}
// We have created an LDS pointer for LDS, and initialized it to point-to LDS
// within all relevent kernels. Now replace all the uses of LDS within
// non-kernel functions by LDS pointer.
void replaceLDSUseByPointer(GlobalVariable *GV, GlobalVariable *LDSPointer) {
SmallVector<User *, 8> LDSUsers(GV->users());
for (auto *U : LDSUsers) {
// When `U` is a constant expression, it is possible that same constant
// expression exists within multiple instructions, and within multiple
// non-kernel functions. Collect all those non-kernel functions and all
// those instructions within which `U` exist.
auto FunctionToInsts =
AMDGPU::getFunctionToInstsMap(U, false /*=CollectKernelInsts*/);
for (auto FI = FunctionToInsts.begin(), FE = FunctionToInsts.end();
FI != FE; ++FI) {
Function *F = FI->first;
auto &Insts = FI->second;
for (auto *I : Insts) {
// If `U` is a constant expression, then we need to break the
// associated instruction into a set of separate instructions by
// converting constant expressions into instructions.
SmallPtrSet<Instruction *, 8> UserInsts;
if (U == I) {
// `U` is an instruction, conversion from constant expression to
// set of instructions is *not* required.
UserInsts.insert(I);
} else {
// `U` is a constant expression, convert it into corresponding set
// of instructions.
auto *CE = cast<ConstantExpr>(U);
convertConstantExprsToInstructions(I, CE, &UserInsts);
}
// Go through all the user instrutions, if LDS exist within them as an
// operand, then replace it by replace instruction.
for (auto *II : UserInsts) {
auto *ReplaceInst = getReplacementInst(F, GV, LDSPointer);
II->replaceUsesOfWith(GV, ReplaceInst);
}
}
}
}
}
// Create a set of replacement instructions which together replace LDS within
// non-kernel function F by accessing LDS indirectly using LDS pointer.
Value *getReplacementInst(Function *F, GlobalVariable *GV,
GlobalVariable *LDSPointer) {
// If the instruction which replaces LDS within F is already created, then
// return it.
auto LDSEntry = FunctionToLDSToReplaceInst.insert(
std::make_pair(F, DenseMap<GlobalVariable *, Value *>()));
if (!LDSEntry.second) {
auto ReplaceInstEntry =
LDSEntry.first->second.insert(std::make_pair(GV, nullptr));
if (!ReplaceInstEntry.second)
return ReplaceInstEntry.first->second;
}
// Get the instruction insertion point within the beginning of the entry
// block of current non-kernel function.
auto *EI = &(*(F->getEntryBlock().getFirstInsertionPt()));
IRBuilder<> Builder(EI);
// Insert required set of instructions which replace LDS within F.
auto *V = Builder.CreateBitCast(
Builder.CreateGEP(
Builder.getInt8Ty(), LDSMemBaseAddr,
Builder.CreateLoad(LDSPointer->getValueType(), LDSPointer)),
GV->getType());
// Mark that the replacement instruction which replace LDS within F is
// created.
FunctionToLDSToReplaceInst[F][GV] = V;
return V;
}
public:
ReplaceLDSUseImpl(Module &M)
: M(M), Ctx(M.getContext()), DL(M.getDataLayout()) {
LDSMemBaseAddr = Constant::getIntegerValue(
PointerType::get(Type::getInt8Ty(M.getContext()),
AMDGPUAS::LOCAL_ADDRESS),
APInt(32, 0));
}
// Entry-point function which interface ReplaceLDSUseImpl with outside of the
// class.
bool replaceLDSUse();
private:
// For a given LDS from collected LDS globals set, replace its non-kernel
// function scope uses by pointer.
bool replaceLDSUse(GlobalVariable *GV);
};
// For given LDS from collected LDS globals set, replace its non-kernel function
// scope uses by pointer.
// For a given LDS global from the collected set: create its LDS pointer on
// demand, initialize that pointer in every kernel which can reach a non-kernel
// accessor of GV, then rewrite GV's non-kernel uses to go through the pointer.
// Returns true if anything was changed.
bool ReplaceLDSUseImpl::replaceLDSUse(GlobalVariable *GV) {
  // Holds all those non-kernel functions within which LDS is being accessed.
  const SmallPtrSet<Function *, 8> &LDSAccessors = LDSToNonKernels[GV];

  // The LDS pointer which points to LDS and replaces all the uses of LDS.
  GlobalVariable *LDSPointer = nullptr;

  // Traverse through each kernel K, check and if required, initialize the
  // LDS pointer to point to LDS within K.
  for (auto &KI : KernelToCallees) {
    Function *K = KI.first;
    const SmallPtrSet<Function *, 8> &Callees = KI.second;

    // Only whether the reachable callees and the LDS accessors intersect
    // matters, so test membership directly (stopping at the first hit)
    // instead of copying the callee set and intersecting it in place.
    bool ReachesLDSAccessor =
        llvm::any_of(LDSAccessors,
                     [&Callees](Function *F) { return Callees.contains(F); });

    // None of the LDS accessing non-kernel functions are reachable from
    // kernel K. Hence, no need to initialize LDS pointer within kernel K.
    if (!ReachesLDSAccessor)
      continue;

    // We need to initialize LDS pointer within kernel K. First make sure the
    // LDS pointer exists (createLDSPointer caches it per LDS global).
    LDSPointer = createLDSPointer(GV);

    // Initialize LDS pointer to point to LDS within kernel K.
    initializeLDSPointer(K, GV, LDSPointer);
  }

  // We have not found reachable and LDS used callees for any of the kernels,
  // and hence we have not created LDS pointer.
  if (!LDSPointer)
    return false;

  // We have created an LDS pointer for LDS, and initialized it to point-to LDS
  // within all relevant kernels. Now replace all the uses of LDS within
  // non-kernel functions by LDS pointer.
  replaceLDSUseByPointer(GV, LDSPointer);

  return true;
}
// Entry-point function which interface ReplaceLDSUseImpl with outside of the
// class.
bool ReplaceLDSUseImpl::replaceLDSUse() {
// Collect LDS which requires their uses to be replaced by pointer.
std::vector<GlobalVariable *> LDSGlobals =
collectLDSRequiringPointerReplace();
// No LDS to pointer-replace. Nothing to do.
if (LDSGlobals.empty())
return false;
// Collect reachable callee set for each kernel defined in the module.
AMDGPU::collectReachableCallees(M, KernelToCallees);
if (KernelToCallees.empty()) {
// Either module does not have any kernel definitions, or none of the kernel
// has a call to non-kernel functions, or we could not resolve any of the
// call sites to proper non-kernel functions, because of the situations like
// inline asm calls. Nothing to replace.
return false;
}
// For every LDS from collected LDS globals set, replace its non-kernel
// function scope use by pointer.
bool Changed = false;
for (auto *GV : LDSGlobals)
Changed |= replaceLDSUse(GV);
return Changed;
}
class AMDGPUReplaceLDSUseWithPointer : public ModulePass {
public:
static char ID;
AMDGPUReplaceLDSUseWithPointer() : ModulePass(ID) {
initializeAMDGPUReplaceLDSUseWithPointerPass(
*PassRegistry::getPassRegistry());
}
bool runOnModule(Module &M) override;
void getAnalysisUsage(AnalysisUsage &AU) const override {
AU.addRequired<TargetPassConfig>();
}
};
} // namespace
// Legacy pass identifier; its address (not its value) identifies the pass.
char AMDGPUReplaceLDSUseWithPointer::ID = 0;

// Externally visible reference to the pass ID (declared in the llvm
// namespace, presumably in AMDGPU.h).
char &llvm::AMDGPUReplaceLDSUseWithPointerID =
AMDGPUReplaceLDSUseWithPointer::ID;

// Register the legacy pass under DEBUG_TYPE
// ("amdgpu-replace-lds-use-with-pointer"), declaring its dependency on
// TargetPassConfig. The pass is neither CFG-only nor an analysis.
INITIALIZE_PASS_BEGIN(
AMDGPUReplaceLDSUseWithPointer, DEBUG_TYPE,
"Replace within non-kernel function use of LDS with pointer",
false /*only look at the cfg*/, false /*analysis pass*/)
INITIALIZE_PASS_DEPENDENCY(TargetPassConfig)
INITIALIZE_PASS_END(
AMDGPUReplaceLDSUseWithPointer, DEBUG_TYPE,
"Replace within non-kernel function use of LDS with pointer",
false /*only look at the cfg*/, false /*analysis pass*/)
// Legacy pass-manager entry point: run the replacement over the whole module
// and report whether anything changed.
bool AMDGPUReplaceLDSUseWithPointer::runOnModule(Module &M) {
  return ReplaceLDSUseImpl(M).replaceLDSUse();
}
// Factory for the legacy pass; the caller takes ownership of the new pass.
ModulePass *llvm::createAMDGPUReplaceLDSUseWithPointerPass() {
  auto *P = new AMDGPUReplaceLDSUseWithPointer();
  return P;
}
// New pass-manager entry point. The transformation creates new globals,
// inserts loads/stores and splits kernel entry blocks (changing the CFG), so
// no analysis can be claimed preserved once the module has been modified.
// Previously this always returned PreservedAnalyses::all(), leaving stale
// cached analyses after a change.
PreservedAnalyses
AMDGPUReplaceLDSUseWithPointerPass::run(Module &M, ModuleAnalysisManager &AM) {
  ReplaceLDSUseImpl LDSUseReplacer{M};
  if (!LDSUseReplacer.replaceLDSUse())
    return PreservedAnalyses::all();
  return PreservedAnalyses::none();
}