diff --git a/kazoo/handlers/gevent.py b/kazoo/handlers/gevent.py index 78d234d9..96ee765d 100644 --- a/kazoo/handlers/gevent.py +++ b/kazoo/handlers/gevent.py @@ -60,6 +60,10 @@ def __init__(self): self._state_change = Semaphore() self._workers = [] + @property + def running(self): + return self._running + class timeout_exception(gevent.Timeout): def __init__(self, msg): gevent.Timeout.__init__(self, exception=msg) diff --git a/kazoo/handlers/threading.py b/kazoo/handlers/threading.py index afd05c56..1ab33491 100644 --- a/kazoo/handlers/threading.py +++ b/kazoo/handlers/threading.py @@ -113,6 +113,10 @@ def __init__(self): self._state_change = threading.Lock() self._workers = [] + @property + def running(self): + return self._running + def _create_thread_worker(self, queue): def _thread_worker(): # pragma: nocover while True: diff --git a/kazoo/handlers/utils.py b/kazoo/handlers/utils.py index 25173906..bd1b92ef 100644 --- a/kazoo/handlers/utils.py +++ b/kazoo/handlers/utils.py @@ -46,20 +46,14 @@ def set(self, value=None): with self._condition: self.value = value self._exception = None - for callback in self._callbacks: - self._handler.completion_queue.put( - functools.partial(callback, self) - ) + self._do_callbacks() self._condition.notify_all() def set_exception(self, exception): """Store the exception. Wake up the waiters.""" with self._condition: self._exception = exception - for callback in self._callbacks: - self._handler.completion_queue.put( - functools.partial(callback, self) - ) + self._do_callbacks() self._condition.notify_all() def get(self, block=True, timeout=None): @@ -102,16 +96,13 @@ def rawlink(self, callback): """Register a callback to call when a value or an exception is set""" with self._condition: - # Are we already set? Dispatch it now - if self.ready(): - self._handler.completion_queue.put( - functools.partial(callback, self) - ) - return - if callback not in self._callbacks: self._callbacks.append(callback) + # Are we already set? Dispatch it now + if self.ready(): + self._do_callbacks() + def unlink(self, callback): """Remove the callback set by :meth:`rawlink`""" with self._condition: @@ -122,6 +113,18 @@ def unlink(self, callback): if callback in self._callbacks: self._callbacks.remove(callback) + def _do_callbacks(self): + """Execute the callbacks that were registered by :meth:`rawlink`. + If the handler is in running state this method only schedules + the calls to be performed by the handler. If it's stopped, + the callbacks are called right away.""" + + for callback in self._callbacks: + if self._handler.running: + self._handler.completion_queue.put( + functools.partial(callback, self)) + else: + functools.partial(callback, self)() def _set_fd_cloexec(fd): flags = fcntl.fcntl(fd, fcntl.F_GETFD) diff --git a/kazoo/tests/test_client.py b/kazoo/tests/test_client.py index e22261de..e988fdb1 100644 --- a/kazoo/tests/test_client.py +++ b/kazoo/tests/test_client.py @@ -1154,7 +1154,7 @@ def test_context(self): eq_(self.client.get('/smith')[0], b'32') -class TestCallbacks(unittest.TestCase): +class TestSessionCallbacks(unittest.TestCase): def test_session_callback_states(self): from kazoo.protocol.states import KazooState, KeeperState from kazoo.client import KazooClient @@ -1185,6 +1185,28 @@ def test_session_callback_states(self): eq_(client.state, KazooState.SUSPENDED) +class TestCallbacks(KazooTestCase): + def test_async_result_callbacks_are_always_called(self): + # create a callback object + callback_mock = mock.Mock() + + # simulate waiting for a response + async_result = self.client.handler.async_result() + async_result.rawlink(callback_mock) + + # begin the procedure to stop the client + self.client.stop() + + # the response has just been received; + # this should be on another thread, + # simultaneously with the stop procedure + async_result.set_exception( + Exception("Anything that throws an exception")) + + # with the fix the callback should be called + self.assertGreater(callback_mock.call_count, 0) + + class TestNonChrootClient(KazooTestCase): def test_create(self):