diff options
Diffstat (limited to 'contrib/native/client/src')
9 files changed, 247 insertions, 307 deletions
diff --git a/contrib/native/client/src/clientlib/CMakeLists.txt b/contrib/native/client/src/clientlib/CMakeLists.txt index 2270c91ee..7b9ecc3c0 100644 --- a/contrib/native/client/src/clientlib/CMakeLists.txt +++ b/contrib/native/client/src/clientlib/CMakeLists.txt @@ -40,7 +40,7 @@ include_directories(${CMAKE_CURRENT_SOURCE_DIR} ${CMAKE_CURRENT_SOURCE_DIR}/../i include_directories(${PROTOBUF_INCLUDE_DIR}) include_directories(${Zookeeper_INCLUDE_DIRS}) include_directories(${SASL_INCLUDE_DIRS}) -include_directories("${OPENSSL_INCLUDE_DIR}") +include_directories(${OPENSSL_INCLUDE_DIR}) link_directories(/usr/local/lib) diff --git a/contrib/native/client/src/clientlib/channel.cpp b/contrib/native/client/src/clientlib/channel.cpp index 62ce976b1..84f3eb43f 100644 --- a/contrib/native/client/src/clientlib/channel.cpp +++ b/contrib/native/client/src/clientlib/channel.cpp @@ -39,7 +39,7 @@ ConnectionEndpoint::ConnectionEndpoint(const char* connStr){ ConnectionEndpoint::ConnectionEndpoint(const char* host, const char* port){ m_host=host; m_port=port; - m_protocol="drillbit"; // direct connection + m_protocol=PROTOCOL_TYPE_DIRECT; // direct connection m_pError=NULL; } @@ -61,7 +61,7 @@ connectionStatus_t ConnectionEndpoint::getDrillbitEndpoint(){ DRILL_LOG(LOG_INFO) << "Failed to get endpoint from zk" << std::endl; return ret; } - }else if(!this->isDirectConnection()){ + }else if(!isDirectConnection()){ return handleError(CONN_INVALID_INPUT, getMessage(ERR_CONN_UNKPROTO, this->getProtocol().c_str())); } }else{ @@ -77,19 +77,19 @@ void ConnectionEndpoint::parseConnectString(){ boost::cmatch matched; if(boost::regex_match(m_connectString.c_str(), matched, connStrExpr)){ - m_protocol.assign(matched[1].first, matched[1].second); + m_protocol = matched[1].str(); if(isDirectConnection()){ - m_host.assign(matched[4].first, matched[4].second); - m_port.assign(matched[5].first, matched[5].second); + m_host = matched[4].str(); + m_port = matched[5].str(); }else { // if the connection is to a zookeeper, // we will get the host and the port only after connecting to the Zookeeper m_host = ""; m_port = ""; } - m_hostPortStr.assign(matched[2].first, matched[2].second); + m_hostPortStr = matched[2].str(); if(matched[6].matched) { - m_pathToDrill.assign(matched[6].first, matched[6].second); + m_pathToDrill = matched[6].str(); } DRILL_MT_LOG(DRILL_LOG(LOG_DEBUG) << "Conn str: "<< m_connectString @@ -106,12 +106,12 @@ void ConnectionEndpoint::parseConnectString(){ bool ConnectionEndpoint::isDirectConnection(){ assert(!m_protocol.empty()); - return (!strcmp(m_protocol.c_str(), "local") || !strcmp(m_protocol.c_str(), "drillbit")); + return ( m_protocol == PROTOCOL_TYPE_DIRECT || m_protocol == PROTOCOL_TYPE_DIRECT_2 ); } bool ConnectionEndpoint::isZookeeperConnection(){ assert(!m_protocol.empty()); - return (!strcmp(m_protocol.c_str(), "zk")); + return (m_protocol == PROTOCOL_TYPE_ZK); } connectionStatus_t ConnectionEndpoint::getDrillbitEndpointFromZk(){ @@ -148,138 +148,86 @@ connectionStatus_t ConnectionEndpoint::handleError(connectionStatus_t status, st return status; } -/**************************** - * Channel Context Factory - ****************************/ -ChannelContext* ChannelContextFactory::getChannelContext(channelType_t t, DrillUserProperties* props){ - ChannelContext* pChannelContext=NULL; - switch(t){ - case CHANNEL_TYPE_SOCKET: - pChannelContext=new ChannelContext(props); - break; -#if defined(IS_SSL_ENABLED) - case CHANNEL_TYPE_SSLSTREAM: { - - std::string protocol; - props->getProp(USERPROP_TLSPROTOCOL, protocol); - boost::asio::ssl::context::method tlsVersion = SSLChannelContext::getTlsVersion(protocol); - - std::string noVerifyCert; - props->getProp(USERPROP_DISABLE_CERTVERIFICATION, noVerifyCert); - boost::asio::ssl::context::verify_mode verifyMode = boost::asio::ssl::context::verify_peer; - if (noVerifyCert == "true") { - verifyMode = boost::asio::ssl::context::verify_none; - } - - pChannelContext = new SSLChannelContext(props, tlsVersion, verifyMode); - } - break; -#endif - default: - DRILL_LOG(LOG_ERROR) << "Channel type " << t << " is not supported." << std::endl; - break; - } - return pChannelContext; -} - /******************* * ChannelFactory * *****************/ -Channel* ChannelFactory::getChannel(channelType_t t, const char* connStr){ +Channel* ChannelFactory::getChannel(channelType_t t, boost::asio::io_service& ioService, const char* connStr, DrillUserProperties* props){ Channel* pChannel=NULL; + ChannelContext_t * pChannelContext = ChannelFactory::getChannelContext(t, props); switch(t){ case CHANNEL_TYPE_SOCKET: - pChannel=new SocketChannel(connStr); + pChannel=new SocketChannel(ioService, connStr); break; #if defined(IS_SSL_ENABLED) case CHANNEL_TYPE_SSLSTREAM: - pChannel=new SSLStreamChannel(connStr); + pChannel=new SSLStreamChannel(ioService, connStr); break; #endif default: DRILL_LOG(LOG_ERROR) << "Channel type " << t << " is not supported." << std::endl; break; } + pChannel->m_pContext = pChannelContext; return pChannel; } -Channel* ChannelFactory::getChannel(channelType_t t, const char* host, const char* port){ +Channel* ChannelFactory::getChannel(channelType_t t, boost::asio::io_service& ioService, const char* host, const char* port, DrillUserProperties* props){ Channel* pChannel=NULL; + ChannelContext_t * pChannelContext = ChannelFactory::getChannelContext(t, props); switch(t){ case CHANNEL_TYPE_SOCKET: - pChannel=new SocketChannel(host, port); + pChannel=new SocketChannel(ioService, host, port); break; #if defined(IS_SSL_ENABLED) case CHANNEL_TYPE_SSLSTREAM: - pChannel=new SSLStreamChannel(host, port); + pChannel=new SSLStreamChannel(ioService, host, port); break; #endif default: DRILL_LOG(LOG_ERROR) << "Channel type " << t << " is not supported." << std::endl; break; } + pChannel->m_pContext = pChannelContext; return pChannel; } -Channel* ChannelFactory::getChannel(channelType_t t, boost::asio::io_service& ioService, const char* connStr){ - Channel* pChannel=NULL; +ChannelContext* ChannelFactory::getChannelContext(channelType_t t, DrillUserProperties* props){ + ChannelContext* pChannelContext=NULL; switch(t){ case CHANNEL_TYPE_SOCKET: - pChannel=new SocketChannel(ioService, connStr); + pChannelContext=new ChannelContext(props); break; #if defined(IS_SSL_ENABLED) - case CHANNEL_TYPE_SSLSTREAM: - pChannel=new SSLStreamChannel(ioService, connStr); - break; -#endif - default: - DRILL_LOG(LOG_ERROR) << "Channel type " << t << " is not supported." << std::endl; - break; - } - return pChannel; -} + case CHANNEL_TYPE_SSLSTREAM: { -Channel* ChannelFactory::getChannel(channelType_t t, boost::asio::io_service& ioService, const char* host, const char* port){ - Channel* pChannel=NULL; - switch(t){ - case CHANNEL_TYPE_SOCKET: - pChannel=new SocketChannel(ioService, host, port); - break; -#if defined(IS_SSL_ENABLED) - case CHANNEL_TYPE_SSLSTREAM: - pChannel=new SSLStreamChannel(ioService, host, port); + std::string protocol; + props->getProp(USERPROP_TLSPROTOCOL, protocol); + boost::asio::ssl::context::method tlsVersion = SSLChannelContext::getTlsVersion(protocol); + + std::string noVerifyCert; + props->getProp(USERPROP_DISABLE_CERTVERIFICATION, noVerifyCert); + boost::asio::ssl::context::verify_mode verifyMode = boost::asio::ssl::context::verify_peer; + if (noVerifyCert == "true") { + verifyMode = boost::asio::ssl::context::verify_none; + } + + pChannelContext = new SSLChannelContext(props, tlsVersion, verifyMode); + } break; #endif default: DRILL_LOG(LOG_ERROR) << "Channel type " << t << " is not supported." << std::endl; break; } - return pChannel; + return pChannelContext; } /******************* * Channel * *****************/ -Channel::Channel(const char* connStr) : m_ioService(m_ioServiceFallback){ - m_pEndpoint=new ConnectionEndpoint(connStr); - m_ownIoService = true; - m_pSocket=NULL; - m_state=CHANNEL_UNINITIALIZED; - m_pError=NULL; -} - -Channel::Channel(const char* host, const char* port) : m_ioService(m_ioServiceFallback){ - m_pEndpoint=new ConnectionEndpoint(host, port); - m_ownIoService = true; - m_pSocket=NULL; - m_state=CHANNEL_UNINITIALIZED; - m_pError=NULL; -} - Channel::Channel(boost::asio::io_service& ioService, const char* connStr):m_ioService(ioService){ m_pEndpoint=new ConnectionEndpoint(connStr); - m_ownIoService = false; m_pSocket=NULL; m_state=CHANNEL_UNINITIALIZED; m_pError=NULL; @@ -287,7 +235,6 @@ Channel::Channel(boost::asio::io_service& ioService, const char* connStr):m_ioSe Channel::Channel(boost::asio::io_service& ioService, const char* host, const char* port) : m_ioService(ioService){ m_pEndpoint=new ConnectionEndpoint(host, port); - m_ownIoService = true; m_pSocket=NULL; m_state=CHANNEL_UNINITIALIZED; m_pError=NULL; @@ -311,10 +258,9 @@ template <typename SettableSocketOption> void Channel::setOption(SettableSocketO assert(0); } -connectionStatus_t Channel::init(ChannelContext_t* pContext){ +connectionStatus_t Channel::init(){ connectionStatus_t ret=CONN_SUCCESS; this->m_state=CHANNEL_INITIALIZED; - this->m_pContext = pContext; return ret; } @@ -389,11 +335,11 @@ connectionStatus_t Channel::connectInternal() { } -connectionStatus_t SocketChannel::init(ChannelContext_t* pContext){ +connectionStatus_t SocketChannel::init(){ connectionStatus_t ret=CONN_SUCCESS; m_pSocket=new Socket(m_ioService); if(m_pSocket!=NULL){ - ret=Channel::init(pContext); + ret=Channel::init(); }else{ DRILL_LOG(LOG_ERROR) << "Channel initialization failure. " << std::endl; handleError(CONN_NOSOCKET, getMessage(ERR_CONN_NOSOCKET)); @@ -403,17 +349,17 @@ connectionStatus_t SocketChannel::init(ChannelContext_t* pContext){ } #if defined(IS_SSL_ENABLED) -connectionStatus_t SSLStreamChannel::init(ChannelContext_t* pContext){ +connectionStatus_t SSLStreamChannel::init(){ connectionStatus_t ret=CONN_SUCCESS; - const DrillUserProperties* props = pContext->getUserProperties(); + const DrillUserProperties* props = m_pContext->getUserProperties(); std::string useSystemTrustStore; props->getProp(USERPROP_USESYSTEMTRUSTSTORE, useSystemTrustStore); if (useSystemTrustStore != "true"){ std::string certFile; props->getProp(USERPROP_CERTFILEPATH, certFile); try{ - ((SSLChannelContext_t*)pContext)->getSslContext().load_verify_file(certFile); + ((SSLChannelContext_t*)m_pContext)->getSslContext().load_verify_file(certFile); } catch (boost::system::system_error e){ DRILL_LOG(LOG_ERROR) << "Channel initialization failure. Certificate file " @@ -429,13 +375,13 @@ connectionStatus_t SSLStreamChannel::init(ChannelContext_t* pContext){ props->getProp(USERPROP_DISABLE_HOSTVERIFICATION, disableHostVerification); if (disableHostVerification != "true") { std::string hostPortStr = m_pEndpoint->getHost() + ":" + m_pEndpoint->getPort(); - ((SSLChannelContext_t *) pContext)->getSslContext().set_verify_callback( + ((SSLChannelContext_t *) m_pContext)->getSslContext().set_verify_callback( boost::asio::ssl::rfc2818_verification(hostPortStr.c_str())); } - m_pSocket=new SslSocket(m_ioService, ((SSLChannelContext_t*)pContext)->getSslContext() ); + m_pSocket=new SslSocket(m_ioService, ((SSLChannelContext_t*)m_pContext)->getSslContext() ); if(m_pSocket!=NULL){ - ret=Channel::init(pContext); + ret=Channel::init(); }else{ DRILL_LOG(LOG_ERROR) << "Channel initialization failure. " << std::endl; handleError(CONN_NOSOCKET, getMessage(ERR_CONN_NOSOCKET)); diff --git a/contrib/native/client/src/clientlib/channel.hpp b/contrib/native/client/src/clientlib/channel.hpp index 7f310e899..c7ebfeee6 100644 --- a/contrib/native/client/src/clientlib/channel.hpp +++ b/contrib/native/client/src/clientlib/channel.hpp @@ -36,9 +36,9 @@ class UserProperties; //parse the connection string and set up the host and port to connect to connectionStatus_t getDrillbitEndpoint(); - std::string& getProtocol(){return m_protocol;} - std::string& getHost(){return m_host;} - std::string& getPort(){return m_port;} + const std::string& getProtocol() const {return m_protocol;} + const std::string& getHost() const {return m_host;} + const std::string& getPort() const {return m_port;} DrillClientError* getError(){ return m_pError;}; private: @@ -70,25 +70,21 @@ class UserProperties; class SSLChannelContext: public ChannelContext{ public: - static boost::asio::ssl::context::method getTlsVersion(std::string version){ - if(version.empty()){ - return boost::asio::ssl::context::tlsv12; - } else if (version == "tlsv12") { + static boost::asio::ssl::context::method getTlsVersion(const std::string & version){ + if (version == "tlsv12") { return boost::asio::ssl::context::tlsv12; } else if (version == "tlsv11") { return boost::asio::ssl::context::tlsv11; - } else if (version == "sslv23") { - return boost::asio::ssl::context::sslv23; } else if (version == "tlsv1") { return boost::asio::ssl::context::tlsv1; - } else if (version == "sslv3") { - return boost::asio::ssl::context::sslv3; } else { return boost::asio::ssl::context::tlsv12; } } - SSLChannelContext(DrillUserProperties *props, boost::asio::ssl::context::method tlsVersion, boost::asio::ssl::verify_mode verifyMode) : + SSLChannelContext(DrillUserProperties *props, + boost::asio::ssl::context::method tlsVersion, + boost::asio::ssl::verify_mode verifyMode) : ChannelContext(props), m_SSLContext(tlsVersion) { m_SSLContext.set_default_verify_paths(); @@ -108,11 +104,6 @@ class UserProperties; typedef ChannelContext ChannelContext_t; typedef SSLChannelContext SSLChannelContext_t; - class ChannelContextFactory{ - public: - static ChannelContext_t* getChannelContext(channelType_t t, DrillUserProperties* props); - }; - /*** * The Channel class encapsulates a connection to a drillbit. Based on * the connection string and the options, the connection will be either @@ -122,13 +113,12 @@ class UserProperties; * will use to communicate with the server. ***/ class Channel{ + friend class ChannelFactory; public: - Channel(const char* connStr); - Channel(const char* host, const char* port); Channel(boost::asio::io_service& ioService, const char* connStr); Channel(boost::asio::io_service& ioService, const char* host, const char* port); virtual ~Channel(); - virtual connectionStatus_t init(ChannelContext_t* context)=0; + virtual connectionStatus_t init()=0; connectionStatus_t connect(); bool isConnected(){ return m_state == CHANNEL_CONNECTED;} template <typename SettableSocketOption> void setOption(SettableSocketOption& option); @@ -168,7 +158,6 @@ class UserProperties; ChannelContext_t *m_pContext; private: - typedef enum channelState{ CHANNEL_UNINITIALIZED=1, CHANNEL_INITIALIZED, @@ -189,45 +178,42 @@ class UserProperties; channelState_t m_state; DrillClientError* m_pError; - bool m_ownIoService; }; class SocketChannel: public Channel{ public: - SocketChannel(const char* connStr):Channel(connStr){ - } - SocketChannel(const char* host, const char* port):Channel(host, port){ - } SocketChannel(boost::asio::io_service& ioService, const char* connStr) :Channel(ioService, connStr){ } SocketChannel(boost::asio::io_service& ioService, const char* host, const char* port) :Channel(ioService, host, port){ } - connectionStatus_t init(ChannelContext_t* context=NULL); + connectionStatus_t init(); }; class SSLStreamChannel: public Channel{ public: - SSLStreamChannel(const char* connStr):Channel(connStr){ - } - SSLStreamChannel(const char* host, const char* port):Channel(host, port){ - } SSLStreamChannel(boost::asio::io_service& ioService, const char* connStr) :Channel(ioService, connStr){ } SSLStreamChannel(boost::asio::io_service& ioService, const char* host, const char* port) :Channel(ioService, host, port){ } - connectionStatus_t init(ChannelContext_t* context); + connectionStatus_t init(); }; class ChannelFactory{ public: - static Channel* getChannel(channelType_t t, const char* connStr); - static Channel* getChannel(channelType_t t, const char* host, const char* port); - static Channel* getChannel(channelType_t t, boost::asio::io_service& ioService, const char* connStr); - static Channel* getChannel(channelType_t t, boost::asio::io_service& ioService, const char* host, const char* port); + static Channel* getChannel(channelType_t t, + boost::asio::io_service& ioService, + const char* connStr, DrillUserProperties* props); + static Channel* getChannel(channelType_t t, + boost::asio::io_service& ioService, + const char* host, + const char* port, + DrillUserProperties* props); + private: + static ChannelContext_t* getChannelContext(channelType_t t, DrillUserProperties* props); }; diff --git a/contrib/native/client/src/clientlib/drillClientImpl.cpp b/contrib/native/client/src/clientlib/drillClientImpl.cpp index 39ac847c6..f0bb636b3 100644 --- a/contrib/native/client/src/clientlib/drillClientImpl.cpp +++ b/contrib/native/client/src/clientlib/drillClientImpl.cpp @@ -19,31 +19,22 @@ #include "drill/common.hpp" #include <queue> -#include <string> #include <boost/algorithm/string.hpp> #include <boost/asio.hpp> #include <boost/assign.hpp> #include <boost/bind.hpp> #include <boost/date_time/posix_time/posix_time.hpp> -#include <boost/date_time/posix_time/posix_time_duration.hpp> #include <boost/functional/factory.hpp> -#include <boost/lexical_cast.hpp> #include <boost/thread.hpp> #include "drill/drillClient.hpp" #include "drill/fieldmeta.hpp" #include "drill/recordBatch.hpp" +#include "drill/userProperties.hpp" #include "drillClientImpl.hpp" -#include "collectionsImpl.hpp" #include "errmsgs.hpp" #include "logger.hpp" -#include "metadata.hpp" -#include "rpcMessage.hpp" -#include "utils.hpp" -#include "GeneralRPC.pb.h" -#include "UserBitShared.pb.h" #include "zookeeperClient.hpp" -#include "saslAuthenticatorImpl.hpp" namespace Drill{ namespace { // anonymous namespace @@ -65,108 +56,69 @@ struct ToRpcType: public std::unary_function<google::protobuf::int32, exec::user return static_cast<exec::user::RpcType>(i); } }; -} -connectionStatus_t DrillClientImpl::connect(const char* connStr, DrillUserProperties* props){ - std::string pathToDrill, protocol, hostPortStr; - std::string host; - std::string port; +} // anonymous +connectionStatus_t DrillClientImpl::connect(const char* connStr, DrillUserProperties* props){ if (this->m_bIsConnected) { - if(std::strcmp(connStr, m_connectStr.c_str())){ // trying to connect to a different address is not allowed if already connected + if(!std::strcmp(connStr, m_connectStr.c_str())){ + // trying to connect to a different address is not allowed if already connected return handleConnError(CONN_ALREADYCONNECTED, getMessage(ERR_CONN_ALREADYCONN)); } return CONN_SUCCESS; } - - m_connectStr=connStr; - Utils::parseConnectStr(connStr, pathToDrill, protocol, hostPortStr); - if(protocol == "zk"){ - ZookeeperClient zook(pathToDrill); - std::vector<std::string> drillbits; - int err = zook.getAllDrillbits(hostPortStr, drillbits); - if(!err){ - if (drillbits.empty()){ - return handleConnError(CONN_FAILURE, getMessage(ERR_CONN_ZKNODBIT)); - } - Utils::shuffle(drillbits); - exec::DrillbitEndpoint endpoint; - err = zook.getEndPoint(drillbits[drillbits.size() -1], endpoint);// get the last one in the list - if(!err){ - host=boost::lexical_cast<std::string>(endpoint.address()); - port=boost::lexical_cast<std::string>(endpoint.user_port()); - } - DRILL_MT_LOG(DRILL_LOG(LOG_TRACE) << "Choosing drillbit <" << (drillbits.size() - 1) << ">. Selected " << endpoint.DebugString() << std::endl;) - - } - if(err){ - return handleConnError(CONN_ZOOKEEPER_ERROR, getMessage(ERR_CONN_ZOOKEEPER, zook.getError().c_str())); - } - zook.close(); - m_bIsDirectConnection=true; - }else if(protocol == "local"){ - boost::lock_guard<boost::mutex> lock(m_dcMutex);//strtok is not reentrant - char tempStr[MAX_CONNECT_STR+1]; - strncpy(tempStr, hostPortStr.c_str(), MAX_CONNECT_STR); tempStr[MAX_CONNECT_STR]=0; - host=strtok(tempStr, ":"); - port=strtok(NULL, ""); - m_bIsDirectConnection=false; - }else{ - return handleConnError(CONN_INVALID_INPUT, getMessage(ERR_CONN_UNKPROTO, protocol.c_str())); - } - DRILL_MT_LOG(DRILL_LOG(LOG_TRACE) << "Connecting to endpoint: " << host << ":" << port << std::endl;) - std::string serviceHost; - for (size_t i = 0; i < props->size(); i++) { - if (props->keyAt(i) == USERPROP_SERVICE_HOST) { - serviceHost = props->valueAt(i); - } + std::string val; + channelType_t type = ( props->isPropSet(USERPROP_USESSL) && + props->getProp(USERPROP_USESSL, val) =="true") ? + CHANNEL_TYPE_SSLSTREAM : + CHANNEL_TYPE_SOCKET; + + connectionStatus_t ret = CONN_SUCCESS; + m_pChannel= ChannelFactory::getChannel(type, m_io_service, connStr, props); + ret=m_pChannel->init(); + if(ret!=CONN_SUCCESS){ + handleConnError(m_pChannel->getError()); + return ret; } - if (serviceHost.empty()) { - props->setProperty(USERPROP_SERVICE_HOST, host); + ret= m_pChannel->connect(); + if(ret!=CONN_SUCCESS){ + handleConnError(m_pChannel->getError()); + return ret; } - connectionStatus_t ret = this->connect(host.c_str(), port.c_str()); + props->setProperty(USERPROP_SERVICE_HOST, m_pChannel->getEndpoint()->getHost()); + m_bIsConnected = true; return ret; } -connectionStatus_t DrillClientImpl::connect(const char* host, const char* port){ - using boost::asio::ip::tcp; - tcp::endpoint endpoint; - try{ - tcp::resolver resolver(m_io_service); - tcp::resolver::query query(tcp::v4(), host, port); - tcp::resolver::iterator iter = resolver.resolve(query); - tcp::resolver::iterator end; - while (iter != end){ - endpoint = *iter++; - DRILL_MT_LOG(DRILL_LOG(LOG_TRACE) << endpoint << std::endl;) - } - boost::system::error_code ec; - m_socket.connect(endpoint, ec); - if(ec){ - return handleConnError(CONN_FAILURE, getMessage(ERR_CONN_FAILURE, host, port, ec.message().c_str())); - } - - }catch(const std::exception & e){ - // Handle case when the hostname cannot be resolved. "resolve" is hard-coded in boost asio resolver.resolve - if (!strcmp(e.what(), "resolve")) { - return handleConnError(CONN_HOSTNAME_RESOLUTION_ERROR, getMessage(ERR_CONN_EXCEPT, e.what())); +connectionStatus_t DrillClientImpl::connect(const char* host, const char* port, DrillUserProperties* props){ + if (this->m_bIsConnected) { + std::string connStr = std::string(host)+":"+std::string(port); + if(!std::strcmp(connStr.c_str(), m_connectStr.c_str())){ + // trying to connect to a different address is not allowed if already connected + return handleConnError(CONN_ALREADYCONNECTED, getMessage(ERR_CONN_ALREADYCONN)); } - return handleConnError(CONN_FAILURE, getMessage(ERR_CONN_EXCEPT, e.what())); + return CONN_SUCCESS; } - - m_bIsConnected=true; - // set socket keep alive - boost::asio::socket_base::keep_alive keepAlive(true); - m_socket.set_option(keepAlive); - // set no_delay - boost::asio::ip::tcp::no_delay noDelay(true); - m_socket.set_option(noDelay); - - std::ostringstream connectedHost; - connectedHost << "id: " << m_socket.native_handle() << " address: " << host << ":" << port; - m_connectedHost = connectedHost.str(); - DRILL_MT_LOG(DRILL_LOG(LOG_INFO) << "Connected to endpoint: " << m_connectedHost << std::endl;) - - return CONN_SUCCESS; + std::string val; + channelType_t type = ( props->isPropSet(USERPROP_USESSL) && + props->getProp(USERPROP_USESSL, val) =="true") ? + CHANNEL_TYPE_SSLSTREAM : + CHANNEL_TYPE_SOCKET; + + connectionStatus_t ret = CONN_SUCCESS; + m_pChannel= ChannelFactory::getChannel(type, m_io_service, host, port, props); + ret=m_pChannel->init(); + if(ret!=CONN_SUCCESS){ + handleConnError(m_pChannel->getError()); + return ret; + } + ret=m_pChannel->connect(); + if(ret!=CONN_SUCCESS){ + handleConnError(m_pChannel->getError()); + return ret; + } + props->setProperty(USERPROP_SERVICE_HOST, m_pChannel->getEndpoint()->getHost()); + m_bIsConnected = true; + return ret; } void DrillClientImpl::startHeartbeatTimer(){ @@ -250,7 +202,15 @@ void DrillClientImpl::doWriteToSocket(const char* dataPtr, size_t bytesToWrite, // Write all the bytes to socket. In case of error when all bytes are not successfully written // proper errorCode will be set. while(1) { - size_t bytesWritten = m_socket.write_some(boost::asio::buffer(dataPtr, bytesToWrite), errorCode); + size_t bytesWritten; + { + boost::lock_guard<boost::mutex> lock(m_channelMutex); + if(m_pChannel==NULL){ + return; + } + bytesWritten = m_pChannel->getSocketStream().writeSome(boost::asio::buffer(dataPtr, bytesToWrite), + errorCode); + } if(errorCode && boost::asio::error::interrupted != errorCode){ break; @@ -359,8 +319,10 @@ connectionStatus_t DrillClientImpl::recvHandshake(){ } m_io_service.reset(); - if (DrillClientConfig::getHandshakeTimeout() > 0){ - m_deadlineTimer.expires_from_now(boost::posix_time::seconds(DrillClientConfig::getHandshakeTimeout())); + + int32_t handshakeTimeout=DrillClientConfig::getHandshakeTimeout(); + if (handshakeTimeout > 0){ + m_deadlineTimer.expires_from_now(boost::posix_time::seconds(handshakeTimeout)); m_deadlineTimer.async_wait(boost::bind( &DrillClientImpl::handleHShakeReadTimeout, this, @@ -370,16 +332,21 @@ connectionStatus_t DrillClientImpl::recvHandshake(){ << DrillClientConfig::getHandshakeTimeout() << " seconds." << std::endl;) } - async_read( - this->m_socket, - boost::asio::buffer(m_rbuf, LEN_PREFIX_BUFLEN), - boost::bind( - &DrillClientImpl::handleHandshake, - this, - m_rbuf, - boost::asio::placeholders::error, - boost::asio::placeholders::bytes_transferred) - ); + { + boost::lock_guard<boost::mutex> lock(m_channelMutex); + if (m_pChannel == NULL) { + return CONN_NOSOCKET; + } + m_pChannel->getSocketStream().asyncRead( + boost::asio::buffer(m_rbuf, LEN_PREFIX_BUFLEN), + boost::bind( + &DrillClientImpl::handleHandshake, + this, + m_rbuf, + boost::asio::placeholders::error, + boost::asio::placeholders::bytes_transferred) + ); + } DRILL_MT_LOG(DRILL_LOG(LOG_DEBUG) << "DrillClientImpl::recvHandshake: async read waiting for server handshake response.\n";) m_io_service.run(); if(m_rbuf!=NULL){ @@ -418,8 +385,15 @@ void DrillClientImpl::doReadFromSocket(ByteBuf_t inBuf, size_t bytesToRead, // Read all the bytes. In case when all the bytes were not read the proper // errorCode will be set. while(1){ - size_t dataBytesRead = m_socket.read_some(boost::asio::buffer(inBuf, bytesToRead), + size_t dataBytesRead; + { + boost::lock_guard<boost::mutex> lock(m_channelMutex); + if(m_pChannel==NULL){ + return; + } + dataBytesRead = m_pChannel->getSocketStream().readSome(boost::asio::buffer(inBuf, bytesToRead), errorCode); + } // Check if errorCode is EINTR then just retry otherwise break from loop if(errorCode && boost::asio::error::interrupted != errorCode){ break; @@ -518,8 +492,10 @@ void DrillClientImpl::handleHShakeReadTimeout(const boost::system::error_code & << "Deadline timer expired; ERR_CONN_HSHAKETIMOUT.\n";) handleConnError(CONN_HANDSHAKE_TIMEOUT, getMessage(ERR_CONN_HSHAKETIMOUT)); m_io_service.stop(); - boost::system::error_code ignorederr; - m_socket.shutdown(boost::asio::ip::tcp::socket::shutdown_both, ignorederr); + { + boost::lock_guard<boost::mutex> lock(m_channelMutex); + if(m_pChannel != NULL) m_pChannel->close(); + } } } return; @@ -1027,16 +1003,21 @@ void DrillClientImpl::getNextResult(){ startHeartbeatTimer(); - async_read( - this->m_socket, - boost::asio::buffer(readBuf, LEN_PREFIX_BUFLEN), - boost::bind( - &DrillClientImpl::handleRead, - this, - readBuf, - boost::asio::placeholders::error, - boost::asio::placeholders::bytes_transferred) - ); + { + boost::lock_guard<boost::mutex> lock(m_channelMutex); + if (m_pChannel == NULL) { + return; + } + m_pChannel->getSocketStream().asyncRead( + boost::asio::buffer(readBuf, LEN_PREFIX_BUFLEN), + boost::bind( + &DrillClientImpl::handleRead, + this, + readBuf, + boost::asio::placeholders::error, + boost::asio::placeholders::bytes_transferred) + ); + } DRILL_MT_LOG(DRILL_LOG(LOG_DEBUG) << "DrillClientImpl::getNextResult: async_read from the server\n";) } @@ -1937,10 +1918,16 @@ void DrillClientImpl::handleReadTimeout(const boost::system::error_code & err){ // defined. To be really sure, we need to close the socket. Closing the socket is a bit // drastic and we will defer that till a later release. #ifdef WIN32_SHUTDOWN_ON_TIMEOUT - boost::system::error_code ignorederr; - m_socket.shutdown(boost::asio::ip::tcp::socket::shutdown_both, ignorederr); + { + boost::lock_guard<boost::mutex> lock(m_channelMutex); + if(m_pChannel != NULL) m_pChannel->close(); + } + m_pChannel->close(); #else // NOT WIN32_SHUTDOWN_ON_TIMEOUT - m_socket.cancel(); + { + boost::lock_guard<boost::mutex> lock(m_channelMutex); + if(m_pChannel != NULL) m_pChannel->getInnerSocket().cancel(); + } #endif // WIN32_SHUTDOWN_ON_TIMEOUT } } @@ -2149,6 +2136,20 @@ connectionStatus_t DrillClientImpl::handleConnError(connectionStatus_t status, c return status; } +connectionStatus_t DrillClientImpl::handleConnError(DrillClientError* err){ + DrillClientError* pErr = new DrillClientError(*err); + m_pendingRequests=0; + if(!m_queryHandles.empty()){ + // set query error only if queries are running + broadcastError(pErr); + }else{ + if(m_pError!=NULL){ delete m_pError; m_pError=NULL;} + m_pError=pErr; + shutdownSocket(); + } + return (connectionStatus_t)pErr->status; +} + /* * Always called with NULL QueryHandle when there is any error while reading data from socket. Once enough data is read * and a valid RPC message is formed then it can get called with NULL/valid QueryHandle depending on if QueryHandle is found @@ -2268,9 +2269,16 @@ void DrillClientImpl::sendCancel(const exec::shared::QueryId* pQueryId){ } void DrillClientImpl::shutdownSocket(){ + m_pendingRequests=0; + m_heartbeatTimer.cancel(); + m_deadlineTimer.cancel(); + { + boost::lock_guard<boost::mutex> lock(m_channelMutex); + if (m_pChannel != NULL) { + m_pChannel->close(); + } + } m_io_service.stop(); - boost::system::error_code ignorederr; - m_socket.shutdown(boost::asio::ip::tcp::socket::shutdown_both, ignorederr); m_bIsConnected=false; // Delete the saslAuthenticatorImpl instance since connection is broken. It will recreated on next @@ -2697,7 +2705,7 @@ connectionStatus_t PooledDrillClientImpl::connect(const char* connStr, DrillUser } DRILL_MT_LOG(DRILL_LOG(LOG_TRACE) << "Connecting to endpoint: (Pooled) " << host << ":" << port << std::endl;) DrillClientImpl* pDrillClientImpl = new DrillClientImpl(); - stat = pDrillClientImpl->connect(host.c_str(), port.c_str()); + stat = pDrillClientImpl->connect(host.c_str(), port.c_str(), props); if(stat == CONN_SUCCESS){ boost::lock_guard<boost::mutex> lock(m_poolMutex); m_clientConnections.push_back(pDrillClientImpl); diff --git a/contrib/native/client/src/clientlib/drillClientImpl.hpp b/contrib/native/client/src/clientlib/drillClientImpl.hpp index dacc2c30a..dc4a67eaa 100644 --- a/contrib/native/client/src/clientlib/drillClientImpl.hpp +++ b/contrib/native/client/src/clientlib/drillClientImpl.hpp @@ -45,6 +45,7 @@ #include "drill/drillConfig.hpp" #include "drill/drillError.hpp" #include "drill/preparedStatement.hpp" +#include "channel.hpp" #include "collectionsImpl.hpp" #include "metadata.hpp" #include "rpcMessage.hpp" @@ -386,7 +387,8 @@ class DrillClientImpl : public DrillClientImplBase{ m_pError(NULL), m_pListenerThread(NULL), m_pWork(NULL), - m_socket(m_io_service), + m_pChannel(NULL), + m_pChannelContext(NULL), m_deadlineTimer(m_io_service), m_heartbeatTimer(m_io_service), m_rbuf(NULL), @@ -399,9 +401,11 @@ class DrillClientImpl : public DrillClientImplBase{ }; ~DrillClientImpl(){ - //TODO: Cleanup. - //Free any record batches or buffers remaining //Cancel any pending requests + m_heartbeatTimer.cancel(); + m_deadlineTimer.cancel(); + m_io_service.stop(); + //Free any record batches or buffers remaining //Clear and destroy DrillClientQueryResults vector? if(this->m_pWork!=NULL){ delete this->m_pWork; @@ -411,13 +415,19 @@ class DrillClientImpl : public DrillClientImplBase{ delete this->m_saslAuthenticator; this->m_saslAuthenticator = NULL; } + { + boost::lock_guard<boost::mutex> lock(m_channelMutex); + if (this->m_pChannel != NULL) { + m_pChannel->close(); + delete this->m_pChannel; + this->m_pChannel = NULL; + } + if (this->m_pChannelContext != NULL) { + delete this->m_pChannelContext; + this->m_pChannelContext = NULL; + } + } - m_heartbeatTimer.cancel(); - m_deadlineTimer.cancel(); - m_io_service.stop(); - boost::system::error_code ignorederr; - m_socket.shutdown(boost::asio::ip::tcp::socket::shutdown_both, ignorederr); - m_socket.close(); if(m_rbuf!=NULL){ Utils::freeBuffer(m_rbuf, MAX_SOCK_RD_BUFSIZE); m_rbuf=NULL; } @@ -442,6 +452,8 @@ class DrillClientImpl : public DrillClientImplBase{ //Connect via Zookeeper or directly connectionStatus_t connect(const char* connStr, DrillUserProperties* props); + connectionStatus_t connect(const char* host, const char* port, DrillUserProperties* props); + // test whether the client is active bool Active(); void Close() ; @@ -523,6 +535,7 @@ class DrillClientImpl : public DrillClientImplBase{ status_t validateResultMessage(const rpc::InBoundRpcMessage& msg, const exec::shared::QueryResult& qr, std::string& valError); bool validateResultRPCType(DrillClientQueryHandle* pQueryHandle, const rpc::InBoundRpcMessage& msg); connectionStatus_t handleConnError(connectionStatus_t status, const std::string& msg); + connectionStatus_t handleConnError(DrillClientError* err); status_t handleQryCancellation(status_t status, DrillClientQueryResult* pQueryResult); status_t handleQryError(status_t status, const std::string& msg, DrillClientQueryHandle* pQueryHandle); status_t handleQryError(status_t status, const exec::shared::DrillPBError& e, DrillClientQueryHandle* pQueryHandle); @@ -603,7 +616,12 @@ class DrillClientImpl : public DrillClientImplBase{ boost::asio::io_service m_io_service; // the work object prevent io_service running out of work boost::asio::io_service::work * m_pWork; - boost::asio::ip::tcp::socket m_socket; + + // Mutex to protect channel + boost::mutex m_channelMutex; + Channel* m_pChannel; + ChannelContext_t* m_pChannelContext; + boost::asio::deadline_timer m_deadlineTimer; // to timeout async queries that never return boost::asio::deadline_timer m_heartbeatTimer; // to send heartbeat messages diff --git a/contrib/native/client/src/clientlib/drillConfig.cpp b/contrib/native/client/src/clientlib/drillConfig.cpp index abaa79aff..90a751a19 100644 --- a/contrib/native/client/src/clientlib/drillConfig.cpp +++ b/contrib/native/client/src/clientlib/drillConfig.cpp @@ -16,7 +16,7 @@ * limitations under the License. */ - +#include <boost/thread/lock_guard.hpp> #include "drill/common.hpp" #include "drill/drillConfig.hpp" #include "env.h" diff --git a/contrib/native/client/src/include/drill/common.hpp b/contrib/native/client/src/include/drill/common.hpp index 012bd1974..d8e2da78d 100644 --- a/contrib/native/client/src/include/drill/common.hpp +++ b/contrib/native/client/src/include/drill/common.hpp @@ -165,6 +165,10 @@ typedef enum{ RET_FAILURE=1 } ret_t; +// Connect string protocol types +#define PROTOCOL_TYPE_ZK "zk" +#define PROTOCOL_TYPE_DIRECT "drillbit" +#define PROTOCOL_TYPE_DIRECT_2 "local" // User Property Names #define USERPROP_USERNAME "userName" @@ -173,7 +177,8 @@ typedef enum{ #define USERPROP_USESSL "enableTLS" #define USERPROP_TLSPROTOCOL "TLSProtocol" //TLS version #define USERPROP_CERTFILEPATH "certFilePath" // pem file path and name -#define USERPROP_CERTPASSWORD "certPassword" // Password for certificate file +// TODO: support truststore protected by password. +// #define USERPROP_CERTPASSWORD "certPassword" // Password for certificate file. #define USERPROP_DISABLE_HOSTVERIFICATION "disableHostVerification" #define USERPROP_DISABLE_CERTVERIFICATION "disableCertVerification" #define USERPROP_USESYSTEMTRUSTSTORE "useSystemTrustStore" //Windows only, use the system trust store diff --git a/contrib/native/client/src/include/drill/drillConfig.hpp b/contrib/native/client/src/include/drill/drillConfig.hpp index 669267d86..46bbbb2d2 100644 --- a/contrib/native/client/src/include/drill/drillConfig.hpp +++ b/contrib/native/client/src/include/drill/drillConfig.hpp @@ -21,27 +21,7 @@ #define DRILL_CONFIG_H #include "drill/common.hpp" -#include <boost/thread.hpp> - - - -#if defined _WIN32 || defined __CYGWIN__ - #ifdef DRILL_CLIENT_EXPORTS - #define DECLSPEC_DRILL_CLIENT __declspec(dllexport) - #else - #ifdef USE_STATIC_LIBDRILL - #define DECLSPEC_DRILL_CLIENT - #else - #define DECLSPEC_DRILL_CLIENT __declspec(dllimport) - #endif - #endif -#else - #if __GNUC__ >= 4 - #define DECLSPEC_DRILL_CLIENT __attribute__ ((visibility ("default"))) - #else - #define DECLSPEC_DRILL_CLIENT - #endif -#endif +#include <boost/thread/mutex.hpp> namespace exec{ namespace shared{ diff --git a/contrib/native/client/src/include/drill/userProperties.hpp b/contrib/native/client/src/include/drill/userProperties.hpp index 3490dce7a..62a04f787 100644 --- a/contrib/native/client/src/include/drill/userProperties.hpp +++ b/contrib/native/client/src/include/drill/userProperties.hpp @@ -36,20 +36,17 @@ class DECLSPEC_DRILL_CLIENT DrillUserProperties{ size_t size() const { return m_properties.size(); } - //const std::string& keyAt(size_t i) const { return m_properties.at(i).first; } - - //const std::string& valueAt(size_t i) const { return m_properties.at(i).second; } - const bool isPropSet(const std::string& key) const{ bool isSet=true; - auto f= m_properties.find(key); + std::map<std::string, std::string>::const_iterator f=m_properties.find(key); if(f==m_properties.end()){ isSet=false; } return isSet; } + const std::string& getProp(const std::string& key, std::string& value) const{ - auto f= m_properties.find(key); + std::map<std::string, std::string>::const_iterator f=m_properties.find(key); if(f!=m_properties.end()){ value=f->second; } |