Skip to content

Commit

Permalink
/* PR_START p--cte 12 */ Separate out visitor from SQL converter.
Browse files Browse the repository at this point in the history
  • Loading branch information
plypaul committed Nov 10, 2024
1 parent e65749f commit 22bb9de
Show file tree
Hide file tree
Showing 2 changed files with 90 additions and 50 deletions.
20 changes: 13 additions & 7 deletions metricflow/dataflow/builder/node_data_set.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,13 +12,15 @@
DataflowPlanNode,
)
from metricflow.dataset.sql_dataset import SqlDataSet
from metricflow.plan_conversion.dataflow_to_sql import DataflowToSqlQueryPlanConverter
from metricflow.plan_conversion.dataflow_to_sql import (
DataflowNodeToSqlSubqueryVisitor,
)

if TYPE_CHECKING:
from metricflow_semantics.model.semantic_manifest_lookup import SemanticManifestLookup


class DataflowPlanNodeOutputDataSetResolver(DataflowToSqlQueryPlanConverter):
class DataflowPlanNodeOutputDataSetResolver:
"""Given a node in a dataflow plan, figure out what is the data set output by that node.
Recall that in the dataflow plan, the nodes represent computation, and the inputs and outputs of the nodes are
Expand Down Expand Up @@ -68,9 +70,11 @@ def __init__( # noqa: D107
_node_to_output_data_set: Optional[Dict[DataflowPlanNode, SqlDataSet]] = None,
) -> None:
self._node_to_output_data_set: Dict[DataflowPlanNode, SqlDataSet] = _node_to_output_data_set or {}
super().__init__(
column_association_resolver=column_association_resolver,
semantic_manifest_lookup=semantic_manifest_lookup,
self._column_association_resolver = column_association_resolver
self._semantic_manifest_lookup = semantic_manifest_lookup
self._to_data_set_visitor = _NodeDataSetVisitor(
column_association_resolver=self._column_association_resolver,
semantic_manifest_lookup=self._semantic_manifest_lookup,
)

def get_output_data_set(self, node: DataflowPlanNode) -> SqlDataSet:
Expand All @@ -79,7 +83,7 @@ def get_output_data_set(self, node: DataflowPlanNode) -> SqlDataSet:
# TODO: The cache needs to be pruned, but has not yet been an issue.
"""
if node not in self._node_to_output_data_set:
self._node_to_output_data_set[node] = node.accept(self)
self._node_to_output_data_set[node] = node.accept(self._to_data_set_visitor)

return self._node_to_output_data_set[node]

Expand All @@ -92,11 +96,13 @@ def cache_output_data_sets(self, nodes: Sequence[DataflowPlanNode]) -> None:
def copy(self) -> DataflowPlanNodeOutputDataSetResolver:
"""Return a copy of this with the same nodes cached."""
return DataflowPlanNodeOutputDataSetResolver(
column_association_resolver=self.column_association_resolver,
column_association_resolver=self._column_association_resolver,
semantic_manifest_lookup=self._semantic_manifest_lookup,
_node_to_output_data_set=dict(self._node_to_output_data_set),
)


class _NodeDataSetVisitor(DataflowNodeToSqlSubqueryVisitor):
@override
def _next_unique_table_alias(self) -> str:
"""Return the next unique table alias to use in generating queries.
Expand Down
120 changes: 77 additions & 43 deletions metricflow/plan_conversion/dataflow_to_sql.py
Original file line number Diff line number Diff line change
Expand Up @@ -143,41 +143,8 @@
logger = logging.getLogger(__name__)


def _make_time_range_comparison_expr(
    table_alias: str, column_alias: str, time_range_constraint: TimeRangeConstraint
) -> SqlExpressionNode:
    """Build a BETWEEN expression that restricts a column to the given time range.

    The start and end of the range are rendered as string literals. If both bounds have no time-of-day component
    (i.e. the constraint uses day or larger grain), they are rendered to the date level. Otherwise, they are rendered
    to the timestamp level.
    """
    range_start = time_range_constraint.start_time
    range_end = time_range_constraint.end_time

    def truncate_to_date(ts: dt.datetime) -> dt.datetime:
        as_date = ts.date()
        return dt.datetime(as_date.year, as_date.month, as_date.day)

    # A bound that changes when its time-of-day is dropped means the constraint is finer than day grain.
    bounds_are_whole_days = all(truncate_to_date(bound) == bound for bound in (range_start, range_end))

    if bounds_are_whole_days:
        time_format_to_render = ISO8601_PYTHON_FORMAT
    else:
        time_format_to_render = ISO8601_PYTHON_TS_FORMAT

    constrained_column = SqlColumnReferenceExpression.create(
        SqlColumnReference(table_alias=table_alias, column_name=column_alias)
    )
    lower_bound = SqlStringLiteralExpression.create(
        literal_value=range_start.strftime(time_format_to_render),
    )
    upper_bound = SqlStringLiteralExpression.create(
        literal_value=range_end.strftime(time_format_to_render),
    )
    return SqlBetweenExpression.create(
        column_arg=constrained_column,
        start_expr=lower_bound,
        end_expr=upper_bound,
    )


class DataflowToSqlQueryPlanConverter(DataflowPlanNodeVisitor[SqlDataSet]):
"""Generates an SQL query plan from a node in the a metric dataflow plan."""
class DataflowToSqlQueryPlanConverter:
"""Generates an SQL query plan from a node in the metric dataflow plan."""

def __init__(
self,
Expand Down Expand Up @@ -214,7 +181,12 @@ def convert_to_sql_query_plan(
sql_query_plan_id: Optional[DagId] = None,
) -> ConvertToSqlPlanResult:
"""Create an SQL query plan that represents the computation up to the given dataflow plan node."""
data_set = dataflow_plan_node.accept(self)
to_sql_visitor = DataflowNodeToSqlSubqueryVisitor(
column_association_resolver=self.column_association_resolver,
semantic_manifest_lookup=self._semantic_manifest_lookup,
)
data_set = dataflow_plan_node.accept(to_sql_visitor)

sql_node: SqlQueryPlanNode = data_set.sql_node
# TODO: Make this a more generally accessible attribute instead of checking against the
# BigQuery-ness of the engine
Expand All @@ -237,6 +209,66 @@ def convert_to_sql_query_plan(
sql_plan=SqlQueryPlan(render_node=sql_node, plan_id=sql_query_plan_id),
)


def _make_time_range_comparison_expr(
    table_alias: str, column_alias: str, time_range_constraint: TimeRangeConstraint
) -> SqlExpressionNode:
    """Build a BETWEEN expression that restricts a column to the given time range.

    The start and end of the range are rendered as string literals. If both bounds have no time-of-day component
    (i.e. the constraint uses day or larger grain), they are rendered to the date level. Otherwise, they are rendered
    to the timestamp level.
    """
    range_start = time_range_constraint.start_time
    range_end = time_range_constraint.end_time

    def truncate_to_date(ts: dt.datetime) -> dt.datetime:
        as_date = ts.date()
        return dt.datetime(as_date.year, as_date.month, as_date.day)

    # A bound that changes when its time-of-day is dropped means the constraint is finer than day grain.
    bounds_are_whole_days = all(truncate_to_date(bound) == bound for bound in (range_start, range_end))

    if bounds_are_whole_days:
        time_format_to_render = ISO8601_PYTHON_FORMAT
    else:
        time_format_to_render = ISO8601_PYTHON_TS_FORMAT

    constrained_column = SqlColumnReferenceExpression.create(
        SqlColumnReference(table_alias=table_alias, column_name=column_alias)
    )
    lower_bound = SqlStringLiteralExpression.create(
        literal_value=range_start.strftime(time_format_to_render),
    )
    upper_bound = SqlStringLiteralExpression.create(
        literal_value=range_end.strftime(time_format_to_render),
    )
    return SqlBetweenExpression.create(
        column_arg=constrained_column,
        start_expr=lower_bound,
        end_expr=upper_bound,
    )


class DataflowNodeToSqlSubqueryVisitor(DataflowPlanNodeVisitor[SqlDataSet]):
"""Generates a SQL query plan by converting parent nodes to a sub-query and the given node to a query."""

def __init__(
    self,
    column_association_resolver: ColumnAssociationResolver,
    semantic_manifest_lookup: SemanticManifestLookup,
) -> None:
    """Store the resolver and manifest lookup, and derive the lookups and time spine sources from them.

    Args:
        column_association_resolver: controls how columns for instances are generated and used between nested
            queries.
        semantic_manifest_lookup: source of the metric lookup, the semantic model lookup, and the semantic
            manifest from which the time spine sources are built.
    """
    self._column_association_resolver = column_association_resolver
    self._semantic_manifest_lookup = semantic_manifest_lookup
    self._metric_lookup = semantic_manifest_lookup.metric_lookup
    self._semantic_model_lookup = semantic_manifest_lookup.semantic_model_lookup

    # The custom-granularity sources are built from the standard ones, so the standard ones must come first.
    standard_time_spine_sources = TimeSpineSource.build_standard_time_spine_sources(
        semantic_manifest_lookup.semantic_manifest
    )
    self._time_spine_sources = standard_time_spine_sources
    self._custom_granularity_time_spine_sources = TimeSpineSource.build_custom_time_spine_sources(
        tuple(standard_time_spine_sources.values())
    )

def _next_unique_table_alias(self) -> str:
    """Generate the next unique table alias for use in a generated query."""
    next_sub_query_id = SequentialIdGenerator.create_next_id(StaticIdPrefix.SUB_QUERY)
    return next_sub_query_id.str_value
Expand Down Expand Up @@ -279,7 +311,7 @@ def _make_time_spine_data_set(
select_columns: Tuple[SqlSelectColumn, ...] = ()
apply_group_by = True
for agg_time_dimension_spec in required_time_spine_specs:
column_alias = self.column_association_resolver.resolve_spec(agg_time_dimension_spec).column_name
column_alias = self._column_association_resolver.resolve_spec(agg_time_dimension_spec).column_name
# If the requested granularity is the same as the granularity of the spine, do a direct select.
agg_time_grain = agg_time_dimension_spec.time_granularity
if (
Expand Down Expand Up @@ -1224,8 +1256,10 @@ def visit_semi_additive_join_node(self, node: SemiAdditiveJoinNode) -> SqlDataSe
column_equality_descriptions: List[ColumnEqualityDescription] = []

# Build Time Dimension SqlSelectColumn
time_dimension_column_name = self.column_association_resolver.resolve_spec(node.time_dimension_spec).column_name
join_time_dimension_column_name = self.column_association_resolver.resolve_spec(
time_dimension_column_name = self._column_association_resolver.resolve_spec(
node.time_dimension_spec
).column_name
join_time_dimension_column_name = self._column_association_resolver.resolve_spec(
node.time_dimension_spec.with_aggregation_state(AggregationState.COMPLETE),
).column_name
time_dimension_select_column = SqlSelectColumn(
Expand All @@ -1250,7 +1284,7 @@ def visit_semi_additive_join_node(self, node: SemiAdditiveJoinNode) -> SqlDataSe
# Build optional window grouping SqlSelectColumn
entity_select_columns: List[SqlSelectColumn] = []
for entity_spec in node.entity_specs:
entity_column_name = self.column_association_resolver.resolve_spec(entity_spec).column_name
entity_column_name = self._column_association_resolver.resolve_spec(entity_spec).column_name
entity_select_columns.append(
SqlSelectColumn(
expr=SqlColumnReferenceExpression.create(
Expand All @@ -1269,10 +1303,10 @@ def visit_semi_additive_join_node(self, node: SemiAdditiveJoinNode) -> SqlDataSe
)
)

# Propogate additional group by during query time of the non-additive time dimension
# Propagate additional group by during query time of the non-additive time dimension
queried_time_dimension_select_column: Optional[SqlSelectColumn] = None
if node.queried_time_dimension_spec:
query_time_dimension_column_name = self.column_association_resolver.resolve_spec(
query_time_dimension_column_name = self._column_association_resolver.resolve_spec(
node.queried_time_dimension_spec
).column_name
queried_time_dimension_select_column = SqlSelectColumn(
Expand Down Expand Up @@ -1360,7 +1394,7 @@ def visit_join_to_time_spine_node(self, node: JoinToTimeSpineNode) -> SqlDataSet
join_description = SqlQueryPlanJoinBuilder.make_join_to_time_spine_join_description(
node=node,
time_spine_alias=time_spine_alias,
agg_time_dimension_column_name=self.column_association_resolver.resolve_spec(
agg_time_dimension_column_name=self._column_association_resolver.resolve_spec(
agg_time_dimension_instance_for_join.spec
).column_name,
parent_sql_select_node=parent_data_set.checked_sql_select_node,
Expand Down

0 comments on commit 22bb9de

Please sign in to comment.