Skip to content

Commit

Permalink
Update LinkableSpecName.from_name to parse custom grains (#1496)
Browse files Browse the repository at this point in the history
## Context

With the addition of custom granularities, when we parse dundered
names via `StructuredLinkableSpecName.from_name`, we need to know
whether the grain provided is a custom grain. This means we need to pass
the custom granularity names through to this static method.

Resolves SL-2971
  • Loading branch information
WilliamDee authored Nov 5, 2024
1 parent 44ba096 commit 289b278
Show file tree
Hide file tree
Showing 10 changed files with 73 additions and 18 deletions.
Original file line number Diff line number Diff line change
@@ -1,7 +1,8 @@
from __future__ import annotations

import logging
from typing import Dict, List, Optional, Sequence, Set
from functools import cached_property
from typing import Dict, List, Optional, Sequence, Set, Tuple

from dbt_semantic_interfaces.protocols.dimension import Dimension
from dbt_semantic_interfaces.protocols.entity import Entity
Expand Down Expand Up @@ -75,6 +76,11 @@ def __init__(self, model: SemanticManifest, custom_granularities: Dict[str, Expa
self._measure_lookup = MeasureLookup(sorted_semantic_models, custom_granularities)
self._dimension_lookup = DimensionLookup(sorted_semantic_models)

@cached_property
def custom_granularity_names(self) -> Tuple[str, ...]:
"""Returns all the custom_granularity names."""
return tuple(self._custom_granularities.keys())

def get_dimension_references(self) -> Sequence[DimensionReference]:
"""Retrieve all dimension references from the collection of semantic models."""
return tuple(self._dimension_index.keys())
Expand Down Expand Up @@ -224,7 +230,9 @@ def _add_semantic_model(self, semantic_model: SemanticModel) -> None:
semantic_models_for_dimension = self._dimension_index.get(dim.reference, []) + [semantic_model]
self._dimension_index[dim.reference] = semantic_models_for_dimension

if not StructuredLinkableSpecName.from_name(dim.name).is_element_name:
if not StructuredLinkableSpecName.from_name(
qualified_name=dim.name, custom_granularity_names=self.custom_granularity_names
).is_element_name:
# TODO: [custom granularity] change this to an assertion once we're sure there aren't exceptions
logger.warning(
LazyFormat(
Expand Down
Original file line number Diff line number Diff line change
@@ -1,7 +1,8 @@
from __future__ import annotations

import logging
from typing import Optional, Tuple
from functools import lru_cache
from typing import Optional, Sequence, Tuple

from dbt_semantic_interfaces.type_enums.date_part import DatePart
from dbt_semantic_interfaces.type_enums.time_granularity import TimeGranularity
Expand Down Expand Up @@ -34,7 +35,8 @@ def __init__(
self.date_part = date_part

@staticmethod
def from_name(qualified_name: str) -> StructuredLinkableSpecName:
@lru_cache
def from_name(qualified_name: str, custom_granularity_names: Sequence[str]) -> StructuredLinkableSpecName:
"""Construct from a name e.g. listing__ds__month."""
name_parts = qualified_name.split(DUNDER)

Expand All @@ -48,24 +50,30 @@ def from_name(qualified_name: str) -> StructuredLinkableSpecName:
"Dunder syntax not supported for querying date_part. Use `group_by` object syntax instead."
)

associated_granularity = None
# TODO: [custom granularity] Update parsing to account for custom granularities
associated_granularity: Optional[str] = None
for granularity in TimeGranularity:
if name_parts[-1] == granularity.value:
associated_granularity = granularity
associated_granularity = granularity.value
break

if associated_granularity is None:
for custom_grain in custom_granularity_names:
if name_parts[-1] == custom_grain:
associated_granularity = custom_grain
break

# Has a time granularity
if associated_granularity:
# e.g. "ds__month"
if len(name_parts) == 2:
return StructuredLinkableSpecName(
entity_link_names=(), element_name=name_parts[0], time_granularity_name=associated_granularity.value
entity_link_names=(), element_name=name_parts[0], time_granularity_name=associated_granularity
)
# e.g. "messages__ds__month"
return StructuredLinkableSpecName(
entity_link_names=tuple(name_parts[:-2]),
element_name=name_parts[-2],
time_granularity_name=associated_granularity.value,
time_granularity_name=associated_granularity,
)

# e.g. "messages__ds"
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -160,7 +160,10 @@ def from_call_parameter_set( # noqa: D102
"This should have been caught by validations."
)
group_by = metric_call_parameter_set.group_by[0]
structured_name = StructuredLinkableSpecName.from_name(group_by.element_name)
# custom_granularity_names is empty because we are not parsing any dimensions here with grain
structured_name = StructuredLinkableSpecName.from_name(
qualified_name=group_by.element_name, custom_granularity_names=()
)
metric_subquery_entity_links = tuple(
EntityReference(entity_name)
for entity_name in (structured_name.entity_link_names + (structured_name.element_name,))
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -56,8 +56,10 @@ def query_resolver_input( # noqa: D102
if self.grain is not None:
fields_to_compare.append(ParameterSetField.TIME_GRANULARITY)

# TODO: assert that the name does not include a time granularity marker
name_structure = StructuredLinkableSpecName.from_name(self.name.lower())
name_structure = StructuredLinkableSpecName.from_name(
qualified_name=self.name.lower(),
custom_granularity_names=semantic_manifest_lookup.semantic_model_lookup.custom_granularity_names,
)

return ResolverInputForGroupByItem(
input_obj=self,
Expand Down Expand Up @@ -97,7 +99,10 @@ def query_resolver_input(self, semantic_manifest_lookup: SemanticManifestLookup)
TODO: Refine these query input classes so that this kind of thing is either enforced in self-documenting
ways or removed from the codebase
"""
name_structure = StructuredLinkableSpecName.from_name(self.name.lower())
name_structure = StructuredLinkableSpecName.from_name(
qualified_name=self.name.lower(),
custom_granularity_names=semantic_manifest_lookup.semantic_model_lookup.custom_granularity_names,
)

return ResolverInputForGroupByItem(
input_obj=self,
Expand Down
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
from __future__ import annotations

from typing import Optional, Sequence
from typing import Optional, Sequence, Tuple

from dbt_semantic_interfaces.call_parameter_sets import (
TimeDimensionCallParameterSet,
Expand Down Expand Up @@ -69,11 +69,13 @@ def __init__( # noqa
spec_resolution_lookup: FilterSpecResolutionLookUp,
where_filter_location: WhereFilterLocation,
rendered_spec_tracker: RenderedSpecTracker,
custom_granularity_names: Tuple[str, ...],
):
self._column_association_resolver = column_association_resolver
self._resolved_spec_lookup = spec_resolution_lookup
self._where_filter_location = where_filter_location
self._rendered_spec_tracker = rendered_spec_tracker
self._custom_granularity_names = custom_granularity_names

def create(
self,
Expand All @@ -90,7 +92,9 @@ def create(
)

time_granularity_name = time_granularity_name.lower() if time_granularity_name else None
structured_name = StructuredLinkableSpecName.from_name(time_dimension_name.lower())
structured_name = StructuredLinkableSpecName.from_name(
qualified_name=time_dimension_name.lower(), custom_granularity_names=self._custom_granularity_names
)

if (
structured_name.time_granularity_name
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
from dbt_semantic_interfaces.implementations.filters.where_filter import PydanticWhereFilterIntersection
from dbt_semantic_interfaces.protocols import WhereFilter, WhereFilterIntersection

from metricflow_semantics.model.semantics.semantic_model_lookup import SemanticModelLookup
from metricflow_semantics.query.group_by_item.filter_spec_resolution.filter_location import WhereFilterLocation
from metricflow_semantics.query.group_by_item.filter_spec_resolution.filter_spec_lookup import (
FilterSpecResolutionLookUp,
Expand Down Expand Up @@ -36,9 +37,11 @@ def __init__( # noqa: D107
self,
column_association_resolver: ColumnAssociationResolver,
spec_resolution_lookup: FilterSpecResolutionLookUp,
semantic_model_lookup: SemanticModelLookup,
) -> None:
self._column_association_resolver = column_association_resolver
self._spec_resolution_lookup = spec_resolution_lookup
self._semantic_model_lookup = semantic_model_lookup

def create_from_where_filter( # noqa: D102
self,
Expand Down Expand Up @@ -73,6 +76,7 @@ def create_from_where_filter_intersection( # noqa: D102
spec_resolution_lookup=self._spec_resolution_lookup,
where_filter_location=filter_location,
rendered_spec_tracker=rendered_spec_tracker,
custom_granularity_names=self._semantic_model_lookup.custom_granularity_names,
)
entity_factory = WhereFilterEntityFactory(
column_association_resolver=self._column_association_resolver,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -136,6 +136,7 @@ def test_dimension_in_filter( # noqa: D103
),
semantic_manifest_lookup=simple_semantic_manifest_lookup,
),
semantic_model_lookup=simple_semantic_manifest_lookup.semantic_model_lookup,
).create_from_where_filter_intersection(
filter_location=EXAMPLE_FILTER_LOCATION,
filter_intersection=create_where_filter_intersection("{{ Dimension('listing__country_latest') }} = 'US'"),
Expand Down Expand Up @@ -196,6 +197,7 @@ def test_dimension_in_filter_with_grain( # noqa: D103
),
semantic_manifest_lookup=simple_semantic_manifest_lookup,
),
semantic_model_lookup=simple_semantic_manifest_lookup.semantic_model_lookup,
).create_from_where_filter_intersection(
filter_location=EXAMPLE_FILTER_LOCATION,
filter_intersection=create_where_filter_intersection(
Expand Down Expand Up @@ -262,6 +264,7 @@ def test_time_dimension_in_filter( # noqa: D103
),
semantic_manifest_lookup=simple_semantic_manifest_lookup,
),
semantic_model_lookup=simple_semantic_manifest_lookup.semantic_model_lookup,
).create_from_where_filter_intersection(
filter_location=EXAMPLE_FILTER_LOCATION,
filter_intersection=create_where_filter_intersection(
Expand Down Expand Up @@ -328,6 +331,7 @@ def test_time_dimension_with_grain_in_name( # noqa: D103
),
semantic_manifest_lookup=simple_semantic_manifest_lookup,
),
semantic_model_lookup=simple_semantic_manifest_lookup.semantic_model_lookup,
).create_from_where_filter_intersection(
filter_location=EXAMPLE_FILTER_LOCATION,
filter_intersection=create_where_filter_intersection(
Expand Down Expand Up @@ -395,6 +399,7 @@ def test_date_part_in_filter( # noqa: D103
),
semantic_manifest_lookup=simple_semantic_manifest_lookup,
),
semantic_model_lookup=simple_semantic_manifest_lookup.semantic_model_lookup,
).create_from_where_filter_intersection(
filter_location=EXAMPLE_FILTER_LOCATION,
filter_intersection=create_where_filter_intersection(
Expand Down Expand Up @@ -486,13 +491,15 @@ def resolved_spec_lookup(
def test_date_part_and_grain_in_filter( # noqa: D103
column_association_resolver: ColumnAssociationResolver,
resolved_spec_lookup: FilterSpecResolutionLookUp,
simple_semantic_manifest_lookup: SemanticManifestLookup,
where_sql: str,
) -> None:
where_filter = PydanticWhereFilter(where_sql_template=where_sql)

where_filter_spec = WhereSpecFactory(
column_association_resolver=column_association_resolver,
spec_resolution_lookup=resolved_spec_lookup,
semantic_model_lookup=simple_semantic_manifest_lookup.semantic_model_lookup,
).create_from_where_filter(EXAMPLE_FILTER_LOCATION, where_filter)

assert where_filter_spec.where_sql == "metric_time__extract_year = '2020'"
Expand Down Expand Up @@ -523,13 +530,15 @@ def test_date_part_and_grain_in_filter( # noqa: D103
def test_date_part_less_than_grain_in_filter( # noqa: D103
column_association_resolver: ColumnAssociationResolver,
resolved_spec_lookup: FilterSpecResolutionLookUp,
simple_semantic_manifest_lookup: SemanticManifestLookup,
where_sql: str,
) -> None:
where_filter = PydanticWhereFilter(where_sql_template=where_sql)

where_filter_spec = WhereSpecFactory(
column_association_resolver=column_association_resolver,
spec_resolution_lookup=resolved_spec_lookup,
semantic_model_lookup=simple_semantic_manifest_lookup.semantic_model_lookup,
).create_from_where_filter(EXAMPLE_FILTER_LOCATION, where_filter)

assert where_filter_spec.where_sql == "metric_time__extract_day = '2020'"
Expand Down Expand Up @@ -587,6 +596,7 @@ def test_entity_in_filter( # noqa: D103
),
semantic_manifest_lookup=simple_semantic_manifest_lookup,
),
semantic_model_lookup=simple_semantic_manifest_lookup.semantic_model_lookup,
).create_from_where_filter(filter_location=EXAMPLE_FILTER_LOCATION, where_filter=where_filter)

assert where_filter_spec.where_sql == "listing__user == 'example_user_id'"
Expand Down Expand Up @@ -646,6 +656,7 @@ def test_metric_in_filter( # noqa: D103
),
semantic_manifest_lookup=simple_semantic_manifest_lookup,
),
semantic_model_lookup=simple_semantic_manifest_lookup.semantic_model_lookup,
).create_from_where_filter(filter_location=EXAMPLE_FILTER_LOCATION, where_filter=where_filter)

assert where_filter_spec.where_sql == "listing__bookings > 2"
Expand Down Expand Up @@ -715,6 +726,7 @@ def get_spec(dimension: str) -> WhereFilterSpec:
),
non_parsable_resolutions=(),
),
semantic_model_lookup=simple_semantic_manifest_lookup.semantic_model_lookup,
).create_from_where_filter(filter_location, where_filter)

time_dimension_spec = get_spec("TimeDimension('metric_time', 'week', date_part_name='year')")
Expand Down
6 changes: 5 additions & 1 deletion metricflow/dataflow/builder/dataflow_plan_builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -165,6 +165,7 @@ def _build_query_output_node(
filter_spec_factory = WhereSpecFactory(
column_association_resolver=self._column_association_resolver,
spec_resolution_lookup=query_spec.filter_spec_resolution_lookup,
semantic_model_lookup=self._semantic_model_lookup,
)

query_level_filter_specs = tuple(
Expand Down Expand Up @@ -409,7 +410,9 @@ def _build_conversion_metric_output_node(
queried_linkable_specs=queried_linkable_specs,
)
# TODO: [custom granularity] change this to an assertion once we're sure there aren't exceptions
if not StructuredLinkableSpecName.from_name(conversion_type_params.entity).is_element_name:
if not StructuredLinkableSpecName.from_name(
qualified_name=conversion_type_params.entity, custom_granularity_names=()
).is_element_name:
logger.warning(
LazyFormat(
lambda: f"Found additional annotations in type param entity name `{conversion_type_params.entity}`, which "
Expand Down Expand Up @@ -806,6 +809,7 @@ def _build_plan_for_distinct_values(
column_association_resolver=self._column_association_resolver,
spec_resolution_lookup=query_spec.filter_spec_resolution_lookup
or FilterSpecResolutionLookUp.empty_instance(),
semantic_model_lookup=self._semantic_model_lookup,
)

query_level_filter_specs = filter_spec_factory.create_from_where_filter_intersection(
Expand Down
1 change: 1 addition & 0 deletions metricflow/engine/metricflow_engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -629,6 +629,7 @@ def simple_dimensions_for_metrics( # noqa: D102
else None
),
).qualified_name,
entity_links=(),
description="Event time for metrics.",
metadata=None,
type_params=PydanticDimensionTypeParams(
Expand Down
10 changes: 8 additions & 2 deletions metricflow/engine/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,6 +79,7 @@ class Dimension:
qualified_name: str
description: Optional[str]
type: DimensionType
entity_links: Tuple[EntityReference, ...]
type_params: Optional[DimensionTypeParams]
metadata: Optional[Metadata]
is_partition: bool = False
Expand All @@ -87,7 +88,9 @@ class Dimension:

@classmethod
def from_pydantic(
cls, pydantic_dimension: SemanticManifestDimension, entity_links: Tuple[EntityReference, ...]
cls,
pydantic_dimension: SemanticManifestDimension,
entity_links: Tuple[EntityReference, ...],
) -> Dimension:
"""Build from pydantic Dimension and entity_key."""
qualified_name = DimensionSpec(element_name=pydantic_dimension.name, entity_links=entity_links).qualified_name
Expand All @@ -107,6 +110,7 @@ def from_pydantic(
is_partition=pydantic_dimension.is_partition,
expr=pydantic_dimension.expr,
label=pydantic_dimension.label,
entity_links=entity_links,
)

@property
Expand All @@ -120,7 +124,9 @@ def granularity_free_qualified_name(self) -> str:
Dimension set has de-duplicated TimeDimensions such that you never have more than one granularity
in your set for each TimeDimension.
"""
return StructuredLinkableSpecName.from_name(qualified_name=self.qualified_name).granularity_free_qualified_name
return StructuredLinkableSpecName(
entity_link_names=tuple(e.element_name for e in self.entity_links), element_name=self.name
).qualified_name


@dataclass(frozen=True)
Expand Down

0 comments on commit 289b278

Please sign in to comment.