From 910f21ac94f706dde8a000547c45f820764fc165 Mon Sep 17 00:00:00 2001 From: Courtney Holcomb Date: Thu, 5 Sep 2024 12:47:52 -0700 Subject: [PATCH] Add SQL rendering logic for custom granularities --- .../time/time_spine_source.py | 21 ++- .../nodes/join_to_custom_granularity.py | 6 + metricflow/plan_conversion/dataflow_to_sql.py | 131 ++++++++++++++++-- 3 files changed, 141 insertions(+), 17 deletions(-) diff --git a/metricflow-semantics/metricflow_semantics/time/time_spine_source.py b/metricflow-semantics/metricflow_semantics/time/time_spine_source.py index 8f7f6744fb..1985369f77 100644 --- a/metricflow-semantics/metricflow_semantics/time/time_spine_source.py +++ b/metricflow-semantics/metricflow_semantics/time/time_spine_source.py @@ -4,6 +4,7 @@ from dataclasses import dataclass from typing import Dict, Optional, Sequence +from dbt_semantic_interfaces.implementations.time_spine import PydanticTimeSpineCustomGranularityColumn from dbt_semantic_interfaces.protocols import SemanticManifest from dbt_semantic_interfaces.type_enums.time_granularity import TimeGranularity @@ -29,7 +30,7 @@ class TimeSpineSource: # The time granularity of the base column. 
base_granularity: TimeGranularity = DEFAULT_TIME_GRANULARITY db_name: Optional[str] = None - custom_granularities: Sequence[str] = () + custom_granularities: Sequence[PydanticTimeSpineCustomGranularityColumn] = () @property def spine_table(self) -> SqlTable: @@ -48,7 +49,14 @@ def build_standard_time_spine_sources( db_name=time_spine.node_relation.database, base_column=time_spine.primary_column.name, base_granularity=time_spine.primary_column.time_granularity, - custom_granularities=[column.name for column in time_spine.custom_granularities], + custom_granularities=tuple( + [ + PydanticTimeSpineCustomGranularityColumn( + name=custom_granularity.name, column_name=custom_granularity.column_name + ) + for custom_granularity in time_spine.custom_granularities + ] + ), ) for time_spine in semantic_manifest.project_configuration.time_spines } @@ -74,3 +82,12 @@ def build_standard_time_spine_sources( ) return time_spine_sources + + @staticmethod + def build_custom_time_spine_sources(time_spine_sources: Sequence[TimeSpineSource]) -> Dict[str, TimeSpineSource]: + """Creates a set of time spine sources with custom granularities based on what's in the manifest.""" + return { + custom_granularity.name: time_spine_source + for time_spine_source in time_spine_sources + for custom_granularity in time_spine_source.custom_granularities + } diff --git a/metricflow/dataflow/nodes/join_to_custom_granularity.py b/metricflow/dataflow/nodes/join_to_custom_granularity.py index 14e5008c0e..83d9a902b0 100644 --- a/metricflow/dataflow/nodes/join_to_custom_granularity.py +++ b/metricflow/dataflow/nodes/join_to_custom_granularity.py @@ -18,6 +18,12 @@ class JoinToCustomGranularityNode(DataflowPlanNode, ABC): time_dimension_spec: TimeDimensionSpec + def __post_init__(self) -> None: # noqa: D105 + assert ( + self.time_dimension_spec.time_granularity.is_custom_granularity + ), ("Time granularity for time dimension spec in JoinToCustomGranularityNode must be qualified as custom granularity." 
+ f" Instead, found {self.time_dimension_spec.time_granularity.name}. This indicates internal misconfiguration.") + @staticmethod def create( # noqa: D102 parent_node: DataflowPlanNode, time_dimension_spec: TimeDimensionSpec diff --git a/metricflow/plan_conversion/dataflow_to_sql.py b/metricflow/plan_conversion/dataflow_to_sql.py index 7b168c4f0c..6e58685232 100644 --- a/metricflow/plan_conversion/dataflow_to_sql.py +++ b/metricflow/plan_conversion/dataflow_to_sql.py @@ -199,6 +199,9 @@ def __init__( self._time_spine_sources = TimeSpineSource.build_standard_time_spine_sources( semantic_manifest_lookup.semantic_manifest ) + self._custom_granularity_time_spine_sources = TimeSpineSource.build_custom_time_spine_sources( + tuple(self._time_spine_sources.values()) + ) @property def column_association_resolver(self) -> ColumnAssociationResolver: # noqa: D102 @@ -237,7 +240,7 @@ def _next_unique_table_alias(self) -> str: """Return the next unique table alias to use in generating queries.""" return SequentialIdGenerator.create_next_id(StaticIdPrefix.SUB_QUERY).str_value - def _choose_time_spine_source( + def _choose_time_spine_source_for_standard_granularity( self, agg_time_dimension_instances: Tuple[TimeDimensionInstance, ...] ) -> TimeSpineSource: """Determine which time spine source to use when building time spine dataset. @@ -254,9 +257,10 @@ def _choose_time_spine_source( assert ( agg_time_dimension_instances ), "Building time spine dataset requires agg_time_dimension_instances, but none were found." 
- smallest_agg_time_grain = min( - dim.spec.time_granularity.base_granularity for dim in agg_time_dimension_instances - ) + smallest_agg_time_grain = sorted( + agg_time_dimension_instances, + key=lambda x: x.spec.time_granularity.base_granularity.to_int(), + )[0].spec.time_granularity.base_granularity compatible_time_spine_grains = [ grain for grain in self._time_spine_sources.keys() if grain.to_int() <= smallest_agg_time_grain.to_int() ] @@ -268,6 +272,48 @@ ) return self._time_spine_sources[max(compatible_time_spine_grains)] + def _get_time_spine_for_custom_granularity(self, custom_granularity: str) -> TimeSpineSource: + time_spine_source = self._custom_granularity_time_spine_sources.get(custom_granularity) + assert time_spine_source, ( + f"Custom granularity {custom_granularity} does not exist in time spine sources. " + f"Available custom granularities: {list(self._custom_granularity_time_spine_sources.keys())}" + ) + return time_spine_source + + def _get_custom_granularity_column_name(self, custom_granularity: str) -> str: + time_spine_source = self._get_time_spine_for_custom_granularity(custom_granularity) + for custom_granularity_column in time_spine_source.custom_granularities: + if custom_granularity_column.name == custom_granularity: + return custom_granularity_column.column_name if custom_granularity_column.column_name else custom_granularity_column.name + + raise RuntimeError( + f"Custom granularity {custom_granularity} not found. This indicates internal misconfiguration." + ) + + def _make_custom_granularity_dataset(self, time_dimension_instance: TimeDimensionInstance) -> SqlDataSet: + time_spine_instance_set = InstanceSet(time_dimension_instances=(time_dimension_instance,)) + time_spine_table_alias = self._next_unique_table_alias() + assert ( + time_dimension_instance.spec.time_granularity.is_custom_granularity + ), "_make_custom_granularity_dataset() should only be called for custom granularities." 
+ + custom_granularity_name = time_dimension_instance.spec.time_granularity.name + time_spine_source = self._get_time_spine_for_custom_granularity(custom_granularity_name) + custom_granularity_column_name = self._get_custom_granularity_column_name(custom_granularity_name) + column_expr = SqlColumnReferenceExpression.from_table_and_column_names( + table_alias=time_spine_table_alias, column_name=custom_granularity_column_name + ) + column_alias = self.column_association_resolver.resolve_spec(time_dimension_instance.spec).column_name + return SqlDataSet( + instance_set=time_spine_instance_set, + sql_select_node=SqlSelectStatementNode.create( + description=TIME_SPINE_DATA_SET_DESCRIPTION, + select_columns=(SqlSelectColumn(expr=column_expr, column_alias=column_alias),), + from_source=SqlTableFromClauseNode.create(sql_table=time_spine_source.spine_table), + from_source_alias=time_spine_table_alias, + ), + ) + def _make_time_spine_data_set( self, agg_time_dimension_instances: Tuple[TimeDimensionInstance, ...], @@ -280,7 +326,7 @@ def _make_time_spine_data_set( time_spine_instance_set = InstanceSet(time_dimension_instances=agg_time_dimension_instances) time_spine_table_alias = self._next_unique_table_alias() - time_spine_source = self._choose_time_spine_source(agg_time_dimension_instances) + time_spine_source = self._choose_time_spine_source_for_standard_granularity(agg_time_dimension_instances) column_expr = SqlColumnReferenceExpression.from_table_and_column_names( table_alias=time_spine_table_alias, column_name=time_spine_source.base_column ) @@ -290,9 +336,7 @@ def _make_time_spine_data_set( column_alias = self.column_association_resolver.resolve_spec(agg_time_dimension_instance.spec).column_name # If the requested granularity is the same as the granularity of the spine, do a direct select. # TODO: also handle date part. 
- # TODO: [custom granularity] add support for custom granularities to make_time_spine_data_set agg_time_grain = agg_time_dimension_instance.spec.time_granularity - assert not agg_time_grain.is_custom_granularity, "Custom time granularities are not yet supported!" if agg_time_grain.base_granularity == time_spine_source.base_granularity: select_columns += (SqlSelectColumn(expr=column_expr, column_alias=column_alias),) # If any columns have a different granularity, apply a DATE_TRUNC() and aggregate via group_by. @@ -334,6 +378,7 @@ def visit_source_node(self, node: ReadSqlSourceNode) -> SqlDataSet: instance_set=node.data_set.instance_set, ) + # TODO: write tests for custom granularities that hit this node def visit_join_over_time_range_node(self, node: JoinOverTimeRangeNode) -> SqlDataSet: """Generate time range join SQL.""" table_alias_to_instance_set: OrderedDict[str, InstanceSet] = OrderedDict() @@ -1257,6 +1302,7 @@ def visit_semi_additive_join_node(self, node: SemiAdditiveJoinNode) -> SqlDataSe ), ) + # TODO: write tests for custom granularities that hit this node def visit_join_to_time_spine_node(self, node: JoinToTimeSpineNode) -> SqlDataSet: # noqa: D102 parent_data_set = node.parent_node.accept(self) parent_alias = self._next_unique_table_alias() @@ -1280,10 +1326,6 @@ def visit_join_to_time_spine_node(self, node: JoinToTimeSpineNode) -> SqlDataSet agg_time_dimension_instances.append(instance) # Choose the instance with the smallest standard granularity available. - # TODO: [custom granularity] Update to account for custom granularity instances - assert all( - [not instance.spec.time_granularity.is_custom_granularity for instance in agg_time_dimension_instances] - ), "Custom granularities are not yet supported!" agg_time_dimension_instances.sort(key=lambda instance: instance.spec.time_granularity.base_granularity.to_int()) assert len(agg_time_dimension_instances) > 0, ( "Couldn't find requested agg_time_dimension in parent data set. 
The dataflow plan may have been " @@ -1369,9 +1411,6 @@ def visit_join_to_time_spine_node(self, node: JoinToTimeSpineNode) -> SqlDataSet # Add requested granularities (if different from time_spine) and date_parts to time spine column. for time_dimension_instance in time_dimensions_to_select_from_time_spine: time_dimension_spec = time_dimension_instance.spec - - # TODO: this will break when we start supporting smaller grain than DAY unless the time spine table is - # updated to use the smallest available grain. if ( time_dimension_spec.time_granularity.base_granularity.to_int() < original_time_spine_dim_instance.spec.time_granularity.base_granularity.to_int() @@ -1437,7 +1476,69 @@ def visit_join_to_time_spine_node(self, node: JoinToTimeSpineNode) -> SqlDataSet ) def visit_join_to_custom_granularity_node(self, node: JoinToCustomGranularityNode) -> SqlDataSet: # noqa: D102 - raise NotImplementedError # TODO in later commit + parent_data_set = node.parent_node.accept(self) + parent_alias = self._next_unique_table_alias() + parent_time_dimension_instance: Optional[TimeDimensionInstance] = None + for instance in parent_data_set.instance_set.time_dimension_instances: + if instance.spec == node.time_dimension_spec: + parent_time_dimension_instance = instance + break + assert parent_time_dimension_instance, ( + "JoinToCustomGranularityNode's expected time_dimension_spec not found in parent dataset instances. " + "This indicates internal misconfiguration." + ) + + # Build time spine dataset. + time_spine_dataset = self._make_custom_granularity_dataset(parent_time_dimension_instance) + assert ( + time_spine_dataset.instance_set.time_dimension_instances + ), "No time dimensions found in time spine dataset. This indicates internal misconfiguration." + time_spine_instance = time_spine_dataset.instance_set.time_dimension_instances[0] + + # Build join expression. 
+ time_spine_source = self._get_time_spine_for_custom_granularity(node.time_dimension_spec.time_granularity.name) + left_expr_for_join: SqlExpressionNode = SqlColumnReferenceExpression.from_table_and_column_names( + table_alias=parent_alias, column_name=parent_time_dimension_instance.associated_column.column_name + ) + left_expr_for_join = ( + left_expr_for_join + if parent_time_dimension_instance.spec.time_granularity == time_spine_source.base_granularity + else SqlDateTruncExpression.create( + time_granularity=time_spine_source.base_granularity, arg=left_expr_for_join + ) + ) + time_spine_alias = self._next_unique_table_alias() + join_description = SqlJoinDescription( + right_source=time_spine_dataset.checked_sql_select_node, + right_source_alias=time_spine_alias, + on_condition=SqlComparisonExpression.create( + left_expr=left_expr_for_join, + comparison=SqlComparison.EQUALS, + right_expr=SqlColumnReferenceExpression.from_table_and_column_names( + table_alias=time_spine_alias, column_name=time_spine_source.base_column + ), + ), + join_type=SqlJoinType.LEFT_OUTER, + ) + + # Build output dataset, replacing the custom time dimension from the parent dataset with the one from the time spine. 
+ parent_instance_set = parent_data_set.instance_set.transform( + FilterElements(exclude_specs=InstanceSpecSet(time_dimension_specs=(parent_time_dimension_instance.spec,))) + ) + time_spine_instance_set = InstanceSet(time_dimension_instances=(time_spine_instance,)) + return SqlDataSet( + instance_set=InstanceSet.merge([time_spine_instance_set, parent_instance_set]), + sql_select_node=SqlSelectStatementNode.create( + description=node.description, + select_columns=create_select_columns_for_instance_sets( + self._column_association_resolver, + OrderedDict({parent_alias: parent_instance_set, time_spine_alias: time_spine_instance_set}), + ), + from_source=parent_data_set.checked_sql_select_node, + from_source_alias=parent_alias, + join_descs=(join_description,), + ), + ) def visit_min_max_node(self, node: MinMaxNode) -> SqlDataSet: # noqa: D102 parent_data_set = node.parent_node.accept(self)