Skip to content

Commit

Permalink
DataflowPlan for custom granularities
Browse files Browse the repository at this point in the history
  • Loading branch information
courtneyholcomb committed Sep 18, 2024
1 parent a01653f commit 0c12518
Show file tree
Hide file tree
Showing 2 changed files with 44 additions and 4 deletions.
39 changes: 37 additions & 2 deletions metricflow/dataflow/builder/dataflow_plan_builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,6 +77,7 @@
from metricflow.dataflow.nodes.join_conversion_events import JoinConversionEventsNode
from metricflow.dataflow.nodes.join_over_time import JoinOverTimeRangeNode
from metricflow.dataflow.nodes.join_to_base import JoinDescription, JoinOnEntitiesNode
from metricflow.dataflow.nodes.join_to_custom_granularity import JoinToCustomGranularityNode
from metricflow.dataflow.nodes.join_to_time_spine import JoinToTimeSpineNode
from metricflow.dataflow.nodes.min_max import MinMaxNode
from metricflow.dataflow.nodes.order_by_limit import OrderByLimitNode
Expand Down Expand Up @@ -795,6 +796,15 @@ def _build_plan_for_distinct_values(
if dataflow_recipe.join_targets:
output_node = JoinOnEntitiesNode.create(left_node=output_node, join_targets=dataflow_recipe.join_targets)

for time_dimension_spec in required_linkable_specs.time_dimension_specs:
if time_dimension_spec.time_granularity.is_custom_granularity:
include_base_grain = time_dimension_spec.with_base_grain in required_linkable_specs.time_dimension_specs
output_node = JoinToCustomGranularityNode.create(
parent_node=output_node,
time_dimension_spec=time_dimension_spec,
include_base_grain=include_base_grain,
)

if len(query_level_filter_specs) > 0:
output_node = WhereConstraintNode.create(parent_node=output_node, where_specs=query_level_filter_specs)
if query_spec.time_range_constraint:
Expand Down Expand Up @@ -885,11 +895,25 @@ def _select_source_nodes_with_linkable_specs(
"""Find source nodes with requested linkable specs and no measures."""
# Use a dictionary to dedupe for consistent ordering.
selected_nodes: Dict[DataflowPlanNode, None] = {}
requested_linkable_specs_set = set(linkable_specs.as_tuple)

# Find the source node that will satisfy the base granularity. Custom granularities will be joined in later.
linkable_specs_set_with_base_granularities: Set[LinkableInstanceSpec] = set()
# TODO: Add support for no-metrics queries for custom grains without a join (i.e., select directly from time spine).
for linkable_spec in linkable_specs.as_tuple:
if isinstance(linkable_spec, TimeDimensionSpec) and linkable_spec.time_granularity.is_custom_granularity:
linkable_spec_with_base_grain = linkable_spec.with_grain(
ExpandedTimeGranularity.from_time_granularity(linkable_spec.time_granularity.base_granularity)
)
linkable_specs_set_with_base_granularities.add(linkable_spec_with_base_grain)
else:
linkable_specs_set_with_base_granularities.add(linkable_spec)

for source_node in source_nodes:
output_spec_set = self._node_data_set_resolver.get_output_data_set(source_node).instance_set.spec_set
all_linkable_specs_in_node = set(output_spec_set.linkable_specs)
requested_linkable_specs_in_node = requested_linkable_specs_set.intersection(all_linkable_specs_in_node)
requested_linkable_specs_in_node = linkable_specs_set_with_base_granularities.intersection(
all_linkable_specs_in_node
)
if requested_linkable_specs_in_node:
selected_nodes[source_node] = None

Expand Down Expand Up @@ -1020,10 +1044,12 @@ def _find_dataflow_recipe(
metric_time_dimension_reference=self._metric_time_dimension_reference,
time_spine_nodes=self._source_node_set.time_spine_nodes_tuple,
)

logger.info(
f"After removing unnecessary nodes, there are {len(candidate_nodes_for_right_side_of_join)} candidate "
f"nodes for the right side of the join"
)
# TODO: test multi-hop with custom grains
if DataflowPlanBuilder._contains_multihop_linkables(linkable_specs):
candidate_nodes_for_right_side_of_join = list(
node_processor.add_multi_hop_joins(
Expand Down Expand Up @@ -1544,6 +1570,15 @@ def _build_aggregated_measure_from_measure_source_node(
else:
unaggregated_measure_node = filtered_measure_source_node

for time_dimension_spec in queried_linkable_specs.time_dimension_specs:
if time_dimension_spec.time_granularity.is_custom_granularity:
include_base_grain = time_dimension_spec.with_base_grain in required_linkable_specs.time_dimension_specs
unaggregated_measure_node = JoinToCustomGranularityNode.create(
parent_node=unaggregated_measure_node,
time_dimension_spec=time_dimension_spec,
include_base_grain=include_base_grain,
)

# If time constraint was previously adjusted for cumulative window or grain, apply original time constraint
# here. Can skip if metric is being aggregated over all time.
cumulative_metric_constrained_node: Optional[ConstrainTimeRangeNode] = None
Expand Down
9 changes: 7 additions & 2 deletions metricflow/dataflow/builder/node_evaluator.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@
from metricflow_semantics.specs.entity_spec import LinklessEntitySpec
from metricflow_semantics.specs.instance_spec import LinkableInstanceSpec
from metricflow_semantics.specs.spec_set import group_specs_by_type
from metricflow_semantics.specs.time_dimension_spec import TimeDimensionSpec
from metricflow_semantics.sql.sql_join_type import SqlJoinType

from metricflow.dataflow.builder.node_data_set import DataflowPlanNodeOutputDataSetResolver
Expand Down Expand Up @@ -406,6 +407,10 @@ def evaluate_node(
logger.debug(f"Candidate spec set is:\n{mf_pformat(candidate_spec_set)}")

data_set_linkable_specs = candidate_spec_set.linkable_specs
# Look for which nodes can satisfy the linkable specs at their base grains. Custom grains will be joined later.
required_linkable_specs_with_base_grains = [
spec.with_base_grain if isinstance(spec, TimeDimensionSpec) else spec for spec in required_linkable_specs
]

# These are linkable specs in the start node data set. Those are considered "local".
local_linkable_specs: List[LinkableInstanceSpec] = []
Expand All @@ -415,10 +420,10 @@ def evaluate_node(

# Group required_linkable_specs into local / un-joinable / or possibly joinable.
unjoinable_linkable_specs = []
for required_linkable_spec in required_linkable_specs:
for required_linkable_spec in required_linkable_specs_with_base_grains:
is_metric_time = required_linkable_spec.element_name == DataSet.metric_time_dimension_name()
is_local = required_linkable_spec in data_set_linkable_specs
is_unjoinable = not is_metric_time and (
is_unjoinable = (not is_metric_time) and (
len(required_linkable_spec.entity_links) == 0
or LinklessEntitySpec.from_reference(required_linkable_spec.entity_links[0])
not in data_set_linkable_specs
Expand Down

0 comments on commit 0c12518

Please sign in to comment.