From 1dff572110f60346e49d0d23998faa4db696d64c Mon Sep 17 00:00:00 2001 From: Courtney Holcomb Date: Wed, 21 Aug 2024 11:22:42 -0700 Subject: [PATCH] Improve branch combining logic for AggregateMeasuresNode --- .../source_scan/cm_branch_combiner.py | 22 ++++++++++--------- 1 file changed, 12 insertions(+), 10 deletions(-) diff --git a/metricflow/dataflow/optimizer/source_scan/cm_branch_combiner.py b/metricflow/dataflow/optimizer/source_scan/cm_branch_combiner.py index 6a3856de11..1882898860 100644 --- a/metricflow/dataflow/optimizer/source_scan/cm_branch_combiner.py +++ b/metricflow/dataflow/optimizer/source_scan/cm_branch_combiner.py @@ -239,7 +239,6 @@ def visit_aggregate_measures_node( # noqa: D102 assert len(combined_parent_nodes) == 1 combined_parent_node = combined_parent_nodes[0] - assert combined_parent_node is not None combined_metric_input_measure_specs = tuple( dict.fromkeys( @@ -247,17 +246,20 @@ def visit_aggregate_measures_node( # noqa: D102 ).keys() ) + # Avoid combining branches if the AggregateMeasuresNode specifies a metric with an alias to avoid + # collisions e.g. two metrics use the same alias for two different measures. This is not always the case, + # so this could be improved later. + seen_aliases = set() for spec in combined_metric_input_measure_specs: - # Avoid combining branches if the AggregateMeasuresNode specifies a metric with an alias to avoid - # collisions e.g. two metrics use the same alias for two different measures. This is not always the case, - # so this could be improved later. if spec.alias is not None: - self._log_combine_failure( - left_node=self._current_left_node, - right_node=current_right_node, - combine_failure_reason=f"Metric input measure spec {spec} has an alias", - ) - return ComputeMetricsBranchCombinerResult() + if spec.alias in seen_aliases: + self._log_combine_failure( + left_node=self._current_left_node, + right_node=current_right_node, + combine_failure_reason=f"Found multiple metric input measure specs with alias '{spec.alias}'", + ) + return ComputeMetricsBranchCombinerResult() + seen_aliases.add(spec.alias) combined_node = AggregateMeasuresNode.create( parent_node=combined_parent_node,