diff --git a/ops/__init__.py b/ops/__init__.py index 1d6e383dc..bb900bf1f 100644 --- a/ops/__init__.py +++ b/ops/__init__.py @@ -138,6 +138,7 @@ 'Network', 'NetworkInterface', 'OpenedPort', + 'Port', 'Pod', 'Relation', 'RelationData', @@ -272,6 +273,7 @@ NetworkInterface, OpenedPort, Pod, + Port, Relation, RelationData, RelationDataAccessError, diff --git a/ops/model.py b/ops/model.py index 5155a96db..1d88eaa34 100644 --- a/ops/model.py +++ b/ops/model.py @@ -592,7 +592,7 @@ def add_secret(self, content: Dict[str, str], *, return Secret(self._backend, id=id, label=label, content=content) def open_port(self, protocol: typing.Literal['tcp', 'udp', 'icmp'], - port: Optional[int] = None): + port: Optional[int] = None) -> None: """Open a port with the given protocol for this unit. Some behaviour, such as whether the port is opened externally without @@ -601,17 +601,25 @@ def open_port(self, protocol: typing.Literal['tcp', 'udp', 'icmp'], `Juju documentation `__ for more detail. + Use :meth:`set_ports` for a more declarative approach where all of + the ports that should be open are provided in a single call. + Args: protocol: String representing the protocol; must be one of 'tcp', 'udp', or 'icmp' (lowercase is recommended, but uppercase is also supported). port: The port to open. Required for TCP and UDP; not allowed for ICMP. + + Raises: + ModelError: If ``port`` is provided when ``protocol`` is 'icmp' + or ``port`` is not provided when ``protocol`` is 'tcp' or + 'udp'. """ self._backend.open_port(protocol.lower(), port) def close_port(self, protocol: typing.Literal['tcp', 'udp', 'icmp'], - port: Optional[int] = None): + port: Optional[int] = None) -> None: """Close a port with the given protocol for this unit. Some behaviour, such as whether the port is closed externally without @@ -620,23 +628,62 @@ def close_port(self, protocol: typing.Literal['tcp', 'udp', 'icmp'], `Juju documentation `__ for more detail. + Use :meth:`set_ports` for a more declarative approach where all + of the ports that should be open are provided in a single call. + For example, ``set_ports()`` will close all open ports. + Args: protocol: String representing the protocol; must be one of 'tcp', 'udp', or 'icmp' (lowercase is recommended, but uppercase is also supported). port: The port to open. Required for TCP and UDP; not allowed for ICMP. + + Raises: + ModelError: If ``port`` is provided when ``protocol`` is 'icmp' + or ``port`` is not provided when ``protocol`` is 'tcp' or + 'udp'. """ self._backend.close_port(protocol.lower(), port) - def opened_ports(self) -> Set['OpenedPort']: + def opened_ports(self) -> Set['Port']: """Return a list of opened ports for this unit.""" return self._backend.opened_ports() + def set_ports(self, *ports: Union[int, 'Port']) -> None: + """Set the open ports for this unit, closing any others that are open. + + Some behaviour, such as whether the port is opened or closed externally without + using Juju's ``expose`` and ``unexpose`` commands, differs between Kubernetes + and machine charms. See the + `Juju documentation `__ + for more detail. + + Use :meth:`open_port` and :meth:`close_port` to manage ports + individually. + + Args: + ports: The ports to open. Provide an int to open a TCP port, or + a :class:`Port` to open a port for another protocol. + """ + # Normalise to get easier comparisons. + existing = { + (port.protocol, port.port) + for port in self._backend.opened_ports() + } + desired = { + ('tcp', port) if isinstance(port, int) else (port.protocol, port.port) + for port in ports + } + for protocol, port in existing - desired: + self._backend.close_port(protocol, port) + for protocol, port in desired - existing: + self._backend.open_port(protocol, port) + @dataclasses.dataclass(frozen=True) -class OpenedPort: - """Represents a port opened by :meth:`Unit.open_port`.""" +class Port: + """Represents a port opened by :meth:`Unit.open_port` or :meth:`Unit.set_ports`.""" protocol: typing.Literal['tcp', 'udp', 'icmp'] """The IP protocol.""" @@ -645,6 +692,9 @@ class OpenedPort: """The port number. Will be ``None`` if protocol is ``'icmp'``.""" +OpenedPort = Port # Alias for backwards compatibility. + + class LazyMapping(Mapping[str, str], ABC): """Represents a dict that isn't populated until it is accessed. @@ -3241,13 +3291,13 @@ def close_port(self, protocol: str, port: Optional[int] = None): arg = f'{port}/{protocol}' if port is not None else protocol self._run('close-port', arg) - def opened_ports(self) -> Set[OpenedPort]: + def opened_ports(self) -> Set[Port]: # We could use "opened-ports --format=json", but it's not really # structured; it's just an array of strings which are the lines of the # text output, like ["icmp","8081/udp"]. So it's probably just as # likely to change as the text output, and doesn't seem any better. output = typing.cast(str, self._run('opened-ports', return_output=True)) - ports: Set[OpenedPort] = set() + ports: Set[Port] = set() for line in output.splitlines(): line = line.strip() if not line: @@ -3258,9 +3308,9 @@ def opened_ports(self) -> Set[OpenedPort]: return ports @classmethod - def _parse_opened_port(cls, port_str: str) -> Optional[OpenedPort]: + def _parse_opened_port(cls, port_str: str) -> Optional[Port]: if port_str == 'icmp': - return OpenedPort('icmp', None) + return Port('icmp', None) port_range, slash, protocol = port_str.partition('/') if not slash or protocol not in ['tcp', 'udp']: logger.warning('Unexpected opened-ports protocol: %s', port_str) @@ -3269,7 +3319,7 @@ def _parse_opened_port(cls, port_str: str) -> Optional[OpenedPort]: if hyphen: logger.warning('Ignoring opened-ports port range: %s', port_str) protocol_lit = typing.cast(typing.Literal['tcp', 'udp'], protocol) - return OpenedPort(protocol_lit, int(port)) + return Port(protocol_lit, int(port)) class _ModelBackendValidator: diff --git a/ops/testing.py b/ops/testing.py index 8b6ebb8db..d81acf9c8 100755 --- a/ops/testing.py +++ b/ops/testing.py @@ -1951,7 +1951,7 @@ def __init__(self, unit_name: str, meta: charm.CharmMeta, config: 'RawConfig'): self._planned_units: Optional[int] = None self._hook_is_running = '' self._secrets: List[_Secret] = [] - self._opened_ports: Set[model.OpenedPort] = set() + self._opened_ports: Set[model.Port] = set() self._networks: Dict[Tuple[Optional[str], Optional[int]], _NetworkDict] = {} def _validate_relation_access(self, relation_name: str, relations: List[model.Relation]): @@ -2489,14 +2489,14 @@ def secret_remove(self, id: str, *, revision: Optional[int] = None) -> None: def open_port(self, protocol: str, port: Optional[int] = None): self._check_protocol_and_port(protocol, port) protocol_lit = cast(Literal['tcp', 'udp', 'icmp'], protocol) - self._opened_ports.add(model.OpenedPort(protocol_lit, port)) + self._opened_ports.add(model.Port(protocol_lit, port)) def close_port(self, protocol: str, port: Optional[int] = None): self._check_protocol_and_port(protocol, port) protocol_lit = cast(Literal['tcp', 'udp', 'icmp'], protocol) - self._opened_ports.discard(model.OpenedPort(protocol_lit, port)) + self._opened_ports.discard(model.Port(protocol_lit, port)) - def opened_ports(self) -> Set[model.OpenedPort]: + def opened_ports(self) -> Set[model.Port]: return set(self._opened_ports) def _check_protocol_and_port(self, protocol: str, port: Optional[int]): diff --git a/test/test_model.py b/test/test_model.py index 8b6434702..189181e48 100755 --- a/test/test_model.py +++ b/test/test_model.py @@ -3368,10 +3368,10 @@ def test_opened_ports(self): self.assertIsInstance(ports_set, set) ports = sorted(ports_set, key=lambda p: (p.protocol, p.port)) self.assertEqual(len(ports), 2) - self.assertIsInstance(ports[0], ops.OpenedPort) + self.assertIsInstance(ports[0], ops.Port) self.assertEqual(ports[0].protocol, 'icmp') self.assertIsNone(ports[0].port) - self.assertIsInstance(ports[1], ops.OpenedPort) + self.assertIsInstance(ports[1], ops.Port) self.assertEqual(ports[1].protocol, 'tcp') self.assertEqual(ports[1].port, 8080) @@ -3391,10 +3391,10 @@ def test_opened_ports_warnings(self): self.assertIsInstance(ports_set, set) ports = sorted(ports_set, key=lambda p: (p.protocol, p.port)) self.assertEqual(len(ports), 2) - self.assertIsInstance(ports[0], ops.OpenedPort) + self.assertIsInstance(ports[0], ops.Port) self.assertEqual(ports[0].protocol, 'tcp') self.assertEqual(ports[0].port, 8080) - self.assertIsInstance(ports[1], ops.OpenedPort) + self.assertIsInstance(ports[1], ops.Port) self.assertEqual(ports[1].protocol, 'udp') self.assertEqual(ports[1].port, 1000) @@ -3402,6 +3402,65 @@ def test_opened_ports_warnings(self): ['opened-ports', ''], ]) + def test_set_ports_all_open(self): + fake_script(self, 'open-port', 'exit 0') + fake_script(self, 'close-port', 'exit 0') + fake_script(self, 'opened-ports', 'exit 0') + self.unit.set_ports(8000, 8025) + calls = fake_script_calls(self, clear=True) + self.assertEqual(calls.pop(0), ['opened-ports', '']) + calls.sort() # We make no guarantee on the order the ports are opened. + self.assertEqual(calls, [ + ['open-port', '8000/tcp'], + ['open-port', '8025/tcp'], + ]) + + def test_set_ports_mixed(self): + # Two open ports, leave one alone and open another one. + fake_script(self, 'open-port', 'exit 0') + fake_script(self, 'close-port', 'exit 0') + fake_script(self, 'opened-ports', 'echo 8025/tcp; echo 8028/tcp') + self.unit.set_ports(ops.Port('udp', 8022), 8028) + self.assertEqual(fake_script_calls(self, clear=True), [ + ['opened-ports', ''], + ['close-port', '8025/tcp'], + ['open-port', '8022/udp'], + ]) + + def test_set_ports_replace(self): + fake_script(self, 'open-port', 'exit 0') + fake_script(self, 'close-port', 'exit 0') + fake_script(self, 'opened-ports', 'echo 8025/tcp; echo 8028/tcp') + self.unit.set_ports(8001, 8002) + calls = fake_script_calls(self, clear=True) + self.assertEqual(calls.pop(0), ['opened-ports', '']) + calls.sort() + self.assertEqual(calls, [ + ['close-port', '8025/tcp'], + ['close-port', '8028/tcp'], + ['open-port', '8001/tcp'], + ['open-port', '8002/tcp'], + ]) + + def test_set_ports_close_all(self): + fake_script(self, 'open-port', 'exit 0') + fake_script(self, 'close-port', 'exit 0') + fake_script(self, 'opened-ports', 'echo 8022/udp') + self.unit.set_ports() + self.assertEqual(fake_script_calls(self, clear=True), [ + ['opened-ports', ''], + ['close-port', '8022/udp'], + ]) + + def test_set_ports_noop(self): + fake_script(self, 'open-port', 'exit 0') + fake_script(self, 'close-port', 'exit 0') + fake_script(self, 'opened-ports', 'echo 8000/tcp') + self.unit.set_ports(ops.Port('tcp', 8000)) + self.assertEqual(fake_script_calls(self, clear=True), [ + ['opened-ports', ''], + ]) + if __name__ == "__main__": unittest.main() diff --git a/test/test_testing.py b/test/test_testing.py index df24292e0..495f704b2 100644 --- a/test/test_testing.py +++ b/test/test_testing.py @@ -4829,13 +4829,12 @@ def test_ports(self): self.assertIsInstance(ports_set, set) ports = sorted(ports_set, key=lambda p: (p.protocol, p.port)) self.assertEqual(len(ports), 3) - self.assertIsInstance(ports[0], ops.OpenedPort) + self.assertIsInstance(ports[0], ops.Port) self.assertEqual(ports[0].protocol, 'icmp') self.assertIsNone(ports[0].port) - self.assertIsInstance(ports[1], ops.OpenedPort) self.assertEqual(ports[1].protocol, 'tcp') self.assertEqual(ports[1].port, 8080) - self.assertIsInstance(ports[2], ops.OpenedPort) + self.assertIsInstance(ports[1], ops.Port) self.assertEqual(ports[2].protocol, 'udp') self.assertEqual(ports[2].port, 4000) @@ -4847,7 +4846,7 @@ def test_ports(self): self.assertIsInstance(ports_set, set) ports = sorted(ports_set, key=lambda p: (p.protocol, p.port)) self.assertEqual(len(ports), 1) - self.assertIsInstance(ports[0], ops.OpenedPort) + self.assertIsInstance(ports[0], ops.Port) self.assertEqual(ports[0].protocol, 'icmp') self.assertIsNone(ports[0].port)