diff options
Diffstat (limited to 'contrib/native/client/src/clientlib/drillClientImpl.cpp')
-rw-r--r-- | contrib/native/client/src/clientlib/drillClientImpl.cpp | 201 |
1 files changed, 170 insertions, 31 deletions
diff --git a/contrib/native/client/src/clientlib/drillClientImpl.cpp b/contrib/native/client/src/clientlib/drillClientImpl.cpp index 05171e5c8..4486068e8 100644 --- a/contrib/native/client/src/clientlib/drillClientImpl.cpp +++ b/contrib/native/client/src/clientlib/drillClientImpl.cpp @@ -20,6 +20,7 @@ #include "drill/common.hpp" #include <queue> #include <string> +#include <boost/algorithm/string.hpp> #include <boost/asio.hpp> #include <boost/assign.hpp> #include <boost/bind.hpp> @@ -43,6 +44,7 @@ #include "GeneralRPC.pb.h" #include "UserBitShared.pb.h" #include "zookeeperClient.hpp" +#include "saslAuthenticatorImpl.hpp" namespace Drill{ @@ -58,7 +60,7 @@ static std::string debugPrintQid(const exec::shared::QueryId& qid){ return std::string("[")+boost::lexical_cast<std::string>(qid.part1()) +std::string(":") + boost::lexical_cast<std::string>(qid.part2())+std::string("] "); } -connectionStatus_t DrillClientImpl::connect(const char* connStr){ +connectionStatus_t DrillClientImpl::connect(const char* connStr, DrillUserProperties* props){ std::string pathToDrill, protocol, hostPortStr; std::string host; std::string port; @@ -103,6 +105,15 @@ connectionStatus_t DrillClientImpl::connect(const char* connStr){ return handleConnError(CONN_INVALID_INPUT, getMessage(ERR_CONN_UNKPROTO, protocol.c_str())); } DRILL_MT_LOG(DRILL_LOG(LOG_TRACE) << "Connecting to endpoint: " << host << ":" << port << std::endl;) + std::string serviceHost; + for (size_t i = 0; i < props->size(); i++) { + if (props->keyAt(i) == USERPROP_SERVICE_HOST) { + serviceHost = props->valueAt(i); + } + } + if (serviceHost.empty()) { + props->setProperty(USERPROP_SERVICE_HOST, host); + } connectionStatus_t ret = this->connect(host.c_str(), port.c_str()); return ret; } @@ -308,6 +319,11 @@ void DrillClientImpl::handleHandshake(ByteBuf_t _buf, this->m_handshakeErrorId=b2u.errorid(); this->m_handshakeErrorMsg=b2u.errormessage(); this->m_serverInfos = b2u.server_infos(); + for (int i=0; i<b2u.authenticationmechanisms_size(); i++) { + std::string mechanism = b2u.authenticationmechanisms(i); + boost::algorithm::to_lower(mechanism); + this->m_serverAuthMechanisms.push_back(mechanism); + } }else{ // boost error @@ -348,6 +364,7 @@ connectionStatus_t DrillClientImpl::validateHandshake(DrillUserProperties* prope u2b.set_rpc_version(DRILL_RPC_VERSION); u2b.set_support_listening(true); u2b.set_support_timeout(DrillClientConfig::getHeartbeatFrequency() > 0); + u2b.set_sasl_support(exec::user::SASL_AUTH); // Adding version info exec::user::RpcEndpointInfos* infos = u2b.mutable_client_infos(); @@ -412,37 +429,155 @@ connectionStatus_t DrillClientImpl::validateHandshake(DrillUserProperties* prope if(ret!=CONN_SUCCESS){ return ret; } - if(this->m_handshakeStatus != exec::user::SUCCESS){ - switch(this->m_handshakeStatus){ - case exec::user::RPC_VERSION_MISMATCH: - DRILL_MT_LOG(DRILL_LOG(LOG_TRACE) << "Invalid rpc version. Expected " - << DRILL_RPC_VERSION << ", actual "<< m_handshakeVersion << "." << std::endl;) - return handleConnError(CONN_BAD_RPC_VER, - getMessage(ERR_CONN_BAD_RPC_VER, DRILL_RPC_VERSION, - m_handshakeVersion, - this->m_handshakeErrorId.c_str(), - this->m_handshakeErrorMsg.c_str())); - case exec::user::AUTH_FAILED: - DRILL_MT_LOG(DRILL_LOG(LOG_TRACE) << "Authentication failed." << std::endl;) - return handleConnError(CONN_AUTH_FAILED, - getMessage(ERR_CONN_AUTHFAIL, - this->m_handshakeErrorId.c_str(), - this->m_handshakeErrorMsg.c_str())); - case exec::user::UNKNOWN_FAILURE: - DRILL_MT_LOG(DRILL_LOG(LOG_TRACE) << "Unknown error during handshake." << std::endl;) - return handleConnError(CONN_HANDSHAKE_FAILED, - getMessage(ERR_CONN_UNKNOWN_ERR, - this->m_handshakeErrorId.c_str(), - this->m_handshakeErrorMsg.c_str())); - default: - break; + + switch(this->m_handshakeStatus) { + case exec::user::SUCCESS: + // reset io_service after handshake is validated before running queries + m_io_service.reset(); + return CONN_SUCCESS; + case exec::user::RPC_VERSION_MISMATCH: + DRILL_MT_LOG(DRILL_LOG(LOG_TRACE) << "Invalid rpc version. Expected " + << DRILL_RPC_VERSION << ", actual "<< m_handshakeVersion << "." << std::endl;) + return handleConnError(CONN_BAD_RPC_VER, getMessage(ERR_CONN_BAD_RPC_VER, DRILL_RPC_VERSION, + m_handshakeVersion, + this->m_handshakeErrorId.c_str(), + this->m_handshakeErrorMsg.c_str())); + case exec::user::AUTH_FAILED: + DRILL_MT_LOG(DRILL_LOG(LOG_TRACE) << "Authentication failed." << std::endl;) + return handleConnError(CONN_AUTH_FAILED, getMessage(ERR_CONN_AUTHFAIL, + this->m_handshakeErrorId.c_str(), + this->m_handshakeErrorMsg.c_str())); + case exec::user::UNKNOWN_FAILURE: + DRILL_MT_LOG(DRILL_LOG(LOG_TRACE) << "Unknown error during handshake." << std::endl;) + return handleConnError(CONN_HANDSHAKE_FAILED, getMessage(ERR_CONN_UNKNOWN_ERR, + this->m_handshakeErrorId.c_str(), + this->m_handshakeErrorMsg.c_str())); + case exec::user::AUTH_REQUIRED: + DRILL_MT_LOG(DRILL_LOG(LOG_TRACE) << "Server requires SASL authentication." << std::endl;) + return handleAuthentication(properties); + default: + DRILL_MT_LOG(DRILL_LOG(LOG_TRACE) << "Unknown return status." << std::endl;) + return handleConnError(CONN_HANDSHAKE_FAILED, getMessage(ERR_CONN_UNKNOWN_ERR, + this->m_handshakeErrorId.c_str(), + this->m_handshakeErrorMsg.c_str())); + } +} + +connectionStatus_t DrillClientImpl::handleAuthentication(const DrillUserProperties *userProperties) { + try { + m_saslAuthenticator = new SaslAuthenticatorImpl(userProperties); + } catch (std::runtime_error& e) { + return handleConnError(CONN_AUTH_FAILED, e.what()); + } + + startMessageListener(); + initiateAuthentication(); + + { // block until SASL exchange is complete + boost::mutex::scoped_lock lock(m_saslMutex); + while (!m_saslDone) { + m_saslCv.wait(lock); } } - // reset io_service after handshake is validated before running queries - m_io_service.reset(); - return CONN_SUCCESS; + + if (SASL_OK == m_saslResultCode) { + DRILL_MT_LOG(DRILL_LOG(LOG_DEBUG) << "DrillClientImpl::handleAuthentication: Successfully authenticated!" + << std::endl;) + + // in future, negotiated security layers are known here.. + + m_io_service.reset(); + return CONN_SUCCESS; + } else { + DRILL_MT_LOG(DRILL_LOG(LOG_DEBUG) << "DrillClientImpl::handleAuthentication: Authentication failed: " + << m_saslResultCode << std::endl;) + // shuts down socket as well + return handleConnError(CONN_AUTH_FAILED, "Authentication failed. Check connection parameters?"); + } } +void DrillClientImpl::initiateAuthentication() { + exec::shared::SaslMessage response; + m_saslResultCode = m_saslAuthenticator->init(m_serverAuthMechanisms, response); + + + switch (m_saslResultCode) { + case SASL_CONTINUE: + case SASL_OK: { + DRILL_MT_LOG(DRILL_LOG(LOG_TRACE) << "DrillClientImpl::initiateAuthentication: initiated. " << std::endl;) + boost::lock_guard<boost::mutex> prLock(m_prMutex); + sendSaslResponse(response); // the challenge returned by server is handled by processSaslChallenge + break; + } + case SASL_NOMECH: + DRILL_MT_LOG(DRILL_LOG(LOG_DEBUG) << "DrillClientImpl::initiateAuthentication: " + << "Mechanism is not supported (by server/client)." << std::endl;) + default: + DRILL_MT_LOG(DRILL_LOG(LOG_DEBUG) << "DrillClientImpl::initiateAuthentication: " + << "Failed to initiate authentication." << std::endl;) + finishAuthentication(); + break; + } +} + +void DrillClientImpl::sendSaslResponse(const exec::shared::SaslMessage& response) { + boost::lock_guard<boost::mutex> lock(m_dcMutex); + const int32_t coordId = getNextCoordinationId(); + rpc::OutBoundRpcMessage msg(exec::rpc::REQUEST, exec::user::SASL_MESSAGE, coordId, &response); + sendSync(msg); + if (m_pendingRequests++ == 0) { + getNextResult(); + } +} + +void DrillClientImpl::processSaslChallenge(AllocatedBufferPtr allocatedBuffer, const rpc::InBoundRpcMessage& msg) { + boost::shared_ptr<AllocatedBuffer> deallocationGuard(allocatedBuffer); + assert(m_saslAuthenticator != NULL); + + // parse challenge + exec::shared::SaslMessage challenge; + const bool parseStatus = challenge.ParseFromArray(msg.m_pbody.data(), msg.m_pbody.size()); + if (!parseStatus) { + DRILL_MT_LOG(DRILL_LOG(LOG_TRACE) << "Failed to parse challenge." << std::endl;) + m_saslResultCode = SASL_FAIL; + finishAuthentication(); + m_pendingRequests--; + return; + } + + // respond accordingly + exec::shared::SaslMessage response; + DRILL_MT_LOG(DRILL_LOG(LOG_TRACE) << "DrillClientImpl::processSaslChallenge: status: " + << exec::shared::SaslStatus_Name(challenge.status()) << std::endl;) + switch (challenge.status()) { + case exec::shared::SASL_IN_PROGRESS: + m_saslResultCode = m_saslAuthenticator->step(challenge, response); + if (m_saslResultCode == SASL_CONTINUE || m_saslResultCode == SASL_OK) { + sendSaslResponse(response); + } else { // failure + finishAuthentication(); + } + break; + case exec::shared::SASL_SUCCESS: + if (SASL_CONTINUE == m_saslResultCode) { // client may need to evaluate once more + m_saslResultCode = m_saslAuthenticator->step(challenge, response); + DRILL_MT_LOG(DRILL_LOG(LOG_TRACE) << "SASL succeeded on client? " << m_saslResultCode << std::endl;) + } + finishAuthentication(); + break; + default: + m_saslResultCode = SASL_FAIL; + finishAuthentication(); + break; + } + m_pendingRequests--; +} + +void DrillClientImpl::finishAuthentication() { + boost::mutex::scoped_lock lock(m_saslMutex); + m_saslDone = true; + m_saslCv.notify_one(); +} FieldDefPtr DrillClientQueryResult::s_emptyColDefs( new (std::vector<Drill::FieldMetadata*>)); @@ -1369,6 +1504,10 @@ void DrillClientImpl::handleRead(ByteBuf_t _buf, delete allocatedBuffer; break; + case exec::user::SASL_MESSAGE: + processSaslChallenge(allocatedBuffer, msg); + break; + case exec::user::ACK: // Cancel requests will result in an ACK sent back. // Consume silently @@ -1859,7 +1998,7 @@ void DrillClientPrepareHandle::clearAndDestroy(){ } } -connectionStatus_t PooledDrillClientImpl::connect(const char* connStr){ +connectionStatus_t PooledDrillClientImpl::connect(const char* connStr, DrillUserProperties* props){ connectionStatus_t stat = CONN_SUCCESS; std::string pathToDrill, protocol, hostPortStr; std::string host; @@ -2062,7 +2201,7 @@ DrillClientImpl* PooledDrillClientImpl::getOneConnection(){ DrillClientImpl* pDrillClientImpl = NULL; while(pDrillClientImpl==NULL){ if(m_queriesExecuted == 0){ - // First query ever sent can use the connection already established to authenticate the user + // First query ever sent can use the connection already established to handleAuthentication the user boost::lock_guard<boost::mutex> lock(m_poolMutex); pDrillClientImpl=m_clientConnections[0];// There should be one connection in the list when the first query is executed }else if(m_clientConnections.size() == m_maxConcurrentConnections){ @@ -2077,7 +2216,7 @@ DrillClientImpl* PooledDrillClientImpl::getOneConnection(){ int tries=0; connectionStatus_t ret=CONN_SUCCESS; while(pDrillClientImpl==NULL && tries++ < 3){ - if((ret=connect(m_connectStr.c_str()))==CONN_SUCCESS){ + if((ret=connect(m_connectStr.c_str(), m_pUserProperties.get()))==CONN_SUCCESS){ boost::lock_guard<boost::mutex> lock(m_poolMutex); pDrillClientImpl=m_clientConnections.back(); ret=pDrillClientImpl->validateHandshake(m_pUserProperties.get()); |