Skip to content

Commit

Permalink
Use correct join node specs for GroupByMetrics (#1193)
Browse files Browse the repository at this point in the history
### Description
In the process of writing rendering tests, I found an issue with the
logic used to determine which specs to include in a joined node. This PR
primarily updates that logic to handle `GroupByMetricSpecs`, plus with
some related adjustments along the way.
  • Loading branch information
courtneyholcomb authored May 10, 2024
1 parent b768e7a commit d3571c3
Show file tree
Hide file tree
Showing 13 changed files with 65 additions and 56 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -3,10 +3,10 @@
import itertools
import typing
from dataclasses import dataclass
from typing import Dict, List, Sequence, Set, Tuple
from typing import Dict, List, Sequence, Tuple

from dbt_semantic_interfaces.dataclass_serialization import SerializableDataclass
from dbt_semantic_interfaces.references import LinkableElementReference, MeasureReference, MetricReference
from dbt_semantic_interfaces.references import MeasureReference, MetricReference
from typing_extensions import override

from metricflow_semantics.collection_helpers.merger import Mergeable
Expand Down Expand Up @@ -84,10 +84,6 @@ def as_tuple(self) -> Tuple[LinkableInstanceSpec, ...]: # noqa: D102
)
)

@property
def as_reference_set(self) -> Set[LinkableElementReference]: # noqa: D102
return {spec.reference for spec in self.as_tuple}

@override
def merge(self, other: LinkableSpecSet) -> LinkableSpecSet:
return LinkableSpecSet(
Expand Down
12 changes: 6 additions & 6 deletions metricflow/dataflow/builder/dataflow_plan_builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
MetricTimeWindow,
MetricType,
)
from dbt_semantic_interfaces.references import LinkableElementReference, MetricReference, TimeDimensionReference
from dbt_semantic_interfaces.references import MetricReference, TimeDimensionReference
from dbt_semantic_interfaces.type_enums.time_granularity import TimeGranularity
from dbt_semantic_interfaces.validations.unique_valid_name import MetricFlowReservedKeywords
from metricflow_semantics.dag.id_prefix import StaticIdPrefix
Expand Down Expand Up @@ -386,7 +386,7 @@ def _build_conversion_metric_output_node(
metric_spec=metric_spec,
aggregated_measures_node=aggregated_measures_node,
for_group_by_source_node=for_group_by_source_node,
aggregated_to_elements=queried_linkable_specs.as_reference_set,
aggregated_to_elements=set(queried_linkable_specs.as_tuple),
)

def _build_base_metric_output_node(
Expand Down Expand Up @@ -445,7 +445,7 @@ def _build_base_metric_output_node(
metric_spec=metric_spec,
aggregated_measures_node=aggregated_measures_node,
for_group_by_source_node=for_group_by_source_node,
aggregated_to_elements=queried_linkable_specs.as_reference_set,
aggregated_to_elements=set(queried_linkable_specs.as_tuple),
)

def _build_derived_metric_output_node(
Expand Down Expand Up @@ -513,7 +513,7 @@ def _build_derived_metric_output_node(
parent_node=parent_node,
metric_specs=[metric_spec],
for_group_by_source_node=for_group_by_source_node,
is_aggregated_to_elements={spec.reference for spec in queried_linkable_specs.as_tuple},
aggregated_to_elements=set(queried_linkable_specs.as_tuple),
)

# For ratio / derived metrics with time offset, apply offset & where constraint after metric computation.
Expand Down Expand Up @@ -1010,15 +1010,15 @@ def build_computed_metrics_node(
self,
metric_spec: MetricSpec,
aggregated_measures_node: Union[AggregateMeasuresNode, BaseOutput],
aggregated_to_elements: Set[LinkableElementReference],
aggregated_to_elements: Set[LinkableInstanceSpec],
for_group_by_source_node: bool = False,
) -> ComputeMetricsNode:
"""Builds a ComputeMetricsNode from aggregated measures."""
return ComputeMetricsNode(
parent_node=aggregated_measures_node,
metric_specs=[metric_spec],
for_group_by_source_node=for_group_by_source_node,
is_aggregated_to_elements=aggregated_to_elements,
aggregated_to_elements=aggregated_to_elements,
)

def _build_input_measure_specs_for_conversion_metric(
Expand Down
32 changes: 20 additions & 12 deletions metricflow/dataflow/builder/node_evaluator.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,10 +79,16 @@ def __post_init__(self) -> None: # noqa: D105
if self.join_on_entity is None and self.join_type != SqlJoinType.CROSS_JOIN:
raise RuntimeError("`join_on_entity` is required unless using CROSS JOIN.")

# TODO: JoinDescription is very similar to JoinLinkableInstancesRecipe. Can we consolidate by just adding a
# `filtered_node_to_join` property on JoinLinkableInstancesRecipe?
@property
def join_description(self) -> JoinDescription:
"""The recipe as a join description to use in the dataflow plan node."""
# Figure out what elements to filter from the joined node.
"""The recipe as a join description to use in the dataflow plan node.
Here, we figure out which instance specs to keep from this node in order to join to it and render its
satisfiable linkable specs, e.g. if the node is used to satisfy "user_id__country", the node must have the
entity "user_id" and the "country" dimension so that it can be joined to the source node.
"""
include_specs: List[LinkableInstanceSpec] = []
assert all(
[
Expand All @@ -92,13 +98,13 @@ def join_description(self) -> JoinDescription:
]
)

include_specs.extend(
[
LinklessEntitySpec.from_reference(spec.entity_links[0])
for spec in self.satisfiable_linkable_specs
if len(spec.entity_links) > 0
]
)
# Get the specs needed to join onto this node.
if self.node_to_join.aggregated_to_elements:
include_specs.extend(self.node_to_join.aggregated_to_elements)
else:
for spec in self.satisfiable_linkable_specs:
if len(spec.entity_links) > 0:
include_specs.append(LinklessEntitySpec.from_reference(spec.entity_links[0]))

include_specs.extend([join.node_to_join_dimension_spec for join in self.join_on_partition_dimensions])
include_specs.extend([join.node_to_join_time_dimension_spec for join in self.join_on_partition_time_dimensions])
Expand All @@ -109,14 +115,13 @@ def join_description(self) -> JoinDescription:

# `satisfiable_linkable_specs` describes what can be satisfied after the join, so remove the entity
# link when filtering before the join.
# e.g. if the node is used to satisfy "user_id__country", then the node must have the entity
# "user_id" and the "country" dimension so that it can be joined to the source node.
include_specs.extend(
[
spec.without_first_entity_link if len(spec.entity_links) > 0 else spec
for spec in self.satisfiable_linkable_specs
]
)

filtered_node_to_join = FilterElementsNode(
parent_node=self.node_to_join, include_specs=group_specs_by_type(include_specs)
)
Expand Down Expand Up @@ -256,6 +261,9 @@ def _find_joinable_candidate_nodes_that_can_satisfy_linkable_specs(
assert len(entity_instance_in_left_node.defined_from) == 1
assert len(entity_instance_in_right_node.defined_from) == 1

entity_spec_matches_aggregated_specs = {
spec.reference for spec in right_node.aggregated_to_elements
} == {entity_spec_in_right_node.reference}
if not (
self._join_evaluator.is_valid_semantic_model_join(
left_semantic_model_reference=entity_instance_in_left_node.defined_from[
Expand All @@ -266,7 +274,7 @@ def _find_joinable_candidate_nodes_that_can_satisfy_linkable_specs(
].semantic_model_reference,
on_entity_reference=entity_spec_in_right_node.reference,
)
or right_node.is_aggregated_to_elements == {entity_spec_in_right_node.reference}
or entity_spec_matches_aggregated_specs
):
continue

Expand Down
4 changes: 2 additions & 2 deletions metricflow/dataflow/dataflow_plan.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
from metricflow_semantics.visitor import Visitable, VisitorOutputT

if typing.TYPE_CHECKING:
from dbt_semantic_interfaces.references import LinkableElementReference
from metricflow_semantics.specs.spec_classes import LinkableInstanceSpec

from metricflow.dataflow.nodes.add_generated_uuid import AddGeneratedUuidColumnNode
from metricflow.dataflow.nodes.aggregate_measures import AggregateMeasuresNode
Expand Down Expand Up @@ -178,7 +178,7 @@ class BaseOutput(DataflowPlanNode, ABC):
"""

@property
def is_aggregated_to_elements(self) -> Set[LinkableElementReference]:
def aggregated_to_elements(self) -> Set[LinkableInstanceSpec]:
"""Indicates that the node has been aggregated to these specs, guaranteeing uniqueness in each combination of them."""
return set()

Expand Down
13 changes: 6 additions & 7 deletions metricflow/dataflow/nodes/compute_metrics.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,10 +2,9 @@

from typing import Sequence, Set

from dbt_semantic_interfaces.references import LinkableElementReference
from metricflow_semantics.dag.id_prefix import IdPrefix, StaticIdPrefix
from metricflow_semantics.dag.mf_dag import DisplayedProperty
from metricflow_semantics.specs.spec_classes import MetricSpec
from metricflow_semantics.specs.spec_classes import LinkableInstanceSpec, MetricSpec
from metricflow_semantics.visitor import VisitorOutputT

from metricflow.dataflow.dataflow_plan import (
Expand All @@ -23,7 +22,7 @@ def __init__(
self,
parent_node: BaseOutput,
metric_specs: Sequence[MetricSpec],
is_aggregated_to_elements: Set[LinkableElementReference],
aggregated_to_elements: Set[LinkableInstanceSpec],
for_group_by_source_node: bool = False,
) -> None:
"""Constructor.
Expand All @@ -36,7 +35,7 @@ def __init__(
self._parent_node = parent_node
self._metric_specs = tuple(metric_specs)
self._for_group_by_source_node = for_group_by_source_node
self._is_aggregated_to_elements = is_aggregated_to_elements
self._aggregated_to_elements = aggregated_to_elements
super().__init__(node_id=self.create_unique_id(), parent_nodes=(self._parent_node,))

@classmethod
Expand Down Expand Up @@ -92,9 +91,9 @@ def with_new_parents(self, new_parent_nodes: Sequence[BaseOutput]) -> ComputeMet
parent_node=new_parent_nodes[0],
metric_specs=self.metric_specs,
for_group_by_source_node=self.for_group_by_source_node,
is_aggregated_to_elements=self._is_aggregated_to_elements,
aggregated_to_elements=self._aggregated_to_elements,
)

@property
def is_aggregated_to_elements(self) -> Set[LinkableElementReference]: # noqa: D102
return self._is_aggregated_to_elements
def aggregated_to_elements(self) -> Set[LinkableInstanceSpec]: # noqa: D102
return self._aggregated_to_elements
Original file line number Diff line number Diff line change
Expand Up @@ -285,7 +285,7 @@ def visit_compute_metrics_node(self, node: ComputeMetricsNode) -> ComputeMetrics
)
return ComputeMetricsBranchCombinerResult()

if not self._current_left_node.is_aggregated_to_elements == current_right_node.is_aggregated_to_elements:
if not self._current_left_node.aggregated_to_elements == current_right_node.aggregated_to_elements:
self._log_combine_failure(
left_node=self._current_left_node,
right_node=current_right_node,
Expand All @@ -307,7 +307,7 @@ def visit_compute_metrics_node(self, node: ComputeMetricsNode) -> ComputeMetrics
combined_node = ComputeMetricsNode(
parent_node=combined_parent_node,
metric_specs=unique_metric_specs,
is_aggregated_to_elements=current_right_node.is_aggregated_to_elements,
aggregated_to_elements=current_right_node.aggregated_to_elements,
)
self._log_combine_success(
left_node=self._current_left_node,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -172,7 +172,7 @@ def visit_compute_metrics_node(self, node: ComputeMetricsNode) -> OptimizeBranch
parent_node=optimized_parent_result.base_output_node,
metric_specs=node.metric_specs,
for_group_by_source_node=node.for_group_by_source_node,
is_aggregated_to_elements=node.is_aggregated_to_elements,
aggregated_to_elements=node.aggregated_to_elements,
)
)

Expand Down
4 changes: 2 additions & 2 deletions metricflow/dataset/sql_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,11 +74,11 @@ def column_associations_for_entity(

if len(matching_instances) != 1:
raise RuntimeError(
f"Expected exactly one matching instance for {entity_spec} in instance set, but found: {matching_instances}"
f"Expected exactly one matching instance for {entity_spec} in instance set, but found: {matching_instances}. "
f"All entity instances: {self.instance_set.entity_instances}"
)
matching_instance = matching_instances[0]
if not matching_instance.associated_columns:
print("entity links to compare:", entity_spec.entity_links, linkable_instance.spec.entity_links)
raise RuntimeError(
f"No associated columns for entity instance {matching_instance} in data set."
"This indicates internal misconfiguration."
Expand Down
22 changes: 14 additions & 8 deletions metricflow/plan_conversion/instance_converters.py
Original file line number Diff line number Diff line change
Expand Up @@ -558,19 +558,25 @@ def transform(self, instance_set: InstanceSet) -> InstanceSet: # noqa: D102
# Sanity check to make sure the specs are in the instance set

if self._include_specs:
include_specs_not_found = []
for include_spec in self._include_specs.all_specs:
if include_spec not in instance_set.spec_set.all_specs:
raise RuntimeError(
f"Include spec {include_spec} is not in the spec set {instance_set.spec_set} - "
f"check if this node was constructed correctly."
)
include_specs_not_found.append(include_spec)
if include_specs_not_found:
raise RuntimeError(
f"Include specs {include_specs_not_found} are not in the spec set {instance_set.spec_set} - "
f"check if this node was constructed correctly."
)
elif self._exclude_specs:
exclude_specs_not_found = []
for exclude_spec in self._exclude_specs.all_specs:
if exclude_spec not in instance_set.spec_set.all_specs:
raise RuntimeError(
f"Exclude spec {exclude_spec} is not in the spec set {instance_set.spec_set} - "
f"check if this node was constructed correctly."
)
exclude_specs_not_found.append(exclude_spec)
if exclude_specs_not_found:
raise RuntimeError(
f"Exclude specs {exclude_specs_not_found} are not in the spec set {instance_set.spec_set} - "
f"check if this node was constructed correctly."
)
else:
assert False, "Include specs or exclude specs should have been specified."

Expand Down
2 changes: 1 addition & 1 deletion metricflow/plan_conversion/node_processor.py
Original file line number Diff line number Diff line change
Expand Up @@ -214,7 +214,7 @@ def _get_candidates_nodes_for_multi_hop(
left_instance_set=data_set_of_first_node_that_could_be_joined.instance_set,
right_instance_set=data_set_of_second_node_that_can_be_joined.instance_set,
on_entity_reference=entity_reference_to_join_first_and_second_nodes,
right_node_is_aggregated_to_entity=second_node_that_could_be_joined.is_aggregated_to_elements
right_node_is_aggregated_to_entity=second_node_that_could_be_joined.aggregated_to_elements
== {entity_reference_to_join_first_and_second_nodes},
):
continue
Expand Down
14 changes: 7 additions & 7 deletions tests_metricflow/plan_conversion/test_dataflow_to_sql_plan.py
Original file line number Diff line number Diff line change
Expand Up @@ -459,7 +459,7 @@ def test_compute_metrics_node(
compute_metrics_node = ComputeMetricsNode(
parent_node=aggregated_measure_node,
metric_specs=[metric_spec],
is_aggregated_to_elements={entity_spec.reference, dimension_spec.reference},
aggregated_to_elements={entity_spec, dimension_spec},
)

convert_and_check(
Expand Down Expand Up @@ -528,7 +528,7 @@ def test_compute_metrics_node_simple_expr(
compute_metrics_node = ComputeMetricsNode(
parent_node=aggregated_measures_node,
metric_specs=[metric_spec],
is_aggregated_to_elements={entity_spec.reference, dimension_spec.reference},
aggregated_to_elements={entity_spec, dimension_spec},
)

sink_node = WriteToResultDataframeNode(compute_metrics_node)
Expand Down Expand Up @@ -592,7 +592,7 @@ def test_join_to_time_spine_node_without_offset(
compute_metrics_node = ComputeMetricsNode(
parent_node=aggregated_measures_node,
metric_specs=[metric_spec],
is_aggregated_to_elements={entity_spec.reference},
aggregated_to_elements={entity_spec},
)
join_to_time_spine_node = JoinToTimeSpineNode(
parent_node=compute_metrics_node,
Expand Down Expand Up @@ -663,7 +663,7 @@ def test_join_to_time_spine_node_with_offset_window(
compute_metrics_node = ComputeMetricsNode(
parent_node=aggregated_measures_node,
metric_specs=[metric_spec],
is_aggregated_to_elements={entity_spec.reference, metric_time_spec.reference},
aggregated_to_elements={entity_spec, metric_time_spec},
)
join_to_time_spine_node = JoinToTimeSpineNode(
parent_node=compute_metrics_node,
Expand Down Expand Up @@ -736,7 +736,7 @@ def test_join_to_time_spine_node_with_offset_to_grain(
compute_metrics_node = ComputeMetricsNode(
parent_node=aggregated_measures_node,
metric_specs=[metric_spec],
is_aggregated_to_elements={entity_spec.reference, metric_time_spec.reference},
aggregated_to_elements={entity_spec, metric_time_spec},
)
join_to_time_spine_node = JoinToTimeSpineNode(
parent_node=compute_metrics_node,
Expand Down Expand Up @@ -838,7 +838,7 @@ def test_compute_metrics_node_ratio_from_single_semantic_model(
compute_metrics_node = ComputeMetricsNode(
parent_node=aggregated_measures_node,
metric_specs=[metric_spec],
is_aggregated_to_elements={entity_spec.reference, dimension_spec.reference},
aggregated_to_elements={entity_spec, dimension_spec},
)

convert_and_check(
Expand Down Expand Up @@ -894,7 +894,7 @@ def test_order_by_node(
compute_metrics_node = ComputeMetricsNode(
parent_node=aggregated_measure_node,
metric_specs=[metric_spec],
is_aggregated_to_elements={dimension_spec.reference, time_dimension_spec.reference},
aggregated_to_elements={dimension_spec, time_dimension_spec},
)

order_by_node = OrderByLimitNode(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -71,7 +71,7 @@
<FilterElementsNode>
<!-- description = "Pass Only Elements: ['listing', 'listing__bookings']" -->
<!-- node_id = NodeId(id_str='pfe_2') -->
<!-- include_spec = LinklessEntitySpec(element_name='listing') -->
<!-- include_spec = EntitySpec(element_name='listing') -->
<!-- include_spec = -->
<!-- GroupByMetricSpec( -->
<!-- element_name='bookings', -->
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -91,7 +91,7 @@
<FilterElementsNode>
<!-- description = "Pass Only Elements: ['listing', 'listing__bookings']" -->
<!-- node_id = NodeId(id_str='pfe_2') -->
<!-- include_spec = LinklessEntitySpec(element_name='listing') -->
<!-- include_spec = EntitySpec(element_name='listing') -->
<!-- include_spec = -->
<!-- GroupByMetricSpec( -->
<!-- element_name='bookings', -->
Expand Down

0 comments on commit d3571c3

Please sign in to comment.