Skip to content

Commit

Permalink
fix racing on linux, swith to case-insensitive header comparisons
Browse files Browse the repository at this point in the history
  • Loading branch information
DaseinPhaos committed Aug 14, 2024
1 parent b0c87ec commit dae6880
Show file tree
Hide file tree
Showing 2 changed files with 25 additions and 25 deletions.
41 changes: 19 additions & 22 deletions linux.jai
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ thread_pool : [THREAD_POOL_COUNT] Thread;
epoll_fd_pool : [THREAD_POOL_COUNT] s32;
epoll_events_pool : [THREAD_POOL_COUNT] [MAX_EVENTS] epoll_event;

request_pool : Table(Socket, Request);
request_pool : [THREAD_POOL_COUNT]Table(Socket, Request);


listen_socket: Socket;
Expand All @@ -38,19 +38,14 @@ http_listen :: (port: u16 = 80, ip: string = "0.0.0.0") {
listen_socket = create_listen_socket(port, ip);
log("HTTP Server listening on %:%", ip, port);

// request_pool init
// without this our code actually crashes when there's high concurrency!
// Array bounds check failed
// i think while this table is expanding in 1 thread it's being accessed out of bounds in another thread?
init(*request_pool, BACKLOG+3); // why +3?

// spin up the thread pool
for * thread, i: thread_pool {
epoll_fd_pool[i] = epoll_create1(0);
if epoll_fd_pool[i] < 0 die("epoll_create1(0)");
add_to_epoll_fd_list(epoll_fd_pool[i], listen_socket, null, EPOLLIN | EPOLLET /*| EPOLLONESHOT*/);

thread.data = xx i;
init(*request_pool[i], BACKLOG+3); // why +3?
thread_init(thread, thread_pool_proc);
thread_start(thread);
}
Expand Down Expand Up @@ -127,8 +122,8 @@ thread_pool_proc :: (thread: *Thread) -> s64 {

if event.data.fd == listen_socket_ssl then ssl_make_socket_ssl(client_socket);

new_request_from_pool(client_socket);
handle_client_socket_accepted_or_read_ready(client_socket, epoll_fd);
new_request_from_pool(client_socket, thread_id);
handle_client_socket_accepted_or_read_ready(client_socket, epoll_fd, thread_id);
}

continue;
Expand All @@ -137,7 +132,7 @@ thread_pool_proc :: (thread: *Thread) -> s64 {
// not the listen socket, must be a client_socket ready to read or write ... right?
if event.events & EPOLLIN {
client_socket := event.data.fd;
handle_client_socket_accepted_or_read_ready(client_socket, epoll_fd);
handle_client_socket_accepted_or_read_ready(client_socket, epoll_fd, thread_id);
// add_to_epoll_fd_list(epoll_fd, client_socket, null, EPOLLIN | EPOLLET | EPOLLONESHOT, mod=true);
continue;
}
Expand All @@ -163,8 +158,8 @@ make_mysocket :: (socket: Socket) -> MySocket {
return mysocket;
}

new_request_from_pool :: (client_socket: Socket) -> *Request {
request: *Request = find_or_add(*request_pool, client_socket);
new_request_from_pool :: (client_socket: Socket, thread_id: s32) -> *Request {
request: *Request = find_or_add(*request_pool[thread_id], client_socket);
request.content_length = 0;
request.body.data = null;
request.response_sent = false;
Expand All @@ -186,7 +181,7 @@ new_request_from_pool :: (client_socket: Socket) -> *Request {
return request;
}

handle_client_socket_accepted_or_read_ready :: (client_socket: Socket, epoll_fd: s32) {
handle_client_socket_accepted_or_read_ready :: (client_socket: Socket, epoll_fd: s32, thread_id: s32) {
while 1 {

// if this socket is a websocket, not a request
Expand All @@ -197,7 +192,7 @@ handle_client_socket_accepted_or_read_ready :: (client_socket: Socket, epoll_fd:
}


request := read_http_request(client_socket);
request := read_http_request(client_socket, thread_id);
if request == null return; // we're waiting to read, need to come back later
if request.err {
// ssl has a massive memory leak when run with high concurrency... this doesn't fix it...
Expand All @@ -210,15 +205,15 @@ handle_client_socket_accepted_or_read_ready :: (client_socket: Socket, epoll_fd:
request.epoll_fd = epoll_fd;
handle_request(request);

new_request_from_pool(client_socket); // reset this requests data when we're done with it
new_request_from_pool(client_socket, thread_id); // reset this requests data when we're done with it
}
}


// todo: i think i have to rewrite this to handle read() not giving us a full line
// this is slowwwwww. maybe i should use faster string compares
read_http_request :: (socket: Socket) -> *Request {
request: *Request = table_find_pointer(*request_pool, socket);
read_http_request :: (socket: Socket, thread_id: s32) -> *Request {
// this is slowwwwww. maybe i should use faster string compares
request: *Request = table_find_pointer(*request_pool[thread_id], socket);
if request == null then die("no request for socket", socket);
// assert(request != null);

Expand Down Expand Up @@ -265,10 +260,12 @@ read_http_request :: (socket: Socket) -> *Request {
kv := split(line, ": ");
#if parseheaders table_add(*request.headers, kv[0], kv[1]);

if kv[0] == {
case "Content-Length"; request.content_length = xx to_integer(kv[1]);
case "Sec-WebSocket-Key"; request.websocket_key = kv[1];
case "Accept-Encoding"; request.accept_encoding = kv[1];
if compare_nocase(kv[0], "Content-Length") == 0 {
request.content_length = xx to_integer(kv[1]);
} else if compare_nocase(kv[0], "Sec-WebSocket-Key") == 0 {
request.websocket_key = kv[1];
} else if compare_nocase(kv[0], "Accept-Encoding") == 0 {
request.accept_encoding = kv[1];
}

if request.raw[request.buffercursor+2] == #char "\r" && request.raw[request.buffercursor+2 + 1] == #char "\n" {
Expand Down
9 changes: 6 additions & 3 deletions windows.jai
Original file line number Diff line number Diff line change
Expand Up @@ -210,9 +210,12 @@ recv_http_request :: (socket: MySocket) -> Request {
kv := split(line, ": ");
#if parseheaders table_add(*request.headers, kv[0], kv[1]);

if kv[0] == {
case "Content-Length"; request.content_length = xx to_integer(kv[1]);
case "Sec-WebSocket-Key"; request.websocket_key = kv[1];
if compare_nocase(kv[0], "Content-Length") == 0 {
request.content_length = xx to_integer(kv[1]);
} else if compare_nocase(kv[0], "Sec-WebSocket-Key") == 0 {
request.websocket_key = kv[1];
} else if compare_nocase(kv[0], "Accept-Encoding") == 0 {
request.accept_encoding = kv[1];
}

if request.raw[buffercursor+2] == #char "\r" && request.raw[buffercursor+2 + 1] == #char "\n" {
Expand Down

0 comments on commit dae6880

Please sign in to comment.