Skip to content

Commit

Permalink
Rename and refactor ert.shared.port_handler
Browse files Browse the repository at this point in the history
This commit changes `port_handler.py` -> `net_utils.py` and cleans up
the functions in the file. `find_available_port` is now called
`find_available_socket`, and only returns the socket. The caller can use
the socket object to get the attached hostname and port, so this is no
longer explicitly returned.
  • Loading branch information
jonathan-eq committed Sep 3, 2024
1 parent 266007d commit de9d63a
Show file tree
Hide file tree
Showing 9 changed files with 124 additions and 135 deletions.
12 changes: 5 additions & 7 deletions src/ert/ensemble_evaluator/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,8 +17,8 @@
from cryptography.hazmat.primitives.asymmetric import rsa
from cryptography.x509.oid import NameOID

from ert.shared import find_available_socket
from ert.shared import get_machine_name as ert_shared_get_machine_name
from ert.shared import port_handler

from .evaluator_connection_info import EvaluatorConnectionInfo

Expand Down Expand Up @@ -128,16 +128,16 @@ def __init__(
generate_cert: bool = True,
custom_host: typing.Optional[str] = None,
) -> None:
self.host, self.port, self._socket_handle = port_handler.find_available_port(
self._socket_handle = find_available_socket(
custom_range=custom_port_range, custom_host=custom_host
)
host, port = self._socket_handle.getsockname()
self.protocol = "wss" if generate_cert else "ws"
self.url = f"{self.protocol}://{self.host}:{self.port}"
self.url = f"{self.protocol}://{host}:{port}"
self.client_uri = f"{self.url}/client"
self.dispatch_uri = f"{self.url}/dispatch"

if generate_cert:
cert, key, pw = _generate_certificate(ip_address=self.host)
cert, key, pw = _generate_certificate(host)
else:
cert, key, pw = None, None, None
self.cert = cert
Expand All @@ -151,8 +151,6 @@ def get_socket(self) -> socket.socket:

def get_connection_info(self) -> EvaluatorConnectionInfo:
return EvaluatorConnectionInfo(
self.host,
self.port,
self.url,
self.cert,
self.token,
Expand Down
2 changes: 0 additions & 2 deletions src/ert/ensemble_evaluator/evaluator_connection_info.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,6 @@
class EvaluatorConnectionInfo:
"""Read only server-info"""

host: str
port: int
url: str
cert: Optional[Union[str, bytes]] = None
token: Optional[str] = None
Expand Down
5 changes: 2 additions & 3 deletions src/ert/services/_storage_main.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
from ert.logging import STORAGE_LOG_CONFIG
from ert.plugins import ErtPluginContext
from ert.shared import __file__ as ert_shared_path
from ert.shared import port_handler
from ert.shared import find_available_socket
from ert.shared.storage.command import add_parser_options


Expand Down Expand Up @@ -95,8 +95,7 @@ def run_server(args: Optional[argparse.Namespace] = None, debug: bool = False) -
config_args.update(reload=True, reload_dirs=[os.path.dirname(ert_shared_path)])
os.environ["ERT_STORAGE_DEBUG"] = "1"

_, _, sock = port_handler.find_available_port(custom_host=args.host)

sock = find_available_socket(custom_host=args.host)
connection_info = _create_connection_info(sock, authtoken)

# Appropriated from uvicorn.main:run
Expand Down
4 changes: 2 additions & 2 deletions src/ert/shared/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,6 @@ def ert_share_path() -> str:
return str(Path(spec_origin).parent.parent / "resources")


from .port_handler import get_machine_name
from .net_utils import find_available_socket, get_machine_name

__all__ = ["__version__", "ert_share_path", "get_machine_name"]
__all__ = ["__version__", "ert_share_path", "find_available_socket", "get_machine_name"]
26 changes: 9 additions & 17 deletions src/ert/shared/port_handler.py → src/ert/shared/net_utils.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import logging
import random
import socket
from typing import Optional, Tuple
from typing import Optional

from dns import exception, resolver, reversename

Expand Down Expand Up @@ -46,11 +46,11 @@ def get_machine_name() -> str:
return "localhost"


def find_available_port(
def find_available_socket(
custom_host: Optional[str] = None,
custom_range: Optional[range] = None,
will_close_then_reopen_socket: bool = False,
) -> Tuple[str, int, socket.socket]:
) -> socket.socket:
"""
The default and recommended approach here is to return a bound socket to the
caller, requiring the caller to keep the socket-object alive as long as the
Expand All @@ -62,7 +62,7 @@ def find_available_port(
but port is not ready to be re-bound yet, and 2) some other process managed to
bind the port before the original caller gets around to re-bind.
Thus, we expect clients calling find_available_port() to keep the returned
Thus, we expect clients calling find_available_socket() to keep the returned
socket-object alive and open as long as the port is needed. If a socket-object
is passed to other modules like for example a websocket-server, use dup() to
obtain a new Python socket-object bound to the same underlying socket (and hence
Expand All @@ -84,14 +84,10 @@ def find_available_port(
random.shuffle(ports)
for port in ports:
try:
return (
current_host,
port,
_bind_socket(
host=current_host,
port=port,
will_close_then_reopen_socket=will_close_then_reopen_socket,
),
return _bind_socket(
host=current_host,
port=port,
will_close_then_reopen_socket=will_close_then_reopen_socket,
)
except PortAlreadyInUseException:
continue
Expand All @@ -108,7 +104,7 @@ def _bind_socket(

# Setting flags like SO_REUSEADDR and/or SO_REUSEPORT may have
# undesirable side-effects but we allow it if caller insists. Refer to
# comment on find_available_port()
# comment on find_available_socket()
#
# See e.g. https://stackoverflow.com/a/14388707 for an extensive
# explanation of these flags, in particular the part about TIME_WAIT
Expand All @@ -135,10 +131,6 @@ def _bind_socket(
raise OSError(f"Unknown `OSError` while binding port {port}") from err_info


def get_family_for_localhost() -> socket.AddressFamily:
return get_family(_get_ip_address())


def get_family(host: str) -> socket.AddressFamily:
try:
socket.inet_pton(socket.AF_INET6, host)
Expand Down
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
from urllib.parse import urlparse

from ert.ensemble_evaluator.config import EvaluatorServerConfig


Expand All @@ -13,8 +15,9 @@ def test_load_config(unused_tcp_port):
expected_client_uri = f"{expected_url}/client"
expected_dispatch_uri = f"{expected_url}/dispatch"

assert serv_config.host == expected_host
assert serv_config.port == expected_port
url = urlparse(serv_config.url)
assert url.hostname == expected_host
assert url.port == expected_port
assert serv_config.url == expected_url
assert serv_config.client_uri == expected_client_uri
assert serv_config.dispatch_uri == expected_dispatch_uri
Expand Down
24 changes: 9 additions & 15 deletions tests/unit_tests/ensemble_evaluator/test_monitor.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,10 @@
import asyncio
import logging
from http import HTTPStatus
from urllib.parse import urlparse

import pytest
import websockets
from websockets import server
from websockets.exceptions import ConnectionClosedOK

from _ert.events import EEUserCancel, EEUserDone, event_from_json
Expand All @@ -18,8 +19,9 @@ async def process_request(path, request_headers):
if path == "/healthcheck":
return HTTPStatus.OK, {}, b""

async with websockets.server.serve(
handler, ee_config.host, ee_config.port, process_request=process_request
url = urlparse(ee_config.url)
async with server.serve(
handler, url.hostname, url.port, process_request=process_request
):
await set_when_done.wait()

Expand All @@ -36,9 +38,7 @@ async def test_no_connection_established(make_ee_config):


async def test_immediate_stop(unused_tcp_port):
ee_con_info = EvaluatorConnectionInfo(
"127.0.0.1", unused_tcp_port, f"ws://127.0.0.1:{unused_tcp_port}"
)
ee_con_info = EvaluatorConnectionInfo(f"ws://127.0.0.1:{unused_tcp_port}")

set_when_done = asyncio.Event()

Expand All @@ -59,9 +59,7 @@ async def mock_ws_event_handler(websocket):


async def test_unexpected_close(unused_tcp_port):
ee_con_info = EvaluatorConnectionInfo(
"127.0.0.1", unused_tcp_port, f"ws://127.0.0.1:{unused_tcp_port}"
)
ee_con_info = EvaluatorConnectionInfo(f"ws://127.0.0.1:{unused_tcp_port}")

set_when_done = asyncio.Event()
socket_closed = asyncio.Event()
Expand Down Expand Up @@ -89,9 +87,7 @@ async def test_that_monitor_track_can_exit_without_terminated_event_from_evaluat
unused_tcp_port, caplog
):
caplog.set_level(logging.ERROR)
ee_con_info = EvaluatorConnectionInfo(
"127.0.0.1", unused_tcp_port, f"ws://127.0.0.1:{unused_tcp_port}"
)
ee_con_info = EvaluatorConnectionInfo(f"ws://127.0.0.1:{unused_tcp_port}")

set_when_done = asyncio.Event()

Expand Down Expand Up @@ -125,9 +121,7 @@ async def test_that_monitor_can_emit_heartbeats(unused_tcp_port):
exit anytime. A heartbeat is a None event.
If the heartbeat is never sent, this test function will hang and then timeout."""
ee_con_info = EvaluatorConnectionInfo(
"127.0.0.1", unused_tcp_port, f"ws://127.0.0.1:{unused_tcp_port}"
)
ee_con_info = EvaluatorConnectionInfo(f"ws://127.0.0.1:{unused_tcp_port}")

set_when_done = asyncio.Event()
websocket_server_task = asyncio.create_task(
Expand Down
4 changes: 2 additions & 2 deletions tests/unit_tests/services/test_storage_service.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,12 +4,12 @@

from ert.services import StorageService
from ert.services._storage_main import _create_connection_info
from ert.shared import port_handler
from ert.shared import find_available_socket


def test_create_connection_string():
authtoken = "very_secret_token"
_, _, sock = port_handler.find_available_port()
sock = find_available_socket()

_create_connection_info(sock, authtoken)

Expand Down
Loading

0 comments on commit de9d63a

Please sign in to comment.