Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Support expectations in snapshot files #1513

Merged
merged 1 commit into from
Nov 10, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,9 @@

logger = logging.getLogger(__name__)

# In the snapshot file, include a header that describe what should be observed in the snapshot.
SNAPSHOT_EXPECTATION_DESCRIPTION = "expectation_description"


@dataclass(frozen=True)
class SnapshotConfiguration:
Expand All @@ -50,6 +53,7 @@ def assert_snapshot_text_equal(
incomparable_strings_replacement_function: Optional[Callable[[str], str]] = None,
additional_sub_directories_for_snapshots: Tuple[str, ...] = (),
additional_header_fields: Optional[Mapping[str, str]] = None,
expectation_description: Optional[str] = None,
) -> None:
"""Similar to assert_plan_snapshot_text_equal(), but with more controls on how the snapshot paths are generated."""
file_path = (
Expand Down Expand Up @@ -81,6 +85,9 @@ def assert_snapshot_text_equal(
if additional_header_fields is not None:
for header_field_name, header_field_value in additional_header_fields.items():
header_lines.append(f"{header_field_name}: {header_field_value}")
if expectation_description is not None:
header_lines.append(f"{SNAPSHOT_EXPECTATION_DESCRIPTION}:")
header_lines.append(indent(expectation_description))
header_lines.append("---")

snapshot_text = "\n".join(header_lines) + "\n" + snapshot_text
Expand Down Expand Up @@ -250,6 +257,7 @@ def assert_plan_snapshot_text_equal(
exclude_line_regex: Optional[str] = None,
incomparable_strings_replacement_function: Optional[Callable[[str], str]] = None,
additional_sub_directories_for_snapshots: Tuple[str, ...] = (),
expectation_description: Optional[str] = None,
) -> None:
"""Checks if the given plan text is equal to the one that's saved for comparison.

Expand All @@ -272,6 +280,7 @@ def assert_plan_snapshot_text_equal(
exclude_line_regex=exclude_line_regex,
incomparable_strings_replacement_function=incomparable_strings_replacement_function,
additional_sub_directories_for_snapshots=additional_sub_directories_for_snapshots,
expectation_description=expectation_description,
)


Expand All @@ -280,6 +289,7 @@ def assert_linkable_element_set_snapshot_equal( # noqa: D103
mf_test_configuration: SnapshotConfiguration,
set_id: str,
linkable_element_set: LinkableElementSet,
expectation_description: Optional[str] = None,
) -> None:
headers = ("Model Join-Path", "Entity Links", "Name", "Time Granularity", "Date Part", "Properties")
rows = []
Expand Down Expand Up @@ -350,17 +360,23 @@ def assert_linkable_element_set_snapshot_equal( # noqa: D103
mf_test_configuration=mf_test_configuration,
snapshot_id=set_id,
snapshot_str=tabulate.tabulate(headers=headers, tabular_data=sorted(rows)),
expectation_description=expectation_description,
)


def assert_spec_set_snapshot_equal( # noqa: D103
request: FixtureRequest, mf_test_configuration: SnapshotConfiguration, set_id: str, spec_set: InstanceSpecSet
request: FixtureRequest,
mf_test_configuration: SnapshotConfiguration,
set_id: str,
spec_set: InstanceSpecSet,
expectation_description: Optional[str] = None,
) -> None:
assert_object_snapshot_equal(
request=request,
mf_test_configuration=mf_test_configuration,
obj_id=set_id,
obj=sorted(spec.qualified_name for spec in spec_set.all_specs),
expectation_description=expectation_description,
)


Expand All @@ -369,6 +385,7 @@ def assert_linkable_spec_set_snapshot_equal( # noqa: D103
mf_test_configuration: SnapshotConfiguration,
set_id: str,
spec_set: LinkableSpecSet,
expectation_description: Optional[str] = None,
) -> None:
naming_scheme = ObjectBuilderNamingScheme()
assert_snapshot_text_equal(
Expand All @@ -379,6 +396,7 @@ def assert_linkable_spec_set_snapshot_equal( # noqa: D103
snapshot_text=mf_pformat(sorted(naming_scheme.input_str(spec) for spec in spec_set.as_tuple)),
snapshot_file_extension=".txt",
additional_sub_directories_for_snapshots=(),
expectation_description=expectation_description,
)


Expand All @@ -387,6 +405,7 @@ def assert_object_snapshot_equal( # type: ignore[misc]
mf_test_configuration: SnapshotConfiguration,
obj: Any,
obj_id: str = "result",
expectation_description: Optional[str] = None,
) -> None:
"""For tests to compare large objects, this can be used to snapshot a text representation of the object."""
assert_snapshot_text_equal(
Expand All @@ -396,6 +415,7 @@ def assert_object_snapshot_equal( # type: ignore[misc]
snapshot_id=obj_id,
snapshot_text=mf_pformat(obj),
snapshot_file_extension=".txt",
expectation_description=expectation_description,
)


Expand All @@ -404,6 +424,7 @@ def assert_str_snapshot_equal( # noqa: D103
mf_test_configuration: SnapshotConfiguration,
snapshot_id: str,
snapshot_str: str,
expectation_description: Optional[str] = None,
) -> None:
"""Write / compare a string snapshot."""
assert_snapshot_text_equal(
Expand All @@ -413,4 +434,5 @@ def assert_str_snapshot_equal( # noqa: D103
snapshot_id=snapshot_id,
snapshot_text=snapshot_str,
snapshot_file_extension=".txt",
expectation_description=expectation_description,
)
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
test_name: test_expectation_description
test_filename: test_snapshot.py
docstring:
Tests having a description of the expectation in a snapshot.
expectation_description:
The snapshot should show the 2 as the result.
---
1 + 1 = 2
19 changes: 19 additions & 0 deletions metricflow-semantics/tests_metricflow_semantics/test_snapshot.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
from __future__ import annotations

from _pytest.fixtures import FixtureRequest
from metricflow_semantics.test_helpers.config_helpers import MetricFlowTestConfiguration
from metricflow_semantics.test_helpers.snapshot_helpers import assert_str_snapshot_equal


def test_expectation_description(
request: FixtureRequest,
mf_test_configuration: MetricFlowTestConfiguration,
) -> None:
"""Tests having a description of the expectation in a snapshot."""
assert_str_snapshot_equal(
request=request,
mf_test_configuration=mf_test_configuration,
snapshot_id="result",
snapshot_str="1 + 1 = 2",
expectation_description="The snapshot should show the 2 as the result.",
)
10 changes: 10 additions & 0 deletions tests_metricflow/snapshot_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@ def assert_execution_plan_text_equal( # noqa: D103
mf_test_configuration: MetricFlowTestConfiguration,
sql_client: SqlClient,
execution_plan: ExecutionPlan,
expectation_description: Optional[str] = None,
) -> None:
assert_plan_snapshot_text_equal(
request=request,
Expand All @@ -44,6 +45,7 @@ def assert_execution_plan_text_equal( # noqa: D103
source_schema=mf_test_configuration.mf_source_schema,
),
additional_sub_directories_for_snapshots=(sql_client.sql_engine_type.value,),
expectation_description=expectation_description,
)


Expand All @@ -52,6 +54,7 @@ def assert_dataflow_plan_text_equal( # noqa: D103
mf_test_configuration: MetricFlowTestConfiguration,
dataflow_plan: DataflowPlan,
sql_client: SqlClient,
expectation_description: Optional[str] = None,
) -> None:
assert_plan_snapshot_text_equal(
request=request,
Expand All @@ -60,6 +63,7 @@ def assert_dataflow_plan_text_equal( # noqa: D103
plan_snapshot_text=dataflow_plan.structure_text(),
incomparable_strings_replacement_function=replace_dataset_id_hash,
additional_sub_directories_for_snapshots=(sql_client.sql_engine_type.value,),
expectation_description=expectation_description,
)


Expand All @@ -69,6 +73,7 @@ def assert_object_snapshot_equal( # type: ignore[misc]
obj_id: str,
obj: Any,
sql_client: Optional[SqlClient] = None,
expectation_description: Optional[str] = None,
) -> None:
"""For tests to compare large objects, this can be used to snapshot a text representation of the object."""
if sql_client is not None:
Expand All @@ -82,6 +87,7 @@ def assert_object_snapshot_equal( # type: ignore[misc]
snapshot_text=mf_pformat(obj),
snapshot_file_extension=".txt",
additional_sub_directories_for_snapshots=(sql_client.sql_engine_type.value,) if sql_client else (),
expectation_description=expectation_description,
)


Expand All @@ -91,6 +97,7 @@ def assert_sql_snapshot_equal(
snapshot_id: str,
sql: str,
sql_engine: Optional[SqlEngine] = None,
expectation_description: Optional[str] = None,
) -> None:
"""For tests that generate SQL, use this to write / check snapshots."""
if sql_engine is not None:
Expand All @@ -109,6 +116,7 @@ def assert_sql_snapshot_equal(
exclude_line_regex=_EXCLUDE_TABLE_ALIAS_REGEX,
additional_sub_directories_for_snapshots=(sql_engine.value,) if sql_engine is not None else (),
additional_header_fields={SQL_ENGINE_HEADER_NAME: sql_engine.value} if sql_engine is not None else None,
expectation_description=expectation_description,
)


Expand All @@ -118,6 +126,7 @@ def assert_str_snapshot_equal( # type: ignore[misc]
snapshot_id: str,
snapshot_str: str,
sql_engine: Optional[SqlEngine] = None,
expectation_description: Optional[str] = None,
) -> None:
"""Write / compare a string snapshot."""
if sql_engine is not None:
Expand All @@ -131,4 +140,5 @@ def assert_str_snapshot_equal( # type: ignore[misc]
snapshot_text=snapshot_str,
snapshot_file_extension=".txt",
additional_sub_directories_for_snapshots=(sql_engine.value,) if sql_engine is not None else (),
expectation_description=expectation_description,
)
Loading