Skip to content

Commit

Permalink
test: add type hints for test_storage (#1023)
Browse files Browse the repository at this point in the history
  • Loading branch information
tonyandrewmeyer authored Oct 2, 2023
1 parent 31cd829 commit ecc8052
Show file tree
Hide file tree
Showing 2 changed files with 25 additions and 11 deletions.
1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@ include = ["ops/*.py", "ops/_private/*.py",
"test/test_lib.py",
"test/test_model.py",
"test/test_testing.py",
"test/test_storage.py",
"test/test_charm.py",
]
pythonVersion = "3.8" # check no python > 3.8 features are used
Expand Down
35 changes: 24 additions & 11 deletions test/test_storage.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,19 +19,25 @@
import pathlib
import sys
import tempfile
import typing
import unittest
from test.test_helpers import BaseTestCase, fake_script, fake_script_calls
from textwrap import dedent

import yaml

import ops
import ops.storage


class StoragePermutations(abc.ABC):

assertEqual = unittest.TestCase.assertEqual # noqa
assertRaises = unittest.TestCase.assertRaises # noqa

def create_framework(self) -> ops.Framework:
"""Create a Framework that we can use to test the backend storage."""
return ops.Framework(self.create_storage(), None, None, None)
return ops.Framework(self.create_storage(), None, None, None) # type: ignore

@abc.abstractmethod
def create_storage(self) -> ops.storage.SQLiteStorage:
Expand All @@ -41,16 +47,17 @@ def create_storage(self) -> ops.storage.SQLiteStorage:
def test_save_and_load_snapshot(self):
f = self.create_framework()

class Sample(ops.Object):
class Sample(ops.StoredStateData):

def __init__(self, parent, key, content):
def __init__(self, parent: ops.Object, key: str,
content: typing.Dict[str, typing.Any]):
super().__init__(parent, key)
self.content = content

def snapshot(self):
return {'content': self.content}

def restore(self, snapshot):
def restore(self, snapshot: typing.Dict[str, typing.Any]):
self.__dict__.update(snapshot)

f.register_type(Sample, None, Sample.handle_kind)
Expand All @@ -69,37 +76,43 @@ def restore(self, snapshot):
del s
gc.collect()
res = f.load_snapshot(handle)
self.assertEqual(data, res.content)
self.assertEqual(data, res.content) # type: ignore

def test_emit_event(self):
f = self.create_framework()

class Evt(ops.EventBase):
def __init__(self, handle, content):
def __init__(self, handle: ops.Handle, content: typing.Any):
super().__init__(handle)
self.content = content

def snapshot(self):
return self.content

def restore(self, content):
def restore(self, content: typing.Any):
self.content = content

class Events(ops.ObjectEvents):
event = ops.EventSource(Evt)

class Sample(ops.Object):

on = Events()
on = Events() # type: ignore

def __init__(self, parent, key):
def __init__(self, parent: ops.Object, key: str):
super().__init__(parent, key)
self.observed_content = None
self.framework.observe(self.on.event, self._on_event)

def _on_event(self, event: Evt):
self.observed_content = event.content

def snapshot(self) -> typing.Dict[str, typing.Any]:
raise NotImplementedError()

def restore(self, snapshot: typing.Dict[str, typing.Any]) -> None:
raise NotImplementedError()

s = Sample(f, 'key')
f.register_type(Sample, None, Sample.handle_kind)
s.on.event.emit('foo')
Expand Down Expand Up @@ -206,7 +219,7 @@ def create_storage(self):
return ops.storage.SQLiteStorage(':memory:')


def setup_juju_backend(test_case, state_file):
def setup_juju_backend(test_case: unittest.TestCase, state_file: pathlib.Path):
"""Create fake scripts for pretending to be state-set and state-get."""
template_args = {
'executable': str(pathlib.Path(sys.executable).as_posix()),
Expand Down Expand Up @@ -305,7 +318,7 @@ def test_handles_tuples(self):
parsed = yaml.load(raw, Loader=ops.storage._SimpleLoader)
self.assertEqual(parsed, (1, 'tuple'))

def assertRefused(self, obj): # noqa: N802
def assertRefused(self, obj: typing.Any): # noqa: N802
# We shouldn't allow them to be written
with self.assertRaises(yaml.representer.RepresenterError):
yaml.dump(obj, Dumper=ops.storage._SimpleDumper)
Expand Down

0 comments on commit ecc8052

Please sign in to comment.