diff options
Diffstat (limited to 'contrib/native/client/src/clientlib/drillClientImpl.cpp')
-rw-r--r-- | contrib/native/client/src/clientlib/drillClientImpl.cpp | 694 |
1 files changed, 582 insertions, 112 deletions
diff --git a/contrib/native/client/src/clientlib/drillClientImpl.cpp b/contrib/native/client/src/clientlib/drillClientImpl.cpp index 30a354e47..0dee309a6 100644 --- a/contrib/native/client/src/clientlib/drillClientImpl.cpp +++ b/contrib/native/client/src/clientlib/drillClientImpl.cpp @@ -30,7 +30,6 @@ #include <boost/lexical_cast.hpp> #include <boost/thread.hpp> - #include "drill/drillClient.hpp" #include "drill/fieldmeta.hpp" #include "drill/recordBatch.hpp" @@ -193,7 +192,7 @@ connectionStatus_t DrillClientImpl::sendHeartbeat(){ boost::lock_guard<boost::mutex> prLock(this->m_prMutex); boost::lock_guard<boost::mutex> lock(m_dcMutex); DRILL_MT_LOG(DRILL_LOG(LOG_TRACE) << "Heartbeat sent." << std::endl;) - status=sendSync(heartbeatMsg); + status=sendSyncCommon(heartbeatMsg); status=status==CONN_SUCCESS?status:CONN_DEAD; //If the server sends responses to a heartbeat, we need to increment the pending requests counter. if(m_pendingRequests++==0){ @@ -233,18 +232,125 @@ void DrillClientImpl::Close() { shutdownSocket(); } +/* + * Write bytesToWrite length data bytes pointed by dataPtr. It handles EINTR error + * occurred during write_some sys call and does a retry on that. + * + * Parameters: + * dataPtr - in param - Pointer to data bytes to write on socket. + * bytesToWrite - in param - Length of data bytes to write from dataPtr. + * errorCode - out param - Error code set by boost. + */ +void DrillClientImpl::doWriteToSocket(const char* dataPtr, size_t bytesToWrite, + boost::system::error_code& errorCode) { + if(0 == bytesToWrite) { + return; + } + + // Write all the bytes to socket. In case of error when all bytes are not successfully written + // proper errorCode will be set. + while(1) { + size_t bytesWritten = m_socket.write_some(boost::asio::buffer(dataPtr, bytesToWrite), errorCode); + + // Update the state + bytesToWrite -= bytesWritten; + dataPtr += bytesWritten; + + if(EINTR != errorCode.value()) break; + + // Check if all the data is written then break from loop + if(0 == bytesToWrite) break; + } +} -connectionStatus_t DrillClientImpl::sendSync(rpc::OutBoundRpcMessage& msg){ +/* + * Common wrapper to take care of sending both plain or encrypted message. It creates a send buffer from an + * OutboundRPCMessage and then call the send handler pointing to either sendSyncPlain or sendSyncEncrypted + * + * Return: + * connectionStatus_t - CONN_SUCCESS - In case of successful send + * - CONN_FAILURE - In case of failure to send + */ +connectionStatus_t DrillClientImpl::sendSyncCommon(rpc::OutBoundRpcMessage& msg) { encode(m_wbuf, msg); + return (this->*m_fpCurrentSendHandler)(); +} + +/* + * Send handler for sending plain messages over wire + * + * Return: + * connectionStatus_t - CONN_SUCCESS - In case of successful send + * - CONN_FAILURE - In case of failure to send + */ +connectionStatus_t DrillClientImpl::sendSyncPlain(){ + boost::system::error_code ec; - size_t s=m_socket.write_some(boost::asio::buffer(m_wbuf), ec); - if(!ec && s!=0){ + doWriteToSocket(reinterpret_cast<char*>(m_wbuf.data()), m_wbuf.size(), ec); + + if(!ec) { return CONN_SUCCESS; - }else{ + } else { return handleConnError(CONN_FAILURE, getMessage(ERR_CONN_WFAIL, ec.message().c_str())); } } +/* + * Send handler for sending encrypted messages over wire. It encrypts the send buffer using wrap api provided by + * saslAuthenticatorImpl and then transmit the encrypted bytes over wire. + * + * Return: + * connectionStatus_t - CONN_SUCCESS - In case of successful send + * - CONN_FAILURE - In case of failure to send + */ +connectionStatus_t DrillClientImpl::sendSyncEncrypted() { + + boost::system::error_code ec; + + // Encoded message is encrypted into chunks of size <= WrapSizeLimit. Each encrypted chunk along with + // its encrypted length in network order (added by Cyrus-SASL plugin) is sent over wire. + const int wrapChunkSize = m_encryptionCtxt.getWrapSizeLimit(); + int lengthToEncrypt = m_wbuf.size(); + + int currentChunkLen = std::min(wrapChunkSize, lengthToEncrypt); + uint32_t currentChunkOffset = 0; + std::stringstream errorMsg; + + // Encrypt and send each chunk + while(lengthToEncrypt != 0) { + const char* wrappedChunk = NULL; + uint32_t wrappedLen = 0; + const int wrapResult = m_saslAuthenticator->wrap(reinterpret_cast<const char*>(m_wbuf.data() + currentChunkOffset), + currentChunkLen, &wrappedChunk, wrappedLen); + if(SASL_OK != wrapResult) { + errorMsg << "Sasl wrap failed while encrypting chunk of length: " << currentChunkLen << " , EncodeError: " + << wrapResult; + DRILL_MT_LOG(DRILL_LOG(LOG_TRACE) << "DrillClientImpl::sendSyncEncrypted - " << errorMsg.str() + << " ,ChunkOffset: " << currentChunkOffset << ", Message Len: " << m_wbuf.size() + << ", Closing connection.";) + return handleConnError(CONN_FAILURE, getMessage(ERR_CONN_WFAIL, errorMsg.str().c_str())); + } + + // Send the encrypted chunk. + doWriteToSocket(wrappedChunk, wrappedLen, ec); + + if(ec) { + errorMsg << "Failure while sending encrypted chunk. Error: " << ec.message().c_str(); + DRILL_MT_LOG(DRILL_LOG(LOG_TRACE) << "DrillClientImpl::sendSyncEncrypted - " << errorMsg.str() + << ", Chunk Length: " << currentChunkLen << ", ChunkOffset:" << currentChunkOffset + << ", Message Len: " << m_wbuf.size() << ", Closing connection.";) + return handleConnError(CONN_FAILURE, getMessage(ERR_CONN_WFAIL, errorMsg.str().c_str())); + } + + // Update variables after sending each encrypted chunk + lengthToEncrypt -= currentChunkLen; + currentChunkOffset += currentChunkLen; + currentChunkLen = std::min(wrapChunkSize, lengthToEncrypt); + } + + return CONN_SUCCESS; +} + connectionStatus_t DrillClientImpl::recvHandshake(){ if(m_rbuf==NULL){ m_rbuf = Utils::allocateBuffer(MAX_SOCK_RD_BUFSIZE); @@ -289,7 +395,41 @@ connectionStatus_t DrillClientImpl::recvHandshake(){ return CONN_SUCCESS; } -void DrillClientImpl::handleHandshake(ByteBuf_t _buf, +/* + * Read bytesToRead length data bytes from socket into inBuf. It handles EINTR error + * occurred during read_some sys call and does a retry on that. + * + * Parameters: + * inBuf - out param - Pointer to buffer to read data into from socket. + * bytesToRead - in param - Length of data bytes to read from socket. + * errorCode - out param - Error code set by boost. + */ +void DrillClientImpl::doReadFromSocket(ByteBuf_t inBuf, size_t bytesToRead, + boost::system::error_code& errorCode) { + + // Check if bytesToRead is zero + if(0 == bytesToRead) { + return; + } + + // Read all the bytes. In case when all the bytes were not read the proper + // errorCode will be set. + while(1){ + size_t dataBytesRead = m_socket.read_some(boost::asio::buffer(inBuf, bytesToRead), + errorCode); + // Update the state + bytesToRead -= dataBytesRead; + inBuf += dataBytesRead; + + // Check if errorCode is EINTR then just retry otherwise break from loop + if(EINTR != errorCode.value()) break; + + // Check if all the data is read then break from loop + if(0 == bytesToRead) break; + } +} + +void DrillClientImpl::handleHandshake(ByteBuf_t inBuf, const boost::system::error_code& err, size_t bytes_transferred) { boost::system::error_code error=err; @@ -299,21 +439,23 @@ void DrillClientImpl::handleHandshake(ByteBuf_t _buf, if(!error){ rpc::InBoundRpcMessage msg; uint32_t length = 0; - std::size_t bytes_read = rpc::lengthDecode(m_rbuf, length); + std::size_t bytes_read = rpcLengthDecode(m_rbuf, length); if(length>0){ - size_t leftover = LEN_PREFIX_BUFLEN - bytes_read; - ByteBuf_t b=m_rbuf + LEN_PREFIX_BUFLEN; - size_t bytesToRead=length - leftover; - while(1){ - size_t dataBytesRead=m_socket.read_some( - boost::asio::buffer(b, bytesToRead), - error); - if(err) break; - DRILL_MT_LOG(DRILL_LOG(LOG_TRACE) << "Handshake Message: actual bytes read = " << dataBytesRead << std::endl;) - if(dataBytesRead==bytesToRead) break; - bytesToRead-=dataBytesRead; - b+=dataBytesRead; + const size_t leftover = LEN_PREFIX_BUFLEN - bytes_read; + const ByteBuf_t b = m_rbuf + LEN_PREFIX_BUFLEN; + const size_t bytesToRead=length - leftover; + doReadFromSocket(b, bytesToRead, error); + + // Check if any error happen while reading the message bytes. If yes then return before decoding the Msg + if(error) { + DRILL_MT_LOG(DRILL_LOG(LOG_TRACE) << "DrillClientImpl::handleHandshake: ERR_CONN_RDFAIL. " + << " Failed to read entire handshake message. with error: " + << error.message().c_str() << "\n";) + handleConnError(CONN_FAILURE, getMessage(ERR_CONN_RDFAIL, "Failed to read entire handshake message")); + return; } + + // Decode the bytes into a valid RPC Message if (!decode(m_rbuf+bytes_read, length, msg)) { DRILL_MT_LOG(DRILL_LOG(LOG_TRACE) << "DrillClientImpl::handleHandshake: ERR_CONN_RDFAIL. Cannot decode handshake.\n";) handleConnError(CONN_FAILURE, getMessage(ERR_CONN_RDFAIL, "Cannot decode handshake")); @@ -340,6 +482,11 @@ void DrillClientImpl::handleHandshake(ByteBuf_t _buf, this->m_serverAuthMechanisms.push_back(mechanism); } + // Updated encryption context based on server response + this->m_encryptionCtxt.setEncryptionReqd(b2u.has_encrypted() && b2u.encrypted()); + if(b2u.has_maxwrappedsize()) { + this->m_encryptionCtxt.setMaxWrappedSize(b2u.maxwrappedsize()); + } }else{ // boost error if(error==boost::asio::error::eof){ // Server broke off the connection @@ -360,7 +507,8 @@ void DrillClientImpl::handleHShakeReadTimeout(const boost::system::error_code & if (m_deadlineTimer.expires_at() <= boost::asio::deadline_timer::traits_type::now()){ // The deadline has passed. m_deadlineTimer.expires_at(boost::posix_time::pos_infin); - DRILL_MT_LOG(DRILL_LOG(LOG_TRACE) << "DrillClientImpl::HandleHShakeReadTimeout: Deadline timer expired; ERR_CONN_HSHAKETIMOUT.\n";) + DRILL_MT_LOG(DRILL_LOG(LOG_TRACE) << "DrillClientImpl::HandleHShakeReadTimeout: " + << "Deadline timer expired; ERR_CONN_HSHAKETIMOUT.\n";) handleConnError(CONN_HANDSHAKE_TIMEOUT, getMessage(ERR_CONN_HSHAKETIMOUT)); m_io_service.stop(); boost::system::error_code ignorederr; @@ -370,6 +518,33 @@ void DrillClientImpl::handleHShakeReadTimeout(const boost::system::error_code & return; } +/* + * Check's if client has explicitly expressed interest in encrypted connections only. It looks for USERPROP_SASL_ENCRYPT + * connection string property. If set to true then returns true else returns false + */ +bool DrillClientImpl::clientNeedsEncryption(const DrillUserProperties* userProperties) { + bool needsEncryption = false; + // check if userProperties is null + if(!userProperties) { + return needsEncryption; + } + + // Loop through the property to find USERPROP_SASL_ENCRYPT and it's value + for (size_t i = 0; i < userProperties->size(); i++) { + const std::string key = userProperties->keyAt(i); + std::string value = userProperties->valueAt(i); + + if(USERPROP_SASL_ENCRYPT == key) { + boost::algorithm::to_lower(value); + + if(0 == value.compare("true")) { + needsEncryption = true; + } + } + } + return needsEncryption; +} + connectionStatus_t DrillClientImpl::validateHandshake(DrillUserProperties* properties){ DRILL_MT_LOG(DRILL_LOG(LOG_TRACE) << "validateHandShake\n";) @@ -379,7 +554,7 @@ connectionStatus_t DrillClientImpl::validateHandshake(DrillUserProperties* prope u2b.set_rpc_version(DRILL_RPC_VERSION); u2b.set_support_listening(true); u2b.set_support_timeout(DrillClientConfig::getHeartbeatFrequency() > 0); - u2b.set_sasl_support(exec::user::SASL_AUTH); + u2b.set_sasl_support(exec::user::SASL_PRIVACY); // Adding version info exec::user::RpcEndpointInfos* infos = u2b.mutable_client_infos(); @@ -436,7 +611,7 @@ connectionStatus_t DrillClientImpl::validateHandshake(DrillUserProperties* prope uint64_t coordId = this->getNextCoordinationId(); rpc::OutBoundRpcMessage out_msg(exec::rpc::REQUEST, exec::user::HANDSHAKE, coordId, &u2b); - sendSync(out_msg); + sendSyncCommon(out_msg); DRILL_MT_LOG(DRILL_LOG(LOG_TRACE) << "Sent handshake request message. Coordination id: " << coordId << "\n";) } @@ -479,6 +654,13 @@ connectionStatus_t DrillClientImpl::validateHandshake(DrillUserProperties* prope } connectionStatus_t DrillClientImpl::handleAuthentication(const DrillUserProperties *userProperties) { + + // Check if client needs encryption and server is configured for encryption or not before starting handshake + if(clientNeedsEncryption(userProperties) && !m_encryptionCtxt.isEncryptionReqd()) { + return handleConnError(CONN_AUTH_FAILED, "Client needs encryption but on server side encryption is disabled." + " Please check connection parameters or contact administrator?"); + } + try { m_saslAuthenticator = new SaslAuthenticatorImpl(userProperties); } catch (std::runtime_error& e) { @@ -495,26 +677,46 @@ connectionStatus_t DrillClientImpl::handleAuthentication(const DrillUserProperti } } + std::stringstream logMsg; + logMsg << "DrillClientImpl::handleAuthentication: Authentication failed. [Details: "; + if (SASL_OK == m_saslResultCode) { - DRILL_MT_LOG(DRILL_LOG(LOG_DEBUG) << "DrillClientImpl::handleAuthentication: Successfully authenticated!" - << std::endl;) + // Check the negotiated SSF value and change the handlers. + if(m_encryptionCtxt.isEncryptionReqd()) { + if(SASL_OK != m_saslAuthenticator->verifyAndUpdateSaslProps()) { + logMsg << m_encryptionCtxt << "]. Negotiated Parameter is invalid." + << " Error: " << m_saslResultCode; + DRILL_MT_LOG(DRILL_LOG(LOG_DEBUG) << logMsg.str() << std::endl;) + return handleConnError(CONN_AUTH_FAILED, logMsg.str().c_str()); + } + + // Successfully negotiated for encryption related security parameters. + // Start using Encrypt and Decrypt handlers. + m_fpCurrentSendHandler = &DrillClientImpl::sendSyncEncrypted; + m_fpCurrentReadMsgHandler = &DrillClientImpl::readAndDecryptMsg; + } - // in future, negotiated security layers are known here.. + // Reset the errorMsg stream since this is success case. + logMsg.str(std::string()); + logMsg << "DrillClientImpl::handleAuthentication: Successfully authenticated! [Details: " + << m_encryptionCtxt << " ]"; + DRILL_MT_LOG(DRILL_LOG(LOG_DEBUG) << logMsg.str() << std::endl;) m_io_service.reset(); return CONN_SUCCESS; } else { - DRILL_MT_LOG(DRILL_LOG(LOG_DEBUG) << "DrillClientImpl::handleAuthentication: Authentication failed: " - << m_saslResultCode << std::endl;) + logMsg << m_encryptionCtxt << ", Error: " << m_saslResultCode; + DRILL_MT_LOG(DRILL_LOG(LOG_DEBUG) << logMsg.str() << std::endl;) + // shuts down socket as well - return handleConnError(CONN_AUTH_FAILED, "Authentication failed. Check connection parameters?"); + logMsg << "]. Check connection parameters?"; + return handleConnError(CONN_AUTH_FAILED, logMsg.str().c_str()); } } void DrillClientImpl::initiateAuthentication() { exec::shared::SaslMessage response; - m_saslResultCode = m_saslAuthenticator->init(m_serverAuthMechanisms, response); - + m_saslResultCode = m_saslAuthenticator->init(m_serverAuthMechanisms, response, &m_encryptionCtxt); switch (m_saslResultCode) { case SASL_CONTINUE: @@ -539,7 +741,7 @@ void DrillClientImpl::sendSaslResponse(const exec::shared::SaslMessage& response boost::lock_guard<boost::mutex> lock(m_dcMutex); const int32_t coordId = getNextCoordinationId(); rpc::OutBoundRpcMessage msg(exec::rpc::REQUEST, exec::user::SASL_MESSAGE, coordId, &response); - sendSync(msg); + sendSyncCommon(msg); if (m_pendingRequests++ == 0) { getNextResult(); } @@ -768,23 +970,23 @@ Handle* DrillClientImpl::sendMsg(boost::function<Handle*(int32_t)> handleFactory phandle = handleFactory(coordId); this->m_queryHandles[coordId]=phandle; - connectionStatus_t cStatus=sendSync(out_msg); + connectionStatus_t cStatus = sendSyncCommon(out_msg); if(cStatus == CONN_SUCCESS){ bool sendRequest=false; DRILL_MT_LOG(DRILL_LOG(LOG_DEBUG) << "Sent " << ::exec::user::RpcType_Name(type) << " request. " << "[" << m_connectedHost << "]" << "Coordination id = " << coordId << std::endl;) - DRILL_MT_LOG(DRILL_LOG(LOG_DEBUG) << "Sent " << ::exec::user::RpcType_Name(type) << " Coordination id = " << coordId << " query: " << phandle->getQuery() << std::endl;) + DRILL_MT_LOG(DRILL_LOG(LOG_DEBUG) << "Sent " << ::exec::user::RpcType_Name(type) << " Coordination id = " << coordId << " query: " << phandle->getQuery() << std::endl;) - if(m_pendingRequests++==0){ - sendRequest=true; - }else{ - DRILL_MT_LOG(DRILL_LOG(LOG_DEBUG) << "Queuing " << ::exec::user::RpcType_Name(type) << " request to server" << std::endl;) - DRILL_MT_LOG(DRILL_LOG(LOG_DEBUG) << "Number of pending requests = " << m_pendingRequests << std::endl;) - } + if(m_pendingRequests++==0){ + sendRequest=true; + }else{ + DRILL_MT_LOG(DRILL_LOG(LOG_DEBUG) << "Queuing " << ::exec::user::RpcType_Name(type) << " request to server" << std::endl;) + DRILL_MT_LOG(DRILL_LOG(LOG_DEBUG) << "Number of pending requests = " << m_pendingRequests << std::endl;) + } if(sendRequest){ DRILL_MT_LOG(DRILL_LOG(LOG_DEBUG) << "Sending " << ::exec::user::RpcType_Name(type) << " request. Number of pending requests = " << m_pendingRequests << std::endl;) - getNextResult(); // async wait for results + getNextResult(); // async wait for results } } @@ -854,76 +1056,319 @@ void DrillClientImpl::waitForResults(){ } } -status_t DrillClientImpl::readMsg(ByteBuf_t _buf, - AllocatedBufferPtr* allocatedBuffer, - rpc::InBoundRpcMessage& msg){ +/* + * Decode the length of the message from bufWithLen and then read entire message from the socket. + * Parameters: + * bufWithLenField - in param - buffer containing the length of the RPC message/encrypted chunk + * bufferWithDataAndLenBytes - out param - buffer pointer which points to memory allocated in this function and has the + * entire one RPC message / encrypted chunk along with the length of the message. + * Memory for this buffer is released by caller. + * lengthFieldLength - out param - bytes of bufWithLen which contains the length of the entire RPC message or + * encrypted chunk + * lengthDecodeHandler - in param - function pointer with length decoder to use. For encrypted chunk we use + * lengthDecode and for plain RPC message we use rpcLengthDecode. + * Return: + * status_t - QRY_SUCCESS - In case of success. + * - QRY_COMM_ERROR/QRY_INTERNAL_ERROR/QRY_CLIENT_OUTOFMEM - In cases of error. + */ +status_t DrillClientImpl::readLenBytesFromSocket(const ByteBuf_t bufWithLenField, AllocatedBufferPtr* bufferWithDataAndLenBytes, + uint32_t& lengthFieldLength, lengthDecoder lengthDecodeHandler) { + + uint32_t rmsgLen = 0; + boost::system::error_code error; + *bufferWithDataAndLenBytes = NULL; + + // Decode the length field + lengthFieldLength = (this->*lengthDecodeHandler)(bufWithLenField, rmsgLen); + + DRILL_MT_LOG(DRILL_LOG(LOG_TRACE) << "Length bytes = " << lengthFieldLength << std::endl;) + DRILL_MT_LOG(DRILL_LOG(LOG_TRACE) << "Msg Length = " << rmsgLen << std::endl;) + + if(rmsgLen>0) { + const size_t leftover = LEN_PREFIX_BUFLEN - lengthFieldLength; + + // Allocate a buffer for reading all the bytes in bufWithLen and length number of bytes. + const size_t bufferSizeWithLenBytes = rmsgLen + lengthFieldLength; + *bufferWithDataAndLenBytes = new AllocatedBuffer(bufferSizeWithLenBytes); + + if(*bufferWithDataAndLenBytes == NULL) { + return handleQryError(QRY_CLIENT_OUTOFMEM, getMessage(ERR_QRY_OUTOFMEM), NULL); + } + + DRILL_MT_LOG(DRILL_LOG(LOG_TRACE) << "DrillClientImpl::readLenBytesFromSocket: Allocated and locked buffer: [ " + << *bufferWithDataAndLenBytes << ", size = " << bufferSizeWithLenBytes << " ]\n";) + + // Copy the memory of bufWithLen into bufferWithLenBytesSize + memcpy((*bufferWithDataAndLenBytes)->m_pBuffer, bufWithLenField, LEN_PREFIX_BUFLEN); + const size_t bytesToRead = rmsgLen - leftover; + DRILL_MT_LOG(DRILL_LOG(LOG_TRACE) << "Copied bufWithLen into bufferWithLenBytes. " + << "Now reading data (rmsgLen - leftover) : " << bytesToRead + << std::endl;) + + // Read the entire data left from socket and copy to currentBuffer. + const ByteBuf_t b = (*bufferWithDataAndLenBytes)->m_pBuffer + LEN_PREFIX_BUFLEN; + doReadFromSocket(b, bytesToRead, error); + } else { + return handleQryError(QRY_INTERNAL_ERROR, getMessage(ERR_QRY_INVREADLEN), NULL); + } + + return error ? handleQryError(QRY_COMM_ERROR, getMessage(ERR_QRY_COMMERR, error.message().c_str()), NULL) + : QRY_SUCCESS; +} + + +/* + * Function to read entire RPC message from socket and decode it to InboundRpcMessage + * Parameters: + * inBuf - in param - Buffer containing the length bytes. + * allocatedBuffer - out param - Buffer containing the length bytes and entire RPC message bytes. + * msg - out param - Decoded InBoundRpcMessage from the bytes in allocatedBuffer + * Return: + * status_t - QRY_SUCCESS - In case of success. + * - QRY_COMM_ERROR/QRY_INTERNAL_ERROR/QRY_CLIENT_OUTOFMEM - In cases of error. + */ +status_t DrillClientImpl::readMsg(const ByteBuf_t inBuf, AllocatedBufferPtr* allocatedBuffer, + rpc::InBoundRpcMessage& msg){ DRILL_MT_LOG(DRILL_LOG(LOG_TRACE) << "DrillClientImpl::readMsg: Read message from buffer " - << reinterpret_cast<int*>(_buf) << std::endl;) - size_t leftover=0; - uint32_t rmsgLen; - AllocatedBufferPtr currentBuffer; - *allocatedBuffer=NULL; + << reinterpret_cast<int*>(inBuf) << std::endl;) + *allocatedBuffer = NULL; + { + // We need to protect the readLength and read buffer, and the pending requests counter, + // but we don't have to keep the lock while we decode the rest of the buffer. + boost::lock_guard<boost::mutex> lock(this->m_dcMutex); + uint32_t lengthFieldSize = 0; + + // Read the message length and extract length size bytes to form InBoundRpcMessage + const status_t statusCode = readLenBytesFromSocket(inBuf, allocatedBuffer, lengthFieldSize, + &DrillClientImpl::rpcLengthDecode); + + // Check for error conditions + if(QRY_SUCCESS != statusCode) { + Utils::freeBuffer(inBuf, LEN_PREFIX_BUFLEN); + return statusCode; + } + + // Get the message size + size_t msgLen = (*allocatedBuffer)->m_bufSize; + + // Read data successfully, now let's try to decode the buffer and form a valid RPC message. + // allocatedBuffer also contains the length bytes which is not needed by decodes so skip that part of buffer. + // We have it since in case of encryption the unwrap function expects it + if (!decode((*allocatedBuffer)->m_pBuffer + lengthFieldSize, msgLen - lengthFieldSize, msg)) { + Utils::freeBuffer(inBuf, LEN_PREFIX_BUFLEN); + return handleQryError(QRY_COMM_ERROR, getMessage(ERR_QRY_COMMERR, "Cannot decode server message"), NULL); + } + + DRILL_MT_LOG(DRILL_LOG(LOG_TRACE) << "Successfully created a RPC message with Coordination id: " + << msg.m_coord_id << std::endl;) + } + DRILL_MT_LOG(DRILL_LOG(LOG_TRACE) << "DrillClientImpl::readMsg: Free buffer " + << reinterpret_cast<int*>(inBuf) << std::endl;) + Utils::freeBuffer(inBuf, LEN_PREFIX_BUFLEN); + return QRY_SUCCESS; +} + + +/* + * Read ENCRYPT_LEN_PREFIX_BUFLEN bytes to decode length of one complete encrypted chunk. The length bytes are expected + * to be in network order. It is converted to host order and the value is stored in rmsgLen parameter. + * Parameters: + * inBuf - in param - ByteBuf_t containing atleast the length bytes. + * rmsgLen - out param - Contain the decoded value of length. + * Return: + * size_t - length bytes read to decode + */ +size_t DrillClientImpl::lengthDecode(const ByteBuf_t inBuf, uint32_t& rmsgLen) { + memcpy(&rmsgLen, inBuf, ENCRYPT_LEN_PREFIX_BUFLEN); + rmsgLen = ntohl(rmsgLen); + return ENCRYPT_LEN_PREFIX_BUFLEN; +} + +/* + * Wrapper which uses RPC message length decoder to get length of one complete RPC message from _buf. + * Parameters: + * inBuf - in param - ByteBuf_t containing atleast the length bytes. + * rmsgLen - out param - Contain the decoded value of length. + * Return: + * size_t - length bytes read to decode + */ +size_t DrillClientImpl::rpcLengthDecode(const ByteBuf_t inBuf, uint32_t& rmsgLen) { + return rpc::lengthDecode(inBuf, rmsgLen); +} + + +/* + * Read all the encrypted chunk needed to form a complete RPC message. Read an entire chunk from network, decrypt it + * and put in a buffer. The same process is repeated until the entire buffer to form a completed RPC message is read. + * Parameters: + * inBuf - in param - ByteBuf_t containing atleast the length bytes. + * allocatedBuffer - out param - Buffer containing the entire RPC message bytes which is formed by reading all the + * required encrypted chunk from network and decrypting each individual chunk. The + * buffer memory is released by caller. +.* msg - out param - InBoundRpcMessage formed from bytes in allocatedBuffer + * Return: + * status_t - QRY_SUCCESS - In case of success. + * - QRY_COMM_ERROR/QRY_INTERNAL_ERROR/QRY_CLIENT_OUTOFMEM - In cases of error. + */ +status_t DrillClientImpl::readAndDecryptMsg(const ByteBuf_t inBuf, AllocatedBufferPtr* allocatedBuffer, + rpc::InBoundRpcMessage& msg) { + + DRILL_MT_LOG(DRILL_LOG(LOG_TRACE) << "DrillClientImpl::readAndDecryptMsg: Read message from buffer " + << reinterpret_cast<int*>(inBuf) << std::endl;) + + size_t leftover = 0; + uint32_t rpcMsgLen = 0; + size_t bytes_read = 0; + uint32_t writeIndex = 0; + size_t bytesToRead = 0; + + *allocatedBuffer = NULL; + boost::system::error_code error; + std::stringstream errorMsg; + { // We need to protect the readLength and read buffer, and the pending requests counter, // but we don't have to keep the lock while we decode the rest of the buffer. boost::lock_guard<boost::mutex> lock(this->m_dcMutex); - std::size_t bytes_read = rpc::lengthDecode(_buf, rmsgLen); - DRILL_MT_LOG(DRILL_LOG(LOG_TRACE) << "len bytes = " << bytes_read << std::endl;) - DRILL_MT_LOG(DRILL_LOG(LOG_TRACE) << "rmsgLen = " << rmsgLen << std::endl;) - - if(rmsgLen>0){ - leftover = LEN_PREFIX_BUFLEN - bytes_read; - // Allocate a buffer - currentBuffer=new AllocatedBuffer(rmsgLen); - DRILL_MT_LOG(DRILL_LOG(LOG_TRACE) << "DrillClientImpl::readMsg: Allocated and locked buffer: [ " - << currentBuffer << ", size = " << rmsgLen << " ]\n";) - if(currentBuffer==NULL){ - Utils::freeBuffer(_buf, LEN_PREFIX_BUFLEN); - return handleQryError(QRY_CLIENT_OUTOFMEM, getMessage(ERR_QRY_OUTOFMEM), NULL); + + do{ + AllocatedBufferPtr currentBuffer = NULL; + uint32_t lengthFieldSize = 0; + const status_t statusCode = readLenBytesFromSocket(inBuf, ¤tBuffer, lengthFieldSize, + &DrillClientImpl::lengthDecode); + + if(QRY_SUCCESS != statusCode) { + Utils::freeBuffer(inBuf, LEN_PREFIX_BUFLEN); + + // Release the buffer allocated to hold chunk + if(currentBuffer != NULL) { + Utils::freeBuffer(currentBuffer->m_pBuffer, currentBuffer->m_bufSize); + currentBuffer = NULL; + } + return statusCode; } - *allocatedBuffer=currentBuffer; - if(leftover){ - memcpy(currentBuffer->m_pBuffer, _buf + bytes_read, leftover); + + // read one chunk successfully. Let's try to decrypt the message + const char* unWrappedData = NULL; + uint32_t unWrappedLen = 0; + const int decryptResult = m_saslAuthenticator->unwrap(reinterpret_cast<const char*>(currentBuffer->m_pBuffer), + currentBuffer->m_bufSize, &unWrappedData, unWrappedLen); + + if(SASL_OK != decryptResult) { + + errorMsg << "Sasl unwrap failed for the buffer of size:" << currentBuffer->m_bufSize << " , Error: " + << decryptResult; + + DRILL_MT_LOG(DRILL_LOG(LOG_DEBUG) << "DrillClientImpl::readAndDecryptMsg: " + << errorMsg.str() << std::endl;) + + Utils::freeBuffer(inBuf, LEN_PREFIX_BUFLEN); + + // Release the buffer allocated to hold chunk + Utils::freeBuffer(currentBuffer->m_pBuffer, currentBuffer->m_bufSize); + currentBuffer = NULL; + return handleQryError(QRY_COMM_ERROR, + getMessage(ERR_QRY_COMMERR, errorMsg.str().c_str()), NULL); } - DRILL_MT_LOG(DRILL_LOG(LOG_TRACE) << "reading data (rmsgLen - leftover) : " - << (rmsgLen - leftover) << std::endl;) - ByteBuf_t b=currentBuffer->m_pBuffer + leftover; - size_t bytesToRead=rmsgLen - leftover; - boost::system::error_code error; - while(1){ - size_t dataBytesRead=this->m_socket.read_some( - boost::asio::buffer(b, bytesToRead), - error); - if(error) break; - DRILL_MT_LOG(DRILL_LOG(LOG_TRACE) << "Data Message: actual bytes read = " << dataBytesRead << std::endl;) - if(dataBytesRead==bytesToRead) break; - bytesToRead-=dataBytesRead; - b+=dataBytesRead; + + // Check for case if the unWrappedLen is 0, since Cyrus SASL plugin verifies if the length of wrapped data + // is less than the length specified by prepended 4 octets as per RFC 4422/2222. If so it just returns + // and waits for more data + if(unWrappedLen == 0 || (unWrappedData == NULL)) { + errorMsg << "Sasl unwrap failed with mismatch in length of wrapped data and the prepended length value"; + DRILL_MT_LOG(DRILL_LOG(LOG_DEBUG) << "DrillClientImpl::readAndDecryptMsg: " << errorMsg.str() + << std::endl;) + + Utils::freeBuffer(inBuf, LEN_PREFIX_BUFLEN); + + // Release the buffer allocated to hold chunk + Utils::freeBuffer(currentBuffer->m_pBuffer, currentBuffer->m_bufSize); + currentBuffer = NULL; + return handleQryError(QRY_COMM_ERROR, + getMessage(ERR_QRY_COMMERR, errorMsg.str().c_str()), NULL); } - if(!error){ - // read data successfully - if (!decode(currentBuffer->m_pBuffer, rmsgLen, msg)) { - Utils::freeBuffer(_buf, LEN_PREFIX_BUFLEN); - return handleQryError(QRY_COMM_ERROR, - getMessage(ERR_QRY_COMMERR, "Cannot decode server message"), NULL);; + DRILL_MT_LOG(DRILL_LOG(LOG_TRACE) << "DrillClientImpl::readAndDecryptMsg: Successfully decrypted the buffer" + << " Sizes - Before Decryption = " << currentBuffer->m_bufSize + << " and After Decryption = " << unWrappedLen << std::endl;) + + // Release the buffer allocated to hold chunk + Utils::freeBuffer(currentBuffer->m_pBuffer, currentBuffer->m_bufSize); + currentBuffer = NULL; + + bytes_read = 0; + if(*allocatedBuffer == NULL) { + // This is the first chunk of the RPC message. We will decode the RPC message full length + bytes_read = rpcLengthDecode(reinterpret_cast<ByteBuf_t>(const_cast<char*>(unWrappedData)), rpcMsgLen); + + DRILL_MT_LOG(DRILL_LOG(LOG_TRACE) << "DrillClientImpl::readAndDecryptMsg: Rpc Message Length bytes = " + << bytes_read << std::endl;) + DRILL_MT_LOG(DRILL_LOG(LOG_TRACE) << "DrillClientImpl::readAndDecryptMsg: Rpc Message Length = " + << rpcMsgLen << std::endl;) + + if(rpcMsgLen == 0) { + Utils::freeBuffer(inBuf, LEN_PREFIX_BUFLEN); + return handleQryError(QRY_INTERNAL_ERROR, getMessage(ERR_QRY_INVREADLEN), NULL); } - DRILL_MT_LOG(DRILL_LOG(LOG_TRACE) << "Done decoding chunk. Coordination id: " <<msg.m_coord_id<< std::endl;) - }else{ - Utils::freeBuffer(_buf, LEN_PREFIX_BUFLEN); - return handleQryError(QRY_COMM_ERROR, - getMessage(ERR_QRY_COMMERR, error.message().c_str()), NULL); + // Allocate a buffer for storing full RPC message. This is released by the caller + *allocatedBuffer = new AllocatedBuffer(rpcMsgLen); + + if(*allocatedBuffer == NULL){ + Utils::freeBuffer(inBuf, LEN_PREFIX_BUFLEN); + return handleQryError(QRY_CLIENT_OUTOFMEM, getMessage(ERR_QRY_OUTOFMEM), NULL); + } + + DRILL_MT_LOG(DRILL_LOG(LOG_TRACE) << "DrillClientImpl::readAndDecryptMsg: Allocated and locked buffer:" + << "[ " << *allocatedBuffer << ", size = " << rpcMsgLen << " ]\n";) + + bytesToRead = rpcMsgLen; } - }else{ - // got a message with an invalid read length. - Utils::freeBuffer(_buf, LEN_PREFIX_BUFLEN); - return handleQryError(QRY_INTERNAL_ERROR, getMessage(ERR_QRY_INVREADLEN), NULL); + + // Update the leftover bytes that is not copied yet + leftover = unWrappedLen - bytes_read; + + // Copy rest of decrypted message to the buffer. We can do this since it is assured that one + // entire decrypted chunk is part of the same RPC message. + if(leftover) { + memcpy((*allocatedBuffer)->m_pBuffer + writeIndex, unWrappedData + bytes_read, leftover); + } + + // Update bytes left to read to form full RPC message. + bytesToRead -= leftover; + writeIndex += leftover; + + DRILL_MT_LOG(DRILL_LOG(LOG_TRACE) << "DrillClientImpl::readAndDecryptMsg: Left to read unencrypted data" + << " of length (bytesToRead) : " << bytesToRead << std::endl;) + + if(bytesToRead > 0) { + // Read synchronously buffer of size LEN_PREFIX_BUFLEN to get length of next chunk + doReadFromSocket(inBuf, LEN_PREFIX_BUFLEN, error); + + if(error) { + Utils::freeBuffer(inBuf, LEN_PREFIX_BUFLEN); + return handleQryError(QRY_COMM_ERROR, getMessage(ERR_QRY_COMMERR, error.message().c_str()), NULL); + } + } + }while(bytesToRead > 0); // more chunks to read for entire RPC message + + DRILL_MT_LOG(DRILL_LOG(LOG_TRACE) << "DrillClientImpl::readAndDecryptMsg: Done decrypting entire RPC message " + << " of length: " << rpcMsgLen << ". Now starting decode:" << std::endl;) + + // Decode the buffer and form a RPC message + if (!decode((*allocatedBuffer)->m_pBuffer, rpcMsgLen, msg)) { + Utils::freeBuffer(inBuf, LEN_PREFIX_BUFLEN); + return handleQryError(QRY_COMM_ERROR, getMessage(ERR_QRY_COMMERR, + "Cannot decode server message into valid RPC message"), NULL); } + + DRILL_MT_LOG(DRILL_LOG(LOG_TRACE) << "Successfully created a RPC message with Coordination id: " + << msg.m_coord_id << std::endl;) } - DRILL_MT_LOG(DRILL_LOG(LOG_TRACE) << "DrillClientImpl::readMsg: Free buffer " - << reinterpret_cast<int*>(_buf) << std::endl;) - Utils::freeBuffer(_buf, LEN_PREFIX_BUFLEN); + + DRILL_MT_LOG(DRILL_LOG(LOG_TRACE) << "DrillClientImpl::readAndDecryptMsg: Free buffer " + << reinterpret_cast<int*>(inBuf) << std::endl;) + Utils::freeBuffer(inBuf, LEN_PREFIX_BUFLEN); return QRY_SUCCESS; } @@ -1364,15 +1809,15 @@ status_t DrillClientImpl::processServerMetaResult(AllocatedBufferPtr allocatedBu std::map<int,DrillClientQueryHandle*>::const_iterator it=this->m_queryHandles.find(msg.m_coord_id); if(it!=this->m_queryHandles.end()){ DrillClientServerMetaHandle* pHandle=static_cast<DrillClientServerMetaHandle*>((*it).second); + exec::user::GetServerMetaResp* resp = new exec::user::GetServerMetaResp(); DRILL_MT_LOG(DRILL_LOG(LOG_TRACE) << "Received GetServerMetaResp result Handle " << msg.m_pbody.size() << std::endl;) - exec::user::GetServerMetaResp resp; - if (!(resp.ParseFromArray(msg.m_pbody.data(), msg.m_pbody.size()))) { + if (!(resp->ParseFromArray(msg.m_pbody.data(), msg.m_pbody.size()))) { return handleQryError(QRY_COMM_ERROR, "Cannot decode GetServerMetaResp results", pHandle); } - if (resp.status() != exec::user::OK) { - return handleQryError(QRY_FAILED, resp.error(), pHandle); + if (resp->status() != exec::user::OK) { + return handleQryError(QRY_FAILED, resp->error(), pHandle); } - pHandle->notifyListener(&(resp.server_meta()), NULL); + pHandle->notifyListener(&(resp->server_meta()), NULL); DRILL_MT_LOG(DRILL_LOG(LOG_DEBUG) << "GetServerMetaResp result " << std::endl;) }else{ return handleQryError(QRY_INTERNAL_ERROR, getMessage(ERR_QRY_INVQUERYID), NULL); @@ -1484,11 +1929,11 @@ void DrillClientImpl::handleReadTimeout(const boost::system::error_code & err){ return; } -void DrillClientImpl::handleRead(ByteBuf_t _buf, +void DrillClientImpl::handleRead(ByteBuf_t inBuf, const boost::system::error_code& error, size_t bytes_transferred) { DRILL_MT_LOG(DRILL_LOG(LOG_TRACE) << "DrillClientImpl::handleRead: Handle Read from buffer " - << reinterpret_cast<int*>(_buf) << std::endl;) + << reinterpret_cast<int*>(inBuf) << std::endl;) if(DrillClientConfig::getQueryTimeout() > 0){ // Cancel the timeout if handleRead is called DRILL_MT_LOG(DRILL_LOG(LOG_TRACE) << "DrillClientImpl::handleRead: Cancel deadline timer.\n";) @@ -1496,7 +1941,7 @@ void DrillClientImpl::handleRead(ByteBuf_t _buf, } if (error) { // boost error - Utils::freeBuffer(_buf, LEN_PREFIX_BUFLEN); + Utils::freeBuffer(inBuf, LEN_PREFIX_BUFLEN); boost::lock_guard<boost::mutex> lock(this->m_dcMutex); DRILL_MT_LOG(DRILL_LOG(LOG_TRACE) << "DrillClientImpl::handleRead: ERR_QRY_COMMERR. " "Boost Communication Error: " << error.message() << std::endl;) @@ -1510,7 +1955,7 @@ void DrillClientImpl::handleRead(ByteBuf_t _buf, DRILL_MT_LOG(DRILL_LOG(LOG_TRACE) << "Getting new message" << std::endl;) AllocatedBufferPtr allocatedBuffer=NULL; - if(readMsg(_buf, &allocatedBuffer, msg)!=QRY_SUCCESS){ + if((this->*m_fpCurrentReadMsgHandler)(inBuf, &allocatedBuffer, msg)!=QRY_SUCCESS){ delete allocatedBuffer; if(m_pendingRequests!=0){ boost::lock_guard<boost::mutex> lock(this->m_dcMutex); @@ -1655,6 +2100,9 @@ status_t DrillClientImpl::validateResultMessage(const rpc::InBoundRpcMessage& ms return QRY_SUCCESS; } +/* + * Called when there is failure in connect/send. + */ connectionStatus_t DrillClientImpl::handleConnError(connectionStatus_t status, const std::string& msg){ DrillClientError* pErr = new DrillClientError(status, DrillClientError::CONN_ERROR_START+status, msg); m_pendingRequests=0; @@ -1669,19 +2117,28 @@ connectionStatus_t DrillClientImpl::handleConnError(connectionStatus_t status, c return status; } +/* + * Always called with NULL QueryHandle when there is any error while reading data from socket. Once enough data is read + * and a valid RPC message is formed then it can get called with NULL/valid QueryHandle depending on if QueryHandle is found + * for the created RPC message. + */ status_t DrillClientImpl::handleQryError(status_t status, const std::string& msg, DrillClientQueryHandle* pQueryHandle){ DrillClientError* pErr = new DrillClientError(status, DrillClientError::QRY_ERROR_START+status, msg); - // set query error only if queries are running + // Set query error only if queries are running. If valid QueryHandle that means the bytes to form a valid + // RPC message was read successfully from socket. So there is no socket/connection issues. if(pQueryHandle!=NULL){ m_pendingRequests--; pQueryHandle->signalError(pErr); - }else{ + }else{ // This means error was while reading from socket, hence call broadcastError which eventually closes socket. m_pendingRequests=0; broadcastError(pErr); } return status; } +/* + * Always called with valid QueryHandle when there is any error processing Query related data. + */ status_t DrillClientImpl::handleQryError(status_t status, const exec::shared::DrillPBError& e, DrillClientQueryHandle* pQueryHandle){ @@ -1766,7 +2223,7 @@ void DrillClientImpl::sendAck(const rpc::InBoundRpcMessage& msg, bool isOk){ ack.set_ok(isOk); rpc::OutBoundRpcMessage ack_msg(exec::rpc::RESPONSE, exec::user::ACK, msg.m_coord_id, &ack); boost::lock_guard<boost::mutex> lock(m_dcMutex); - sendSync(ack_msg); + sendSyncCommon(ack_msg); DRILL_MT_LOG(DRILL_LOG(LOG_TRACE) << "ACK sent" << std::endl;) } @@ -1774,7 +2231,7 @@ void DrillClientImpl::sendCancel(const exec::shared::QueryId* pQueryId){ boost::lock_guard<boost::mutex> lock(m_dcMutex); uint64_t coordId = this->getNextCoordinationId(); rpc::OutBoundRpcMessage cancel_msg(exec::rpc::REQUEST, exec::user::CANCEL_QUERY, coordId, pQueryId); - sendSync(cancel_msg); + sendSyncCommon(cancel_msg); DRILL_MT_LOG(DRILL_LOG(LOG_TRACE) << "CANCEL sent" << std::endl;) } @@ -1783,6 +2240,21 @@ void DrillClientImpl::shutdownSocket(){ boost::system::error_code ignorederr; m_socket.shutdown(boost::asio::ip::tcp::socket::shutdown_both, ignorederr); m_bIsConnected=false; + + // Delete the saslAuthenticatorImpl instance since connection is broken. It will recreated on next + // call to connect. + if(m_saslAuthenticator != NULL) { + delete m_saslAuthenticator; + m_saslAuthenticator = NULL; + } + + // Reset the SASL states. + m_saslDone = false; + m_saslResultCode = SASL_OK; + + // Reset the encryption context since connection is invalid + m_encryptionCtxt.reset(); + DRILL_MT_LOG(DRILL_LOG(LOG_TRACE) << "Socket shutdown" << std::endl;) } @@ -1799,8 +2271,6 @@ struct ServerMetaContext { boost::mutex m_mutex; boost::condition_variable m_cv; - ServerMetaContext(): m_done(false), m_status(QRY_SUCCESS), m_serverMeta(), m_mutex(), m_cv() {}; - static status_t listener(void* ctx, const exec::user::ServerMeta* serverMeta, DrillClientError* err) { ServerMetaContext* context = static_cast<ServerMetaContext*>(ctx); if (err) { |