diff --git a/tests_metricflow/integration/test_cases/itest_granularity.yaml b/tests_metricflow/integration/test_cases/itest_granularity.yaml index a0f495e86a..7ad299077f 100644 --- a/tests_metricflow/integration/test_cases/itest_granularity.yaml +++ b/tests_metricflow/integration/test_cases/itest_granularity.yaml @@ -456,14 +456,14 @@ integration_test: model: SIMPLE_MODEL metrics: ["bookings"] where_filter: | - {{ render_time_dimension_template('listing__ds', 'martian_day') }} >= '2019-12-20' + {{ render_time_constraint(render_time_dimension_template('listing__ds', 'martian_day'), start_time="2019-12-20") }} check_query: | SELECT SUM(1) AS bookings FROM {{ source_schema }}.fct_bookings b LEFT OUTER JOIN {{ source_schema }}.dim_listings_latest l ON b.listing_id = l.listing_id LEFT OUTER JOIN {{ source_schema }}.mf_time_spine ts ON {{ render_date_trunc("l.created_at", TimeGranularity.DAY) }} = ts.ds - WHERE ts.martian_day >= '2019-12-20' + WHERE {{ render_time_constraint("ts.martian_day", start_time="2019-12-20") }} --- integration_test: name: simple_metric_with_custom_granularity_in_filter_and_group_by @@ -472,7 +472,7 @@ integration_test: metrics: ["bookings"] group_bys: ["listing__ds__martian_day"] where_filter: | - {{ render_time_dimension_template('listing__ds', 'martian_day') }} >= '2019-12-20' + {{ render_time_constraint(render_time_dimension_template('listing__ds', 'martian_day'), start_time="2019-12-20") }} check_query: | SELECT ts.martian_day AS listing__ds__martian_day @@ -481,8 +481,8 @@ integration_test: LEFT OUTER JOIN {{ source_schema }}.dim_listings_latest l ON b.listing_id = l.listing_id LEFT OUTER JOIN {{ source_schema }}.mf_time_spine ts ON {{ render_date_trunc("l.created_at", TimeGranularity.DAY) }} = ts.ds - WHERE ts.martian_day >= '2019-12-20' - GROUP BY listing__ds__martian_day + WHERE {{ render_time_constraint("ts.martian_day", start_time="2019-12-20") }} + GROUP BY ts.martian_day --- integration_test: name: test_no_metrics_with_custom_granularity_filter @@ -490,7 +490,7 @@ integration_test: model: SIMPLE_MODEL group_bys: ["metric_time__day"] where_filter: | - {{ render_time_dimension_template('listing__ds', 'martian_day') }} >= '2019-12-20' + {{ render_time_constraint(render_time_dimension_template('listing__ds', 'martian_day'), start_time="2019-12-20") }} check_query: | SELECT {{ render_date_trunc("ts.ds", TimeGranularity.DAY) }} AS metric_time__day @@ -498,8 +498,8 @@ integration_test: CROSS JOIN {{ source_schema }}.mf_time_spine ts LEFT OUTER JOIN {{ source_schema }}.mf_time_spine ts2 ON {{ render_date_trunc("l.created_at", TimeGranularity.DAY) }} = ts2.ds - WHERE ts2.martian_day >= '2019-12-20' - GROUP BY metric_time__day + WHERE {{ render_time_constraint("ts2.martian_day", start_time="2019-12-20") }} + GROUP BY {{ render_date_trunc("ts.ds", TimeGranularity.DAY) }} --- integration_test: name: test_no_metrics_with_custom_granularity_in_filter_and_group_by @@ -507,12 +507,12 @@ integration_test: model: SIMPLE_MODEL group_bys: ["listing__ds__martian_day"] where_filter: | - {{ render_time_dimension_template('listing__ds', 'martian_day') }} >= '2019-12-20' + {{ render_time_constraint(render_time_dimension_template('listing__ds', 'martian_day'), start_time="2019-12-20") }} check_query: | SELECT ts.martian_day AS listing__ds__martian_day FROM {{ source_schema }}.dim_listings_latest l LEFT OUTER JOIN {{ source_schema }}.mf_time_spine ts ON {{ render_date_trunc("l.created_at", TimeGranularity.DAY) }} = ts.ds - WHERE ts.martian_day >= '2019-12-20' - GROUP BY listing__ds__martian_day + WHERE {{ render_time_constraint("ts.martian_day", start_time="2019-12-20") }} + GROUP BY ts.martian_day diff --git a/tests_metricflow/integration/test_configured_cases.py b/tests_metricflow/integration/test_configured_cases.py index 9c700d7d6f..3c387a3b3f 100644 --- a/tests_metricflow/integration/test_configured_cases.py +++ b/tests_metricflow/integration/test_configured_cases.py @@ -16,6 +16,7 @@ from metricflow_semantics.protocols.query_parameter import DimensionOrEntityQueryParameter from metricflow_semantics.specs.query_param_implementations import DimensionOrEntityParameter, TimeDimensionParameter from metricflow_semantics.test_helpers.config_helpers import MetricFlowTestConfiguration +from metricflow_semantics.time.time_constants import ISO8601_PYTHON_FORMAT, ISO8601_PYTHON_TS_FORMAT from metricflow_semantics.time.time_spine_source import TimeSpineSource from metricflow.engine.metricflow_engine import MetricFlowQueryRequest @@ -53,20 +54,26 @@ def __init__(self, sql_client: SqlClient) -> None: # noqa: D107 def render_time_constraint( self, expr: str, - start_time: str, - stop_time: str, + start_time: Optional[str] = None, + stop_time: Optional[str] = None, ) -> str: """Render an expression like "ds >='2020-01-01' AND ds < '2020-01-02'" for start_time = stop_time = '2020-01-01'.""" - start_expr = self.cast_to_ts(f"{start_time}") - time_format = "%Y-%m-%d" if len(start_time) == 10 else "%Y-%m-%d %H:%M:%S" - stop_time_dt = datetime.datetime.strptime(stop_time, time_format) - if len(start_time) == 10: - stop_time = (stop_time_dt + datetime.timedelta(days=1)).strftime(time_format) - else: - stop_time = (stop_time_dt + datetime.timedelta(seconds=1)).strftime(time_format) - - stop_expr = self.cast_to_ts(f"{stop_time}") - return f"{self.cast_expr_to_ts(expr)} >= {start_expr} AND {self.cast_expr_to_ts(expr)} < {stop_expr}" + time_param = start_time or stop_time # needed for type checking + assert time_param, "At least one of start_time or stop_time must be provided." + time_format = ISO8601_PYTHON_FORMAT if len(time_param) == 10 else ISO8601_PYTHON_TS_FORMAT + + if start_time: + start_expr = f"{self.cast_expr_to_ts(expr)} >= {self.cast_to_ts(f'{start_time}')}" + + if stop_time: + stop_time_dt = datetime.datetime.strptime(stop_time, time_format) + if time_format == ISO8601_PYTHON_FORMAT: + stop_time = (stop_time_dt + datetime.timedelta(days=1)).strftime(time_format) + else: + stop_time = (stop_time_dt + datetime.timedelta(seconds=1)).strftime(time_format) + stop_expr = f"{self.cast_expr_to_ts(expr)} < {self.cast_to_ts(f'{stop_time}')}" + + return f"{start_expr if start_time else ''}{' AND ' if start_time and stop_time else ''}{stop_expr if stop_time else ''}" def cast_expr_to_ts(self, expr: str) -> str: """Returns the expression as a new expression cast to the timestamp type, if applicable for the DB."""