Skip to content

Commit

Permalink
/* PR_START 03 */ Update logic for queried semantic models in query r…
Browse files Browse the repository at this point in the history
…esolver.
  • Loading branch information
plypaul committed May 15, 2024
1 parent b865f97 commit 7a8c634
Show file tree
Hide file tree
Showing 6 changed files with 142 additions and 7 deletions.
Original file line number Diff line number Diff line change
@@ -1,8 +1,13 @@
from __future__ import annotations

import itertools
from abc import ABC, abstractmethod
from typing import TYPE_CHECKING, Generic, Sequence
from dataclasses import dataclass
from typing import TYPE_CHECKING, Generic, Sequence, Tuple

from typing_extensions import override

from metricflow_semantics.collection_helpers.merger import Mergeable
from metricflow_semantics.dag.mf_dag import DagNode, NodeId
from metricflow_semantics.visitor import Visitable, VisitorOutputT

Expand Down Expand Up @@ -46,6 +51,22 @@ def ui_description(self) -> str:
def parent_nodes(self) -> Sequence[GroupByItemResolutionNode]: # noqa: D102
raise NotImplementedError

@abstractmethod
def _self_set(self) -> GroupByItemResolutionNodeSet:
"""Return a `GroupByItemResolutionNodeInclusiveAncestorSet` only containing self.
Use to simplify implementation of `inclusive_ancestors`
"""
raise NotImplementedError

def inclusive_ancestors(self) -> GroupByItemResolutionNodeSet:
"""Return a set containing itself and all its ancestors."""
return GroupByItemResolutionNodeSet.merge_iterable(
itertools.chain(
[self._self_set()], (parent_node.inclusive_ancestors() for parent_node in self.parent_nodes)
)
)


class GroupByItemResolutionNodeVisitor(Generic[VisitorOutputT], ABC):
"""Visitor for traversing GroupByItemResolutionNodes."""
Expand All @@ -65,3 +86,27 @@ def visit_metric_node(self, node: MetricGroupByItemResolutionNode) -> VisitorOut
@abstractmethod
def visit_query_node(self, node: QueryGroupByItemResolutionNode) -> VisitorOutputT: # noqa: D102
raise NotImplementedError


@dataclass(frozen=True)
class GroupByItemResolutionNodeSet(Mergeable):
"""Set containing nodes in a group-by-item resolution DAG."""

measure_nodes: Tuple[MeasureGroupByItemSourceNode, ...] = ()
no_metrics_query_nodes: Tuple[NoMetricsGroupByItemSourceNode, ...] = ()
metric_nodes: Tuple[MetricGroupByItemResolutionNode, ...] = ()
query_nodes: Tuple[QueryGroupByItemResolutionNode, ...] = ()

@override
def merge(self, other: GroupByItemResolutionNodeSet) -> GroupByItemResolutionNodeSet:
return GroupByItemResolutionNodeSet(
measure_nodes=self.measure_nodes + other.measure_nodes,
no_metrics_query_nodes=self.no_metrics_query_nodes + other.no_metrics_query_nodes,
metric_nodes=self.metric_nodes + other.metric_nodes,
query_nodes=self.query_nodes + other.query_nodes,
)

@classmethod
@override
def empty_instance(cls) -> GroupByItemResolutionNodeSet:
return GroupByItemResolutionNodeSet()
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
from metricflow_semantics.dag.mf_dag import DisplayedProperty
from metricflow_semantics.query.group_by_item.resolution_dag.resolution_nodes.base_node import (
GroupByItemResolutionNode,
GroupByItemResolutionNodeSet,
GroupByItemResolutionNodeVisitor,
)
from metricflow_semantics.visitor import VisitorOutputT
Expand Down Expand Up @@ -78,3 +79,7 @@ def child_metric_reference(self) -> MetricReference:
@override
def ui_description(self) -> str:
return f"Measure({repr(self.measure_reference.element_name)})"

@override
def _self_set(self) -> GroupByItemResolutionNodeSet:
return GroupByItemResolutionNodeSet(measure_nodes=(self,))
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
from metricflow_semantics.query.group_by_item.resolution_dag.input_metric_location import InputMetricDefinitionLocation
from metricflow_semantics.query.group_by_item.resolution_dag.resolution_nodes.base_node import (
GroupByItemResolutionNode,
GroupByItemResolutionNodeSet,
GroupByItemResolutionNodeVisitor,
)
from metricflow_semantics.query.group_by_item.resolution_dag.resolution_nodes.measure_source_node import (
Expand Down Expand Up @@ -86,3 +87,7 @@ def ui_description(self) -> str:
f"Metric({repr(self._metric_reference.element_name)}, "
f"input_metric_index={self._metric_input_location.input_metric_list_index})"
)

@override
def _self_set(self) -> GroupByItemResolutionNodeSet:
return GroupByItemResolutionNodeSet(metric_nodes=(self,))
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
from metricflow_semantics.dag.id_prefix import IdPrefix, StaticIdPrefix
from metricflow_semantics.query.group_by_item.resolution_dag.resolution_nodes.base_node import (
GroupByItemResolutionNode,
GroupByItemResolutionNodeSet,
GroupByItemResolutionNodeVisitor,
)
from metricflow_semantics.query.group_by_item.resolution_dag.resolution_nodes.metric_resolution_node import (
Expand Down Expand Up @@ -44,3 +45,7 @@ def id_prefix(cls) -> IdPrefix:
@override
def ui_description(self) -> str:
return f"{self.__class__.__name__}()"

@override
def _self_set(self) -> GroupByItemResolutionNodeSet:
return GroupByItemResolutionNodeSet(no_metrics_query_nodes=(self,))
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
from metricflow_semantics.dag.mf_dag import DisplayedProperty
from metricflow_semantics.query.group_by_item.resolution_dag.resolution_nodes.base_node import (
GroupByItemResolutionNode,
GroupByItemResolutionNodeSet,
GroupByItemResolutionNodeVisitor,
)
from metricflow_semantics.query.group_by_item.resolution_dag.resolution_nodes.metric_resolution_node import (
Expand Down Expand Up @@ -92,3 +93,7 @@ def where_filter_intersection(self) -> WhereFilterIntersection: # noqa: D102
@override
def ui_description(self) -> str:
return f"Query({repr([metric_reference.element_name for metric_reference in self._metrics_in_query])})"

@override
def _self_set(self) -> GroupByItemResolutionNodeSet:
return GroupByItemResolutionNodeSet(query_nodes=(self,))
82 changes: 76 additions & 6 deletions metricflow-semantics/metricflow_semantics/query/query_resolver.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,16 @@
from __future__ import annotations

import itertools
import logging
from dataclasses import dataclass
from typing import List, Optional, Sequence, Tuple
from typing import List, Optional, Sequence, Set, Tuple

from dbt_semantic_interfaces.references import MetricReference
from dbt_semantic_interfaces.references import MeasureReference, MetricReference, SemanticModelReference

from metricflow_semantics.mf_logging.pretty_print import mf_pformat
from metricflow_semantics.mf_logging.pretty_print import mf_pformat, mf_pformat_many
from metricflow_semantics.mf_logging.runtime import log_runtime
from metricflow_semantics.model.semantic_manifest_lookup import SemanticManifestLookup
from metricflow_semantics.model.semantic_model_derivation import SemanticModelDerivation
from metricflow_semantics.model.semantics.linkable_element_set import LinkableElementSet
from metricflow_semantics.naming.metric_scheme import MetricNamingScheme
from metricflow_semantics.query.group_by_item.filter_spec_resolution.filter_pattern_factory import (
Expand Down Expand Up @@ -61,6 +63,7 @@
OrderBySpec,
)
from metricflow_semantics.specs.spec_set import group_specs_by_type
from metricflow_semantics.workarounds.reference import sorted_semantic_model_references

logger = logging.getLogger(__name__)

Expand Down Expand Up @@ -520,6 +523,46 @@ def _resolve_query(self, resolver_input_for_query: ResolverInputForQuery) -> Met
queried_semantic_models=(),
)

model_reference_set = set(resolve_group_by_item_result.linkable_element_set.derived_from_semantic_models)
for filter_spec_resolution in filter_spec_lookup.spec_resolutions:
model_reference_set.update(
set(filter_spec_resolution.resolved_linkable_element_set.derived_from_semantic_models)
)

# Collect all semantic models referenced by the query.
semantic_models_in_group_by_items = set(
resolve_group_by_item_result.linkable_element_set.derived_from_semantic_models
)
semantic_models_in_filters = set(
itertools.chain.from_iterable(
filter_spec_resolution.resolved_linkable_element_set.derived_from_semantic_models
for filter_spec_resolution in filter_spec_lookup.spec_resolutions
)
)
measure_semantic_models = self._get_models_for_measures(resolution_dag)

queried_semantic_models = set.union(
semantic_models_in_group_by_items, semantic_models_in_filters, measure_semantic_models
)
queried_semantic_models -= {SemanticModelDerivation.VIRTUAL_SEMANTIC_MODEL_REFERENCE}

# Sanity check to make sure that all queried semantic models are in the model.
models_not_in_manifest = queried_semantic_models - {
semantic_model.reference for semantic_model in self._manifest_lookup.semantic_manifest.semantic_models
}

# There are no known cases where this should happen, but adding this check just in case there's a bug where
# a measure alias is used incorrectly.
if len(models_not_in_manifest) > 0:
logger.error(
mf_pformat_many(
"Semantic references that aren't in the manifest were found in the set used in "
"a query. This is a bug, and to avoid potential issues, they will be filtered out.",
{"models_not_in_manifest": models_not_in_manifest},
)
)
queried_semantic_models -= models_not_in_manifest

return MetricFlowQueryResolution(
query_spec=MetricFlowQuerySpec(
metric_specs=metric_specs,
Expand All @@ -535,7 +578,34 @@ def _resolve_query(self, resolver_input_for_query: ResolverInputForQuery) -> Met
resolution_dag=resolution_dag,
filter_spec_lookup=filter_spec_lookup,
input_to_issue_set=issue_set_mapping,
queried_semantic_models=tuple(
resolve_group_by_item_result.linkable_element_set.derived_from_semantic_models
),
queried_semantic_models=sorted_semantic_model_references(queried_semantic_models),
)

def _get_models_for_measures(self, resolution_dag: GroupByItemResolutionDag) -> Set[SemanticModelReference]:
"""Return the semantic model references for the measures used in the query."""
resolution_dag_node_set = resolution_dag.sink_node.inclusive_ancestors()

measure_references: Set[MeasureReference] = set()

# Collect measures for metrics through the associated measure nodes.
for measure_node in resolution_dag_node_set.measure_nodes:
measure_references.add(measure_node.measure_reference)

# For conversion metrics, get the measures through the metric since those measures aren't in the DAG.
for metric_node in resolution_dag_node_set.metric_nodes:
metric = self._manifest_lookup.metric_lookup.get_metric(metric_node.metric_reference)
conversion_type_params = metric.type_params.conversion_type_params
if conversion_type_params is None:
continue

measure_references.add(conversion_type_params.base_measure.measure_reference)
measure_references.add(conversion_type_params.conversion_measure.measure_reference)

model_references: Set[SemanticModelReference] = set()
for measure_reference in measure_references:
measure_semantic_model = self._manifest_lookup.semantic_model_lookup.get_semantic_model_for_measure(
measure_reference
)
model_references.add(measure_semantic_model.reference)

return model_references

0 comments on commit 7a8c634

Please sign in to comment.