From d8b0a8422658b2acb50f98409c8d32fc95411e51 Mon Sep 17 00:00:00 2001 From: Tony Meyer Date: Wed, 27 Nov 2024 17:02:18 +1300 Subject: [PATCH] Refactor to use ops._main._Manager. --- testing/src/scenario/_consistency_checker.py | 18 +- testing/src/scenario/_ops_main_mock.py | 386 ++++++++----------- testing/src/scenario/_runtime.py | 169 +------- testing/src/scenario/context.py | 31 +- testing/tests/test_context_on.py | 9 +- testing/tests/test_e2e/test_stored_state.py | 10 +- testing/tests/test_emitted_events_util.py | 10 +- testing/tests/test_runtime.py | 24 +- 8 files changed, 222 insertions(+), 435 deletions(-) diff --git a/testing/src/scenario/_consistency_checker.py b/testing/src/scenario/_consistency_checker.py index 335c9c2ee..a9f253f6d 100644 --- a/testing/src/scenario/_consistency_checker.py +++ b/testing/src/scenario/_consistency_checker.py @@ -39,7 +39,7 @@ ) from .errors import InconsistentScenarioError -from .runtime import logger as scenario_logger +from ._runtime import logger as scenario_logger from .state import ( CharmType, PeerRelation, @@ -170,16 +170,14 @@ def check_event_consistency( errors: List[str] = [] warnings: List[str] = [] + # custom event: can't make assumptions about its name and its semantics + # todo: should we then just skip the other checks? if not event._is_builtin_event(charm_spec): - # This is a custom event - we can't make assumptions about its name and - # semantics. It doesn't really make sense to do checks that are designed - # for relations, workloads, and so on - most likely those will end up - # with false positives. Realistically, we can't know about what the - # requirements for the custom event are (in terms of the state), so we - # skip everything here. Perhaps in the future, custom events could - # optionally include some sort of state metadata that made testing - # consistency possible? - return Results(errors, warnings) + warnings.append( + "this is a custom event; if its name makes it look like a builtin one " + "(e.g. a relation event, or a workload event), you might get some false-negative " + "consistency checks.", + ) if event._is_relation_event: _check_relation_event(charm_spec, event, state, errors, warnings) diff --git a/testing/src/scenario/_ops_main_mock.py b/testing/src/scenario/_ops_main_mock.py index fbfea2909..87f2ec1ab 100644 --- a/testing/src/scenario/_ops_main_mock.py +++ b/testing/src/scenario/_ops_main_mock.py @@ -2,185 +2,101 @@ # Copyright 2023 Canonical Ltd. # See LICENSE file for licensing details. -import inspect -import os -import pathlib +import dataclasses +import marshal +import re import sys -from typing import TYPE_CHECKING, Any, Optional, Sequence, Type, cast +from typing import TYPE_CHECKING, Any, Dict, FrozenSet, List, Sequence, Set -import ops.charm -import ops.framework +import ops import ops.jujucontext -import ops.model import ops.storage -from ops import CharmBase -# use logger from ops._main so that juju_log will be triggered -from ops._main import CHARM_STATE_FILE, _Dispatcher, _get_event_args +from ops.framework import _event_regex +from ops._main import _Dispatcher, _Manager from ops._main import logger as ops_logger -from ops.charm import CharmMeta -from ops.log import setup_root_logging from .errors import BadOwnerPath, NoObserverError +from .logger import logger as scenario_logger +from .mocking import _MockModelBackend +from .state import DeferredEvent, StoredState if TYPE_CHECKING: # pragma: no cover from .context import Context from .state import CharmType, State, _CharmSpec, _Event -# pyright: reportPrivateUsage=false - - -def _get_owner(root: Any, path: Sequence[str]) -> ops.ObjectEvents: - """Walk path on root to an ObjectEvents instance.""" - obj = root - for step in path: - try: - obj = getattr(obj, step) - except AttributeError: - raise BadOwnerPath( - f"event_owner_path {path!r} invalid: {step!r} leads to nowhere.", - ) - if not isinstance(obj, ops.ObjectEvents): - raise BadOwnerPath( - f"event_owner_path {path!r} invalid: does not lead to " - f"an ObjectEvents instance.", - ) - return obj - - -def _emit_charm_event( - charm: "CharmBase", - event_name: str, - juju_context: ops.jujucontext._JujuContext, - event: Optional["_Event"] = None, -): - """Emits a charm event based on a Juju event name. - - Args: - charm: A charm instance to emit an event from. - event_name: A Juju event name to emit on a charm. - event: Event to emit. - juju_context: Juju context to use for the event. - """ - owner = _get_owner(charm, event.owner_path) if event else charm.on - - try: - event_to_emit = getattr(owner, event_name) - except AttributeError: - ops_logger.debug("Event %s not defined for %s.", event_name, charm) - raise NoObserverError( - f"Cannot fire {event_name!r} on {owner}: " - f"invalid event (not on charm.on).", - ) - - args, kwargs = _get_event_args(charm, event_to_emit, juju_context) - ops_logger.debug("Emitting Juju event %s.", event_name) - event_to_emit.emit(*args, **kwargs) - - -def setup_framework( - charm_dir: pathlib.Path, - state: "State", - event: "_Event", - context: "Context", - charm_spec: "_CharmSpec[CharmType]", - juju_context: Optional[ops.jujucontext._JujuContext] = None, -): - from .mocking import _MockModelBackend - - if juju_context is None: - juju_context = ops.jujucontext._JujuContext.from_dict(os.environ) - model_backend = _MockModelBackend( - state=state, - event=event, - context=context, - charm_spec=charm_spec, - juju_context=juju_context, - ) - setup_root_logging(model_backend, debug=juju_context.debug) - # ops sets sys.excepthook to go to Juju's debug-log, but that's not useful - # in a testing context, so reset it. - sys.excepthook = sys.__excepthook__ - ops_logger.debug( - "Operator Framework %s up and running.", - ops.__version__, - ) - - metadata = (charm_dir / "metadata.yaml").read_text() - actions_meta = charm_dir / "actions.yaml" - if actions_meta.exists(): - actions_metadata = actions_meta.read_text() - else: - actions_metadata = None - - meta = CharmMeta.from_yaml(metadata, actions_metadata) - - # ops >= 2.10 - if inspect.signature(ops.model.Model).parameters.get("broken_relation_id"): - # If we are in a RelationBroken event, we want to know which relation is - # broken within the model, not only in the event's `.relation` attribute. - broken_relation_id = ( - event.relation.id # type: ignore - if event.name.endswith("_relation_broken") - else None - ) - - model = ops.model.Model( - meta, - model_backend, - broken_relation_id=broken_relation_id, - ) - else: - ops_logger.warning( - "It looks like this charm is using an older `ops` version. " - "You may experience weirdness. Please update ops.", - ) - model = ops.model.Model(meta, model_backend) - - charm_state_path = charm_dir / CHARM_STATE_FILE +EVENT_REGEX = re.compile(_event_regex) +STORED_STATE_REGEX = re.compile( + r"((?P.*)\/)?(?P<_data_type_name>\D+)\[(?P.*)\]", +) - # TODO: add use_juju_for_storage support - store = ops.storage.SQLiteStorage(charm_state_path) - framework = ops.Framework(store, charm_dir, meta, model) - framework.set_breakpointhook() - return framework +logger = scenario_logger.getChild("ops_main_mock") - -def setup_charm( - charm_class: Type[ops.CharmBase], framework: ops.Framework, dispatcher: _Dispatcher -): - sig = inspect.signature(charm_class) - sig.bind(framework) # signature check - - charm = charm_class(framework) - dispatcher.ensure_event_links(charm) - return charm - - -def setup( - state: "State", - event: "_Event", - context: "Context", - charm_spec: "_CharmSpec[CharmType]", - juju_context: Optional[ops.jujucontext._JujuContext] = None, -): - """Setup dispatcher, framework and charm objects.""" - charm_class = charm_spec.charm_type - if juju_context is None: - juju_context = ops.jujucontext._JujuContext.from_dict(os.environ) - charm_dir = juju_context.charm_dir - - dispatcher = _Dispatcher(charm_dir, juju_context) - dispatcher.run_any_legacy_hook() - - framework = setup_framework( - charm_dir, state, event, context, charm_spec, juju_context - ) - charm = setup_charm(charm_class, framework, dispatcher) - return dispatcher, framework, charm +# pyright: reportPrivateUsage=false -class Ops: +class UnitStateDB: + """Wraps the unit-state database with convenience methods for adjusting the state.""" + + def __init__(self, underlying_store: ops.storage.SQLiteStorage): + self._db = underlying_store + + def get_stored_states(self) -> FrozenSet["StoredState"]: + """Load any StoredState data structures from the db.""" + db = self._db + stored_states: Set[StoredState] = set() + for handle_path in db.list_snapshots(): + if not EVENT_REGEX.match(handle_path) and ( + match := STORED_STATE_REGEX.match(handle_path) + ): + stored_state_snapshot = db.load_snapshot(handle_path) + kwargs = match.groupdict() + sst = StoredState(content=stored_state_snapshot, **kwargs) + stored_states.add(sst) + + return frozenset(stored_states) + + def get_deferred_events(self) -> List["DeferredEvent"]: + """Load any DeferredEvent data structures from the db.""" + db = self._db + deferred: List[DeferredEvent] = [] + for handle_path in db.list_snapshots(): + if EVENT_REGEX.match(handle_path): + notices = db.notices(handle_path) + for handle, owner, observer in notices: + try: + snapshot_data = db.load_snapshot(handle) + except ops.storage.NoSnapshotError: + snapshot_data: Dict[str, Any] = {} + + event = DeferredEvent( + handle_path=handle, + owner=owner, + observer=observer, + snapshot_data=snapshot_data, + ) + deferred.append(event) + + return deferred + + def apply_state(self, state: "State"): + """Add DeferredEvent and StoredState from this State instance to the storage.""" + db = self._db + for event in state.deferred: + db.save_notice(event.handle_path, event.owner, event.observer) + try: + marshal.dumps(event.snapshot_data) + except ValueError as e: + raise ValueError( + f"unable to save the data for {event}, it must contain only simple types.", + ) from e + db.save_snapshot(event.handle_path, event.snapshot_data) + + for stored_state in state.stored_states: + db.save_snapshot(stored_state._handle_path, stored_state.content) + + +class Ops(_Manager): """Class to manage stepping through ops setup, event emission and framework commit.""" def __init__( @@ -189,81 +105,93 @@ def __init__( event: "_Event", context: "Context", charm_spec: "_CharmSpec[CharmType]", - juju_context: Optional[ops.jujucontext._JujuContext] = None, + juju_context: ops.jujucontext._JujuContext, ): self.state = state self.event = event self.context = context self.charm_spec = charm_spec - if juju_context is None: - juju_context = ops.jujucontext._JujuContext.from_dict(os.environ) - self.juju_context = juju_context - - # set by setup() - self.dispatcher: Optional[_Dispatcher] = None - self.framework: Optional[ops.Framework] = None - self.charm: Optional[ops.CharmBase] = None - - self._has_setup = False - self._has_emitted = False - self._has_committed = False - - def setup(self): - """Setup framework, charm and dispatcher.""" - self._has_setup = True - self.dispatcher, self.framework, self.charm = setup( - self.state, - self.event, - self.context, - self.charm_spec, - self.juju_context, + self.store = None + + model_backend = _MockModelBackend( + state=state, + event=event, + context=context, + charm_spec=charm_spec, + juju_context=juju_context, ) - def emit(self): - """Emit the event on the charm.""" - if not self._has_setup: - raise RuntimeError("should .setup() before you .emit()") - self._has_emitted = True + super().__init__( + self.charm_spec.charm_type, model_backend, juju_context=juju_context + ) - dispatcher = cast(_Dispatcher, self.dispatcher) - charm = cast(CharmBase, self.charm) - framework = cast(ops.Framework, self.framework) + def _load_charm_meta(self): + metadata = (self._charm_root / "metadata.yaml").read_text() + actions_meta = self._charm_root / "actions.yaml" + if actions_meta.exists(): + actions_metadata = actions_meta.read_text() + else: + actions_metadata = None + + return ops.CharmMeta.from_yaml(metadata, actions_metadata) + + def _setup_root_logging(self): + # Ops sets sys.excepthook to go to Juju's debug-log, but that's not + # useful in a testing context, so we reset it here. + super()._setup_root_logging() + sys.excepthook = sys.__excepthook__ + + def _make_storage(self, _: _Dispatcher): + # TODO: add use_juju_for_storage support + # TODO: Pass a charm_state_path that is ':memory:' when appropriate. + charm_state_path = self._charm_root / self._charm_state_path + storage = ops.storage.SQLiteStorage(charm_state_path) + logger.info("Copying input state to storage.") + self.store = UnitStateDB(storage) + self.store.apply_state(self.state) + return storage + + def _get_event_to_emit(self, event_name: str): + owner = ( + self._get_owner(self.charm, self.event.owner_path) + if self.event + else self.charm.on + ) try: - if not dispatcher.is_restricted_context(): - framework.reemit() - - _emit_charm_event( - charm, dispatcher.event_name, self.juju_context, self.event + event_to_emit = getattr(owner, event_name) + except AttributeError: + ops_logger.debug("Event %s not defined for %s.", event_name, self.charm) + raise NoObserverError( + f"Cannot fire {event_name!r} on {owner}: " + f"invalid event (not on charm.on).", ) - - except Exception: - framework.close() - raise - - def commit(self): - """Commit the framework and teardown.""" - if not self._has_emitted: - raise RuntimeError("should .emit() before you .commit()") - - framework = cast(ops.Framework, self.framework) - charm = cast(CharmBase, self.charm) - - # emit collect-status events - ops.charm._evaluate_status(charm) - - self._has_committed = True - - try: - framework.commit() - finally: - framework.close() - - def finalize(self): - """Step through all non-manually-called procedures and run them.""" - if not self._has_setup: - self.setup() - if not self._has_emitted: - self.emit() - if not self._has_committed: - self.commit() + return event_to_emit + + @staticmethod + def _get_owner(root: Any, path: Sequence[str]) -> ops.ObjectEvents: + """Walk path on root to an ObjectEvents instance.""" + obj = root + for step in path: + try: + obj = getattr(obj, step) + except AttributeError: + raise BadOwnerPath( + f"event_owner_path {path!r} invalid: {step!r} leads to nowhere.", + ) + if not isinstance(obj, ops.ObjectEvents): + raise BadOwnerPath( + f"event_owner_path {path!r} invalid: does not lead to " + f"an ObjectEvents instance.", + ) + return obj + + def _close(self): + """Now that we're done processing this event, read the charm state and expose it.""" + logger.info("Copying storage to output state.") + assert self.store is not None + deferred = self.store.get_deferred_events() + stored_state = self.store.get_stored_states() + self.state = dataclasses.replace( + self.state, deferred=deferred, stored_states=stored_state + ) diff --git a/testing/src/scenario/_runtime.py b/testing/src/scenario/_runtime.py index 3ad2fd0a2..889a2d96a 100644 --- a/testing/src/scenario/_runtime.py +++ b/testing/src/scenario/_runtime.py @@ -2,24 +2,17 @@ # Copyright 2023 Canonical Ltd. # See LICENSE file for licensing details. -"""Test framework runtime.""" - import copy import dataclasses -import marshal -import re import tempfile import typing from contextlib import contextmanager from pathlib import Path from typing import ( TYPE_CHECKING, - Any, Dict, - FrozenSet, List, Optional, - Set, Type, TypeVar, Union, @@ -37,17 +30,13 @@ PreCommitEvent, ) from ops.jujucontext import _JujuContext -from ops.storage import NoSnapshotError, SQLiteStorage -from ops.framework import _event_regex from ops._private.harness import ActionFailed from .errors import NoObserverError, UncaughtCharmError from .logger import logger as scenario_logger from .state import ( - DeferredEvent, PeerRelation, Relation, - StoredState, SubordinateRelation, ) @@ -56,114 +45,10 @@ from .state import CharmType, State, _CharmSpec, _Event logger = scenario_logger.getChild("runtime") -STORED_STATE_REGEX = re.compile( - r"((?P.*)\/)?(?P<_data_type_name>\D+)\[(?P.*)\]", -) -EVENT_REGEX = re.compile(_event_regex) RUNTIME_MODULE = Path(__file__).parent -class UnitStateDB: - """Represents the unit-state.db.""" - - def __init__(self, db_path: Union[Path, str]): - self._db_path = db_path - self._state_file = Path(self._db_path) - - def _open_db(self) -> SQLiteStorage: - """Open the db.""" - return SQLiteStorage(self._state_file) - - def get_stored_states(self) -> FrozenSet["StoredState"]: - """Load any StoredState data structures from the db.""" - - db = self._open_db() - - stored_states: Set[StoredState] = set() - for handle_path in db.list_snapshots(): - if not EVENT_REGEX.match(handle_path) and ( - match := STORED_STATE_REGEX.match(handle_path) - ): - stored_state_snapshot = db.load_snapshot(handle_path) - kwargs = match.groupdict() - sst = StoredState(content=stored_state_snapshot, **kwargs) - stored_states.add(sst) - - db.close() - return frozenset(stored_states) - - def get_deferred_events(self) -> List["DeferredEvent"]: - """Load any DeferredEvent data structures from the db.""" - - db = self._open_db() - - deferred: List[DeferredEvent] = [] - for handle_path in db.list_snapshots(): - if EVENT_REGEX.match(handle_path): - notices = db.notices(handle_path) - for handle, owner, observer in notices: - try: - snapshot_data = db.load_snapshot(handle) - except NoSnapshotError: - snapshot_data: Dict[str, Any] = {} - - event = DeferredEvent( - handle_path=handle, - owner=owner, - observer=observer, - snapshot_data=snapshot_data, - ) - deferred.append(event) - - db.close() - return deferred - - def apply_state(self, state: "State"): - """Add DeferredEvent and StoredState from this State instance to the storage.""" - db = self._open_db() - for event in state.deferred: - db.save_notice(event.handle_path, event.owner, event.observer) - try: - marshal.dumps(event.snapshot_data) - except ValueError as e: - raise ValueError( - f"unable to save the data for {event}, it must contain only simple types.", - ) from e - db.save_snapshot(event.handle_path, event.snapshot_data) - - for stored_state in state.stored_states: - db.save_snapshot(stored_state._handle_path, stored_state.content) - - db.close() - - -class _OpsMainContext: # type: ignore - """Context manager representing ops.main execution context. - - When entered, ops.main sets up everything up until the charm. - When .emit() is called, ops.main proceeds with emitting the event. - When exited, if .emit has not been called manually, it is called automatically. - """ - - def __init__(self): - self._has_emitted = False - - def __enter__(self): - pass - - def emit(self): - """Emit the event. - - Within the test framework, this only requires recording that it was emitted. - """ - self._has_emitted = True - - def __exit__(self, exc_type: Any, exc_val: Any, exc_tb: Any): # noqa: U100 - if not self._has_emitted: - self.emit() - - class Runtime: """Charm runtime wrapper. @@ -199,11 +84,13 @@ def _get_event_env(self, state: "State", event: "_Event", charm_root: Path): "JUJU_MODEL_NAME": state.model.name, "JUJU_MODEL_UUID": state.model.uuid, "JUJU_CHARM_DIR": str(charm_root.absolute()), + # todo consider setting pwd, (python)path } if event._is_action_event and (action := event.action): env.update( { + # TODO: we should check we're doing the right thing here. "JUJU_ACTION_NAME": action.name.replace("_", "-"), "JUJU_ACTION_UUID": action.id, }, @@ -249,7 +136,7 @@ def _get_event_env(self, state: "State", event: "_Event", charm_root: Path): else: logger.warning( "remote unit ID unset; no remote unit data present. " - "Is this a realistic scenario?", + "Is this a realistic scenario?", # TODO: is it? ) if remote_unit_id is not None: @@ -301,17 +188,14 @@ def _get_event_env(self, state: "State", event: "_Event", charm_root: Path): @staticmethod def _wrap(charm_type: Type["CharmType"]) -> Type["CharmType"]: # dark sorcery to work around framework using class attrs to hold on to event sources - # this should only be needed if we call play multiple times on the same runtime. + # todo this should only be needed if we call play multiple times on the same runtime. + # can we avoid it? class WrappedEvents(charm_type.on.__class__): - """The charm's event sources, but wrapped.""" - pass WrappedEvents.__name__ = charm_type.on.__class__.__name__ class WrappedCharm(charm_type): - """The test charm's type, but with events wrapped.""" - on = WrappedEvents() WrappedCharm.__name__ = charm_type.__name__ @@ -388,28 +272,12 @@ def _virtual_charm_root(self): # charm_virtual_root is a tempdir typing.cast(tempfile.TemporaryDirectory, charm_virtual_root).cleanup() # type: ignore - @staticmethod - def _get_state_db(temporary_charm_root: Path): - charm_state_path = temporary_charm_root / ".unit-state.db" - return UnitStateDB(charm_state_path) - - def _initialize_storage(self, state: "State", temporary_charm_root: Path): - """Before we start processing this event, store the relevant parts of State.""" - store = self._get_state_db(temporary_charm_root) - store.apply_state(state) - - def _close_storage(self, state: "State", temporary_charm_root: Path): - """Now that we're done processing this event, read the charm state and expose it.""" - store = self._get_state_db(temporary_charm_root) - deferred = store.get_deferred_events() - stored_state = store.get_stored_states() - return dataclasses.replace(state, deferred=deferred, stored_states=stored_state) - @contextmanager def _exec_ctx(self, ctx: "Context"): """python 3.8 compatibility shim""" with self._virtual_charm_root() as temporary_charm_root: - with _capture_events( + # TODO: allow customising capture_events + with capture_events( include_deferred=ctx.capture_deferred_events, include_framework=ctx.capture_framework_events, ) as captured: @@ -430,6 +298,9 @@ def exec( This will set the environment up and call ops.main(). After that it's up to ops. """ + # todo consider forking out a real subprocess and do the mocking by + # mocking hook tool executables + from ._consistency_checker import check_consistency # avoid cycles check_consistency(state, event, self._charm_spec, self._juju_version) @@ -442,9 +313,6 @@ def exec( logger.info(" - generating virtual charm root") with self._exec_ctx(context) as (temporary_charm_root, captured): - logger.info(" - initializing storage") - self._initialize_storage(state, temporary_charm_root) - logger.info(" - preparing env") env = self._get_event_env( state=state, @@ -453,8 +321,8 @@ def exec( ) juju_context = _JujuContext.from_dict(env) - logger.info(" - Entering ops.main (mocked).") - from .ops_main_mock import Ops # noqa: F811 + logger.info(" - entering ops.main (mocked)") + from ._ops_main_mock import Ops # noqa: F811 try: ops = Ops( @@ -467,13 +335,9 @@ def exec( ), juju_context=juju_context, ) - ops.setup() yield ops - # if the caller did not manually emit or commit: do that. - ops.finalize() - except (NoObserverError, ActionFailed): raise # propagate along except Exception as e: @@ -482,21 +346,18 @@ def exec( ) from e finally: - logger.info(" - Exited ops.main.") - - logger.info(" - closing storage") - output_state = self._close_storage(output_state, temporary_charm_root) + logger.info(" - exited ops.main") context.emitted_events.extend(captured) logger.info("event dispatched. done.") - context._set_output_state(output_state) + context._set_output_state(ops.state) _T = TypeVar("_T", bound=EventBase) @contextmanager -def _capture_events( +def capture_events( *types: Type[EventBase], include_framework: bool = False, include_deferred: bool = True, diff --git a/testing/src/scenario/context.py b/testing/src/scenario/context.py index 6b3ea2170..4063d0e13 100644 --- a/testing/src/scenario/context.py +++ b/testing/src/scenario/context.py @@ -2,12 +2,6 @@ # Copyright 2023 Canonical Ltd. # See LICENSE file for licensing details. -"""Test Context - -The test `Context` object provides the context of the wider Juju system that the -specific `State` exists in, and the events that can be executed on that `State`. -""" - from __future__ import annotations import functools @@ -19,7 +13,6 @@ Any, Callable, Mapping, - cast, ) import ops @@ -31,7 +24,6 @@ MetadataNotFoundError, ) from .logger import logger as scenario_logger -from .runtime import Runtime from .state import ( CharmType, CheckInfo, @@ -43,12 +35,14 @@ _CharmSpec, _Event, ) +from ._runtime import Runtime if TYPE_CHECKING: # pragma: no cover from ops._private.harness import ExecArgs - from .ops_main_mock import Ops + from ._ops_main_mock import Ops from .state import ( AnyJson, + CharmType, JujuLogLine, RelationBase, State, @@ -83,7 +77,6 @@ def __init__( self._state_in = state_in self._emitted: bool = False - self._wrapped_ctx = None self.ops: Ops | None = None @@ -97,7 +90,7 @@ def charm(self) -> ops.CharmBase: raise RuntimeError( "you should __enter__ this context manager before accessing this", ) - return cast(ops.CharmBase, self.ops.charm) + return self.ops.charm @property def _runner(self): @@ -105,7 +98,8 @@ def _runner(self): def __enter__(self): self._wrapped_ctx = wrapped_ctx = self._runner(self._arg, self._state_in) - self.ops = wrapped_ctx.__enter__() + ops = wrapped_ctx.__enter__() + self.ops = ops return self def run(self) -> State: @@ -115,10 +109,14 @@ def run(self) -> State: """ if self._emitted: raise AlreadyEmittedError("Can only run once.") + if not self.ops: + raise RuntimeError( + "you should __enter__ this context manager before running it", + ) self._emitted = True + self.ops.run() # wrap up Runtime.exec() so that we can gather the output state - assert self._wrapped_ctx is not None self._wrapped_ctx.__exit__(None, None, None) assert self._ctx._output_state is not None @@ -127,7 +125,8 @@ def run(self) -> State: def __exit__(self, exc_type: Any, exc_val: Any, exc_tb: Any): # noqa: U100 if not self._emitted: logger.debug( - "user didn't emit the event within the context manager scope. Doing so implicitly upon exit...", + "user didn't emit the event within the context manager scope. " + "Doing so implicitly upon exit...", ) self.run() @@ -662,8 +661,8 @@ def run(self, event: _Event, state: State) -> State: if self.action_results is not None: self.action_results.clear() self._action_failure_message = None - with self._run(event=event, state=state) as manager: - manager.emit() + with self._run(event=event, state=state) as ops: + ops.run() # We know that the output state will have been set by this point, # so let the type checkers know that too. assert self._output_state is not None diff --git a/testing/tests/test_context_on.py b/testing/tests/test_context_on.py index 32759fd49..402de45ce 100644 --- a/testing/tests/test_context_on.py +++ b/testing/tests/test_context_on.py @@ -1,4 +1,5 @@ import copy +import typing import ops import pytest @@ -35,13 +36,13 @@ class ContextCharm(ops.CharmBase): - def __init__(self, framework): + def __init__(self, framework: ops.Framework): super().__init__(framework) - self.observed = [] + self.observed: typing.List[ops.EventBase] = [] for event in self.on.events().values(): framework.observe(event, self._on_event) - def _on_event(self, event): + def _on_event(self, event: ops.EventBase): self.observed.append(event) @@ -60,7 +61,7 @@ def _on_event(self, event): ("leader_elected", ops.LeaderElectedEvent), ], ) -def test_simple_events(event_name, event_kind): +def test_simple_events(event_name: str, event_kind: typing.Type[ops.EventBase]): ctx = scenario.Context(ContextCharm, meta=META, actions=ACTIONS) # These look like: # ctx.run(ctx.on.install(), state) diff --git a/testing/tests/test_e2e/test_stored_state.py b/testing/tests/test_e2e/test_stored_state.py index b4cb7c7a9..1f26e0aaa 100644 --- a/testing/tests/test_e2e/test_stored_state.py +++ b/testing/tests/test_e2e/test_stored_state.py @@ -1,6 +1,6 @@ import pytest -from ops.charm import CharmBase -from ops.framework import Framework + +import ops from ops.framework import StoredState as ops_storedstate from scenario.state import State, StoredState @@ -9,21 +9,21 @@ @pytest.fixture(scope="function") def mycharm(): - class MyCharm(CharmBase): + class MyCharm(ops.CharmBase): META = {"name": "mycharm"} _read = {} _stored = ops_storedstate() _stored2 = ops_storedstate() - def __init__(self, framework: Framework): + def __init__(self, framework: ops.Framework): super().__init__(framework) self._stored.set_default(foo="bar", baz={12: 142}) self._stored2.set_default(foo="bar", baz={12: 142}) for evt in self.on.events().values(): self.framework.observe(evt, self._on_event) - def _on_event(self, event): + def _on_event(self, _: ops.EventBase): self._read["foo"] = self._stored.foo self._read["baz"] = self._stored.baz diff --git a/testing/tests/test_emitted_events_util.py b/testing/tests/test_emitted_events_util.py index f22a69586..0714562f5 100644 --- a/testing/tests/test_emitted_events_util.py +++ b/testing/tests/test_emitted_events_util.py @@ -2,8 +2,8 @@ from ops.framework import CommitEvent, EventBase, EventSource, PreCommitEvent from scenario import State -from scenario.runtime import _capture_events from scenario.state import _Event +from scenario._runtime import capture_events from .helpers import trigger @@ -32,7 +32,7 @@ def _on_foo(self, e): def test_capture_custom_evt_nonspecific_capture_include_fw_evts(): - with _capture_events(include_framework=True) as emitted: + with capture_events(include_framework=True) as emitted: trigger(State(), "start", MyCharm, meta=MyCharm.META) assert len(emitted) == 5 @@ -44,7 +44,7 @@ def test_capture_custom_evt_nonspecific_capture_include_fw_evts(): def test_capture_juju_evt(): - with _capture_events() as emitted: + with capture_events() as emitted: trigger(State(), "start", MyCharm, meta=MyCharm.META) assert len(emitted) == 2 @@ -54,7 +54,7 @@ def test_capture_juju_evt(): def test_capture_deferred_evt(): # todo: this test should pass with ops < 2.1 as well - with _capture_events() as emitted: + with capture_events() as emitted: trigger( State(deferred=[_Event("foo").deferred(handler=MyCharm._on_foo)]), "start", @@ -70,7 +70,7 @@ def test_capture_deferred_evt(): def test_capture_no_deferred_evt(): # todo: this test should pass with ops < 2.1 as well - with _capture_events(include_deferred=False) as emitted: + with capture_events(include_deferred=False) as emitted: trigger( State(deferred=[_Event("foo").deferred(handler=MyCharm._on_foo)]), "start", diff --git a/testing/tests/test_runtime.py b/testing/tests/test_runtime.py index b303fadf8..79e465636 100644 --- a/testing/tests/test_runtime.py +++ b/testing/tests/test_runtime.py @@ -2,28 +2,28 @@ from tempfile import TemporaryDirectory import pytest -from ops.charm import CharmBase, CharmEvents -from ops.framework import EventBase + +import ops from scenario import Context -from scenario.runtime import Runtime, UncaughtCharmError from scenario.state import Relation, State, _CharmSpec, _Event +from scenario._runtime import Runtime, UncaughtCharmError def charm_type(): - class _CharmEvents(CharmEvents): + class _CharmEvents(ops.CharmEvents): pass - class MyCharm(CharmBase): - on = _CharmEvents() + class MyCharm(ops.CharmBase): + on = _CharmEvents() # type: ignore _event = None - def __init__(self, framework): + def __init__(self, framework: ops.Framework): super().__init__(framework) for evt in self.on.events().values(): self.framework.observe(evt, self._catchall) - def _catchall(self, e): + def _catchall(self, e: ops.EventBase): if self._event: return MyCharm._event = e @@ -40,7 +40,7 @@ def test_event_emission(): my_charm_type = charm_type() - class MyEvt(EventBase): + class MyEvt(ops.EventBase): pass my_charm_type.on.define_event("bar", MyEvt) @@ -56,8 +56,8 @@ class MyEvt(EventBase): state=State(), event=_Event("bar"), context=Context(my_charm_type, meta=meta), - ): - pass + ) as manager: + manager.run() assert my_charm_type._event assert isinstance(my_charm_type._event, MyEvt) @@ -109,7 +109,7 @@ def test_env_clean_on_charm_error(): event=_Event("box_relation_changed", relation=rel), context=Context(my_charm_type, meta=meta), ) as manager: - assert manager.juju_context.remote_app_name == remote_name + assert manager._juju_context.remote_app_name == remote_name assert "JUJU_REMOTE_APP" not in os.environ _ = 1 / 0 # raise some error # Ensure that some other error didn't occur (like AssertionError!).