From 8637c35c2d4fe0d51d5076741af33b3083d9f497 Mon Sep 17 00:00:00 2001 From: Courtney Holcomb Date: Tue, 20 Aug 2024 17:05:58 -0700 Subject: [PATCH] Add custom granularities to TimeSpineSource and rename other fields --- .../time/time_spine_source.py | 25 +++++++++++-------- metricflow/dataflow/builder/source_node.py | 2 +- metricflow/dataset/convert_semantic_model.py | 4 +-- metricflow/plan_conversion/dataflow_to_sql.py | 6 ++--- .../fixtures/dataflow_fixtures.py | 6 ++--- tests_metricflow/fixtures/table_fixtures.py | 2 +- .../plan_conversion/test_time_spine.py | 4 +-- 7 files changed, 27 insertions(+), 22 deletions(-) diff --git a/metricflow-semantics/metricflow_semantics/time/time_spine_source.py b/metricflow-semantics/metricflow_semantics/time/time_spine_source.py index 82f1a97acc..27655a5c00 100644 --- a/metricflow-semantics/metricflow_semantics/time/time_spine_source.py +++ b/metricflow-semantics/metricflow_semantics/time/time_spine_source.py @@ -2,7 +2,7 @@ import logging from dataclasses import dataclass -from typing import Dict, Optional +from typing import Dict, Optional, Sequence from dbt_semantic_interfaces.protocols import SemanticManifest from dbt_semantic_interfaces.type_enums.time_granularity import TimeGranularity @@ -17,15 +17,19 @@ @dataclass(frozen=True) class TimeSpineSource: - """Defines a source table containing all timestamps to use for computing cumulative metrics.""" + """A calendar table. Should contain at least one column with dates/times that map to a standard granularity. + + Dates should be contiguous. May also contain custom granularity columns. + """ schema_name: str table_name: str = "mf_time_spine" - # Name of the column in the table that contains the dates. - time_column_name: str = "ds" - # The time granularity of the dates in the spine table. - time_column_granularity: TimeGranularity = DEFAULT_TIME_GRANULARITY + # Name of the column in the table that contains date/time values that map to a standard granularity. + base_column: str = "ds" + # The time granularity of the base column. + base_granularity: TimeGranularity = DEFAULT_TIME_GRANULARITY db_name: Optional[str] = None + custom_granularity_columns: Sequence[str] = () @property def spine_table(self) -> SqlTable: @@ -40,8 +44,9 @@ def create_from_manifest(semantic_manifest: SemanticManifest) -> Dict[TimeGranul schema_name=time_spine.node_relation.schema_name, table_name=time_spine.node_relation.alias, db_name=time_spine.node_relation.database, - time_column_name=time_spine.primary_column.name, - time_column_granularity=time_spine.primary_column.time_granularity, + base_column=time_spine.primary_column.name, + base_granularity=time_spine.primary_column.time_granularity, + custom_granularity_columns=[column.name for column in time_spine.custom_granularity_columns], ) for time_spine in semantic_manifest.project_configuration.time_spines } @@ -56,8 +61,8 @@ def create_from_manifest(semantic_manifest: SemanticManifest) -> Dict[TimeGranul schema_name=time_spine_table.schema_name, table_name=time_spine_table.table_name, db_name=time_spine_table.db_name, - time_column_name=legacy_time_spine.column_name, - time_column_granularity=legacy_time_spine.grain, + base_column=legacy_time_spine.column_name, + base_granularity=legacy_time_spine.grain, ) # Sanity check: this should have been validated during manifest parsing. diff --git a/metricflow/dataflow/builder/source_node.py b/metricflow/dataflow/builder/source_node.py index ba48632eb5..cf1fbcdd1c 100644 --- a/metricflow/dataflow/builder/source_node.py +++ b/metricflow/dataflow/builder/source_node.py @@ -66,7 +66,7 @@ def __init__( # noqa: D107 data_set = data_set_converter.build_time_spine_source_data_set(time_spine_source) self._time_spine_source_nodes[granularity] = MetricTimeDimensionTransformNode.create( parent_node=ReadSqlSourceNode.create(data_set), - aggregation_time_dimension_reference=TimeDimensionReference(time_spine_source.time_column_name), + aggregation_time_dimension_reference=TimeDimensionReference(time_spine_source.base_column), ) self._query_parser = MetricFlowQueryParser(semantic_manifest_lookup) diff --git a/metricflow/dataset/convert_semantic_model.py b/metricflow/dataset/convert_semantic_model.py index 3304aefce3..e1de4b1258 100644 --- a/metricflow/dataset/convert_semantic_model.py +++ b/metricflow/dataset/convert_semantic_model.py @@ -516,8 +516,8 @@ def create_sql_source_data_set(self, semantic_model: SemanticModel) -> SemanticM def build_time_spine_source_data_set(self, time_spine_source: TimeSpineSource) -> SqlDataSet: """Build data set for time spine.""" from_source_alias = SequentialIdGenerator.create_next_id(StaticIdPrefix.TIME_SPINE_SOURCE).str_value - defined_time_granularity = time_spine_source.time_column_granularity - time_column_name = time_spine_source.time_column_name + defined_time_granularity = time_spine_source.base_granularity + time_column_name = time_spine_source.base_column time_dimension_instances: List[TimeDimensionInstance] = [] select_columns: List[SqlSelectColumn] = [] diff --git a/metricflow/plan_conversion/dataflow_to_sql.py b/metricflow/plan_conversion/dataflow_to_sql.py index 63592b54f7..93aed77ec1 100644 --- a/metricflow/plan_conversion/dataflow_to_sql.py +++ b/metricflow/plan_conversion/dataflow_to_sql.py @@ -275,7 +275,7 @@ def _make_time_spine_data_set( time_spine_source = self._choose_time_spine_source(agg_time_dimension_instances) column_expr = SqlColumnReferenceExpression.from_table_and_column_names( - table_alias=time_spine_table_alias, column_name=time_spine_source.time_column_name + table_alias=time_spine_table_alias, column_name=time_spine_source.base_column ) select_columns: Tuple[SqlSelectColumn, ...] = () apply_group_by = False @@ -283,7 +283,7 @@ def _make_time_spine_data_set( column_alias = self.column_association_resolver.resolve_spec(agg_time_dimension_instance.spec).column_name # If the requested granularity is the same as the granularity of the spine, do a direct select. # TODO: also handle date part. - if agg_time_dimension_instance.spec.time_granularity == time_spine_source.time_column_granularity: + if agg_time_dimension_instance.spec.time_granularity == time_spine_source.base_granularity: select_columns += (SqlSelectColumn(expr=column_expr, column_alias=column_alias),) # If any columns have a different granularity, apply a DATE_TRUNC() and aggregate via group_by. else: @@ -308,7 +308,7 @@ def _make_time_spine_data_set( where=( _make_time_range_comparison_expr( table_alias=time_spine_table_alias, - column_alias=time_spine_source.time_column_name, + column_alias=time_spine_source.base_column, time_range_constraint=time_range_constraint, ) if time_range_constraint diff --git a/tests_metricflow/fixtures/dataflow_fixtures.py b/tests_metricflow/fixtures/dataflow_fixtures.py index 571d86972d..e47a4af09a 100644 --- a/tests_metricflow/fixtures/dataflow_fixtures.py +++ b/tests_metricflow/fixtures/dataflow_fixtures.py @@ -96,7 +96,7 @@ def time_spine_sources( # noqa: D103 ) -> Mapping[TimeGranularity, TimeSpineSource]: legacy_time_spine_grain = TimeGranularity.DAY time_spine_base_table_name = "mf_time_spine" - print("expected schema name:", mf_test_configuration.mf_source_schema) + # Legacy time spine time_spine_sources = { legacy_time_spine_grain: TimeSpineSource( @@ -113,8 +113,8 @@ def time_spine_sources( # noqa: D103 time_spine_sources[granularity] = TimeSpineSource( schema_name=mf_test_configuration.mf_source_schema, table_name=f"{time_spine_base_table_name}_{granularity.value}", - time_column_name="ts", - time_column_granularity=granularity, + base_column="ts", + base_granularity=granularity, ) return time_spine_sources diff --git a/tests_metricflow/fixtures/table_fixtures.py b/tests_metricflow/fixtures/table_fixtures.py index 11f7f7f19d..630b3fbbe4 100644 --- a/tests_metricflow/fixtures/table_fixtures.py +++ b/tests_metricflow/fixtures/table_fixtures.py @@ -64,7 +64,7 @@ def check_time_spine_source( assert len(time_spine_snapshot.column_definitions) == 1 time_column = time_spine_snapshot.column_definitions[0] - assert time_column.name == time_spine_source.time_column_name + assert time_column.name == time_spine_source.base_column @pytest.fixture(scope="session") diff --git a/tests_metricflow/plan_conversion/test_time_spine.py b/tests_metricflow/plan_conversion/test_time_spine.py index 654058da6a..178506aa39 100644 --- a/tests_metricflow/plan_conversion/test_time_spine.py +++ b/tests_metricflow/plan_conversion/test_time_spine.py @@ -22,8 +22,8 @@ def test_date_spine_date_range( # noqa: D103 textwrap.dedent( f"""\ SELECT - MIN({time_spine_source.time_column_name}) - , MAX({time_spine_source.time_column_name}) + MIN({time_spine_source.base_column}) + , MAX({time_spine_source.base_column}) FROM {time_spine_source.spine_table.sql} """, )