Skip to content

Commit

Permalink
added ability to combine aggregated outputs nodes into CombineMetrics…
Browse files Browse the repository at this point in the history
…Node
  • Loading branch information
WilliamDee committed Nov 15, 2023
1 parent 40e5adb commit a858854
Show file tree
Hide file tree
Showing 2 changed files with 101 additions and 71 deletions.
84 changes: 16 additions & 68 deletions metricflow/plan_conversion/dataflow_to_sql.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,12 +2,11 @@

import logging
from collections import OrderedDict
from typing import List, Optional, Sequence, Tuple, Union
from typing import List, Optional, Sequence, Union

from dbt_semantic_interfaces.enum_extension import assert_values_exhausted
from dbt_semantic_interfaces.protocols.metric import MetricInputMeasure, MetricType
from dbt_semantic_interfaces.references import MetricModelReference, MetricReference
from dbt_semantic_interfaces.type_enums.aggregation_type import AggregationType
from dbt_semantic_interfaces.references import MetricModelReference

from metricflow.aggregation_properties import AggregationState
from metricflow.dag.id_generation import IdGeneratorRegistry
Expand Down Expand Up @@ -46,6 +45,7 @@
AliasAggregatedMeasures,
ChangeAssociatedColumns,
ChangeMeasureAggregationState,
CreateSelectColumnForJoinOutputNode,
CreateSelectColumnsForInstances,
CreateSelectColumnsWithMeasuresAggregated,
FilterElements,
Expand Down Expand Up @@ -789,62 +789,6 @@ def visit_where_constraint_node(self, node: WhereConstraintNode) -> SqlDataSet:
),
)

def _make_select_columns_for_multiple_metrics(
    self,
    table_alias_to_metric_instances: OrderedDict[str, Tuple[MetricInstance, ...]],
    aggregation_type: Optional[AggregationType],
) -> List[SqlSelectColumn]:
    """Build one select column per metric instance, reading each metric from its table alias.

    For every (table alias, metric instance) pair, the column name is obtained from the column association
    resolver and the resulting select column is aliased to that same name.

    * If ``aggregation_type`` is set, the column reference is wrapped in the matching aggregation function.
    * If the metric's configured input measure has a non-None ``fill_nulls_with``, the expression is further
      wrapped in COALESCE with that value.

    e.g. with a metric instance whose column name resolves to "bookings" stored under table alias "a", no
    aggregation type, and no fill value:

        a.bookings AS bookings

    Columns are returned in the iteration order of the mapping, then in the order of each instance tuple.
    """
    columns: List[SqlSelectColumn] = []
    for alias, instances in table_alias_to_metric_instances.items():
        for instance in instances:
            resolved_name = self._column_association_resolver.resolve_spec(instance.spec).column_name
            source_column = SqlColumnReferenceExpression(
                col_ref=SqlColumnReference(table_alias=alias, column_name=resolved_name)
            )

            expr: SqlExpressionNode = source_column
            if aggregation_type:
                expr = SqlFunctionExpression.build_expression_from_aggregation_type(
                    aggregation_type=aggregation_type, sql_column_expression=source_column
                )

            # The spec may carry an alias rather than the metric's element name, so the lookup goes through
            # `defined_from` to recover the name the metric was originally defined under.
            configured_measure = self._metric_lookup.configured_input_measure_for_metric(
                metric_reference=MetricReference(element_name=instance.defined_from.metric_name)
            )
            if configured_measure and configured_measure.fill_nulls_with is not None:
                expr = SqlAggregateFunctionExpression(
                    sql_function=SqlFunction.COALESCE,
                    sql_function_args=[
                        expr,
                        SqlStringExpression(str(configured_measure.fill_nulls_with)),
                    ],
                )

            columns.append(SqlSelectColumn(expr=expr, column_alias=resolved_name))
    return columns

def visit_combine_metrics_node(self, node: CombineMetricsNode) -> SqlDataSet:
"""Join computed metric datasets together to return a single dataset containing all metrics.
Expand Down Expand Up @@ -881,13 +825,13 @@ def visit_combine_metrics_node(self, node: CombineMetricsNode) -> SqlDataSet:
), "Shouldn't have a CombineMetricsNode in the dataflow plan if there's only 1 parent."

parent_data_sets: List[AnnotatedSqlDataSet] = []
table_alias_to_metric_instances: OrderedDict[str, Tuple[MetricInstance, ...]] = OrderedDict()
table_alias_to_instance_set: OrderedDict[str, InstanceSet] = OrderedDict()

for parent_node in node.parent_nodes:
parent_sql_data_set = parent_node.accept(self)
table_alias = self._next_unique_table_alias()
parent_data_sets.append(AnnotatedSqlDataSet(data_set=parent_sql_data_set, alias=table_alias))
table_alias_to_metric_instances[table_alias] = parent_sql_data_set.instance_set.metric_instances
table_alias_to_instance_set[table_alias] = parent_sql_data_set.instance_set

# When we create the components of the join that combines metrics it will be one of INNER, FULL OUTER,
# or CROSS JOIN. Order doesn't matter for these join types, so we will use the first element in the FROM
Expand Down Expand Up @@ -927,20 +871,24 @@ def visit_combine_metrics_node(self, node: CombineMetricsNode) -> SqlDataSet:
output_instance_set = InstanceSet.merge([x.data_set.instance_set for x in parent_data_sets])
output_instance_set = output_instance_set.transform(ChangeAssociatedColumns(self._column_association_resolver))

metric_aggregation_type = AggregationType.MAX
metric_select_column_set = SelectColumnSet(
metric_columns=self._make_select_columns_for_multiple_metrics(
table_alias_to_metric_instances=table_alias_to_metric_instances,
aggregation_type=metric_aggregation_type,
aggregated_select_columns = SelectColumnSet()
for table_alias, instance_set in table_alias_to_instance_set.items():
aggregated_select_columns = aggregated_select_columns.merge(
instance_set.transform(
CreateSelectColumnForJoinOutputNode(
table_alias=table_alias,
column_resolver=self._column_association_resolver,
metric_lookup=self._metric_lookup,
)
)
)
)
linkable_select_column_set = linkable_spec_set.transform(
CreateSelectCoalescedColumnsForLinkableSpecs(
column_association_resolver=self._column_association_resolver,
table_aliases=[x.alias for x in parent_data_sets],
)
)
combined_select_column_set = linkable_select_column_set.merge(metric_select_column_set)
combined_select_column_set = linkable_select_column_set.merge(aggregated_select_columns)

return SqlDataSet(
instance_set=output_instance_set,
Expand Down
88 changes: 85 additions & 3 deletions metricflow/plan_conversion/instance_converters.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,8 @@
from itertools import chain
from typing import Dict, List, Optional, Sequence, Tuple

from dbt_semantic_interfaces.references import SemanticModelReference
from dbt_semantic_interfaces.references import MetricReference, SemanticModelReference
from dbt_semantic_interfaces.type_enums.aggregation_type import AggregationType
from dbt_semantic_interfaces.type_enums.date_part import DatePart
from dbt_semantic_interfaces.type_enums.time_granularity import TimeGranularity
from more_itertools import bucket
Expand All @@ -28,7 +29,7 @@
TimeDimensionInstance,
)
from metricflow.plan_conversion.select_column_gen import SelectColumnSet
from metricflow.protocols.semantics import SemanticModelAccessor
from metricflow.protocols.semantics import MetricAccessor, SemanticModelAccessor
from metricflow.specs.column_assoc import ColumnAssociationResolver
from metricflow.specs.specs import (
DimensionSpec,
Expand All @@ -43,11 +44,17 @@
TimeDimensionSpec,
)
from metricflow.sql.sql_exprs import (
SqlAggregateFunctionExpression,
SqlColumnReference,
SqlColumnReferenceExpression,
SqlExpressionNode,
SqlFunction,
SqlFunctionExpression,
SqlStringExpression,
)
from metricflow.sql.sql_plan import (
SqlSelectColumn,
)
from metricflow.sql.sql_plan import SqlSelectColumn

logger = logging.getLogger(__name__)

Expand Down Expand Up @@ -762,6 +769,81 @@ def transform(self, instance_set: InstanceSet) -> Tuple[SqlColumnReferenceExpres
)


class CreateSelectColumnForJoinOutputNode(InstanceSetTransform[SelectColumnSet]):
    """Build the select columns for the metric and measure instances of a joined output.

    Each column is read from the supplied table alias under the name that the column association resolver
    assigns to the instance's spec, wrapped in a MAX aggregation, and aliased back to that same name. When a
    fill value applies, the aggregated expression is additionally wrapped in COALESCE.
    """

    def __init__(  # noqa: D
        self,
        table_alias: str,
        column_resolver: ColumnAssociationResolver,
        metric_lookup: MetricAccessor,
    ) -> None:
        self._table_alias = table_alias
        self._column_resolver = column_resolver
        self._metric_lookup = metric_lookup

    def _create_select_column(self, spec: InstanceSpec, fill_nulls_with: Optional[int] = None) -> SqlSelectColumn:
        """Return `MAX(<alias>.<column>) AS <column>`, wrapped in COALESCE if a fill value is given."""
        resolved_name = self._column_resolver.resolve_spec(spec).column_name
        source_column = SqlColumnReferenceExpression(
            col_ref=SqlColumnReference(table_alias=self._table_alias, column_name=resolved_name)
        )
        max_expr: SqlExpressionNode = SqlFunctionExpression.build_expression_from_aggregation_type(
            aggregation_type=AggregationType.MAX, sql_column_expression=source_column
        )
        if fill_nulls_with is None:
            return SqlSelectColumn(expr=max_expr, column_alias=resolved_name)

        coalesced_expr = SqlAggregateFunctionExpression(
            sql_function=SqlFunction.COALESCE,
            sql_function_args=[
                max_expr,
                SqlStringExpression(str(fill_nulls_with)),
            ],
        )
        return SqlSelectColumn(expr=coalesced_expr, column_alias=resolved_name)

    def _create_select_columns_for_metrics(
        self, metric_instances: Tuple[MetricInstance, ...]
    ) -> List[SqlSelectColumn]:
        """Create a select column per metric, using the fill value of the metric's configured input measure."""
        columns: List[SqlSelectColumn] = []
        for instance in metric_instances:
            # Look the metric up by the name it was defined under (via `defined_from`), not by its spec.
            configured_measure = self._metric_lookup.configured_input_measure_for_metric(
                metric_reference=MetricReference(element_name=instance.defined_from.metric_name)
            )
            # No configured input measure means there is nothing to fill with.
            fill_value: Optional[int] = configured_measure.fill_nulls_with if configured_measure else None
            columns.append(self._create_select_column(spec=instance.spec, fill_nulls_with=fill_value))
        return columns

    def _create_select_columns_for_measures(
        self, measure_instances: Tuple[MeasureInstance, ...]
    ) -> List[SqlSelectColumn]:
        """Create a select column per measure, using the fill value carried on the measure's own spec."""
        return [
            self._create_select_column(spec=instance.spec, fill_nulls_with=instance.spec.fill_nulls_with)
            for instance in measure_instances
        ]

    def transform(self, instance_set: InstanceSet) -> SelectColumnSet:
        """Return the select columns for the metric and measure instances of `instance_set`."""
        metric_columns = self._create_select_columns_for_metrics(instance_set.metric_instances)
        measure_columns = self._create_select_columns_for_measures(instance_set.measure_instances)
        return SelectColumnSet(metric_columns=metric_columns, measure_columns=measure_columns)


class ChangeAssociatedColumns(InstanceSetTransform[InstanceSet]):
"""Change the columns associated with instances to the one specified by the resolver."""

Expand Down

0 comments on commit a858854

Please sign in to comment.