diff --git a/src/syrupy/assertion.py b/src/syrupy/assertion.py index 35c301dd..7013514b 100644 --- a/src/syrupy/assertion.py +++ b/src/syrupy/assertion.py @@ -23,7 +23,13 @@ from .extensions.amber.serializer import Repr if TYPE_CHECKING: - from .extensions.base import AbstractSyrupyExtension + from .extensions.base import ( + AbstractSyrupyExtension, + SnapshotCollectionStorage, + SnapshotComparator, + SnapshotReporter, + SnapshotSerializer, + ) from .location import PyTestLocation from .session import SnapshotSession from .types import ( @@ -91,6 +97,19 @@ class SnapshotAssertion: init=False, default_factory=list, ) + _serializer: Optional["SnapshotSerializer"] = field(init=False, default=None) + _storage: Optional["SnapshotCollectionStorage"] = field( + init=False, + default=None, + ) + _comparator: Optional["SnapshotComparator"] = field( + init=False, + default=None, + ) + _reporter: Optional["SnapshotReporter"] = field( + init=False, + default=None, + ) def __post_init__(self) -> None: self.session.register_request(self) @@ -106,6 +125,22 @@ def extension(self) -> "AbstractSyrupyExtension": self._extension = self.__init_extension(self.extension_class) return self._extension + @property + def serializer(self) -> "SnapshotSerializer": + return self._serializer or self.extension + + @property + def storage(self) -> "SnapshotCollectionStorage": + return self._storage or self.extension + + @property + def comparator(self) -> "SnapshotComparator": + return self._comparator or self.extension + + @property + def reporter(self) -> "SnapshotReporter": + return self._reporter or self.extension + @property def num_executions(self) -> int: return int(self._executions) @@ -178,7 +213,7 @@ def assert_match(self, data: "SerializableData") -> None: assert self == data def _serialize(self, data: "SerializableData") -> "SerializedData": - return self.extension.serialize( + return self.serializer.serialize( data, exclude=self._exclude, matcher=self.__matcher ) @@ -214,7 +249,7 @@ def get_assert_diff(self) -> List[str]: ) ) if not assertion_result.success: - diff.extend(self.extension.diff_lines(serialized_data, snapshot_data or "")) + diff.extend(self.reporter.diff_lines(serialized_data, snapshot_data or "")) return diff def __with_prop(self, prop_name: str, prop_value: Any) -> None: @@ -229,6 +264,10 @@ def __call__( extension_class: Optional[Type["AbstractSyrupyExtension"]] = None, matcher: Optional["PropertyMatcher"] = None, name: Optional["SnapshotIndex"] = None, + serializer: Optional["SnapshotSerializer"] = None, + storage: Optional["SnapshotCollectionStorage"] = None, + reporter: Optional["SnapshotReporter"] = None, + comparator: Optional["SnapshotComparator"] = None, ) -> "SnapshotAssertion": """ Modifies assertion instance options @@ -243,6 +282,14 @@ def __call__( self.__with_prop("_custom_index", name) if diff is not None: self.__with_prop("_snapshot_diff", diff) + if serializer is not None: + self.__with_prop("_serializer", serializer) + if storage is not None: + self.__with_prop("_storage", storage) + if reporter is not None: + self.__with_prop("_reporter", reporter) + if comparator is not None: + self.__with_prop("_comparator", comparator) return self def __repr__(self) -> str: @@ -252,10 +299,10 @@ def __eq__(self, other: "SerializableData") -> bool: return self._assert(other) def _assert(self, data: "SerializableData") -> bool: - snapshot_location = self.extension.get_location( + snapshot_location = self.storage.get_location( test_location=self.test_location, index=self.index ) - snapshot_name = self.extension.get_snapshot_name( + snapshot_name = self.storage.get_snapshot_name( test_location=self.test_location, index=self.index ) snapshot_data: Optional["SerializedData"] = None @@ -271,14 +318,14 @@ def _assert(self, data: "SerializableData") -> bool: snapshot_data_diff, _ = self._recall_data(index=snapshot_diff) if snapshot_data_diff is None: raise SnapshotDoesNotExist() - serialized_data = self.extension.diff_snapshots( + serialized_data = self.reporter.diff_snapshots( serialized_data=serialized_data, snapshot_data=snapshot_data_diff, ) matches = ( not tainted and snapshot_data is not None - and self.extension.matches( + and self.comparator.matches( serialized_data=serialized_data, snapshot_data=snapshot_data ) ) @@ -286,7 +333,7 @@ def _assert(self, data: "SerializableData") -> bool: if not matches: if self.update_snapshots: self.session.queue_snapshot_write( - extension=self.extension, + storage=self.storage, test_location=self.test_location, data=serialized_data, index=self.index, @@ -327,7 +374,7 @@ def _recall_data( ) -> Tuple[Optional["SerializableData"], bool]: try: return ( - self.extension.read_snapshot( + self.storage.read_snapshot( test_location=self.test_location, index=index, session_id=str(id(self.session)), diff --git a/src/syrupy/session.py b/src/syrupy/session.py index 6b612145..11cb65a4 100644 --- a/src/syrupy/session.py +++ b/src/syrupy/session.py @@ -14,7 +14,6 @@ Optional, Set, Tuple, - Type, ) import pytest @@ -34,7 +33,7 @@ if TYPE_CHECKING: from .assertion import SnapshotAssertion - from .extensions.base import AbstractSyrupyExtension + from .extensions.base import SnapshotCollectionStorage @dataclass @@ -47,28 +46,28 @@ class SnapshotSession: # All the selected test items. Will be set to False until the test item is run. _selected_items: Dict[str, bool] = field(default_factory=dict) _assertions: List["SnapshotAssertion"] = field(default_factory=list) - _extensions: Dict[str, "AbstractSyrupyExtension"] = field(default_factory=dict) + _extensions: Dict[str, "SnapshotCollectionStorage"] = field(default_factory=dict) _locations_discovered: DefaultDict[str, Set[Any]] = field( default_factory=lambda: defaultdict(set) ) _queued_snapshot_writes: Dict[ - Tuple[Type["AbstractSyrupyExtension"], str], + Tuple["SnapshotCollectionStorage", str], List[Tuple["SerializedData", "PyTestLocation", "SnapshotIndex"]], ] = field(default_factory=dict) def queue_snapshot_write( self, - extension: "AbstractSyrupyExtension", + storage: "SnapshotCollectionStorage", test_location: "PyTestLocation", data: "SerializedData", index: "SnapshotIndex", ) -> None: - snapshot_location = extension.get_location( + snapshot_location = storage.get_location( test_location=test_location, index=index ) - key = (extension.__class__, snapshot_location) + key = (storage, snapshot_location) queue = self._queued_snapshot_writes.get(key, []) queue.append((data, test_location, index)) self._queued_snapshot_writes[key] = queue @@ -147,12 +146,12 @@ def register_request(self, assertion: "SnapshotAssertion") -> None: self._assertions.append(assertion) test_location = assertion.test_location.filepath - extension_class = assertion.extension.__class__ - if extension_class not in self._locations_discovered[test_location]: - self._locations_discovered[test_location].add(extension_class) + storage_class = assertion.storage.__class__ + if storage_class not in self._locations_discovered[test_location]: + self._locations_discovered[test_location].add(storage_class) discovered_extensions = { - discovered.location: assertion.extension - for discovered in assertion.extension.discover_snapshots( + discovered.location: assertion.storage + for discovered in assertion.storage.discover_snapshots( test_location=assertion.test_location ) if discovered.has_snapshots diff --git a/tests/syrupy/extensions/__snapshots__/test_single_file_amber/test_single_file_amber.1.raw b/tests/syrupy/extensions/__snapshots__/test_single_file_amber/test_single_file_amber.1.raw new file mode 100644 index 00000000..07d85950 --- /dev/null +++ b/tests/syrupy/extensions/__snapshots__/test_single_file_amber/test_single_file_amber.1.raw @@ -0,0 +1,3 @@ +dict({ + 'fruit': 'orange', +}) \ No newline at end of file diff --git a/tests/syrupy/extensions/__snapshots__/test_single_file_amber/test_single_file_amber.raw b/tests/syrupy/extensions/__snapshots__/test_single_file_amber/test_single_file_amber.raw new file mode 100644 index 00000000..4b648631 --- /dev/null +++ b/tests/syrupy/extensions/__snapshots__/test_single_file_amber/test_single_file_amber.raw @@ -0,0 +1,3 @@ +dict({ + 'fruit': 'apple', +}) \ No newline at end of file diff --git a/tests/syrupy/extensions/test_single_file_amber.py b/tests/syrupy/extensions/test_single_file_amber.py new file mode 100644 index 00000000..ebf22ede --- /dev/null +++ b/tests/syrupy/extensions/test_single_file_amber.py @@ -0,0 +1,16 @@ +from syrupy.extensions.amber import AmberSnapshotExtension +from syrupy.extensions.single_file import ( + SingleFileSnapshotExtension, + WriteMode, +) + + +class SingleTextFileExtension(SingleFileSnapshotExtension): + _write_mode = WriteMode.TEXT + + +def test_single_file_amber(snapshot): + storage = SingleTextFileExtension() + serializer = AmberSnapshotExtension() + assert {"fruit": "apple"} == snapshot(storage=storage, serializer=serializer) + assert {"fruit": "orange"} == snapshot(storage=storage, serializer=serializer)