Allow CombineMetricsNode to combine the outputs from aggregated measures #858

Merged · 6 commits · Nov 16, 2023
2 changes: 1 addition & 1 deletion metricflow/dag/id_generation.py
@@ -18,7 +18,7 @@
 DATAFLOW_NODE_WHERE_CONSTRAINT_ID_PREFIX = "wcc"
 DATAFLOW_NODE_WRITE_TO_RESULT_DATAFRAME_ID_PREFIX = "wrd"
 DATAFLOW_NODE_WRITE_TO_RESULT_TABLE_ID_PREFIX = "wrt"
-DATAFLOW_NODE_COMBINE_METRICS_ID_PREFIX = "cbm"
+DATAFLOW_NODE_COMBINE_AGGREGATED_OUTPUTS_ID_PREFIX = "cao"
 DATAFLOW_NODE_CONSTRAIN_TIME_RANGE_ID_PREFIX = "ctr"
 DATAFLOW_NODE_SET_MEASURE_AGGREGATION_TIME = "sma"
 DATAFLOW_NODE_SEMI_ADDITIVE_JOIN_ID_PREFIX = "saj"
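The prefix string feeds MetricFlow's ID generator, so nodes of the renamed class get IDs like `cao_0` instead of `cbm_0`. A minimal sketch of the prefix-to-ID mechanism (toy class, not the real `IdGeneratorRegistry` API):

```python
from collections import defaultdict
from typing import DefaultDict


class SimpleIdGenerator:
    """Toy stand-in for MetricFlow's ID registry, for illustration only."""

    def __init__(self) -> None:
        # One counter per prefix, so each node type gets its own sequence.
        self._counters: DefaultDict[str, int] = defaultdict(int)

    def create_id(self, prefix: str) -> str:
        node_id = f"{prefix}_{self._counters[prefix]}"
        self._counters[prefix] += 1
        return node_id


generator = SimpleIdGenerator()
assert generator.create_id("cao") == "cao_0"
assert generator.create_id("cao") == "cao_1"
```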
4 changes: 2 additions & 2 deletions metricflow/dataflow/builder/dataflow_plan_builder.py
@@ -24,7 +24,7 @@
 from metricflow.dataflow.dataflow_plan import (
     AggregateMeasuresNode,
     BaseOutput,
-    CombineMetricsNode,
+    CombineAggregatedOutputsNode,
     ComputeMetricsNode,
     ConstrainTimeRangeNode,
     DataflowPlan,
@@ -298,7 +298,7 @@ def _build_metrics_output_node(
         if len(output_nodes) == 1:
             return output_nodes[0]

-        return CombineMetricsNode(parent_nodes=output_nodes)
+        return CombineAggregatedOutputsNode(parent_nodes=output_nodes)

     def build_plan_for_distinct_values(self, query_spec: MetricFlowQuerySpec) -> DataflowPlan:
         """Generate a plan that would get the distinct values of a linkable instance.
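The builder change keeps the single-branch short-circuit: one aggregated output passes through untouched, while two or more get wrapped in the renamed combine node. A self-contained sketch of that logic with stub node types (not MetricFlow's real classes):

```python
from dataclasses import dataclass
from typing import Sequence, Tuple


@dataclass(frozen=True)
class StubOutputNode:
    name: str


@dataclass(frozen=True)
class StubCombineNode:
    parent_nodes: Tuple[StubOutputNode, ...]


def build_metrics_output_node(output_nodes: Sequence[StubOutputNode]):
    # Mirrors the diff: a single branch needs no combining step.
    if len(output_nodes) == 1:
        return output_nodes[0]
    return StubCombineNode(parent_nodes=tuple(output_nodes))


assert build_metrics_output_node([StubOutputNode("bookings")]).name == "bookings"
combined = build_metrics_output_node([StubOutputNode("bookings"), StubOutputNode("views")])
assert len(combined.parent_nodes) == 2
```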
16 changes: 8 additions & 8 deletions metricflow/dataflow/dataflow_plan.py
@@ -17,7 +17,7 @@

 from metricflow.dag.id_generation import (
     DATAFLOW_NODE_AGGREGATE_MEASURES_ID_PREFIX,
-    DATAFLOW_NODE_COMBINE_METRICS_ID_PREFIX,
+    DATAFLOW_NODE_COMBINE_AGGREGATED_OUTPUTS_ID_PREFIX,
     DATAFLOW_NODE_COMPUTE_METRICS_ID_PREFIX,
     DATAFLOW_NODE_CONSTRAIN_TIME_RANGE_ID_PREFIX,
     DATAFLOW_NODE_JOIN_SELF_OVER_TIME_RANGE_ID_PREFIX,
@@ -150,7 +150,7 @@ def visit_pass_elements_filter_node(self, node: FilterElementsNode) -> VisitorOu
         pass

     @abstractmethod
-    def visit_combine_metrics_node(self, node: CombineMetricsNode) -> VisitorOutputT:  # noqa: D
+    def visit_combine_aggregated_outputs_node(self, node: CombineAggregatedOutputsNode) -> VisitorOutputT:  # noqa: D
         pass

     @abstractmethod
@@ -1160,7 +1160,7 @@ def with_new_parents(self, new_parent_nodes: Sequence[BaseOutput]) -> WhereConst
         )


-class CombineMetricsNode(ComputedMetricsOutput):
+class CombineAggregatedOutputsNode(ComputedMetricsOutput):
     """Combines metrics from different nodes into a single output."""

     def __init__(  # noqa: D
@@ -1171,21 +1171,21 @@ def __init__(  # noqa: D

     @classmethod
     def id_prefix(cls) -> str:  # noqa: D
-        return DATAFLOW_NODE_COMBINE_METRICS_ID_PREFIX
+        return DATAFLOW_NODE_COMBINE_AGGREGATED_OUTPUTS_ID_PREFIX

     def accept(self, visitor: DataflowPlanNodeVisitor[VisitorOutputT]) -> VisitorOutputT:  # noqa: D
-        return visitor.visit_combine_metrics_node(self)
+        return visitor.visit_combine_aggregated_outputs_node(self)

     @property
     def description(self) -> str:  # noqa: D
-        return "Combine Metrics"
+        return "Combine Aggregated Outputs"

     def functionally_identical(self, other_node: DataflowPlanNode) -> bool:  # noqa: D
         return isinstance(other_node, self.__class__)

-    def with_new_parents(self, new_parent_nodes: Sequence[BaseOutput]) -> CombineMetricsNode:  # noqa: D
+    def with_new_parents(self, new_parent_nodes: Sequence[BaseOutput]) -> CombineAggregatedOutputsNode:  # noqa: D
         assert len(new_parent_nodes) == 1
-        return CombineMetricsNode(parent_nodes=new_parent_nodes)
+        return CombineAggregatedOutputsNode(parent_nodes=new_parent_nodes)


 class ConstrainTimeRangeNode(AggregatedMeasuresOutput, BaseOutput):
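The rename ripples from the node class into the visitor interface because of double dispatch: each node's `accept()` calls the visitor method named for that node, so renaming the node also renames the visitor hook. A minimal sketch of the pattern using simplified stand-ins for the real classes:

```python
from abc import ABC, abstractmethod


class Visitor(ABC):
    @abstractmethod
    def visit_combine_aggregated_outputs_node(self, node: "CombineAggregatedOutputs") -> str:
        ...


class CombineAggregatedOutputs:
    def accept(self, visitor: Visitor) -> str:
        # Double dispatch: the node picks the visitor method matching its type.
        return visitor.visit_combine_aggregated_outputs_node(self)


class DescribingVisitor(Visitor):
    def visit_combine_aggregated_outputs_node(self, node: CombineAggregatedOutputs) -> str:
        return "Combine Aggregated Outputs"


assert CombineAggregatedOutputs().accept(DescribingVisitor()) == "Combine Aggregated Outputs"
```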
@@ -7,7 +7,7 @@
 from metricflow.dataflow.dataflow_plan import (
     AggregateMeasuresNode,
     BaseOutput,
-    CombineMetricsNode,
+    CombineAggregatedOutputsNode,
     ComputeMetricsNode,
     ConstrainTimeRangeNode,
     DataflowPlanNode,
@@ -368,7 +368,9 @@ def visit_pass_elements_filter_node(  # noqa: D
         )
         return ComputeMetricsBranchCombinerResult(combined_node)

-    def visit_combine_metrics_node(self, node: CombineMetricsNode) -> ComputeMetricsBranchCombinerResult:  # noqa: D
+    def visit_combine_aggregated_outputs_node(  # noqa: D
+        self, node: CombineAggregatedOutputsNode
+    ) -> ComputeMetricsBranchCombinerResult:
         self._log_visit_node_type(node)
         return self._handle_unsupported_node(node)
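As the hunk above shows, the branch combiner treats a CombineAggregatedOutputsNode as unsupported rather than trying to merge through it. A toy sketch of that convention, assuming a simplified result type in place of the real ComputeMetricsBranchCombinerResult:

```python
from dataclasses import dataclass
from typing import Optional


@dataclass(frozen=True)
class CombineResult:
    """Simplified stand-in: an empty result means the branches stay separate."""

    combined_node: Optional[object] = None

    @property
    def can_combine(self) -> bool:
        return self.combined_node is not None


def handle_unsupported_node(node: object) -> CombineResult:
    # No combined node -> the optimizer leaves this branch as-is.
    return CombineResult(combined_node=None)


assert not handle_unsupported_node(object()).can_combine
```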
30 changes: 17 additions & 13 deletions metricflow/dataflow/optimizer/source_scan/source_scan_optimizer.py
@@ -8,7 +8,7 @@
 from metricflow.dataflow.dataflow_plan import (
     AggregateMeasuresNode,
     BaseOutput,
-    CombineMetricsNode,
+    CombineAggregatedOutputsNode,
     ComputeMetricsNode,
     ConstrainTimeRangeNode,
     DataflowPlan,
@@ -68,12 +68,12 @@ class SourceScanOptimizer(
 ):
     """Reduces the number of scans (ReadSqlSourceNodes) in a dataflow plan.

-    This attempts to reduce the number of scans by combining the parent nodes of CombineMetricsNode via the
+    This attempts to reduce the number of scans by combining the parent nodes of CombineAggregatedOutputsNode via the
     ComputeMetricsBranchCombiner.

     A plan with a structure similar to
     ...
-    <CombineMetricsNode>
+    <CombineAggregatedOutputsNode>
         <ComputeMetricsNode metrics="[metric0]">
             <AggregateMeasuresNode>
                 ...
@@ -89,11 +89,11 @@
                 ...
             </AggregateMeasuresNode>
         </ComputeMetricsNode>
-    </CombineMetricsNode>
+    </CombineAggregatedOutputsNode>
     ...
     will be converted to
     ...
-    <CombineMetricsNode>
+    <CombineAggregatedOutputsNode>
         <ComputeMetricsNode metrics="[metric0, metric1]">
             <AggregateMeasuresNode>
                 ...
@@ -104,11 +104,11 @@
                 ...
             </AggregateMeasuresNode>
         </ComputeMetricsNode>
-    </CombineMetricsNode>
+    </CombineAggregatedOutputsNode>
     ...
     when possible.

-    In cases where all ComputeMetricsNodes can be combined into a single one, the CombineMetricsNode may be removed as
+    In cases where all ComputeMetricsNodes can be combined into a single one, the CombineAggregatedOutputsNode may be removed as
     well.

     This traverses the dataflow plan using DFS. When visiting a node (current_node), it first runs the optimization
@@ -229,9 +229,11 @@ def _combine_branches(
         )
         return results

-    def visit_combine_metrics_node(self, node: CombineMetricsNode) -> OptimizeBranchResult:  # noqa: D
+    def visit_combine_aggregated_outputs_node(  # noqa: D
+        self, node: CombineAggregatedOutputsNode
+    ) -> OptimizeBranchResult:
         self._log_visit_node_type(node)
-        # The parent node of the CombineMetricsNode can be either ComputeMetricsNodes or CombineMetricsNodes
+        # The parent node of the CombineAggregatedOutputsNode can be either ComputeMetricsNodes or CombineAggregatedOutputsNodes

         # Stores the result of running this optimizer on each parent branch separately.
         optimized_parent_branches = []
@@ -248,7 +250,7 @@ def visit_combine_metrics_node(self, node: CombineMetricsNode) -> OptimizeBranch

             assert (
                 result.base_output_node is not None
-            ), f"Traversing the parents of a CombineMetricsNode should always produce a BaseOutput. Got: {result}"
+            ), f"Traversing the parents of a CombineAggregatedOutputsNode should always produce a BaseOutput. Got: {result}"
             optimized_parent_branches.append(result.base_output_node)

         # Try to combine (using ComputeMetricsBranchCombiner) as many parent branches as possible in a
@@ -275,12 +277,14 @@
         logger.log(level=self._log_level, msg=f"Got {len(combined_parent_branches)} branches after combination")
         assert len(combined_parent_branches) > 0

-        # If we were able to reduce the parent branches of the CombineMetricsNode into a single one, there's no need
-        # for a CombineMetricsNode.
+        # If we were able to reduce the parent branches of the CombineAggregatedOutputsNode into a single one, there's no need
+        # for a CombineAggregatedOutputsNode.
         if len(combined_parent_branches) == 1:
             return OptimizeBranchResult(base_output_node=combined_parent_branches[0])

-        return OptimizeBranchResult(base_output_node=CombineMetricsNode(parent_nodes=combined_parent_branches))
+        return OptimizeBranchResult(
+            base_output_node=CombineAggregatedOutputsNode(parent_nodes=combined_parent_branches)
+        )

     def visit_constrain_time_range_node(self, node: ConstrainTimeRangeNode) -> OptimizeBranchResult:  # noqa: D
         self._log_visit_node_type(node)
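The optimizer's core move is folding as many optimized parent branches as possible into each other before deciding whether the combine node is still needed. A hedged sketch of that greedy fold using toy types and a hypothetical `can_combine` check (the real code delegates this decision to ComputeMetricsBranchCombiner):

```python
from typing import Callable, List, Sequence, TypeVar

Branch = TypeVar("Branch")


def combine_branches(
    branches: Sequence[Branch],
    can_combine: Callable[[Branch, Branch], bool],
    combine: Callable[[Branch, Branch], Branch],
) -> List[Branch]:
    """Fold each branch into the running combined branch when possible."""
    combined: List[Branch] = []
    for branch in branches:
        if combined and can_combine(combined[-1], branch):
            combined[-1] = combine(combined[-1], branch)
        else:
            combined.append(branch)
    return combined


# Sets stand in for "metrics computed from the same source scan":
result = combine_branches(
    [{"metric0"}, {"metric1"}, {"metric2"}],
    can_combine=lambda a, b: True,  # assume every pair is combinable
    combine=lambda a, b: a | b,
)
# One remaining branch -> the combine node itself can be dropped.
assert result == [{"metric0", "metric1", "metric2"}]
```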
11 changes: 3 additions & 8 deletions metricflow/model/semantics/metric_lookup.py
@@ -104,14 +104,9 @@ def add_metric(self, metric: Metric) -> None:
         )
         self._metrics[metric_reference] = metric

-    def configured_input_measure_for_metric(self, metric_reference: MetricReference) -> Optional[MetricInputMeasure]:
-        """Get input measure defined in the original metric config, if exists.
-
-        When SemanticModel is constructed, input measures from input metrics are added to the list of input measures
-        for a metric. Here, use rules about metric types to determine which input measures were defined in the config:
-        - Simple & cumulative metrics require one input measure, and can't take any input metrics.
-        - Derived & ratio metrics take no input measures, only input metrics.
-        """
+    def configured_input_measure_for_metric(  # noqa: D
+        self, metric_reference: MetricReference
+    ) -> Optional[MetricInputMeasure]:
         metric = self.get_metric(metric_reference=metric_reference)
         if metric.type is MetricType.CUMULATIVE or metric.type is MetricType.SIMPLE:
             assert len(metric.input_measures) == 1, "Simple and cumulative metrics should have one input measure."
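For reference, the rule the removed docstring spelled out still governs the method body: simple and cumulative metrics carry exactly one configured input measure, while derived and ratio metrics are defined purely in terms of input metrics. A simplified sketch (stand-in enum and types, not the real Metric protocol):

```python
from enum import Enum
from typing import List, Optional


class MetricType(Enum):  # stand-in for the dbt_semantic_interfaces enum
    SIMPLE = "simple"
    CUMULATIVE = "cumulative"
    DERIVED = "derived"
    RATIO = "ratio"


def configured_input_measure(metric_type: MetricType, input_measures: List[str]) -> Optional[str]:
    if metric_type in (MetricType.SIMPLE, MetricType.CUMULATIVE):
        assert len(input_measures) == 1, "Simple and cumulative metrics should have one input measure."
        return input_measures[0]
    # Derived and ratio metrics: all input measures come from input metrics.
    return None


assert configured_input_measure(MetricType.SIMPLE, ["bookings"]) == "bookings"
assert configured_input_measure(MetricType.RATIO, []) is None
```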
105 changes: 29 additions & 76 deletions metricflow/plan_conversion/dataflow_to_sql.py
@@ -2,19 +2,18 @@

 import logging
 from collections import OrderedDict
-from typing import List, Optional, Sequence, Tuple, Union
+from typing import List, Optional, Sequence, Union

 from dbt_semantic_interfaces.enum_extension import assert_values_exhausted
 from dbt_semantic_interfaces.protocols.metric import MetricInputMeasure, MetricType
-from dbt_semantic_interfaces.references import MetricModelReference, MetricReference
-from dbt_semantic_interfaces.type_enums.aggregation_type import AggregationType
+from dbt_semantic_interfaces.references import MetricModelReference

 from metricflow.aggregation_properties import AggregationState
 from metricflow.dag.id_generation import IdGeneratorRegistry
 from metricflow.dataflow.dataflow_plan import (
     AggregateMeasuresNode,
     BaseOutput,
-    CombineMetricsNode,
+    CombineAggregatedOutputsNode,
     ComputedMetricsOutput,
     ComputeMetricsNode,
     ConstrainTimeRangeNode,
@@ -46,12 +45,14 @@
     AliasAggregatedMeasures,
     ChangeAssociatedColumns,
     ChangeMeasureAggregationState,
+    CreateSelectColumnForCombineOutputNode,
     CreateSelectColumnsForInstances,
     CreateSelectColumnsWithMeasuresAggregated,
     FilterElements,
     FilterLinkableInstancesWithLeadingLink,
     RemoveMeasures,
     RemoveMetrics,
+    UpdateMeasureFillNullsWith,
     create_select_columns_for_instance_sets,
 )
 from metricflow.plan_conversion.select_column_gen import (
@@ -484,6 +485,10 @@ def visit_aggregate_measures_node(self, node: AggregateMeasuresNode) -> SqlDataS
             ChangeAssociatedColumns(self._column_association_resolver)
         )

+        # Add fill null property to corresponding measure spec
+        aggregated_instance_set = aggregated_instance_set.transform(
+            UpdateMeasureFillNullsWith(metric_input_measure_specs=node.metric_input_measure_specs)
+        )
         from_data_set_alias = self._next_unique_table_alias()

         # Convert the instance set into a set of select column statements with updated aliases
@@ -781,68 +786,12 @@ def visit_where_constraint_node(self, node: WhereConstraintNode) -> SqlDataSet:
             ),
         )

-    def _make_select_columns_for_multiple_metrics(
-        self,
-        table_alias_to_metric_instances: OrderedDict[str, Tuple[MetricInstance, ...]],
-        aggregation_type: Optional[AggregationType],
-    ) -> List[SqlSelectColumn]:
-        """Creates select columns that get the given metric using the given table alias.
-
-        e.g.
-
-        with table_alias_to_metric_instances = {"a": MetricSpec(element_name="bookings")}
-
-        ->
-
-        a.bookings AS bookings
-        """
-        select_columns = []
-        for table_alias, metric_instances in table_alias_to_metric_instances.items():
-            for metric_instance in metric_instances:
-                metric_spec = metric_instance.spec
-                metric_column_name = self._column_association_resolver.resolve_spec(metric_spec).column_name
-                column_reference_expression = SqlColumnReferenceExpression(
-                    col_ref=SqlColumnReference(
-                        table_alias=table_alias,
-                        column_name=metric_column_name,
-                    )
-                )
-                if aggregation_type:
-                    select_expression: SqlExpressionNode = SqlFunctionExpression.build_expression_from_aggregation_type(
-                        aggregation_type=aggregation_type, sql_column_expression=column_reference_expression
-                    )
-                else:
-                    select_expression = column_reference_expression
-
-                # At this point, the MetricSpec might have the alias in place of the element name, so we need to look
-                # back at where it was defined from to get the metric element name.
-                metric_reference = MetricReference(element_name=metric_instance.defined_from.metric_name)
-                input_measure = self._metric_lookup.configured_input_measure_for_metric(
-                    metric_reference=metric_reference
-                )
-                if input_measure and input_measure.fill_nulls_with is not None:
-                    select_expression = SqlAggregateFunctionExpression(
-                        sql_function=SqlFunction.COALESCE,
-                        sql_function_args=[
-                            select_expression,
-                            SqlStringExpression(str(input_measure.fill_nulls_with)),
-                        ],
-                    )
-
-                select_columns.append(
-                    SqlSelectColumn(
-                        expr=select_expression,
-                        column_alias=metric_column_name,
-                    )
-                )
-        return select_columns
-
-    def visit_combine_metrics_node(self, node: CombineMetricsNode) -> SqlDataSet:
-        """Join computed metric datasets together to return a single dataset containing all metrics.
-
-        This node may exist in one of two situations: when metrics need to be combined in order to produce a single
-        dataset with all required inputs for a derived metric, or when metrics need to be combined in order to produce
-        a single dataset of output for downstream consumption by the end user.
+    def visit_combine_aggregated_outputs_node(self, node: CombineAggregatedOutputsNode) -> SqlDataSet:
+        """Join aggregated output datasets together to return a single dataset containing all metrics/measures.
+
+        This node may exist in one of two situations: when metrics/measures need to be combined in order to produce a single
+        dataset with all required inputs for a metric (i.e., a derived metric), or when metrics need to be combined in order to
+        produce a single dataset of output for downstream consumption by the end user.

         The join key will be a coalesced set of all previously seen dimension values. For example:
             FROM (
@@ -870,16 +819,16 @@ def visit_combine_metrics_node(self, node: CombineMetricsNode) -> SqlDataSet:
         """
         assert (
             len(node.parent_nodes) > 1
-        ), "Shouldn't have a CombineMetricsNode in the dataflow plan if there's only 1 parent."
+        ), "Shouldn't have a CombineAggregatedOutputsNode in the dataflow plan if there's only 1 parent."

         parent_data_sets: List[AnnotatedSqlDataSet] = []
-        table_alias_to_metric_instances: OrderedDict[str, Tuple[MetricInstance, ...]] = OrderedDict()
+        table_alias_to_instance_set: OrderedDict[str, InstanceSet] = OrderedDict()

         for parent_node in node.parent_nodes:
             parent_sql_data_set = parent_node.accept(self)
             table_alias = self._next_unique_table_alias()
             parent_data_sets.append(AnnotatedSqlDataSet(data_set=parent_sql_data_set, alias=table_alias))
-            table_alias_to_metric_instances[table_alias] = parent_sql_data_set.instance_set.metric_instances
+            table_alias_to_instance_set[table_alias] = parent_sql_data_set.instance_set

         # When we create the components of the join that combines metrics it will be one of INNER, FULL OUTER,
         # or CROSS JOIN. Order doesn't matter for these join types, so we will use the first element in the FROM
@@ -905,7 +854,7 @@ def visit_combine_metrics_node(self, node: CombineMetricsNode) -> SqlDataSet:
         aliases_seen = [from_data_set.alias]
         for join_data_set in join_data_sets:
             joins_descriptions.append(
-                SqlQueryPlanJoinBuilder.make_combine_metrics_join_description(
+                SqlQueryPlanJoinBuilder.make_join_description_for_combining_datasets(
                     from_data_set=from_data_set,
                     join_data_set=join_data_set,
                     join_type=join_type,
@@ -919,20 +868,24 @@ def visit_combine_metrics_node(self, node: CombineMetricsNode) -> SqlDataSet:
         output_instance_set = InstanceSet.merge([x.data_set.instance_set for x in parent_data_sets])
         output_instance_set = output_instance_set.transform(ChangeAssociatedColumns(self._column_association_resolver))

-        metric_aggregation_type = AggregationType.MAX
-        metric_select_column_set = SelectColumnSet(
-            metric_columns=self._make_select_columns_for_multiple_metrics(
-                table_alias_to_metric_instances=table_alias_to_metric_instances,
-                aggregation_type=metric_aggregation_type,
-            )
-        )
+        aggregated_select_columns = SelectColumnSet()
+        for table_alias, instance_set in table_alias_to_instance_set.items():
+            aggregated_select_columns = aggregated_select_columns.merge(
+                instance_set.transform(
+                    CreateSelectColumnForCombineOutputNode(
+                        table_alias=table_alias,
+                        column_resolver=self._column_association_resolver,
+                        metric_lookup=self._metric_lookup,
+                    )
+                )
+            )
         linkable_select_column_set = linkable_spec_set.transform(
             CreateSelectCoalescedColumnsForLinkableSpecs(
                 column_association_resolver=self._column_association_resolver,
                 table_aliases=[x.alias for x in parent_data_sets],
             )
         )
-        combined_select_column_set = linkable_select_column_set.merge(metric_select_column_set)
+        combined_select_column_set = linkable_select_column_set.merge(aggregated_select_columns)

         return SqlDataSet(
             instance_set=output_instance_set,
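Two COALESCE uses drive this conversion: the join key coalesces each dimension across all parent datasets, and a configured fill_nulls_with value coalesces a metric column that a FULL OUTER JOIN may have left NULL. A hedged sketch of both, using plain string building rather than MetricFlow's SqlExpressionNode tree:

```python
from typing import Optional, Sequence


def coalesced_dimension(table_aliases: Sequence[str], dimension: str) -> str:
    """Join key: take the first non-NULL value of a dimension across parents."""
    refs = ", ".join(f"{alias}.{dimension}" for alias in table_aliases)
    return f"COALESCE({refs}) AS {dimension}"


def metric_select_column(table_alias: str, column: str, fill_nulls_with: Optional[int]) -> str:
    """Metric column: apply fill_nulls_with when the join may produce NULLs."""
    expr = f"{table_alias}.{column}"
    if fill_nulls_with is not None:
        expr = f"COALESCE({expr}, {fill_nulls_with})"
    return f"{expr} AS {column}"


assert coalesced_dimension(["a", "b"], "metric_time") == "COALESCE(a.metric_time, b.metric_time) AS metric_time"
assert metric_select_column("b", "bookings", 0) == "COALESCE(b.bookings, 0) AS bookings"
```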