From dd1f6059cf8928aa6b146094e38e4514a6d57b1e Mon Sep 17 00:00:00 2001 From: dr7ana Date: Sat, 27 Apr 2024 18:39:28 -0700 Subject: [PATCH] Custom client SCID's - Endpoints can now specify a custom SCID to set for the initial connection attempt. This can be accessed by the receiving remote, and can be used to coordinate information in a reliable (not in the transport params) way --- include/oxen/quic/connection.hpp | 5 ++ include/oxen/quic/connection_ids.hpp | 1 - include/oxen/quic/context.hpp | 3 ++ include/oxen/quic/endpoint.hpp | 6 ++- include/oxen/quic/format.hpp | 1 + include/oxen/quic/formattable.hpp | 10 +--- include/oxen/quic/opt.hpp | 51 ++++++++++++++++++- src/connection.cpp | 7 +++ src/connection_ids.cpp | 7 ++- src/context.cpp | 8 +++ tests/001-handshake.cpp | 73 ++++++++++++++++++++++++++++ 11 files changed, 158 insertions(+), 14 deletions(-) diff --git a/include/oxen/quic/connection.hpp b/include/oxen/quic/connection.hpp index a795da82..3d282265 100644 --- a/include/oxen/quic/connection.hpp +++ b/include/oxen/quic/connection.hpp @@ -158,6 +158,7 @@ namespace oxen::quic virtual bool is_inbound() const = 0; virtual bool is_outbound() const = 0; virtual std::string direction_str() = 0; + virtual ustring initial_client_scid() const = 0; // Non-virtual base class wrappers for the virtual methods of the same name with _impl // appended (e.g. path_impl); these versions in the base class wrap the _impl call in a @@ -365,6 +366,10 @@ namespace oxen::quic // alive after it gets dropped from libquic internal structures). void drop_streams(); + // Returns the initial scid chosen by the client, which will be the value of whatever was passed in + // opt::outbound_scid on the client's initiation of the connection + ustring initial_client_scid() const override; + private: // private Constructor (publicly construct via `make_conn` instead, so that we can properly // set up the shared_from_this shenanigans). diff --git a/include/oxen/quic/connection_ids.hpp b/include/oxen/quic/connection_ids.hpp index 512fa35c..52cdbd63 100644 --- a/include/oxen/quic/connection_ids.hpp +++ b/include/oxen/quic/connection_ids.hpp @@ -9,7 +9,6 @@ extern "C" #include "formattable.hpp" #include "types.hpp" -#include "utils.hpp" namespace oxen::quic { diff --git a/include/oxen/quic/context.hpp b/include/oxen/quic/context.hpp index b84a05ea..f95a34bc 100644 --- a/include/oxen/quic/context.hpp +++ b/include/oxen/quic/context.hpp @@ -29,6 +29,8 @@ namespace oxen::quic bool split_packet{false}; // splitting policy Splitting policy{Splitting::NONE}; + // outbound scid + std::optional scid{std::nullopt}; user_config() = default; }; @@ -65,6 +67,7 @@ namespace oxen::quic void handle_ioctx_opt(opt::keep_alive ka); void handle_ioctx_opt(opt::idle_timeout ito); void handle_ioctx_opt(opt::handshake_timeout hto); + void handle_ioctx_opt(opt::outbound_scid scid); void handle_ioctx_opt(stream_data_callback func); void handle_ioctx_opt(stream_open_callback func); void handle_ioctx_opt(stream_close_callback func); diff --git a/include/oxen/quic/endpoint.hpp b/include/oxen/quic/endpoint.hpp index eafb0518..17b7fe2c 100644 --- a/include/oxen/quic/endpoint.hpp +++ b/include/oxen/quic/endpoint.hpp @@ -39,7 +39,7 @@ namespace oxen::quic { static_assert( (0 + ... + std::is_convertible_v, std::shared_ptr>) == 1, - "Endpoint listen/connect require exactly one std::shared_ptr argument"); + "Endpoint::{listen,connect}(...) require exactly one std::shared_ptr argument"); } class Endpoint : public std::enable_shared_from_this @@ -108,8 +108,10 @@ namespace oxen::quic for (;;) { + auto scid = outbound_ctx->config.scid.value_or(quic_cid::random()); + // emplace random CID into lookup keyed to unique reference ID - if (auto [it_a, res_a] = conn_lookup.emplace(quic_cid::random(), next_rid); res_a) + if (auto [it_a, res_a] = conn_lookup.emplace(scid, next_rid); res_a) { qcid = it_a->first; diff --git a/include/oxen/quic/format.hpp b/include/oxen/quic/format.hpp index b21679ce..d25f7ade 100644 --- a/include/oxen/quic/format.hpp +++ b/include/oxen/quic/format.hpp @@ -11,6 +11,7 @@ #include #include "formattable.hpp" +#include "utils.hpp" namespace oxen::quic { diff --git a/include/oxen/quic/formattable.hpp b/include/oxen/quic/formattable.hpp index e67f3530..4154ad7c 100644 --- a/include/oxen/quic/formattable.hpp +++ b/include/oxen/quic/formattable.hpp @@ -2,19 +2,11 @@ #include -// GCC before 10 requires a "bool" keyword in concept; this CONCEPT_COMPAT is empty by default, but -// expands to bool if under such a GCC. -#if (!(defined(__clang__)) && defined(__GNUC__) && __GNUC__ < 10) -#define CONCEPT_COMPAT bool -#else -#define CONCEPT_COMPAT -#endif - namespace oxen::quic { // Types can opt-in to being fmt-formattable by ensuring they have a ::to_string() method defined template - concept CONCEPT_COMPAT ToStringFormattable = requires(T a) { + concept ToStringFormattable = requires(T a) { { a.to_string() } -> std::convertible_to; diff --git a/include/oxen/quic/opt.hpp b/include/oxen/quic/opt.hpp index 46ab92fe..d2f8c2d6 100644 --- a/include/oxen/quic/opt.hpp +++ b/include/oxen/quic/opt.hpp @@ -3,8 +3,8 @@ #include #include "address.hpp" +#include "connection_ids.hpp" #include "crypto.hpp" -#include "types.hpp" namespace oxen::quic { @@ -171,5 +171,54 @@ namespace oxen::quic explicit operator bool() const { return send_hook != nullptr; } }; + + // Used to allow the client to set its initial SCID. During ConnectionID association, the client's initial SCID will + // be the server's initial DCID. This can allow the client to coordinate information to the server prior to any + // stream data being sent down. For example, if a client is attempting to tunnel a connection to a remote port + // through streams, the server will be able to access this port value on it's created inbound connection in its + // connection_established_cb + struct outbound_scid + { + private: + size_t len; + std::array scid; + + template + constexpr outbound_scid(const T* data, size_t l) : len{l} + { + if (len > NGTCP2_MAX_CIDLEN) + throw std::runtime_error{"Max SCID length is 20B"}; + + std::memcpy(scid.data(), data, len); + } + + public: + outbound_scid() = default; + + size_t size() { return scid.size(); } + const uint8_t* data() { return scid.data(); } + + template + constexpr outbound_scid(T view) : outbound_scid{reinterpret_cast(view.data()), view.size()} + {} + + template + constexpr outbound_scid(T val) + { + std::array buf; + + if (auto [res, ec] = std::to_chars(buf.data(), buf.data() + buf.size(), val); ec == std::errc()) + { + len = res - buf.data(); + std::memcpy(scid.data(), buf.data(), len); + } + else + throw std::runtime_error{"outbound_scid input must be string_view-like or int-like!"}; + } + + // This converting operator passes NGTCP2_MAX_CIDLEN as the length because ngtcp2 expects it + explicit operator oxen::quic::quic_cid() const { return oxen::quic::quic_cid{scid.data(), NGTCP2_MAX_CIDLEN}; } + }; + } // namespace opt } // namespace oxen::quic diff --git a/src/connection.cpp b/src/connection.cpp index e5bb87ca..4c266830 100644 --- a/src/connection.cpp +++ b/src/connection.cpp @@ -474,6 +474,13 @@ namespace oxen::quic return tls_session.get(); } + ustring Connection::initial_client_scid() const + { + auto ret = _is_outbound ? _source_cid.to_string() : _dest_cid.to_string(); + auto reusv = to_usv(ret); + return ustring{reusv.data(), reusv.find_first_of('\0')}; + } + void Connection::halt_events() { log::trace(log_cat, "{} called", __PRETTY_FUNCTION__); diff --git a/src/connection_ids.cpp b/src/connection_ids.cpp index c8be2bb7..7a8b9f43 100644 --- a/src/connection_ids.cpp +++ b/src/connection_ids.cpp @@ -15,7 +15,12 @@ namespace oxen::quic std::string quic_cid::to_string() const { - return oxenc::to_hex(data, data + datalen); + if (oxenc::is_hex(data, data + datalen)) + { + return oxenc::to_hex(data, data + datalen); + } + else + return {data, data + datalen}; } quic_cid quic_cid::random() diff --git a/src/context.cpp b/src/context.cpp index d1e9ae74..373c89f5 100644 --- a/src/context.cpp +++ b/src/context.cpp @@ -42,6 +42,14 @@ namespace oxen::quic log::trace(log_cat, "User passed connection handshake_timeout config value: {}", config.handshake_timeout->count()); } + void IOContext::handle_ioctx_opt(opt::outbound_scid scid) + { + if (dir == Direction::INBOUND) + throw std::runtime_error{"Inbound connection contexts cannot store an outbound scid!"}; + + config.scid.emplace(scid); + } + void IOContext::handle_ioctx_opt(stream_data_callback func) { log::trace(log_cat, "IO context stored stream close callback"); diff --git a/tests/001-handshake.cpp b/tests/001-handshake.cpp index 094fdeed..2e06511c 100644 --- a/tests/001-handshake.cpp +++ b/tests/001-handshake.cpp @@ -795,4 +795,77 @@ namespace oxen::quic::test REQUIRE(stream_callback_called); CHECK(*stream_callback_called); } + + TEST_CASE("001 - Custom Client Initial Scid", "[001][scid]") + { + Network test_net{}; + + auto [client_tls, server_tls] = defaults::tls_creds_from_ed_keys(); + + Address server_local{}; + Address client_local1{}; + Address client_local2{}; + + auto client_established1 = callback_waiter{[](connection_interface&) {}}; + auto client_established2 = callback_waiter{[](connection_interface&) {}}; + + // Length must be less than NGTCP2_MAX_CIDLEN (20B) + CHECK_THROWS(opt::outbound_scid{"hello from your neighbor"_usv}); + + auto hello_sv = "goodmorning"_usv; + // The converting constructor to quic_cid will pad it out to NGTCP2_MAX_CIDLEN + opt::outbound_scid scid_str{hello_sv}; + + uint16_t int_like{5685}; + + opt::outbound_scid scid_int{int_like}; + + auto server_endpoint = test_net.endpoint(server_local); + + // Server cannot accept opt::outbound_scid in call to ::listen(...) + CHECK_THROWS(server_endpoint->listen(server_tls, scid_str)); + + CHECK_NOTHROW(server_endpoint->listen(server_tls)); + + RemoteAddress client_remote{defaults::SERVER_PUBKEY, "127.0.0.1"s, server_endpoint->local().port()}; + + auto client_endpoint1 = test_net.endpoint(client_local1, client_established1); + auto client_endpoint2 = test_net.endpoint(client_local2, client_established2); + + auto client_ci1 = client_endpoint1->connect(client_remote, client_tls, scid_str); + CHECK(client_established1.wait()); + + auto client_ci2 = client_endpoint2->connect(client_remote, client_tls, scid_int); + CHECK(client_established2.wait()); + + auto server_conns = server_endpoint->get_all_conns(Direction::INBOUND); + auto server_ci1 = server_conns.front(); + auto server_ci2 = server_conns.back(); + + auto server_ci = server_endpoint->get_all_conns(Direction::INBOUND).front(); + + auto client_str_scid1 = client_ci1->initial_client_scid(); + auto server_str_scid1 = server_ci1->initial_client_scid(); + + REQUIRE(client_str_scid1 == hello_sv); + REQUIRE(server_str_scid1 == hello_sv); + + auto client_str_scid2 = client_ci2->initial_client_scid(); + auto server_str_scid2 = server_ci2->initial_client_scid(); + + uint16_t client_parsed, server_parsed; + REQUIRE(std::from_chars( + reinterpret_cast(client_str_scid2.data()), + reinterpret_cast(client_str_scid2.data()) + client_str_scid2.size(), + client_parsed) + .ec == std::errc()); + REQUIRE(std::from_chars( + reinterpret_cast(server_str_scid2.data()), + reinterpret_cast(server_str_scid2.data()) + server_str_scid2.size(), + server_parsed) + .ec == std::errc()); + + REQUIRE(client_parsed == int_like); + REQUIRE(server_parsed == int_like); + } } // namespace oxen::quic::test