Skip to content

Commit

Permalink
both: optimize for udp over wss
Browse files Browse the repository at this point in the history
  • Loading branch information
liudongmiao committed Apr 7, 2024
1 parent 311a57d commit 54d3e0e
Show file tree
Hide file tree
Showing 5 changed files with 118 additions and 118 deletions.
161 changes: 73 additions & 88 deletions common.c
Original file line number Diff line number Diff line change
Expand Up @@ -369,6 +369,7 @@ static enum bufferevent_filter_result wss_input_filter(struct evbuffer *src, str
ssize_t header_size;
int result;
struct ws_header_info info;
struct udp_frame udp_frame;

(void) dst_limit;
(void) mode;
Expand Down Expand Up @@ -465,37 +466,21 @@ static enum bufferevent_filter_result wss_input_filter(struct evbuffer *src, str
reply_close(src, info.payload_size, info.mask_key, raw);
return BEV_ERROR;
}
#ifdef WSS_PROXY_SERVER
if (evbuffer_remove(src, udp_frame.buffer, info.payload_size) != info.payload_size) {
LOGW("cannot read more data");
send_close(raw, CLOSE_INTERNAL_ERROR);
return BEV_ERROR;
}
if (info.mask_key) {
char buffer[MAX_WSS_PAYLOAD_SIZE];
uint16_t payload_size = info.payload_size;
while (payload_size > 0) {
int size = evbuffer_remove(src, buffer, MIN(payload_size, MAX_WSS_PAYLOAD_SIZE));
if (size <= 0) {
LOGW("cannot read more data");
send_close(raw, CLOSE_INTERNAL_ERROR);
return BEV_ERROR;
}
unmask(buffer, (uint16_t) size, info.mask_key);
evbuffer_add(dst, buffer, (uint16_t) size);
payload_size -= (uint16_t) size;
}
return BEV_OK;
unmask(udp_frame.buffer, info.payload_size, info.mask_key);
}
#endif
{
uint16_t payload_size = info.payload_size;
while (payload_size > 0) {
int size = evbuffer_remove_buffer(src, dst, payload_size);
if (size <= 0) {
LOGW("cannot read more data");
send_close(raw, CLOSE_INTERNAL_ERROR);
return BEV_ERROR;
}
payload_size -= (uint16_t) size;
}
return BEV_OK;
if (((struct bufferevent *) raw)->be_ops) {
evbuffer_add(dst, udp_frame.buffer, info.payload_size);
} else {
udp_frame.length = ntohs(info.payload_size);
evbuffer_add(dst, &udp_frame, info.payload_size + 2);
}
return BEV_OK;
}

static void close_wss_data_cb(struct bufferevent *tev, void *wss) {
Expand Down Expand Up @@ -567,24 +552,37 @@ static void raw_forward_cb(struct bufferevent *raw, void *wss) {
struct evbuffer *src;
struct evbuffer *dst;
struct bufferevent *tev;
int payload_size;
size_t total_size;

tev = evhttp_connection_get_bufferevent(wss);
src = bufferevent_get_input(raw);
dst = bufferevent_get_output(tev);

payload_size = raw->be_ops ? WSS_PAYLOAD_SIZE : MAX_WSS_PAYLOAD_SIZE;
for (;;) {
total_size = evbuffer_get_length(src);
while (total_size > 0) {
// should we use continuation fame?
uint8_t *wss_header, wss_header_size;
struct wss_frame_data {
char header[MAX_WS_HEADER_SIZE];
char buffer[MAX_WSS_PAYLOAD_SIZE];
} wss_frame_data;
int size = evbuffer_remove(src, wss_frame_data.buffer, payload_size);
int size, payload_size;
if (raw->be_ops) {
payload_size = WSS_PAYLOAD_SIZE;
} else {
if (evbuffer_remove(src, &payload_size, UDP_FRAME_LENGTH_SIZE) != UDP_FRAME_LENGTH_SIZE) {
LOGE("cannot remove 2 from src for payload_size");
break;
}
total_size -= UDP_FRAME_LENGTH_SIZE;
payload_size = htons(payload_size);
}
size = evbuffer_remove(src, wss_frame_data.buffer, payload_size);
if (size <= 0) {
LOGE("cannot remove %d from src", payload_size);
break;
}
total_size -= size;
wss_header = build_ws_frame(OP_BINARY, &(wss_frame_data.buffer), (uint16_t) size, &wss_header_size);
evbuffer_add(dst, wss_header, (uint16_t) size + wss_header_size);
}
Expand Down Expand Up @@ -624,6 +622,9 @@ static void wss_forward_cb(struct bufferevent *wev, void *raw) {
struct evbuffer *dst;

src = bufferevent_get_input(wev);
if (!evbuffer_get_length(src)) {
return;
}
dst = bufferevent_get_output(raw);
evbuffer_add_buffer(dst, src);
}
Expand Down Expand Up @@ -680,6 +681,8 @@ void tunnel_wss(struct bufferevent *raw, struct evhttp_connection *wss) {
raw->readcb = raw_forward_cb;
raw->errorcb = raw_event_cb_wss;
raw->cbarg = wss;
evbuffer_add_cb(raw->input, udp_read_cb, raw);
raw->readcb(raw, raw->cbarg);
}
}

Expand All @@ -706,6 +709,9 @@ static void raw_forward_cb_ss(struct bufferevent *raw, void *wss) {
struct evbuffer *dst;

src = bufferevent_get_input(raw);
if (!evbuffer_get_length(src)) {
return;
}
dst = bufferevent_get_output(evhttp_connection_get_bufferevent(wss));
evbuffer_add_buffer(dst, src);
}
Expand All @@ -725,50 +731,53 @@ void tunnel_ss(struct bufferevent *raw, struct evhttp_connection *wss) {
raw->readcb = raw_forward_cb_ss;
raw->errorcb = raw_event_cb;
raw->cbarg = wss;
evbuffer_add_cb(raw->input, udp_read_cb, raw);
raw->readcb(raw, raw->cbarg);
}
}

void udp_read_cb(struct evbuffer *buf, const struct evbuffer_cb_info *info, void *arg) {
struct bufferevent *raw = arg;
(void) buf;
if (info->n_added > 0) {
raw->readcb(raw, raw->cbarg);
}
}

void udp_send_cb(struct evbuffer *buf, const struct evbuffer_cb_info *info, void *arg) {
char *payload;
uint8_t offset;
unsigned length;
uint16_t payload_length;
struct udp_frame udp_frame;
struct bufferevent *raw = arg;
struct bufferevent_udp *bev_udp = arg;
size_t size = evbuffer_get_length(buf);
(void) info;
if (info->n_added <= 0) {
return;
}
while (size > 0) {
if (raw->errorcb == raw_event_cb) {
offset = 2;
if (size < offset) {
break;
}
if (evbuffer_copyout(buf, &(udp_frame.raw.length), offset) < 0) {
LOGD("cannot read udp 2 for %d, will try later", get_port(bev_udp->sockaddr));
break;
}
payload_length = htons(udp_frame.raw.length);
if (size < payload_length + offset) {
LOGD("cannot read udp %u + 2 for %d, will try later", payload_length, get_port(bev_udp->sockaddr));
break;
}
} else {
offset = 0;
payload_length = size;
if (size < UDP_FRAME_LENGTH_SIZE) {
break;
}
if (evbuffer_copyout(buf, &udp_frame, UDP_FRAME_LENGTH_SIZE) != UDP_FRAME_LENGTH_SIZE) {
LOGE("cannot copy udp to get payload length for %d", get_port(bev_udp->sockaddr));
raw->errorcb(raw, BEV_EVENT_ERROR, get_wss(raw));
break;
}
payload_length = htons(udp_frame.length);
length = payload_length + UDP_FRAME_LENGTH_SIZE;
if (size < length) {
break;
}
length = payload_length + offset;
if (evbuffer_copyout(buf, &udp_frame, length) != (int) length) {
LOGE("cannot copy udp %d for %d", (int) length, get_port(bev_udp->sockaddr));
raw_event_cb(raw, BEV_EVENT_ERROR, get_wss(raw));
raw->errorcb(raw, BEV_EVENT_ERROR, get_wss(raw));
break;
}
payload = raw->errorcb == raw_event_cb ? udp_frame.raw.buffer : udp_frame.buffer;
if (sendto(bev_udp->sock, payload, payload_length, 0, bev_udp->sockaddr, bev_udp->socklen) < 0) {
if (sendto(bev_udp->sock, udp_frame.buffer, payload_length, 0, bev_udp->sockaddr, bev_udp->socklen) < 0) {
// is there any chance to sendto later?
int socket_error = evutil_socket_geterror(bev_udp->sock);
LOGE("cannot send udp to %d: %s", get_port(bev_udp->sockaddr), evutil_socket_error_to_string(socket_error));
raw_event_cb(raw, BEV_EVENT_ERROR, get_wss(raw));
raw->errorcb(raw, BEV_EVENT_ERROR, get_wss(raw));
break;
}
LOGD("udp sent %d to peer %d", payload_length, get_port(bev_udp->sockaddr));
Expand All @@ -777,46 +786,22 @@ void udp_send_cb(struct evbuffer *buf, const struct evbuffer_cb_info *info, void
}
}

int udp_read_cb(evutil_socket_t sock, struct bufferevent *raw, struct sockaddr *sockaddr, ev_socklen_t *socklen) {
struct udp_frame udp_frame;
struct timeval one_minute = {60, 0};
char *payload = raw->errorcb == raw_event_cb ? udp_frame.raw.buffer : udp_frame.buffer;
ssize_t size = recvfrom(sock, payload, MAX_UDP_FRAME_SIZE, 0, sockaddr, socklen);
#ifdef WSS_PROXY_CLIENT
#define get_common_port() get_port(sockaddr)
#endif
#ifdef WSS_PROXY_SERVER
#define get_common_port() get_http_port(get_wss(raw))
#endif
if (!raw->readcb || !raw->errorcb) {
LOGE("no readcb or errorcb");
return -1;
}
ssize_t udp_read(evutil_socket_t sock, struct udp_frame *udp_frame, struct sockaddr *sockaddr, ev_socklen_t *socklen) {
ssize_t size = recvfrom(sock, udp_frame->buffer, MAX_UDP_FRAME_SIZE, 0, sockaddr, socklen);
if (size < 0) {
int socket_error = evutil_socket_geterror(sock);
if (!EVUTIL_ERR_RW_RETRIABLE(socket_error)) {
LOGE("cannot recvfrom udp: %s", evutil_socket_error_to_string(socket_error));
}
return -1;
} else if (size == 0) {
LOGE("udp receive 0 for peer %d", get_common_port());
raw->errorcb(raw, BEV_EVENT_EOF, raw->cbarg);
LOGE("udp receive 0 from port %d", get_port(sockaddr));
return 0;
} else {
size_t old_size = evbuffer_get_length(raw->input);
if (old_size) {
LOGE("udp receive %d for peer %d, previous: %d", (int) size, get_common_port(), (int) old_size);
}
event_add(&(raw->ev_read), &one_minute);
if (raw->errorcb == raw_event_cb) {
udp_frame.raw.length = ntohs(size);
evbuffer_add(raw->input, &udp_frame, size + 2);
} else {
evbuffer_add(raw->input, &udp_frame, size);
}
LOGD("udp read %d for peer %d", (int) size, get_common_port());
raw->readcb(raw, raw->cbarg);
udp_frame->length = ntohs(size);
LOGD("udp read %d from port %d", (int) size, get_port(sockaddr));
return (int) size;
}
return 0;
}

#ifdef HAVE_SSL_CTX_SET_KEYLOG_CALLBACK
Expand Down
16 changes: 7 additions & 9 deletions common.h
Original file line number Diff line number Diff line change
Expand Up @@ -52,14 +52,10 @@ enum log_level {
};

#define MAX_UDP_FRAME_SIZE 65535
#define UDP_FRAME_LENGTH_SIZE 2
struct udp_frame {
union {
struct {
uint16_t length;
char buffer[MAX_UDP_FRAME_SIZE];
} raw;
char buffer[MAX_UDP_FRAME_SIZE];
};
uint16_t length;
char buffer[MAX_UDP_FRAME_SIZE];
};
#define MAX_WSS_PAYLOAD_SIZE MAX_UDP_FRAME_SIZE

Expand Down Expand Up @@ -168,9 +164,11 @@ void tunnel_ss(struct bufferevent *raw, struct evhttp_connection *wss);
#define SOCK_TYPE_UDP "udp"
#define IS_UDP(x) (x != NULL && !evutil_ascii_strcasecmp(x, SOCK_TYPE_UDP))

void udp_send_cb(struct evbuffer *buf, const struct evbuffer_cb_info *info, void *arg);
ssize_t udp_read(evutil_socket_t sock, struct udp_frame *udp_frame, struct sockaddr *sockaddr, ev_socklen_t *socklen);

int udp_read_cb(evutil_socket_t sock, struct bufferevent *raw, struct sockaddr *sockaddr, ev_socklen_t *socklen);
void udp_read_cb(struct evbuffer *buf, const struct evbuffer_cb_info *info, void *arg);

void udp_send_cb(struct evbuffer *buf, const struct evbuffer_cb_info *info, void *arg);

#ifdef HAVE_SSL_CTX_SET_KEYLOG_CALLBACK
void ssl_keylog_callback(const SSL *ssl, const char *line);
Expand Down
36 changes: 30 additions & 6 deletions testcase/check-udp.py
Original file line number Diff line number Diff line change
@@ -1,19 +1,43 @@
import socket
import socks
import time

if __name__ == '__main__':
for x in range(4):
sock = socks.socksocket(socket.AF_INET, socket.SOCK_DGRAM)
sock.set_proxy(socks.SOCKS5, "localhost", 1081)
for y in range(65):
data = ("deadbeef_%d_%d" % (x, y))
data += 'c' * (1024 * y - len(data))
reqs = []
for y in range(10):
data = ("deadbeef_%d_%d_1" % (x, y))
req = data.encode('utf8')
sock.sendto(req, ("127.0.0.1", 1235))
reqs.append(req)
for y in range(10):
data = ("deadbeef_%d_%d_2" % (x, y))
req = data.encode('utf8')
sock.sendto(req, ("127.0.0.1", 1235))
time.sleep(.1)
reqs.append(req)
ress = []
for req in reqs:
(res, addr) = sock.recvfrom(65535)
ress.append(res)
if len(reqs) != len(ress):
raise ValueError("length wrong")
if set(reqs) != set(ress):
raise ValueError("value wrong")
print(b', '.join(reqs).decode('utf8'))
print(b', '.join(ress).decode('utf8'))
for y in range(1, 10):
req = ('c' * 1024 * y).encode('utf8')
try:
sock.sendto(req, ("127.0.0.1", 1235))
sock.sendto(req, ('127.0.0.1', 1235))
except OSError:
print('length %d ko' % len(req))
sock.close()
break
(res, addr) = sock.recvfrom(65535)
print("round %d_%d, length: %d" % (x, y, len(data)))
print('length %d ok' % len(req))
if req != res:
raise ValueError
raise ValueError()

18 changes: 4 additions & 14 deletions wss-proxy-client.c
Original file line number Diff line number Diff line change
Expand Up @@ -462,31 +462,21 @@ static struct bufferevent_udp *init_udp_client(struct bufferevent_udp *key, stru
}

static void udp_read_cb_client(evutil_socket_t sock, short event, void *ctx) {
char buffer[1];
struct udp_context *context = ctx;
struct bufferevent_udp key, *data;
struct udp_frame udp_frame;
(void) event;
key.sockaddr = (struct sockaddr *) &(key.sockaddr_storage);
for (;;) {
ssize_t size;
key.socklen = sizeof(struct sockaddr_storage);
if (recvfrom(sock, &buffer, 1, MSG_PEEK, key.sockaddr, &(key.socklen)) < 0) {
int socket_error = evutil_socket_geterror(sock);
if (!EVUTIL_ERR_RW_RETRIABLE(socket_error)) {
LOGE("cannot recvfrom udp when check addr: %s", evutil_socket_error_to_string(socket_error));
}
if ((size = udp_read(sock, &udp_frame, key.sockaddr, &(key.socklen))) < 0) {
break;
}
if ((data = init_udp_client(&key, context, sock, get_port(key.sockaddr))) == NULL) {
break;
}
if (!data->be.readcb) {
// wait for tunnel ok
break;
}
key.socklen = sizeof(struct sockaddr_storage);
if (udp_read_cb(sock, (struct bufferevent *) data, key.sockaddr, &(key.socklen)) < 0) {
break;
}
evbuffer_add(data->be.input, &udp_frame, size + 2);
}
}

Expand Down
5 changes: 4 additions & 1 deletion wss-proxy-server.c
Original file line number Diff line number Diff line change
Expand Up @@ -114,13 +114,16 @@ static void udp_read_cb_server(evutil_socket_t sock, short event, void *ctx) {
LOGD("udp timeout for peer %d", get_http_port(wss));
raw_event_cb(raw, BEV_EVENT_EOF, wss);
} else if (event & EV_READ) {
struct udp_frame udp_frame;
ev_socklen_t socklen;
struct sockaddr_storage sockaddr;
for (;;) {
ssize_t size;
socklen = sizeof(struct sockaddr_storage);
if (udp_read_cb(sock, raw, (struct sockaddr *) &sockaddr, &socklen) < 0) {
if ((size = udp_read(sock, &udp_frame, (struct sockaddr *) &sockaddr, &socklen)) < 0) {
break;
}
evbuffer_add(raw->input, &udp_frame, size + 2);
}
}
}
Expand Down

0 comments on commit 54d3e0e

Please sign in to comment.