Skip to content

Commit

Permalink
Make pattern classes dataclasses (#1346)
Browse files Browse the repository at this point in the history
This PR makes pattern classes dataclass for easier comparison for
equality. The comparison is helpful when memoizing / caching function
calls.
  • Loading branch information
plypaul authored Sep 6, 2024
1 parent aa0fb35 commit 27218b2
Show file tree
Hide file tree
Showing 7 changed files with 22 additions and 9 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -157,7 +157,7 @@ def _resolve_group_by_item_input(
input_str=str(group_by_item_input.input_obj),
candidate_filters=QueryItemSuggestionGenerator.GROUP_BY_ITEM_CANDIDATE_FILTERS
+ (
MatchListSpecPattern(
MatchListSpecPattern.create(
listed_specs=valid_group_by_item_specs_for_querying,
),
),
Expand Down
Original file line number Diff line number Diff line change
@@ -1,22 +1,27 @@
from __future__ import annotations

from typing import Sequence
from dataclasses import dataclass
from typing import Sequence, Tuple

from typing_extensions import override

from metricflow_semantics.specs.instance_spec import InstanceSpec
from metricflow_semantics.specs.patterns.spec_pattern import SpecPattern


@dataclass(frozen=True)
class MatchListSpecPattern(SpecPattern):
"""A spec pattern that matches based on a configured list of specs.
This is useful for filtering possible group-by-items to ones valid for a query.
"""

def __init__(self, listed_specs: Sequence[InstanceSpec]) -> None: # noqa: D107
self._listed_specs = set(listed_specs)
listed_specs: Tuple[InstanceSpec, ...]

@staticmethod
def create(listed_specs: Sequence[InstanceSpec]) -> MatchListSpecPattern: # noqa: D102
return MatchListSpecPattern(tuple(listed_specs))

@override
def match(self, candidate_specs: Sequence[InstanceSpec]) -> Sequence[InstanceSpec]:
return tuple(spec for spec in candidate_specs if spec in self._listed_specs)
return tuple(spec for spec in candidate_specs if spec in self.listed_specs)
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from __future__ import annotations

from collections import defaultdict
from dataclasses import dataclass
from typing import Dict, Optional, Sequence, Set, Tuple

from dbt_semantic_interfaces.type_enums import TimeGranularity
Expand All @@ -18,6 +19,7 @@
from metricflow_semantics.time.granularity import ExpandedTimeGranularity


@dataclass(frozen=True)
class MetricTimeDefaultGranularityPattern(SpecPattern):
"""A pattern that matches metric_time specs if they have the default granularity for the requested metrics.
Expand All @@ -42,21 +44,19 @@ class MetricTimeDefaultGranularityPattern(SpecPattern):
]
"""

def __init__(self, max_metric_default_time_granularity: Optional[TimeGranularity]) -> None: # noqa: D107
self._max_metric_default_time_granularity = max_metric_default_time_granularity
max_metric_default_time_granularity: Optional[TimeGranularity]

@override
def match(self, candidate_specs: Sequence[InstanceSpec]) -> Sequence[InstanceSpec]:
spec_set = group_specs_by_type(candidate_specs)

# If there are no metric_time specs in the query, skip this filter.
if not spec_set.metric_time_specs:
return candidate_specs

# If there are metrics in the query, use max metric default. For no-metric queries, use standard default.
# TODO: [custom granularity] allow custom granularities to be used as defaults if appropriate
default_granularity = ExpandedTimeGranularity.from_time_granularity(
self._max_metric_default_time_granularity or DEFAULT_TIME_GRANULARITY
self.max_metric_default_time_granularity or DEFAULT_TIME_GRANULARITY
)

spec_key_to_grains: Dict[TimeDimensionSpecComparisonKey, Set[ExpandedTimeGranularity]] = defaultdict(set)
Expand Down
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from __future__ import annotations

from dataclasses import dataclass
from typing import Sequence

from dbt_semantic_interfaces.naming.keywords import METRIC_TIME_ELEMENT_NAME
Expand All @@ -11,6 +12,7 @@
from metricflow_semantics.specs.time_dimension_spec import TimeDimensionSpec


@dataclass(frozen=True)
class MetricTimePattern(SpecPattern):
"""Pattern that matches to only metric_time specs.
Expand Down
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from __future__ import annotations

from collections import defaultdict
from dataclasses import dataclass
from typing import Dict, List, Sequence, Set

from dbt_semantic_interfaces.type_enums import TimeGranularity
Expand All @@ -17,6 +18,7 @@
from metricflow_semantics.time.granularity import ExpandedTimeGranularity


@dataclass(frozen=True)
class MinimumTimeGrainPattern(SpecPattern):
"""A pattern that matches linkable specs, but for time dimension specs, only the one with the finest base grain.
Expand Down
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from __future__ import annotations

from dataclasses import dataclass
from typing import List, Sequence

from typing_extensions import override
Expand All @@ -9,6 +10,7 @@
from metricflow_semantics.specs.spec_set import group_specs_by_type


@dataclass(frozen=True)
class NoGroupByMetricPattern(SpecPattern):
"""Matches to linkable specs, but only if they're not group by metrics.
Expand Down
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from __future__ import annotations

from dataclasses import dataclass
from typing import List, Sequence

from typing_extensions import override
Expand All @@ -9,6 +10,7 @@
from metricflow_semantics.specs.spec_set import group_specs_by_type


@dataclass(frozen=True)
class NoneDatePartPattern(SpecPattern):
"""Matches to linkable specs, but for time dimension specs, only matches to ones without date_part.
Expand Down

0 comments on commit 27218b2

Please sign in to comment.