aboutsummaryrefslogtreecommitdiff
path: root/contrib/native/client/src/clientlib/drillClientImpl.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'contrib/native/client/src/clientlib/drillClientImpl.cpp')
-rw-r--r--contrib/native/client/src/clientlib/drillClientImpl.cpp201
1 files changed, 170 insertions, 31 deletions
diff --git a/contrib/native/client/src/clientlib/drillClientImpl.cpp b/contrib/native/client/src/clientlib/drillClientImpl.cpp
index 05171e5c8..4486068e8 100644
--- a/contrib/native/client/src/clientlib/drillClientImpl.cpp
+++ b/contrib/native/client/src/clientlib/drillClientImpl.cpp
@@ -20,6 +20,7 @@
#include "drill/common.hpp"
#include <queue>
#include <string>
+#include <boost/algorithm/string.hpp>
#include <boost/asio.hpp>
#include <boost/assign.hpp>
#include <boost/bind.hpp>
@@ -43,6 +44,7 @@
#include "GeneralRPC.pb.h"
#include "UserBitShared.pb.h"
#include "zookeeperClient.hpp"
+#include "saslAuthenticatorImpl.hpp"
namespace Drill{
@@ -58,7 +60,7 @@ static std::string debugPrintQid(const exec::shared::QueryId& qid){
return std::string("[")+boost::lexical_cast<std::string>(qid.part1()) +std::string(":") + boost::lexical_cast<std::string>(qid.part2())+std::string("] ");
}
-connectionStatus_t DrillClientImpl::connect(const char* connStr){
+connectionStatus_t DrillClientImpl::connect(const char* connStr, DrillUserProperties* props){
std::string pathToDrill, protocol, hostPortStr;
std::string host;
std::string port;
@@ -103,6 +105,15 @@ connectionStatus_t DrillClientImpl::connect(const char* connStr){
return handleConnError(CONN_INVALID_INPUT, getMessage(ERR_CONN_UNKPROTO, protocol.c_str()));
}
DRILL_MT_LOG(DRILL_LOG(LOG_TRACE) << "Connecting to endpoint: " << host << ":" << port << std::endl;)
+ std::string serviceHost;
+ for (size_t i = 0; i < props->size(); i++) {
+ if (props->keyAt(i) == USERPROP_SERVICE_HOST) {
+ serviceHost = props->valueAt(i);
+ }
+ }
+ if (serviceHost.empty()) {
+ props->setProperty(USERPROP_SERVICE_HOST, host);
+ }
connectionStatus_t ret = this->connect(host.c_str(), port.c_str());
return ret;
}
@@ -308,6 +319,11 @@ void DrillClientImpl::handleHandshake(ByteBuf_t _buf,
this->m_handshakeErrorId=b2u.errorid();
this->m_handshakeErrorMsg=b2u.errormessage();
this->m_serverInfos = b2u.server_infos();
+ for (int i=0; i<b2u.authenticationmechanisms_size(); i++) {
+ std::string mechanism = b2u.authenticationmechanisms(i);
+ boost::algorithm::to_lower(mechanism);
+ this->m_serverAuthMechanisms.push_back(mechanism);
+ }
}else{
// boost error
@@ -348,6 +364,7 @@ connectionStatus_t DrillClientImpl::validateHandshake(DrillUserProperties* prope
u2b.set_rpc_version(DRILL_RPC_VERSION);
u2b.set_support_listening(true);
u2b.set_support_timeout(DrillClientConfig::getHeartbeatFrequency() > 0);
+ u2b.set_sasl_support(exec::user::SASL_AUTH);
// Adding version info
exec::user::RpcEndpointInfos* infos = u2b.mutable_client_infos();
@@ -412,37 +429,155 @@ connectionStatus_t DrillClientImpl::validateHandshake(DrillUserProperties* prope
if(ret!=CONN_SUCCESS){
return ret;
}
- if(this->m_handshakeStatus != exec::user::SUCCESS){
- switch(this->m_handshakeStatus){
- case exec::user::RPC_VERSION_MISMATCH:
- DRILL_MT_LOG(DRILL_LOG(LOG_TRACE) << "Invalid rpc version. Expected "
- << DRILL_RPC_VERSION << ", actual "<< m_handshakeVersion << "." << std::endl;)
- return handleConnError(CONN_BAD_RPC_VER,
- getMessage(ERR_CONN_BAD_RPC_VER, DRILL_RPC_VERSION,
- m_handshakeVersion,
- this->m_handshakeErrorId.c_str(),
- this->m_handshakeErrorMsg.c_str()));
- case exec::user::AUTH_FAILED:
- DRILL_MT_LOG(DRILL_LOG(LOG_TRACE) << "Authentication failed." << std::endl;)
- return handleConnError(CONN_AUTH_FAILED,
- getMessage(ERR_CONN_AUTHFAIL,
- this->m_handshakeErrorId.c_str(),
- this->m_handshakeErrorMsg.c_str()));
- case exec::user::UNKNOWN_FAILURE:
- DRILL_MT_LOG(DRILL_LOG(LOG_TRACE) << "Unknown error during handshake." << std::endl;)
- return handleConnError(CONN_HANDSHAKE_FAILED,
- getMessage(ERR_CONN_UNKNOWN_ERR,
- this->m_handshakeErrorId.c_str(),
- this->m_handshakeErrorMsg.c_str()));
- default:
- break;
+
+ switch(this->m_handshakeStatus) {
+ case exec::user::SUCCESS:
+ // reset io_service after handshake is validated before running queries
+ m_io_service.reset();
+ return CONN_SUCCESS;
+ case exec::user::RPC_VERSION_MISMATCH:
+ DRILL_MT_LOG(DRILL_LOG(LOG_TRACE) << "Invalid rpc version. Expected "
+ << DRILL_RPC_VERSION << ", actual "<< m_handshakeVersion << "." << std::endl;)
+ return handleConnError(CONN_BAD_RPC_VER, getMessage(ERR_CONN_BAD_RPC_VER, DRILL_RPC_VERSION,
+ m_handshakeVersion,
+ this->m_handshakeErrorId.c_str(),
+ this->m_handshakeErrorMsg.c_str()));
+ case exec::user::AUTH_FAILED:
+ DRILL_MT_LOG(DRILL_LOG(LOG_TRACE) << "Authentication failed." << std::endl;)
+ return handleConnError(CONN_AUTH_FAILED, getMessage(ERR_CONN_AUTHFAIL,
+ this->m_handshakeErrorId.c_str(),
+ this->m_handshakeErrorMsg.c_str()));
+ case exec::user::UNKNOWN_FAILURE:
+ DRILL_MT_LOG(DRILL_LOG(LOG_TRACE) << "Unknown error during handshake." << std::endl;)
+ return handleConnError(CONN_HANDSHAKE_FAILED, getMessage(ERR_CONN_UNKNOWN_ERR,
+ this->m_handshakeErrorId.c_str(),
+ this->m_handshakeErrorMsg.c_str()));
+ case exec::user::AUTH_REQUIRED:
+ DRILL_MT_LOG(DRILL_LOG(LOG_TRACE) << "Server requires SASL authentication." << std::endl;)
+ return handleAuthentication(properties);
+ default:
+ DRILL_MT_LOG(DRILL_LOG(LOG_TRACE) << "Unknown return status." << std::endl;)
+ return handleConnError(CONN_HANDSHAKE_FAILED, getMessage(ERR_CONN_UNKNOWN_ERR,
+ this->m_handshakeErrorId.c_str(),
+ this->m_handshakeErrorMsg.c_str()));
+ }
+}
+
+connectionStatus_t DrillClientImpl::handleAuthentication(const DrillUserProperties *userProperties) {
+ try {
+ m_saslAuthenticator = new SaslAuthenticatorImpl(userProperties);
+ } catch (std::runtime_error& e) {
+ return handleConnError(CONN_AUTH_FAILED, e.what());
+ }
+
+ startMessageListener();
+ initiateAuthentication();
+
+ { // block until SASL exchange is complete
+ boost::mutex::scoped_lock lock(m_saslMutex);
+ while (!m_saslDone) {
+ m_saslCv.wait(lock);
}
}
- // reset io_service after handshake is validated before running queries
- m_io_service.reset();
- return CONN_SUCCESS;
+
+ if (SASL_OK == m_saslResultCode) {
+ DRILL_MT_LOG(DRILL_LOG(LOG_DEBUG) << "DrillClientImpl::handleAuthentication: Successfully authenticated!"
+ << std::endl;)
+
+ // in future, negotiated security layers are known here..
+
+ m_io_service.reset();
+ return CONN_SUCCESS;
+ } else {
+ DRILL_MT_LOG(DRILL_LOG(LOG_DEBUG) << "DrillClientImpl::handleAuthentication: Authentication failed: "
+ << m_saslResultCode << std::endl;)
+ // shuts down socket as well
+ return handleConnError(CONN_AUTH_FAILED, "Authentication failed. Check connection parameters?");
+ }
}
+void DrillClientImpl::initiateAuthentication() {
+ exec::shared::SaslMessage response;
+ m_saslResultCode = m_saslAuthenticator->init(m_serverAuthMechanisms, response);
+
+
+ switch (m_saslResultCode) {
+ case SASL_CONTINUE:
+ case SASL_OK: {
+ DRILL_MT_LOG(DRILL_LOG(LOG_TRACE) << "DrillClientImpl::initiateAuthentication: initiated. " << std::endl;)
+ boost::lock_guard<boost::mutex> prLock(m_prMutex);
+ sendSaslResponse(response); // the challenge returned by server is handled by processSaslChallenge
+ break;
+ }
+ case SASL_NOMECH:
+ DRILL_MT_LOG(DRILL_LOG(LOG_DEBUG) << "DrillClientImpl::initiateAuthentication: "
+ << "Mechanism is not supported (by server/client)." << std::endl;)
+ default:
+ DRILL_MT_LOG(DRILL_LOG(LOG_DEBUG) << "DrillClientImpl::initiateAuthentication: "
+ << "Failed to initiate authentication." << std::endl;)
+ finishAuthentication();
+ break;
+ }
+}
+
+void DrillClientImpl::sendSaslResponse(const exec::shared::SaslMessage& response) {
+ boost::lock_guard<boost::mutex> lock(m_dcMutex);
+ const int32_t coordId = getNextCoordinationId();
+ rpc::OutBoundRpcMessage msg(exec::rpc::REQUEST, exec::user::SASL_MESSAGE, coordId, &response);
+ sendSync(msg);
+ if (m_pendingRequests++ == 0) {
+ getNextResult();
+ }
+}
+
+void DrillClientImpl::processSaslChallenge(AllocatedBufferPtr allocatedBuffer, const rpc::InBoundRpcMessage& msg) {
+ boost::shared_ptr<AllocatedBuffer> deallocationGuard(allocatedBuffer);
+ assert(m_saslAuthenticator != NULL);
+
+ // parse challenge
+ exec::shared::SaslMessage challenge;
+ const bool parseStatus = challenge.ParseFromArray(msg.m_pbody.data(), msg.m_pbody.size());
+ if (!parseStatus) {
+ DRILL_MT_LOG(DRILL_LOG(LOG_TRACE) << "Failed to parse challenge." << std::endl;)
+ m_saslResultCode = SASL_FAIL;
+ finishAuthentication();
+ m_pendingRequests--;
+ return;
+ }
+
+ // respond accordingly
+ exec::shared::SaslMessage response;
+ DRILL_MT_LOG(DRILL_LOG(LOG_TRACE) << "DrillClientImpl::processSaslChallenge: status: "
+ << exec::shared::SaslStatus_Name(challenge.status()) << std::endl;)
+ switch (challenge.status()) {
+ case exec::shared::SASL_IN_PROGRESS:
+ m_saslResultCode = m_saslAuthenticator->step(challenge, response);
+ if (m_saslResultCode == SASL_CONTINUE || m_saslResultCode == SASL_OK) {
+ sendSaslResponse(response);
+ } else { // failure
+ finishAuthentication();
+ }
+ break;
+ case exec::shared::SASL_SUCCESS:
+ if (SASL_CONTINUE == m_saslResultCode) { // client may need to evaluate once more
+ m_saslResultCode = m_saslAuthenticator->step(challenge, response);
+ DRILL_MT_LOG(DRILL_LOG(LOG_TRACE) << "SASL succeeded on client? " << m_saslResultCode << std::endl;)
+ }
+ finishAuthentication();
+ break;
+ default:
+ m_saslResultCode = SASL_FAIL;
+ finishAuthentication();
+ break;
+ }
+ m_pendingRequests--;
+}
+
+void DrillClientImpl::finishAuthentication() {
+ boost::mutex::scoped_lock lock(m_saslMutex);
+ m_saslDone = true;
+ m_saslCv.notify_one();
+}
FieldDefPtr DrillClientQueryResult::s_emptyColDefs( new (std::vector<Drill::FieldMetadata*>));
@@ -1369,6 +1504,10 @@ void DrillClientImpl::handleRead(ByteBuf_t _buf,
delete allocatedBuffer;
break;
+ case exec::user::SASL_MESSAGE:
+ processSaslChallenge(allocatedBuffer, msg);
+ break;
+
case exec::user::ACK:
// Cancel requests will result in an ACK sent back.
// Consume silently
@@ -1859,7 +1998,7 @@ void DrillClientPrepareHandle::clearAndDestroy(){
}
}
-connectionStatus_t PooledDrillClientImpl::connect(const char* connStr){
+connectionStatus_t PooledDrillClientImpl::connect(const char* connStr, DrillUserProperties* props){
connectionStatus_t stat = CONN_SUCCESS;
std::string pathToDrill, protocol, hostPortStr;
std::string host;
@@ -2062,7 +2201,7 @@ DrillClientImpl* PooledDrillClientImpl::getOneConnection(){
DrillClientImpl* pDrillClientImpl = NULL;
while(pDrillClientImpl==NULL){
if(m_queriesExecuted == 0){
- // First query ever sent can use the connection already established to authenticate the user
+ // First query ever sent can use the connection already established to handleAuthentication the user
boost::lock_guard<boost::mutex> lock(m_poolMutex);
pDrillClientImpl=m_clientConnections[0];// There should be one connection in the list when the first query is executed
}else if(m_clientConnections.size() == m_maxConcurrentConnections){
@@ -2077,7 +2216,7 @@ DrillClientImpl* PooledDrillClientImpl::getOneConnection(){
int tries=0;
connectionStatus_t ret=CONN_SUCCESS;
while(pDrillClientImpl==NULL && tries++ < 3){
- if((ret=connect(m_connectStr.c_str()))==CONN_SUCCESS){
+ if((ret=connect(m_connectStr.c_str(), m_pUserProperties.get()))==CONN_SUCCESS){
boost::lock_guard<boost::mutex> lock(m_poolMutex);
pDrillClientImpl=m_clientConnections.back();
ret=pDrillClientImpl->validateHandshake(m_pUserProperties.get());