diff --git a/kazoo/protocol/connection.py b/kazoo/protocol/connection.py index d7d84d15..ee21c8d6 100644 --- a/kazoo/protocol/connection.py +++ b/kazoo/protocol/connection.py @@ -614,7 +614,9 @@ def _connect_attempt(self, host, hostip, port, retry): try: self._xid = 0 - read_timeout, connect_timeout = self._connect(host, hostip, port) + read_timeout, connect_timeout = self._connect( + host, hostip, port, timeout=retry.cur_delay + ) read_timeout = read_timeout / 1000.0 connect_timeout = connect_timeout / 1000.0 retry.reset() @@ -685,7 +687,7 @@ def _connect_attempt(self, host, hostip, port, retry): if self._socket is not None: self._socket.close() - def _connect(self, host, hostip, port): + def _connect(self, host, hostip, port, timeout): client = self.client self.logger.info( "Connecting to %s(%s):%s, use_ssl: %r", @@ -705,7 +707,7 @@ def _connect(self, host, hostip, port): with self._socket_error_handling(): self._socket = self.handler.create_connection( address=(hostip, port), - timeout=client._session_timeout / 1000.0, + timeout=timeout, use_ssl=self.client.use_ssl, keyfile=self.client.keyfile, certfile=self.client.certfile, diff --git a/kazoo/retry.py b/kazoo/retry.py index 48b94e15..124b1f26 100644 --- a/kazoo/retry.py +++ b/kazoo/retry.py @@ -109,6 +109,10 @@ def copy(self): obj.retry_exceptions = self.retry_exceptions return obj + @property + def cur_delay(self): + return self._cur_delay + def __call__(self, func, *args, **kwargs): """Call a function with arguments until it completes without throwing a Kazoo exception diff --git a/kazoo/tests/test_connection.py b/kazoo/tests/test_connection.py new file mode 100644 index 00000000..724ae48a --- /dev/null +++ b/kazoo/tests/test_connection.py @@ -0,0 +1,106 @@ +from unittest import mock + +import pytest + +from kazoo import retry +from kazoo.handlers import threading +from kazoo.protocol import connection +from kazoo.protocol import states + + +@mock.patch("kazoo.protocol.connection.ConnectionHandler._expand_client_hosts") +def test_retry_logic(mock_expand): + mock_client = mock.Mock() + mock_client._state = states.KeeperState.CLOSED + mock_client._session_id = None + mock_client._session_passwd = b"\x00" * 16 + mock_client._stopped.is_set.return_value = False + mock_client.handler.timeout_exception = threading.KazooTimeoutError + mock_client.handler.create_connection.side_effect = ( + threading.KazooTimeoutError() + ) + test_retry = retry.KazooRetry( + max_tries=6, + delay=1.0, + backoff=2, + max_delay=30.0, + max_jitter=0.0, + sleep_func=lambda _x: None, + ) + test_cnx = connection.ConnectionHandler( + client=mock_client, + retry_sleeper=test_retry, + ) + mock_expand.return_value = [ + ("a", "1.1.1.1", 2181), + ("b", "2.2.2.2", 2181), + ("c", "3.3.3.3", 2181), + ] + + with pytest.raises(retry.RetryFailedError): + test_retry(test_cnx._connect_loop, test_retry) + + assert mock_client.handler.create_connection.call_args_list[:3] == [ + mock.call( + address=("1.1.1.1", 2181), + timeout=1.0, + use_ssl=mock.ANY, + keyfile=mock.ANY, + certfile=mock.ANY, + ca=mock.ANY, + keyfile_password=mock.ANY, + verify_certs=mock.ANY, + ), + mock.call( + address=("2.2.2.2", 2181), + timeout=1.0, + use_ssl=mock.ANY, + keyfile=mock.ANY, + certfile=mock.ANY, + ca=mock.ANY, + keyfile_password=mock.ANY, + verify_certs=mock.ANY, + ), + mock.call( + address=("3.3.3.3", 2181), + timeout=1.0, + use_ssl=mock.ANY, + keyfile=mock.ANY, + certfile=mock.ANY, + ca=mock.ANY, + keyfile_password=mock.ANY, + verify_certs=mock.ANY, + ), + ], "All hosts are first tried with the lowest timeout value" + assert mock_client.handler.create_connection.call_args_list[-3:] == [ + mock.call( + address=("1.1.1.1", 2181), + timeout=30.0, + use_ssl=mock.ANY, + keyfile=mock.ANY, + certfile=mock.ANY, + ca=mock.ANY, + keyfile_password=mock.ANY, + verify_certs=mock.ANY, + ), + mock.call( + address=("2.2.2.2", 2181), + timeout=30.0, + use_ssl=mock.ANY, + keyfile=mock.ANY, + certfile=mock.ANY, + ca=mock.ANY, + keyfile_password=mock.ANY, + verify_certs=mock.ANY, + ), + mock.call( + address=("3.3.3.3", 2181), + timeout=30.0, + use_ssl=mock.ANY, + keyfile=mock.ANY, + certfile=mock.ANY, + ca=mock.ANY, + keyfile_password=mock.ANY, + verify_certs=mock.ANY, + ), + ], "All hosts are last tried with the highest timeout value"