Skip to content

Commit

Permalink
Add kwargs to AbstractSyrupyExtension classes
Browse files Browse the repository at this point in the history
  • Loading branch information
atharva-2001 committed Sep 25, 2023
1 parent 7730070 commit f5b3c4c
Showing 1 changed file with 8 additions and 6 deletions.
14 changes: 8 additions & 6 deletions src/syrupy/extensions/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,6 +67,7 @@ def serialize(
exclude: Optional["PropertyFilter"] = None,
include: Optional["PropertyFilter"] = None,
matcher: Optional["PropertyMatcher"] = None,
**kwargs
) -> "SerializedData":
"""
Serializes a python object / data structure into a string
Expand Down Expand Up @@ -108,7 +109,7 @@ def is_snapshot_location(self, *, location: str) -> bool:
return location.endswith(self._file_extension)

def discover_snapshots(
self, *, test_location: "PyTestLocation"
self, *, test_location: "PyTestLocation", **kwargs
) -> "SnapshotCollections":
"""
Returns all snapshot collections in test site
Expand Down Expand Up @@ -216,7 +217,7 @@ def delete_snapshots(

@abstractmethod
def _read_snapshot_collection(
self, *, snapshot_location: str
self, *, snapshot_location: str, **kwargs
) -> "SnapshotCollection":
"""
Read the snapshot location and construct a snapshot collection object
Expand All @@ -235,15 +236,15 @@ def _read_snapshot_data_from_location(
@classmethod
@abstractmethod
def _write_snapshot_collection(
cls, *, snapshot_collection: "SnapshotCollection"
cls, *, snapshot_collection: "SnapshotCollection", **kwargs
) -> None:
"""
Adds the snapshot data to the snapshots in collection location
"""
raise NotImplementedError

@classmethod
def dirname(cls, *, test_location: "PyTestLocation") -> str:
def dirname(cls, *, test_location: "PyTestLocation", **kwargs) -> str:
test_dir = Path(test_location.filepath).parent
return str(test_dir.joinpath(SNAPSHOT_DIRNAME))

Expand All @@ -259,15 +260,15 @@ class SnapshotReporter(ABC):
_context_line_count = 1

def diff_snapshots(
self, serialized_data: "SerializedData", snapshot_data: "SerializedData"
self, serialized_data: "SerializedData", snapshot_data: "SerializedData", **kwargs
) -> "SerializedData":
env = {DISABLE_COLOR_ENV_VAR: "true"}
attrs = {"_context_line_count": 0}
with env_context(**env), obj_attrs(self, attrs):
return "\n".join(self.diff_lines(serialized_data, snapshot_data))

def diff_lines(
self, serialized_data: "SerializedData", snapshot_data: "SerializedData"
self, serialized_data: "SerializedData", snapshot_data: "SerializedData", **kwargs
) -> Iterator[str]:
for line in self.__diff_lines(str(snapshot_data), str(serialized_data)):
yield reset(line)
Expand Down Expand Up @@ -407,6 +408,7 @@ def matches(
*,
serialized_data: "SerializableData",
snapshot_data: "SerializableData",
**kwargs
) -> bool:
"""
Compares serialized data and snapshot data and returns
Expand Down

0 comments on commit f5b3c4c

Please sign in to comment.