The patch adds SPIR-V-specific intrinsics required to preserve information critical to SPIR-V consistency (types, constants, etc.) during translation from IR to MIR. Two related passes (SPIRVEmitIntrinsics and SPIRVPreLegalizer) and several LIT tests (which pass with this change) have also been added. It also fixes the issue with opaque pointers in SPIRVGlobalRegistry.cpp and the data layout mismatch between the SPIR-V backend and clang (Issue #55122). Differential Revision: https://reviews.llvm.org/D124416 Co-authored-by: Aleksandr Bezzubikov <zuban32s@gmail.com> Co-authored-by: Michal Paszkowski <michal.paszkowski@outlook.com> Co-authored-by: Andrey Tretyakov <andrey1.tretyakov@intel.com> Co-authored-by: Konrad Trifunovic <konrad.trifunovic@intel.com>
302 lines
11 KiB
C++
302 lines
11 KiB
C++
//===- SPIRVLegalizerInfo.cpp --- SPIR-V Legalization Rules ------*- C++ -*-==//
|
|
//
|
|
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
|
// See https://llvm.org/LICENSE.txt for license information.
|
|
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
|
//
|
|
//===----------------------------------------------------------------------===//
|
|
//
|
|
// This file implements the targeting of the MachineLegalizer class for SPIR-V.
|
|
//
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
#include "SPIRVLegalizerInfo.h"
#include "SPIRV.h"
#include "SPIRVGlobalRegistry.h"
#include "SPIRVSubtarget.h"
#include "llvm/CodeGen/GlobalISel/GISelChangeObserver.h"
#include "llvm/CodeGen/GlobalISel/LegalizerHelper.h"
#include "llvm/CodeGen/GlobalISel/MachineIRBuilder.h"
#include "llvm/CodeGen/MachineInstr.h"
#include "llvm/CodeGen/MachineRegisterInfo.h"
#include "llvm/CodeGen/TargetOpcodes.h"
|
|
|
|
using namespace llvm;
|
|
using namespace llvm::LegalizeActions;
|
|
using namespace llvm::LegalityPredicates;
|
|
|
|
static const std::set<unsigned> TypeFoldingSupportingOpcs = {
|
|
TargetOpcode::G_ADD,
|
|
TargetOpcode::G_FADD,
|
|
TargetOpcode::G_SUB,
|
|
TargetOpcode::G_FSUB,
|
|
TargetOpcode::G_MUL,
|
|
TargetOpcode::G_FMUL,
|
|
TargetOpcode::G_SDIV,
|
|
TargetOpcode::G_UDIV,
|
|
TargetOpcode::G_FDIV,
|
|
TargetOpcode::G_SREM,
|
|
TargetOpcode::G_UREM,
|
|
TargetOpcode::G_FREM,
|
|
TargetOpcode::G_FNEG,
|
|
TargetOpcode::G_CONSTANT,
|
|
TargetOpcode::G_FCONSTANT,
|
|
TargetOpcode::G_AND,
|
|
TargetOpcode::G_OR,
|
|
TargetOpcode::G_XOR,
|
|
TargetOpcode::G_SHL,
|
|
TargetOpcode::G_ASHR,
|
|
TargetOpcode::G_LSHR,
|
|
TargetOpcode::G_SELECT,
|
|
TargetOpcode::G_EXTRACT_VECTOR_ELT,
|
|
};
|
|
|
|
bool isTypeFoldingSupported(unsigned Opcode) {
|
|
return TypeFoldingSupportingOpcs.count(Opcode) > 0;
|
|
}
|
|
|
|
/// Builds the GlobalISel legalization rule tables for the SPIR-V target.
///
/// Opcodes in TypeFoldingSupportingOpcs, and G_ICMP on boolean results, are
/// marked custom and handled by legalizeCustom(). Most other opcodes are
/// declared legal for explicit lists of scalar, vector and pointer LLTs.
/// Pointer LLTs for address spaces 0..5 use the subtarget's pointer size.
///
/// Note: LLT scalars only encode a bit width, so the "int" and "float" type
/// lists below overlap (for example, s32 appears in both).
SPIRVLegalizerInfo::SPIRVLegalizerInfo(const SPIRVSubtarget &ST) {
  using namespace TargetOpcode;

  this->ST = &ST;
  GR = ST.getSPIRVGlobalRegistry();

  const LLT s1 = LLT::scalar(1);
  const LLT s8 = LLT::scalar(8);
  const LLT s16 = LLT::scalar(16);
  const LLT s32 = LLT::scalar(32);
  const LLT s64 = LLT::scalar(64);

  const LLT v16s64 = LLT::fixed_vector(16, 64);
  const LLT v16s32 = LLT::fixed_vector(16, 32);
  const LLT v16s16 = LLT::fixed_vector(16, 16);
  const LLT v16s8 = LLT::fixed_vector(16, 8);
  const LLT v16s1 = LLT::fixed_vector(16, 1);

  const LLT v8s64 = LLT::fixed_vector(8, 64);
  const LLT v8s32 = LLT::fixed_vector(8, 32);
  const LLT v8s16 = LLT::fixed_vector(8, 16);
  const LLT v8s8 = LLT::fixed_vector(8, 8);
  const LLT v8s1 = LLT::fixed_vector(8, 1);

  const LLT v4s64 = LLT::fixed_vector(4, 64);
  const LLT v4s32 = LLT::fixed_vector(4, 32);
  const LLT v4s16 = LLT::fixed_vector(4, 16);
  const LLT v4s8 = LLT::fixed_vector(4, 8);
  const LLT v4s1 = LLT::fixed_vector(4, 1);

  const LLT v3s64 = LLT::fixed_vector(3, 64);
  const LLT v3s32 = LLT::fixed_vector(3, 32);
  const LLT v3s16 = LLT::fixed_vector(3, 16);
  const LLT v3s8 = LLT::fixed_vector(3, 8);
  const LLT v3s1 = LLT::fixed_vector(3, 1);

  const LLT v2s64 = LLT::fixed_vector(2, 64);
  const LLT v2s32 = LLT::fixed_vector(2, 32);
  const LLT v2s16 = LLT::fixed_vector(2, 16);
  const LLT v2s8 = LLT::fixed_vector(2, 8);
  const LLT v2s1 = LLT::fixed_vector(2, 1);

  const unsigned PSize = ST.getPointerSize();
  const LLT p0 = LLT::pointer(0, PSize); // Function
  const LLT p1 = LLT::pointer(1, PSize); // CrossWorkgroup
  const LLT p2 = LLT::pointer(2, PSize); // UniformConstant
  const LLT p3 = LLT::pointer(3, PSize); // Workgroup
  const LLT p4 = LLT::pointer(4, PSize); // Generic
  const LLT p5 = LLT::pointer(5, PSize); // Input

  // TODO: remove copy-pasting here by using concatenation in some way.
  auto allPtrsScalarsAndVectors = {
      p0,    p1,    p2,     p3,     p4,     p5,    s1,    s8,     s16,
      s32,   s64,   v2s1,   v2s8,   v2s16,  v2s32, v2s64, v3s1,   v3s8,
      v3s16, v3s32, v3s64,  v4s1,   v4s8,   v4s16, v4s32, v4s64,  v8s1,
      v8s8,  v8s16, v8s32,  v8s64,  v16s1,  v16s8, v16s16, v16s32, v16s64};

  auto allScalarsAndVectors = {
      s1,   s8,   s16,   s32,   s64,   v2s1,  v2s8,  v2s16,  v2s32,  v2s64,
      v3s1, v3s8, v3s16, v3s32, v3s64, v4s1,  v4s8,  v4s16,  v4s32,  v4s64,
      v8s1, v8s8, v8s16, v8s32, v8s64, v16s1, v16s8, v16s16, v16s32, v16s64};

  auto allIntScalarsAndVectors = {s8,    s16,   s32,   s64,    v2s8,   v2s16,
                                  v2s32, v2s64, v3s8,  v3s16,  v3s32,  v3s64,
                                  v4s8,  v4s16, v4s32, v4s64,  v8s8,   v8s16,
                                  v8s32, v8s64, v16s8, v16s16, v16s32, v16s64};

  auto allBoolScalarsAndVectors = {s1, v2s1, v3s1, v4s1, v8s1, v16s1};

  auto allIntScalars = {s8, s16, s32, s64};

  auto allFloatScalarsAndVectors = {
      s16,   s32,   s64,   v2s16, v2s32, v2s64, v3s16,  v3s32,  v3s64,
      v4s16, v4s32, v4s64, v8s16, v8s32, v8s64, v16s16, v16s32, v16s64};

  // The float scalar widths (s16/s32/s64) are already members of
  // allIntScalars, so the integer list covers both.
  auto allFloatAndIntScalars = allIntScalars;

  auto allPtrs = {p0, p1, p2, p3, p4, p5};
  // p2 (UniformConstant) and p5 (Input) are deliberately excluded here.
  auto allWritablePtrs = {p0, p1, p3, p4};

  // Type-foldable opcodes are rewritten in legalizeCustom().
  for (auto Opc : TypeFoldingSupportingOpcs)
    getActionDefinitionsBuilder(Opc).custom();

  getActionDefinitionsBuilder(G_GLOBAL_VALUE).alwaysLegal();

  // TODO: add proper rules for vectors legalization.
  getActionDefinitionsBuilder({G_BUILD_VECTOR, G_SHUFFLE_VECTOR}).alwaysLegal();

  getActionDefinitionsBuilder({G_MEMCPY, G_MEMMOVE})
      .legalIf(all(typeInSet(0, allWritablePtrs), typeInSet(1, allPtrs)));

  getActionDefinitionsBuilder(G_ADDRSPACE_CAST)
      .legalForCartesianProduct(allPtrs, allPtrs);

  getActionDefinitionsBuilder({G_LOAD, G_STORE}).legalIf(typeInSet(1, allPtrs));

  // OpBitReverse (like llvm.bitreverse) only accepts integer scalars/vectors.
  // Using the float list here wrongly rejected s8 and its vectors.
  getActionDefinitionsBuilder(G_BITREVERSE).legalFor(allIntScalarsAndVectors);

  getActionDefinitionsBuilder(G_FMA).legalFor(allFloatScalarsAndVectors);

  getActionDefinitionsBuilder({G_FPTOSI, G_FPTOUI})
      .legalForCartesianProduct(allIntScalarsAndVectors,
                                allFloatScalarsAndVectors);

  getActionDefinitionsBuilder({G_SITOFP, G_UITOFP})
      .legalForCartesianProduct(allFloatScalarsAndVectors,
                                allScalarsAndVectors);

  getActionDefinitionsBuilder({G_SMIN, G_SMAX, G_UMIN, G_UMAX, G_ABS})
      .legalFor(allIntScalarsAndVectors);

  getActionDefinitionsBuilder(G_CTPOP).legalForCartesianProduct(
      allIntScalarsAndVectors, allIntScalarsAndVectors);

  getActionDefinitionsBuilder(G_PHI).legalFor(allPtrsScalarsAndVectors);

  // Bitcasts are legal only between known types of identical bit width. The
  // explicit LegalityPredicate wrapper gives all() a uniform predicate type.
  getActionDefinitionsBuilder(G_BITCAST).legalIf(all(
      typeInSet(0, allPtrsScalarsAndVectors),
      typeInSet(1, allPtrsScalarsAndVectors),
      LegalityPredicate([](const LegalityQuery &Query) {
        return Query.Types[0].getSizeInBits() == Query.Types[1].getSizeInBits();
      })));

  getActionDefinitionsBuilder(G_IMPLICIT_DEF).alwaysLegal();

  getActionDefinitionsBuilder(G_INTTOPTR)
      .legalForCartesianProduct(allPtrs, allIntScalars);
  getActionDefinitionsBuilder(G_PTRTOINT)
      .legalForCartesianProduct(allIntScalars, allPtrs);
  getActionDefinitionsBuilder(G_PTR_ADD).legalForCartesianProduct(
      allPtrs, allIntScalars);

  // ST.canDirectlyComparePointers() for pointer args is supported in
  // legalizeCustom().
  getActionDefinitionsBuilder(G_ICMP).customIf(
      all(typeInSet(0, allBoolScalarsAndVectors),
          typeInSet(1, allPtrsScalarsAndVectors)));

  getActionDefinitionsBuilder(G_FCMP).legalIf(
      all(typeInSet(0, allBoolScalarsAndVectors),
          typeInSet(1, allFloatScalarsAndVectors)));

  getActionDefinitionsBuilder({G_ATOMICRMW_OR, G_ATOMICRMW_ADD, G_ATOMICRMW_AND,
                               G_ATOMICRMW_MAX, G_ATOMICRMW_MIN,
                               G_ATOMICRMW_SUB, G_ATOMICRMW_XOR,
                               G_ATOMICRMW_UMAX, G_ATOMICRMW_UMIN})
      .legalForCartesianProduct(allIntScalars, allWritablePtrs);

  getActionDefinitionsBuilder(G_ATOMICRMW_XCHG)
      .legalForCartesianProduct(allFloatAndIntScalars, allWritablePtrs);

  getActionDefinitionsBuilder(G_ATOMIC_CMPXCHG_WITH_SUCCESS).lower();
  // TODO: add proper legalization rules.
  getActionDefinitionsBuilder(G_ATOMIC_CMPXCHG).alwaysLegal();

  getActionDefinitionsBuilder({G_UADDO, G_USUBO, G_SMULO, G_UMULO})
      .alwaysLegal();

  // Extensions.
  getActionDefinitionsBuilder({G_TRUNC, G_ZEXT, G_SEXT, G_ANYEXT})
      .legalForCartesianProduct(allScalarsAndVectors);

  // FP conversions.
  getActionDefinitionsBuilder({G_FPTRUNC, G_FPEXT})
      .legalForCartesianProduct(allFloatScalarsAndVectors);

  // Pointer-handling.
  getActionDefinitionsBuilder(G_FRAME_INDEX).legalFor({p0});

  // Control-flow.
  getActionDefinitionsBuilder(G_BRCOND).legalFor({s1});

  getActionDefinitionsBuilder({G_FPOW,
                               G_FEXP,
                               G_FEXP2,
                               G_FLOG,
                               G_FLOG2,
                               G_FABS,
                               G_FMINNUM,
                               G_FMAXNUM,
                               G_FCEIL,
                               G_FCOS,
                               G_FSIN,
                               G_FSQRT,
                               G_FFLOOR,
                               G_FRINT,
                               G_FNEARBYINT,
                               G_INTRINSIC_ROUND,
                               G_INTRINSIC_TRUNC,
                               G_FMINIMUM,
                               G_FMAXIMUM,
                               G_INTRINSIC_ROUNDEVEN})
      .legalFor(allFloatScalarsAndVectors);

  getActionDefinitionsBuilder(G_FCOPYSIGN)
      .legalForCartesianProduct(allFloatScalarsAndVectors,
                                allFloatScalarsAndVectors);

  getActionDefinitionsBuilder(G_FPOWI).legalForCartesianProduct(
      allFloatScalarsAndVectors, allIntScalarsAndVectors);

  getLegacyLegalizerInfo().computeTables();
  verify(*ST.getInstrInfo());
}
|
|
|
|
/// Emits a G_PTRTOINT of \p Reg into a fresh generic virtual register of LLT
/// \p ConvTy, registers \p SpirvType for that new register in \p GR, and
/// returns the new register. The instruction is emitted at the current
/// insertion point of \p Helper's MIRBuilder.
static Register convertPtrToInt(Register Reg, LLT ConvTy, SPIRVType *SpirvType,
                                LegalizerHelper &Helper,
                                MachineRegisterInfo &MRI,
                                SPIRVGlobalRegistry *GR) {
  auto &Builder = Helper.MIRBuilder;
  Register IntReg = MRI.createGenericVirtualRegister(ConvTy);
  // Keep the SPIR-V type information attached to the new vreg.
  GR->assignSPIRVTypeToVReg(SpirvType, IntReg, Builder.getMF());
  Builder.buildInstr(TargetOpcode::G_PTRTOINT).addDef(IntReg).addUse(Reg);
  return IntReg;
}
|
|
|
|
/// Custom legalization hook for the opcodes the constructor marks as custom.
///
/// - Opcodes in TypeFoldingSupportingOpcs are currently accepted unchanged.
/// - G_ICMP: when both compared operands are pointers and the subtarget
///   cannot compare them directly (or the predicate is not EQ/NE), each
///   pointer operand is converted with G_PTRTOINT to a pointer-sized integer
///   and the compare is rewritten in place to use the integers.
///
/// \returns true in all cases (no failure path is implemented).
bool SPIRVLegalizerInfo::legalizeCustom(LegalizerHelper &Helper,
                                        MachineInstr &MI) const {
  auto Opc = MI.getOpcode();
  // TODO: implement legalization for other opcodes.
  if (isTypeFoldingSupported(Opc))
    return true;

  // The only other opcode routed here is G_ICMP (see customIf in the ctor).
  assert(Opc == TargetOpcode::G_ICMP);
  assert(GR->getSPIRVTypeForVReg(MI.getOperand(0).getReg()));

  MachineRegisterInfo &MRI = MI.getMF()->getRegInfo();
  // G_ICMP operands: 0 = result, 1 = predicate, 2/3 = compared values.
  auto &Op0 = MI.getOperand(2);
  auto &Op1 = MI.getOperand(3);
  Register Reg0 = Op0.getReg();
  Register Reg1 = Op1.getReg();
  CmpInst::Predicate Cond =
      static_cast<CmpInst::Predicate>(MI.getOperand(1).getPredicate());

  const bool BothPointers =
      MRI.getType(Reg0).isPointer() && MRI.getType(Reg1).isPointer();
  const bool NeedsIntCompare =
      !ST->canDirectlyComparePointers() ||
      (Cond != CmpInst::ICMP_EQ && Cond != CmpInst::ICMP_NE);
  if (!BothPointers || !NeedsIntCompare)
    return true;

  LLT ConvT = LLT::scalar(ST->getPointerSize());
  Type *LLVMTy = IntegerType::get(MI.getMF()->getFunction().getContext(),
                                  ST->getPointerSize());
  SPIRVType *SpirvTy = GR->getOrCreateSPIRVType(LLVMTy, Helper.MIRBuilder);
  Register Int0 = convertPtrToInt(Reg0, ConvT, SpirvTy, Helper, MRI, GR);
  Register Int1 = convertPtrToInt(Reg1, ConvT, SpirvTy, Helper, MRI, GR);

  // MI is mutated in place, so the change observer (legalizer worklist, CSE
  // info) must be told before and after the operands are rewritten.
  Helper.Observer.changingInstr(MI);
  Op0.setReg(Int0);
  Op1.setReg(Int1);
  Helper.Observer.changedInstr(MI);
  return true;
}
|