Skip to content

Commit

Permalink
multi-client tests (#148)
Browse files Browse the repository at this point in the history
  • Loading branch information
MisterTea authored Oct 5, 2018
1 parent 16fa301 commit d862d1a
Show file tree
Hide file tree
Showing 3 changed files with 69 additions and 33 deletions.
3 changes: 2 additions & 1 deletion src/base/ServerConnection.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,8 @@ bool ServerConnection::acceptNewConnection(int fd) {
return false;
}
VLOG(1) << "SERVER: got client socket fd: " << clientSocketFd;
clientHandlerThreadPool.push([&, this](int id) { clientHandler(clientSocketFd); });
clientHandlerThreadPool.push(
[&, this](int id) { clientHandler(clientSocketFd); });
return true;
}

Expand Down
2 changes: 2 additions & 0 deletions src/base/SocketEndpoint.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,8 @@
namespace et {
class SocketEndpoint {
public:
SocketEndpoint() : name(""), port(-1) {}

explicit SocketEndpoint(const string &_name) : name(_name), port(-1) {}

explicit SocketEndpoint(int _port) : name(""), port(_port) {}
Expand Down
97 changes: 65 additions & 32 deletions test/ConnectionTest.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -59,10 +59,10 @@ class Collector {
}

void finish() {
connection->shutdown();
lock_guard<std::mutex> guard(collectorMutex);
done = true;
collectorThread->join();
connection->shutdown();
}

bool hasData() {
Expand Down Expand Up @@ -110,12 +110,17 @@ void listenFn(bool* stopListening, int serverFd,
}
}

shared_ptr<ServerClientConnection> serverClientConnection;
map<string, shared_ptr<ServerClientConnection>> serverClientConnections;
class NewConnectionHandler : public ServerConnectionHandler {
public:
virtual bool newClient(
shared_ptr<ServerClientConnection> _serverClientState) {
serverClientConnection = _serverClientState;
string clientId = _serverClientState->getId();
if (serverClientConnections.find(clientId) !=
serverClientConnections.end()) {
LOG(FATAL) << "TRIED TO CREATE DUPLICATE CLIENT ID";
}
serverClientConnections[clientId] = _serverClientState;
return true;
}
};
Expand All @@ -124,56 +129,60 @@ class ConnectionTest : public testing::Test {
protected:
void SetUp() override {
el::Helpers::setThreadName("Main");
const string CRYPTO_KEY = "12345678901234567890123456789012";
const string CLIENT_ID = "1234567890123456";

string tmpPath = string("/tmp/et_test_XXXXXXXX");
pipeDirectory = string(mkdtemp(&tmpPath[0]));
pipePath = string(pipeDirectory) + "/pipe";
SocketEndpoint endpoint(pipePath);
endpoint = SocketEndpoint(pipePath);

serverConnection.reset(new ServerConnection(
serverSocketHandler, endpoint,
shared_ptr<ServerConnectionHandler>(new NewConnectionHandler())));
serverConnection->addClientKey(CLIENT_ID, CRYPTO_KEY);

int serverFd = *(serverSocketHandler->getEndpointFds(endpoint).begin());
stopListening = false;
serverListenThread.reset(
new std::thread(listenFn, &stopListening, serverFd, serverConnection));
}

void TearDown() override {
stopListening = true;
serverListenThread->join();
serverClientConnections.clear();
FATAL_FAIL(::remove(pipePath.c_str()));
FATAL_FAIL(::remove(pipeDirectory.c_str()));
}

void readWriteTest(const string& clientId) {
serverConnection->addClientKey(clientId, CRYPTO_KEY);
// Wait for server to spin up
::usleep(1000 * 1000);
clientConnection.reset(new ClientConnection(clientSocketHandler, endpoint,
CLIENT_ID, CRYPTO_KEY));

shared_ptr<ClientConnection> clientConnection(new ClientConnection(
clientSocketHandler, endpoint, clientId, CRYPTO_KEY));
while (true) {
try {
clientConnection->connect();
break;
} catch (const std::runtime_error& ex) {
LOG(INFO) << "Connection failed, retrying...";
::usleep(10 * 1000);
::usleep(1000 * 1000);
}
}

while(serverClientConnection.get() == NULL) {
while (serverClientConnections.find(clientId) ==
serverClientConnections.end()) {
::usleep(1000 * 1000);
}
serverCollector.reset(new Collector(
std::static_pointer_cast<Connection>(serverClientConnection), "Server"));
serverClientConnection.reset();
shared_ptr<Collector> serverCollector(
new Collector(std::static_pointer_cast<Connection>(
serverClientConnections.find(clientId)->second),
"Server"));
serverCollector->start();
clientCollector.reset(
new Collector(std::static_pointer_cast<Connection>(clientConnection), "Client"));
shared_ptr<Collector> clientCollector(new Collector(
std::static_pointer_cast<Connection>(clientConnection), "Client"));
clientCollector->start();
}

void TearDown() override {
FATAL_FAIL(::remove(pipePath.c_str()));
FATAL_FAIL(::remove(pipeDirectory.c_str()));
}

void readWriteTest() {
const int NUM_MESSAGES = 32;
string s(NUM_MESSAGES * 1024, '\0');
for (int a = 0; a < NUM_MESSAGES * 1024; a++) {
Expand All @@ -196,10 +205,6 @@ class ConnectionTest : public testing::Test {
result = clientCollector->read();
EXPECT_EQ(result, "DONE");

clientConnection->shutdown();
serverConnection->shutdown();
stopListening = true;
serverListenThread->join();
serverCollector->finish();
clientCollector->finish();

Expand All @@ -209,13 +214,12 @@ class ConnectionTest : public testing::Test {
shared_ptr<SocketHandler> serverSocketHandler;
shared_ptr<SocketHandler> clientSocketHandler;
shared_ptr<ServerConnection> serverConnection;
shared_ptr<ClientConnection> clientConnection;
shared_ptr<Collector> serverCollector;
shared_ptr<Collector> clientCollector;
shared_ptr<std::thread> serverListenThread;
string pipeDirectory;
string pipePath;
SocketEndpoint endpoint;
bool stopListening;
const string CRYPTO_KEY = "12345678901234567890123456789012";
};

class ReliableConnectionTest : public ConnectionTest {
Expand All @@ -230,7 +234,20 @@ class ReliableConnectionTest : public ConnectionTest {
}
};

TEST_F(ReliableConnectionTest, ReadWrite) { readWriteTest(); }
TEST_F(ReliableConnectionTest, ReadWrite) { readWriteTest("1234567890123456"); }

TEST_F(ReliableConnectionTest, MultiReadWrite) {
thread_pool pool(16);
string base_id = "1234567890123456";
for (int a = 0; a < 16; a++) {
string new_id = base_id;
new_id[0] = 'A' + a;
pool.push([&, this](int id, string clientId) { readWriteTest(clientId); },
new_id);
::usleep((500 + rand() % 1000) * 1000);
}
pool.stop(true);
}

class FlakyConnectionTest : public ConnectionTest {
protected:
Expand All @@ -252,4 +269,20 @@ class FlakyConnectionTest : public ConnectionTest {
}
};

TEST_F(FlakyConnectionTest, ReadWrite) { readWriteTest(); }
TEST_F(FlakyConnectionTest, ReadWrite) {
const string clientId = "1234567890123456";
readWriteTest(clientId);
}

TEST_F(FlakyConnectionTest, MultiReadWrite) {
thread_pool pool(16);
string base_id = "1234567890123456";
for (int a = 0; a < 16; a++) {
string new_id = base_id;
new_id[0] = 'A' + a;
pool.push([&, this](int id, string clientId) { readWriteTest(clientId); },
new_id);
::usleep((500 + rand() % 1000) * 1000);
}
pool.stop(true);
}

0 comments on commit d862d1a

Please sign in to comment.