diff options
author | Justin Lebar <jlebar@google.com> | 2018-12-26 19:12:31 +0000 |
---|---|---|
committer | Justin Lebar <jlebar@google.com> | 2018-12-26 19:12:31 +0000 |
commit | e1e1965a52af03a76f13559a3d06af70d237c244 (patch) | |
tree | 3279c8552477d8bdf9dd717717b1f942555fd2d9 /llvm/lib/Target/NVPTX | |
parent | b2fe5df755ec29ac22f9f6ad47af08579bf4504b (diff) |
[NVPTX] Allow libcalls that are defined in the current module.
This patch makes it possible to make library calls on NVPTX.
An important requirement for library functions is that they must be
defined within the current module. This should guarantee that we produce
valid PTX assembly (without calls to undefined functions). Anyone who
wants to use the libcalls will probably have to link against compiler-rt
or any other implementation.
Currently, it is completely impossible to make library calls because they
fail with the error "LLVM ERROR: Cannot select: i32 = ExternalSymbol '...'".
However, we can lower ExternalSymbol to TargetExternalSymbol and check
whether the function definition is available.
Also, there was an issue with the DAG during legalisation. When we expand
an instruction into a libcall, the inner call chain is not "integrated"
into the outer chain. Since the last data-flow node (the load of the call's
return value) is located earlier in the call chain than the CALLSEQ_END
node, the latter becomes a leaf and therefore a dead node (and is removed
quite quickly).
The solution proposed here relies on additional data-flow pseudo nodes
(ProxyReg) whose only purpose is to keep CALLSEQ_END alive during the
legalisation and instruction selection phases; we remove the pseudo
instructions before the register allocation phase.
Patch by Denys Zariaiev!
Differential Revision: https://reviews.llvm.org/D34708
Diffstat (limited to 'llvm/lib/Target/NVPTX')
-rw-r--r-- | llvm/lib/Target/NVPTX/CMakeLists.txt | 1 | ||||
-rw-r--r-- | llvm/lib/Target/NVPTX/NVPTX.h | 1 | ||||
-rw-r--r-- | llvm/lib/Target/NVPTX/NVPTXAsmPrinter.cpp | 5 | ||||
-rw-r--r-- | llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp | 47 | ||||
-rw-r--r-- | llvm/lib/Target/NVPTX/NVPTXISelLowering.h | 1 | ||||
-rw-r--r-- | llvm/lib/Target/NVPTX/NVPTXInstrInfo.td | 21 | ||||
-rw-r--r-- | llvm/lib/Target/NVPTX/NVPTXProxyRegErasure.cpp | 122 | ||||
-rw-r--r-- | llvm/lib/Target/NVPTX/NVPTXTargetMachine.cpp | 8 |
8 files changed, 202 insertions, 4 deletions
diff --git a/llvm/lib/Target/NVPTX/CMakeLists.txt b/llvm/lib/Target/NVPTX/CMakeLists.txt index 4a64fe0961e..d094620f1bf 100644 --- a/llvm/lib/Target/NVPTX/CMakeLists.txt +++ b/llvm/lib/Target/NVPTX/CMakeLists.txt @@ -32,6 +32,7 @@ set(NVPTXCodeGen_sources NVPTXUtilities.cpp NVVMIntrRange.cpp NVVMReflect.cpp + NVPTXProxyRegErasure.cpp ) add_llvm_target(NVPTXCodeGen ${NVPTXCodeGen_sources}) diff --git a/llvm/lib/Target/NVPTX/NVPTX.h b/llvm/lib/Target/NVPTX/NVPTX.h index 02b8d8fff64..07bfc58a8da 100644 --- a/llvm/lib/Target/NVPTX/NVPTX.h +++ b/llvm/lib/Target/NVPTX/NVPTX.h @@ -53,6 +53,7 @@ FunctionPass *createNVPTXImageOptimizerPass(); FunctionPass *createNVPTXLowerArgsPass(const NVPTXTargetMachine *TM); BasicBlockPass *createNVPTXLowerAllocaPass(); MachineFunctionPass *createNVPTXPeephole(); +MachineFunctionPass *createNVPTXProxyRegErasurePass(); Target &getTheNVPTXTarget32(); Target &getTheNVPTXTarget64(); diff --git a/llvm/lib/Target/NVPTX/NVPTXAsmPrinter.cpp b/llvm/lib/Target/NVPTX/NVPTXAsmPrinter.cpp index ad9b3b37810..6284ad8b82e 100644 --- a/llvm/lib/Target/NVPTX/NVPTXAsmPrinter.cpp +++ b/llvm/lib/Target/NVPTX/NVPTXAsmPrinter.cpp @@ -730,6 +730,11 @@ void NVPTXAsmPrinter::emitDeclarations(const Module &M, raw_ostream &O) { for (Module::const_iterator FI = M.begin(), FE = M.end(); FI != FE; ++FI) { const Function *F = &*FI; + if (F->getAttributes().hasFnAttribute("nvptx-libcall-callee")) { + emitDeclaration(F, O); + continue; + } + if (F->isDeclaration()) { if (F->use_empty()) continue; diff --git a/llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp b/llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp index 5c16c34e21d..bec8ece2905 100644 --- a/llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp +++ b/llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp @@ -663,6 +663,8 @@ const char *NVPTXTargetLowering::getTargetNodeName(unsigned Opcode) const { return "NVPTXISD::CallSeqEnd"; case NVPTXISD::CallPrototype: return "NVPTXISD::CallPrototype"; + case NVPTXISD::ProxyReg: + return 
"NVPTXISD::ProxyReg"; case NVPTXISD::LoadV2: return "NVPTXISD::LoadV2"; case NVPTXISD::LoadV4: @@ -1666,6 +1668,18 @@ SDValue NVPTXTargetLowering::LowerCall(TargetLowering::CallLoweringInfo &CLI, // indirect calls but is always null for libcalls. bool isIndirectCall = !Func && CS; + if (isa<ExternalSymbolSDNode>(Callee)) { + Function* CalleeFunc = nullptr; + + // Try to find the callee in the current module. + Callee = DAG.getSymbolFunctionGlobalAddress(Callee, &CalleeFunc); + assert(CalleeFunc != nullptr && "Libcall callee must be set."); + + // Set the "libcall callee" attribute to indicate that the function + // must always have a declaration. + CalleeFunc->addFnAttr("nvptx-libcall-callee", "true"); + } + if (isIndirectCall) { // This is indirect function call case : PTX requires a prototype of the // form @@ -1738,6 +1752,9 @@ SDValue NVPTXTargetLowering::LowerCall(TargetLowering::CallLoweringInfo &CLI, InFlag = Chain.getValue(1); } + SmallVector<SDValue, 16> ProxyRegOps; + SmallVector<Optional<MVT>, 16> ProxyRegTruncates; + // Generate loads from param memory/moves from registers for result if (Ins.size() > 0) { SmallVector<EVT, 16> VTs; @@ -1808,11 +1825,14 @@ SDValue NVPTXTargetLowering::LowerCall(TargetLowering::CallLoweringInfo &CLI, MachineMemOperand::MOLoad); for (unsigned j = 0; j < NumElts; ++j) { - SDValue Ret = RetVal.getValue(j); + ProxyRegOps.push_back(RetVal.getValue(j)); + if (needTruncate) - Ret = DAG.getNode(ISD::TRUNCATE, dl, Ins[VecIdx + j].VT, Ret); - InVals.push_back(Ret); + ProxyRegTruncates.push_back(Optional<MVT>(Ins[VecIdx + j].VT)); + else + ProxyRegTruncates.push_back(Optional<MVT>()); } + Chain = RetVal.getValue(NumElts); InFlag = RetVal.getValue(NumElts + 1); @@ -1828,8 +1848,29 @@ SDValue NVPTXTargetLowering::LowerCall(TargetLowering::CallLoweringInfo &CLI, DAG.getIntPtrConstant(uniqueCallSite + 1, dl, true), InFlag, dl); + InFlag = Chain.getValue(1); uniqueCallSite++; + // Append ProxyReg instructions to the chain to make sure 
that `callseq_end` + // will not get lost. Otherwise, during libcalls expansion, the nodes can become + // dangling. + for (unsigned i = 0; i < ProxyRegOps.size(); ++i) { + SDValue Ret = DAG.getNode( + NVPTXISD::ProxyReg, dl, + DAG.getVTList(ProxyRegOps[i].getSimpleValueType(), MVT::Other, MVT::Glue), + { Chain, ProxyRegOps[i], InFlag } + ); + + Chain = Ret.getValue(1); + InFlag = Ret.getValue(2); + + if (ProxyRegTruncates[i].hasValue()) { + Ret = DAG.getNode(ISD::TRUNCATE, dl, ProxyRegTruncates[i].getValue(), Ret); + } + + InVals.push_back(Ret); + } + // set isTailCall to false for now, until we figure out how to express // tail call optimization in PTX isTailCall = false; diff --git a/llvm/lib/Target/NVPTX/NVPTXISelLowering.h b/llvm/lib/Target/NVPTX/NVPTXISelLowering.h index 3e109f75b66..66fab2b6f48 100644 --- a/llvm/lib/Target/NVPTX/NVPTXISelLowering.h +++ b/llvm/lib/Target/NVPTX/NVPTXISelLowering.h @@ -51,6 +51,7 @@ enum NodeType : unsigned { CallSeqBegin, CallSeqEnd, CallPrototype, + ProxyReg, FUN_SHFL_CLAMP, FUN_SHFR_CLAMP, MUL_WIDE_SIGNED, diff --git a/llvm/lib/Target/NVPTX/NVPTXInstrInfo.td b/llvm/lib/Target/NVPTX/NVPTXInstrInfo.td index 48db941db9b..02a40b9f526 100644 --- a/llvm/lib/Target/NVPTX/NVPTXInstrInfo.td +++ b/llvm/lib/Target/NVPTX/NVPTXInstrInfo.td @@ -1885,6 +1885,7 @@ def SDTStoreRetvalProfile : SDTypeProfile<0, 2, [SDTCisInt<0>]>; def SDTStoreRetvalV2Profile : SDTypeProfile<0, 3, [SDTCisInt<0>]>; def SDTStoreRetvalV4Profile : SDTypeProfile<0, 5, [SDTCisInt<0>]>; def SDTPseudoUseParamProfile : SDTypeProfile<0, 1, []>; +def SDTProxyRegProfile : SDTypeProfile<1, 1, []>; def DeclareParam : SDNode<"NVPTXISD::DeclareParam", SDTDeclareParamProfile, @@ -1972,6 +1973,9 @@ def PseudoUseParam : def RETURNNode : SDNode<"NVPTXISD::RETURN", SDTCallArgMarkProfile, [SDNPHasChain, SDNPSideEffect]>; +def ProxyReg : + SDNode<"NVPTXISD::ProxyReg", SDTProxyRegProfile, + [SDNPHasChain, SDNPOutGlue, SDNPInGlue, SDNPSideEffect]>; let mayLoad = 1 in { class 
LoadParamMemInst<NVPTXRegClass regclass, string opstr> : @@ -2249,6 +2253,21 @@ def PseudoUseParamI16 : PseudoUseParamInst<Int16Regs>; def PseudoUseParamF64 : PseudoUseParamInst<Float64Regs>; def PseudoUseParamF32 : PseudoUseParamInst<Float32Regs>; +class ProxyRegInst<string SzStr, NVPTXRegClass regclass> : + NVPTXInst<(outs regclass:$dst), (ins regclass:$src), + !strconcat("mov.", SzStr, " \t$dst, $src;"), + [(set regclass:$dst, (ProxyReg regclass:$src))]>; + +let isCodeGenOnly=1, isPseudo=1 in { + def ProxyRegI1 : ProxyRegInst<"pred", Int1Regs>; + def ProxyRegI16 : ProxyRegInst<"b16", Int16Regs>; + def ProxyRegI32 : ProxyRegInst<"b32", Int32Regs>; + def ProxyRegI64 : ProxyRegInst<"b64", Int64Regs>; + def ProxyRegF16 : ProxyRegInst<"b16", Float16Regs>; + def ProxyRegF32 : ProxyRegInst<"f32", Float32Regs>; + def ProxyRegF64 : ProxyRegInst<"f64", Float64Regs>; + def ProxyRegF16x2 : ProxyRegInst<"b32", Float16x2Regs>; +} // // Load / Store Handling @@ -2541,7 +2560,7 @@ let mayStore=1, hasSideEffects=0 in { class F_BITCONVERT<string SzStr, NVPTXRegClass regclassIn, NVPTXRegClass regclassOut> : NVPTXInst<(outs regclassOut:$d), (ins regclassIn:$a), - !strconcat("mov.b", !strconcat(SzStr, " \t$d, $a;")), + !strconcat("mov.b", SzStr, " \t$d, $a;"), [(set regclassOut:$d, (bitconvert regclassIn:$a))]>; def BITCONVERT_16_I2F : F_BITCONVERT<"16", Int16Regs, Float16Regs>; diff --git a/llvm/lib/Target/NVPTX/NVPTXProxyRegErasure.cpp b/llvm/lib/Target/NVPTX/NVPTXProxyRegErasure.cpp new file mode 100644 index 00000000000..f60d841c168 --- /dev/null +++ b/llvm/lib/Target/NVPTX/NVPTXProxyRegErasure.cpp @@ -0,0 +1,122 @@ +//===- NVPTXProxyRegErasure.cpp - NVPTX Proxy Register Instruction Erasure -==// +// +// The LLVM Compiler Infrastructure +// +// This file is distributed under the University of Illinois Open Source +// License. See LICENSE.TXT for details. 
+// +//===----------------------------------------------------------------------===// +// +// The pass is needed to remove ProxyReg instructions and restore related +// registers. The instructions were needed at instruction selection stage to +// make sure that callseq_end nodes won't be removed as "dead nodes". This can +// happen when we expand instructions into libcalls and the call site doesn't +// care about the libcall chain. Call site cares about data flow only, and the +// latest data flow node happens to be before callseq_end. Therefore the node +// becomes dangling and "dead". The ProxyReg acts like an additional data flow +// node *after* the callseq_end in the chain and ensures that everything will be +// preserved. +// +//===----------------------------------------------------------------------===// + +#include "NVPTX.h" +#include "llvm/CodeGen/MachineFunctionPass.h" +#include "llvm/CodeGen/MachineInstrBuilder.h" +#include "llvm/CodeGen/MachineRegisterInfo.h" +#include "llvm/CodeGen/TargetInstrInfo.h" +#include "llvm/CodeGen/TargetRegisterInfo.h" + +using namespace llvm; + +namespace llvm { +void initializeNVPTXProxyRegErasurePass(PassRegistry &); +} + +namespace { + +struct NVPTXProxyRegErasure : public MachineFunctionPass { +public: + static char ID; + NVPTXProxyRegErasure() : MachineFunctionPass(ID) { + initializeNVPTXProxyRegErasurePass(*PassRegistry::getPassRegistry()); + } + + bool runOnMachineFunction(MachineFunction &MF) override; + + StringRef getPassName() const override { + return "NVPTX Proxy Register Instruction Erasure"; + } + + void getAnalysisUsage(AnalysisUsage &AU) const override { + MachineFunctionPass::getAnalysisUsage(AU); + } + +private: + void replaceMachineInstructionUsage(MachineFunction &MF, MachineInstr &MI); + + void replaceRegisterUsage(MachineInstr &Instr, MachineOperand &From, + MachineOperand &To); +}; + +} // namespace + +char NVPTXProxyRegErasure::ID = 0; + +INITIALIZE_PASS(NVPTXProxyRegErasure, 
"nvptx-proxyreg-erasure", "NVPTX ProxyReg Erasure", false, false) + +bool NVPTXProxyRegErasure::runOnMachineFunction(MachineFunction &MF) { + SmallVector<MachineInstr *, 16> RemoveList; + + for (auto &BB : MF) { + for (auto &MI : BB) { + switch (MI.getOpcode()) { + case NVPTX::ProxyRegI1: + case NVPTX::ProxyRegI16: + case NVPTX::ProxyRegI32: + case NVPTX::ProxyRegI64: + case NVPTX::ProxyRegF16: + case NVPTX::ProxyRegF16x2: + case NVPTX::ProxyRegF32: + case NVPTX::ProxyRegF64: + replaceMachineInstructionUsage(MF, MI); + RemoveList.push_back(&MI); + break; + } + } + } + + for (auto *MI : RemoveList) { + MI->eraseFromParent(); + } + + return !RemoveList.empty(); +} + +void NVPTXProxyRegErasure::replaceMachineInstructionUsage(MachineFunction &MF, + MachineInstr &MI) { + auto &InOp = *MI.uses().begin(); + auto &OutOp = *MI.defs().begin(); + + assert(InOp.isReg() && "ProxyReg input operand should be a register."); + assert(OutOp.isReg() && "ProxyReg output operand should be a register."); + + for (auto &BB : MF) { + for (auto &I : BB) { + replaceRegisterUsage(I, OutOp, InOp); + } + } +} + +void NVPTXProxyRegErasure::replaceRegisterUsage(MachineInstr &Instr, + MachineOperand &From, + MachineOperand &To) { + for (auto &Op : Instr.uses()) { + if (Op.isReg() && Op.getReg() == From.getReg()) { + Op.setReg(To.getReg()); + } + } +} + +MachineFunctionPass *llvm::createNVPTXProxyRegErasurePass() { + return new NVPTXProxyRegErasure(); +} diff --git a/llvm/lib/Target/NVPTX/NVPTXTargetMachine.cpp b/llvm/lib/Target/NVPTX/NVPTXTargetMachine.cpp index 8c009aed887..8ec0ddb9b3d 100644 --- a/llvm/lib/Target/NVPTX/NVPTXTargetMachine.cpp +++ b/llvm/lib/Target/NVPTX/NVPTXTargetMachine.cpp @@ -68,6 +68,7 @@ void initializeNVPTXAssignValidGlobalNamesPass(PassRegistry&); void initializeNVPTXLowerAggrCopiesPass(PassRegistry &); void initializeNVPTXLowerArgsPass(PassRegistry &); void initializeNVPTXLowerAllocaPass(PassRegistry &); +void initializeNVPTXProxyRegErasurePass(PassRegistry &); } // end 
namespace llvm @@ -87,6 +88,7 @@ extern "C" void LLVMInitializeNVPTXTarget() { initializeNVPTXLowerArgsPass(PR); initializeNVPTXLowerAllocaPass(PR); initializeNVPTXLowerAggrCopiesPass(PR); + initializeNVPTXProxyRegErasurePass(PR); } static std::string computeDataLayout(bool is64Bit, bool UseShortPointers) { @@ -160,6 +162,7 @@ public: void addIRPasses() override; bool addInstSelector() override; + void addPreRegAlloc() override; void addPostRegAlloc() override; void addMachineSSAOptimization() override; @@ -301,6 +304,11 @@ bool NVPTXPassConfig::addInstSelector() { return false; } +void NVPTXPassConfig::addPreRegAlloc() { + // Remove Proxy Register pseudo instructions used to keep `callseq_end` alive. + addPass(createNVPTXProxyRegErasurePass()); +} + void NVPTXPassConfig::addPostRegAlloc() { addPass(createNVPTXPrologEpilogPass(), false); if (getOptLevel() != CodeGenOpt::None) { |