aboutsummaryrefslogtreecommitdiff
path: root/contrib
diff options
context:
space:
mode:
authorSudheesh Katkam <sudheesh@apache.org>2017-02-23 18:47:04 -0800
committerSudheesh Katkam <sudheesh@apache.org>2017-02-24 20:01:18 -0800
commit3c3b08c5abbd4e9e11dadb8e97367db0c13b3243 (patch)
treed07cd92c86e351f0d8ddde71a3de9d8cdbf6b9b4 /contrib
parentf9f99e08084c0016f5a3edc74b7b86b304436284 (diff)
DRILL-4280: CORE (user to bit authentication, C++)
closes #578
Diffstat (limited to 'contrib')
-rw-r--r--contrib/native/client/CMakeLists.txt4
-rw-r--r--contrib/native/client/cmakeModules/FindSASL.cmake49
-rw-r--r--contrib/native/client/example/querySubmitter.cpp9
-rw-r--r--contrib/native/client/src/clientlib/CMakeLists.txt7
-rw-r--r--contrib/native/client/src/clientlib/drillClient.cpp16
-rw-r--r--contrib/native/client/src/clientlib/drillClientImpl.cpp201
-rw-r--r--contrib/native/client/src/clientlib/drillClientImpl.hpp27
-rw-r--r--contrib/native/client/src/clientlib/saslAuthenticatorImpl.cpp207
-rw-r--r--contrib/native/client/src/clientlib/saslAuthenticatorImpl.hpp65
-rw-r--r--contrib/native/client/src/include/drill/common.hpp3
-rw-r--r--contrib/native/client/src/include/drill/drillClient.hpp4
11 files changed, 552 insertions, 40 deletions
diff --git a/contrib/native/client/CMakeLists.txt b/contrib/native/client/CMakeLists.txt
index 65e3b857b..7b54b00fd 100644
--- a/contrib/native/client/CMakeLists.txt
+++ b/contrib/native/client/CMakeLists.txt
@@ -125,6 +125,8 @@ include_directories(${PROTOBUF_INCLUDE_DIR})
#Find Zookeeper
find_package(Zookeeper REQUIRED )
+# Find Cyrus SASL
+find_package(SASL REQUIRED)
# Generated sources
configure_file(
@@ -152,6 +154,8 @@ add_subdirectory("${CMAKE_SOURCE_DIR}/src/clientlib/y2038")
add_subdirectory("${CMAKE_SOURCE_DIR}/src/clientlib")
include_directories(${CMAKE_SOURCE_DIR}/src/include ${Zookeeper_INCLUDE_DIRS})
+include_directories(${SASL_INCLUDE_DIRS})
+
add_subdirectory("${CMAKE_SOURCE_DIR}/src/test")
# add a DEBUG preprocessor macro
diff --git a/contrib/native/client/cmakeModules/FindSASL.cmake b/contrib/native/client/cmakeModules/FindSASL.cmake
new file mode 100644
index 000000000..35d91c7f5
--- /dev/null
+++ b/contrib/native/client/cmakeModules/FindSASL.cmake
@@ -0,0 +1,49 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+# - Try to find Cyrus SASL
+
+if (MSVC)
+ if("${SASL_HOME}_" MATCHES "^_$")
+ message(" ")
+ message("- Please set the cache variable SASL_HOME to point to the directory with the Cyrus SASL source.")
+ message("- CMAKE will look for Cyrus SASL include files in $SASL_HOME/include or $SASL_HOME/win32/include.")
+ message("- CMAKE will look for Cyrus SASL library files in $SASL_HOME/lib.")
+ else()
+ FILE(TO_CMAKE_PATH ${SASL_HOME} SASL_HomePath)
+ set(SASL_LIB_PATHS ${SASL_HomePath}/lib)
+
+ find_path(SASL_INCLUDE_DIR sasl.h ${SASL_HomePath}/include ${SASL_HomePath}/win32/include)
+ find_library(SASL_LIBRARY NAMES "libsasl2${CMAKE_SHARED_LIBRARY_SUFFIX}" PATHS ${SASL_LIB_PATHS})
+ endif()
+else()
+ set(SASL_LIB_PATHS /usr/local/lib /opt/local/lib)
+ find_path(SASL_INCLUDE_DIR sasl/sasl.h /usr/local/include /opt/local/include)
+ find_library(SASL_LIBRARY NAMES "libsasl2${CMAKE_SHARED_LIBRARY_SUFFIX}" PATHS ${SASL_LIB_PATHS})
+endif()
+
+
+set(SASL_LIBRARIES ${SASL_LIBRARY})
+set(SASL_INCLUDE_DIRS ${SASL_INCLUDE_DIR})
+
+include(FindPackageHandleStandardArgs)
+# handle the QUIETLY and REQUIRED arguments and set SASL_FOUND to TRUE if all listed variables are valid
+find_package_handle_standard_args(SASL DEFAULT_MSG
+ SASL_LIBRARY SASL_INCLUDE_DIR)
+
+mark_as_advanced(SASL_INCLUDE_DIR SASL_LIBRARY)
diff --git a/contrib/native/client/example/querySubmitter.cpp b/contrib/native/client/example/querySubmitter.cpp
index 60f7b8a95..5b85a3e5d 100644
--- a/contrib/native/client/example/querySubmitter.cpp
+++ b/contrib/native/client/example/querySubmitter.cpp
@@ -23,7 +23,7 @@
#include <boost/thread.hpp>
#include "drill/drillc.hpp"
-int nOptions=13;
+int nOptions=15;
struct Option{
char name[32];
@@ -43,7 +43,8 @@ struct Option{
{"queryTimeout", "Query timeout (second).", false},
{"heartbeatFrequency", "Heartbeat frequency (second). Disabled if set to 0.", false},
{"user", "Username", false},
- {"password", "Password", false}
+ {"password", "Password", false},
+ {"saslPluginPath", "Path to where SASL plugins are installed", false}
};
std::map<std::string, std::string> qsOptionValues;
@@ -286,6 +287,7 @@ int main(int argc, char* argv[]) {
std::string heartbeatFrequency=qsOptionValues["heartbeatFrequency"];
std::string user=qsOptionValues["user"];
std::string password=qsOptionValues["password"];
+ std::string saslPluginPath=qsOptionValues["saslPluginPath"];
Drill::QueryType type;
@@ -348,6 +350,9 @@ int main(int argc, char* argv[]) {
if(!heartbeatFrequency.empty()) {
Drill::DrillClientConfig::setHeartbeatFrequency(atoi(heartbeatFrequency.c_str()));
}
+ if (!saslPluginPath.empty()){
+ Drill::DrillClientConfig::setSaslPluginPath(saslPluginPath.c_str());
+ }
Drill::DrillUserProperties props;
if(schema.length()>0){
diff --git a/contrib/native/client/src/clientlib/CMakeLists.txt b/contrib/native/client/src/clientlib/CMakeLists.txt
index 68326e2fc..343bb4d8c 100644
--- a/contrib/native/client/src/clientlib/CMakeLists.txt
+++ b/contrib/native/client/src/clientlib/CMakeLists.txt
@@ -29,12 +29,13 @@ set (CLIENTLIB_SRC_FILES
${CMAKE_CURRENT_SOURCE_DIR}/errmsgs.cpp
${CMAKE_CURRENT_SOURCE_DIR}/logger.cpp
${CMAKE_CURRENT_SOURCE_DIR}/utils.cpp
- ${CMAKE_CURRENT_SOURCE_DIR}/zookeeperClient.cpp
- )
+ ${CMAKE_CURRENT_SOURCE_DIR}/saslAuthenticatorImpl.cpp
+ ${CMAKE_CURRENT_SOURCE_DIR}/zookeeperClient.cpp)
include_directories(${CMAKE_CURRENT_SOURCE_DIR} ${CMAKE_CURRENT_SOURCE_DIR}/../include )
include_directories(${PROTOBUF_INCLUDE_DIR})
include_directories(${Zookeeper_INCLUDE_DIRS})
+include_directories(${SASL_INCLUDE_DIRS})
link_directories(/usr/local/lib)
@@ -49,4 +50,4 @@ if(MSVC)
endif()
add_library(drillClient SHARED ${CLIENTLIB_SRC_FILES} )
-target_link_libraries(drillClient ${Boost_LIBRARIES} ${PROTOBUF_LIBRARY} ${Zookeeper_LIBRARIES} protomsgs y2038)
+target_link_libraries(drillClient ${Boost_LIBRARIES} ${PROTOBUF_LIBRARY} ${Zookeeper_LIBRARIES} ${SASL_LIBRARIES} protomsgs y2038)
diff --git a/contrib/native/client/src/clientlib/drillClient.cpp b/contrib/native/client/src/clientlib/drillClient.cpp
index f97a25c5f..70002721c 100644
--- a/contrib/native/client/src/clientlib/drillClient.cpp
+++ b/contrib/native/client/src/clientlib/drillClient.cpp
@@ -47,6 +47,7 @@ DrillClientInitializer::~DrillClientInitializer(){
// Initialize static member of DrillClientConfig
logLevel_t DrillClientConfig::s_logLevel=LOG_ERROR;
+const char* DrillClientConfig::s_saslPluginPath = NULL;
uint64_t DrillClientConfig::s_bufferLimit=MAX_MEM_ALLOC_SIZE;
int32_t DrillClientConfig::s_socketTimeout=0;
int32_t DrillClientConfig::s_handshakeTimeout=5;
@@ -77,6 +78,16 @@ void DrillClientConfig::setLogLevel(logLevel_t l){
//boost::log::core::get()->set_filter(boost::log::trivial::severity >= s_logLevel);
}
+void DrillClientConfig::setSaslPluginPath(const char *path){
+ boost::lock_guard<boost::mutex> configLock(DrillClientConfig::s_mutex);
+ s_saslPluginPath = path;
+}
+
+const char* DrillClientConfig::getSaslPluginPath(){
+ boost::lock_guard<boost::mutex> configLock(DrillClientConfig::s_mutex);
+ return s_saslPluginPath;
+}
+
void DrillClientConfig::setBufferLimit(uint64_t l){
boost::lock_guard<boost::mutex> configLock(DrillClientConfig::s_mutex);
s_bufferLimit=l;
@@ -164,6 +175,9 @@ const std::map<std::string, uint32_t> DrillUserProperties::USER_PROPERTIES=boos
( USERPROP_PASSWORD, USERPROP_FLAGS_SERVERPROP|USERPROP_FLAGS_PASSWORD)
( USERPROP_SCHEMA, USERPROP_FLAGS_SERVERPROP|USERPROP_FLAGS_STRING)
( USERPROP_IMPERSONATION_TARGET, USERPROP_FLAGS_SERVERPROP|USERPROP_FLAGS_STRING)
+ ( USERPROP_AUTH_MECHANISM, USERPROP_FLAGS_STRING)
+ ( USERPROP_SERVICE_NAME, USERPROP_FLAGS_STRING)
+ ( USERPROP_SERVICE_HOST, USERPROP_FLAGS_STRING)
( USERPROP_USESSL, USERPROP_FLAGS_BOOLEAN|USERPROP_FLAGS_SSLPROP)
( USERPROP_FILEPATH, USERPROP_FLAGS_STRING|USERPROP_FLAGS_SSLPROP|USERPROP_FLAGS_FILEPATH)
( USERPROP_FILENAME, USERPROP_FLAGS_STRING|USERPROP_FLAGS_SSLPROP|USERPROP_FLAGS_FILENAME)
@@ -365,7 +379,7 @@ connectionStatus_t DrillClient::connect(const char* connectStr, const char* defa
connectionStatus_t DrillClient::connect(const char* connectStr, DrillUserProperties* properties){
connectionStatus_t ret=CONN_SUCCESS;
- ret=this->m_pImpl->connect(connectStr);
+ ret=this->m_pImpl->connect(connectStr, properties);
if(ret==CONN_SUCCESS){
ret=this->m_pImpl->validateHandshake(properties);
}
diff --git a/contrib/native/client/src/clientlib/drillClientImpl.cpp b/contrib/native/client/src/clientlib/drillClientImpl.cpp
index 05171e5c8..4486068e8 100644
--- a/contrib/native/client/src/clientlib/drillClientImpl.cpp
+++ b/contrib/native/client/src/clientlib/drillClientImpl.cpp
@@ -20,6 +20,7 @@
#include "drill/common.hpp"
#include <queue>
#include <string>
+#include <boost/algorithm/string.hpp>
#include <boost/asio.hpp>
#include <boost/assign.hpp>
#include <boost/bind.hpp>
@@ -43,6 +44,7 @@
#include "GeneralRPC.pb.h"
#include "UserBitShared.pb.h"
#include "zookeeperClient.hpp"
+#include "saslAuthenticatorImpl.hpp"
namespace Drill{
@@ -58,7 +60,7 @@ static std::string debugPrintQid(const exec::shared::QueryId& qid){
return std::string("[")+boost::lexical_cast<std::string>(qid.part1()) +std::string(":") + boost::lexical_cast<std::string>(qid.part2())+std::string("] ");
}
-connectionStatus_t DrillClientImpl::connect(const char* connStr){
+connectionStatus_t DrillClientImpl::connect(const char* connStr, DrillUserProperties* props){
std::string pathToDrill, protocol, hostPortStr;
std::string host;
std::string port;
@@ -103,6 +105,15 @@ connectionStatus_t DrillClientImpl::connect(const char* connStr){
return handleConnError(CONN_INVALID_INPUT, getMessage(ERR_CONN_UNKPROTO, protocol.c_str()));
}
DRILL_MT_LOG(DRILL_LOG(LOG_TRACE) << "Connecting to endpoint: " << host << ":" << port << std::endl;)
+ std::string serviceHost;
+ for (size_t i = 0; i < props->size(); i++) {
+ if (props->keyAt(i) == USERPROP_SERVICE_HOST) {
+ serviceHost = props->valueAt(i);
+ }
+ }
+ if (serviceHost.empty()) {
+ props->setProperty(USERPROP_SERVICE_HOST, host);
+ }
connectionStatus_t ret = this->connect(host.c_str(), port.c_str());
return ret;
}
@@ -308,6 +319,11 @@ void DrillClientImpl::handleHandshake(ByteBuf_t _buf,
this->m_handshakeErrorId=b2u.errorid();
this->m_handshakeErrorMsg=b2u.errormessage();
this->m_serverInfos = b2u.server_infos();
+ for (int i=0; i<b2u.authenticationmechanisms_size(); i++) {
+ std::string mechanism = b2u.authenticationmechanisms(i);
+ boost::algorithm::to_lower(mechanism);
+ this->m_serverAuthMechanisms.push_back(mechanism);
+ }
}else{
// boost error
@@ -348,6 +364,7 @@ connectionStatus_t DrillClientImpl::validateHandshake(DrillUserProperties* prope
u2b.set_rpc_version(DRILL_RPC_VERSION);
u2b.set_support_listening(true);
u2b.set_support_timeout(DrillClientConfig::getHeartbeatFrequency() > 0);
+ u2b.set_sasl_support(exec::user::SASL_AUTH);
// Adding version info
exec::user::RpcEndpointInfos* infos = u2b.mutable_client_infos();
@@ -412,37 +429,155 @@ connectionStatus_t DrillClientImpl::validateHandshake(DrillUserProperties* prope
if(ret!=CONN_SUCCESS){
return ret;
}
- if(this->m_handshakeStatus != exec::user::SUCCESS){
- switch(this->m_handshakeStatus){
- case exec::user::RPC_VERSION_MISMATCH:
- DRILL_MT_LOG(DRILL_LOG(LOG_TRACE) << "Invalid rpc version. Expected "
- << DRILL_RPC_VERSION << ", actual "<< m_handshakeVersion << "." << std::endl;)
- return handleConnError(CONN_BAD_RPC_VER,
- getMessage(ERR_CONN_BAD_RPC_VER, DRILL_RPC_VERSION,
- m_handshakeVersion,
- this->m_handshakeErrorId.c_str(),
- this->m_handshakeErrorMsg.c_str()));
- case exec::user::AUTH_FAILED:
- DRILL_MT_LOG(DRILL_LOG(LOG_TRACE) << "Authentication failed." << std::endl;)
- return handleConnError(CONN_AUTH_FAILED,
- getMessage(ERR_CONN_AUTHFAIL,
- this->m_handshakeErrorId.c_str(),
- this->m_handshakeErrorMsg.c_str()));
- case exec::user::UNKNOWN_FAILURE:
- DRILL_MT_LOG(DRILL_LOG(LOG_TRACE) << "Unknown error during handshake." << std::endl;)
- return handleConnError(CONN_HANDSHAKE_FAILED,
- getMessage(ERR_CONN_UNKNOWN_ERR,
- this->m_handshakeErrorId.c_str(),
- this->m_handshakeErrorMsg.c_str()));
- default:
- break;
+
+ switch(this->m_handshakeStatus) {
+ case exec::user::SUCCESS:
+ // reset io_service after handshake is validated before running queries
+ m_io_service.reset();
+ return CONN_SUCCESS;
+ case exec::user::RPC_VERSION_MISMATCH:
+ DRILL_MT_LOG(DRILL_LOG(LOG_TRACE) << "Invalid rpc version. Expected "
+ << DRILL_RPC_VERSION << ", actual "<< m_handshakeVersion << "." << std::endl;)
+ return handleConnError(CONN_BAD_RPC_VER, getMessage(ERR_CONN_BAD_RPC_VER, DRILL_RPC_VERSION,
+ m_handshakeVersion,
+ this->m_handshakeErrorId.c_str(),
+ this->m_handshakeErrorMsg.c_str()));
+ case exec::user::AUTH_FAILED:
+ DRILL_MT_LOG(DRILL_LOG(LOG_TRACE) << "Authentication failed." << std::endl;)
+ return handleConnError(CONN_AUTH_FAILED, getMessage(ERR_CONN_AUTHFAIL,
+ this->m_handshakeErrorId.c_str(),
+ this->m_handshakeErrorMsg.c_str()));
+ case exec::user::UNKNOWN_FAILURE:
+ DRILL_MT_LOG(DRILL_LOG(LOG_TRACE) << "Unknown error during handshake." << std::endl;)
+ return handleConnError(CONN_HANDSHAKE_FAILED, getMessage(ERR_CONN_UNKNOWN_ERR,
+ this->m_handshakeErrorId.c_str(),
+ this->m_handshakeErrorMsg.c_str()));
+ case exec::user::AUTH_REQUIRED:
+ DRILL_MT_LOG(DRILL_LOG(LOG_TRACE) << "Server requires SASL authentication." << std::endl;)
+ return handleAuthentication(properties);
+ default:
+ DRILL_MT_LOG(DRILL_LOG(LOG_TRACE) << "Unknown return status." << std::endl;)
+ return handleConnError(CONN_HANDSHAKE_FAILED, getMessage(ERR_CONN_UNKNOWN_ERR,
+ this->m_handshakeErrorId.c_str(),
+ this->m_handshakeErrorMsg.c_str()));
+ }
+}
+
+connectionStatus_t DrillClientImpl::handleAuthentication(const DrillUserProperties *userProperties) {
+ try {
+ m_saslAuthenticator = new SaslAuthenticatorImpl(userProperties);
+ } catch (std::runtime_error& e) {
+ return handleConnError(CONN_AUTH_FAILED, e.what());
+ }
+
+ startMessageListener();
+ initiateAuthentication();
+
+ { // block until SASL exchange is complete
+ boost::mutex::scoped_lock lock(m_saslMutex);
+ while (!m_saslDone) {
+ m_saslCv.wait(lock);
}
}
- // reset io_service after handshake is validated before running queries
- m_io_service.reset();
- return CONN_SUCCESS;
+
+ if (SASL_OK == m_saslResultCode) {
+ DRILL_MT_LOG(DRILL_LOG(LOG_DEBUG) << "DrillClientImpl::handleAuthentication: Successfully authenticated!"
+ << std::endl;)
+
+ // in future, negotiated security layers are known here..
+
+ m_io_service.reset();
+ return CONN_SUCCESS;
+ } else {
+ DRILL_MT_LOG(DRILL_LOG(LOG_DEBUG) << "DrillClientImpl::handleAuthentication: Authentication failed: "
+ << m_saslResultCode << std::endl;)
+ // shuts down socket as well
+ return handleConnError(CONN_AUTH_FAILED, "Authentication failed. Check connection parameters?");
+ }
}
+void DrillClientImpl::initiateAuthentication() {
+ exec::shared::SaslMessage response;
+ m_saslResultCode = m_saslAuthenticator->init(m_serverAuthMechanisms, response);
+
+
+ switch (m_saslResultCode) {
+ case SASL_CONTINUE:
+ case SASL_OK: {
+ DRILL_MT_LOG(DRILL_LOG(LOG_TRACE) << "DrillClientImpl::initiateAuthentication: initiated. " << std::endl;)
+ boost::lock_guard<boost::mutex> prLock(m_prMutex);
+ sendSaslResponse(response); // the challenge returned by server is handled by processSaslChallenge
+ break;
+ }
+ case SASL_NOMECH:
+ DRILL_MT_LOG(DRILL_LOG(LOG_DEBUG) << "DrillClientImpl::initiateAuthentication: "
+ << "Mechanism is not supported (by server/client)." << std::endl;)
+ default:
+ DRILL_MT_LOG(DRILL_LOG(LOG_DEBUG) << "DrillClientImpl::initiateAuthentication: "
+ << "Failed to initiate authentication." << std::endl;)
+ finishAuthentication();
+ break;
+ }
+}
+
+void DrillClientImpl::sendSaslResponse(const exec::shared::SaslMessage& response) {
+ boost::lock_guard<boost::mutex> lock(m_dcMutex);
+ const int32_t coordId = getNextCoordinationId();
+ rpc::OutBoundRpcMessage msg(exec::rpc::REQUEST, exec::user::SASL_MESSAGE, coordId, &response);
+ sendSync(msg);
+ if (m_pendingRequests++ == 0) {
+ getNextResult();
+ }
+}
+
+void DrillClientImpl::processSaslChallenge(AllocatedBufferPtr allocatedBuffer, const rpc::InBoundRpcMessage& msg) {
+ boost::shared_ptr<AllocatedBuffer> deallocationGuard(allocatedBuffer);
+ assert(m_saslAuthenticator != NULL);
+
+ // parse challenge
+ exec::shared::SaslMessage challenge;
+ const bool parseStatus = challenge.ParseFromArray(msg.m_pbody.data(), msg.m_pbody.size());
+ if (!parseStatus) {
+ DRILL_MT_LOG(DRILL_LOG(LOG_TRACE) << "Failed to parse challenge." << std::endl;)
+ m_saslResultCode = SASL_FAIL;
+ finishAuthentication();
+ m_pendingRequests--;
+ return;
+ }
+
+ // respond accordingly
+ exec::shared::SaslMessage response;
+ DRILL_MT_LOG(DRILL_LOG(LOG_TRACE) << "DrillClientImpl::processSaslChallenge: status: "
+ << exec::shared::SaslStatus_Name(challenge.status()) << std::endl;)
+ switch (challenge.status()) {
+ case exec::shared::SASL_IN_PROGRESS:
+ m_saslResultCode = m_saslAuthenticator->step(challenge, response);
+ if (m_saslResultCode == SASL_CONTINUE || m_saslResultCode == SASL_OK) {
+ sendSaslResponse(response);
+ } else { // failure
+ finishAuthentication();
+ }
+ break;
+ case exec::shared::SASL_SUCCESS:
+ if (SASL_CONTINUE == m_saslResultCode) { // client may need to evaluate once more
+ m_saslResultCode = m_saslAuthenticator->step(challenge, response);
+ DRILL_MT_LOG(DRILL_LOG(LOG_TRACE) << "SASL succeeded on client? " << m_saslResultCode << std::endl;)
+ }
+ finishAuthentication();
+ break;
+ default:
+ m_saslResultCode = SASL_FAIL;
+ finishAuthentication();
+ break;
+ }
+ m_pendingRequests--;
+}
+
+void DrillClientImpl::finishAuthentication() {
+ boost::mutex::scoped_lock lock(m_saslMutex);
+ m_saslDone = true;
+ m_saslCv.notify_one();
+}
FieldDefPtr DrillClientQueryResult::s_emptyColDefs( new (std::vector<Drill::FieldMetadata*>));
@@ -1369,6 +1504,10 @@ void DrillClientImpl::handleRead(ByteBuf_t _buf,
delete allocatedBuffer;
break;
+ case exec::user::SASL_MESSAGE:
+ processSaslChallenge(allocatedBuffer, msg);
+ break;
+
case exec::user::ACK:
// Cancel requests will result in an ACK sent back.
// Consume silently
@@ -1859,7 +1998,7 @@ void DrillClientPrepareHandle::clearAndDestroy(){
}
}
-connectionStatus_t PooledDrillClientImpl::connect(const char* connStr){
+connectionStatus_t PooledDrillClientImpl::connect(const char* connStr, DrillUserProperties* props){
connectionStatus_t stat = CONN_SUCCESS;
std::string pathToDrill, protocol, hostPortStr;
std::string host;
@@ -2062,7 +2201,7 @@ DrillClientImpl* PooledDrillClientImpl::getOneConnection(){
DrillClientImpl* pDrillClientImpl = NULL;
while(pDrillClientImpl==NULL){
if(m_queriesExecuted == 0){
- // First query ever sent can use the connection already established to authenticate the user
+ // First query ever sent can use the connection already established to handleAuthentication the user
boost::lock_guard<boost::mutex> lock(m_poolMutex);
pDrillClientImpl=m_clientConnections[0];// There should be one connection in the list when the first query is executed
}else if(m_clientConnections.size() == m_maxConcurrentConnections){
@@ -2077,7 +2216,7 @@ DrillClientImpl* PooledDrillClientImpl::getOneConnection(){
int tries=0;
connectionStatus_t ret=CONN_SUCCESS;
while(pDrillClientImpl==NULL && tries++ < 3){
- if((ret=connect(m_connectStr.c_str()))==CONN_SUCCESS){
+ if((ret=connect(m_connectStr.c_str(), m_pUserProperties.get()))==CONN_SUCCESS){
boost::lock_guard<boost::mutex> lock(m_poolMutex);
pDrillClientImpl=m_clientConnections.back();
ret=pDrillClientImpl->validateHandshake(m_pUserProperties.get());
diff --git a/contrib/native/client/src/clientlib/drillClientImpl.hpp b/contrib/native/client/src/clientlib/drillClientImpl.hpp
index 22e34af7d..262edc9ea 100644
--- a/contrib/native/client/src/clientlib/drillClientImpl.hpp
+++ b/contrib/native/client/src/clientlib/drillClientImpl.hpp
@@ -50,6 +50,7 @@
#include "utils.hpp"
#include "User.pb.h"
#include "UserBitShared.pb.h"
+#include "saslAuthenticatorImpl.hpp"
namespace Drill {
@@ -73,7 +74,7 @@ class DrillClientImplBase{
//Connect via Zookeeper or directly.
//Makes an initial connection to a drillbit. successful connect adds the first drillbit to the pool.
- virtual connectionStatus_t connect(const char* connStr)=0;
+ virtual connectionStatus_t connect(const char* connStr, DrillUserProperties* props)=0;
// Test whether the client is active. Returns true if any one of the underlying connections is active
virtual bool Active()=0;
@@ -362,6 +363,8 @@ class DrillClientImpl : public DrillClientImplBase{
m_handshakeVersion(0),
m_handshakeStatus(exec::user::SUCCESS),
m_bIsConnected(false),
+ m_saslAuthenticator(NULL),
+ m_saslDone(false),
m_pendingRequests(0),
m_pError(NULL),
m_pListenerThread(NULL),
@@ -385,6 +388,10 @@ class DrillClientImpl : public DrillClientImplBase{
delete this->m_pWork;
this->m_pWork = NULL;
}
+ if(this->m_saslAuthenticator!=NULL){
+ delete this->m_saslAuthenticator;
+ this->m_saslAuthenticator = NULL;
+ }
m_heartbeatTimer.cancel();
m_deadlineTimer.cancel();
@@ -415,7 +422,7 @@ class DrillClientImpl : public DrillClientImplBase{
};
//Connect via Zookeeper or directly
- connectionStatus_t connect(const char* connStr);
+ connectionStatus_t connect(const char* connStr, DrillUserProperties* props);
// test whether the client is active
bool Active();
void Close() ;
@@ -511,6 +518,13 @@ class DrillClientImpl : public DrillClientImplBase{
DrillClientTableResult* getTables(const std::string& catalogPattern, const std::string& schemaPattern, const std::string& tablePattern, const std::vector<std::string>* tableTypes, Metadata::pfnTableMetadataListener listener, void* listenerCtx);
DrillClientColumnResult* getColumns(const std::string& catalogPattern, const std::string& schemaPattern, const std::string& tablePattern, const std::string& columnPattern, Metadata::pfnColumnMetadataListener listener, void* listenerCtx);
+ // SASL exchange
+ connectionStatus_t handleAuthentication(const DrillUserProperties *userProperties);
+ void initiateAuthentication();
+ void sendSaslResponse(const exec::shared::SaslMessage& response);
+ void processSaslChallenge(AllocatedBufferPtr allocatedBuffer, const rpc::InBoundRpcMessage& msg);
+ void finishAuthentication();
+
void shutdownSocket();
int32_t m_coordinationId;
@@ -521,6 +535,13 @@ class DrillClientImpl : public DrillClientImplBase{
exec::user::RpcEndpointInfos m_serverInfos;
bool m_bIsConnected;
+ std::vector<std::string> m_serverAuthMechanisms;
+ SaslAuthenticatorImpl* m_saslAuthenticator;
+ int m_saslResultCode;
+ bool m_saslDone;
+ boost::mutex m_saslMutex; // mutex to protect m_saslDone
+ boost::condition_variable m_saslCv; // to signal completion of SASL exchange
+
std::string m_connectStr;
//
@@ -605,7 +626,7 @@ class PooledDrillClientImpl : public DrillClientImplBase{
//Connect via Zookeeper or directly.
//Makes an initial connection to a drillbit. successful connect adds the first drillbit to the pool.
- connectionStatus_t connect(const char* connStr);
+ connectionStatus_t connect(const char* connStr, DrillUserProperties* props);
// Test whether the client is active. Returns true if any one of the underlying connections is active
bool Active();
diff --git a/contrib/native/client/src/clientlib/saslAuthenticatorImpl.cpp b/contrib/native/client/src/clientlib/saslAuthenticatorImpl.cpp
new file mode 100644
index 000000000..e7e2ba594
--- /dev/null
+++ b/contrib/native/client/src/clientlib/saslAuthenticatorImpl.cpp
@@ -0,0 +1,207 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include <vector>
+#include <algorithm>
+#include <boost/algorithm/string.hpp>
+#include <boost/assign.hpp>
+#include "saslAuthenticatorImpl.hpp"
+
+#include "drillClientImpl.hpp"
+#include "logger.hpp"
+
+namespace Drill {
+
+static const std::string DEFAULT_SERVICE_NAME = "drill";
+
+static const std::string KERBEROS_SIMPLE_NAME = "kerberos";
+static const std::string KERBEROS_SASL_NAME = "gssapi";
+static const std::string PLAIN_NAME = "plain";
+
+const std::map<std::string, std::string> SaslAuthenticatorImpl::MECHANISM_MAPPING = boost::assign::map_list_of
+ (KERBEROS_SIMPLE_NAME, KERBEROS_SASL_NAME)
+ (PLAIN_NAME, PLAIN_NAME)
+;
+
+boost::mutex SaslAuthenticatorImpl::s_mutex;
+bool SaslAuthenticatorImpl::s_initialized = false;
+
+SaslAuthenticatorImpl::SaslAuthenticatorImpl(const DrillUserProperties* const properties) :
+ m_pUserProperties(properties), m_pConnection(NULL), m_ppwdSecret(NULL) {
+
+ if (!s_initialized) {
+ boost::lock_guard<boost::mutex> lock(SaslAuthenticatorImpl::s_mutex);
+ if (!s_initialized) {
+ // set plugin path if provided
+ if (DrillClientConfig::getSaslPluginPath()) {
+ char *saslPluginPath = const_cast<char *>(DrillClientConfig::getSaslPluginPath());
+ sasl_set_path(0, saslPluginPath);
+ }
+
+ // loads all the available mechanism and factories in the sasl_lib referenced by the path
+ const int err = sasl_client_init(NULL);
+ if (0 != err) {
+ std::stringstream errMsg;
+ errMsg << "Failed to load authentication libraries. code: " << err;
+ DRILL_MT_LOG(DRILL_LOG(LOG_TRACE) << errMsg.str() << std::endl;)
+ throw std::runtime_error(errMsg.str().c_str());
+ }
+ { // for debugging purposes
+ const char **mechanisms = sasl_global_listmech();
+ int i = 0;
+ DRILL_MT_LOG(DRILL_LOG(LOG_TRACE) << "SASL mechanisms available on client: " << std::endl;)
+ while (mechanisms[i] != NULL) {
+ DRILL_MT_LOG(DRILL_LOG(LOG_TRACE) << i << " : " << mechanisms[i] << std::endl;)
+ i++;
+ }
+ }
+ s_initialized = true;
+ }
+ }
+}
+
+SaslAuthenticatorImpl::~SaslAuthenticatorImpl() {
+ if (m_ppwdSecret) {
+ free(m_ppwdSecret);
+ }
+ m_ppwdSecret = NULL;
+ // may be used to negotiated security layers before disposing in the future
+ if (m_pConnection) {
+ sasl_dispose(&m_pConnection);
+ }
+ m_pConnection = NULL;
+}
+
+typedef int (*sasl_callback_proc_t)(void); // see sasl_callback_ft
+
+int SaslAuthenticatorImpl::userNameCallback(void *context, int id, const char **result, unsigned *len) {
+ const std::string* const username = static_cast<const std::string* const>(context);
+
+ if ((SASL_CB_USER == id || SASL_CB_AUTHNAME == id)
+ && username != NULL) {
+ *result = username->c_str();
+ // *len = (unsigned int) username->length();
+ }
+ return SASL_OK;
+}
+
+int SaslAuthenticatorImpl::passwordCallback(sasl_conn_t *conn, void *context, int id, sasl_secret_t **psecret) {
+ const SaslAuthenticatorImpl* const authenticator = static_cast<const SaslAuthenticatorImpl* const>(context);
+
+ if (SASL_CB_PASS == id) {
+ *psecret = authenticator->m_ppwdSecret;
+ }
+ return SASL_OK;
+}
+
+int SaslAuthenticatorImpl::init(const std::vector<std::string>& mechanisms, exec::shared::SaslMessage& response) {
+ // find and set parameters
+ std::string authMechanismToUse;
+ std::string serviceName;
+ std::string serviceHost;
+ for (size_t i = 0; i < m_pUserProperties->size(); i++) {
+ const std::string key = m_pUserProperties->keyAt(i);
+ const std::string value = m_pUserProperties->valueAt(i);
+
+ if (USERPROP_SERVICE_HOST == key) {
+ serviceHost = value;
+ } else if (USERPROP_SERVICE_NAME == key) {
+ serviceName = value;
+ } else if (USERPROP_PASSWORD == key) {
+ const size_t length = value.length();
+ m_ppwdSecret = (sasl_secret_t *) malloc(sizeof(sasl_secret_t) + length);
+ std::memcpy(m_ppwdSecret->data, value.c_str(), length);
+ m_ppwdSecret->len = length;
+ authMechanismToUse = PLAIN_NAME;
+ } else if (USERPROP_USERNAME == key) {
+ m_username = value;
+ } else if (USERPROP_AUTH_MECHANISM == key) {
+ authMechanismToUse = value;
+ }
+ }
+ if (authMechanismToUse.empty()) return SASL_NOMECH;
+
+ // check if requested mechanism is supported by server
+ boost::algorithm::to_lower(authMechanismToUse);
+ if (std::find(mechanisms.begin(), mechanisms.end(), authMechanismToUse) == mechanisms.end()) return SASL_NOMECH;
+
+ // find the SASL name
+ const std::map<std::string, std::string>::const_iterator it =
+ SaslAuthenticatorImpl::MECHANISM_MAPPING.find(authMechanismToUse);
+ if (it == SaslAuthenticatorImpl::MECHANISM_MAPPING.end()) return SASL_NOMECH;
+
+ const std::string saslMechanismToUse = it->second;
+
+ // setup callbacks and parameters
+ const sasl_callback_t callbacks[] = {
+ { SASL_CB_USER, (sasl_callback_proc_t) &userNameCallback, static_cast<void *>(&m_username) },
+ { SASL_CB_AUTHNAME, (sasl_callback_proc_t) &userNameCallback, static_cast<void *>(&m_username) },
+ { SASL_CB_PASS, (sasl_callback_proc_t) &passwordCallback, static_cast<void *>(this) },
+ { SASL_CB_LIST_END, NULL, NULL }
+ };
+ if (serviceName.empty()) serviceName = DEFAULT_SERVICE_NAME;
+
+ // create SASL client
+ int saslResult = sasl_client_new(serviceName.c_str(), serviceHost.c_str(), NULL /** iplocalport */,
+ NULL /** ipremoteport */, callbacks, 0 /** sec flags */, &m_pConnection);
+ DRILL_MT_LOG(DRILL_LOG(LOG_TRACE) << "SaslAuthenticatorImpl::init: sasl_client_new code: "
+ << saslResult << std::endl;)
+ if (saslResult != SASL_OK) return saslResult;
+
+ // initiate; for now, pass in only one mechanism
+ const char *out;
+ unsigned outlen;
+ const char *mech;
+ saslResult = sasl_client_start(m_pConnection, saslMechanismToUse.c_str(), NULL /** no prompt */, &out, &outlen,
+ &mech);
+ DRILL_MT_LOG(DRILL_LOG(LOG_TRACE) << "SaslAuthenticatorImpl::init: sasl_client_start code: "
+ << saslResult << std::endl;)
+ if (saslResult != SASL_OK && saslResult != SASL_CONTINUE) return saslResult;
+
+ // prepare response
+ DRILL_MT_LOG(DRILL_LOG(LOG_TRACE) << "SaslAuthenticatorImpl::init: chosen: " << authMechanismToUse << std::endl;)
+ response.set_mechanism(authMechanismToUse);
+ response.set_data(NULL == out ? "" : out, outlen);
+ response.set_status(exec::shared::SASL_START);
+ return saslResult;
+}
+
+int SaslAuthenticatorImpl::step(const exec::shared::SaslMessage& challenge, exec::shared::SaslMessage& response) const {
+ const char *in = challenge.data().c_str();
+ const unsigned inlen = challenge.data().length();
+ const char *out;
+ unsigned outlen;
+ const int saslResult = sasl_client_step(m_pConnection, in, inlen, NULL /** no prompt */, &out, &outlen);
+ switch (saslResult) {
+ case SASL_CONTINUE:
+ response.set_data(out, outlen);
+ response.set_status(exec::shared::SASL_IN_PROGRESS);
+ break;
+ case SASL_OK:
+ response.set_data(out, outlen);
+ response.set_status(exec::shared::SASL_SUCCESS);
+ break;
+ default:
+ response.set_status(exec::shared::SASL_FAILED);
+ break;
+ }
+ DRILL_MT_LOG(DRILL_LOG(LOG_TRACE) << "SaslAuthenticatorImpl::step: result: " << saslResult << std::endl;)
+ return saslResult;
+}
+
+} /* namespace Drill */
diff --git a/contrib/native/client/src/clientlib/saslAuthenticatorImpl.hpp b/contrib/native/client/src/clientlib/saslAuthenticatorImpl.hpp
new file mode 100644
index 000000000..5e36ee123
--- /dev/null
+++ b/contrib/native/client/src/clientlib/saslAuthenticatorImpl.hpp
@@ -0,0 +1,65 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef DRILLCLIENT_SASLAUTHENTICATORIMPL_HPP
+#define DRILLCLIENT_SASLAUTHENTICATORIMPL_HPP
+
+#include <string>
+#include <map>
+#include <vector>
+#include "drill/drillClient.hpp"
+#include "UserBitShared.pb.h"
+
+#include "sasl/sasl.h"
+#include "sasl/saslplug.h"
+
+namespace Drill {
+
+class SaslAuthenticatorImpl {
+
+public:
+
+ SaslAuthenticatorImpl(const DrillUserProperties *const properties);
+
+ ~SaslAuthenticatorImpl();
+
+ int init(const std::vector<std::string>& mechanisms, exec::shared::SaslMessage& response);
+
+ int step(const exec::shared::SaslMessage& challenge, exec::shared::SaslMessage& response) const;
+
+private:
+
+ static const std::map<std::string, std::string> MECHANISM_MAPPING;
+
+ static boost::mutex s_mutex;
+ static bool s_initialized;
+
+ const DrillUserProperties *const m_pUserProperties;
+ sasl_conn_t *m_pConnection;
+ std::string m_username;
+ sasl_secret_t *m_ppwdSecret;
+
+ static int passwordCallback(sasl_conn_t *conn, void *context, int id, sasl_secret_t **psecret);
+
+ static int userNameCallback(void *context, int id, const char **result, unsigned int *len);
+
+};
+
+} /* namespace Drill */
+
+#endif //DRILLCLIENT_SASLAUTHENTICATORIMPL_HPP
diff --git a/contrib/native/client/src/include/drill/common.hpp b/contrib/native/client/src/include/drill/common.hpp
index 6d3816e7c..ed0a1ed32 100644
--- a/contrib/native/client/src/include/drill/common.hpp
+++ b/contrib/native/client/src/include/drill/common.hpp
@@ -166,6 +166,9 @@ typedef enum{
#define USERPROP_FILEPATH "pemLocation" // Not implemented yet
#define USERPROP_FILENAME "pemFile" // Not implemented yet
#define USERPROP_IMPERSONATION_TARGET "impersonation_target"
+#define USERPROP_AUTH_MECHANISM "auth"
+#define USERPROP_SERVICE_NAME "service_name"
+#define USERPROP_SERVICE_HOST "service_host"
// Bitflags to describe user properties
// Used in DrillUserProperties::USER_PROPERTIES
diff --git a/contrib/native/client/src/include/drill/drillClient.hpp b/contrib/native/client/src/include/drill/drillClient.hpp
index 00ff72344..01c9f676f 100644
--- a/contrib/native/client/src/include/drill/drillClient.hpp
+++ b/contrib/native/client/src/include/drill/drillClient.hpp
@@ -80,6 +80,8 @@ class DECLSPEC_DRILL_CLIENT DrillClientConfig{
~DrillClientConfig();
static void initLogging(const char* path);
static void setLogLevel(logLevel_t l);
+ static void setSaslPluginPath(const char* path);
+ static const char* getSaslPluginPath();
static void setBufferLimit(uint64_t l);
static uint64_t getBufferLimit();
static void setSocketTimeout(int32_t l);
@@ -135,6 +137,8 @@ class DECLSPEC_DRILL_CLIENT DrillClientConfig{
// For future use. Currently, not enforced.
static uint64_t s_bufferLimit;
+ static const char* s_saslPluginPath;
+
/**
* DrillClient configures timeout (in seconds) in a fine granularity.
* Disabled by setting the value to zero.