Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Update fields in LinkableElement-related classes to use tuples #1336

Merged
merged 3 commits into from
Jul 17, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
from __future__ import annotations

from enum import Enum
from typing import FrozenSet
from typing import Any, FrozenSet


class LinkableElementProperty(Enum):
Expand Down Expand Up @@ -31,3 +31,10 @@ class LinkableElementProperty(Enum):
@staticmethod
def all_properties() -> FrozenSet[LinkableElementProperty]: # noqa: D102
return frozenset({linkable_element_property for linkable_element_property in LinkableElementProperty})

def __lt__(self, other: Any) -> bool: # type: ignore[misc]
"""When ordering, order by the enum name."""
if not isinstance(other, LinkableElementProperty):
return NotImplemented

return self.name < other.name
Original file line number Diff line number Diff line change
@@ -1,10 +1,12 @@
from __future__ import annotations

import collections
import logging
from abc import ABC, abstractmethod
from dataclasses import dataclass
from enum import Enum
from typing import FrozenSet, Optional, Sequence, Tuple
from functools import cached_property
from typing import FrozenSet, Iterable, Optional, Sequence, Tuple

from dbt_semantic_interfaces.dataclass_serialization import SerializableDataclass
from dbt_semantic_interfaces.enum_extension import assert_values_exhausted
Expand Down Expand Up @@ -98,6 +100,19 @@ class SemanticModelJoinPathElement(SerializableDataclass):
class LinkableElement(SemanticModelDerivation, SerializableDataclass, ABC):
"""An entity / dimension that may have been joined by entities."""

properties: Tuple[LinkableElementProperty, ...]

def __post_init__(self) -> None: # noqa: D105
if len(self.property_set) != len(self.properties):
duplicate_properties = [item for item, count in collections.Counter(self.properties).items() if count > 1]
assert False, f"Found duplicate properties {duplicate_properties} in {self.properties}"

assert self.properties == tuple(sorted(self.properties)), f"Properties are not sorted: {self.properties}"

@cached_property
def property_set(self) -> FrozenSet[LinkableElementProperty]: # noqa: D102
return frozenset(self.properties)

@property
@abstractmethod
def element_type(self) -> LinkableElementType:
Expand Down Expand Up @@ -126,10 +141,31 @@ class LinkableDimension(LinkableElement, SerializableDataclass):
dimension_type: DimensionType
entity_links: Tuple[EntityReference, ...]
join_path: SemanticModelJoinPath
properties: FrozenSet[LinkableElementProperty]
time_granularity: Optional[TimeGranularity]
date_part: Optional[DatePart]

@staticmethod
def create( # noqa: D102
properties: Iterable[LinkableElementProperty],
defined_in_semantic_model: Optional[SemanticModelReference],
element_name: str,
dimension_type: DimensionType,
entity_links: Tuple[EntityReference, ...],
join_path: SemanticModelJoinPath,
time_granularity: Optional[TimeGranularity],
date_part: Optional[DatePart],
) -> LinkableDimension:
return LinkableDimension(
properties=tuple(sorted(set(properties))),
defined_in_semantic_model=defined_in_semantic_model,
element_name=element_name,
dimension_type=dimension_type,
entity_links=entity_links,
join_path=join_path,
time_granularity=time_granularity,
date_part=date_part,
)

@property
@override
def element_type(self) -> LinkableElementType:
Expand Down Expand Up @@ -185,10 +221,25 @@ class LinkableEntity(LinkableElement, SerializableDataclass):
# The semantic model where this entity was defined.
defined_in_semantic_model: SemanticModelReference
element_name: str
properties: FrozenSet[LinkableElementProperty]
entity_links: Tuple[EntityReference, ...]
join_path: SemanticModelJoinPath

@staticmethod
def create( # noqa: D102
properties: Iterable[LinkableElementProperty],
defined_in_semantic_model: SemanticModelReference,
element_name: str,
entity_links: Tuple[EntityReference, ...],
join_path: SemanticModelJoinPath,
) -> LinkableEntity:
return LinkableEntity(
properties=tuple(sorted(set(properties))),
defined_in_semantic_model=defined_in_semantic_model,
element_name=element_name,
entity_links=entity_links,
join_path=join_path,
)

@property
@override
def element_type(self) -> LinkableElementType:
Expand Down Expand Up @@ -221,15 +272,24 @@ def semantic_model_origin(self) -> SemanticModelReference:
class LinkableMetric(LinkableElement, SerializableDataclass):
"""Describes how a metric can be realized by joining based on entity links."""

properties: FrozenSet[LinkableElementProperty]
join_path: SemanticModelToMetricSubqueryJoinPath

@staticmethod
def create( # noqa: D102
properties: Iterable[LinkableElementProperty], join_path: SemanticModelToMetricSubqueryJoinPath
) -> LinkableMetric:
return LinkableMetric(
properties=tuple(sorted(set(properties))),
join_path=join_path,
)

def __post_init__(self) -> None:
"""Ensure expected LinkableElementProperties have been set.

LinkableMetrics always require a join to a metric subquery.
"""
assert {LinkableElementProperty.METRIC, LinkableElementProperty.JOINED}.issubset(self.properties)
super().__post_init__()
assert {LinkableElementProperty.METRIC, LinkableElementProperty.JOINED}.issubset(self.property_set)

@property
@override
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -237,11 +237,11 @@ def filter(
filtered_linkable_dimensions = tuple(
linkable_dimension
for linkable_dimension in linkable_dimensions
if len(linkable_dimension.properties.intersection(with_any_of)) > 0
and len(linkable_dimension.properties.intersection(without_any_of)) == 0
if len(linkable_dimension.property_set.intersection(with_any_of)) > 0
and len(linkable_dimension.property_set.intersection(without_any_of)) == 0
and (
len(without_all_of) == 0
or linkable_dimension.properties.intersection(without_all_of) != without_all_of
or linkable_dimension.property_set.intersection(without_all_of) != without_all_of
)
)
if len(filtered_linkable_dimensions) > 0:
Expand All @@ -251,11 +251,11 @@ def filter(
filtered_linkable_entities = tuple(
linkable_entity
for linkable_entity in linkable_entities
if len(linkable_entity.properties.intersection(with_any_of)) > 0
and len(linkable_entity.properties.intersection(without_any_of)) == 0
if len(linkable_entity.property_set.intersection(with_any_of)) > 0
and len(linkable_entity.property_set.intersection(without_any_of)) == 0
and (
len(without_all_of) == 0
or linkable_entity.properties.intersection(without_all_of) != without_all_of
or linkable_entity.property_set.intersection(without_all_of) != without_all_of
)
)
if len(filtered_linkable_entities) > 0:
Expand All @@ -265,11 +265,11 @@ def filter(
filtered_linkable_metrics = tuple(
linkable_metric
for linkable_metric in linkable_metrics
if len(linkable_metric.properties.intersection(with_any_of)) > 0
and len(linkable_metric.properties.intersection(without_any_of)) == 0
if len(linkable_metric.property_set.intersection(with_any_of)) > 0
and len(linkable_metric.property_set.intersection(without_any_of)) == 0
and (
len(without_all_of) == 0
or linkable_metric.properties.intersection(without_all_of) != without_all_of
or linkable_metric.property_set.intersection(without_all_of) != without_all_of
)
)
if len(filtered_linkable_metrics) > 0:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -69,23 +69,23 @@ def _generate_linkable_time_dimensions(
properties.add(LinkableElementProperty.DERIVED_TIME_GRANULARITY)

linkable_dimensions.append(
LinkableDimension(
LinkableDimension.create(
defined_in_semantic_model=semantic_model_origin,
element_name=dimension.reference.element_name,
dimension_type=DimensionType.TIME,
entity_links=entity_links,
join_path=join_path,
time_granularity=time_granularity,
date_part=None,
properties=frozenset(properties),
properties=tuple(sorted(properties)),
)
)

# Add the time dimension aggregated to a different date part.
for date_part in DatePart:
if time_granularity.to_int() <= date_part.to_int():
linkable_dimensions.append(
LinkableDimension(
LinkableDimension.create(
defined_in_semantic_model=semantic_model_origin,
element_name=dimension.reference.element_name,
dimension_type=DimensionType.TIME,
Expand Down Expand Up @@ -305,7 +305,7 @@ def get_joinable_metrics_for_semantic_model(
if join_path_has_path_links and entity_reference in using_join_path.entity_links:
continue
for metric_subquery_join_path_element in self._joinable_metrics_for_entities[entity_reference]:
linkable_metric = LinkableMetric(
linkable_metric = LinkableMetric.create(
properties=properties,
join_path=SemanticModelToMetricSubqueryJoinPath(
metric_subquery_join_path_element=metric_subquery_join_path_element,
Expand All @@ -328,7 +328,7 @@ def _get_elements_in_semantic_model(self, semantic_model: SemanticModel) -> Link
linkable_entities = []
for entity in semantic_model.entities:
linkable_entities.append(
LinkableEntity(
LinkableEntity.create(
defined_in_semantic_model=semantic_model.reference,
element_name=entity.reference.element_name,
entity_links=(),
Expand All @@ -343,7 +343,7 @@ def _get_elements_in_semantic_model(self, semantic_model: SemanticModel) -> Link
if entity_link == entity.reference:
continue
linkable_entities.append(
LinkableEntity(
LinkableEntity.create(
defined_in_semantic_model=semantic_model.reference,
element_name=entity.reference.element_name,
entity_links=(entity_link,),
Expand All @@ -360,7 +360,7 @@ def _get_elements_in_semantic_model(self, semantic_model: SemanticModel) -> Link
dimension_type = dimension.type
if dimension_type is DimensionType.CATEGORICAL:
linkable_dimensions.append(
LinkableDimension(
LinkableDimension.create(
defined_in_semantic_model=semantic_model.reference,
element_name=dimension.reference.element_name,
dimension_type=DimensionType.CATEGORICAL,
Expand Down Expand Up @@ -492,7 +492,7 @@ def _get_metric_time_elements(self, measure_reference: Optional[MeasureReference
date_part=date_part,
)
path_key_to_linkable_dimensions[path_key].append(
LinkableDimension(
LinkableDimension.create(
defined_in_semantic_model=measure_semantic_model.reference if measure_semantic_model else None,
element_name=MetricFlowReservedKeywords.METRIC_TIME.value,
dimension_type=DimensionType.TIME,
Expand Down Expand Up @@ -719,7 +719,7 @@ def create_linkable_element_set_from_join_path(
dimension_type = dimension.type
if dimension_type == DimensionType.CATEGORICAL:
linkable_dimensions.append(
LinkableDimension(
LinkableDimension.create(
defined_in_semantic_model=semantic_model.reference,
element_name=dimension.reference.element_name,
dimension_type=DimensionType.CATEGORICAL,
Expand Down Expand Up @@ -747,7 +747,7 @@ def create_linkable_element_set_from_join_path(
# Avoid creating "booking_id__booking_id"
if entity.reference != join_path.last_entity_link:
linkable_entities.append(
LinkableEntity(
LinkableEntity.create(
defined_in_semantic_model=semantic_model.reference,
element_name=entity.reference.element_name,
entity_links=join_path.entity_links,
Expand Down
Loading
Loading