Allow CombineMetricsNode to combine the outputs from aggregated measures
WilliamDee committed Nov 9, 2023
1 parent 1410c46 commit b9db730
Showing 5 changed files with 1,235 additions and 1 deletion.
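
At a high level, visit_combine_metrics_node previously assumed every parent node produced metric instances; with this change it also accepts parents that expose only aggregated measure instances, and it raises if a parent carries neither. Before the file-by-file diff, here is a minimal, dependency-free sketch of that routing, using placeholder types and names (FakeInstanceSet, route_parent_outputs) rather than MetricFlow's API:

from dataclasses import dataclass
from typing import Dict, Tuple


@dataclass(frozen=True)
class FakeInstanceSet:
    """Stand-in for MetricFlow's InstanceSet; holds column names only."""

    metric_instances: Tuple[str, ...] = ()
    measure_instances: Tuple[str, ...] = ()


def route_parent_outputs(
    parent_outputs: Dict[str, FakeInstanceSet],
) -> Tuple[Dict[str, Tuple[str, ...]], Dict[str, Tuple[str, ...]]]:
    """Split combine-node parents into metric outputs and aggregated-measure outputs."""
    metrics_by_alias: Dict[str, Tuple[str, ...]] = {}
    measures_by_alias: Dict[str, Tuple[str, ...]] = {}
    for table_alias, instance_set in parent_outputs.items():
        if instance_set.metric_instances:
            metrics_by_alias[table_alias] = instance_set.metric_instances
        elif instance_set.measure_instances:
            measures_by_alias[table_alias] = instance_set.measure_instances
        else:
            raise RuntimeError("Attempting to combine output nodes without any measures/metrics upstream")
    return metrics_by_alias, measures_by_alias


metrics, measures = route_parent_outputs(
    {
        "subq_0": FakeInstanceSet(metric_instances=("bookings",)),
        "subq_1": FakeInstanceSet(measure_instances=("instant_bookings", "bookers")),
    }
)
# metrics  == {"subq_0": ("bookings",)}
# measures == {"subq_1": ("instant_bookings", "bookers")}
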
33 changes: 32 additions & 1 deletion metricflow/plan_conversion/dataflow_to_sql.py
@@ -36,6 +36,7 @@
from metricflow.filters.time_constraint import TimeRangeConstraint
from metricflow.instances import (
InstanceSet,
MeasureInstance,
MetricInstance,
TimeDimensionInstance,
)
@@ -787,6 +788,7 @@ def visit_where_constraint_node(self, node: WhereConstraintNode) -> SqlDataSet:
def _make_select_columns_for_multiple_metrics(
self,
table_alias_to_metric_instances: OrderedDict[str, Tuple[MetricInstance, ...]],
table_alias_to_measure_instances: OrderedDict[str, Tuple[MeasureInstance, ...]],
aggregation_type: Optional[AggregationType],
) -> List[SqlSelectColumn]:
"""Creates select columns that get the given metric using the given table alias.
@@ -838,6 +840,28 @@ def _make_select_columns_for_multiple_metrics(
column_alias=metric_column_name,
)
)
for table_alias, measure_instances in table_alias_to_measure_instances.items():
for measure_instance in measure_instances:
measure_spec = measure_instance.spec
measure_column_name = self._column_association_resolver.resolve_spec(measure_spec).column_name
column_reference_expression = SqlColumnReferenceExpression(
col_ref=SqlColumnReference(
table_alias=table_alias,
column_name=measure_column_name,
)
)
if aggregation_type:
select_expression: SqlExpressionNode = SqlFunctionExpression.build_expression_from_aggregation_type(
aggregation_type=aggregation_type, sql_column_expression=column_reference_expression
)
else:
select_expression = column_reference_expression
select_columns.append(
SqlSelectColumn(
expr=select_expression,
column_alias=measure_column_name,
)
)
return select_columns

def visit_combine_metrics_node(self, node: CombineMetricsNode) -> SqlDataSet:
@@ -877,12 +901,18 @@ def visit_combine_metrics_node(self, node: CombineMetricsNode) -> SqlDataSet:

parent_data_sets: List[AnnotatedSqlDataSet] = []
table_alias_to_metric_instances: OrderedDict[str, Tuple[MetricInstance, ...]] = OrderedDict()
table_alias_to_measure_instances: OrderedDict[str, Tuple[MeasureInstance, ...]] = OrderedDict()

for parent_node in node.parent_nodes:
parent_sql_data_set = parent_node.accept(self)
table_alias = self._next_unique_table_alias()
parent_data_sets.append(AnnotatedSqlDataSet(data_set=parent_sql_data_set, alias=table_alias))
table_alias_to_metric_instances[table_alias] = parent_sql_data_set.instance_set.metric_instances
if parent_sql_data_set.instance_set.metric_instances:
table_alias_to_metric_instances[table_alias] = parent_sql_data_set.instance_set.metric_instances
elif parent_sql_data_set.instance_set.measure_instances:
table_alias_to_measure_instances[table_alias] = parent_sql_data_set.instance_set.measure_instances
else:
raise RuntimeError("Attempting to combine output nodes without any measures/metrics upstream")

# When we create the components of the join that combines metrics it will be one of INNER, FULL OUTER,
# or CROSS JOIN. Order doesn't matter for these join types, so we will use the first element in the FROM
Expand Down Expand Up @@ -926,6 +956,7 @@ def visit_combine_metrics_node(self, node: CombineMetricsNode) -> SqlDataSet:
metric_select_column_set = SelectColumnSet(
metric_columns=self._make_select_columns_for_multiple_metrics(
table_alias_to_metric_instances=table_alias_to_metric_instances,
table_alias_to_measure_instances=table_alias_to_measure_instances,
aggregation_type=metric_aggregation_type,
)
)
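
To make the new select-column construction in this file concrete, here is a minimal, dependency-free sketch of the pattern the added loop follows: each (table alias, column) pair from the metric and measure buckets becomes one select expression, optionally wrapped in an aggregation function so that values coming from the joined parent subqueries collapse to a single value per group. The function name, the subq_* aliases, and the MAX aggregation type below are illustrative assumptions, not MetricFlow API.

from typing import Dict, List, Optional, Tuple


def make_combined_select_columns(
    table_alias_to_columns: Dict[str, Tuple[str, ...]],
    aggregation_type: Optional[str],
) -> List[str]:
    """Render one select expression per upstream metric or aggregated-measure column."""
    select_columns: List[str] = []
    for table_alias, column_names in table_alias_to_columns.items():
        for column_name in column_names:
            column_reference = f"{table_alias}.{column_name}"
            if aggregation_type:
                # Wrap the column in the requested aggregation, mirroring the
                # SqlFunctionExpression.build_expression_from_aggregation_type branch above.
                expression = f"{aggregation_type}({column_reference})"
            else:
                expression = column_reference
            select_columns.append(f"{expression} AS {column_name}")
    return select_columns


for column in make_combined_select_columns(
    {"subq_0": ("bookings",), "subq_1": ("instant_bookings", "bookers")},
    aggregation_type="MAX",
):
    print(column)
# MAX(subq_0.bookings) AS bookings
# MAX(subq_1.instant_bookings) AS instant_bookings
# MAX(subq_1.bookers) AS bookers
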
64 changes: 64 additions & 0 deletions metricflow/test/plan_conversion/test_dataflow_to_sql_plan.py
@@ -15,6 +15,7 @@
from metricflow.dataflow.dataflow_plan import (
AggregateMeasuresNode,
BaseOutput,
CombineMetricsNode,
ComputeMetricsNode,
ConstrainTimeRangeNode,
DataflowPlan,
@@ -1003,3 +1004,66 @@ def test_compute_metrics_node_ratio_from_multiple_semantic_models(
sql_client=sql_client,
node=dataflow_plan.sink_output_nodes[0].parent_node,
)


@pytest.mark.sql_engine_snapshot
def test_combine_output_node( # noqa: D
request: FixtureRequest,
mf_test_session_state: MetricFlowTestSessionState,
dataflow_to_sql_converter: DataflowToSqlQueryPlanConverter,
consistent_id_object_repository: ConsistentIdObjectRepository,
sql_client: SqlClient,
) -> None:
"""Tests converting a dataflow plan to a SQL query plan where there is a leaf measure aggregation node.
Covers SUM, AVERAGE, SUM_BOOLEAN (transformed to SUM upstream), and COUNT_DISTINCT agg types
"""
sum_spec = MeasureSpec(
element_name="bookings",
)
sum_boolean_spec = MeasureSpec(
element_name="instant_bookings",
)
count_distinct_spec = MeasureSpec(
element_name="bookers",
)
dimension_spec = DimensionSpec(
element_name="is_instant",
entity_links=(),
)
measure_source_node = consistent_id_object_repository.simple_model_read_nodes["bookings_source"]

# Build compute measures node
measure_specs: List[MeasureSpec] = [sum_spec]
filtered_measure_node = FilterElementsNode(
parent_node=measure_source_node,
include_specs=InstanceSpecSet(measure_specs=tuple(measure_specs), dimension_specs=(dimension_spec,)),
)
aggregated_measure_node = AggregateMeasuresNode(
parent_node=filtered_measure_node,
metric_input_measure_specs=tuple(MetricInputMeasureSpec(measure_spec=x) for x in measure_specs),
)
metric_spec = MetricSpec(element_name="bookings")
compute_metrics_node = ComputeMetricsNode(parent_node=aggregated_measure_node, metric_specs=[metric_spec])

# Build agg measures node
measure_specs_2 = [sum_boolean_spec, count_distinct_spec]
filtered_measure_node_2 = FilterElementsNode(
parent_node=measure_source_node,
include_specs=InstanceSpecSet(measure_specs=tuple(measure_specs_2), dimension_specs=(dimension_spec,)),
)
aggregated_measure_node_2 = AggregateMeasuresNode(
parent_node=filtered_measure_node_2,
metric_input_measure_specs=tuple(MetricInputMeasureSpec(measure_spec=x) for x in measure_specs_2),
)

# Combine metrics node
combine_metrics_node = CombineMetricsNode([compute_metrics_node, aggregated_measure_node_2])

convert_and_check(
request=request,
mf_test_session_state=mf_test_session_state,
dataflow_to_sql_converter=dataflow_to_sql_converter,
sql_client=sql_client,
node=combine_metrics_node,
)
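
For orientation only, the combined query for the plan built in this test is expected to have roughly the shape sketched below. This is a hedged illustration: the authoritative output lives in this commit's SQL snapshot files (not shown on this page), and the exact join type (INNER, FULL OUTER, or CROSS JOIN, per the comment in visit_combine_metrics_node), the COALESCE handling of is_instant, and the subquery aliases are assumptions.

# Hypothetical shape of the combined query; see the caveats above.
COMBINED_QUERY_SKETCH = """\
SELECT
  COALESCE(subq_0.is_instant, subq_1.is_instant) AS is_instant
  , MAX(subq_0.bookings) AS bookings
  , MAX(subq_1.instant_bookings) AS instant_bookings
  , MAX(subq_1.bookers) AS bookers
FROM (<metric subquery computing bookings>) subq_0
FULL OUTER JOIN (<aggregated-measures subquery>) subq_1
ON subq_0.is_instant = subq_1.is_instant
GROUP BY COALESCE(subq_0.is_instant, subq_1.is_instant)
"""
print(COMBINED_QUERY_SKETCH)
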