diff --git a/metricflow-semantics/metricflow_semantics/query/group_by_item/resolution_dag/resolution_nodes/base_node.py b/metricflow-semantics/metricflow_semantics/query/group_by_item/resolution_dag/resolution_nodes/base_node.py index ac3fef5c31..db6f65e646 100644 --- a/metricflow-semantics/metricflow_semantics/query/group_by_item/resolution_dag/resolution_nodes/base_node.py +++ b/metricflow-semantics/metricflow_semantics/query/group_by_item/resolution_dag/resolution_nodes/base_node.py @@ -1,8 +1,13 @@ from __future__ import annotations +import itertools from abc import ABC, abstractmethod -from typing import TYPE_CHECKING, Generic, Sequence +from dataclasses import dataclass +from typing import TYPE_CHECKING, Generic, Sequence, Tuple +from typing_extensions import override + +from metricflow_semantics.collection_helpers.merger import Mergeable from metricflow_semantics.dag.mf_dag import DagNode, NodeId from metricflow_semantics.visitor import Visitable, VisitorOutputT @@ -46,6 +51,22 @@ def ui_description(self) -> str: def parent_nodes(self) -> Sequence[GroupByItemResolutionNode]: # noqa: D102 raise NotImplementedError + @abstractmethod + def _self_set(self) -> GroupByItemResolutionNodeSet: + """Return a `GroupByItemResolutionNodeInclusiveAncestorSet` only containing self. + + Use to simplify implementation of `inclusive_ancestors` + """ + raise NotImplementedError + + def inclusive_ancestors(self) -> GroupByItemResolutionNodeSet: + """Return a set containing itself and all its ancestors.""" + return GroupByItemResolutionNodeSet.merge_iterable( + itertools.chain( + [self._self_set()], (parent_node.inclusive_ancestors() for parent_node in self.parent_nodes) + ) + ) + class GroupByItemResolutionNodeVisitor(Generic[VisitorOutputT], ABC): """Visitor for traversing GroupByItemResolutionNodes.""" @@ -65,3 +86,27 @@ def visit_metric_node(self, node: MetricGroupByItemResolutionNode) -> VisitorOut @abstractmethod def visit_query_node(self, node: QueryGroupByItemResolutionNode) -> VisitorOutputT: # noqa: D102 raise NotImplementedError + + +@dataclass(frozen=True) +class GroupByItemResolutionNodeSet(Mergeable): + """Set containing nodes in a group-by-item resolution DAG.""" + + measure_nodes: Tuple[MeasureGroupByItemSourceNode, ...] = () + no_metrics_query_nodes: Tuple[NoMetricsGroupByItemSourceNode, ...] = () + metric_nodes: Tuple[MetricGroupByItemResolutionNode, ...] = () + query_nodes: Tuple[QueryGroupByItemResolutionNode, ...] = () + + @override + def merge(self, other: GroupByItemResolutionNodeSet) -> GroupByItemResolutionNodeSet: + return GroupByItemResolutionNodeSet( + measure_nodes=self.measure_nodes + other.measure_nodes, + no_metrics_query_nodes=self.no_metrics_query_nodes + other.no_metrics_query_nodes, + metric_nodes=self.metric_nodes + other.metric_nodes, + query_nodes=self.query_nodes + other.query_nodes, + ) + + @classmethod + @override + def empty_instance(cls) -> GroupByItemResolutionNodeSet: + return GroupByItemResolutionNodeSet() diff --git a/metricflow-semantics/metricflow_semantics/query/group_by_item/resolution_dag/resolution_nodes/measure_source_node.py b/metricflow-semantics/metricflow_semantics/query/group_by_item/resolution_dag/resolution_nodes/measure_source_node.py index 34502b4adb..5a2dd593ae 100644 --- a/metricflow-semantics/metricflow_semantics/query/group_by_item/resolution_dag/resolution_nodes/measure_source_node.py +++ b/metricflow-semantics/metricflow_semantics/query/group_by_item/resolution_dag/resolution_nodes/measure_source_node.py @@ -9,6 +9,7 @@ from metricflow_semantics.dag.mf_dag import DisplayedProperty from metricflow_semantics.query.group_by_item.resolution_dag.resolution_nodes.base_node import ( GroupByItemResolutionNode, + GroupByItemResolutionNodeSet, GroupByItemResolutionNodeVisitor, ) from metricflow_semantics.visitor import VisitorOutputT @@ -78,3 +79,7 @@ def child_metric_reference(self) -> MetricReference: @override def ui_description(self) -> str: return f"Measure({repr(self.measure_reference.element_name)})" + + @override + def _self_set(self) -> GroupByItemResolutionNodeSet: + return GroupByItemResolutionNodeSet(measure_nodes=(self,)) diff --git a/metricflow-semantics/metricflow_semantics/query/group_by_item/resolution_dag/resolution_nodes/metric_resolution_node.py b/metricflow-semantics/metricflow_semantics/query/group_by_item/resolution_dag/resolution_nodes/metric_resolution_node.py index 914cc69946..4c3bd9bba5 100644 --- a/metricflow-semantics/metricflow_semantics/query/group_by_item/resolution_dag/resolution_nodes/metric_resolution_node.py +++ b/metricflow-semantics/metricflow_semantics/query/group_by_item/resolution_dag/resolution_nodes/metric_resolution_node.py @@ -10,6 +10,7 @@ from metricflow_semantics.query.group_by_item.resolution_dag.input_metric_location import InputMetricDefinitionLocation from metricflow_semantics.query.group_by_item.resolution_dag.resolution_nodes.base_node import ( GroupByItemResolutionNode, + GroupByItemResolutionNodeSet, GroupByItemResolutionNodeVisitor, ) from metricflow_semantics.query.group_by_item.resolution_dag.resolution_nodes.measure_source_node import ( @@ -86,3 +87,7 @@ def ui_description(self) -> str: f"Metric({repr(self._metric_reference.element_name)}, " f"input_metric_index={self._metric_input_location.input_metric_list_index})" ) + + @override + def _self_set(self) -> GroupByItemResolutionNodeSet: + return GroupByItemResolutionNodeSet(metric_nodes=(self,)) diff --git a/metricflow-semantics/metricflow_semantics/query/group_by_item/resolution_dag/resolution_nodes/no_metrics_query_source_node.py b/metricflow-semantics/metricflow_semantics/query/group_by_item/resolution_dag/resolution_nodes/no_metrics_query_source_node.py index d1ba72470a..015f592b1d 100644 --- a/metricflow-semantics/metricflow_semantics/query/group_by_item/resolution_dag/resolution_nodes/no_metrics_query_source_node.py +++ b/metricflow-semantics/metricflow_semantics/query/group_by_item/resolution_dag/resolution_nodes/no_metrics_query_source_node.py @@ -7,6 +7,7 @@ from metricflow_semantics.dag.id_prefix import IdPrefix, StaticIdPrefix from metricflow_semantics.query.group_by_item.resolution_dag.resolution_nodes.base_node import ( GroupByItemResolutionNode, + GroupByItemResolutionNodeSet, GroupByItemResolutionNodeVisitor, ) from metricflow_semantics.query.group_by_item.resolution_dag.resolution_nodes.metric_resolution_node import ( @@ -44,3 +45,7 @@ def id_prefix(cls) -> IdPrefix: @override def ui_description(self) -> str: return f"{self.__class__.__name__}()" + + @override + def _self_set(self) -> GroupByItemResolutionNodeSet: + return GroupByItemResolutionNodeSet(no_metrics_query_nodes=(self,)) diff --git a/metricflow-semantics/metricflow_semantics/query/group_by_item/resolution_dag/resolution_nodes/query_resolution_node.py b/metricflow-semantics/metricflow_semantics/query/group_by_item/resolution_dag/resolution_nodes/query_resolution_node.py index 684da6a2d6..f4b8593608 100644 --- a/metricflow-semantics/metricflow_semantics/query/group_by_item/resolution_dag/resolution_nodes/query_resolution_node.py +++ b/metricflow-semantics/metricflow_semantics/query/group_by_item/resolution_dag/resolution_nodes/query_resolution_node.py @@ -10,6 +10,7 @@ from metricflow_semantics.dag.mf_dag import DisplayedProperty from metricflow_semantics.query.group_by_item.resolution_dag.resolution_nodes.base_node import ( GroupByItemResolutionNode, + GroupByItemResolutionNodeSet, GroupByItemResolutionNodeVisitor, ) from metricflow_semantics.query.group_by_item.resolution_dag.resolution_nodes.metric_resolution_node import ( @@ -92,3 +93,7 @@ def where_filter_intersection(self) -> WhereFilterIntersection: # noqa: D102 @override def ui_description(self) -> str: return f"Query({repr([metric_reference.element_name for metric_reference in self._metrics_in_query])})" + + @override + def _self_set(self) -> GroupByItemResolutionNodeSet: + return GroupByItemResolutionNodeSet(query_nodes=(self,)) diff --git a/metricflow-semantics/metricflow_semantics/query/query_resolver.py b/metricflow-semantics/metricflow_semantics/query/query_resolver.py index 1a41d3418c..a1517b5462 100644 --- a/metricflow-semantics/metricflow_semantics/query/query_resolver.py +++ b/metricflow-semantics/metricflow_semantics/query/query_resolver.py @@ -1,14 +1,16 @@ from __future__ import annotations +import itertools import logging from dataclasses import dataclass -from typing import List, Optional, Sequence, Tuple +from typing import List, Optional, Sequence, Set, Tuple -from dbt_semantic_interfaces.references import MetricReference +from dbt_semantic_interfaces.references import MeasureReference, MetricReference, SemanticModelReference -from metricflow_semantics.mf_logging.pretty_print import mf_pformat +from metricflow_semantics.mf_logging.pretty_print import mf_pformat, mf_pformat_many from metricflow_semantics.mf_logging.runtime import log_runtime from metricflow_semantics.model.semantic_manifest_lookup import SemanticManifestLookup +from metricflow_semantics.model.semantic_model_derivation import SemanticModelDerivation from metricflow_semantics.model.semantics.linkable_element_set import LinkableElementSet from metricflow_semantics.naming.metric_scheme import MetricNamingScheme from metricflow_semantics.query.group_by_item.filter_spec_resolution.filter_pattern_factory import ( @@ -61,6 +63,7 @@ OrderBySpec, ) from metricflow_semantics.specs.spec_set import group_specs_by_type +from metricflow_semantics.workarounds.reference import sorted_semantic_model_references logger = logging.getLogger(__name__) @@ -520,6 +523,46 @@ def _resolve_query(self, resolver_input_for_query: ResolverInputForQuery) -> Met queried_semantic_models=(), ) + model_reference_set = set(resolve_group_by_item_result.linkable_element_set.derived_from_semantic_models) + for filter_spec_resolution in filter_spec_lookup.spec_resolutions: + model_reference_set.update( + set(filter_spec_resolution.resolved_linkable_element_set.derived_from_semantic_models) + ) + + # Collect all semantic models referenced by the query. + semantic_models_in_group_by_items = set( + resolve_group_by_item_result.linkable_element_set.derived_from_semantic_models + ) + semantic_models_in_filters = set( + itertools.chain.from_iterable( + filter_spec_resolution.resolved_linkable_element_set.derived_from_semantic_models + for filter_spec_resolution in filter_spec_lookup.spec_resolutions + ) + ) + measure_semantic_models = self._get_models_for_measures(resolution_dag) + + queried_semantic_models = set.union( + semantic_models_in_group_by_items, semantic_models_in_filters, measure_semantic_models + ) + queried_semantic_models -= {SemanticModelDerivation.VIRTUAL_SEMANTIC_MODEL_REFERENCE} + + # Sanity check to make sure that all queried semantic models are in the model. + models_not_in_manifest = queried_semantic_models - { + semantic_model.reference for semantic_model in self._manifest_lookup.semantic_manifest.semantic_models + } + + # There are no known cases where this should happen, but adding this check just in case there's a bug where + # a measure alias is used incorrectly. + if len(models_not_in_manifest) > 0: + logger.error( + mf_pformat_many( + "Semantic references that aren't in the manifest were found in the set used in " + "a query. This is a bug, and to avoid potential issues, they will be filtered out.", + {"models_not_in_manifest": models_not_in_manifest}, + ) + ) + queried_semantic_models -= models_not_in_manifest + return MetricFlowQueryResolution( query_spec=MetricFlowQuerySpec( metric_specs=metric_specs, @@ -535,7 +578,35 @@ def _resolve_query(self, resolver_input_for_query: ResolverInputForQuery) -> Met resolution_dag=resolution_dag, filter_spec_lookup=filter_spec_lookup, input_to_issue_set=issue_set_mapping, - queried_semantic_models=tuple( - resolve_group_by_item_result.linkable_element_set.derived_from_semantic_models - ), + queried_semantic_models=sorted_semantic_model_references(queried_semantic_models), ) + + def _get_models_for_measures(self, resolution_dag: GroupByItemResolutionDag) -> Set[SemanticModelReference]: + """Return the semantic model references for the measures used in the query.""" + resolution_dag_node_set = resolution_dag.sink_node.inclusive_ancestors() + + measure_references: Set[MeasureReference] = set() + + # Collect measures for metrics through the associated measure nodes. + for measure_node in resolution_dag_node_set.measure_nodes: + measure_references.add(measure_node.measure_reference) + + # For conversion metrics, get the measures through the metric since those measures aren't in the DAG. + for metric_node in resolution_dag_node_set.metric_nodes: + metric = self._manifest_lookup.metric_lookup.get_metric(metric_node.metric_reference) + conversion_type_params = metric.type_params.conversion_type_params + if conversion_type_params is None: + continue + + # The base measure should be in a DAG, but just in case. + measure_references.add(conversion_type_params.base_measure.measure_reference) + measure_references.add(conversion_type_params.conversion_measure.measure_reference) + + model_references: Set[SemanticModelReference] = set() + for measure_reference in measure_references: + measure_semantic_model = self._manifest_lookup.semantic_model_lookup.get_semantic_model_for_measure( + measure_reference + ) + model_references.add(measure_semantic_model.reference) + + return model_references diff --git a/metricflow-semantics/tests_metricflow_semantics/query/test_conversion_metrics.py b/metricflow-semantics/tests_metricflow_semantics/query/test_conversion_metrics.py new file mode 100644 index 0000000000..fbe6176856 --- /dev/null +++ b/metricflow-semantics/tests_metricflow_semantics/query/test_conversion_metrics.py @@ -0,0 +1,31 @@ +from __future__ import annotations + +import pytest +from _pytest.fixtures import FixtureRequest +from dbt_semantic_interfaces.implementations.semantic_manifest import PydanticSemanticManifest +from metricflow_semantics.model.semantic_manifest_lookup import SemanticManifestLookup +from metricflow_semantics.query.query_parser import MetricFlowQueryParser +from metricflow_semantics.test_helpers.config_helpers import MetricFlowTestConfiguration +from metricflow_semantics.test_helpers.snapshot_helpers import assert_object_snapshot_equal + + +@pytest.fixture(scope="module") +def query_parser(simple_semantic_manifest: PydanticSemanticManifest) -> MetricFlowQueryParser: # noqa: D103 + return MetricFlowQueryParser(SemanticManifestLookup(simple_semantic_manifest)) + + +def test_conversion_rate_with_constant_properties( # noqa: D103 + request: FixtureRequest, + mf_test_configuration: MetricFlowTestConfiguration, + query_parser: MetricFlowQueryParser, +) -> None: + result = query_parser.parse_and_validate_query( + metric_names=("visit_buy_conversion_rate_by_session",), + group_by_names=("visit__referrer_id", "user__home_state_latest", "metric_time"), + ) + + assert_object_snapshot_equal( + request=request, + mf_test_configuration=mf_test_configuration, + obj=result, + ) diff --git a/metricflow-semantics/tests_metricflow_semantics/snapshots/test_conversion_metrics.py/ParseQueryResult/test_conversion_rate_with_constant_properties__result.txt b/metricflow-semantics/tests_metricflow_semantics/snapshots/test_conversion_metrics.py/ParseQueryResult/test_conversion_rate_with_constant_properties__result.txt new file mode 100644 index 0000000000..7eec0e4d6c --- /dev/null +++ b/metricflow-semantics/tests_metricflow_semantics/snapshots/test_conversion_metrics.py/ParseQueryResult/test_conversion_rate_with_constant_properties__result.txt @@ -0,0 +1,24 @@ +ParseQueryResult( + query_spec=MetricFlowQuerySpec( + metric_specs=(MetricSpec(element_name='visit_buy_conversion_rate_by_session'),), + dimension_specs=( + DimensionSpec( + element_name='referrer_id', + entity_links=(EntityReference(element_name='visit'),), + ), + DimensionSpec( + element_name='home_state_latest', + entity_links=(EntityReference(element_name='user'),), + ), + ), + time_dimension_specs=(TimeDimensionSpec(element_name='metric_time', time_granularity=DAY),), + filter_intersection=PydanticWhereFilterIntersection(), + filter_spec_resolution_lookup=FilterSpecResolutionLookUp(), + min_max_only=False, + ), + queried_semantic_models=( + SemanticModelReference(semantic_model_name='buys_source'), + SemanticModelReference(semantic_model_name='users_latest'), + SemanticModelReference(semantic_model_name='visits_source'), + ), +)