diff --git a/metricflow/dataflow/builder/dataflow_plan_builder.py b/metricflow/dataflow/builder/dataflow_plan_builder.py index c0b5c146bb..2f0a3c802b 100644 --- a/metricflow/dataflow/builder/dataflow_plan_builder.py +++ b/metricflow/dataflow/builder/dataflow_plan_builder.py @@ -1,10 +1,9 @@ from __future__ import annotations -import collections import logging import time from dataclasses import dataclass -from typing import DefaultDict, Dict, List, Optional, Sequence, Set, Tuple, Union +from typing import Dict, List, Optional, Sequence, Set, Tuple, Union from dbt_semantic_interfaces.enum_extension import assert_values_exhausted from dbt_semantic_interfaces.pretty_print import pformat_big_objects @@ -15,7 +14,6 @@ from dbt_semantic_interfaces.type_enums.time_granularity import TimeGranularity from metricflow.dag.id_generation import DATAFLOW_PLAN_PREFIX, IdGeneratorRegistry -from metricflow.dataflow.builder.measure_additiveness import group_measure_specs_by_additiveness from metricflow.dataflow.builder.node_data_set import DataflowPlanNodeOutputDataSetResolver from metricflow.dataflow.builder.node_evaluator import ( JoinLinkableInstancesRecipe, @@ -30,7 +28,6 @@ ConstrainTimeRangeNode, DataflowPlan, FilterElementsNode, - JoinAggregatedMeasuresByGroupByColumnsNode, JoinDescription, JoinOverTimeRangeNode, JoinToBaseOutputNode, @@ -253,18 +250,22 @@ def _build_metrics_output_node( metric_reference=metric_reference, column_association_resolver=self._column_association_resolver, ) + assert ( + len(metric_input_measure_specs) == 1 + ), "Simple and cumulative metrics must have one input measure." + metric_input_measure_spec = metric_input_measure_specs[0] logger.info( - f"For {metric_spec}, needed measures are:\n" - f"{pformat_big_objects(metric_input_measure_specs=metric_input_measure_specs)}" + f"For {metric_spec}, needed measure is:\n" + f"{pformat_big_objects(metric_input_measure_spec=metric_input_measure_spec)}" ) combined_where = where_constraint if metric_spec.constraint: combined_where = ( combined_where.combine(metric_spec.constraint) if combined_where else metric_spec.constraint ) - aggregated_measures_node = self.build_aggregated_measures( - metric_input_measure_specs=metric_input_measure_specs, + aggregated_measures_node = self.build_aggregated_measure( + metric_input_measure_spec=metric_input_measure_spec, metric_spec=metric_spec, queried_linkable_specs=queried_linkable_specs, where_constraint=combined_where, @@ -632,9 +633,9 @@ def build_computed_metrics_node( metric_specs=[metric_spec], ) - def build_aggregated_measures( + def build_aggregated_measure( self, - metric_input_measure_specs: Sequence[MetricInputMeasureSpec], + metric_input_measure_spec: MetricInputMeasureSpec, metric_spec: MetricSpec, queried_linkable_specs: LinkableSpecSet, where_constraint: Optional[WhereFilterSpec] = None, @@ -649,81 +650,29 @@ def build_aggregated_measures( a composite set of aggregations originating from multiple semantic models, and joined into a single aggregated set of measures. """ - output_nodes: List[BaseOutput] = [] - semantic_models_and_constraints_to_measures: DefaultDict[ - tuple[str, Optional[WhereFilterSpec]], List[MetricInputMeasureSpec] - ] = collections.defaultdict(list) - for input_spec in metric_input_measure_specs: - semantic_model_names = [ - dsource.name - for dsource in self._semantic_model_lookup.get_semantic_models_for_measure( - measure_reference=input_spec.measure_spec.as_reference - ) - ] - assert ( - len(semantic_model_names) == 1 - ), f"Validation should enforce one semantic model per measure, but found {semantic_model_names} for {input_spec}!" - semantic_models_and_constraints_to_measures[(semantic_model_names[0], input_spec.constraint)].append( - input_spec - ) - - for (semantic_model, measure_constraint), measures in semantic_models_and_constraints_to_measures.items(): - logger.info( - f"Building aggregated measures for {semantic_model}. " - f" Input measures: {measures} with constraints: {measure_constraint}" - ) - if measure_constraint is None: - node_where_constraint = where_constraint - elif where_constraint is None: - node_where_constraint = measure_constraint - else: - node_where_constraint = where_constraint.combine(measure_constraint) - - input_specs_by_measure_spec = {spec.measure_spec: spec for spec in measures} - grouped_measures_by_additiveness = group_measure_specs_by_additiveness( - tuple(input_specs_by_measure_spec.keys()) - ) - measures_by_additiveness = grouped_measures_by_additiveness.measures_by_additiveness - - # Build output nodes for each distinct non-additive dimension spec, including the None case - for non_additive_spec, measure_specs in measures_by_additiveness.items(): - non_additive_message = "" - if non_additive_spec is not None: - non_additive_message = f" with non-additive dimension spec: {non_additive_spec}" - - logger.info(f"Building aggregated measures for {semantic_model}{non_additive_message}") - input_specs = tuple(input_specs_by_measure_spec[measure_spec] for measure_spec in measure_specs) - output_nodes.append( - self._build_aggregated_measures_from_measure_source_node( - metric_input_measure_specs=input_specs, - metric_spec=metric_spec, - queried_linkable_specs=queried_linkable_specs, - where_constraint=node_where_constraint, - time_range_constraint=time_range_constraint, - cumulative=cumulative, - cumulative_window=cumulative_window, - cumulative_grain_to_date=cumulative_grain_to_date, - ) - ) - - if len(output_nodes) == 1: - return output_nodes[0] + measure_constraint = metric_input_measure_spec.constraint + logger.info(f"Building aggregated measure: {metric_input_measure_spec} with constraint: {measure_constraint}") + if measure_constraint is None: + node_where_constraint = where_constraint + elif where_constraint is None: + node_where_constraint = measure_constraint else: - return FilterElementsNode( - parent_node=JoinAggregatedMeasuresByGroupByColumnsNode(parent_nodes=output_nodes), - include_specs=InstanceSpecSet.merge( - ( - queried_linkable_specs.as_spec_set, - InstanceSpecSet( - measure_specs=tuple(x.post_aggregation_spec for x in metric_input_measure_specs) - ), - ) - ), - ) + node_where_constraint = where_constraint.combine(measure_constraint) + + return self._build_aggregated_measure_from_measure_source_node( + metric_input_measure_spec=metric_input_measure_spec, + metric_spec=metric_spec, + queried_linkable_specs=queried_linkable_specs, + where_constraint=node_where_constraint, + time_range_constraint=time_range_constraint, + cumulative=cumulative, + cumulative_window=cumulative_window, + cumulative_grain_to_date=cumulative_grain_to_date, + ) - def _build_aggregated_measures_from_measure_source_node( + def _build_aggregated_measure_from_measure_source_node( self, - metric_input_measure_specs: Sequence[MetricInputMeasureSpec], + metric_input_measure_spec: MetricInputMeasureSpec, metric_spec: MetricSpec, queried_linkable_specs: LinkableSpecSet, where_constraint: Optional[WhereFilterSpec] = None, @@ -738,8 +687,8 @@ def _build_aggregated_measures_from_measure_source_node( if time_dimension_spec.element_name == self._metric_time_dimension_reference.element_name ] metric_time_dimension_requested = len(metric_time_dimension_specs) > 0 - measure_specs = tuple(x.measure_spec for x in metric_input_measure_specs) - measure_properties = self._build_measure_spec_properties(measure_specs) + measure_spec = metric_input_measure_spec.measure_spec + measure_properties = self._build_measure_spec_properties([measure_spec]) non_additive_dimension_spec = measure_properties.non_additive_dimension_spec cumulative_metric_adjusted_time_constraint: Optional[TimeRangeConstraint] = None @@ -774,7 +723,7 @@ def _build_aggregated_measures_from_measure_source_node( required_linkable_specs = LinkableSpecSet.merge((queried_linkable_specs, extraneous_linkable_specs)) logger.info( f"Looking for a recipe to get:\n" - f"{pformat_big_objects(measure_specs=measure_specs, required_linkable_set=required_linkable_specs)}" + f"{pformat_big_objects(measure_specs=[measure_spec], required_linkable_set=required_linkable_specs)}" ) find_recipe_start_time = time.time() @@ -793,7 +742,7 @@ def _build_aggregated_measures_from_measure_source_node( if not measure_recipe: # TODO: Improve for better user understandability. raise UnableToSatisfyQueryError( - f"Recipe not found for measure specs: {measure_specs} and linkable specs: {required_linkable_specs}" + f"Recipe not found for measure spec: {measure_spec} and linkable specs: {required_linkable_specs}" ) # If a cumulative metric is queried with metric_time, join over time range. @@ -825,7 +774,7 @@ def _build_aggregated_measures_from_measure_source_node( parent_node=join_to_time_spine_node or time_range_node or measure_recipe.source_node, include_specs=InstanceSpecSet.merge( ( - InstanceSpecSet(measure_specs=measure_specs), + InstanceSpecSet(measure_specs=(measure_spec,)), InstanceSpecSet.create_from_linkable_specs(measure_recipe.required_local_linkable_specs), ) ), @@ -841,7 +790,7 @@ def _build_aggregated_measures_from_measure_source_node( specs_to_keep_after_join = InstanceSpecSet.merge( ( - InstanceSpecSet(measure_specs=measure_specs), + InstanceSpecSet(measure_specs=(measure_spec,)), required_linkable_specs.as_spec_set, ) ) @@ -902,22 +851,16 @@ def _build_aggregated_measures_from_measure_source_node( pre_aggregate_node = FilterElementsNode( parent_node=pre_aggregate_node, include_specs=InstanceSpecSet.merge( - (InstanceSpecSet(measure_specs=measure_specs), queried_linkable_specs.as_spec_set) + (InstanceSpecSet(measure_specs=(measure_spec,)), queried_linkable_specs.as_spec_set) ), ) aggregate_measures_node = AggregateMeasuresNode( parent_node=pre_aggregate_node, - metric_input_measure_specs=tuple(metric_input_measure_specs), + metric_input_measure_specs=(metric_input_measure_spec,), ) - join_aggregated_measure_to_time_spine = False - for metric_input_measure in metric_input_measure_specs: - if metric_input_measure.join_to_timespine: - join_aggregated_measure_to_time_spine = True - break - # Only join to time spine if metric time was requested in the query. - if join_aggregated_measure_to_time_spine and metric_time_dimension_requested: + if metric_input_measure_spec.join_to_timespine and metric_time_dimension_requested: return JoinToTimeSpineNode( parent_node=aggregate_measures_node, requested_metric_time_dimension_specs=metric_time_dimension_specs, diff --git a/metricflow/dataflow/dataflow_plan.py b/metricflow/dataflow/dataflow_plan.py index 1449b88723..f3b101d3ef 100644 --- a/metricflow/dataflow/dataflow_plan.py +++ b/metricflow/dataflow/dataflow_plan.py @@ -20,7 +20,6 @@ DATAFLOW_NODE_COMBINE_METRICS_ID_PREFIX, DATAFLOW_NODE_COMPUTE_METRICS_ID_PREFIX, DATAFLOW_NODE_CONSTRAIN_TIME_RANGE_ID_PREFIX, - DATAFLOW_NODE_JOIN_AGGREGATED_MEASURES_BY_GROUPBY_COLUMNS_PREFIX, DATAFLOW_NODE_JOIN_SELF_OVER_TIME_RANGE_ID_PREFIX, DATAFLOW_NODE_JOIN_TO_STANDARD_OUTPUT_ID_PREFIX, DATAFLOW_NODE_JOIN_TO_TIME_SPINE_ID_PREFIX, @@ -122,12 +121,6 @@ def visit_source_node(self, node: ReadSqlSourceNode) -> VisitorOutputT: # noqa: def visit_join_to_base_output_node(self, node: JoinToBaseOutputNode) -> VisitorOutputT: # noqa: D pass - @abstractmethod - def visit_join_aggregated_measures_by_groupby_columns_node( # noqa: D - self, node: JoinAggregatedMeasuresByGroupByColumnsNode - ) -> VisitorOutputT: - pass - @abstractmethod def visit_aggregate_measures_node(self, node: AggregateMeasuresNode) -> VisitorOutputT: # noqa: D pass @@ -499,60 +492,6 @@ def with_new_parents(self, new_parent_nodes: Sequence[BaseOutput]) -> AggregateM ) -class JoinAggregatedMeasuresByGroupByColumnsNode(AggregatedMeasuresOutput): - """A node that joins aggregated measures with group by elements. - - This is designed to link two separate semantic models with measures aggregated by the complete set of group by - elements shared across both measures. Due to the way the DataflowPlan currently processes joins, this means - each separate semantic model will be pre-aggregated, and this final join will be run across fully aggregated - sets of input data. As such, all this requires is the list of aggregated measure outputs, since they can be - transformed into a SqlDataSet containing the complete list of non-measure specs for joining. - """ - - def __init__( - self, - parent_nodes: Sequence[BaseOutput], - ): - """Constructor. - - Args: - parent_nodes: sequence of nodes that output aggregated measures - """ - if len(parent_nodes) < 2: - raise ValueError( - "This node is designed for joining 2 or more aggregated nodes together, but " - f"we got {len(parent_nodes)}" - ) - super().__init__(node_id=self.create_unique_id(), parent_nodes=list(parent_nodes)) - - @classmethod - def id_prefix(cls) -> str: # noqa: D - return DATAFLOW_NODE_JOIN_AGGREGATED_MEASURES_BY_GROUPBY_COLUMNS_PREFIX - - def accept(self, visitor: DataflowPlanNodeVisitor[VisitorOutputT]) -> VisitorOutputT: # noqa: D - return visitor.visit_join_aggregated_measures_by_groupby_columns_node(self) - - @property - def description(self) -> str: # noqa: D - return """Join Aggregated Measures with Standard Outputs""" - - @property - def displayed_properties(self) -> List[DisplayedProperty]: # noqa: D - return super().displayed_properties + [ - DisplayedProperty("Join aggregated measure nodes: ", f"{[node.node_id for node in self.parent_nodes]}") - ] - - def functionally_identical(self, other_node: DataflowPlanNode) -> bool: # noqa: D - return isinstance(other_node, self.__class__) - - def with_new_parents( # noqa: D - self, new_parent_nodes: Sequence[BaseOutput] - ) -> JoinAggregatedMeasuresByGroupByColumnsNode: - return JoinAggregatedMeasuresByGroupByColumnsNode( - parent_nodes=new_parent_nodes, - ) - - class SemiAdditiveJoinNode(BaseOutput): """A node that performs a row filter by aggregating a given non-additive dimension. diff --git a/metricflow/dataflow/optimizer/source_scan/cm_branch_combiner.py b/metricflow/dataflow/optimizer/source_scan/cm_branch_combiner.py index 398f3c559f..2155caeb3e 100644 --- a/metricflow/dataflow/optimizer/source_scan/cm_branch_combiner.py +++ b/metricflow/dataflow/optimizer/source_scan/cm_branch_combiner.py @@ -13,7 +13,6 @@ DataflowPlanNode, DataflowPlanNodeVisitor, FilterElementsNode, - JoinAggregatedMeasuresByGroupByColumnsNode, JoinOverTimeRangeNode, JoinToBaseOutputNode, JoinToTimeSpineNode, @@ -216,12 +215,6 @@ def visit_join_to_base_output_node( # noqa: D self._log_visit_node_type(node) return self._default_handler(node) - def visit_join_aggregated_measures_by_groupby_columns_node( # noqa: D - self, node: JoinAggregatedMeasuresByGroupByColumnsNode - ) -> ComputeMetricsBranchCombinerResult: - self._log_visit_node_type(node) - return self._default_handler(node) - def visit_aggregate_measures_node( # noqa: D self, node: AggregateMeasuresNode ) -> ComputeMetricsBranchCombinerResult: diff --git a/metricflow/dataflow/optimizer/source_scan/source_scan_optimizer.py b/metricflow/dataflow/optimizer/source_scan/source_scan_optimizer.py index 26beacef35..1903ec9b0d 100644 --- a/metricflow/dataflow/optimizer/source_scan/source_scan_optimizer.py +++ b/metricflow/dataflow/optimizer/source_scan/source_scan_optimizer.py @@ -15,7 +15,6 @@ DataflowPlanNode, DataflowPlanNodeVisitor, FilterElementsNode, - JoinAggregatedMeasuresByGroupByColumnsNode, JoinOverTimeRangeNode, JoinToBaseOutputNode, JoinToTimeSpineNode, @@ -155,12 +154,6 @@ def visit_join_to_base_output_node(self, node: JoinToBaseOutputNode) -> Optimize self._log_visit_node_type(node) return self._default_base_output_handler(node) - def visit_join_aggregated_measures_by_groupby_columns_node( # noqa: D - self, node: JoinAggregatedMeasuresByGroupByColumnsNode - ) -> OptimizeBranchResult: - self._log_visit_node_type(node) - return self._default_base_output_handler(node) - def visit_aggregate_measures_node(self, node: AggregateMeasuresNode) -> OptimizeBranchResult: # noqa: D self._log_visit_node_type(node) return self._default_base_output_handler(node) diff --git a/metricflow/plan_conversion/dataflow_to_sql.py b/metricflow/plan_conversion/dataflow_to_sql.py index e403f3648e..e7f238c686 100644 --- a/metricflow/plan_conversion/dataflow_to_sql.py +++ b/metricflow/plan_conversion/dataflow_to_sql.py @@ -20,7 +20,6 @@ ConstrainTimeRangeNode, DataflowPlanNodeVisitor, FilterElementsNode, - JoinAggregatedMeasuresByGroupByColumnsNode, JoinOverTimeRangeNode, JoinToBaseOutputNode, JoinToTimeSpineNode, @@ -457,83 +456,6 @@ def visit_join_to_base_output_node(self, node: JoinToBaseOutputNode) -> SqlDataS ), ) - def visit_join_aggregated_measures_by_groupby_columns_node( # noqa: D - self, node: JoinAggregatedMeasuresByGroupByColumnsNode - ) -> SqlDataSet: - """Generates the query that realizes the behavior of the JoinAggregatedMeasuresByGroupByColumnsNode. - - This node is a straight inner join against all of the columns used for grouping in the input - aggregation steps. Every column should be used, and at this point all inputs are fully aggregated, - meaning we can make assumptions about things like NULL handling and there being one row per value - set in each semantic model. - - In addition, this is used in cases where we expect a final metric to be computed using these - measures as input. Therefore, we make the assumption that any mismatch should be discarded, as - the behavior of the metric will be undefined in that case. This is why the INNER JOIN type is - appropriate - if a dimension value set exists in one aggregated set but not the other, there is - no sensible metric value for that dimension set. - """ - assert len(node.parent_nodes) > 1, "This cannot happen, the node initializer would have failed" - - table_alias_to_instance_set: OrderedDict[str, InstanceSet] = OrderedDict() - - from_data_set: SqlDataSet = node.parent_nodes[0].accept(self) - from_data_set_alias = self._next_unique_table_alias() - table_alias_to_instance_set[from_data_set_alias] = from_data_set.instance_set - join_aliases = [column.column_name for column in from_data_set.groupable_column_associations] - use_cross_join = len(join_aliases) == 0 - - sql_join_descs: List[SqlJoinDescription] = [] - for aggregated_node in node.parent_nodes[1:]: - right_data_set: SqlDataSet = aggregated_node.accept(self) - right_data_set_alias = self._next_unique_table_alias() - right_column_names = {column.column_name for column in right_data_set.groupable_column_associations} - if right_column_names != set(join_aliases): - # TODO test multi-hop dimensions and address any issues. For now, let's raise an exception. - raise ValueError( - f"We only support joins where all dimensions have the same name, but we got {right_column_names} " - f"and {join_aliases}, which differ by {right_column_names.difference(set(join_aliases))}!" - ) - # sort column names to ensure consistent join ordering for ease of debugging and testing - ordered_right_column_names = sorted(right_column_names) - column_equality_descriptions = [ - ColumnEqualityDescription( - left_column_alias=colname, right_column_alias=colname, treat_nulls_as_equal=True - ) - for colname in ordered_right_column_names - ] - sql_join_descs.append( - SqlQueryPlanJoinBuilder.make_column_equality_sql_join_description( - right_source_node=right_data_set.sql_select_node, - right_source_alias=right_data_set_alias, - left_source_alias=from_data_set_alias, - column_equality_descriptions=column_equality_descriptions, - join_type=SqlJoinType.INNER if not use_cross_join else SqlJoinType.CROSS_JOIN, - ) - ) - # All groupby columns are shared by all inputs, so we only want the measure/metric columns - # from the semantic models on the right side of the join - table_alias_to_instance_set[right_data_set_alias] = InstanceSet( - measure_instances=right_data_set.instance_set.measure_instances, - metric_instances=right_data_set.instance_set.metric_instances, - ) - - return SqlDataSet( - instance_set=InstanceSet.merge(list(table_alias_to_instance_set.values())), - sql_select_node=SqlSelectStatementNode( - description=node.description, - select_columns=create_select_columns_for_instance_sets( - self._column_association_resolver, table_alias_to_instance_set - ), - from_source=from_data_set.sql_select_node, - from_source_alias=from_data_set_alias, - joins_descs=tuple(sql_join_descs), - group_bys=(), - where=None, - order_bys=(), - ), - ) - def visit_aggregate_measures_node(self, node: AggregateMeasuresNode) -> SqlDataSet: """Generates the query that realizes the behavior of AggregateMeasuresNode. diff --git a/metricflow/test/dataflow/optimizer/source_scan/test_source_scan_optimizer.py b/metricflow/test/dataflow/optimizer/source_scan/test_source_scan_optimizer.py index ed38f67087..03ca1b9134 100644 --- a/metricflow/test/dataflow/optimizer/source_scan/test_source_scan_optimizer.py +++ b/metricflow/test/dataflow/optimizer/source_scan/test_source_scan_optimizer.py @@ -17,7 +17,6 @@ DataflowPlanNode, DataflowPlanNodeVisitor, FilterElementsNode, - JoinAggregatedMeasuresByGroupByColumnsNode, JoinOverTimeRangeNode, JoinToBaseOutputNode, JoinToTimeSpineNode, @@ -59,11 +58,6 @@ def visit_source_node(self, node: ReadSqlSourceNode) -> int: # noqa: D def visit_join_to_base_output_node(self, node: JoinToBaseOutputNode) -> int: # noqa: D return self._sum_parents(node) - def visit_join_aggregated_measures_by_groupby_columns_node( # noqa: D - self, node: JoinAggregatedMeasuresByGroupByColumnsNode - ) -> int: - return self._sum_parents(node) - def visit_aggregate_measures_node(self, node: AggregateMeasuresNode) -> int: # noqa: D return self._sum_parents(node) diff --git a/metricflow/test/plan_conversion/test_dataflow_to_sql_plan.py b/metricflow/test/plan_conversion/test_dataflow_to_sql_plan.py index 5cdf027392..caebaf4b6d 100644 --- a/metricflow/test/plan_conversion/test_dataflow_to_sql_plan.py +++ b/metricflow/test/plan_conversion/test_dataflow_to_sql_plan.py @@ -977,11 +977,7 @@ def test_compute_metrics_node_ratio_from_multiple_semantic_models( dataflow_to_sql_converter: DataflowToSqlQueryPlanConverter, sql_client: SqlClient, ) -> None: - """Tests the compute metrics node for ratio type metrics. - - This test exercises the functionality provided in JoinAggregatedMeasuresByGroupByColumnsNode for - merging multiple measures into a single input source for final metrics computation. - """ + """Tests the combine metrics node for ratio type metrics.""" dimension_spec = DimensionSpec( element_name="country_latest", entity_links=(EntityReference(element_name="listing"),),