-
Notifications
You must be signed in to change notification settings - Fork 99
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Fill nulls for multi-metric queries #850
Changes from 8 commits
4de5d57
5cd869e
e8e1b17
304b730
6abdf1c
affea74
b8bc8dc
aa32059
6545bc1
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,6 @@ | ||
kind: Features | ||
body: Fill nulls for multi-metric queries | ||
time: 2023-11-06T15:00:14.37926-08:00 | ||
custom: | ||
Author: courtneyholcomb | ||
Issue: "850" |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,12 +1,13 @@ | ||
from __future__ import annotations | ||
|
||
import logging | ||
from typing import Dict, FrozenSet, List, Sequence | ||
from typing import Dict, FrozenSet, List, Optional, Sequence | ||
|
||
from dbt_semantic_interfaces.enum_extension import assert_values_exhausted | ||
from dbt_semantic_interfaces.implementations.filters.where_filter import PydanticWhereFilterIntersection | ||
from dbt_semantic_interfaces.implementations.metric import PydanticMetricTimeWindow | ||
from dbt_semantic_interfaces.protocols import WhereFilter | ||
from dbt_semantic_interfaces.protocols.metric import Metric, MetricType | ||
from dbt_semantic_interfaces.protocols.metric import Metric, MetricInputMeasure, MetricType | ||
from dbt_semantic_interfaces.protocols.semantic_manifest import SemanticManifest | ||
from dbt_semantic_interfaces.references import MetricReference | ||
|
||
|
@@ -105,6 +106,23 @@ def add_metric(self, metric: Metric) -> None: | |
) | ||
self._metrics[metric_reference] = metric | ||
|
||
def yaml_input_measure_for_metric(self, metric_reference: MetricReference) -> Optional[MetricInputMeasure]: | ||
"""Get input measure defined in the metric YAML, if exists. | ||
|
||
When SemanticModel is constructed, input measures from input metrics are added to the list of input measures | ||
for a metric. Here, use rules about metric types to determine which input measures were defined in the YAML: | ||
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Then here we can replace |
||
- Simple & cumulative metrics require one input measure, and can't take any input metrics. | ||
- Derived & ratio metrics take no input measures, only input metrics. | ||
""" | ||
metric = self.get_metric(metric_reference=metric_reference) | ||
if metric.type is MetricType.CUMULATIVE or metric.type is MetricType.SIMPLE: | ||
assert len(metric.input_measures) == 1, "Simple and cumulative metrics should have one input measure." | ||
return metric.input_measures[0] | ||
elif metric.type is MetricType.RATIO or metric.type is MetricType.DERIVED: | ||
return None | ||
else: | ||
assert_values_exhausted(metric.type) | ||
|
||
def measures_for_metric( | ||
self, | ||
metric_reference: MetricReference, | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -2,11 +2,11 @@ | |
|
||
import logging | ||
from collections import OrderedDict | ||
from typing import List, Optional, Sequence, Union | ||
from typing import List, Optional, Sequence, Tuple, Union | ||
|
||
from dbt_semantic_interfaces.enum_extension import assert_values_exhausted | ||
from dbt_semantic_interfaces.protocols.metric import MetricInputMeasure, MetricType | ||
from dbt_semantic_interfaces.references import MetricModelReference | ||
from dbt_semantic_interfaces.references import MetricModelReference, MetricReference | ||
from dbt_semantic_interfaces.type_enums.aggregation_type import AggregationType | ||
|
||
from metricflow.aggregation_properties import AggregationState | ||
|
@@ -706,7 +706,7 @@ def visit_compute_metrics_node(self, node: ComputeMetricsNode) -> SqlDataSet: | |
metric_instances.append( | ||
MetricInstance( | ||
associated_columns=(output_column_association,), | ||
defined_from=(MetricModelReference(metric_name=metric_spec.element_name),), | ||
defined_from=MetricModelReference(metric_name=metric_spec.element_name), | ||
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. LOL, thanks. We have so many of these kind of "let's allow a sequence of things and then have exactly one callsite that manually enforces there's only one" interfaces lying around... |
||
spec=metric_spec.alias_spec, | ||
) | ||
) | ||
|
@@ -862,24 +862,25 @@ def visit_where_constraint_node(self, node: WhereConstraintNode) -> SqlDataSet: | |
), | ||
) | ||
|
||
def _make_select_columns_for_metrics( | ||
def _make_select_columns_for_multiple_metrics( | ||
self, | ||
table_alias_to_metric_specs: OrderedDict[str, Sequence[MetricSpec]], | ||
table_alias_to_metric_instances: OrderedDict[str, Tuple[MetricInstance, ...]], | ||
aggregation_type: Optional[AggregationType], | ||
) -> List[SqlSelectColumn]: | ||
"""Creates select columns that get the given metric using the given table alias. | ||
|
||
e.g. | ||
|
||
with table_alias_to_metric_specs = {"a": MetricSpec(element_name="bookings")} | ||
with table_alias_to_metric_instances = {"a": MetricSpec(element_name="bookings")} | ||
|
||
-> | ||
|
||
a.bookings AS bookings | ||
""" | ||
select_columns = [] | ||
for table_alias, metric_specs in table_alias_to_metric_specs.items(): | ||
for metric_spec in metric_specs: | ||
for table_alias, metric_instances in table_alias_to_metric_instances.items(): | ||
for metric_instance in metric_instances: | ||
metric_spec = metric_instance.spec | ||
metric_column_name = self._column_association_resolver.resolve_spec(metric_spec).column_name | ||
column_reference_expression = SqlColumnReferenceExpression( | ||
col_ref=SqlColumnReference( | ||
|
@@ -894,6 +895,19 @@ def _make_select_columns_for_metrics( | |
else: | ||
select_expression = column_reference_expression | ||
|
||
# At this point, the MetricSpec might have the alias in place of the element name, so we need to look | ||
# back at where it was defined from to get the metric element name. | ||
metric_reference = MetricReference(element_name=metric_instance.defined_from.metric_name) | ||
input_measure = self._metric_lookup.yaml_input_measure_for_metric(metric_reference=metric_reference) | ||
if input_measure and input_measure.fill_nulls_with is not None: | ||
select_expression = SqlAggregateFunctionExpression( | ||
sql_function=SqlFunction.COALESCE, | ||
sql_function_args=[ | ||
select_expression, | ||
SqlStringExpression(str(input_measure.fill_nulls_with)), | ||
], | ||
) | ||
|
||
select_columns.append( | ||
SqlSelectColumn( | ||
expr=select_expression, | ||
|
@@ -938,13 +952,13 @@ def visit_combine_metrics_node(self, node: CombineMetricsNode) -> SqlDataSet: | |
), "Shouldn't have a CombineMetricsNode in the dataflow plan if there's only 1 parent." | ||
|
||
parent_data_sets: List[AnnotatedSqlDataSet] = [] | ||
table_alias_to_metric_specs: OrderedDict[str, Sequence[MetricSpec]] = OrderedDict() | ||
table_alias_to_metric_instances: OrderedDict[str, Tuple[MetricInstance, ...]] = OrderedDict() | ||
|
||
for parent_node in node.parent_nodes: | ||
parent_sql_data_set = parent_node.accept(self) | ||
table_alias = self._next_unique_table_alias() | ||
parent_data_sets.append(AnnotatedSqlDataSet(data_set=parent_sql_data_set, alias=table_alias)) | ||
table_alias_to_metric_specs[table_alias] = parent_sql_data_set.instance_set.spec_set.metric_specs | ||
table_alias_to_metric_instances[table_alias] = parent_sql_data_set.instance_set.metric_instances | ||
|
||
# When we create the components of the join that combines metrics it will be one of INNER, FULL OUTER, | ||
# or CROSS JOIN. Order doesn't matter for these join types, so we will use the first element in the FROM | ||
|
@@ -986,8 +1000,9 @@ def visit_combine_metrics_node(self, node: CombineMetricsNode) -> SqlDataSet: | |
|
||
metric_aggregation_type = AggregationType.MAX | ||
metric_select_column_set = SelectColumnSet( | ||
metric_columns=self._make_select_columns_for_metrics( | ||
table_alias_to_metric_specs, aggregation_type=metric_aggregation_type | ||
metric_columns=self._make_select_columns_for_multiple_metrics( | ||
table_alias_to_metric_instances=table_alias_to_metric_instances, | ||
aggregation_type=metric_aggregation_type, | ||
) | ||
) | ||
linkable_select_column_set = linkable_spec_set.transform( | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -436,7 +436,7 @@ def from_reference(reference: MetricReference) -> MetricSpec: | |
|
||
@property | ||
def alias_spec(self) -> MetricSpec: | ||
"""Returns a MetricSpec represneting the alias state.""" | ||
"""Returns a MetricSpec representing the alias state.""" | ||
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. 😸 |
||
return MetricSpec( | ||
element_name=self.alias or self.element_name, | ||
constraint=self.constraint, | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,145 @@ | ||
from __future__ import annotations | ||
|
||
import datetime | ||
|
||
import pytest | ||
from _pytest.fixtures import FixtureRequest | ||
|
||
from metricflow.engine.metricflow_engine import MetricFlowQueryRequest | ||
from metricflow.protocols.sql_client import SqlClient | ||
from metricflow.test.fixtures.setup_fixtures import MetricFlowTestSessionState | ||
from metricflow.test.integration.conftest import IntegrationTestHelpers | ||
from metricflow.test.snapshot_utils import assert_object_snapshot_equal | ||
|
||
|
||
@pytest.mark.sql_engine_snapshot | ||
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. These snapshot tests are great, thanks for doing this! |
||
def test_simple_fill_nulls_with_0_metric_time( # noqa: D | ||
request: FixtureRequest, | ||
mf_test_session_state: MetricFlowTestSessionState, | ||
sql_client: SqlClient, | ||
it_helpers: IntegrationTestHelpers, | ||
) -> None: | ||
query_result = it_helpers.mf_engine.query( | ||
MetricFlowQueryRequest.create_with_random_request_id( | ||
metric_names=["bookings_fill_nulls_with_0"], | ||
group_by_names=["metric_time"], | ||
order_by_names=["metric_time"], | ||
time_constraint_start=datetime.datetime(2019, 11, 27), | ||
time_constraint_end=datetime.datetime(2020, 1, 5), | ||
) | ||
) | ||
assert query_result.result_df is not None, "Unexpected empty result." | ||
|
||
assert_object_snapshot_equal( | ||
request=request, | ||
mf_test_session_state=mf_test_session_state, | ||
obj_id="query_output", | ||
obj=query_result.result_df.to_string(), | ||
sql_client=sql_client, | ||
) | ||
|
||
|
||
@pytest.mark.sql_engine_snapshot | ||
def test_simple_fill_nulls_with_0_month( # noqa: D | ||
request: FixtureRequest, | ||
mf_test_session_state: MetricFlowTestSessionState, | ||
sql_client: SqlClient, | ||
it_helpers: IntegrationTestHelpers, | ||
) -> None: | ||
query_result = it_helpers.mf_engine.query( | ||
MetricFlowQueryRequest.create_with_random_request_id( | ||
metric_names=["bookings_fill_nulls_with_0"], | ||
group_by_names=["metric_time__month"], | ||
order_by_names=["metric_time__month"], | ||
time_constraint_start=datetime.datetime(2019, 1, 1), | ||
time_constraint_end=datetime.datetime(2020, 12, 1), | ||
) | ||
) | ||
assert query_result.result_df is not None, "Unexpected empty result." | ||
|
||
assert_object_snapshot_equal( | ||
request=request, | ||
mf_test_session_state=mf_test_session_state, | ||
obj_id="query_output", | ||
obj=query_result.result_df.to_string(), | ||
sql_client=sql_client, | ||
) | ||
|
||
|
||
@pytest.mark.sql_engine_snapshot | ||
def test_simple_join_to_time_spine( # noqa: D | ||
request: FixtureRequest, | ||
mf_test_session_state: MetricFlowTestSessionState, | ||
sql_client: SqlClient, | ||
it_helpers: IntegrationTestHelpers, | ||
) -> None: | ||
query_result = it_helpers.mf_engine.query( | ||
MetricFlowQueryRequest.create_with_random_request_id( | ||
metric_names=["bookings_join_to_time_spine"], | ||
group_by_names=["metric_time"], | ||
time_constraint_start=datetime.datetime(2019, 11, 27), | ||
time_constraint_end=datetime.datetime(2020, 1, 5), | ||
order_by_names=["metric_time"], | ||
) | ||
) | ||
assert query_result.result_df is not None, "Unexpected empty result." | ||
|
||
assert_object_snapshot_equal( | ||
request=request, | ||
mf_test_session_state=mf_test_session_state, | ||
obj_id="query_output", | ||
obj=query_result.result_df.to_string(), | ||
sql_client=sql_client, | ||
) | ||
|
||
|
||
@pytest.mark.sql_engine_snapshot | ||
def test_fill_nulls_with_0_multi_metric_query( # noqa: D | ||
request: FixtureRequest, | ||
mf_test_session_state: MetricFlowTestSessionState, | ||
sql_client: SqlClient, | ||
it_helpers: IntegrationTestHelpers, | ||
) -> None: | ||
query_result = it_helpers.mf_engine.query( | ||
MetricFlowQueryRequest.create_with_random_request_id( | ||
metric_names=["bookings_fill_nulls_with_0", "views"], | ||
group_by_names=["metric_time"], | ||
order_by_names=["metric_time"], | ||
time_constraint_start=datetime.datetime(2019, 11, 27), | ||
time_constraint_end=datetime.datetime(2020, 1, 5), | ||
) | ||
) | ||
assert query_result.result_df is not None, "Unexpected empty result." | ||
|
||
assert_object_snapshot_equal( | ||
request=request, | ||
mf_test_session_state=mf_test_session_state, | ||
obj_id="query_output", | ||
obj=query_result.result_df.to_string(), | ||
sql_client=sql_client, | ||
) | ||
|
||
|
||
@pytest.mark.sql_engine_snapshot | ||
def test_fill_nulls_with_0_multi_metric_query_with_categorical_dimension( # noqa: D | ||
request: FixtureRequest, | ||
mf_test_session_state: MetricFlowTestSessionState, | ||
sql_client: SqlClient, | ||
it_helpers: IntegrationTestHelpers, | ||
) -> None: | ||
query_result = it_helpers.mf_engine.query( | ||
MetricFlowQueryRequest.create_with_random_request_id( | ||
metric_names=["bookings_fill_nulls_with_0_without_time_spine", "views"], | ||
group_by_names=["metric_time", "listing__is_lux_latest"], | ||
order_by_names=["metric_time", "listing__is_lux_latest"], | ||
) | ||
) | ||
assert query_result.result_df is not None, "Unexpected empty result." | ||
|
||
assert_object_snapshot_equal( | ||
request=request, | ||
mf_test_session_state=mf_test_session_state, | ||
obj_id="query_output", | ||
obj=query_result.result_df.to_string(), | ||
sql_client=sql_client, | ||
) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Nice! Super nitty but can we rename this away from yaml? YAML makes me grumpy. More importantly, there's no guarantee that the input is defined in YAML (although in practice that is how it works today).
Maybe something like `direct_input_measure_for_metric` or `configured_input_measure_for_metric`?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Updated!