Skip to content

Commit

Permalink
Replace uses of __del__ with weakref.finalize (#317)
Browse files Browse the repository at this point in the history
Modern Python should use `weakref.finalize` instead of `__del__`. This change removes all legacy uses of `__del__` in favor of `weakref.finalize`.

Additionally rename the `id` attribute from `ActiveClients` as it's a reserved keyword in Python, thus it's best if we use another name such as `ident`.

Closes #209

Authors:
  - Peter Andreas Entschev (https://github.com/pentschev)

Approvers:
  - Lawrence Mitchell (https://github.com/wence-)
  - Kyle Edwards (https://github.com/KyleFromNVIDIA)

URL: #317
  • Loading branch information
pentschev authored Nov 14, 2024
1 parent 73e2102 commit f1d98f2
Show file tree
Hide file tree
Showing 5 changed files with 110 additions and 54 deletions.
4 changes: 2 additions & 2 deletions ci/test_common.sh
Original file line number Diff line number Diff line change
Expand Up @@ -128,13 +128,13 @@ run_py_tests_async() {
ENABLE_PYTHON_FUTURE=$3
SKIP=$4

CMD_LINE="UCXPY_PROGRESS_MODE=${PROGRESS_MODE} UCXPY_ENABLE_DELAYED_SUBMISSION=${ENABLE_DELAYED_SUBMISSION} UCXPY_ENABLE_PYTHON_FUTURE=${ENABLE_PYTHON_FUTURE} timeout 20m python -m pytest -vs python/ucxx/ucxx/_lib_async/tests/ --runslow"
CMD_LINE="UCXPY_PROGRESS_MODE=${PROGRESS_MODE} UCXPY_ENABLE_DELAYED_SUBMISSION=${ENABLE_DELAYED_SUBMISSION} UCXPY_ENABLE_PYTHON_FUTURE=${ENABLE_PYTHON_FUTURE} timeout 30m python -m pytest -vs python/ucxx/ucxx/_lib_async/tests/ --runslow"

if [ $SKIP -ne 0 ]; then
echo -e "\e[1;33mSkipping unstable test: ${CMD_LINE}\e[0m"
else
log_command "${CMD_LINE}"
UCXPY_PROGRESS_MODE=${PROGRESS_MODE} UCXPY_ENABLE_DELAYED_SUBMISSION=${ENABLE_DELAYED_SUBMISSION} UCXPY_ENABLE_PYTHON_FUTURE=${ENABLE_PYTHON_FUTURE} timeout 20m python -m pytest -vs python/ucxx/ucxx/_lib_async/tests/ --durations=50
UCXPY_PROGRESS_MODE=${PROGRESS_MODE} UCXPY_ENABLE_DELAYED_SUBMISSION=${ENABLE_DELAYED_SUBMISSION} UCXPY_ENABLE_PYTHON_FUTURE=${ENABLE_PYTHON_FUTURE} timeout 30m python -m pytest -vs python/ucxx/ucxx/_lib_async/tests/ --runslow
fi
}

Expand Down
32 changes: 28 additions & 4 deletions python/distributed-ucxx/distributed_ucxx/ucxx.py
Original file line number Diff line number Diff line change
Expand Up @@ -307,6 +307,25 @@ def _close_comm(ref):
comm._closed = True


def _finalizer(endpoint: ucxx.Endpoint, resource_id: int) -> None:
"""UCXX comms object finalizer.
Attempt to close the UCXX endpoint if it's still alive, and deregister Dask
resource.
Parameters
----------
endpoint: ucx_api.UCXEndpoint
The endpoint to close.
resource_id: int
The unique ID of the resource returned by `_register_dask_resource` upon
registration.
"""
if endpoint is not None:
endpoint.abort()
_deregister_dask_resource(resource_id)


class UCXX(Comm):
"""Comm object using UCXX.
Expand Down Expand Up @@ -375,14 +394,19 @@ def __init__( # type: ignore[no-untyped-def]
else:
self._has_close_callback = False

self._resource_id = _register_dask_resource()
resource_id = _register_dask_resource()
self._resource_id = resource_id

logger.debug("UCX.__init__ %s", self)

weakref.finalize(self, _deregister_dask_resource, self._resource_id)
weakref.finalize(self, _finalizer, ep, resource_id)

def __del__(self) -> None:
self.abort()
def abort(self):
self._closed = True
if self._ep is not None:
self._ep.abort()
self._ep = None
_deregister_dask_resource(self._resource_id)

@property
def local_address(self) -> str:
Expand Down
4 changes: 1 addition & 3 deletions python/ucxx/ucxx/_lib_async/continuous_ucx_progress.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,9 +64,7 @@ def __init__(self, worker, event_loop, polling_mode=False):
super().__init__(worker, event_loop)
worker.set_progress_thread_start_callback(_create_context)
worker.start_progress_thread(polling_mode=polling_mode, epoll_timeout=1)

def __del__(self):
self.worker.stop_progress_thread()
weakref.finalize(self, worker.stop_progress_thread)


class PollingMode(ProgressTask):
Expand Down
23 changes: 20 additions & 3 deletions python/ucxx/ucxx/_lib_async/endpoint.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
import asyncio
import logging
import warnings
import weakref

import ucxx._lib.libucxx as ucx_api
from ucxx._lib.arr import Array
Expand All @@ -17,6 +18,23 @@
logger = logging.getLogger("ucx")


def _finalizer(endpoint: ucx_api.UCXEndpoint) -> None:
"""Endpoint finalizer.
Attempt to close the endpoint if it's still alive.
Parameters
----------
endpoint: ucx_api.UCXEndpoint
The endpoint to close.
"""
if endpoint is not None:
logger.debug(f"Endpoint _finalize(): {endpoint.handle:#x}")
# Wait for a maximum of `period` ns
endpoint.close_blocking(period=10**10, max_attempts=1)
endpoint.remove_close_callback()


class Endpoint:
"""An endpoint represents a connection to a peer
Expand All @@ -41,8 +59,7 @@ def __init__(self, endpoint, ctx, tags=None):
self._close_after_n_recv = None
self._tags = tags

def __del__(self):
self.abort()
weakref.finalize(self, _finalizer, endpoint)

@property
def alive(self):
Expand Down Expand Up @@ -107,7 +124,7 @@ def abort(self, period=10**10, max_attempts=1):
if worker is running a progress thread and `period > 0`.
"""
if self._ep is not None:
logger.debug("Endpoint.abort(): 0x%x" % self.uid)
logger.debug(f"Endpoint.abort(): {self.uid:#x}")
# Wait for a maximum of `period` ns
self._ep.close_blocking(period=period, max_attempts=max_attempts)
self._ep.remove_close_callback()
Expand Down
101 changes: 59 additions & 42 deletions python/ucxx/ucxx/_lib_async/listener.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
import logging
import os
import threading
import weakref

import ucxx._lib.libucxx as ucx_api
from ucxx.exceptions import UCXMessageTruncatedError
Expand All @@ -30,37 +31,62 @@ def __init__(self):
self._locks = dict()
self._active_clients = dict()

def add_listener(self, id: int) -> None:
if id in self._active_clients:
raise ValueError("Listener {id} is already registered in ActiveClients.")
def add_listener(self, ident: int) -> None:
if ident in self._active_clients:
raise ValueError("Listener {ident} is already registered in ActiveClients.")

self._locks[id] = threading.Lock()
self._active_clients[id] = 0
self._locks[ident] = threading.Lock()
self._active_clients[ident] = 0

def remove_listener(self, id: int) -> None:
with self._locks[id]:
active_clients = self.get_active(id)
def remove_listener(self, ident: int) -> None:
with self._locks[ident]:
active_clients = self.get_active(ident)
if active_clients > 0:
raise RuntimeError(
"Listener {id} is being removed from ActiveClients, but "
"Listener {ident} is being removed from ActiveClients, but "
f"{active_clients} active client(s) is(are) still accounted for."
)

del self._locks[id]
del self._active_clients[id]
del self._locks[ident]
del self._active_clients[ident]

def inc(self, id: int) -> None:
with self._locks[id]:
self._active_clients[id] += 1
def inc(self, ident: int) -> None:
with self._locks[ident]:
self._active_clients[ident] += 1

def dec(self, id: int) -> None:
with self._locks[id]:
if self._active_clients[id] == 0:
raise ValueError(f"There are no active clients for listener {id}")
self._active_clients[id] -= 1
def dec(self, ident: int) -> None:
with self._locks[ident]:
if self._active_clients[ident] == 0:
raise ValueError(f"There are no active clients for listener {ident}")
self._active_clients[ident] -= 1

def get_active(self, id: int) -> int:
return self._active_clients[id]
def get_active(self, ident: int) -> int:
return self._active_clients[ident]


def _finalizer(ident: int, active_clients: ActiveClients) -> None:
"""Listener finalizer.
Finalize the listener and remove it from the `ActiveClients`. If there are
active clients, a warning is logged.
Parameters
----------
ident: int
The unique identifier of the `Listener`.
active_clients: ActiveClients
Instance of `ActiveClients` owned by the parent `ApplicationContext`
from which to remove the `Listener`.
"""
try:
active_clients.remove_listener(ident)
except RuntimeError:
active_clients = active_clients.get_active(ident)
logger.warning(
f"Listener object is being destroyed, but {active_clients} client "
"handler(s) is(are) still alive. This usually indicates the Listener "
"was prematurely destroyed."
)


class Listener:
Expand All @@ -70,26 +96,17 @@ class Listener:
Please use `create_listener()` to create an Listener.
"""

def __init__(self, listener, id, active_clients):
def __init__(self, listener, ident, active_clients):
if not isinstance(listener, ucx_api.UCXListener):
raise ValueError("listener must be an instance of UCXListener")

self._listener = listener

active_clients.add_listener(id)
self._id = id
active_clients.add_listener(ident)
self._ident = ident
self._active_clients = active_clients

def __del__(self):
try:
self._active_clients.remove_listener(self._id)
except RuntimeError:
active_clients = self._active_clients.get_active(self._id)
logger.warning(
f"Listener object is being destroyed, but {active_clients} client "
"handler(s) is(are) still alive. This usually indicates the Listener "
"was prematurely destroyed."
)
weakref.finalize(self, _finalizer, ident, active_clients)

@property
def closed(self):
Expand All @@ -108,7 +125,7 @@ def port(self):

@property
def active_clients(self):
return self._active_clients.get_active(self._id)
return self._active_clients.get_active(self._ident)

def close(self):
"""Closing the listener"""
Expand All @@ -121,19 +138,19 @@ async def _listener_handler_coroutine(
func,
endpoint_error_handling,
exchange_peer_info_timeout,
id,
ident,
active_clients,
):
# def _listener_handler_coroutine(
# conn_request, ctx, func, endpoint_error_handling, id, active_clients
# conn_request, ctx, func, endpoint_error_handling, ident, active_clients
# ):
# We create the Endpoint in five steps:
# 1) Create endpoint from conn_request
# 2) Generate unique IDs to use as tags
# 3) Exchange endpoint info such as tags
# 4) Setup control receive callback
# 5) Execute the listener's callback function
active_clients.inc(id)
active_clients.inc(ident)
endpoint = conn_request

seed = os.urandom(16)
Expand Down Expand Up @@ -186,9 +203,9 @@ async def _listener_handler_coroutine(
else:
func(ep)

active_clients.dec(id)
active_clients.dec(ident)

# Ensure `ep` is destroyed and `__del__` is called
# Ensure no references to `ep` remain to permit garbage collection.
del ep


Expand All @@ -199,7 +216,7 @@ def _listener_handler(
ctx,
endpoint_error_handling,
exchange_peer_info_timeout,
id,
ident,
active_clients,
):
asyncio.run_coroutine_threadsafe(
Expand All @@ -209,7 +226,7 @@ def _listener_handler(
callback_func,
endpoint_error_handling,
exchange_peer_info_timeout,
id,
ident,
active_clients,
),
event_loop,
Expand Down

0 comments on commit f1d98f2

Please sign in to comment.