From 76a5f15178cfb30a4a3a7312b4db487a32e17c3b Mon Sep 17 00:00:00 2001 From: Tony Meyer Date: Tue, 26 Sep 2023 12:15:58 +1300 Subject: [PATCH 1/2] feat: add Unit.set_ports() for declarative port opening (#1005) * Make the protocol argument of open_port and close_port optional. In the majority of cases, the protocol is tcp, so make it a little easier and more concise for those cases. * Rename OpenPort to Port. This avoids the impression that the port is already/always opened, rather than just being a port and the instruction is then given via the open/close methods. * Add a Unit.set_ports method to declare which ports should be open. Tests still to be added. This does do a transform to Port() objects that are then disassembled again almost immediately, for the case where ints are provided. We could normalise to a pair of (str, int) instead to avoid that. Not normalising at all ended up with a fairly messy couple of loops instead of simple set operations. * Verify that the set_ports() method works. This can be flakey, because we're working with sets but then asserting that things happen in a specific order. That should get resolved. * Fix flakiness. The opened-ports call should always be first, but after that it depends on the order of items returned from set operations, and we don't care which order they happen in, so compare the sorted items. * Normalise to (str, int) to avoid unnecessary transforms. * Enforce structure on the open_port and close_port methods. Add typing overloads to ensure that type checking will raise issues. Good: * open_port('icmp') * open_port('tcp', 1) * open_port('udp', 2) * open_port(port=3) Bad: * open_port() * open_port('icmp', 1) * open_port('tcp') * open_port('udp') * open_port(4, 'tcp') Also check that nothing bad has been done at runtime, raising an error (ModelError to be consistent with ops.testing) if one of the invalid specifications is used. * Remove the default protocol of tcp, as per discussion. Also some minor adjustments per code review comments. * Remove overload docstrings and satisfy the linter with #noqa instead. * Remove the type hitn overloads. In other repos, there's code like this: Unfortunately, pyright cannot tell that the protocol argument is always 'tcp' in this case, only that the protocol is one of tcp/udp/icmp, and that's too broad to fit the righter overloaded specs, so it complains. This could be handled with changes in the upstream places, but I don't see any simple way to handle it within ops itself. This was a minor improvement and not really needed, plus it's in the two methods that we're wanting to move people away from, so drop this change. * Minor docstring fix. * Update docstring to align with the new open/close port ones. * Remove empty file, presumably a bad git add. * Remove duplicated validation. * Split tests into separate test methods. * Tweak docstring: "&" -> "and" --------- Co-authored-by: Ben Hoyt Co-authored-by: Ben Hoyt --- ops/__init__.py | 2 ++ ops/model.py | 70 +++++++++++++++++++++++++++++++++++++------- ops/testing.py | 8 ++--- test/test_model.py | 67 +++++++++++++++++++++++++++++++++++++++--- test/test_testing.py | 7 ++--- 5 files changed, 132 insertions(+), 22 deletions(-) diff --git a/ops/__init__.py b/ops/__init__.py index 1d6e383dc..bb900bf1f 100644 --- a/ops/__init__.py +++ b/ops/__init__.py @@ -138,6 +138,7 @@ 'Network', 'NetworkInterface', 'OpenedPort', + 'Port', 'Pod', 'Relation', 'RelationData', @@ -272,6 +273,7 @@ NetworkInterface, OpenedPort, Pod, + Port, Relation, RelationData, RelationDataAccessError, diff --git a/ops/model.py b/ops/model.py index 5155a96db..1d88eaa34 100644 --- a/ops/model.py +++ b/ops/model.py @@ -592,7 +592,7 @@ def add_secret(self, content: Dict[str, str], *, return Secret(self._backend, id=id, label=label, content=content) def open_port(self, protocol: typing.Literal['tcp', 'udp', 'icmp'], - port: Optional[int] = None): + port: Optional[int] = None) -> None: """Open a port with the given protocol for this unit. Some behaviour, such as whether the port is opened externally without @@ -601,17 +601,25 @@ def open_port(self, protocol: typing.Literal['tcp', 'udp', 'icmp'], `Juju documentation `__ for more detail. + Use :meth:`set_ports` for a more declarative approach where all of + the ports that should be open are provided in a single call. + Args: protocol: String representing the protocol; must be one of 'tcp', 'udp', or 'icmp' (lowercase is recommended, but uppercase is also supported). port: The port to open. Required for TCP and UDP; not allowed for ICMP. + + Raises: + ModelError: If ``port`` is provided when ``protocol`` is 'icmp' + or ``port`` is not provided when ``protocol`` is 'tcp' or + 'udp'. """ self._backend.open_port(protocol.lower(), port) def close_port(self, protocol: typing.Literal['tcp', 'udp', 'icmp'], - port: Optional[int] = None): + port: Optional[int] = None) -> None: """Close a port with the given protocol for this unit. Some behaviour, such as whether the port is closed externally without @@ -620,23 +628,62 @@ def close_port(self, protocol: typing.Literal['tcp', 'udp', 'icmp'], `Juju documentation `__ for more detail. + Use :meth:`set_ports` for a more declarative approach where all + of the ports that should be open are provided in a single call. + For example, ``set_ports()`` will close all open ports. + Args: protocol: String representing the protocol; must be one of 'tcp', 'udp', or 'icmp' (lowercase is recommended, but uppercase is also supported). port: The port to open. Required for TCP and UDP; not allowed for ICMP. + + Raises: + ModelError: If ``port`` is provided when ``protocol`` is 'icmp' + or ``port`` is not provided when ``protocol`` is 'tcp' or + 'udp'. """ self._backend.close_port(protocol.lower(), port) - def opened_ports(self) -> Set['OpenedPort']: + def opened_ports(self) -> Set['Port']: """Return a list of opened ports for this unit.""" return self._backend.opened_ports() + def set_ports(self, *ports: Union[int, 'Port']) -> None: + """Set the open ports for this unit, closing any others that are open. + + Some behaviour, such as whether the port is opened or closed externally without + using Juju's ``expose`` and ``unexpose`` commands, differs between Kubernetes + and machine charms. See the + `Juju documentation `__ + for more detail. + + Use :meth:`open_port` and :meth:`close_port` to manage ports + individually. + + Args: + ports: The ports to open. Provide an int to open a TCP port, or + a :class:`Port` to open a port for another protocol. + """ + # Normalise to get easier comparisons. + existing = { + (port.protocol, port.port) + for port in self._backend.opened_ports() + } + desired = { + ('tcp', port) if isinstance(port, int) else (port.protocol, port.port) + for port in ports + } + for protocol, port in existing - desired: + self._backend.close_port(protocol, port) + for protocol, port in desired - existing: + self._backend.open_port(protocol, port) + @dataclasses.dataclass(frozen=True) -class OpenedPort: - """Represents a port opened by :meth:`Unit.open_port`.""" +class Port: + """Represents a port opened by :meth:`Unit.open_port` or :meth:`Unit.set_ports`.""" protocol: typing.Literal['tcp', 'udp', 'icmp'] """The IP protocol.""" @@ -645,6 +692,9 @@ class OpenedPort: """The port number. Will be ``None`` if protocol is ``'icmp'``.""" +OpenedPort = Port # Alias for backwards compatibility. + + class LazyMapping(Mapping[str, str], ABC): """Represents a dict that isn't populated until it is accessed. @@ -3241,13 +3291,13 @@ def close_port(self, protocol: str, port: Optional[int] = None): arg = f'{port}/{protocol}' if port is not None else protocol self._run('close-port', arg) - def opened_ports(self) -> Set[OpenedPort]: + def opened_ports(self) -> Set[Port]: # We could use "opened-ports --format=json", but it's not really # structured; it's just an array of strings which are the lines of the # text output, like ["icmp","8081/udp"]. So it's probably just as # likely to change as the text output, and doesn't seem any better. output = typing.cast(str, self._run('opened-ports', return_output=True)) - ports: Set[OpenedPort] = set() + ports: Set[Port] = set() for line in output.splitlines(): line = line.strip() if not line: @@ -3258,9 +3308,9 @@ def opened_ports(self) -> Set[OpenedPort]: return ports @classmethod - def _parse_opened_port(cls, port_str: str) -> Optional[OpenedPort]: + def _parse_opened_port(cls, port_str: str) -> Optional[Port]: if port_str == 'icmp': - return OpenedPort('icmp', None) + return Port('icmp', None) port_range, slash, protocol = port_str.partition('/') if not slash or protocol not in ['tcp', 'udp']: logger.warning('Unexpected opened-ports protocol: %s', port_str) @@ -3269,7 +3319,7 @@ def _parse_opened_port(cls, port_str: str) -> Optional[OpenedPort]: if hyphen: logger.warning('Ignoring opened-ports port range: %s', port_str) protocol_lit = typing.cast(typing.Literal['tcp', 'udp'], protocol) - return OpenedPort(protocol_lit, int(port)) + return Port(protocol_lit, int(port)) class _ModelBackendValidator: diff --git a/ops/testing.py b/ops/testing.py index 8b6ebb8db..d81acf9c8 100755 --- a/ops/testing.py +++ b/ops/testing.py @@ -1951,7 +1951,7 @@ def __init__(self, unit_name: str, meta: charm.CharmMeta, config: 'RawConfig'): self._planned_units: Optional[int] = None self._hook_is_running = '' self._secrets: List[_Secret] = [] - self._opened_ports: Set[model.OpenedPort] = set() + self._opened_ports: Set[model.Port] = set() self._networks: Dict[Tuple[Optional[str], Optional[int]], _NetworkDict] = {} def _validate_relation_access(self, relation_name: str, relations: List[model.Relation]): @@ -2489,14 +2489,14 @@ def secret_remove(self, id: str, *, revision: Optional[int] = None) -> None: def open_port(self, protocol: str, port: Optional[int] = None): self._check_protocol_and_port(protocol, port) protocol_lit = cast(Literal['tcp', 'udp', 'icmp'], protocol) - self._opened_ports.add(model.OpenedPort(protocol_lit, port)) + self._opened_ports.add(model.Port(protocol_lit, port)) def close_port(self, protocol: str, port: Optional[int] = None): self._check_protocol_and_port(protocol, port) protocol_lit = cast(Literal['tcp', 'udp', 'icmp'], protocol) - self._opened_ports.discard(model.OpenedPort(protocol_lit, port)) + self._opened_ports.discard(model.Port(protocol_lit, port)) - def opened_ports(self) -> Set[model.OpenedPort]: + def opened_ports(self) -> Set[model.Port]: return set(self._opened_ports) def _check_protocol_and_port(self, protocol: str, port: Optional[int]): diff --git a/test/test_model.py b/test/test_model.py index 8b6434702..189181e48 100755 --- a/test/test_model.py +++ b/test/test_model.py @@ -3368,10 +3368,10 @@ def test_opened_ports(self): self.assertIsInstance(ports_set, set) ports = sorted(ports_set, key=lambda p: (p.protocol, p.port)) self.assertEqual(len(ports), 2) - self.assertIsInstance(ports[0], ops.OpenedPort) + self.assertIsInstance(ports[0], ops.Port) self.assertEqual(ports[0].protocol, 'icmp') self.assertIsNone(ports[0].port) - self.assertIsInstance(ports[1], ops.OpenedPort) + self.assertIsInstance(ports[1], ops.Port) self.assertEqual(ports[1].protocol, 'tcp') self.assertEqual(ports[1].port, 8080) @@ -3391,10 +3391,10 @@ def test_opened_ports_warnings(self): self.assertIsInstance(ports_set, set) ports = sorted(ports_set, key=lambda p: (p.protocol, p.port)) self.assertEqual(len(ports), 2) - self.assertIsInstance(ports[0], ops.OpenedPort) + self.assertIsInstance(ports[0], ops.Port) self.assertEqual(ports[0].protocol, 'tcp') self.assertEqual(ports[0].port, 8080) - self.assertIsInstance(ports[1], ops.OpenedPort) + self.assertIsInstance(ports[1], ops.Port) self.assertEqual(ports[1].protocol, 'udp') self.assertEqual(ports[1].port, 1000) @@ -3402,6 +3402,65 @@ def test_opened_ports_warnings(self): ['opened-ports', ''], ]) + def test_set_ports_all_open(self): + fake_script(self, 'open-port', 'exit 0') + fake_script(self, 'close-port', 'exit 0') + fake_script(self, 'opened-ports', 'exit 0') + self.unit.set_ports(8000, 8025) + calls = fake_script_calls(self, clear=True) + self.assertEqual(calls.pop(0), ['opened-ports', '']) + calls.sort() # We make no guarantee on the order the ports are opened. + self.assertEqual(calls, [ + ['open-port', '8000/tcp'], + ['open-port', '8025/tcp'], + ]) + + def test_set_ports_mixed(self): + # Two open ports, leave one alone and open another one. + fake_script(self, 'open-port', 'exit 0') + fake_script(self, 'close-port', 'exit 0') + fake_script(self, 'opened-ports', 'echo 8025/tcp; echo 8028/tcp') + self.unit.set_ports(ops.Port('udp', 8022), 8028) + self.assertEqual(fake_script_calls(self, clear=True), [ + ['opened-ports', ''], + ['close-port', '8025/tcp'], + ['open-port', '8022/udp'], + ]) + + def test_set_ports_replace(self): + fake_script(self, 'open-port', 'exit 0') + fake_script(self, 'close-port', 'exit 0') + fake_script(self, 'opened-ports', 'echo 8025/tcp; echo 8028/tcp') + self.unit.set_ports(8001, 8002) + calls = fake_script_calls(self, clear=True) + self.assertEqual(calls.pop(0), ['opened-ports', '']) + calls.sort() + self.assertEqual(calls, [ + ['close-port', '8025/tcp'], + ['close-port', '8028/tcp'], + ['open-port', '8001/tcp'], + ['open-port', '8002/tcp'], + ]) + + def test_set_ports_close_all(self): + fake_script(self, 'open-port', 'exit 0') + fake_script(self, 'close-port', 'exit 0') + fake_script(self, 'opened-ports', 'echo 8022/udp') + self.unit.set_ports() + self.assertEqual(fake_script_calls(self, clear=True), [ + ['opened-ports', ''], + ['close-port', '8022/udp'], + ]) + + def test_set_ports_noop(self): + fake_script(self, 'open-port', 'exit 0') + fake_script(self, 'close-port', 'exit 0') + fake_script(self, 'opened-ports', 'echo 8000/tcp') + self.unit.set_ports(ops.Port('tcp', 8000)) + self.assertEqual(fake_script_calls(self, clear=True), [ + ['opened-ports', ''], + ]) + if __name__ == "__main__": unittest.main() diff --git a/test/test_testing.py b/test/test_testing.py index df24292e0..495f704b2 100644 --- a/test/test_testing.py +++ b/test/test_testing.py @@ -4829,13 +4829,12 @@ def test_ports(self): self.assertIsInstance(ports_set, set) ports = sorted(ports_set, key=lambda p: (p.protocol, p.port)) self.assertEqual(len(ports), 3) - self.assertIsInstance(ports[0], ops.OpenedPort) + self.assertIsInstance(ports[0], ops.Port) self.assertEqual(ports[0].protocol, 'icmp') self.assertIsNone(ports[0].port) - self.assertIsInstance(ports[1], ops.OpenedPort) self.assertEqual(ports[1].protocol, 'tcp') self.assertEqual(ports[1].port, 8080) - self.assertIsInstance(ports[2], ops.OpenedPort) + self.assertIsInstance(ports[1], ops.Port) self.assertEqual(ports[2].protocol, 'udp') self.assertEqual(ports[2].port, 4000) @@ -4847,7 +4846,7 @@ def test_ports(self): self.assertIsInstance(ports_set, set) ports = sorted(ports_set, key=lambda p: (p.protocol, p.port)) self.assertEqual(len(ports), 1) - self.assertIsInstance(ports[0], ops.OpenedPort) + self.assertIsInstance(ports[0], ops.Port) self.assertEqual(ports[0].protocol, 'icmp') self.assertIsNone(ports[0].port) From b2c4a3eb85020c9acdbf493b9e640f059e084b42 Mon Sep 17 00:00:00 2001 From: Tony Meyer Date: Tue, 26 Sep 2023 14:03:43 +1300 Subject: [PATCH 2/2] test: add type hinting to test_helpers (#1014) * Add type hints to tests/test_helpers * For fake_script and fake_script_calls, there are a bunch of `type: ignore`s to work around not being able to state the type of `fake_script_path` - I think there will be a few of these through other test modules as well, as hints are added there. I feel it would be cleaner to have the fake script functionality in a class rather than dynamically added to TestCase instances, but [that is a much more substantial change](https://github.com/tonyandrewmeyer/operator/commit/46238eee2d0b39d383f2ea6b3d6cac722be5dddd) Partially addresses #1007 --- pyproject.toml | 1 + test/test_helpers.py | 44 +++++++++++++++++++++++++------------------- 2 files changed, 26 insertions(+), 19 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index 197b45dfc..463870eaf 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -33,6 +33,7 @@ include = ["ops/*.py", "ops/_private/*.py", "test/test_infra.py", "test/test_jujuversion.py", "test/test_log.py", + "test/test_helpers.py", "test/test_lib.py", ] pythonVersion = "3.8" # check no python > 3.8 features are used diff --git a/test/test_helpers.py b/test/test_helpers.py index 34f951e57..ada79af09 100755 --- a/test/test_helpers.py +++ b/test/test_helpers.py @@ -17,6 +17,7 @@ import shutil import subprocess import tempfile +import typing import unittest import ops @@ -24,7 +25,7 @@ from ops.storage import SQLiteStorage -def fake_script(test_case, name, content): +def fake_script(test_case: unittest.TestCase, name: str, content: str): if not hasattr(test_case, 'fake_script_path'): fake_script_path = tempfile.mkdtemp('-fake_script') old_path = os.environ["PATH"] @@ -35,40 +36,42 @@ def cleanup(): os.environ['PATH'] = old_path test_case.addCleanup(cleanup) - test_case.fake_script_path = pathlib.Path(fake_script_path) + test_case.fake_script_path = pathlib.Path(fake_script_path) # type: ignore - template_args = { + template_args: typing.Dict[str, str] = { 'name': name, - 'path': test_case.fake_script_path.as_posix(), + 'path': test_case.fake_script_path.as_posix(), # type: ignore 'content': content, } - path = test_case.fake_script_path / name - with path.open('wt') as f: + path: pathlib.Path = test_case.fake_script_path / name # type: ignore + with path.open('wt') as f: # type: ignore # Before executing the provided script, dump the provided arguments in calls.txt. # ASCII 1E is RS 'record separator', and 1C is FS 'file separator', which seem appropriate. - f.write('''#!/bin/sh + f.write( # type: ignore + '''#!/bin/sh {{ printf {name}; printf "\\036%s" "$@"; printf "\\034"; }} >> {path}/calls.txt {content}'''.format_map(template_args)) - os.chmod(str(path), 0o755) + os.chmod(str(path), 0o755) # type: ignore # TODO: this hardcodes the path to bash.exe, which works for now but might # need to be set via environ or something like that. - path.with_suffix(".bat").write_text( + path.with_suffix(".bat").write_text( # type: ignore f'@"C:\\Program Files\\git\\bin\\bash.exe" {path} %*\n') -def fake_script_calls(test_case, clear=False): - calls_file = test_case.fake_script_path / 'calls.txt' - if not calls_file.exists(): +def fake_script_calls(test_case: unittest.TestCase, + clear: bool = False) -> typing.List[typing.List[str]]: + calls_file: pathlib.Path = test_case.fake_script_path / 'calls.txt' # type: ignore + if not calls_file.exists(): # type: ignore return [] # newline and encoding forced to linuxy defaults because on # windows they're written from git-bash - with calls_file.open('r+t', newline='\n', encoding='utf8') as f: - calls = [line.split('\x1e') for line in f.read().split('\x1c')[:-1]] + with calls_file.open('r+t', newline='\n', encoding='utf8') as f: # type: ignore + calls = [line.split('\x1e') for line in f.read().split('\x1c')[:-1]] # type: ignore if clear: - f.truncate(0) - return calls + f.truncate(0) # type: ignore + return calls # type: ignore class FakeScriptTest(unittest.TestCase): @@ -105,7 +108,10 @@ def test_fake_script_clear(self): class BaseTestCase(unittest.TestCase): - def create_framework(self, *, model=None, tmpdir=None): + def create_framework(self, + *, + model: typing.Optional[ops.Model] = None, + tmpdir: typing.Optional[pathlib.Path] = None): """Create a Framework object. By default operate in-memory; pass a temporary directory via the 'tmpdir' @@ -122,8 +128,8 @@ def create_framework(self, *, model=None, tmpdir=None): framework = ops.Framework( SQLiteStorage(data_fpath), charm_dir, - meta=None, - model=model) + meta=model._cache._meta if model else ops.CharmMeta(), + model=model) # type: ignore self.addCleanup(framework.close) return framework