Skip to content

Commit

Permalink
Implement Mergeable for LinkableSpecSet.
Browse files Browse the repository at this point in the history
LinkableSpecSet already has a merge() call, so this updates the class to
imlement the Mergeable interface for consistency.
  • Loading branch information
plypaul committed Nov 14, 2023
1 parent 6282aac commit ff6ae22
Show file tree
Hide file tree
Showing 2 changed files with 39 additions and 31 deletions.
14 changes: 6 additions & 8 deletions metricflow/dataflow/builder/dataflow_plan_builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -674,17 +674,15 @@ def _build_aggregated_measure_from_measure_source_node(

# Extraneous linkable specs are specs that are used in this phase that should not show up in the final result
# unless it was already a requested spec in the query
extraneous_linkable_specs = LinkableSpecSet()
linkable_spec_sets_to_merge: List[LinkableSpecSet] = []
if where_constraint:
extraneous_linkable_specs = LinkableSpecSet.merge(
(extraneous_linkable_specs, where_constraint.linkable_spec_set)
)
linkable_spec_sets_to_merge.append(where_constraint.linkable_spec_set)
if non_additive_dimension_spec:
extraneous_linkable_specs = LinkableSpecSet.merge(
(extraneous_linkable_specs, non_additive_dimension_spec.linkable_specs)
)
linkable_spec_sets_to_merge.append(non_additive_dimension_spec.linkable_specs)

extraneous_linkable_specs = LinkableSpecSet.merge_iterable(linkable_spec_sets_to_merge).dedupe()
required_linkable_specs = queried_linkable_specs.merge(extraneous_linkable_specs).dedupe()

required_linkable_specs = LinkableSpecSet.merge((queried_linkable_specs, extraneous_linkable_specs))
logger.info(
f"Looking for a recipe to get:\n"
f"{pformat_big_objects(measure_specs=[measure_spec], required_linkable_set=required_linkable_specs)}"
Expand Down
56 changes: 33 additions & 23 deletions metricflow/specs/specs.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
from abc import ABC, abstractmethod
from dataclasses import dataclass
from hashlib import sha1
from typing import Any, Generic, List, Optional, Sequence, Tuple, TypeVar
from typing import Any, Dict, Generic, List, Optional, Sequence, Tuple, TypeVar

from dbt_semantic_interfaces.dataclass_serialization import SerializableDataclass
from dbt_semantic_interfaces.implementations.metric import PydanticMetricTimeWindow
Expand Down Expand Up @@ -493,7 +493,7 @@ class FilterSpec(SerializableDataclass): # noqa: D


@dataclass(frozen=True)
class LinkableSpecSet(SerializableDataclass):
class LinkableSpecSet(Mergeable, SerializableDataclass):
"""Groups linkable specs."""

dimension_specs: Tuple[DimensionSpec, ...] = ()
Expand All @@ -504,28 +504,38 @@ class LinkableSpecSet(SerializableDataclass):
def as_tuple(self) -> Tuple[LinkableInstanceSpec, ...]: # noqa: D
return tuple(itertools.chain(self.dimension_specs, self.time_dimension_specs, self.entity_specs))

@staticmethod
def merge(spec_sets: Sequence[LinkableSpecSet]) -> LinkableSpecSet:
"""Merges and dedupes the linkable specs."""
dimension_specs: List[DimensionSpec] = []
time_dimension_specs: List[TimeDimensionSpec] = []
entity_specs: List[EntitySpec] = []

for spec_set in spec_sets:
for dimension_spec in spec_set.dimension_specs:
if dimension_spec not in dimension_specs:
dimension_specs.append(dimension_spec)
for time_dimension_spec in spec_set.time_dimension_specs:
if time_dimension_spec not in time_dimension_specs:
time_dimension_specs.append(time_dimension_spec)
for entity_spec in spec_set.entity_specs:
if entity_spec not in entity_specs:
entity_specs.append(entity_spec)
@override
def merge(self, other: LinkableSpecSet) -> LinkableSpecSet:
return LinkableSpecSet(
dimension_specs=self.dimension_specs + other.dimension_specs,
time_dimension_specs=self.time_dimension_specs + other.time_dimension_specs,
entity_specs=self.entity_specs + other.entity_specs,
)

@override
@classmethod
def empty_instance(cls) -> LinkableSpecSet:
return LinkableSpecSet()

def dedupe(self) -> LinkableSpecSet: # noqa: D
# Use dictionaries to dedupe as it preserves insertion order.

dimension_spec_dict: Dict[DimensionSpec, None] = {}
for dimension_spec in self.dimension_specs:
dimension_spec_dict[dimension_spec] = None

time_dimension_spec_dict: Dict[TimeDimensionSpec, None] = {}
for time_dimension_spec in self.time_dimension_specs:
time_dimension_spec_dict[time_dimension_spec] = None

entity_spec_dict: Dict[EntitySpec, None] = {}
for entity_spec in self.entity_specs:
entity_spec_dict[entity_spec] = None

return LinkableSpecSet(
dimension_specs=tuple(dimension_specs),
time_dimension_specs=tuple(time_dimension_specs),
entity_specs=tuple(entity_specs),
dimension_specs=tuple(dimension_spec_dict.keys()),
time_dimension_specs=tuple(time_dimension_spec_dict.keys()),
entity_specs=tuple(entity_spec_dict.keys()),
)

def is_subset_of(self, other_set: LinkableSpecSet) -> bool: # noqa: D
Expand Down Expand Up @@ -724,5 +734,5 @@ def combine(self, other: WhereFilterSpec) -> WhereFilterSpec: # noqa: D
return WhereFilterSpec(
where_sql=f"({self.where_sql}) AND ({other.where_sql})",
bind_parameters=self.bind_parameters.combine(other.bind_parameters),
linkable_spec_set=LinkableSpecSet.merge([self.linkable_spec_set, other.linkable_spec_set]),
linkable_spec_set=self.linkable_spec_set.merge(other.linkable_spec_set),
)

0 comments on commit ff6ae22

Please sign in to comment.