Skip to content

Commit

Permalink
Switch query parameter classes onto ExpandedTimeGranularity
Browse files Browse the repository at this point in the history
In preparation for supporting user-provided custom granularity names
at query time, we need to update our query parameter processing
layer to use ExpandedTimeGranularity instances internally while accepting
string names from the original user input.

This change makes that update, although it is necessarily incomplete.
The TimeDimensionCallParameterSet, in particular, still uses an
enumeration-typed property for representing the user-requested grain,
but this interface is imported from dbt-semantic-interfaces and will
be updated in a later series of changes.

Follow-ups inside MetricFlow will enable lookups against custom granularity
values as we solidify our test layouts.
  • Loading branch information
tlento committed Sep 18, 2024
1 parent 9bd4b43 commit 387d2ec
Show file tree
Hide file tree
Showing 9 changed files with 39 additions and 36 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
ParameterSetField,
)
from metricflow_semantics.specs.spec_set import InstanceSpecSet, InstanceSpecSetTransform, group_spec_by_type
from metricflow_semantics.time.granularity import ExpandedTimeGranularity


class DunderNamingScheme(QueryItemNamingScheme):
Expand Down Expand Up @@ -81,7 +82,8 @@ def spec_pattern(self, input_str: str, semantic_manifest_lookup: SemanticManifes
# At this point, len(input_str_parts) >= 2
for granularity in TimeGranularity:
if input_str_parts[-1] == granularity.value:
time_grain = granularity
# TODO: [custom granularity] add support for custom granularity names here
time_grain = ExpandedTimeGranularity.from_time_granularity(granularity)

# Has a time grain specified.
if time_grain is not None:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
)
from metricflow_semantics.specs.patterns.spec_pattern import SpecPattern
from metricflow_semantics.specs.patterns.typed_patterns import DimensionPattern, TimeDimensionPattern
from metricflow_semantics.time.granularity import ExpandedTimeGranularity

logger = logging.getLogger(__name__)

Expand Down Expand Up @@ -80,14 +81,18 @@ def spec_pattern(self, input_str: str, semantic_manifest_lookup: SemanticManifes
ParameterSetField.DATE_PART,
]

time_granularity = None
if time_dimension_call_parameter_set.time_granularity is not None:
fields_to_compare.append(ParameterSetField.TIME_GRANULARITY)
time_granularity = ExpandedTimeGranularity.from_time_granularity(
time_dimension_call_parameter_set.time_granularity
)

return TimeDimensionPattern(
EntityLinkPatternParameterSet.from_parameters(
element_name=time_dimension_call_parameter_set.time_dimension_reference.element_name,
entity_links=time_dimension_call_parameter_set.entity_path,
time_granularity=time_dimension_call_parameter_set.time_granularity,
time_granularity=time_granularity,
date_part=time_dimension_call_parameter_set.date_part,
fields_to_compare=tuple(fields_to_compare),
)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@

from typing import TYPE_CHECKING, Optional, Protocol, Union, runtime_checkable

from dbt_semantic_interfaces.type_enums import TimeGranularity
from dbt_semantic_interfaces.type_enums.date_part import DatePart

if TYPE_CHECKING:
Expand Down Expand Up @@ -52,8 +51,12 @@ def name(self) -> str:
raise NotImplementedError

@property
def grain(self) -> Optional[TimeGranularity]:
"""The time granularity."""
def grain(self) -> Optional[str]:
"""The name of the time granularity.
This may be the name of a custom granularity or the string value of an entry in the standard
TimeGranularity enum.
"""
raise NotImplementedError

@property
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,14 +6,14 @@
from typing import Any, List, Optional, Sequence, Tuple

from dbt_semantic_interfaces.references import EntityReference
from dbt_semantic_interfaces.type_enums import TimeGranularity
from dbt_semantic_interfaces.type_enums.date_part import DatePart
from more_itertools import is_sorted
from typing_extensions import override

from metricflow_semantics.specs.instance_spec import InstanceSpec, LinkableInstanceSpec
from metricflow_semantics.specs.patterns.spec_pattern import SpecPattern
from metricflow_semantics.specs.spec_set import group_specs_by_type
from metricflow_semantics.time.granularity import ExpandedTimeGranularity

logger = logging.getLogger(__name__)

Expand Down Expand Up @@ -50,7 +50,7 @@ class EntityLinkPatternParameterSet:
# The entities used for joining semantic models.
entity_links: Optional[Tuple[EntityReference, ...]] = None
# Properties of time dimensions to match.
time_granularity: Optional[TimeGranularity] = None
time_granularity: Optional[ExpandedTimeGranularity] = None
date_part: Optional[DatePart] = None
metric_subquery_entity_links: Optional[Tuple[EntityReference, ...]] = None

Expand All @@ -59,7 +59,7 @@ def from_parameters( # noqa: D102
fields_to_compare: Sequence[ParameterSetField],
element_name: Optional[str] = None,
entity_links: Optional[Sequence[EntityReference]] = None,
time_granularity: Optional[TimeGranularity] = None,
time_granularity: Optional[ExpandedTimeGranularity] = None,
date_part: Optional[DatePart] = None,
metric_subquery_entity_links: Optional[Tuple[EntityReference, ...]] = None,
) -> EntityLinkPatternParameterSet:
Expand Down Expand Up @@ -112,22 +112,6 @@ def _match_entity_links(self, candidate_specs: Sequence[LinkableInstanceSpec]) -
shortest_entity_link_length = min(len(matching_spec.entity_links) for matching_spec in matching_specs)
return tuple(spec for spec in matching_specs if len(spec.entity_links) == shortest_entity_link_length)

def _match_time_granularities(
self, candidate_specs: Sequence[LinkableInstanceSpec]
) -> Sequence[LinkableInstanceSpec]:
"""Do a partial match on time granularities.
TODO: [custom granularity] Support custom granularities properly. This requires us to allow these pattern classes
to take in ExpandedTimeGranularity types, which should be viable. Once that is done, this method can be removed.
"""
matching_specs: Sequence[LinkableInstanceSpec] = tuple(
candidate_spec
for candidate_spec in group_specs_by_type(candidate_specs).time_dimension_specs
if candidate_spec.time_granularity.base_granularity == self.parameter_set.time_granularity
)

return matching_specs

@override
def match(self, candidate_specs: Sequence[InstanceSpec]) -> Sequence[LinkableInstanceSpec]:
filtered_candidate_specs = group_specs_by_type(candidate_specs).linkable_specs
Expand All @@ -136,13 +120,10 @@ def match(self, candidate_specs: Sequence[InstanceSpec]) -> Sequence[LinkableIns
# Entity links could be a partial match, so it's handled separately.
if ParameterSetField.ENTITY_LINKS in self.parameter_set.fields_to_compare:
filtered_candidate_specs = self._match_entity_links(filtered_candidate_specs)
# Time granularities are special, so they are also handled separately.
if ParameterSetField.TIME_GRANULARITY in self.parameter_set.fields_to_compare:
filtered_candidate_specs = self._match_time_granularities(filtered_candidate_specs)

other_keys_to_check = set(
field_to_compare.value for field_to_compare in self.parameter_set.fields_to_compare
).difference({ParameterSetField.ENTITY_LINKS.value, ParameterSetField.TIME_GRANULARITY.value})
).difference({ParameterSetField.ENTITY_LINKS.value})

matching_specs: List[LinkableInstanceSpec] = []
parameter_set_values = tuple(getattr(self.parameter_set, key_to_check) for key_to_check in other_keys_to_check)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
ParameterSetField,
)
from metricflow_semantics.specs.spec_set import group_specs_by_type
from metricflow_semantics.time.granularity import ExpandedTimeGranularity


@dataclass(frozen=True)
Expand Down Expand Up @@ -78,15 +79,19 @@ def from_call_parameter_set(
ParameterSetField.DATE_PART,
]

time_granularity = None
if time_dimension_call_parameter_set.time_granularity is not None:
fields_to_compare.append(ParameterSetField.TIME_GRANULARITY)
time_granularity = ExpandedTimeGranularity.from_time_granularity(
time_dimension_call_parameter_set.time_granularity
)

return TimeDimensionPattern(
parameter_set=EntityLinkPatternParameterSet.from_parameters(
fields_to_compare=tuple(fields_to_compare),
element_name=time_dimension_call_parameter_set.time_dimension_reference.element_name,
entity_links=time_dimension_call_parameter_set.entity_path,
time_granularity=time_dimension_call_parameter_set.time_granularity,
time_granularity=time_granularity,
date_part=time_dimension_call_parameter_set.date_part,
)
)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@
EntityLinkPatternParameterSet,
ParameterSetField,
)
from metricflow_semantics.time.granularity import ExpandedTimeGranularity


@dataclass(frozen=True)
Expand All @@ -41,7 +42,7 @@ def _implements_protocol(self) -> TimeDimensionQueryParameter:
return self

name: str
grain: Optional[TimeGranularity] = None
grain: Optional[str] = None
date_part: Optional[DatePart] = None

def query_resolver_input( # noqa: D102
Expand All @@ -54,9 +55,17 @@ def query_resolver_input( # noqa: D102
ParameterSetField.ENTITY_LINKS,
ParameterSetField.DATE_PART,
]
time_granularity = None
if self.grain is not None:
fields_to_compare.append(ParameterSetField.TIME_GRANULARITY)

# TODO: [custom granularity] support custom granularity lookups
assert ExpandedTimeGranularity.is_standard_granularity_name(self.grain), (
f"We got a non-standard granularity name, `{self.grain}`, but we have not yet "
"implemented support for custom granularities!"
)
time_granularity = ExpandedTimeGranularity.from_time_granularity(TimeGranularity(self.grain))

# TODO: assert that the name does not include a time granularity marker
name_structure = StructuredLinkableSpecName.from_name(self.name.lower())

return ResolverInputForGroupByItem(
Expand All @@ -67,7 +76,7 @@ def query_resolver_input( # noqa: D102
fields_to_compare=tuple(fields_to_compare),
element_name=name_structure.element_name,
entity_links=tuple(EntityReference(link_name) for link_name in name_structure.entity_link_names),
time_granularity=self.grain,
time_granularity=time_granularity,
date_part=self.date_part,
)
),
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -560,7 +560,7 @@ def test_date_part_parsing(
query_parser.parse_and_validate_query(
metric_names=["revenue"],
group_by=(
TimeDimensionParameter(name="metric_time", grain=TimeGranularity.YEAR, date_part=DatePart.MONTH),
TimeDimensionParameter(name="metric_time", grain=TimeGranularity.YEAR.value, date_part=DatePart.MONTH),
),
)

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -145,7 +145,7 @@ def test_time_dimension_match(specs: Sequence[LinkableInstanceSpec]) -> None: #
EntityLinkPatternParameterSet.from_parameters(
element_name=METRIC_TIME_ELEMENT_NAME,
entity_links=(),
time_granularity=TimeGranularity.WEEK,
time_granularity=ExpandedTimeGranularity.from_time_granularity(TimeGranularity.WEEK),
date_part=None,
fields_to_compare=(
ParameterSetField.ELEMENT_NAME,
Expand Down
2 changes: 0 additions & 2 deletions tests_metricflow/integration/test_configured_cases.py
Original file line number Diff line number Diff line change
Expand Up @@ -261,8 +261,6 @@ def test_case(
if date_part or grain:
if date_part:
kwargs["date_part"] = DatePart(date_part)
if grain:
kwargs["grain"] = TimeGranularity(grain)
group_by.append(TimeDimensionParameter(**kwargs))
else:
group_by.append(DimensionOrEntityParameter(**kwargs))
Expand Down

0 comments on commit 387d2ec

Please sign in to comment.