diff --git a/redshift_connector/core.py b/redshift_connector/core.py index ba1a4c5..528f058 100644 --- a/redshift_connector/core.py +++ b/redshift_connector/core.py @@ -656,6 +656,11 @@ def get_calling_module() -> str: self._sock: typing.Optional[typing.BinaryIO] = self._usock.makefile(mode="rwb") if tcp_keepalive: self._usock.setsockopt(socket.SOL_SOCKET, socket.SO_KEEPALIVE, 1) + + except socket.timeout as timeout_error: + self._usock.close() + raise OperationalError("connection time out", timeout_error) + except socket.error as e: self._usock.close() raise InterfaceError("communication error", e) diff --git a/test/integration/test_connection.py b/test/integration/test_connection.py index 2f1fd58..0c5195c 100644 --- a/test/integration/test_connection.py +++ b/test/integration/test_connection.py @@ -307,3 +307,9 @@ def test_execute_do_parsing_bind_params_when_exist(mocker, db_kwargs, sql, args) with redshift_connector.connect(**db_kwargs) as conn: conn.cursor().execute(sql, args) assert convert_paramstyle_spy.called + +def test_socket_timeout(db_kwargs): + db_kwargs["timeout"] = 0 + + with pytest.raises(redshift_connector.InterfaceError): + redshift_connector.connect(**db_kwargs) \ No newline at end of file diff --git a/test/unit/test_connection.py b/test/unit/test_connection.py index c5c90bf..61c9b92 100644 --- a/test/unit/test_connection.py +++ b/test/unit/test_connection.py @@ -2,6 +2,8 @@ from collections import deque from decimal import Decimal from unittest.mock import patch +import socket +from unittest import mock import pytest # type: ignore @@ -12,6 +14,7 @@ IntegrityError, InterfaceError, ProgrammingError, + OperationalError ) from redshift_connector.config import ( ClientProtocolVersion, @@ -328,3 +331,9 @@ def test_client_os_version_is_not_present(): with patch("platform.platform", side_effect=Exception("not for you")): assert mock_connection.client_os_version == "unknown" + +def test_socket_timeout_error(): + with mock.patch('socket.socket.connect') as mock_socket: + mock_socket.side_effect = (socket.timeout) + with pytest.raises(OperationalError): + Connection(user='mock_user', password='mock_password', host='localhost', port=8080, database='mocked')