Skip to content

Commit

Permalink
Merge branch 'main' into pyright-test_helpers-1007
Browse files Browse the repository at this point in the history
  • Loading branch information
tonyandrewmeyer authored Sep 26, 2023
2 parents 8c6de94 + 76a5f15 commit dd27ae0
Show file tree
Hide file tree
Showing 11 changed files with 206 additions and 131 deletions.
2 changes: 2 additions & 0 deletions ops/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -138,6 +138,7 @@
'Network',
'NetworkInterface',
'OpenedPort',
'Port',
'Pod',
'Relation',
'RelationData',
Expand Down Expand Up @@ -272,6 +273,7 @@
NetworkInterface,
OpenedPort,
Pod,
Port,
Relation,
RelationData,
RelationDataAccessError,
Expand Down
5 changes: 3 additions & 2 deletions ops/lib/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
import os
import re
import sys
import typing
import warnings
from ast import literal_eval
from importlib.machinery import ModuleSpec
Expand Down Expand Up @@ -117,7 +118,7 @@ def autoimport():
versions.sort(reverse=True)


def _find_all_specs(path):
def _find_all_specs(path: typing.Iterable[str]) -> typing.Iterator[ModuleSpec]:
for sys_dir in path:
if sys_dir == "":
sys_dir = "."
Expand Down Expand Up @@ -192,7 +193,7 @@ def __str__(self):
return f"got {_join_and(sorted(got))}, but missing {_join_and(sorted(exp - got))}"


def _parse_lib(spec):
def _parse_lib(spec: ModuleSpec) -> typing.Optional["_Lib"]:
if spec.origin is None:
# "can't happen"
logger.warning("No origin for %r (no idea why; please report)", spec.name)
Expand Down
70 changes: 60 additions & 10 deletions ops/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -592,7 +592,7 @@ def add_secret(self, content: Dict[str, str], *,
return Secret(self._backend, id=id, label=label, content=content)

def open_port(self, protocol: typing.Literal['tcp', 'udp', 'icmp'],
port: Optional[int] = None):
port: Optional[int] = None) -> None:
"""Open a port with the given protocol for this unit.
Some behaviour, such as whether the port is opened externally without
Expand All @@ -601,17 +601,25 @@ def open_port(self, protocol: typing.Literal['tcp', 'udp', 'icmp'],
`Juju documentation <https://juju.is/docs/sdk/hook-tool#heading--open-port>`__
for more detail.
Use :meth:`set_ports` for a more declarative approach where all of
the ports that should be open are provided in a single call.
Args:
protocol: String representing the protocol; must be one of
'tcp', 'udp', or 'icmp' (lowercase is recommended, but
uppercase is also supported).
port: The port to open. Required for TCP and UDP; not allowed
for ICMP.
Raises:
ModelError: If ``port`` is provided when ``protocol`` is 'icmp'
or ``port`` is not provided when ``protocol`` is 'tcp' or
'udp'.
"""
self._backend.open_port(protocol.lower(), port)

def close_port(self, protocol: typing.Literal['tcp', 'udp', 'icmp'],
port: Optional[int] = None):
port: Optional[int] = None) -> None:
"""Close a port with the given protocol for this unit.
Some behaviour, such as whether the port is closed externally without
Expand All @@ -620,23 +628,62 @@ def close_port(self, protocol: typing.Literal['tcp', 'udp', 'icmp'],
`Juju documentation <https://juju.is/docs/sdk/hook-tool#heading--close-port>`__
for more detail.
Use :meth:`set_ports` for a more declarative approach where all
of the ports that should be open are provided in a single call.
For example, ``set_ports()`` will close all open ports.
Args:
protocol: String representing the protocol; must be one of
'tcp', 'udp', or 'icmp' (lowercase is recommended, but
uppercase is also supported).
port: The port to open. Required for TCP and UDP; not allowed
for ICMP.
Raises:
ModelError: If ``port`` is provided when ``protocol`` is 'icmp'
or ``port`` is not provided when ``protocol`` is 'tcp' or
'udp'.
"""
self._backend.close_port(protocol.lower(), port)

def opened_ports(self) -> Set['OpenedPort']:
def opened_ports(self) -> Set['Port']:
"""Return a list of opened ports for this unit."""
return self._backend.opened_ports()

def set_ports(self, *ports: Union[int, 'Port']) -> None:
"""Set the open ports for this unit, closing any others that are open.
Some behaviour, such as whether the port is opened or closed externally without
using Juju's ``expose`` and ``unexpose`` commands, differs between Kubernetes
and machine charms. See the
`Juju documentation <https://juju.is/docs/sdk/hook-tool#heading--networking>`__
for more detail.
Use :meth:`open_port` and :meth:`close_port` to manage ports
individually.
Args:
ports: The ports to open. Provide an int to open a TCP port, or
a :class:`Port` to open a port for another protocol.
"""
# Normalise to get easier comparisons.
existing = {
(port.protocol, port.port)
for port in self._backend.opened_ports()
}
desired = {
('tcp', port) if isinstance(port, int) else (port.protocol, port.port)
for port in ports
}
for protocol, port in existing - desired:
self._backend.close_port(protocol, port)
for protocol, port in desired - existing:
self._backend.open_port(protocol, port)


@dataclasses.dataclass(frozen=True)
class OpenedPort:
"""Represents a port opened by :meth:`Unit.open_port`."""
class Port:
"""Represents a port opened by :meth:`Unit.open_port` or :meth:`Unit.set_ports`."""

protocol: typing.Literal['tcp', 'udp', 'icmp']
"""The IP protocol."""
Expand All @@ -645,6 +692,9 @@ class OpenedPort:
"""The port number. Will be ``None`` if protocol is ``'icmp'``."""


OpenedPort = Port # Alias for backwards compatibility.


class LazyMapping(Mapping[str, str], ABC):
"""Represents a dict that isn't populated until it is accessed.
Expand Down Expand Up @@ -3241,13 +3291,13 @@ def close_port(self, protocol: str, port: Optional[int] = None):
arg = f'{port}/{protocol}' if port is not None else protocol
self._run('close-port', arg)

def opened_ports(self) -> Set[OpenedPort]:
def opened_ports(self) -> Set[Port]:
# We could use "opened-ports --format=json", but it's not really
# structured; it's just an array of strings which are the lines of the
# text output, like ["icmp","8081/udp"]. So it's probably just as
# likely to change as the text output, and doesn't seem any better.
output = typing.cast(str, self._run('opened-ports', return_output=True))
ports: Set[OpenedPort] = set()
ports: Set[Port] = set()
for line in output.splitlines():
line = line.strip()
if not line:
Expand All @@ -3258,9 +3308,9 @@ def opened_ports(self) -> Set[OpenedPort]:
return ports

@classmethod
def _parse_opened_port(cls, port_str: str) -> Optional[OpenedPort]:
def _parse_opened_port(cls, port_str: str) -> Optional[Port]:
if port_str == 'icmp':
return OpenedPort('icmp', None)
return Port('icmp', None)
port_range, slash, protocol = port_str.partition('/')
if not slash or protocol not in ['tcp', 'udp']:
logger.warning('Unexpected opened-ports protocol: %s', port_str)
Expand All @@ -3269,7 +3319,7 @@ def _parse_opened_port(cls, port_str: str) -> Optional[OpenedPort]:
if hyphen:
logger.warning('Ignoring opened-ports port range: %s', port_str)
protocol_lit = typing.cast(typing.Literal['tcp', 'udp'], protocol)
return OpenedPort(protocol_lit, int(port))
return Port(protocol_lit, int(port))


class _ModelBackendValidator:
Expand Down
8 changes: 4 additions & 4 deletions ops/testing.py
Original file line number Diff line number Diff line change
Expand Up @@ -1951,7 +1951,7 @@ def __init__(self, unit_name: str, meta: charm.CharmMeta, config: 'RawConfig'):
self._planned_units: Optional[int] = None
self._hook_is_running = ''
self._secrets: List[_Secret] = []
self._opened_ports: Set[model.OpenedPort] = set()
self._opened_ports: Set[model.Port] = set()
self._networks: Dict[Tuple[Optional[str], Optional[int]], _NetworkDict] = {}

def _validate_relation_access(self, relation_name: str, relations: List[model.Relation]):
Expand Down Expand Up @@ -2489,14 +2489,14 @@ def secret_remove(self, id: str, *, revision: Optional[int] = None) -> None:
def open_port(self, protocol: str, port: Optional[int] = None):
self._check_protocol_and_port(protocol, port)
protocol_lit = cast(Literal['tcp', 'udp', 'icmp'], protocol)
self._opened_ports.add(model.OpenedPort(protocol_lit, port))
self._opened_ports.add(model.Port(protocol_lit, port))

def close_port(self, protocol: str, port: Optional[int] = None):
self._check_protocol_and_port(protocol, port)
protocol_lit = cast(Literal['tcp', 'udp', 'icmp'], protocol)
self._opened_ports.discard(model.OpenedPort(protocol_lit, port))
self._opened_ports.discard(model.Port(protocol_lit, port))

def opened_ports(self) -> Set[model.OpenedPort]:
def opened_ports(self) -> Set[model.Port]:
return set(self._opened_ports)

def _check_protocol_and_port(self, protocol: str, port: Optional[int]):
Expand Down
1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@ include = ["ops/*.py", "ops/_private/*.py",
"test/test_jujuversion.py",
"test/test_log.py",
"test/test_helpers.py",
"test/test_lib.py",
]
pythonVersion = "3.8" # check no python > 3.8 features are used
pythonPlatform = "All"
Expand Down
1 change: 0 additions & 1 deletion requirements-dev.txt
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
isort==5.11.4
logassert==7
autopep8==1.6.0
flake8==4.0.1
flake8-docstrings==1.6.0
Expand Down
51 changes: 25 additions & 26 deletions test/test_framework.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,8 +25,6 @@
from test.test_helpers import BaseTestCase, fake_script
from unittest.mock import patch

import logassert

import ops
from ops.framework import _BREAKPOINT_WELCOME_MESSAGE, _event_regex
from ops.storage import NoSnapshotError, SQLiteStorage
Expand All @@ -41,13 +39,14 @@ def setUp(self):
patcher = patch('ops.storage.SQLiteStorage.DB_LOCK_TIMEOUT', datetime.timedelta(0))
patcher.start()
self.addCleanup(patcher.stop)
logassert.setup(self, 'ops')

def test_deprecated_init(self):
# For 0.7, this still works, but it is deprecated.
framework = ops.Framework(':memory:', None, None, None)
self.assertLoggedWarning(
"deprecated: Framework now takes a Storage not a path")
with self.assertLogs(level="WARNING") as cm:
framework = ops.Framework(':memory:', None, None, None)
self.assertIn(
"WARNING:ops.framework:deprecated: Framework now takes a Storage not a path",
cm.output)
self.assertIsInstance(framework._storage, SQLiteStorage)

def test_handle_path(self):
Expand Down Expand Up @@ -356,12 +355,7 @@ def _on_foo(self, event):
obs = MyObserver(framework, "1")

framework.observe(pub.foo, obs._on_foo)

self.assertNotLogged("Deferring")
pub.foo.emit(1)
self.assertLogged("Deferring <MyEvent via MyNotifier[1]/foo[1]>.")
self.assertNotLogged("Re-emitting")

framework.reemit()

# Two things being checked here:
Expand All @@ -375,7 +369,6 @@ def _on_foo(self, event):
# we'd get a foo=3).
#
self.assertEqual(obs.seen, ["on_foo:foo=2", "on_foo:foo=2"])
self.assertLoggedDebug("Re-emitting deferred event <MyEvent via MyNotifier[1]/foo[1]>.")

def test_weak_observer(self):
framework = self.create_framework()
Expand Down Expand Up @@ -1487,21 +1480,24 @@ def callback_method(self, event):
@patch('sys.stderr', new_callable=io.StringIO)
class BreakpointTests(BaseTestCase):

def setUp(self):
super().setUp()
logassert.setup(self, 'ops')

def test_ignored(self, fake_stderr):
# It doesn't do anything really unless proper environment is there.
with patch.dict(os.environ):
os.environ.pop('JUJU_DEBUG_AT', None)
framework = self.create_framework()

with patch('pdb.Pdb.set_trace') as mock:
framework.breakpoint()
# We want to verify that there are *no* logs at warning level.
# However, assertNoLogs is Python 3.10+.
try:
with self.assertLogs(level="WARNING"):
framework.breakpoint()
except AssertionError:
pass
else:
self.fail("No warning logs should be generated")
self.assertEqual(mock.call_count, 0)
self.assertEqual(fake_stderr.getvalue(), "")
self.assertNotLoggedWarning("Breakpoint", "skipped")

def test_pdb_properly_called(self, fake_stderr):
# The debugger needs to leave the user in the frame where the breakpoint is executed,
Expand Down Expand Up @@ -1658,17 +1654,20 @@ def test_named_indicated_specifically(self, fake_stderr):

def test_named_indicated_unnamed(self, fake_stderr):
# Some breakpoint was indicated, but the framework call was unnamed
self.check_trace_set('some-breakpoint', None, 0)
self.assertLoggedWarning(
"Breakpoint None skipped",
"not found in the requested breakpoints: {'some-breakpoint'}")
with self.assertLogs(level="WARNING") as cm:
self.check_trace_set('some-breakpoint', None, 0)
self.assertEqual(cm.output, [
"WARNING:ops.framework:Breakpoint None skipped "
"(not found in the requested breakpoints: {'some-breakpoint'})"
])

def test_named_indicated_somethingelse(self, fake_stderr):
# Some breakpoint was indicated, but the framework call was with a different name
self.check_trace_set('some-breakpoint', 'other-name', 0)
self.assertLoggedWarning(
"Breakpoint 'other-name' skipped",
"not found in the requested breakpoints: {'some-breakpoint'}")
with self.assertLogs(level="WARNING") as cm:
self.check_trace_set('some-breakpoint', 'other-name', 0)
self.assertEqual(cm.output, [
"WARNING:ops.framework:Breakpoint 'other-name' skipped "
"(not found in the requested breakpoints: {'some-breakpoint'})"])

def test_named_indicated_ingroup(self, fake_stderr):
# A multiple breakpoint was indicated, and the framework call used a name among those.
Expand Down
Loading

0 comments on commit dd27ae0

Please sign in to comment.