Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Time Spine Source Node #1543

Closed
wants to merge 11 commits into from
  •  
  •  
  •  
Original file line number Diff line number Diff line change
Expand Up @@ -246,14 +246,14 @@ def with_entity_prefix(self, entity_prefix: EntityReference) -> TimeDimensionSpe
)

@staticmethod
def with_base_grains(time_dimension_specs: Sequence[TimeDimensionSpec]) -> List[TimeDimensionSpec]:
def with_base_grains(time_dimension_specs: Sequence[TimeDimensionSpec]) -> Tuple[TimeDimensionSpec, ...]:
"""Return the list of time dimension specs, replacing any custom grains with base grains.

Dedupes new specs, but preserves the initial order.
"""
base_grain_specs: List[TimeDimensionSpec] = []
base_grain_specs: Tuple[TimeDimensionSpec, ...] = ()
for spec in time_dimension_specs:
base_grain_spec = spec.with_base_grain()
if base_grain_spec not in base_grain_specs:
base_grain_specs.append(base_grain_spec)
base_grain_specs += (base_grain_spec,)
return base_grain_specs
Original file line number Diff line number Diff line change
Expand Up @@ -15,8 +15,6 @@

logger = logging.getLogger(__name__)

TIME_SPINE_DATA_SET_DESCRIPTION = "Time Spine"


@dataclass(frozen=True)
class TimeSpineSource:
Expand Down Expand Up @@ -101,10 +99,10 @@ def build_custom_granularities(time_spine_sources: Sequence[TimeSpineSource]) ->
}

@staticmethod
def choose_time_spine_sources(
def choose_time_spine_source(
required_time_spine_specs: Sequence[TimeDimensionSpec],
time_spine_sources: Dict[TimeGranularity, TimeSpineSource],
) -> Sequence[TimeSpineSource]:
) -> TimeSpineSource:
"""Determine which time spine sources to use to satisfy the given specs.

Custom grains can only use the time spine where they are defined. For standard grains, this will choose the time
Expand Down Expand Up @@ -147,4 +145,15 @@ def choose_time_spine_sources(
if not required_time_spines.intersection(set(compatible_time_spines_for_standard_grains.values())):
required_time_spines.add(time_spine_sources[max(compatible_time_spines_for_standard_grains)])

return tuple(required_time_spines)
if len(required_time_spines) != 1:
raise RuntimeError(
"Multiple time spines are required to satisfy the specs, but only one is supported per query currently. "
f"Multiple will be supported in the future. Time spines required: {required_time_spines}."
)

return required_time_spines.pop()

@property
def data_set_description(self) -> str:
"""Description to be displayed when this time spine is used in a data set."""
return f"Read From Time Spine '{self.table_name}'"
180 changes: 136 additions & 44 deletions metricflow/dataflow/builder/dataflow_plan_builder.py

Large diffs are not rendered by default.

6 changes: 3 additions & 3 deletions metricflow/dataflow/builder/node_evaluator.py
Original file line number Diff line number Diff line change
Expand Up @@ -169,7 +169,7 @@ def __init__(
semantic_model_lookup: SemanticModelLookup,
nodes_available_for_joins: Sequence[DataflowPlanNode],
node_data_set_resolver: DataflowPlanNodeOutputDataSetResolver,
time_spine_nodes: Sequence[MetricTimeDimensionTransformNode],
time_spine_metric_time_nodes: Sequence[MetricTimeDimensionTransformNode],
) -> None:
"""Initializer.

Expand All @@ -186,7 +186,7 @@ def __init__(
self._node_data_set_resolver = node_data_set_resolver
self._partition_resolver = PartitionJoinResolver(self._semantic_model_lookup)
self._join_evaluator = SemanticModelJoinEvaluator(self._semantic_model_lookup)
self._time_spine_nodes = time_spine_nodes
self._time_spine_metric_time_nodes = time_spine_metric_time_nodes

def _find_joinable_candidate_nodes_that_can_satisfy_linkable_specs(
self,
Expand All @@ -205,7 +205,7 @@ def _find_joinable_candidate_nodes_that_can_satisfy_linkable_specs(
linkable_specs_in_right_node = data_set_in_right_node.instance_set.spec_set.linkable_specs

# If right node is time spine source node, use cross join.
if right_node in self._time_spine_nodes:
if right_node in self._time_spine_metric_time_nodes:
satisfiable_metric_time_specs = [
spec for spec in linkable_specs_in_right_node if spec in needed_linkable_specs
]
Expand Down
28 changes: 18 additions & 10 deletions metricflow/dataflow/builder/source_node.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,20 +36,23 @@ class SourceNodeSet:
# Semantic models are 1:1 mapped to a ReadSqlSourceNode.
source_nodes_for_group_by_item_queries: Tuple[DataflowPlanNode, ...]

# Provides the time spines.
time_spine_nodes: Mapping[TimeGranularity, MetricTimeDimensionTransformNode]
# Provides time spines that can be used to satisfy time spine joins.
time_spine_read_nodes: Mapping[TimeGranularity, ReadSqlSourceNode]

# Provides time spines that can be used to satisfy metric_time without metrics.
time_spine_metric_time_nodes: Mapping[TimeGranularity, MetricTimeDimensionTransformNode]

@property
def all_nodes(self) -> Sequence[DataflowPlanNode]: # noqa: D102
return (
self.source_nodes_for_metric_queries
+ self.source_nodes_for_group_by_item_queries
+ self.time_spine_nodes_tuple
+ self.time_spine_metric_time_nodes_tuple
)

@property
def time_spine_nodes_tuple(self) -> Tuple[MetricTimeDimensionTransformNode, ...]: # noqa: D102
return tuple(self.time_spine_nodes.values())
def time_spine_metric_time_nodes_tuple(self) -> Tuple[MetricTimeDimensionTransformNode, ...]: # noqa: D102
return tuple(self.time_spine_metric_time_nodes.values())


class SourceNodeBuilder:
Expand All @@ -65,11 +68,15 @@ def __init__( # noqa: D107
self.time_spine_sources = TimeSpineSource.build_standard_time_spine_sources(
semantic_manifest_lookup.semantic_manifest
)
self._time_spine_source_nodes = {}
for granularity, time_spine_source in self.time_spine_sources.items():

self._time_spine_read_nodes = {}
self._time_spine_metric_time_nodes = {}
for base_granularity, time_spine_source in self.time_spine_sources.items():
data_set = data_set_converter.build_time_spine_source_data_set(time_spine_source)
self._time_spine_source_nodes[granularity] = MetricTimeDimensionTransformNode.create(
parent_node=ReadSqlSourceNode.create(data_set),
read_node = ReadSqlSourceNode.create(data_set)
self._time_spine_read_nodes[base_granularity] = read_node
self._time_spine_metric_time_nodes[base_granularity] = MetricTimeDimensionTransformNode.create(
parent_node=read_node,
aggregation_time_dimension_reference=TimeDimensionReference(time_spine_source.base_column),
)

Expand Down Expand Up @@ -103,7 +110,8 @@ def create_from_data_sets(self, data_sets: Sequence[SemanticModelDataSet]) -> So
source_nodes_for_metric_queries.append(metric_time_transform_node)

return SourceNodeSet(
time_spine_nodes=self._time_spine_source_nodes,
time_spine_metric_time_nodes=self._time_spine_metric_time_nodes,
time_spine_read_nodes=self._time_spine_read_nodes,
source_nodes_for_group_by_item_queries=tuple(group_by_item_source_nodes),
source_nodes_for_metric_queries=tuple(source_nodes_for_metric_queries),
)
Expand Down
11 changes: 10 additions & 1 deletion metricflow/dataflow/dataflow_plan_visitor.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
from metricflow.dataflow.nodes.order_by_limit import OrderByLimitNode
from metricflow.dataflow.nodes.read_sql_source import ReadSqlSourceNode
from metricflow.dataflow.nodes.semi_additive_join import SemiAdditiveJoinNode
from metricflow.dataflow.nodes.transform_time_dimensions import TransformTimeDimensionsNode
from metricflow.dataflow.nodes.where_filter import WhereConstraintNode
from metricflow.dataflow.nodes.window_reaggregation_node import WindowReaggregationNode
from metricflow.dataflow.nodes.write_to_data_table import WriteToResultDataTableNode
Expand Down Expand Up @@ -121,6 +122,10 @@ def visit_join_conversion_events_node(self, node: JoinConversionEventsNode) -> V
def visit_join_to_custom_granularity_node(self, node: JoinToCustomGranularityNode) -> VisitorOutputT: # noqa: D102
raise NotImplementedError

@abstractmethod
def visit_transform_time_dimensions_node(self, node: TransformTimeDimensionsNode) -> VisitorOutputT: # noqa: D102
raise NotImplementedError


class DataflowPlanNodeVisitorWithDefaultHandler(DataflowPlanNodeVisitor[VisitorOutputT], Generic[VisitorOutputT]):
"""Similar to `DataflowPlanNodeVisitor`, but with an abstract default handler that gets called for each node.
Expand Down Expand Up @@ -191,7 +196,7 @@ def visit_semi_additive_join_node(self, node: SemiAdditiveJoinNode) -> VisitorOu
@override
def visit_metric_time_dimension_transform_node( # noqa: D102
self, node: MetricTimeDimensionTransformNode
) -> VisitorOutputT: # noqa: D102
) -> VisitorOutputT:
return self._default_handler(node)

@override
Expand All @@ -213,3 +218,7 @@ def visit_join_conversion_events_node(self, node: JoinConversionEventsNode) -> V
@override
def visit_join_to_custom_granularity_node(self, node: JoinToCustomGranularityNode) -> VisitorOutputT: # noqa: D102
return self._default_handler(node)

@override
def visit_transform_time_dimensions_node(self, node: TransformTimeDimensionsNode) -> VisitorOutputT: # noqa: D102
return self._default_handler(node)
32 changes: 11 additions & 21 deletions metricflow/dataflow/nodes/join_to_time_spine.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,9 +8,7 @@
from dbt_semantic_interfaces.type_enums import TimeGranularity
from metricflow_semantics.dag.id_prefix import IdPrefix, StaticIdPrefix
from metricflow_semantics.dag.mf_dag import DisplayedProperty
from metricflow_semantics.filters.time_constraint import TimeRangeConstraint
from metricflow_semantics.specs.time_dimension_spec import TimeDimensionSpec
from metricflow_semantics.specs.where_filter.where_filter_spec import WhereFilterSpec
from metricflow_semantics.sql.sql_join_type import SqlJoinType
from metricflow_semantics.visitor import VisitorOutputT

Expand All @@ -25,17 +23,17 @@ class JoinToTimeSpineNode(DataflowPlanNode, ABC):
Attributes:
requested_agg_time_dimension_specs: Time dimensions requested in the query.
join_type: Join type to use when joining to time spine.
time_range_constraint: Time range to constrain the time spine to.
join_on_time_dimension_spec: The time dimension to use in the join ON condition.
offset_window: Time window to offset the parent dataset by when joining to time spine.
offset_to_grain: Granularity period to offset the parent dataset to when joining to time spine.
"""

time_spine_node: DataflowPlanNode
requested_agg_time_dimension_specs: Sequence[TimeDimensionSpec]
join_on_time_dimension_spec: TimeDimensionSpec
join_type: SqlJoinType
time_range_constraint: Optional[TimeRangeConstraint]
offset_window: Optional[MetricTimeWindow]
offset_to_grain: Optional[TimeGranularity]
time_spine_filters: Optional[Sequence[WhereFilterSpec]] = None

def __post_init__(self) -> None: # noqa: D105
super().__post_init__()
Expand All @@ -51,21 +49,21 @@ def __post_init__(self) -> None: # noqa: D105
@staticmethod
def create( # noqa: D102
parent_node: DataflowPlanNode,
time_spine_node: DataflowPlanNode,
requested_agg_time_dimension_specs: Sequence[TimeDimensionSpec],
join_on_time_dimension_spec: TimeDimensionSpec,
join_type: SqlJoinType,
time_range_constraint: Optional[TimeRangeConstraint] = None,
offset_window: Optional[MetricTimeWindow] = None,
offset_to_grain: Optional[TimeGranularity] = None,
time_spine_filters: Optional[Sequence[WhereFilterSpec]] = None,
) -> JoinToTimeSpineNode:
return JoinToTimeSpineNode(
parent_nodes=(parent_node,),
time_spine_node=time_spine_node,
requested_agg_time_dimension_specs=tuple(requested_agg_time_dimension_specs),
join_on_time_dimension_spec=join_on_time_dimension_spec,
join_type=join_type,
time_range_constraint=time_range_constraint,
offset_window=offset_window,
offset_to_grain=offset_to_grain,
time_spine_filters=time_spine_filters,
)

@classmethod
Expand All @@ -83,20 +81,13 @@ def description(self) -> str: # noqa: D102
def displayed_properties(self) -> Sequence[DisplayedProperty]: # noqa: D102
props = tuple(super().displayed_properties) + (
DisplayedProperty("requested_agg_time_dimension_specs", self.requested_agg_time_dimension_specs),
DisplayedProperty("join_on_time_dimension_spec", self.join_on_time_dimension_spec),
DisplayedProperty("join_type", self.join_type),
)
if self.offset_window:
props += (DisplayedProperty("offset_window", self.offset_window),)
if self.offset_to_grain:
props += (DisplayedProperty("offset_to_grain", self.offset_to_grain),)
if self.time_range_constraint:
props += (DisplayedProperty("time_range_constraint", self.time_range_constraint),)
if self.time_spine_filters:
props += (
DisplayedProperty(
"time_spine_filters", [time_spine_filter.where_sql for time_spine_filter in self.time_spine_filters]
),
)
return props

@property
Expand All @@ -106,22 +97,21 @@ def parent_node(self) -> DataflowPlanNode: # noqa: D102
def functionally_identical(self, other_node: DataflowPlanNode) -> bool: # noqa: D102
return (
isinstance(other_node, self.__class__)
and other_node.time_range_constraint == self.time_range_constraint
and other_node.offset_window == self.offset_window
and other_node.offset_to_grain == self.offset_to_grain
and other_node.requested_agg_time_dimension_specs == self.requested_agg_time_dimension_specs
and other_node.join_on_time_dimension_spec == self.join_on_time_dimension_spec
and other_node.join_type == self.join_type
and other_node.time_spine_filters == self.time_spine_filters
)

def with_new_parents(self, new_parent_nodes: Sequence[DataflowPlanNode]) -> JoinToTimeSpineNode: # noqa: D102
assert len(new_parent_nodes) == 1
return JoinToTimeSpineNode.create(
parent_node=new_parent_nodes[0],
time_spine_node=self.time_spine_node,
requested_agg_time_dimension_specs=self.requested_agg_time_dimension_specs,
time_range_constraint=self.time_range_constraint,
offset_window=self.offset_window,
offset_to_grain=self.offset_to_grain,
join_type=self.join_type,
time_spine_filters=self.time_spine_filters,
join_on_time_dimension_spec=self.join_on_time_dimension_spec,
)
74 changes: 74 additions & 0 deletions metricflow/dataflow/nodes/transform_time_dimensions.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,74 @@
from __future__ import annotations

from abc import ABC
from dataclasses import dataclass
from typing import Sequence

from metricflow_semantics.dag.id_prefix import IdPrefix, StaticIdPrefix
from metricflow_semantics.dag.mf_dag import DisplayedProperty
from metricflow_semantics.specs.time_dimension_spec import TimeDimensionSpec
from metricflow_semantics.visitor import VisitorOutputT

from metricflow.dataflow.dataflow_plan import DataflowPlanNode
from metricflow.dataflow.dataflow_plan_visitor import DataflowPlanNodeVisitor


@dataclass(frozen=True, eq=False)
class TransformTimeDimensionsNode(DataflowPlanNode, ABC):
    """Change the columns in the parent node to match the requested time dimension specs.

    Attributes:
        requested_time_dimension_specs: The time dimension specs to match in the parent node and transform.
            Normalized to a tuple at construction time (see `create`).
    """

    requested_time_dimension_specs: Sequence[TimeDimensionSpec]

    def __post_init__(self) -> None:  # noqa: D105
        super().__post_init__()
        # An empty spec list would make this node a no-op; fail fast at plan-build time instead.
        assert (
            len(self.requested_time_dimension_specs) > 0
        ), "Must have at least one value in requested_time_dimension_specs for TransformTimeDimensionsNode."

    @staticmethod
    def create(  # noqa: D102
        parent_node: DataflowPlanNode, requested_time_dimension_specs: Sequence[TimeDimensionSpec]
    ) -> TransformTimeDimensionsNode:
        # Normalize to a tuple so the frozen dataclass holds an immutable sequence and so
        # `functionally_identical` compares equal regardless of the caller's sequence type
        # (list vs tuple). Matches the convention in sibling nodes, e.g. JoinToTimeSpineNode.create.
        return TransformTimeDimensionsNode(
            parent_nodes=(parent_node,), requested_time_dimension_specs=tuple(requested_time_dimension_specs)
        )

    @classmethod
    def id_prefix(cls) -> IdPrefix:  # noqa: D102
        # NOTE(review): this reuses JoinToCustomGranularityNode's ID prefix, so the two node types
        # share an ID namespace — looks like a copy-paste leftover. Confirm whether a dedicated
        # DATAFLOW_NODE_TRANSFORM_TIME_DIMENSIONS_ID_PREFIX should be added to StaticIdPrefix.
        return StaticIdPrefix.DATAFLOW_NODE_JOIN_TO_CUSTOM_GRANULARITY_ID_PREFIX

    def accept(self, visitor: DataflowPlanNodeVisitor[VisitorOutputT]) -> VisitorOutputT:  # noqa: D102
        return visitor.visit_transform_time_dimensions_node(self)

    @property
    def description(self) -> str:  # noqa: D102
        return """Transform Time Dimension Columns"""

    @property
    def displayed_properties(self) -> Sequence[DisplayedProperty]:  # noqa: D102
        return tuple(super().displayed_properties) + (
            DisplayedProperty("requested_time_dimension_specs", self.requested_time_dimension_specs),
        )

    @property
    def parent_node(self) -> DataflowPlanNode:  # noqa: D102
        return self.parent_nodes[0]

    def functionally_identical(self, other_node: DataflowPlanNode) -> bool:  # noqa: D102
        return (
            isinstance(other_node, self.__class__)
            and other_node.requested_time_dimension_specs == self.requested_time_dimension_specs
        )

    def with_new_parents(  # noqa: D102
        self, new_parent_nodes: Sequence[DataflowPlanNode]
    ) -> TransformTimeDimensionsNode:
        assert len(new_parent_nodes) == 1, "TransformTimeDimensionsNode accepts exactly one parent node."
        return TransformTimeDimensionsNode.create(
            parent_node=new_parent_nodes[0],
            requested_time_dimension_specs=self.requested_time_dimension_specs,
        )
6 changes: 6 additions & 0 deletions metricflow/dataflow/optimizer/predicate_pushdown_optimizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@
from metricflow.dataflow.nodes.order_by_limit import OrderByLimitNode
from metricflow.dataflow.nodes.read_sql_source import ReadSqlSourceNode
from metricflow.dataflow.nodes.semi_additive_join import SemiAdditiveJoinNode
from metricflow.dataflow.nodes.transform_time_dimensions import TransformTimeDimensionsNode
from metricflow.dataflow.nodes.where_filter import WhereConstraintNode
from metricflow.dataflow.nodes.window_reaggregation_node import WindowReaggregationNode
from metricflow.dataflow.nodes.write_to_data_table import WriteToResultDataTableNode
Expand Down Expand Up @@ -468,6 +469,11 @@ def visit_join_to_custom_granularity_node( # noqa: D102
) -> OptimizeBranchResult:
raise NotImplementedError

def visit_transform_time_dimensions_node( # noqa: D102
self, node: TransformTimeDimensionsNode
) -> OptimizeBranchResult:
raise NotImplementedError

def visit_join_on_entities_node(self, node: JoinOnEntitiesNode) -> OptimizeBranchResult:
"""Handles pushdown state propagation for the standard join node type.

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@
from metricflow.dataflow.nodes.order_by_limit import OrderByLimitNode
from metricflow.dataflow.nodes.read_sql_source import ReadSqlSourceNode
from metricflow.dataflow.nodes.semi_additive_join import SemiAdditiveJoinNode
from metricflow.dataflow.nodes.transform_time_dimensions import TransformTimeDimensionsNode
from metricflow.dataflow.nodes.where_filter import WhereConstraintNode
from metricflow.dataflow.nodes.window_reaggregation_node import WindowReaggregationNode
from metricflow.dataflow.nodes.write_to_data_table import WriteToResultDataTableNode
Expand Down Expand Up @@ -460,3 +461,9 @@ def visit_join_to_custom_granularity_node( # noqa: D102
def visit_min_max_node(self, node: MinMaxNode) -> ComputeMetricsBranchCombinerResult: # noqa: D102
self._log_visit_node_type(node)
return self._default_handler(node)

def visit_transform_time_dimensions_node( # noqa: D102
self, node: TransformTimeDimensionsNode
) -> ComputeMetricsBranchCombinerResult:
self._log_visit_node_type(node)
return self._default_handler(node)
Loading
Loading