From 265b0f5ce8c6c8d78208b9ed364bb7f81e3c8846 Mon Sep 17 00:00:00 2001 From: Garrett D'Amore Date: Fri, 29 Nov 2024 13:10:13 -0500 Subject: [PATCH] socket: convert to using reference counts for shutdown This is a major change, but it should eliminate some of the problems we have seen with use-after-free bugs in shutdown. It should also be faster as we don't need to use locks as much. --- src/core/device.c | 4 +- src/core/pipe.c | 20 ++++- src/core/socket.c | 187 ++++++++++++++++++++++------------------------ src/nng.c | 2 +- src/sp/protocol.c | 1 + 5 files changed, 110 insertions(+), 104 deletions(-) diff --git a/src/core/device.c b/src/core/device.c index 7a6fb479d..5a2ef9781 100644 --- a/src/core/device.c +++ b/src/core/device.c @@ -51,6 +51,8 @@ device_fini(void *arg) for (int i = 0; i < d->num_paths; i++) { nni_aio_stop(&d->paths[i].aio); } + nni_sock_rele(d->paths[0].src); + nni_sock_rele(d->paths[0].dst); NNI_FREE_STRUCT(d); } @@ -97,8 +99,6 @@ device_cb(void *arg) nni_aio_finish_error(d->user, d->rv); d->user = NULL; } - nni_sock_rele(d->paths[0].src); - nni_sock_rele(d->paths[0].dst); nni_reap(&device_reap, d); } diff --git a/src/core/pipe.c b/src/core/pipe.c index 11f815095..f13f874be 100644 --- a/src/core/pipe.c +++ b/src/core/pipe.c @@ -9,8 +9,12 @@ // found online at https://opensource.org/licenses/MIT. // +#include + #include "core/nng_impl.h" -#include "nng/nng.h" + +#include "dialer.h" +#include "listener.h" #include "sockimpl.h" #include @@ -47,8 +51,10 @@ pipe_destroy(void *arg) void pipe_reap(void *arg) { - nni_pipe *p = arg; - nni_sock *s = p->p_sock; + nni_pipe *p = arg; + nni_sock *s = p->p_sock; + nni_dialer *d = p->p_dialer; + nni_listener *l = p->p_listener; nni_pipe_run_cb(p, NNG_PIPE_EV_REM_POST); @@ -71,6 +77,12 @@ pipe_reap(void *arg) } nni_pipe_rele(p); + if (l != NULL) { + nni_listener_rele(l); + } + if (d != NULL) { + nni_dialer_rele(d); + } nni_sock_rele(s); } @@ -291,6 +303,7 @@ nni_pipe_create_dialer(nni_pipe **pp, nni_dialer *d, void *tran_data) pipe_stat_init(p, &p->st_ep_id, &dialer_info); nni_stat_set_id(&p->st_ep_id, (int) nni_dialer_id(d)); #endif + nni_dialer_hold(d); *pp = p; return (0); } @@ -312,6 +325,7 @@ nni_pipe_create_listener(nni_pipe **pp, nni_listener *l, void *tran_data) .si_desc = "listener for pipe", .si_type = NNG_STAT_ID, }; + nni_listener_hold(l); pipe_stat_init(p, &p->st_ep_id, &listener_info); nni_stat_set_id(&p->st_ep_id, (int) nni_listener_id(l)); #endif diff --git a/src/core/socket.c b/src/core/socket.c index 228188349..c74a25bde 100644 --- a/src/core/socket.c +++ b/src/core/socket.c @@ -9,10 +9,12 @@ // #include "core/defs.h" +#include "core/idhash.h" #include "core/nng_impl.h" #include "core/options.h" #include "core/pipe.h" #include "core/platform.h" +#include "core/reap.h" #include "list.h" #include "nng/nng.h" #include "sockimpl.h" @@ -68,13 +70,10 @@ static const nni_ep_option ep_options[] = { }; struct nni_socket { - nni_list_node s_node; - nni_mtx s_mx; - nni_cv s_close_cv; + nni_mtx s_mx; uint32_t s_id; uint32_t s_flags; - unsigned s_ref; // protected by global lock void *s_data; // Protocol private size_t s_size; @@ -106,6 +105,10 @@ struct nni_socket { nni_mtx s_pipe_cbs_mtx; nni_sock_pipe_cb s_pipe_cbs[NNG_PIPE_EV_NUM]; + nni_reap_node s_reap; + nni_atomic_flag s_close_once; + nni_refcnt s_refcnt; + #ifdef NNG_ENABLE_STATS nni_stat_item st_root; // socket scope nni_stat_item st_id; // socket id @@ -121,8 +124,7 @@ struct nni_socket { #endif }; -static nni_list sock_list = NNI_LIST_INITIALIZER(sock_list, nni_sock, s_node); -static nni_mtx sock_lk = NNI_MTX_INITIALIZER; +static nni_mtx sock_lk = NNI_MTX_INITIALIZER; static nni_id_map sock_ids = NNI_ID_MAP_INITIALIZER(1, 0x7fffffff, 0); static nni_id_map ctx_ids = NNI_ID_MAP_INITIALIZER(1, 0x7fffffff, 0); @@ -343,7 +345,7 @@ nni_sock_find(nni_sock **sockp, uint32_t id) if (s->s_closed) { rv = NNG_ECLOSED; } else { - s->s_ref++; + nni_sock_hold(s); *sockp = s; } } else { @@ -357,20 +359,13 @@ nni_sock_find(nni_sock **sockp, uint32_t id) void nni_sock_hold(nni_sock *s) { - nni_mtx_lock(&sock_lk); - s->s_ref++; - nni_mtx_unlock(&sock_lk); + nni_refcnt_hold(&s->s_refcnt); } void nni_sock_rele(nni_sock *s) { - nni_mtx_lock(&sock_lk); - s->s_ref--; - if (s->s_closed && (s->s_ref < 2)) { - nni_cv_wake(&s->s_close_cv); - } - nni_mtx_unlock(&sock_lk); + nni_refcnt_rele(&s->s_refcnt); } #ifdef NNG_ENABLE_STATS @@ -473,8 +468,9 @@ sock_stats_init(nni_sock *s) #endif static void -sock_destroy(nni_sock *s) +sock_destroy(void *arg) { + nni_sock *s = arg; #ifdef NNG_ENABLE_STATS nni_stat_unregister(&s->st_root); #endif @@ -486,7 +482,6 @@ sock_destroy(nni_sock *s) nni_msgq_fini(s->s_urq); nni_msgq_fini(s->s_uwq); - nni_cv_fini(&s->s_close_cv); nni_mtx_fini(&s->s_mx); nni_mtx_fini(&s->s_pipe_cbs_mtx); nni_free(s, s->s_size); @@ -511,7 +506,6 @@ nni_sock_create(nni_sock **sp, const nni_proto *proto) s->s_reconnmax = 0; s->s_rcvmaxsz = 0; // unlimited by default s->s_id = 0; - s->s_ref = 0; s->s_self_id = proto->proto_self; s->s_peer_id = proto->proto_peer; s->s_flags = proto->proto_flags; @@ -527,14 +521,12 @@ nni_sock_create(nni_sock **sp, const nni_proto *proto) NNI_ASSERT(s->s_sock_ops.sock_open != NULL); NNI_ASSERT(s->s_sock_ops.sock_close != NULL); - NNI_LIST_NODE_INIT(&s->s_node); NNI_LIST_INIT(&s->s_ctxs, nni_ctx, c_node); NNI_LIST_INIT(&s->s_pipes, nni_pipe, p_sock_node); NNI_LIST_INIT(&s->s_listeners, nni_listener, l_node); NNI_LIST_INIT(&s->s_dialers, nni_dialer, d_node); nni_mtx_init(&s->s_mx); nni_mtx_init(&s->s_pipe_cbs_mtx); - nni_cv_init(&s->s_close_cv, &sock_lk); #ifdef NNG_ENABLE_STATS sock_stats_init(s); @@ -547,6 +539,8 @@ nni_sock_create(nni_sock **sp, const nni_proto *proto) } s->s_sock_ops.sock_init(s->s_data, s); + nni_refcnt_init(&s->s_refcnt, 2, s, sock_destroy); + // These we *attempt* to set so that we are likely to have initial // values loaded. They should not fail, but if they do we don't // worry about it. @@ -589,7 +583,6 @@ nni_sock_open(nni_sock **sockp, const nni_proto *proto) sock_destroy(s); return (rv); } else { - nni_list_append(&sock_list, s); s->s_sock_ops.sock_open(s->s_data); *sockp = s; } @@ -615,54 +608,67 @@ nni_sock_open(nni_sock **sockp, const nni_proto *proto) int nni_sock_shutdown(nni_sock *sock) { - nni_pipe *pipe; - nni_dialer *d; + nni_mtx_lock(&sock->s_mx); + + // At this point, we've done everything we politely can to + // give the protocol a chance to flush its write side. Now + // it is time to be a little more insistent. + + sock->s_sock_ops.sock_close(sock->s_data); + + nni_mtx_unlock(&sock->s_mx); + + return (0); +} + +static void +sock_reap(void *arg) +{ + nni_sock *s = arg; + + // Shutdown everything if not already done. This operation + // is idempotent. + nni_sock_shutdown(s); + + nni_sock_rele(s); +} + +static nni_reap_list sock_reap_list = { + .rl_offset = offsetof(nni_sock, s_reap), + .rl_func = sock_reap, +}; + +static void +sock_close_children(nni_sock *s) +{ nni_listener *l; + nni_dialer *d; + nni_pipe *p; nni_ctx *ctx; - nni_mtx_lock(&sock->s_mx); - if (sock->s_closing) { - nni_mtx_unlock(&sock->s_mx); - return (NNG_ECLOSED); - } - // Mark us closing, so no more EPs or changes can occur. - sock->s_closing = true; + // Close the upper queues immediately. + nni_msgq_close(s->s_urq); + nni_msgq_close(s->s_uwq); - NNI_LIST_FOREACH (&sock->s_listeners, l) { + nni_mtx_lock(&s->s_mx); + s->s_closing = true; + + NNI_LIST_FOREACH (&s->s_listeners, l) { nni_listener_close(l); } - NNI_LIST_FOREACH (&sock->s_dialers, d) { + NNI_LIST_FOREACH (&s->s_dialers, d) { nni_dialer_close(d); } - NNI_LIST_FOREACH (&sock->s_pipes, pipe) { - nni_pipe_close(pipe); + NNI_LIST_FOREACH (&s->s_pipes, p) { + nni_pipe_close(p); } - NNI_LIST_FOREACH (&sock->s_ctxs, ctx) { + NNI_LIST_FOREACH (&s->s_ctxs, ctx) { nni_ctx_close(ctx); } - nni_mtx_unlock(&sock->s_mx); - - // Close the upper queues immediately. - nni_msgq_close(sock->s_urq); - nni_msgq_close(sock->s_uwq); - - nni_mtx_lock(&sock->s_mx); - - // At this point, we've done everything we politely can to - // give the protocol a chance to flush its write side. Now - // it is time to be a little more insistent. - - sock->s_sock_ops.sock_close(sock->s_data); - - nni_mtx_unlock(&sock->s_mx); - - // At this point, there are no threads blocked inside of us - // that are referencing socket state. User code should call - // nng_close to release the last resources. - return (0); + nni_mtx_unlock(&s->s_mx); } // nni_sock_close shuts down the socket, then releases any resources @@ -672,61 +678,47 @@ nni_sock_shutdown(nni_sock *sock) void nni_sock_close(nni_sock *s) { - // Shutdown everything if not already done. This operation - // is idempotent. - nni_sock_shutdown(s); - - nni_mtx_lock(&sock_lk); - if (s->s_closed) { - // Some other thread called close. All we need to do - // is drop our reference count. - nni_mtx_unlock(&sock_lk); - nni_sock_rele(s); + if (nni_atomic_flag_test_and_set(&s->s_close_once)) { return; } + + nni_mtx_lock(&sock_lk); s->s_closed = true; nni_id_remove(&sock_ids, s->s_id); + nni_mtx_unlock(&sock_lk); - // We might have been removed from the list already, e.g. by - // nni_sock_closeall. This is idempotent. - nni_list_node_remove(&s->s_node); - - // Wait for all other references to drop. Note that we - // have a reference already (from our caller). - while ((s->s_ref > 1) || (!nni_list_empty(&s->s_ctxs))) { - nni_cv_wait(&s->s_close_cv); + // no more pipe callbacks after close + nni_mtx_lock(&s->s_pipe_cbs_mtx); + for (int ev = 0; ev < NNG_PIPE_EV_NUM; ev++) { + s->s_pipe_cbs[ev].cb_fn = NULL; + s->s_pipe_cbs[ev].cb_arg = NULL; } - nni_mtx_unlock(&sock_lk); + nni_mtx_unlock(&s->s_pipe_cbs_mtx); - // Because we already shut everything down before, we should not - // have any child objects. - nni_mtx_lock(&s->s_mx); - NNI_ASSERT(nni_list_empty(&s->s_dialers)); - NNI_ASSERT(nni_list_empty(&s->s_listeners)); - NNI_ASSERT(nni_list_empty(&s->s_pipes)); - nni_mtx_unlock(&s->s_mx); + sock_close_children(s); - sock_destroy(s); + nni_reap(&sock_reap_list, s); } void nni_sock_closeall(void) { nni_sock *s; + uint32_t next = 0; - for (;;) { - nni_mtx_lock(&sock_lk); - if ((s = nni_list_first(&sock_list)) == NULL) { - nni_mtx_unlock(&sock_lk); - return; + nni_mtx_lock(&sock_lk); + while (nni_id_visit(&sock_ids, NULL, (void **) &s, &next)) { + if (nni_atomic_flag_test_and_set(&s->s_close_once)) { + continue; } - // Bump the reference count. The close call below - // will drop it. - s->s_ref++; - nni_list_node_remove(&s->s_node); - nni_mtx_unlock(&sock_lk); - nni_sock_close(s); + s->s_closed = true; + nni_id_remove(&sock_ids, s->s_id); + + sock_close_children(s); + + nni_reap(&sock_reap_list, s); } + nni_mtx_unlock(&sock_lk); } void @@ -953,7 +945,8 @@ nni_ctx_find(nni_ctx **cp, uint32_t id) nni_ctx *ctx; nni_mtx_lock(&sock_lk); - if ((ctx = nni_id_get(&ctx_ids, id)) != NULL) { + if (((ctx = nni_id_get(&ctx_ids, id)) != NULL) && + (!ctx->c_sock->s_closed)) { // We refuse a reference if either the socket is // closed, or the context is closed. (If the socket // is closed, and we are only getting the reference so @@ -1040,7 +1033,7 @@ nni_ctx_open(nni_ctx **ctxp, nni_sock *sock) sock->s_ctx_ops.ctx_init(ctx->c_data, sock->s_data); nni_refcnt_init(&ctx->c_refcnt, 2, ctx, ctx_destroy); - sock->s_ref++; + nni_sock_hold(sock); nni_mtx_unlock(&sock_lk); nni_mtx_lock(&sock->s_mx); @@ -1403,8 +1396,6 @@ nni_pipe_remove(nni_pipe *p) #endif nni_list_node_remove(&p->p_sock_node); nni_list_node_remove(&p->p_ep_node); - p->p_listener = NULL; - p->p_dialer = NULL; if ((d != NULL) && (d->d_pipe == p)) { d->d_pipe = NULL; dialer_timer_start_locked(d); // Kick the timer to redial. diff --git a/src/nng.c b/src/nng.c index e845c091e..000890610 100644 --- a/src/nng.c +++ b/src/nng.c @@ -36,8 +36,8 @@ nng_close(nng_socket s) if ((rv = nni_sock_find(&sock, s.id)) != 0) { return (rv); } - // No release -- close releases it. nni_sock_close(sock); + nni_sock_rele(sock); return (0); } diff --git a/src/sp/protocol.c b/src/sp/protocol.c index d64c85a37..1312ea79c 100644 --- a/src/sp/protocol.c +++ b/src/sp/protocol.c @@ -22,6 +22,7 @@ nni_proto_open(nng_socket *sip, const nni_proto *proto) nng_socket s; s.id = nni_sock_id(sock); // Keep socket held open. *sip = s; + nni_sock_rele(sock); } return (rv); }