From 54d3e0ee774c7c8bf7ead011e8096985d36bd856 Mon Sep 17 00:00:00 2001 From: Liu Dongmiao Date: Mon, 8 Apr 2024 00:36:54 +0800 Subject: [PATCH] both: optimize for udp over wss --- common.c | 161 +++++++++++++++++++----------------------- common.h | 16 ++--- testcase/check-udp.py | 36 ++++++++-- wss-proxy-client.c | 18 ++--- wss-proxy-server.c | 5 +- 5 files changed, 118 insertions(+), 118 deletions(-) diff --git a/common.c b/common.c index 014a3a3..1477c8d 100644 --- a/common.c +++ b/common.c @@ -369,6 +369,7 @@ static enum bufferevent_filter_result wss_input_filter(struct evbuffer *src, str ssize_t header_size; int result; struct ws_header_info info; + struct udp_frame udp_frame; (void) dst_limit; (void) mode; @@ -465,37 +466,21 @@ static enum bufferevent_filter_result wss_input_filter(struct evbuffer *src, str reply_close(src, info.payload_size, info.mask_key, raw); return BEV_ERROR; } -#ifdef WSS_PROXY_SERVER + if (evbuffer_remove(src, udp_frame.buffer, info.payload_size) != info.payload_size) { + LOGW("cannot read more data"); + send_close(raw, CLOSE_INTERNAL_ERROR); + return BEV_ERROR; + } if (info.mask_key) { - char buffer[MAX_WSS_PAYLOAD_SIZE]; - uint16_t payload_size = info.payload_size; - while (payload_size > 0) { - int size = evbuffer_remove(src, buffer, MIN(payload_size, MAX_WSS_PAYLOAD_SIZE)); - if (size <= 0) { - LOGW("cannot read more data"); - send_close(raw, CLOSE_INTERNAL_ERROR); - return BEV_ERROR; - } - unmask(buffer, (uint16_t) size, info.mask_key); - evbuffer_add(dst, buffer, (uint16_t) size); - payload_size -= (uint16_t) size; - } - return BEV_OK; + unmask(udp_frame.buffer, info.payload_size, info.mask_key); } -#endif - { - uint16_t payload_size = info.payload_size; - while (payload_size > 0) { - int size = evbuffer_remove_buffer(src, dst, payload_size); - if (size <= 0) { - LOGW("cannot read more data"); - send_close(raw, CLOSE_INTERNAL_ERROR); - return BEV_ERROR; - } - payload_size -= (uint16_t) size; - } - return BEV_OK; + if (((struct bufferevent *) raw)->be_ops) { + evbuffer_add(dst, udp_frame.buffer, info.payload_size); + } else { + udp_frame.length = ntohs(info.payload_size); + evbuffer_add(dst, &udp_frame, info.payload_size + 2); } + return BEV_OK; } static void close_wss_data_cb(struct bufferevent *tev, void *wss) { @@ -567,24 +552,37 @@ static void raw_forward_cb(struct bufferevent *raw, void *wss) { struct evbuffer *src; struct evbuffer *dst; struct bufferevent *tev; - int payload_size; + size_t total_size; tev = evhttp_connection_get_bufferevent(wss); src = bufferevent_get_input(raw); dst = bufferevent_get_output(tev); - payload_size = raw->be_ops ? WSS_PAYLOAD_SIZE : MAX_WSS_PAYLOAD_SIZE; - for (;;) { + total_size = evbuffer_get_length(src); + while (total_size > 0) { // should we use continuation fame? uint8_t *wss_header, wss_header_size; struct wss_frame_data { char header[MAX_WS_HEADER_SIZE]; char buffer[MAX_WSS_PAYLOAD_SIZE]; } wss_frame_data; - int size = evbuffer_remove(src, wss_frame_data.buffer, payload_size); + int size, payload_size; + if (raw->be_ops) { + payload_size = WSS_PAYLOAD_SIZE; + } else { + if (evbuffer_remove(src, &payload_size, UDP_FRAME_LENGTH_SIZE) != UDP_FRAME_LENGTH_SIZE) { + LOGE("cannot remove 2 from src for payload_size"); + break; + } + total_size -= UDP_FRAME_LENGTH_SIZE; + payload_size = htons(payload_size); + } + size = evbuffer_remove(src, wss_frame_data.buffer, payload_size); if (size <= 0) { + LOGE("cannot remove %d from src", payload_size); break; } + total_size -= size; wss_header = build_ws_frame(OP_BINARY, &(wss_frame_data.buffer), (uint16_t) size, &wss_header_size); evbuffer_add(dst, wss_header, (uint16_t) size + wss_header_size); } @@ -624,6 +622,9 @@ static void wss_forward_cb(struct bufferevent *wev, void *raw) { struct evbuffer *dst; src = bufferevent_get_input(wev); + if (!evbuffer_get_length(src)) { + return; + } dst = bufferevent_get_output(raw); evbuffer_add_buffer(dst, src); } @@ -680,6 +681,8 @@ void tunnel_wss(struct bufferevent *raw, struct evhttp_connection *wss) { raw->readcb = raw_forward_cb; raw->errorcb = raw_event_cb_wss; raw->cbarg = wss; + evbuffer_add_cb(raw->input, udp_read_cb, raw); + raw->readcb(raw, raw->cbarg); } } @@ -706,6 +709,9 @@ static void raw_forward_cb_ss(struct bufferevent *raw, void *wss) { struct evbuffer *dst; src = bufferevent_get_input(raw); + if (!evbuffer_get_length(src)) { + return; + } dst = bufferevent_get_output(evhttp_connection_get_bufferevent(wss)); evbuffer_add_buffer(dst, src); } @@ -725,50 +731,53 @@ void tunnel_ss(struct bufferevent *raw, struct evhttp_connection *wss) { raw->readcb = raw_forward_cb_ss; raw->errorcb = raw_event_cb; raw->cbarg = wss; + evbuffer_add_cb(raw->input, udp_read_cb, raw); + raw->readcb(raw, raw->cbarg); + } +} + +void udp_read_cb(struct evbuffer *buf, const struct evbuffer_cb_info *info, void *arg) { + struct bufferevent *raw = arg; + (void) buf; + if (info->n_added > 0) { + raw->readcb(raw, raw->cbarg); } } void udp_send_cb(struct evbuffer *buf, const struct evbuffer_cb_info *info, void *arg) { - char *payload; - uint8_t offset; unsigned length; uint16_t payload_length; struct udp_frame udp_frame; struct bufferevent *raw = arg; struct bufferevent_udp *bev_udp = arg; size_t size = evbuffer_get_length(buf); - (void) info; + if (info->n_added <= 0) { + return; + } while (size > 0) { - if (raw->errorcb == raw_event_cb) { - offset = 2; - if (size < offset) { - break; - } - if (evbuffer_copyout(buf, &(udp_frame.raw.length), offset) < 0) { - LOGD("cannot read udp 2 for %d, will try later", get_port(bev_udp->sockaddr)); - break; - } - payload_length = htons(udp_frame.raw.length); - if (size < payload_length + offset) { - LOGD("cannot read udp %u + 2 for %d, will try later", payload_length, get_port(bev_udp->sockaddr)); - break; - } - } else { - offset = 0; - payload_length = size; + if (size < UDP_FRAME_LENGTH_SIZE) { + break; + } + if (evbuffer_copyout(buf, &udp_frame, UDP_FRAME_LENGTH_SIZE) != UDP_FRAME_LENGTH_SIZE) { + LOGE("cannot copy udp to get payload length for %d", get_port(bev_udp->sockaddr)); + raw->errorcb(raw, BEV_EVENT_ERROR, get_wss(raw)); + break; + } + payload_length = htons(udp_frame.length); + length = payload_length + UDP_FRAME_LENGTH_SIZE; + if (size < length) { + break; } - length = payload_length + offset; if (evbuffer_copyout(buf, &udp_frame, length) != (int) length) { LOGE("cannot copy udp %d for %d", (int) length, get_port(bev_udp->sockaddr)); - raw_event_cb(raw, BEV_EVENT_ERROR, get_wss(raw)); + raw->errorcb(raw, BEV_EVENT_ERROR, get_wss(raw)); break; } - payload = raw->errorcb == raw_event_cb ? udp_frame.raw.buffer : udp_frame.buffer; - if (sendto(bev_udp->sock, payload, payload_length, 0, bev_udp->sockaddr, bev_udp->socklen) < 0) { + if (sendto(bev_udp->sock, udp_frame.buffer, payload_length, 0, bev_udp->sockaddr, bev_udp->socklen) < 0) { // is there any chance to sendto later? int socket_error = evutil_socket_geterror(bev_udp->sock); LOGE("cannot send udp to %d: %s", get_port(bev_udp->sockaddr), evutil_socket_error_to_string(socket_error)); - raw_event_cb(raw, BEV_EVENT_ERROR, get_wss(raw)); + raw->errorcb(raw, BEV_EVENT_ERROR, get_wss(raw)); break; } LOGD("udp sent %d to peer %d", payload_length, get_port(bev_udp->sockaddr)); @@ -777,21 +786,8 @@ void udp_send_cb(struct evbuffer *buf, const struct evbuffer_cb_info *info, void } } -int udp_read_cb(evutil_socket_t sock, struct bufferevent *raw, struct sockaddr *sockaddr, ev_socklen_t *socklen) { - struct udp_frame udp_frame; - struct timeval one_minute = {60, 0}; - char *payload = raw->errorcb == raw_event_cb ? udp_frame.raw.buffer : udp_frame.buffer; - ssize_t size = recvfrom(sock, payload, MAX_UDP_FRAME_SIZE, 0, sockaddr, socklen); -#ifdef WSS_PROXY_CLIENT -#define get_common_port() get_port(sockaddr) -#endif -#ifdef WSS_PROXY_SERVER -#define get_common_port() get_http_port(get_wss(raw)) -#endif - if (!raw->readcb || !raw->errorcb) { - LOGE("no readcb or errorcb"); - return -1; - } +ssize_t udp_read(evutil_socket_t sock, struct udp_frame *udp_frame, struct sockaddr *sockaddr, ev_socklen_t *socklen) { + ssize_t size = recvfrom(sock, udp_frame->buffer, MAX_UDP_FRAME_SIZE, 0, sockaddr, socklen); if (size < 0) { int socket_error = evutil_socket_geterror(sock); if (!EVUTIL_ERR_RW_RETRIABLE(socket_error)) { @@ -799,24 +795,13 @@ int udp_read_cb(evutil_socket_t sock, struct bufferevent *raw, struct sockaddr * } return -1; } else if (size == 0) { - LOGE("udp receive 0 for peer %d", get_common_port()); - raw->errorcb(raw, BEV_EVENT_EOF, raw->cbarg); + LOGE("udp receive 0 from port %d", get_port(sockaddr)); + return 0; } else { - size_t old_size = evbuffer_get_length(raw->input); - if (old_size) { - LOGE("udp receive %d for peer %d, previous: %d", (int) size, get_common_port(), (int) old_size); - } - event_add(&(raw->ev_read), &one_minute); - if (raw->errorcb == raw_event_cb) { - udp_frame.raw.length = ntohs(size); - evbuffer_add(raw->input, &udp_frame, size + 2); - } else { - evbuffer_add(raw->input, &udp_frame, size); - } - LOGD("udp read %d for peer %d", (int) size, get_common_port()); - raw->readcb(raw, raw->cbarg); + udp_frame->length = ntohs(size); + LOGD("udp read %d from port %d", (int) size, get_port(sockaddr)); + return (int) size; } - return 0; } #ifdef HAVE_SSL_CTX_SET_KEYLOG_CALLBACK diff --git a/common.h b/common.h index 79a2461..14264bd 100644 --- a/common.h +++ b/common.h @@ -52,14 +52,10 @@ enum log_level { }; #define MAX_UDP_FRAME_SIZE 65535 +#define UDP_FRAME_LENGTH_SIZE 2 struct udp_frame { - union { - struct { - uint16_t length; - char buffer[MAX_UDP_FRAME_SIZE]; - } raw; - char buffer[MAX_UDP_FRAME_SIZE]; - }; + uint16_t length; + char buffer[MAX_UDP_FRAME_SIZE]; }; #define MAX_WSS_PAYLOAD_SIZE MAX_UDP_FRAME_SIZE @@ -168,9 +164,11 @@ void tunnel_ss(struct bufferevent *raw, struct evhttp_connection *wss); #define SOCK_TYPE_UDP "udp" #define IS_UDP(x) (x != NULL && !evutil_ascii_strcasecmp(x, SOCK_TYPE_UDP)) -void udp_send_cb(struct evbuffer *buf, const struct evbuffer_cb_info *info, void *arg); +ssize_t udp_read(evutil_socket_t sock, struct udp_frame *udp_frame, struct sockaddr *sockaddr, ev_socklen_t *socklen); -int udp_read_cb(evutil_socket_t sock, struct bufferevent *raw, struct sockaddr *sockaddr, ev_socklen_t *socklen); +void udp_read_cb(struct evbuffer *buf, const struct evbuffer_cb_info *info, void *arg); + +void udp_send_cb(struct evbuffer *buf, const struct evbuffer_cb_info *info, void *arg); #ifdef HAVE_SSL_CTX_SET_KEYLOG_CALLBACK void ssl_keylog_callback(const SSL *ssl, const char *line); diff --git a/testcase/check-udp.py b/testcase/check-udp.py index 2955265..64a6460 100644 --- a/testcase/check-udp.py +++ b/testcase/check-udp.py @@ -1,19 +1,43 @@ import socket import socks +import time if __name__ == '__main__': for x in range(4): sock = socks.socksocket(socket.AF_INET, socket.SOCK_DGRAM) sock.set_proxy(socks.SOCKS5, "localhost", 1081) - for y in range(65): - data = ("deadbeef_%d_%d" % (x, y)) - data += 'c' * (1024 * y - len(data)) + reqs = [] + for y in range(10): + data = ("deadbeef_%d_%d_1" % (x, y)) req = data.encode('utf8') + sock.sendto(req, ("127.0.0.1", 1235)) + reqs.append(req) + for y in range(10): + data = ("deadbeef_%d_%d_2" % (x, y)) + req = data.encode('utf8') + sock.sendto(req, ("127.0.0.1", 1235)) + time.sleep(.1) + reqs.append(req) + ress = [] + for req in reqs: + (res, addr) = sock.recvfrom(65535) + ress.append(res) + if len(reqs) != len(ress): + raise ValueError("length wrong") + if set(reqs) != set(ress): + raise ValueError("value wrong") + print(b', '.join(reqs).decode('utf8')) + print(b', '.join(ress).decode('utf8')) + for y in range(1, 10): + req = ('c' * 1024 * y).encode('utf8') try: - sock.sendto(req, ("127.0.0.1", 1235)) + sock.sendto(req, ('127.0.0.1', 1235)) except OSError: + print('length %d ko' % len(req)) + sock.close() break (res, addr) = sock.recvfrom(65535) - print("round %d_%d, length: %d" % (x, y, len(data))) + print('length %d ok' % len(req)) if req != res: - raise ValueError + raise ValueError() + diff --git a/wss-proxy-client.c b/wss-proxy-client.c index 504d1a4..a5e551a 100644 --- a/wss-proxy-client.c +++ b/wss-proxy-client.c @@ -462,31 +462,21 @@ static struct bufferevent_udp *init_udp_client(struct bufferevent_udp *key, stru } static void udp_read_cb_client(evutil_socket_t sock, short event, void *ctx) { - char buffer[1]; struct udp_context *context = ctx; struct bufferevent_udp key, *data; + struct udp_frame udp_frame; (void) event; key.sockaddr = (struct sockaddr *) &(key.sockaddr_storage); for (;;) { + ssize_t size; key.socklen = sizeof(struct sockaddr_storage); - if (recvfrom(sock, &buffer, 1, MSG_PEEK, key.sockaddr, &(key.socklen)) < 0) { - int socket_error = evutil_socket_geterror(sock); - if (!EVUTIL_ERR_RW_RETRIABLE(socket_error)) { - LOGE("cannot recvfrom udp when check addr: %s", evutil_socket_error_to_string(socket_error)); - } + if ((size = udp_read(sock, &udp_frame, key.sockaddr, &(key.socklen))) < 0) { break; } if ((data = init_udp_client(&key, context, sock, get_port(key.sockaddr))) == NULL) { break; } - if (!data->be.readcb) { - // wait for tunnel ok - break; - } - key.socklen = sizeof(struct sockaddr_storage); - if (udp_read_cb(sock, (struct bufferevent *) data, key.sockaddr, &(key.socklen)) < 0) { - break; - } + evbuffer_add(data->be.input, &udp_frame, size + 2); } } diff --git a/wss-proxy-server.c b/wss-proxy-server.c index 39b86ad..877bedb 100644 --- a/wss-proxy-server.c +++ b/wss-proxy-server.c @@ -114,13 +114,16 @@ static void udp_read_cb_server(evutil_socket_t sock, short event, void *ctx) { LOGD("udp timeout for peer %d", get_http_port(wss)); raw_event_cb(raw, BEV_EVENT_EOF, wss); } else if (event & EV_READ) { + struct udp_frame udp_frame; ev_socklen_t socklen; struct sockaddr_storage sockaddr; for (;;) { + ssize_t size; socklen = sizeof(struct sockaddr_storage); - if (udp_read_cb(sock, raw, (struct sockaddr *) &sockaddr, &socklen) < 0) { + if ((size = udp_read(sock, &udp_frame, (struct sockaddr *) &sockaddr, &socklen)) < 0) { break; } + evbuffer_add(raw->input, &udp_frame, size + 2); } } }