Skip to content

Commit

Permalink
/* PR_START p--misc-improvements 03 */ Add snapshot expectation descr…
Browse files Browse the repository at this point in the history
…iption.
  • Loading branch information
plypaul committed Nov 7, 2024
1 parent 63c86ce commit 33132e0
Show file tree
Hide file tree
Showing 5 changed files with 60 additions and 1 deletion.
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,9 @@

logger = logging.getLogger(__name__)

# In the snapshot file, include a header that describe what should be observed in the snapshot.
SNAPSHOT_EXPECTATION_DESCRIPTION = "expectation_description"


@dataclass(frozen=True)
class SnapshotConfiguration:
Expand All @@ -50,6 +53,7 @@ def assert_snapshot_text_equal(
incomparable_strings_replacement_function: Optional[Callable[[str], str]] = None,
additional_sub_directories_for_snapshots: Tuple[str, ...] = (),
additional_header_fields: Optional[Mapping[str, str]] = None,
expectation_description: Optional[str] = None,
) -> None:
"""Similar to assert_plan_snapshot_text_equal(), but with more controls on how the snapshot paths are generated."""
file_path = (
Expand Down Expand Up @@ -81,6 +85,9 @@ def assert_snapshot_text_equal(
if additional_header_fields is not None:
for header_field_name, header_field_value in additional_header_fields.items():
header_lines.append(f"{header_field_name}: {header_field_value}")
if expectation_description is not None:
header_lines.append(f"{SNAPSHOT_EXPECTATION_DESCRIPTION}:")
header_lines.append(indent(expectation_description))
header_lines.append("\n")

snapshot_text = "\n".join(header_lines) + snapshot_text
Expand Down Expand Up @@ -250,6 +257,7 @@ def assert_plan_snapshot_text_equal(
exclude_line_regex: Optional[str] = None,
incomparable_strings_replacement_function: Optional[Callable[[str], str]] = None,
additional_sub_directories_for_snapshots: Tuple[str, ...] = (),
expectation_description: Optional[str] = None,
) -> None:
"""Checks if the given plan text is equal to the one that's saved for comparison.
Expand All @@ -272,6 +280,7 @@ def assert_plan_snapshot_text_equal(
exclude_line_regex=exclude_line_regex,
incomparable_strings_replacement_function=incomparable_strings_replacement_function,
additional_sub_directories_for_snapshots=additional_sub_directories_for_snapshots,
expectation_description=expectation_description,
)


Expand All @@ -280,6 +289,7 @@ def assert_linkable_element_set_snapshot_equal( # noqa: D103
mf_test_configuration: SnapshotConfiguration,
set_id: str,
linkable_element_set: LinkableElementSet,
expectation_description: Optional[str] = None,
) -> None:
headers = ("Model Join-Path", "Entity Links", "Name", "Time Granularity", "Date Part", "Properties")
rows = []
Expand Down Expand Up @@ -350,17 +360,23 @@ def assert_linkable_element_set_snapshot_equal( # noqa: D103
mf_test_configuration=mf_test_configuration,
snapshot_id=set_id,
snapshot_str=tabulate.tabulate(headers=headers, tabular_data=sorted(rows)),
expectation_description=expectation_description,
)


def assert_spec_set_snapshot_equal( # noqa: D103
request: FixtureRequest, mf_test_configuration: SnapshotConfiguration, set_id: str, spec_set: InstanceSpecSet
request: FixtureRequest,
mf_test_configuration: SnapshotConfiguration,
set_id: str,
spec_set: InstanceSpecSet,
expectation_description: Optional[str] = None,
) -> None:
assert_object_snapshot_equal(
request=request,
mf_test_configuration=mf_test_configuration,
obj_id=set_id,
obj=sorted(spec.qualified_name for spec in spec_set.all_specs),
expectation_description=expectation_description,
)


Expand All @@ -369,6 +385,7 @@ def assert_linkable_spec_set_snapshot_equal( # noqa: D103
mf_test_configuration: SnapshotConfiguration,
set_id: str,
spec_set: LinkableSpecSet,
expectation_description: Optional[str] = None,
) -> None:
naming_scheme = ObjectBuilderNamingScheme()
assert_snapshot_text_equal(
Expand All @@ -379,6 +396,7 @@ def assert_linkable_spec_set_snapshot_equal( # noqa: D103
snapshot_text=mf_pformat(sorted(naming_scheme.input_str(spec) for spec in spec_set.as_tuple)),
snapshot_file_extension=".txt",
additional_sub_directories_for_snapshots=(),
expectation_description=expectation_description,
)


Expand All @@ -387,6 +405,7 @@ def assert_object_snapshot_equal( # type: ignore[misc]
mf_test_configuration: SnapshotConfiguration,
obj: Any,
obj_id: str = "result",
expectation_description: Optional[str] = None,
) -> None:
"""For tests to compare large objects, this can be used to snapshot a text representation of the object."""
assert_snapshot_text_equal(
Expand All @@ -396,6 +415,7 @@ def assert_object_snapshot_equal( # type: ignore[misc]
snapshot_id=obj_id,
snapshot_text=mf_pformat(obj),
snapshot_file_extension=".txt",
expectation_description=expectation_description,
)


Expand All @@ -404,6 +424,7 @@ def assert_str_snapshot_equal( # noqa: D103
mf_test_configuration: SnapshotConfiguration,
snapshot_id: str,
snapshot_str: str,
expectation_description: Optional[str] = None,
) -> None:
"""Write / compare a string snapshot."""
assert_snapshot_text_equal(
Expand All @@ -413,4 +434,5 @@ def assert_str_snapshot_equal( # noqa: D103
snapshot_id=snapshot_id,
snapshot_text=snapshot_str,
snapshot_file_extension=".txt",
expectation_description=expectation_description,
)
Empty file.
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
from __future__ import annotations

from _pytest.fixtures import FixtureRequest
from metricflow_semantics.test_helpers.config_helpers import MetricFlowTestConfiguration
from metricflow_semantics.test_helpers.snapshot_helpers import assert_str_snapshot_equal


def test_expectation_description(
request: FixtureRequest,
mf_test_configuration: MetricFlowTestConfiguration,
) -> None:
"""Tests having a description of the expectation in a snapshot."""
assert_str_snapshot_equal(
request=request,
mf_test_configuration=mf_test_configuration,
snapshot_id="result",
snapshot_str="1 + 1 = 2",
expectation_description="The snapshot should show the 2 as the result.",
)
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
test_name: test_expectation_description
test_filename: test_snapshot.py
docstring:
Tests having a description of the expectation in a snapshot.
expectation_description:
The snapshot should show the 2 as the result.

1 + 1 = 2
10 changes: 10 additions & 0 deletions tests_metricflow/snapshot_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@ def assert_execution_plan_text_equal( # noqa: D103
mf_test_configuration: MetricFlowTestConfiguration,
sql_client: SqlClient,
execution_plan: ExecutionPlan,
expectation_description: Optional[str] = None,
) -> None:
assert_plan_snapshot_text_equal(
request=request,
Expand All @@ -44,6 +45,7 @@ def assert_execution_plan_text_equal( # noqa: D103
source_schema=mf_test_configuration.mf_source_schema,
),
additional_sub_directories_for_snapshots=(sql_client.sql_engine_type.value,),
expectation_description=expectation_description,
)


Expand All @@ -52,6 +54,7 @@ def assert_dataflow_plan_text_equal( # noqa: D103
mf_test_configuration: MetricFlowTestConfiguration,
dataflow_plan: DataflowPlan,
sql_client: SqlClient,
expectation_description: Optional[str] = None,
) -> None:
assert_plan_snapshot_text_equal(
request=request,
Expand All @@ -60,6 +63,7 @@ def assert_dataflow_plan_text_equal( # noqa: D103
plan_snapshot_text=dataflow_plan.structure_text(),
incomparable_strings_replacement_function=replace_dataset_id_hash,
additional_sub_directories_for_snapshots=(sql_client.sql_engine_type.value,),
expectation_description=expectation_description,
)


Expand All @@ -69,6 +73,7 @@ def assert_object_snapshot_equal( # type: ignore[misc]
obj_id: str,
obj: Any,
sql_client: Optional[SqlClient] = None,
expectation_description: Optional[str] = None,
) -> None:
"""For tests to compare large objects, this can be used to snapshot a text representation of the object."""
if sql_client is not None:
Expand All @@ -82,6 +87,7 @@ def assert_object_snapshot_equal( # type: ignore[misc]
snapshot_text=mf_pformat(obj),
snapshot_file_extension=".txt",
additional_sub_directories_for_snapshots=(sql_client.sql_engine_type.value,) if sql_client else (),
expectation_description=expectation_description,
)


Expand All @@ -91,6 +97,7 @@ def assert_sql_snapshot_equal(
snapshot_id: str,
sql: str,
sql_engine: Optional[SqlEngine] = None,
expectation_description: Optional[str] = None,
) -> None:
"""For tests that generate SQL, use this to write / check snapshots."""
if sql_engine is not None:
Expand All @@ -109,6 +116,7 @@ def assert_sql_snapshot_equal(
exclude_line_regex=_EXCLUDE_TABLE_ALIAS_REGEX,
additional_sub_directories_for_snapshots=(sql_engine.value,) if sql_engine is not None else (),
additional_header_fields={SQL_ENGINE_HEADER_NAME: sql_engine.value} if sql_engine is not None else None,
expectation_description=expectation_description,
)


Expand All @@ -118,6 +126,7 @@ def assert_str_snapshot_equal( # type: ignore[misc]
snapshot_id: str,
snapshot_str: str,
sql_engine: Optional[SqlEngine] = None,
expectation_description: Optional[str] = None,
) -> None:
"""Write / compare a string snapshot."""
if sql_engine is not None:
Expand All @@ -131,4 +140,5 @@ def assert_str_snapshot_equal( # type: ignore[misc]
snapshot_text=snapshot_str,
snapshot_file_extension=".txt",
additional_sub_directories_for_snapshots=(sql_engine.value,) if sql_engine is not None else (),
expectation_description=expectation_description,
)

0 comments on commit 33132e0

Please sign in to comment.