Use LinkableElementSet instead of specs in group-by-item resolution classes.

Using `LinkableElementSet` allows retrieval of the semantic models that
are needed for computation.
plypaul committed Apr 26, 2024
1 parent 0673c22 commit 9cb1650
Showing 8 changed files with 179 additions and 103 deletions.
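
Before the per-file diffs, a minimal sketch of the idea stated in the commit message: carrying a `LinkableElementSet` rather than a bare tuple of specs lets consumers ask which semantic models the candidates were derived from. The `Toy*` classes and the example spec and model names below are hypothetical stand-ins, not the real `metricflow_semantics` classes.

from __future__ import annotations

from dataclasses import dataclass
from typing import FrozenSet, Tuple


@dataclass(frozen=True)
class ToyLinkableElementSet:
    """Toy stand-in: each group-by item remembers the semantic models it came from."""

    specs: Tuple[str, ...] = ()
    derived_from_semantic_models: FrozenSet[str] = frozenset()


@dataclass(frozen=True)
class ToyCandidateSet:
    """Holds the element set (as in the diff below) rather than a bare tuple of specs."""

    linkable_element_set: ToyLinkableElementSet

    @property
    def specs(self) -> Tuple[str, ...]:
        # Callers that only need the specs keep working through this property.
        return self.linkable_element_set.specs

    @property
    def derived_from_semantic_models(self) -> FrozenSet[str]:
        # The new capability: provenance is retrievable without a separate lookup.
        return self.linkable_element_set.derived_from_semantic_models


candidates = ToyCandidateSet(
    ToyLinkableElementSet(
        specs=("booking__ds__day", "listing__country"),
        derived_from_semantic_models=frozenset({"bookings", "listings"}),
    )
)
print(candidates.specs)
print(sorted(candidates.derived_from_semantic_models))

The real `GroupByItemCandidateSet` in the first file below follows the same shape: `specs` becomes a read-only property over `linkable_element_set`, and `derived_from_semantic_models` is delegated to it.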
(first changed file)
@@ -1,19 +1,25 @@
from __future__ import annotations

import itertools
import logging
from dataclasses import dataclass
from typing import Iterable, Tuple
from typing import Sequence, Tuple

from dbt_semantic_interfaces.references import SemanticModelReference
from typing_extensions import override

from metricflow_semantics.model.semantic_model_derivation import SemanticModelDerivation
from metricflow_semantics.model.semantics.linkable_element_set import LinkableElementSet
from metricflow_semantics.query.group_by_item.path_prefixable import PathPrefixable
from metricflow_semantics.query.group_by_item.resolution_path import MetricFlowQueryResolutionPath
from metricflow_semantics.specs.patterns.spec_pattern import SpecPattern
from metricflow_semantics.specs.spec_classes import InstanceSpecSet, LinkableInstanceSpec, LinkableSpecSet
from metricflow_semantics.specs.spec_classes import LinkableInstanceSpec

logger = logging.getLogger(__name__)


@dataclass(frozen=True)
class GroupByItemCandidateSet(PathPrefixable):
class GroupByItemCandidateSet(PathPrefixable, SemanticModelDerivation):
"""The set of candidate specs that could match a given spec pattern.
This candidate set is refined as it is passed from the root node (representing a measure) to the leaf node
@@ -29,7 +35,7 @@ class GroupByItemCandidateSet(PathPrefixable):
error messages, you start analyzing from the leaf node.
"""

specs: Tuple[LinkableInstanceSpec, ...]
linkable_element_set: LinkableElementSet
measure_paths: Tuple[MetricFlowQueryResolutionPath, ...]
path_from_leaf_node: MetricFlowQueryResolutionPath

@@ -39,26 +45,40 @@ def __post_init__(self) -> None: # noqa: D105
len(self.specs) == 0 and len(self.measure_paths) == 0
)

@property
def specs(self) -> Sequence[LinkableInstanceSpec]: # noqa: D102
return self.linkable_element_set.specs

@staticmethod
def intersection(
path_from_leaf_node: MetricFlowQueryResolutionPath, candidate_sets: Iterable[GroupByItemCandidateSet]
path_from_leaf_node: MetricFlowQueryResolutionPath, candidate_sets: Sequence[GroupByItemCandidateSet]
) -> GroupByItemCandidateSet:
"""Create a new candidate set that is the intersection of the given candidate sets.
The intersection is defined as the specs common to all candidate sets. path_from_leaf_node is used to indicate
where the new candidate set was created.
"""
specs_as_sets = tuple(set(candidate_set.specs) for candidate_set in candidate_sets)
common_specs = set.intersection(*specs_as_sets) if specs_as_sets else set()
if len(common_specs) == 0:
if len(candidate_sets) == 0:
return GroupByItemCandidateSet.empty_instance()
elif len(candidate_sets) == 1:
return GroupByItemCandidateSet(
linkable_element_set=candidate_sets[0].linkable_element_set,
measure_paths=candidate_sets[0].measure_paths,
path_from_leaf_node=path_from_leaf_node,
)
linkable_element_set_candidates = tuple(candidate_set.linkable_element_set for candidate_set in candidate_sets)
intersection_result = LinkableElementSet.intersection_by_path_key(linkable_element_set_candidates)
if intersection_result.spec_count == 0:
return GroupByItemCandidateSet.empty_instance()

measure_paths = tuple(
itertools.chain.from_iterable(candidate_set.measure_paths for candidate_set in candidate_sets)
)

return GroupByItemCandidateSet(
specs=tuple(common_specs), measure_paths=measure_paths, path_from_leaf_node=path_from_leaf_node
linkable_element_set=intersection_result,
measure_paths=measure_paths,
path_from_leaf_node=path_from_leaf_node,
)

@property
@@ -72,31 +92,34 @@ def num_candidates(self) -> int: # noqa: D102
@staticmethod
def empty_instance() -> GroupByItemCandidateSet: # noqa: D102
return GroupByItemCandidateSet(
specs=(), measure_paths=(), path_from_leaf_node=MetricFlowQueryResolutionPath.empty_instance()
linkable_element_set=LinkableElementSet(),
measure_paths=(),
path_from_leaf_node=MetricFlowQueryResolutionPath.empty_instance(),
)

@property
def spec_set(self) -> LinkableSpecSet:
"""Return the candidates as a spec set."""
return LinkableSpecSet.from_specs(self.specs)

def filter_candidates_by_pattern(
self,
spec_pattern: SpecPattern,
) -> GroupByItemCandidateSet:
"""Return a new candidate set that only contains specs that match the given pattern."""
matching_specs = tuple(InstanceSpecSet.from_specs(spec_pattern.match(self.specs)).linkable_specs)
if len(matching_specs) == 0:
filtered_element_set = self.linkable_element_set.filter_by_spec_patterns((spec_pattern,))
if filtered_element_set.spec_count == 0:
return GroupByItemCandidateSet.empty_instance()

return GroupByItemCandidateSet(
specs=matching_specs, measure_paths=self.measure_paths, path_from_leaf_node=self.path_from_leaf_node
linkable_element_set=filtered_element_set,
measure_paths=self.measure_paths,
path_from_leaf_node=self.path_from_leaf_node,
)

@override
def with_path_prefix(self, path_prefix: MetricFlowQueryResolutionPath) -> GroupByItemCandidateSet:
return GroupByItemCandidateSet(
specs=self.specs,
linkable_element_set=self.linkable_element_set,
measure_paths=tuple(path.with_path_prefix(path_prefix) for path in self.measure_paths),
path_from_leaf_node=self.path_from_leaf_node.with_path_prefix(path_prefix),
)

@property
@override
def derived_from_semantic_models(self) -> Sequence[SemanticModelReference]:
return self.linkable_element_set.derived_from_semantic_models
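
As a rough illustration of the intersection semantics in the docstring above ("the specs common to all candidate sets"), here is a self-contained toy. The dict-based representation, the `toy_intersection` helper, and the choice to union provenance across inputs are assumptions for illustration; the real work is done by `LinkableElementSet.intersection_by_path_key`, whose exact merge rules may differ.

from functools import reduce
from typing import Dict, FrozenSet, Tuple


def toy_intersection(candidate_sets: Tuple[Dict[str, FrozenSet[str]], ...]) -> Dict[str, FrozenSet[str]]:
    """Keep only specs present in every candidate set; union their source semantic models."""
    if not candidate_sets:
        return {}
    common_names = reduce(set.intersection, (set(candidate) for candidate in candidate_sets))
    return {
        name: frozenset().union(*(candidate[name] for candidate in candidate_sets))
        for name in sorted(common_names)
    }


left = {"metric_time__day": frozenset({"bookings"}), "listing__country": frozenset({"listings"})}
right = {"metric_time__day": frozenset({"web_visits"})}
print(toy_intersection((left, right)))  # only metric_time__day survives, with provenance from both sides

As in the diff, an empty intersection maps to `GroupByItemCandidateSet.empty_instance()`, and the measure paths of the inputs are concatenated onto the surviving candidates.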
(next changed file)
@@ -49,7 +49,6 @@
from metricflow_semantics.specs.patterns.base_time_grain import BaseTimeGrainPattern
from metricflow_semantics.specs.patterns.none_date_part import NoneDatePartPattern
from metricflow_semantics.specs.patterns.spec_pattern import SpecPattern
from metricflow_semantics.specs.spec_classes import InstanceSpecSet, LinkableInstanceSpec

logger = logging.getLogger(__name__)

@@ -156,13 +155,11 @@ def visit_measure_node(self, node: MeasureGroupByItemSourceNode) -> PushDownResu
"""Push the group-by-item specs that are available to the measure and match the source patterns to the child."""
with self._path_from_start_node_tracker.track_node_visit(node) as current_traversal_path:
logger.info(f"Handling {node.ui_description}")
specs_available_for_measure: Sequence[
LinkableInstanceSpec
] = self._semantic_manifest_lookup.metric_lookup.linkable_elements_for_measure(
items_available_for_measure = self._semantic_manifest_lookup.metric_lookup.linkable_elements_for_measure(
measure_reference=node.measure_reference,
with_any_of=self._with_any_property,
without_any_of=self._without_any_property,
).as_spec_set.as_tuple
)

# The following is needed to handle limitation of cumulative metrics. Filtering could be done at the measure
# node, but doing it here makes it a little easier to generate the error message.
@@ -192,32 +189,28 @@ def visit_measure_node(self, node: MeasureGroupByItemSourceNode) -> PushDownResu
else:
assert_values_exhausted(metric.type)

specs_available_for_measure_given_child_metric = specs_available_for_measure

for pattern_to_apply in patterns_to_apply:
specs_available_for_measure_given_child_metric = InstanceSpecSet.from_specs(
pattern_to_apply.match(specs_available_for_measure_given_child_metric)
).linkable_specs

matching_specs = specs_available_for_measure_given_child_metric

for source_spec_pattern in self._source_spec_patterns:
matching_specs = InstanceSpecSet.from_specs(source_spec_pattern.match(matching_specs)).linkable_specs
matching_items = items_available_for_measure.filter_by_spec_patterns(
patterns_to_apply + self._source_spec_patterns
)

logger.debug(
f"For {node.ui_description}:\n"
+ indent(
"After applying patterns:\n"
+ indent(mf_pformat(patterns_to_apply))
+ "\n"
+ "to inputs, matches are:\n"
+ indent(mf_pformat(matching_specs))
if logger.isEnabledFor(logging.DEBUG):
logger.debug(
f"For {node.ui_description}:\n"
+ indent(
"After applying patterns:\n"
+ indent(mf_pformat(patterns_to_apply))
+ "\n"
+ "to inputs, matches are:\n"
+ indent(mf_pformat(matching_items.specs))
)
)
)

# The specified patterns don't match to any of the available group-by-items that can be queried for the
# measure.
if len(matching_specs) == 0:
if matching_items.spec_count == 0:
items_available_for_measure_given_child_metric = items_available_for_measure.filter_by_spec_patterns(
patterns_to_apply
)
return PushDownResult(
candidate_set=GroupByItemCandidateSet.empty_instance(),
issue_set=MetricFlowQueryResolutionIssueSet.from_issue(
Expand All @@ -227,7 +220,7 @@ def visit_measure_node(self, node: MeasureGroupByItemSourceNode) -> PushDownResu
input_suggestions=(
tuple(
self._suggestion_generator.input_suggestions(
specs_available_for_measure_given_child_metric
items_available_for_measure_given_child_metric.specs
)
)
if self._suggestion_generator is not None
@@ -240,7 +233,7 @@ def visit_measure_node(self, node: MeasureGroupByItemSourceNode) -> PushDownResu
return PushDownResult(
candidate_set=GroupByItemCandidateSet(
measure_paths=(current_traversal_path,),
specs=tuple(matching_specs),
linkable_element_set=matching_items,
path_from_leaf_node=current_traversal_path,
),
issue_set=MetricFlowQueryResolutionIssueSet(),
@@ -314,9 +307,10 @@ def visit_metric_node(self, node: MetricGroupByItemResolutionNode) -> PushDownRe
current_traversal_path=current_traversal_path,
)
logger.info(f"Handling {node.ui_description}")
logger.debug(
"candidates from parents:\n" + indent(mf_pformat(merged_result_from_parents.candidate_set.specs))
)
if logger.isEnabledFor(logging.DEBUG):
logger.debug(
"Candidates from parents:\n" + indent(mf_pformat(merged_result_from_parents.candidate_set.specs))
)
if merged_result_from_parents.candidate_set.is_empty:
return merged_result_from_parents

@@ -340,43 +334,47 @@ def visit_metric_node(self, node: MetricGroupByItemResolutionNode) -> PushDownRe
else:
assert_values_exhausted(metric.type)

candidate_specs: Sequence[LinkableInstanceSpec] = merged_result_from_parents.candidate_set.specs
candidate_items = merged_result_from_parents.candidate_set.linkable_element_set
issue_sets_to_merge = [merged_result_from_parents.issue_set]

matched_specs = candidate_specs
for pattern_to_apply in patterns_to_apply:
matched_specs = InstanceSpecSet.from_specs(pattern_to_apply.match(matched_specs)).linkable_specs

logger.debug(
f"For {node.ui_description}:\n"
+ indent(
"After applying patterns:\n"
+ indent(mf_pformat(patterns_to_apply))
+ "\n"
+ "to inputs, outputs are:\n"
+ indent(mf_pformat(matched_specs))
matched_items = candidate_items.filter_by_spec_patterns(patterns_to_apply)

if logger.isEnabledFor(logging.DEBUG):
matched_specs = matched_items.specs
logger.debug(
f"For {node.ui_description}:\n"
+ indent(
"After applying patterns:\n"
+ indent(mf_pformat(patterns_to_apply))
+ "\n"
+ "to inputs, outputs are:\n"
+ indent(mf_pformat(matched_specs))
)
)
)

# There were candidates that were common from the ones passed from parents, but after applying the filters,
# none of the candidates were valid.
if len(matched_specs) == 0:
if matched_items.spec_count == 0:
issue_sets_to_merge.append(
MetricFlowQueryResolutionIssueSet.from_issue(
MetricExcludesDatePartIssue.from_parameters(
query_resolution_path=current_traversal_path,
candidate_specs=candidate_specs,
candidate_specs=candidate_items.specs,
parent_issues=(),
)
)
)

if matched_items.spec_count == 0:
return PushDownResult(
candidate_set=GroupByItemCandidateSet.empty_instance(),
issue_set=MetricFlowQueryResolutionIssueSet.merge_iterable(issue_sets_to_merge),
)

return PushDownResult(
candidate_set=GroupByItemCandidateSet(
specs=tuple(matched_specs),
measure_paths=(
merged_result_from_parents.candidate_set.measure_paths if len(matched_specs) > 0 else ()
),
linkable_element_set=matched_items,
measure_paths=merged_result_from_parents.candidate_set.measure_paths,
path_from_leaf_node=current_traversal_path,
),
issue_set=MetricFlowQueryResolutionIssueSet.merge_iterable(issue_sets_to_merge),
@@ -395,11 +393,10 @@ def visit_query_node(self, node: QueryGroupByItemResolutionNode) -> PushDownResu
},
current_traversal_path=current_traversal_path,
)

logger.info(f"Handling {node.ui_description}")
logger.debug(
"candidates from parents:\n" + indent(mf_pformat(merged_result_from_parents.candidate_set.specs))
)
if logger.isEnabledFor(logging.DEBUG):
logger.debug(
"Candidates from parents:\n" + indent(mf_pformat(merged_result_from_parents.candidate_set.specs))
)

return merged_result_from_parents

@@ -410,14 +407,11 @@ def visit_no_metrics_query_node(self, node: NoMetricsGroupByItemSourceNode) -> P
logger.info(f"Handling {node.ui_description}")
# This is a case for distinct dimension values from semantic models.
candidate_elements = self._semantic_manifest_lookup.metric_lookup.linkable_elements_for_no_metrics_query()
if logger.isEnabledFor(logging.DEBUG):
logger.debug(f"Candidate elements are:\n{mf_pformat(candidate_elements)}")
candidates_after_filtering = candidate_elements.filter_by_spec_patterns(self._source_spec_patterns)

matching_specs: Sequence[LinkableInstanceSpec] = tuple(
sorted(candidate_elements.as_spec_set.as_tuple, key=lambda x: x.qualified_name)
)
for pattern_to_apply in self._source_spec_patterns:
matching_specs = InstanceSpecSet.from_specs(pattern_to_apply.match(matching_specs)).linkable_specs

if len(matching_specs) == 0:
if candidates_after_filtering.spec_count == 0:
return PushDownResult(
candidate_set=GroupByItemCandidateSet.empty_instance(),
issue_set=MetricFlowQueryResolutionIssueSet.from_issue(
@@ -430,7 +424,7 @@ def visit_no_metrics_query_node(self, node: NoMetricsGroupByItemSourceNode) -> P

return PushDownResult(
candidate_set=GroupByItemCandidateSet(
specs=tuple(matching_specs),
linkable_element_set=candidates_after_filtering,
measure_paths=(current_traversal_path,),
path_from_leaf_node=current_traversal_path,
),
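
The measure-node and metric-node changes above fold the hand-rolled loops over spec patterns (repeated `InstanceSpecSet.from_specs(pattern.match(...)).linkable_specs` calls) into a single `filter_by_spec_patterns(...)` call on the element set. A toy sketch of that shape, with hypothetical pattern helpers rather than the real `SpecPattern` API:

from typing import Callable, Sequence, Tuple

ToySpec = str
ToyPattern = Callable[[Sequence[ToySpec]], Sequence[ToySpec]]


def filter_by_spec_patterns(specs: Sequence[ToySpec], patterns: Sequence[ToyPattern]) -> Tuple[ToySpec, ...]:
    """Apply each pattern in turn, narrowing the candidates behind one call site."""
    matched: Sequence[ToySpec] = specs
    for pattern in patterns:
        matched = pattern(matched)
    return tuple(matched)


def day_grain_only(specs: Sequence[ToySpec]) -> Sequence[ToySpec]:
    # Toy pattern loosely inspired by the time-grain patterns imported above.
    return [s for s in specs if s.endswith("__day")]


def no_date_part(specs: Sequence[ToySpec]) -> Sequence[ToySpec]:
    # Toy pattern loosely inspired by NoneDatePartPattern: drop date-part extractions.
    return [s for s in specs if "extract" not in s]


print(filter_by_spec_patterns(
    ("metric_time__day", "metric_time__month", "metric_time__extract_dow"),
    (day_grain_only, no_date_part),
))  # -> ('metric_time__day',)

Either way the behavior is the same: each pattern narrows the surviving candidates, and an empty result is reported as `GroupByItemCandidateSet.empty_instance()` together with an explanatory issue set.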
(next changed file)
@@ -24,6 +24,7 @@
from metricflow_semantics.specs.patterns.spec_pattern import SpecPattern

if TYPE_CHECKING:
from metricflow_semantics.model.semantics.linkable_element_set import LinkableElementSet
from metricflow_semantics.specs.spec_classes import LinkableInstanceSpec

logger = logging.getLogger(__name__)
@@ -170,13 +171,30 @@ class FilterSpecResolution:

lookup_key: ResolvedSpecLookUpKey
where_filter_intersection: WhereFilterIntersection
resolved_spec: Optional[LinkableInstanceSpec]
resolved_linkable_element_set: Optional[LinkableElementSet]
spec_pattern: SpecPattern
issue_set: MetricFlowQueryResolutionIssueSet
# Used for error messages.
filter_location_path: MetricFlowQueryResolutionPath
object_builder_str: str

def __post_init__(self) -> None: # noqa: D105
if self.resolved_linkable_element_set is not None:
assert len(self.resolved_linkable_element_set.specs) <= 1

@property
def resolved_spec(self) -> Optional[LinkableInstanceSpec]: # noqa: D102
if self.resolved_linkable_element_set is None:
return None

specs = self.resolved_linkable_element_set.specs
if len(specs) == 0:
return None
elif len(specs) == 1:
return specs[0]
else:
raise RuntimeError(f"Found {len(specs)} in {self.resolved_linkable_element_set}")


CallParameterSet = Union[
DimensionCallParameterSet, TimeDimensionCallParameterSet, EntityCallParameterSet, MetricCallParameterSet
(next changed file)
@@ -362,7 +362,7 @@ def _resolve_specs_for_where_filters(
call_parameter_set=group_by_item_in_where_filter.call_parameter_set,
),
filter_location_path=resolution_path,
resolved_spec=group_by_item_resolution.spec,
resolved_linkable_element_set=group_by_item_resolution.linkable_element_set,
where_filter_intersection=where_filter_intersection,
spec_pattern=group_by_item_in_where_filter.spec_pattern,
issue_set=group_by_item_resolution.issue_set.with_path_prefix(path_prefix),
(diffs for the remaining changed files were not loaded in this view)
