diff --git a/cms/io/rpc.py b/cms/io/rpc.py index 9ec358b05a..bf6b09f97c 100644 --- a/cms/io/rpc.py +++ b/cms/io/rpc.py @@ -5,6 +5,7 @@ # Copyright © 2010-2018 Stefano Maggiolo # Copyright © 2010-2012 Matteo Boscariol # Copyright © 2013-2017 Luca Wehrstedt +# Copyright © 2019 Edoardo Morassutto # # This program is free software: you can redistribute it and/or modify # it under the terms of the GNU Affero General Public License as @@ -478,13 +479,28 @@ def _connect(self): """ try: - sock = gevent.socket.socket(socket.AF_INET, socket.SOCK_STREAM) - sock.connect(self.remote_address) - except OSError as error: - logger.debug("Couldn't connect to %s: %s.", - self._repr_remote(), error) - else: - self.initialize(sock, self.remote_service_coord) + # Try to resolve the address, this can lead to many possible + # addresses, we'll try all of them. + addresses = gevent.socket.getaddrinfo( + self.remote_address.ip, + self.remote_address.port, + type=socket.SOCK_STREAM) + except socket.gaierror: + logger.warning("Cannot resolve %s.", self.remote_address) + raise + + for family, type, proto, _canonname, sockaddr in addresses: + try: + host, port, *rest = sockaddr + logger.debug("Trying to connect to %s at port %d.", host, port) + sock = gevent.socket.socket(family, type, proto) + sock.connect(sockaddr) + except OSError as error: + logger.debug("Couldn't connect to %s at %s port %d: %s.", + self._repr_remote(), host, port, error) + else: + self.initialize(sock, self.remote_service_coord) + break def _run(self): """Maintain the connection up, if required. diff --git a/cms/io/service.py b/cms/io/service.py index 03ce17deb5..3ba62aa81a 100644 --- a/cms/io/service.py +++ b/cms/io/service.py @@ -5,6 +5,7 @@ # Copyright © 2010-2016 Stefano Maggiolo # Copyright © 2010-2012 Matteo Boscariol # Copyright © 2013 Luca Wehrstedt +# Copyright © 2019 Edoardo Morassutto # # This program is free software: you can redistribute it and/or modify # it under the terms of the GNU Affero General Public License as @@ -156,13 +157,7 @@ def _connection_handler(self, sock, address): connection. """ - try: - ipaddr, port = address - ipaddr = gevent.socket.gethostbyname(ipaddr) - address = Address(ipaddr, port) - except OSError: - logger.warning("Unexpected error.", exc_info=True) - return + address = Address(address[0], address[1]) remote_service = RemoteServiceServer(self, address) remote_service.handle(sock) diff --git a/cmstestsuite/unit_tests/io/rpc_test.py b/cmstestsuite/unit_tests/io/rpc_test.py index eab99069e2..ab22ec8a2b 100755 --- a/cmstestsuite/unit_tests/io/rpc_test.py +++ b/cmstestsuite/unit_tests/io/rpc_test.py @@ -3,6 +3,7 @@ # Contest Management System - http://cms-dev.github.io/ # Copyright © 2014-2017 Luca Wehrstedt # Copyright © 2015 Stefano Maggiolo +# Copyright © 2019 Edoardo Morassutto # # This program is free software: you can redistribute it and/or modify # it under the terms of the GNU Affero General Public License as @@ -21,6 +22,7 @@ """ +import socket import unittest from unittest.mock import Mock, patch @@ -231,7 +233,14 @@ def test_method_return_list(self): self.assertEqual(result.value, ["Hello", 42, "World"]) @patch("cms.io.rpc.gevent.socket.socket") - def test_background_connect(self, socket_mock): + @patch("cms.io.rpc.gevent.socket.getaddrinfo") + def test_background_connect(self, getaddrinfo_mock, socket_mock): + # Calling getaddrinfo breaks the mocking of the socket, so it is mocked + # as well. It returns the addresses associated with the service, in + # this case just one. + getaddrinfo_mock.return_value = [ + (gevent.socket.AF_INET, gevent.socket.SOCK_STREAM, 6, "", + (self.host, self.port))] # Patch the connect method of sockets so that it blocks until # we set the done_event (we will do so at the end of the test). connect_mock = socket_mock.return_value.connect @@ -255,7 +264,9 @@ def test_background_connect(self, socket_mock): # event triggered. done_event.set() gevent.sleep() - connect_mock.assert_called_once_with(Address(self.host, self.port)) + getaddrinfo_mock.assert_called_once_with(self.host, self.port, + type=socket.SOCK_STREAM) + connect_mock.assert_called_once_with((self.host, self.port)) def test_autoreconnect1(self): client = self.get_client(ServiceCoord("Foo", 0), auto_retry=0.002)