Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Switch query parameter classes onto ExpandedTimeGranularity #1404

Merged
merged 2 commits into from
Sep 19, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
ParameterSetField,
)
from metricflow_semantics.specs.spec_set import InstanceSpecSet, InstanceSpecSetTransform, group_spec_by_type
from metricflow_semantics.time.granularity import ExpandedTimeGranularity


class DunderNamingScheme(QueryItemNamingScheme):
Expand Down Expand Up @@ -81,7 +82,8 @@ def spec_pattern(self, input_str: str, semantic_manifest_lookup: SemanticManifes
# At this point, len(input_str_parts) >= 2
for granularity in TimeGranularity:
if input_str_parts[-1] == granularity.value:
time_grain = granularity
# TODO: [custom granularity] add support for custom granularity names here
time_grain = ExpandedTimeGranularity.from_time_granularity(granularity)

# Has a time grain specified.
if time_grain is not None:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
)
from metricflow_semantics.specs.patterns.spec_pattern import SpecPattern
from metricflow_semantics.specs.patterns.typed_patterns import DimensionPattern, TimeDimensionPattern
from metricflow_semantics.time.granularity import ExpandedTimeGranularity

logger = logging.getLogger(__name__)

Expand Down Expand Up @@ -80,14 +81,18 @@ def spec_pattern(self, input_str: str, semantic_manifest_lookup: SemanticManifes
ParameterSetField.DATE_PART,
]

time_granularity: Optional[ExpandedTimeGranularity] = None
if time_dimension_call_parameter_set.time_granularity is not None:
fields_to_compare.append(ParameterSetField.TIME_GRANULARITY)
time_granularity = ExpandedTimeGranularity.from_time_granularity(
time_dimension_call_parameter_set.time_granularity
)

return TimeDimensionPattern(
EntityLinkPatternParameterSet.from_parameters(
element_name=time_dimension_call_parameter_set.time_dimension_reference.element_name,
entity_links=time_dimension_call_parameter_set.entity_path,
time_granularity=time_dimension_call_parameter_set.time_granularity,
time_granularity=time_granularity,
date_part=time_dimension_call_parameter_set.date_part,
fields_to_compare=tuple(fields_to_compare),
)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@

from typing import TYPE_CHECKING, Optional, Protocol, Union, runtime_checkable

from dbt_semantic_interfaces.type_enums import TimeGranularity
from dbt_semantic_interfaces.type_enums.date_part import DatePart

if TYPE_CHECKING:
Expand Down Expand Up @@ -52,8 +51,12 @@ def name(self) -> str:
raise NotImplementedError

@property
def grain(self) -> Optional[TimeGranularity]:
"""The time granularity."""
def grain(self) -> Optional[str]:
"""The name of the time granularity.

This may be the name of a custom granularity or the string value of an entry in the standard
TimeGranularity enum.
"""
raise NotImplementedError

@property
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,14 +6,14 @@
from typing import Any, List, Optional, Sequence, Tuple

from dbt_semantic_interfaces.references import EntityReference
from dbt_semantic_interfaces.type_enums import TimeGranularity
from dbt_semantic_interfaces.type_enums.date_part import DatePart
from more_itertools import is_sorted
from typing_extensions import override

from metricflow_semantics.specs.instance_spec import InstanceSpec, LinkableInstanceSpec
from metricflow_semantics.specs.patterns.spec_pattern import SpecPattern
from metricflow_semantics.specs.spec_set import group_specs_by_type
from metricflow_semantics.time.granularity import ExpandedTimeGranularity

logger = logging.getLogger(__name__)

Expand Down Expand Up @@ -50,7 +50,7 @@ class EntityLinkPatternParameterSet:
# The entities used for joining semantic models.
entity_links: Optional[Tuple[EntityReference, ...]] = None
# Properties of time dimensions to match.
time_granularity: Optional[TimeGranularity] = None
time_granularity: Optional[ExpandedTimeGranularity] = None
date_part: Optional[DatePart] = None
metric_subquery_entity_links: Optional[Tuple[EntityReference, ...]] = None

Expand All @@ -59,7 +59,7 @@ def from_parameters( # noqa: D102
fields_to_compare: Sequence[ParameterSetField],
element_name: Optional[str] = None,
entity_links: Optional[Sequence[EntityReference]] = None,
time_granularity: Optional[TimeGranularity] = None,
time_granularity: Optional[ExpandedTimeGranularity] = None,
date_part: Optional[DatePart] = None,
metric_subquery_entity_links: Optional[Tuple[EntityReference, ...]] = None,
) -> EntityLinkPatternParameterSet:
Expand Down Expand Up @@ -112,22 +112,6 @@ def _match_entity_links(self, candidate_specs: Sequence[LinkableInstanceSpec]) -
shortest_entity_link_length = min(len(matching_spec.entity_links) for matching_spec in matching_specs)
return tuple(spec for spec in matching_specs if len(spec.entity_links) == shortest_entity_link_length)

def _match_time_granularities(
self, candidate_specs: Sequence[LinkableInstanceSpec]
) -> Sequence[LinkableInstanceSpec]:
"""Do a partial match on time granularities.

TODO: [custom granularity] Support custom granularities properly. This requires us to allow these pattern classes
to take in ExpandedTimeGranularity types, which should be viable. Once that is done, this method can be removed.
"""
matching_specs: Sequence[LinkableInstanceSpec] = tuple(
candidate_spec
for candidate_spec in group_specs_by_type(candidate_specs).time_dimension_specs
if candidate_spec.time_granularity.base_granularity == self.parameter_set.time_granularity
)

return matching_specs

@override
def match(self, candidate_specs: Sequence[InstanceSpec]) -> Sequence[LinkableInstanceSpec]:
filtered_candidate_specs = group_specs_by_type(candidate_specs).linkable_specs
Expand All @@ -136,13 +120,10 @@ def match(self, candidate_specs: Sequence[InstanceSpec]) -> Sequence[LinkableIns
# Entity links could be a partial match, so it's handled separately.
if ParameterSetField.ENTITY_LINKS in self.parameter_set.fields_to_compare:
filtered_candidate_specs = self._match_entity_links(filtered_candidate_specs)
# Time granularities are special, so they are also handled separately.
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

So we just won't check this field anymore?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We check it via the standard comparison; see the change below to the set.difference() call.

if ParameterSetField.TIME_GRANULARITY in self.parameter_set.fields_to_compare:
filtered_candidate_specs = self._match_time_granularities(filtered_candidate_specs)

other_keys_to_check = set(
field_to_compare.value for field_to_compare in self.parameter_set.fields_to_compare
).difference({ParameterSetField.ENTITY_LINKS.value, ParameterSetField.TIME_GRANULARITY.value})
).difference({ParameterSetField.ENTITY_LINKS.value})

matching_specs: List[LinkableInstanceSpec] = []
parameter_set_values = tuple(getattr(self.parameter_set, key_to_check) for key_to_check in other_keys_to_check)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
ParameterSetField,
)
from metricflow_semantics.specs.spec_set import group_specs_by_type
from metricflow_semantics.time.granularity import ExpandedTimeGranularity


@dataclass(frozen=True)
Expand Down Expand Up @@ -78,15 +79,19 @@ def from_call_parameter_set(
ParameterSetField.DATE_PART,
]

time_granularity = None
if time_dimension_call_parameter_set.time_granularity is not None:
fields_to_compare.append(ParameterSetField.TIME_GRANULARITY)
time_granularity = ExpandedTimeGranularity.from_time_granularity(
time_dimension_call_parameter_set.time_granularity
)

return TimeDimensionPattern(
parameter_set=EntityLinkPatternParameterSet.from_parameters(
fields_to_compare=tuple(fields_to_compare),
element_name=time_dimension_call_parameter_set.time_dimension_reference.element_name,
entity_links=time_dimension_call_parameter_set.entity_path,
time_granularity=time_dimension_call_parameter_set.time_granularity,
time_granularity=time_granularity,
date_part=time_dimension_call_parameter_set.date_part,
)
)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@
EntityLinkPatternParameterSet,
ParameterSetField,
)
from metricflow_semantics.time.granularity import ExpandedTimeGranularity


@dataclass(frozen=True)
Expand All @@ -41,7 +42,7 @@ def _implements_protocol(self) -> TimeDimensionQueryParameter:
return self

name: str
grain: Optional[TimeGranularity] = None
grain: Optional[str] = None
date_part: Optional[DatePart] = None

def query_resolver_input( # noqa: D102
Expand All @@ -54,9 +55,17 @@ def query_resolver_input( # noqa: D102
ParameterSetField.ENTITY_LINKS,
ParameterSetField.DATE_PART,
]
time_granularity = None
if self.grain is not None:
fields_to_compare.append(ParameterSetField.TIME_GRANULARITY)

# TODO: [custom granularity] support custom granularity lookups
assert ExpandedTimeGranularity.is_standard_granularity_name(self.grain), (
f"We got a non-standard granularity name, `{self.grain}`, but we have not yet "
"implemented support for custom granularities!"
)
time_granularity = ExpandedTimeGranularity.from_time_granularity(TimeGranularity(self.grain))

# TODO: assert that the name does not include a time granularity marker
name_structure = StructuredLinkableSpecName.from_name(self.name.lower())

return ResolverInputForGroupByItem(
Expand All @@ -67,7 +76,7 @@ def query_resolver_input( # noqa: D102
fields_to_compare=tuple(fields_to_compare),
element_name=name_structure.element_name,
entity_links=tuple(EntityReference(link_name) for link_name in name_structure.entity_link_names),
time_granularity=self.grain,
time_granularity=time_granularity,
date_part=self.date_part,
)
),
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -560,7 +560,7 @@ def test_date_part_parsing(
query_parser.parse_and_validate_query(
metric_names=["revenue"],
group_by=(
TimeDimensionParameter(name="metric_time", grain=TimeGranularity.YEAR, date_part=DatePart.MONTH),
TimeDimensionParameter(name="metric_time", grain=TimeGranularity.YEAR.value, date_part=DatePart.MONTH),
),
)

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -145,7 +145,7 @@ def test_time_dimension_match(specs: Sequence[LinkableInstanceSpec]) -> None: #
EntityLinkPatternParameterSet.from_parameters(
element_name=METRIC_TIME_ELEMENT_NAME,
entity_links=(),
time_granularity=TimeGranularity.WEEK,
time_granularity=ExpandedTimeGranularity.from_time_granularity(TimeGranularity.WEEK),
date_part=None,
fields_to_compare=(
ParameterSetField.ELEMENT_NAME,
Expand Down
2 changes: 0 additions & 2 deletions tests_metricflow/integration/test_configured_cases.py
Original file line number Diff line number Diff line change
Expand Up @@ -261,8 +261,6 @@ def test_case(
if date_part or grain:
if date_part:
kwargs["date_part"] = DatePart(date_part)
if grain:
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

How did this not break anything? 🤔

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is the fix for a change to TimeDimensionParameter above.

We no longer need to convert the string input to a TimeGranularity before passing it to the TimeDimensionParameter initializer, because the initializer now accepts string values. Without this change we were passing an incorrectly typed value for that argument, and weird runtime failures were cropping up.

kwargs["grain"] = TimeGranularity(grain)
group_by.append(TimeDimensionParameter(**kwargs))
else:
group_by.append(DimensionOrEntityParameter(**kwargs))
Expand Down
Loading