diff --git a/tests/test_socket.py b/tests/test_socket.py index f0015b8..3657069 100644 --- a/tests/test_socket.py +++ b/tests/test_socket.py @@ -146,3 +146,24 @@ def test_client_socket_set_timeout(): conn.close() client_socket.close() server_socket.close() + + +def test_client_socket_not_open_after_failed_open(): + client_socket = TSocket(host="localhost", port=12345) + assert not client_socket.is_open() + + with pytest.raises(TTransportException): + client_socket.open() + + assert not client_socket.is_open() + + +def test_client_socket_not_open_after_close(): + client_socket = TSocket(host="localhost", port=12345) + assert not client_socket.is_open() + + with pytest.raises(TTransportException): + client_socket.open() + client_socket.close() + + assert not client_socket.is_open() diff --git a/thriftpy/transport/socket.py b/thriftpy/transport/socket.py index 51a5283..a40a2c7 100644 --- a/thriftpy/transport/socket.py +++ b/thriftpy/transport/socket.py @@ -99,6 +99,7 @@ def open(self): self.sock.settimeout(self.socket_timeout) except (socket.error, OSError): + self.sock = None raise TTransportException( type=TTransportException.NOT_OPEN, message="Could not connect to %s" % str(addr)) @@ -140,7 +141,7 @@ def close(self): self.sock.close() self.sock = None except (socket.error, OSError): - pass + self.sock = None class TServerSocket(object):