diff --git a/proxy/include/proxy/async_connect.hpp b/proxy/include/proxy/async_connect.hpp index 66ff2ef63..62bf74a7e 100644 --- a/proxy/include/proxy/async_connect.hpp +++ b/proxy/include/proxy/async_connect.hpp @@ -11,23 +11,21 @@ #ifndef INCLUDE__2023_10_18__ASYNC_CONNECT_HPP #define INCLUDE__2023_10_18__ASYNC_CONNECT_HPP - -#include #include -#include +#include #include -#include #include #include +#include +#include - -#include +#include #include +#include #include -#include -#include #include +#include #include #include @@ -35,453 +33,381 @@ #include - namespace asio_util { - namespace net = boost::asio; - - namespace detail { - template - struct connect_context - { - connect_context(Handler&& h) - : handler_(std::move(h)) - {} - - std::atomic_int flag_; - std::atomic_int num_; - Handler handler_; - std::vector> socket_; - }; - - template - void do_result(Handler&& handle, - const boost::system::error_code& error, ResultType&& result) - { - handle(error, result); - } - - template - void callback(Handler&& handler, Executor ex, - Iterator& begin, const boost::system::error_code& error) - { - net::post(ex, - [error, h = std::move(handler), begin]() mutable - { - if constexpr (std::same_as) - do_result(h, error, *begin); - - if constexpr (!std::same_as) - do_result(h, error, begin); - }); - } - - struct default_connect_condition - { - template - bool operator()(const boost::system::error_code&, - Stream&, const Endpoint&) - { - return true; - } - }; - - struct initiate_do_connect - { - bool use_happy_eyeball = false; - int reject = 0; - - template - bool check_condition(const boost::system::error_code& ec, - Stream& stream, Endpoint& endp, - ConnectCondition connect_condition) - { - bool ret = connect_condition(ec, stream, endp); - - if (!ret) - reject++; - - return ret; - } - - template - void cancellation_slot(boost::local_shared_ptr< - connect_context>& context) - { - auto cstate = net::get_associated_cancellation_slot( - context->handler_); - - if (!cstate.is_connected()) - return; - - boost::weak_ptr< - connect_context - > weak_ptr = context; - - cstate.assign([weak_ptr](net::cancellation_type_t) mutable - { - auto context = weak_ptr.lock(); - if (!context) - return; - - auto& sockets = context->socket_; - for (auto& stream : sockets) - { - if (!stream) - continue; - - boost::system::error_code ignore_ec; - stream->cancel(ignore_ec); - } - }); - } - - template - bool check_connect_iterator(boost::local_shared_ptr< - connect_context> &context, - Executor ex, Iterator begin, Iterator end) - { - context->flag_ = false; - context->num_ = std::distance(begin, end); - - if (context->num_ == 0) - { - boost::system::error_code error = net::error::not_found; - - callback( - std::move(context->handler_), ex, - begin, error); - - return false; - } - - return true; - } - - template - void happy_eyeballs_detection(Iterator begin, Iterator end) - { - bool has_a = false, has_aaaa = false; - - for (; begin != end && !(has_a && has_aaaa); begin++) - { - const auto& addr = begin->endpoint().address(); - - if (!has_aaaa) - has_aaaa = addr.is_v6(); - - if (!has_a) - has_a = addr.is_v4(); - } - - if (has_aaaa && has_a) - use_happy_eyeball = true; - } - - template - void do_connect(Iterator iter, Stream& stream, - boost::local_shared_ptr< - connect_context> &context, - Executor ex, - boost::local_shared_ptr sock, - ConnectCondition connect_condition) - { - if (!check_condition({}, *sock, *iter, connect_condition)) - { - if (reject == context->num_) - { - boost::system::error_code error = net::error::not_found; - - callback( - std::forward(context->handler_), - ex, - iter, - error); - } - - return; - } - - sock->async_connect(*iter, - [&stream, context, ex, iter, sock] - (const boost::system::error_code& error) mutable - { - if (!error) - { - if (context->flag_.exchange(true)) - return; - - stream = std::move(*sock); - } - - context->num_--; - bool is_last = context->num_ == 0; - - if (error) - { - if (context->flag_ || !is_last) - return; - } - - - auto& sockets = context->socket_; - for (auto& s : sockets) - { - if (!s) - continue; - - boost::system::error_code ignore_ec; - s->cancel(ignore_ec); - } - - callback( - std::forward(context->handler_), - ex, - iter, - error); - }); - } - - template - void do_async_connect(Handler handler, Stream& stream, - Iterator begin, Iterator end, - ConnectCondition connect_condition) - { - auto context = boost::make_local_shared< - connect_context>(std::move(handler)); - - // Process handler cancellation slot - cancellation_slot(context); - - // Get executor from handler or stream. - auto executor = net::get_associated_executor( - context->handler_, stream.get_executor()); - - // Check connect iterator valid - if (!check_connect_iterator< - Stream, Handler, decltype(executor), Iterator, ResultType>( - context, executor, begin, end)) - return; - - // happy eyeballs detection - happy_eyeballs_detection(begin, end); - - using connector_type = std::tuple, bool>; - std::vector connectors; - - for (; begin != end; begin++) - { - auto sock = boost::make_local_shared< - Stream>(stream.get_executor()); - - context->socket_.emplace_back(sock); - - auto conn_func = [this, - iter = begin, - &stream, - context, - executor, - sock, - connect_condition]() mutable - { - do_connect( - iter, stream, context, executor, - sock, connect_condition); - }; - - auto v4 = begin->endpoint().address().is_v4(); - - connectors.emplace_back(connector_type{ conn_func, v4 }); - } - - for (auto& [conn_func, v4] : connectors) - { - if (use_happy_eyeball && v4) - { - using namespace std::chrono_literals; - using net::steady_timer; - - // ipv4 delay 200ms. - auto timer = boost::make_local_shared< - steady_timer>(stream.get_executor()); - - const auto delay = 200ms; - - timer->expires_from_now(delay); - timer->async_wait([timer, - conn_func = std::move(conn_func), - context] - ([[maybe_unused]] auto error) - { - if (context->flag_) - return; - conn_func(); - }); - } - else - { - conn_func(); - } - } - } - - template - void operator()(Handler&& handler, Stream* s, - Iterator begin, Iterator end, - ConnectCondition connect_condition) - { - do_async_connect(std::move(handler), *s, - begin, end, connect_condition); - } - - template - void operator()(Handler&& handler, Stream* s, - const EndpointSequence& endpoints, - ConnectCondition connect_condition) - { - auto begin = endpoints.begin(); - auto end = endpoints.end(); - using Iterator = decltype(begin); - - do_async_connect( - std::move(handler), *s, - begin, end, connect_condition); - } - }; - } - - template - inline auto async_connect(net::basic_stream_socket& s, - Iterator begin, - ConnectHandler handler = net::default_completion_token_t(), - typename net::enable_if< - !net::is_endpoint_sequence::value>::type* = 0) - -> decltype(net::async_initiate - (detail::initiate_do_connect{}, handler, &s, - begin, Iterator(), - detail::default_connect_condition{})) - { - return net::async_initiate - (detail::initiate_do_connect{}, handler, &s, - begin, Iterator(), - detail::default_connect_condition{}); - } - - template > - auto async_connect( - net::basic_stream_socket& s, Iterator begin, - Iterator end, - ConnectHandler&& handler = net::default_completion_token_t()) - -> decltype(net::async_initiate( - detail::initiate_do_connect{}, handler, &s, begin, end, - detail::default_connect_condition{})) { - return net::async_initiate( - detail::initiate_do_connect{}, handler, &s, begin, end, - detail::default_connect_condition{}); - } - - template > - auto async_connect( - net::basic_stream_socket& s, - const EndpointSequence& endpoints, - ConnectHandler&& handler = net::default_completion_token_t(), - typename net::enable_if< - net::is_endpoint_sequence::value>::type* = 0) { - return net::async_initiate::endpoint_type)>( - detail::initiate_do_connect{}, handler, &s, endpoints, - detail::default_connect_condition{}); - } - - template > - auto async_connect( - net::basic_stream_socket& s, Iterator begin, - ConnectCondition connect_condition, - ConnectHandler&& handler = net::default_completion_token_t(), - typename net::enable_if< - !net::is_endpoint_sequence::value>::type* = 0) - -> decltype(net::async_initiate( - detail::initiate_do_connect{}, handler, &s, begin, Iterator(), - connect_condition)) { - return net::async_initiate( - detail::initiate_do_connect{}, handler, &s, begin, Iterator(), - connect_condition); - } - - template > - auto async_connect( - net::basic_stream_socket& s, Iterator begin, - Iterator end, ConnectCondition connect_condition, - ConnectHandler&& handler = net::default_completion_token_t()) - -> decltype(net::async_initiate( - detail::initiate_do_connect{}, handler, &s, begin, end, - connect_condition)) { - return net::async_initiate( - detail::initiate_do_connect{}, handler, &s, begin, end, - connect_condition); - } - - template > - auto async_connect( - net::basic_stream_socket& s, - const EndpointSequence& endpoints, ConnectCondition connect_condition, - ConnectHandler&& handler = net::default_completion_token_t(), - typename net::enable_if< - net::is_endpoint_sequence::value>::type* = 0) - -> decltype(net::async_initiate< - ConnectHandler, void(boost::system::error_code, - typename net::basic_stream_socket< - Protocol, Executor>::endpoint_type)>( - detail::initiate_do_connect{}, handler, & s, endpoints, - connect_condition)) { - return net::async_initiate::endpoint_type)>( - detail::initiate_do_connect{}, handler, &s, endpoints, - connect_condition); - } +namespace net = boost::asio; + +namespace detail { +template struct connect_context { + connect_context(Handler &&h) : handler_(std::move(h)) {} + + std::atomic_int flag_; + std::atomic_int num_; + Handler handler_; + std::vector> socket_; +}; + +template +void do_result(Handler &&handle, const boost::system::error_code &error, + ResultType &&result) { + handle(error, result); +} + +template +void callback(Handler &&handler, Executor ex, Iterator &begin, + const boost::system::error_code &error) { + net::post(ex, [error, h = std::move(handler), begin]() mutable { + if constexpr (std::same_as) + do_result(h, error, *begin); + + if constexpr (!std::same_as) + do_result(h, error, begin); + }); +} + +struct default_connect_condition { + template + bool operator()(const boost::system::error_code &, Stream &, + const Endpoint &) { + return true; + } +}; + +struct initiate_do_connect { + bool use_happy_eyeball = false; + int reject = 0; + + template + bool check_condition(const boost::system::error_code &ec, Stream &stream, + Endpoint &endp, ConnectCondition connect_condition) { + bool ret = connect_condition(ec, stream, endp); + + if (!ret) + reject++; + + return ret; + } + + template + void cancellation_slot( + boost::local_shared_ptr> &context) { + auto cstate = net::get_associated_cancellation_slot(context->handler_); + + if (!cstate.is_connected()) + return; + + boost::weak_ptr> weak_ptr = context; + + cstate.assign([weak_ptr](net::cancellation_type_t) mutable { + auto context = weak_ptr.lock(); + if (!context) + return; + + auto &sockets = context->socket_; + for (auto &stream : sockets) { + if (!stream) + continue; + + boost::system::error_code ignore_ec; + stream->cancel(ignore_ec); + } + }); + } + + template + bool check_connect_iterator( + boost::local_shared_ptr> &context, + Executor ex, Iterator begin, Iterator end) { + context->flag_ = false; + context->num_ = std::distance(begin, end); + + if (context->num_ == 0) { + boost::system::error_code error = net::error::not_found; + + callback( + std::move(context->handler_), ex, begin, error); + + return false; + } + + return true; + } + + template + void happy_eyeballs_detection(Iterator begin, Iterator end) { + bool has_a = false, has_aaaa = false; + + for (; begin != end && !(has_a && has_aaaa); begin++) { + const auto &addr = begin->endpoint().address(); + + if (!has_aaaa) + has_aaaa = addr.is_v6(); + + if (!has_a) + has_a = addr.is_v4(); + } + + if (has_aaaa && has_a) + use_happy_eyeball = true; + } + + template + void + do_connect(Iterator iter, Stream &stream, + boost::local_shared_ptr> &context, + Executor ex, boost::local_shared_ptr sock, + ConnectCondition connect_condition) { + if (!check_condition({}, *sock, *iter, connect_condition)) { + if (reject == context->num_) { + boost::system::error_code error = net::error::not_found; + + callback( + std::forward(context->handler_), ex, iter, error); + } + + return; + } + + sock->async_connect( + *iter, [&stream, context, ex, iter, + sock](const boost::system::error_code &error) mutable { + if (!error) { + if (context->flag_.exchange(true)) + return; + + stream = std::move(*sock); + } + + context->num_--; + bool is_last = context->num_ == 0; + + if (error) { + if (context->flag_ || !is_last) + return; + } + + auto &sockets = context->socket_; + for (auto &s : sockets) { + if (!s) + continue; + + boost::system::error_code ignore_ec; + s->cancel(ignore_ec); + } + + callback( + std::forward(context->handler_), ex, iter, error); + }); + } + + template + void do_async_connect(Handler handler, Stream &stream, Iterator begin, + Iterator end, ConnectCondition connect_condition) { + auto context = boost::make_local_shared>( + std::move(handler)); + + // Process handler cancellation slot + cancellation_slot(context); + + // Get executor from handler or stream. + auto executor = + net::get_associated_executor(context->handler_, stream.get_executor()); + + // Check connect iterator valid + if (!check_connect_iterator(context, executor, begin, end)) + return; + + // happy eyeballs detection + happy_eyeballs_detection(begin, end); + + using connector_type = std::tuple, bool>; + std::vector connectors; + + for (; begin != end; begin++) { + auto sock = boost::make_local_shared(stream.get_executor()); + + context->socket_.emplace_back(sock); + + auto conn_func = [this, iter = begin, &stream, context, executor, sock, + connect_condition]() mutable { + do_connect( + iter, stream, context, executor, sock, connect_condition); + }; + + auto v4 = begin->endpoint().address().is_v4(); + + connectors.emplace_back(connector_type{conn_func, v4}); + } + + for (auto &[conn_func, v4] : connectors) { + if (use_happy_eyeball && v4) { + using namespace std::chrono_literals; + using net::steady_timer; + + // ipv4 delay 200ms. + auto timer = + boost::make_local_shared(stream.get_executor()); + + const auto delay = 200ms; + + timer->expires_from_now(delay); + timer->async_wait([timer, conn_func = std::move(conn_func), + context]([[maybe_unused]] auto error) { + if (context->flag_) + return; + conn_func(); + }); + } else { + conn_func(); + } + } + } + + template + void operator()(Handler &&handler, Stream *s, Iterator begin, Iterator end, + ConnectCondition connect_condition) { + do_async_connect(std::move(handler), *s, begin, end, connect_condition); + } + + template + void operator()(Handler &&handler, Stream *s, + const EndpointSequence &endpoints, + ConnectCondition connect_condition) { + auto begin = endpoints.begin(); + auto end = endpoints.end(); + using Iterator = decltype(begin); + + do_async_connect( + std::move(handler), *s, begin, end, connect_condition); + } +}; +} // namespace detail + +template +inline auto async_connect( + net::basic_stream_socket &s, Iterator begin, + ConnectHandler handler = net::default_completion_token_t(), + typename net::enable_if::value>::type + * = 0) + -> decltype(net::async_initiate( + detail::initiate_do_connect{}, handler, &s, begin, Iterator(), + detail::default_connect_condition{})) { + return net::async_initiate( + detail::initiate_do_connect{}, handler, &s, begin, Iterator(), + detail::default_connect_condition{}); +} + +template > +auto async_connect( + net::basic_stream_socket &s, Iterator begin, + Iterator end, + ConnectHandler &&handler = net::default_completion_token_t()) + -> decltype(net::async_initiate( + detail::initiate_do_connect{}, handler, &s, begin, end, + detail::default_connect_condition{})) { + return net::async_initiate( + detail::initiate_do_connect{}, handler, &s, begin, end, + detail::default_connect_condition{}); +} + +template > +auto async_connect( + net::basic_stream_socket &s, + const EndpointSequence &endpoints, + ConnectHandler &&handler = net::default_completion_token_t(), + typename net::enable_if< + net::is_endpoint_sequence::value>::type * = 0) { + return net::async_initiate::endpoint_type)>( + detail::initiate_do_connect{}, handler, &s, endpoints, + detail::default_connect_condition{}); +} + +template > +auto async_connect( + net::basic_stream_socket &s, Iterator begin, + ConnectCondition connect_condition, + ConnectHandler &&handler = net::default_completion_token_t(), + typename net::enable_if::value>::type + * = 0) + -> decltype(net::async_initiate( + detail::initiate_do_connect{}, handler, &s, begin, Iterator(), + connect_condition)) { + return net::async_initiate( + detail::initiate_do_connect{}, handler, &s, begin, Iterator(), + connect_condition); +} + +template > +auto async_connect( + net::basic_stream_socket &s, Iterator begin, + Iterator end, ConnectCondition connect_condition, + ConnectHandler &&handler = net::default_completion_token_t()) + -> decltype(net::async_initiate( + detail::initiate_do_connect{}, handler, &s, begin, end, + connect_condition)) { + return net::async_initiate( + detail::initiate_do_connect{}, handler, &s, begin, end, + connect_condition); +} + +template > +auto async_connect( + net::basic_stream_socket &s, + const EndpointSequence &endpoints, ConnectCondition connect_condition, + ConnectHandler &&handler = net::default_completion_token_t(), + typename net::enable_if< + net::is_endpoint_sequence::value>::type * = 0) + -> decltype(net::async_initiate< + ConnectHandler, void(boost::system::error_code, + typename net::basic_stream_socket< + Protocol, Executor>::endpoint_type)>( + detail::initiate_do_connect{}, handler, &s, endpoints, + connect_condition)) { + return net::async_initiate::endpoint_type)>( + detail::initiate_do_connect{}, handler, &s, endpoints, connect_condition); } +} // namespace asio_util #endif // INCLUDE__2023_10_18__ASYNC_CONNECT_HPP