From 631dae1b53d8de33975729de841e632b004ff01f Mon Sep 17 00:00:00 2001 From: Hayden McAfee Date: Tue, 16 Feb 2021 00:47:16 -0800 Subject: [PATCH] =?UTF-8?q?=E2=8F=B0=20Time=20out=20connections=20that=20d?= =?UTF-8?q?on't=20auth=20(#27)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Fixes #26 This change will time out connections that don't finish their TLS handshake within a set time, preventing the server from getting hung up trying to negotiate dead connections. Eventually we should negotiate these connections without blocking the `accept` thread, but there would be a non-trivial amount of work to update the tests to the new behavior. --- inc/TlsConnectionTransport.h | 18 ++++++++++++++++++ 1 file changed, 18 insertions(+) diff --git a/inc/TlsConnectionTransport.h b/inc/TlsConnectionTransport.h index 4626298..2285c36 100644 --- a/inc/TlsConnectionTransport.h +++ b/inc/TlsConnectionTransport.h @@ -14,6 +14,7 @@ #include #include +#include #include #include #include @@ -180,6 +181,8 @@ class TlsConnectionTransport : public IConnectionTransport private: /* Static members */ static constexpr int BUFFER_SIZE = 512; + static constexpr std::chrono::milliseconds CONNECT_TIMEOUT = + std::chrono::milliseconds(2500); /* Private members */ const bool isServer; const int socketHandle; @@ -241,10 +244,25 @@ class TlsConnectionTransport : public IConnectionTransport // Indicate when we've exited this thread connectionThreadEndedPromise.set_value_at_thread_exit(); + // Keep track of how long it takes to connect, so we can time out + std::chrono::time_point connectStartTime = + std::chrono::steady_clock::now(); + // First, we need to connect. int connectResult = isServer ? SSL_accept(ssl.get()) : SSL_connect(ssl.get()); while (connectResult == -1) { + // Have we taken too long? + auto elapsedTime = (std::chrono::steady_clock::now() - connectStartTime); + if (elapsedTime > CONNECT_TIMEOUT) + { + // Whoops, took too long to connect. + spdlog::debug("{} SSL negotiation timed out"); + sslConnectedPromise.set_value(false); + closeConnection(); + return; + } + // We're not done connecting yet - figure out what we're waiting on int connectError = SSL_get_error(ssl.get(), connectResult); if (connectError == SSL_ERROR_WANT_READ)