Skip to content

Commit

Permalink
Make time spine a standard node in the dataflow plan
Browse files Browse the repository at this point in the history
  • Loading branch information
courtneyholcomb committed Nov 23, 2024
1 parent ca52355 commit 7ec9c75
Show file tree
Hide file tree
Showing 3 changed files with 147 additions and 100 deletions.
107 changes: 94 additions & 13 deletions metricflow/dataflow/builder/dataflow_plan_builder.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from __future__ import annotations

import copy
import logging
import time
from typing import Dict, FrozenSet, List, Optional, Sequence, Set, Tuple, Union
Expand Down Expand Up @@ -92,6 +93,7 @@
from metricflow.dataflow.nodes.order_by_limit import OrderByLimitNode
from metricflow.dataflow.nodes.read_sql_source import ReadSqlSourceNode
from metricflow.dataflow.nodes.semi_additive_join import SemiAdditiveJoinNode
from metricflow.dataflow.nodes.transform_time_dimensions import TransformTimeDimensionsNode
from metricflow.dataflow.nodes.where_filter import WhereConstraintNode
from metricflow.dataflow.nodes.window_reaggregation_node import WindowReaggregationNode
from metricflow.dataflow.nodes.write_to_data_table import WriteToResultDataTableNode
Expand Down Expand Up @@ -648,14 +650,19 @@ def _build_derived_metric_output_node(
metric_reference=metric_spec.reference, metric_lookup=self._metric_lookup
)
if metric_spec.has_time_offset and queried_agg_time_dimension_specs:
# TODO: move this to a helper method
time_spine_node = self._build_time_spine_node(queried_agg_time_dimension_specs)
output_node = JoinToTimeSpineNode.create(
parent_node=output_node,
time_spine_node=time_spine_node,
requested_agg_time_dimension_specs=queried_agg_time_dimension_specs,
join_on_time_dimension_spec=self._sort_by_base_granularity(queried_agg_time_dimension_specs)[0],
offset_window=metric_spec.offset_window,
offset_to_grain=metric_spec.offset_to_grain,
join_type=SqlJoinType.INNER,
)

# TODO: fix bug here where filter specs are being included in when aggregating.
if len(metric_spec.filter_spec_set.all_filter_specs) > 0 or predicate_pushdown_state.time_range_constraint:
# FilterElementsNode will only be needed if there are where filter specs that were selected in the group by.
specs_in_filters = set(
Expand Down Expand Up @@ -1616,15 +1623,22 @@ def _build_aggregated_measure_from_measure_source_node(

# If querying an offset metric, join to time spine before aggregation.
if before_aggregation_time_spine_join_description and base_queried_agg_time_dimension_specs:
# TODO: move all of this to a helper function
assert before_aggregation_time_spine_join_description.join_type is SqlJoinType.INNER, (
f"Expected {SqlJoinType.INNER} for joining to time spine before aggregation. Remove this if there's a "
f"new use case."
)
# This also uses the original time range constraint due to the application of the time window intervals
# in join rendering

join_on_time_dimension_spec = self._determine_time_spine_join_spec(
measure_properties=measure_properties, required_time_spine_specs=base_queried_agg_time_dimension_specs
)
required_time_spine_specs = (join_on_time_dimension_spec,) + base_queried_agg_time_dimension_specs
time_spine_node = self._build_time_spine_node(required_time_spine_specs)
unaggregated_measure_node = JoinToTimeSpineNode.create(
parent_node=unaggregated_measure_node,
time_spine_node=time_spine_node,
requested_agg_time_dimension_specs=base_queried_agg_time_dimension_specs,
join_on_time_dimension_spec=join_on_time_dimension_spec,
offset_window=before_aggregation_time_spine_join_description.offset_window,
offset_to_grain=before_aggregation_time_spine_join_description.offset_to_grain,
join_type=before_aggregation_time_spine_join_description.join_type,
Expand Down Expand Up @@ -1667,10 +1681,13 @@ def _build_aggregated_measure_from_measure_source_node(
measure_reference=measure_spec.reference, semantic_model_lookup=self._semantic_model_lookup
)
if after_aggregation_time_spine_join_description and queried_agg_time_dimension_specs:
# TODO: move all of this to a helper function
assert after_aggregation_time_spine_join_description.join_type is SqlJoinType.LEFT_OUTER, (
f"Expected {SqlJoinType.LEFT_OUTER} for joining to time spine after aggregation. Remove this if "
f"there's a new use case."
)
time_spine_required_specs = copy.deepcopy(queried_agg_time_dimension_specs)

# Find filters that contain only metric_time or agg_time_dimension. They will be applied to the time spine table.
agg_time_only_filters: List[WhereFilterSpec] = []
non_agg_time_filters: List[WhereFilterSpec] = []
Expand All @@ -1680,24 +1697,23 @@ def _build_aggregated_measure_from_measure_source_node(
)
if set(included_agg_time_specs) == set(filter_spec.linkable_spec_set.as_tuple):
agg_time_only_filters.append(filter_spec)
if filter_spec.linkable_spec_set.time_dimension_specs_with_custom_grain:
raise ValueError(
"Using custom granularity in filters for `join_to_timespine` metrics is not yet fully supported. "
"This feature is coming soon!"
)
for agg_time_spec in included_agg_time_specs:
if agg_time_spec not in time_spine_required_specs:
time_spine_required_specs.append(agg_time_spec)
else:
non_agg_time_filters.append(filter_spec)

# TODO: split this node into TimeSpineSourceNode and JoinToTimeSpineNode - then can use standard nodes here
# like JoinToCustomGranularityNode, WhereConstraintNode, etc.
time_spine_node = self._build_time_spine_node(
queried_time_spine_specs=queried_agg_time_dimension_specs,
time_range_constraint=predicate_pushdown_state.time_range_constraint,
where_filter_specs=agg_time_only_filters,
)
output_node: DataflowPlanNode = JoinToTimeSpineNode.create(
parent_node=aggregate_measures_node,
time_spine_node=time_spine_node,
requested_agg_time_dimension_specs=queried_agg_time_dimension_specs,
join_on_time_dimension_spec=self._sort_by_base_granularity(queried_agg_time_dimension_specs)[0],
join_type=after_aggregation_time_spine_join_description.join_type,
time_range_constraint=predicate_pushdown_state.time_range_constraint,
offset_window=after_aggregation_time_spine_join_description.offset_window,
offset_to_grain=after_aggregation_time_spine_join_description.offset_to_grain,
time_spine_filters=agg_time_only_filters,
)

# Since new rows might have been added due to time spine join, re-apply constraints here. Only re-apply filters
Expand Down Expand Up @@ -1824,3 +1840,68 @@ def _choose_time_spine_metric_time_node(
def _choose_time_spine_read_node(self, time_spine_source: TimeSpineSource) -> ReadSqlSourceNode:
"""Return the MetricTimeDimensionTransform time spine node needed to satisfy the specs."""
return self._source_node_set.time_spine_read_nodes[time_spine_source.base_granularity]

def _build_time_spine_node(
self,
queried_time_spine_specs: Sequence[TimeDimensionSpec],
where_filter_specs: Sequence[WhereFilterSpec] = (),
time_range_constraint: Optional[TimeRangeConstraint] = None,
) -> DataflowPlanNode:
"""Return the time spine node needed to satisfy the specs."""
required_time_spine_spec_set = self.__get_required_linkable_specs(
queried_linkable_specs=LinkableSpecSet(time_dimension_specs=tuple(queried_time_spine_specs)),
filter_specs=where_filter_specs,
)
required_time_spine_specs = required_time_spine_spec_set.time_dimension_specs

# TODO: support multiple time spines here. Build node on the one with the smallest base grain.
# Then, pass custom_granularity_specs into _build_pre_aggregation_plan if they aren't satisfied by smallest time spine.
time_spine_source = self._choose_time_spine_source(required_time_spine_specs)
time_spine_node = TransformTimeDimensionsNode.create(
parent_node=self._choose_time_spine_read_node(time_spine_source),
requested_time_dimension_specs=required_time_spine_specs,
)

# If the base grain of the time spine isn't selected, it will have duplicate rows that need deduping.
should_dedupe = ExpandedTimeGranularity.from_time_granularity(time_spine_source.base_granularity) not in {
spec.time_granularity for spec in queried_time_spine_specs
}

return self._build_pre_aggregation_plan(
source_node=time_spine_node,
filter_to_specs=InstanceSpecSet(time_dimension_specs=tuple(queried_time_spine_specs)),
time_range_constraint=time_range_constraint,
where_filter_specs=where_filter_specs,
distinct=should_dedupe,
)

def _sort_by_base_granularity(self, time_dimension_specs: Sequence[TimeDimensionSpec]) -> List[TimeDimensionSpec]:
"""Sort the time dimensions by their base granularity.
Specs with date part will come after specs without it. Standard grains will come before custom.
"""
return sorted(
time_dimension_specs,
key=lambda spec: (
spec.date_part is not None,
spec.time_granularity.is_custom_granularity,
spec.time_granularity.base_granularity.to_int(),
),
)

def _determine_time_spine_join_spec(
self, measure_properties: MeasureSpecProperties, required_time_spine_specs: Tuple[TimeDimensionSpec, ...]
) -> TimeDimensionSpec:
"""Determine the spec to join on for a time spine join.
Defaults to metric_time if it is included in the request, else the agg_time_dimension.
Will use the smallest available grain for the meeasure.
"""
join_spec_grain = ExpandedTimeGranularity.from_time_granularity(measure_properties.agg_time_dimension_grain)
join_on_time_dimension_spec = DataSet.metric_time_dimension_spec(time_granularity=join_spec_grain)
if not LinkableSpecSet(time_dimension_specs=required_time_spine_specs).contains_metric_time:
sample_agg_time_dimension_spec = required_time_spine_specs[0]
join_on_time_dimension_spec = sample_agg_time_dimension_spec.with_grain_and_date_part(
time_granularity=join_spec_grain, date_part=None
)
return join_on_time_dimension_spec
32 changes: 11 additions & 21 deletions metricflow/dataflow/nodes/join_to_time_spine.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,9 +8,7 @@
from dbt_semantic_interfaces.type_enums import TimeGranularity
from metricflow_semantics.dag.id_prefix import IdPrefix, StaticIdPrefix
from metricflow_semantics.dag.mf_dag import DisplayedProperty
from metricflow_semantics.filters.time_constraint import TimeRangeConstraint
from metricflow_semantics.specs.time_dimension_spec import TimeDimensionSpec
from metricflow_semantics.specs.where_filter.where_filter_spec import WhereFilterSpec
from metricflow_semantics.sql.sql_join_type import SqlJoinType
from metricflow_semantics.visitor import VisitorOutputT

Expand All @@ -25,17 +23,17 @@ class JoinToTimeSpineNode(DataflowPlanNode, ABC):
Attributes:
requested_agg_time_dimension_specs: Time dimensions requested in the query.
join_type: Join type to use when joining to time spine.
time_range_constraint: Time range to constrain the time spine to.
join_on_time_dimension_spec: The time dimension to use in the join ON condition.
offset_window: Time window to offset the parent dataset by when joining to time spine.
offset_to_grain: Granularity period to offset the parent dataset to when joining to time spine.
"""

time_spine_node: DataflowPlanNode
requested_agg_time_dimension_specs: Sequence[TimeDimensionSpec]
join_on_time_dimension_spec: TimeDimensionSpec
join_type: SqlJoinType
time_range_constraint: Optional[TimeRangeConstraint]
offset_window: Optional[MetricTimeWindow]
offset_to_grain: Optional[TimeGranularity]
time_spine_filters: Optional[Sequence[WhereFilterSpec]] = None

def __post_init__(self) -> None: # noqa: D105
super().__post_init__()
Expand All @@ -51,21 +49,21 @@ def __post_init__(self) -> None: # noqa: D105
@staticmethod
def create( # noqa: D102
parent_node: DataflowPlanNode,
time_spine_node: DataflowPlanNode,
requested_agg_time_dimension_specs: Sequence[TimeDimensionSpec],
join_on_time_dimension_spec: TimeDimensionSpec,
join_type: SqlJoinType,
time_range_constraint: Optional[TimeRangeConstraint] = None,
offset_window: Optional[MetricTimeWindow] = None,
offset_to_grain: Optional[TimeGranularity] = None,
time_spine_filters: Optional[Sequence[WhereFilterSpec]] = None,
) -> JoinToTimeSpineNode:
return JoinToTimeSpineNode(
parent_nodes=(parent_node,),
time_spine_node=time_spine_node,
requested_agg_time_dimension_specs=tuple(requested_agg_time_dimension_specs),
join_on_time_dimension_spec=join_on_time_dimension_spec,
join_type=join_type,
time_range_constraint=time_range_constraint,
offset_window=offset_window,
offset_to_grain=offset_to_grain,
time_spine_filters=time_spine_filters,
)

@classmethod
Expand All @@ -83,20 +81,13 @@ def description(self) -> str: # noqa: D102
def displayed_properties(self) -> Sequence[DisplayedProperty]: # noqa: D102
props = tuple(super().displayed_properties) + (
DisplayedProperty("requested_agg_time_dimension_specs", self.requested_agg_time_dimension_specs),
DisplayedProperty("join_on_time_dimension_spec", self.join_on_time_dimension_spec),
DisplayedProperty("join_type", self.join_type),
)
if self.offset_window:
props += (DisplayedProperty("offset_window", self.offset_window),)
if self.offset_to_grain:
props += (DisplayedProperty("offset_to_grain", self.offset_to_grain),)
if self.time_range_constraint:
props += (DisplayedProperty("time_range_constraint", self.time_range_constraint),)
if self.time_spine_filters:
props += (
DisplayedProperty(
"time_spine_filters", [time_spine_filter.where_sql for time_spine_filter in self.time_spine_filters]
),
)
return props

@property
Expand All @@ -106,22 +97,21 @@ def parent_node(self) -> DataflowPlanNode: # noqa: D102
def functionally_identical(self, other_node: DataflowPlanNode) -> bool: # noqa: D102
return (
isinstance(other_node, self.__class__)
and other_node.time_range_constraint == self.time_range_constraint
and other_node.offset_window == self.offset_window
and other_node.offset_to_grain == self.offset_to_grain
and other_node.requested_agg_time_dimension_specs == self.requested_agg_time_dimension_specs
and other_node.join_on_time_dimension_spec == self.join_on_time_dimension_spec
and other_node.join_type == self.join_type
and other_node.time_spine_filters == self.time_spine_filters
)

def with_new_parents(self, new_parent_nodes: Sequence[DataflowPlanNode]) -> JoinToTimeSpineNode: # noqa: D102
assert len(new_parent_nodes) == 1
return JoinToTimeSpineNode.create(
parent_node=new_parent_nodes[0],
time_spine_node=self.time_spine_node,
requested_agg_time_dimension_specs=self.requested_agg_time_dimension_specs,
time_range_constraint=self.time_range_constraint,
offset_window=self.offset_window,
offset_to_grain=self.offset_to_grain,
join_type=self.join_type,
time_spine_filters=self.time_spine_filters,
join_on_time_dimension_spec=self.join_on_time_dimension_spec,
)
Loading

0 comments on commit 7ec9c75

Please sign in to comment.