Skip to content

Commit

Permalink
Dataflow plan for custom granularities (#1409)
Browse files Browse the repository at this point in the history
Update the `DataflowPlanBuilder` to support custom granularities. Steps
included:
- For each custom granularity requested, add the appropriate
`JoinToCustomGranularityNode` to the `DataflowPlan`.
- When looking for nodes that can satisfy a given linkable spec, check
for the ability to satisfy the spec's base granularity, not the custom
granularity requested. The custom granularity will be joined in later.
  • Loading branch information
courtneyholcomb authored Sep 24, 2024
1 parent 4ee5994 commit 51d781b
Show file tree
Hide file tree
Showing 7 changed files with 77 additions and 45 deletions.
6 changes: 6 additions & 0 deletions .changes/unreleased/Features-20240827-112415.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
kind: Features
body: Handle custom granularities in DataflowPlan.
time: 2024-08-27T11:24:15.909853-07:00
custom:
Author: courtneyholcomb
Issue: "1382"
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,10 @@ def contains_metric_time(self) -> bool:
"""Returns true if this set contains a spec referring to metric time at any grain."""
return len(self.metric_time_specs) > 0

@property
def time_dimension_specs_with_custom_grain(self) -> Tuple[TimeDimensionSpec, ...]:
    """Return the time dimension specs in this set whose time granularity is a custom granularity.

    The result preserves the order of `self.time_dimension_specs` and is empty when no spec qualifies.
    """
    # Feed the generator straight into tuple() rather than materializing an intermediate list first.
    return tuple(spec for spec in self.time_dimension_specs if spec.time_granularity.is_custom_granularity)

def included_agg_time_dimension_specs_for_metric(
self, metric_reference: MetricReference, metric_lookup: MetricLookup
) -> List[TimeDimensionSpec]:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -154,7 +154,6 @@ def with_grain(self, time_granularity: ExpandedTimeGranularity) -> TimeDimension
aggregation_state=self.aggregation_state,
)

@property
def with_base_grain(self) -> TimeDimensionSpec: # noqa: D102
return TimeDimensionSpec(
element_name=self.element_name,
Expand Down
45 changes: 42 additions & 3 deletions metricflow/dataflow/builder/dataflow_plan_builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,6 +78,7 @@
from metricflow.dataflow.nodes.join_conversion_events import JoinConversionEventsNode
from metricflow.dataflow.nodes.join_over_time import JoinOverTimeRangeNode
from metricflow.dataflow.nodes.join_to_base import JoinDescription, JoinOnEntitiesNode
from metricflow.dataflow.nodes.join_to_custom_granularity import JoinToCustomGranularityNode
from metricflow.dataflow.nodes.join_to_time_spine import JoinToTimeSpineNode
from metricflow.dataflow.nodes.min_max import MinMaxNode
from metricflow.dataflow.nodes.order_by_limit import OrderByLimitNode
Expand Down Expand Up @@ -809,6 +810,12 @@ def _build_plan_for_distinct_values(
if dataflow_recipe.join_targets:
output_node = JoinOnEntitiesNode.create(left_node=output_node, join_targets=dataflow_recipe.join_targets)

for time_dimension_spec in required_linkable_specs.time_dimension_specs:
if time_dimension_spec.time_granularity.is_custom_granularity:
output_node = JoinToCustomGranularityNode.create(
parent_node=output_node, time_dimension_spec=time_dimension_spec
)

if len(query_level_filter_specs) > 0:
output_node = WhereConstraintNode.create(parent_node=output_node, where_specs=query_level_filter_specs)
if query_spec.time_range_constraint:
Expand Down Expand Up @@ -899,11 +906,25 @@ def _select_source_nodes_with_linkable_specs(
"""Find source nodes with requested linkable specs and no measures."""
# Use a dictionary to dedupe for consistent ordering.
selected_nodes: Dict[DataflowPlanNode, None] = {}
requested_linkable_specs_set = set(linkable_specs.as_tuple)

# Find the source node that will satisfy the base granularity. Custom granularities will be joined in later.
linkable_specs_set_with_base_granularities: Set[LinkableInstanceSpec] = set()
# TODO: Add support for no-metrics queries for custom grains without a join (i.e., select directly from time spine).
for linkable_spec in linkable_specs.as_tuple:
if isinstance(linkable_spec, TimeDimensionSpec) and linkable_spec.time_granularity.is_custom_granularity:
linkable_spec_with_base_grain = linkable_spec.with_grain(
ExpandedTimeGranularity.from_time_granularity(linkable_spec.time_granularity.base_granularity)
)
linkable_specs_set_with_base_granularities.add(linkable_spec_with_base_grain)
else:
linkable_specs_set_with_base_granularities.add(linkable_spec)

for source_node in source_nodes:
output_spec_set = self._node_data_set_resolver.get_output_data_set(source_node).instance_set.spec_set
all_linkable_specs_in_node = set(output_spec_set.linkable_specs)
requested_linkable_specs_in_node = requested_linkable_specs_set.intersection(all_linkable_specs_in_node)
requested_linkable_specs_in_node = linkable_specs_set_with_base_granularities.intersection(
all_linkable_specs_in_node
)
if requested_linkable_specs_in_node:
selected_nodes[source_node] = None

Expand Down Expand Up @@ -1042,6 +1063,7 @@ def _find_dataflow_recipe(
f"nodes for the right side of the join"
)
)
# TODO: test multi-hop with custom grains
if DataflowPlanBuilder._contains_multihop_linkables(linkable_specs):
candidate_nodes_for_right_side_of_join = list(
node_processor.add_multi_hop_joins(
Expand Down Expand Up @@ -1422,6 +1444,12 @@ def __get_required_and_extraneous_linkable_specs(
extraneous_linkable_specs = LinkableSpecSet.merge_iterable(linkable_spec_sets_to_merge).dedupe()
required_linkable_specs = queried_linkable_specs.merge(extraneous_linkable_specs).dedupe()

# Custom grains require joining to their base grain, so add base grain to extraneous specs.
base_grain_set = LinkableSpecSet.create_from_specs(
[spec.with_base_grain() for spec in required_linkable_specs.time_dimension_specs_with_custom_grain]
)
extraneous_linkable_specs = extraneous_linkable_specs.merge(base_grain_set).dedupe()

return required_linkable_specs, extraneous_linkable_specs

def _build_aggregated_measure_from_measure_source_node(
Expand Down Expand Up @@ -1584,7 +1612,12 @@ def _build_aggregated_measure_from_measure_source_node(
)

specs_to_keep_after_join = InstanceSpecSet(measure_specs=(measure_spec,)).merge(
InstanceSpecSet.create_from_specs(required_linkable_specs.as_tuple),
InstanceSpecSet.create_from_specs(
[
spec.with_base_grain() if isinstance(spec, TimeDimensionSpec) else spec
for spec in required_linkable_specs.as_tuple
]
),
)

after_join_filtered_node = FilterElementsNode.create(
Expand All @@ -1594,6 +1627,12 @@ def _build_aggregated_measure_from_measure_source_node(
else:
unaggregated_measure_node = filtered_measure_source_node

for time_dimension_spec in required_linkable_specs.time_dimension_specs:
if time_dimension_spec.time_granularity.is_custom_granularity:
unaggregated_measure_node = JoinToCustomGranularityNode.create(
parent_node=unaggregated_measure_node, time_dimension_spec=time_dimension_spec
)

# If time constraint was previously adjusted for cumulative window or grain, apply original time constraint
# here. Can skip if metric is being aggregated over all time.
cumulative_metric_constrained_node: Optional[ConstrainTimeRangeNode] = None
Expand Down
22 changes: 17 additions & 5 deletions metricflow/dataflow/builder/node_evaluator.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@
from metricflow_semantics.specs.entity_spec import LinklessEntitySpec
from metricflow_semantics.specs.instance_spec import LinkableInstanceSpec
from metricflow_semantics.specs.spec_set import group_specs_by_type
from metricflow_semantics.specs.time_dimension_spec import TimeDimensionSpec
from metricflow_semantics.sql.sql_join_type import SqlJoinType

from metricflow.dataflow.builder.node_data_set import DataflowPlanNodeOutputDataSetResolver
Expand Down Expand Up @@ -407,6 +408,10 @@ def evaluate_node(
logger.debug(LazyFormat(lambda: f"Candidate spec set is:\n{mf_pformat(candidate_spec_set)}"))

data_set_linkable_specs = candidate_spec_set.linkable_specs
# Look for which nodes can satisfy the linkable specs at their base grains. Custom grains will be joined later.
required_linkable_specs_with_base_grains = [
spec.with_base_grain() if isinstance(spec, TimeDimensionSpec) else spec for spec in required_linkable_specs
]

# These are linkable specs in the start node data set. Those are considered "local".
local_linkable_specs: List[LinkableInstanceSpec] = []
Expand All @@ -416,13 +421,20 @@ def evaluate_node(

# Group required_linkable_specs into local / un-joinable / or possibly joinable.
unjoinable_linkable_specs = []
for required_linkable_spec in required_linkable_specs:
for required_linkable_spec in required_linkable_specs_with_base_grains:
is_metric_time = required_linkable_spec.element_name == DataSet.metric_time_dimension_name()
is_local = required_linkable_spec in data_set_linkable_specs
is_unjoinable = not is_metric_time and (
len(required_linkable_spec.entity_links) == 0
or LinklessEntitySpec.from_reference(required_linkable_spec.entity_links[0])
not in data_set_linkable_specs
is_unjoinable = (
# metric_time is never unjoinable. In metric queries, the agg_time_dimension is local to the measure source node.
# In no-metric queries, can always CROSS JOIN to a time spine.
(not is_metric_time)
and (
# metric_time is the only element that can be joined without entity links.
len(required_linkable_spec.entity_links) == 0
# In order to be joinable, the first entity link must be in the left node's dataset.
or LinklessEntitySpec.from_reference(required_linkable_spec.entity_links[0])
not in data_set_linkable_specs
)
)
if is_local:
local_linkable_specs.append(required_linkable_spec)
Expand Down
19 changes: 4 additions & 15 deletions metricflow/dataflow/nodes/join_to_custom_granularity.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,27 +18,22 @@ class JoinToCustomGranularityNode(DataflowPlanNode, ABC):
Args:
time_dimension_spec: The time dimension spec with a custom granularity that will be satisfied by this node.
include_base_grain: Bool that indicates if a spec with the custom granularity's base grain
should be included in the node's output dataset. This is needed when the same time dimension is requested
twice in one query, with both a custom grain and that custom grain's base grain.
"""

time_dimension_spec: TimeDimensionSpec
include_base_grain: bool

def __post_init__(self) -> None:
    """Run the parent class's post-init hook, then check that the spec uses a custom granularity.

    Raises:
        AssertionError: If `time_dimension_spec.time_granularity.is_custom_granularity` is falsy.
    """
    super().__post_init__()
    granularity = self.time_dimension_spec.time_granularity
    # The two message fragments are wrapped in parentheses so they concatenate into a single assertion message.
    # Without the parentheses, the f-string on its own line is a separate no-op expression statement and the
    # "Instead, found ..." detail never reaches the AssertionError.
    assert granularity.is_custom_granularity, (
        "Time granularity for time dimension spec in JoinToCustomGranularityNode must be qualified as custom granularity."
        f" Instead, found {granularity.name}. This indicates internal misconfiguration."
    )

@staticmethod
def create( # noqa: D102
parent_node: DataflowPlanNode, time_dimension_spec: TimeDimensionSpec, include_base_grain: bool
parent_node: DataflowPlanNode, time_dimension_spec: TimeDimensionSpec
) -> JoinToCustomGranularityNode:
return JoinToCustomGranularityNode(
parent_nodes=(parent_node,), time_dimension_spec=time_dimension_spec, include_base_grain=include_base_grain
)
return JoinToCustomGranularityNode(parent_nodes=(parent_node,), time_dimension_spec=time_dimension_spec)

@classmethod
def id_prefix(cls) -> IdPrefix: # noqa: D102
Expand All @@ -55,19 +50,14 @@ def description(self) -> str: # noqa: D102
def displayed_properties(self) -> Sequence[DisplayedProperty]: # noqa: D102
return tuple(super().displayed_properties) + (
DisplayedProperty("time_dimension_spec", self.time_dimension_spec),
DisplayedProperty("include_base_grain", self.include_base_grain),
)

@property
def parent_node(self) -> DataflowPlanNode:
    """Return the first node in `parent_nodes`."""
    first_parent = self.parent_nodes[0]
    return first_parent

def functionally_identical(self, other_node: DataflowPlanNode) -> bool: # noqa: D102
return (
isinstance(other_node, self.__class__)
and other_node.time_dimension_spec == self.time_dimension_spec
and other_node.include_base_grain == self.include_base_grain
)
return isinstance(other_node, self.__class__) and other_node.time_dimension_spec == self.time_dimension_spec

def with_new_parents( # noqa: D102
self, new_parent_nodes: Sequence[DataflowPlanNode]
Expand All @@ -76,5 +66,4 @@ def with_new_parents( # noqa: D102
return JoinToCustomGranularityNode.create(
parent_node=new_parent_nodes[0],
time_dimension_spec=self.time_dimension_spec,
include_base_grain=self.include_base_grain,
)
25 changes: 4 additions & 21 deletions metricflow/plan_conversion/dataflow_to_sql.py
Original file line number Diff line number Diff line change
Expand Up @@ -1439,7 +1439,7 @@ def visit_join_to_custom_granularity_node(self, node: JoinToCustomGranularityNod
parent_alias = parent_data_set.checked_sql_select_node.from_source_alias
parent_time_dimension_instance: Optional[TimeDimensionInstance] = None
for instance in parent_data_set.instance_set.time_dimension_instances:
if instance.spec == node.time_dimension_spec.with_base_grain:
if instance.spec == node.time_dimension_spec.with_base_grain():
parent_time_dimension_instance = instance
break
assert parent_time_dimension_instance, (
Expand Down Expand Up @@ -1467,23 +1467,6 @@ def visit_join_to_custom_granularity_node(self, node: JoinToCustomGranularityNod
join_type=SqlJoinType.LEFT_OUTER,
)

# Remove base grain from parent dataset, unless that grain was also requested (in addition to the custom grain).
parent_instance_set = parent_data_set.instance_set
parent_select_columns = parent_data_set.checked_sql_select_node.select_columns
if not node.include_base_grain:
parent_instance_set = parent_instance_set.transform(
FilterElements(
exclude_specs=InstanceSpecSet(time_dimension_specs=(parent_time_dimension_instance.spec,))
)
)
parent_select_columns = tuple(
[
column
for column in parent_select_columns
if column.column_alias != parent_time_dimension_instance.associated_column.column_name
]
)

# Build output time spine instances and columns.
time_spine_instance = TimeDimensionInstance(
defined_from=parent_time_dimension_instance.defined_from,
Expand All @@ -1501,10 +1484,10 @@ def visit_join_to_custom_granularity_node(self, node: JoinToCustomGranularityNod
),
)
return SqlDataSet(
instance_set=InstanceSet.merge([time_spine_instance_set, parent_instance_set]),
instance_set=InstanceSet.merge([time_spine_instance_set, parent_data_set.instance_set]),
sql_select_node=SqlSelectStatementNode.create(
description=node.description + "\n" + parent_data_set.checked_sql_select_node.description,
select_columns=parent_select_columns + time_spine_select_columns,
description=parent_data_set.checked_sql_select_node.description + "\n" + node.description,
select_columns=parent_data_set.checked_sql_select_node.select_columns + time_spine_select_columns,
from_source=parent_data_set.checked_sql_select_node.from_source,
from_source_alias=parent_alias,
join_descs=parent_data_set.checked_sql_select_node.join_descs + (join_description,),
Expand Down

0 comments on commit 51d781b

Please sign in to comment.