diff --git a/metricflow-semantics/metricflow_semantics/model/semantics/semantic_model_lookup.py b/metricflow-semantics/metricflow_semantics/model/semantics/semantic_model_lookup.py index 057bc349a7..f122f6f5c2 100644 --- a/metricflow-semantics/metricflow_semantics/model/semantics/semantic_model_lookup.py +++ b/metricflow-semantics/metricflow_semantics/model/semantics/semantic_model_lookup.py @@ -1,7 +1,8 @@ from __future__ import annotations import logging -from typing import Dict, List, Optional, Sequence, Set +from functools import cached_property +from typing import Dict, List, Optional, Sequence, Set, Tuple from dbt_semantic_interfaces.protocols.dimension import Dimension from dbt_semantic_interfaces.protocols.entity import Entity @@ -75,6 +76,11 @@ def __init__(self, model: SemanticManifest, custom_granularities: Dict[str, Expa self._measure_lookup = MeasureLookup(sorted_semantic_models, custom_granularities) self._dimension_lookup = DimensionLookup(sorted_semantic_models) + @cached_property + def custom_granularity_names(self) -> Tuple[str, ...]: + """Returns all the custom_granularity names.""" + return tuple(self._custom_granularities.keys()) + def get_dimension_references(self) -> Sequence[DimensionReference]: """Retrieve all dimension references from the collection of semantic models.""" return tuple(self._dimension_index.keys()) @@ -224,7 +230,9 @@ def _add_semantic_model(self, semantic_model: SemanticModel) -> None: semantic_models_for_dimension = self._dimension_index.get(dim.reference, []) + [semantic_model] self._dimension_index[dim.reference] = semantic_models_for_dimension - if not StructuredLinkableSpecName.from_name(dim.name).is_element_name: + if not StructuredLinkableSpecName.from_name( + qualified_name=dim.name, custom_granularity_names=self.custom_granularity_names + ).is_element_name: # TODO: [custom granularity] change this to an assertion once we're sure there aren't exceptions logger.warning( LazyFormat( diff --git a/metricflow-semantics/metricflow_semantics/naming/linkable_spec_name.py b/metricflow-semantics/metricflow_semantics/naming/linkable_spec_name.py index aa0185115a..ecca286ccb 100644 --- a/metricflow-semantics/metricflow_semantics/naming/linkable_spec_name.py +++ b/metricflow-semantics/metricflow_semantics/naming/linkable_spec_name.py @@ -1,7 +1,8 @@ from __future__ import annotations import logging -from typing import Optional, Tuple +from functools import lru_cache +from typing import Optional, Sequence, Tuple from dbt_semantic_interfaces.type_enums.date_part import DatePart from dbt_semantic_interfaces.type_enums.time_granularity import TimeGranularity @@ -34,7 +35,8 @@ def __init__( self.date_part = date_part @staticmethod - def from_name(qualified_name: str) -> StructuredLinkableSpecName: + @lru_cache + def from_name(qualified_name: str, custom_granularity_names: Sequence[str]) -> StructuredLinkableSpecName: """Construct from a name e.g. listing__ds__month.""" name_parts = qualified_name.split(DUNDER) @@ -48,24 +50,30 @@ def from_name(qualified_name: str) -> StructuredLinkableSpecName: "Dunder syntax not supported for querying date_part. Use `group_by` object syntax instead." ) - associated_granularity = None - # TODO: [custom granularity] Update parsing to account for custom granularities + associated_granularity: Optional[str] = None for granularity in TimeGranularity: if name_parts[-1] == granularity.value: - associated_granularity = granularity + associated_granularity = granularity.value + break + + if associated_granularity is None: + for custom_grain in custom_granularity_names: + if name_parts[-1] == custom_grain: + associated_granularity = custom_grain + break # Has a time granularity if associated_granularity: # e.g. "ds__month" if len(name_parts) == 2: return StructuredLinkableSpecName( - entity_link_names=(), element_name=name_parts[0], time_granularity_name=associated_granularity.value + entity_link_names=(), element_name=name_parts[0], time_granularity_name=associated_granularity ) # e.g. "messages__ds__month" return StructuredLinkableSpecName( entity_link_names=tuple(name_parts[:-2]), element_name=name_parts[-2], - time_granularity_name=associated_granularity.value, + time_granularity_name=associated_granularity, ) # e.g. "messages__ds" diff --git a/metricflow-semantics/metricflow_semantics/specs/patterns/typed_patterns.py b/metricflow-semantics/metricflow_semantics/specs/patterns/typed_patterns.py index ffafc7561e..a5664f9f9e 100644 --- a/metricflow-semantics/metricflow_semantics/specs/patterns/typed_patterns.py +++ b/metricflow-semantics/metricflow_semantics/specs/patterns/typed_patterns.py @@ -160,7 +160,10 @@ def from_call_parameter_set( # noqa: D102 "This should have been caught by validations." ) group_by = metric_call_parameter_set.group_by[0] - structured_name = StructuredLinkableSpecName.from_name(group_by.element_name) + # custom_granularity_names is empty because we are not parsing any dimensions here with grain + structured_name = StructuredLinkableSpecName.from_name( + qualified_name=group_by.element_name, custom_granularity_names=() + ) metric_subquery_entity_links = tuple( EntityReference(entity_name) for entity_name in (structured_name.entity_link_names + (structured_name.element_name,)) diff --git a/metricflow-semantics/metricflow_semantics/specs/query_param_implementations.py b/metricflow-semantics/metricflow_semantics/specs/query_param_implementations.py index 37a4f352e3..9715ddbf00 100644 --- a/metricflow-semantics/metricflow_semantics/specs/query_param_implementations.py +++ b/metricflow-semantics/metricflow_semantics/specs/query_param_implementations.py @@ -56,8 +56,10 @@ def query_resolver_input( # noqa: D102 if self.grain is not None: fields_to_compare.append(ParameterSetField.TIME_GRANULARITY) - # TODO: assert that the name does not include a time granularity marker - name_structure = StructuredLinkableSpecName.from_name(self.name.lower()) + name_structure = StructuredLinkableSpecName.from_name( + qualified_name=self.name.lower(), + custom_granularity_names=semantic_manifest_lookup.semantic_model_lookup.custom_granularity_names, + ) return ResolverInputForGroupByItem( input_obj=self, @@ -97,7 +99,10 @@ def query_resolver_input(self, semantic_manifest_lookup: SemanticManifestLookup) TODO: Refine these query input classes so that this kind of thing is either enforced in self-documenting ways or removed from the codebase """ - name_structure = StructuredLinkableSpecName.from_name(self.name.lower()) + name_structure = StructuredLinkableSpecName.from_name( + qualified_name=self.name.lower(), + custom_granularity_names=semantic_manifest_lookup.semantic_model_lookup.custom_granularity_names, + ) return ResolverInputForGroupByItem( input_obj=self, diff --git a/metricflow-semantics/metricflow_semantics/specs/where_filter/where_filter_time_dimension.py b/metricflow-semantics/metricflow_semantics/specs/where_filter/where_filter_time_dimension.py index 8f5e41e336..607f0bf123 100644 --- a/metricflow-semantics/metricflow_semantics/specs/where_filter/where_filter_time_dimension.py +++ b/metricflow-semantics/metricflow_semantics/specs/where_filter/where_filter_time_dimension.py @@ -1,6 +1,6 @@ from __future__ import annotations -from typing import Optional, Sequence +from typing import Optional, Sequence, Tuple from dbt_semantic_interfaces.call_parameter_sets import ( TimeDimensionCallParameterSet, @@ -69,11 +69,13 @@ def __init__( # noqa spec_resolution_lookup: FilterSpecResolutionLookUp, where_filter_location: WhereFilterLocation, rendered_spec_tracker: RenderedSpecTracker, + custom_granularity_names: Tuple[str, ...], ): self._column_association_resolver = column_association_resolver self._resolved_spec_lookup = spec_resolution_lookup self._where_filter_location = where_filter_location self._rendered_spec_tracker = rendered_spec_tracker + self._custom_granularity_names = custom_granularity_names def create( self, @@ -90,7 +92,9 @@ def create( ) time_granularity_name = time_granularity_name.lower() if time_granularity_name else None - structured_name = StructuredLinkableSpecName.from_name(time_dimension_name.lower()) + structured_name = StructuredLinkableSpecName.from_name( + qualified_name=time_dimension_name.lower(), custom_granularity_names=self._custom_granularity_names + ) if ( structured_name.time_granularity_name diff --git a/metricflow-semantics/metricflow_semantics/specs/where_filter/where_filter_transform.py b/metricflow-semantics/metricflow_semantics/specs/where_filter/where_filter_transform.py index 2a9f303628..47ef06476d 100644 --- a/metricflow-semantics/metricflow_semantics/specs/where_filter/where_filter_transform.py +++ b/metricflow-semantics/metricflow_semantics/specs/where_filter/where_filter_transform.py @@ -8,6 +8,7 @@ from dbt_semantic_interfaces.implementations.filters.where_filter import PydanticWhereFilterIntersection from dbt_semantic_interfaces.protocols import WhereFilter, WhereFilterIntersection +from metricflow_semantics.model.semantics.semantic_model_lookup import SemanticModelLookup from metricflow_semantics.query.group_by_item.filter_spec_resolution.filter_location import WhereFilterLocation from metricflow_semantics.query.group_by_item.filter_spec_resolution.filter_spec_lookup import ( FilterSpecResolutionLookUp, @@ -36,9 +37,11 @@ def __init__( # noqa: D107 self, column_association_resolver: ColumnAssociationResolver, spec_resolution_lookup: FilterSpecResolutionLookUp, + semantic_model_lookup: SemanticModelLookup, ) -> None: self._column_association_resolver = column_association_resolver self._spec_resolution_lookup = spec_resolution_lookup + self._semantic_model_lookup = semantic_model_lookup def create_from_where_filter( # noqa: D102 self, @@ -73,6 +76,7 @@ def create_from_where_filter_intersection( # noqa: D102 spec_resolution_lookup=self._spec_resolution_lookup, where_filter_location=filter_location, rendered_spec_tracker=rendered_spec_tracker, + custom_granularity_names=self._semantic_model_lookup.custom_granularity_names, ) entity_factory = WhereFilterEntityFactory( column_association_resolver=self._column_association_resolver, diff --git a/metricflow-semantics/tests_metricflow_semantics/model/test_where_filter_spec.py b/metricflow-semantics/tests_metricflow_semantics/model/test_where_filter_spec.py index 7d40ad21ae..f991e91fc3 100644 --- a/metricflow-semantics/tests_metricflow_semantics/model/test_where_filter_spec.py +++ b/metricflow-semantics/tests_metricflow_semantics/model/test_where_filter_spec.py @@ -136,6 +136,7 @@ def test_dimension_in_filter( # noqa: D103 ), semantic_manifest_lookup=simple_semantic_manifest_lookup, ), + semantic_model_lookup=simple_semantic_manifest_lookup.semantic_model_lookup, ).create_from_where_filter_intersection( filter_location=EXAMPLE_FILTER_LOCATION, filter_intersection=create_where_filter_intersection("{{ Dimension('listing__country_latest') }} = 'US'"), @@ -196,6 +197,7 @@ def test_dimension_in_filter_with_grain( # noqa: D103 ), semantic_manifest_lookup=simple_semantic_manifest_lookup, ), + semantic_model_lookup=simple_semantic_manifest_lookup.semantic_model_lookup, ).create_from_where_filter_intersection( filter_location=EXAMPLE_FILTER_LOCATION, filter_intersection=create_where_filter_intersection( @@ -262,6 +264,7 @@ def test_time_dimension_in_filter( # noqa: D103 ), semantic_manifest_lookup=simple_semantic_manifest_lookup, ), + semantic_model_lookup=simple_semantic_manifest_lookup.semantic_model_lookup, ).create_from_where_filter_intersection( filter_location=EXAMPLE_FILTER_LOCATION, filter_intersection=create_where_filter_intersection( @@ -328,6 +331,7 @@ def test_time_dimension_with_grain_in_name( # noqa: D103 ), semantic_manifest_lookup=simple_semantic_manifest_lookup, ), + semantic_model_lookup=simple_semantic_manifest_lookup.semantic_model_lookup, ).create_from_where_filter_intersection( filter_location=EXAMPLE_FILTER_LOCATION, filter_intersection=create_where_filter_intersection( @@ -395,6 +399,7 @@ def test_date_part_in_filter( # noqa: D103 ), semantic_manifest_lookup=simple_semantic_manifest_lookup, ), + semantic_model_lookup=simple_semantic_manifest_lookup.semantic_model_lookup, ).create_from_where_filter_intersection( filter_location=EXAMPLE_FILTER_LOCATION, filter_intersection=create_where_filter_intersection( @@ -486,6 +491,7 @@ def resolved_spec_lookup( def test_date_part_and_grain_in_filter( # noqa: D103 column_association_resolver: ColumnAssociationResolver, resolved_spec_lookup: FilterSpecResolutionLookUp, + simple_semantic_manifest_lookup: SemanticManifestLookup, where_sql: str, ) -> None: where_filter = PydanticWhereFilter(where_sql_template=where_sql) @@ -493,6 +499,7 @@ def test_date_part_and_grain_in_filter( # noqa: D103 where_filter_spec = WhereSpecFactory( column_association_resolver=column_association_resolver, spec_resolution_lookup=resolved_spec_lookup, + semantic_model_lookup=simple_semantic_manifest_lookup.semantic_model_lookup, ).create_from_where_filter(EXAMPLE_FILTER_LOCATION, where_filter) assert where_filter_spec.where_sql == "metric_time__extract_year = '2020'" @@ -523,6 +530,7 @@ def test_date_part_and_grain_in_filter( # noqa: D103 def test_date_part_less_than_grain_in_filter( # noqa: D103 column_association_resolver: ColumnAssociationResolver, resolved_spec_lookup: FilterSpecResolutionLookUp, + simple_semantic_manifest_lookup: SemanticManifestLookup, where_sql: str, ) -> None: where_filter = PydanticWhereFilter(where_sql_template=where_sql) @@ -530,6 +538,7 @@ def test_date_part_less_than_grain_in_filter( # noqa: D103 where_filter_spec = WhereSpecFactory( column_association_resolver=column_association_resolver, spec_resolution_lookup=resolved_spec_lookup, + semantic_model_lookup=simple_semantic_manifest_lookup.semantic_model_lookup, ).create_from_where_filter(EXAMPLE_FILTER_LOCATION, where_filter) assert where_filter_spec.where_sql == "metric_time__extract_day = '2020'" @@ -587,6 +596,7 @@ def test_entity_in_filter( # noqa: D103 ), semantic_manifest_lookup=simple_semantic_manifest_lookup, ), + semantic_model_lookup=simple_semantic_manifest_lookup.semantic_model_lookup, ).create_from_where_filter(filter_location=EXAMPLE_FILTER_LOCATION, where_filter=where_filter) assert where_filter_spec.where_sql == "listing__user == 'example_user_id'" @@ -646,6 +656,7 @@ def test_metric_in_filter( # noqa: D103 ), semantic_manifest_lookup=simple_semantic_manifest_lookup, ), + semantic_model_lookup=simple_semantic_manifest_lookup.semantic_model_lookup, ).create_from_where_filter(filter_location=EXAMPLE_FILTER_LOCATION, where_filter=where_filter) assert where_filter_spec.where_sql == "listing__bookings > 2" @@ -715,6 +726,7 @@ def get_spec(dimension: str) -> WhereFilterSpec: ), non_parsable_resolutions=(), ), + semantic_model_lookup=simple_semantic_manifest_lookup.semantic_model_lookup, ).create_from_where_filter(filter_location, where_filter) time_dimension_spec = get_spec("TimeDimension('metric_time', 'week', date_part_name='year')") diff --git a/metricflow/dataflow/builder/dataflow_plan_builder.py b/metricflow/dataflow/builder/dataflow_plan_builder.py index 507a566525..00469eb755 100644 --- a/metricflow/dataflow/builder/dataflow_plan_builder.py +++ b/metricflow/dataflow/builder/dataflow_plan_builder.py @@ -165,6 +165,7 @@ def _build_query_output_node( filter_spec_factory = WhereSpecFactory( column_association_resolver=self._column_association_resolver, spec_resolution_lookup=query_spec.filter_spec_resolution_lookup, + semantic_model_lookup=self._semantic_model_lookup, ) query_level_filter_specs = tuple( @@ -409,7 +410,9 @@ def _build_conversion_metric_output_node( queried_linkable_specs=queried_linkable_specs, ) # TODO: [custom granularity] change this to an assertion once we're sure there aren't exceptions - if not StructuredLinkableSpecName.from_name(conversion_type_params.entity).is_element_name: + if not StructuredLinkableSpecName.from_name( + qualified_name=conversion_type_params.entity, custom_granularity_names=() + ).is_element_name: logger.warning( LazyFormat( lambda: f"Found additional annotations in type param entity name `{conversion_type_params.entity}`, which " @@ -806,6 +809,7 @@ def _build_plan_for_distinct_values( column_association_resolver=self._column_association_resolver, spec_resolution_lookup=query_spec.filter_spec_resolution_lookup or FilterSpecResolutionLookUp.empty_instance(), + semantic_model_lookup=self._semantic_model_lookup, ) query_level_filter_specs = filter_spec_factory.create_from_where_filter_intersection( diff --git a/metricflow/engine/metricflow_engine.py b/metricflow/engine/metricflow_engine.py index 2c4f33ce08..d5916a11d1 100644 --- a/metricflow/engine/metricflow_engine.py +++ b/metricflow/engine/metricflow_engine.py @@ -629,6 +629,7 @@ def simple_dimensions_for_metrics( # noqa: D102 else None ), ).qualified_name, + entity_links=(), description="Event time for metrics.", metadata=None, type_params=PydanticDimensionTypeParams( diff --git a/metricflow/engine/models.py b/metricflow/engine/models.py index 25961e26ae..60424ccf91 100644 --- a/metricflow/engine/models.py +++ b/metricflow/engine/models.py @@ -79,6 +79,7 @@ class Dimension: qualified_name: str description: Optional[str] type: DimensionType + entity_links: Tuple[EntityReference, ...] type_params: Optional[DimensionTypeParams] metadata: Optional[Metadata] is_partition: bool = False @@ -87,7 +88,9 @@ class Dimension: @classmethod def from_pydantic( - cls, pydantic_dimension: SemanticManifestDimension, entity_links: Tuple[EntityReference, ...] + cls, + pydantic_dimension: SemanticManifestDimension, + entity_links: Tuple[EntityReference, ...], ) -> Dimension: """Build from pydantic Dimension and entity_key.""" qualified_name = DimensionSpec(element_name=pydantic_dimension.name, entity_links=entity_links).qualified_name @@ -107,6 +110,7 @@ def from_pydantic( is_partition=pydantic_dimension.is_partition, expr=pydantic_dimension.expr, label=pydantic_dimension.label, + entity_links=entity_links, ) @property @@ -120,7 +124,9 @@ def granularity_free_qualified_name(self) -> str: Dimension set has de-duplicated TimeDimensions such that you never have more than one granularity in your set for each TimeDimension. """ - return StructuredLinkableSpecName.from_name(qualified_name=self.qualified_name).granularity_free_qualified_name + return StructuredLinkableSpecName( + entity_link_names=tuple(e.element_name for e in self.entity_links), element_name=self.name + ).qualified_name @dataclass(frozen=True)