Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

DataflowPlan for custom granularities #1382

Closed
wants to merge 15 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions .changes/unreleased/Features-20240827-112415.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
kind: Features
body: Add JoinToCustomGranularityNode to DataflowPlan.
time: 2024-08-27T11:24:15.909853-07:00
custom:
Author: courtneyholcomb
Issue: "1382"
2 changes: 1 addition & 1 deletion extra-hatch-configuration/requirements.txt
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
Jinja2>=3.1.3
dbt-semantic-interfaces==0.7.1.dev0
dbt-semantic-interfaces==0.7.1.dev2
more-itertools>=8.10.0, <10.2.0
pydantic>=1.10.0, <3.0
tabulate>=0.8.9
Expand Down
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
dbt-semantic-interfaces==0.7.1.dev0
dbt-semantic-interfaces==0.7.1.dev2
graphviz>=0.18.2, <0.21
python-dateutil>=2.9.0, <2.10.0
rapidfuzz>=3.0, <4.0
1 change: 1 addition & 0 deletions metricflow-semantics/metricflow_semantics/dag/id_prefix.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,7 @@ class StaticIdPrefix(IdPrefix, Enum, metaclass=EnumMetaClassHelper):
DATAFLOW_NODE_SET_MEASURE_AGGREGATION_TIME = "sma"
DATAFLOW_NODE_SEMI_ADDITIVE_JOIN_ID_PREFIX = "saj"
DATAFLOW_NODE_JOIN_TO_TIME_SPINE_ID_PREFIX = "jts"
DATAFLOW_NODE_JOIN_TO_CUSTOM_GRANULARITY_ID_PREFIX = "jcg"
DATAFLOW_NODE_MIN_MAX_ID_PREFIX = "mm"
DATAFLOW_NODE_ADD_UUID_COLUMN_PREFIX = "auid"
DATAFLOW_NODE_JOIN_CONVERSION_EVENTS_PREFIX = "jce"
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -125,7 +125,7 @@ def __init__(
# Sort semantic models by name for consistency in building derived objects.
self._semantic_models = sorted(self._semantic_manifest.semantic_models, key=lambda x: x.name)
self._join_evaluator = SemanticModelJoinEvaluator(semantic_model_lookup)
self._time_spine_sources = TimeSpineSource.create_from_manifest(self._semantic_manifest)
self._time_spine_sources = TimeSpineSource.build_standard_time_spine_sources(self._semantic_manifest)

assert max_entity_links >= 0
self._max_entity_links = max_entity_links
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -162,6 +162,17 @@ def with_grain(self, time_granularity: TimeGranularity) -> TimeDimensionSpec: #
aggregation_state=self.aggregation_state,
)

def with_grain_and_date_part( # noqa: D102
self, time_granularity: TimeGranularity, date_part: Optional[DatePart]
) -> TimeDimensionSpec:
return TimeDimensionSpec(
element_name=self.element_name,
entity_links=self.entity_links,
time_granularity=time_granularity,
date_part=date_part,
aggregation_state=self.aggregation_state,
)

def with_aggregation_state(self, aggregation_state: AggregationState) -> TimeDimensionSpec: # noqa: D102
return TimeDimensionSpec(
element_name=self.element_name,
Expand Down
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
---
# Note: This file may be symlinked.
project_configuration:
time_spine_table_configurations:
time_spine_table_configurations: # TODO: update DSI json schema so this is not required
- location: $source_schema.mf_time_spine
column_name: ds
grain: day
Expand Down Expand Up @@ -42,3 +42,11 @@ project_configuration:
primary_column:
name: ts
time_granularity: hour
- node_relation:
alias: mf_time_spine
schema_name: $source_schema
primary_column:
name: ds
time_granularity: day
custom_granularities:
- name: martian_day
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
from dataclasses import dataclass
from typing import Dict, Optional, Sequence

from dbt_semantic_interfaces.implementations.time_spine import PydanticTimeSpineCustomGranularityColumn
from dbt_semantic_interfaces.protocols import SemanticManifest
from dbt_semantic_interfaces.type_enums.time_granularity import TimeGranularity

Expand All @@ -29,24 +30,33 @@ class TimeSpineSource:
# The time granularity of the base column.
base_granularity: TimeGranularity = DEFAULT_TIME_GRANULARITY
db_name: Optional[str] = None
custom_granularity_columns: Sequence[str] = ()
custom_granularities: Sequence[PydanticTimeSpineCustomGranularityColumn] = ()

@property
def spine_table(self) -> SqlTable:
"""Table containing all dates."""
return SqlTable(schema_name=self.schema_name, table_name=self.table_name, db_name=self.db_name)

@staticmethod
def create_from_manifest(semantic_manifest: SemanticManifest) -> Dict[TimeGranularity, TimeSpineSource]:
"""Creates a time spine source based on what's in the manifest."""
def build_standard_time_spine_sources(
semantic_manifest: SemanticManifest,
) -> Dict[TimeGranularity, TimeSpineSource]:
"""Creates a set of time spine sources with standard granularities based on what's in the manifest."""
time_spine_sources = {
time_spine.primary_column.time_granularity: TimeSpineSource(
schema_name=time_spine.node_relation.schema_name,
table_name=time_spine.node_relation.alias,
db_name=time_spine.node_relation.database,
base_column=time_spine.primary_column.name,
base_granularity=time_spine.primary_column.time_granularity,
custom_granularity_columns=[column.name for column in time_spine.custom_granularity_columns],
custom_granularities=tuple(
[
PydanticTimeSpineCustomGranularityColumn(
name=custom_granularity.name, column_name=custom_granularity.column_name
)
for custom_granularity in time_spine.custom_granularities
]
),
)
for time_spine in semantic_manifest.project_configuration.time_spines
}
Expand All @@ -72,3 +82,12 @@ def create_from_manifest(semantic_manifest: SemanticManifest) -> Dict[TimeGranul
)

return time_spine_sources

@staticmethod
def build_custom_time_spine_sources(time_spine_sources: Sequence[TimeSpineSource]) -> Dict[str, TimeSpineSource]:
"""Creates a set of time spine sources with custom granularities based on what's in the manifest."""
return {
custom_granularity.name: time_spine_source
for time_spine_source in time_spine_sources
for custom_granularity in time_spine_source.custom_granularities
}
15 changes: 15 additions & 0 deletions metricflow/dataflow/builder/dataflow_plan_builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,6 +74,7 @@
from metricflow.dataflow.nodes.join_conversion_events import JoinConversionEventsNode
from metricflow.dataflow.nodes.join_over_time import JoinOverTimeRangeNode
from metricflow.dataflow.nodes.join_to_base import JoinDescription, JoinOnEntitiesNode
from metricflow.dataflow.nodes.join_to_custom_granularity import JoinToCustomGranularityNode
from metricflow.dataflow.nodes.join_to_time_spine import JoinToTimeSpineNode
from metricflow.dataflow.nodes.min_max import MinMaxNode
from metricflow.dataflow.nodes.order_by_limit import OrderByLimitNode
Expand Down Expand Up @@ -783,6 +784,13 @@ def _build_plan_for_distinct_values(
if dataflow_recipe.join_targets:
output_node = JoinOnEntitiesNode.create(left_node=output_node, join_targets=dataflow_recipe.join_targets)

for time_dimension_spec in required_linkable_specs.time_dimension_specs:
# TODO: Update conditional when avail - if the time dimension spec has a custom granularity
if False:
output_node = JoinToCustomGranularityNode.create(
parent_node=output_node, time_dimension_spec=time_dimension_spec
)

if len(query_level_filter_specs) > 0:
output_node = WhereConstraintNode.create(parent_node=output_node, where_specs=query_level_filter_specs)
if query_spec.time_range_constraint:
Expand Down Expand Up @@ -1517,6 +1525,13 @@ def _build_aggregated_measure_from_measure_source_node(
else:
unaggregated_measure_node = filtered_measure_source_node

for time_dimension_spec in queried_linkable_specs.time_dimension_specs:
# TODO: Update conditional when avail - if the time dimension spec has a custom granularity
if False:
unaggregated_measure_node = JoinToCustomGranularityNode.create(
parent_node=unaggregated_measure_node, time_dimension_spec=time_dimension_spec
)

# If time constraint was previously adjusted for cumulative window or grain, apply original time constraint
# here. Can skip if metric is being aggregated over all time.
cumulative_metric_constrained_node: Optional[ConstrainTimeRangeNode] = None
Expand Down
2 changes: 1 addition & 1 deletion metricflow/dataflow/builder/source_node.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,7 @@ def __init__( # noqa: D107
self._semantic_manifest_lookup = semantic_manifest_lookup
data_set_converter = SemanticModelToDataSetConverter(column_association_resolver)
self._time_spine_source_nodes = {}
for granularity, time_spine_source in TimeSpineSource.create_from_manifest(
for granularity, time_spine_source in TimeSpineSource.build_standard_time_spine_sources(
semantic_manifest_lookup.semantic_manifest
).items():
data_set = data_set_converter.build_time_spine_source_data_set(time_spine_source)
Expand Down
5 changes: 5 additions & 0 deletions metricflow/dataflow/dataflow_plan.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@
from metricflow.dataflow.nodes.join_conversion_events import JoinConversionEventsNode
from metricflow.dataflow.nodes.join_over_time import JoinOverTimeRangeNode
from metricflow.dataflow.nodes.join_to_base import JoinOnEntitiesNode
from metricflow.dataflow.nodes.join_to_custom_granularity import JoinToCustomGranularityNode
from metricflow.dataflow.nodes.join_to_time_spine import JoinToTimeSpineNode
from metricflow.dataflow.nodes.metric_time_transform import MetricTimeDimensionTransformNode
from metricflow.dataflow.nodes.min_max import MinMaxNode
Expand Down Expand Up @@ -187,6 +188,10 @@ def visit_add_generated_uuid_column_node(self, node: AddGeneratedUuidColumnNode)
def visit_join_conversion_events_node(self, node: JoinConversionEventsNode) -> VisitorOutputT: # noqa: D102
pass

@abstractmethod
def visit_join_to_custom_granularity_node(self, node: JoinToCustomGranularityNode) -> VisitorOutputT: # noqa: D102
pass


class DataflowPlan(MetricFlowDag[DataflowPlanNode]):
"""Describes the flow of metric data as it goes from source nodes to sink nodes in the graph."""
Expand Down
57 changes: 57 additions & 0 deletions metricflow/dataflow/nodes/join_to_custom_granularity.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,57 @@
from __future__ import annotations

from abc import ABC
from dataclasses import dataclass
from typing import Sequence

from metricflow_semantics.dag.id_prefix import IdPrefix, StaticIdPrefix
from metricflow_semantics.dag.mf_dag import DisplayedProperty
from metricflow_semantics.specs.time_dimension_spec import TimeDimensionSpec
from metricflow_semantics.visitor import VisitorOutputT

from metricflow.dataflow.dataflow_plan import DataflowPlanNode, DataflowPlanNodeVisitor


@dataclass(frozen=True)
class JoinToCustomGranularityNode(DataflowPlanNode, ABC):
"""Join parent dataset to time spine dataset to convert time dimension to a custom granularity."""

time_dimension_spec: TimeDimensionSpec

@staticmethod
def create( # noqa: D102
parent_node: DataflowPlanNode, time_dimension_spec: TimeDimensionSpec
) -> JoinToCustomGranularityNode:
return JoinToCustomGranularityNode(parent_nodes=(parent_node,), time_dimension_spec=time_dimension_spec)

@classmethod
def id_prefix(cls) -> IdPrefix: # noqa: D102
return StaticIdPrefix.DATAFLOW_NODE_JOIN_TO_CUSTOM_GRANULARITY_ID_PREFIX

def accept(self, visitor: DataflowPlanNodeVisitor[VisitorOutputT]) -> VisitorOutputT: # noqa: D102
return visitor.visit_join_to_custom_granularity_node(self)

@property
def description(self) -> str: # noqa: D102
return """Join to Custom Granularity Dataset"""

@property
def displayed_properties(self) -> Sequence[DisplayedProperty]: # noqa: D102
return tuple(super().displayed_properties) + (
DisplayedProperty("time_dimension_spec", self.time_dimension_spec),
)

@property
def parent_node(self) -> DataflowPlanNode: # noqa: D102
return self.parent_nodes[0]

def functionally_identical(self, other_node: DataflowPlanNode) -> bool: # noqa: D102
return isinstance(other_node, self.__class__) and other_node.time_dimension_spec == self.time_dimension_spec

def with_new_parents( # noqa: D102
self, new_parent_nodes: Sequence[DataflowPlanNode]
) -> JoinToCustomGranularityNode:
assert len(new_parent_nodes) == 1, "JoinToCustomGranularity accepts exactly one parent node."
return JoinToCustomGranularityNode.create(
parent_node=new_parent_nodes[0], time_dimension_spec=self.time_dimension_spec
)
8 changes: 8 additions & 0 deletions metricflow/dataflow/optimizer/predicate_pushdown_optimizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
from metricflow.dataflow.nodes.join_conversion_events import JoinConversionEventsNode
from metricflow.dataflow.nodes.join_over_time import JoinOverTimeRangeNode
from metricflow.dataflow.nodes.join_to_base import JoinOnEntitiesNode
from metricflow.dataflow.nodes.join_to_custom_granularity import JoinToCustomGranularityNode
from metricflow.dataflow.nodes.join_to_time_spine import JoinToTimeSpineNode
from metricflow.dataflow.nodes.metric_time_transform import MetricTimeDimensionTransformNode
from metricflow.dataflow.nodes.min_max import MinMaxNode
Expand Down Expand Up @@ -458,6 +459,13 @@ def visit_join_conversion_events_node(self, node: JoinConversionEventsNode) -> O
)
)

def visit_join_to_custom_granularity_node( # noqa: D102
self, node: JoinToCustomGranularityNode
) -> OptimizeBranchResult:
# TODO: is this right? also add docstring when you understand better
self._log_visit_node_type(node)
return self._default_handler(node)

def visit_join_on_entities_node(self, node: JoinOnEntitiesNode) -> OptimizeBranchResult:
"""Handles pushdown state propagation for the standard join node type.

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
from metricflow.dataflow.nodes.join_conversion_events import JoinConversionEventsNode
from metricflow.dataflow.nodes.join_over_time import JoinOverTimeRangeNode
from metricflow.dataflow.nodes.join_to_base import JoinOnEntitiesNode
from metricflow.dataflow.nodes.join_to_custom_granularity import JoinToCustomGranularityNode
from metricflow.dataflow.nodes.join_to_time_spine import JoinToTimeSpineNode
from metricflow.dataflow.nodes.metric_time_transform import MetricTimeDimensionTransformNode
from metricflow.dataflow.nodes.min_max import MinMaxNode
Expand Down Expand Up @@ -450,6 +451,12 @@ def visit_join_conversion_events_node( # noqa: D102
self._log_visit_node_type(node)
return self._default_handler(node)

def visit_join_to_custom_granularity_node( # noqa: D102
self, node: JoinToCustomGranularityNode
) -> ComputeMetricsBranchCombinerResult:
self._log_visit_node_type(node)
return self._default_handler(node)

def visit_min_max_node(self, node: MinMaxNode) -> ComputeMetricsBranchCombinerResult: # noqa: D102
self._log_visit_node_type(node)
return self._default_handler(node)
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
from metricflow.dataflow.nodes.join_conversion_events import JoinConversionEventsNode
from metricflow.dataflow.nodes.join_over_time import JoinOverTimeRangeNode
from metricflow.dataflow.nodes.join_to_base import JoinOnEntitiesNode
from metricflow.dataflow.nodes.join_to_custom_granularity import JoinToCustomGranularityNode
from metricflow.dataflow.nodes.join_to_time_spine import JoinToTimeSpineNode
from metricflow.dataflow.nodes.metric_time_transform import MetricTimeDimensionTransformNode
from metricflow.dataflow.nodes.min_max import MinMaxNode
Expand Down Expand Up @@ -315,6 +316,12 @@ def visit_join_conversion_events_node(self, node: JoinConversionEventsNode) -> O
self._log_visit_node_type(node)
return self._default_base_output_handler(node)

def visit_join_to_custom_granularity_node( # noqa: D102
self, node: JoinToCustomGranularityNode
) -> OptimizeBranchResult:
self._log_visit_node_type(node)
return self._default_base_output_handler(node)

def visit_min_max_node(self, node: MinMaxNode) -> OptimizeBranchResult: # noqa: D102
self._log_visit_node_type(node)
return self._default_base_output_handler(node)
4 changes: 3 additions & 1 deletion metricflow/engine/metricflow_engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -366,7 +366,9 @@ def __init__(
DunderColumnAssociationResolver(semantic_manifest_lookup)
)
self._time_source = time_source
self._time_spine_sources = TimeSpineSource.create_from_manifest(semantic_manifest_lookup.semantic_manifest)
self._time_spine_sources = TimeSpineSource.build_standard_time_spine_sources(
semantic_manifest_lookup.semantic_manifest
)
self._source_data_sets: List[SemanticModelDataSet] = []
converter = SemanticModelToDataSetConverter(column_association_resolver=self._column_association_resolver)
for semantic_model in sorted(
Expand Down
5 changes: 5 additions & 0 deletions metricflow/execution/dataflow_to_execution.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
from metricflow.dataflow.nodes.join_conversion_events import JoinConversionEventsNode
from metricflow.dataflow.nodes.join_over_time import JoinOverTimeRangeNode
from metricflow.dataflow.nodes.join_to_base import JoinOnEntitiesNode
from metricflow.dataflow.nodes.join_to_custom_granularity import JoinToCustomGranularityNode
from metricflow.dataflow.nodes.join_to_time_spine import JoinToTimeSpineNode
from metricflow.dataflow.nodes.metric_time_transform import MetricTimeDimensionTransformNode
from metricflow.dataflow.nodes.min_max import MinMaxNode
Expand Down Expand Up @@ -189,3 +190,7 @@ def visit_add_generated_uuid_column_node(self, node: AddGeneratedUuidColumnNode)
@override
def visit_join_conversion_events_node(self, node: JoinConversionEventsNode) -> ConvertToExecutionPlanResult:
raise NotImplementedError

@override
def visit_join_to_custom_granularity_node(self, node: JoinToCustomGranularityNode) -> ConvertToExecutionPlanResult:
raise NotImplementedError
Loading
Loading