Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

test: add type hints for test_storage #1023

Merged
merged 3 commits into from
Oct 2, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@ include = ["ops/*.py", "ops/_private/*.py",
"test/test_lib.py",
"test/test_model.py",
"test/test_testing.py",
"test/test_storage.py",
"test/test_charm.py",
]
pythonVersion = "3.8" # check no python > 3.8 features are used
Expand Down
35 changes: 24 additions & 11 deletions test/test_storage.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,19 +19,25 @@
import pathlib
import sys
import tempfile
import typing
import unittest
from test.test_helpers import BaseTestCase, fake_script, fake_script_calls
from textwrap import dedent

import yaml

import ops
import ops.storage


class StoragePermutations(abc.ABC):

assertEqual = unittest.TestCase.assertEqual # noqa
assertRaises = unittest.TestCase.assertRaises # noqa

def create_framework(self) -> ops.Framework:
"""Create a Framework that we can use to test the backend storage."""
return ops.Framework(self.create_storage(), None, None, None)
return ops.Framework(self.create_storage(), None, None, None) # type: ignore

@abc.abstractmethod
def create_storage(self) -> ops.storage.SQLiteStorage:
Expand All @@ -41,16 +47,17 @@ def create_storage(self) -> ops.storage.SQLiteStorage:
def test_save_and_load_snapshot(self):
f = self.create_framework()

class Sample(ops.Object):
class Sample(ops.StoredStateData):

def __init__(self, parent, key, content):
def __init__(self, parent: ops.Object, key: str,
content: typing.Dict[str, typing.Any]):
super().__init__(parent, key)
self.content = content

def snapshot(self):
return {'content': self.content}

def restore(self, snapshot):
def restore(self, snapshot: typing.Dict[str, typing.Any]):
self.__dict__.update(snapshot)

f.register_type(Sample, None, Sample.handle_kind)
Expand All @@ -69,37 +76,43 @@ def restore(self, snapshot):
del s
gc.collect()
res = f.load_snapshot(handle)
self.assertEqual(data, res.content)
self.assertEqual(data, res.content) # type: ignore

def test_emit_event(self):
f = self.create_framework()

class Evt(ops.EventBase):
def __init__(self, handle, content):
def __init__(self, handle: ops.Handle, content: typing.Any):
super().__init__(handle)
self.content = content

def snapshot(self):
return self.content

def restore(self, content):
def restore(self, content: typing.Any):
self.content = content

class Events(ops.ObjectEvents):
event = ops.EventSource(Evt)

class Sample(ops.Object):

on = Events()
on = Events() # type: ignore

def __init__(self, parent, key):
def __init__(self, parent: ops.Object, key: str):
super().__init__(parent, key)
self.observed_content = None
self.framework.observe(self.on.event, self._on_event)

def _on_event(self, event: Evt):
self.observed_content = event.content

def snapshot(self) -> typing.Dict[str, typing.Any]:
raise NotImplementedError()

def restore(self, snapshot: typing.Dict[str, typing.Any]) -> None:
raise NotImplementedError()

s = Sample(f, 'key')
f.register_type(Sample, None, Sample.handle_kind)
s.on.event.emit('foo')
Expand Down Expand Up @@ -206,7 +219,7 @@ def create_storage(self):
return ops.storage.SQLiteStorage(':memory:')


def setup_juju_backend(test_case, state_file):
def setup_juju_backend(test_case: unittest.TestCase, state_file: pathlib.Path):
"""Create fake scripts for pretending to be state-set and state-get."""
template_args = {
'executable': str(pathlib.Path(sys.executable).as_posix()),
Expand Down Expand Up @@ -305,7 +318,7 @@ def test_handles_tuples(self):
parsed = yaml.load(raw, Loader=ops.storage._SimpleLoader)
self.assertEqual(parsed, (1, 'tuple'))

def assertRefused(self, obj): # noqa: N802
def assertRefused(self, obj: typing.Any): # noqa: N802
# We shouldn't allow them to be written
with self.assertRaises(yaml.representer.RepresenterError):
yaml.dump(obj, Dumper=ops.storage._SimpleDumper)
Expand Down