diff --git a/metricflow/dataflow/builder/dataflow_plan_builder.py b/metricflow/dataflow/builder/dataflow_plan_builder.py index 52621baabb..58f6e5ae5c 100644 --- a/metricflow/dataflow/builder/dataflow_plan_builder.py +++ b/metricflow/dataflow/builder/dataflow_plan_builder.py @@ -650,6 +650,7 @@ def _build_derived_metric_output_node( ), "Joining to time spine requires querying with metric_time or the appropriate agg_time_dimension." output_node = JoinToTimeSpineNode.create( parent_node=output_node, + time_spine_source_node=time_spine_node, replace_time_dimension_specs=queried_agg_time_dimension_specs, time_range_constraint=predicate_pushdown_state.time_range_constraint, offset_window=metric_spec.offset_window, @@ -1486,6 +1487,14 @@ def __get_required_and_extraneous_linkable_specs( return required_linkable_specs, extraneous_linkable_specs + # TODO: inline this helper if it does not grow more complex. + # NOTE(review): confirm this selects the correct time spine source for the required specs. + def _choose_time_spine_source(self, required_time_spine_specs: Sequence[TimeDimensionSpec]): + return TimeSpineSource.choose_time_spine_source( + required_time_spine_specs=required_time_spine_specs, + time_spine_sources=self._source_node_set.time_spine_metric_time_nodes, + ) + def _build_aggregated_measure_from_measure_source_node( self, metric_input_measure_spec: MetricInputMeasureSpec, @@ -1623,6 +1632,7 @@ def _build_aggregated_measure_from_measure_source_node( # in join rendering join_to_time_spine_node = JoinToTimeSpineNode.create( parent_node=time_range_node or measure_recipe.source_node, + time_spine_source_node=time_spine_node, replace_time_dimension_specs=queried_agg_time_dimension_specs, time_range_constraint=predicate_pushdown_state.time_range_constraint, offset_window=before_aggregation_time_spine_join_description.offset_window, @@ -1757,10 +1767,14 @@ def _build_aggregated_measure_from_measure_source_node( else: non_agg_time_filters.append(filter_spec) - # TODO: split this node into TimeSpineSourceNode and JoinToTimeSpineNode - then can use 
standard nodes here - # like JoinToCustomGranularityNode, WhereConstraintNode, etc. + # Choose the time spine source node that covers the queried agg time dimension specs. + time_spine_node = self._choose_time_spine_source(queried_agg_time_dimension_specs) + + # TODO: apply WhereConstraintNode & TimeConstraintNode here + output_node: DataflowPlanNode = JoinToTimeSpineNode.create( parent_node=aggregate_measures_node, + time_spine_source_node=time_spine_node, replace_time_dimension_specs=queried_agg_time_dimension_specs, join_type=after_aggregation_time_spine_join_description.join_type, time_range_constraint=predicate_pushdown_state.time_range_constraint, diff --git a/metricflow/dataflow/nodes/join_to_time_spine.py b/metricflow/dataflow/nodes/join_to_time_spine.py index 0760fbd2c8..00db07dee3 100644 --- a/metricflow/dataflow/nodes/join_to_time_spine.py +++ b/metricflow/dataflow/nodes/join_to_time_spine.py @@ -15,6 +15,7 @@ from metricflow_semantics.visitor import VisitorOutputT from metricflow.dataflow.dataflow_plan import DataflowPlanNode, DataflowPlanNodeVisitor +from metricflow.dataflow.nodes.read_sql_source import ReadSqlSourceNode @dataclass(frozen=True) @@ -22,6 +23,7 @@ class JoinToTimeSpineNode(DataflowPlanNode, ABC): """Join parent dataset to time spine dataset. Attributes: + time_spine_source_node: The source node that should be joined to the parent node. replace_time_dimension_specs: Time dimensions that should be replaced with columns from the time spine. join_type: Join type to use when joining to time spine. time_range_constraint: Time range to constrain the time spine to. 
@@ -30,7 +32,7 @@ class JoinToTimeSpineNode(DataflowPlanNode, ABC): """ # TODO: filter params; will apply where filters & time constraints separately using standard nodes - # TODO: add time_spine_source_node as a param + time_spine_source_node: ReadSqlSourceNode replace_time_dimension_specs: Sequence[TimeDimensionSpec] join_type: SqlJoinType time_range_constraint: Optional[TimeRangeConstraint] @@ -52,6 +54,7 @@ def __post_init__(self) -> None: # noqa: D105 @staticmethod def create( # noqa: D102 parent_node: DataflowPlanNode, + time_spine_source_node: ReadSqlSourceNode, replace_time_dimension_specs: Sequence[TimeDimensionSpec], join_type: SqlJoinType, time_range_constraint: Optional[TimeRangeConstraint] = None, @@ -61,6 +64,7 @@ def create( # noqa: D102 ) -> JoinToTimeSpineNode: return JoinToTimeSpineNode( parent_nodes=(parent_node,), + time_spine_source_node=time_spine_source_node, replace_time_dimension_specs=tuple(replace_time_dimension_specs), join_type=join_type, time_range_constraint=time_range_constraint, @@ -100,6 +104,8 @@ def displayed_properties(self) -> Sequence[DisplayedProperty]: # noqa: D102 ) return props + # TODO: should the time spine be considered a parent node? There must be downstream implications. + # e.g., if this node is used in an export, the time spine source should show up in the DAG, right? 
@property def parent_node(self) -> DataflowPlanNode: # noqa: D102 return self.parent_nodes[0] @@ -107,6 +113,7 @@ def parent_node(self) -> DataflowPlanNode: # noqa: D102 def functionally_identical(self, other_node: DataflowPlanNode) -> bool: # noqa: D102 return ( isinstance(other_node, self.__class__) + and self.time_spine_source_node == other_node.time_spine_source_node and other_node.time_range_constraint == self.time_range_constraint and other_node.offset_window == self.offset_window and other_node.offset_to_grain == self.offset_to_grain @@ -119,6 +126,7 @@ def with_new_parents(self, new_parent_nodes: Sequence[DataflowPlanNode]) -> Join assert len(new_parent_nodes) == 1 return JoinToTimeSpineNode.create( parent_node=new_parent_nodes[0], + time_spine_source_node=self.time_spine_source_node, replace_time_dimension_specs=self.replace_time_dimension_specs, time_range_constraint=self.time_range_constraint, offset_window=self.offset_window, diff --git a/tests_metricflow/plan_conversion/test_dataflow_to_sql_plan.py b/tests_metricflow/plan_conversion/test_dataflow_to_sql_plan.py index 7208d7bdc2..427388e625 100644 --- a/tests_metricflow/plan_conversion/test_dataflow_to_sql_plan.py +++ b/tests_metricflow/plan_conversion/test_dataflow_to_sql_plan.py @@ -565,6 +565,7 @@ def test_compute_metrics_node_simple_expr( ) +# TODO: move these tests to normal rendering test section (unless already covered) @pytest.mark.sql_engine_snapshot def test_join_to_time_spine_node_without_offset( request: FixtureRequest, @@ -607,6 +608,7 @@ def test_join_to_time_spine_node_without_offset( ) join_to_time_spine_node = JoinToTimeSpineNode.create( parent_node=compute_metrics_node, + time_spine_source_node=time_spine_node, replace_time_dimension_specs=[MTD_SPEC_DAY], time_range_constraint=TimeRangeConstraint( start_time=as_datetime("2020-01-01"), end_time=as_datetime("2021-01-01") ) @@ -680,6 +682,7 @@ def test_join_to_time_spine_node_with_offset_window( ) join_to_time_spine_node = 
JoinToTimeSpineNode.create( parent_node=compute_metrics_node, + time_spine_source_node=time_spine_node, replace_time_dimension_specs=[MTD_SPEC_DAY], time_range_constraint=TimeRangeConstraint( start_time=as_datetime("2020-01-01"), end_time=as_datetime("2021-01-01") ) @@ -754,6 +757,7 @@ def test_join_to_time_spine_node_with_offset_to_grain( ) join_to_time_spine_node = JoinToTimeSpineNode.create( parent_node=compute_metrics_node, + time_spine_source_node=time_spine_node, replace_time_dimension_specs=[MTD_SPEC_DAY], time_range_constraint=TimeRangeConstraint( start_time=as_datetime("2020-01-01"), end_time=as_datetime("2021-01-01")