diff --git a/pyproject.toml b/pyproject.toml index 33393ad4c..43d596f95 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -37,6 +37,7 @@ include = ["ops/*.py", "ops/_private/*.py", "test/test_lib.py", "test/test_model.py", "test/test_testing.py", + "test/test_storage.py", "test/test_charm.py", ] pythonVersion = "3.8" # check no python > 3.8 features are used diff --git a/test/test_storage.py b/test/test_storage.py index fdc32323e..2d58b7532 100755 --- a/test/test_storage.py +++ b/test/test_storage.py @@ -19,19 +19,25 @@ import pathlib import sys import tempfile +import typing +import unittest from test.test_helpers import BaseTestCase, fake_script, fake_script_calls from textwrap import dedent import yaml import ops +import ops.storage class StoragePermutations(abc.ABC): + assertEqual = unittest.TestCase.assertEqual # noqa + assertRaises = unittest.TestCase.assertRaises # noqa + def create_framework(self) -> ops.Framework: """Create a Framework that we can use to test the backend storage.""" - return ops.Framework(self.create_storage(), None, None, None) + return ops.Framework(self.create_storage(), None, None, None) # type: ignore @abc.abstractmethod def create_storage(self) -> ops.storage.SQLiteStorage: @@ -41,16 +47,17 @@ def create_storage(self) -> ops.storage.SQLiteStorage: def test_save_and_load_snapshot(self): f = self.create_framework() - class Sample(ops.Object): + class Sample(ops.StoredStateData): - def __init__(self, parent, key, content): + def __init__(self, parent: ops.Object, key: str, + content: typing.Dict[str, typing.Any]): super().__init__(parent, key) self.content = content def snapshot(self): return {'content': self.content} - def restore(self, snapshot): + def restore(self, snapshot: typing.Dict[str, typing.Any]): self.__dict__.update(snapshot) f.register_type(Sample, None, Sample.handle_kind) @@ -69,20 +76,20 @@ def restore(self, snapshot): del s gc.collect() res = f.load_snapshot(handle) - self.assertEqual(data, res.content) + self.assertEqual(data, res.content) # type: ignore def test_emit_event(self): f = self.create_framework() class Evt(ops.EventBase): - def __init__(self, handle, content): + def __init__(self, handle: ops.Handle, content: typing.Any): super().__init__(handle) self.content = content def snapshot(self): return self.content - def restore(self, content): + def restore(self, content: typing.Any): self.content = content class Events(ops.ObjectEvents): @@ -90,9 +97,9 @@ class Events(ops.ObjectEvents): class Sample(ops.Object): - on = Events() + on = Events() # type: ignore - def __init__(self, parent, key): + def __init__(self, parent: ops.Object, key: str): super().__init__(parent, key) self.observed_content = None self.framework.observe(self.on.event, self._on_event) @@ -100,6 +107,12 @@ def __init__(self, parent, key): def _on_event(self, event: Evt): self.observed_content = event.content + def snapshot(self) -> typing.Dict[str, typing.Any]: + raise NotImplementedError() + + def restore(self, snapshot: typing.Dict[str, typing.Any]) -> None: + raise NotImplementedError() + s = Sample(f, 'key') f.register_type(Sample, None, Sample.handle_kind) s.on.event.emit('foo') @@ -206,7 +219,7 @@ def create_storage(self): return ops.storage.SQLiteStorage(':memory:') -def setup_juju_backend(test_case, state_file): +def setup_juju_backend(test_case: unittest.TestCase, state_file: pathlib.Path): """Create fake scripts for pretending to be state-set and state-get.""" template_args = { 'executable': str(pathlib.Path(sys.executable).as_posix()), @@ -305,7 +318,7 @@ def test_handles_tuples(self): parsed = yaml.load(raw, Loader=ops.storage._SimpleLoader) self.assertEqual(parsed, (1, 'tuple')) - def assertRefused(self, obj): # noqa: N802 + def assertRefused(self, obj: typing.Any): # noqa: N802 # We shouldn't allow them to be written with self.assertRaises(yaml.representer.RepresenterError): yaml.dump(obj, Dumper=ops.storage._SimpleDumper)