Files
clang-p2996/llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.cpp
Vyacheslav Levytskyy 83c1d00311 [SPIR-V] Overhaul module analysis to improve translation speed and simplify the underlying logics (#120415)
This PR addresses legacy issues with module analysis, which currently
uses a complicated and inefficient approach to trace dependencies
between SPIR-V ids via duplicate tracker data structures and an
explicitly built dependency graph. Even a quick performance check
without any specialized benchmarks points to this part of the
implementation as the biggest bottleneck.

This PR specifically:
* eliminates a need to build a dependency graph as a data structure,
* updates the test suite (mainly, by fixing incorrect CHECK's referring
to a hardcoded order of definitions, contradicting the spec requirement
to allow certain definitions to go "in any order", see
https://registry.khronos.org/SPIR-V/specs/unified1/SPIRV.html#_logical_layout_of_a_module),
* improves function pointers implementation so that it now passes
EXPENSIVE_CHECKS (thus removing 3 XFAIL's in the test suite).

As a quick sanity check of whether goals of the PR are achieved, we can
measure time of translation for any big LLVM IR. While testing the PR in
the local development environment, improvements of the x5 order have
been observed.

For example, the SYCL test case "group barrier" that is a ~1Mb binary IR
input shows the following values of the naive performance metric that we
can nevertheless apply here to roughly estimate effects of the PR.

before the PR:
```
$ time llc -O0 -mtriple=spirv64v1.6-unknown-unknown _group_barrier_phi.bc -o 1 --filetype=obj

real    3m33.241s
user    3m14.688s
sys     0m18.530s
```

after the PR

```
$ time llc -O0 -mtriple=spirv64v1.6-unknown-unknown _group_barrier_phi.bc -o 1 --filetype=obj

real    0m42.031s
user    0m38.834s
sys     0m3.193s
```

Follow-up work should probably revisit the Duplicate Tracker, since it
now needs to be analyzed to determine which parts of it are no longer
necessary after the change of approach to the implementation of the
module analysis step.
2025-01-07 10:42:23 +01:00

1707 lines
68 KiB
C++

//===-- SPIRVGlobalRegistry.cpp - SPIR-V Global Registry --------*- C++ -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file contains the implementation of the SPIRVGlobalRegistry class,
// which is used to maintain rich type information required for SPIR-V even
// after lowering from LLVM IR to GMIR. It can convert an llvm::Type into
// an OpTypeXXX instruction, and map it to a virtual register. Also it builds
// and supports consistency of constants and global variables.
//
//===----------------------------------------------------------------------===//
#include "SPIRVGlobalRegistry.h"
#include "SPIRV.h"
#include "SPIRVBuiltins.h"
#include "SPIRVSubtarget.h"
#include "SPIRVTargetMachine.h"
#include "SPIRVUtils.h"
#include "llvm/ADT/APInt.h"
#include "llvm/IR/Constants.h"
#include "llvm/IR/Type.h"
#include "llvm/Support/Casting.h"
#include <cassert>
#include <functional>
using namespace llvm;
inline unsigned typeToAddressSpace(const Type *Ty) {
if (auto PType = dyn_cast<TypedPointerType>(Ty))
return PType->getAddressSpace();
if (auto PType = dyn_cast<PointerType>(Ty))
return PType->getAddressSpace();
if (auto *ExtTy = dyn_cast<TargetExtType>(Ty);
ExtTy && isTypedPointerWrapper(ExtTy))
return ExtTy->getIntParameter(0);
report_fatal_error("Unable to convert LLVM type to SPIRVType", true);
}
/// Constructs a registry that remembers \p PointerSize and starts with the
/// Bound member set to zero.
SPIRVGlobalRegistry::SPIRVGlobalRegistry(unsigned PointerSize)
: PointerSize(PointerSize), Bound(0) {}
/// Obtains the SPIR-V integer type of \p BitWidth bits and records it as the
/// type of \p VReg in the current machine function.
/// \returns the SPIR-V type that was assigned.
SPIRVType *SPIRVGlobalRegistry::assignIntTypeToVReg(unsigned BitWidth,
                                                    Register VReg,
                                                    MachineInstr &I,
                                                    const SPIRVInstrInfo &TII) {
  SPIRVType *const IntTy = getOrCreateSPIRVIntegerType(BitWidth, I, TII);
  assignSPIRVTypeToVReg(IntTy, VReg, *CurMF);
  return IntTy;
}
/// Obtains the SPIR-V floating-point type of \p BitWidth bits and records it
/// as the type of \p VReg in the current machine function.
/// \returns the SPIR-V type that was assigned.
SPIRVType *
SPIRVGlobalRegistry::assignFloatTypeToVReg(unsigned BitWidth, Register VReg,
                                           MachineInstr &I,
                                           const SPIRVInstrInfo &TII) {
  SPIRVType *const FloatTy = getOrCreateSPIRVFloatType(BitWidth, I, TII);
  assignSPIRVTypeToVReg(FloatTy, VReg, *CurMF);
  return FloatTy;
}
/// Obtains the SPIR-V vector type with \p NumElements elements of
/// \p BaseType and records it as the type of \p VReg in the current machine
/// function.
/// \returns the SPIR-V type that was assigned.
SPIRVType *SPIRVGlobalRegistry::assignVectTypeToVReg(
    SPIRVType *BaseType, unsigned NumElements, Register VReg, MachineInstr &I,
    const SPIRVInstrInfo &TII) {
  SPIRVType *const VecTy =
      getOrCreateSPIRVVectorType(BaseType, NumElements, I, TII);
  assignSPIRVTypeToVReg(VecTy, VReg, *CurMF);
  return VecTy;
}
/// Obtains the SPIR-V type corresponding to the LLVM type \p Type (forwarding
/// \p AccessQual and \p EmitIR) and records it as the type of \p VReg in the
/// machine function that \p MIRBuilder is attached to.
/// \returns the SPIR-V type that was assigned.
SPIRVType *SPIRVGlobalRegistry::assignTypeToVReg(
    const Type *Type, Register VReg, MachineIRBuilder &MIRBuilder,
    SPIRV::AccessQualifier::AccessQualifier AccessQual, bool EmitIR) {
  SPIRVType *const ResolvedTy =
      getOrCreateSPIRVType(Type, MIRBuilder, AccessQual, EmitIR);
  assignSPIRVTypeToVReg(ResolvedTy, VReg, MIRBuilder.getMF());
  return ResolvedTy;
}
/// Records \p SpirvType as the SPIR-V type of \p VReg within \p MF,
/// replacing any type previously stored for that register.
void SPIRVGlobalRegistry::assignSPIRVTypeToVReg(SPIRVType *SpirvType,
                                                Register VReg,
                                                const MachineFunction &MF) {
  auto &TypesOfMF = VRegToTypeMap[&MF];
  TypesOfMF[VReg] = SpirvType;
}
/// Creates a fresh 64-bit scalar generic virtual register in \p MRI and puts
/// it into the SPIR-V TYPE register class.
static Register createTypeVReg(MachineRegisterInfo &MRI) {
  const Register TypeReg = MRI.createGenericVirtualRegister(LLT::scalar(64));
  MRI.setRegClass(TypeReg, &SPIRV::TYPERegClass);
  return TypeReg;
}
/// Convenience overload: creates a type register in the register info of the
/// machine function that \p MIRBuilder is attached to.
inline Register createTypeVReg(MachineIRBuilder &MIRBuilder) {
  MachineRegisterInfo &MRI = MIRBuilder.getMF().getRegInfo();
  return createTypeVReg(MRI);
}
/// Emits an OpTypeBool definition via createOpType() and returns the new
/// instruction.
SPIRVType *SPIRVGlobalRegistry::getOpTypeBool(MachineIRBuilder &MIRBuilder) {
  auto EmitBoolType = [](MachineIRBuilder &Builder) {
    return Builder.buildInstr(SPIRV::OpTypeBool)
        .addDef(createTypeVReg(Builder));
  };
  return createOpType(MIRBuilder, EmitBoolType);
}
unsigned SPIRVGlobalRegistry::adjustOpTypeIntWidth(unsigned Width) const {
if (Width > 64)
report_fatal_error("Unsupported integer width!");
const SPIRVSubtarget &ST = cast<SPIRVSubtarget>(CurMF->getSubtarget());
if (ST.canUseExtension(
SPIRV::Extension::SPV_INTEL_arbitrary_precision_integers))
return Width;
if (Width <= 8)
Width = 8;
else if (Width <= 16)
Width = 16;
else if (Width <= 32)
Width = 32;
else
Width = 64;
return Width;
}
/// Emits an OpTypeInt definition for the (adjusted) \p Width and signedness
/// \p IsSigned via createOpType().
///
/// If the subtarget can use SPV_INTEL_arbitrary_precision_integers, the
/// matching OpExtension and OpCapability instructions are emitted right
/// before the type definition.
SPIRVType *SPIRVGlobalRegistry::getOpTypeInt(unsigned Width,
                                             MachineIRBuilder &MIRBuilder,
                                             bool IsSigned) {
  Width = adjustOpTypeIntWidth(Width);
  const auto &Subtarget =
      cast<SPIRVSubtarget>(MIRBuilder.getMF().getSubtarget());

  auto EmitIntType = [&](MachineIRBuilder &Builder) {
    const bool UsesArbitraryPrecision = Subtarget.canUseExtension(
        SPIRV::Extension::SPV_INTEL_arbitrary_precision_integers);
    if (UsesArbitraryPrecision) {
      Builder.buildInstr(SPIRV::OpExtension)
          .addImm(SPIRV::Extension::SPV_INTEL_arbitrary_precision_integers);
      Builder.buildInstr(SPIRV::OpCapability)
          .addImm(SPIRV::Capability::ArbitraryPrecisionIntegersINTEL);
    }
    return Builder.buildInstr(SPIRV::OpTypeInt)
        .addDef(createTypeVReg(Builder))
        .addImm(Width)
        .addImm(IsSigned ? 1 : 0);
  };
  return createOpType(MIRBuilder, EmitIntType);
}
/// Emits an OpTypeFloat definition of \p Width bits via createOpType() and
/// returns the new instruction.
SPIRVType *SPIRVGlobalRegistry::getOpTypeFloat(uint32_t Width,
                                               MachineIRBuilder &MIRBuilder) {
  auto EmitFloatType = [Width](MachineIRBuilder &Builder) {
    return Builder.buildInstr(SPIRV::OpTypeFloat)
        .addDef(createTypeVReg(Builder))
        .addImm(Width);
  };
  return createOpType(MIRBuilder, EmitFloatType);
}
/// Emits an OpTypeVoid definition via createOpType() and returns the new
/// instruction.
SPIRVType *SPIRVGlobalRegistry::getOpTypeVoid(MachineIRBuilder &MIRBuilder) {
  auto EmitVoidType = [](MachineIRBuilder &Builder) {
    return Builder.buildInstr(SPIRV::OpTypeVoid)
        .addDef(createTypeVReg(Builder));
  };
  return createOpType(MIRBuilder, EmitVoidType);
}
/// Forgets \p MI if it is the instruction recorded in LastInsertedTypeMap for
/// its parent machine function, so that createOpType() does not use it as an
/// insertion anchor afterwards.
void SPIRVGlobalRegistry::invalidateMachineInstr(MachineInstr *MI) {
  // TODO:
  // - take into account duplicate tracker case which is a known issue,
  // - review other data structure wrt. possible issues related to removal
  //   of a machine instruction during instruction selection.
  const MachineFunction *ParentMF = MI->getParent()->getParent();
  auto Entry = LastInsertedTypeMap.find(ParentMF);
  const bool IsLastInserted =
      Entry != LastInsertedTypeMap.end() && Entry->second == MI;
  if (IsLastInserted)
    LastInsertedTypeMap.erase(ParentMF);
}
/// Runs \p Op with the builder's insertion point temporarily moved into the
/// first basic block of the builder's machine function, then restores the
/// original insertion point.
///
/// The insertion position is right after the instruction most recently
/// produced through this function for CurMF (tracked in LastInsertedTypeMap),
/// or the beginning of the first block when nothing has been recorded yet.
///
/// \param MIRBuilder builder whose insertion point is temporarily redirected.
/// \param Op callback that emits the definition and returns the instruction.
/// \returns the instruction returned by \p Op.
SPIRVType *SPIRVGlobalRegistry::createOpType(
MachineIRBuilder &MIRBuilder,
std::function<MachineInstr *(MachineIRBuilder &)> Op) {
// Remember where the caller was inserting so it can be restored at the end.
auto oldInsertPoint = MIRBuilder.getInsertPt();
MachineBasicBlock *OldMBB = &MIRBuilder.getMBB();
// All definitions emitted here go into the first block of the function.
MachineBasicBlock *NewMBB = &*MIRBuilder.getMF().begin();
auto LastInsertedType = LastInsertedTypeMap.find(CurMF);
if (LastInsertedType != LastInsertedTypeMap.end()) {
auto It = LastInsertedType->second->getIterator();
// It might happen that this instruction was removed from the first MBB,
// hence the Parent's check.
MachineBasicBlock::iterator InsertAt;
if (It->getParent() != NewMBB)
// The recorded anchor is no longer in the first block: reuse the caller's
// position if it is already in that block, otherwise fall back to the
// position computed by getInsertPtValidEnd().
InsertAt = oldInsertPoint->getParent() == NewMBB
? oldInsertPoint
: getInsertPtValidEnd(NewMBB);
else if (It->getNextNode())
// Insert right after the previously emitted definition.
InsertAt = It->getNextNode()->getIterator();
else
InsertAt = getInsertPtValidEnd(NewMBB);
MIRBuilder.setInsertPt(*NewMBB, InsertAt);
} else {
// Nothing recorded for CurMF yet: start at the top of the first block and
// create the map entry that is filled in below.
MIRBuilder.setInsertPt(*NewMBB, NewMBB->begin());
auto Result = LastInsertedTypeMap.try_emplace(CurMF, nullptr);
assert(Result.second);
LastInsertedType = Result.first;
}
MachineInstr *Type = Op(MIRBuilder);
// We expect all users of this function to insert definitions at the insertion
// point set above that is always the first MBB.
assert(Type->getParent() == NewMBB);
// Record the new instruction as the anchor for the next call.
LastInsertedType->second = Type;
MIRBuilder.setInsertPt(*OldMBB, oldInsertPoint);
return Type;
}
/// Emits an OpTypeVector definition with \p NumElems elements of \p ElemType
/// via createOpType().
///
/// \p ElemType is asserted to be an integer, floating-point or boolean type.
SPIRVType *SPIRVGlobalRegistry::getOpTypeVector(uint32_t NumElems,
                                                SPIRVType *ElemType,
                                                MachineIRBuilder &MIRBuilder) {
  [[maybe_unused]] const auto ElemOpcode = ElemType->getOpcode();
  assert((ElemOpcode == SPIRV::OpTypeInt ||
          ElemOpcode == SPIRV::OpTypeFloat ||
          ElemOpcode == SPIRV::OpTypeBool) &&
         "Invalid vector element type");

  auto EmitVectorType = [&](MachineIRBuilder &Builder) {
    return Builder.buildInstr(SPIRV::OpTypeVector)
        .addDef(createTypeVReg(Builder))
        .addUse(getSPIRVTypeID(ElemType))
        .addImm(NumElems);
  };
  return createOpType(MIRBuilder, EmitVectorType);
}
std::tuple<Register, ConstantInt *, bool, unsigned>
SPIRVGlobalRegistry::getOrCreateConstIntReg(uint64_t Val, SPIRVType *SpvType,
MachineIRBuilder *MIRBuilder,
MachineInstr *I,
const SPIRVInstrInfo *TII) {
assert(SpvType);
const IntegerType *LLVMIntTy =
cast<IntegerType>(getTypeForSPIRVType(SpvType));
unsigned BitWidth = getScalarOrVectorBitWidth(SpvType);
bool NewInstr = false;
// Find a constant in DT or build a new one.
ConstantInt *CI = ConstantInt::get(const_cast<IntegerType *>(LLVMIntTy), Val);
Register Res = DT.find(CI, CurMF);
if (!Res.isValid()) {
Res =
CurMF->getRegInfo().createGenericVirtualRegister(LLT::scalar(BitWidth));
CurMF->getRegInfo().setRegClass(Res, &SPIRV::iIDRegClass);
if (MIRBuilder)
assignTypeToVReg(LLVMIntTy, Res, *MIRBuilder);
else
assignIntTypeToVReg(BitWidth, Res, *I, *TII);
DT.add(CI, CurMF, Res);
NewInstr = true;
}
return std::make_tuple(Res, CI, NewInstr, BitWidth);
}
/// Finds in the duplicate tracker, or allocates, the register that holds the
/// floating-point constant \p Val in the current function.
///
/// A newly allocated register is put into the fID register class, given its
/// SPIR-V type (through \p MIRBuilder when provided, otherwise through \p I
/// and \p TII) and registered in the duplicate tracker. No constant
/// instruction is emitted here.
///
/// \returns a tuple of: the register, the LLVM constant used as the lookup
/// key, whether the register was newly created, and the bit width of
/// \p SpvType.
std::tuple<Register, ConstantFP *, bool, unsigned>
SPIRVGlobalRegistry::getOrCreateConstFloatReg(APFloat Val, SPIRVType *SpvType,
                                              MachineIRBuilder *MIRBuilder,
                                              MachineInstr *I,
                                              const SPIRVInstrInfo *TII) {
  assert(SpvType);
  LLVMContext &Ctx = CurMF->getFunction().getContext();
  const Type *LLVMFloatTy = getTypeForSPIRVType(SpvType);
  const unsigned BitWidth = getScalarOrVectorBitWidth(SpvType);

  // Find a constant in DT or build a new one.
  auto *const CFP = ConstantFP::get(Ctx, Val);
  Register Res = DT.find(CFP, CurMF);
  if (Res.isValid())
    return std::make_tuple(Res, CFP, false, BitWidth);

  MachineRegisterInfo &MRI = CurMF->getRegInfo();
  Res = MRI.createGenericVirtualRegister(LLT::scalar(BitWidth));
  MRI.setRegClass(Res, &SPIRV::fIDRegClass);
  if (MIRBuilder)
    assignTypeToVReg(LLVMFloatTy, Res, *MIRBuilder);
  else
    assignFloatTypeToVReg(BitWidth, Res, *I, *TII);
  DT.add(CFP, CurMF, Res);
  return std::make_tuple(Res, CFP, true, BitWidth);
}
/// Returns a register holding the floating-point constant \p Val of type
/// \p SpvType, emitting the defining instruction when needed.
///
/// An already tracked register is returned as is, unless it is the register
/// defined by operand 0 of \p I; in that case, and for a newly created
/// register, a definition is emitted via createOpType(): OpConstantNull when
/// \p Val is +0.0 and \p ZeroAsNull is set, OpConstantF otherwise.
Register SPIRVGlobalRegistry::getOrCreateConstFP(APFloat Val, MachineInstr &I,
                                                 SPIRVType *SpvType,
                                                 const SPIRVInstrInfo &TII,
                                                 bool ZeroAsNull) {
  assert(SpvType);
  ConstantFP *CFP;
  Register Res;
  bool IsNewReg;
  unsigned BitWidth;
  std::tie(Res, CFP, IsNewReg, BitWidth) =
      getOrCreateConstFloatReg(Val, SpvType, nullptr, &I, &TII);

  // If we have found Res register which is defined by the passed G_CONSTANT
  // machine instruction, a new constant instruction should be created.
  if (!IsNewReg) {
    const auto &DefOp = I.getOperand(0);
    const bool DefinedByI = DefOp.isReg() && DefOp.getReg() == Res;
    if (!DefinedByI)
      return Res;
  }

  MachineIRBuilder MIRBuilder(I);
  auto EmitConstant = [&](MachineIRBuilder &Builder) {
    MachineInstrBuilder MIB;
    // In OpenCL OpConstantNull - Scalar floating point: +0.0 (all bits 0)
    const bool EmitNull = Val.isPosZero() && ZeroAsNull;
    if (!EmitNull) {
      MIB = Builder.buildInstr(SPIRV::OpConstantF)
                .addDef(Res)
                .addUse(getSPIRVTypeID(SpvType));
      addNumImm(
          APInt(BitWidth, CFP->getValueAPF().bitcastToAPInt().getZExtValue()),
          MIB);
    } else {
      MIB = Builder.buildInstr(SPIRV::OpConstantNull)
                .addDef(Res)
                .addUse(getSPIRVTypeID(SpvType));
    }
    const auto &Subtarget = CurMF->getSubtarget();
    constrainSelectedInstRegOperands(*MIB, *Subtarget.getInstrInfo(),
                                     *Subtarget.getRegisterInfo(),
                                     *Subtarget.getRegBankInfo());
    return MIB;
  };
  createOpType(MIRBuilder, EmitConstant);
  return Res;
}
/// Returns a register holding the integer constant \p Val of type
/// \p SpvType, emitting the defining instruction when needed.
///
/// An already tracked register is returned as is, unless it is the register
/// defined by operand 0 of \p I; in that case, and for a newly created
/// register, a definition is emitted via createOpType(): OpConstantNull when
/// \p Val is zero and \p ZeroAsNull is set, OpConstantI otherwise.
Register SPIRVGlobalRegistry::getOrCreateConstInt(uint64_t Val, MachineInstr &I,
                                                  SPIRVType *SpvType,
                                                  const SPIRVInstrInfo &TII,
                                                  bool ZeroAsNull) {
  assert(SpvType);
  ConstantInt *CI;
  Register Res;
  bool IsNewReg;
  unsigned BitWidth;
  std::tie(Res, CI, IsNewReg, BitWidth) =
      getOrCreateConstIntReg(Val, SpvType, nullptr, &I, &TII);

  // If we have found Res register which is defined by the passed G_CONSTANT
  // machine instruction, a new constant instruction should be created.
  if (!IsNewReg) {
    const auto &DefOp = I.getOperand(0);
    const bool DefinedByI = DefOp.isReg() && DefOp.getReg() == Res;
    if (!DefinedByI)
      return Res;
  }

  MachineIRBuilder MIRBuilder(I);
  auto EmitConstant = [&](MachineIRBuilder &Builder) {
    MachineInstrBuilder MIB;
    const bool EmitNull = !Val && ZeroAsNull;
    if (EmitNull) {
      MIB = Builder.buildInstr(SPIRV::OpConstantNull)
                .addDef(Res)
                .addUse(getSPIRVTypeID(SpvType));
    } else {
      MIB = Builder.buildInstr(SPIRV::OpConstantI)
                .addDef(Res)
                .addUse(getSPIRVTypeID(SpvType));
      addNumImm(APInt(BitWidth, Val), MIB);
    }
    const auto &Subtarget = CurMF->getSubtarget();
    constrainSelectedInstRegOperands(*MIB, *Subtarget.getInstrInfo(),
                                     *Subtarget.getRegisterInfo(),
                                     *Subtarget.getRegBankInfo());
    return MIB;
  };
  createOpType(MIRBuilder, EmitConstant);
  return Res;
}
/// Returns a register holding the integer constant \p Val of type
/// \p SpvType in the builder's machine function, creating it if it is not
/// yet in the duplicate tracker.
///
/// With \p EmitIR a generic constant is built at the builder's current
/// position. Otherwise a SPIR-V definition is emitted via createOpType():
/// OpConstantNull when \p Val is zero and \p ZeroAsNull is set, OpConstantI
/// otherwise.
Register SPIRVGlobalRegistry::buildConstantInt(uint64_t Val,
                                               MachineIRBuilder &MIRBuilder,
                                               SPIRVType *SpvType, bool EmitIR,
                                               bool ZeroAsNull) {
  assert(SpvType);
  auto &MF = MIRBuilder.getMF();
  const IntegerType *LLVMIntTy =
      cast<IntegerType>(getTypeForSPIRVType(SpvType));

  // Find a constant in DT or build a new one.
  const auto ConstInt =
      ConstantInt::get(const_cast<IntegerType *>(LLVMIntTy), Val);
  Register Res = DT.find(ConstInt, &MF);
  if (Res.isValid())
    return Res;

  const unsigned BitWidth = getScalarOrVectorBitWidth(SpvType);
  Res = MF.getRegInfo().createGenericVirtualRegister(LLT::scalar(BitWidth));
  MF.getRegInfo().setRegClass(Res, &SPIRV::iIDRegClass);
  assignTypeToVReg(LLVMIntTy, Res, MIRBuilder,
                   SPIRV::AccessQualifier::ReadWrite, EmitIR);
  DT.add(ConstInt, &MIRBuilder.getMF(), Res);

  if (EmitIR) {
    MIRBuilder.buildConstant(Res, *ConstInt);
    return Res;
  }

  Register SpvTypeReg = getSPIRVTypeID(SpvType);
  auto EmitConstant = [&](MachineIRBuilder &Builder) {
    MachineInstrBuilder MIB;
    const bool EmitNull = !Val && ZeroAsNull;
    if (EmitNull) {
      MIB = Builder.buildInstr(SPIRV::OpConstantNull)
                .addDef(Res)
                .addUse(SpvTypeReg);
    } else {
      MIB = Builder.buildInstr(SPIRV::OpConstantI)
                .addDef(Res)
                .addUse(SpvTypeReg);
      addNumImm(APInt(BitWidth, Val), MIB);
    }
    const auto &ST = CurMF->getSubtarget();
    constrainSelectedInstRegOperands(*MIB, *ST.getInstrInfo(),
                                     *ST.getRegisterInfo(),
                                     *ST.getRegBankInfo());
    return MIB;
  };
  createOpType(MIRBuilder, EmitConstant);
  return Res;
}
/// Returns a register holding the floating-point constant \p Val, emitting
/// an OpConstantF definition via createOpType() if the constant is not yet
/// in the duplicate tracker for the builder's machine function.
///
/// When \p SpvType is null, the SPIR-V type for LLVM's `float` is used.
Register SPIRVGlobalRegistry::buildConstantFP(APFloat Val,
                                              MachineIRBuilder &MIRBuilder,
                                              SPIRVType *SpvType) {
  auto &MF = MIRBuilder.getMF();
  auto &Ctx = MF.getFunction().getContext();
  if (!SpvType)
    SpvType = getOrCreateSPIRVType(Type::getFloatTy(Ctx), MIRBuilder);

  // Find a constant in DT or build a new one.
  const auto ConstFP = ConstantFP::get(Ctx, Val);
  Register Res = DT.find(ConstFP, &MF);
  if (Res.isValid())
    return Res;

  const LLT RegTy = LLT::scalar(getScalarOrVectorBitWidth(SpvType));
  Res = MF.getRegInfo().createGenericVirtualRegister(RegTy);
  MF.getRegInfo().setRegClass(Res, &SPIRV::fIDRegClass);
  assignSPIRVTypeToVReg(SpvType, Res, MF);
  DT.add(ConstFP, &MF, Res);

  auto EmitConstant = [&](MachineIRBuilder &Builder) {
    MachineInstrBuilder MIB = Builder.buildInstr(SPIRV::OpConstantF)
                                  .addDef(Res)
                                  .addUse(getSPIRVTypeID(SpvType));
    addNumImm(ConstFP->getValueAPF().bitcastToAPInt(), MIB);
    return MIB;
  };
  createOpType(MIRBuilder, EmitConstant);
  return Res;
}
/// Creates (or reuses) the scalar constant \p Val that serves as the element
/// value of a constant of type \p SpvType.
///
/// For vector and array types the element type is taken from operand 1 of
/// the type instruction; any other type is used directly. A floating-point
/// element type yields a constant from getOrCreateConstFP(); otherwise the
/// element type is asserted to be an integer type and the constant comes
/// from getOrCreateConstInt().
///
/// \param Val the scalar LLVM constant; must be a ConstantFP when the element
///        type is floating-point.
/// \param I, TII forwarded to the type and constant creation helpers.
/// \param BitWidth width of the scalar base type to create.
/// \param ZeroAsNull forwarded to the constant creation helpers.
/// \returns the register holding the scalar constant.
Register SPIRVGlobalRegistry::getOrCreateBaseRegister(
    Constant *Val, MachineInstr &I, SPIRVType *SpvType,
    const SPIRVInstrInfo &TII, unsigned BitWidth, bool ZeroAsNull) {
  SPIRVType *ElemTy = SpvType;
  if (SpvType->getOpcode() == SPIRV::OpTypeVector ||
      SpvType->getOpcode() == SPIRV::OpTypeArray) {
    Register EleTypeReg = SpvType->getOperand(1).getReg();
    ElemTy = getSPIRVTypeForVReg(EleTypeReg);
  }
  if (ElemTy->getOpcode() == SPIRV::OpTypeFloat) {
    SPIRVType *SpvBaseType = getOrCreateSPIRVFloatType(BitWidth, I, TII);
    // The value must be a ConstantFP here, so use cast<> (which asserts on a
    // mismatch) rather than dereferencing an unchecked dyn_cast<> result.
    return getOrCreateConstFP(cast<ConstantFP>(Val)->getValue(), I,
                              SpvBaseType, TII, ZeroAsNull);
  }
  assert(ElemTy->getOpcode() == SPIRV::OpTypeInt);
  SPIRVType *SpvBaseType = getOrCreateSPIRVIntegerType(BitWidth, I, TII);
  return getOrCreateConstInt(Val->getUniqueInteger().getZExtValue(), I,
                             SpvBaseType, TII, ZeroAsNull);
}
/// Returns a register holding a composite constant of type \p SpvType whose
/// \p ElemCnt elements all equal \p Val, using \p CA as the duplicate
/// tracker key.
///
/// If the key is not tracked yet, a definition is emitted via
/// createOpType(): OpConstantNull when \p Val is a null value and
/// \p ZeroAsNull is set, otherwise OpConstantComposite that repeats the
/// scalar constant \p ElemCnt times.
Register SPIRVGlobalRegistry::getOrCreateCompositeOrNull(
    Constant *Val, MachineInstr &I, SPIRVType *SpvType,
    const SPIRVInstrInfo &TII, Constant *CA, unsigned BitWidth,
    unsigned ElemCnt, bool ZeroAsNull) {
  // Find a constant vector or array in DT or build a new one.
  const Register Cached = DT.find(CA, CurMF);
  // If no values are attached, the composite is null constant.
  const bool IsNull = Val->isNullValue() && ZeroAsNull;
  if (Cached.isValid())
    return Cached;

  // The scalar constant should be created before the composite one to avoid
  // undefined ID error on validation.
  // TODO: can moved below once sorting of types/consts/defs is implemented.
  Register ElemConst;
  if (!IsNull)
    ElemConst =
        getOrCreateBaseRegister(Val, I, SpvType, TII, BitWidth, ZeroAsNull);

  const Register CompositeReg =
      CurMF->getRegInfo().createGenericVirtualRegister(LLT::scalar(64));
  CurMF->getRegInfo().setRegClass(CompositeReg, getRegClass(SpvType));
  assignSPIRVTypeToVReg(SpvType, CompositeReg, *CurMF);
  DT.add(CA, CurMF, CompositeReg);

  MachineIRBuilder MIRBuilder(I);
  auto EmitComposite = [&](MachineIRBuilder &Builder) {
    MachineInstrBuilder MIB;
    if (IsNull) {
      MIB = Builder.buildInstr(SPIRV::OpConstantNull)
                .addDef(CompositeReg)
                .addUse(getSPIRVTypeID(SpvType));
    } else {
      MIB = Builder.buildInstr(SPIRV::OpConstantComposite)
                .addDef(CompositeReg)
                .addUse(getSPIRVTypeID(SpvType));
      for (unsigned Idx = 0; Idx != ElemCnt; ++Idx)
        MIB.addUse(ElemConst);
    }
    const auto &ST = CurMF->getSubtarget();
    constrainSelectedInstRegOperands(*MIB, *ST.getInstrInfo(),
                                     *ST.getRegisterInfo(),
                                     *ST.getRegBankInfo());
    return MIB;
  };
  createOpType(MIRBuilder, EmitComposite);
  return CompositeReg;
}
/// Returns a register holding an integer vector constant of type \p SpvType
/// with every element equal to \p Val.
///
/// The splat LLVM constant is used as the duplicate tracker key and the
/// element count is read from operand 2 of \p SpvType.
Register SPIRVGlobalRegistry::getOrCreateConstVector(uint64_t Val,
                                                     MachineInstr &I,
                                                     SPIRVType *SpvType,
                                                     const SPIRVInstrInfo &TII,
                                                     bool ZeroAsNull) {
  const Type *LLVMTy = getTypeForSPIRVType(SpvType);
  assert(LLVMTy->isVectorTy());
  const auto *VecTy = cast<FixedVectorType>(LLVMTy);
  Type *ElemTy = VecTy->getElementType();
  assert(ElemTy->isIntegerTy());

  auto *ElemConst = ConstantInt::get(ElemTy, Val);
  auto *SplatConst =
      ConstantVector::getSplat(VecTy->getElementCount(), ElemConst);
  const unsigned BitWidth = getScalarOrVectorBitWidth(SpvType);
  const auto NumElems = SpvType->getOperand(2).getImm();
  return getOrCreateCompositeOrNull(ElemConst, I, SpvType, TII, SplatConst,
                                    BitWidth, NumElems, ZeroAsNull);
}
/// Returns a register holding a floating-point vector constant of type
/// \p SpvType with every element equal to \p Val.
///
/// The splat LLVM constant is used as the duplicate tracker key and the
/// element count is read from operand 2 of \p SpvType.
Register SPIRVGlobalRegistry::getOrCreateConstVector(APFloat Val,
                                                     MachineInstr &I,
                                                     SPIRVType *SpvType,
                                                     const SPIRVInstrInfo &TII,
                                                     bool ZeroAsNull) {
  const Type *LLVMTy = getTypeForSPIRVType(SpvType);
  assert(LLVMTy->isVectorTy());
  const auto *VecTy = cast<FixedVectorType>(LLVMTy);
  Type *ElemTy = VecTy->getElementType();
  assert(ElemTy->isFloatingPointTy());

  auto *ElemConst = ConstantFP::get(ElemTy, Val);
  auto *SplatConst =
      ConstantVector::getSplat(VecTy->getElementCount(), ElemConst);
  const unsigned BitWidth = getScalarOrVectorBitWidth(SpvType);
  const auto NumElems = SpvType->getOperand(2).getImm();
  return getOrCreateCompositeOrNull(ElemConst, I, SpvType, TII, SplatConst,
                                    BitWidth, NumElems, ZeroAsNull);
}
/// Returns a register holding an integer array constant of type \p SpvType
/// with every element equal to \p Val.
///
/// The duplicate tracker key is an anonymous struct constant built from a
/// poison value of the array type, \p Val and \p Num. The number of emitted
/// elements is the element count of the LLVM array type.
Register SPIRVGlobalRegistry::getOrCreateConstIntArray(
    uint64_t Val, size_t Num, MachineInstr &I, SPIRVType *SpvType,
    const SPIRVInstrInfo &TII) {
  const Type *LLVMTy = getTypeForSPIRVType(SpvType);
  assert(LLVMTy->isArrayTy());
  const auto *ArrTy = cast<ArrayType>(LLVMTy);
  Type *ElemTy = ArrTy->getElementType();
  Constant *ElemConst = ConstantInt::get(ElemTy, Val);

  SPIRVType *SpvElemTy = getSPIRVTypeForVReg(SpvType->getOperand(1).getReg());
  const unsigned BitWidth = getScalarOrVectorBitWidth(SpvElemTy);

  // The key below is reasonably unique and better than using [Val] alone.
  // The naive alternative would be something along the lines of:
  //   SmallVector<Constant *> NumCI(Num, CI);
  //   Constant *UniqueKey =
  //     ConstantArray::get(const_cast<ArrayType*>(LLVMArrTy), NumCI);
  // which would be a truly unique but dangerous key, because it could lead
  // to the creation of constants of arbitrary length (that is, the parameter
  // of memset) which were missing in the original module.
  Constant *UniqueKey = ConstantStruct::getAnon(
      {PoisonValue::get(const_cast<ArrayType *>(ArrTy)),
       ConstantInt::get(ElemTy, Val), ConstantInt::get(ElemTy, Num)});
  return getOrCreateCompositeOrNull(ElemConst, I, SpvType, TII, UniqueKey,
                                    BitWidth, ArrTy->getNumElements());
}
/// Returns a register holding an integer composite constant of type
/// \p SpvType whose \p ElemCnt elements all equal \p Val, using \p CA as the
/// duplicate tracker key.
///
/// If the key is not tracked yet: with \p EmitIR a splat build-vector is
/// built at the builder's current position; otherwise a definition is
/// emitted via createOpType(), OpConstantComposite for a non-zero \p Val and
/// OpConstantNull for zero.
Register SPIRVGlobalRegistry::getOrCreateIntCompositeOrNull(
    uint64_t Val, MachineIRBuilder &MIRBuilder, SPIRVType *SpvType, bool EmitIR,
    Constant *CA, unsigned BitWidth, unsigned ElemCnt) {
  const Register Cached = DT.find(CA, CurMF);
  if (Cached.isValid())
    return Cached;

  // The scalar element constant is only needed for a non-zero value or when
  // generic IR is emitted.
  Register ElemConst;
  if (Val || EmitIR) {
    SPIRVType *SpvBaseType = getOrCreateSPIRVIntegerType(BitWidth, MIRBuilder);
    ElemConst = buildConstantInt(Val, MIRBuilder, SpvBaseType, EmitIR);
  }

  const LLT RegTy =
      EmitIR ? LLT::fixed_vector(ElemCnt, BitWidth) : LLT::scalar(64);
  const Register CompositeReg =
      CurMF->getRegInfo().createGenericVirtualRegister(RegTy);
  CurMF->getRegInfo().setRegClass(CompositeReg, &SPIRV::iIDRegClass);
  assignSPIRVTypeToVReg(SpvType, CompositeReg, *CurMF);
  DT.add(CA, CurMF, CompositeReg);

  if (EmitIR) {
    MIRBuilder.buildSplatBuildVector(CompositeReg, ElemConst);
    return CompositeReg;
  }

  auto EmitComposite = [&](MachineIRBuilder &Builder) {
    if (!Val)
      return Builder.buildInstr(SPIRV::OpConstantNull)
          .addDef(CompositeReg)
          .addUse(getSPIRVTypeID(SpvType));
    auto MIB = Builder.buildInstr(SPIRV::OpConstantComposite)
                   .addDef(CompositeReg)
                   .addUse(getSPIRVTypeID(SpvType));
    for (unsigned Idx = 0; Idx != ElemCnt; ++Idx)
      MIB.addUse(ElemConst);
    return MIB;
  };
  createOpType(MIRBuilder, EmitComposite);
  return CompositeReg;
}
/// Returns a register holding an integer vector constant of type \p SpvType
/// with every element equal to \p Val, built through \p MIRBuilder.
///
/// The splat LLVM constant is used as the duplicate tracker key and the
/// element count is read from operand 2 of \p SpvType.
Register
SPIRVGlobalRegistry::getOrCreateConsIntVector(uint64_t Val,
                                              MachineIRBuilder &MIRBuilder,
                                              SPIRVType *SpvType, bool EmitIR) {
  const Type *LLVMTy = getTypeForSPIRVType(SpvType);
  assert(LLVMTy->isVectorTy());
  const auto *VecTy = cast<FixedVectorType>(LLVMTy);
  Type *ElemTy = VecTy->getElementType();

  const auto ElemConst = ConstantInt::get(ElemTy, Val);
  auto SplatConst =
      ConstantVector::getSplat(VecTy->getElementCount(), ElemConst);
  const unsigned BitWidth = getScalarOrVectorBitWidth(SpvType);
  const auto NumElems = SpvType->getOperand(2).getImm();
  return getOrCreateIntCompositeOrNull(Val, MIRBuilder, SpvType, EmitIR,
                                       SplatConst, BitWidth, NumElems);
}
/// Returns a register holding the null pointer constant of the pointer type
/// \p SpvType, emitting OpConstantNull via createOpType() if it is not yet
/// in the duplicate tracker for the current function.
Register
SPIRVGlobalRegistry::getOrCreateConstNullPtr(MachineIRBuilder &MIRBuilder,
                                             SPIRVType *SpvType) {
  const Type *LLVMTy = getTypeForSPIRVType(SpvType);
  const unsigned AddressSpace = typeToAddressSpace(LLVMTy);

  // Find a constant in DT or build a new one.
  Constant *NullPtr = ConstantPointerNull::get(
      PointerType::get(::getPointeeType(LLVMTy), AddressSpace));
  Register Res = DT.find(NullPtr, CurMF);
  if (Res.isValid())
    return Res;

  const LLT PtrLLT = LLT::pointer(AddressSpace, PointerSize);
  Res = CurMF->getRegInfo().createGenericVirtualRegister(PtrLLT);
  CurMF->getRegInfo().setRegClass(Res, &SPIRV::pIDRegClass);
  assignSPIRVTypeToVReg(SpvType, Res, *CurMF);

  auto EmitNull = [&](MachineIRBuilder &Builder) {
    return Builder.buildInstr(SPIRV::OpConstantNull)
        .addDef(Res)
        .addUse(getSPIRVTypeID(SpvType));
  };
  createOpType(MIRBuilder, EmitNull);
  DT.add(NullPtr, CurMF, Res);
  return Res;
}
/// Emits an OpConstantSampler definition via createOpType().
///
/// The sampler type is derived from \p SpvType when it is given, otherwise
/// looked up by the name "opencl.sampler_t" (fatal error if that fails). The
/// result goes into \p ResReg when it is valid, otherwise into a fresh iID
/// virtual register.
/// \returns the register defined by the emitted instruction.
Register SPIRVGlobalRegistry::buildConstantSampler(
    Register ResReg, unsigned AddrMode, unsigned Param, unsigned FilerMode,
    MachineIRBuilder &MIRBuilder, SPIRVType *SpvType) {
  SPIRVType *SampTy = nullptr;
  if (SpvType) {
    SampTy = getOrCreateSPIRVType(getTypeForSPIRVType(SpvType), MIRBuilder);
  } else {
    SampTy = getOrCreateSPIRVTypeByName("opencl.sampler_t", MIRBuilder);
    if (SampTy == nullptr)
      report_fatal_error(
          "Unable to recognize SPIRV type name: opencl.sampler_t");
  }

  Register SamplerReg = ResReg;
  if (!ResReg.isValid())
    SamplerReg =
        MIRBuilder.getMRI()->createVirtualRegister(&SPIRV::iIDRegClass);

  auto EmitSampler = [&](MachineIRBuilder &Builder) {
    return Builder.buildInstr(SPIRV::OpConstantSampler)
        .addDef(SamplerReg)
        .addUse(getSPIRVTypeID(SampTy))
        .addImm(AddrMode)
        .addImm(Param)
        .addImm(FilerMode);
  };
  auto SamplerInstr = createOpType(MIRBuilder, EmitSampler);
  assert(SamplerInstr->getOperand(0).isReg());
  return SamplerInstr->getOperand(0).getReg();
}
/// Emits an OpVariable for a global variable, together with its name and
/// decorations, at the builder's current insertion point.
///
/// \param ResVReg register that receives the variable.
/// \param BaseType SPIR-V type used as the result type of the OpVariable.
/// \param Name name used to look up or create the LLVM global when \p GV is
///        null, for the linkage decoration, and for the built-in lookup.
/// \param GV the LLVM global; when null it is found by \p Name in the module
///        or created there with external linkage.
/// \param Storage storage class operand of the OpVariable.
/// \param Init optional initializer; operand 0 of that instruction is used.
/// \param IsConst adds the Constant decoration in the OpenCL environment.
/// \param HasLinkageTy adds a LinkageAttributes decoration with
///        \p LinkageType and \p Name.
/// \param IsInstSelector when set, register operands of the new instruction
///        are constrained, which may replace the result register.
/// \returns \p ResVReg if the global was already tracked for this function
///          (a copy from the tracked register is emitted when they differ),
///          otherwise the register defined by the new OpVariable.
Register SPIRVGlobalRegistry::buildGlobalVariable(
    Register ResVReg, SPIRVType *BaseType, StringRef Name,
    const GlobalValue *GV, SPIRV::StorageClass::StorageClass Storage,
    const MachineInstr *Init, bool IsConst, bool HasLinkageTy,
    SPIRV::LinkageType::LinkageType LinkageType, MachineIRBuilder &MIRBuilder,
    bool IsInstSelector) {
  const GlobalVariable *GVar = nullptr;
  if (GV) {
    GVar = cast<const GlobalVariable>(GV);
  } else {
    // If GV is not passed explicitly, use the name to find or construct
    // the global variable.
    Module *M = MIRBuilder.getMF().getFunction().getParent();
    GVar = M->getGlobalVariable(Name);
    if (GVar == nullptr) {
      const Type *Ty = getTypeForSPIRVType(BaseType); // TODO: check type.
      // Module takes ownership of the global var.
      GVar = new GlobalVariable(*M, const_cast<Type *>(Ty), false,
                                GlobalValue::ExternalLinkage, nullptr,
                                Twine(Name));
    }
  }

  // Reuse the variable if it has already been emitted for this function.
  Register Reg = DT.find(GVar, &MIRBuilder.getMF());
  if (Reg.isValid()) {
    if (Reg != ResVReg)
      MIRBuilder.buildCopy(ResVReg, Reg);
    return ResVReg;
  }

  auto MIB = MIRBuilder.buildInstr(SPIRV::OpVariable)
                 .addDef(ResVReg)
                 .addUse(getSPIRVTypeID(BaseType))
                 .addImm(static_cast<uint32_t>(Storage));
  if (Init)
    MIB.addUse(Init->getOperand(0).getReg());

  // ISel may introduce a new register on this step, so we need to add it to
  // DT and correct its type avoiding fails on the next stage.
  if (IsInstSelector) {
    const auto &Subtarget = CurMF->getSubtarget();
    constrainSelectedInstRegOperands(*MIB, *Subtarget.getInstrInfo(),
                                     *Subtarget.getRegisterInfo(),
                                     *Subtarget.getRegBankInfo());
  }
  Reg = MIB->getOperand(0).getReg();
  DT.add(GVar, &MIRBuilder.getMF(), Reg);
  addGlobalObject(GVar, &MIRBuilder.getMF(), Reg);

  // Set to Reg the same type as ResVReg has.
  auto MRI = MIRBuilder.getMRI();
  if (Reg != ResVReg) {
    LLT RegLLTy =
        LLT::pointer(MRI->getType(ResVReg).getAddressSpace(), getPointerSize());
    MRI->setType(Reg, RegLLTy);
    assignSPIRVTypeToVReg(BaseType, Reg, MIRBuilder.getMF());
  } else {
    // Our knowledge about the type may be updated.
    // If that's the case, we need to update a type
    // associated with the register.
    SPIRVType *DefType = getSPIRVTypeForVReg(ResVReg);
    if (!DefType || DefType != BaseType)
      assignSPIRVTypeToVReg(BaseType, Reg, MIRBuilder.getMF());
  }

  // If it's a global variable with name, output OpName for it.
  if (GVar && GVar->hasName())
    buildOpName(Reg, GVar->getName(), MIRBuilder);

  // Output decorations for the GV.
  // TODO: maybe move to GenerateDecorations pass.
  const SPIRVSubtarget &ST =
      cast<SPIRVSubtarget>(MIRBuilder.getMF().getSubtarget());
  if (IsConst && ST.isOpenCLEnv())
    buildOpDecorate(Reg, MIRBuilder, SPIRV::Decoration::Constant, {});
  if (GVar) {
    // An alignment of one (also used when none is specified) needs no
    // decoration.
    const auto AlignValue = GVar->getAlign().valueOrOne().value();
    if (AlignValue != 1) {
      unsigned Alignment = static_cast<unsigned>(AlignValue);
      buildOpDecorate(Reg, MIRBuilder, SPIRV::Decoration::Alignment,
                      {Alignment});
    }
  }
  if (HasLinkageTy)
    buildOpDecorate(Reg, MIRBuilder, SPIRV::Decoration::LinkageAttributes,
                    {static_cast<uint32_t>(LinkageType)}, Name);

  SPIRV::BuiltIn::BuiltIn BuiltInId;
  if (getSpirvBuiltInIdByName(Name, BuiltInId))
    buildOpDecorate(Reg, MIRBuilder, SPIRV::Decoration::BuiltIn,
                    {static_cast<uint32_t>(BuiltInId)});

  // If it's a global variable with "spirv.Decorations" metadata node
  // recognize it as a SPIR-V friendly LLVM IR and parse "spirv.Decorations"
  // arguments.
  MDNode *GVarMD = nullptr;
  if (GVar && (GVarMD = GVar->getMetadata("spirv.Decorations")) != nullptr)
    buildOpSpirvDecorations(Reg, MIRBuilder, GVarMD);

  return Reg;
}
// Forward declaration: buildSpirvTypeName() below calls this helper for image
// and sampled-image types, and the helper in turn calls buildSpirvTypeName().
static std::string GetSpirvImageTypeName(const SPIRVType *Type,
MachineIRBuilder &MIRBuilder,
const std::string &Prefix);
/// Builds a textual name for the SPIR-V type \p Type.
///
/// Supported opcodes and the names produced:
///  - OpTypeSampledImage / OpTypeImage: see GetSpirvImageTypeName(), with the
///    prefix "sampled_image_" or "image_";
///  - OpTypeArray: "<element name>[<size>]";
///  - OpTypeFloat: "f<width>";
///  - OpTypeSampler: "sampler";
///  - OpTypeInt: "i<width>" when operand 2 is non-zero, "u<width>" otherwise.
/// Any other opcode is treated as unreachable.
static std::string buildSpirvTypeName(const SPIRVType *Type,
                                      MachineIRBuilder &MIRBuilder) {
  switch (Type->getOpcode()) {
  case SPIRV::OpTypeSampledImage: {
    return GetSpirvImageTypeName(Type, MIRBuilder, "sampled_image_");
  }
  case SPIRV::OpTypeImage: {
    return GetSpirvImageTypeName(Type, MIRBuilder, "image_");
  }
  case SPIRV::OpTypeArray: {
    MachineRegisterInfo *MRI = MIRBuilder.getMRI();
    Register ElementTypeReg = Type->getOperand(1).getReg();
    auto *ElementType = MRI->getUniqueVRegDef(ElementTypeReg);
    // The array length is read through two levels of definitions: operand 2
    // of the array type, then operand 1 of that instruction, which is
    // expected to be defined by a G_CONSTANT.
    const SPIRVType *TypeInst = MRI->getVRegDef(Type->getOperand(2).getReg());
    assert(TypeInst->getOpcode() != SPIRV::OpConstantI);
    MachineInstr *ImmInst = MRI->getVRegDef(TypeInst->getOperand(1).getReg());
    assert(ImmInst->getOpcode() == TargetOpcode::G_CONSTANT);
    uint32_t ArraySize = ImmInst->getOperand(1).getCImm()->getZExtValue();
    return (buildSpirvTypeName(ElementType, MIRBuilder) + Twine("[") +
            Twine(ArraySize) + Twine("]"))
        .str();
  }
  case SPIRV::OpTypeFloat:
    return ("f" + Twine(Type->getOperand(1).getImm())).str();
  case SPIRV::OpTypeSampler:
    return ("sampler");
  case SPIRV::OpTypeInt:
    if (Type->getOperand(2).getImm())
      return ("i" + Twine(Type->getOperand(1).getImm())).str();
    return ("u" + Twine(Type->getOperand(1).getImm())).str();
  default:
    llvm_unreachable("Trying to get the name of an unknown type.");
  }
}
static std::string GetSpirvImageTypeName(const SPIRVType *Type,
MachineIRBuilder &MIRBuilder,
const std::string &Prefix) {
Register SampledTypeReg = Type->getOperand(1).getReg();
auto *SampledType = MIRBuilder.getMRI()->getUniqueVRegDef(SampledTypeReg);
std::string TypeName = Prefix + buildSpirvTypeName(SampledType, MIRBuilder);
for (uint32_t I = 2; I < Type->getNumOperands(); ++I) {
TypeName = (TypeName + '_' + Twine(Type->getOperand(I).getImm())).str();
}
return TypeName;
}
/// Builds a UniformConstant global variable of type pointer-to-\p VarType in
/// a fresh iID virtual register and decorates that register with the
/// descriptor set \p Set and binding \p Binding.
///
/// The variable is named "__resource_<type name>_<Set>_<Binding>".
/// \returns the register created for the variable.
Register SPIRVGlobalRegistry::getOrCreateGlobalVariableWithBinding(
    const SPIRVType *VarType, uint32_t Set, uint32_t Binding,
    MachineIRBuilder &MIRBuilder) {
  SPIRVType *VarPtrTy = getOrCreateSPIRVPointerType(
      VarType, MIRBuilder, SPIRV::StorageClass::UniformConstant);
  const Register VarReg =
      MIRBuilder.getMRI()->createVirtualRegister(&SPIRV::iIDRegClass);

  // TODO: The name should come from the llvm-ir, but how that name will be
  // passed from the HLSL to the backend has not been decided. Using this place
  // holder for now.
  std::string Name = "__resource_";
  Name += buildSpirvTypeName(VarType, MIRBuilder);
  Name += "_";
  Name += Twine(Set).str();
  Name += "_";
  Name += Twine(Binding).str();

  buildGlobalVariable(VarReg, VarPtrTy, Name, nullptr,
                      SPIRV::StorageClass::UniformConstant, nullptr, false,
                      false, SPIRV::LinkageType::Import, MIRBuilder, false);

  buildOpDecorate(VarReg, MIRBuilder, SPIRV::Decoration::DescriptorSet, {Set});
  buildOpDecorate(VarReg, MIRBuilder, SPIRV::Decoration::Binding, {Binding});
  return VarReg;
}
/// Emits an OpTypeArray definition with \p NumElems elements of \p ElemType
/// via createOpType().
///
/// The length operand is a 32-bit integer constant built with
/// buildConstantInt() (honoring \p EmitIR). \p ElemType is asserted not to
/// be void.
SPIRVType *SPIRVGlobalRegistry::getOpTypeArray(uint32_t NumElems,
                                               SPIRVType *ElemType,
                                               MachineIRBuilder &MIRBuilder,
                                               bool EmitIR) {
  assert((ElemType->getOpcode() != SPIRV::OpTypeVoid) &&
         "Invalid array element type");
  SPIRVType *Int32Ty = getOrCreateSPIRVIntegerType(32, MIRBuilder);
  const Register LengthReg =
      buildConstantInt(NumElems, MIRBuilder, Int32Ty, EmitIR);

  auto EmitArrayType = [&](MachineIRBuilder &Builder) {
    return Builder.buildInstr(SPIRV::OpTypeArray)
        .addDef(createTypeVReg(Builder))
        .addUse(getSPIRVTypeID(ElemType))
        .addUse(LengthReg);
  };
  return createOpType(MIRBuilder, EmitArrayType);
}
/// Emits an OpTypeOpaque definition for the struct type \p Ty via
/// createOpType(), using the struct's name as the string operand, and also
/// emits an OpName for the result register.
///
/// \p Ty is asserted to have a name; without assertions an unnamed struct
/// falls back to an empty name.
SPIRVType *SPIRVGlobalRegistry::getOpTypeOpaque(const StructType *Ty,
                                                MachineIRBuilder &MIRBuilder) {
  assert(Ty->hasName());
  StringRef Name = "";
  if (Ty->hasName())
    Name = Ty->getName();

  const Register ResVReg = createTypeVReg(MIRBuilder);
  auto EmitOpaqueType = [&](MachineIRBuilder &Builder) {
    auto MIB = Builder.buildInstr(SPIRV::OpTypeOpaque).addDef(ResVReg);
    addStringImm(Name, MIB);
    buildOpName(ResVReg, Name, Builder);
    return MIB;
  };
  return createOpType(MIRBuilder, EmitOpaqueType);
}
// Emits an OpTypeStruct for Ty. Member types are resolved first (pointers are
// converted with toTypedPointer), then the struct instruction is built with
// one use operand per member. A named struct also gets an OpName and a packed
// struct gets the CPacked decoration.
SPIRVType *SPIRVGlobalRegistry::getOpTypeStruct(const StructType *Ty,
                                                MachineIRBuilder &MIRBuilder,
                                                bool EmitIR) {
  SmallVector<Register, 4> MemberTypeIds;
  for (const auto &Member : Ty->elements()) {
    SPIRVType *MemberTy = findSPIRVType(toTypedPointer(Member), MIRBuilder);
    assert(MemberTy && MemberTy->getOpcode() != SPIRV::OpTypeVoid &&
           "Invalid struct element type");
    MemberTypeIds.push_back(getSPIRVTypeID(MemberTy));
  }

  const Register ResVReg = createTypeVReg(MIRBuilder);
  auto EmitStructType = [&](MachineIRBuilder &Builder) {
    auto MIB = Builder.buildInstr(SPIRV::OpTypeStruct).addDef(ResVReg);
    for (const Register &MemberId : MemberTypeIds)
      MIB.addUse(MemberId);
    if (Ty->hasName())
      buildOpName(ResVReg, Ty->getName(), Builder);
    if (Ty->isPacked())
      buildOpDecorate(ResVReg, Builder, SPIRV::Decoration::CPacked, {});
    return MIB;
  };
  return createOpType(MIRBuilder, EmitStructType);
}
// Lowers a special opaque builtin type by delegating to
// SPIRV::lowerBuiltinType. Ty must satisfy isSpecialOpaqueType.
SPIRVType *SPIRVGlobalRegistry::getOrCreateSpecialType(
    const Type *Ty, MachineIRBuilder &MIRBuilder,
    SPIRV::AccessQualifier::AccessQualifier AccQual) {
  assert(isSpecialOpaqueType(Ty) && "Not a special opaque builtin type");
  SPIRVType *LoweredTy =
      SPIRV::lowerBuiltinType(Ty, AccQual, MIRBuilder, this);
  return LoweredTy;
}
// Emits an OpTypePointer to ElemType in storage class SC. When Reg is valid it
// is used as the result register; otherwise a fresh type register is created.
SPIRVType *SPIRVGlobalRegistry::getOpTypePointer(
    SPIRV::StorageClass::StorageClass SC, SPIRVType *ElemType,
    MachineIRBuilder &MIRBuilder, Register Reg) {
  const Register ResultReg = Reg.isValid() ? Reg : createTypeVReg(MIRBuilder);
  auto EmitPointerType = [&](MachineIRBuilder &Builder) {
    auto MIB = Builder.buildInstr(SPIRV::OpTypePointer);
    MIB.addDef(ResultReg);
    MIB.addImm(static_cast<uint32_t>(SC));
    MIB.addUse(getSPIRVTypeID(ElemType));
    return MIB;
  };
  return createOpType(MIRBuilder, EmitPointerType);
}
// Emits an OpTypeForwardPointer in storage class SC. The fresh type register
// is attached as a use operand (not a def), which is why getSPIRVTypeID reads
// the id of a forward pointer from its uses.
SPIRVType *SPIRVGlobalRegistry::getOpTypeForwardPointer(
    SPIRV::StorageClass::StorageClass SC, MachineIRBuilder &MIRBuilder) {
  auto EmitForwardPointer = [&](MachineIRBuilder &Builder) {
    auto MIB = Builder.buildInstr(SPIRV::OpTypeForwardPointer);
    MIB.addUse(createTypeVReg(Builder));
    MIB.addImm(static_cast<uint32_t>(SC));
    return MIB;
  };
  return createOpType(MIRBuilder, EmitForwardPointer);
}
// Emits an OpTypeFunction with return type RetType followed by one use
// operand per entry of ArgTypes, in order.
SPIRVType *SPIRVGlobalRegistry::getOpTypeFunction(
    SPIRVType *RetType, const SmallVectorImpl<SPIRVType *> &ArgTypes,
    MachineIRBuilder &MIRBuilder) {
  MachineInstrBuilder FnTy = MIRBuilder.buildInstr(SPIRV::OpTypeFunction);
  FnTy.addDef(createTypeVReg(MIRBuilder));
  FnTy.addUse(getSPIRVTypeID(RetType));
  for (unsigned Idx = 0, NumArgs = ArgTypes.size(); Idx != NumArgs; ++Idx)
    FnTy.addUse(getSPIRVTypeID(ArgTypes[Idx]));
  return FnTy;
}
// Returns the SPIR-V function type already recorded for Ty in the duplicate
// tracker, or builds one from RetType/ArgTypes, records it and returns it.
SPIRVType *SPIRVGlobalRegistry::getOrCreateOpTypeFunctionWithArgs(
    const Type *Ty, SPIRVType *RetType,
    const SmallVectorImpl<SPIRVType *> &ArgTypes,
    MachineIRBuilder &MIRBuilder) {
  if (Register Known = DT.find(Ty, &MIRBuilder.getMF()); Known.isValid())
    return getSPIRVTypeForVReg(Known);

  SPIRVType *FnTy = getOpTypeFunction(RetType, ArgTypes, MIRBuilder);
  DT.add(Ty, CurMF, getSPIRVTypeID(FnTy));
  return finishCreatingSPIRVType(Ty, FnTy);
}
// Looks up the SPIR-V type for Ty (after integer-width adjustment): first in
// the duplicate tracker, then among pending forward pointers. If neither has
// it, the type is created via restOfCreateSPIRVType.
SPIRVType *SPIRVGlobalRegistry::findSPIRVType(
    const Type *Ty, MachineIRBuilder &MIRBuilder,
    SPIRV::AccessQualifier::AccessQualifier AccQual, bool EmitIR) {
  const Type *AdjustedTy = adjustIntTypeByWidth(Ty);

  if (Register Known = DT.find(AdjustedTy, &MIRBuilder.getMF());
      Known.isValid())
    return getSPIRVTypeForVReg(Known);

  if (auto Fwd = ForwardPointerTypes.find(AdjustedTy);
      Fwd != ForwardPointerTypes.end())
    return Fwd->second;

  return restOfCreateSPIRVType(AdjustedTy, MIRBuilder, AccQual, EmitIR);
}
// Returns the register that identifies SpirvType. For OpTypeForwardPointer
// the id is the first use operand; for every other type instruction it is the
// first def operand.
Register SPIRVGlobalRegistry::getSPIRVTypeID(const SPIRVType *SpirvType) const {
  assert(SpirvType && "Attempting to get type id for nullptr type.");
  const bool IsForwardPointer =
      SpirvType->getOpcode() == SPIRV::OpTypeForwardPointer;
  Register TypeId;
  if (IsForwardPointer)
    TypeId = SpirvType->uses().begin()->getReg();
  else
    TypeId = SpirvType->defs().begin()->getReg();
  return TypeId;
}
// When the bit width of an LLVM integer type differs from the width SPIR-V
// will use for it, substitute an LLVM integer type of the adjusted width so
// that DuplicateTracker can keep SPIR-V types unique per LLVM type. Without
// this, SPIRVGlobalRegistry::getOpTypeInt() could create the same
// "OpTypeInt 8" for several LLVM integer types narrower than 8 bits, which
// would eventually produce duplicate type definitions because of the way
// DuplicateTracker reasons about uniqueness of type records.
// Non-integer types and 1-bit integers are returned unchanged.
const Type *SPIRVGlobalRegistry::adjustIntTypeByWidth(const Type *Ty) const {
  const auto *IntTy = dyn_cast<IntegerType>(Ty);
  if (!IntTy)
    return Ty;

  const unsigned SrcBitWidth = IntTy->getBitWidth();
  if (SrcBitWidth <= 1)
    return Ty;

  const unsigned BitWidth = adjustOpTypeIntWidth(SrcBitWidth);
  if (SrcBitWidth == BitWidth)
    return Ty;

  // Maybe change source LLVM type to keep DuplicateTracker consistent.
  return IntegerType::get(Ty->getContext(), BitWidth);
}
// Creates the SPIR-V type instruction that corresponds to the LLVM type Ty.
// The function dispatches on the kind of Ty: special opaque types, integers
// (1-bit integers become bool), floating point, void, fixed vectors, arrays,
// structs (opaque or regular) and function types. Anything that reaches the
// end of the dispatch is handled as a pointer. Component types are resolved
// through findSPIRVType / getOrCreateSPIRVType, so this function can be
// re-entered while a type is still being built.
SPIRVType *SPIRVGlobalRegistry::createSPIRVType(
const Type *Ty, MachineIRBuilder &MIRBuilder,
SPIRV::AccessQualifier::AccessQualifier AccQual, bool EmitIR) {
// Special opaque builtin types have their own lowering path.
if (isSpecialOpaqueType(Ty))
return getOrCreateSpecialType(Ty, MIRBuilder, AccQual);
// Reuse a type already recorded for (Ty, current machine function) in the
// duplicate tracker's type map.
auto &TypeToSPIRVTypeMap = DT.getTypes()->getAllUses();
auto t = TypeToSPIRVTypeMap.find(Ty);
if (t != TypeToSPIRVTypeMap.end()) {
auto tt = t->second.find(&MIRBuilder.getMF());
if (tt != t->second.end())
return getSPIRVTypeForVReg(tt->second);
}
// Integers: width 1 maps to a bool type, any other width to an integer type.
if (auto IType = dyn_cast<IntegerType>(Ty)) {
const unsigned Width = IType->getBitWidth();
return Width == 1 ? getOpTypeBool(MIRBuilder)
: getOpTypeInt(Width, MIRBuilder, false);
}
if (Ty->isFloatingPointTy())
return getOpTypeFloat(Ty->getPrimitiveSizeInBits(), MIRBuilder);
if (Ty->isVoidTy())
return getOpTypeVoid(MIRBuilder);
// Vectors are cast to FixedVectorType; the element type is resolved first.
if (Ty->isVectorTy()) {
SPIRVType *El =
findSPIRVType(cast<FixedVectorType>(Ty)->getElementType(), MIRBuilder);
return getOpTypeVector(cast<FixedVectorType>(Ty)->getNumElements(), El,
MIRBuilder);
}
if (Ty->isArrayTy()) {
SPIRVType *El = findSPIRVType(Ty->getArrayElementType(), MIRBuilder);
return getOpTypeArray(Ty->getArrayNumElements(), El, MIRBuilder, EmitIR);
}
// Structs without a body become opaque types, all others regular structs.
if (auto SType = dyn_cast<StructType>(Ty)) {
if (SType->isOpaque())
return getOpTypeOpaque(SType, MIRBuilder);
return getOpTypeStruct(SType, MIRBuilder, EmitIR);
}
// Function types: resolve the return type, then every parameter type.
if (auto FType = dyn_cast<FunctionType>(Ty)) {
SPIRVType *RetTy = findSPIRVType(FType->getReturnType(), MIRBuilder);
SmallVector<SPIRVType *, 4> ParamTypes;
// Note: this loop variable shadows the iterator `t` declared above.
for (const auto &t : FType->params()) {
ParamTypes.push_back(findSPIRVType(t, MIRBuilder));
}
return getOpTypeFunction(RetTy, ParamTypes, MIRBuilder);
}
// Everything else is treated as a pointer. The pointee is taken from
// ::getPointeeType when it yields one; otherwise an 8-bit integer is used as
// the element type.
unsigned AddrSpace = typeToAddressSpace(Ty);
SPIRVType *SpvElementType = nullptr;
if (Type *ElemTy = ::getPointeeType(Ty))
SpvElementType = getOrCreateSPIRVType(ElemTy, MIRBuilder, AccQual, EmitIR);
else
SpvElementType = getOrCreateSPIRVIntegerType(8, MIRBuilder);
// Get access to information about available extensions
const SPIRVSubtarget *ST =
static_cast<const SPIRVSubtarget *>(&MIRBuilder.getMF().getSubtarget());
auto SC = addressSpaceToStorageClass(AddrSpace, *ST);
// Null pointer means we have a loop in type definitions, make and
// return corresponding OpTypeForwardPointer.
if (SpvElementType == nullptr) {
if (!ForwardPointerTypes.contains(Ty))
ForwardPointerTypes[Ty] = getOpTypeForwardPointer(SC, MIRBuilder);
return ForwardPointerTypes[Ty];
}
// If we have forward pointer associated with this type, use its register
// operand to create OpTypePointer.
if (ForwardPointerTypes.contains(Ty)) {
Register Reg = getSPIRVTypeID(ForwardPointerTypes[Ty]);
return getOpTypePointer(SC, SpvElementType, MIRBuilder, Reg);
}
// Plain pointer with no pending forward declaration.
return getOrCreateSPIRVPointerType(SpvElementType, MIRBuilder, SC);
}
// Creates the SPIR-V type for Ty via createSPIRVType and registers the result
// in VRegToTypeMap, SPIRVToLLVMType and, where applicable, the duplicate
// tracker. Returns nullptr when Ty is already being processed and is not a
// pointer (or typed-pointer wrapper); callers use that as the signal of a
// cycle in type definitions.
SPIRVType *SPIRVGlobalRegistry::restOfCreateSPIRVType(
const Type *Ty, MachineIRBuilder &MIRBuilder,
SPIRV::AccessQualifier::AccessQualifier AccessQual, bool EmitIR) {
// Recursion guard: a non-pointer type seen again while it is being built.
if (TypesInProcessing.count(Ty) && !isPointerTyOrWrapper(Ty))
return nullptr;
// Mark Ty as in progress only for the duration of createSPIRVType.
TypesInProcessing.insert(Ty);
SPIRVType *SpirvType = createSPIRVType(Ty, MIRBuilder, AccessQual, EmitIR);
TypesInProcessing.erase(Ty);
// Record the vreg -> type and SPIR-V type -> LLVM type associations.
VRegToTypeMap[&MIRBuilder.getMF()][getSPIRVTypeID(SpirvType)] = SpirvType;
SPIRVToLLVMType[SpirvType] = unifyPtrType(Ty);
Register Reg = DT.find(Ty, &MIRBuilder.getMF());
// Do not add OpTypeForwardPointer to DT, a corresponding normal pointer type
// will be added later. For special types it is already added to DT.
if (SpirvType->getOpcode() != SPIRV::OpTypeForwardPointer && !Reg.isValid() &&
!isSpecialOpaqueType(Ty)) {
// The tracker key depends on the kind of Ty:
//  - typed-pointer wrapper: (type parameter 0, int parameter 0)
//  - non-pointer:           Ty itself
//  - typed pointer:         (element type, address space)
//  - any other pointer:     (i8, address space)
if (auto *ExtTy = dyn_cast<TargetExtType>(Ty);
ExtTy && isTypedPointerWrapper(ExtTy))
DT.add(ExtTy->getTypeParameter(0), ExtTy->getIntParameter(0),
&MIRBuilder.getMF(), getSPIRVTypeID(SpirvType));
else if (!isPointerTy(Ty))
DT.add(Ty, &MIRBuilder.getMF(), getSPIRVTypeID(SpirvType));
else if (isTypedPointerTy(Ty))
DT.add(cast<TypedPointerType>(Ty)->getElementType(),
getPointerAddressSpace(Ty), &MIRBuilder.getMF(),
getSPIRVTypeID(SpirvType));
else
DT.add(Type::getInt8Ty(MIRBuilder.getMF().getFunction().getContext()),
getPointerAddressSpace(Ty), &MIRBuilder.getMF(),
getSPIRVTypeID(SpirvType));
}
return SpirvType;
}
// Returns the SPIR-V type recorded for VReg in machine function MF (or in
// CurMF when MF is null), or nullptr if nothing is recorded.
SPIRVType *
SPIRVGlobalRegistry::getSPIRVTypeForVReg(Register VReg,
                                         const MachineFunction *MF) const {
  const MachineFunction *Key = MF ? MF : CurMF;
  auto PerFunction = VRegToTypeMap.find(Key);
  if (PerFunction == VRegToTypeMap.end())
    return nullptr;

  auto Entry = PerFunction->second.find(VReg);
  if (Entry == PerFunction->second.end())
    return nullptr;
  return Entry->second;
}
// Returns the SPIR-V type named by operand 1 of the instruction that defines
// VReg in the current machine function.
SPIRVType *SPIRVGlobalRegistry::getResultType(Register VReg) {
  MachineInstr *Def = getVRegDef(CurMF->getRegInfo(), VReg);
  const Register TypeReg = Def->getOperand(1).getReg();
  return getSPIRVTypeForVReg(TypeReg);
}
// Returns the SPIR-V type for the LLVM type Ty, creating it if the duplicate
// tracker has no record for the current machine function. After creation,
// every pending entry of ForwardPointerTypes is resolved to a normal type and
// the pending set is cleared.
SPIRVType *SPIRVGlobalRegistry::getOrCreateSPIRVType(
const Type *Ty, MachineIRBuilder &MIRBuilder,
SPIRV::AccessQualifier::AccessQualifier AccessQual, bool EmitIR) {
// Look the type up using the same key scheme restOfCreateSPIRVType uses when
// adding to the tracker: wrapper parameters, the (width-adjusted) type itself,
// (element type, address space), or (i8, address space).
Register Reg;
if (auto *ExtTy = dyn_cast<TargetExtType>(Ty);
ExtTy && isTypedPointerWrapper(ExtTy)) {
Reg = DT.find(ExtTy->getTypeParameter(0), ExtTy->getIntParameter(0),
&MIRBuilder.getMF());
} else if (!isPointerTy(Ty)) {
Ty = adjustIntTypeByWidth(Ty);
Reg = DT.find(Ty, &MIRBuilder.getMF());
} else if (isTypedPointerTy(Ty)) {
Reg = DT.find(cast<TypedPointerType>(Ty)->getElementType(),
getPointerAddressSpace(Ty), &MIRBuilder.getMF());
} else {
Reg =
DT.find(Type::getInt8Ty(MIRBuilder.getMF().getFunction().getContext()),
getPointerAddressSpace(Ty), &MIRBuilder.getMF());
}
// A tracker hit is returned directly, except for special opaque types, which
// always go through the creation path below.
if (Reg.isValid() && !isSpecialOpaqueType(Ty))
return getSPIRVTypeForVReg(Reg);
// Reset the recursion guard before starting a new top-level creation.
TypesInProcessing.clear();
SPIRVType *STy = restOfCreateSPIRVType(Ty, MIRBuilder, AccessQual, EmitIR);
// Create normal pointer types for the corresponding OpTypeForwardPointers.
for (auto &CU : ForwardPointerTypes) {
const Type *Ty2 = CU.first;
SPIRVType *STy2 = CU.second;
if ((Reg = DT.find(Ty2, &MIRBuilder.getMF())).isValid())
STy2 = getSPIRVTypeForVReg(Reg);
else
STy2 = restOfCreateSPIRVType(Ty2, MIRBuilder, AccessQual, EmitIR);
// If the requested type was itself forward-declared, return its resolved
// form rather than the forward pointer.
if (Ty == Ty2)
STy = STy2;
}
ForwardPointerTypes.clear();
return STy;
}
// Returns true when the SPIR-V type assigned to VReg has opcode TypeOpcode.
// VReg must have a type assigned.
bool SPIRVGlobalRegistry::isScalarOfType(Register VReg,
                                         unsigned TypeOpcode) const {
  SPIRVType *AssignedTy = getSPIRVTypeForVReg(VReg);
  assert(AssignedTy && "isScalarOfType VReg has no type assigned");
  const unsigned ActualOpcode = AssignedTy->getOpcode();
  return ActualOpcode == TypeOpcode;
}
// Returns true when the type assigned to VReg either has opcode TypeOpcode
// itself, or is an OpTypeVector whose component type has that opcode.
bool SPIRVGlobalRegistry::isScalarOrVectorOfType(Register VReg,
                                                 unsigned TypeOpcode) const {
  SPIRVType *AssignedTy = getSPIRVTypeForVReg(VReg);
  assert(AssignedTy && "isScalarOrVectorOfType VReg has no type assigned");

  const unsigned ActualOpcode = AssignedTy->getOpcode();
  if (ActualOpcode == TypeOpcode)
    return true;
  if (ActualOpcode != SPIRV::OpTypeVector)
    return false;

  // Operand 1 of OpTypeVector is the component type register.
  SPIRVType *ComponentTy =
      getSPIRVTypeForVReg(AssignedTy->getOperand(1).getReg());
  return ComponentTy->getOpcode() == TypeOpcode;
}
// Register overload: resolves the type assigned to VReg and forwards to the
// SPIRVType overload.
unsigned
SPIRVGlobalRegistry::getScalarOrVectorComponentCount(Register VReg) const {
  SPIRVType *AssignedTy = getSPIRVTypeForVReg(VReg);
  return getScalarOrVectorComponentCount(AssignedTy);
}
// Returns 0 for a null type, the component count (operand 2) for an
// OpTypeVector, and 1 for any other type.
unsigned
SPIRVGlobalRegistry::getScalarOrVectorComponentCount(SPIRVType *Type) const {
  if (!Type)
    return 0;
  if (Type->getOpcode() != SPIRV::OpTypeVector)
    return 1;
  return static_cast<unsigned>(Type->getOperand(2).getImm());
}
// Register overload: resolves the type assigned to VReg and forwards to the
// SPIRVType overload.
SPIRVType *
SPIRVGlobalRegistry::getScalarOrVectorComponentType(Register VReg) const {
  SPIRVType *AssignedTy = getSPIRVTypeForVReg(VReg);
  return getScalarOrVectorComponentType(AssignedTy);
}
// Returns the component type of an OpTypeVector (operand 1), or the type
// recorded for the type's own result register (operand 0) otherwise. A null
// input yields nullptr.
SPIRVType *
SPIRVGlobalRegistry::getScalarOrVectorComponentType(SPIRVType *Type) const {
  if (!Type)
    return nullptr;

  const bool IsVector = Type->getOpcode() == SPIRV::OpTypeVector;
  const unsigned RegOperandIdx = IsVector ? 1 : 0;
  SPIRVType *ScalarType =
      getSPIRVTypeForVReg(Type->getOperand(RegOperandIdx).getReg());
  assert(isScalarOrVectorOfType(Type->getOperand(0).getReg(),
                                ScalarType->getOpcode()));
  return ScalarType;
}
// Returns the bit width of an integer, float or bool type. For a vector, the
// width of its component type is returned. Bool is reported as 1 bit; any
// other type is a programming error.
unsigned
SPIRVGlobalRegistry::getScalarOrVectorBitWidth(const SPIRVType *Type) const {
  assert(Type && "Invalid Type pointer");
  // Look through a vector to its component type (operand 1).
  if (Type->getOpcode() == SPIRV::OpTypeVector)
    Type = getSPIRVTypeForVReg(Type->getOperand(1).getReg());

  switch (Type->getOpcode()) {
  case SPIRV::OpTypeInt:
  case SPIRV::OpTypeFloat:
    return Type->getOperand(1).getImm();
  case SPIRV::OpTypeBool:
    return 1;
  default:
    llvm_unreachable("Attempting to get bit width of non-integer/float type.");
  }
}
// Returns the total number of bits of an integer/float scalar, or of a vector
// of such scalars (component count times component width). Returns 0 when the
// scalar/component type is neither OpTypeInt nor OpTypeFloat.
unsigned SPIRVGlobalRegistry::getNumScalarOrVectorTotalBitWidth(
    const SPIRVType *Type) const {
  assert(Type && "Invalid Type pointer");
  unsigned NumElements = 1;
  if (Type->getOpcode() == SPIRV::OpTypeVector) {
    NumElements = static_cast<unsigned>(Type->getOperand(2).getImm());
    Type = getSPIRVTypeForVReg(Type->getOperand(1).getReg());
  }

  const unsigned ScalarOpcode = Type->getOpcode();
  if (ScalarOpcode != SPIRV::OpTypeInt && ScalarOpcode != SPIRV::OpTypeFloat)
    return 0;
  return NumElements * Type->getOperand(1).getImm();
}
// Returns the OpTypeInt behind Type: Type itself if it is an integer type, or
// the component type if Type is a vector of integers. Returns nullptr in all
// other cases, including a null input.
const SPIRVType *SPIRVGlobalRegistry::retrieveScalarOrVectorIntType(
    const SPIRVType *Type) const {
  if (Type && Type->getOpcode() == SPIRV::OpTypeVector)
    Type = getSPIRVTypeForVReg(Type->getOperand(1).getReg());
  if (!Type || Type->getOpcode() != SPIRV::OpTypeInt)
    return nullptr;
  return Type;
}
bool SPIRVGlobalRegistry::isScalarOrVectorSigned(const SPIRVType *Type) const {
const SPIRVType *IntType = retrieveScalarOrVectorIntType(Type);
return IntType && IntType->getOperand(2).getImm() != 0;
}
// Returns the type recorded for operand 2 of an OpTypePointer, or nullptr
// when PtrType is null or is not an OpTypePointer.
SPIRVType *SPIRVGlobalRegistry::getPointeeType(SPIRVType *PtrType) {
  if (!PtrType || PtrType->getOpcode() != SPIRV::OpTypePointer)
    return nullptr;
  return getSPIRVTypeForVReg(PtrType->getOperand(2).getReg());
}
// Returns the opcode of the pointee type of the pointer type assigned to
// PtrReg, or 0 when no pointee type can be determined.
unsigned SPIRVGlobalRegistry::getPointeeTypeOp(Register PtrReg) {
  if (SPIRVType *ElemType = getPointeeType(getSPIRVTypeForVReg(PtrReg)))
    return ElemType->getOpcode();
  return 0;
}
bool SPIRVGlobalRegistry::isBitcastCompatible(const SPIRVType *Type1,
const SPIRVType *Type2) const {
if (!Type1 || !Type2)
return false;
auto Op1 = Type1->getOpcode(), Op2 = Type2->getOpcode();
// Ignore difference between <1.5 and >=1.5 protocol versions:
// it's valid if either Result Type or Operand is a pointer, and the other
// is a pointer, an integer scalar, or an integer vector.
if (Op1 == SPIRV::OpTypePointer &&
(Op2 == SPIRV::OpTypePointer || retrieveScalarOrVectorIntType(Type2)))
return true;
if (Op2 == SPIRV::OpTypePointer &&
(Op1 == SPIRV::OpTypePointer || retrieveScalarOrVectorIntType(Type1)))
return true;
unsigned Bits1 = getNumScalarOrVectorTotalBitWidth(Type1),
Bits2 = getNumScalarOrVectorTotalBitWidth(Type2);
return Bits1 > 0 && Bits1 == Bits2;
}
// Register overload: the type assigned to VReg must be an OpTypePointer whose
// operand 1 is an immediate; its storage class is returned.
SPIRV::StorageClass::StorageClass
SPIRVGlobalRegistry::getPointerStorageClass(Register VReg) const {
  SPIRVType *PtrTy = getSPIRVTypeForVReg(VReg);
  assert(PtrTy && PtrTy->getOpcode() == SPIRV::OpTypePointer &&
         PtrTy->getOperand(1).isImm() && "Pointer type is expected");
  return getPointerStorageClass(PtrTy);
}
// Reads the immediate in operand 1 of Type and returns it as a storage class.
// No check is made here that Type is a pointer type.
SPIRV::StorageClass::StorageClass
SPIRVGlobalRegistry::getPointerStorageClass(const SPIRVType *Type) const {
  const auto RawStorageClass = Type->getOperand(1).getImm();
  return static_cast<SPIRV::StorageClass::StorageClass>(RawStorageClass);
}
// Returns the OpTypeImage matching the given parameters, reusing a previously
// recorded instruction with the same descriptor when there is one. The access
// qualifier operand is emitted only when AccessQual is not None.
SPIRVType *SPIRVGlobalRegistry::getOrCreateOpTypeImage(
    MachineIRBuilder &MIRBuilder, SPIRVType *SampledType, SPIRV::Dim::Dim Dim,
    uint32_t Depth, uint32_t Arrayed, uint32_t Multisampled, uint32_t Sampled,
    SPIRV::ImageFormat::ImageFormat ImageFormat,
    SPIRV::AccessQualifier::AccessQualifier AccessQual) {
  auto Descr = SPIRV::make_descr_image(SPIRVToLLVMType.lookup(SampledType),
                                       Dim, Depth, Arrayed, Multisampled,
                                       Sampled, ImageFormat, AccessQual);
  if (auto *Existing = checkSpecialInstr(Descr, MIRBuilder))
    return Existing;

  const Register ResVReg = createTypeVReg(MIRBuilder);
  DT.add(Descr, &MIRBuilder.getMF(), ResVReg);

  MachineInstrBuilder MIB = MIRBuilder.buildInstr(SPIRV::OpTypeImage);
  MIB.addDef(ResVReg);
  MIB.addUse(getSPIRVTypeID(SampledType));
  MIB.addImm(Dim);
  MIB.addImm(Depth);        // Depth (whether or not it is a Depth image).
  MIB.addImm(Arrayed);      // Arrayed.
  MIB.addImm(Multisampled); // Multisampled (0 = only single-sample).
  MIB.addImm(Sampled);      // Sampled (0 = usage known at runtime).
  MIB.addImm(ImageFormat);
  if (AccessQual != SPIRV::AccessQualifier::None)
    MIB.addImm(AccessQual);
  return MIB;
}
// Returns the OpTypeSampler instruction, reusing a previously recorded one
// when present and otherwise building and recording a new one.
SPIRVType *
SPIRVGlobalRegistry::getOrCreateOpTypeSampler(MachineIRBuilder &MIRBuilder) {
  auto Descr = SPIRV::make_descr_sampler();
  if (auto *Existing = checkSpecialInstr(Descr, MIRBuilder))
    return Existing;

  const Register ResVReg = createTypeVReg(MIRBuilder);
  DT.add(Descr, &MIRBuilder.getMF(), ResVReg);
  MachineInstrBuilder MIB = MIRBuilder.buildInstr(SPIRV::OpTypeSampler);
  MIB.addDef(ResVReg);
  return MIB;
}
// Returns the OpTypePipe with access qualifier AccessQual, reusing a
// previously recorded one when present and otherwise building and recording a
// new one.
SPIRVType *SPIRVGlobalRegistry::getOrCreateOpTypePipe(
    MachineIRBuilder &MIRBuilder,
    SPIRV::AccessQualifier::AccessQualifier AccessQual) {
  auto Descr = SPIRV::make_descr_pipe(AccessQual);
  if (auto *Existing = checkSpecialInstr(Descr, MIRBuilder))
    return Existing;

  const Register ResVReg = createTypeVReg(MIRBuilder);
  DT.add(Descr, &MIRBuilder.getMF(), ResVReg);
  MachineInstrBuilder MIB = MIRBuilder.buildInstr(SPIRV::OpTypePipe);
  MIB.addDef(ResVReg);
  MIB.addImm(AccessQual);
  return MIB;
}
// Returns the OpTypeDeviceEvent instruction, reusing a previously recorded one
// when present and otherwise building and recording a new one.
SPIRVType *SPIRVGlobalRegistry::getOrCreateOpTypeDeviceEvent(
    MachineIRBuilder &MIRBuilder) {
  auto Descr = SPIRV::make_descr_event();
  if (auto *Existing = checkSpecialInstr(Descr, MIRBuilder))
    return Existing;

  const Register ResVReg = createTypeVReg(MIRBuilder);
  DT.add(Descr, &MIRBuilder.getMF(), ResVReg);
  MachineInstrBuilder MIB = MIRBuilder.buildInstr(SPIRV::OpTypeDeviceEvent);
  MIB.addDef(ResVReg);
  return MIB;
}
// Returns the OpTypeSampledImage wrapping ImageType, reusing a previously
// recorded one when present. The descriptor is keyed on the LLVM type of the
// instruction that defines operand 1 of ImageType, together with ImageType.
SPIRVType *SPIRVGlobalRegistry::getOrCreateOpTypeSampledImage(
    SPIRVType *ImageType, MachineIRBuilder &MIRBuilder) {
  SPIRVType *SampledTyDef = MIRBuilder.getMF().getRegInfo().getVRegDef(
      ImageType->getOperand(1).getReg());
  auto Descr = SPIRV::make_descr_sampled_image(
      SPIRVToLLVMType.lookup(SampledTyDef), ImageType);
  if (auto *Existing = checkSpecialInstr(Descr, MIRBuilder))
    return Existing;

  const Register ResVReg = createTypeVReg(MIRBuilder);
  DT.add(Descr, &MIRBuilder.getMF(), ResVReg);
  MachineInstrBuilder MIB = MIRBuilder.buildInstr(SPIRV::OpTypeSampledImage);
  MIB.addDef(ResVReg);
  MIB.addUse(getSPIRVTypeID(ImageType));
  return MIB;
}
// Returns the OpTypeCooperativeMatrixKHR recorded for ExtensionType, or builds
// one. Scope, Rows, Columns and Use are each emitted as 32-bit integer
// constants and attached, in that order, as use operands after the element
// type.
SPIRVType *SPIRVGlobalRegistry::getOrCreateOpTypeCoopMatr(
    MachineIRBuilder &MIRBuilder, const TargetExtType *ExtensionType,
    const SPIRVType *ElemType, uint32_t Scope, uint32_t Rows, uint32_t Columns,
    uint32_t Use) {
  if (Register Known = DT.find(ExtensionType, &MIRBuilder.getMF());
      Known.isValid())
    return MIRBuilder.getMF().getRegInfo().getUniqueVRegDef(Known);

  const Register ResVReg = createTypeVReg(MIRBuilder);
  SPIRVType *SpvTypeInt32 = getOrCreateSPIRVIntegerType(32, MIRBuilder);

  // The type instruction is created first; the constants are built as the
  // operands are appended.
  MachineInstrBuilder MIB =
      MIRBuilder.buildInstr(SPIRV::OpTypeCooperativeMatrixKHR);
  MIB.addDef(ResVReg);
  MIB.addUse(getSPIRVTypeID(ElemType));
  for (uint32_t Param : {Scope, Rows, Columns, Use})
    MIB.addUse(buildConstantInt(Param, MIRBuilder, SpvTypeInt32, true));

  DT.add(ExtensionType, &MIRBuilder.getMF(), ResVReg);
  return MIB;
}
// Returns the type instruction recorded for Ty, or builds an instruction with
// the given Opcode and a single def operand, records it for Ty and returns it.
SPIRVType *SPIRVGlobalRegistry::getOrCreateOpTypeByOpcode(
    const Type *Ty, MachineIRBuilder &MIRBuilder, unsigned Opcode) {
  if (Register Known = DT.find(Ty, &MIRBuilder.getMF()); Known.isValid())
    return MIRBuilder.getMF().getRegInfo().getUniqueVRegDef(Known);

  const Register ResVReg = createTypeVReg(MIRBuilder);
  MachineInstrBuilder MIB = MIRBuilder.buildInstr(Opcode);
  MIB.addDef(ResVReg);
  DT.add(Ty, &MIRBuilder.getMF(), ResVReg);
  return MIB;
}
// Returns the instruction already recorded for the special type descriptor TD
// in the builder's machine function, or nullptr if there is none.
const MachineInstr *
SPIRVGlobalRegistry::checkSpecialInstr(const SPIRV::SpecialTypeDescriptor &TD,
                                       MachineIRBuilder &MIRBuilder) {
  const Register Known = DT.find(TD, &MIRBuilder.getMF());
  if (!Known.isValid())
    return nullptr;
  return MIRBuilder.getMF().getRegInfo().getUniqueVRegDef(Known);
}
// Builds a SPIR-V type from its textual name. Builtin (SPIR-V or OpenCL) type
// names are parsed into a target extension type; otherwise a basic type name
// with optional pointer and vector suffixes is handled.
// Returns nullptr if unable to recognize SPIRV type name.
SPIRVType *SPIRVGlobalRegistry::getOrCreateSPIRVTypeByName(
    StringRef TypeStr, MachineIRBuilder &MIRBuilder,
    SPIRV::StorageClass::StorageClass SC,
    SPIRV::AccessQualifier::AccessQualifier AQ) {
  auto &Ctx = MIRBuilder.getMF().getFunction().getContext();

  // Parse strings representing either a SPIR-V or OpenCL builtin type.
  if (hasBuiltinTypePrefix(TypeStr)) {
    auto *BuiltinTy = SPIRV::parseBuiltinTypeNameToTargetExtType(
        TypeStr.str(), MIRBuilder.getContext());
    return getOrCreateSPIRVType(BuiltinTy, MIRBuilder, AQ);
  }

  // Parse type name in either "typeN" or "type vector[N]" format, where
  // N is the number of elements of the vector.
  // NOTE(review): TypeStr is passed to parseBasicTypeName and inspected for
  // suffixes afterwards, so the helper presumably strips the recognized base
  // name from it -- confirm against its definition.
  Type *Ty = parseBasicTypeName(TypeStr, Ctx);
  if (!Ty)
    return nullptr; // Unable to recognize SPIRV type name.

  SPIRVType *SpirvTy = getOrCreateSPIRVType(Ty, MIRBuilder, AQ);

  // Handle "type*" or "type* vector[N]".
  if (TypeStr.consume_front("*"))
    SpirvTy = getOrCreateSPIRVPointerType(SpirvTy, MIRBuilder, SC);

  // Handle "typeN*" or "type vector[N]*".
  const bool IsPtrToVec = TypeStr.consume_back("*");

  if (TypeStr.consume_front(" vector["))
    TypeStr = TypeStr.substr(0, TypeStr.find(']'));

  // Whatever is left is read as the element count; VecElts starts at 0 and
  // the parse result is deliberately not checked.
  unsigned VecElts = 0;
  TypeStr.getAsInteger(10, VecElts);
  if (VecElts > 0)
    SpirvTy = getOrCreateSPIRVVectorType(SpirvTy, VecElts, MIRBuilder);

  if (IsPtrToVec)
    SpirvTy = getOrCreateSPIRVPointerType(SpirvTy, MIRBuilder, SC);

  return SpirvTy;
}
// Returns the SPIR-V type for the LLVM integer type of the given bit width.
SPIRVType *
SPIRVGlobalRegistry::getOrCreateSPIRVIntegerType(unsigned BitWidth,
                                                 MachineIRBuilder &MIRBuilder) {
  LLVMContext &Ctx = MIRBuilder.getMF().getFunction().getContext();
  IntegerType *IntTy = IntegerType::get(Ctx, BitWidth);
  return getOrCreateSPIRVType(IntTy, MIRBuilder);
}
// Registers a freshly built type instruction: maps its id register to the
// instruction for the current machine function and remembers the LLVM type it
// came from (after unifyPtrType). Returns SpirvType unchanged.
SPIRVType *SPIRVGlobalRegistry::finishCreatingSPIRVType(const Type *LLVMTy,
                                                        SPIRVType *SpirvType) {
  assert(CurMF == SpirvType->getMF());
  const Register TypeId = getSPIRVTypeID(SpirvType);
  VRegToTypeMap[CurMF][TypeId] = SpirvType;
  SPIRVToLLVMType[SpirvType] = unifyPtrType(LLVMTy);
  return SpirvType;
}
// Returns the SPIR-V type recorded for LLVMTy in the current machine function,
// or inserts a new instruction with opcode SPIRVOPcode before I. The new
// instruction carries BitWidth and a trailing 0 as immediates.
SPIRVType *SPIRVGlobalRegistry::getOrCreateSPIRVType(unsigned BitWidth,
                                                     MachineInstr &I,
                                                     const SPIRVInstrInfo &TII,
                                                     unsigned SPIRVOPcode,
                                                     Type *LLVMTy) {
  if (Register Known = DT.find(LLVMTy, CurMF); Known.isValid())
    return getSPIRVTypeForVReg(Known);

  MachineInstrBuilder MIB =
      BuildMI(*I.getParent(), I, I.getDebugLoc(), TII.get(SPIRVOPcode));
  MIB.addDef(createTypeVReg(CurMF->getRegInfo()));
  MIB.addImm(BitWidth);
  MIB.addImm(0);

  DT.add(LLVMTy, CurMF, getSPIRVTypeID(MIB));
  return finishCreatingSPIRVType(LLVMTy, MIB);
}
// Returns the OpTypeInt for the given bit width, inserting it before I when it
// does not exist yet.
SPIRVType *SPIRVGlobalRegistry::getOrCreateSPIRVIntegerType(
    unsigned BitWidth, MachineInstr &I, const SPIRVInstrInfo &TII) {
  // Maybe adjust bit width to keep DuplicateTracker consistent. Without
  // such an adjustment SPIRVGlobalRegistry::getOpTypeInt() could create, for
  // example, the same "OpTypeInt 8" type for a series of LLVM integer types
  // with number of bits less than 8, causing duplicate type definitions.
  const unsigned AdjustedWidth = adjustOpTypeIntWidth(BitWidth);
  LLVMContext &Ctx = CurMF->getFunction().getContext();
  Type *LLVMTy = IntegerType::get(Ctx, AdjustedWidth);
  return getOrCreateSPIRVType(AdjustedWidth, I, TII, SPIRV::OpTypeInt, LLVMTy);
}
// Returns the OpTypeFloat for a 16-, 32- or 64-bit float, inserting it before
// I when it does not exist yet. Any other width is a programming error.
SPIRVType *SPIRVGlobalRegistry::getOrCreateSPIRVFloatType(
    unsigned BitWidth, MachineInstr &I, const SPIRVInstrInfo &TII) {
  LLVMContext &Ctx = CurMF->getFunction().getContext();
  Type *LLVMTy = nullptr;
  if (BitWidth == 16)
    LLVMTy = Type::getHalfTy(Ctx);
  else if (BitWidth == 32)
    LLVMTy = Type::getFloatTy(Ctx);
  else if (BitWidth == 64)
    LLVMTy = Type::getDoubleTy(Ctx);
  else
    llvm_unreachable("Bit width is of unexpected size.");
  return getOrCreateSPIRVType(BitWidth, I, TII, SPIRV::OpTypeFloat, LLVMTy);
}
// Returns the SPIR-V type for the LLVM 1-bit integer type.
SPIRVType *
SPIRVGlobalRegistry::getOrCreateSPIRVBoolType(MachineIRBuilder &MIRBuilder) {
  LLVMContext &Ctx = MIRBuilder.getMF().getFunction().getContext();
  IntegerType *BoolTy = IntegerType::get(Ctx, 1);
  return getOrCreateSPIRVType(BoolTy, MIRBuilder);
}
// Returns the OpTypeBool recorded for the LLVM 1-bit integer type in the
// current machine function, inserting a new one before I if there is none.
SPIRVType *
SPIRVGlobalRegistry::getOrCreateSPIRVBoolType(MachineInstr &I,
                                              const SPIRVInstrInfo &TII) {
  Type *LLVMTy = IntegerType::get(CurMF->getFunction().getContext(), 1);
  if (Register Known = DT.find(LLVMTy, CurMF); Known.isValid())
    return getSPIRVTypeForVReg(Known);

  MachineInstrBuilder MIB =
      BuildMI(*I.getParent(), I, I.getDebugLoc(), TII.get(SPIRV::OpTypeBool));
  MIB.addDef(createTypeVReg(CurMF->getRegInfo()));

  DT.add(LLVMTy, CurMF, getSPIRVTypeID(MIB));
  return finishCreatingSPIRVType(LLVMTy, MIB);
}
// Returns the SPIR-V type for a fixed vector of NumElements elements whose
// element type is the LLVM type associated with BaseType.
SPIRVType *SPIRVGlobalRegistry::getOrCreateSPIRVVectorType(
    SPIRVType *BaseType, unsigned NumElements, MachineIRBuilder &MIRBuilder) {
  Type *ElemTy = const_cast<Type *>(getTypeForSPIRVType(BaseType));
  FixedVectorType *VecTy = FixedVectorType::get(ElemTy, NumElements);
  return getOrCreateSPIRVType(VecTy, MIRBuilder);
}
// Returns the OpTypeVector of NumElements x BaseType recorded for the current
// machine function, inserting a new one before I if there is none.
SPIRVType *SPIRVGlobalRegistry::getOrCreateSPIRVVectorType(
    SPIRVType *BaseType, unsigned NumElements, MachineInstr &I,
    const SPIRVInstrInfo &TII) {
  Type *ElemTy = const_cast<Type *>(getTypeForSPIRVType(BaseType));
  Type *LLVMTy = FixedVectorType::get(ElemTy, NumElements);
  if (Register Known = DT.find(LLVMTy, CurMF); Known.isValid())
    return getSPIRVTypeForVReg(Known);

  MachineInstrBuilder MIB = BuildMI(*I.getParent(), I, I.getDebugLoc(),
                                    TII.get(SPIRV::OpTypeVector));
  MIB.addDef(createTypeVReg(CurMF->getRegInfo()));
  MIB.addUse(getSPIRVTypeID(BaseType));
  MIB.addImm(NumElements);

  DT.add(LLVMTy, CurMF, getSPIRVTypeID(MIB));
  return finishCreatingSPIRVType(LLVMTy, MIB);
}
// Returns the OpTypeArray of NumElements x BaseType recorded for the current
// machine function, inserting a new one before I if there is none. The length
// operand is a 32-bit integer constant.
SPIRVType *SPIRVGlobalRegistry::getOrCreateSPIRVArrayType(
    SPIRVType *BaseType, unsigned NumElements, MachineInstr &I,
    const SPIRVInstrInfo &TII) {
  Type *ElemTy = const_cast<Type *>(getTypeForSPIRVType(BaseType));
  Type *LLVMTy = ArrayType::get(ElemTy, NumElements);
  if (Register Known = DT.find(LLVMTy, CurMF); Known.isValid())
    return getSPIRVTypeForVReg(Known);

  MachineBasicBlock &BB = *I.getParent();
  // The length type and constant are obtained before the array instruction
  // itself is built.
  SPIRVType *LengthTy = getOrCreateSPIRVIntegerType(32, I, TII);
  const Register LengthReg = getOrCreateConstInt(NumElements, I, LengthTy, TII);

  MachineInstrBuilder MIB =
      BuildMI(BB, I, I.getDebugLoc(), TII.get(SPIRV::OpTypeArray));
  MIB.addDef(createTypeVReg(CurMF->getRegInfo()));
  MIB.addUse(getSPIRVTypeID(BaseType));
  MIB.addUse(LengthReg);

  DT.add(LLVMTy, CurMF, getSPIRVTypeID(MIB));
  return finishCreatingSPIRVType(LLVMTy, MIB);
}
// Returns the OpTypePointer to BaseType in storage class SC. The duplicate
// tracker is keyed on (LLVM pointee type, address space); on a miss a new
// instruction is built at the builder's insertion point and recorded.
SPIRVType *SPIRVGlobalRegistry::getOrCreateSPIRVPointerType(
    SPIRVType *BaseType, MachineIRBuilder &MIRBuilder,
    SPIRV::StorageClass::StorageClass SC) {
  const Type *PointerElementType = getTypeForSPIRVType(BaseType);
  const unsigned AddressSpace = storageClassToAddressSpace(SC);
  Type *LLVMTy = TypedPointerType::get(const_cast<Type *>(PointerElementType),
                                       AddressSpace);

  // Check if this type is already available.
  if (Register Known = DT.find(PointerElementType, AddressSpace, CurMF);
      Known.isValid())
    return getSPIRVTypeForVReg(Known);

  // Create a new type.
  auto EmitPointerType = [&](MachineIRBuilder &Builder) {
    MachineInstrBuilder MIB =
        BuildMI(Builder.getMBB(), Builder.getInsertPt(), Builder.getDebugLoc(),
                Builder.getTII().get(SPIRV::OpTypePointer));
    MIB.addDef(createTypeVReg(CurMF->getRegInfo()));
    MIB.addImm(static_cast<uint32_t>(SC));
    MIB.addUse(getSPIRVTypeID(BaseType));
    DT.add(PointerElementType, AddressSpace, CurMF, getSPIRVTypeID(MIB));
    finishCreatingSPIRVType(LLVMTy, MIB);
    return MIB;
  };
  return createOpType(MIRBuilder, EmitPointerType);
}
// MachineInstr overload: wraps I in a MachineIRBuilder and forwards to the
// builder-based overload. The SPIRVInstrInfo argument is unused.
SPIRVType *SPIRVGlobalRegistry::getOrCreateSPIRVPointerType(
    SPIRVType *BaseType, MachineInstr &I, const SPIRVInstrInfo &,
    SPIRV::StorageClass::StorageClass SC) {
  MachineIRBuilder BuilderAtI(I);
  return getOrCreateSPIRVPointerType(BaseType, BuilderAtI, SC);
}
// Returns a register holding an undef value of type SpvType. An existing
// register recorded for the matching UndefValue is reused; otherwise an
// OpUndef is inserted before I and recorded.
Register SPIRVGlobalRegistry::getOrCreateUndef(MachineInstr &I,
                                               SPIRVType *SpvType,
                                               const SPIRVInstrInfo &TII) {
  assert(SpvType);
  const Type *LLVMTy = getTypeForSPIRVType(SpvType);
  assert(LLVMTy);

  // Find a constant in DT or build a new one.
  UndefValue *UV = UndefValue::get(const_cast<Type *>(LLVMTy));
  if (Register Known = DT.find(UV, CurMF); Known.isValid())
    return Known;

  // New result register: generic 64-bit scalar constrained to iIDRegClass.
  MachineRegisterInfo &MRI = CurMF->getRegInfo();
  const Register Res = MRI.createGenericVirtualRegister(LLT::scalar(64));
  MRI.setRegClass(Res, &SPIRV::iIDRegClass);
  assignSPIRVTypeToVReg(SpvType, Res, *CurMF);
  DT.add(UV, CurMF, Res);

  MachineInstrBuilder MIB =
      BuildMI(*I.getParent(), I, I.getDebugLoc(), TII.get(SPIRV::OpUndef));
  MIB.addDef(Res);
  MIB.addUse(getSPIRVTypeID(SpvType));

  const auto &ST = CurMF->getSubtarget();
  constrainSelectedInstRegOperands(*MIB, *ST.getInstrInfo(),
                                   *ST.getRegisterInfo(), *ST.getRegBankInfo());
  return Res;
}
// Maps a SPIR-V type to a register class: float -> fID, pointer -> pID,
// vector -> vfID / vpID / vID depending on the component type, and iID for
// everything else.
const TargetRegisterClass *
SPIRVGlobalRegistry::getRegClass(SPIRVType *SpvType) const {
  const unsigned Opcode = SpvType->getOpcode();
  if (Opcode == SPIRV::OpTypeFloat)
    return &SPIRV::fIDRegClass;
  if (Opcode == SPIRV::OpTypePointer)
    return &SPIRV::pIDRegClass;
  if (Opcode != SPIRV::OpTypeVector)
    return &SPIRV::iIDRegClass;

  // Vector: choose by component type; an unknown component type is treated
  // like a non-float, non-pointer component.
  SPIRVType *ElemType = getSPIRVTypeForVReg(SpvType->getOperand(1).getReg());
  const unsigned ElemOpcode = ElemType ? ElemType->getOpcode() : 0;
  if (ElemOpcode == SPIRV::OpTypeFloat)
    return &SPIRV::vfIDRegClass;
  if (ElemOpcode == SPIRV::OpTypePointer)
    return &SPIRV::vpIDRegClass;
  return &SPIRV::vIDRegClass;
}
// Returns the address space that corresponds to the storage class stored as
// the immediate in operand 1 of SpvType. The caller must pass a non-null type
// whose operand 1 is a storage-class immediate (as the OpTypePointer callers
// in this file do).
// Declared static: this helper is local to this translation unit, so it
// should have internal linkage.
static inline unsigned getAS(SPIRVType *SpvType) {
  return storageClassToAddressSpace(
      static_cast<SPIRV::StorageClass::StorageClass>(
          SpvType->getOperand(1).getImm()));
}
// Maps a SPIR-V type to a low-level type: int/float/bool -> scalar of the
// type's bit width, pointer -> pointer in the matching address space, vector
// -> fixed vector of the component's LLT. A null type, and any other type,
// maps to a 64-bit scalar.
LLT SPIRVGlobalRegistry::getRegType(SPIRVType *SpvType) const {
  const unsigned Opcode = SpvType ? SpvType->getOpcode() : 0;

  if (Opcode == SPIRV::OpTypeInt || Opcode == SPIRV::OpTypeFloat ||
      Opcode == SPIRV::OpTypeBool)
    return LLT::scalar(getScalarOrVectorBitWidth(SpvType));
  if (Opcode == SPIRV::OpTypePointer)
    return LLT::pointer(getAS(SpvType), getPointerSize());
  if (Opcode != SPIRV::OpTypeVector)
    return LLT::scalar(64);

  // Vector: derive the component LLT, defaulting to a 64-bit scalar when the
  // component type is unknown or is not int/float/bool/pointer.
  SPIRVType *ElemType = getSPIRVTypeForVReg(SpvType->getOperand(1).getReg());
  const unsigned ElemOpcode = ElemType ? ElemType->getOpcode() : 0;
  LLT ElemLLT = LLT::scalar(64);
  if (ElemOpcode == SPIRV::OpTypePointer)
    ElemLLT = LLT::pointer(getAS(ElemType), getPointerSize());
  else if (ElemOpcode == SPIRV::OpTypeInt ||
           ElemOpcode == SPIRV::OpTypeFloat || ElemOpcode == SPIRV::OpTypeBool)
    ElemLLT = LLT::scalar(getScalarOrVectorBitWidth(ElemType));

  return LLT::fixed_vector(
      static_cast<unsigned>(SpvType->getOperand(2).getImm()), ElemLLT);
}