diff options
Diffstat (limited to 'contrib/native/client/src')
-rw-r--r-- | contrib/native/client/src/clientlib/CMakeLists.txt | 4 | ||||
-rw-r--r-- | contrib/native/client/src/clientlib/channel.cpp | 448 | ||||
-rw-r--r-- | contrib/native/client/src/clientlib/channel.hpp | 237 | ||||
-rw-r--r-- | contrib/native/client/src/clientlib/errmsgs.cpp | 3 | ||||
-rw-r--r-- | contrib/native/client/src/clientlib/errmsgs.hpp | 5 | ||||
-rw-r--r-- | contrib/native/client/src/clientlib/logger.hpp | 1 | ||||
-rw-r--r-- | contrib/native/client/src/clientlib/streamSocket.hpp | 218 | ||||
-rw-r--r-- | contrib/native/client/src/clientlib/userProperties.cpp | 4 | ||||
-rw-r--r-- | contrib/native/client/src/clientlib/wincert.ipp | 98 | ||||
-rw-r--r-- | contrib/native/client/src/include/drill/common.hpp | 19 |
10 files changed, 1031 insertions, 6 deletions
diff --git a/contrib/native/client/src/clientlib/CMakeLists.txt b/contrib/native/client/src/clientlib/CMakeLists.txt index 6124fc898..2270c91ee 100644 --- a/contrib/native/client/src/clientlib/CMakeLists.txt +++ b/contrib/native/client/src/clientlib/CMakeLists.txt @@ -19,6 +19,7 @@ # Drill Client library set (CLIENTLIB_SRC_FILES + ${CMAKE_CURRENT_SOURCE_DIR}/channel.cpp ${CMAKE_CURRENT_SOURCE_DIR}/decimalUtils.cpp ${CMAKE_CURRENT_SOURCE_DIR}/drillConfig.cpp ${CMAKE_CURRENT_SOURCE_DIR}/drillClient.cpp @@ -39,6 +40,7 @@ include_directories(${CMAKE_CURRENT_SOURCE_DIR} ${CMAKE_CURRENT_SOURCE_DIR}/../i include_directories(${PROTOBUF_INCLUDE_DIR}) include_directories(${Zookeeper_INCLUDE_DIRS}) include_directories(${SASL_INCLUDE_DIRS}) +include_directories("${OPENSSL_INCLUDE_DIR}") link_directories(/usr/local/lib) @@ -53,4 +55,4 @@ if(MSVC) endif() add_library(drillClient SHARED ${CLIENTLIB_SRC_FILES} ) -target_link_libraries(drillClient ${Boost_LIBRARIES} ${PROTOBUF_LIBRARY} ${Zookeeper_LIBRARIES} ${SASL_LIBRARIES} protomsgs y2038) +target_link_libraries(drillClient ${Boost_LIBRARIES} ${PROTOBUF_LIBRARY} ${Zookeeper_LIBRARIES} ${SASL_LIBRARIES} ${OPENSSL_LIBRARIES} protomsgs y2038) diff --git a/contrib/native/client/src/clientlib/channel.cpp b/contrib/native/client/src/clientlib/channel.cpp new file mode 100644 index 000000000..62ce976b1 --- /dev/null +++ b/contrib/native/client/src/clientlib/channel.cpp @@ -0,0 +1,448 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include <boost/lexical_cast.hpp> +#include <boost/regex.hpp> +#include "drill/drillConfig.hpp" +#include "drill/drillError.hpp" +#include "drill/userProperties.hpp" +#include "channel.hpp" +#include "errmsgs.hpp" +#include "logger.hpp" +#include "utils.hpp" +#include "zookeeperClient.hpp" + +#include "GeneralRPC.pb.h" + +namespace Drill{ + +ConnectionEndpoint::ConnectionEndpoint(const char* connStr){ + m_connectString=connStr; + m_pError=NULL; +} + +ConnectionEndpoint::ConnectionEndpoint(const char* host, const char* port){ + m_host=host; + m_port=port; + m_protocol="drillbit"; // direct connection + m_pError=NULL; +} + +ConnectionEndpoint::~ConnectionEndpoint(){ + if(m_pError!=NULL){ + delete m_pError; m_pError=NULL; + } +} + +connectionStatus_t ConnectionEndpoint::getDrillbitEndpoint(){ + connectionStatus_t ret=CONN_SUCCESS; + if(!m_connectString.empty()){ + parseConnectString(); + if(m_protocol.empty()){ + return handleError(CONN_INVALID_INPUT, getMessage(ERR_CONN_UNKPROTO, "<invalid_string>")); + } + if(isZookeeperConnection()){ + if((ret=getDrillbitEndpointFromZk())!=CONN_SUCCESS){ + DRILL_LOG(LOG_INFO) << "Failed to get endpoint from zk" << std::endl; + return ret; + } + }else if(!this->isDirectConnection()){ + return handleError(CONN_INVALID_INPUT, getMessage(ERR_CONN_UNKPROTO, this->getProtocol().c_str())); + } + }else{ + if(m_host.empty() || m_port.empty()){ + return handleError(CONN_INVALID_INPUT, getMessage(ERR_CONN_NOCONNSTR)); + } + } + return ret; +} + +void ConnectionEndpoint::parseConnectString(){ + boost::regex connStrExpr("(.*)=(((.*):([0-9]+),?)+)(/.+)?"); + boost::cmatch matched; + + if(boost::regex_match(m_connectString.c_str(), matched, connStrExpr)){ + m_protocol.assign(matched[1].first, matched[1].second); + if(isDirectConnection()){ + m_host.assign(matched[4].first, matched[4].second); + m_port.assign(matched[5].first, matched[5].second); + }else { + // if the connection is to a zookeeper, + // we will get the host and the port only after connecting to the Zookeeper + m_host = ""; + m_port = ""; + } + m_hostPortStr.assign(matched[2].first, matched[2].second); + if(matched[6].matched) { + m_pathToDrill.assign(matched[6].first, matched[6].second); + } + DRILL_MT_LOG(DRILL_LOG(LOG_DEBUG) + << "Conn str: "<< m_connectString + << "; protocol: " << m_protocol + << "; host: " << m_host + << "; port: " << m_port + << "; path to drill: " << m_pathToDrill + << std::endl;) + } else { + DRILL_MT_LOG(DRILL_LOG(LOG_DEBUG) << "Invalid connect string. Regexp did not match" << std::endl;) + } + return; +} + +bool ConnectionEndpoint::isDirectConnection(){ + assert(!m_protocol.empty()); + return (!strcmp(m_protocol.c_str(), "local") || !strcmp(m_protocol.c_str(), "drillbit")); +} + +bool ConnectionEndpoint::isZookeeperConnection(){ + assert(!m_protocol.empty()); + return (!strcmp(m_protocol.c_str(), "zk")); +} + +connectionStatus_t ConnectionEndpoint::getDrillbitEndpointFromZk(){ + ZookeeperClient zook(m_pathToDrill); + assert(!m_hostPortStr.empty()); + std::vector<std::string> drillbits; + if(zook.getAllDrillbits(m_hostPortStr.c_str(), drillbits)!=0){ + return handleError(CONN_ZOOKEEPER_ERROR, getMessage(ERR_CONN_ZOOKEEPER, zook.getError().c_str())); + } + if (drillbits.empty()){ + return handleError(CONN_FAILURE, getMessage(ERR_CONN_ZKNODBIT)); + } + Utils::shuffle(drillbits); + exec::DrillbitEndpoint endpoint; + int err = zook.getEndPoint(drillbits[drillbits.size() -1], endpoint);// get the last one in the list + if(!err){ + m_host=boost::lexical_cast<std::string>(endpoint.address()); + m_port=boost::lexical_cast<std::string>(endpoint.user_port()); + } + if(err){ + return handleError(CONN_ZOOKEEPER_ERROR, getMessage(ERR_CONN_ZOOKEEPER, zook.getError().c_str())); + } + DRILL_MT_LOG(DRILL_LOG(LOG_TRACE) << "Choosing drillbit <" << (drillbits.size() - 1) + << ">. Selected " << endpoint.DebugString() << std::endl;) + zook.close(); + return CONN_SUCCESS; +} + +connectionStatus_t ConnectionEndpoint::handleError(connectionStatus_t status, std::string msg){ + DrillClientError* pErr = new DrillClientError(status, DrillClientError::CONN_ERROR_START+status, msg); + if(m_pError!=NULL){ delete m_pError; m_pError=NULL;} + m_pError=pErr; + DRILL_MT_LOG(DRILL_LOG(LOG_TRACE) << "Error getting drillbit endpoint:" << pErr->msg << std::endl;) + return status; +} + +/**************************** + * Channel Context Factory + ****************************/ +ChannelContext* ChannelContextFactory::getChannelContext(channelType_t t, DrillUserProperties* props){ + ChannelContext* pChannelContext=NULL; + switch(t){ + case CHANNEL_TYPE_SOCKET: + pChannelContext=new ChannelContext(props); + break; +#if defined(IS_SSL_ENABLED) + case CHANNEL_TYPE_SSLSTREAM: { + + std::string protocol; + props->getProp(USERPROP_TLSPROTOCOL, protocol); + boost::asio::ssl::context::method tlsVersion = SSLChannelContext::getTlsVersion(protocol); + + std::string noVerifyCert; + props->getProp(USERPROP_DISABLE_CERTVERIFICATION, noVerifyCert); + boost::asio::ssl::context::verify_mode verifyMode = boost::asio::ssl::context::verify_peer; + if (noVerifyCert == "true") { + verifyMode = boost::asio::ssl::context::verify_none; + } + + pChannelContext = new SSLChannelContext(props, tlsVersion, verifyMode); + } + break; +#endif + default: + DRILL_LOG(LOG_ERROR) << "Channel type " << t << " is not supported." << std::endl; + break; + } + return pChannelContext; +} + +/******************* + * ChannelFactory + * *****************/ +Channel* ChannelFactory::getChannel(channelType_t t, const char* connStr){ + Channel* pChannel=NULL; + switch(t){ + case CHANNEL_TYPE_SOCKET: + pChannel=new SocketChannel(connStr); + break; +#if defined(IS_SSL_ENABLED) + case CHANNEL_TYPE_SSLSTREAM: + pChannel=new SSLStreamChannel(connStr); + break; +#endif + default: + DRILL_LOG(LOG_ERROR) << "Channel type " << t << " is not supported." << std::endl; + break; + } + return pChannel; +} + +Channel* ChannelFactory::getChannel(channelType_t t, const char* host, const char* port){ + Channel* pChannel=NULL; + switch(t){ + case CHANNEL_TYPE_SOCKET: + pChannel=new SocketChannel(host, port); + break; +#if defined(IS_SSL_ENABLED) + case CHANNEL_TYPE_SSLSTREAM: + pChannel=new SSLStreamChannel(host, port); + break; +#endif + default: + DRILL_LOG(LOG_ERROR) << "Channel type " << t << " is not supported." << std::endl; + break; + } + return pChannel; +} + +Channel* ChannelFactory::getChannel(channelType_t t, boost::asio::io_service& ioService, const char* connStr){ + Channel* pChannel=NULL; + switch(t){ + case CHANNEL_TYPE_SOCKET: + pChannel=new SocketChannel(ioService, connStr); + break; +#if defined(IS_SSL_ENABLED) + case CHANNEL_TYPE_SSLSTREAM: + pChannel=new SSLStreamChannel(ioService, connStr); + break; +#endif + default: + DRILL_LOG(LOG_ERROR) << "Channel type " << t << " is not supported." << std::endl; + break; + } + return pChannel; +} + +Channel* ChannelFactory::getChannel(channelType_t t, boost::asio::io_service& ioService, const char* host, const char* port){ + Channel* pChannel=NULL; + switch(t){ + case CHANNEL_TYPE_SOCKET: + pChannel=new SocketChannel(ioService, host, port); + break; +#if defined(IS_SSL_ENABLED) + case CHANNEL_TYPE_SSLSTREAM: + pChannel=new SSLStreamChannel(ioService, host, port); + break; +#endif + default: + DRILL_LOG(LOG_ERROR) << "Channel type " << t << " is not supported." << std::endl; + break; + } + return pChannel; +} + +/******************* + * Channel + * *****************/ + +Channel::Channel(const char* connStr) : m_ioService(m_ioServiceFallback){ + m_pEndpoint=new ConnectionEndpoint(connStr); + m_ownIoService = true; + m_pSocket=NULL; + m_state=CHANNEL_UNINITIALIZED; + m_pError=NULL; +} + +Channel::Channel(const char* host, const char* port) : m_ioService(m_ioServiceFallback){ + m_pEndpoint=new ConnectionEndpoint(host, port); + m_ownIoService = true; + m_pSocket=NULL; + m_state=CHANNEL_UNINITIALIZED; + m_pError=NULL; +} + +Channel::Channel(boost::asio::io_service& ioService, const char* connStr):m_ioService(ioService){ + m_pEndpoint=new ConnectionEndpoint(connStr); + m_ownIoService = false; + m_pSocket=NULL; + m_state=CHANNEL_UNINITIALIZED; + m_pError=NULL; +} + +Channel::Channel(boost::asio::io_service& ioService, const char* host, const char* port) : m_ioService(ioService){ + m_pEndpoint=new ConnectionEndpoint(host, port); + m_ownIoService = true; + m_pSocket=NULL; + m_state=CHANNEL_UNINITIALIZED; + m_pError=NULL; +} + +Channel::~Channel(){ + if(m_pEndpoint!=NULL){ + delete m_pEndpoint; m_pEndpoint=NULL; + } + if(m_pSocket!=NULL){ + delete m_pSocket; m_pSocket=NULL; + } + if(m_pError!=NULL){ + delete m_pError; m_pError=NULL; + } +} + +template <typename SettableSocketOption> void Channel::setOption(SettableSocketOption& option){ + //May be useful some day. + //At the moment, we only need to set some well known options after we connect. + assert(0); +} + +connectionStatus_t Channel::init(ChannelContext_t* pContext){ + connectionStatus_t ret=CONN_SUCCESS; + this->m_state=CHANNEL_INITIALIZED; + this->m_pContext = pContext; + return ret; +} + +connectionStatus_t Channel::connect(){ + connectionStatus_t ret=CONN_FAILURE; + if(this->m_state==CHANNEL_INITIALIZED){ + ret=m_pEndpoint->getDrillbitEndpoint(); + if(ret==CONN_SUCCESS){ + DRILL_LOG(LOG_TRACE) << "Connecting to drillbit: " + << m_pEndpoint->getHost() + << ":" << m_pEndpoint->getPort() + << "." << std::endl; + ret=this->connectInternal(); + }else{ + handleError(ret, m_pEndpoint->getError()->msg); + } + } + this->m_state=(ret==CONN_SUCCESS)?CHANNEL_CONNECTED:this->m_state; + return ret; +} + +connectionStatus_t Channel::handleError(connectionStatus_t status, std::string msg){ + DrillClientError* pErr = new DrillClientError(status, DrillClientError::CONN_ERROR_START+status, msg); + if(m_pError!=NULL){ delete m_pError; m_pError=NULL;} + m_pError=pErr; + DRILL_MT_LOG(DRILL_LOG(LOG_TRACE) << "Error connecting:" << pErr->msg << std::endl;) + return status; +} + +connectionStatus_t Channel::connectInternal() { + using boost::asio::ip::tcp; + tcp::endpoint endpoint; + const char *host = m_pEndpoint->getHost().c_str(); + const char *port = m_pEndpoint->getPort().c_str(); + try { + tcp::resolver resolver(m_ioService); + tcp::resolver::query query(tcp::v4(), host, port); + tcp::resolver::iterator iter = resolver.resolve(query); + tcp::resolver::iterator end; + while (iter != end) { + endpoint = *iter++; + DRILL_LOG(LOG_TRACE) << endpoint << std::endl; + } + boost::system::error_code ec; + m_pSocket->getInnerSocket().connect(endpoint, ec); + if (ec) { + return handleError(CONN_FAILURE, getMessage(ERR_CONN_FAILURE, host, port, ec.message().c_str())); + } + } catch (boost::system::system_error e) { + // Handle case when the hostname cannot be resolved. "resolve" is hard-coded in boost asio resolver.resolve + if (!strncmp(e.what(), "resolve", 7)) { + return handleError(CONN_HOSTNAME_RESOLUTION_ERROR, getMessage(ERR_CONN_EXCEPT, e.what())); + } + } catch (std::exception e) { + return handleError(CONN_FAILURE, getMessage(ERR_CONN_EXCEPT, e.what())); + } + + // set socket keep alive + boost::asio::socket_base::keep_alive keepAlive(true); + m_pSocket->getInnerSocket().set_option(keepAlive); + // set no_delay + boost::asio::ip::tcp::no_delay noDelay(true); + m_pSocket->getInnerSocket().set_option(noDelay); + // set reuse addr + boost::asio::socket_base::reuse_address reuseAddr(true); + m_pSocket->getInnerSocket().set_option(reuseAddr); + + std::string useSystemTrustStore; + m_pContext->getUserProperties()->getProp(USERPROP_USESYSTEMTRUSTSTORE, useSystemTrustStore); + DRILL_LOG(LOG_TRACE) << "Connected" << std::endl; + return this->protocolHandshake(useSystemTrustStore=="true"); + +} + +connectionStatus_t SocketChannel::init(ChannelContext_t* pContext){ + connectionStatus_t ret=CONN_SUCCESS; + m_pSocket=new Socket(m_ioService); + if(m_pSocket!=NULL){ + ret=Channel::init(pContext); + }else{ + DRILL_LOG(LOG_ERROR) << "Channel initialization failure. " << std::endl; + handleError(CONN_NOSOCKET, getMessage(ERR_CONN_NOSOCKET)); + ret=CONN_FAILURE; + } + return ret; +} + +#if defined(IS_SSL_ENABLED) +connectionStatus_t SSLStreamChannel::init(ChannelContext_t* pContext){ + connectionStatus_t ret=CONN_SUCCESS; + + const DrillUserProperties* props = pContext->getUserProperties(); + std::string useSystemTrustStore; + props->getProp(USERPROP_USESYSTEMTRUSTSTORE, useSystemTrustStore); + if (useSystemTrustStore != "true"){ + std::string certFile; + props->getProp(USERPROP_CERTFILEPATH, certFile); + try{ + ((SSLChannelContext_t*)pContext)->getSslContext().load_verify_file(certFile); + } + catch (boost::system::system_error e){ + DRILL_LOG(LOG_ERROR) << "Channel initialization failure. Certificate file " + << certFile + << " could not be loaded." + << std::endl; + handleError(CONN_SSLERROR, getMessage(ERR_CONN_SSLCERTFAIL, certFile.c_str(), e.what())); + ret = CONN_FAILURE; + } + } + + std::string disableHostVerification; + props->getProp(USERPROP_DISABLE_HOSTVERIFICATION, disableHostVerification); + if (disableHostVerification != "true") { + std::string hostPortStr = m_pEndpoint->getHost() + ":" + m_pEndpoint->getPort(); + ((SSLChannelContext_t *) pContext)->getSslContext().set_verify_callback( + boost::asio::ssl::rfc2818_verification(hostPortStr.c_str())); + } + + m_pSocket=new SslSocket(m_ioService, ((SSLChannelContext_t*)pContext)->getSslContext() ); + if(m_pSocket!=NULL){ + ret=Channel::init(pContext); + }else{ + DRILL_LOG(LOG_ERROR) << "Channel initialization failure. " << std::endl; + handleError(CONN_NOSOCKET, getMessage(ERR_CONN_NOSOCKET)); + ret=CONN_FAILURE; + } + return ret; +} +#endif + +} // namespace Drill diff --git a/contrib/native/client/src/clientlib/channel.hpp b/contrib/native/client/src/clientlib/channel.hpp new file mode 100644 index 000000000..7f310e899 --- /dev/null +++ b/contrib/native/client/src/clientlib/channel.hpp @@ -0,0 +1,237 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef CHANNEL_HPP +#define CHANNEL_HPP + +#include "drill/common.hpp" +#include "drill/drillClient.hpp" +#include "streamSocket.hpp" + +namespace Drill { + +class UserProperties; + + class ConnectionEndpoint{ + public: + ConnectionEndpoint(const char* connStr); + ConnectionEndpoint(const char* host, const char* port); + ~ConnectionEndpoint(); + + //parse the connection string and set up the host and port to connect to + connectionStatus_t getDrillbitEndpoint(); + + std::string& getProtocol(){return m_protocol;} + std::string& getHost(){return m_host;} + std::string& getPort(){return m_port;} + DrillClientError* getError(){ return m_pError;}; + + private: + void parseConnectString(); + bool isDirectConnection(); + bool isZookeeperConnection(); + connectionStatus_t getDrillbitEndpointFromZk(); + connectionStatus_t handleError(connectionStatus_t status, std::string msg); + + std::string m_connectString; + std::string m_pathToDrill; + std::string m_protocol; + std::string m_hostPortStr; + std::string m_host; + std::string m_port; + + DrillClientError* m_pError; + + }; + + class ChannelContext{ + public: + ChannelContext(DrillUserProperties* props):m_properties(props){}; + virtual ~ChannelContext(){}; + const DrillUserProperties* getUserProperties() const { return m_properties;} + protected: + DrillUserProperties* m_properties; + }; + + class SSLChannelContext: public ChannelContext{ + public: + static boost::asio::ssl::context::method getTlsVersion(std::string version){ + if(version.empty()){ + return boost::asio::ssl::context::tlsv12; + } else if (version == "tlsv12") { + return boost::asio::ssl::context::tlsv12; + } else if (version == "tlsv11") { + return boost::asio::ssl::context::tlsv11; + } else if (version == "sslv23") { + return boost::asio::ssl::context::sslv23; + } else if (version == "tlsv1") { + return boost::asio::ssl::context::tlsv1; + } else if (version == "sslv3") { + return boost::asio::ssl::context::sslv3; + } else { + return boost::asio::ssl::context::tlsv12; + } + } + + SSLChannelContext(DrillUserProperties *props, boost::asio::ssl::context::method tlsVersion, boost::asio::ssl::verify_mode verifyMode) : + ChannelContext(props), + m_SSLContext(tlsVersion) { + m_SSLContext.set_default_verify_paths(); + m_SSLContext.set_options( + boost::asio::ssl::context::default_workarounds + | boost::asio::ssl::context::no_sslv2 + | boost::asio::ssl::context::single_dh_use + ); + m_SSLContext.set_verify_mode(verifyMode); + }; + ~SSLChannelContext(){}; + boost::asio::ssl::context& getSslContext(){ return m_SSLContext;} + private: + boost::asio::ssl::context m_SSLContext; + }; + + typedef ChannelContext ChannelContext_t; + typedef SSLChannelContext SSLChannelContext_t; + + class ChannelContextFactory{ + public: + static ChannelContext_t* getChannelContext(channelType_t t, DrillUserProperties* props); + }; + + /*** + * The Channel class encapsulates a connection to a drillbit. Based on + * the connection string and the options, the connection will be either + * a simple socket or a socket using an ssl stream. The class also encapsulates + * connecting to a drillbit directly or thru zookeeper. + * The channel class owns the socket and the io_service that the applications + * will use to communicate with the server. + ***/ + class Channel{ + public: + Channel(const char* connStr); + Channel(const char* host, const char* port); + Channel(boost::asio::io_service& ioService, const char* connStr); + Channel(boost::asio::io_service& ioService, const char* host, const char* port); + virtual ~Channel(); + virtual connectionStatus_t init(ChannelContext_t* context)=0; + connectionStatus_t connect(); + bool isConnected(){ return m_state == CHANNEL_CONNECTED;} + template <typename SettableSocketOption> void setOption(SettableSocketOption& option); + DrillClientError* getError(){ return m_pError;} + void close(){ + if(m_state==CHANNEL_INITIALIZED||m_state==CHANNEL_CONNECTED){ + m_pSocket->protocolClose(); + m_state=CHANNEL_CLOSED; + } + } // Not OK to use the channel after this call. + + boost::asio::io_service& getIOService(){ + return m_ioService; + } + + // returns a reference to the underlying socket + // This access should really be removed and encapsulated in calls that + // manage async_send and async_recv + // Until then we will let DrillClientImpl have direct access + streamSocket_t& getInnerSocket(){ + return m_pSocket->getInnerSocket(); + } + + AsioStreamSocket& getSocketStream(){ + return *m_pSocket; + } + + ConnectionEndpoint* getEndpoint(){return m_pEndpoint;} + + protected: + connectionStatus_t handleError(connectionStatus_t status, std::string msg); + + boost::asio::io_service& m_ioService; + boost::asio::io_service m_ioServiceFallback; // used if m_ioService is not provided + AsioStreamSocket* m_pSocket; + ConnectionEndpoint *m_pEndpoint; + ChannelContext_t *m_pContext; + + private: + + typedef enum channelState{ + CHANNEL_UNINITIALIZED=1, + CHANNEL_INITIALIZED, + CHANNEL_CONNECTED, + CHANNEL_CLOSED + } channelState_t; + + connectionStatus_t connectInternal(); + connectionStatus_t protocolHandshake(bool useSystemConfig){ + connectionStatus_t status = CONN_SUCCESS; + try{ + m_pSocket->protocolHandshake(useSystemConfig); + } catch (boost::system::system_error e) { + status = handleError(CONN_HANDSHAKE_FAILED, e.what()); + } + return status; + } + + channelState_t m_state; + DrillClientError* m_pError; + bool m_ownIoService; + }; + + class SocketChannel: public Channel{ + public: + SocketChannel(const char* connStr):Channel(connStr){ + } + SocketChannel(const char* host, const char* port):Channel(host, port){ + } + SocketChannel(boost::asio::io_service& ioService, const char* connStr) + :Channel(ioService, connStr){ + } + SocketChannel(boost::asio::io_service& ioService, const char* host, const char* port) + :Channel(ioService, host, port){ + } + connectionStatus_t init(ChannelContext_t* context=NULL); + }; + + class SSLStreamChannel: public Channel{ + public: + SSLStreamChannel(const char* connStr):Channel(connStr){ + } + SSLStreamChannel(const char* host, const char* port):Channel(host, port){ + } + SSLStreamChannel(boost::asio::io_service& ioService, const char* connStr) + :Channel(ioService, connStr){ + } + SSLStreamChannel(boost::asio::io_service& ioService, const char* host, const char* port) + :Channel(ioService, host, port){ + } + connectionStatus_t init(ChannelContext_t* context); + }; + + class ChannelFactory{ + public: + static Channel* getChannel(channelType_t t, const char* connStr); + static Channel* getChannel(channelType_t t, const char* host, const char* port); + static Channel* getChannel(channelType_t t, boost::asio::io_service& ioService, const char* connStr); + static Channel* getChannel(channelType_t t, boost::asio::io_service& ioService, const char* host, const char* port); + }; + + +} // namespace Drill + +#endif // CHANNEL_HPP + diff --git a/contrib/native/client/src/clientlib/errmsgs.cpp b/contrib/native/client/src/clientlib/errmsgs.cpp index 15d2256fb..d2d8c1806 100644 --- a/contrib/native/client/src/clientlib/errmsgs.cpp +++ b/contrib/native/client/src/clientlib/errmsgs.cpp @@ -49,6 +49,9 @@ static Drill::ErrorMessages errorMessages[]={ {ERR_CONN_UNKNOWN_ERR, ERR_CATEGORY_CONN, 0, "Handshake Failed due to an error on the server. [Server message was: (%s) %s]"}, {ERR_CONN_NOCONN, ERR_CATEGORY_CONN, 0, "There is no connection to the server."}, {ERR_CONN_ALREADYCONN, ERR_CATEGORY_CONN, 0, "This client is already connected to a server."}, + {ERR_CONN_NOCONNSTR, ERR_CATEGORY_CONN, 0, "Cannot connect if either host name or port number are empty."}, + {ERR_CONN_SSLCERTFAIL, ERR_CATEGORY_CONN, 0, "SSL certificate file %s could not be loaded (exception message: %s)."}, + {ERR_CONN_NOSOCKET, ERR_CATEGORY_CONN, 0, "Failed to open socket connection."}, {ERR_QRY_OUTOFMEM, ERR_CATEGORY_QRY, 0, "Out of memory."}, {ERR_QRY_COMMERR, ERR_CATEGORY_QRY, 0, "Communication error. %s"}, {ERR_QRY_INVREADLEN, ERR_CATEGORY_QRY, 0, "Internal Error: Received a message with an invalid read length."}, diff --git a/contrib/native/client/src/clientlib/errmsgs.hpp b/contrib/native/client/src/clientlib/errmsgs.hpp index cfb56a6b0..246e4bbf2 100644 --- a/contrib/native/client/src/clientlib/errmsgs.hpp +++ b/contrib/native/client/src/clientlib/errmsgs.hpp @@ -51,7 +51,10 @@ namespace Drill{ #define ERR_CONN_UNKNOWN_ERR DRILL_ERR_START+18 #define ERR_CONN_NOCONN DRILL_ERR_START+19 #define ERR_CONN_ALREADYCONN DRILL_ERR_START+20 -#define ERR_CONN_MAX DRILL_ERR_START+20 +#define ERR_CONN_NOCONNSTR DRILL_ERR_START+21 +#define ERR_CONN_SSLCERTFAIL DRILL_ERR_START+22 +#define ERR_CONN_NOSOCKET DRILL_ERR_START+23 +#define ERR_CONN_MAX DRILL_ERR_START+23 #define ERR_QRY_OUTOFMEM ERR_CONN_MAX+1 #define ERR_QRY_COMMERR ERR_CONN_MAX+2 diff --git a/contrib/native/client/src/clientlib/logger.hpp b/contrib/native/client/src/clientlib/logger.hpp index 7baf50c41..966e3a1d3 100644 --- a/contrib/native/client/src/clientlib/logger.hpp +++ b/contrib/native/client/src/clientlib/logger.hpp @@ -21,6 +21,7 @@ #include <sstream> #include <ostream> +#include <iostream> #include <fstream> #include <string> #include <stdio.h> diff --git a/contrib/native/client/src/clientlib/streamSocket.hpp b/contrib/native/client/src/clientlib/streamSocket.hpp new file mode 100644 index 000000000..5db4dcaa9 --- /dev/null +++ b/contrib/native/client/src/clientlib/streamSocket.hpp @@ -0,0 +1,218 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + + +#ifndef STREAMSOCKET_HPP +#define STREAMSOCKET_HPP + +#include "logger.hpp" +#include "wincert.ipp" + +#include <boost/asio.hpp> +#include <boost/asio/ssl.hpp> + +namespace Drill { + +typedef boost::asio::ip::tcp::socket::lowest_layer_type streamSocket_t; +typedef boost::asio::ssl::stream<boost::asio::ip::tcp::socket> sslTCPSocket_t; +typedef boost::asio::ip::tcp::socket basicTCPSocket_t; + + +// Some helper typedefs to define the highly templatized boost::asio methods +typedef boost::asio::const_buffers_1 ConstBufferSequence; +typedef boost::asio::mutable_buffers_1 MutableBufferSequence; + +// ReadHandlers have different possible signatures. +// +// As a standard C-type callback +// typedef void (*ReadHandler)(const boost::system::error_code& ec, std::size_t bytes_transferred); +// +// Or as a C++ functor +// struct ReadHandler { +// virtual void operator()( +// const boost::system::error_code& ec, +// std::size_t bytes_transferred) = 0; +//}; +// +// We need a different signature though, since we need to pass in a member function of a drill client +// class (which is C++), as a functor generated by boost::bind as a ReadHandler +// +typedef boost::function<void (const boost::system::error_code& ec, std::size_t bytes_transferred) > ReadHandler; + +class AsioStreamSocket{ + public: + virtual ~AsioStreamSocket(){}; + virtual streamSocket_t& getInnerSocket() = 0; + + virtual std::size_t writeSome( + const ConstBufferSequence& buffers, + boost::system::error_code & ec) = 0; + + virtual std::size_t readSome( + const MutableBufferSequence& buffers, + boost::system::error_code & ec) = 0; + + // + // boost::asio::async_read has the signature + // template< + // typename AsyncReadStream, + // typename MutableBufferSequence, + // typename ReadHandler> + // void-or-deduced async_read( + // AsyncReadStream & s, + // const MutableBufferSequence & buffers, + // ReadHandler handler); + // + // For our use case, the derived class will have an instance of a concrete type for AsyncReadStream which + // will implement the requirements for the AsyncReadStream type. We need not pass that in as a parameter + // since the class already has the value + // The method is templatized since the ReadHandler type is dependent on the class implementing the read + // handler (basically the class using the asio stream) + // + virtual void asyncRead( const MutableBufferSequence & buffers, ReadHandler handler) = 0; + + // call the underlying protocol's handshake method. + // if the useSystemConfig flag is true, then use properties read + // from the underlying operating system + virtual void protocolHandshake(bool useSystemConfig) = 0; + virtual void protocolClose() = 0; +}; + +class Socket: + public AsioStreamSocket, + public basicTCPSocket_t{ + + public: + Socket(boost::asio::io_service& ioService) : basicTCPSocket_t(ioService) { + } + + ~Socket(){ + boost::system::error_code ignorederr; + this->shutdown(boost::asio::ip::tcp::socket::shutdown_both, ignorederr); + this->close(); + }; + + basicTCPSocket_t& getSocketStream(){ return *this;} + + streamSocket_t& getInnerSocket(){ return this->lowest_layer();} + + std::size_t writeSome( + const ConstBufferSequence& buffers, + boost::system::error_code & ec){ + return this->write_some(buffers, ec); + } + + std::size_t readSome( + const MutableBufferSequence& buffers, + boost::system::error_code & ec){ + return this->read_some(buffers, ec); + } + + void asyncRead( const MutableBufferSequence & buffers, ReadHandler handler){ + return async_read(*this, buffers, handler); + } + + void protocolHandshake(bool useSystemConfig){}; //nothing to do + void protocolClose(){ + // shuts down the socket! + boost::system::error_code ignorederr; + ((basicTCPSocket_t*)this)->shutdown(boost::asio::ip::tcp::socket::shutdown_both, + ignorederr + ); + } +}; + + +#if defined(IS_SSL_ENABLED) + +class SslSocket: + public AsioStreamSocket, + public sslTCPSocket_t{ + + public: + SslSocket(boost::asio::io_service& ioService, boost::asio::ssl::context &sslContext) : + sslTCPSocket_t(ioService, sslContext) { + } + + ~SslSocket(){ + this->lowest_layer().close(); + }; + + sslTCPSocket_t& getSocketStream(){ return *this;} + + streamSocket_t& getInnerSocket(){ return this->lowest_layer();} + + std::size_t writeSome( + const ConstBufferSequence& buffers, + boost::system::error_code & ec){ + return this->write_some(buffers, ec); + } + + std::size_t readSome( + const MutableBufferSequence& buffers, + boost::system::error_code & ec){ + return this->read_some(buffers, ec); + } + + void asyncRead( const MutableBufferSequence & buffers, ReadHandler handler){ + return async_read(*this, buffers, handler); + } + + // + // public method that can be invoked by callers to invoke the ssl handshake + // throws: boost::system::system_error + void protocolHandshake(bool useSystemConfig){ + if(useSystemConfig){ + std::string msg = ""; + int ret = loadSystemTrustStore(this->native_handle(), msg); + if(!msg.empty()){ + DRILL_LOG(LOG_WARNING) << msg.c_str() << std::endl; + } + if(ret){ + boost::system::error_code ec(EPROTO, boost::system::system_category()); + boost::asio::detail::throw_error(ec, msg.c_str()); + } + } + this->handshake(boost::asio::ssl::stream<boost::asio::ip::tcp::socket>::client); + return; + }; + // + // public method that can be invoked by callers to invoke a clean ssl shutdown + // throws: boost::system::system_error + void protocolClose(){ + try{ + this->shutdown(); + }catch(boost::system::system_error e){ + //swallow the exception. The channel is unusable anyway + } + // shuts down the socket! + boost::system::error_code ignorederr; + this->lowest_layer().shutdown(boost::asio::ip::tcp::socket::shutdown_both, + ignorederr + ); + return; + }; + +}; +#endif + + +} // namespace Drill + +#endif //STREAMSOCKET_HPP + diff --git a/contrib/native/client/src/clientlib/userProperties.cpp b/contrib/native/client/src/clientlib/userProperties.cpp index 07497ef55..666f587c5 100644 --- a/contrib/native/client/src/clientlib/userProperties.cpp +++ b/contrib/native/client/src/clientlib/userProperties.cpp @@ -31,7 +31,11 @@ const std::map<std::string, uint32_t> DrillUserProperties::USER_PROPERTIES=boos ( USERPROP_SERVICE_NAME, USERPROP_FLAGS_STRING) ( USERPROP_SERVICE_HOST, USERPROP_FLAGS_STRING) ( USERPROP_USESSL, USERPROP_FLAGS_BOOLEAN|USERPROP_FLAGS_SSLPROP) + ( USERPROP_TLSPROTOCOL, USERPROP_FLAGS_STRING|USERPROP_FLAGS_SSLPROP) ( USERPROP_CERTFILEPATH, USERPROP_FLAGS_STRING|USERPROP_FLAGS_SSLPROP|USERPROP_FLAGS_FILEPATH) + ( USERPROP_DISABLE_HOSTVERIFICATION, USERPROP_FLAGS_BOOLEAN|USERPROP_FLAGS_SSLPROP) + ( USERPROP_DISABLE_CERTVERIFICATION, USERPROP_FLAGS_BOOLEAN|USERPROP_FLAGS_SSLPROP) + ( USERPROP_USESYSTEMTRUSTSTORE, USERPROP_FLAGS_BOOLEAN|USERPROP_FLAGS_SSLPROP) ( USERPROP_SASL_ENCRYPT, USERPROP_FLAGS_STRING) ; diff --git a/contrib/native/client/src/clientlib/wincert.ipp b/contrib/native/client/src/clientlib/wincert.ipp new file mode 100644 index 000000000..c1af70a36 --- /dev/null +++ b/contrib/native/client/src/clientlib/wincert.ipp @@ -0,0 +1,98 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#if defined(IS_SSL_ENABLED) + +#include <openssl/x509.h> +#include <openssl/ssl.h> + +#if defined _WIN32 || defined _WIN64 + +#include <stdio.h> +#include <windows.h> +#include <wincrypt.h> +#include <cryptuiapi.h> +#include <iostream> +#include <tchar.h> + + +#pragma comment (lib, "crypt32.lib") +#pragma comment (lib, "cryptui.lib") + +#define MY_ENCODING_TYPE (PKCS_7_ASN_ENCODING | X509_ASN_ENCODING) + +inline +int loadSystemTrustStore(const SSL *ssl, std::string& msg) { + HCERTSTORE hStore; + PCCERT_CONTEXT pContext = NULL; + X509 *x509; + char* stores[] = { + "CA", + "MY", + "ROOT", + "SPC" + }; + int certCount=0; + + SSL_CTX * ctx = SSL_get_SSL_CTX(ssl); + X509_STORE *store = SSL_CTX_get_cert_store(ctx); + + for(int i=0; i<4; i++){ + hStore = CertOpenSystemStore(NULL, stores[i]); + + if (!hStore){ + msg.append("Failed to load store: ").append(stores[i]).append("\n"); + continue; + } + + while (pContext = CertEnumCertificatesInStore(hStore, pContext)) { + //uncomment the line below if you want to see the certificates as pop ups + //CryptUIDlgViewContext(CERT_STORE_CERTIFICATE_CONTEXT, pContext, NULL, NULL, 0, NULL); + + x509 = NULL; + x509 = d2i_X509(NULL, (const unsigned char **)&pContext->pbCertEncoded, pContext->cbCertEncoded); + if (x509) { + int ret = X509_STORE_add_cert(store, x509); + + //if (ret == 1) + // std::cout << "Added certificate " << x509->name << " from " << stores[i] << std::endl; + + X509_free(x509); + certCount++; + } + } + + CertFreeCertificateContext(pContext); + CertCloseStore(hStore, 0); + } + if(certCount==0){ + msg.append("No certificates found."); + return -1; + } + return 0; +} + +#else // notwindows +inline +int loadSystemTrustStore(const SSL *ssl, std::string& msg) { + return 0; +} + +#endif // WIN32 or WIN64 + +#endif // SSL_ENABLED diff --git a/contrib/native/client/src/include/drill/common.hpp b/contrib/native/client/src/include/drill/common.hpp index 5401c75a9..012bd1974 100644 --- a/contrib/native/client/src/include/drill/common.hpp +++ b/contrib/native/client/src/include/drill/common.hpp @@ -106,6 +106,11 @@ class AllocatedBuffer; typedef AllocatedBuffer* AllocatedBufferPtr; typedef enum{ + CHANNEL_TYPE_SOCKET=1, + CHANNEL_TYPE_SSLSTREAM=2 +} channelType_t; + +typedef enum{ QRY_SUCCESS=0, QRY_FAILURE=1, QRY_SUCCESS_WITH_INFO=2, @@ -136,7 +141,9 @@ typedef enum{ CONN_BAD_RPC_VER=8, CONN_DEAD=9, CONN_NOTCONNECTED=10, - CONN_ALREADYCONNECTED=11 + CONN_ALREADYCONNECTED=11, + CONN_SSLERROR=12, + CONN_NOSOCKET=13 } connectionStatus_t; typedef enum{ @@ -163,9 +170,13 @@ typedef enum{ #define USERPROP_USERNAME "userName" #define USERPROP_PASSWORD "password" #define USERPROP_SCHEMA "schema" -#define USERPROP_USESSL "useSSL" // Not implemented yet -#define USERPROP_FILEPATH "pemLocation" // Not implemented yet -#define USERPROP_FILENAME "pemFile" // Not implemented yet +#define USERPROP_USESSL "enableTLS" +#define USERPROP_TLSPROTOCOL "TLSProtocol" //TLS version +#define USERPROP_CERTFILEPATH "certFilePath" // pem file path and name +#define USERPROP_CERTPASSWORD "certPassword" // Password for certificate file +#define USERPROP_DISABLE_HOSTVERIFICATION "disableHostVerification" +#define USERPROP_DISABLE_CERTVERIFICATION "disableCertVerification" +#define USERPROP_USESYSTEMTRUSTSTORE "useSystemTrustStore" //Windows only, use the system trust store #define USERPROP_IMPERSONATION_TARGET "impersonation_target" #define USERPROP_AUTH_MECHANISM "auth" #define USERPROP_SERVICE_NAME "service_name" |