-
Notifications
You must be signed in to change notification settings - Fork 97
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Add tests that compare the output of the time adjustment implemetnati…
…ons.
- Loading branch information
Showing
1 changed file
with
159 additions
and
0 deletions.
There are no files selected for viewing
159 changes: 159 additions & 0 deletions
159
metricflow-semantics/tests_metricflow_semantics/time/test_time_adjuster.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,159 @@ | ||
from __future__ import annotations | ||
|
||
import datetime | ||
import logging | ||
from typing import Dict, List, Sequence, Tuple | ||
|
||
import pytest | ||
import tabulate | ||
from _pytest.fixtures import FixtureRequest | ||
from dbt_semantic_interfaces.test_utils import as_datetime | ||
from dbt_semantic_interfaces.type_enums import TimeGranularity | ||
from metricflow_semantics.filters.time_constraint import TimeRangeConstraint | ||
from metricflow_semantics.test_helpers.config_helpers import MetricFlowTestConfiguration | ||
from metricflow_semantics.test_helpers.snapshot_helpers import assert_str_snapshot_equal | ||
from metricflow_semantics.time.dateutil_adjuster import DateutilTimePeriodAdjuster | ||
from metricflow_semantics.time.pandas_adjuster import PandasTimePeriodAdjuster | ||
|
||
logger = logging.getLogger(__name__) | ||
|
||
|
||
@pytest.fixture(scope="session") | ||
def date_times_to_check() -> Sequence[datetime.datetime]: # noqa: D103 | ||
date_times = [] | ||
# Cover leap years, non-leap years | ||
# 1900 was tested to work, though that requires a change to `TimeRangeConstraint.ALL_TIME_BEGIN` | ||
for year in (2000, 2000, 2021): | ||
start_date_time = datetime.datetime(year=year, month=1, day=1) | ||
end_date_time = datetime.datetime(year=year, month=12, day=31) | ||
current_date_time = start_date_time | ||
while True: | ||
date_times.append(current_date_time) | ||
current_date_time += datetime.timedelta(days=1) | ||
if current_date_time == end_date_time: | ||
break | ||
return date_times | ||
|
||
|
||
@pytest.fixture(scope="session") | ||
def grain_to_count_in_year() -> Dict[TimeGranularity, int]: | ||
"""Returns the maximum number of times the given item occurs in a year.""" | ||
return { | ||
TimeGranularity.DAY: 366, | ||
TimeGranularity.WEEK: 53, | ||
TimeGranularity.MONTH: 31, | ||
TimeGranularity.QUARTER: 4, | ||
TimeGranularity.YEAR: 1, | ||
} | ||
|
||
|
||
def test_start_and_end_periods( # noqa: D103 | ||
request: FixtureRequest, | ||
mf_test_configuration: MetricFlowTestConfiguration, | ||
date_times_to_check: Sequence[datetime.datetime], | ||
) -> None: | ||
pandas_adjuster = PandasTimePeriodAdjuster() | ||
dateutil_adjuster = DateutilTimePeriodAdjuster() | ||
|
||
rows: List[Tuple[str, ...]] = [] | ||
for date_time in date_times_to_check: | ||
for time_granularity in TimeGranularity: | ||
# Pandas implementation of `adjust_to_start_of_period` doesn't support DAY. | ||
if time_granularity == TimeGranularity.DAY: | ||
pandas_start_of_period = None | ||
pandas_end_of_period = None | ||
else: | ||
pandas_start_of_period = pandas_adjuster.adjust_to_start_of_period(time_granularity, date_time) | ||
pandas_end_of_period = pandas_adjuster.adjust_to_end_of_period(time_granularity, date_time) | ||
dateutil_start_of_period = dateutil_adjuster.adjust_to_start_of_period(time_granularity, date_time) | ||
dateutil_end_of_period = dateutil_adjuster.adjust_to_end_of_period(time_granularity, date_time) | ||
assert ( | ||
pandas_start_of_period or dateutil_start_of_period | ||
) == dateutil_start_of_period, f"start-of-period mismatch: {date_time.isoformat()} {time_granularity}" | ||
assert ( | ||
pandas_end_of_period or dateutil_end_of_period | ||
) == dateutil_end_of_period, f"end-of-period mismatch: {date_time.isoformat()} {time_granularity}" | ||
rows.append( | ||
( | ||
date_time.isoformat(), | ||
time_granularity.name, | ||
dateutil_start_of_period.isoformat(), | ||
dateutil_end_of_period.isoformat(), | ||
) | ||
) | ||
assert_str_snapshot_equal( | ||
request=request, | ||
mf_test_configuration=mf_test_configuration, | ||
snapshot_id="results", | ||
snapshot_str=tabulate.tabulate(rows, headers=["Date", "Grain", "Period Start", "Period End"]), | ||
) | ||
|
||
|
||
def test_expand_time_constraint_to_fill_granularity( # noqa: D103 | ||
date_times_to_check: Sequence[datetime.datetime], grain_to_count_in_year: Dict[TimeGranularity, int] | ||
) -> None: | ||
pandas_adjuster = PandasTimePeriodAdjuster() | ||
dateutil_adjuster = DateutilTimePeriodAdjuster() | ||
|
||
test_cases = tuple( | ||
(start_time, end_time, time_granularity) | ||
for start_time in date_times_to_check | ||
for time_granularity in TimeGranularity | ||
for end_time in ( | ||
start_time + datetime.timedelta(days=day_offset) | ||
# Add 2 to cross time grain boundaries. | ||
for day_offset in range(grain_to_count_in_year[time_granularity] + 2) | ||
) | ||
) | ||
|
||
test_case_count = len(test_cases) | ||
logger.info(f"There are {test_case_count} test cases") | ||
|
||
finished_count = 0 | ||
|
||
for start_time, end_time, time_granularity in test_cases: | ||
time_constraint = TimeRangeConstraint(start_time=start_time, end_time=end_time) | ||
pandas_adjuster_result = pandas_adjuster.expand_time_constraint_to_fill_granularity( | ||
time_constraint, time_granularity | ||
) | ||
|
||
dateutil_adjuster_result = dateutil_adjuster.expand_time_constraint_to_fill_granularity( | ||
time_constraint, time_granularity | ||
) | ||
|
||
assert ( | ||
pandas_adjuster_result == dateutil_adjuster_result | ||
), f"Expansion mismatch: {pandas_adjuster_result=} {dateutil_adjuster_result=} {time_granularity=}" | ||
finished_count += 1 | ||
if finished_count % 100000 == 0 or finished_count == test_case_count: | ||
logger.info(f"Progress {finished_count / test_case_count * 100:.0f}%") | ||
|
||
|
||
def test_expand_time_constraint_for_cumulative_metric( # noqa: D103 | ||
grain_to_count_in_year: Dict[TimeGranularity, int] | ||
) -> None: | ||
pandas_adjuster = PandasTimePeriodAdjuster() | ||
dateutil_adjuster = DateutilTimePeriodAdjuster() | ||
|
||
test_cases = tuple( | ||
(as_datetime("2020-01-01"), time_granularity, count) | ||
for time_granularity in TimeGranularity | ||
for count in (range(grain_to_count_in_year[time_granularity] + 2)) | ||
) | ||
|
||
test_case_count = len(test_cases) | ||
logger.info(f"There are {test_case_count} test cases") | ||
|
||
for start_time, time_granularity, count in test_cases: | ||
time_constraint = TimeRangeConstraint(start_time=start_time, end_time=start_time) | ||
pandas_adjuster_result = pandas_adjuster.expand_time_constraint_for_cumulative_metric( | ||
time_constraint, time_granularity, count | ||
) | ||
|
||
dateutil_adjuster_result = dateutil_adjuster.expand_time_constraint_for_cumulative_metric( | ||
time_constraint, time_granularity, count | ||
) | ||
|
||
assert ( | ||
pandas_adjuster_result == dateutil_adjuster_result | ||
), f"Expansion mismatch: {start_time=}, {time_granularity=}, {count=}" |