Skip to content

Commit

Permalink
Update MetricFlowQueryResolution.queried_semantic_models for Conver…
Browse files Browse the repository at this point in the history
…sion Metrics (#1204)

### Description

I noticed that the `queried_semantic_models` field was not correct for
conversion metrics and also for queries without any group-by-items. This
PR addresses those cases.

<!--- 
  Before requesting review, please make sure you have:
1. read [the contributing
guide](https://github.com/dbt-labs/metricflow/blob/main/CONTRIBUTING.md),
2. signed the
[CLA](https://docs.getdbt.com/docs/contributor-license-agreements)
3. run `changie new` to [create a changelog
entry](https://github.com/dbt-labs/metricflow/blob/main/CONTRIBUTING.md#adding-a-changelog-entry)
-->
  • Loading branch information
plypaul authored May 15, 2024
1 parent 30d6dc6 commit 03d7650
Show file tree
Hide file tree
Showing 8 changed files with 198 additions and 7 deletions.
Original file line number Diff line number Diff line change
@@ -1,8 +1,13 @@
from __future__ import annotations

import itertools
from abc import ABC, abstractmethod
from typing import TYPE_CHECKING, Generic, Sequence
from dataclasses import dataclass
from typing import TYPE_CHECKING, Generic, Sequence, Tuple

from typing_extensions import override

from metricflow_semantics.collection_helpers.merger import Mergeable
from metricflow_semantics.dag.mf_dag import DagNode, NodeId
from metricflow_semantics.visitor import Visitable, VisitorOutputT

Expand Down Expand Up @@ -46,6 +51,22 @@ def ui_description(self) -> str:
def parent_nodes(self) -> Sequence[GroupByItemResolutionNode]: # noqa: D102
raise NotImplementedError

@abstractmethod
def _self_set(self) -> GroupByItemResolutionNodeSet:
"""Return a `GroupByItemResolutionNodeInclusiveAncestorSet` only containing self.
Use to simplify implementation of `inclusive_ancestors`
"""
raise NotImplementedError

def inclusive_ancestors(self) -> GroupByItemResolutionNodeSet:
"""Return a set containing itself and all its ancestors."""
return GroupByItemResolutionNodeSet.merge_iterable(
itertools.chain(
[self._self_set()], (parent_node.inclusive_ancestors() for parent_node in self.parent_nodes)
)
)


class GroupByItemResolutionNodeVisitor(Generic[VisitorOutputT], ABC):
"""Visitor for traversing GroupByItemResolutionNodes."""
Expand All @@ -65,3 +86,27 @@ def visit_metric_node(self, node: MetricGroupByItemResolutionNode) -> VisitorOut
@abstractmethod
def visit_query_node(self, node: QueryGroupByItemResolutionNode) -> VisitorOutputT: # noqa: D102
raise NotImplementedError


@dataclass(frozen=True)
class GroupByItemResolutionNodeSet(Mergeable):
"""Set containing nodes in a group-by-item resolution DAG."""

measure_nodes: Tuple[MeasureGroupByItemSourceNode, ...] = ()
no_metrics_query_nodes: Tuple[NoMetricsGroupByItemSourceNode, ...] = ()
metric_nodes: Tuple[MetricGroupByItemResolutionNode, ...] = ()
query_nodes: Tuple[QueryGroupByItemResolutionNode, ...] = ()

@override
def merge(self, other: GroupByItemResolutionNodeSet) -> GroupByItemResolutionNodeSet:
return GroupByItemResolutionNodeSet(
measure_nodes=self.measure_nodes + other.measure_nodes,
no_metrics_query_nodes=self.no_metrics_query_nodes + other.no_metrics_query_nodes,
metric_nodes=self.metric_nodes + other.metric_nodes,
query_nodes=self.query_nodes + other.query_nodes,
)

@classmethod
@override
def empty_instance(cls) -> GroupByItemResolutionNodeSet:
return GroupByItemResolutionNodeSet()
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
from metricflow_semantics.dag.mf_dag import DisplayedProperty
from metricflow_semantics.query.group_by_item.resolution_dag.resolution_nodes.base_node import (
GroupByItemResolutionNode,
GroupByItemResolutionNodeSet,
GroupByItemResolutionNodeVisitor,
)
from metricflow_semantics.visitor import VisitorOutputT
Expand Down Expand Up @@ -78,3 +79,7 @@ def child_metric_reference(self) -> MetricReference:
@override
def ui_description(self) -> str:
return f"Measure({repr(self.measure_reference.element_name)})"

@override
def _self_set(self) -> GroupByItemResolutionNodeSet:
return GroupByItemResolutionNodeSet(measure_nodes=(self,))
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
from metricflow_semantics.query.group_by_item.resolution_dag.input_metric_location import InputMetricDefinitionLocation
from metricflow_semantics.query.group_by_item.resolution_dag.resolution_nodes.base_node import (
GroupByItemResolutionNode,
GroupByItemResolutionNodeSet,
GroupByItemResolutionNodeVisitor,
)
from metricflow_semantics.query.group_by_item.resolution_dag.resolution_nodes.measure_source_node import (
Expand Down Expand Up @@ -86,3 +87,7 @@ def ui_description(self) -> str:
f"Metric({repr(self._metric_reference.element_name)}, "
f"input_metric_index={self._metric_input_location.input_metric_list_index})"
)

@override
def _self_set(self) -> GroupByItemResolutionNodeSet:
return GroupByItemResolutionNodeSet(metric_nodes=(self,))
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
from metricflow_semantics.dag.id_prefix import IdPrefix, StaticIdPrefix
from metricflow_semantics.query.group_by_item.resolution_dag.resolution_nodes.base_node import (
GroupByItemResolutionNode,
GroupByItemResolutionNodeSet,
GroupByItemResolutionNodeVisitor,
)
from metricflow_semantics.query.group_by_item.resolution_dag.resolution_nodes.metric_resolution_node import (
Expand Down Expand Up @@ -44,3 +45,7 @@ def id_prefix(cls) -> IdPrefix:
@override
def ui_description(self) -> str:
return f"{self.__class__.__name__}()"

@override
def _self_set(self) -> GroupByItemResolutionNodeSet:
return GroupByItemResolutionNodeSet(no_metrics_query_nodes=(self,))
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
from metricflow_semantics.dag.mf_dag import DisplayedProperty
from metricflow_semantics.query.group_by_item.resolution_dag.resolution_nodes.base_node import (
GroupByItemResolutionNode,
GroupByItemResolutionNodeSet,
GroupByItemResolutionNodeVisitor,
)
from metricflow_semantics.query.group_by_item.resolution_dag.resolution_nodes.metric_resolution_node import (
Expand Down Expand Up @@ -92,3 +93,7 @@ def where_filter_intersection(self) -> WhereFilterIntersection: # noqa: D102
@override
def ui_description(self) -> str:
return f"Query({repr([metric_reference.element_name for metric_reference in self._metrics_in_query])})"

@override
def _self_set(self) -> GroupByItemResolutionNodeSet:
return GroupByItemResolutionNodeSet(query_nodes=(self,))
83 changes: 77 additions & 6 deletions metricflow-semantics/metricflow_semantics/query/query_resolver.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,16 @@
from __future__ import annotations

import itertools
import logging
from dataclasses import dataclass
from typing import List, Optional, Sequence, Tuple
from typing import List, Optional, Sequence, Set, Tuple

from dbt_semantic_interfaces.references import MetricReference
from dbt_semantic_interfaces.references import MeasureReference, MetricReference, SemanticModelReference

from metricflow_semantics.mf_logging.pretty_print import mf_pformat
from metricflow_semantics.mf_logging.pretty_print import mf_pformat, mf_pformat_many
from metricflow_semantics.mf_logging.runtime import log_runtime
from metricflow_semantics.model.semantic_manifest_lookup import SemanticManifestLookup
from metricflow_semantics.model.semantic_model_derivation import SemanticModelDerivation
from metricflow_semantics.model.semantics.linkable_element_set import LinkableElementSet
from metricflow_semantics.naming.metric_scheme import MetricNamingScheme
from metricflow_semantics.query.group_by_item.filter_spec_resolution.filter_pattern_factory import (
Expand Down Expand Up @@ -61,6 +63,7 @@
OrderBySpec,
)
from metricflow_semantics.specs.spec_set import group_specs_by_type
from metricflow_semantics.workarounds.reference import sorted_semantic_model_references

logger = logging.getLogger(__name__)

Expand Down Expand Up @@ -520,6 +523,46 @@ def _resolve_query(self, resolver_input_for_query: ResolverInputForQuery) -> Met
queried_semantic_models=(),
)

model_reference_set = set(resolve_group_by_item_result.linkable_element_set.derived_from_semantic_models)
for filter_spec_resolution in filter_spec_lookup.spec_resolutions:
model_reference_set.update(
set(filter_spec_resolution.resolved_linkable_element_set.derived_from_semantic_models)
)

# Collect all semantic models referenced by the query.
semantic_models_in_group_by_items = set(
resolve_group_by_item_result.linkable_element_set.derived_from_semantic_models
)
semantic_models_in_filters = set(
itertools.chain.from_iterable(
filter_spec_resolution.resolved_linkable_element_set.derived_from_semantic_models
for filter_spec_resolution in filter_spec_lookup.spec_resolutions
)
)
measure_semantic_models = self._get_models_for_measures(resolution_dag)

queried_semantic_models = set.union(
semantic_models_in_group_by_items, semantic_models_in_filters, measure_semantic_models
)
queried_semantic_models -= {SemanticModelDerivation.VIRTUAL_SEMANTIC_MODEL_REFERENCE}

# Sanity check to make sure that all queried semantic models are in the model.
models_not_in_manifest = queried_semantic_models - {
semantic_model.reference for semantic_model in self._manifest_lookup.semantic_manifest.semantic_models
}

# There are no known cases where this should happen, but adding this check just in case there's a bug where
# a measure alias is used incorrectly.
if len(models_not_in_manifest) > 0:
logger.error(
mf_pformat_many(
"Semantic references that aren't in the manifest were found in the set used in "
"a query. This is a bug, and to avoid potential issues, they will be filtered out.",
{"models_not_in_manifest": models_not_in_manifest},
)
)
queried_semantic_models -= models_not_in_manifest

return MetricFlowQueryResolution(
query_spec=MetricFlowQuerySpec(
metric_specs=metric_specs,
Expand All @@ -535,7 +578,35 @@ def _resolve_query(self, resolver_input_for_query: ResolverInputForQuery) -> Met
resolution_dag=resolution_dag,
filter_spec_lookup=filter_spec_lookup,
input_to_issue_set=issue_set_mapping,
queried_semantic_models=tuple(
resolve_group_by_item_result.linkable_element_set.derived_from_semantic_models
),
queried_semantic_models=sorted_semantic_model_references(queried_semantic_models),
)

def _get_models_for_measures(self, resolution_dag: GroupByItemResolutionDag) -> Set[SemanticModelReference]:
"""Return the semantic model references for the measures used in the query."""
resolution_dag_node_set = resolution_dag.sink_node.inclusive_ancestors()

measure_references: Set[MeasureReference] = set()

# Collect measures for metrics through the associated measure nodes.
for measure_node in resolution_dag_node_set.measure_nodes:
measure_references.add(measure_node.measure_reference)

# For conversion metrics, get the measures through the metric since those measures aren't in the DAG.
for metric_node in resolution_dag_node_set.metric_nodes:
metric = self._manifest_lookup.metric_lookup.get_metric(metric_node.metric_reference)
conversion_type_params = metric.type_params.conversion_type_params
if conversion_type_params is None:
continue

# The base measure should be in a DAG, but just in case.
measure_references.add(conversion_type_params.base_measure.measure_reference)
measure_references.add(conversion_type_params.conversion_measure.measure_reference)

model_references: Set[SemanticModelReference] = set()
for measure_reference in measure_references:
measure_semantic_model = self._manifest_lookup.semantic_model_lookup.get_semantic_model_for_measure(
measure_reference
)
model_references.add(measure_semantic_model.reference)

return model_references
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@
from __future__ import annotations

import pytest
from _pytest.fixtures import FixtureRequest
from dbt_semantic_interfaces.implementations.semantic_manifest import PydanticSemanticManifest
from metricflow_semantics.model.semantic_manifest_lookup import SemanticManifestLookup
from metricflow_semantics.query.query_parser import MetricFlowQueryParser
from metricflow_semantics.test_helpers.config_helpers import MetricFlowTestConfiguration
from metricflow_semantics.test_helpers.snapshot_helpers import assert_object_snapshot_equal


@pytest.fixture(scope="module")
def query_parser(simple_semantic_manifest: PydanticSemanticManifest) -> MetricFlowQueryParser: # noqa: D103
return MetricFlowQueryParser(SemanticManifestLookup(simple_semantic_manifest))


def test_conversion_rate_with_constant_properties( # noqa: D103
request: FixtureRequest,
mf_test_configuration: MetricFlowTestConfiguration,
query_parser: MetricFlowQueryParser,
) -> None:
result = query_parser.parse_and_validate_query(
metric_names=("visit_buy_conversion_rate_by_session",),
group_by_names=("visit__referrer_id", "user__home_state_latest", "metric_time"),
)

assert_object_snapshot_equal(
request=request,
mf_test_configuration=mf_test_configuration,
obj=result,
)
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
ParseQueryResult(
query_spec=MetricFlowQuerySpec(
metric_specs=(MetricSpec(element_name='visit_buy_conversion_rate_by_session'),),
dimension_specs=(
DimensionSpec(
element_name='referrer_id',
entity_links=(EntityReference(element_name='visit'),),
),
DimensionSpec(
element_name='home_state_latest',
entity_links=(EntityReference(element_name='user'),),
),
),
time_dimension_specs=(TimeDimensionSpec(element_name='metric_time', time_granularity=DAY),),
filter_intersection=PydanticWhereFilterIntersection(),
filter_spec_resolution_lookup=FilterSpecResolutionLookUp(),
min_max_only=False,
),
queried_semantic_models=(
SemanticModelReference(semantic_model_name='buys_source'),
SemanticModelReference(semantic_model_name='users_latest'),
SemanticModelReference(semantic_model_name='visits_source'),
),
)

0 comments on commit 03d7650

Please sign in to comment.