Skip to content

Commit

Permalink
Streamline LinkableSpecSet.create_from_specs() to avoid circular im…
Browse files Browse the repository at this point in the history
…ports.
  • Loading branch information
plypaul committed Jul 15, 2024
1 parent 607d78b commit e9ee66c
Show file tree
Hide file tree
Showing 2 changed files with 60 additions and 15 deletions.
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from __future__ import annotations

import dataclasses
import itertools
import typing
from dataclasses import dataclass
Expand All @@ -13,13 +14,15 @@
from metricflow_semantics.specs.dimension_spec import DimensionSpec
from metricflow_semantics.specs.entity_spec import EntitySpec
from metricflow_semantics.specs.group_by_metric_spec import GroupByMetricSpec
from metricflow_semantics.specs.instance_spec import InstanceSpec, LinkableInstanceSpec
from metricflow_semantics.specs.spec_set import InstanceSpecSet
from metricflow_semantics.specs.instance_spec import InstanceSpecVisitor, LinkableInstanceSpec
from metricflow_semantics.specs.time_dimension_spec import TimeDimensionSpec

if typing.TYPE_CHECKING:
from metricflow_semantics.model.semantics.metric_lookup import MetricLookup
from metricflow_semantics.model.semantics.semantic_model_lookup import SemanticModelLookup
from metricflow_semantics.specs.measure_spec import MeasureSpec
from metricflow_semantics.specs.metadata_spec import MetadataSpec
from metricflow_semantics.specs.metric_spec import MetricSpec


@dataclass(frozen=True)
Expand Down Expand Up @@ -90,8 +93,8 @@ def merge(self, other: LinkableSpecSet) -> LinkableSpecSet:
group_by_metric_specs=self.group_by_metric_specs + other.group_by_metric_specs,
)

@override
@classmethod
@override
def empty_instance(cls) -> LinkableSpecSet:
return LinkableSpecSet()

Expand Down Expand Up @@ -136,14 +139,57 @@ def __len__(self) -> int: # noqa: D105
return len(self.dimension_specs) + len(self.time_dimension_specs) + len(self.entity_specs)

@staticmethod
def create_from_spec_set(spec_set: InstanceSpecSet) -> LinkableSpecSet: # noqa: D102
return LinkableSpecSet(
dimension_specs=spec_set.dimension_specs,
time_dimension_specs=spec_set.time_dimension_specs,
entity_specs=spec_set.entity_specs,
group_by_metric_specs=spec_set.group_by_metric_specs,
)
def create_from_specs(specs: Sequence[LinkableInstanceSpec]) -> LinkableSpecSet: # noqa: D102
return _group_specs_by_type(specs)

@staticmethod
def create_from_specs(specs: Sequence[InstanceSpec]) -> LinkableSpecSet: # noqa: D102
return LinkableSpecSet.create_from_spec_set(InstanceSpecSet.create_from_specs(specs))

@dataclass
class _GroupSpecByTypeVisitor(InstanceSpecVisitor[None]):
"""Groups a spec by type into an `InstanceSpecSet`."""

dimension_specs: List[DimensionSpec] = dataclasses.field(default_factory=list)
entity_specs: List[EntitySpec] = dataclasses.field(default_factory=list)
time_dimension_specs: List[TimeDimensionSpec] = dataclasses.field(default_factory=list)
group_by_metric_specs: List[GroupByMetricSpec] = dataclasses.field(default_factory=list)

@override
def visit_measure_spec(self, measure_spec: MeasureSpec) -> None:
pass

@override
def visit_dimension_spec(self, dimension_spec: DimensionSpec) -> None:
self.dimension_specs.append(dimension_spec)

@override
def visit_time_dimension_spec(self, time_dimension_spec: TimeDimensionSpec) -> None:
self.time_dimension_specs.append(time_dimension_spec)

@override
def visit_entity_spec(self, entity_spec: EntitySpec) -> None:
self.entity_specs.append(entity_spec)

@override
def visit_group_by_metric_spec(self, group_by_metric_spec: GroupByMetricSpec) -> None:
self.group_by_metric_specs.append(group_by_metric_spec)

@override
def visit_metric_spec(self, metric_spec: MetricSpec) -> None:
pass

@override
def visit_metadata_spec(self, metadata_spec: MetadataSpec) -> None:
pass


def _group_specs_by_type(specs: Sequence[LinkableInstanceSpec]) -> LinkableSpecSet:
"""Groups a sequence of specs by type."""
grouper = _GroupSpecByTypeVisitor()
for spec in specs:
spec.accept(grouper)

return LinkableSpecSet(
dimension_specs=tuple(grouper.dimension_specs),
entity_specs=tuple(grouper.entity_specs),
time_dimension_specs=tuple(grouper.time_dimension_specs),
group_by_metric_specs=tuple(grouper.group_by_metric_specs),
)
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,6 @@
from metricflow_semantics.query.group_by_item.group_by_item_resolver import GroupByItemResolver
from metricflow_semantics.query.group_by_item.resolution_dag.dag import GroupByItemResolutionDag
from metricflow_semantics.specs.linkable_spec_set import LinkableSpecSet
from metricflow_semantics.specs.spec_set import group_specs_by_type
from metricflow_semantics.test_helpers.config_helpers import MetricFlowTestConfiguration
from metricflow_semantics.test_helpers.snapshot_helpers import assert_linkable_spec_set_snapshot_equal

Expand Down Expand Up @@ -37,5 +36,5 @@ def test_available_group_by_items( # noqa: D103
request=request,
mf_test_configuration=mf_test_configuration,
set_id="set0",
spec_set=LinkableSpecSet.create_from_spec_set(group_specs_by_type(result.specs)),
spec_set=LinkableSpecSet.create_from_specs(result.specs),
)

0 comments on commit e9ee66c

Please sign in to comment.