diff --git a/testing/src/scenario/context.py b/testing/src/scenario/context.py index 6b3ea2170..8087480f2 100644 --- a/testing/src/scenario/context.py +++ b/testing/src/scenario/context.py @@ -15,11 +15,11 @@ from contextlib import contextmanager from pathlib import Path from typing import ( + Generic, TYPE_CHECKING, Any, Callable, Mapping, - cast, ) import ops @@ -60,7 +60,7 @@ _DEFAULT_JUJU_VERSION = "3.5" -class Manager: +class Manager(Generic[CharmType]): """Context manager to offer test code some runtime charm object introspection. This class should not be instantiated directly: use a :class:`Context` @@ -74,7 +74,7 @@ class Manager: def __init__( self, - ctx: Context, + ctx: Context[CharmType], arg: _Event, state_in: State, ): @@ -85,19 +85,19 @@ def __init__( self._emitted: bool = False self._wrapped_ctx = None - self.ops: Ops | None = None + self.ops: Ops[CharmType] | None = None @property - def charm(self) -> ops.CharmBase: + def charm(self) -> CharmType: """The charm object instantiated by ops to handle the event. The charm is only available during the context manager scope. """ - if not self.ops: + if self.ops is None or self.ops.charm is None: raise RuntimeError( "you should __enter__ this context manager before accessing this", ) - return cast(ops.CharmBase, self.ops.charm) + return self.ops.charm @property def _runner(self): @@ -361,7 +361,7 @@ def action( return _Event(f"{name}_action", action=_Action(name, **kwargs)) -class Context: +class Context(Generic[CharmType]): """Represents a simulated charm's execution context. The main entry point to running a test. It contains: @@ -571,7 +571,7 @@ def _record_status(self, state: State, is_app: bool): else: self.unit_status_history.append(state.unit_status) - def __call__(self, event: _Event, state: State): + def __call__(self, event: _Event, state: State) -> Manager[CharmType]: """Context manager to introspect live charm object before and after the event is emitted. Usage:: diff --git a/testing/src/scenario/ops_main_mock.py b/testing/src/scenario/ops_main_mock.py index fbfea2909..5e4846eba 100644 --- a/testing/src/scenario/ops_main_mock.py +++ b/testing/src/scenario/ops_main_mock.py @@ -6,7 +6,7 @@ import os import pathlib import sys -from typing import TYPE_CHECKING, Any, Optional, Sequence, Type, cast +from typing import TYPE_CHECKING, Any, Generic, Optional, Sequence, Type, cast import ops.charm import ops.framework @@ -22,10 +22,11 @@ from ops.log import setup_root_logging from .errors import BadOwnerPath, NoObserverError +from .state import CharmType if TYPE_CHECKING: # pragma: no cover from .context import Context - from .state import CharmType, State, _CharmSpec, _Event + from .state import State, _CharmSpec, _Event # pyright: reportPrivateUsage=false @@ -82,7 +83,7 @@ def setup_framework( charm_dir: pathlib.Path, state: "State", event: "_Event", - context: "Context", + context: "Context[CharmType]", charm_spec: "_CharmSpec[CharmType]", juju_context: Optional[ops.jujucontext._JujuContext] = None, ): @@ -160,7 +161,7 @@ def setup_charm( def setup( state: "State", event: "_Event", - context: "Context", + context: "Context[CharmType]", charm_spec: "_CharmSpec[CharmType]", juju_context: Optional[ops.jujucontext._JujuContext] = None, ): @@ -180,14 +181,14 @@ def setup( return dispatcher, framework, charm -class Ops: +class Ops(Generic[CharmType]): """Class to manage stepping through ops setup, event emission and framework commit.""" def __init__( self, state: "State", event: "_Event", - context: "Context", + context: "Context[CharmType]", charm_spec: "_CharmSpec[CharmType]", juju_context: Optional[ops.jujucontext._JujuContext] = None, ): @@ -202,7 +203,7 @@ def __init__( # set by setup() self.dispatcher: Optional[_Dispatcher] = None self.framework: Optional[ops.Framework] = None - self.charm: Optional[ops.CharmBase] = None + self.charm: Optional["CharmType"] = None self._has_setup = False self._has_emitted = False