Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Custom client SCID's #124

Open
wants to merge 1 commit into
base: dev
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions include/oxen/quic/connection.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@ namespace oxen::quic
virtual std::shared_ptr<Stream> open_stream_impl(
std::function<std::shared_ptr<Stream>(Connection& c, Endpoint& e)> make_stream) = 0;
virtual std::shared_ptr<Stream> get_stream_impl(int64_t id) = 0;
virtual ustring initial_client_scid_impl() const = 0;

public:
virtual ustring_view selected_alpn() const = 0;
Expand Down Expand Up @@ -158,6 +159,7 @@ namespace oxen::quic
virtual bool is_inbound() const = 0;
virtual bool is_outbound() const = 0;
virtual std::string direction_str() = 0;
ustring initial_client_scid();

// Non-virtual base class wrappers for the virtual methods of the same name with _impl
// appended (e.g. path_impl); these versions in the base class wrap the _impl call in a
Expand Down Expand Up @@ -365,6 +367,10 @@ namespace oxen::quic
// alive after it gets dropped from libquic internal structures).
void drop_streams();

// Returns the initial scid chosen by the client, which will be the value of whatever was passed in
// opt::outbound_scid on the client's initiation of the connection
ustring initial_client_scid_impl() const override;

private:
// private Constructor (publicly construct via `make_conn` instead, so that we can properly
// set up the shared_from_this shenanigans).
Expand Down
1 change: 0 additions & 1 deletion include/oxen/quic/connection_ids.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,6 @@ extern "C"

#include "formattable.hpp"
#include "types.hpp"
#include "utils.hpp"

namespace oxen::quic
{
Expand Down
3 changes: 3 additions & 0 deletions include/oxen/quic/context.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,8 @@ namespace oxen::quic
bool split_packet{false};
// splitting policy
Splitting policy{Splitting::NONE};
// outbound scid
std::optional<quic_cid> scid{std::nullopt};

user_config() = default;
};
Expand Down Expand Up @@ -65,6 +67,7 @@ namespace oxen::quic
void handle_ioctx_opt(opt::keep_alive ka);
void handle_ioctx_opt(opt::idle_timeout ito);
void handle_ioctx_opt(opt::handshake_timeout hto);
void handle_ioctx_opt(opt::outbound_scid scid);
void handle_ioctx_opt(stream_data_callback func);
void handle_ioctx_opt(stream_open_callback func);
void handle_ioctx_opt(stream_close_callback func);
Expand Down
6 changes: 4 additions & 2 deletions include/oxen/quic/endpoint.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@ namespace oxen::quic
{
static_assert(
(0 + ... + std::is_convertible_v<std::remove_cvref_t<Opt>, std::shared_ptr<TLSCreds>>) == 1,
"Endpoint listen/connect require exactly one std::shared_ptr<TLSCreds> argument");
"Endpoint::{listen,connect}(...) require exactly one std::shared_ptr<TLSCreds> argument");
}

class Endpoint : public std::enable_shared_from_this<Endpoint>
Expand Down Expand Up @@ -108,8 +108,10 @@ namespace oxen::quic

for (;;)
{
auto scid = outbound_ctx->config.scid.value_or(quic_cid::random());

// emplace random CID into lookup keyed to unique reference ID
if (auto [it_a, res_a] = conn_lookup.emplace(quic_cid::random(), next_rid); res_a)
if (auto [it_a, res_a] = conn_lookup.emplace(scid, next_rid); res_a)
{
qcid = it_a->first;

Expand Down
1 change: 1 addition & 0 deletions include/oxen/quic/format.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
#include <iostream>

#include "formattable.hpp"
#include "utils.hpp"

namespace oxen::quic
{
Expand Down
10 changes: 1 addition & 9 deletions include/oxen/quic/formattable.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -2,19 +2,11 @@

#include <string_view>

// GCC before 10 requires a "bool" keyword in concept; this CONCEPT_COMPAT is empty by default, but
// expands to bool if under such a GCC.
#if (!(defined(__clang__)) && defined(__GNUC__) && __GNUC__ < 10)
#define CONCEPT_COMPAT bool
#else
#define CONCEPT_COMPAT
#endif

namespace oxen::quic
{
// Types can opt-in to being fmt-formattable by ensuring they have a ::to_string() method defined
template <typename T>
concept CONCEPT_COMPAT ToStringFormattable = requires(T a) {
concept ToStringFormattable = requires(T a) {
{
a.to_string()
} -> std::convertible_to<std::string_view>;
Expand Down
51 changes: 50 additions & 1 deletion include/oxen/quic/opt.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,8 @@
#include <stdexcept>

#include "address.hpp"
#include "connection_ids.hpp"
#include "crypto.hpp"
#include "types.hpp"

namespace oxen::quic
{
Expand Down Expand Up @@ -171,5 +171,54 @@ namespace oxen::quic

explicit operator bool() const { return send_hook != nullptr; }
};

// Used to allow the client to set its initial SCID. During ConnectionID association, the client's initial SCID will
// be the server's initial DCID. This can allow the client to coordinate information to the server prior to any
// stream data being sent down. For example, if a client is attempting to tunnel a connection to a remote port
// through streams, the server will be able to access this port value on it's created inbound connection in its
// connection_established_cb
struct outbound_scid
{
private:
size_t len;
std::array<uint8_t, NGTCP2_MAX_CIDLEN> scid;

template <oxenc::basic_char T>
constexpr outbound_scid(const T* data, size_t l) : len{l}
{
if (len > NGTCP2_MAX_CIDLEN)
throw std::runtime_error{"Max SCID length is 20B"};

std::memcpy(scid.data(), data, len);
}

public:
outbound_scid() = default;

size_t size() { return scid.size(); }
const uint8_t* data() { return scid.data(); }

template <oxenc::string_view_compatible T>
constexpr outbound_scid(T view) : outbound_scid{reinterpret_cast<const char*>(view.data()), view.size()}
{}

template <std::integral T>
constexpr outbound_scid(T val)
{
std::array<char, NGTCP2_MAX_CIDLEN> buf;

if (auto [res, ec] = std::to_chars(buf.data(), buf.data() + buf.size(), val); ec == std::errc())
{
len = res - buf.data();
std::memcpy(scid.data(), buf.data(), len);
}
else
throw std::runtime_error{"outbound_scid input must be string_view-like or int-like!"};
}

// This converting operator passes NGTCP2_MAX_CIDLEN as the length because ngtcp2 expects it
explicit operator oxen::quic::quic_cid() const { return oxen::quic::quic_cid{scid.data(), NGTCP2_MAX_CIDLEN}; }
};

} // namespace opt
} // namespace oxen::quic
12 changes: 12 additions & 0 deletions src/connection.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -474,6 +474,13 @@ namespace oxen::quic
return tls_session.get();
}

ustring Connection::initial_client_scid_impl() const
{
auto ret = _is_outbound ? _source_cid.to_string() : _dest_cid.to_string();
auto reusv = to_usv(ret);
return ustring{reusv.data(), reusv.find_first_of('\0')};
}

void Connection::halt_events()
{
log::trace(log_cat, "{} called", __PRETTY_FUNCTION__);
Expand Down Expand Up @@ -1758,6 +1765,11 @@ namespace oxen::quic
s->check_timeouts();
}

ustring connection_interface::initial_client_scid()
{
return endpoint().call_get([this]() { return initial_client_scid_impl(); });
}

size_t connection_interface::num_streams_active()
{
return endpoint().call_get([this] { return num_streams_active_impl(); });
Expand Down
7 changes: 6 additions & 1 deletion src/connection_ids.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,12 @@ namespace oxen::quic

std::string quic_cid::to_string() const
{
return oxenc::to_hex(data, data + datalen);
if (oxenc::is_hex(data, data + datalen))
{
return oxenc::to_hex(data, data + datalen);
}
else
return {data, data + datalen};
}

quic_cid quic_cid::random()
Expand Down
8 changes: 8 additions & 0 deletions src/context.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,14 @@ namespace oxen::quic
log::trace(log_cat, "User passed connection handshake_timeout config value: {}", config.handshake_timeout->count());
}

void IOContext::handle_ioctx_opt(opt::outbound_scid scid)
{
if (dir == Direction::INBOUND)
throw std::runtime_error{"Inbound connection contexts cannot store an outbound scid!"};

config.scid.emplace(scid);
}

void IOContext::handle_ioctx_opt(stream_data_callback func)
{
log::trace(log_cat, "IO context stored stream close callback");
Expand Down
73 changes: 73 additions & 0 deletions tests/001-handshake.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -795,4 +795,77 @@ namespace oxen::quic::test
REQUIRE(stream_callback_called);
CHECK(*stream_callback_called);
}

TEST_CASE("001 - Custom Client Initial Scid", "[001][scid]")
{
Network test_net{};

auto [client_tls, server_tls] = defaults::tls_creds_from_ed_keys();

Address server_local{};
Address client_local1{};
Address client_local2{};

auto client_established1 = callback_waiter{[](connection_interface&) {}};
auto client_established2 = callback_waiter{[](connection_interface&) {}};

// Length must be less than NGTCP2_MAX_CIDLEN (20B)
CHECK_THROWS(opt::outbound_scid{"hello from your neighbor"_usv});

auto hello_sv = "goodmorning"_usv;
// The converting constructor to quic_cid will pad it out to NGTCP2_MAX_CIDLEN
opt::outbound_scid scid_str{hello_sv};

uint16_t int_like{5685};

opt::outbound_scid scid_int{int_like};

auto server_endpoint = test_net.endpoint(server_local);

// Server cannot accept opt::outbound_scid in call to ::listen(...)
CHECK_THROWS(server_endpoint->listen(server_tls, scid_str));

CHECK_NOTHROW(server_endpoint->listen(server_tls));

RemoteAddress client_remote{defaults::SERVER_PUBKEY, "127.0.0.1"s, server_endpoint->local().port()};

auto client_endpoint1 = test_net.endpoint(client_local1, client_established1);
auto client_endpoint2 = test_net.endpoint(client_local2, client_established2);

auto client_ci1 = client_endpoint1->connect(client_remote, client_tls, scid_str);
CHECK(client_established1.wait());

auto client_ci2 = client_endpoint2->connect(client_remote, client_tls, scid_int);
CHECK(client_established2.wait());

auto server_conns = server_endpoint->get_all_conns(Direction::INBOUND);
auto server_ci1 = server_conns.front();
auto server_ci2 = server_conns.back();

auto server_ci = server_endpoint->get_all_conns(Direction::INBOUND).front();

auto client_str_scid1 = client_ci1->initial_client_scid();
auto server_str_scid1 = server_ci1->initial_client_scid();

REQUIRE(client_str_scid1 == hello_sv);
REQUIRE(server_str_scid1 == hello_sv);

auto client_str_scid2 = client_ci2->initial_client_scid();
auto server_str_scid2 = server_ci2->initial_client_scid();

uint16_t client_parsed, server_parsed;
REQUIRE(std::from_chars(
reinterpret_cast<const char*>(client_str_scid2.data()),
reinterpret_cast<const char*>(client_str_scid2.data()) + client_str_scid2.size(),
client_parsed)
.ec == std::errc());
REQUIRE(std::from_chars(
reinterpret_cast<const char*>(server_str_scid2.data()),
reinterpret_cast<const char*>(server_str_scid2.data()) + server_str_scid2.size(),
server_parsed)
.ec == std::errc());

REQUIRE(client_parsed == int_like);
REQUIRE(server_parsed == int_like);
}
} // namespace oxen::quic::test