From b5ca00a58de52d33c1935f49ce104f9d90cda67c Mon Sep 17 00:00:00 2001 From: Justin Lebar Date: Tue, 1 Mar 2016 19:24:03 +0000 Subject: [PATCH] [NVPTX] Use different, convergent MIs for convergent calls. Summary: Calls sometimes need to be convergent. This is already handled at the LLVM IR level, but it also needs to be handled at the MI level. Ideally we'd propagate convergence from instructions, down through the selection DAG, and into MIs. But this is Hard, and would affect optimizations in the SDNs -- right now only SDNs with two operands have any flags at all. Instead, here's a much simpler hack: Add new opcodes for NVPTX for convergent calls, and generate these when lowering convergent LLVM calls. Reviewers: jholewinski Subscribers: jholewinski, chandlerc, joker.eph, jhen, tra, llvm-commits Differential Revision: http://reviews.llvm.org/D17423 llvm-svn: 262373 --- llvm/include/llvm/Target/TargetLowering.h | 15 +++- .../SelectionDAG/SelectionDAGBuilder.cpp | 8 +- llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp | 12 ++- llvm/lib/Target/NVPTX/NVPTXISelLowering.h | 2 + llvm/lib/Target/NVPTX/NVPTXInstrInfo.td | 89 +++++++++---------- .../test/CodeGen/NVPTX/convergent-mir-call.ll | 27 ++++++ 6 files changed, 98 insertions(+), 55 deletions(-) create mode 100644 llvm/test/CodeGen/NVPTX/convergent-mir-call.ll diff --git a/llvm/include/llvm/Target/TargetLowering.h b/llvm/include/llvm/Target/TargetLowering.h index eb640529e0fe..6abeb44a3681 100644 --- a/llvm/include/llvm/Target/TargetLowering.h +++ b/llvm/include/llvm/Target/TargetLowering.h @@ -2348,6 +2348,7 @@ public: bool IsInReg : 1; bool DoesNotReturn : 1; bool IsReturnValueUsed : 1; + bool IsConvergent : 1; // IsTailCall should be modified by implementations of // TargetLowering::LowerCall that perform tail call conversions. @@ -2366,10 +2367,11 @@ public: SmallVector Ins; CallLoweringInfo(SelectionDAG &DAG) - : RetTy(nullptr), RetSExt(false), RetZExt(false), IsVarArg(false), - IsInReg(false), DoesNotReturn(false), IsReturnValueUsed(true), - IsTailCall(false), NumFixedArgs(-1), CallConv(CallingConv::C), - DAG(DAG), CS(nullptr), IsPatchPoint(false) {} + : RetTy(nullptr), RetSExt(false), RetZExt(false), IsVarArg(false), + IsInReg(false), DoesNotReturn(false), IsReturnValueUsed(true), + IsConvergent(false), IsTailCall(false), NumFixedArgs(-1), + CallConv(CallingConv::C), DAG(DAG), CS(nullptr), IsPatchPoint(false) { + } CallLoweringInfo &setDebugLoc(SDLoc dl) { DL = dl; @@ -2441,6 +2443,11 @@ public: return *this; } + CallLoweringInfo &setConvergent(bool Value = true) { + IsConvergent = Value; + return *this; + } + CallLoweringInfo &setSExtResult(bool Value = true) { RetSExt = Value; return *this; diff --git a/llvm/lib/CodeGen/SelectionDAG/SelectionDAGBuilder.cpp b/llvm/lib/CodeGen/SelectionDAG/SelectionDAGBuilder.cpp index d501a916ee8f..a2e7eca127da 100644 --- a/llvm/lib/CodeGen/SelectionDAG/SelectionDAGBuilder.cpp +++ b/llvm/lib/CodeGen/SelectionDAG/SelectionDAGBuilder.cpp @@ -5562,9 +5562,11 @@ void SelectionDAGBuilder::LowerCallTo(ImmutableCallSite CS, SDValue Callee, isTailCall = false; TargetLowering::CallLoweringInfo CLI(DAG); - CLI.setDebugLoc(getCurSDLoc()).setChain(getRoot()) - .setCallee(RetTy, FTy, Callee, std::move(Args), CS) - .setTailCall(isTailCall); + CLI.setDebugLoc(getCurSDLoc()) + .setChain(getRoot()) + .setCallee(RetTy, FTy, Callee, std::move(Args), CS) + .setTailCall(isTailCall) + .setConvergent(CS.isConvergent()); std::pair Result = lowerInvokable(CLI, EHPadBB); if (Result.first.getNode()) { diff --git a/llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp b/llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp index c6263ca7317e..592a269d1a0b 100644 --- a/llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp +++ b/llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp @@ -314,8 +314,12 @@ const char *NVPTXTargetLowering::getTargetNodeName(unsigned Opcode) const { return "NVPTXISD::DeclareRetParam"; case NVPTXISD::PrintCall: return "NVPTXISD::PrintCall"; + case NVPTXISD::PrintConvergentCall: + return "NVPTXISD::PrintConvergentCall"; case NVPTXISD::PrintCallUni: return "NVPTXISD::PrintCallUni"; + case NVPTXISD::PrintConvergentCallUni: + return "NVPTXISD::PrintConvergentCallUni"; case NVPTXISD::LoadParam: return "NVPTXISD::LoadParam"; case NVPTXISD::LoadParamV2: @@ -1439,8 +1443,12 @@ SDValue NVPTXTargetLowering::LowerCall(TargetLowering::CallLoweringInfo &CLI, SDValue PrintCallOps[] = { Chain, DAG.getConstant((Ins.size() == 0) ? 0 : 1, dl, MVT::i32), InFlag }; - Chain = DAG.getNode(Func ? (NVPTXISD::PrintCallUni) : (NVPTXISD::PrintCall), - dl, PrintCallVTs, PrintCallOps); + // We model convergent calls as separate opcodes. + unsigned Opcode = Func ? NVPTXISD::PrintCallUni : NVPTXISD::PrintCall; + if (CLI.IsConvergent) + Opcode = Opcode == NVPTXISD::PrintCallUni ? NVPTXISD::PrintConvergentCallUni + : NVPTXISD::PrintConvergentCall; + Chain = DAG.getNode(Opcode, dl, PrintCallVTs, PrintCallOps); InFlag = Chain.getValue(1); // Ops to print out the function name diff --git a/llvm/lib/Target/NVPTX/NVPTXISelLowering.h b/llvm/lib/Target/NVPTX/NVPTXISelLowering.h index 60914c1d09b4..735cd01ced6b 100644 --- a/llvm/lib/Target/NVPTX/NVPTXISelLowering.h +++ b/llvm/lib/Target/NVPTX/NVPTXISelLowering.h @@ -34,7 +34,9 @@ enum NodeType : unsigned { DeclareRet, DeclareScalarRet, PrintCall, + PrintConvergentCall, PrintCallUni, + PrintConvergentCallUni, CallArgBegin, CallArg, LastCallArg, diff --git a/llvm/lib/Target/NVPTX/NVPTXInstrInfo.td b/llvm/lib/Target/NVPTX/NVPTXInstrInfo.td index 51db8246c532..685d1b447b99 100644 --- a/llvm/lib/Target/NVPTX/NVPTXInstrInfo.td +++ b/llvm/lib/Target/NVPTX/NVPTXInstrInfo.td @@ -1701,9 +1701,15 @@ def LoadParamV4 : def PrintCall : SDNode<"NVPTXISD::PrintCall", SDTPrintCallProfile, [SDNPHasChain, SDNPOutGlue, SDNPInGlue, SDNPSideEffect]>; +def PrintConvergentCall : + SDNode<"NVPTXISD::PrintConvergentCall", SDTPrintCallProfile, + [SDNPHasChain, SDNPOutGlue, SDNPInGlue, SDNPSideEffect]>; def PrintCallUni : SDNode<"NVPTXISD::PrintCallUni", SDTPrintCallUniProfile, [SDNPHasChain, SDNPOutGlue, SDNPInGlue, SDNPSideEffect]>; +def PrintConvergentCallUni : + SDNode<"NVPTXISD::PrintConvergentCallUni", SDTPrintCallUniProfile, + [SDNPHasChain, SDNPOutGlue, SDNPInGlue, SDNPSideEffect]>; def StoreParam : SDNode<"NVPTXISD::StoreParam", SDTStoreParamProfile, [SDNPHasChain, SDNPOutGlue, SDNPInGlue, SDNPSideEffect]>; @@ -1821,53 +1827,44 @@ class StoreRetvalV4Inst : []>; let isCall=1 in { - def PrintCallNoRetInst : NVPTXInst<(outs), (ins), - "call ", [(PrintCall (i32 0))]>; - def PrintCallRetInst1 : NVPTXInst<(outs), (ins), - "call (retval0), ", [(PrintCall (i32 1))]>; - def PrintCallRetInst2 : NVPTXInst<(outs), (ins), - "call (retval0, retval1), ", [(PrintCall (i32 2))]>; - def PrintCallRetInst3 : NVPTXInst<(outs), (ins), - "call (retval0, retval1, retval2), ", [(PrintCall (i32 3))]>; - def PrintCallRetInst4 : NVPTXInst<(outs), (ins), - "call (retval0, retval1, retval2, retval3), ", [(PrintCall (i32 4))]>; - def PrintCallRetInst5 : NVPTXInst<(outs), (ins), - "call (retval0, retval1, retval2, retval3, retval4), ", - [(PrintCall (i32 5))]>; - def PrintCallRetInst6 : NVPTXInst<(outs), (ins), - "call (retval0, retval1, retval2, retval3, retval4, retval5), ", - [(PrintCall (i32 6))]>; - def PrintCallRetInst7 : NVPTXInst<(outs), (ins), - "call (retval0, retval1, retval2, retval3, retval4, retval5, retval6), ", - [(PrintCall (i32 7))]>; - def PrintCallRetInst8 : NVPTXInst<(outs), (ins), - "call (retval0, retval1, retval2, retval3, retval4, retval5, retval6, " - "retval7), ", - [(PrintCall (i32 8))]>; + multiclass CALL { + def PrintCallNoRetInst : NVPTXInst<(outs), (ins), + !strconcat(OpcStr, " "), [(OpNode (i32 0))]>; + def PrintCallRetInst1 : NVPTXInst<(outs), (ins), + !strconcat(OpcStr, " (retval0), "), [(OpNode (i32 1))]>; + def PrintCallRetInst2 : NVPTXInst<(outs), (ins), + !strconcat(OpcStr, " (retval0, retval1), "), [(OpNode (i32 2))]>; + def PrintCallRetInst3 : NVPTXInst<(outs), (ins), + !strconcat(OpcStr, " (retval0, retval1, retval2), "), [(OpNode (i32 3))]>; + def PrintCallRetInst4 : NVPTXInst<(outs), (ins), + !strconcat(OpcStr, " (retval0, retval1, retval2, retval3), "), + [(OpNode (i32 4))]>; + def PrintCallRetInst5 : NVPTXInst<(outs), (ins), + !strconcat(OpcStr, " (retval0, retval1, retval2, retval3, retval4), "), + [(OpNode (i32 5))]>; + def PrintCallRetInst6 : NVPTXInst<(outs), (ins), + !strconcat(OpcStr, " (retval0, retval1, retval2, retval3, retval4, " + "retval5), "), + [(OpNode (i32 6))]>; + def PrintCallRetInst7 : NVPTXInst<(outs), (ins), + !strconcat(OpcStr, " (retval0, retval1, retval2, retval3, retval4, " + "retval5, retval6), "), + [(OpNode (i32 7))]>; + def PrintCallRetInst8 : NVPTXInst<(outs), (ins), + !strconcat(OpcStr, " (retval0, retval1, retval2, retval3, retval4, " + "retval5, retval6, retval7), "), + [(OpNode (i32 8))]>; + } +} - def PrintCallUniNoRetInst : NVPTXInst<(outs), (ins), - "call.uni ", [(PrintCallUni (i32 0))]>; - def PrintCallUniRetInst1 : NVPTXInst<(outs), (ins), - "call.uni (retval0), ", [(PrintCallUni (i32 1))]>; - def PrintCallUniRetInst2 : NVPTXInst<(outs), (ins), - "call.uni (retval0, retval1), ", [(PrintCallUni (i32 2))]>; - def PrintCallUniRetInst3 : NVPTXInst<(outs), (ins), - "call.uni (retval0, retval1, retval2), ", [(PrintCallUni (i32 3))]>; - def PrintCallUniRetInst4 : NVPTXInst<(outs), (ins), - "call.uni (retval0, retval1, retval2, retval3), ", [(PrintCallUni (i32 4))]>; - def PrintCallUniRetInst5 : NVPTXInst<(outs), (ins), - "call.uni (retval0, retval1, retval2, retval3, retval4), ", - [(PrintCallUni (i32 5))]>; - def PrintCallUniRetInst6 : NVPTXInst<(outs), (ins), - "call.uni (retval0, retval1, retval2, retval3, retval4, retval5), ", - [(PrintCallUni (i32 6))]>; - def PrintCallUniRetInst7 : NVPTXInst<(outs), (ins), - "call.uni (retval0, retval1, retval2, retval3, retval4, retval5, retval6), ", - [(PrintCallUni (i32 7))]>; - def PrintCallUniRetInst8 : NVPTXInst<(outs), (ins), - "call.uni (retval0, retval1, retval2, retval3, retval4, retval5, retval6, " - "retval7), ", - [(PrintCallUni (i32 8))]>; +defm Call : CALL<"call", PrintCall>; +defm CallUni : CALL<"call.uni", PrintCallUni>; + +// Convergent call instructions. These are identical to regular calls, except +// they have the isConvergent bit set. +let isConvergent=1 in { + defm ConvergentCall : CALL<"call", PrintConvergentCall>; + defm ConvergentCallUni : CALL<"call.uni", PrintConvergentCallUni>; } def LoadParamMemI64 : LoadParamMemInst; diff --git a/llvm/test/CodeGen/NVPTX/convergent-mir-call.ll b/llvm/test/CodeGen/NVPTX/convergent-mir-call.ll new file mode 100644 index 000000000000..18142450490c --- /dev/null +++ b/llvm/test/CodeGen/NVPTX/convergent-mir-call.ll @@ -0,0 +1,27 @@ +; RUN: llc -mtriple nvptx64-nvidia-cuda -stop-after machine-cp -o - < %s 2>&1 | FileCheck %s + +; Check that convergent calls are emitted using convergent MIR instructions, +; while non-convergent calls are not. + +target triple = "nvptx64-nvidia-cuda" + +declare void @conv() convergent +declare void @not_conv() + +define void @test(void ()* %f) { + ; CHECK: ConvergentCallUniPrintCall + ; CHECK-NEXT: @conv + call void @conv() + + ; CHECK: CallUniPrintCall + ; CHECK-NEXT: @not_conv + call void @not_conv() + + ; CHECK: ConvergentCallPrintCall + call void %f() convergent + + ; CHECK: CallPrintCall + call void %f() + + ret void +}