diff --git a/metricflow-semantics/metricflow_semantics/model/semantics/linkable_element.py b/metricflow-semantics/metricflow_semantics/model/semantics/linkable_element.py index b4047702e2..33933b220d 100644 --- a/metricflow-semantics/metricflow_semantics/model/semantics/linkable_element.py +++ b/metricflow-semantics/metricflow_semantics/model/semantics/linkable_element.py @@ -20,13 +20,6 @@ from metricflow_semantics.model.linkable_element_property import LinkableElementProperty from metricflow_semantics.model.semantic_model_derivation import SemanticModelDerivation -from metricflow_semantics.specs.spec_classes import ( - DimensionSpec, - EntitySpec, - GroupByMetricSpec, - LinkableInstanceSpec, - TimeDimensionSpec, -) logger = logging.getLogger(__name__) @@ -85,42 +78,6 @@ def __post_init__(self) -> None: else: assert_values_exhausted(element_type) - @property - def spec(self) -> LinkableInstanceSpec: - """The corresponding spec object for this path key.""" - if self.element_type is LinkableElementType.DIMENSION: - return DimensionSpec( - element_name=self.element_name, - entity_links=self.entity_links, - ) - elif self.element_type is LinkableElementType.TIME_DIMENSION: - assert ( - self.time_granularity is not None - ), f"{self.time_granularity=} should not be None as per check in dataclass validation" - return TimeDimensionSpec( - element_name=self.element_name, - entity_links=self.entity_links, - time_granularity=self.time_granularity, - date_part=self.date_part, - ) - elif self.element_type is LinkableElementType.ENTITY: - return EntitySpec( - element_name=self.element_name, - entity_links=self.entity_links, - ) - elif self.element_type is LinkableElementType.METRIC: - assert self.metric_subquery_entity_links is not None, ( - "ElementPathKeys for metrics must have non-null metric_subquery_entity_links." - "This should have been checked in post_init." - ) - return GroupByMetricSpec( - element_name=self.element_name, - entity_links=self.entity_links, - metric_subquery_entity_links=self.metric_subquery_entity_links, - ) - else: - assert_values_exhausted(self.element_type) - @dataclass(frozen=True) class SemanticModelJoinPathElement: diff --git a/metricflow-semantics/metricflow_semantics/model/semantics/linkable_element_set.py b/metricflow-semantics/metricflow_semantics/model/semantics/linkable_element_set.py index 7789a71bcc..78487fbeeb 100644 --- a/metricflow-semantics/metricflow_semantics/model/semantics/linkable_element_set.py +++ b/metricflow-semantics/metricflow_semantics/model/semantics/linkable_element_set.py @@ -4,6 +4,7 @@ from dataclasses import dataclass, field from typing import Dict, FrozenSet, List, Sequence, Set, Tuple +from dbt_semantic_interfaces.enum_extension import assert_values_exhausted from dbt_semantic_interfaces.references import SemanticModelReference from typing_extensions import override @@ -18,8 +19,12 @@ ) from metricflow_semantics.specs.patterns.spec_pattern import SpecPattern from metricflow_semantics.specs.spec_classes import ( + DimensionSpec, + EntitySpec, + GroupByMetricSpec, InstanceSpec, LinkableInstanceSpec, + TimeDimensionSpec, ) @@ -308,14 +313,50 @@ def spec_count(self) -> int: @property def specs(self) -> Sequence[LinkableInstanceSpec]: """Converts the items in a `LinkableElementSet` to their corresponding spec objects.""" - return tuple( - path_key.spec - for path_key in ( - tuple(self.path_key_to_linkable_dimensions.keys()) - + tuple(self.path_key_to_linkable_entities.keys()) - + tuple(self.path_key_to_linkable_metrics.keys()) + specs: List[LinkableInstanceSpec] = [] + + for path_key in ( + tuple(self.path_key_to_linkable_dimensions.keys()) + + tuple(self.path_key_to_linkable_entities.keys()) + + tuple(self.path_key_to_linkable_metrics.keys()) + ): + specs.append(LinkableElementSet._path_key_to_spec(path_key)) + + return specs + + @staticmethod + def _path_key_to_spec(path_key: ElementPathKey) -> LinkableInstanceSpec: + """Helper method to convert ElementPathKey instances to LinkableInstanceSpecs. + + This is currently used in the context of switching between ElementPathKeys and LinkableInstanceSpecs + within a LinkableElementSet, so we leave it here for now. + """ + if path_key.element_type is LinkableElementType.DIMENSION: + return DimensionSpec( + element_name=path_key.element_name, + entity_links=path_key.entity_links, ) - ) + elif path_key.element_type is LinkableElementType.TIME_DIMENSION: + assert path_key.time_granularity is not None + return TimeDimensionSpec( + element_name=path_key.element_name, + entity_links=path_key.entity_links, + time_granularity=path_key.time_granularity, + date_part=path_key.date_part, + ) + elif path_key.element_type is LinkableElementType.ENTITY: + return EntitySpec( + element_name=path_key.element_name, + entity_links=path_key.entity_links, + ) + elif path_key.element_type is LinkableElementType.METRIC: + return GroupByMetricSpec( + element_name=path_key.element_name, + entity_links=path_key.entity_links, + metric_subquery_entity_links=path_key.metric_subquery_entity_links, + ) + else: + assert_values_exhausted(path_key.element_type) def filter_by_spec_patterns(self, spec_patterns: Sequence[SpecPattern]) -> LinkableElementSet: """Filter the elements in the set by the given spec patters. @@ -335,15 +376,15 @@ def filter_by_spec_patterns(self, spec_patterns: Sequence[SpecPattern]) -> Linka path_key_to_linkable_metrics: Dict[ElementPathKey, Tuple[LinkableMetric, ...]] = {} for path_key, linkable_dimensions in self.path_key_to_linkable_dimensions.items(): - if path_key.spec in specs_to_include: + if LinkableElementSet._path_key_to_spec(path_key) in specs_to_include: path_key_to_linkable_dimensions[path_key] = linkable_dimensions for path_key, linkable_entities in self.path_key_to_linkable_entities.items(): - if path_key.spec in specs_to_include: + if LinkableElementSet._path_key_to_spec(path_key) in specs_to_include: path_key_to_linkable_entities[path_key] = linkable_entities for path_key, linkable_metrics in self.path_key_to_linkable_metrics.items(): - if path_key.spec in specs_to_include: + if LinkableElementSet._path_key_to_spec(path_key) in specs_to_include: path_key_to_linkable_metrics[path_key] = linkable_metrics return LinkableElementSet( diff --git a/tests_metricflow/dataflow/builder/test_dataflow_plan_builder.py b/tests_metricflow/dataflow/builder/test_dataflow_plan_builder.py index 8821bc9a1c..9dca206645 100644 --- a/tests_metricflow/dataflow/builder/test_dataflow_plan_builder.py +++ b/tests_metricflow/dataflow/builder/test_dataflow_plan_builder.py @@ -13,6 +13,7 @@ from metricflow_semantics.errors.error_classes import UnableToSatisfyQueryError from metricflow_semantics.filters.time_constraint import TimeRangeConstraint from metricflow_semantics.model.linkable_element_property import LinkableElementProperty +from metricflow_semantics.model.semantics.linkable_element_set import LinkableElementSet from metricflow_semantics.query.query_parser import MetricFlowQueryParser from metricflow_semantics.specs.column_assoc import ColumnAssociationResolver from metricflow_semantics.specs.query_spec import MetricFlowQuerySpec @@ -1338,7 +1339,7 @@ def test_all_available_single_join_metric_filters( MeasureReference("listings"), without_any_of={LinkableElementProperty.MULTI_HOP} ).path_key_to_linkable_metrics.values(): for linkable_metric in linkable_metric_tuple: - group_by_metric_spec = linkable_metric.path_key.spec + group_by_metric_spec = LinkableElementSet._path_key_to_spec(linkable_metric.path_key) assert isinstance(group_by_metric_spec, GroupByMetricSpec) entity_spec = group_by_metric_spec.metric_subquery_entity_spec if entity_spec.entity_links: # multi-hop for inner query