aboutsummaryrefslogtreecommitdiff
path: root/contrib/native/client
diff options
context:
space:
mode:
authorParth Chandra <parthc@apache.org>2017-07-25 09:22:23 -0700
committerParth Chandra <parthc@apache.org>2017-10-11 19:27:48 -0700
commitfacbb92ba319373dd8b8baa171ac1d7978c926c5 (patch)
tree95b2f1b102047d4244d9f9a302779e1d11a81a9a /contrib/native/client
parent414c94ee7e088c89b0c7c8b9d9dde9335cfcbe6d (diff)
DRILL-5431: SSL Support (C++) - Add (Netty like) socket abstraction that encapsulates a TCP socket or a SSL Stream on TCP.
The testSSL program tests the client connection against a drillbit by sending a drill handshake.
Diffstat (limited to 'contrib/native/client')
-rw-r--r--contrib/native/client/CMakeLists.txt13
-rw-r--r--contrib/native/client/src/clientlib/CMakeLists.txt4
-rw-r--r--contrib/native/client/src/clientlib/channel.cpp448
-rw-r--r--contrib/native/client/src/clientlib/channel.hpp237
-rw-r--r--contrib/native/client/src/clientlib/errmsgs.cpp3
-rw-r--r--contrib/native/client/src/clientlib/errmsgs.hpp5
-rw-r--r--contrib/native/client/src/clientlib/logger.hpp1
-rw-r--r--contrib/native/client/src/clientlib/streamSocket.hpp218
-rw-r--r--contrib/native/client/src/clientlib/userProperties.cpp4
-rw-r--r--contrib/native/client/src/clientlib/wincert.ipp98
-rw-r--r--contrib/native/client/src/include/drill/common.hpp19
-rw-r--r--contrib/native/client/test/ssl/testSSL.cpp384
12 files changed, 1428 insertions, 6 deletions
diff --git a/contrib/native/client/CMakeLists.txt b/contrib/native/client/CMakeLists.txt
index ddb151917..0c104abc0 100644
--- a/contrib/native/client/CMakeLists.txt
+++ b/contrib/native/client/CMakeLists.txt
@@ -17,6 +17,7 @@
#
cmake_minimum_required(VERSION 2.6)
+project(drillclient)
cmake_policy(SET CMP0043 NEW)
cmake_policy(SET CMP0048 NEW)
enable_testing()
@@ -125,6 +126,12 @@ endif()
find_package(Protobuf REQUIRED )
include_directories(${PROTOBUF_INCLUDE_DIR})
+#Find SSL
+find_package(OpenSSL REQUIRED )
+if(OPENSSL_FOUND)
+ add_definitions("-DIS_SSL_ENABLED=1")
+endif()
+
#Find Zookeeper
find_package(Zookeeper REQUIRED )
@@ -170,6 +177,12 @@ set_property(
# Link directory
link_directories(/usr/local/lib)
+#test programs
+add_subdirectory("${CMAKE_SOURCE_DIR}/test")
+message("Open SSL Include = ${OPENSSL_INCLUDE_DIR}")
+message("Open SSL Libraries = ${OPENSSL_LIBRARIES}")
+message("Open SSL = ${OPENSSL_ROOT_DIR}")
+
add_executable(querySubmitter example/querySubmitter.cpp )
target_link_libraries(querySubmitter ${Boost_LIBRARIES} ${PROTOBUF_LIBRARY} drillClient protomsgs y2038)
diff --git a/contrib/native/client/src/clientlib/CMakeLists.txt b/contrib/native/client/src/clientlib/CMakeLists.txt
index 6124fc898..2270c91ee 100644
--- a/contrib/native/client/src/clientlib/CMakeLists.txt
+++ b/contrib/native/client/src/clientlib/CMakeLists.txt
@@ -19,6 +19,7 @@
# Drill Client library
set (CLIENTLIB_SRC_FILES
+ ${CMAKE_CURRENT_SOURCE_DIR}/channel.cpp
${CMAKE_CURRENT_SOURCE_DIR}/decimalUtils.cpp
${CMAKE_CURRENT_SOURCE_DIR}/drillConfig.cpp
${CMAKE_CURRENT_SOURCE_DIR}/drillClient.cpp
@@ -39,6 +40,7 @@ include_directories(${CMAKE_CURRENT_SOURCE_DIR} ${CMAKE_CURRENT_SOURCE_DIR}/../i
include_directories(${PROTOBUF_INCLUDE_DIR})
include_directories(${Zookeeper_INCLUDE_DIRS})
include_directories(${SASL_INCLUDE_DIRS})
+include_directories("${OPENSSL_INCLUDE_DIR}")
link_directories(/usr/local/lib)
@@ -53,4 +55,4 @@ if(MSVC)
endif()
add_library(drillClient SHARED ${CLIENTLIB_SRC_FILES} )
-target_link_libraries(drillClient ${Boost_LIBRARIES} ${PROTOBUF_LIBRARY} ${Zookeeper_LIBRARIES} ${SASL_LIBRARIES} protomsgs y2038)
+target_link_libraries(drillClient ${Boost_LIBRARIES} ${PROTOBUF_LIBRARY} ${Zookeeper_LIBRARIES} ${SASL_LIBRARIES} ${OPENSSL_LIBRARIES} protomsgs y2038)
diff --git a/contrib/native/client/src/clientlib/channel.cpp b/contrib/native/client/src/clientlib/channel.cpp
new file mode 100644
index 000000000..62ce976b1
--- /dev/null
+++ b/contrib/native/client/src/clientlib/channel.cpp
@@ -0,0 +1,448 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include <boost/lexical_cast.hpp>
+#include <boost/regex.hpp>
+#include "drill/drillConfig.hpp"
+#include "drill/drillError.hpp"
+#include "drill/userProperties.hpp"
+#include "channel.hpp"
+#include "errmsgs.hpp"
+#include "logger.hpp"
+#include "utils.hpp"
+#include "zookeeperClient.hpp"
+
+#include "GeneralRPC.pb.h"
+
+namespace Drill{
+
+ConnectionEndpoint::ConnectionEndpoint(const char* connStr){
+ m_connectString=connStr;
+ m_pError=NULL;
+}
+
+ConnectionEndpoint::ConnectionEndpoint(const char* host, const char* port){
+ m_host=host;
+ m_port=port;
+ m_protocol="drillbit"; // direct connection
+ m_pError=NULL;
+}
+
+ConnectionEndpoint::~ConnectionEndpoint(){
+ if(m_pError!=NULL){
+ delete m_pError; m_pError=NULL;
+ }
+}
+
+connectionStatus_t ConnectionEndpoint::getDrillbitEndpoint(){
+ connectionStatus_t ret=CONN_SUCCESS;
+ if(!m_connectString.empty()){
+ parseConnectString();
+ if(m_protocol.empty()){
+ return handleError(CONN_INVALID_INPUT, getMessage(ERR_CONN_UNKPROTO, "<invalid_string>"));
+ }
+ if(isZookeeperConnection()){
+ if((ret=getDrillbitEndpointFromZk())!=CONN_SUCCESS){
+ DRILL_LOG(LOG_INFO) << "Failed to get endpoint from zk" << std::endl;
+ return ret;
+ }
+ }else if(!this->isDirectConnection()){
+ return handleError(CONN_INVALID_INPUT, getMessage(ERR_CONN_UNKPROTO, this->getProtocol().c_str()));
+ }
+ }else{
+ if(m_host.empty() || m_port.empty()){
+ return handleError(CONN_INVALID_INPUT, getMessage(ERR_CONN_NOCONNSTR));
+ }
+ }
+ return ret;
+}
+
+void ConnectionEndpoint::parseConnectString(){
+ boost::regex connStrExpr("(.*)=(((.*):([0-9]+),?)+)(/.+)?");
+ boost::cmatch matched;
+
+ if(boost::regex_match(m_connectString.c_str(), matched, connStrExpr)){
+ m_protocol.assign(matched[1].first, matched[1].second);
+ if(isDirectConnection()){
+ m_host.assign(matched[4].first, matched[4].second);
+ m_port.assign(matched[5].first, matched[5].second);
+ }else {
+ // if the connection is to a zookeeper,
+ // we will get the host and the port only after connecting to the Zookeeper
+ m_host = "";
+ m_port = "";
+ }
+ m_hostPortStr.assign(matched[2].first, matched[2].second);
+ if(matched[6].matched) {
+ m_pathToDrill.assign(matched[6].first, matched[6].second);
+ }
+ DRILL_MT_LOG(DRILL_LOG(LOG_DEBUG)
+ << "Conn str: "<< m_connectString
+ << "; protocol: " << m_protocol
+ << "; host: " << m_host
+ << "; port: " << m_port
+ << "; path to drill: " << m_pathToDrill
+ << std::endl;)
+ } else {
+ DRILL_MT_LOG(DRILL_LOG(LOG_DEBUG) << "Invalid connect string. Regexp did not match" << std::endl;)
+ }
+ return;
+}
+
+bool ConnectionEndpoint::isDirectConnection(){
+ assert(!m_protocol.empty());
+ return (!strcmp(m_protocol.c_str(), "local") || !strcmp(m_protocol.c_str(), "drillbit"));
+}
+
+bool ConnectionEndpoint::isZookeeperConnection(){
+ assert(!m_protocol.empty());
+ return (!strcmp(m_protocol.c_str(), "zk"));
+}
+
+connectionStatus_t ConnectionEndpoint::getDrillbitEndpointFromZk(){
+ ZookeeperClient zook(m_pathToDrill);
+ assert(!m_hostPortStr.empty());
+ std::vector<std::string> drillbits;
+ if(zook.getAllDrillbits(m_hostPortStr.c_str(), drillbits)!=0){
+ return handleError(CONN_ZOOKEEPER_ERROR, getMessage(ERR_CONN_ZOOKEEPER, zook.getError().c_str()));
+ }
+ if (drillbits.empty()){
+ return handleError(CONN_FAILURE, getMessage(ERR_CONN_ZKNODBIT));
+ }
+ Utils::shuffle(drillbits);
+ exec::DrillbitEndpoint endpoint;
+ int err = zook.getEndPoint(drillbits[drillbits.size() -1], endpoint);// get the last one in the list
+ if(!err){
+ m_host=boost::lexical_cast<std::string>(endpoint.address());
+ m_port=boost::lexical_cast<std::string>(endpoint.user_port());
+ }
+ if(err){
+ return handleError(CONN_ZOOKEEPER_ERROR, getMessage(ERR_CONN_ZOOKEEPER, zook.getError().c_str()));
+ }
+ DRILL_MT_LOG(DRILL_LOG(LOG_TRACE) << "Choosing drillbit <" << (drillbits.size() - 1)
+ << ">. Selected " << endpoint.DebugString() << std::endl;)
+ zook.close();
+ return CONN_SUCCESS;
+}
+
+connectionStatus_t ConnectionEndpoint::handleError(connectionStatus_t status, std::string msg){
+ DrillClientError* pErr = new DrillClientError(status, DrillClientError::CONN_ERROR_START+status, msg);
+ if(m_pError!=NULL){ delete m_pError; m_pError=NULL;}
+ m_pError=pErr;
+ DRILL_MT_LOG(DRILL_LOG(LOG_TRACE) << "Error getting drillbit endpoint:" << pErr->msg << std::endl;)
+ return status;
+}
+
+/****************************
+ * Channel Context Factory
+ ****************************/
+ChannelContext* ChannelContextFactory::getChannelContext(channelType_t t, DrillUserProperties* props){
+ ChannelContext* pChannelContext=NULL;
+ switch(t){
+ case CHANNEL_TYPE_SOCKET:
+ pChannelContext=new ChannelContext(props);
+ break;
+#if defined(IS_SSL_ENABLED)
+ case CHANNEL_TYPE_SSLSTREAM: {
+
+ std::string protocol;
+ props->getProp(USERPROP_TLSPROTOCOL, protocol);
+ boost::asio::ssl::context::method tlsVersion = SSLChannelContext::getTlsVersion(protocol);
+
+ std::string noVerifyCert;
+ props->getProp(USERPROP_DISABLE_CERTVERIFICATION, noVerifyCert);
+ boost::asio::ssl::context::verify_mode verifyMode = boost::asio::ssl::context::verify_peer;
+ if (noVerifyCert == "true") {
+ verifyMode = boost::asio::ssl::context::verify_none;
+ }
+
+ pChannelContext = new SSLChannelContext(props, tlsVersion, verifyMode);
+ }
+ break;
+#endif
+ default:
+ DRILL_LOG(LOG_ERROR) << "Channel type " << t << " is not supported." << std::endl;
+ break;
+ }
+ return pChannelContext;
+}
+
+/*******************
+ * ChannelFactory
+ * *****************/
+Channel* ChannelFactory::getChannel(channelType_t t, const char* connStr){
+ Channel* pChannel=NULL;
+ switch(t){
+ case CHANNEL_TYPE_SOCKET:
+ pChannel=new SocketChannel(connStr);
+ break;
+#if defined(IS_SSL_ENABLED)
+ case CHANNEL_TYPE_SSLSTREAM:
+ pChannel=new SSLStreamChannel(connStr);
+ break;
+#endif
+ default:
+ DRILL_LOG(LOG_ERROR) << "Channel type " << t << " is not supported." << std::endl;
+ break;
+ }
+ return pChannel;
+}
+
+Channel* ChannelFactory::getChannel(channelType_t t, const char* host, const char* port){
+ Channel* pChannel=NULL;
+ switch(t){
+ case CHANNEL_TYPE_SOCKET:
+ pChannel=new SocketChannel(host, port);
+ break;
+#if defined(IS_SSL_ENABLED)
+ case CHANNEL_TYPE_SSLSTREAM:
+ pChannel=new SSLStreamChannel(host, port);
+ break;
+#endif
+ default:
+ DRILL_LOG(LOG_ERROR) << "Channel type " << t << " is not supported." << std::endl;
+ break;
+ }
+ return pChannel;
+}
+
+Channel* ChannelFactory::getChannel(channelType_t t, boost::asio::io_service& ioService, const char* connStr){
+ Channel* pChannel=NULL;
+ switch(t){
+ case CHANNEL_TYPE_SOCKET:
+ pChannel=new SocketChannel(ioService, connStr);
+ break;
+#if defined(IS_SSL_ENABLED)
+ case CHANNEL_TYPE_SSLSTREAM:
+ pChannel=new SSLStreamChannel(ioService, connStr);
+ break;
+#endif
+ default:
+ DRILL_LOG(LOG_ERROR) << "Channel type " << t << " is not supported." << std::endl;
+ break;
+ }
+ return pChannel;
+}
+
+Channel* ChannelFactory::getChannel(channelType_t t, boost::asio::io_service& ioService, const char* host, const char* port){
+ Channel* pChannel=NULL;
+ switch(t){
+ case CHANNEL_TYPE_SOCKET:
+ pChannel=new SocketChannel(ioService, host, port);
+ break;
+#if defined(IS_SSL_ENABLED)
+ case CHANNEL_TYPE_SSLSTREAM:
+ pChannel=new SSLStreamChannel(ioService, host, port);
+ break;
+#endif
+ default:
+ DRILL_LOG(LOG_ERROR) << "Channel type " << t << " is not supported." << std::endl;
+ break;
+ }
+ return pChannel;
+}
+
+/*******************
+ * Channel
+ * *****************/
+
+Channel::Channel(const char* connStr) : m_ioService(m_ioServiceFallback){
+ m_pEndpoint=new ConnectionEndpoint(connStr);
+ m_ownIoService = true;
+ m_pSocket=NULL;
+ m_state=CHANNEL_UNINITIALIZED;
+ m_pError=NULL;
+}
+
+Channel::Channel(const char* host, const char* port) : m_ioService(m_ioServiceFallback){
+ m_pEndpoint=new ConnectionEndpoint(host, port);
+ m_ownIoService = true;
+ m_pSocket=NULL;
+ m_state=CHANNEL_UNINITIALIZED;
+ m_pError=NULL;
+}
+
+Channel::Channel(boost::asio::io_service& ioService, const char* connStr):m_ioService(ioService){
+ m_pEndpoint=new ConnectionEndpoint(connStr);
+ m_ownIoService = false;
+ m_pSocket=NULL;
+ m_state=CHANNEL_UNINITIALIZED;
+ m_pError=NULL;
+}
+
+Channel::Channel(boost::asio::io_service& ioService, const char* host, const char* port) : m_ioService(ioService){
+ m_pEndpoint=new ConnectionEndpoint(host, port);
+ m_ownIoService = true;
+ m_pSocket=NULL;
+ m_state=CHANNEL_UNINITIALIZED;
+ m_pError=NULL;
+}
+
+Channel::~Channel(){
+ if(m_pEndpoint!=NULL){
+ delete m_pEndpoint; m_pEndpoint=NULL;
+ }
+ if(m_pSocket!=NULL){
+ delete m_pSocket; m_pSocket=NULL;
+ }
+ if(m_pError!=NULL){
+ delete m_pError; m_pError=NULL;
+ }
+}
+
+template <typename SettableSocketOption> void Channel::setOption(SettableSocketOption& option){
+ //May be useful some day.
+ //At the moment, we only need to set some well known options after we connect.
+ assert(0);
+}
+
+connectionStatus_t Channel::init(ChannelContext_t* pContext){
+ connectionStatus_t ret=CONN_SUCCESS;
+ this->m_state=CHANNEL_INITIALIZED;
+ this->m_pContext = pContext;
+ return ret;
+}
+
+connectionStatus_t Channel::connect(){
+ connectionStatus_t ret=CONN_FAILURE;
+ if(this->m_state==CHANNEL_INITIALIZED){
+ ret=m_pEndpoint->getDrillbitEndpoint();
+ if(ret==CONN_SUCCESS){
+ DRILL_LOG(LOG_TRACE) << "Connecting to drillbit: "
+ << m_pEndpoint->getHost()
+ << ":" << m_pEndpoint->getPort()
+ << "." << std::endl;
+ ret=this->connectInternal();
+ }else{
+ handleError(ret, m_pEndpoint->getError()->msg);
+ }
+ }
+ this->m_state=(ret==CONN_SUCCESS)?CHANNEL_CONNECTED:this->m_state;
+ return ret;
+}
+
+connectionStatus_t Channel::handleError(connectionStatus_t status, std::string msg){
+ DrillClientError* pErr = new DrillClientError(status, DrillClientError::CONN_ERROR_START+status, msg);
+ if(m_pError!=NULL){ delete m_pError; m_pError=NULL;}
+ m_pError=pErr;
+ DRILL_MT_LOG(DRILL_LOG(LOG_TRACE) << "Error connecting:" << pErr->msg << std::endl;)
+ return status;
+}
+
+connectionStatus_t Channel::connectInternal() {
+ using boost::asio::ip::tcp;
+ tcp::endpoint endpoint;
+ const char *host = m_pEndpoint->getHost().c_str();
+ const char *port = m_pEndpoint->getPort().c_str();
+ try {
+ tcp::resolver resolver(m_ioService);
+ tcp::resolver::query query(tcp::v4(), host, port);
+ tcp::resolver::iterator iter = resolver.resolve(query);
+ tcp::resolver::iterator end;
+ while (iter != end) {
+ endpoint = *iter++;
+ DRILL_LOG(LOG_TRACE) << endpoint << std::endl;
+ }
+ boost::system::error_code ec;
+ m_pSocket->getInnerSocket().connect(endpoint, ec);
+ if (ec) {
+ return handleError(CONN_FAILURE, getMessage(ERR_CONN_FAILURE, host, port, ec.message().c_str()));
+ }
+ } catch (boost::system::system_error e) {
+ // Handle case when the hostname cannot be resolved. "resolve" is hard-coded in boost asio resolver.resolve
+ if (!strncmp(e.what(), "resolve", 7)) {
+ return handleError(CONN_HOSTNAME_RESOLUTION_ERROR, getMessage(ERR_CONN_EXCEPT, e.what()));
+ }
+ } catch (std::exception e) {
+ return handleError(CONN_FAILURE, getMessage(ERR_CONN_EXCEPT, e.what()));
+ }
+
+ // set socket keep alive
+ boost::asio::socket_base::keep_alive keepAlive(true);
+ m_pSocket->getInnerSocket().set_option(keepAlive);
+ // set no_delay
+ boost::asio::ip::tcp::no_delay noDelay(true);
+ m_pSocket->getInnerSocket().set_option(noDelay);
+ // set reuse addr
+ boost::asio::socket_base::reuse_address reuseAddr(true);
+ m_pSocket->getInnerSocket().set_option(reuseAddr);
+
+ std::string useSystemTrustStore;
+ m_pContext->getUserProperties()->getProp(USERPROP_USESYSTEMTRUSTSTORE, useSystemTrustStore);
+ DRILL_LOG(LOG_TRACE) << "Connected" << std::endl;
+ return this->protocolHandshake(useSystemTrustStore=="true");
+
+}
+
+connectionStatus_t SocketChannel::init(ChannelContext_t* pContext){
+ connectionStatus_t ret=CONN_SUCCESS;
+ m_pSocket=new Socket(m_ioService);
+ if(m_pSocket!=NULL){
+ ret=Channel::init(pContext);
+ }else{
+ DRILL_LOG(LOG_ERROR) << "Channel initialization failure. " << std::endl;
+ handleError(CONN_NOSOCKET, getMessage(ERR_CONN_NOSOCKET));
+ ret=CONN_FAILURE;
+ }
+ return ret;
+}
+
+#if defined(IS_SSL_ENABLED)
+connectionStatus_t SSLStreamChannel::init(ChannelContext_t* pContext){
+ connectionStatus_t ret=CONN_SUCCESS;
+
+ const DrillUserProperties* props = pContext->getUserProperties();
+ std::string useSystemTrustStore;
+ props->getProp(USERPROP_USESYSTEMTRUSTSTORE, useSystemTrustStore);
+ if (useSystemTrustStore != "true"){
+ std::string certFile;
+ props->getProp(USERPROP_CERTFILEPATH, certFile);
+ try{
+ ((SSLChannelContext_t*)pContext)->getSslContext().load_verify_file(certFile);
+ }
+ catch (boost::system::system_error e){
+ DRILL_LOG(LOG_ERROR) << "Channel initialization failure. Certificate file "
+ << certFile
+ << " could not be loaded."
+ << std::endl;
+ handleError(CONN_SSLERROR, getMessage(ERR_CONN_SSLCERTFAIL, certFile.c_str(), e.what()));
+ ret = CONN_FAILURE;
+ }
+ }
+
+ std::string disableHostVerification;
+ props->getProp(USERPROP_DISABLE_HOSTVERIFICATION, disableHostVerification);
+ if (disableHostVerification != "true") {
+ std::string hostPortStr = m_pEndpoint->getHost() + ":" + m_pEndpoint->getPort();
+ ((SSLChannelContext_t *) pContext)->getSslContext().set_verify_callback(
+ boost::asio::ssl::rfc2818_verification(hostPortStr.c_str()));
+ }
+
+ m_pSocket=new SslSocket(m_ioService, ((SSLChannelContext_t*)pContext)->getSslContext() );
+ if(m_pSocket!=NULL){
+ ret=Channel::init(pContext);
+ }else{
+ DRILL_LOG(LOG_ERROR) << "Channel initialization failure. " << std::endl;
+ handleError(CONN_NOSOCKET, getMessage(ERR_CONN_NOSOCKET));
+ ret=CONN_FAILURE;
+ }
+ return ret;
+}
+#endif
+
+} // namespace Drill
diff --git a/contrib/native/client/src/clientlib/channel.hpp b/contrib/native/client/src/clientlib/channel.hpp
new file mode 100644
index 000000000..7f310e899
--- /dev/null
+++ b/contrib/native/client/src/clientlib/channel.hpp
@@ -0,0 +1,237 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef CHANNEL_HPP
+#define CHANNEL_HPP
+
+#include "drill/common.hpp"
+#include "drill/drillClient.hpp"
+#include "streamSocket.hpp"
+
+namespace Drill {
+
+class UserProperties;
+
+ class ConnectionEndpoint{
+ public:
+ ConnectionEndpoint(const char* connStr);
+ ConnectionEndpoint(const char* host, const char* port);
+ ~ConnectionEndpoint();
+
+ //parse the connection string and set up the host and port to connect to
+ connectionStatus_t getDrillbitEndpoint();
+
+ std::string& getProtocol(){return m_protocol;}
+ std::string& getHost(){return m_host;}
+ std::string& getPort(){return m_port;}
+ DrillClientError* getError(){ return m_pError;};
+
+ private:
+ void parseConnectString();
+ bool isDirectConnection();
+ bool isZookeeperConnection();
+ connectionStatus_t getDrillbitEndpointFromZk();
+ connectionStatus_t handleError(connectionStatus_t status, std::string msg);
+
+ std::string m_connectString;
+ std::string m_pathToDrill;
+ std::string m_protocol;
+ std::string m_hostPortStr;
+ std::string m_host;
+ std::string m_port;
+
+ DrillClientError* m_pError;
+
+ };
+
+ class ChannelContext{
+ public:
+ ChannelContext(DrillUserProperties* props):m_properties(props){};
+ virtual ~ChannelContext(){};
+ const DrillUserProperties* getUserProperties() const { return m_properties;}
+ protected:
+ DrillUserProperties* m_properties;
+ };
+
+ class SSLChannelContext: public ChannelContext{
+ public:
+ static boost::asio::ssl::context::method getTlsVersion(std::string version){
+ if(version.empty()){
+ return boost::asio::ssl::context::tlsv12;
+ } else if (version == "tlsv12") {
+ return boost::asio::ssl::context::tlsv12;
+ } else if (version == "tlsv11") {
+ return boost::asio::ssl::context::tlsv11;
+ } else if (version == "sslv23") {
+ return boost::asio::ssl::context::sslv23;
+ } else if (version == "tlsv1") {
+ return boost::asio::ssl::context::tlsv1;
+ } else if (version == "sslv3") {
+ return boost::asio::ssl::context::sslv3;
+ } else {
+ return boost::asio::ssl::context::tlsv12;
+ }
+ }
+
+ SSLChannelContext(DrillUserProperties *props, boost::asio::ssl::context::method tlsVersion, boost::asio::ssl::verify_mode verifyMode) :
+ ChannelContext(props),
+ m_SSLContext(tlsVersion) {
+ m_SSLContext.set_default_verify_paths();
+ m_SSLContext.set_options(
+ boost::asio::ssl::context::default_workarounds
+ | boost::asio::ssl::context::no_sslv2
+ | boost::asio::ssl::context::single_dh_use
+ );
+ m_SSLContext.set_verify_mode(verifyMode);
+ };
+ ~SSLChannelContext(){};
+ boost::asio::ssl::context& getSslContext(){ return m_SSLContext;}
+ private:
+ boost::asio::ssl::context m_SSLContext;
+ };
+
+ typedef ChannelContext ChannelContext_t;
+ typedef SSLChannelContext SSLChannelContext_t;
+
+ class ChannelContextFactory{
+ public:
+ static ChannelContext_t* getChannelContext(channelType_t t, DrillUserProperties* props);
+ };
+
+ /***
+ * The Channel class encapsulates a connection to a drillbit. Based on
+ * the connection string and the options, the connection will be either
+ * a simple socket or a socket using an ssl stream. The class also encapsulates
+ * connecting to a drillbit directly or thru zookeeper.
+ * The channel class owns the socket and the io_service that the applications
+ * will use to communicate with the server.
+ ***/
+ class Channel{
+ public:
+ Channel(const char* connStr);
+ Channel(const char* host, const char* port);
+ Channel(boost::asio::io_service& ioService, const char* connStr);
+ Channel(boost::asio::io_service& ioService, const char* host, const char* port);
+ virtual ~Channel();
+ virtual connectionStatus_t init(ChannelContext_t* context)=0;
+ connectionStatus_t connect();
+ bool isConnected(){ return m_state == CHANNEL_CONNECTED;}
+ template <typename SettableSocketOption> void setOption(SettableSocketOption& option);
+ DrillClientError* getError(){ return m_pError;}
+ void close(){
+ if(m_state==CHANNEL_INITIALIZED||m_state==CHANNEL_CONNECTED){
+ m_pSocket->protocolClose();
+ m_state=CHANNEL_CLOSED;
+ }
+ } // Not OK to use the channel after this call.
+
+ boost::asio::io_service& getIOService(){
+ return m_ioService;
+ }
+
+ // returns a reference to the underlying socket
+ // This access should really be removed and encapsulated in calls that
+ // manage async_send and async_recv
+ // Until then we will let DrillClientImpl have direct access
+ streamSocket_t& getInnerSocket(){
+ return m_pSocket->getInnerSocket();
+ }
+
+ AsioStreamSocket& getSocketStream(){
+ return *m_pSocket;
+ }
+
+ ConnectionEndpoint* getEndpoint(){return m_pEndpoint;}
+
+ protected:
+ connectionStatus_t handleError(connectionStatus_t status, std::string msg);
+
+ boost::asio::io_service& m_ioService;
+ boost::asio::io_service m_ioServiceFallback; // used if m_ioService is not provided
+ AsioStreamSocket* m_pSocket;
+ ConnectionEndpoint *m_pEndpoint;
+ ChannelContext_t *m_pContext;
+
+ private:
+
+ typedef enum channelState{
+ CHANNEL_UNINITIALIZED=1,
+ CHANNEL_INITIALIZED,
+ CHANNEL_CONNECTED,
+ CHANNEL_CLOSED
+ } channelState_t;
+
+ connectionStatus_t connectInternal();
+ connectionStatus_t protocolHandshake(bool useSystemConfig){
+ connectionStatus_t status = CONN_SUCCESS;
+ try{
+ m_pSocket->protocolHandshake(useSystemConfig);
+ } catch (boost::system::system_error e) {
+ status = handleError(CONN_HANDSHAKE_FAILED, e.what());
+ }
+ return status;
+ }
+
+ channelState_t m_state;
+ DrillClientError* m_pError;
+ bool m_ownIoService;
+ };
+
+ class SocketChannel: public Channel{
+ public:
+ SocketChannel(const char* connStr):Channel(connStr){
+ }
+ SocketChannel(const char* host, const char* port):Channel(host, port){
+ }
+ SocketChannel(boost::asio::io_service& ioService, const char* connStr)
+ :Channel(ioService, connStr){
+ }
+ SocketChannel(boost::asio::io_service& ioService, const char* host, const char* port)
+ :Channel(ioService, host, port){
+ }
+ connectionStatus_t init(ChannelContext_t* context=NULL);
+ };
+
+ class SSLStreamChannel: public Channel{
+ public:
+ SSLStreamChannel(const char* connStr):Channel(connStr){
+ }
+ SSLStreamChannel(const char* host, const char* port):Channel(host, port){
+ }
+ SSLStreamChannel(boost::asio::io_service& ioService, const char* connStr)
+ :Channel(ioService, connStr){
+ }
+ SSLStreamChannel(boost::asio::io_service& ioService, const char* host, const char* port)
+ :Channel(ioService, host, port){
+ }
+ connectionStatus_t init(ChannelContext_t* context);
+ };
+
+ class ChannelFactory{
+ public:
+ static Channel* getChannel(channelType_t t, const char* connStr);
+ static Channel* getChannel(channelType_t t, const char* host, const char* port);
+ static Channel* getChannel(channelType_t t, boost::asio::io_service& ioService, const char* connStr);
+ static Channel* getChannel(channelType_t t, boost::asio::io_service& ioService, const char* host, const char* port);
+ };
+
+
+} // namespace Drill
+
+#endif // CHANNEL_HPP
+
diff --git a/contrib/native/client/src/clientlib/errmsgs.cpp b/contrib/native/client/src/clientlib/errmsgs.cpp
index 15d2256fb..d2d8c1806 100644
--- a/contrib/native/client/src/clientlib/errmsgs.cpp
+++ b/contrib/native/client/src/clientlib/errmsgs.cpp
@@ -49,6 +49,9 @@ static Drill::ErrorMessages errorMessages[]={
{ERR_CONN_UNKNOWN_ERR, ERR_CATEGORY_CONN, 0, "Handshake Failed due to an error on the server. [Server message was: (%s) %s]"},
{ERR_CONN_NOCONN, ERR_CATEGORY_CONN, 0, "There is no connection to the server."},
{ERR_CONN_ALREADYCONN, ERR_CATEGORY_CONN, 0, "This client is already connected to a server."},
+ {ERR_CONN_NOCONNSTR, ERR_CATEGORY_CONN, 0, "Cannot connect if either host name or port number are empty."},
+ {ERR_CONN_SSLCERTFAIL, ERR_CATEGORY_CONN, 0, "SSL certificate file %s could not be loaded (exception message: %s)."},
+ {ERR_CONN_NOSOCKET, ERR_CATEGORY_CONN, 0, "Failed to open socket connection."},
{ERR_QRY_OUTOFMEM, ERR_CATEGORY_QRY, 0, "Out of memory."},
{ERR_QRY_COMMERR, ERR_CATEGORY_QRY, 0, "Communication error. %s"},
{ERR_QRY_INVREADLEN, ERR_CATEGORY_QRY, 0, "Internal Error: Received a message with an invalid read length."},
diff --git a/contrib/native/client/src/clientlib/errmsgs.hpp b/contrib/native/client/src/clientlib/errmsgs.hpp
index cfb56a6b0..246e4bbf2 100644
--- a/contrib/native/client/src/clientlib/errmsgs.hpp
+++ b/contrib/native/client/src/clientlib/errmsgs.hpp
@@ -51,7 +51,10 @@ namespace Drill{
#define ERR_CONN_UNKNOWN_ERR DRILL_ERR_START+18
#define ERR_CONN_NOCONN DRILL_ERR_START+19
#define ERR_CONN_ALREADYCONN DRILL_ERR_START+20
-#define ERR_CONN_MAX DRILL_ERR_START+20
+#define ERR_CONN_NOCONNSTR DRILL_ERR_START+21
+#define ERR_CONN_SSLCERTFAIL DRILL_ERR_START+22
+#define ERR_CONN_NOSOCKET DRILL_ERR_START+23
+#define ERR_CONN_MAX DRILL_ERR_START+23
#define ERR_QRY_OUTOFMEM ERR_CONN_MAX+1
#define ERR_QRY_COMMERR ERR_CONN_MAX+2
diff --git a/contrib/native/client/src/clientlib/logger.hpp b/contrib/native/client/src/clientlib/logger.hpp
index 7baf50c41..966e3a1d3 100644
--- a/contrib/native/client/src/clientlib/logger.hpp
+++ b/contrib/native/client/src/clientlib/logger.hpp
@@ -21,6 +21,7 @@
#include <sstream>
#include <ostream>
+#include <iostream>
#include <fstream>
#include <string>
#include <stdio.h>
diff --git a/contrib/native/client/src/clientlib/streamSocket.hpp b/contrib/native/client/src/clientlib/streamSocket.hpp
new file mode 100644
index 000000000..5db4dcaa9
--- /dev/null
+++ b/contrib/native/client/src/clientlib/streamSocket.hpp
@@ -0,0 +1,218 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+
+#ifndef STREAMSOCKET_HPP
+#define STREAMSOCKET_HPP
+
+#include "logger.hpp"
+#include "wincert.ipp"
+
+#include <boost/asio.hpp>
+#include <boost/asio/ssl.hpp>
+
+namespace Drill {
+
+typedef boost::asio::ip::tcp::socket::lowest_layer_type streamSocket_t;
+typedef boost::asio::ssl::stream<boost::asio::ip::tcp::socket> sslTCPSocket_t;
+typedef boost::asio::ip::tcp::socket basicTCPSocket_t;
+
+
+// Some helper typedefs to define the highly templatized boost::asio methods
+typedef boost::asio::const_buffers_1 ConstBufferSequence;
+typedef boost::asio::mutable_buffers_1 MutableBufferSequence;
+
+// ReadHandlers have different possible signatures.
+//
+// As a standard C-type callback
+// typedef void (*ReadHandler)(const boost::system::error_code& ec, std::size_t bytes_transferred);
+//
+// Or as a C++ functor
+// struct ReadHandler {
+// virtual void operator()(
+// const boost::system::error_code& ec,
+// std::size_t bytes_transferred) = 0;
+//};
+//
+// We need a different signature though, since we need to pass in a member function of a drill client
+// class (which is C++), as a functor generated by boost::bind as a ReadHandler
+//
+typedef boost::function<void (const boost::system::error_code& ec, std::size_t bytes_transferred) > ReadHandler;
+
+class AsioStreamSocket{
+ public:
+ virtual ~AsioStreamSocket(){};
+ virtual streamSocket_t& getInnerSocket() = 0;
+
+ virtual std::size_t writeSome(
+ const ConstBufferSequence& buffers,
+ boost::system::error_code & ec) = 0;
+
+ virtual std::size_t readSome(
+ const MutableBufferSequence& buffers,
+ boost::system::error_code & ec) = 0;
+
+ //
+ // boost::asio::async_read has the signature
+ // template<
+ // typename AsyncReadStream,
+ // typename MutableBufferSequence,
+ // typename ReadHandler>
+ // void-or-deduced async_read(
+ // AsyncReadStream & s,
+ // const MutableBufferSequence & buffers,
+ // ReadHandler handler);
+ //
+ // For our use case, the derived class will have an instance of a concrete type for AsyncReadStream which
+ // will implement the requirements for the AsyncReadStream type. We need not pass that in as a parameter
+ // since the class already has the value
+ // The method is templatized since the ReadHandler type is dependent on the class implementing the read
+ // handler (basically the class using the asio stream)
+ //
+ virtual void asyncRead( const MutableBufferSequence & buffers, ReadHandler handler) = 0;
+
+ // call the underlying protocol's handshake method.
+ // if the useSystemConfig flag is true, then use properties read
+ // from the underlying operating system
+ virtual void protocolHandshake(bool useSystemConfig) = 0;
+ virtual void protocolClose() = 0;
+};
+
+class Socket:
+ public AsioStreamSocket,
+ public basicTCPSocket_t{
+
+ public:
+ Socket(boost::asio::io_service& ioService) : basicTCPSocket_t(ioService) {
+ }
+
+ ~Socket(){
+ boost::system::error_code ignorederr;
+ this->shutdown(boost::asio::ip::tcp::socket::shutdown_both, ignorederr);
+ this->close();
+ };
+
+ basicTCPSocket_t& getSocketStream(){ return *this;}
+
+ streamSocket_t& getInnerSocket(){ return this->lowest_layer();}
+
+ std::size_t writeSome(
+ const ConstBufferSequence& buffers,
+ boost::system::error_code & ec){
+ return this->write_some(buffers, ec);
+ }
+
+ std::size_t readSome(
+ const MutableBufferSequence& buffers,
+ boost::system::error_code & ec){
+ return this->read_some(buffers, ec);
+ }
+
+ void asyncRead( const MutableBufferSequence & buffers, ReadHandler handler){
+ return async_read(*this, buffers, handler);
+ }
+
+ void protocolHandshake(bool useSystemConfig){}; //nothing to do
+ void protocolClose(){
+ // shuts down the socket!
+ boost::system::error_code ignorederr;
+ ((basicTCPSocket_t*)this)->shutdown(boost::asio::ip::tcp::socket::shutdown_both,
+ ignorederr
+ );
+ }
+};
+
+
+#if defined(IS_SSL_ENABLED)
+
+class SslSocket:
+ public AsioStreamSocket,
+ public sslTCPSocket_t{
+
+ public:
+ SslSocket(boost::asio::io_service& ioService, boost::asio::ssl::context &sslContext) :
+ sslTCPSocket_t(ioService, sslContext) {
+ }
+
+ ~SslSocket(){
+ this->lowest_layer().close();
+ };
+
+ sslTCPSocket_t& getSocketStream(){ return *this;}
+
+ streamSocket_t& getInnerSocket(){ return this->lowest_layer();}
+
+ std::size_t writeSome(
+ const ConstBufferSequence& buffers,
+ boost::system::error_code & ec){
+ return this->write_some(buffers, ec);
+ }
+
+ std::size_t readSome(
+ const MutableBufferSequence& buffers,
+ boost::system::error_code & ec){
+ return this->read_some(buffers, ec);
+ }
+
+ void asyncRead( const MutableBufferSequence & buffers, ReadHandler handler){
+ return async_read(*this, buffers, handler);
+ }
+
+ //
+ // public method that can be invoked by callers to invoke the ssl handshake
+ // throws: boost::system::system_error
+ void protocolHandshake(bool useSystemConfig){
+ if(useSystemConfig){
+ std::string msg = "";
+ int ret = loadSystemTrustStore(this->native_handle(), msg);
+ if(!msg.empty()){
+ DRILL_LOG(LOG_WARNING) << msg.c_str() << std::endl;
+ }
+ if(ret){
+ boost::system::error_code ec(EPROTO, boost::system::system_category());
+ boost::asio::detail::throw_error(ec, msg.c_str());
+ }
+ }
+ this->handshake(boost::asio::ssl::stream<boost::asio::ip::tcp::socket>::client);
+ return;
+ };
+ //
+ // public method that can be invoked by callers to invoke a clean ssl shutdown
+ // throws: boost::system::system_error
+ void protocolClose(){
+ try{
+ this->shutdown();
+ }catch(boost::system::system_error e){
+ //swallow the exception. The channel is unusable anyway
+ }
+ // shuts down the socket!
+ boost::system::error_code ignorederr;
+ this->lowest_layer().shutdown(boost::asio::ip::tcp::socket::shutdown_both,
+ ignorederr
+ );
+ return;
+ };
+
+};
+#endif
+
+
+} // namespace Drill
+
+#endif //STREAMSOCKET_HPP
+
diff --git a/contrib/native/client/src/clientlib/userProperties.cpp b/contrib/native/client/src/clientlib/userProperties.cpp
index 07497ef55..666f587c5 100644
--- a/contrib/native/client/src/clientlib/userProperties.cpp
+++ b/contrib/native/client/src/clientlib/userProperties.cpp
@@ -31,7 +31,11 @@ const std::map<std::string, uint32_t> DrillUserProperties::USER_PROPERTIES=boos
( USERPROP_SERVICE_NAME, USERPROP_FLAGS_STRING)
( USERPROP_SERVICE_HOST, USERPROP_FLAGS_STRING)
( USERPROP_USESSL, USERPROP_FLAGS_BOOLEAN|USERPROP_FLAGS_SSLPROP)
+ ( USERPROP_TLSPROTOCOL, USERPROP_FLAGS_STRING|USERPROP_FLAGS_SSLPROP)
( USERPROP_CERTFILEPATH, USERPROP_FLAGS_STRING|USERPROP_FLAGS_SSLPROP|USERPROP_FLAGS_FILEPATH)
+ ( USERPROP_DISABLE_HOSTVERIFICATION, USERPROP_FLAGS_BOOLEAN|USERPROP_FLAGS_SSLPROP)
+ ( USERPROP_DISABLE_CERTVERIFICATION, USERPROP_FLAGS_BOOLEAN|USERPROP_FLAGS_SSLPROP)
+ ( USERPROP_USESYSTEMTRUSTSTORE, USERPROP_FLAGS_BOOLEAN|USERPROP_FLAGS_SSLPROP)
( USERPROP_SASL_ENCRYPT, USERPROP_FLAGS_STRING)
;
diff --git a/contrib/native/client/src/clientlib/wincert.ipp b/contrib/native/client/src/clientlib/wincert.ipp
new file mode 100644
index 000000000..c1af70a36
--- /dev/null
+++ b/contrib/native/client/src/clientlib/wincert.ipp
@@ -0,0 +1,98 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#if defined(IS_SSL_ENABLED)
+
+#include <openssl/x509.h>
+#include <openssl/ssl.h>
+
+#if defined _WIN32 || defined _WIN64
+
+#include <stdio.h>
+#include <windows.h>
+#include <wincrypt.h>
+#include <cryptuiapi.h>
+#include <iostream>
+#include <tchar.h>
+
+
+#pragma comment (lib, "crypt32.lib")
+#pragma comment (lib, "cryptui.lib")
+
+#define MY_ENCODING_TYPE (PKCS_7_ASN_ENCODING | X509_ASN_ENCODING)
+
+inline
+int loadSystemTrustStore(const SSL *ssl, std::string& msg) {
+ HCERTSTORE hStore;
+ PCCERT_CONTEXT pContext = NULL;
+ X509 *x509;
+ char* stores[] = {
+ "CA",
+ "MY",
+ "ROOT",
+ "SPC"
+ };
+ int certCount=0;
+
+ SSL_CTX * ctx = SSL_get_SSL_CTX(ssl);
+ X509_STORE *store = SSL_CTX_get_cert_store(ctx);
+
+ for(int i=0; i<4; i++){
+ hStore = CertOpenSystemStore(NULL, stores[i]);
+
+ if (!hStore){
+ msg.append("Failed to load store: ").append(stores[i]).append("\n");
+ continue;
+ }
+
+ while (pContext = CertEnumCertificatesInStore(hStore, pContext)) {
+ //uncomment the line below if you want to see the certificates as pop ups
+ //CryptUIDlgViewContext(CERT_STORE_CERTIFICATE_CONTEXT, pContext, NULL, NULL, 0, NULL);
+
+ x509 = NULL;
+ x509 = d2i_X509(NULL, (const unsigned char **)&pContext->pbCertEncoded, pContext->cbCertEncoded);
+ if (x509) {
+ int ret = X509_STORE_add_cert(store, x509);
+
+ //if (ret == 1)
+ // std::cout << "Added certificate " << x509->name << " from " << stores[i] << std::endl;
+
+ X509_free(x509);
+ certCount++;
+ }
+ }
+
+ CertFreeCertificateContext(pContext);
+ CertCloseStore(hStore, 0);
+ }
+ if(certCount==0){
+ msg.append("No certificates found.");
+ return -1;
+ }
+ return 0;
+}
+
+#else // notwindows
+inline
+int loadSystemTrustStore(const SSL *ssl, std::string& msg) {
+ return 0;
+}
+
+#endif // WIN32 or WIN64
+
+#endif // SSL_ENABLED
diff --git a/contrib/native/client/src/include/drill/common.hpp b/contrib/native/client/src/include/drill/common.hpp
index 5401c75a9..012bd1974 100644
--- a/contrib/native/client/src/include/drill/common.hpp
+++ b/contrib/native/client/src/include/drill/common.hpp
@@ -106,6 +106,11 @@ class AllocatedBuffer;
typedef AllocatedBuffer* AllocatedBufferPtr;
typedef enum{
+ CHANNEL_TYPE_SOCKET=1,
+ CHANNEL_TYPE_SSLSTREAM=2
+} channelType_t;
+
+typedef enum{
QRY_SUCCESS=0,
QRY_FAILURE=1,
QRY_SUCCESS_WITH_INFO=2,
@@ -136,7 +141,9 @@ typedef enum{
CONN_BAD_RPC_VER=8,
CONN_DEAD=9,
CONN_NOTCONNECTED=10,
- CONN_ALREADYCONNECTED=11
+ CONN_ALREADYCONNECTED=11,
+ CONN_SSLERROR=12,
+ CONN_NOSOCKET=13
} connectionStatus_t;
typedef enum{
@@ -163,9 +170,13 @@ typedef enum{
#define USERPROP_USERNAME "userName"
#define USERPROP_PASSWORD "password"
#define USERPROP_SCHEMA "schema"
-#define USERPROP_USESSL "useSSL" // Not implemented yet
-#define USERPROP_FILEPATH "pemLocation" // Not implemented yet
-#define USERPROP_FILENAME "pemFile" // Not implemented yet
+#define USERPROP_USESSL "enableTLS"
+#define USERPROP_TLSPROTOCOL "TLSProtocol" //TLS version
+#define USERPROP_CERTFILEPATH "certFilePath" // pem file path and name
+#define USERPROP_CERTPASSWORD "certPassword" // Password for certificate file
+#define USERPROP_DISABLE_HOSTVERIFICATION "disableHostVerification"
+#define USERPROP_DISABLE_CERTVERIFICATION "disableCertVerification"
+#define USERPROP_USESYSTEMTRUSTSTORE "useSystemTrustStore" //Windows only, use the system trust store
#define USERPROP_IMPERSONATION_TARGET "impersonation_target"
#define USERPROP_AUTH_MECHANISM "auth"
#define USERPROP_SERVICE_NAME "service_name"
diff --git a/contrib/native/client/test/ssl/testSSL.cpp b/contrib/native/client/test/ssl/testSSL.cpp
new file mode 100644
index 000000000..3eaac4876
--- /dev/null
+++ b/contrib/native/client/test/ssl/testSSL.cpp
@@ -0,0 +1,384 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include <fstream>
+#include <iostream>
+#include <stdio.h>
+#include <stdlib.h>
+#include <boost/algorithm/string.hpp>
+#include <boost/asio.hpp>
+#include <boost/assign.hpp>
+#include <boost/bind.hpp>
+#include "drill/drillc.hpp"
+#include "drill/drillError.hpp"
+#include "clientlib/channel.hpp"
+#include "clientlib/drillClientImpl.hpp"
+#include "clientlib/errmsgs.hpp"
+#include "clientlib/logger.hpp"
+#include "clientlib/rpcMessage.hpp"
+#include "clientlib/utils.hpp"
+#include "protobuf/GeneralRPC.pb.h"
+#include "protobuf/UserBitShared.pb.h"
+
+namespace Drill {
+
+class DrillTestClient {
+
+ public:
+
+ DrillTestClient(Channel* pChannel):
+ m_handshakeStatus(exec::user::SUCCESS),
+ m_wbuf(MAX_SOCK_RD_BUFSIZE),
+ m_rbuf(0){
+ m_pChannel=pChannel;
+ m_pError=NULL;
+ m_coordinationId=Utils::s_randomNumber()%1729+1;
+ }
+
+ connectionStatus_t recvHandshake(){
+ if(m_rbuf==NULL){
+ m_rbuf = Utils::allocateBuffer(MAX_SOCK_RD_BUFSIZE);
+ }
+
+ m_pChannel->getIOService().reset();
+
+ m_pChannel->getSocketStream().asyncRead(
+ boost::asio::buffer(m_rbuf, LEN_PREFIX_BUFLEN),
+ boost::bind(
+ &DrillTestClient::handleHandshake,
+ this,
+ m_rbuf,
+ boost::asio::placeholders::error,
+ boost::asio::placeholders::bytes_transferred)
+ );
+ DRILL_MT_LOG(DRILL_LOG(LOG_DEBUG) << "DrillClientImpl::recvHandshake: async read waiting for server handshake response.\n";)
+ m_pChannel->getIOService().run();
+ if(m_rbuf!=NULL){
+ Utils::freeBuffer(m_rbuf, MAX_SOCK_RD_BUFSIZE); m_rbuf=NULL;
+ }
+
+ if (m_pError != NULL) {
+ DRILL_MT_LOG(DRILL_LOG(LOG_ERROR) << "DrillClientImpl::recvHandshake: failed to complete handshake with server."
+ << m_pError->msg << "\n";)
+ return static_cast<connectionStatus_t>(m_pError->status);
+ }
+
+ return CONN_SUCCESS;
+ }
+
+ void doReadFromSocket(ByteBuf_t inBuf, size_t bytesToRead, boost::system::error_code& errorCode) {
+ // Check if bytesToRead is zero
+ if(0 == bytesToRead) {
+ return;
+ }
+
+ // Read all the bytes. In case when all the bytes were not read the proper
+ // errorCode will be set.
+ while(1){
+ size_t dataBytesRead = m_pChannel->getSocketStream().readSome(boost::asio::buffer(inBuf, bytesToRead), errorCode);
+ // Update the state
+ bytesToRead -= dataBytesRead;
+ inBuf += dataBytesRead;
+
+ // Check if errorCode is EINTR then just retry otherwise break from loop
+ if(EINTR != errorCode.value()) break;
+
+ // Check if all the data is read then break from loop
+ if(0 == bytesToRead) break;
+ }
+ }
+
+ void handleHandshake(ByteBuf_t inBuf,
+ const boost::system::error_code& err,
+ size_t bytes_transferred) {
+ boost::system::error_code error=err;
+ DRILL_MT_LOG(DRILL_LOG(LOG_TRACE) << "Deadline timer cancelled." << std::endl;)
+ if(!error){
+ rpc::InBoundRpcMessage msg;
+ uint32_t length = 0;
+ std::size_t bytes_read = rpc::lengthDecode(m_rbuf, length);
+ if(length>0){
+ const size_t leftover = LEN_PREFIX_BUFLEN - bytes_read;
+ const ByteBuf_t b = m_rbuf + LEN_PREFIX_BUFLEN;
+ const size_t bytesToRead=length - leftover;
+ doReadFromSocket(b, bytesToRead, error);
+
+ // Check if any error happen while reading the message bytes. If yes then return before decoding the Msg
+ if(error) {
+ DRILL_MT_LOG(DRILL_LOG(LOG_TRACE) << "DrillClientImpl::handleHandshake: ERR_CONN_RDFAIL. "
+ << " Failed to read entire handshake message. with error: "
+ << error.message().c_str() << "\n";)
+ handleConnError(CONN_FAILURE, getMessage(ERR_CONN_RDFAIL, "Failed to read entire handshake message"));
+ return;
+ }
+
+ // Decode the bytes into a valid RPC Message
+ if (!decode(m_rbuf+bytes_read, length, msg)) {
+ DRILL_MT_LOG(DRILL_LOG(LOG_TRACE) << "DrillClientImpl::handleHandshake: ERR_CONN_RDFAIL. Cannot decode handshake.\n";)
+ handleConnError(CONN_FAILURE, getMessage(ERR_CONN_RDFAIL, "Cannot decode handshake"));
+ return;
+ }
+ }else{
+ DRILL_MT_LOG(DRILL_LOG(LOG_TRACE) << "DrillClientImpl::handleHandshake: ERR_CONN_RDFAIL. No handshake.\n";)
+ handleConnError(CONN_FAILURE, getMessage(ERR_CONN_RDFAIL, "No handshake"));
+ return;
+ }
+ exec::user::BitToUserHandshake b2u;
+ b2u.ParseFromArray(msg.m_pbody.data(), msg.m_pbody.size());
+ this->m_handshakeErrorId=b2u.errorid();
+ this->m_handshakeErrorMsg=b2u.errormessage();
+ }else{
+ // boost error
+ if(error==boost::asio::error::eof){ // Server broke off the connection
+ handleConnError(CONN_HANDSHAKE_FAILED, getMessage(ERR_CONN_NOHSHAKE, DRILL_RPC_VERSION));
+ }else{
+ handleConnError(CONN_FAILURE, getMessage(ERR_CONN_RDFAIL, error.message().c_str()));
+ }
+ return;
+ }
+ return;
+ }
+
+ connectionStatus_t handleConnError(connectionStatus_t status, const std::string& msg){
+ DrillClientError* pErr = new DrillClientError(status, DrillClientError::CONN_ERROR_START+status, msg);
+ if(m_pError!=NULL){ delete m_pError; m_pError=NULL;}
+ m_pError=pErr;
+ return status;
+ }
+
+ connectionStatus_t sendSyncCommon(rpc::OutBoundRpcMessage& msg) {
+ encode(m_wbuf, msg);
+ boost::system::error_code ec;
+ doWriteToSocket(reinterpret_cast<char*>(m_wbuf.data()), m_wbuf.size(), ec);
+
+ if(!ec) {
+ return CONN_SUCCESS;
+ } else {
+ return handleConnError(CONN_FAILURE, getMessage(ERR_CONN_WFAIL, ec.message().c_str()));
+ }
+ }
+
+ void doWriteToSocket(const char* dataPtr, size_t bytesToWrite,
+ boost::system::error_code& errorCode) {
+ if(0 == bytesToWrite) {
+ return;
+ }
+
+ // Write all the bytes to socket. In case of error when all bytes are not successfully written
+ // proper errorCode will be set.
+ while(1) {
+ size_t bytesWritten = m_pChannel->getSocketStream().writeSome(boost::asio::buffer(dataPtr, bytesToWrite), errorCode);
+ // Update the state
+ bytesToWrite -= bytesWritten;
+ dataPtr += bytesWritten;
+
+ if(EINTR != errorCode.value()) break;
+
+ // Check if all the data is written then break from loop
+ if(0 == bytesToWrite) break;
+ }
+ }
+
+ connectionStatus_t validateHandshake(DrillUserProperties* properties){
+
+ DRILL_MT_LOG(DRILL_LOG(LOG_TRACE) << "validateHandShake\n";)
+
+ exec::user::UserToBitHandshake u2b;
+ u2b.set_channel(exec::shared::USER);
+ u2b.set_rpc_version(DRILL_RPC_VERSION);
+ u2b.set_support_listening(true);
+ u2b.set_support_timeout(DrillClientConfig::getHeartbeatFrequency() > 0);
+ u2b.set_sasl_support(exec::user::SASL_PRIVACY);
+
+ // Adding version info
+ exec::user::RpcEndpointInfos* infos = u2b.mutable_client_infos();
+ infos->set_name(DrillClientConfig::getClientName());
+ infos->set_application(DrillClientConfig::getApplicationName());
+ infos->set_version(DRILL_VERSION_STRING);
+ infos->set_majorversion(DRILL_VERSION_MAJOR);
+ infos->set_minorversion(DRILL_VERSION_MINOR);
+ infos->set_patchversion(DRILL_VERSION_PATCH);
+
+ if(properties != NULL && properties->size()>0){
+ std::string username;
+ std::string err;
+ if(!properties->validate(err)){
+ DRILL_MT_LOG(DRILL_LOG(LOG_INFO) << "Invalid user input:" << err << std::endl;)
+ }
+ exec::user::UserProperties* userProperties = u2b.mutable_properties();
+
+ std::map<char,int>::iterator it;
+ for (std::map<std::string,std::string>::const_iterator propIter=properties->begin(); propIter!=properties->end(); ++propIter){
+ std::string currKey=propIter->first;
+ std::string currVal=propIter->second;
+ std::map<std::string,uint32_t>::const_iterator it=DrillUserProperties::USER_PROPERTIES.find(currKey);
+ if(it==DrillUserProperties::USER_PROPERTIES.end()){
+ DRILL_MT_LOG(DRILL_LOG(LOG_INFO) << "Connection property ("<< currKey
+ << ") is unknown" << std::endl;)
+ exec::user::Property* connProp = userProperties->add_properties();
+ connProp->set_key(currKey);
+ connProp->set_value(currVal);
+ continue;
+ }
+ if(IS_BITSET((*it).second,USERPROP_FLAGS_SERVERPROP)){
+ exec::user::Property* connProp = userProperties->add_properties();
+ connProp->set_key(currKey);
+ connProp->set_value(currVal);
+ //Username(but not the password) also needs to be set in UserCredentials
+ if(IS_BITSET((*it).second,USERPROP_FLAGS_USERNAME)){
+ exec::shared::UserCredentials* creds = u2b.mutable_credentials();
+ username=currVal;
+ creds->set_user_name(username);
+ //u2b.set_credentials(&creds);
+ }
+ if(IS_BITSET((*it).second,USERPROP_FLAGS_PASSWORD)){
+ DRILL_MT_LOG(DRILL_LOG(LOG_INFO) << currKey << ": ********** " << std::endl;)
+ }else{
+ DRILL_MT_LOG(DRILL_LOG(LOG_INFO) << currKey << ":" << currVal << std::endl;)
+ }
+ }// Server properties
+ }
+ }
+
+ {
+ boost::lock_guard<boost::mutex> lock(this->m_dcMutex);
+ uint64_t coordId = ++m_coordinationId;
+
+ rpc::OutBoundRpcMessage out_msg(exec::rpc::REQUEST, exec::user::HANDSHAKE, coordId, &u2b);
+ sendSyncCommon(out_msg);
+ DRILL_MT_LOG(DRILL_LOG(LOG_TRACE) << "Sent handshake request message. Coordination id: " << coordId << "\n";)
+ }
+
+ connectionStatus_t ret = recvHandshake();
+ if(ret!=CONN_SUCCESS){
+ return ret;
+ }
+
+ switch(this->m_handshakeStatus) {
+ case exec::user::SUCCESS:
+ // reset io_service after handshake is validated before running queries
+ m_pChannel->getIOService().reset();
+ return CONN_SUCCESS;
+ case exec::user::RPC_VERSION_MISMATCH:
+ DRILL_MT_LOG(DRILL_LOG(LOG_TRACE) << "Invalid rpc version. Expected "
+ << DRILL_RPC_VERSION << ", actual "<< 0 << "." << std::endl;)
+ return handleConnError(CONN_BAD_RPC_VER, getMessage(ERR_CONN_BAD_RPC_VER, DRILL_RPC_VERSION,
+ 0,
+ this->m_handshakeErrorId.c_str(),
+ this->m_handshakeErrorMsg.c_str()));
+ case exec::user::AUTH_FAILED:
+ DRILL_MT_LOG(DRILL_LOG(LOG_TRACE) << "Authentication failed." << std::endl;)
+ return handleConnError(CONN_AUTH_FAILED, getMessage(ERR_CONN_AUTHFAIL,
+ this->m_handshakeErrorId.c_str(),
+ this->m_handshakeErrorMsg.c_str()));
+ case exec::user::UNKNOWN_FAILURE:
+ DRILL_MT_LOG(DRILL_LOG(LOG_TRACE) << "Unknown error during handshake." << std::endl;)
+ return handleConnError(CONN_HANDSHAKE_FAILED, getMessage(ERR_CONN_UNKNOWN_ERR,
+ this->m_handshakeErrorId.c_str(),
+ this->m_handshakeErrorMsg.c_str()));
+ case exec::user::AUTH_REQUIRED:
+ DRILL_MT_LOG(DRILL_LOG(LOG_TRACE) << "Server requires SASL authentication." << std::endl;)
+ return handleConnError(CONN_HANDSHAKE_FAILED, getMessage(ERR_CONN_UNKNOWN_ERR,
+ this->m_handshakeErrorId.c_str(),
+ this->m_handshakeErrorMsg.c_str()));
+ default:
+ DRILL_MT_LOG(DRILL_LOG(LOG_TRACE) << "Unknown return status." << std::endl;)
+ return handleConnError(CONN_HANDSHAKE_FAILED, getMessage(ERR_CONN_UNKNOWN_ERR,
+ this->m_handshakeErrorId.c_str(),
+ this->m_handshakeErrorMsg.c_str()));
+ }
+ }
+
+ DrillClientError* m_pError;
+ private:
+ Channel* m_pChannel;
+ int32_t m_coordinationId;
+ std::string m_handshakeErrorId;
+ std::string m_handshakeErrorMsg;
+ exec::user::HandshakeStatus m_handshakeStatus;
+ DataBuf m_wbuf;
+ ByteBuf_t m_rbuf;
+ boost::mutex m_dcMutex;
+
+
+
+
+};
+
+} // namespace Drill
+
+using namespace Drill;
+
+int main(int argc, char* argv[]){
+ Channel *pChannel = NULL;
+ ChannelContext *pChannelContext = NULL;
+ std::string connectStr = "zk=localhost:2181/drill/drillbits1";
+ //std::string connectStr = "drillbit=localhost:31090";
+ channelType_t type;
+
+ bool isSSL = argc==2 && !(strcmp(argv[1], "ssl"));
+ type = CHANNEL_TYPE_SOCKET;
+ if(isSSL){
+ type = CHANNEL_TYPE_SSLSTREAM;
+ }
+ Drill::DrillUserProperties props;
+ props.setProperty(USERPROP_USERNAME, "admin");
+ props.setProperty(USERPROP_PASSWORD, "admin");
+ props.setProperty(USERPROP_CERTFILEPATH, "../../../test/ssl/drillTestCert.pem");
+
+ pChannelContext = ChannelContextFactory::getChannelContext(type, &props);
+
+ pChannel = ChannelFactory::getChannel(type, connectStr.c_str());
+ if(pChannel != NULL){
+ connectionStatus_t connStat;
+ connStat = pChannel->init(pChannelContext);
+ if(connStat != CONN_SUCCESS){
+ std::cout << "Init Failed." << std::endl;
+ return -1;
+ }
+ connStat = pChannel->connect();
+ if(connStat != CONN_SUCCESS){
+ std::cout << "Connect Failed." << std::endl;
+ std::cout << pChannel->getError()->msg << std::endl;
+ return -1;
+ }
+ } else{
+ std::cout << "Channel creation failed." << std::endl;
+ return -1;
+ }
+ std::cout << "Connected." << std::endl;
+ std::cout << "Starting Drill handshake" << std::endl;
+
+
+ DrillTestClient client(pChannel);
+
+ connectionStatus_t stat = client.validateHandshake(&props);
+ if(stat == CONN_SUCCESS){
+ std::cout << "Handshake validated." << std::endl;
+ } else{
+ if(client.m_pError != NULL){
+ std::cout << "Handshake failed: " << client.m_pError->msg << ". " << std::endl;
+ } else{
+ std::cout << "Handshake failed with unknown error" << ". " << std::endl;
+ }
+ }
+
+ return 0;
+
+}
+