Skip to content

Commit

Permalink
Custom client SCID's
Browse files Browse the repository at this point in the history
- Endpoints can now specify a custom SCID to set for the initial connection attempt. This can be accessed by the receiving remote, and can be used to coordinate information in a reliable (not in the transport params) way
  • Loading branch information
dr7ana committed Apr 28, 2024
1 parent 826c6db commit dd1f605
Show file tree
Hide file tree
Showing 11 changed files with 158 additions and 14 deletions.
5 changes: 5 additions & 0 deletions include/oxen/quic/connection.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -158,6 +158,7 @@ namespace oxen::quic
virtual bool is_inbound() const = 0;
virtual bool is_outbound() const = 0;
virtual std::string direction_str() = 0;
virtual ustring initial_client_scid() const = 0;

// Non-virtual base class wrappers for the virtual methods of the same name with _impl
// appended (e.g. path_impl); these versions in the base class wrap the _impl call in a
Expand Down Expand Up @@ -365,6 +366,10 @@ namespace oxen::quic
// alive after it gets dropped from libquic internal structures).
void drop_streams();

// Returns the initial scid chosen by the client, which will be the value of whatever was passed in
// opt::outbound_scid on the client's initiation of the connection
ustring initial_client_scid() const override;

private:
// private Constructor (publicly construct via `make_conn` instead, so that we can properly
// set up the shared_from_this shenanigans).
Expand Down
1 change: 0 additions & 1 deletion include/oxen/quic/connection_ids.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,6 @@ extern "C"

#include "formattable.hpp"
#include "types.hpp"
#include "utils.hpp"

namespace oxen::quic
{
Expand Down
3 changes: 3 additions & 0 deletions include/oxen/quic/context.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,8 @@ namespace oxen::quic
bool split_packet{false};
// splitting policy
Splitting policy{Splitting::NONE};
// outbound scid
std::optional<quic_cid> scid{std::nullopt};

user_config() = default;
};
Expand Down Expand Up @@ -65,6 +67,7 @@ namespace oxen::quic
void handle_ioctx_opt(opt::keep_alive ka);
void handle_ioctx_opt(opt::idle_timeout ito);
void handle_ioctx_opt(opt::handshake_timeout hto);
void handle_ioctx_opt(opt::outbound_scid scid);
void handle_ioctx_opt(stream_data_callback func);
void handle_ioctx_opt(stream_open_callback func);
void handle_ioctx_opt(stream_close_callback func);
Expand Down
6 changes: 4 additions & 2 deletions include/oxen/quic/endpoint.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@ namespace oxen::quic
{
static_assert(
(0 + ... + std::is_convertible_v<std::remove_cvref_t<Opt>, std::shared_ptr<TLSCreds>>) == 1,
"Endpoint listen/connect require exactly one std::shared_ptr<TLSCreds> argument");
"Endpoint::{listen,connect}(...) require exactly one std::shared_ptr<TLSCreds> argument");
}

class Endpoint : public std::enable_shared_from_this<Endpoint>
Expand Down Expand Up @@ -108,8 +108,10 @@ namespace oxen::quic

for (;;)
{
auto scid = outbound_ctx->config.scid.value_or(quic_cid::random());

// emplace random CID into lookup keyed to unique reference ID
if (auto [it_a, res_a] = conn_lookup.emplace(quic_cid::random(), next_rid); res_a)
if (auto [it_a, res_a] = conn_lookup.emplace(scid, next_rid); res_a)
{
qcid = it_a->first;

Expand Down
1 change: 1 addition & 0 deletions include/oxen/quic/format.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
#include <iostream>

#include "formattable.hpp"
#include "utils.hpp"

namespace oxen::quic
{
Expand Down
10 changes: 1 addition & 9 deletions include/oxen/quic/formattable.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -2,19 +2,11 @@

#include <string_view>

// GCC before 10 requires a "bool" keyword in concept; this CONCEPT_COMPAT is empty by default, but
// expands to bool if under such a GCC.
#if (!(defined(__clang__)) && defined(__GNUC__) && __GNUC__ < 10)
#define CONCEPT_COMPAT bool
#else
#define CONCEPT_COMPAT
#endif

namespace oxen::quic
{
// Types can opt-in to being fmt-formattable by ensuring they have a ::to_string() method defined
template <typename T>
concept CONCEPT_COMPAT ToStringFormattable = requires(T a) {
concept ToStringFormattable = requires(T a) {
{
a.to_string()
} -> std::convertible_to<std::string_view>;
Expand Down
51 changes: 50 additions & 1 deletion include/oxen/quic/opt.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,8 @@
#include <stdexcept>

#include "address.hpp"
#include "connection_ids.hpp"
#include "crypto.hpp"
#include "types.hpp"

namespace oxen::quic
{
Expand Down Expand Up @@ -171,5 +171,54 @@ namespace oxen::quic

explicit operator bool() const { return send_hook != nullptr; }
};

// Used to allow the client to set its initial SCID. During ConnectionID association, the client's initial SCID will
// be the server's initial DCID. This can allow the client to coordinate information to the server prior to any
// stream data being sent down. For example, if a client is attempting to tunnel a connection to a remote port
// through streams, the server will be able to access this port value on it's created inbound connection in its
// connection_established_cb
struct outbound_scid
{
private:
size_t len;
std::array<uint8_t, NGTCP2_MAX_CIDLEN> scid;

template <oxenc::basic_char T>
constexpr outbound_scid(const T* data, size_t l) : len{l}
{
if (len > NGTCP2_MAX_CIDLEN)
throw std::runtime_error{"Max SCID length is 20B"};

std::memcpy(scid.data(), data, len);
}

public:
outbound_scid() = default;

size_t size() { return scid.size(); }
const uint8_t* data() { return scid.data(); }

template <oxenc::string_view_compatible T>
constexpr outbound_scid(T view) : outbound_scid{reinterpret_cast<const char*>(view.data()), view.size()}
{}

template <std::integral T>
constexpr outbound_scid(T val)
{
std::array<char, NGTCP2_MAX_CIDLEN> buf;

if (auto [res, ec] = std::to_chars(buf.data(), buf.data() + buf.size(), val); ec == std::errc())
{
len = res - buf.data();
std::memcpy(scid.data(), buf.data(), len);
}
else
throw std::runtime_error{"outbound_scid input must be string_view-like or int-like!"};
}

// This converting operator passes NGTCP2_MAX_CIDLEN as the length because ngtcp2 expects it
explicit operator oxen::quic::quic_cid() const { return oxen::quic::quic_cid{scid.data(), NGTCP2_MAX_CIDLEN}; }
};

} // namespace opt
} // namespace oxen::quic
7 changes: 7 additions & 0 deletions src/connection.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -474,6 +474,13 @@ namespace oxen::quic
return tls_session.get();
}

ustring Connection::initial_client_scid() const
{
auto ret = _is_outbound ? _source_cid.to_string() : _dest_cid.to_string();
auto reusv = to_usv(ret);
return ustring{reusv.data(), reusv.find_first_of('\0')};
}

void Connection::halt_events()
{
log::trace(log_cat, "{} called", __PRETTY_FUNCTION__);
Expand Down
7 changes: 6 additions & 1 deletion src/connection_ids.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,12 @@ namespace oxen::quic

std::string quic_cid::to_string() const
{
return oxenc::to_hex(data, data + datalen);
if (oxenc::is_hex(data, data + datalen))
{
return oxenc::to_hex(data, data + datalen);
}
else
return {data, data + datalen};
}

quic_cid quic_cid::random()
Expand Down
8 changes: 8 additions & 0 deletions src/context.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,14 @@ namespace oxen::quic
log::trace(log_cat, "User passed connection handshake_timeout config value: {}", config.handshake_timeout->count());
}

void IOContext::handle_ioctx_opt(opt::outbound_scid scid)
{
if (dir == Direction::INBOUND)
throw std::runtime_error{"Inbound connection contexts cannot store an outbound scid!"};

config.scid.emplace(scid);
}

void IOContext::handle_ioctx_opt(stream_data_callback func)
{
log::trace(log_cat, "IO context stored stream close callback");
Expand Down
73 changes: 73 additions & 0 deletions tests/001-handshake.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -795,4 +795,77 @@ namespace oxen::quic::test
REQUIRE(stream_callback_called);
CHECK(*stream_callback_called);
}

TEST_CASE("001 - Custom Client Initial Scid", "[001][scid]")
{
Network test_net{};

auto [client_tls, server_tls] = defaults::tls_creds_from_ed_keys();

Address server_local{};
Address client_local1{};
Address client_local2{};

auto client_established1 = callback_waiter{[](connection_interface&) {}};
auto client_established2 = callback_waiter{[](connection_interface&) {}};

// Length must be less than NGTCP2_MAX_CIDLEN (20B)
CHECK_THROWS(opt::outbound_scid{"hello from your neighbor"_usv});

auto hello_sv = "goodmorning"_usv;
// The converting constructor to quic_cid will pad it out to NGTCP2_MAX_CIDLEN
opt::outbound_scid scid_str{hello_sv};

uint16_t int_like{5685};

opt::outbound_scid scid_int{int_like};

auto server_endpoint = test_net.endpoint(server_local);

// Server cannot accept opt::outbound_scid in call to ::listen(...)
CHECK_THROWS(server_endpoint->listen(server_tls, scid_str));

CHECK_NOTHROW(server_endpoint->listen(server_tls));

RemoteAddress client_remote{defaults::SERVER_PUBKEY, "127.0.0.1"s, server_endpoint->local().port()};

auto client_endpoint1 = test_net.endpoint(client_local1, client_established1);
auto client_endpoint2 = test_net.endpoint(client_local2, client_established2);

auto client_ci1 = client_endpoint1->connect(client_remote, client_tls, scid_str);
CHECK(client_established1.wait());

auto client_ci2 = client_endpoint2->connect(client_remote, client_tls, scid_int);
CHECK(client_established2.wait());

auto server_conns = server_endpoint->get_all_conns(Direction::INBOUND);
auto server_ci1 = server_conns.front();
auto server_ci2 = server_conns.back();

auto server_ci = server_endpoint->get_all_conns(Direction::INBOUND).front();

auto client_str_scid1 = client_ci1->initial_client_scid();
auto server_str_scid1 = server_ci1->initial_client_scid();

REQUIRE(client_str_scid1 == hello_sv);
REQUIRE(server_str_scid1 == hello_sv);

auto client_str_scid2 = client_ci2->initial_client_scid();
auto server_str_scid2 = server_ci2->initial_client_scid();

uint16_t client_parsed, server_parsed;
REQUIRE(std::from_chars(
reinterpret_cast<const char*>(client_str_scid2.data()),
reinterpret_cast<const char*>(client_str_scid2.data()) + client_str_scid2.size(),
client_parsed)
.ec == std::errc());
REQUIRE(std::from_chars(
reinterpret_cast<const char*>(server_str_scid2.data()),
reinterpret_cast<const char*>(server_str_scid2.data()) + server_str_scid2.size(),
server_parsed)
.ec == std::errc());

REQUIRE(client_parsed == int_like);
REQUIRE(server_parsed == int_like);
}
} // namespace oxen::quic::test

0 comments on commit dd1f605

Please sign in to comment.