Skip to content

Commit

Permalink
Add new dataflow plan nodes for custom offset windows
Browse files Browse the repository at this point in the history
  • Loading branch information
courtneyholcomb committed Dec 18, 2024
1 parent 74c5cf7 commit f267e57
Show file tree
Hide file tree
Showing 13 changed files with 576 additions and 11 deletions.
2 changes: 2 additions & 0 deletions metricflow-semantics/metricflow_semantics/dag/id_prefix.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,8 @@ class StaticIdPrefix(IdPrefix, Enum, metaclass=EnumMetaClassHelper):
DATAFLOW_NODE_JOIN_CONVERSION_EVENTS_PREFIX = "jce"
DATAFLOW_NODE_WINDOW_REAGGREGATION_ID_PREFIX = "wr"
DATAFLOW_NODE_ALIAS_SPECS_ID_PREFIX = "as"
DATAFLOW_NODE_CUSTOM_GRANULARITY_BOUNDS_ID_PREFIX = "cgb"
DATAFLOW_NODE_OFFSET_BY_CUSTOMG_GRANULARITY_ID_PREFIX = "obcg"

SQL_EXPR_COLUMN_REFERENCE_ID_PREFIX = "cr"
SQL_EXPR_COMPARISON_ID_PREFIX = "cmp"
Expand Down
8 changes: 4 additions & 4 deletions metricflow/dataflow/builder/dataflow_plan_builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -1648,9 +1648,7 @@ def _build_aggregated_measure_from_measure_source_node(
join_on_time_dimension_spec = self._determine_time_spine_join_spec(
measure_properties=measure_properties, required_time_spine_specs=base_queried_agg_time_dimension_specs
)
required_time_spine_specs = base_queried_agg_time_dimension_specs
if join_on_time_dimension_spec not in required_time_spine_specs:
required_time_spine_specs = (join_on_time_dimension_spec,) + required_time_spine_specs
required_time_spine_specs = (join_on_time_dimension_spec,) + base_queried_agg_time_dimension_specs
time_spine_node = self._build_time_spine_node(required_time_spine_specs)
unaggregated_measure_node = JoinToTimeSpineNode.create(
metric_source_node=unaggregated_measure_node,
Expand Down Expand Up @@ -1883,7 +1881,9 @@ def _build_time_spine_node(
parent_node=read_node,
change_specs=tuple(
SpecToAlias(
input_spec=time_spine_data_set.instance_from_time_dimension_grain_and_date_part(required_spec).spec,
input_spec=time_spine_data_set.instance_from_time_dimension_grain_and_date_part(
time_granularity_name=required_spec.time_granularity.name, date_part=required_spec.date_part
).spec,
output_spec=required_spec,
)
for required_spec in required_time_spine_specs
Expand Down
22 changes: 22 additions & 0 deletions metricflow/dataflow/dataflow_plan_visitor.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
from metricflow.dataflow.nodes.combine_aggregated_outputs import CombineAggregatedOutputsNode
from metricflow.dataflow.nodes.compute_metrics import ComputeMetricsNode
from metricflow.dataflow.nodes.constrain_time import ConstrainTimeRangeNode
from metricflow.dataflow.nodes.custom_granularity_bounds import CustomGranularityBoundsNode
from metricflow.dataflow.nodes.filter_elements import FilterElementsNode
from metricflow.dataflow.nodes.join_conversion_events import JoinConversionEventsNode
from metricflow.dataflow.nodes.join_over_time import JoinOverTimeRangeNode
Expand All @@ -23,6 +24,7 @@
from metricflow.dataflow.nodes.join_to_time_spine import JoinToTimeSpineNode
from metricflow.dataflow.nodes.metric_time_transform import MetricTimeDimensionTransformNode
from metricflow.dataflow.nodes.min_max import MinMaxNode
from metricflow.dataflow.nodes.offset_by_custom_granularity import OffsetByCustomGranularityNode
from metricflow.dataflow.nodes.order_by_limit import OrderByLimitNode
from metricflow.dataflow.nodes.read_sql_source import ReadSqlSourceNode
from metricflow.dataflow.nodes.semi_additive_join import SemiAdditiveJoinNode
Expand Down Expand Up @@ -126,6 +128,16 @@ def visit_join_to_custom_granularity_node(self, node: JoinToCustomGranularityNod
def visit_alias_specs_node(self, node: AliasSpecsNode) -> VisitorOutputT: # noqa: D102
raise NotImplementedError

@abstractmethod
def visit_custom_granularity_bounds_node(self, node: CustomGranularityBoundsNode) -> VisitorOutputT: # noqa: D102
raise NotImplementedError

@abstractmethod
def visit_offset_by_custom_granularity_node( # noqa: D102
self, node: OffsetByCustomGranularityNode
) -> VisitorOutputT:
raise NotImplementedError


class DataflowPlanNodeVisitorWithDefaultHandler(DataflowPlanNodeVisitor[VisitorOutputT], Generic[VisitorOutputT]):
"""Similar to `DataflowPlanNodeVisitor`, but with an abstract default handler that gets called for each node.
Expand Down Expand Up @@ -222,3 +234,13 @@ def visit_join_to_custom_granularity_node(self, node: JoinToCustomGranularityNod
@override
def visit_alias_specs_node(self, node: AliasSpecsNode) -> VisitorOutputT: # noqa: D102
return self._default_handler(node)

@override
def visit_custom_granularity_bounds_node(self, node: CustomGranularityBoundsNode) -> VisitorOutputT: # noqa: D102
return self._default_handler(node)

@override
def visit_offset_by_custom_granularity_node( # noqa: D102
self, node: OffsetByCustomGranularityNode
) -> VisitorOutputT:
return self._default_handler(node)
64 changes: 64 additions & 0 deletions metricflow/dataflow/nodes/custom_granularity_bounds.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,64 @@
from __future__ import annotations

from abc import ABC
from dataclasses import dataclass
from typing import Sequence

from metricflow_semantics.dag.id_prefix import IdPrefix, StaticIdPrefix
from metricflow_semantics.dag.mf_dag import DisplayedProperty
from metricflow_semantics.visitor import VisitorOutputT

from metricflow.dataflow.dataflow_plan import DataflowPlanNode
from metricflow.dataflow.dataflow_plan_visitor import DataflowPlanNodeVisitor


@dataclass(frozen=True, eq=False)
class CustomGranularityBoundsNode(DataflowPlanNode, ABC):
"""Calculate the start and end of a custom granularity period and each row number within that period."""

custom_granularity_name: str

def __post_init__(self) -> None: # noqa: D105
super().__post_init__()
assert len(self.parent_nodes) == 1

@staticmethod
def create( # noqa: D102
parent_node: DataflowPlanNode, custom_granularity_name: str
) -> CustomGranularityBoundsNode:
return CustomGranularityBoundsNode(parent_nodes=(parent_node,), custom_granularity_name=custom_granularity_name)

@classmethod
def id_prefix(cls) -> IdPrefix: # noqa: D102
return StaticIdPrefix.DATAFLOW_NODE_CUSTOM_GRANULARITY_BOUNDS_ID_PREFIX

def accept(self, visitor: DataflowPlanNodeVisitor[VisitorOutputT]) -> VisitorOutputT: # noqa: D102
return visitor.visit_custom_granularity_bounds_node(self)

@property
def description(self) -> str: # noqa: D102
return """Calculate Custom Granularity Bounds"""

@property
def displayed_properties(self) -> Sequence[DisplayedProperty]: # noqa: D102
return tuple(super().displayed_properties) + (
DisplayedProperty("custom_granularity_name", self.custom_granularity_name),
)

@property
def parent_node(self) -> DataflowPlanNode: # noqa: D102
return self.parent_nodes[0]

def functionally_identical(self, other_node: DataflowPlanNode) -> bool: # noqa: D102
return (
isinstance(other_node, self.__class__)
and other_node.custom_granularity_name == self.custom_granularity_name
)

def with_new_parents( # noqa: D102
self, new_parent_nodes: Sequence[DataflowPlanNode]
) -> CustomGranularityBoundsNode:
assert len(new_parent_nodes) == 1
return CustomGranularityBoundsNode.create(
parent_node=new_parent_nodes[0], custom_granularity_name=self.custom_granularity_name
)
95 changes: 95 additions & 0 deletions metricflow/dataflow/nodes/offset_by_custom_granularity.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,95 @@
from __future__ import annotations

from abc import ABC
from dataclasses import dataclass
from typing import Optional, Sequence

from dbt_semantic_interfaces.protocols.metric import MetricTimeWindow
from metricflow_semantics.dag.id_prefix import IdPrefix, StaticIdPrefix
from metricflow_semantics.dag.mf_dag import DisplayedProperty
from metricflow_semantics.specs.time_dimension_spec import TimeDimensionSpec
from metricflow_semantics.visitor import VisitorOutputT

from metricflow.dataflow.dataflow_plan import DataflowPlanNode
from metricflow.dataflow.dataflow_plan_visitor import DataflowPlanNodeVisitor
from metricflow.dataflow.nodes.custom_granularity_bounds import CustomGranularityBoundsNode
from metricflow.dataflow.nodes.filter_elements import FilterElementsNode


@dataclass(frozen=True, eq=False)
class OffsetByCustomGranularityNode(DataflowPlanNode, ABC):
"""For a given custom grain, offset its base grain by the requested number of custom grain periods.
Only accepts CustomGranularityBoundsNode as parent node.
"""

offset_window: MetricTimeWindow
required_time_spine_specs: Sequence[TimeDimensionSpec]
custom_granularity_bounds_node: CustomGranularityBoundsNode
filter_elements_node: FilterElementsNode

def __post_init__(self) -> None: # noqa: D105
super().__post_init__()

@staticmethod
def create( # noqa: D102
custom_granularity_bounds_node: CustomGranularityBoundsNode,
filter_elements_node: FilterElementsNode,
offset_window: MetricTimeWindow,
required_time_spine_specs: Sequence[TimeDimensionSpec],
) -> OffsetByCustomGranularityNode:
return OffsetByCustomGranularityNode(
parent_nodes=(custom_granularity_bounds_node, filter_elements_node),
custom_granularity_bounds_node=custom_granularity_bounds_node,
filter_elements_node=filter_elements_node,
offset_window=offset_window,
required_time_spine_specs=required_time_spine_specs,
)

@classmethod
def id_prefix(cls) -> IdPrefix: # noqa: D102
return StaticIdPrefix.DATAFLOW_NODE_OFFSET_BY_CUSTOMG_GRANULARITY_ID_PREFIX

def accept(self, visitor: DataflowPlanNodeVisitor[VisitorOutputT]) -> VisitorOutputT: # noqa: D102
return visitor.visit_offset_by_custom_granularity_node(self)

@property
def description(self) -> str: # noqa: D102
return """Offset Base Granularity By Custom Granularity Period(s)"""

@property
def displayed_properties(self) -> Sequence[DisplayedProperty]: # noqa: D102
return tuple(super().displayed_properties) + (
DisplayedProperty("offset_window", self.offset_window),
DisplayedProperty("required_time_spine_specs", self.required_time_spine_specs),
)

def functionally_identical(self, other_node: DataflowPlanNode) -> bool: # noqa: D102
return (
isinstance(other_node, self.__class__)
and other_node.offset_window == self.offset_window
and other_node.required_time_spine_specs == self.required_time_spine_specs
)

def with_new_parents( # noqa: D102
self, new_parent_nodes: Sequence[DataflowPlanNode]
) -> OffsetByCustomGranularityNode:
custom_granularity_bounds_node: Optional[CustomGranularityBoundsNode] = None
filter_elements_node: Optional[FilterElementsNode] = None
for parent_node in new_parent_nodes:
if isinstance(parent_node, CustomGranularityBoundsNode):
custom_granularity_bounds_node = parent_node
elif isinstance(parent_node, FilterElementsNode):
filter_elements_node = parent_node
assert custom_granularity_bounds_node and filter_elements_node, (
"Can't rewrite OffsetByCustomGranularityNode because the node requires a CustomGranularityBoundsNode and a "
f"FilterElementsNode as parents. Instead, got: {new_parent_nodes}"
)

return OffsetByCustomGranularityNode(
parent_nodes=tuple(new_parent_nodes),
custom_granularity_bounds_node=custom_granularity_bounds_node,
filter_elements_node=filter_elements_node,
offset_window=self.offset_window,
required_time_spine_specs=self.required_time_spine_specs,
)
12 changes: 12 additions & 0 deletions metricflow/dataflow/optimizer/predicate_pushdown_optimizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
from metricflow.dataflow.nodes.combine_aggregated_outputs import CombineAggregatedOutputsNode
from metricflow.dataflow.nodes.compute_metrics import ComputeMetricsNode
from metricflow.dataflow.nodes.constrain_time import ConstrainTimeRangeNode
from metricflow.dataflow.nodes.custom_granularity_bounds import CustomGranularityBoundsNode
from metricflow.dataflow.nodes.filter_elements import FilterElementsNode
from metricflow.dataflow.nodes.join_conversion_events import JoinConversionEventsNode
from metricflow.dataflow.nodes.join_over_time import JoinOverTimeRangeNode
Expand All @@ -31,6 +32,7 @@
from metricflow.dataflow.nodes.join_to_time_spine import JoinToTimeSpineNode
from metricflow.dataflow.nodes.metric_time_transform import MetricTimeDimensionTransformNode
from metricflow.dataflow.nodes.min_max import MinMaxNode
from metricflow.dataflow.nodes.offset_by_custom_granularity import OffsetByCustomGranularityNode
from metricflow.dataflow.nodes.order_by_limit import OrderByLimitNode
from metricflow.dataflow.nodes.read_sql_source import ReadSqlSourceNode
from metricflow.dataflow.nodes.semi_additive_join import SemiAdditiveJoinNode
Expand Down Expand Up @@ -472,6 +474,16 @@ def visit_join_to_custom_granularity_node( # noqa: D102
def visit_alias_specs_node(self, node: AliasSpecsNode) -> OptimizeBranchResult: # noqa: D102
raise NotImplementedError

def visit_custom_granularity_bounds_node( # noqa: D102
self, node: CustomGranularityBoundsNode
) -> OptimizeBranchResult:
raise NotImplementedError

def visit_offset_by_custom_granularity_node( # noqa: D102
self, node: OffsetByCustomGranularityNode
) -> OptimizeBranchResult:
raise NotImplementedError

def visit_join_on_entities_node(self, node: JoinOnEntitiesNode) -> OptimizeBranchResult:
"""Handles pushdown state propagation for the standard join node type.
Expand Down
14 changes: 14 additions & 0 deletions metricflow/dataflow/optimizer/source_scan/cm_branch_combiner.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
from metricflow.dataflow.nodes.combine_aggregated_outputs import CombineAggregatedOutputsNode
from metricflow.dataflow.nodes.compute_metrics import ComputeMetricsNode
from metricflow.dataflow.nodes.constrain_time import ConstrainTimeRangeNode
from metricflow.dataflow.nodes.custom_granularity_bounds import CustomGranularityBoundsNode
from metricflow.dataflow.nodes.filter_elements import FilterElementsNode
from metricflow.dataflow.nodes.join_conversion_events import JoinConversionEventsNode
from metricflow.dataflow.nodes.join_over_time import JoinOverTimeRangeNode
Expand All @@ -25,6 +26,7 @@
from metricflow.dataflow.nodes.join_to_time_spine import JoinToTimeSpineNode
from metricflow.dataflow.nodes.metric_time_transform import MetricTimeDimensionTransformNode
from metricflow.dataflow.nodes.min_max import MinMaxNode
from metricflow.dataflow.nodes.offset_by_custom_granularity import OffsetByCustomGranularityNode
from metricflow.dataflow.nodes.order_by_limit import OrderByLimitNode
from metricflow.dataflow.nodes.read_sql_source import ReadSqlSourceNode
from metricflow.dataflow.nodes.semi_additive_join import SemiAdditiveJoinNode
Expand Down Expand Up @@ -472,3 +474,15 @@ def visit_min_max_node(self, node: MinMaxNode) -> ComputeMetricsBranchCombinerRe
def visit_alias_specs_node(self, node: AliasSpecsNode) -> ComputeMetricsBranchCombinerResult: # noqa: D102
self._log_visit_node_type(node)
return self._default_handler(node)

def visit_custom_granularity_bounds_node( # noqa: D102
self, node: CustomGranularityBoundsNode
) -> ComputeMetricsBranchCombinerResult:
self._log_visit_node_type(node)
return self._default_handler(node)

def visit_offset_by_custom_granularity_node( # noqa: D102
self, node: OffsetByCustomGranularityNode
) -> ComputeMetricsBranchCombinerResult:
self._log_visit_node_type(node)
return self._default_handler(node)
14 changes: 14 additions & 0 deletions metricflow/dataflow/optimizer/source_scan/source_scan_optimizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
from metricflow.dataflow.nodes.combine_aggregated_outputs import CombineAggregatedOutputsNode
from metricflow.dataflow.nodes.compute_metrics import ComputeMetricsNode
from metricflow.dataflow.nodes.constrain_time import ConstrainTimeRangeNode
from metricflow.dataflow.nodes.custom_granularity_bounds import CustomGranularityBoundsNode
from metricflow.dataflow.nodes.filter_elements import FilterElementsNode
from metricflow.dataflow.nodes.join_conversion_events import JoinConversionEventsNode
from metricflow.dataflow.nodes.join_over_time import JoinOverTimeRangeNode
Expand All @@ -27,6 +28,7 @@
from metricflow.dataflow.nodes.join_to_time_spine import JoinToTimeSpineNode
from metricflow.dataflow.nodes.metric_time_transform import MetricTimeDimensionTransformNode
from metricflow.dataflow.nodes.min_max import MinMaxNode
from metricflow.dataflow.nodes.offset_by_custom_granularity import OffsetByCustomGranularityNode
from metricflow.dataflow.nodes.order_by_limit import OrderByLimitNode
from metricflow.dataflow.nodes.read_sql_source import ReadSqlSourceNode
from metricflow.dataflow.nodes.semi_additive_join import SemiAdditiveJoinNode
Expand Down Expand Up @@ -363,3 +365,15 @@ def visit_min_max_node(self, node: MinMaxNode) -> OptimizeBranchResult: # noqa:
def visit_alias_specs_node(self, node: AliasSpecsNode) -> OptimizeBranchResult: # noqa: D102
self._log_visit_node_type(node)
return self._default_base_output_handler(node)

def visit_custom_granularity_bounds_node( # noqa: D102
self, node: CustomGranularityBoundsNode
) -> OptimizeBranchResult:
self._log_visit_node_type(node)
return self._default_base_output_handler(node)

def visit_offset_by_custom_granularity_node( # noqa: D102
self, node: OffsetByCustomGranularityNode
) -> OptimizeBranchResult:
self._log_visit_node_type(node)
return self._default_base_output_handler(node)
Loading

0 comments on commit f267e57

Please sign in to comment.