The MachineUA now queries the target to determine if a given register holds a uniform value. This is determined using the corresponding register bank if available, or by a combination of the register class and value type. This assumes that the target is optimizing for performance by choosing registers, and the target is responsible for any mismatch with the inferred uniformity. For example, on AMDGPU, an SGPR is now treated as uniform, except if the register bank is VCC (i.e., the register holds a wave-wide vector of 1-bit values) or equivalently if it has a value type of s1. - This does not always work with inline asm, where the register bank or the value type might not be present. We assume that the SGPR is uniform, because it is not expected to be s1 in the vast majority of cases. - The pseudo branch instruction SI_LOOP is now hard-coded to be always divergent, although its condition is an SGPR. Reviewed By: arsenm Differential Revision: https://reviews.llvm.org/D150438
268 lines
8.4 KiB
C++
268 lines
8.4 KiB
C++
//===- MachineUniformityAnalysis.cpp --------------------------------------===//
|
|
//
|
|
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
|
// See https://llvm.org/LICENSE.txt for license information.
|
|
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
|
//
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
#include "llvm/CodeGen/MachineUniformityAnalysis.h"
|
|
#include "llvm/ADT/GenericUniformityImpl.h"
|
|
#include "llvm/CodeGen/MachineCycleAnalysis.h"
|
|
#include "llvm/CodeGen/MachineDominators.h"
|
|
#include "llvm/CodeGen/MachineRegisterInfo.h"
|
|
#include "llvm/CodeGen/MachineSSAContext.h"
|
|
#include "llvm/CodeGen/TargetInstrInfo.h"
|
|
#include "llvm/InitializePasses.h"
|
|
|
|
using namespace llvm;
|
|
|
|
template <>
|
|
bool llvm::GenericUniformityAnalysisImpl<MachineSSAContext>::hasDivergentDefs(
|
|
const MachineInstr &I) const {
|
|
for (auto &op : I.operands()) {
|
|
if (!op.isReg() || !op.isDef())
|
|
continue;
|
|
if (isDivergent(op.getReg()))
|
|
return true;
|
|
}
|
|
return false;
|
|
}
|
|
|
|
template <>
|
|
bool llvm::GenericUniformityAnalysisImpl<MachineSSAContext>::markDefsDivergent(
|
|
const MachineInstr &Instr) {
|
|
bool insertedDivergent = false;
|
|
const auto &MRI = F.getRegInfo();
|
|
const auto &RBI = *F.getSubtarget().getRegBankInfo();
|
|
const auto &TRI = *MRI.getTargetRegisterInfo();
|
|
for (auto &op : Instr.operands()) {
|
|
if (!op.isReg() || !op.isDef())
|
|
continue;
|
|
if (!op.getReg().isVirtual())
|
|
continue;
|
|
assert(!op.getSubReg());
|
|
if (TRI.isUniformReg(MRI, RBI, op.getReg()))
|
|
continue;
|
|
insertedDivergent |= markDivergent(op.getReg());
|
|
}
|
|
return insertedDivergent;
|
|
}
|
|
|
|
template <>
|
|
void llvm::GenericUniformityAnalysisImpl<MachineSSAContext>::initialize() {
|
|
const auto &InstrInfo = *F.getSubtarget().getInstrInfo();
|
|
|
|
for (const MachineBasicBlock &block : F) {
|
|
for (const MachineInstr &instr : block) {
|
|
auto uniformity = InstrInfo.getInstructionUniformity(instr);
|
|
if (uniformity == InstructionUniformity::AlwaysUniform) {
|
|
addUniformOverride(instr);
|
|
continue;
|
|
}
|
|
|
|
if (uniformity == InstructionUniformity::NeverUniform) {
|
|
if (markDivergent(instr))
|
|
Worklist.push_back(&instr);
|
|
}
|
|
}
|
|
}
|
|
}
|
|
|
|
template <>
|
|
void llvm::GenericUniformityAnalysisImpl<MachineSSAContext>::pushUsers(
|
|
Register Reg) {
|
|
const auto &RegInfo = F.getRegInfo();
|
|
for (MachineInstr &UserInstr : RegInfo.use_instructions(Reg)) {
|
|
if (markDivergent(UserInstr))
|
|
Worklist.push_back(&UserInstr);
|
|
}
|
|
}
|
|
|
|
template <>
|
|
void llvm::GenericUniformityAnalysisImpl<MachineSSAContext>::pushUsers(
|
|
const MachineInstr &Instr) {
|
|
assert(!isAlwaysUniform(Instr));
|
|
if (Instr.isTerminator())
|
|
return;
|
|
for (const MachineOperand &op : Instr.operands()) {
|
|
if (op.isReg() && op.isDef() && op.getReg().isVirtual())
|
|
pushUsers(op.getReg());
|
|
}
|
|
}
|
|
|
|
template <>
|
|
bool llvm::GenericUniformityAnalysisImpl<MachineSSAContext>::usesValueFromCycle(
|
|
const MachineInstr &I, const MachineCycle &DefCycle) const {
|
|
assert(!isAlwaysUniform(I));
|
|
for (auto &Op : I.operands()) {
|
|
if (!Op.isReg() || !Op.readsReg())
|
|
continue;
|
|
auto Reg = Op.getReg();
|
|
|
|
// FIXME: Physical registers need to be properly checked instead of always
|
|
// returning true
|
|
if (Reg.isPhysical())
|
|
return true;
|
|
|
|
auto *Def = F.getRegInfo().getVRegDef(Reg);
|
|
if (DefCycle.contains(Def->getParent()))
|
|
return true;
|
|
}
|
|
return false;
|
|
}
|
|
|
|
template <>
|
|
void llvm::GenericUniformityAnalysisImpl<MachineSSAContext>::
|
|
propagateTemporalDivergence(const MachineInstr &I,
|
|
const MachineCycle &DefCycle) {
|
|
const auto &RegInfo = F.getRegInfo();
|
|
for (auto &Op : I.operands()) {
|
|
if (!Op.isReg() || !Op.isDef())
|
|
continue;
|
|
if (!Op.getReg().isVirtual())
|
|
continue;
|
|
auto Reg = Op.getReg();
|
|
if (isDivergent(Reg))
|
|
continue;
|
|
for (MachineInstr &UserInstr : RegInfo.use_instructions(Reg)) {
|
|
if (DefCycle.contains(UserInstr.getParent()))
|
|
continue;
|
|
if (markDivergent(UserInstr))
|
|
Worklist.push_back(&UserInstr);
|
|
}
|
|
}
|
|
}
|
|
|
|
template <>
|
|
bool llvm::GenericUniformityAnalysisImpl<MachineSSAContext>::isDivergentUse(
|
|
const MachineOperand &U) const {
|
|
if (!U.isReg())
|
|
return false;
|
|
|
|
auto Reg = U.getReg();
|
|
if (isDivergent(Reg))
|
|
return true;
|
|
|
|
const auto &RegInfo = F.getRegInfo();
|
|
auto *Def = RegInfo.getOneDef(Reg);
|
|
if (!Def)
|
|
return true;
|
|
|
|
auto *DefInstr = Def->getParent();
|
|
auto *UseInstr = U.getParent();
|
|
return isTemporalDivergent(*UseInstr->getParent(), *DefInstr);
|
|
}
|
|
|
|
// This ensures explicit instantiation of
|
|
// GenericUniformityAnalysisImpl::ImplDeleter::operator()
|
|
template class llvm::GenericUniformityInfo<MachineSSAContext>;
|
|
template struct llvm::GenericUniformityAnalysisImplDeleter<
|
|
llvm::GenericUniformityAnalysisImpl<MachineSSAContext>>;
|
|
|
|
/// Builds a MachineUniformityInfo for \p F from the supplied cycle
/// information and dominator tree.
///
/// \param F         The machine function to analyze; asserted to be in SSA
///                  form.
/// \param cycleInfo Cycle information passed to the result's constructor.
/// \param domTree   Dominator tree passed to the result's constructor.
/// \returns the newly constructed uniformity info.
MachineUniformityInfo
llvm::computeMachineUniformityInfo(MachineFunction &F,
                                   const MachineCycleInfo &cycleInfo,
                                   const MachineDomTree &domTree) {
  assert(F.getRegInfo().isSSA() && "Expected to be run on SSA form!");
  // Note the argument order: the constructor takes the dominator tree first.
  return MachineUniformityInfo(F, domTree, cycleInfo);
}
|
|
|
|
namespace {
|
|
|
|
/// Legacy analysis pass which computes a \ref MachineUniformityInfo.
|
|
class MachineUniformityAnalysisPass : public MachineFunctionPass {
|
|
MachineUniformityInfo UI;
|
|
|
|
public:
|
|
static char ID;
|
|
|
|
MachineUniformityAnalysisPass();
|
|
|
|
MachineUniformityInfo &getUniformityInfo() { return UI; }
|
|
const MachineUniformityInfo &getUniformityInfo() const { return UI; }
|
|
|
|
bool runOnMachineFunction(MachineFunction &F) override;
|
|
void getAnalysisUsage(AnalysisUsage &AU) const override;
|
|
void print(raw_ostream &OS, const Module *M = nullptr) const override;
|
|
|
|
// TODO: verify analysis
|
|
};
|
|
|
|
class MachineUniformityInfoPrinterPass : public MachineFunctionPass {
|
|
public:
|
|
static char ID;
|
|
|
|
MachineUniformityInfoPrinterPass();
|
|
|
|
bool runOnMachineFunction(MachineFunction &F) override;
|
|
void getAnalysisUsage(AnalysisUsage &AU) const override;
|
|
};
|
|
|
|
} // namespace
|
|
|
|
char MachineUniformityAnalysisPass::ID = 0;
|
|
|
|
MachineUniformityAnalysisPass::MachineUniformityAnalysisPass()
|
|
: MachineFunctionPass(ID) {
|
|
initializeMachineUniformityAnalysisPassPass(*PassRegistry::getPassRegistry());
|
|
}
|
|
|
|
// Pass registration for the analysis under the name "machine-uniformity",
// declaring its dependencies on machine cycle info and the machine dominator
// tree.
INITIALIZE_PASS_BEGIN(MachineUniformityAnalysisPass, "machine-uniformity",
                      "Machine Uniformity Info Analysis", true, true)
INITIALIZE_PASS_DEPENDENCY(MachineCycleInfoWrapperPass)
INITIALIZE_PASS_DEPENDENCY(MachineDominatorTree)
INITIALIZE_PASS_END(MachineUniformityAnalysisPass, "machine-uniformity",
                    "Machine Uniformity Info Analysis", true, true)
|
|
|
|
/// Declares that this pass preserves all analyses and requires the machine
/// cycle info and the machine dominator tree, then defers to the base class.
void MachineUniformityAnalysisPass::getAnalysisUsage(AnalysisUsage &AU) const {
  // This pass only computes information; it invalidates nothing.
  AU.setPreservesAll();

  // Inputs consumed by runOnMachineFunction().
  AU.addRequired<MachineCycleInfoWrapperPass>();
  AU.addRequired<MachineDominatorTree>();

  MachineFunctionPass::getAnalysisUsage(AU);
}
|
|
|
|
/// Computes the uniformity info for \p MF from the required dominator tree
/// and cycle info and stores it in this pass.
///
/// \returns false, always.
bool MachineUniformityAnalysisPass::runOnMachineFunction(MachineFunction &MF) {
  const auto &Dominators = getAnalysis<MachineDominatorTree>().getBase();
  const auto &Cycles = getAnalysis<MachineCycleInfoWrapperPass>().getCycleInfo();

  UI = computeMachineUniformityInfo(MF, Cycles, Dominators);
  return false;
}
|
|
|
|
/// Writes a header line naming the analyzed function to \p OS, followed by
/// the stored uniformity info. The module argument is unused.
void MachineUniformityAnalysisPass::print(raw_ostream &OS,
                                          const Module *) const {
  OS << "MachineUniformityInfo for function: ";
  OS << UI.getFunction().getName();
  OS << "\n";
  UI.print(OS);
}
|
|
|
|
char MachineUniformityInfoPrinterPass::ID = 0;
|
|
|
|
MachineUniformityInfoPrinterPass::MachineUniformityInfoPrinterPass()
|
|
: MachineFunctionPass(ID) {
|
|
initializeMachineUniformityInfoPrinterPassPass(
|
|
*PassRegistry::getPassRegistry());
|
|
}
|
|
|
|
// Pass registration for the printer under the name
// "print-machine-uniformity", declaring its dependency on the uniformity
// analysis pass.
INITIALIZE_PASS_BEGIN(MachineUniformityInfoPrinterPass,
                      "print-machine-uniformity",
                      "Print Machine Uniformity Info Analysis", true, true)
INITIALIZE_PASS_DEPENDENCY(MachineUniformityAnalysisPass)
INITIALIZE_PASS_END(MachineUniformityInfoPrinterPass,
                    "print-machine-uniformity",
                    "Print Machine Uniformity Info Analysis", true, true)
|
|
|
|
/// Declares that the printer preserves all analyses and requires the
/// uniformity analysis pass, then defers to the base class.
void MachineUniformityInfoPrinterPass::getAnalysisUsage(
    AnalysisUsage &AU) const {
  // Printing changes nothing.
  AU.setPreservesAll();

  // The analysis whose result is printed.
  AU.addRequired<MachineUniformityAnalysisPass>();

  MachineFunctionPass::getAnalysisUsage(AU);
}
|
|
|
|
/// Prints the required uniformity analysis to errs().
///
/// \returns false, always.
bool MachineUniformityInfoPrinterPass::runOnMachineFunction(
    MachineFunction &F) {
  getAnalysis<MachineUniformityAnalysisPass>().print(errs());
  return false;
}
|