diff options
author | Parth Chandra <parthc@apache.org> | 2017-07-25 09:22:23 -0700 |
---|---|---|
committer | Parth Chandra <parthc@apache.org> | 2017-10-11 19:27:48 -0700 |
commit | facbb92ba319373dd8b8baa171ac1d7978c926c5 (patch) | |
tree | 95b2f1b102047d4244d9f9a302779e1d11a81a9a /contrib/native/client | |
parent | 414c94ee7e088c89b0c7c8b9d9dde9335cfcbe6d (diff) |
DRILL-5431: SSL Support (C++) - Add (Netty like) socket abstraction that encapsulates a TCP socket or a SSL Stream on TCP.
The testSSL program tests the client connection against a drillbit by sending a drill handshake.
Diffstat (limited to 'contrib/native/client')
-rw-r--r-- | contrib/native/client/CMakeLists.txt | 13 | ||||
-rw-r--r-- | contrib/native/client/src/clientlib/CMakeLists.txt | 4 | ||||
-rw-r--r-- | contrib/native/client/src/clientlib/channel.cpp | 448 | ||||
-rw-r--r-- | contrib/native/client/src/clientlib/channel.hpp | 237 | ||||
-rw-r--r-- | contrib/native/client/src/clientlib/errmsgs.cpp | 3 | ||||
-rw-r--r-- | contrib/native/client/src/clientlib/errmsgs.hpp | 5 | ||||
-rw-r--r-- | contrib/native/client/src/clientlib/logger.hpp | 1 | ||||
-rw-r--r-- | contrib/native/client/src/clientlib/streamSocket.hpp | 218 | ||||
-rw-r--r-- | contrib/native/client/src/clientlib/userProperties.cpp | 4 | ||||
-rw-r--r-- | contrib/native/client/src/clientlib/wincert.ipp | 98 | ||||
-rw-r--r-- | contrib/native/client/src/include/drill/common.hpp | 19 | ||||
-rw-r--r-- | contrib/native/client/test/ssl/testSSL.cpp | 384 |
12 files changed, 1428 insertions, 6 deletions
diff --git a/contrib/native/client/CMakeLists.txt b/contrib/native/client/CMakeLists.txt index ddb151917..0c104abc0 100644 --- a/contrib/native/client/CMakeLists.txt +++ b/contrib/native/client/CMakeLists.txt @@ -17,6 +17,7 @@ # cmake_minimum_required(VERSION 2.6) +project(drillclient) cmake_policy(SET CMP0043 NEW) cmake_policy(SET CMP0048 NEW) enable_testing() @@ -125,6 +126,12 @@ endif() find_package(Protobuf REQUIRED ) include_directories(${PROTOBUF_INCLUDE_DIR}) +#Find SSL +find_package(OpenSSL REQUIRED ) +if(OPENSSL_FOUND) + add_definitions("-DIS_SSL_ENABLED=1") +endif() + #Find Zookeeper find_package(Zookeeper REQUIRED ) @@ -170,6 +177,12 @@ set_property( # Link directory link_directories(/usr/local/lib) +#test programs +add_subdirectory("${CMAKE_SOURCE_DIR}/test") +message("Open SSL Include = ${OPENSSL_INCLUDE_DIR}") +message("Open SSL Libraries = ${OPENSSL_LIBRARIES}") +message("Open SSL = ${OPENSSL_ROOT_DIR}") + add_executable(querySubmitter example/querySubmitter.cpp ) target_link_libraries(querySubmitter ${Boost_LIBRARIES} ${PROTOBUF_LIBRARY} drillClient protomsgs y2038) diff --git a/contrib/native/client/src/clientlib/CMakeLists.txt b/contrib/native/client/src/clientlib/CMakeLists.txt index 6124fc898..2270c91ee 100644 --- a/contrib/native/client/src/clientlib/CMakeLists.txt +++ b/contrib/native/client/src/clientlib/CMakeLists.txt @@ -19,6 +19,7 @@ # Drill Client library set (CLIENTLIB_SRC_FILES + ${CMAKE_CURRENT_SOURCE_DIR}/channel.cpp ${CMAKE_CURRENT_SOURCE_DIR}/decimalUtils.cpp ${CMAKE_CURRENT_SOURCE_DIR}/drillConfig.cpp ${CMAKE_CURRENT_SOURCE_DIR}/drillClient.cpp @@ -39,6 +40,7 @@ include_directories(${CMAKE_CURRENT_SOURCE_DIR} ${CMAKE_CURRENT_SOURCE_DIR}/../i include_directories(${PROTOBUF_INCLUDE_DIR}) include_directories(${Zookeeper_INCLUDE_DIRS}) include_directories(${SASL_INCLUDE_DIRS}) +include_directories("${OPENSSL_INCLUDE_DIR}") link_directories(/usr/local/lib) @@ -53,4 +55,4 @@ if(MSVC) endif() add_library(drillClient SHARED ${CLIENTLIB_SRC_FILES} ) -target_link_libraries(drillClient ${Boost_LIBRARIES} ${PROTOBUF_LIBRARY} ${Zookeeper_LIBRARIES} ${SASL_LIBRARIES} protomsgs y2038) +target_link_libraries(drillClient ${Boost_LIBRARIES} ${PROTOBUF_LIBRARY} ${Zookeeper_LIBRARIES} ${SASL_LIBRARIES} ${OPENSSL_LIBRARIES} protomsgs y2038) diff --git a/contrib/native/client/src/clientlib/channel.cpp b/contrib/native/client/src/clientlib/channel.cpp new file mode 100644 index 000000000..62ce976b1 --- /dev/null +++ b/contrib/native/client/src/clientlib/channel.cpp @@ -0,0 +1,448 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include <boost/lexical_cast.hpp> +#include <boost/regex.hpp> +#include "drill/drillConfig.hpp" +#include "drill/drillError.hpp" +#include "drill/userProperties.hpp" +#include "channel.hpp" +#include "errmsgs.hpp" +#include "logger.hpp" +#include "utils.hpp" +#include "zookeeperClient.hpp" + +#include "GeneralRPC.pb.h" + +namespace Drill{ + +ConnectionEndpoint::ConnectionEndpoint(const char* connStr){ + m_connectString=connStr; + m_pError=NULL; +} + +ConnectionEndpoint::ConnectionEndpoint(const char* host, const char* port){ + m_host=host; + m_port=port; + m_protocol="drillbit"; // direct connection + m_pError=NULL; +} + +ConnectionEndpoint::~ConnectionEndpoint(){ + if(m_pError!=NULL){ + delete m_pError; m_pError=NULL; + } +} + +connectionStatus_t ConnectionEndpoint::getDrillbitEndpoint(){ + connectionStatus_t ret=CONN_SUCCESS; + if(!m_connectString.empty()){ + parseConnectString(); + if(m_protocol.empty()){ + return handleError(CONN_INVALID_INPUT, getMessage(ERR_CONN_UNKPROTO, "<invalid_string>")); + } + if(isZookeeperConnection()){ + if((ret=getDrillbitEndpointFromZk())!=CONN_SUCCESS){ + DRILL_LOG(LOG_INFO) << "Failed to get endpoint from zk" << std::endl; + return ret; + } + }else if(!this->isDirectConnection()){ + return handleError(CONN_INVALID_INPUT, getMessage(ERR_CONN_UNKPROTO, this->getProtocol().c_str())); + } + }else{ + if(m_host.empty() || m_port.empty()){ + return handleError(CONN_INVALID_INPUT, getMessage(ERR_CONN_NOCONNSTR)); + } + } + return ret; +} + +void ConnectionEndpoint::parseConnectString(){ + boost::regex connStrExpr("(.*)=(((.*):([0-9]+),?)+)(/.+)?"); + boost::cmatch matched; + + if(boost::regex_match(m_connectString.c_str(), matched, connStrExpr)){ + m_protocol.assign(matched[1].first, matched[1].second); + if(isDirectConnection()){ + m_host.assign(matched[4].first, matched[4].second); + m_port.assign(matched[5].first, matched[5].second); + }else { + // if the connection is to a zookeeper, + // we will get the host and the port only after connecting to the Zookeeper + m_host = ""; + m_port = ""; + } + m_hostPortStr.assign(matched[2].first, matched[2].second); + if(matched[6].matched) { + m_pathToDrill.assign(matched[6].first, matched[6].second); + } + DRILL_MT_LOG(DRILL_LOG(LOG_DEBUG) + << "Conn str: "<< m_connectString + << "; protocol: " << m_protocol + << "; host: " << m_host + << "; port: " << m_port + << "; path to drill: " << m_pathToDrill + << std::endl;) + } else { + DRILL_MT_LOG(DRILL_LOG(LOG_DEBUG) << "Invalid connect string. Regexp did not match" << std::endl;) + } + return; +} + +bool ConnectionEndpoint::isDirectConnection(){ + assert(!m_protocol.empty()); + return (!strcmp(m_protocol.c_str(), "local") || !strcmp(m_protocol.c_str(), "drillbit")); +} + +bool ConnectionEndpoint::isZookeeperConnection(){ + assert(!m_protocol.empty()); + return (!strcmp(m_protocol.c_str(), "zk")); +} + +connectionStatus_t ConnectionEndpoint::getDrillbitEndpointFromZk(){ + ZookeeperClient zook(m_pathToDrill); + assert(!m_hostPortStr.empty()); + std::vector<std::string> drillbits; + if(zook.getAllDrillbits(m_hostPortStr.c_str(), drillbits)!=0){ + return handleError(CONN_ZOOKEEPER_ERROR, getMessage(ERR_CONN_ZOOKEEPER, zook.getError().c_str())); + } + if (drillbits.empty()){ + return handleError(CONN_FAILURE, getMessage(ERR_CONN_ZKNODBIT)); + } + Utils::shuffle(drillbits); + exec::DrillbitEndpoint endpoint; + int err = zook.getEndPoint(drillbits[drillbits.size() -1], endpoint);// get the last one in the list + if(!err){ + m_host=boost::lexical_cast<std::string>(endpoint.address()); + m_port=boost::lexical_cast<std::string>(endpoint.user_port()); + } + if(err){ + return handleError(CONN_ZOOKEEPER_ERROR, getMessage(ERR_CONN_ZOOKEEPER, zook.getError().c_str())); + } + DRILL_MT_LOG(DRILL_LOG(LOG_TRACE) << "Choosing drillbit <" << (drillbits.size() - 1) + << ">. Selected " << endpoint.DebugString() << std::endl;) + zook.close(); + return CONN_SUCCESS; +} + +connectionStatus_t ConnectionEndpoint::handleError(connectionStatus_t status, std::string msg){ + DrillClientError* pErr = new DrillClientError(status, DrillClientError::CONN_ERROR_START+status, msg); + if(m_pError!=NULL){ delete m_pError; m_pError=NULL;} + m_pError=pErr; + DRILL_MT_LOG(DRILL_LOG(LOG_TRACE) << "Error getting drillbit endpoint:" << pErr->msg << std::endl;) + return status; +} + +/**************************** + * Channel Context Factory + ****************************/ +ChannelContext* ChannelContextFactory::getChannelContext(channelType_t t, DrillUserProperties* props){ + ChannelContext* pChannelContext=NULL; + switch(t){ + case CHANNEL_TYPE_SOCKET: + pChannelContext=new ChannelContext(props); + break; +#if defined(IS_SSL_ENABLED) + case CHANNEL_TYPE_SSLSTREAM: { + + std::string protocol; + props->getProp(USERPROP_TLSPROTOCOL, protocol); + boost::asio::ssl::context::method tlsVersion = SSLChannelContext::getTlsVersion(protocol); + + std::string noVerifyCert; + props->getProp(USERPROP_DISABLE_CERTVERIFICATION, noVerifyCert); + boost::asio::ssl::context::verify_mode verifyMode = boost::asio::ssl::context::verify_peer; + if (noVerifyCert == "true") { + verifyMode = boost::asio::ssl::context::verify_none; + } + + pChannelContext = new SSLChannelContext(props, tlsVersion, verifyMode); + } + break; +#endif + default: + DRILL_LOG(LOG_ERROR) << "Channel type " << t << " is not supported." << std::endl; + break; + } + return pChannelContext; +} + +/******************* + * ChannelFactory + * *****************/ +Channel* ChannelFactory::getChannel(channelType_t t, const char* connStr){ + Channel* pChannel=NULL; + switch(t){ + case CHANNEL_TYPE_SOCKET: + pChannel=new SocketChannel(connStr); + break; +#if defined(IS_SSL_ENABLED) + case CHANNEL_TYPE_SSLSTREAM: + pChannel=new SSLStreamChannel(connStr); + break; +#endif + default: + DRILL_LOG(LOG_ERROR) << "Channel type " << t << " is not supported." << std::endl; + break; + } + return pChannel; +} + +Channel* ChannelFactory::getChannel(channelType_t t, const char* host, const char* port){ + Channel* pChannel=NULL; + switch(t){ + case CHANNEL_TYPE_SOCKET: + pChannel=new SocketChannel(host, port); + break; +#if defined(IS_SSL_ENABLED) + case CHANNEL_TYPE_SSLSTREAM: + pChannel=new SSLStreamChannel(host, port); + break; +#endif + default: + DRILL_LOG(LOG_ERROR) << "Channel type " << t << " is not supported." << std::endl; + break; + } + return pChannel; +} + +Channel* ChannelFactory::getChannel(channelType_t t, boost::asio::io_service& ioService, const char* connStr){ + Channel* pChannel=NULL; + switch(t){ + case CHANNEL_TYPE_SOCKET: + pChannel=new SocketChannel(ioService, connStr); + break; +#if defined(IS_SSL_ENABLED) + case CHANNEL_TYPE_SSLSTREAM: + pChannel=new SSLStreamChannel(ioService, connStr); + break; +#endif + default: + DRILL_LOG(LOG_ERROR) << "Channel type " << t << " is not supported." << std::endl; + break; + } + return pChannel; +} + +Channel* ChannelFactory::getChannel(channelType_t t, boost::asio::io_service& ioService, const char* host, const char* port){ + Channel* pChannel=NULL; + switch(t){ + case CHANNEL_TYPE_SOCKET: + pChannel=new SocketChannel(ioService, host, port); + break; +#if defined(IS_SSL_ENABLED) + case CHANNEL_TYPE_SSLSTREAM: + pChannel=new SSLStreamChannel(ioService, host, port); + break; +#endif + default: + DRILL_LOG(LOG_ERROR) << "Channel type " << t << " is not supported." << std::endl; + break; + } + return pChannel; +} + +/******************* + * Channel + * *****************/ + +Channel::Channel(const char* connStr) : m_ioService(m_ioServiceFallback){ + m_pEndpoint=new ConnectionEndpoint(connStr); + m_ownIoService = true; + m_pSocket=NULL; + m_state=CHANNEL_UNINITIALIZED; + m_pError=NULL; +} + +Channel::Channel(const char* host, const char* port) : m_ioService(m_ioServiceFallback){ + m_pEndpoint=new ConnectionEndpoint(host, port); + m_ownIoService = true; + m_pSocket=NULL; + m_state=CHANNEL_UNINITIALIZED; + m_pError=NULL; +} + +Channel::Channel(boost::asio::io_service& ioService, const char* connStr):m_ioService(ioService){ + m_pEndpoint=new ConnectionEndpoint(connStr); + m_ownIoService = false; + m_pSocket=NULL; + m_state=CHANNEL_UNINITIALIZED; + m_pError=NULL; +} + +Channel::Channel(boost::asio::io_service& ioService, const char* host, const char* port) : m_ioService(ioService){ + m_pEndpoint=new ConnectionEndpoint(host, port); + m_ownIoService = true; + m_pSocket=NULL; + m_state=CHANNEL_UNINITIALIZED; + m_pError=NULL; +} + +Channel::~Channel(){ + if(m_pEndpoint!=NULL){ + delete m_pEndpoint; m_pEndpoint=NULL; + } + if(m_pSocket!=NULL){ + delete m_pSocket; m_pSocket=NULL; + } + if(m_pError!=NULL){ + delete m_pError; m_pError=NULL; + } +} + +template <typename SettableSocketOption> void Channel::setOption(SettableSocketOption& option){ + //May be useful some day. + //At the moment, we only need to set some well known options after we connect. + assert(0); +} + +connectionStatus_t Channel::init(ChannelContext_t* pContext){ + connectionStatus_t ret=CONN_SUCCESS; + this->m_state=CHANNEL_INITIALIZED; + this->m_pContext = pContext; + return ret; +} + +connectionStatus_t Channel::connect(){ + connectionStatus_t ret=CONN_FAILURE; + if(this->m_state==CHANNEL_INITIALIZED){ + ret=m_pEndpoint->getDrillbitEndpoint(); + if(ret==CONN_SUCCESS){ + DRILL_LOG(LOG_TRACE) << "Connecting to drillbit: " + << m_pEndpoint->getHost() + << ":" << m_pEndpoint->getPort() + << "." << std::endl; + ret=this->connectInternal(); + }else{ + handleError(ret, m_pEndpoint->getError()->msg); + } + } + this->m_state=(ret==CONN_SUCCESS)?CHANNEL_CONNECTED:this->m_state; + return ret; +} + +connectionStatus_t Channel::handleError(connectionStatus_t status, std::string msg){ + DrillClientError* pErr = new DrillClientError(status, DrillClientError::CONN_ERROR_START+status, msg); + if(m_pError!=NULL){ delete m_pError; m_pError=NULL;} + m_pError=pErr; + DRILL_MT_LOG(DRILL_LOG(LOG_TRACE) << "Error connecting:" << pErr->msg << std::endl;) + return status; +} + +connectionStatus_t Channel::connectInternal() { + using boost::asio::ip::tcp; + tcp::endpoint endpoint; + const char *host = m_pEndpoint->getHost().c_str(); + const char *port = m_pEndpoint->getPort().c_str(); + try { + tcp::resolver resolver(m_ioService); + tcp::resolver::query query(tcp::v4(), host, port); + tcp::resolver::iterator iter = resolver.resolve(query); + tcp::resolver::iterator end; + while (iter != end) { + endpoint = *iter++; + DRILL_LOG(LOG_TRACE) << endpoint << std::endl; + } + boost::system::error_code ec; + m_pSocket->getInnerSocket().connect(endpoint, ec); + if (ec) { + return handleError(CONN_FAILURE, getMessage(ERR_CONN_FAILURE, host, port, ec.message().c_str())); + } + } catch (boost::system::system_error e) { + // Handle case when the hostname cannot be resolved. "resolve" is hard-coded in boost asio resolver.resolve + if (!strncmp(e.what(), "resolve", 7)) { + return handleError(CONN_HOSTNAME_RESOLUTION_ERROR, getMessage(ERR_CONN_EXCEPT, e.what())); + } + } catch (std::exception e) { + return handleError(CONN_FAILURE, getMessage(ERR_CONN_EXCEPT, e.what())); + } + + // set socket keep alive + boost::asio::socket_base::keep_alive keepAlive(true); + m_pSocket->getInnerSocket().set_option(keepAlive); + // set no_delay + boost::asio::ip::tcp::no_delay noDelay(true); + m_pSocket->getInnerSocket().set_option(noDelay); + // set reuse addr + boost::asio::socket_base::reuse_address reuseAddr(true); + m_pSocket->getInnerSocket().set_option(reuseAddr); + + std::string useSystemTrustStore; + m_pContext->getUserProperties()->getProp(USERPROP_USESYSTEMTRUSTSTORE, useSystemTrustStore); + DRILL_LOG(LOG_TRACE) << "Connected" << std::endl; + return this->protocolHandshake(useSystemTrustStore=="true"); + +} + +connectionStatus_t SocketChannel::init(ChannelContext_t* pContext){ + connectionStatus_t ret=CONN_SUCCESS; + m_pSocket=new Socket(m_ioService); + if(m_pSocket!=NULL){ + ret=Channel::init(pContext); + }else{ + DRILL_LOG(LOG_ERROR) << "Channel initialization failure. " << std::endl; + handleError(CONN_NOSOCKET, getMessage(ERR_CONN_NOSOCKET)); + ret=CONN_FAILURE; + } + return ret; +} + +#if defined(IS_SSL_ENABLED) +connectionStatus_t SSLStreamChannel::init(ChannelContext_t* pContext){ + connectionStatus_t ret=CONN_SUCCESS; + + const DrillUserProperties* props = pContext->getUserProperties(); + std::string useSystemTrustStore; + props->getProp(USERPROP_USESYSTEMTRUSTSTORE, useSystemTrustStore); + if (useSystemTrustStore != "true"){ + std::string certFile; + props->getProp(USERPROP_CERTFILEPATH, certFile); + try{ + ((SSLChannelContext_t*)pContext)->getSslContext().load_verify_file(certFile); + } + catch (boost::system::system_error e){ + DRILL_LOG(LOG_ERROR) << "Channel initialization failure. Certificate file " + << certFile + << " could not be loaded." + << std::endl; + handleError(CONN_SSLERROR, getMessage(ERR_CONN_SSLCERTFAIL, certFile.c_str(), e.what())); + ret = CONN_FAILURE; + } + } + + std::string disableHostVerification; + props->getProp(USERPROP_DISABLE_HOSTVERIFICATION, disableHostVerification); + if (disableHostVerification != "true") { + std::string hostPortStr = m_pEndpoint->getHost() + ":" + m_pEndpoint->getPort(); + ((SSLChannelContext_t *) pContext)->getSslContext().set_verify_callback( + boost::asio::ssl::rfc2818_verification(hostPortStr.c_str())); + } + + m_pSocket=new SslSocket(m_ioService, ((SSLChannelContext_t*)pContext)->getSslContext() ); + if(m_pSocket!=NULL){ + ret=Channel::init(pContext); + }else{ + DRILL_LOG(LOG_ERROR) << "Channel initialization failure. " << std::endl; + handleError(CONN_NOSOCKET, getMessage(ERR_CONN_NOSOCKET)); + ret=CONN_FAILURE; + } + return ret; +} +#endif + +} // namespace Drill diff --git a/contrib/native/client/src/clientlib/channel.hpp b/contrib/native/client/src/clientlib/channel.hpp new file mode 100644 index 000000000..7f310e899 --- /dev/null +++ b/contrib/native/client/src/clientlib/channel.hpp @@ -0,0 +1,237 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef CHANNEL_HPP +#define CHANNEL_HPP + +#include "drill/common.hpp" +#include "drill/drillClient.hpp" +#include "streamSocket.hpp" + +namespace Drill { + +class UserProperties; + + class ConnectionEndpoint{ + public: + ConnectionEndpoint(const char* connStr); + ConnectionEndpoint(const char* host, const char* port); + ~ConnectionEndpoint(); + + //parse the connection string and set up the host and port to connect to + connectionStatus_t getDrillbitEndpoint(); + + std::string& getProtocol(){return m_protocol;} + std::string& getHost(){return m_host;} + std::string& getPort(){return m_port;} + DrillClientError* getError(){ return m_pError;}; + + private: + void parseConnectString(); + bool isDirectConnection(); + bool isZookeeperConnection(); + connectionStatus_t getDrillbitEndpointFromZk(); + connectionStatus_t handleError(connectionStatus_t status, std::string msg); + + std::string m_connectString; + std::string m_pathToDrill; + std::string m_protocol; + std::string m_hostPortStr; + std::string m_host; + std::string m_port; + + DrillClientError* m_pError; + + }; + + class ChannelContext{ + public: + ChannelContext(DrillUserProperties* props):m_properties(props){}; + virtual ~ChannelContext(){}; + const DrillUserProperties* getUserProperties() const { return m_properties;} + protected: + DrillUserProperties* m_properties; + }; + + class SSLChannelContext: public ChannelContext{ + public: + static boost::asio::ssl::context::method getTlsVersion(std::string version){ + if(version.empty()){ + return boost::asio::ssl::context::tlsv12; + } else if (version == "tlsv12") { + return boost::asio::ssl::context::tlsv12; + } else if (version == "tlsv11") { + return boost::asio::ssl::context::tlsv11; + } else if (version == "sslv23") { + return boost::asio::ssl::context::sslv23; + } else if (version == "tlsv1") { + return boost::asio::ssl::context::tlsv1; + } else if (version == "sslv3") { + return boost::asio::ssl::context::sslv3; + } else { + return boost::asio::ssl::context::tlsv12; + } + } + + SSLChannelContext(DrillUserProperties *props, boost::asio::ssl::context::method tlsVersion, boost::asio::ssl::verify_mode verifyMode) : + ChannelContext(props), + m_SSLContext(tlsVersion) { + m_SSLContext.set_default_verify_paths(); + m_SSLContext.set_options( + boost::asio::ssl::context::default_workarounds + | boost::asio::ssl::context::no_sslv2 + | boost::asio::ssl::context::single_dh_use + ); + m_SSLContext.set_verify_mode(verifyMode); + }; + ~SSLChannelContext(){}; + boost::asio::ssl::context& getSslContext(){ return m_SSLContext;} + private: + boost::asio::ssl::context m_SSLContext; + }; + + typedef ChannelContext ChannelContext_t; + typedef SSLChannelContext SSLChannelContext_t; + + class ChannelContextFactory{ + public: + static ChannelContext_t* getChannelContext(channelType_t t, DrillUserProperties* props); + }; + + /*** + * The Channel class encapsulates a connection to a drillbit. Based on + * the connection string and the options, the connection will be either + * a simple socket or a socket using an ssl stream. The class also encapsulates + * connecting to a drillbit directly or thru zookeeper. + * The channel class owns the socket and the io_service that the applications + * will use to communicate with the server. + ***/ + class Channel{ + public: + Channel(const char* connStr); + Channel(const char* host, const char* port); + Channel(boost::asio::io_service& ioService, const char* connStr); + Channel(boost::asio::io_service& ioService, const char* host, const char* port); + virtual ~Channel(); + virtual connectionStatus_t init(ChannelContext_t* context)=0; + connectionStatus_t connect(); + bool isConnected(){ return m_state == CHANNEL_CONNECTED;} + template <typename SettableSocketOption> void setOption(SettableSocketOption& option); + DrillClientError* getError(){ return m_pError;} + void close(){ + if(m_state==CHANNEL_INITIALIZED||m_state==CHANNEL_CONNECTED){ + m_pSocket->protocolClose(); + m_state=CHANNEL_CLOSED; + } + } // Not OK to use the channel after this call. + + boost::asio::io_service& getIOService(){ + return m_ioService; + } + + // returns a reference to the underlying socket + // This access should really be removed and encapsulated in calls that + // manage async_send and async_recv + // Until then we will let DrillClientImpl have direct access + streamSocket_t& getInnerSocket(){ + return m_pSocket->getInnerSocket(); + } + + AsioStreamSocket& getSocketStream(){ + return *m_pSocket; + } + + ConnectionEndpoint* getEndpoint(){return m_pEndpoint;} + + protected: + connectionStatus_t handleError(connectionStatus_t status, std::string msg); + + boost::asio::io_service& m_ioService; + boost::asio::io_service m_ioServiceFallback; // used if m_ioService is not provided + AsioStreamSocket* m_pSocket; + ConnectionEndpoint *m_pEndpoint; + ChannelContext_t *m_pContext; + + private: + + typedef enum channelState{ + CHANNEL_UNINITIALIZED=1, + CHANNEL_INITIALIZED, + CHANNEL_CONNECTED, + CHANNEL_CLOSED + } channelState_t; + + connectionStatus_t connectInternal(); + connectionStatus_t protocolHandshake(bool useSystemConfig){ + connectionStatus_t status = CONN_SUCCESS; + try{ + m_pSocket->protocolHandshake(useSystemConfig); + } catch (boost::system::system_error e) { + status = handleError(CONN_HANDSHAKE_FAILED, e.what()); + } + return status; + } + + channelState_t m_state; + DrillClientError* m_pError; + bool m_ownIoService; + }; + + class SocketChannel: public Channel{ + public: + SocketChannel(const char* connStr):Channel(connStr){ + } + SocketChannel(const char* host, const char* port):Channel(host, port){ + } + SocketChannel(boost::asio::io_service& ioService, const char* connStr) + :Channel(ioService, connStr){ + } + SocketChannel(boost::asio::io_service& ioService, const char* host, const char* port) + :Channel(ioService, host, port){ + } + connectionStatus_t init(ChannelContext_t* context=NULL); + }; + + class SSLStreamChannel: public Channel{ + public: + SSLStreamChannel(const char* connStr):Channel(connStr){ + } + SSLStreamChannel(const char* host, const char* port):Channel(host, port){ + } + SSLStreamChannel(boost::asio::io_service& ioService, const char* connStr) + :Channel(ioService, connStr){ + } + SSLStreamChannel(boost::asio::io_service& ioService, const char* host, const char* port) + :Channel(ioService, host, port){ + } + connectionStatus_t init(ChannelContext_t* context); + }; + + class ChannelFactory{ + public: + static Channel* getChannel(channelType_t t, const char* connStr); + static Channel* getChannel(channelType_t t, const char* host, const char* port); + static Channel* getChannel(channelType_t t, boost::asio::io_service& ioService, const char* connStr); + static Channel* getChannel(channelType_t t, boost::asio::io_service& ioService, const char* host, const char* port); + }; + + +} // namespace Drill + +#endif // CHANNEL_HPP + diff --git a/contrib/native/client/src/clientlib/errmsgs.cpp b/contrib/native/client/src/clientlib/errmsgs.cpp index 15d2256fb..d2d8c1806 100644 --- a/contrib/native/client/src/clientlib/errmsgs.cpp +++ b/contrib/native/client/src/clientlib/errmsgs.cpp @@ -49,6 +49,9 @@ static Drill::ErrorMessages errorMessages[]={ {ERR_CONN_UNKNOWN_ERR, ERR_CATEGORY_CONN, 0, "Handshake Failed due to an error on the server. [Server message was: (%s) %s]"}, {ERR_CONN_NOCONN, ERR_CATEGORY_CONN, 0, "There is no connection to the server."}, {ERR_CONN_ALREADYCONN, ERR_CATEGORY_CONN, 0, "This client is already connected to a server."}, + {ERR_CONN_NOCONNSTR, ERR_CATEGORY_CONN, 0, "Cannot connect if either host name or port number are empty."}, + {ERR_CONN_SSLCERTFAIL, ERR_CATEGORY_CONN, 0, "SSL certificate file %s could not be loaded (exception message: %s)."}, + {ERR_CONN_NOSOCKET, ERR_CATEGORY_CONN, 0, "Failed to open socket connection."}, {ERR_QRY_OUTOFMEM, ERR_CATEGORY_QRY, 0, "Out of memory."}, {ERR_QRY_COMMERR, ERR_CATEGORY_QRY, 0, "Communication error. %s"}, {ERR_QRY_INVREADLEN, ERR_CATEGORY_QRY, 0, "Internal Error: Received a message with an invalid read length."}, diff --git a/contrib/native/client/src/clientlib/errmsgs.hpp b/contrib/native/client/src/clientlib/errmsgs.hpp index cfb56a6b0..246e4bbf2 100644 --- a/contrib/native/client/src/clientlib/errmsgs.hpp +++ b/contrib/native/client/src/clientlib/errmsgs.hpp @@ -51,7 +51,10 @@ namespace Drill{ #define ERR_CONN_UNKNOWN_ERR DRILL_ERR_START+18 #define ERR_CONN_NOCONN DRILL_ERR_START+19 #define ERR_CONN_ALREADYCONN DRILL_ERR_START+20 -#define ERR_CONN_MAX DRILL_ERR_START+20 +#define ERR_CONN_NOCONNSTR DRILL_ERR_START+21 +#define ERR_CONN_SSLCERTFAIL DRILL_ERR_START+22 +#define ERR_CONN_NOSOCKET DRILL_ERR_START+23 +#define ERR_CONN_MAX DRILL_ERR_START+23 #define ERR_QRY_OUTOFMEM ERR_CONN_MAX+1 #define ERR_QRY_COMMERR ERR_CONN_MAX+2 diff --git a/contrib/native/client/src/clientlib/logger.hpp b/contrib/native/client/src/clientlib/logger.hpp index 7baf50c41..966e3a1d3 100644 --- a/contrib/native/client/src/clientlib/logger.hpp +++ b/contrib/native/client/src/clientlib/logger.hpp @@ -21,6 +21,7 @@ #include <sstream> #include <ostream> +#include <iostream> #include <fstream> #include <string> #include <stdio.h> diff --git a/contrib/native/client/src/clientlib/streamSocket.hpp b/contrib/native/client/src/clientlib/streamSocket.hpp new file mode 100644 index 000000000..5db4dcaa9 --- /dev/null +++ b/contrib/native/client/src/clientlib/streamSocket.hpp @@ -0,0 +1,218 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + + +#ifndef STREAMSOCKET_HPP +#define STREAMSOCKET_HPP + +#include "logger.hpp" +#include "wincert.ipp" + +#include <boost/asio.hpp> +#include <boost/asio/ssl.hpp> + +namespace Drill { + +typedef boost::asio::ip::tcp::socket::lowest_layer_type streamSocket_t; +typedef boost::asio::ssl::stream<boost::asio::ip::tcp::socket> sslTCPSocket_t; +typedef boost::asio::ip::tcp::socket basicTCPSocket_t; + + +// Some helper typedefs to define the highly templatized boost::asio methods +typedef boost::asio::const_buffers_1 ConstBufferSequence; +typedef boost::asio::mutable_buffers_1 MutableBufferSequence; + +// ReadHandlers have different possible signatures. +// +// As a standard C-type callback +// typedef void (*ReadHandler)(const boost::system::error_code& ec, std::size_t bytes_transferred); +// +// Or as a C++ functor +// struct ReadHandler { +// virtual void operator()( +// const boost::system::error_code& ec, +// std::size_t bytes_transferred) = 0; +//}; +// +// We need a different signature though, since we need to pass in a member function of a drill client +// class (which is C++), as a functor generated by boost::bind as a ReadHandler +// +typedef boost::function<void (const boost::system::error_code& ec, std::size_t bytes_transferred) > ReadHandler; + +class AsioStreamSocket{ + public: + virtual ~AsioStreamSocket(){}; + virtual streamSocket_t& getInnerSocket() = 0; + + virtual std::size_t writeSome( + const ConstBufferSequence& buffers, + boost::system::error_code & ec) = 0; + + virtual std::size_t readSome( + const MutableBufferSequence& buffers, + boost::system::error_code & ec) = 0; + + // + // boost::asio::async_read has the signature + // template< + // typename AsyncReadStream, + // typename MutableBufferSequence, + // typename ReadHandler> + // void-or-deduced async_read( + // AsyncReadStream & s, + // const MutableBufferSequence & buffers, + // ReadHandler handler); + // + // For our use case, the derived class will have an instance of a concrete type for AsyncReadStream which + // will implement the requirements for the AsyncReadStream type. We need not pass that in as a parameter + // since the class already has the value + // The method is templatized since the ReadHandler type is dependent on the class implementing the read + // handler (basically the class using the asio stream) + // + virtual void asyncRead( const MutableBufferSequence & buffers, ReadHandler handler) = 0; + + // call the underlying protocol's handshake method. + // if the useSystemConfig flag is true, then use properties read + // from the underlying operating system + virtual void protocolHandshake(bool useSystemConfig) = 0; + virtual void protocolClose() = 0; +}; + +class Socket: + public AsioStreamSocket, + public basicTCPSocket_t{ + + public: + Socket(boost::asio::io_service& ioService) : basicTCPSocket_t(ioService) { + } + + ~Socket(){ + boost::system::error_code ignorederr; + this->shutdown(boost::asio::ip::tcp::socket::shutdown_both, ignorederr); + this->close(); + }; + + basicTCPSocket_t& getSocketStream(){ return *this;} + + streamSocket_t& getInnerSocket(){ return this->lowest_layer();} + + std::size_t writeSome( + const ConstBufferSequence& buffers, + boost::system::error_code & ec){ + return this->write_some(buffers, ec); + } + + std::size_t readSome( + const MutableBufferSequence& buffers, + boost::system::error_code & ec){ + return this->read_some(buffers, ec); + } + + void asyncRead( const MutableBufferSequence & buffers, ReadHandler handler){ + return async_read(*this, buffers, handler); + } + + void protocolHandshake(bool useSystemConfig){}; //nothing to do + void protocolClose(){ + // shuts down the socket! + boost::system::error_code ignorederr; + ((basicTCPSocket_t*)this)->shutdown(boost::asio::ip::tcp::socket::shutdown_both, + ignorederr + ); + } +}; + + +#if defined(IS_SSL_ENABLED) + +class SslSocket: + public AsioStreamSocket, + public sslTCPSocket_t{ + + public: + SslSocket(boost::asio::io_service& ioService, boost::asio::ssl::context &sslContext) : + sslTCPSocket_t(ioService, sslContext) { + } + + ~SslSocket(){ + this->lowest_layer().close(); + }; + + sslTCPSocket_t& getSocketStream(){ return *this;} + + streamSocket_t& getInnerSocket(){ return this->lowest_layer();} + + std::size_t writeSome( + const ConstBufferSequence& buffers, + boost::system::error_code & ec){ + return this->write_some(buffers, ec); + } + + std::size_t readSome( + const MutableBufferSequence& buffers, + boost::system::error_code & ec){ + return this->read_some(buffers, ec); + } + + void asyncRead( const MutableBufferSequence & buffers, ReadHandler handler){ + return async_read(*this, buffers, handler); + } + + // + // public method that can be invoked by callers to invoke the ssl handshake + // throws: boost::system::system_error + void protocolHandshake(bool useSystemConfig){ + if(useSystemConfig){ + std::string msg = ""; + int ret = loadSystemTrustStore(this->native_handle(), msg); + if(!msg.empty()){ + DRILL_LOG(LOG_WARNING) << msg.c_str() << std::endl; + } + if(ret){ + boost::system::error_code ec(EPROTO, boost::system::system_category()); + boost::asio::detail::throw_error(ec, msg.c_str()); + } + } + this->handshake(boost::asio::ssl::stream<boost::asio::ip::tcp::socket>::client); + return; + }; + // + // public method that can be invoked by callers to invoke a clean ssl shutdown + // throws: boost::system::system_error + void protocolClose(){ + try{ + this->shutdown(); + }catch(boost::system::system_error e){ + //swallow the exception. The channel is unusable anyway + } + // shuts down the socket! + boost::system::error_code ignorederr; + this->lowest_layer().shutdown(boost::asio::ip::tcp::socket::shutdown_both, + ignorederr + ); + return; + }; + +}; +#endif + + +} // namespace Drill + +#endif //STREAMSOCKET_HPP + diff --git a/contrib/native/client/src/clientlib/userProperties.cpp b/contrib/native/client/src/clientlib/userProperties.cpp index 07497ef55..666f587c5 100644 --- a/contrib/native/client/src/clientlib/userProperties.cpp +++ b/contrib/native/client/src/clientlib/userProperties.cpp @@ -31,7 +31,11 @@ const std::map<std::string, uint32_t> DrillUserProperties::USER_PROPERTIES=boos ( USERPROP_SERVICE_NAME, USERPROP_FLAGS_STRING) ( USERPROP_SERVICE_HOST, USERPROP_FLAGS_STRING) ( USERPROP_USESSL, USERPROP_FLAGS_BOOLEAN|USERPROP_FLAGS_SSLPROP) + ( USERPROP_TLSPROTOCOL, USERPROP_FLAGS_STRING|USERPROP_FLAGS_SSLPROP) ( USERPROP_CERTFILEPATH, USERPROP_FLAGS_STRING|USERPROP_FLAGS_SSLPROP|USERPROP_FLAGS_FILEPATH) + ( USERPROP_DISABLE_HOSTVERIFICATION, USERPROP_FLAGS_BOOLEAN|USERPROP_FLAGS_SSLPROP) + ( USERPROP_DISABLE_CERTVERIFICATION, USERPROP_FLAGS_BOOLEAN|USERPROP_FLAGS_SSLPROP) + ( USERPROP_USESYSTEMTRUSTSTORE, USERPROP_FLAGS_BOOLEAN|USERPROP_FLAGS_SSLPROP) ( USERPROP_SASL_ENCRYPT, USERPROP_FLAGS_STRING) ; diff --git a/contrib/native/client/src/clientlib/wincert.ipp b/contrib/native/client/src/clientlib/wincert.ipp new file mode 100644 index 000000000..c1af70a36 --- /dev/null +++ b/contrib/native/client/src/clientlib/wincert.ipp @@ -0,0 +1,98 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#if defined(IS_SSL_ENABLED) + +#include <openssl/x509.h> +#include <openssl/ssl.h> + +#if defined _WIN32 || defined _WIN64 + +#include <stdio.h> +#include <windows.h> +#include <wincrypt.h> +#include <cryptuiapi.h> +#include <iostream> +#include <tchar.h> + + +#pragma comment (lib, "crypt32.lib") +#pragma comment (lib, "cryptui.lib") + +#define MY_ENCODING_TYPE (PKCS_7_ASN_ENCODING | X509_ASN_ENCODING) + +inline +int loadSystemTrustStore(const SSL *ssl, std::string& msg) { + HCERTSTORE hStore; + PCCERT_CONTEXT pContext = NULL; + X509 *x509; + char* stores[] = { + "CA", + "MY", + "ROOT", + "SPC" + }; + int certCount=0; + + SSL_CTX * ctx = SSL_get_SSL_CTX(ssl); + X509_STORE *store = SSL_CTX_get_cert_store(ctx); + + for(int i=0; i<4; i++){ + hStore = CertOpenSystemStore(NULL, stores[i]); + + if (!hStore){ + msg.append("Failed to load store: ").append(stores[i]).append("\n"); + continue; + } + + while (pContext = CertEnumCertificatesInStore(hStore, pContext)) { + //uncomment the line below if you want to see the certificates as pop ups + //CryptUIDlgViewContext(CERT_STORE_CERTIFICATE_CONTEXT, pContext, NULL, NULL, 0, NULL); + + x509 = NULL; + x509 = d2i_X509(NULL, (const unsigned char **)&pContext->pbCertEncoded, pContext->cbCertEncoded); + if (x509) { + int ret = X509_STORE_add_cert(store, x509); + + //if (ret == 1) + // std::cout << "Added certificate " << x509->name << " from " << stores[i] << std::endl; + + X509_free(x509); + certCount++; + } + } + + CertFreeCertificateContext(pContext); + CertCloseStore(hStore, 0); + } + if(certCount==0){ + msg.append("No certificates found."); + return -1; + } + return 0; +} + +#else // notwindows +inline +int loadSystemTrustStore(const SSL *ssl, std::string& msg) { + return 0; +} + +#endif // WIN32 or WIN64 + +#endif // SSL_ENABLED diff --git a/contrib/native/client/src/include/drill/common.hpp b/contrib/native/client/src/include/drill/common.hpp index 5401c75a9..012bd1974 100644 --- a/contrib/native/client/src/include/drill/common.hpp +++ b/contrib/native/client/src/include/drill/common.hpp @@ -106,6 +106,11 @@ class AllocatedBuffer; typedef AllocatedBuffer* AllocatedBufferPtr; typedef enum{ + CHANNEL_TYPE_SOCKET=1, + CHANNEL_TYPE_SSLSTREAM=2 +} channelType_t; + +typedef enum{ QRY_SUCCESS=0, QRY_FAILURE=1, QRY_SUCCESS_WITH_INFO=2, @@ -136,7 +141,9 @@ typedef enum{ CONN_BAD_RPC_VER=8, CONN_DEAD=9, CONN_NOTCONNECTED=10, - CONN_ALREADYCONNECTED=11 + CONN_ALREADYCONNECTED=11, + CONN_SSLERROR=12, + CONN_NOSOCKET=13 } connectionStatus_t; typedef enum{ @@ -163,9 +170,13 @@ typedef enum{ #define USERPROP_USERNAME "userName" #define USERPROP_PASSWORD "password" #define USERPROP_SCHEMA "schema" -#define USERPROP_USESSL "useSSL" // Not implemented yet -#define USERPROP_FILEPATH "pemLocation" // Not implemented yet -#define USERPROP_FILENAME "pemFile" // Not implemented yet +#define USERPROP_USESSL "enableTLS" +#define USERPROP_TLSPROTOCOL "TLSProtocol" //TLS version +#define USERPROP_CERTFILEPATH "certFilePath" // pem file path and name +#define USERPROP_CERTPASSWORD "certPassword" // Password for certificate file +#define USERPROP_DISABLE_HOSTVERIFICATION "disableHostVerification" +#define USERPROP_DISABLE_CERTVERIFICATION "disableCertVerification" +#define USERPROP_USESYSTEMTRUSTSTORE "useSystemTrustStore" //Windows only, use the system trust store #define USERPROP_IMPERSONATION_TARGET "impersonation_target" #define USERPROP_AUTH_MECHANISM "auth" #define USERPROP_SERVICE_NAME "service_name" diff --git a/contrib/native/client/test/ssl/testSSL.cpp b/contrib/native/client/test/ssl/testSSL.cpp new file mode 100644 index 000000000..3eaac4876 --- /dev/null +++ b/contrib/native/client/test/ssl/testSSL.cpp @@ -0,0 +1,384 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include <fstream> +#include <iostream> +#include <stdio.h> +#include <stdlib.h> +#include <boost/algorithm/string.hpp> +#include <boost/asio.hpp> +#include <boost/assign.hpp> +#include <boost/bind.hpp> +#include "drill/drillc.hpp" +#include "drill/drillError.hpp" +#include "clientlib/channel.hpp" +#include "clientlib/drillClientImpl.hpp" +#include "clientlib/errmsgs.hpp" +#include "clientlib/logger.hpp" +#include "clientlib/rpcMessage.hpp" +#include "clientlib/utils.hpp" +#include "protobuf/GeneralRPC.pb.h" +#include "protobuf/UserBitShared.pb.h" + +namespace Drill { + +class DrillTestClient { + + public: + + DrillTestClient(Channel* pChannel): + m_handshakeStatus(exec::user::SUCCESS), + m_wbuf(MAX_SOCK_RD_BUFSIZE), + m_rbuf(0){ + m_pChannel=pChannel; + m_pError=NULL; + m_coordinationId=Utils::s_randomNumber()%1729+1; + } + + connectionStatus_t recvHandshake(){ + if(m_rbuf==NULL){ + m_rbuf = Utils::allocateBuffer(MAX_SOCK_RD_BUFSIZE); + } + + m_pChannel->getIOService().reset(); + + m_pChannel->getSocketStream().asyncRead( + boost::asio::buffer(m_rbuf, LEN_PREFIX_BUFLEN), + boost::bind( + &DrillTestClient::handleHandshake, + this, + m_rbuf, + boost::asio::placeholders::error, + boost::asio::placeholders::bytes_transferred) + ); + DRILL_MT_LOG(DRILL_LOG(LOG_DEBUG) << "DrillClientImpl::recvHandshake: async read waiting for server handshake response.\n";) + m_pChannel->getIOService().run(); + if(m_rbuf!=NULL){ + Utils::freeBuffer(m_rbuf, MAX_SOCK_RD_BUFSIZE); m_rbuf=NULL; + } + + if (m_pError != NULL) { + DRILL_MT_LOG(DRILL_LOG(LOG_ERROR) << "DrillClientImpl::recvHandshake: failed to complete handshake with server." + << m_pError->msg << "\n";) + return static_cast<connectionStatus_t>(m_pError->status); + } + + return CONN_SUCCESS; + } + + void doReadFromSocket(ByteBuf_t inBuf, size_t bytesToRead, boost::system::error_code& errorCode) { + // Check if bytesToRead is zero + if(0 == bytesToRead) { + return; + } + + // Read all the bytes. In case when all the bytes were not read the proper + // errorCode will be set. + while(1){ + size_t dataBytesRead = m_pChannel->getSocketStream().readSome(boost::asio::buffer(inBuf, bytesToRead), errorCode); + // Update the state + bytesToRead -= dataBytesRead; + inBuf += dataBytesRead; + + // Check if errorCode is EINTR then just retry otherwise break from loop + if(EINTR != errorCode.value()) break; + + // Check if all the data is read then break from loop + if(0 == bytesToRead) break; + } + } + + void handleHandshake(ByteBuf_t inBuf, + const boost::system::error_code& err, + size_t bytes_transferred) { + boost::system::error_code error=err; + DRILL_MT_LOG(DRILL_LOG(LOG_TRACE) << "Deadline timer cancelled." << std::endl;) + if(!error){ + rpc::InBoundRpcMessage msg; + uint32_t length = 0; + std::size_t bytes_read = rpc::lengthDecode(m_rbuf, length); + if(length>0){ + const size_t leftover = LEN_PREFIX_BUFLEN - bytes_read; + const ByteBuf_t b = m_rbuf + LEN_PREFIX_BUFLEN; + const size_t bytesToRead=length - leftover; + doReadFromSocket(b, bytesToRead, error); + + // Check if any error happen while reading the message bytes. If yes then return before decoding the Msg + if(error) { + DRILL_MT_LOG(DRILL_LOG(LOG_TRACE) << "DrillClientImpl::handleHandshake: ERR_CONN_RDFAIL. " + << " Failed to read entire handshake message. with error: " + << error.message().c_str() << "\n";) + handleConnError(CONN_FAILURE, getMessage(ERR_CONN_RDFAIL, "Failed to read entire handshake message")); + return; + } + + // Decode the bytes into a valid RPC Message + if (!decode(m_rbuf+bytes_read, length, msg)) { + DRILL_MT_LOG(DRILL_LOG(LOG_TRACE) << "DrillClientImpl::handleHandshake: ERR_CONN_RDFAIL. Cannot decode handshake.\n";) + handleConnError(CONN_FAILURE, getMessage(ERR_CONN_RDFAIL, "Cannot decode handshake")); + return; + } + }else{ + DRILL_MT_LOG(DRILL_LOG(LOG_TRACE) << "DrillClientImpl::handleHandshake: ERR_CONN_RDFAIL. No handshake.\n";) + handleConnError(CONN_FAILURE, getMessage(ERR_CONN_RDFAIL, "No handshake")); + return; + } + exec::user::BitToUserHandshake b2u; + b2u.ParseFromArray(msg.m_pbody.data(), msg.m_pbody.size()); + this->m_handshakeErrorId=b2u.errorid(); + this->m_handshakeErrorMsg=b2u.errormessage(); + }else{ + // boost error + if(error==boost::asio::error::eof){ // Server broke off the connection + handleConnError(CONN_HANDSHAKE_FAILED, getMessage(ERR_CONN_NOHSHAKE, DRILL_RPC_VERSION)); + }else{ + handleConnError(CONN_FAILURE, getMessage(ERR_CONN_RDFAIL, error.message().c_str())); + } + return; + } + return; + } + + connectionStatus_t handleConnError(connectionStatus_t status, const std::string& msg){ + DrillClientError* pErr = new DrillClientError(status, DrillClientError::CONN_ERROR_START+status, msg); + if(m_pError!=NULL){ delete m_pError; m_pError=NULL;} + m_pError=pErr; + return status; + } + + connectionStatus_t sendSyncCommon(rpc::OutBoundRpcMessage& msg) { + encode(m_wbuf, msg); + boost::system::error_code ec; + doWriteToSocket(reinterpret_cast<char*>(m_wbuf.data()), m_wbuf.size(), ec); + + if(!ec) { + return CONN_SUCCESS; + } else { + return handleConnError(CONN_FAILURE, getMessage(ERR_CONN_WFAIL, ec.message().c_str())); + } + } + + void doWriteToSocket(const char* dataPtr, size_t bytesToWrite, + boost::system::error_code& errorCode) { + if(0 == bytesToWrite) { + return; + } + + // Write all the bytes to socket. In case of error when all bytes are not successfully written + // proper errorCode will be set. + while(1) { + size_t bytesWritten = m_pChannel->getSocketStream().writeSome(boost::asio::buffer(dataPtr, bytesToWrite), errorCode); + // Update the state + bytesToWrite -= bytesWritten; + dataPtr += bytesWritten; + + if(EINTR != errorCode.value()) break; + + // Check if all the data is written then break from loop + if(0 == bytesToWrite) break; + } + } + + connectionStatus_t validateHandshake(DrillUserProperties* properties){ + + DRILL_MT_LOG(DRILL_LOG(LOG_TRACE) << "validateHandShake\n";) + + exec::user::UserToBitHandshake u2b; + u2b.set_channel(exec::shared::USER); + u2b.set_rpc_version(DRILL_RPC_VERSION); + u2b.set_support_listening(true); + u2b.set_support_timeout(DrillClientConfig::getHeartbeatFrequency() > 0); + u2b.set_sasl_support(exec::user::SASL_PRIVACY); + + // Adding version info + exec::user::RpcEndpointInfos* infos = u2b.mutable_client_infos(); + infos->set_name(DrillClientConfig::getClientName()); + infos->set_application(DrillClientConfig::getApplicationName()); + infos->set_version(DRILL_VERSION_STRING); + infos->set_majorversion(DRILL_VERSION_MAJOR); + infos->set_minorversion(DRILL_VERSION_MINOR); + infos->set_patchversion(DRILL_VERSION_PATCH); + + if(properties != NULL && properties->size()>0){ + std::string username; + std::string err; + if(!properties->validate(err)){ + DRILL_MT_LOG(DRILL_LOG(LOG_INFO) << "Invalid user input:" << err << std::endl;) + } + exec::user::UserProperties* userProperties = u2b.mutable_properties(); + + std::map<char,int>::iterator it; + for (std::map<std::string,std::string>::const_iterator propIter=properties->begin(); propIter!=properties->end(); ++propIter){ + std::string currKey=propIter->first; + std::string currVal=propIter->second; + std::map<std::string,uint32_t>::const_iterator it=DrillUserProperties::USER_PROPERTIES.find(currKey); + if(it==DrillUserProperties::USER_PROPERTIES.end()){ + DRILL_MT_LOG(DRILL_LOG(LOG_INFO) << "Connection property ("<< currKey + << ") is unknown" << std::endl;) + exec::user::Property* connProp = userProperties->add_properties(); + connProp->set_key(currKey); + connProp->set_value(currVal); + continue; + } + if(IS_BITSET((*it).second,USERPROP_FLAGS_SERVERPROP)){ + exec::user::Property* connProp = userProperties->add_properties(); + connProp->set_key(currKey); + connProp->set_value(currVal); + //Username(but not the password) also needs to be set in UserCredentials + if(IS_BITSET((*it).second,USERPROP_FLAGS_USERNAME)){ + exec::shared::UserCredentials* creds = u2b.mutable_credentials(); + username=currVal; + creds->set_user_name(username); + //u2b.set_credentials(&creds); + } + if(IS_BITSET((*it).second,USERPROP_FLAGS_PASSWORD)){ + DRILL_MT_LOG(DRILL_LOG(LOG_INFO) << currKey << ": ********** " << std::endl;) + }else{ + DRILL_MT_LOG(DRILL_LOG(LOG_INFO) << currKey << ":" << currVal << std::endl;) + } + }// Server properties + } + } + + { + boost::lock_guard<boost::mutex> lock(this->m_dcMutex); + uint64_t coordId = ++m_coordinationId; + + rpc::OutBoundRpcMessage out_msg(exec::rpc::REQUEST, exec::user::HANDSHAKE, coordId, &u2b); + sendSyncCommon(out_msg); + DRILL_MT_LOG(DRILL_LOG(LOG_TRACE) << "Sent handshake request message. Coordination id: " << coordId << "\n";) + } + + connectionStatus_t ret = recvHandshake(); + if(ret!=CONN_SUCCESS){ + return ret; + } + + switch(this->m_handshakeStatus) { + case exec::user::SUCCESS: + // reset io_service after handshake is validated before running queries + m_pChannel->getIOService().reset(); + return CONN_SUCCESS; + case exec::user::RPC_VERSION_MISMATCH: + DRILL_MT_LOG(DRILL_LOG(LOG_TRACE) << "Invalid rpc version. Expected " + << DRILL_RPC_VERSION << ", actual "<< 0 << "." << std::endl;) + return handleConnError(CONN_BAD_RPC_VER, getMessage(ERR_CONN_BAD_RPC_VER, DRILL_RPC_VERSION, + 0, + this->m_handshakeErrorId.c_str(), + this->m_handshakeErrorMsg.c_str())); + case exec::user::AUTH_FAILED: + DRILL_MT_LOG(DRILL_LOG(LOG_TRACE) << "Authentication failed." << std::endl;) + return handleConnError(CONN_AUTH_FAILED, getMessage(ERR_CONN_AUTHFAIL, + this->m_handshakeErrorId.c_str(), + this->m_handshakeErrorMsg.c_str())); + case exec::user::UNKNOWN_FAILURE: + DRILL_MT_LOG(DRILL_LOG(LOG_TRACE) << "Unknown error during handshake." << std::endl;) + return handleConnError(CONN_HANDSHAKE_FAILED, getMessage(ERR_CONN_UNKNOWN_ERR, + this->m_handshakeErrorId.c_str(), + this->m_handshakeErrorMsg.c_str())); + case exec::user::AUTH_REQUIRED: + DRILL_MT_LOG(DRILL_LOG(LOG_TRACE) << "Server requires SASL authentication." << std::endl;) + return handleConnError(CONN_HANDSHAKE_FAILED, getMessage(ERR_CONN_UNKNOWN_ERR, + this->m_handshakeErrorId.c_str(), + this->m_handshakeErrorMsg.c_str())); + default: + DRILL_MT_LOG(DRILL_LOG(LOG_TRACE) << "Unknown return status." << std::endl;) + return handleConnError(CONN_HANDSHAKE_FAILED, getMessage(ERR_CONN_UNKNOWN_ERR, + this->m_handshakeErrorId.c_str(), + this->m_handshakeErrorMsg.c_str())); + } + } + + DrillClientError* m_pError; + private: + Channel* m_pChannel; + int32_t m_coordinationId; + std::string m_handshakeErrorId; + std::string m_handshakeErrorMsg; + exec::user::HandshakeStatus m_handshakeStatus; + DataBuf m_wbuf; + ByteBuf_t m_rbuf; + boost::mutex m_dcMutex; + + + + +}; + +} // namespace Drill + +using namespace Drill; + +int main(int argc, char* argv[]){ + Channel *pChannel = NULL; + ChannelContext *pChannelContext = NULL; + std::string connectStr = "zk=localhost:2181/drill/drillbits1"; + //std::string connectStr = "drillbit=localhost:31090"; + channelType_t type; + + bool isSSL = argc==2 && !(strcmp(argv[1], "ssl")); + type = CHANNEL_TYPE_SOCKET; + if(isSSL){ + type = CHANNEL_TYPE_SSLSTREAM; + } + Drill::DrillUserProperties props; + props.setProperty(USERPROP_USERNAME, "admin"); + props.setProperty(USERPROP_PASSWORD, "admin"); + props.setProperty(USERPROP_CERTFILEPATH, "../../../test/ssl/drillTestCert.pem"); + + pChannelContext = ChannelContextFactory::getChannelContext(type, &props); + + pChannel = ChannelFactory::getChannel(type, connectStr.c_str()); + if(pChannel != NULL){ + connectionStatus_t connStat; + connStat = pChannel->init(pChannelContext); + if(connStat != CONN_SUCCESS){ + std::cout << "Init Failed." << std::endl; + return -1; + } + connStat = pChannel->connect(); + if(connStat != CONN_SUCCESS){ + std::cout << "Connect Failed." << std::endl; + std::cout << pChannel->getError()->msg << std::endl; + return -1; + } + } else{ + std::cout << "Channel creation failed." << std::endl; + return -1; + } + std::cout << "Connected." << std::endl; + std::cout << "Starting Drill handshake" << std::endl; + + + DrillTestClient client(pChannel); + + connectionStatus_t stat = client.validateHandshake(&props); + if(stat == CONN_SUCCESS){ + std::cout << "Handshake validated." << std::endl; + } else{ + if(client.m_pError != NULL){ + std::cout << "Handshake failed: " << client.m_pError->msg << ". " << std::endl; + } else{ + std::cout << "Handshake failed with unknown error" << ". " << std::endl; + } + } + + return 0; + +} + |