Skip to content

Commit

Permalink
Add new dataflow plan nodes for custom offset windows
Browse files Browse the repository at this point in the history
  • Loading branch information
courtneyholcomb committed Dec 18, 2024
1 parent 4765589 commit 65346f9
Show file tree
Hide file tree
Showing 8 changed files with 330 additions and 25 deletions.
2 changes: 2 additions & 0 deletions metricflow-semantics/metricflow_semantics/dag/id_prefix.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,8 @@ class StaticIdPrefix(IdPrefix, Enum, metaclass=EnumMetaClassHelper):
DATAFLOW_NODE_JOIN_CONVERSION_EVENTS_PREFIX = "jce"
DATAFLOW_NODE_WINDOW_REAGGREGATION_ID_PREFIX = "wr"
DATAFLOW_NODE_ALIAS_SPECS_ID_PREFIX = "as"
DATAFLOW_NODE_CUSTOM_GRANULARITY_BOUNDS_ID_PREFIX = "cgb"
DATAFLOW_NODE_OFFSET_BY_CUSTOMG_GRANULARITY_ID_PREFIX = "obcg"

SQL_EXPR_COLUMN_REFERENCE_ID_PREFIX = "cr"
SQL_EXPR_COMPARISON_ID_PREFIX = "cmp"
Expand Down
132 changes: 107 additions & 25 deletions metricflow/dataflow/builder/dataflow_plan_builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,7 @@
from metricflow_semantics.specs.where_filter.where_filter_spec import WhereFilterSpec
from metricflow_semantics.specs.where_filter.where_filter_spec_set import WhereFilterSpecSet
from metricflow_semantics.specs.where_filter.where_filter_transform import WhereSpecFactory
from metricflow_semantics.sql.sql_exprs import SqlWindowFunction
from metricflow_semantics.sql.sql_join_type import SqlJoinType
from metricflow_semantics.sql.sql_table import SqlTable
from metricflow_semantics.time.dateutil_adjuster import DateutilTimePeriodAdjuster
Expand Down Expand Up @@ -84,6 +85,7 @@
from metricflow.dataflow.nodes.combine_aggregated_outputs import CombineAggregatedOutputsNode
from metricflow.dataflow.nodes.compute_metrics import ComputeMetricsNode
from metricflow.dataflow.nodes.constrain_time import ConstrainTimeRangeNode
from metricflow.dataflow.nodes.custom_granularity_bounds import CustomGranularityBoundsNode
from metricflow.dataflow.nodes.filter_elements import FilterElementsNode
from metricflow.dataflow.nodes.join_conversion_events import JoinConversionEventsNode
from metricflow.dataflow.nodes.join_over_time import JoinOverTimeRangeNode
Expand All @@ -92,6 +94,7 @@
from metricflow.dataflow.nodes.join_to_time_spine import JoinToTimeSpineNode
from metricflow.dataflow.nodes.metric_time_transform import MetricTimeDimensionTransformNode
from metricflow.dataflow.nodes.min_max import MinMaxNode
from metricflow.dataflow.nodes.offset_by_custom_granularity import OffsetByCustomGranularityNode
from metricflow.dataflow.nodes.order_by_limit import OrderByLimitNode
from metricflow.dataflow.nodes.read_sql_source import ReadSqlSourceNode
from metricflow.dataflow.nodes.semi_additive_join import SemiAdditiveJoinNode
Expand Down Expand Up @@ -658,13 +661,22 @@ def _build_derived_metric_output_node(
)
if metric_spec.has_time_offset and queried_agg_time_dimension_specs:
# TODO: move this to a helper method
time_spine_node = self._build_time_spine_node(queried_agg_time_dimension_specs)
time_spine_node = self._build_time_spine_node(
queried_time_spine_specs=queried_agg_time_dimension_specs,
offset_window=metric_spec.offset_window,
)
output_node = JoinToTimeSpineNode.create(
metric_source_node=output_node,
time_spine_node=time_spine_node,
requested_agg_time_dimension_specs=queried_agg_time_dimension_specs,
join_on_time_dimension_spec=self._sort_by_base_granularity(queried_agg_time_dimension_specs)[0],
offset_window=metric_spec.offset_window,
offset_window=(
metric_spec.offset_window
if metric_spec.offset_window
and metric_spec.offset_window.granularity
not in self._semantic_model_lookup.custom_granularity_names
else None
),
offset_to_grain=metric_spec.offset_to_grain,
join_type=SqlJoinType.INNER,
)
Expand Down Expand Up @@ -1648,14 +1660,25 @@ def _build_aggregated_measure_from_measure_source_node(
join_on_time_dimension_spec = self._determine_time_spine_join_spec(
measure_properties=measure_properties, required_time_spine_specs=base_queried_agg_time_dimension_specs
)
required_time_spine_specs = (join_on_time_dimension_spec,) + base_queried_agg_time_dimension_specs
time_spine_node = self._build_time_spine_node(required_time_spine_specs)
required_time_spine_specs = base_queried_agg_time_dimension_specs
if join_on_time_dimension_spec not in required_time_spine_specs:
required_time_spine_specs = (join_on_time_dimension_spec,) + required_time_spine_specs
time_spine_node = self._build_time_spine_node(
queried_time_spine_specs=required_time_spine_specs,
offset_window=before_aggregation_time_spine_join_description.offset_window,
)
unaggregated_measure_node = JoinToTimeSpineNode.create(
metric_source_node=unaggregated_measure_node,
time_spine_node=time_spine_node,
requested_agg_time_dimension_specs=base_queried_agg_time_dimension_specs,
join_on_time_dimension_spec=join_on_time_dimension_spec,
offset_window=before_aggregation_time_spine_join_description.offset_window,
offset_window=(
before_aggregation_time_spine_join_description.offset_window
if before_aggregation_time_spine_join_description.offset_window
and before_aggregation_time_spine_join_description.offset_window.granularity
not in self._semantic_model_lookup.custom_granularity_names
else None
),
offset_to_grain=before_aggregation_time_spine_join_description.offset_to_grain,
join_type=before_aggregation_time_spine_join_description.join_type,
)
Expand Down Expand Up @@ -1862,6 +1885,7 @@ def _build_time_spine_node(
queried_time_spine_specs: Sequence[TimeDimensionSpec],
where_filter_specs: Sequence[WhereFilterSpec] = (),
time_range_constraint: Optional[TimeRangeConstraint] = None,
offset_window: Optional[MetricTimeWindow] = None,
) -> DataflowPlanNode:
"""Return the time spine node needed to satisfy the specs."""
required_time_spine_spec_set = self.__get_required_linkable_specs(
Expand All @@ -1870,28 +1894,86 @@ def _build_time_spine_node(
)
required_time_spine_specs = required_time_spine_spec_set.time_dimension_specs

# TODO: support multiple time spines here. Build node on the one with the smallest base grain.
# Then, pass custom_granularity_specs into _build_pre_aggregation_plan if they aren't satisfied by smallest time spine.
time_spine_source = self._choose_time_spine_source(required_time_spine_specs)
read_node = self._choose_time_spine_read_node(time_spine_source)
time_spine_data_set = self._node_data_set_resolver.get_output_data_set(read_node)

# Change the column aliases to match the specs that were requested in the query.
time_spine_node = AliasSpecsNode.create(
parent_node=read_node,
change_specs=tuple(
SpecToAlias(
input_spec=time_spine_data_set.instance_from_time_dimension_grain_and_date_part(required_spec).spec,
output_spec=required_spec,
should_dedupe = False
if offset_window and offset_window.granularity in self._semantic_model_lookup._custom_granularities:
# Are sets the right choice here?
all_queried_grains: Set[ExpandedTimeGranularity] = set()
queried_custom_specs: Tuple[TimeDimensionSpec, ...] = ()
queried_standard_specs: Tuple[TimeDimensionSpec, ...] = ()
for spec in queried_time_spine_specs:
all_queried_grains.add(spec.time_granularity)
if spec.time_granularity.is_custom_granularity:
queried_custom_specs += (spec,)
else:
queried_standard_specs += (spec,)

custom_grain = self._semantic_model_lookup._custom_granularities[offset_window.granularity]
time_spine_source = self._choose_time_spine_source((DataSet.metric_time_dimension_spec(custom_grain),))
time_spine_read_node = self._choose_time_spine_read_node(time_spine_source)
# TODO: make sure this is checking the correct granularity type once DSI is updated
if {spec.time_granularity for spec in queried_time_spine_specs} == {custom_grain}:
# If querying with only the same grain as is used in the offset_window, can use a simpler plan.
# offset_node = OffsetCustomGranularityNode.create(
# parent_node=time_spine_read_node, offset_window=offset_window
# )
# time_spine_node: DataflowPlanNode = JoinToTimeSpineNode.create(
# metric_source_node=offset_node,
# # TODO: need to make sure we apply both agg time and metric time
# requested_agg_time_dimension_specs=queried_time_spine_specs,
# time_spine_node=time_spine_read_node,
# join_type=SqlJoinType.INNER,
# join_on_time_dimension_spec=custom_grain_metric_time_spec,
# )
pass
else:
bounds_node = CustomGranularityBoundsNode.create(
parent_node=time_spine_read_node,
custom_granularity_name=custom_grain.name,
)
for required_spec in required_time_spine_specs
),
)
bounds_data_set = self._node_data_set_resolver.get_output_data_set(bounds_node)
bounds_specs = tuple(
bounds_data_set.instance_from_window_function(window_func).spec
for window_func in (SqlWindowFunction.FIRST_VALUE, SqlWindowFunction.LAST_VALUE)
)
custom_grain_spec = bounds_data_set.instance_from_time_dimension_grain_and_date_part(
time_granularity_name=custom_grain.name, date_part=None
).spec
filter_elements_node = FilterElementsNode.create(
parent_node=bounds_node,
include_specs=InstanceSpecSet(time_dimension_specs=(custom_grain_spec,) + bounds_specs),
distinct=True,
)
time_spine_node: DataflowPlanNode = OffsetByCustomGranularityNode.create(
custom_granularity_bounds_node=bounds_node,
filter_elements_node=filter_elements_node,
offset_window=offset_window,
required_time_spine_specs=required_time_spine_specs,
)
else:
# TODO: support multiple time spines here. Build node on the one with the smallest base grain.
# Then, pass custom_granularity_specs into _build_pre_aggregation_plan if they aren't satisfied by smallest time spine.
time_spine_source = self._choose_time_spine_source(required_time_spine_specs)
read_node = self._choose_time_spine_read_node(time_spine_source)
time_spine_data_set = self._node_data_set_resolver.get_output_data_set(read_node)

# Change the column aliases to match the specs that were requested in the query.
time_spine_node = AliasSpecsNode.create(
parent_node=read_node,
change_specs=tuple(
SpecToAlias(
input_spec=time_spine_data_set.instance_from_time_dimension_grain_and_date_part(
time_granularity_name=required_spec.time_granularity.name, date_part=required_spec.date_part
).spec,
output_spec=required_spec,
)
for required_spec in required_time_spine_specs
),
)

# If the base grain of the time spine isn't selected, it will have duplicate rows that need deduping.
should_dedupe = ExpandedTimeGranularity.from_time_granularity(time_spine_source.base_granularity) not in {
spec.time_granularity for spec in queried_time_spine_specs
}
# If the base grain of the time spine isn't selected, it will have duplicate rows that need deduping.
should_dedupe = ExpandedTimeGranularity.from_time_granularity(time_spine_source.base_granularity) not in {
spec.time_granularity for spec in queried_time_spine_specs
}

return self._build_pre_aggregation_plan(
source_node=time_spine_node,
Expand Down
22 changes: 22 additions & 0 deletions metricflow/dataflow/dataflow_plan_visitor.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
from metricflow.dataflow.nodes.combine_aggregated_outputs import CombineAggregatedOutputsNode
from metricflow.dataflow.nodes.compute_metrics import ComputeMetricsNode
from metricflow.dataflow.nodes.constrain_time import ConstrainTimeRangeNode
from metricflow.dataflow.nodes.custom_granularity_bounds import CustomGranularityBoundsNode
from metricflow.dataflow.nodes.filter_elements import FilterElementsNode
from metricflow.dataflow.nodes.join_conversion_events import JoinConversionEventsNode
from metricflow.dataflow.nodes.join_over_time import JoinOverTimeRangeNode
Expand All @@ -23,6 +24,7 @@
from metricflow.dataflow.nodes.join_to_time_spine import JoinToTimeSpineNode
from metricflow.dataflow.nodes.metric_time_transform import MetricTimeDimensionTransformNode
from metricflow.dataflow.nodes.min_max import MinMaxNode
from metricflow.dataflow.nodes.offset_by_custom_granularity import OffsetByCustomGranularityNode
from metricflow.dataflow.nodes.order_by_limit import OrderByLimitNode
from metricflow.dataflow.nodes.read_sql_source import ReadSqlSourceNode
from metricflow.dataflow.nodes.semi_additive_join import SemiAdditiveJoinNode
Expand Down Expand Up @@ -126,6 +128,16 @@ def visit_join_to_custom_granularity_node(self, node: JoinToCustomGranularityNod
def visit_alias_specs_node(self, node: AliasSpecsNode) -> VisitorOutputT: # noqa: D102
raise NotImplementedError

@abstractmethod
def visit_custom_granularity_bounds_node(self, node: CustomGranularityBoundsNode) -> VisitorOutputT: # noqa: D102
raise NotImplementedError

@abstractmethod
def visit_offset_by_custom_granularity_node( # noqa: D102
self, node: OffsetByCustomGranularityNode
) -> VisitorOutputT:
raise NotImplementedError


class DataflowPlanNodeVisitorWithDefaultHandler(DataflowPlanNodeVisitor[VisitorOutputT], Generic[VisitorOutputT]):
"""Similar to `DataflowPlanNodeVisitor`, but with an abstract default handler that gets called for each node.
Expand Down Expand Up @@ -222,3 +234,13 @@ def visit_join_to_custom_granularity_node(self, node: JoinToCustomGranularityNod
@override
def visit_alias_specs_node(self, node: AliasSpecsNode) -> VisitorOutputT: # noqa: D102
return self._default_handler(node)

@override
def visit_custom_granularity_bounds_node(self, node: CustomGranularityBoundsNode) -> VisitorOutputT: # noqa: D102
return self._default_handler(node)

@override
def visit_offset_by_custom_granularity_node( # noqa: D102
self, node: OffsetByCustomGranularityNode
) -> VisitorOutputT:
return self._default_handler(node)
64 changes: 64 additions & 0 deletions metricflow/dataflow/nodes/custom_granularity_bounds.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,64 @@
from __future__ import annotations

from abc import ABC
from dataclasses import dataclass
from typing import Sequence

from metricflow_semantics.dag.id_prefix import IdPrefix, StaticIdPrefix
from metricflow_semantics.dag.mf_dag import DisplayedProperty
from metricflow_semantics.visitor import VisitorOutputT

from metricflow.dataflow.dataflow_plan import DataflowPlanNode
from metricflow.dataflow.dataflow_plan_visitor import DataflowPlanNodeVisitor


@dataclass(frozen=True, eq=False)
class CustomGranularityBoundsNode(DataflowPlanNode, ABC):
"""Calculate the start and end of a custom granularity period and each row number within that period."""

custom_granularity_name: str

def __post_init__(self) -> None: # noqa: D105
super().__post_init__()
assert len(self.parent_nodes) == 1

@staticmethod
def create( # noqa: D102
parent_node: DataflowPlanNode, custom_granularity_name: str
) -> CustomGranularityBoundsNode:
return CustomGranularityBoundsNode(parent_nodes=(parent_node,), custom_granularity_name=custom_granularity_name)

@classmethod
def id_prefix(cls) -> IdPrefix: # noqa: D102
return StaticIdPrefix.DATAFLOW_NODE_CUSTOM_GRANULARITY_BOUNDS_ID_PREFIX

def accept(self, visitor: DataflowPlanNodeVisitor[VisitorOutputT]) -> VisitorOutputT: # noqa: D102
return visitor.visit_custom_granularity_bounds_node(self)

@property
def description(self) -> str: # noqa: D102
return """Calculate Custom Granularity Bounds"""

@property
def displayed_properties(self) -> Sequence[DisplayedProperty]: # noqa: D102
return tuple(super().displayed_properties) + (
DisplayedProperty("custom_granularity_name", self.custom_granularity_name),
)

@property
def parent_node(self) -> DataflowPlanNode: # noqa: D102
return self.parent_nodes[0]

def functionally_identical(self, other_node: DataflowPlanNode) -> bool: # noqa: D102
return (
isinstance(other_node, self.__class__)
and other_node.custom_granularity_name == self.custom_granularity_name
)

def with_new_parents( # noqa: D102
self, new_parent_nodes: Sequence[DataflowPlanNode]
) -> CustomGranularityBoundsNode:
assert len(new_parent_nodes) == 1
return CustomGranularityBoundsNode.create(
parent_node=new_parent_nodes[0], custom_granularity_name=self.custom_granularity_name
)
Loading

0 comments on commit 65346f9

Please sign in to comment.