Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Update LinkableSpecName.from_name to parse custom grains #1496

Merged
merged 5 commits into from
Nov 5, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
@@ -1,7 +1,8 @@
from __future__ import annotations

import logging
from typing import Dict, List, Optional, Sequence, Set
from functools import cached_property
from typing import Dict, List, Optional, Sequence, Set, Tuple

from dbt_semantic_interfaces.protocols.dimension import Dimension
from dbt_semantic_interfaces.protocols.entity import Entity
Expand Down Expand Up @@ -75,6 +76,11 @@ def __init__(self, model: SemanticManifest, custom_granularities: Dict[str, Expa
self._measure_lookup = MeasureLookup(sorted_semantic_models, custom_granularities)
self._dimension_lookup = DimensionLookup(sorted_semantic_models)

@cached_property
def custom_granularity_names(self) -> Tuple[str, ...]:
"""Returns all the custom_granularity names."""
return tuple(self._custom_granularities.keys())

def get_dimension_references(self) -> Sequence[DimensionReference]:
"""Retrieve all dimension references from the collection of semantic models."""
return tuple(self._dimension_index.keys())
Expand Down Expand Up @@ -224,7 +230,9 @@ def _add_semantic_model(self, semantic_model: SemanticModel) -> None:
semantic_models_for_dimension = self._dimension_index.get(dim.reference, []) + [semantic_model]
self._dimension_index[dim.reference] = semantic_models_for_dimension

if not StructuredLinkableSpecName.from_name(dim.name).is_element_name:
if not StructuredLinkableSpecName.from_name(
qualified_name=dim.name, custom_granularity_names=self.custom_granularity_names
).is_element_name:
# TODO: [custom granularity] change this to an assertion once we're sure there aren't exceptions
logger.warning(
LazyFormat(
Expand Down
Original file line number Diff line number Diff line change
@@ -1,7 +1,8 @@
from __future__ import annotations

import logging
from typing import Optional, Tuple
from functools import lru_cache
from typing import Optional, Sequence, Tuple

from dbt_semantic_interfaces.type_enums.date_part import DatePart
from dbt_semantic_interfaces.type_enums.time_granularity import TimeGranularity
Expand Down Expand Up @@ -34,7 +35,8 @@ def __init__(
self.date_part = date_part

@staticmethod
def from_name(qualified_name: str) -> StructuredLinkableSpecName:
@lru_cache
def from_name(qualified_name: str, custom_granularity_names: Sequence[str]) -> StructuredLinkableSpecName:
WilliamDee marked this conversation as resolved.
Show resolved Hide resolved
"""Construct from a name e.g. listing__ds__month."""
name_parts = qualified_name.split(DUNDER)

Expand All @@ -48,24 +50,30 @@ def from_name(qualified_name: str) -> StructuredLinkableSpecName:
"Dunder syntax not supported for querying date_part. Use `group_by` object syntax instead."
)

associated_granularity = None
# TODO: [custom granularity] Update parsing to account for custom granularities
associated_granularity: Optional[str] = None
for granularity in TimeGranularity:
if name_parts[-1] == granularity.value:
associated_granularity = granularity
associated_granularity = granularity.value
break

if associated_granularity is None:
for custom_grain in custom_granularity_names:
if name_parts[-1] == custom_grain:
associated_granularity = custom_grain
break

# Has a time granularity
if associated_granularity:
# e.g. "ds__month"
if len(name_parts) == 2:
return StructuredLinkableSpecName(
entity_link_names=(), element_name=name_parts[0], time_granularity_name=associated_granularity.value
entity_link_names=(), element_name=name_parts[0], time_granularity_name=associated_granularity
)
# e.g. "messages__ds__month"
return StructuredLinkableSpecName(
entity_link_names=tuple(name_parts[:-2]),
element_name=name_parts[-2],
time_granularity_name=associated_granularity.value,
time_granularity_name=associated_granularity,
)

# e.g. "messages__ds"
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -160,7 +160,10 @@ def from_call_parameter_set( # noqa: D102
"This should have been caught by validations."
)
group_by = metric_call_parameter_set.group_by[0]
structured_name = StructuredLinkableSpecName.from_name(group_by.element_name)
# custom_granularity_names is empty because we are not parsing any dimensions here with grain
structured_name = StructuredLinkableSpecName.from_name(
qualified_name=group_by.element_name, custom_granularity_names=()
)
metric_subquery_entity_links = tuple(
EntityReference(entity_name)
for entity_name in (structured_name.entity_link_names + (structured_name.element_name,))
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -56,8 +56,10 @@ def query_resolver_input( # noqa: D102
if self.grain is not None:
fields_to_compare.append(ParameterSetField.TIME_GRANULARITY)

# TODO: assert that the name does not include a time granularity marker
name_structure = StructuredLinkableSpecName.from_name(self.name.lower())
name_structure = StructuredLinkableSpecName.from_name(
qualified_name=self.name.lower(),
custom_granularity_names=semantic_manifest_lookup.semantic_model_lookup.custom_granularity_names,
)

return ResolverInputForGroupByItem(
input_obj=self,
Expand Down Expand Up @@ -97,7 +99,10 @@ def query_resolver_input(self, semantic_manifest_lookup: SemanticManifestLookup)
TODO: Refine these query input classes so that this kind of thing is either enforced in self-documenting
ways or removed from the codebase
"""
name_structure = StructuredLinkableSpecName.from_name(self.name.lower())
name_structure = StructuredLinkableSpecName.from_name(
qualified_name=self.name.lower(),
custom_granularity_names=semantic_manifest_lookup.semantic_model_lookup.custom_granularity_names,
)

return ResolverInputForGroupByItem(
input_obj=self,
Expand Down
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
from __future__ import annotations

from typing import Optional, Sequence
from typing import Optional, Sequence, Tuple

from dbt_semantic_interfaces.call_parameter_sets import (
TimeDimensionCallParameterSet,
Expand Down Expand Up @@ -69,11 +69,13 @@ def __init__( # noqa
spec_resolution_lookup: FilterSpecResolutionLookUp,
where_filter_location: WhereFilterLocation,
rendered_spec_tracker: RenderedSpecTracker,
custom_granularity_names: Tuple[str, ...],
):
self._column_association_resolver = column_association_resolver
self._resolved_spec_lookup = spec_resolution_lookup
self._where_filter_location = where_filter_location
self._rendered_spec_tracker = rendered_spec_tracker
self._custom_granularity_names = custom_granularity_names

def create(
self,
Expand All @@ -90,7 +92,9 @@ def create(
)

time_granularity_name = time_granularity_name.lower() if time_granularity_name else None
structured_name = StructuredLinkableSpecName.from_name(time_dimension_name.lower())
structured_name = StructuredLinkableSpecName.from_name(
qualified_name=time_dimension_name.lower(), custom_granularity_names=self._custom_granularity_names
)

if (
structured_name.time_granularity_name
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
from dbt_semantic_interfaces.implementations.filters.where_filter import PydanticWhereFilterIntersection
from dbt_semantic_interfaces.protocols import WhereFilter, WhereFilterIntersection

from metricflow_semantics.model.semantics.semantic_model_lookup import SemanticModelLookup
from metricflow_semantics.query.group_by_item.filter_spec_resolution.filter_location import WhereFilterLocation
from metricflow_semantics.query.group_by_item.filter_spec_resolution.filter_spec_lookup import (
FilterSpecResolutionLookUp,
Expand Down Expand Up @@ -36,9 +37,11 @@ def __init__( # noqa: D107
self,
column_association_resolver: ColumnAssociationResolver,
spec_resolution_lookup: FilterSpecResolutionLookUp,
semantic_model_lookup: SemanticModelLookup,
) -> None:
self._column_association_resolver = column_association_resolver
self._spec_resolution_lookup = spec_resolution_lookup
self._semantic_model_lookup = semantic_model_lookup

def create_from_where_filter( # noqa: D102
self,
Expand Down Expand Up @@ -73,6 +76,7 @@ def create_from_where_filter_intersection( # noqa: D102
spec_resolution_lookup=self._spec_resolution_lookup,
where_filter_location=filter_location,
rendered_spec_tracker=rendered_spec_tracker,
custom_granularity_names=self._semantic_model_lookup.custom_granularity_names,
)
entity_factory = WhereFilterEntityFactory(
column_association_resolver=self._column_association_resolver,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -136,6 +136,7 @@ def test_dimension_in_filter( # noqa: D103
),
semantic_manifest_lookup=simple_semantic_manifest_lookup,
),
semantic_model_lookup=simple_semantic_manifest_lookup.semantic_model_lookup,
).create_from_where_filter_intersection(
filter_location=EXAMPLE_FILTER_LOCATION,
filter_intersection=create_where_filter_intersection("{{ Dimension('listing__country_latest') }} = 'US'"),
Expand Down Expand Up @@ -196,6 +197,7 @@ def test_dimension_in_filter_with_grain( # noqa: D103
),
semantic_manifest_lookup=simple_semantic_manifest_lookup,
),
semantic_model_lookup=simple_semantic_manifest_lookup.semantic_model_lookup,
).create_from_where_filter_intersection(
filter_location=EXAMPLE_FILTER_LOCATION,
filter_intersection=create_where_filter_intersection(
Expand Down Expand Up @@ -262,6 +264,7 @@ def test_time_dimension_in_filter( # noqa: D103
),
semantic_manifest_lookup=simple_semantic_manifest_lookup,
),
semantic_model_lookup=simple_semantic_manifest_lookup.semantic_model_lookup,
).create_from_where_filter_intersection(
filter_location=EXAMPLE_FILTER_LOCATION,
filter_intersection=create_where_filter_intersection(
Expand Down Expand Up @@ -328,6 +331,7 @@ def test_time_dimension_with_grain_in_name( # noqa: D103
),
semantic_manifest_lookup=simple_semantic_manifest_lookup,
),
semantic_model_lookup=simple_semantic_manifest_lookup.semantic_model_lookup,
).create_from_where_filter_intersection(
filter_location=EXAMPLE_FILTER_LOCATION,
filter_intersection=create_where_filter_intersection(
Expand Down Expand Up @@ -395,6 +399,7 @@ def test_date_part_in_filter( # noqa: D103
),
semantic_manifest_lookup=simple_semantic_manifest_lookup,
),
semantic_model_lookup=simple_semantic_manifest_lookup.semantic_model_lookup,
).create_from_where_filter_intersection(
filter_location=EXAMPLE_FILTER_LOCATION,
filter_intersection=create_where_filter_intersection(
Expand Down Expand Up @@ -486,13 +491,15 @@ def resolved_spec_lookup(
def test_date_part_and_grain_in_filter( # noqa: D103
column_association_resolver: ColumnAssociationResolver,
resolved_spec_lookup: FilterSpecResolutionLookUp,
simple_semantic_manifest_lookup: SemanticManifestLookup,
where_sql: str,
) -> None:
where_filter = PydanticWhereFilter(where_sql_template=where_sql)

where_filter_spec = WhereSpecFactory(
column_association_resolver=column_association_resolver,
spec_resolution_lookup=resolved_spec_lookup,
semantic_model_lookup=simple_semantic_manifest_lookup.semantic_model_lookup,
).create_from_where_filter(EXAMPLE_FILTER_LOCATION, where_filter)

assert where_filter_spec.where_sql == "metric_time__extract_year = '2020'"
Expand Down Expand Up @@ -523,13 +530,15 @@ def test_date_part_and_grain_in_filter( # noqa: D103
def test_date_part_less_than_grain_in_filter( # noqa: D103
column_association_resolver: ColumnAssociationResolver,
resolved_spec_lookup: FilterSpecResolutionLookUp,
simple_semantic_manifest_lookup: SemanticManifestLookup,
where_sql: str,
) -> None:
where_filter = PydanticWhereFilter(where_sql_template=where_sql)

where_filter_spec = WhereSpecFactory(
column_association_resolver=column_association_resolver,
spec_resolution_lookup=resolved_spec_lookup,
semantic_model_lookup=simple_semantic_manifest_lookup.semantic_model_lookup,
).create_from_where_filter(EXAMPLE_FILTER_LOCATION, where_filter)

assert where_filter_spec.where_sql == "metric_time__extract_day = '2020'"
Expand Down Expand Up @@ -587,6 +596,7 @@ def test_entity_in_filter( # noqa: D103
),
semantic_manifest_lookup=simple_semantic_manifest_lookup,
),
semantic_model_lookup=simple_semantic_manifest_lookup.semantic_model_lookup,
).create_from_where_filter(filter_location=EXAMPLE_FILTER_LOCATION, where_filter=where_filter)

assert where_filter_spec.where_sql == "listing__user == 'example_user_id'"
Expand Down Expand Up @@ -646,6 +656,7 @@ def test_metric_in_filter( # noqa: D103
),
semantic_manifest_lookup=simple_semantic_manifest_lookup,
),
semantic_model_lookup=simple_semantic_manifest_lookup.semantic_model_lookup,
).create_from_where_filter(filter_location=EXAMPLE_FILTER_LOCATION, where_filter=where_filter)

assert where_filter_spec.where_sql == "listing__bookings > 2"
Expand Down Expand Up @@ -715,6 +726,7 @@ def get_spec(dimension: str) -> WhereFilterSpec:
),
non_parsable_resolutions=(),
),
semantic_model_lookup=simple_semantic_manifest_lookup.semantic_model_lookup,
).create_from_where_filter(filter_location, where_filter)

time_dimension_spec = get_spec("TimeDimension('metric_time', 'week', date_part_name='year')")
Expand Down
6 changes: 5 additions & 1 deletion metricflow/dataflow/builder/dataflow_plan_builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -165,6 +165,7 @@ def _build_query_output_node(
filter_spec_factory = WhereSpecFactory(
column_association_resolver=self._column_association_resolver,
spec_resolution_lookup=query_spec.filter_spec_resolution_lookup,
semantic_model_lookup=self._semantic_model_lookup,
)

query_level_filter_specs = tuple(
Expand Down Expand Up @@ -409,7 +410,9 @@ def _build_conversion_metric_output_node(
queried_linkable_specs=queried_linkable_specs,
)
# TODO: [custom granularity] change this to an assertion once we're sure there aren't exceptions
if not StructuredLinkableSpecName.from_name(conversion_type_params.entity).is_element_name:
if not StructuredLinkableSpecName.from_name(
qualified_name=conversion_type_params.entity, custom_granularity_names=()
).is_element_name:
logger.warning(
LazyFormat(
lambda: f"Found additional annotations in type param entity name `{conversion_type_params.entity}`, which "
Expand Down Expand Up @@ -806,6 +809,7 @@ def _build_plan_for_distinct_values(
column_association_resolver=self._column_association_resolver,
spec_resolution_lookup=query_spec.filter_spec_resolution_lookup
or FilterSpecResolutionLookUp.empty_instance(),
semantic_model_lookup=self._semantic_model_lookup,
)

query_level_filter_specs = filter_spec_factory.create_from_where_filter_intersection(
Expand Down
1 change: 1 addition & 0 deletions metricflow/engine/metricflow_engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -629,6 +629,7 @@ def simple_dimensions_for_metrics( # noqa: D102
else None
),
).qualified_name,
entity_links=(),
description="Event time for metrics.",
metadata=None,
type_params=PydanticDimensionTypeParams(
Expand Down
10 changes: 8 additions & 2 deletions metricflow/engine/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,6 +79,7 @@ class Dimension:
qualified_name: str
description: Optional[str]
type: DimensionType
entity_links: Tuple[EntityReference, ...]
type_params: Optional[DimensionTypeParams]
metadata: Optional[Metadata]
is_partition: bool = False
Expand All @@ -87,7 +88,9 @@ class Dimension:

@classmethod
def from_pydantic(
cls, pydantic_dimension: SemanticManifestDimension, entity_links: Tuple[EntityReference, ...]
cls,
pydantic_dimension: SemanticManifestDimension,
entity_links: Tuple[EntityReference, ...],
) -> Dimension:
"""Build from pydantic Dimension and entity_key."""
qualified_name = DimensionSpec(element_name=pydantic_dimension.name, entity_links=entity_links).qualified_name
Expand All @@ -107,6 +110,7 @@ def from_pydantic(
is_partition=pydantic_dimension.is_partition,
expr=pydantic_dimension.expr,
label=pydantic_dimension.label,
entity_links=entity_links,
)

@property
Expand All @@ -120,7 +124,9 @@ def granularity_free_qualified_name(self) -> str:
Dimension set has de-duplicated TimeDimensions such that you never have more than one granularity
in your set for each TimeDimension.
"""
return StructuredLinkableSpecName.from_name(qualified_name=self.qualified_name).granularity_free_qualified_name
return StructuredLinkableSpecName(
entity_link_names=tuple(e.element_name for e in self.entity_links), element_name=self.name
).qualified_name


@dataclass(frozen=True)
Expand Down
Loading