Commit dcd97f6: WIP - reviewing snapshots
courtneyholcomb committed Nov 22, 2024
1 parent cb4cc61 commit dcd97f6
Showing 150 changed files with 6,799 additions and 4,945 deletions.
@@ -246,14 +246,14 @@ def with_entity_prefix(self, entity_prefix: EntityReference) -> TimeDimensionSpe
         )
 
     @staticmethod
-    def with_base_grains(time_dimension_specs: Sequence[TimeDimensionSpec]) -> List[TimeDimensionSpec]:
+    def with_base_grains(time_dimension_specs: Sequence[TimeDimensionSpec]) -> Tuple[TimeDimensionSpec, ...]:
         """Return the list of time dimension specs, replacing any custom grains with base grains.
 
         Dedupes new specs, but preserves the initial order.
         """
-        base_grain_specs: List[TimeDimensionSpec] = []
+        base_grain_specs: Tuple[TimeDimensionSpec, ...] = ()
         for spec in time_dimension_specs:
             base_grain_spec = spec.with_base_grain()
             if base_grain_spec not in base_grain_specs:
-                base_grain_specs.append(base_grain_spec)
+                base_grain_specs += (base_grain_spec,)
         return base_grain_specs
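A minimal sketch of the dedupe-into-a-tuple pattern above, in plain Python rather than MetricFlow types; the BASE_GRAIN mapping and the grain names are hypothetical stand-ins for spec.with_base_grain():

    # Dedupe while preserving first-seen order, accumulating into a tuple.
    BASE_GRAIN = {"martian_week": "day", "fiscal_year": "month"}  # hypothetical custom grains

    def with_base_grains(grains: tuple) -> tuple:
        base_grain_specs: tuple = ()
        for grain in grains:
            base = BASE_GRAIN.get(grain, grain)  # stand-in for spec.with_base_grain()
            if base not in base_grain_specs:
                base_grain_specs += (base,)
        return base_grain_specs

    assert with_base_grains(("day", "martian_week", "month")) == ("day", "month")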
138 changes: 110 additions & 28 deletions metricflow/dataflow/builder/dataflow_plan_builder.py
@@ -1,5 +1,6 @@
 from __future__ import annotations
 
+import copy
 import logging
 import time
 from typing import Dict, FrozenSet, List, Optional, Sequence, Set, Tuple, Union
@@ -87,9 +88,12 @@
 from metricflow.dataflow.nodes.join_to_base import JoinDescription, JoinOnEntitiesNode
 from metricflow.dataflow.nodes.join_to_custom_granularity import JoinToCustomGranularityNode
 from metricflow.dataflow.nodes.join_to_time_spine import JoinToTimeSpineNode
+from metricflow.dataflow.nodes.metric_time_transform import MetricTimeDimensionTransformNode
 from metricflow.dataflow.nodes.min_max import MinMaxNode
 from metricflow.dataflow.nodes.order_by_limit import OrderByLimitNode
+from metricflow.dataflow.nodes.read_sql_source import ReadSqlSourceNode
 from metricflow.dataflow.nodes.semi_additive_join import SemiAdditiveJoinNode
+from metricflow.dataflow.nodes.transform_time_dimensions import TransformTimeDimensionsNode
 from metricflow.dataflow.nodes.where_filter import WhereConstraintNode
 from metricflow.dataflow.nodes.window_reaggregation_node import WindowReaggregationNode
 from metricflow.dataflow.nodes.write_to_data_table import WriteToResultDataTableNode
@@ -642,13 +646,16 @@ def _build_derived_metric_output_node(
 
         # For ratio / derived metrics with time offset, apply offset join here. Constraints will be applied after the offset
         # to avoid filtering out values that will be changed.
-        if metric_spec.has_time_offset:
-            queried_agg_time_dimension_specs = queried_linkable_specs.included_agg_time_dimension_specs_for_metric(
-                metric_reference=metric_spec.reference, metric_lookup=self._metric_lookup
-            )
+        queried_agg_time_dimension_specs = queried_linkable_specs.included_agg_time_dimension_specs_for_metric(
+            metric_reference=metric_spec.reference, metric_lookup=self._metric_lookup
+        )
+        if metric_spec.has_time_offset and queried_agg_time_dimension_specs:
+            time_spine_node = self._build_time_spine_node(queried_agg_time_dimension_specs)
             output_node = JoinToTimeSpineNode.create(
                 parent_node=output_node,
+                time_spine_node=time_spine_node,
                 requested_agg_time_dimension_specs=queried_agg_time_dimension_specs,
+                join_on_time_dimension_spec=self._sort_by_base_granularity(queried_agg_time_dimension_specs)[0],
                 offset_window=metric_spec.offset_window,
                 offset_to_grain=metric_spec.offset_to_grain,
                 join_type=SqlJoinType.INNER,
@@ -677,7 +684,7 @@ def _build_derived_metric_output_node(
 
     def _get_base_agg_time_dimensions(
         self, queried_linkable_specs: LinkableSpecSet, measure_reference: MeasureReference
-    ) -> Sequence[TimeDimensionSpec]:
+    ) -> Tuple[TimeDimensionSpec, ...]:
         """Get queried agg_time_dimensions with their base grains, deduped.
 
         Custom grains are joined right before measure aggregation and after all other pre-aggregation joins,
@@ -1037,8 +1044,7 @@ def _find_source_node_recipe_non_cached(
         )
         # If metric_time is requested without metrics, choose appropriate time spine node to select those values from.
         if linkable_specs_to_satisfy.metric_time_specs:
-            time_spine_source = self._choose_time_spine_source(linkable_specs_to_satisfy.metric_time_specs)
-            time_spine_node = self._source_node_set.time_spine_metric_time_nodes[time_spine_source.base_granularity]
+            time_spine_node = self._choose_time_spine_metric_time_node(linkable_specs_to_satisfy.metric_time_specs)
             candidate_nodes_for_right_side_of_join += [time_spine_node]
             candidate_nodes_for_left_side_of_join += [time_spine_node]
             default_join_type = SqlJoinType.FULL_OUTER
@@ -1591,17 +1597,17 @@ def _build_aggregated_measure_from_measure_source_node(
                 f"Unable to join all items in request. Measure: {measure_spec.element_name}; Specs to join: {required_linkable_specs}"
             )
 
-        base_agg_time_dimension_specs = self._get_base_agg_time_dimensions(
+        base_queried_agg_time_dimension_specs = self._get_base_agg_time_dimensions(
             queried_linkable_specs=queried_linkable_specs, measure_reference=measure_spec.reference
         )
 
         # If a cumulative metric is queried with metric_time / agg_time_dimension, join over time range.
         # Otherwise, the measure will be aggregated over all time.
         unaggregated_measure_node: DataflowPlanNode = measure_recipe.source_node
-        if cumulative and base_agg_time_dimension_specs:
+        if cumulative and base_queried_agg_time_dimension_specs:
             unaggregated_measure_node = JoinOverTimeRangeNode.create(
                 parent_node=unaggregated_measure_node,
-                queried_agg_time_dimension_specs=tuple(base_agg_time_dimension_specs),
+                queried_agg_time_dimension_specs=tuple(base_queried_agg_time_dimension_specs),
                 window=cumulative_window,
                 grain_to_date=cumulative_grain_to_date,
                 # Note: we use the original constraint here because the JoinOverTimeRangeNode will eventually get
@@ -1614,16 +1620,32 @@ def _build_aggregated_measure_from_measure_source_node(
             )
 
         # If querying an offset metric, join to time spine before aggregation.
-        if before_aggregation_time_spine_join_description is not None:
+        if before_aggregation_time_spine_join_description and base_queried_agg_time_dimension_specs:
+            # TODO: move all of this to a helper function
             assert before_aggregation_time_spine_join_description.join_type is SqlJoinType.INNER, (
                 f"Expected {SqlJoinType.INNER} for joining to time spine before aggregation. Remove this if there's a "
                 f"new use case."
             )
-            # This also uses the original time range constraint due to the application of the time window intervals
-            # in join rendering
+
+            # Determine the spec to join on. Defaults to metric_time if it is included in the request, else agg_time.
+            # Will use the smallest available grain for the measure.
+            required_time_spine_specs = base_queried_agg_time_dimension_specs
+            join_spec_grain = ExpandedTimeGranularity.from_time_granularity(measure_properties.agg_time_dimension_grain)
+            join_on_time_dimension_spec = DataSet.metric_time_dimension_spec(time_granularity=join_spec_grain)
+            if not LinkableSpecSet(time_dimension_specs=required_time_spine_specs).contains_metric_time:
+                sample_agg_time_dimension_spec = required_time_spine_specs[0]
+                join_on_time_dimension_spec = sample_agg_time_dimension_spec.with_grain_and_date_part(
+                    time_granularity=join_spec_grain, date_part=None
+                )
+            required_time_spine_specs += (join_on_time_dimension_spec,)
+            time_spine_node = self._build_time_spine_node(required_time_spine_specs)
+
+            # Join time spine to source node.
             unaggregated_measure_node = JoinToTimeSpineNode.create(
                 parent_node=unaggregated_measure_node,
-                requested_agg_time_dimension_specs=base_agg_time_dimension_specs,
+                time_spine_node=time_spine_node,
+                requested_agg_time_dimension_specs=base_queried_agg_time_dimension_specs,
+                join_on_time_dimension_spec=join_on_time_dimension_spec,
                 offset_window=before_aggregation_time_spine_join_description.offset_window,
                 offset_to_grain=before_aggregation_time_spine_join_description.offset_to_grain,
                 join_type=before_aggregation_time_spine_join_description.join_type,
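A sketch of the join-spec selection logic above, using (name, grain) tuples in place of TimeDimensionSpecs; the dimension name booking__ds is hypothetical:

    # Join on metric_time at the measure's grain when metric_time is queried;
    # otherwise re-grain the first queried agg_time dimension.
    def choose_join_on_spec(queried_specs: tuple, measure_grain: str) -> tuple:
        if any(name == "metric_time" for name, _ in queried_specs):
            return ("metric_time", measure_grain)
        sample_name, _ = queried_specs[0]
        return (sample_name, measure_grain)

    assert choose_join_on_spec((("metric_time", "month"),), "day") == ("metric_time", "day")
    assert choose_join_on_spec((("booking__ds", "month"),), "day") == ("booking__ds", "day")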
@@ -1662,11 +1684,17 @@ def _build_aggregated_measure_from_measure_source_node(
         after_aggregation_time_spine_join_description = (
             metric_input_measure_spec.after_aggregation_time_spine_join_description
         )
-        if after_aggregation_time_spine_join_description is not None:
+        queried_agg_time_dimension_specs = queried_linkable_specs.included_agg_time_dimension_specs_for_measure(
+            measure_reference=measure_spec.reference, semantic_model_lookup=self._semantic_model_lookup
+        )
+        if after_aggregation_time_spine_join_description and queried_agg_time_dimension_specs:
+            # TODO: move all of this to a helper function
             assert after_aggregation_time_spine_join_description.join_type is SqlJoinType.LEFT_OUTER, (
                 f"Expected {SqlJoinType.LEFT_OUTER} for joining to time spine after aggregation. Remove this if "
                 f"there's a new use case."
             )
+            time_spine_required_specs = copy.deepcopy(queried_agg_time_dimension_specs)
+
             # Find filters that contain only metric_time or agg_time_dimension. They will be applied to the time spine table.
             agg_time_only_filters: List[WhereFilterSpec] = []
             non_agg_time_filters: List[WhereFilterSpec] = []
@@ -1676,27 +1704,23 @@ def _build_aggregated_measure_from_measure_source_node(
                 )
             if set(included_agg_time_specs) == set(filter_spec.linkable_spec_set.as_tuple):
                 agg_time_only_filters.append(filter_spec)
-                if filter_spec.linkable_spec_set.time_dimension_specs_with_custom_grain:
-                    raise ValueError(
-                        "Using custom granularity in filters for `join_to_timespine` metrics is not yet fully supported. "
-                        "This feature is coming soon!"
-                    )
-                for agg_time_spec in included_agg_time_specs:
-                    if agg_time_spec not in time_spine_required_specs:
-                        time_spine_required_specs.append(agg_time_spec)
             else:
                 non_agg_time_filters.append(filter_spec)
 
+            # TODO: split this node into TimeSpineSourceNode and JoinToTimeSpineNode - then can use standard nodes here
+            # like JoinToCustomGranularityNode, WhereConstraintNode, etc.
-            queried_agg_time_dimension_specs = queried_linkable_specs.included_agg_time_dimension_specs_for_measure(
-                measure_reference=measure_spec.reference, semantic_model_lookup=self._semantic_model_lookup
-            )
+            time_spine_node = self._build_time_spine_node(
+                queried_time_spine_specs=queried_agg_time_dimension_specs,
+                time_range_constraint=predicate_pushdown_state.time_range_constraint,
+                where_filter_specs=agg_time_only_filters,
+            )
             output_node: DataflowPlanNode = JoinToTimeSpineNode.create(
                 parent_node=aggregate_measures_node,
+                time_spine_node=time_spine_node,
                 requested_agg_time_dimension_specs=queried_agg_time_dimension_specs,
+                join_on_time_dimension_spec=self._sort_by_base_granularity(queried_agg_time_dimension_specs)[0],
                 join_type=after_aggregation_time_spine_join_description.join_type,
-                time_range_constraint=predicate_pushdown_state.time_range_constraint,
                 offset_window=after_aggregation_time_spine_join_description.offset_window,
                 offset_to_grain=after_aggregation_time_spine_join_description.offset_to_grain,
-                time_spine_filters=agg_time_only_filters,
             )
 
         # Since new rows might have been added due to time spine join, re-apply constraints here. Only re-apply filters
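A sketch of the filter partition above, with plain sets standing in for linkable spec sets; the filter strings and spec names are hypothetical:

    # A filter is applied to the time spine only if every spec it references is an agg-time spec.
    filters = {
        "metric_time__day >= '2020-01-01'": {"metric_time__day"},
        "booking__is_instant": {"booking__is_instant"},
    }
    agg_time_specs = {"metric_time__day"}

    agg_time_only = [sql for sql, specs in filters.items() if specs <= agg_time_specs]
    non_agg_time = [sql for sql, specs in filters.items() if not specs <= agg_time_specs]
    assert agg_time_only == ["metric_time__day >= '2020-01-01'"]
    assert non_agg_time == ["booking__is_instant"]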
@@ -1812,3 +1836,61 @@ def _choose_time_spine_source(self, required_time_spine_specs: Sequence[TimeDime
             required_time_spine_specs=required_time_spine_specs,
             time_spine_sources=self._source_node_builder.time_spine_sources,
         )
+
+    def _choose_time_spine_metric_time_node(
+        self, required_time_spine_specs: Sequence[TimeDimensionSpec]
+    ) -> MetricTimeDimensionTransformNode:
+        """Return the MetricTimeDimensionTransform time spine node needed to satisfy the specs."""
+        time_spine_source = self._choose_time_spine_source(required_time_spine_specs)
+        return self._source_node_set.time_spine_metric_time_nodes[time_spine_source.base_granularity]
+
+    def _choose_time_spine_read_node(self, time_spine_source: TimeSpineSource) -> ReadSqlSourceNode:
+        """Return the ReadSqlSource time spine node for the given time spine source."""
+        return self._source_node_set.time_spine_read_nodes[time_spine_source.base_granularity.value]
+
+    def _build_time_spine_node(
+        self,
+        queried_time_spine_specs: Sequence[TimeDimensionSpec],
+        where_filter_specs: Sequence[WhereFilterSpec] = (),
+        time_range_constraint: Optional[TimeRangeConstraint] = None,
+    ) -> DataflowPlanNode:
+        """Return the time spine node needed to satisfy the specs."""
+        required_time_spine_spec_set = self.__get_required_linkable_specs(
+            queried_linkable_specs=LinkableSpecSet(time_dimension_specs=tuple(queried_time_spine_specs)),
+            filter_specs=where_filter_specs,
+        )
+        required_time_spine_specs = required_time_spine_spec_set.time_dimension_specs
+
+        time_spine_source = self._choose_time_spine_source(required_time_spine_specs)
+        time_spine_node = TransformTimeDimensionsNode.create(
+            parent_node=self._choose_time_spine_read_node(time_spine_source),
+            requested_time_dimension_specs=required_time_spine_specs,
+        )
+
+        # If the base grain of the time spine isn't selected, it will have duplicate rows that need deduping.
+        should_dedupe = ExpandedTimeGranularity.from_time_granularity(time_spine_source.base_granularity) not in {
+            spec.time_granularity for spec in queried_time_spine_specs
+        }
+
+        return self._build_pre_aggregation_plan(
+            source_node=time_spine_node,
+            filter_to_specs=InstanceSpecSet(time_dimension_specs=tuple(queried_time_spine_specs)),
+            custom_granularity_specs=required_time_spine_spec_set.time_dimension_specs_with_custom_grain,
+            time_range_constraint=time_range_constraint,
+            where_filter_specs=where_filter_specs,
+            distinct=should_dedupe,
+        )
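A sketch of the dedupe decision above: when no queried spec is at the spine's base grain, each projected row repeats (for example, a day-grain spine projected only at month grain), so the pre-aggregation plan selects distinct rows. Grain names are plain strings here:

    def should_dedupe(base_grain: str, queried_grains: set) -> bool:
        return base_grain not in queried_grains

    assert should_dedupe("day", {"month"}) is True
    assert should_dedupe("day", {"day", "month"}) is False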

+    def _sort_by_base_granularity(self, time_dimension_specs: Sequence[TimeDimensionSpec]) -> List[TimeDimensionSpec]:
+        """Sort the time dimensions by their base granularity.
+
+        Specs with date part will come after specs without it. Standard grains will come before custom.
+        """
+        return sorted(
+            time_dimension_specs,
+            key=lambda spec: (
+                spec.date_part is not None,
+                spec.time_granularity.is_custom_granularity,
+                spec.time_granularity.base_granularity.to_int(),
+            ),
+        )
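A sketch of the sort key above, with tuples of (name, has_date_part, is_custom_granularity, base_grain_rank) standing in for TimeDimensionSpecs; smaller key tuples sort first, and the names are hypothetical:

    specs = [
        ("ds__extract_dow", True, False, 1),   # has a date part, sorts last
        ("martian_week", False, True, 1),      # custom grain, sorts after standard grains
        ("metric_time__year", False, False, 4),
        ("metric_time__day", False, False, 1),
    ]
    ordered = sorted(specs, key=lambda s: (s[1], s[2], s[3]))
    assert [name for name, *_ in ordered] == [
        "metric_time__day", "metric_time__year", "martian_week", "ds__extract_dow"
    ]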
1 change: 1 addition & 0 deletions metricflow/dataflow/builder/source_node.py
@@ -36,6 +36,7 @@ class SourceNodeSet:
     # Semantic models are 1:1 mapped to a ReadSqlSourceNode.
     source_nodes_for_group_by_item_queries: Tuple[DataflowPlanNode, ...]
 
+    # TODO: maybe this didn't need to have string keys, check later
     # Provides time spines that can be used to satisfy time spine joins, organized by granularity name.
     time_spine_read_nodes: Mapping[str, ReadSqlSourceNode]
 