diff --git a/metricflow/dataflow/builder/node_data_set.py b/metricflow/dataflow/builder/node_data_set.py index 61b594768..df34568c3 100644 --- a/metricflow/dataflow/builder/node_data_set.py +++ b/metricflow/dataflow/builder/node_data_set.py @@ -12,13 +12,15 @@ DataflowPlanNode, ) from metricflow.dataset.sql_dataset import SqlDataSet -from metricflow.plan_conversion.dataflow_to_sql import DataflowToSqlQueryPlanConverter +from metricflow.plan_conversion.dataflow_to_sql import ( + DataflowNodeToSqlSubqueryVisitor, +) if TYPE_CHECKING: from metricflow_semantics.model.semantic_manifest_lookup import SemanticManifestLookup -class DataflowPlanNodeOutputDataSetResolver(DataflowToSqlQueryPlanConverter): +class DataflowPlanNodeOutputDataSetResolver: """Given a node in a dataflow plan, figure out what is the data set output by that node. Recall that in the dataflow plan, the nodes represent computation, and the inputs and outputs of the nodes are @@ -68,9 +70,11 @@ def __init__( # noqa: D107 _node_to_output_data_set: Optional[Dict[DataflowPlanNode, SqlDataSet]] = None, ) -> None: self._node_to_output_data_set: Dict[DataflowPlanNode, SqlDataSet] = _node_to_output_data_set or {} - super().__init__( - column_association_resolver=column_association_resolver, - semantic_manifest_lookup=semantic_manifest_lookup, + self._column_association_resolver = column_association_resolver + self._semantic_manifest_lookup = semantic_manifest_lookup + self._to_data_set_visitor = _NodeDataSetVisitor( + column_association_resolver=self._column_association_resolver, + semantic_manifest_lookup=self._semantic_manifest_lookup, ) def get_output_data_set(self, node: DataflowPlanNode) -> SqlDataSet: @@ -79,7 +83,7 @@ def get_output_data_set(self, node: DataflowPlanNode) -> SqlDataSet: # TODO: The cache needs to be pruned, but has not yet been an issue. """ if node not in self._node_to_output_data_set: - self._node_to_output_data_set[node] = node.accept(self) + self._node_to_output_data_set[node] = node.accept(self._to_data_set_visitor) return self._node_to_output_data_set[node] @@ -92,11 +96,13 @@ def cache_output_data_sets(self, nodes: Sequence[DataflowPlanNode]) -> None: def copy(self) -> DataflowPlanNodeOutputDataSetResolver: """Return a copy of this with the same nodes cached.""" return DataflowPlanNodeOutputDataSetResolver( - column_association_resolver=self.column_association_resolver, + column_association_resolver=self._column_association_resolver, semantic_manifest_lookup=self._semantic_manifest_lookup, _node_to_output_data_set=dict(self._node_to_output_data_set), ) + +class _NodeDataSetVisitor(DataflowNodeToSqlSubqueryVisitor): @override def _next_unique_table_alias(self) -> str: """Return the next unique table alias to use in generating queries. diff --git a/metricflow/plan_conversion/dataflow_to_sql.py b/metricflow/plan_conversion/dataflow_to_sql.py index 17302b46d..7c894dfda 100644 --- a/metricflow/plan_conversion/dataflow_to_sql.py +++ b/metricflow/plan_conversion/dataflow_to_sql.py @@ -143,41 +143,8 @@ logger = logging.getLogger(__name__) -def _make_time_range_comparison_expr( - table_alias: str, column_alias: str, time_range_constraint: TimeRangeConstraint -) -> SqlExpressionNode: - """Build an expression like "ds BETWEEN CAST('2020-01-01' AS TIMESTAMP) AND CAST('2020-01-02' AS TIMESTAMP). - - If the constraint uses day or larger grain, only render to the date level. Otherwise, render to the timestamp level. - """ - - def strip_time_from_dt(ts: dt.datetime) -> dt.datetime: - date_obj = ts.date() - return dt.datetime(date_obj.year, date_obj.month, date_obj.day) - - constraint_uses_day_or_larger_grain = True - for constraint_input in (time_range_constraint.start_time, time_range_constraint.end_time): - if strip_time_from_dt(constraint_input) != constraint_input: - constraint_uses_day_or_larger_grain = False - break - - time_format_to_render = ISO8601_PYTHON_FORMAT if constraint_uses_day_or_larger_grain else ISO8601_PYTHON_TS_FORMAT - - return SqlBetweenExpression.create( - column_arg=SqlColumnReferenceExpression.create( - SqlColumnReference(table_alias=table_alias, column_name=column_alias) - ), - start_expr=SqlStringLiteralExpression.create( - literal_value=time_range_constraint.start_time.strftime(time_format_to_render), - ), - end_expr=SqlStringLiteralExpression.create( - literal_value=time_range_constraint.end_time.strftime(time_format_to_render), - ), - ) - - -class DataflowToSqlQueryPlanConverter(DataflowPlanNodeVisitor[SqlDataSet]): - """Generates an SQL query plan from a node in the a metric dataflow plan.""" +class DataflowToSqlQueryPlanConverter: + """Generates an SQL query plan from a node in the metric dataflow plan.""" def __init__( self, @@ -214,7 +181,12 @@ def convert_to_sql_query_plan( sql_query_plan_id: Optional[DagId] = None, ) -> ConvertToSqlPlanResult: """Create an SQL query plan that represents the computation up to the given dataflow plan node.""" - data_set = dataflow_plan_node.accept(self) + to_sql_visitor = DataflowNodeToSqlSubqueryVisitor( + column_association_resolver=self.column_association_resolver, + semantic_manifest_lookup=self._semantic_manifest_lookup, + ) + data_set = dataflow_plan_node.accept(to_sql_visitor) + sql_node: SqlQueryPlanNode = data_set.sql_node # TODO: Make this a more generally accessible attribute instead of checking against the # BigQuery-ness of the engine @@ -237,6 +209,69 @@ def convert_to_sql_query_plan( sql_plan=SqlQueryPlan(render_node=sql_node, plan_id=sql_query_plan_id), ) + +def _make_time_range_comparison_expr( + table_alias: str, column_alias: str, time_range_constraint: TimeRangeConstraint +) -> SqlExpressionNode: + """Build an expression like "ds BETWEEN CAST('2020-01-01' AS TIMESTAMP) AND CAST('2020-01-02' AS TIMESTAMP). + + If the constraint uses day or larger grain, only render to the date level. Otherwise, render to the timestamp level. + """ + + def strip_time_from_dt(ts: dt.datetime) -> dt.datetime: + date_obj = ts.date() + return dt.datetime(date_obj.year, date_obj.month, date_obj.day) + + constraint_uses_day_or_larger_grain = True + for constraint_input in (time_range_constraint.start_time, time_range_constraint.end_time): + if strip_time_from_dt(constraint_input) != constraint_input: + constraint_uses_day_or_larger_grain = False + break + + time_format_to_render = ISO8601_PYTHON_FORMAT if constraint_uses_day_or_larger_grain else ISO8601_PYTHON_TS_FORMAT + + return SqlBetweenExpression.create( + column_arg=SqlColumnReferenceExpression.create( + SqlColumnReference(table_alias=table_alias, column_name=column_alias) + ), + start_expr=SqlStringLiteralExpression.create( + literal_value=time_range_constraint.start_time.strftime(time_format_to_render), + ), + end_expr=SqlStringLiteralExpression.create( + literal_value=time_range_constraint.end_time.strftime(time_format_to_render), + ), + ) + + +class DataflowNodeToSqlSubqueryVisitor(DataflowPlanNodeVisitor[SqlDataSet]): + """Generates a SQL query plan by converting parent nodes to a sub-query and the given node to a query. + + TODO: Split classes in this file to separate files. + """ + + def __init__( + self, + column_association_resolver: ColumnAssociationResolver, + semantic_manifest_lookup: SemanticManifestLookup, + ) -> None: + """Constructor. + + Args: + column_association_resolver: controls how columns for instances are generated and used between nested + queries. + semantic_manifest_lookup: Self-explanatory. + """ + self._column_association_resolver = column_association_resolver + self._semantic_manifest_lookup = semantic_manifest_lookup + self._metric_lookup = semantic_manifest_lookup.metric_lookup + self._semantic_model_lookup = semantic_manifest_lookup.semantic_model_lookup + self._time_spine_sources = TimeSpineSource.build_standard_time_spine_sources( + semantic_manifest_lookup.semantic_manifest + ) + self._custom_granularity_time_spine_sources = TimeSpineSource.build_custom_time_spine_sources( + tuple(self._time_spine_sources.values()) + ) + def _next_unique_table_alias(self) -> str: """Return the next unique table alias to use in generating queries.""" return SequentialIdGenerator.create_next_id(StaticIdPrefix.SUB_QUERY).str_value @@ -279,7 +314,7 @@ def _make_time_spine_data_set( select_columns: Tuple[SqlSelectColumn, ...] = () apply_group_by = True for agg_time_dimension_spec in required_time_spine_specs: - column_alias = self.column_association_resolver.resolve_spec(agg_time_dimension_spec).column_name + column_alias = self._column_association_resolver.resolve_spec(agg_time_dimension_spec).column_name # If the requested granularity is the same as the granularity of the spine, do a direct select. agg_time_grain = agg_time_dimension_spec.time_granularity if ( @@ -1224,8 +1259,10 @@ def visit_semi_additive_join_node(self, node: SemiAdditiveJoinNode) -> SqlDataSe column_equality_descriptions: List[ColumnEqualityDescription] = [] # Build Time Dimension SqlSelectColumn - time_dimension_column_name = self.column_association_resolver.resolve_spec(node.time_dimension_spec).column_name - join_time_dimension_column_name = self.column_association_resolver.resolve_spec( + time_dimension_column_name = self._column_association_resolver.resolve_spec( + node.time_dimension_spec + ).column_name + join_time_dimension_column_name = self._column_association_resolver.resolve_spec( node.time_dimension_spec.with_aggregation_state(AggregationState.COMPLETE), ).column_name time_dimension_select_column = SqlSelectColumn( @@ -1250,7 +1287,7 @@ def visit_semi_additive_join_node(self, node: SemiAdditiveJoinNode) -> SqlDataSe # Build optional window grouping SqlSelectColumn entity_select_columns: List[SqlSelectColumn] = [] for entity_spec in node.entity_specs: - entity_column_name = self.column_association_resolver.resolve_spec(entity_spec).column_name + entity_column_name = self._column_association_resolver.resolve_spec(entity_spec).column_name entity_select_columns.append( SqlSelectColumn( expr=SqlColumnReferenceExpression.create( @@ -1269,10 +1306,10 @@ def visit_semi_additive_join_node(self, node: SemiAdditiveJoinNode) -> SqlDataSe ) ) - # Propogate additional group by during query time of the non-additive time dimension + # Propagate additional group by during query time of the non-additive time dimension queried_time_dimension_select_column: Optional[SqlSelectColumn] = None if node.queried_time_dimension_spec: - query_time_dimension_column_name = self.column_association_resolver.resolve_spec( + query_time_dimension_column_name = self._column_association_resolver.resolve_spec( node.queried_time_dimension_spec ).column_name queried_time_dimension_select_column = SqlSelectColumn( @@ -1360,7 +1397,7 @@ def visit_join_to_time_spine_node(self, node: JoinToTimeSpineNode) -> SqlDataSet join_description = SqlQueryPlanJoinBuilder.make_join_to_time_spine_join_description( node=node, time_spine_alias=time_spine_alias, - agg_time_dimension_column_name=self.column_association_resolver.resolve_spec( + agg_time_dimension_column_name=self._column_association_resolver.resolve_spec( agg_time_dimension_instance_for_join.spec ).column_name, parent_sql_select_node=parent_data_set.checked_sql_select_node,