Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: Support passing arbitrary arguments/context to custom extensions (Issue #700) #814

Open
wants to merge 6 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions src/syrupy/assertion.py
Original file line number Diff line number Diff line change
Expand Up @@ -264,6 +264,7 @@ def __call__(
extension_class: Optional[Type["AbstractSyrupyExtension"]] = None,
matcher: Optional["PropertyMatcher"] = None,
name: Optional["SnapshotIndex"] = None,
**kwargs: Any,
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm a bit worried about name collisions by just forwarding arbitrary arguments, vs. something like:

assert "hi" == snapshot(extra={...})

I know it doesn't read as nicely, but you can always add syntactic sugar via a fixture:

@pytest.fixture
def snapshot(named_arg: str):
    return snapshot.with_defaults(extra={ "named_arg": named_arg })

def test_case(snapshot):
    assert "hi" == snapshot(named_arg="hi")

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I added the argument to both __call__ and with_defaults method. I added it to the __call__ argument because I wanted it to work with cases when I want to change the snapshot behaviour but don't want it to persist.

) -> "SnapshotAssertion":
"""
Modifies assertion instance options
Expand Down
6 changes: 4 additions & 2 deletions src/syrupy/extensions/amber/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,9 @@ def delete_snapshots(
else:
Path(snapshot_location).unlink()

def _read_snapshot_collection(self, snapshot_location: str) -> "SnapshotCollection":
def _read_snapshot_collection(
self, snapshot_location: str, **kwargs: Any
) -> "SnapshotCollection":
return self.serializer_class.read_file(snapshot_location)

@classmethod
Expand All @@ -72,7 +74,7 @@ def _read_snapshot_data_from_location(

@classmethod
def _write_snapshot_collection(
cls, *, snapshot_collection: "SnapshotCollection"
cls, *, snapshot_collection: "SnapshotCollection", **kwargs: Any
) -> None:
cls.serializer_class.write_file(snapshot_collection, merge=True)

Expand Down
21 changes: 15 additions & 6 deletions src/syrupy/extensions/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
from pathlib import Path
from typing import (
TYPE_CHECKING,
Any,
Callable,
Dict,
Iterator,
Expand Down Expand Up @@ -67,6 +68,7 @@ def serialize(
exclude: Optional["PropertyFilter"] = None,
include: Optional["PropertyFilter"] = None,
matcher: Optional["PropertyMatcher"] = None,
**kwargs: Any,
) -> "SerializedData":
"""
Serializes a python object / data structure into a string
Expand Down Expand Up @@ -108,7 +110,7 @@ def is_snapshot_location(self, *, location: str) -> bool:
return location.endswith(self._file_extension)

def discover_snapshots(
self, *, test_location: "PyTestLocation"
self, *, test_location: "PyTestLocation", **kwargs: Any
) -> "SnapshotCollections":
"""
Returns all snapshot collections in test site
Expand Down Expand Up @@ -216,7 +218,7 @@ def delete_snapshots(

@abstractmethod
def _read_snapshot_collection(
self, *, snapshot_location: str
self, *, snapshot_location: str, **kwargs: Any
) -> "SnapshotCollection":
"""
Read the snapshot location and construct a snapshot collection object
Expand All @@ -235,15 +237,15 @@ def _read_snapshot_data_from_location(
@classmethod
@abstractmethod
def _write_snapshot_collection(
cls, *, snapshot_collection: "SnapshotCollection"
cls, *, snapshot_collection: "SnapshotCollection", **kwargs: Any
) -> None:
"""
Adds the snapshot data to the snapshots in collection location
"""
raise NotImplementedError

@classmethod
def dirname(cls, *, test_location: "PyTestLocation") -> str:
def dirname(cls, *, test_location: "PyTestLocation", **kwargs: Any) -> str:
test_dir = Path(test_location.filepath).parent
return str(test_dir.joinpath(SNAPSHOT_DIRNAME))

Expand All @@ -259,15 +261,21 @@ class SnapshotReporter(ABC):
_context_line_count = 1

def diff_snapshots(
self, serialized_data: "SerializedData", snapshot_data: "SerializedData"
self,
serialized_data: "SerializedData",
snapshot_data: "SerializedData",
**kwargs: Any,
) -> "SerializedData":
env = {DISABLE_COLOR_ENV_VAR: "true"}
attrs = {"_context_line_count": 0}
with env_context(**env), obj_attrs(self, attrs):
return "\n".join(self.diff_lines(serialized_data, snapshot_data))

def diff_lines(
self, serialized_data: "SerializedData", snapshot_data: "SerializedData"
self,
serialized_data: "SerializedData",
snapshot_data: "SerializedData",
**kwargs: Any,
) -> Iterator[str]:
for line in self.__diff_lines(str(snapshot_data), str(serialized_data)):
yield reset(line)
Expand Down Expand Up @@ -407,6 +415,7 @@ def matches(
*,
serialized_data: "SerializableData",
snapshot_data: "SerializableData",
**kwargs: Any,
) -> bool:
"""
Compares serialized data and snapshot data and returns
Expand Down
1 change: 1 addition & 0 deletions src/syrupy/extensions/json/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -145,6 +145,7 @@ def serialize(
exclude: Optional["PropertyFilter"] = None,
include: Optional["PropertyFilter"] = None,
matcher: Optional["PropertyMatcher"] = None,
**kwargs: Any,
) -> "SerializedData":
data = self._filter(
data=data,
Expand Down
14 changes: 11 additions & 3 deletions src/syrupy/extensions/single_file.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
from pathlib import Path
from typing import (
TYPE_CHECKING,
Any,
Optional,
Set,
Type,
Expand Down Expand Up @@ -49,6 +50,7 @@ def serialize(
exclude: Optional["PropertyFilter"] = None,
include: Optional["PropertyFilter"] = None,
matcher: Optional["PropertyMatcher"] = None,
**kwargs: Any,
) -> "SerializedData":
return self.get_supported_dataclass()(data)

Expand All @@ -74,12 +76,15 @@ def _get_file_basename(
return cls.get_snapshot_name(test_location=test_location, index=index)

@classmethod
def dirname(cls, *, test_location: "PyTestLocation") -> str:
def dirname(cls, *, test_location: "PyTestLocation", **kwargs: Any) -> str:
original_dirname = AbstractSyrupyExtension.dirname(test_location=test_location)
return str(Path(original_dirname).joinpath(test_location.basename))

def _read_snapshot_collection(
self, *, snapshot_location: str
self,
*,
snapshot_location: str,
**kwargs: Any,
) -> "SnapshotCollection":
file_ext_len = len(self._file_extension) + 1 if self._file_extension else 0
filename_wo_ext = snapshot_location[:-file_ext_len]
Expand Down Expand Up @@ -116,7 +121,10 @@ def get_write_encoding(cls) -> Optional[str]:

@classmethod
def _write_snapshot_collection(
cls, *, snapshot_collection: "SnapshotCollection"
cls,
*,
snapshot_collection: "SnapshotCollection",
**kwargs: Any,
) -> None:
filepath, data = (
snapshot_collection.location,
Expand Down