Skip to content

Commit

Permalink
Resolve candidate filters for suggestions dynamically
Browse files Browse the repository at this point in the history
  • Loading branch information
courtneyholcomb committed Jul 11, 2024
1 parent 80880f2 commit 3c16e16
Show file tree
Hide file tree
Showing 5 changed files with 90 additions and 29 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -347,6 +347,7 @@ def _resolve_specs_for_where_filters(
input_str=group_by_item_in_where_filter.object_builder_str,
spec_pattern=group_by_item_in_where_filter.spec_pattern,
resolution_node=current_node,
filter_location=filter_location,
)
# The paths in the issue set are generated relative to the current node. For error messaging, it seems more
# helpful for those paths to be relative to the query. To do so, we have to add nodes from the resolution path.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,13 +20,14 @@
PushDownResult,
_PushDownGroupByItemCandidatesVisitor,
)
from metricflow_semantics.query.group_by_item.filter_spec_resolution.filter_location import WhereFilterLocation
from metricflow_semantics.query.group_by_item.resolution_dag.dag import GroupByItemResolutionDag, ResolutionDagSinkNode
from metricflow_semantics.query.group_by_item.resolution_path import MetricFlowQueryResolutionPath
from metricflow_semantics.query.issues.group_by_item_resolver.ambiguous_group_by_item import AmbiguousGroupByItemIssue
from metricflow_semantics.query.issues.issues_base import (
MetricFlowQueryResolutionIssueSet,
)
from metricflow_semantics.query.suggestion_generator import QueryItemSuggestionGenerator
from metricflow_semantics.query.suggestion_generator import QueryItemSuggestionGenerator, QueryPartForSuggestions
from metricflow_semantics.specs.patterns.minimum_time_grain import MinimumTimeGrainPattern
from metricflow_semantics.specs.patterns.no_group_by_metric import NoGroupByMetricPattern
from metricflow_semantics.specs.patterns.spec_pattern import SpecPattern
Expand Down Expand Up @@ -135,6 +136,7 @@ def resolve_matching_item_for_filters(
input_str: str,
spec_pattern: SpecPattern,
resolution_node: ResolutionDagSinkNode,
filter_location: WhereFilterLocation,
) -> GroupByItemResolution:
"""Returns the spec that matches the spec_pattern associated with the filter in the given node.
Expand All @@ -147,7 +149,9 @@ def resolve_matching_item_for_filters(
suggestion_generator = QueryItemSuggestionGenerator(
input_naming_scheme=ObjectBuilderNamingScheme(),
input_str=input_str,
candidate_filters=QueryItemSuggestionGenerator.FILTER_ITEM_CANDIDATE_FILTERS,
query_part=QueryPartForSuggestions.WHERE_FILTER,
metric_lookup=self._manifest_lookup.metric_lookup,
queried_metrics=filter_location.metric_references,
)

push_down_visitor = _PushDownGroupByItemCandidatesVisitor(
Expand Down
23 changes: 12 additions & 11 deletions metricflow-semantics/metricflow_semantics/query/query_resolver.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,9 +52,8 @@
ResolverInputForQueryLevelWhereFilterIntersection,
ResolverInputForWhereFilterIntersection,
)
from metricflow_semantics.query.suggestion_generator import QueryItemSuggestionGenerator
from metricflow_semantics.query.suggestion_generator import QueryItemSuggestionGenerator, QueryPartForSuggestions
from metricflow_semantics.query.validation_rules.query_validator import PostResolutionQueryValidator
from metricflow_semantics.specs.patterns.match_list_pattern import MatchListSpecPattern
from metricflow_semantics.specs.query_spec import MetricFlowQuerySpec
from metricflow_semantics.specs.spec_classes import (
InstanceSpec,
Expand Down Expand Up @@ -149,21 +148,20 @@ def _resolve_has_metric_or_group_by_inputs(
)
return ResolveMetricOrGroupByItemsResult(input_to_issue_set_mapping=InputToIssueSetMapping.empty_instance())

@staticmethod
def _resolve_group_by_item_input(
self,
group_by_item_input: ResolverInputForGroupByItem,
group_by_item_resolver: GroupByItemResolver,
valid_group_by_item_specs_for_querying: Sequence[LinkableInstanceSpec],
queried_metrics: Sequence[MetricReference],
) -> GroupByItemResolution:
suggestion_generator = QueryItemSuggestionGenerator(
input_naming_scheme=group_by_item_input.input_obj_naming_scheme,
input_str=str(group_by_item_input.input_obj),
candidate_filters=QueryItemSuggestionGenerator.GROUP_BY_ITEM_CANDIDATE_FILTERS
+ (
MatchListSpecPattern(
listed_specs=valid_group_by_item_specs_for_querying,
),
),
query_part=QueryPartForSuggestions.GROUP_BY,
metric_lookup=self._manifest_lookup.metric_lookup,
queried_metrics=queried_metrics,
valid_group_by_item_specs_for_querying=valid_group_by_item_specs_for_querying,
)
return group_by_item_resolver.resolve_matching_item_for_querying(
spec_pattern=group_by_item_input.spec_pattern,
Expand All @@ -190,7 +188,9 @@ def _resolve_metric_inputs(
suggestion_generator = QueryItemSuggestionGenerator(
input_naming_scheme=MetricNamingScheme(),
input_str=str(metric_input.input_obj),
candidate_filters=(),
query_part=QueryPartForSuggestions.METRIC,
metric_lookup=self._manifest_lookup.metric_lookup,
queried_metrics=tuple(metric_input.spec_pattern.metric_reference for metric_input in metric_inputs),
)
metric_suggestions = suggestion_generator.input_suggestions(candidate_specs=available_metric_specs)
input_to_issue_set_mapping_items.append(
Expand Down Expand Up @@ -238,10 +238,11 @@ def _resolve_group_by_items_result(
group_by_item_specs: List[LinkableInstanceSpec] = []
linkable_element_sets: List[LinkableElementSet] = []
for group_by_item_input in group_by_item_inputs:
resolution = MetricFlowQueryResolver._resolve_group_by_item_input(
resolution = self._resolve_group_by_item_input(
group_by_item_resolver=group_by_item_resolver,
group_by_item_input=group_by_item_input,
valid_group_by_item_specs_for_querying=valid_group_by_item_specs_for_querying,
queried_metrics=metric_references,
)
if resolution.issue_set.has_issues:
input_to_issue_set_mapping_items.append(
Expand Down
Original file line number Diff line number Diff line change
@@ -1,19 +1,34 @@
from __future__ import annotations

import logging
from typing import Sequence, Tuple
from enum import Enum
from typing import Optional, Sequence, Tuple

from dbt_semantic_interfaces.enum_extension import assert_values_exhausted
from dbt_semantic_interfaces.references import MetricReference

from metricflow_semantics.model.semantics.metric_lookup import MetricLookup
from metricflow_semantics.naming.naming_scheme import QueryItemNamingScheme
from metricflow_semantics.query.similarity import top_fuzzy_matches
from metricflow_semantics.specs.patterns.minimum_time_grain import MinimumTimeGrainPattern
from metricflow_semantics.specs.patterns.match_list_pattern import MatchListSpecPattern
from metricflow_semantics.specs.patterns.metric_time_default_granularity import MetricTimeDefaultGranularityPattern
from metricflow_semantics.specs.patterns.min_time_grain import MinimumTimeGrainPattern
from metricflow_semantics.specs.patterns.no_group_by_metric import NoGroupByMetricPattern
from metricflow_semantics.specs.patterns.none_date_part import NoneDatePartPattern
from metricflow_semantics.specs.patterns.spec_pattern import SpecPattern
from metricflow_semantics.specs.spec_classes import InstanceSpec
from metricflow_semantics.specs.spec_classes import InstanceSpec, LinkableInstanceSpec

logger = logging.getLogger(__name__)


class QueryPartForSuggestions(Enum):
    """Identifies the part of a query that suggestions are being generated for."""

    # Item referenced inside a where filter.
    WHERE_FILTER = "where_filter"
    # Item requested in the group-by portion of the query.
    GROUP_BY = "group_by"
    # Metric requested by the query.
    METRIC = "metric"


class QueryItemSuggestionGenerator:
"""Returns specs that partially match a spec pattern created from user input. Used for suggestions in errors.
Expand All @@ -22,29 +37,67 @@ class QueryItemSuggestionGenerator:
a candidate filter is not needed as any available spec at a resolution node can be used.
"""

# Adding these filters so that we don't get multiple suggestions that are similar, but with different
# grains. Some additional thought is needed to tweak this as the base grain may not be the best suggestion.
FILTER_ITEM_CANDIDATE_FILTERS: Tuple[SpecPattern, ...] = (MinimumTimeGrainPattern(), NoneDatePartPattern())
GROUP_BY_ITEM_CANDIDATE_FILTERS: Tuple[SpecPattern, ...] = (
MinimumTimeGrainPattern(),
NoneDatePartPattern(),
NoGroupByMetricPattern(),
)

def __init__(
    self,
    input_naming_scheme: QueryItemNamingScheme,
    input_str: str,
    query_part: QueryPartForSuggestions,
    metric_lookup: MetricLookup,
    queried_metrics: Sequence[MetricReference],
    valid_group_by_item_specs_for_querying: Optional[Sequence[LinkableInstanceSpec]] = None,
) -> None:
    """Store the inputs needed to build candidate filters and suggestions.

    Args:
        input_naming_scheme: Naming scheme associated with the user's input.
        input_str: The user input that suggestions are generated for.
        query_part: Which part of the query the input belongs to; this selects the candidate filters.
        metric_lookup: Passed to MetricTimeDefaultGranularityPattern when the candidate filters are built.
        queried_metrics: Passed to MetricTimeDefaultGranularityPattern when the candidate filters are built.
        valid_group_by_item_specs_for_querying: Specs passed to MatchListSpecPattern for group-by suggestions.
            Must be provided (it may be empty) when query_part is GROUP_BY.

    Raises:
        ValueError: If query_part is GROUP_BY and valid_group_by_item_specs_for_querying is None.
    """
    # Validate first so an unusable generator is never left partially initialized.
    if query_part is QueryPartForSuggestions.GROUP_BY and valid_group_by_item_specs_for_querying is None:
        raise ValueError(
            "QueryItemSuggestionGenerator requires valid_group_by_item_specs_for_querying param when used on group by items."
        )

    self._input_naming_scheme = input_naming_scheme
    self._input_str = input_str
    self._query_part = query_part
    self._metric_lookup = metric_lookup
    self._queried_metrics = queried_metrics
    self._valid_group_by_item_specs_for_querying = valid_group_by_item_specs_for_querying

@property
def candidate_filters(self) -> Tuple[SpecPattern, ...]:
    """Filters to apply before determining suggestions.

    These ensure we don't get multiple suggestions that are similar, but with different grains or date_parts.

    Returns:
        For where filters, the date-part, default-granularity and minimum-grain patterns (in that order). For
        group-by items, those patterns followed by NoGroupByMetricPattern and a MatchListSpecPattern over the
        valid group-by specs. For metrics, an empty tuple.

    Raises:
        AssertionError: If this generator is for group-by items but the valid group-by specs are None.
    """
    default_filters: Tuple[SpecPattern, ...] = (
        NoneDatePartPattern(),
        # MetricTimeDefaultGranularityPattern must come before MinimumTimeGrainPattern to ensure we don't remove the
        # default grain from the candidate set prematurely.
        MetricTimeDefaultGranularityPattern(
            metric_lookup=self._metric_lookup, queried_metrics=self._queried_metrics
        ),
        MinimumTimeGrainPattern(),
    )
    if self._query_part is QueryPartForSuggestions.WHERE_FILTER:
        return default_filters
    elif self._query_part is QueryPartForSuggestions.GROUP_BY:
        # Compare against None rather than relying on truthiness: __init__ only rejects None, so an empty
        # sequence of valid specs is legitimate input and must not trip this invariant check.
        assert self._valid_group_by_item_specs_for_querying is not None, (
            "Group by suggestions require valid_group_by_item_specs_for_querying param. "
            "This should have been validated on init."
        )
        return default_filters + (
            NoGroupByMetricPattern(),
            MatchListSpecPattern(
                listed_specs=self._valid_group_by_item_specs_for_querying,
            ),
        )
    elif self._query_part is QueryPartForSuggestions.METRIC:
        # No candidate filters are applied to metric inputs.
        return ()
    else:
        assert_values_exhausted(self._query_part)

def input_suggestions(
self,
candidate_specs: Sequence[InstanceSpec],
max_suggestions: int = 6,
) -> Sequence[str]:
"""Return the best specs that match the given pattern from candidate_specs and match the candidate_filer."""
for candidate_filter in self._candidate_filters:
"""Return the best specs that match the given pattern from candidate_specs and match the candidate_filter."""
for candidate_filter in self.candidate_filters:
candidate_specs = candidate_filter.match(candidate_specs)

# Use edit distance to figure out the closest matches, so convert the specs to strings.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
from metricflow_semantics.model.semantic_manifest_lookup import SemanticManifestLookup
from metricflow_semantics.naming.naming_scheme import QueryItemNamingScheme
from metricflow_semantics.naming.object_builder_scheme import ObjectBuilderNamingScheme
from metricflow_semantics.query.group_by_item.filter_spec_resolution.filter_location import WhereFilterLocation
from metricflow_semantics.query.group_by_item.group_by_item_resolver import GroupByItemResolver
from metricflow_semantics.query.group_by_item.resolution_dag.dag import GroupByItemResolutionDag
from metricflow_semantics.test_helpers.config_helpers import MetricFlowTestConfiguration
Expand Down Expand Up @@ -42,6 +43,7 @@ def test_ambiguous_metric_time_in_query_filter( # noqa: D103
input_str=input_str,
spec_pattern=spec_pattern,
resolution_node=resolution_dag.sink_node,
filter_location=WhereFilterLocation(metric_references=()),
)

assert_object_snapshot_equal(
Expand Down

0 comments on commit 3c16e16

Please sign in to comment.