diff --git a/include/oxen/quic/connection.hpp b/include/oxen/quic/connection.hpp index 2a76adf5..e23256dc 100644 --- a/include/oxen/quic/connection.hpp +++ b/include/oxen/quic/connection.hpp @@ -129,9 +129,9 @@ namespace oxen::quic template void send_datagram(std::vector&& buf) { - send_datagram( - std::basic_string_view{buf.data(), buf.size()}, - std::make_shared>(std::move(buf))); + auto keep_alive = std::make_shared>(std::move(buf)); + std::basic_string_view view{keep_alive->data(), keep_alive->size()}; + send_datagram(view, std::move(keep_alive)); } template diff --git a/src/stream.cpp b/src/stream.cpp index 4a6f8e30..9f736129 100644 --- a/src/stream.cpp +++ b/src/stream.cpp @@ -214,7 +214,7 @@ namespace oxen::quic { return endpoint.call_get([this]() { return _is_writing; }); } - + std::shared_ptr Stream::get_conn_interface() { return std::static_pointer_cast(_conn->shared_from_this()); diff --git a/tests/001-handshake.cpp b/tests/001-handshake.cpp index c103114e..23e1f2d0 100644 --- a/tests/001-handshake.cpp +++ b/tests/001-handshake.cpp @@ -3,6 +3,7 @@ #include #include +#include "tcp.hpp" #include "utils.hpp" namespace oxen::quic::test diff --git a/tests/tcp.hpp b/tests/tcp.hpp index 1491d8fc..77b1327d 100644 --- a/tests/tcp.hpp +++ b/tests/tcp.hpp @@ -32,14 +32,35 @@ namespace oxen::quic class TCPHandle; inline const auto LOCALHOST = "127.0.0.1"s; - inline const auto TUNNEL_SEED = oxenc::from_hex("0000000000000000000000000000000000000000000000000000000000000000"); - inline const auto TUNNEL_PUBKEY = oxenc::from_hex("3b6a27bcceb6a42d62a3a8d02a6f0d73653215771de243a63ac048a18b59da29"); + inline constexpr auto TUNNEL_SEED = "0000000000000000000000000000000000000000000000000000000000000000"_hex; + inline constexpr auto TUNNEL_PUBKEY = "3b6a27bcceb6a42d62a3a8d02a6f0d73653215771de243a63ac048a18b59da29"_hex; + + inline constexpr size_t HIGH_WATERMARK{4_Mi}; + inline constexpr size_t LOW_WATERMARK{HIGH_WATERMARK / 4}; + + inline std::vector serialize_payload(bstring_view data, uint16_t port = 0) + { + std::vector ret(data.size() + sizeof(port)); + oxenc::write_host_as_big(port, ret.data()); + std::memcpy(&ret[2], data.data(), data.size()); + return ret; + } + + inline std::tuple deserialize_payload(bstring data) + { + uint16_t p = oxenc::load_big_to_host(data.data()); + + return {p, data.substr(2)}; + } struct TCPQUIC { std::shared_ptr _ci; + std::unordered_set> t; + // keyed against backend tcp address + std::unordered_map>> _tcp_conns2; std::unordered_map> _tcp_conns; }; @@ -58,17 +79,62 @@ namespace oxen::quic evconnlistener_free(e); }; + void tcp_drained_write_cb(struct bufferevent* bev, void* user_arg); + void tcp_read_cb(struct bufferevent* bev, void* user_arg); + void tcp_event_cb(struct bufferevent* bev, short what, void* user_arg); + void tcp_listen_cb( struct evconnlistener* listener, evutil_socket_t fd, struct sockaddr* src, int socklen, void* user_arg); + void tcp_err_cb(struct evconnlistener* listener, void* user_arg); struct TCPConnection { TCPConnection(struct bufferevent* _bev, evutil_socket_t _fd, std::shared_ptr _s) : bev{_bev}, fd{_fd}, stream{std::move(_s)} - {} + { + stream->set_stream_data_cb([this](oxen::quic::Stream& s, bstring_view data) { + auto rv = bev ? bufferevent_write(bev, data.data(), data.size()) : -1; + log::info( + test_cat, + "Stream (id: {}) {} {}B to TCP buffer", + s.stream_id(), + rv < 0 ? "failed to write" : "successfully wrote", + data.size()); + + // we get the output buffer (it sounds backwards but it isn't) + if (evbuffer_get_length(bufferevent_get_output(bev)) >= HIGH_WATERMARK) + { + log::info( + test_cat, "TCP input buffer over high-water threshold ({}); pausing stream...", HIGH_WATERMARK); + s.pause(); + + bufferevent_setcb(bev, tcp_read_cb, tcp_drained_write_cb, tcp_event_cb, this); + bufferevent_setwatermark(bev, EV_WRITE, LOW_WATERMARK, HIGH_WATERMARK); + } + }); + + stream->set_stream_close_cb([this](Stream&, uint64_t) { + log::info( + test_cat, + "Stream closed cb fired, {}...", + bev ? "freeing bufferevent" : "bufferevent already freed"); + if (bev) + bufferevent_free(bev); + }); + + stream->set_remote_reset_hooks(opt::remote_stream_reset{ + [](Stream& s, uint64_t) { + log::info(test_cat, "Remote stream signalled reading termination; halting local stream write!"); + s.stop_writing(); + }, + [](Stream& s, uint64_t) { + log::info(test_cat, "Remote stream signalled writing termination; halting local stream read!"); + s.stop_reading(); + }}); + } TCPConnection() = delete; @@ -163,37 +229,18 @@ namespace oxen::quic // returns the socket address of the TCP connection std::optional
connect() const { return _connect; } - std::shared_ptr connect_to_backend(std::shared_ptr s, Address addr) + std::shared_ptr connect_to_backend(std::shared_ptr stream, Address addr) { if (addr.port() == 0) throw std::runtime_error{"TCP backend must have valid port on localhost!"}; - log::critical(test_cat, "Attempting TCP connection to backend at: {}", addr); + log::info(test_cat, "Attempting TCP connection to backend at: {}", addr); sockaddr_in _addr = addr.in4(); struct bufferevent* _bev = bufferevent_socket_new(_ev->loop().get(), -1, BEV_OPT_CLOSE_ON_FREE | BEV_OPT_THREADSAFE); - s->set_stream_data_cb([_bev](oxen::quic::Stream& s, bstring_view data) { - auto rv = _bev ? bufferevent_write(_bev, data.data(), data.size()) : -1; - log::info( - test_cat, - "Stream (id: {}) {} {}B to TCP buffer", - s.stream_id(), - rv < 0 ? "failed to write" : "successfully wrote", - data.size()); - }); - - s->set_stream_close_cb([_bev](Stream&, uint64_t) { - log::critical( - test_cat, - "Stream closed cb fired, {}...", - _bev ? "freeing bufferevent" : "bufferevent already freed"); - if (_bev) - bufferevent_free(_bev); - }); - - auto tcp_conn = std::make_shared(_bev, -1, std::move(s)); + auto tcp_conn = std::make_shared(_bev, -1, std::move(stream)); bufferevent_setcb(_bev, tcp_read_cb, nullptr, tcp_event_cb, tcp_conn.get()); bufferevent_enable(_bev, EV_READ | EV_WRITE); @@ -260,13 +307,25 @@ namespace oxen::quic evconnlistener_set_error_cb(_tcp_listener.get(), tcp_err_cb); - log::critical(test_cat, "TCPHandle set up listener on: {}", *_bound); + log::info(test_cat, "TCPHandle set up listener on: {}", *_bound); } }; + inline void tcp_drained_write_cb(struct bufferevent* bev, void* user_arg) + { + bufferevent_setcb(bev, tcp_read_cb, nullptr, tcp_event_cb, user_arg); + bufferevent_setwatermark(bev, EV_WRITE, 0, 0); + + auto* conn = reinterpret_cast(user_arg); + assert(conn); + + log::info(test_cat, "TCP input buffer below low-water threshold ({}); resuming stream!", LOW_WATERMARK); + conn->stream->resume(); + } + inline void tcp_read_cb(struct bufferevent* bev, void* user_arg) { - std::array buf{}; + std::array buf{}; // Load data from input buffer to local buffer auto nwrite = bufferevent_read(bev, buf.data(), buf.size()); @@ -277,22 +336,38 @@ namespace oxen::quic { auto* conn = reinterpret_cast(user_arg); assert(conn); + auto& stream = conn->stream; + assert(stream); + + stream->send(ustring{buf.data(), nwrite}); - conn->stream->send(ustring{(buf.data()), nwrite}); + if (stream->unsent() >= HIGH_WATERMARK) + { + stream->set_watermark( + LOW_WATERMARK, + HIGH_WATERMARK, + opt::watermark{ + [bev](Stream&) { + log::info(test_cat, "Stream buffer below low-water threshold; enabling TCP read!"); + bufferevent_enable(bev, EV_READ); + }, + false}, + opt::watermark{ + [bev](Stream&) { + log::info(test_cat, "Stream buffer above high-water threshold; disabling TCP read!"); + bufferevent_disable(bev, EV_READ); + }, + false}); + } } } inline void tcp_event_cb([[maybe_unused]] struct bufferevent* bev, short what, void* user_arg) { - // this is where the InboundSession confirms it established a TCP connection to the backend app if (what & BEV_EVENT_CONNECTED) { log::info(test_cat, "TCP connect operation succeeded!"); } - if (what & BEV_EVENT_EOF) - { - log::critical(test_cat, "TCP Connection EOF!"); - } if (what & BEV_EVENT_ERROR) { log::critical( @@ -300,25 +375,35 @@ namespace oxen::quic "TCP Connection encountered bufferevent error (msg: {})!", evutil_socket_error_to_string(EVUTIL_SOCKET_ERROR())); } - if (what & (BEV_EVENT_ERROR | BEV_EVENT_EOF)) - { - // if (bev) - // { - // log::critical(test_cat, "Freeing bufferevent socket..."); - // bufferevent_free(bev); - // } - auto* conn = reinterpret_cast(user_arg); - assert(conn); + auto* conn = reinterpret_cast(user_arg); + assert(conn); + auto& stream = conn->stream; + if (what & BEV_EVENT_EOF) + { + if (what & BEV_EVENT_WRITING) + { + // remote shut down reading + log::info(test_cat, "Remote TCP stopped reading! Halting stream write..."); + stream->stop_writing(); + } + else if (what & BEV_EVENT_READING) + { + // remote shut down writing + log::info(test_cat, "Error encountered while reading! Halting stream read..."); + stream->stop_reading(); + } + else + { + // remote closed connection + log::info(test_cat, "TCP Connection EOF!"); + } + } + if (what & (BEV_EVENT_ERROR | BEV_EVENT_EOF) and not(what & BEV_EVENT_READING) and not(what & BEV_EVENT_WRITING)) + { log::critical(test_cat, "Closing stream..."); - conn->stream->close(); - // auto& str = conn->stream; - // if (str and not str->is_closing()) - // { - // return str->close(); - // } - // log::critical(test_cat, "Stream for tcp connection already destroyed..."); + stream->close(); } } @@ -326,43 +411,17 @@ namespace oxen::quic struct evconnlistener* listener, evutil_socket_t fd, struct sockaddr* src, int socklen, void* user_arg) { oxen::quic::Address source{src, static_cast(socklen)}; - log::critical(test_cat, "TCP CONNECTION ESTABLISHED -- SRC: {}", source); + log::info(test_cat, "TCP CONNECTION ESTABLISHED -- SRC: {}", source); auto* b = evconnlistener_get_base(listener); auto* _bev = bufferevent_socket_new(b, fd, BEV_OPT_CLOSE_ON_FREE | BEV_OPT_THREADSAFE); - // int yes{1}; - // if (setsockopt(fd, SOL_SOCKET, SO_KEEPALIVE, &yes, sizeof(int)) < 0) - // { - // log::critical( - // test_cat, - // "Failed to set keepalive on accepted TCP connection socket: {}", - // evutil_socket_error_to_string(EVUTIL_SOCKET_ERROR())); - // return bufferevent_free(bevent); - // } - auto* handle = reinterpret_cast(user_arg); assert(handle); // make TCPConnection here! auto* conn = handle->_conn_maker(_bev, fd, std::move(source)); - - conn->stream->set_stream_data_cb([_bev](Stream& s, bstring_view data) { - auto rv = _bev ? bufferevent_write(_bev, data.data(), data.size()) : -1; - log::info( - test_cat, - "Stream (id: {}) {} {}B to TCP buffer", - s.stream_id(), - rv < 0 ? "failed to write" : "successfully wrote", - data.size()); - }); - - conn->stream->set_stream_close_cb([_bev](Stream&, uint64_t) { - log::critical( - test_cat, "Stream closed cb fired, {}...", _bev ? "freeing bufferevent" : "bufferevent already freed"); - if (_bev) - bufferevent_free(_bev); - }); + auto stream = conn->stream; bufferevent_setcb(_bev, tcp_read_cb, nullptr, tcp_event_cb, conn); bufferevent_enable(_bev, EV_READ | EV_WRITE); diff --git a/tests/tunnel-client.cpp b/tests/tunnel-client.cpp index 27d60abf..0f4aa3c3 100644 --- a/tests/tunnel-client.cpp +++ b/tests/tunnel-client.cpp @@ -36,7 +36,7 @@ int main(int argc, char* argv[]) std::vector
remote_addrs{{LOCALHOST, 4444}, {LOCALHOST, 4455}, {LOCALHOST, 4466}}; // TODO: make this a CLI arg and generate all the addresses? - const int num_conns{3}; + const int num_conns{1}; std::atomic current_conn{0}; @@ -55,23 +55,19 @@ int main(int argc, char* argv[]) for (auto& r : remote_addrs) connect_addrs.push_back(RemoteAddress{TUNNEL_PUBKEY, r}); - // Paths from manual client to remote manual server - std::vector paths{}; + // Paths from manual client to remote manual server keyed to remote port + std::unordered_map paths; - // for (auto& r : remote_addrs) - // paths.push_back(Path{manual_client_local, r}); + for (auto& r : remote_addrs) + paths.emplace(r.port(), Path{localhost_blank, r}); /** key: remote address to which we are connecting value: tunneled quic connection */ std::unordered_map _tunnels; - // callback_waiter initial_tunnel_established{ - // [](connection_interface&) { log::info(test_cat, "Initial tunnel established"); }}; - auto manual_client_established = [&](connection_interface& ci) { auto path = ci.path(); - paths.push_back(ci.path()); // make a copy for the list auto& remote = path.remote; auto _handle = TCPHandle::make_server( @@ -94,9 +90,10 @@ int main(int argc, char* argv[]) log::info(test_cat, "Opening TCPConnection..."); auto tcp_conn = std::make_shared(_bev, _fd, std::move(s)); - auto [it, _] = tcp_quic._tcp_conns.insert_or_assign(src, std::move(tcp_conn)); + auto [it, _] = tcp_quic._tcp_conns2[src].insert(std::move(tcp_conn)); + // auto [it, _] = tcp_quic._tcp_conns.insert_or_assign(src, std::move(tcp_conn)); - return it->second.get(); + return it->get(); } throw std::runtime_error{"Could not find paired TCP-QUIC for remote port:{}"_format(remote.port())}; } @@ -137,13 +134,19 @@ int main(int argc, char* argv[]) { std::shared_ptr tunnel_ci; - auto manual_client = - client_net.endpoint(manual_client_local, opt::manual_routing{[&](const Path& p, bstring_view data) { - tunnel_ci->send_datagram(Packet(p, bstring{data}).bt_encode()); - }}); + auto manual_client = client_net.endpoint(localhost_blank, opt::manual_routing{[&](const Path& p, bstring_view data) { + log::debug(log_cat, "client manual send path: {}", p); + log::debug(log_cat, "client manual sending {}B", data.size()); + tunnel_ci->send_datagram(serialize_payload(data, p.remote.port())); + }}); - dgram_data_callback recv_dgram_cb = [&](dgram_interface&, bstring data) { - manual_client->manually_receive_packet(*Packet::bt_decode(std::move(data))); + dgram_data_callback recv_dgram_cb = [&](dgram_interface&, bstring buf) { + auto [p, data] = deserialize_payload(std::move(buf)); + + if (auto it = paths.find(p); it != paths.end()) + manual_client->manually_receive_packet(Packet{it->second, data}); + else + throw std::runtime_error{"Could not find path for route to remote port:{}"_format(p)}; }; auto tunnel_client_established = callback_waiter{ @@ -159,9 +162,6 @@ int main(int argc, char* argv[]) tunnel_ci = tunnel_client->connect(tunnel_server_addr, client_tls, opt::keep_alive{10s}, tunnel_client_established); tunnel_client_established.wait(); - // manual_client->connect(RemoteAddress{TUNNEL_PUBKEY, localhost_blank}, client_tls, initial_tunnel_established); - // initial_tunnel_established.wait(); - for (int i = 0; i < num_conns; ++i) { manual_client->connect(connect_addrs[i], client_tls, manual_client_established, opt::keep_alive{10s}); diff --git a/tests/tunnel-server.cpp b/tests/tunnel-server.cpp index d3673a88..d6a2b400 100644 --- a/tests/tunnel-server.cpp +++ b/tests/tunnel-server.cpp @@ -30,16 +30,12 @@ int main(int argc, char* argv[]) Address localhost_blank{LOCALHOST, 0}, manual_server_local1{LOCALHOST, 4444}, manual_server_local2{LOCALHOST, 4455}, manual_server_local3{LOCALHOST, 4466}; - std::unordered_map localport_to_backendtcp{ - {manual_server_local1.port(), backend_tcp1}, - {manual_server_local2.port(), backend_tcp2}, - {manual_server_local3.port(), backend_tcp3}}; + std::unordered_map> localport_to_backendpair{ + {manual_server_local1.port(), {manual_server_local1, backend_tcp1}}, + {manual_server_local2.port(), {manual_server_local2, backend_tcp2}}, + {manual_server_local3.port(), {manual_server_local3, backend_tcp3}}}; - std::unordered_map backendport_to_localaddr{ - // - }; - - // std::atomic initial_tunnel{false}; + std::unordered_map localport_to_route{}; auto server_tls = GNUTLSCreds::make_from_ed_keys(TUNNEL_SEED, TUNNEL_PUBKEY); @@ -55,14 +51,15 @@ int main(int argc, char* argv[]) auto path = s.path().invert(); auto& remote = path.remote; auto& local = path.local; + auto localport = local.port(); Address backend_addr; - if (auto it = localport_to_backendtcp.find(local.port()); it != localport_to_backendtcp.end()) - backend_addr = it->second; + if (auto it = localport_to_backendpair.find(local.port()); it != localport_to_backendpair.end()) + backend_addr = std::get<1>(it->second); else throw std::runtime_error{"You fucked your mapping up dan (local:{}, remote:{})"_format(local, remote)}; - log::info(test_cat, "Inbound new stream to local port {} routing to backend at: {}", local.port(), backend_addr); + log::info(test_cat, "Inbound new stream to local port {} routing to backend at: {}", localport, backend_addr); if (auto it_a = _tunnels.find(local); it_a != _tunnels.end()) { @@ -73,12 +70,13 @@ int main(int argc, char* argv[]) auto tcp_conn = _handle->connect_to_backend(s.shared_from_this(), backend_addr); // search for local manual_server port extracted from the path - if (auto it_b = conns.find(local.port()); it_b != conns.end()) + if (auto it_b = conns.find(localport); it_b != conns.end()) { - it_b->second._tcp_conns.insert_or_assign(backend_addr, std::move(tcp_conn)); + it_b->second._tcp_conns2[backend_addr].insert(std::move(tcp_conn)); + // it_b->second._tcp_conns.insert_or_assign(backend_addr, std::move(tcp_conn)); } else - throw std::runtime_error{"Could not find paired TCP-QUIC for local port:{}"_format(local.port())}; + throw std::runtime_error{"Could not find paired TCP-QUIC for local port:{}"_format(localport)}; } else throw std::runtime_error{"Could not find tunnel to local:{}!"_format(local)}; @@ -87,20 +85,18 @@ int main(int argc, char* argv[]) }; auto manual_server_established = [&](connection_interface& ci) { + // set up the routing lookup; this expects the inverted path as set by the client + auto route_path = ci.path(); + // Path needs inversion because it is set by the client in manual routing - auto path = ci.path().invert(); - auto& remote = path.remote; - auto& local = path.local; + auto& remote = route_path.local; + auto& local = route_path.remote; + auto localport = local.port(); - // if (not initial_tunnel) - // { - // log::info(test_cat, "Server established initial connection with remote endpoint!"); - // initial_tunnel = true; - // return; - // } + // store in lookup + localport_to_route.emplace(localport, route_path); - log::debug(test_cat, "Server: (local:{}, remote:{})", local, remote); - log::critical(test_cat, "Manual server established connection (path: {})...", path); + log::critical(test_cat, "Manual server established connection (local:{} -> remote:{})...", local, remote); auto _handle = TCPHandle::make_client(server_net.loop()); @@ -114,8 +110,7 @@ int main(int argc, char* argv[]) tcp_quic._ci = ci.shared_from_this(); // map against local manual server port - // tunneled_conn.conns[local.port()] = std::move(tcp_quic); - if (auto [_, b] = tunneled_conn.conns.emplace(local.port(), std::move(tcp_quic)); not b) + if (auto [_, b] = tunneled_conn.conns.emplace(localport, std::move(tcp_quic)); not b) throw std::runtime_error{"Failed to emplace tunneled_connection!"}; _tunnels.emplace(local, std::move(tunneled_conn)); @@ -129,11 +124,29 @@ int main(int argc, char* argv[]) std::shared_ptr tunnel_ci; auto manual_server = server_net.endpoint(localhost_blank, opt::manual_routing{[&](const Path& p, bstring_view data) { - tunnel_ci->send_datagram(Packet(p, bstring{data}).bt_encode()); + log::debug(log_cat, "server manual send path: {}", p); + tunnel_ci->send_datagram(serialize_payload(data, p.remote.port())); }}); - dgram_data_callback recv_dgram_cb = [&](dgram_interface&, bstring data) { - manual_server->manually_receive_packet(*Packet::bt_decode(std::move(data))); + dgram_data_callback recv_dgram_cb = [&](dgram_interface&, bstring buf) { + auto [p, data] = deserialize_payload(buf); + Path path; + + if (auto it = localport_to_route.find(p); it != localport_to_route.end()) + path = it->second; + else + { + if (auto it = localport_to_backendpair.find(p); it != localport_to_backendpair.end()) + { + path = Path{localhost_blank, std::get<0>(it->second)}; + auto [itr, _] = localport_to_route.emplace(p, path); + log::info(log_cat, "server manual mapping port:{} to route:{}", p, itr->second); + } + else + throw std::runtime_error{"Could not find backend pair for port:{}"_format(p)}; + } + + manual_server->manually_receive_packet(Packet{path, std::move(data)}); }; std::promise tunnel_prom; diff --git a/utils/tcp_cannon.py b/utils/tcp_cannon.py index b53dff99..8fff34a5 100755 --- a/utils/tcp_cannon.py +++ b/utils/tcp_cannon.py @@ -11,7 +11,7 @@ parser = argparse.ArgumentParser("Simple TCP Cannon") parser.add_argument( "--size", - default=400, + default=40000, help="The number of bytes to send", type=int, ) @@ -100,12 +100,13 @@ connected = True while not message_sent: - print("Constructing msg of size {}B".format(SENDSIZE)) msg = b"" for i in range(SENDSIZE): msg += random.randint(0, 9).to_bytes() + print("Constructing msg of size {}B".format(SENDSIZE)) + if len(msg) > 0: print("\nSending message...") clientsocket.sendall(msg) @@ -121,7 +122,9 @@ print("EOF reached!") if len(buf) > 0: - print("\nReceived {}B in response!".format(len(buf))) + # print("\nReceived {}B in response!".format(len(buf))) + print("Response received:\n") + print(buf.decode()) buf = b"" received = True diff --git a/utils/tcp_client.py b/utils/tcp_client.py index 76cf6309..1c19b6ab 100755 --- a/utils/tcp_client.py +++ b/utils/tcp_client.py @@ -113,7 +113,7 @@ # explicitly conditional on this so connection failures will not enter this and loop around/restart while not received: print("Awaiting response...") - buf = clientsocket.recv(2048).strip() + buf = clientsocket.recv(4096).strip() if len(buf) == 0: print("EOF reached!") @@ -142,44 +142,6 @@ message_sent = False received = False - # while received: - # msg = input("Enter message to tunnel to remote...\n") - - # if msg.lower() == "q": - # print("Closing client and exiting...") - # if connected: - # clientsocket.shutdown(socket.SHUT_RDWR) - # clientsocket.close() - # sys.exit() - - # if len(msg) > 0: - # print("\nSending message...") - # clientsocket.sendall(bytes(msg, encoding="utf8")) - # msg = "" - # received = False - - # # explicitly conditional on this so connection failures will not enter this and loop around/restart - # while not received: - # print("Awaiting response...") - # buf = clientsocket.recv(2048).strip() - - # if len(buf) == 0: - # print("EOF reached!") - - # if len(buf) > 0: - # print("Response received:\n") - # print(buf.decode()) - - # break - - # clientsocket.shutdown(socket.SHUT_RDWR) - # clientsocket.close() - # print("\nClient connection closed\n") - - # awaiting_input = True - # connected = False - # received = True - except KeyboardInterrupt or ConnectionError or ConnectionResetError: print("Shutting down TCP client...") diff --git a/utils/tcp_server.py b/utils/tcp_server.py index cb2eece6..76cec007 100755 --- a/utils/tcp_server.py +++ b/utils/tcp_server.py @@ -2,36 +2,9 @@ import argparse import socket -import socketserver import sys - -class MyTCPHandler(socketserver.BaseRequestHandler): - # def setup(self): - # print("Configuring socket as non-blocking...") - # self.request.setblocking(0) - - def handle(self): - # buf = b"" - - # while True: - # read = self.request.recv(4096) - # if read == b"": - # break - # buf += read - # read = b"" - - self.data = self.request.recv(4096).strip() - - print( - "Received {}B received from {}:{}".format( - len(self.data), self.client_address[0], self.client_address[1] - ) - ) - - self.request.sendall(self.data) - self.request.close() - +READSIZE = 4096 parser = argparse.ArgumentParser("Simple TCP Server") parser.add_argument( @@ -47,7 +20,6 @@ def handle(self): type=int, ) - if __name__ == "__main__": argvars = vars(parser.parse_args()) @@ -56,12 +28,41 @@ def handle(self): print("Starting TCP server at {}:{}...".format(LOCALIP, LOCALPORT)) - try: - with socketserver.TCPServer((LOCALIP, LOCALPORT), MyTCPHandler) as server: - server.serve_forever() + serversock = socket.socket(socket.AF_INET, socket.SOCK_STREAM) + serversock.bind((LOCALIP, LOCALPORT)) + serversock.listen(1) + + while True: + try: + conn, addr = serversock.accept() + remote = "{}:{}".format(addr[0], addr[1]) + + print("Accepted connection from {}".format(remote)) + + buf = b"" + + while True: + b = conn.recv(READSIZE).strip() + buf += b + if not b: + break + + size = len(buf) + print("Received {}B from {}".format(size, remote)) + + conn.sendall( + bytes("{}B successfully received!".format(size), encoding="utf8") + ) + conn.close() + + print("Connection to {} closed!".format(remote)) + + except socket.error: + print("Remote {} disconnected! Continuing...".format(remote)) + pass - except KeyboardInterrupt: - print("Shutting down TCP server...") - server.shutdown() - server.socket.close() - sys.exit() + except KeyboardInterrupt: + print("Shutting down TCP server...") + serversock.shutdown(socket.SHUT_RDWR) + serversock.close() + sys.exit() diff --git a/utils/tcp_speedclient.py b/utils/tcp_speedclient.py new file mode 100755 index 00000000..4cd381d5 --- /dev/null +++ b/utils/tcp_speedclient.py @@ -0,0 +1,101 @@ +#!/usr/bin/env python3 + +import argparse +import random +import socket +import sys +import os +import time + +LOCALHOST = "127.0.0.1" + +_kibibytes = 1024 +_mibibytes = 1024 * _kibibytes +_gibibytes = 1024 * _mibibytes + +DEFAULT_SENDSIZE = 40 * _mibibytes + +parser = argparse.ArgumentParser("Simple TCP Cannon") +parser.add_argument( + "--size", + default=DEFAULT_SENDSIZE, + help="The number of bytes to send", + type=int, +) +parser.add_argument( + "--remoteip", + default=LOCALHOST, + help="The remote IP address to which the TCP client should connect to", + type=str, +) +parser.add_argument( + "--remoteport", + required=True, + help="The remote port to which the TCP client should connect to", + type=int, +) + +if __name__ == "__main__": + argvars = vars(parser.parse_args()) + + SENDSIZE = argvars["size"] + REMOTEIP = argvars["remoteip"] + REMOTEPORT = argvars["remoteport"] + + connected = False + + # outer try/except to catch SIGINT, connection errors + try: + if SENDSIZE <= 0: + raise RuntimeError("SENDSIZE must be greater than 0!") + + clientsocket = socket.socket(socket.AF_INET, socket.SOCK_STREAM) + + print("Pregenerating msg of size {}B...".format(SENDSIZE)) + + msg = bytearray(os.urandom(SENDSIZE)) + + # msg = b"" + + # for i in range(SENDSIZE): + # msg += random.randint(0, 9).to_bytes() + + print("\nTCP Client connecting to {}:{}...".format(REMOTEIP, REMOTEPORT)) + + t1 = time.time() + + clientsocket.connect((REMOTEIP, int(REMOTEPORT))) + + t2 = time.time() + connected = True + + print("Sending payload...") + clientsocket.sendall(msg) + + clientsocket.shutdown(socket.SHUT_WR) + t3 = time.time() + + print("Payload away...") + + buf = clientsocket.recv(4096).strip() + t4 = time.time() + + ping = ((t2 - t1) + (t4 - t3)) / 2 + time = t3 - t2 + bandwidth = (SENDSIZE / time) * 2e-6 + + print("\nPayload Transmitted:") + print("Ping: {}".format(ping)) + print("Time: {}".format(time)) + print("Bandwidth (MB/s): {}".format(bandwidth)) + + print(buf.decode()) + + except KeyboardInterrupt or ConnectionError or ConnectionResetError or RuntimeError: + print("Shutting down TCP client...") + + if connected: + clientsocket.shutdown(socket.SHUT_RDWR) + + clientsocket.close() + sys.exit()