diff --git a/src/base/ServerConnection.cpp b/src/base/ServerConnection.cpp index 6295eb197..27d1e5758 100644 --- a/src/base/ServerConnection.cpp +++ b/src/base/ServerConnection.cpp @@ -21,7 +21,8 @@ bool ServerConnection::acceptNewConnection(int fd) { return false; } VLOG(1) << "SERVER: got client socket fd: " << clientSocketFd; - clientHandlerThreadPool.push([&, this](int id) { clientHandler(clientSocketFd); }); + clientHandlerThreadPool.push( + [&, this](int id) { clientHandler(clientSocketFd); }); return true; } diff --git a/src/base/SocketEndpoint.hpp b/src/base/SocketEndpoint.hpp index 7cce98d5e..83abcf82a 100644 --- a/src/base/SocketEndpoint.hpp +++ b/src/base/SocketEndpoint.hpp @@ -6,6 +6,8 @@ namespace et { class SocketEndpoint { public: + SocketEndpoint() : name(""), port(-1) {} + explicit SocketEndpoint(const string &_name) : name(_name), port(-1) {} explicit SocketEndpoint(int _port) : name(""), port(_port) {} diff --git a/test/ConnectionTest.cpp b/test/ConnectionTest.cpp index b88392a9d..7640e4a6e 100644 --- a/test/ConnectionTest.cpp +++ b/test/ConnectionTest.cpp @@ -59,10 +59,10 @@ class Collector { } void finish() { + connection->shutdown(); lock_guard guard(collectorMutex); done = true; collectorThread->join(); - connection->shutdown(); } bool hasData() { @@ -110,12 +110,17 @@ void listenFn(bool* stopListening, int serverFd, } } -shared_ptr serverClientConnection; +map> serverClientConnections; class NewConnectionHandler : public ServerConnectionHandler { public: virtual bool newClient( shared_ptr _serverClientState) { - serverClientConnection = _serverClientState; + string clientId = _serverClientState->getId(); + if (serverClientConnections.find(clientId) != + serverClientConnections.end()) { + LOG(FATAL) << "TRIED TO CREATE DUPLICATE CLIENT ID"; + } + serverClientConnections[clientId] = _serverClientState; return true; } }; @@ -124,56 +129,60 @@ class ConnectionTest : public testing::Test { protected: void SetUp() override { el::Helpers::setThreadName("Main"); - const string CRYPTO_KEY = "12345678901234567890123456789012"; - const string CLIENT_ID = "1234567890123456"; string tmpPath = string("/tmp/et_test_XXXXXXXX"); pipeDirectory = string(mkdtemp(&tmpPath[0])); pipePath = string(pipeDirectory) + "/pipe"; - SocketEndpoint endpoint(pipePath); + endpoint = SocketEndpoint(pipePath); serverConnection.reset(new ServerConnection( serverSocketHandler, endpoint, shared_ptr(new NewConnectionHandler()))); - serverConnection->addClientKey(CLIENT_ID, CRYPTO_KEY); int serverFd = *(serverSocketHandler->getEndpointFds(endpoint).begin()); stopListening = false; serverListenThread.reset( new std::thread(listenFn, &stopListening, serverFd, serverConnection)); + } + void TearDown() override { + stopListening = true; + serverListenThread->join(); + serverClientConnections.clear(); + FATAL_FAIL(::remove(pipePath.c_str())); + FATAL_FAIL(::remove(pipeDirectory.c_str())); + } + + void readWriteTest(const string& clientId) { + serverConnection->addClientKey(clientId, CRYPTO_KEY); // Wait for server to spin up ::usleep(1000 * 1000); - clientConnection.reset(new ClientConnection(clientSocketHandler, endpoint, - CLIENT_ID, CRYPTO_KEY)); + + shared_ptr clientConnection(new ClientConnection( + clientSocketHandler, endpoint, clientId, CRYPTO_KEY)); while (true) { try { clientConnection->connect(); break; } catch (const std::runtime_error& ex) { LOG(INFO) << "Connection failed, retrying..."; - ::usleep(10 * 1000); + ::usleep(1000 * 1000); } } - while(serverClientConnection.get() == NULL) { + while (serverClientConnections.find(clientId) == + serverClientConnections.end()) { ::usleep(1000 * 1000); } - serverCollector.reset(new Collector( - std::static_pointer_cast(serverClientConnection), "Server")); - serverClientConnection.reset(); + shared_ptr serverCollector( + new Collector(std::static_pointer_cast( + serverClientConnections.find(clientId)->second), + "Server")); serverCollector->start(); - clientCollector.reset( - new Collector(std::static_pointer_cast(clientConnection), "Client")); + shared_ptr clientCollector(new Collector( + std::static_pointer_cast(clientConnection), "Client")); clientCollector->start(); - } - void TearDown() override { - FATAL_FAIL(::remove(pipePath.c_str())); - FATAL_FAIL(::remove(pipeDirectory.c_str())); - } - - void readWriteTest() { const int NUM_MESSAGES = 32; string s(NUM_MESSAGES * 1024, '\0'); for (int a = 0; a < NUM_MESSAGES * 1024; a++) { @@ -196,10 +205,6 @@ class ConnectionTest : public testing::Test { result = clientCollector->read(); EXPECT_EQ(result, "DONE"); - clientConnection->shutdown(); - serverConnection->shutdown(); - stopListening = true; - serverListenThread->join(); serverCollector->finish(); clientCollector->finish(); @@ -209,13 +214,12 @@ class ConnectionTest : public testing::Test { shared_ptr serverSocketHandler; shared_ptr clientSocketHandler; shared_ptr serverConnection; - shared_ptr clientConnection; - shared_ptr serverCollector; - shared_ptr clientCollector; shared_ptr serverListenThread; string pipeDirectory; string pipePath; + SocketEndpoint endpoint; bool stopListening; + const string CRYPTO_KEY = "12345678901234567890123456789012"; }; class ReliableConnectionTest : public ConnectionTest { @@ -230,7 +234,20 @@ class ReliableConnectionTest : public ConnectionTest { } }; -TEST_F(ReliableConnectionTest, ReadWrite) { readWriteTest(); } +TEST_F(ReliableConnectionTest, ReadWrite) { readWriteTest("1234567890123456"); } + +TEST_F(ReliableConnectionTest, MultiReadWrite) { + thread_pool pool(16); + string base_id = "1234567890123456"; + for (int a = 0; a < 16; a++) { + string new_id = base_id; + new_id[0] = 'A' + a; + pool.push([&, this](int id, string clientId) { readWriteTest(clientId); }, + new_id); + ::usleep((500 + rand() % 1000) * 1000); + } + pool.stop(true); +} class FlakyConnectionTest : public ConnectionTest { protected: @@ -252,4 +269,20 @@ class FlakyConnectionTest : public ConnectionTest { } }; -TEST_F(FlakyConnectionTest, ReadWrite) { readWriteTest(); } +TEST_F(FlakyConnectionTest, ReadWrite) { + const string clientId = "1234567890123456"; + readWriteTest(clientId); +} + +TEST_F(FlakyConnectionTest, MultiReadWrite) { + thread_pool pool(16); + string base_id = "1234567890123456"; + for (int a = 0; a < 16; a++) { + string new_id = base_id; + new_id[0] = 'A' + a; + pool.push([&, this](int id, string clientId) { readWriteTest(clientId); }, + new_id); + ::usleep((500 + rand() % 1000) * 1000); + } + pool.stop(true); +}