aboutsummaryrefslogtreecommitdiff
path: root/contrib/native/client/test/ssl/testSSL.cpp
diff options
context:
space:
mode:
authorParth Chandra <parthc@apache.org>2017-07-25 09:22:23 -0700
committerParth Chandra <parthc@apache.org>2017-10-11 19:27:48 -0700
commitfacbb92ba319373dd8b8baa171ac1d7978c926c5 (patch)
tree95b2f1b102047d4244d9f9a302779e1d11a81a9a /contrib/native/client/test/ssl/testSSL.cpp
parent414c94ee7e088c89b0c7c8b9d9dde9335cfcbe6d (diff)
DRILL-5431: SSL Support (C++) - Add (Netty like) socket abstraction that encapsulates a TCP socket or a SSL Stream on TCP.
The testSSL program tests the client connection against a drillbit by sending a drill handshake.
Diffstat (limited to 'contrib/native/client/test/ssl/testSSL.cpp')
-rw-r--r--contrib/native/client/test/ssl/testSSL.cpp384
1 files changed, 384 insertions, 0 deletions
diff --git a/contrib/native/client/test/ssl/testSSL.cpp b/contrib/native/client/test/ssl/testSSL.cpp
new file mode 100644
index 000000000..3eaac4876
--- /dev/null
+++ b/contrib/native/client/test/ssl/testSSL.cpp
@@ -0,0 +1,384 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include <fstream>
+#include <iostream>
+#include <stdio.h>
+#include <stdlib.h>
+#include <boost/algorithm/string.hpp>
+#include <boost/asio.hpp>
+#include <boost/assign.hpp>
+#include <boost/bind.hpp>
+#include "drill/drillc.hpp"
+#include "drill/drillError.hpp"
+#include "clientlib/channel.hpp"
+#include "clientlib/drillClientImpl.hpp"
+#include "clientlib/errmsgs.hpp"
+#include "clientlib/logger.hpp"
+#include "clientlib/rpcMessage.hpp"
+#include "clientlib/utils.hpp"
+#include "protobuf/GeneralRPC.pb.h"
+#include "protobuf/UserBitShared.pb.h"
+
+namespace Drill {
+
+class DrillTestClient {
+
+ public:
+
+ DrillTestClient(Channel* pChannel):
+ m_handshakeStatus(exec::user::SUCCESS),
+ m_wbuf(MAX_SOCK_RD_BUFSIZE),
+ m_rbuf(0){
+ m_pChannel=pChannel;
+ m_pError=NULL;
+ m_coordinationId=Utils::s_randomNumber()%1729+1;
+ }
+
+ connectionStatus_t recvHandshake(){
+ if(m_rbuf==NULL){
+ m_rbuf = Utils::allocateBuffer(MAX_SOCK_RD_BUFSIZE);
+ }
+
+ m_pChannel->getIOService().reset();
+
+ m_pChannel->getSocketStream().asyncRead(
+ boost::asio::buffer(m_rbuf, LEN_PREFIX_BUFLEN),
+ boost::bind(
+ &DrillTestClient::handleHandshake,
+ this,
+ m_rbuf,
+ boost::asio::placeholders::error,
+ boost::asio::placeholders::bytes_transferred)
+ );
+ DRILL_MT_LOG(DRILL_LOG(LOG_DEBUG) << "DrillClientImpl::recvHandshake: async read waiting for server handshake response.\n";)
+ m_pChannel->getIOService().run();
+ if(m_rbuf!=NULL){
+ Utils::freeBuffer(m_rbuf, MAX_SOCK_RD_BUFSIZE); m_rbuf=NULL;
+ }
+
+ if (m_pError != NULL) {
+ DRILL_MT_LOG(DRILL_LOG(LOG_ERROR) << "DrillClientImpl::recvHandshake: failed to complete handshake with server."
+ << m_pError->msg << "\n";)
+ return static_cast<connectionStatus_t>(m_pError->status);
+ }
+
+ return CONN_SUCCESS;
+ }
+
+ void doReadFromSocket(ByteBuf_t inBuf, size_t bytesToRead, boost::system::error_code& errorCode) {
+ // Check if bytesToRead is zero
+ if(0 == bytesToRead) {
+ return;
+ }
+
+ // Read all the bytes. In case when all the bytes were not read the proper
+ // errorCode will be set.
+ while(1){
+ size_t dataBytesRead = m_pChannel->getSocketStream().readSome(boost::asio::buffer(inBuf, bytesToRead), errorCode);
+ // Update the state
+ bytesToRead -= dataBytesRead;
+ inBuf += dataBytesRead;
+
+ // Check if errorCode is EINTR then just retry otherwise break from loop
+ if(EINTR != errorCode.value()) break;
+
+ // Check if all the data is read then break from loop
+ if(0 == bytesToRead) break;
+ }
+ }
+
+ void handleHandshake(ByteBuf_t inBuf,
+ const boost::system::error_code& err,
+ size_t bytes_transferred) {
+ boost::system::error_code error=err;
+ DRILL_MT_LOG(DRILL_LOG(LOG_TRACE) << "Deadline timer cancelled." << std::endl;)
+ if(!error){
+ rpc::InBoundRpcMessage msg;
+ uint32_t length = 0;
+ std::size_t bytes_read = rpc::lengthDecode(m_rbuf, length);
+ if(length>0){
+ const size_t leftover = LEN_PREFIX_BUFLEN - bytes_read;
+ const ByteBuf_t b = m_rbuf + LEN_PREFIX_BUFLEN;
+ const size_t bytesToRead=length - leftover;
+ doReadFromSocket(b, bytesToRead, error);
+
+ // Check if any error happen while reading the message bytes. If yes then return before decoding the Msg
+ if(error) {
+ DRILL_MT_LOG(DRILL_LOG(LOG_TRACE) << "DrillClientImpl::handleHandshake: ERR_CONN_RDFAIL. "
+ << " Failed to read entire handshake message. with error: "
+ << error.message().c_str() << "\n";)
+ handleConnError(CONN_FAILURE, getMessage(ERR_CONN_RDFAIL, "Failed to read entire handshake message"));
+ return;
+ }
+
+ // Decode the bytes into a valid RPC Message
+ if (!decode(m_rbuf+bytes_read, length, msg)) {
+ DRILL_MT_LOG(DRILL_LOG(LOG_TRACE) << "DrillClientImpl::handleHandshake: ERR_CONN_RDFAIL. Cannot decode handshake.\n";)
+ handleConnError(CONN_FAILURE, getMessage(ERR_CONN_RDFAIL, "Cannot decode handshake"));
+ return;
+ }
+ }else{
+ DRILL_MT_LOG(DRILL_LOG(LOG_TRACE) << "DrillClientImpl::handleHandshake: ERR_CONN_RDFAIL. No handshake.\n";)
+ handleConnError(CONN_FAILURE, getMessage(ERR_CONN_RDFAIL, "No handshake"));
+ return;
+ }
+ exec::user::BitToUserHandshake b2u;
+ b2u.ParseFromArray(msg.m_pbody.data(), msg.m_pbody.size());
+ this->m_handshakeErrorId=b2u.errorid();
+ this->m_handshakeErrorMsg=b2u.errormessage();
+ }else{
+ // boost error
+ if(error==boost::asio::error::eof){ // Server broke off the connection
+ handleConnError(CONN_HANDSHAKE_FAILED, getMessage(ERR_CONN_NOHSHAKE, DRILL_RPC_VERSION));
+ }else{
+ handleConnError(CONN_FAILURE, getMessage(ERR_CONN_RDFAIL, error.message().c_str()));
+ }
+ return;
+ }
+ return;
+ }
+
+ connectionStatus_t handleConnError(connectionStatus_t status, const std::string& msg){
+ DrillClientError* pErr = new DrillClientError(status, DrillClientError::CONN_ERROR_START+status, msg);
+ if(m_pError!=NULL){ delete m_pError; m_pError=NULL;}
+ m_pError=pErr;
+ return status;
+ }
+
+ connectionStatus_t sendSyncCommon(rpc::OutBoundRpcMessage& msg) {
+ encode(m_wbuf, msg);
+ boost::system::error_code ec;
+ doWriteToSocket(reinterpret_cast<char*>(m_wbuf.data()), m_wbuf.size(), ec);
+
+ if(!ec) {
+ return CONN_SUCCESS;
+ } else {
+ return handleConnError(CONN_FAILURE, getMessage(ERR_CONN_WFAIL, ec.message().c_str()));
+ }
+ }
+
+ void doWriteToSocket(const char* dataPtr, size_t bytesToWrite,
+ boost::system::error_code& errorCode) {
+ if(0 == bytesToWrite) {
+ return;
+ }
+
+ // Write all the bytes to socket. In case of error when all bytes are not successfully written
+ // proper errorCode will be set.
+ while(1) {
+ size_t bytesWritten = m_pChannel->getSocketStream().writeSome(boost::asio::buffer(dataPtr, bytesToWrite), errorCode);
+ // Update the state
+ bytesToWrite -= bytesWritten;
+ dataPtr += bytesWritten;
+
+ if(EINTR != errorCode.value()) break;
+
+ // Check if all the data is written then break from loop
+ if(0 == bytesToWrite) break;
+ }
+ }
+
+ connectionStatus_t validateHandshake(DrillUserProperties* properties){
+
+ DRILL_MT_LOG(DRILL_LOG(LOG_TRACE) << "validateHandShake\n";)
+
+ exec::user::UserToBitHandshake u2b;
+ u2b.set_channel(exec::shared::USER);
+ u2b.set_rpc_version(DRILL_RPC_VERSION);
+ u2b.set_support_listening(true);
+ u2b.set_support_timeout(DrillClientConfig::getHeartbeatFrequency() > 0);
+ u2b.set_sasl_support(exec::user::SASL_PRIVACY);
+
+ // Adding version info
+ exec::user::RpcEndpointInfos* infos = u2b.mutable_client_infos();
+ infos->set_name(DrillClientConfig::getClientName());
+ infos->set_application(DrillClientConfig::getApplicationName());
+ infos->set_version(DRILL_VERSION_STRING);
+ infos->set_majorversion(DRILL_VERSION_MAJOR);
+ infos->set_minorversion(DRILL_VERSION_MINOR);
+ infos->set_patchversion(DRILL_VERSION_PATCH);
+
+ if(properties != NULL && properties->size()>0){
+ std::string username;
+ std::string err;
+ if(!properties->validate(err)){
+ DRILL_MT_LOG(DRILL_LOG(LOG_INFO) << "Invalid user input:" << err << std::endl;)
+ }
+ exec::user::UserProperties* userProperties = u2b.mutable_properties();
+
+ std::map<char,int>::iterator it;
+ for (std::map<std::string,std::string>::const_iterator propIter=properties->begin(); propIter!=properties->end(); ++propIter){
+ std::string currKey=propIter->first;
+ std::string currVal=propIter->second;
+ std::map<std::string,uint32_t>::const_iterator it=DrillUserProperties::USER_PROPERTIES.find(currKey);
+ if(it==DrillUserProperties::USER_PROPERTIES.end()){
+ DRILL_MT_LOG(DRILL_LOG(LOG_INFO) << "Connection property ("<< currKey
+ << ") is unknown" << std::endl;)
+ exec::user::Property* connProp = userProperties->add_properties();
+ connProp->set_key(currKey);
+ connProp->set_value(currVal);
+ continue;
+ }
+ if(IS_BITSET((*it).second,USERPROP_FLAGS_SERVERPROP)){
+ exec::user::Property* connProp = userProperties->add_properties();
+ connProp->set_key(currKey);
+ connProp->set_value(currVal);
+ //Username(but not the password) also needs to be set in UserCredentials
+ if(IS_BITSET((*it).second,USERPROP_FLAGS_USERNAME)){
+ exec::shared::UserCredentials* creds = u2b.mutable_credentials();
+ username=currVal;
+ creds->set_user_name(username);
+ //u2b.set_credentials(&creds);
+ }
+ if(IS_BITSET((*it).second,USERPROP_FLAGS_PASSWORD)){
+ DRILL_MT_LOG(DRILL_LOG(LOG_INFO) << currKey << ": ********** " << std::endl;)
+ }else{
+ DRILL_MT_LOG(DRILL_LOG(LOG_INFO) << currKey << ":" << currVal << std::endl;)
+ }
+ }// Server properties
+ }
+ }
+
+ {
+ boost::lock_guard<boost::mutex> lock(this->m_dcMutex);
+ uint64_t coordId = ++m_coordinationId;
+
+ rpc::OutBoundRpcMessage out_msg(exec::rpc::REQUEST, exec::user::HANDSHAKE, coordId, &u2b);
+ sendSyncCommon(out_msg);
+ DRILL_MT_LOG(DRILL_LOG(LOG_TRACE) << "Sent handshake request message. Coordination id: " << coordId << "\n";)
+ }
+
+ connectionStatus_t ret = recvHandshake();
+ if(ret!=CONN_SUCCESS){
+ return ret;
+ }
+
+ switch(this->m_handshakeStatus) {
+ case exec::user::SUCCESS:
+ // reset io_service after handshake is validated before running queries
+ m_pChannel->getIOService().reset();
+ return CONN_SUCCESS;
+ case exec::user::RPC_VERSION_MISMATCH:
+ DRILL_MT_LOG(DRILL_LOG(LOG_TRACE) << "Invalid rpc version. Expected "
+ << DRILL_RPC_VERSION << ", actual "<< 0 << "." << std::endl;)
+ return handleConnError(CONN_BAD_RPC_VER, getMessage(ERR_CONN_BAD_RPC_VER, DRILL_RPC_VERSION,
+ 0,
+ this->m_handshakeErrorId.c_str(),
+ this->m_handshakeErrorMsg.c_str()));
+ case exec::user::AUTH_FAILED:
+ DRILL_MT_LOG(DRILL_LOG(LOG_TRACE) << "Authentication failed." << std::endl;)
+ return handleConnError(CONN_AUTH_FAILED, getMessage(ERR_CONN_AUTHFAIL,
+ this->m_handshakeErrorId.c_str(),
+ this->m_handshakeErrorMsg.c_str()));
+ case exec::user::UNKNOWN_FAILURE:
+ DRILL_MT_LOG(DRILL_LOG(LOG_TRACE) << "Unknown error during handshake." << std::endl;)
+ return handleConnError(CONN_HANDSHAKE_FAILED, getMessage(ERR_CONN_UNKNOWN_ERR,
+ this->m_handshakeErrorId.c_str(),
+ this->m_handshakeErrorMsg.c_str()));
+ case exec::user::AUTH_REQUIRED:
+ DRILL_MT_LOG(DRILL_LOG(LOG_TRACE) << "Server requires SASL authentication." << std::endl;)
+ return handleConnError(CONN_HANDSHAKE_FAILED, getMessage(ERR_CONN_UNKNOWN_ERR,
+ this->m_handshakeErrorId.c_str(),
+ this->m_handshakeErrorMsg.c_str()));
+ default:
+ DRILL_MT_LOG(DRILL_LOG(LOG_TRACE) << "Unknown return status." << std::endl;)
+ return handleConnError(CONN_HANDSHAKE_FAILED, getMessage(ERR_CONN_UNKNOWN_ERR,
+ this->m_handshakeErrorId.c_str(),
+ this->m_handshakeErrorMsg.c_str()));
+ }
+ }
+
+ DrillClientError* m_pError;
+ private:
+ Channel* m_pChannel;
+ int32_t m_coordinationId;
+ std::string m_handshakeErrorId;
+ std::string m_handshakeErrorMsg;
+ exec::user::HandshakeStatus m_handshakeStatus;
+ DataBuf m_wbuf;
+ ByteBuf_t m_rbuf;
+ boost::mutex m_dcMutex;
+
+
+
+
+};
+
+} // namespace Drill
+
+using namespace Drill;
+
+int main(int argc, char* argv[]){
+ Channel *pChannel = NULL;
+ ChannelContext *pChannelContext = NULL;
+ std::string connectStr = "zk=localhost:2181/drill/drillbits1";
+ //std::string connectStr = "drillbit=localhost:31090";
+ channelType_t type;
+
+ bool isSSL = argc==2 && !(strcmp(argv[1], "ssl"));
+ type = CHANNEL_TYPE_SOCKET;
+ if(isSSL){
+ type = CHANNEL_TYPE_SSLSTREAM;
+ }
+ Drill::DrillUserProperties props;
+ props.setProperty(USERPROP_USERNAME, "admin");
+ props.setProperty(USERPROP_PASSWORD, "admin");
+ props.setProperty(USERPROP_CERTFILEPATH, "../../../test/ssl/drillTestCert.pem");
+
+ pChannelContext = ChannelContextFactory::getChannelContext(type, &props);
+
+ pChannel = ChannelFactory::getChannel(type, connectStr.c_str());
+ if(pChannel != NULL){
+ connectionStatus_t connStat;
+ connStat = pChannel->init(pChannelContext);
+ if(connStat != CONN_SUCCESS){
+ std::cout << "Init Failed." << std::endl;
+ return -1;
+ }
+ connStat = pChannel->connect();
+ if(connStat != CONN_SUCCESS){
+ std::cout << "Connect Failed." << std::endl;
+ std::cout << pChannel->getError()->msg << std::endl;
+ return -1;
+ }
+ } else{
+ std::cout << "Channel creation failed." << std::endl;
+ return -1;
+ }
+ std::cout << "Connected." << std::endl;
+ std::cout << "Starting Drill handshake" << std::endl;
+
+
+ DrillTestClient client(pChannel);
+
+ connectionStatus_t stat = client.validateHandshake(&props);
+ if(stat == CONN_SUCCESS){
+ std::cout << "Handshake validated." << std::endl;
+ } else{
+ if(client.m_pError != NULL){
+ std::cout << "Handshake failed: " << client.m_pError->msg << ". " << std::endl;
+ } else{
+ std::cout << "Handshake failed with unknown error" << ". " << std::endl;
+ }
+ }
+
+ return 0;
+
+}
+