diff --git a/metricflow-semantics/metricflow_semantics/specs/linkable_spec_set.py b/metricflow-semantics/metricflow_semantics/specs/linkable_spec_set.py index 00b60ad92d..43d55d3b6b 100644 --- a/metricflow-semantics/metricflow_semantics/specs/linkable_spec_set.py +++ b/metricflow-semantics/metricflow_semantics/specs/linkable_spec_set.py @@ -1,5 +1,6 @@ from __future__ import annotations +import dataclasses import itertools import typing from dataclasses import dataclass @@ -13,13 +14,15 @@ from metricflow_semantics.specs.dimension_spec import DimensionSpec from metricflow_semantics.specs.entity_spec import EntitySpec from metricflow_semantics.specs.group_by_metric_spec import GroupByMetricSpec -from metricflow_semantics.specs.instance_spec import InstanceSpec, LinkableInstanceSpec -from metricflow_semantics.specs.spec_set import InstanceSpecSet +from metricflow_semantics.specs.instance_spec import InstanceSpecVisitor, LinkableInstanceSpec from metricflow_semantics.specs.time_dimension_spec import TimeDimensionSpec if typing.TYPE_CHECKING: from metricflow_semantics.model.semantics.metric_lookup import MetricLookup from metricflow_semantics.model.semantics.semantic_model_lookup import SemanticModelLookup + from metricflow_semantics.specs.measure_spec import MeasureSpec + from metricflow_semantics.specs.metadata_spec import MetadataSpec + from metricflow_semantics.specs.metric_spec import MetricSpec @dataclass(frozen=True) @@ -90,8 +93,8 @@ def merge(self, other: LinkableSpecSet) -> LinkableSpecSet: group_by_metric_specs=self.group_by_metric_specs + other.group_by_metric_specs, ) - @override @classmethod + @override def empty_instance(cls) -> LinkableSpecSet: return LinkableSpecSet() @@ -136,14 +139,57 @@ def __len__(self) -> int: # noqa: D105 return len(self.dimension_specs) + len(self.time_dimension_specs) + len(self.entity_specs) @staticmethod - def create_from_spec_set(spec_set: InstanceSpecSet) -> LinkableSpecSet: # noqa: D102 - return LinkableSpecSet( - dimension_specs=spec_set.dimension_specs, - time_dimension_specs=spec_set.time_dimension_specs, - entity_specs=spec_set.entity_specs, - group_by_metric_specs=spec_set.group_by_metric_specs, - ) + def create_from_specs(specs: Sequence[LinkableInstanceSpec]) -> LinkableSpecSet: # noqa: D102 + return _group_specs_by_type(specs) - @staticmethod - def create_from_specs(specs: Sequence[InstanceSpec]) -> LinkableSpecSet: # noqa: D102 - return LinkableSpecSet.create_from_spec_set(InstanceSpecSet.create_from_specs(specs)) + +@dataclass +class _GroupSpecByTypeVisitor(InstanceSpecVisitor[None]): + """Groups a spec by type into an `InstanceSpecSet`.""" + + dimension_specs: List[DimensionSpec] = dataclasses.field(default_factory=list) + entity_specs: List[EntitySpec] = dataclasses.field(default_factory=list) + time_dimension_specs: List[TimeDimensionSpec] = dataclasses.field(default_factory=list) + group_by_metric_specs: List[GroupByMetricSpec] = dataclasses.field(default_factory=list) + + @override + def visit_measure_spec(self, measure_spec: MeasureSpec) -> None: + pass + + @override + def visit_dimension_spec(self, dimension_spec: DimensionSpec) -> None: + self.dimension_specs.append(dimension_spec) + + @override + def visit_time_dimension_spec(self, time_dimension_spec: TimeDimensionSpec) -> None: + self.time_dimension_specs.append(time_dimension_spec) + + @override + def visit_entity_spec(self, entity_spec: EntitySpec) -> None: + self.entity_specs.append(entity_spec) + + @override + def visit_group_by_metric_spec(self, group_by_metric_spec: GroupByMetricSpec) -> None: + self.group_by_metric_specs.append(group_by_metric_spec) + + @override + def visit_metric_spec(self, metric_spec: MetricSpec) -> None: + pass + + @override + def visit_metadata_spec(self, metadata_spec: MetadataSpec) -> None: + pass + + +def _group_specs_by_type(specs: Sequence[LinkableInstanceSpec]) -> LinkableSpecSet: + """Groups a sequence of specs by type.""" + grouper = _GroupSpecByTypeVisitor() + for spec in specs: + spec.accept(grouper) + + return LinkableSpecSet( + dimension_specs=tuple(grouper.dimension_specs), + entity_specs=tuple(grouper.entity_specs), + time_dimension_specs=tuple(grouper.time_dimension_specs), + group_by_metric_specs=tuple(grouper.group_by_metric_specs), + ) diff --git a/metricflow-semantics/tests_metricflow_semantics/query/group_by_item/test_available_group_by_items.py b/metricflow-semantics/tests_metricflow_semantics/query/group_by_item/test_available_group_by_items.py index ad35fa2748..74ba861dd6 100644 --- a/metricflow-semantics/tests_metricflow_semantics/query/group_by_item/test_available_group_by_items.py +++ b/metricflow-semantics/tests_metricflow_semantics/query/group_by_item/test_available_group_by_items.py @@ -9,7 +9,6 @@ from metricflow_semantics.query.group_by_item.group_by_item_resolver import GroupByItemResolver from metricflow_semantics.query.group_by_item.resolution_dag.dag import GroupByItemResolutionDag from metricflow_semantics.specs.linkable_spec_set import LinkableSpecSet -from metricflow_semantics.specs.spec_set import group_specs_by_type from metricflow_semantics.test_helpers.config_helpers import MetricFlowTestConfiguration from metricflow_semantics.test_helpers.snapshot_helpers import assert_linkable_spec_set_snapshot_equal @@ -37,5 +36,5 @@ def test_available_group_by_items( # noqa: D103 request=request, mf_test_configuration=mf_test_configuration, set_id="set0", - spec_set=LinkableSpecSet.create_from_spec_set(group_specs_by_type(result.specs)), + spec_set=LinkableSpecSet.create_from_specs(result.specs), )