From 42c064e6265a52d69a48bdc815901e4912e87205 Mon Sep 17 00:00:00 2001 From: tlento Date: Sun, 8 Sep 2024 20:56:20 -0700 Subject: [PATCH] Switch query parameter classes onto ExpandedTimeGranularity In preparation for supporting user-provided custom granularity names at query time we need to update our query parameter processing layer to use ExpandedTimeGranularities internally while accepting string names from the original user input. This change makes that update, although it is necessarily incomplete. The TimeDimensionCallParameterSet, in particular, still uses an enumeration-typed property for representing the user-requested grain, but this interface is imported from dbt-semantic-interfaces and will be updated in a later series of changes. Follow-ups inside MetricFlow will enable lookups against custom granularity values as we solidify our test layouts. --- .../naming/dunder_scheme.py | 4 ++- .../naming/object_builder_scheme.py | 7 ++++- .../protocols/query_parameter.py | 9 ++++--- .../specs/patterns/entity_link_pattern.py | 27 +++---------------- .../specs/patterns/typed_patterns.py | 7 ++++- .../specs/query_param_implementations.py | 15 ++++++++--- .../query/test_query_parser.py | 2 +- .../patterns/test_entity_link_pattern.py | 2 +- .../integration/test_configured_cases.py | 2 -- 9 files changed, 39 insertions(+), 36 deletions(-) diff --git a/metricflow-semantics/metricflow_semantics/naming/dunder_scheme.py b/metricflow-semantics/metricflow_semantics/naming/dunder_scheme.py index 0f0f239050..1c9e961a81 100644 --- a/metricflow-semantics/metricflow_semantics/naming/dunder_scheme.py +++ b/metricflow-semantics/metricflow_semantics/naming/dunder_scheme.py @@ -18,6 +18,7 @@ ParameterSetField, ) from metricflow_semantics.specs.spec_set import InstanceSpecSet, InstanceSpecSetTransform, group_spec_by_type +from metricflow_semantics.time.granularity import ExpandedTimeGranularity class DunderNamingScheme(QueryItemNamingScheme): @@ -81,7 +82,8 @@ def spec_pattern(self, input_str: str, semantic_manifest_lookup: SemanticManifes # At this point, len(input_str_parts) >= 2 for granularity in TimeGranularity: if input_str_parts[-1] == granularity.value: - time_grain = granularity + # TODO: [custom granularity] add support for custom granularity names here + time_grain = ExpandedTimeGranularity.from_time_granularity(granularity) # Has a time grain specified. if time_grain is not None: diff --git a/metricflow-semantics/metricflow_semantics/naming/object_builder_scheme.py b/metricflow-semantics/metricflow_semantics/naming/object_builder_scheme.py index c425332afc..83e49e45b4 100644 --- a/metricflow-semantics/metricflow_semantics/naming/object_builder_scheme.py +++ b/metricflow-semantics/metricflow_semantics/naming/object_builder_scheme.py @@ -23,6 +23,7 @@ ) from metricflow_semantics.specs.patterns.spec_pattern import SpecPattern from metricflow_semantics.specs.patterns.typed_patterns import DimensionPattern, TimeDimensionPattern +from metricflow_semantics.time.granularity import ExpandedTimeGranularity logger = logging.getLogger(__name__) @@ -80,14 +81,18 @@ def spec_pattern(self, input_str: str, semantic_manifest_lookup: SemanticManifes ParameterSetField.DATE_PART, ] + time_granularity = None if time_dimension_call_parameter_set.time_granularity is not None: fields_to_compare.append(ParameterSetField.TIME_GRANULARITY) + time_granularity = ExpandedTimeGranularity.from_time_granularity( + time_dimension_call_parameter_set.time_granularity + ) return TimeDimensionPattern( EntityLinkPatternParameterSet.from_parameters( element_name=time_dimension_call_parameter_set.time_dimension_reference.element_name, entity_links=time_dimension_call_parameter_set.entity_path, - time_granularity=time_dimension_call_parameter_set.time_granularity, + time_granularity=time_granularity, date_part=time_dimension_call_parameter_set.date_part, fields_to_compare=tuple(fields_to_compare), ) diff --git a/metricflow-semantics/metricflow_semantics/protocols/query_parameter.py b/metricflow-semantics/metricflow_semantics/protocols/query_parameter.py index 2b6969a252..e223966618 100644 --- a/metricflow-semantics/metricflow_semantics/protocols/query_parameter.py +++ b/metricflow-semantics/metricflow_semantics/protocols/query_parameter.py @@ -2,7 +2,6 @@ from typing import TYPE_CHECKING, Optional, Protocol, Union, runtime_checkable -from dbt_semantic_interfaces.type_enums import TimeGranularity from dbt_semantic_interfaces.type_enums.date_part import DatePart if TYPE_CHECKING: @@ -52,8 +51,12 @@ def name(self) -> str: raise NotImplementedError @property - def grain(self) -> Optional[TimeGranularity]: - """The time granularity.""" + def grain(self) -> Optional[str]: + """The name of the time granularity. + + This may be the name of a custom granularity or the string value of an entry in the standard + TimeGranularity enum. + """ raise NotImplementedError @property diff --git a/metricflow-semantics/metricflow_semantics/specs/patterns/entity_link_pattern.py b/metricflow-semantics/metricflow_semantics/specs/patterns/entity_link_pattern.py index ce7f970ad9..1f7b793148 100644 --- a/metricflow-semantics/metricflow_semantics/specs/patterns/entity_link_pattern.py +++ b/metricflow-semantics/metricflow_semantics/specs/patterns/entity_link_pattern.py @@ -6,7 +6,6 @@ from typing import Any, List, Optional, Sequence, Tuple from dbt_semantic_interfaces.references import EntityReference -from dbt_semantic_interfaces.type_enums import TimeGranularity from dbt_semantic_interfaces.type_enums.date_part import DatePart from more_itertools import is_sorted from typing_extensions import override @@ -14,6 +13,7 @@ from metricflow_semantics.specs.instance_spec import InstanceSpec, LinkableInstanceSpec from metricflow_semantics.specs.patterns.spec_pattern import SpecPattern from metricflow_semantics.specs.spec_set import group_specs_by_type +from metricflow_semantics.time.granularity import ExpandedTimeGranularity logger = logging.getLogger(__name__) @@ -50,7 +50,7 @@ class EntityLinkPatternParameterSet: # The entities used for joining semantic models. entity_links: Optional[Tuple[EntityReference, ...]] = None # Properties of time dimensions to match. - time_granularity: Optional[TimeGranularity] = None + time_granularity: Optional[ExpandedTimeGranularity] = None date_part: Optional[DatePart] = None metric_subquery_entity_links: Optional[Tuple[EntityReference, ...]] = None @@ -59,7 +59,7 @@ def from_parameters( # noqa: D102 fields_to_compare: Sequence[ParameterSetField], element_name: Optional[str] = None, entity_links: Optional[Sequence[EntityReference]] = None, - time_granularity: Optional[TimeGranularity] = None, + time_granularity: Optional[ExpandedTimeGranularity] = None, date_part: Optional[DatePart] = None, metric_subquery_entity_links: Optional[Tuple[EntityReference, ...]] = None, ) -> EntityLinkPatternParameterSet: @@ -112,22 +112,6 @@ def _match_entity_links(self, candidate_specs: Sequence[LinkableInstanceSpec]) - shortest_entity_link_length = min(len(matching_spec.entity_links) for matching_spec in matching_specs) return tuple(spec for spec in matching_specs if len(spec.entity_links) == shortest_entity_link_length) - def _match_time_granularities( - self, candidate_specs: Sequence[LinkableInstanceSpec] - ) -> Sequence[LinkableInstanceSpec]: - """Do a partial match on time granularities. - - TODO: [custom granularity] Support custom granularities properly. This requires us to allow these pattern classes - to take in ExpandedTimeGranularity types, which should be viable. Once that is done, this method can be removed. - """ - matching_specs: Sequence[LinkableInstanceSpec] = tuple( - candidate_spec - for candidate_spec in group_specs_by_type(candidate_specs).time_dimension_specs - if candidate_spec.time_granularity.base_granularity == self.parameter_set.time_granularity - ) - - return matching_specs - @override def match(self, candidate_specs: Sequence[InstanceSpec]) -> Sequence[LinkableInstanceSpec]: filtered_candidate_specs = group_specs_by_type(candidate_specs).linkable_specs @@ -136,13 +120,10 @@ def match(self, candidate_specs: Sequence[InstanceSpec]) -> Sequence[LinkableIns # Entity links could be a partial match, so it's handled separately. if ParameterSetField.ENTITY_LINKS in self.parameter_set.fields_to_compare: filtered_candidate_specs = self._match_entity_links(filtered_candidate_specs) - # Time granularities are special, so they are also handled separately. - if ParameterSetField.TIME_GRANULARITY in self.parameter_set.fields_to_compare: - filtered_candidate_specs = self._match_time_granularities(filtered_candidate_specs) other_keys_to_check = set( field_to_compare.value for field_to_compare in self.parameter_set.fields_to_compare - ).difference({ParameterSetField.ENTITY_LINKS.value, ParameterSetField.TIME_GRANULARITY.value}) + ).difference({ParameterSetField.ENTITY_LINKS.value}) matching_specs: List[LinkableInstanceSpec] = [] parameter_set_values = tuple(getattr(self.parameter_set, key_to_check) for key_to_check in other_keys_to_check) diff --git a/metricflow-semantics/metricflow_semantics/specs/patterns/typed_patterns.py b/metricflow-semantics/metricflow_semantics/specs/patterns/typed_patterns.py index 22d46f3170..8ddce95d32 100644 --- a/metricflow-semantics/metricflow_semantics/specs/patterns/typed_patterns.py +++ b/metricflow-semantics/metricflow_semantics/specs/patterns/typed_patterns.py @@ -20,6 +20,7 @@ ParameterSetField, ) from metricflow_semantics.specs.spec_set import group_specs_by_type +from metricflow_semantics.time.granularity import ExpandedTimeGranularity @dataclass(frozen=True) @@ -78,15 +79,19 @@ def from_call_parameter_set( ParameterSetField.DATE_PART, ] + time_granularity = None if time_dimension_call_parameter_set.time_granularity is not None: fields_to_compare.append(ParameterSetField.TIME_GRANULARITY) + time_granularity = ExpandedTimeGranularity.from_time_granularity( + time_dimension_call_parameter_set.time_granularity + ) return TimeDimensionPattern( parameter_set=EntityLinkPatternParameterSet.from_parameters( fields_to_compare=tuple(fields_to_compare), element_name=time_dimension_call_parameter_set.time_dimension_reference.element_name, entity_links=time_dimension_call_parameter_set.entity_path, - time_granularity=time_dimension_call_parameter_set.time_granularity, + time_granularity=time_granularity, date_part=time_dimension_call_parameter_set.date_part, ) ) diff --git a/metricflow-semantics/metricflow_semantics/specs/query_param_implementations.py b/metricflow-semantics/metricflow_semantics/specs/query_param_implementations.py index 791b9cea34..8bab25e0a2 100644 --- a/metricflow-semantics/metricflow_semantics/specs/query_param_implementations.py +++ b/metricflow-semantics/metricflow_semantics/specs/query_param_implementations.py @@ -31,6 +31,7 @@ EntityLinkPatternParameterSet, ParameterSetField, ) +from metricflow_semantics.time.granularity import ExpandedTimeGranularity @dataclass(frozen=True) @@ -41,7 +42,7 @@ def _implements_protocol(self) -> TimeDimensionQueryParameter: return self name: str - grain: Optional[TimeGranularity] = None + grain: Optional[str] = None date_part: Optional[DatePart] = None def query_resolver_input( # noqa: D102 @@ -54,9 +55,17 @@ def query_resolver_input( # noqa: D102 ParameterSetField.ENTITY_LINKS, ParameterSetField.DATE_PART, ] + time_granularity = None if self.grain is not None: fields_to_compare.append(ParameterSetField.TIME_GRANULARITY) - + # TODO: [custom granularity] support custom granularity lookups + assert ExpandedTimeGranularity.is_standard_granularity_name(self.grain), ( + f"We got a non-standard granularity name, `{self.grain}`, but we have not yet " + "implemented support for custom granularities!" + ) + time_granularity = ExpandedTimeGranularity.from_time_granularity(TimeGranularity(self.grain)) + + # TODO: assert that the name does not include a time granularity marker name_structure = StructuredLinkableSpecName.from_name(self.name.lower()) return ResolverInputForGroupByItem( @@ -67,7 +76,7 @@ def query_resolver_input( # noqa: D102 fields_to_compare=tuple(fields_to_compare), element_name=name_structure.element_name, entity_links=tuple(EntityReference(link_name) for link_name in name_structure.entity_link_names), - time_granularity=self.grain, + time_granularity=time_granularity, date_part=self.date_part, ) ), diff --git a/metricflow-semantics/tests_metricflow_semantics/query/test_query_parser.py b/metricflow-semantics/tests_metricflow_semantics/query/test_query_parser.py index 0f30cb2537..50c9cc6853 100644 --- a/metricflow-semantics/tests_metricflow_semantics/query/test_query_parser.py +++ b/metricflow-semantics/tests_metricflow_semantics/query/test_query_parser.py @@ -560,7 +560,7 @@ def test_date_part_parsing( query_parser.parse_and_validate_query( metric_names=["revenue"], group_by=( - TimeDimensionParameter(name="metric_time", grain=TimeGranularity.YEAR, date_part=DatePart.MONTH), + TimeDimensionParameter(name="metric_time", grain=TimeGranularity.YEAR.value, date_part=DatePart.MONTH), ), ) diff --git a/metricflow-semantics/tests_metricflow_semantics/specs/patterns/test_entity_link_pattern.py b/metricflow-semantics/tests_metricflow_semantics/specs/patterns/test_entity_link_pattern.py index a917a575b8..b95298e6f6 100644 --- a/metricflow-semantics/tests_metricflow_semantics/specs/patterns/test_entity_link_pattern.py +++ b/metricflow-semantics/tests_metricflow_semantics/specs/patterns/test_entity_link_pattern.py @@ -145,7 +145,7 @@ def test_time_dimension_match(specs: Sequence[LinkableInstanceSpec]) -> None: # EntityLinkPatternParameterSet.from_parameters( element_name=METRIC_TIME_ELEMENT_NAME, entity_links=(), - time_granularity=TimeGranularity.WEEK, + time_granularity=ExpandedTimeGranularity.from_time_granularity(TimeGranularity.WEEK), date_part=None, fields_to_compare=( ParameterSetField.ELEMENT_NAME, diff --git a/tests_metricflow/integration/test_configured_cases.py b/tests_metricflow/integration/test_configured_cases.py index 0fdce55958..b192c7acae 100644 --- a/tests_metricflow/integration/test_configured_cases.py +++ b/tests_metricflow/integration/test_configured_cases.py @@ -261,8 +261,6 @@ def test_case( if date_part or grain: if date_part: kwargs["date_part"] = DatePart(date_part) - if grain: - kwargs["grain"] = TimeGranularity(grain) group_by.append(TimeDimensionParameter(**kwargs)) else: group_by.append(DimensionOrEntityParameter(**kwargs))