From 7fc74c0009554a9851e2a56b67b44be8b04bf499 Mon Sep 17 00:00:00 2001 From: Courtney Holcomb Date: Wed, 12 Jun 2024 15:24:18 -0700 Subject: [PATCH] Refactor _make_time_spine_data_set() for readability & simplicity Also supports an upcoming change to allow multiple agg_time_dimensions in this function. --- metricflow/plan_conversion/dataflow_to_sql.py | 120 ++++++------------ metricflow/sql/sql_exprs.py | 4 + 2 files changed, 46 insertions(+), 78 deletions(-) diff --git a/metricflow/plan_conversion/dataflow_to_sql.py b/metricflow/plan_conversion/dataflow_to_sql.py index 4413b9e8b5..bd6a4567ce 100644 --- a/metricflow/plan_conversion/dataflow_to_sql.py +++ b/metricflow/plan_conversion/dataflow_to_sql.py @@ -223,89 +223,59 @@ def _next_unique_table_alias(self) -> str: def _make_time_spine_data_set( self, agg_time_dimension_instance: TimeDimensionInstance, - agg_time_dimension_column_name: str, time_spine_source: TimeSpineSource, time_range_constraint: Optional[TimeRangeConstraint] = None, ) -> SqlDataSet: - """Make a time spine data set, which contains all date values like '2020-01-01', '2020-01-02'... + """Make a time spine data set, which contains all date/time values like '2020-01-01', '2020-01-02'... - This is useful in computing cumulative metrics. This will need to be updated to support granularities finer than a - day. + Returns a data set with a column for the agg_time_dimension requested. + Column alias will use 'metric_time' or the agg_time_dimension name depending on which the user requested. """ - time_spine_instance = TimeDimensionInstance( - defined_from=agg_time_dimension_instance.defined_from, - associated_columns=(ColumnAssociation(agg_time_dimension_column_name),), - spec=agg_time_dimension_instance.spec, - ) - - time_spine_instance_set = InstanceSet(time_dimension_instances=(time_spine_instance,)) + time_spine_instance_set = InstanceSet(time_dimension_instances=(agg_time_dimension_instance,)) time_spine_table_alias = self._next_unique_table_alias() - # If the requested granularity is the same as the granularity of the spine, do a direct select. + column_expr = SqlColumnReferenceExpression.from_table_and_column_names( + table_alias=time_spine_table_alias, column_name=time_spine_source.time_column_name + ) + + select_columns: Tuple[SqlSelectColumn, ...] = () + apply_group_by = False + column_alias = self.column_association_resolver.resolve_spec(agg_time_dimension_instance.spec).column_name + # If the requested granularity matches that of the time spine, do a direct select. + # TODO: also handle date part. if agg_time_dimension_instance.spec.time_granularity == time_spine_source.time_column_granularity: - return SqlDataSet( - instance_set=time_spine_instance_set, - sql_select_node=SqlSelectStatementNode( - description=TIME_SPINE_DATA_SET_DESCRIPTION, - select_columns=( - SqlSelectColumn( - expr=SqlColumnReferenceExpression( - SqlColumnReference( - table_alias=time_spine_table_alias, - column_name=time_spine_source.time_column_name, - ), - ), - column_alias=agg_time_dimension_column_name, - ), - ), - from_source=SqlTableFromClauseNode(sql_table=time_spine_source.spine_table), - from_source_alias=time_spine_table_alias, - where=( - _make_time_range_comparison_expr( - table_alias=time_spine_table_alias, - column_alias=time_spine_source.time_column_name, - time_range_constraint=time_range_constraint, - ) - if time_range_constraint - else None - ), - ), - ) - # If the granularity is different, apply a DATE_TRUNC() and aggregate. + select_columns += (SqlSelectColumn(expr=column_expr, column_alias=column_alias),) + # Otherwise, apply a DATE_TRUNC() and aggregate via group_by. else: - select_columns = ( + select_columns += ( SqlSelectColumn( expr=SqlDateTruncExpression( - time_granularity=agg_time_dimension_instance.spec.time_granularity, - arg=SqlColumnReferenceExpression( - SqlColumnReference( - table_alias=time_spine_table_alias, - column_name=time_spine_source.time_column_name, - ), - ), + time_granularity=agg_time_dimension_instance.spec.time_granularity, arg=column_expr ), - column_alias=agg_time_dimension_column_name, + column_alias=column_alias, ), ) - return SqlDataSet( - instance_set=time_spine_instance_set, - sql_select_node=SqlSelectStatementNode( - description=TIME_SPINE_DATA_SET_DESCRIPTION, - select_columns=select_columns, - from_source=SqlTableFromClauseNode(sql_table=time_spine_source.spine_table), - from_source_alias=time_spine_table_alias, - group_bys=select_columns, - where=( - _make_time_range_comparison_expr( - table_alias=time_spine_table_alias, - column_alias=time_spine_source.time_column_name, - time_range_constraint=time_range_constraint, - ) - if time_range_constraint - else None - ), + apply_group_by = True + + return SqlDataSet( + instance_set=time_spine_instance_set, + sql_select_node=SqlSelectStatementNode( + description=TIME_SPINE_DATA_SET_DESCRIPTION, + select_columns=select_columns, + from_source=SqlTableFromClauseNode(sql_table=time_spine_source.spine_table), + from_source_alias=time_spine_table_alias, + group_bys=select_columns if apply_group_by else (), + where=( + _make_time_range_comparison_expr( + table_alias=time_spine_table_alias, + column_alias=time_spine_source.time_column_name, + time_range_constraint=time_range_constraint, + ) + if time_range_constraint + else None ), - ) + ), + ) def visit_source_node(self, node: ReadSqlSourceNode) -> SqlDataSet: """Generate the SQL to read from the source.""" @@ -332,15 +302,10 @@ def visit_join_over_time_range_node(self, node: JoinOverTimeRangeNode) -> SqlDat time_spine_data_set_alias = self._next_unique_table_alias() - agg_time_dimension_column_name = self.column_association_resolver.resolve_spec( - agg_time_dimension_instance.spec - ).column_name - # Assemble time_spine dataset with metric_time_dimension to join. # Granularity of time_spine column should match granularity of metric_time column from parent dataset. time_spine_data_set = self._make_time_spine_data_set( agg_time_dimension_instance=agg_time_dimension_instance, - agg_time_dimension_column_name=agg_time_dimension_column_name, time_spine_source=self._time_spine_source, time_range_constraint=node.time_range_constraint, ) @@ -1275,22 +1240,21 @@ def visit_join_to_time_spine_node(self, node: JoinToTimeSpineNode) -> SqlDataSet agg_time_dimension_instance_for_join = agg_time_dimension_instances[0] # Build time spine data set using the requested agg_time_dimension name. - agg_time_dimension_column_name = self.column_association_resolver.resolve_spec( - agg_time_dimension_instance_for_join.spec - ).column_name time_spine_alias = self._next_unique_table_alias() time_spine_dataset = self._make_time_spine_data_set( agg_time_dimension_instance=agg_time_dimension_instance_for_join, - agg_time_dimension_column_name=agg_time_dimension_column_name, time_spine_source=self._time_spine_source, time_range_constraint=node.time_range_constraint, ) # Build join expression. + join_description = SqlQueryPlanJoinBuilder.make_join_to_time_spine_join_description( node=node, time_spine_alias=time_spine_alias, - agg_time_dimension_column_name=agg_time_dimension_column_name, + agg_time_dimension_column_name=self.column_association_resolver.resolve_spec( + agg_time_dimension_instance_for_join.spec + ).column_name, parent_sql_select_node=parent_data_set.checked_sql_select_node, parent_alias=parent_alias, ) diff --git a/metricflow/sql/sql_exprs.py b/metricflow/sql/sql_exprs.py index 0d1f598abe..3a1f424822 100644 --- a/metricflow/sql/sql_exprs.py +++ b/metricflow/sql/sql_exprs.py @@ -476,6 +476,10 @@ def matches(self, other: SqlExpressionNode) -> bool: # noqa: D102 return False return self.col_ref == other.col_ref + @staticmethod + def from_table_and_column_names(table_alias: str, column_name: str) -> SqlColumnReferenceExpression: # noqa: D102 + return SqlColumnReferenceExpression(SqlColumnReference(table_alias=table_alias, column_name=column_name)) + class SqlColumnAliasReferenceExpression(SqlExpressionNode): """An expression that evaluates to the alias of a column, but is not qualified with a table alias.