diff options
author | Parth Chandra <parthc@apache.org> | 2017-07-25 09:22:23 -0700 |
---|---|---|
committer | Parth Chandra <parthc@apache.org> | 2017-10-11 19:27:48 -0700 |
commit | facbb92ba319373dd8b8baa171ac1d7978c926c5 (patch) | |
tree | 95b2f1b102047d4244d9f9a302779e1d11a81a9a /contrib/native/client/test/ssl/testSSL.cpp | |
parent | 414c94ee7e088c89b0c7c8b9d9dde9335cfcbe6d (diff) |
DRILL-5431: SSL Support (C++) - Add (Netty like) socket abstraction that encapsulates a TCP socket or a SSL Stream on TCP.
The testSSL program tests the client connection against a drillbit by sending a drill handshake.
Diffstat (limited to 'contrib/native/client/test/ssl/testSSL.cpp')
-rw-r--r-- | contrib/native/client/test/ssl/testSSL.cpp | 384 |
1 files changed, 384 insertions, 0 deletions
diff --git a/contrib/native/client/test/ssl/testSSL.cpp b/contrib/native/client/test/ssl/testSSL.cpp new file mode 100644 index 000000000..3eaac4876 --- /dev/null +++ b/contrib/native/client/test/ssl/testSSL.cpp @@ -0,0 +1,384 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include <fstream> +#include <iostream> +#include <stdio.h> +#include <stdlib.h> +#include <boost/algorithm/string.hpp> +#include <boost/asio.hpp> +#include <boost/assign.hpp> +#include <boost/bind.hpp> +#include "drill/drillc.hpp" +#include "drill/drillError.hpp" +#include "clientlib/channel.hpp" +#include "clientlib/drillClientImpl.hpp" +#include "clientlib/errmsgs.hpp" +#include "clientlib/logger.hpp" +#include "clientlib/rpcMessage.hpp" +#include "clientlib/utils.hpp" +#include "protobuf/GeneralRPC.pb.h" +#include "protobuf/UserBitShared.pb.h" + +namespace Drill { + +class DrillTestClient { + + public: + + DrillTestClient(Channel* pChannel): + m_handshakeStatus(exec::user::SUCCESS), + m_wbuf(MAX_SOCK_RD_BUFSIZE), + m_rbuf(0){ + m_pChannel=pChannel; + m_pError=NULL; + m_coordinationId=Utils::s_randomNumber()%1729+1; + } + + connectionStatus_t recvHandshake(){ + if(m_rbuf==NULL){ + m_rbuf = Utils::allocateBuffer(MAX_SOCK_RD_BUFSIZE); + } + + m_pChannel->getIOService().reset(); + + m_pChannel->getSocketStream().asyncRead( + boost::asio::buffer(m_rbuf, LEN_PREFIX_BUFLEN), + boost::bind( + &DrillTestClient::handleHandshake, + this, + m_rbuf, + boost::asio::placeholders::error, + boost::asio::placeholders::bytes_transferred) + ); + DRILL_MT_LOG(DRILL_LOG(LOG_DEBUG) << "DrillClientImpl::recvHandshake: async read waiting for server handshake response.\n";) + m_pChannel->getIOService().run(); + if(m_rbuf!=NULL){ + Utils::freeBuffer(m_rbuf, MAX_SOCK_RD_BUFSIZE); m_rbuf=NULL; + } + + if (m_pError != NULL) { + DRILL_MT_LOG(DRILL_LOG(LOG_ERROR) << "DrillClientImpl::recvHandshake: failed to complete handshake with server." + << m_pError->msg << "\n";) + return static_cast<connectionStatus_t>(m_pError->status); + } + + return CONN_SUCCESS; + } + + void doReadFromSocket(ByteBuf_t inBuf, size_t bytesToRead, boost::system::error_code& errorCode) { + // Check if bytesToRead is zero + if(0 == bytesToRead) { + return; + } + + // Read all the bytes. In case when all the bytes were not read the proper + // errorCode will be set. + while(1){ + size_t dataBytesRead = m_pChannel->getSocketStream().readSome(boost::asio::buffer(inBuf, bytesToRead), errorCode); + // Update the state + bytesToRead -= dataBytesRead; + inBuf += dataBytesRead; + + // Check if errorCode is EINTR then just retry otherwise break from loop + if(EINTR != errorCode.value()) break; + + // Check if all the data is read then break from loop + if(0 == bytesToRead) break; + } + } + + void handleHandshake(ByteBuf_t inBuf, + const boost::system::error_code& err, + size_t bytes_transferred) { + boost::system::error_code error=err; + DRILL_MT_LOG(DRILL_LOG(LOG_TRACE) << "Deadline timer cancelled." << std::endl;) + if(!error){ + rpc::InBoundRpcMessage msg; + uint32_t length = 0; + std::size_t bytes_read = rpc::lengthDecode(m_rbuf, length); + if(length>0){ + const size_t leftover = LEN_PREFIX_BUFLEN - bytes_read; + const ByteBuf_t b = m_rbuf + LEN_PREFIX_BUFLEN; + const size_t bytesToRead=length - leftover; + doReadFromSocket(b, bytesToRead, error); + + // Check if any error happen while reading the message bytes. If yes then return before decoding the Msg + if(error) { + DRILL_MT_LOG(DRILL_LOG(LOG_TRACE) << "DrillClientImpl::handleHandshake: ERR_CONN_RDFAIL. " + << " Failed to read entire handshake message. with error: " + << error.message().c_str() << "\n";) + handleConnError(CONN_FAILURE, getMessage(ERR_CONN_RDFAIL, "Failed to read entire handshake message")); + return; + } + + // Decode the bytes into a valid RPC Message + if (!decode(m_rbuf+bytes_read, length, msg)) { + DRILL_MT_LOG(DRILL_LOG(LOG_TRACE) << "DrillClientImpl::handleHandshake: ERR_CONN_RDFAIL. Cannot decode handshake.\n";) + handleConnError(CONN_FAILURE, getMessage(ERR_CONN_RDFAIL, "Cannot decode handshake")); + return; + } + }else{ + DRILL_MT_LOG(DRILL_LOG(LOG_TRACE) << "DrillClientImpl::handleHandshake: ERR_CONN_RDFAIL. No handshake.\n";) + handleConnError(CONN_FAILURE, getMessage(ERR_CONN_RDFAIL, "No handshake")); + return; + } + exec::user::BitToUserHandshake b2u; + b2u.ParseFromArray(msg.m_pbody.data(), msg.m_pbody.size()); + this->m_handshakeErrorId=b2u.errorid(); + this->m_handshakeErrorMsg=b2u.errormessage(); + }else{ + // boost error + if(error==boost::asio::error::eof){ // Server broke off the connection + handleConnError(CONN_HANDSHAKE_FAILED, getMessage(ERR_CONN_NOHSHAKE, DRILL_RPC_VERSION)); + }else{ + handleConnError(CONN_FAILURE, getMessage(ERR_CONN_RDFAIL, error.message().c_str())); + } + return; + } + return; + } + + connectionStatus_t handleConnError(connectionStatus_t status, const std::string& msg){ + DrillClientError* pErr = new DrillClientError(status, DrillClientError::CONN_ERROR_START+status, msg); + if(m_pError!=NULL){ delete m_pError; m_pError=NULL;} + m_pError=pErr; + return status; + } + + connectionStatus_t sendSyncCommon(rpc::OutBoundRpcMessage& msg) { + encode(m_wbuf, msg); + boost::system::error_code ec; + doWriteToSocket(reinterpret_cast<char*>(m_wbuf.data()), m_wbuf.size(), ec); + + if(!ec) { + return CONN_SUCCESS; + } else { + return handleConnError(CONN_FAILURE, getMessage(ERR_CONN_WFAIL, ec.message().c_str())); + } + } + + void doWriteToSocket(const char* dataPtr, size_t bytesToWrite, + boost::system::error_code& errorCode) { + if(0 == bytesToWrite) { + return; + } + + // Write all the bytes to socket. In case of error when all bytes are not successfully written + // proper errorCode will be set. + while(1) { + size_t bytesWritten = m_pChannel->getSocketStream().writeSome(boost::asio::buffer(dataPtr, bytesToWrite), errorCode); + // Update the state + bytesToWrite -= bytesWritten; + dataPtr += bytesWritten; + + if(EINTR != errorCode.value()) break; + + // Check if all the data is written then break from loop + if(0 == bytesToWrite) break; + } + } + + connectionStatus_t validateHandshake(DrillUserProperties* properties){ + + DRILL_MT_LOG(DRILL_LOG(LOG_TRACE) << "validateHandShake\n";) + + exec::user::UserToBitHandshake u2b; + u2b.set_channel(exec::shared::USER); + u2b.set_rpc_version(DRILL_RPC_VERSION); + u2b.set_support_listening(true); + u2b.set_support_timeout(DrillClientConfig::getHeartbeatFrequency() > 0); + u2b.set_sasl_support(exec::user::SASL_PRIVACY); + + // Adding version info + exec::user::RpcEndpointInfos* infos = u2b.mutable_client_infos(); + infos->set_name(DrillClientConfig::getClientName()); + infos->set_application(DrillClientConfig::getApplicationName()); + infos->set_version(DRILL_VERSION_STRING); + infos->set_majorversion(DRILL_VERSION_MAJOR); + infos->set_minorversion(DRILL_VERSION_MINOR); + infos->set_patchversion(DRILL_VERSION_PATCH); + + if(properties != NULL && properties->size()>0){ + std::string username; + std::string err; + if(!properties->validate(err)){ + DRILL_MT_LOG(DRILL_LOG(LOG_INFO) << "Invalid user input:" << err << std::endl;) + } + exec::user::UserProperties* userProperties = u2b.mutable_properties(); + + std::map<char,int>::iterator it; + for (std::map<std::string,std::string>::const_iterator propIter=properties->begin(); propIter!=properties->end(); ++propIter){ + std::string currKey=propIter->first; + std::string currVal=propIter->second; + std::map<std::string,uint32_t>::const_iterator it=DrillUserProperties::USER_PROPERTIES.find(currKey); + if(it==DrillUserProperties::USER_PROPERTIES.end()){ + DRILL_MT_LOG(DRILL_LOG(LOG_INFO) << "Connection property ("<< currKey + << ") is unknown" << std::endl;) + exec::user::Property* connProp = userProperties->add_properties(); + connProp->set_key(currKey); + connProp->set_value(currVal); + continue; + } + if(IS_BITSET((*it).second,USERPROP_FLAGS_SERVERPROP)){ + exec::user::Property* connProp = userProperties->add_properties(); + connProp->set_key(currKey); + connProp->set_value(currVal); + //Username(but not the password) also needs to be set in UserCredentials + if(IS_BITSET((*it).second,USERPROP_FLAGS_USERNAME)){ + exec::shared::UserCredentials* creds = u2b.mutable_credentials(); + username=currVal; + creds->set_user_name(username); + //u2b.set_credentials(&creds); + } + if(IS_BITSET((*it).second,USERPROP_FLAGS_PASSWORD)){ + DRILL_MT_LOG(DRILL_LOG(LOG_INFO) << currKey << ": ********** " << std::endl;) + }else{ + DRILL_MT_LOG(DRILL_LOG(LOG_INFO) << currKey << ":" << currVal << std::endl;) + } + }// Server properties + } + } + + { + boost::lock_guard<boost::mutex> lock(this->m_dcMutex); + uint64_t coordId = ++m_coordinationId; + + rpc::OutBoundRpcMessage out_msg(exec::rpc::REQUEST, exec::user::HANDSHAKE, coordId, &u2b); + sendSyncCommon(out_msg); + DRILL_MT_LOG(DRILL_LOG(LOG_TRACE) << "Sent handshake request message. Coordination id: " << coordId << "\n";) + } + + connectionStatus_t ret = recvHandshake(); + if(ret!=CONN_SUCCESS){ + return ret; + } + + switch(this->m_handshakeStatus) { + case exec::user::SUCCESS: + // reset io_service after handshake is validated before running queries + m_pChannel->getIOService().reset(); + return CONN_SUCCESS; + case exec::user::RPC_VERSION_MISMATCH: + DRILL_MT_LOG(DRILL_LOG(LOG_TRACE) << "Invalid rpc version. Expected " + << DRILL_RPC_VERSION << ", actual "<< 0 << "." << std::endl;) + return handleConnError(CONN_BAD_RPC_VER, getMessage(ERR_CONN_BAD_RPC_VER, DRILL_RPC_VERSION, + 0, + this->m_handshakeErrorId.c_str(), + this->m_handshakeErrorMsg.c_str())); + case exec::user::AUTH_FAILED: + DRILL_MT_LOG(DRILL_LOG(LOG_TRACE) << "Authentication failed." << std::endl;) + return handleConnError(CONN_AUTH_FAILED, getMessage(ERR_CONN_AUTHFAIL, + this->m_handshakeErrorId.c_str(), + this->m_handshakeErrorMsg.c_str())); + case exec::user::UNKNOWN_FAILURE: + DRILL_MT_LOG(DRILL_LOG(LOG_TRACE) << "Unknown error during handshake." << std::endl;) + return handleConnError(CONN_HANDSHAKE_FAILED, getMessage(ERR_CONN_UNKNOWN_ERR, + this->m_handshakeErrorId.c_str(), + this->m_handshakeErrorMsg.c_str())); + case exec::user::AUTH_REQUIRED: + DRILL_MT_LOG(DRILL_LOG(LOG_TRACE) << "Server requires SASL authentication." << std::endl;) + return handleConnError(CONN_HANDSHAKE_FAILED, getMessage(ERR_CONN_UNKNOWN_ERR, + this->m_handshakeErrorId.c_str(), + this->m_handshakeErrorMsg.c_str())); + default: + DRILL_MT_LOG(DRILL_LOG(LOG_TRACE) << "Unknown return status." << std::endl;) + return handleConnError(CONN_HANDSHAKE_FAILED, getMessage(ERR_CONN_UNKNOWN_ERR, + this->m_handshakeErrorId.c_str(), + this->m_handshakeErrorMsg.c_str())); + } + } + + DrillClientError* m_pError; + private: + Channel* m_pChannel; + int32_t m_coordinationId; + std::string m_handshakeErrorId; + std::string m_handshakeErrorMsg; + exec::user::HandshakeStatus m_handshakeStatus; + DataBuf m_wbuf; + ByteBuf_t m_rbuf; + boost::mutex m_dcMutex; + + + + +}; + +} // namespace Drill + +using namespace Drill; + +int main(int argc, char* argv[]){ + Channel *pChannel = NULL; + ChannelContext *pChannelContext = NULL; + std::string connectStr = "zk=localhost:2181/drill/drillbits1"; + //std::string connectStr = "drillbit=localhost:31090"; + channelType_t type; + + bool isSSL = argc==2 && !(strcmp(argv[1], "ssl")); + type = CHANNEL_TYPE_SOCKET; + if(isSSL){ + type = CHANNEL_TYPE_SSLSTREAM; + } + Drill::DrillUserProperties props; + props.setProperty(USERPROP_USERNAME, "admin"); + props.setProperty(USERPROP_PASSWORD, "admin"); + props.setProperty(USERPROP_CERTFILEPATH, "../../../test/ssl/drillTestCert.pem"); + + pChannelContext = ChannelContextFactory::getChannelContext(type, &props); + + pChannel = ChannelFactory::getChannel(type, connectStr.c_str()); + if(pChannel != NULL){ + connectionStatus_t connStat; + connStat = pChannel->init(pChannelContext); + if(connStat != CONN_SUCCESS){ + std::cout << "Init Failed." << std::endl; + return -1; + } + connStat = pChannel->connect(); + if(connStat != CONN_SUCCESS){ + std::cout << "Connect Failed." << std::endl; + std::cout << pChannel->getError()->msg << std::endl; + return -1; + } + } else{ + std::cout << "Channel creation failed." << std::endl; + return -1; + } + std::cout << "Connected." << std::endl; + std::cout << "Starting Drill handshake" << std::endl; + + + DrillTestClient client(pChannel); + + connectionStatus_t stat = client.validateHandshake(&props); + if(stat == CONN_SUCCESS){ + std::cout << "Handshake validated." << std::endl; + } else{ + if(client.m_pError != NULL){ + std::cout << "Handshake failed: " << client.m_pError->msg << ". " << std::endl; + } else{ + std::cout << "Handshake failed with unknown error" << ". " << std::endl; + } + } + + return 0; + +} + |