Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Decouple different snapshot functionality #754

Draft
wants to merge 2 commits into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
65 changes: 56 additions & 9 deletions src/syrupy/assertion.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,13 @@
from .extensions.amber.serializer import Repr

if TYPE_CHECKING:
from .extensions.base import AbstractSyrupyExtension
from .extensions.base import (
AbstractSyrupyExtension,
SnapshotCollectionStorage,
SnapshotComparator,
SnapshotReporter,
SnapshotSerializer,
)
from .location import PyTestLocation
from .session import SnapshotSession
from .types import (
Expand Down Expand Up @@ -91,6 +97,19 @@ class SnapshotAssertion:
init=False,
default_factory=list,
)
_serializer: Optional["SnapshotSerializer"] = field(init=False, default=None)
_storage: Optional["SnapshotCollectionStorage"] = field(
init=False,
default=None,
)
_comparator: Optional["SnapshotComparator"] = field(
init=False,
default=None,
)
_reporter: Optional["SnapshotReporter"] = field(
init=False,
default=None,
)

def __post_init__(self) -> None:
self.session.register_request(self)
Expand All @@ -106,6 +125,22 @@ def extension(self) -> "AbstractSyrupyExtension":
self._extension = self.__init_extension(self.extension_class)
return self._extension

@property
def serializer(self) -> "SnapshotSerializer":
return self._serializer or self.extension

@property
def storage(self) -> "SnapshotCollectionStorage":
return self._storage or self.extension

@property
def comparator(self) -> "SnapshotComparator":
return self._comparator or self.extension

@property
def reporter(self) -> "SnapshotReporter":
return self._reporter or self.extension

@property
def num_executions(self) -> int:
return int(self._executions)
Expand Down Expand Up @@ -178,7 +213,7 @@ def assert_match(self, data: "SerializableData") -> None:
assert self == data

def _serialize(self, data: "SerializableData") -> "SerializedData":
return self.extension.serialize(
return self.serializer.serialize(
data, exclude=self._exclude, matcher=self.__matcher
)

Expand Down Expand Up @@ -214,7 +249,7 @@ def get_assert_diff(self) -> List[str]:
)
)
if not assertion_result.success:
diff.extend(self.extension.diff_lines(serialized_data, snapshot_data or ""))
diff.extend(self.reporter.diff_lines(serialized_data, snapshot_data or ""))
return diff

def __with_prop(self, prop_name: str, prop_value: Any) -> None:
Expand All @@ -229,6 +264,10 @@ def __call__(
extension_class: Optional[Type["AbstractSyrupyExtension"]] = None,
matcher: Optional["PropertyMatcher"] = None,
name: Optional["SnapshotIndex"] = None,
serializer: Optional["SnapshotSerializer"] = None,
storage: Optional["SnapshotCollectionStorage"] = None,
reporter: Optional["SnapshotReporter"] = None,
comparator: Optional["SnapshotComparator"] = None,
) -> "SnapshotAssertion":
"""
Modifies assertion instance options
Expand All @@ -243,6 +282,14 @@ def __call__(
self.__with_prop("_custom_index", name)
if diff is not None:
self.__with_prop("_snapshot_diff", diff)
if serializer is not None:
self.__with_prop("_serializer", serializer)
if storage is not None:
self.__with_prop("_storage", storage)
if reporter is not None:
self.__with_prop("_reporter", reporter)
if comparator is not None:
self.__with_prop("_comparator", comparator)
return self

def __repr__(self) -> str:
Expand All @@ -252,10 +299,10 @@ def __eq__(self, other: "SerializableData") -> bool:
return self._assert(other)

def _assert(self, data: "SerializableData") -> bool:
snapshot_location = self.extension.get_location(
snapshot_location = self.storage.get_location(
test_location=self.test_location, index=self.index
)
snapshot_name = self.extension.get_snapshot_name(
snapshot_name = self.storage.get_snapshot_name(
test_location=self.test_location, index=self.index
)
snapshot_data: Optional["SerializedData"] = None
Expand All @@ -271,22 +318,22 @@ def _assert(self, data: "SerializableData") -> bool:
snapshot_data_diff, _ = self._recall_data(index=snapshot_diff)
if snapshot_data_diff is None:
raise SnapshotDoesNotExist()
serialized_data = self.extension.diff_snapshots(
serialized_data = self.reporter.diff_snapshots(
serialized_data=serialized_data,
snapshot_data=snapshot_data_diff,
)
matches = (
not tainted
and snapshot_data is not None
and self.extension.matches(
and self.comparator.matches(
serialized_data=serialized_data, snapshot_data=snapshot_data
)
)
assertion_success = matches
if not matches:
if self.update_snapshots:
self.session.queue_snapshot_write(
extension=self.extension,
storage=self.storage,
test_location=self.test_location,
data=serialized_data,
index=self.index,
Expand Down Expand Up @@ -327,7 +374,7 @@ def _recall_data(
) -> Tuple[Optional["SerializableData"], bool]:
try:
return (
self.extension.read_snapshot(
self.storage.read_snapshot(
test_location=self.test_location,
index=index,
session_id=str(id(self.session)),
Expand Down
23 changes: 11 additions & 12 deletions src/syrupy/session.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,6 @@
Optional,
Set,
Tuple,
Type,
)

import pytest
Expand All @@ -34,7 +33,7 @@

if TYPE_CHECKING:
from .assertion import SnapshotAssertion
from .extensions.base import AbstractSyrupyExtension
from .extensions.base import SnapshotCollectionStorage


@dataclass
Expand All @@ -47,28 +46,28 @@ class SnapshotSession:
# All the selected test items. Will be set to False until the test item is run.
_selected_items: Dict[str, bool] = field(default_factory=dict)
_assertions: List["SnapshotAssertion"] = field(default_factory=list)
_extensions: Dict[str, "AbstractSyrupyExtension"] = field(default_factory=dict)
_extensions: Dict[str, "SnapshotCollectionStorage"] = field(default_factory=dict)

_locations_discovered: DefaultDict[str, Set[Any]] = field(
default_factory=lambda: defaultdict(set)
)

_queued_snapshot_writes: Dict[
Tuple[Type["AbstractSyrupyExtension"], str],
Tuple["SnapshotCollectionStorage", str],
List[Tuple["SerializedData", "PyTestLocation", "SnapshotIndex"]],
] = field(default_factory=dict)

def queue_snapshot_write(
self,
extension: "AbstractSyrupyExtension",
storage: "SnapshotCollectionStorage",
test_location: "PyTestLocation",
data: "SerializedData",
index: "SnapshotIndex",
) -> None:
snapshot_location = extension.get_location(
snapshot_location = storage.get_location(
test_location=test_location, index=index
)
key = (extension.__class__, snapshot_location)
key = (storage, snapshot_location)
queue = self._queued_snapshot_writes.get(key, [])
queue.append((data, test_location, index))
self._queued_snapshot_writes[key] = queue
Expand Down Expand Up @@ -147,12 +146,12 @@ def register_request(self, assertion: "SnapshotAssertion") -> None:
self._assertions.append(assertion)

test_location = assertion.test_location.filepath
extension_class = assertion.extension.__class__
if extension_class not in self._locations_discovered[test_location]:
self._locations_discovered[test_location].add(extension_class)
storage_class = assertion.storage.__class__
if storage_class not in self._locations_discovered[test_location]:
self._locations_discovered[test_location].add(storage_class)
discovered_extensions = {
discovered.location: assertion.extension
for discovered in assertion.extension.discover_snapshots(
discovered.location: assertion.storage
for discovered in assertion.storage.discover_snapshots(
test_location=assertion.test_location
)
if discovered.has_snapshots
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
dict({
'fruit': 'orange',
})
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
dict({
'fruit': 'apple',
})
16 changes: 16 additions & 0 deletions tests/syrupy/extensions/test_single_file_amber.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
from syrupy.extensions.amber import AmberSnapshotExtension
from syrupy.extensions.single_file import (
SingleFileSnapshotExtension,
WriteMode,
)


class SingleTextFileExtension(SingleFileSnapshotExtension):
_write_mode = WriteMode.TEXT


def test_single_file_amber(snapshot):
storage = SingleTextFileExtension()
serializer = AmberSnapshotExtension()
assert {"fruit": "apple"} == snapshot(storage=storage, serializer=serializer)
assert {"fruit": "orange"} == snapshot(storage=storage, serializer=serializer)