Skip to content

Commit

Permalink
Add LinkableElementUnion and use it in WhereFilterSpec (#1337)
Browse files Browse the repository at this point in the history
As a part of making `WhereFilterSpec` serializable, this updates the
field `WhereFilterSpec.linkable_elements` to be a union type that is
supported by the serializer.
  • Loading branch information
plypaul authored Jul 17, 2024
1 parent d6c2dea commit 17a7fb3
Show file tree
Hide file tree
Showing 31 changed files with 3,023 additions and 2,622 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
from dbt_semantic_interfaces.type_enums.time_granularity import TimeGranularity
from typing_extensions import override

from metricflow_semantics.assert_one_arg import assert_exactly_one_arg_set
from metricflow_semantics.model.linkable_element_property import LinkableElementProperty
from metricflow_semantics.model.semantic_model_derivation import SemanticModelDerivation
from metricflow_semantics.workarounds.reference import sorted_semantic_model_references
Expand Down Expand Up @@ -130,6 +131,12 @@ def semantic_model_origin(self) -> SemanticModelReference:
"""
raise NotImplementedError

@property
@abstractmethod
def as_union(self) -> LinkableElementUnion:
"""Return `self` in a union-type container for better serialization support."""
raise NotImplementedError


@dataclass(frozen=True)
class LinkableDimension(LinkableElement, SerializableDataclass):
Expand Down Expand Up @@ -213,6 +220,11 @@ def semantic_model_origin(self) -> SemanticModelReference:
else SemanticModelDerivation.VIRTUAL_SEMANTIC_MODEL_REFERENCE
)

@property
@override
def as_union(self) -> LinkableElementUnion:
return LinkableElementUnion(linkable_dimension=self)


@dataclass(frozen=True)
class LinkableEntity(LinkableElement, SerializableDataclass):
Expand Down Expand Up @@ -267,6 +279,11 @@ def derived_from_semantic_models(self) -> Sequence[SemanticModelReference]:
def semantic_model_origin(self) -> SemanticModelReference:
return self.defined_in_semantic_model

@property
@override
def as_union(self) -> LinkableElementUnion:
return LinkableElementUnion(linkable_entity=self)


@dataclass(frozen=True)
class LinkableMetric(LinkableElement, SerializableDataclass):
Expand Down Expand Up @@ -355,6 +372,38 @@ def metric_subquery_entity_links(self) -> Tuple[EntityReference, ...]:
"""
return self.join_path.metric_subquery_entity_links

@property
@override
def as_union(self) -> LinkableElementUnion:
return LinkableElementUnion(linkable_metric=self)


@dataclass(frozen=True)
class LinkableElementUnion(SerializableDataclass):
"""A union type to use in classes that require a concrete implementation for serialization."""

linkable_dimension: Optional[LinkableDimension] = None
linkable_entity: Optional[LinkableEntity] = None
linkable_metric: Optional[LinkableMetric] = None

def __post_init__(self) -> None: # noqa: D105
assert_exactly_one_arg_set(
linkable_dimension=self.linkable_dimension,
linkable_entity=self.linkable_entity,
linkable_metric=self.linkable_metric,
)

@property
def linkable_element(self) -> LinkableElement: # noqa: D102
if self.linkable_dimension is not None:
return self.linkable_dimension
elif self.linkable_entity is not None:
return self.linkable_entity
elif self.linkable_metric is not None:
return self.linkable_metric

assert False, "All fields are None - this should have been caught in object initialization."


@dataclass(frozen=True)
class SemanticModelJoinPath(SemanticModelDerivation, SerializableDataclass):
Expand Down
Original file line number Diff line number Diff line change
@@ -1,14 +1,14 @@
from __future__ import annotations

from dataclasses import dataclass
from typing import Tuple
from typing import Sequence, Tuple

from dbt_semantic_interfaces.dataclass_serialization import SerializableDataclass
from typing_extensions import override

from metricflow_semantics.collection_helpers.dedupe import ordered_dedupe
from metricflow_semantics.collection_helpers.merger import Mergeable
from metricflow_semantics.model.semantics.linkable_element import LinkableElement
from metricflow_semantics.model.semantics.linkable_element import LinkableElement, LinkableElementUnion
from metricflow_semantics.specs.instance_spec import LinkableInstanceSpec
from metricflow_semantics.specs.linkable_spec_set import LinkableSpecSet
from metricflow_semantics.sql.sql_bind_parameters import SqlBindParameters
Expand Down Expand Up @@ -46,9 +46,13 @@ class WhereFilterSpec(Mergeable, SerializableDataclass):
# quoted identifiers later.
where_sql: str
bind_parameters: SqlBindParameters
linkable_elements: Tuple[LinkableElement, ...]
linkable_element_unions: Tuple[LinkableElementUnion, ...]
linkable_spec_set: LinkableSpecSet

@property
def linkable_elements(self) -> Sequence[LinkableElement]: # noqa: D102
return tuple(linkable_element_union.linkable_element for linkable_element_union in self.linkable_element_unions)

@property
def linkable_specs(self) -> Tuple[LinkableInstanceSpec, ...]: # noqa: D102
return self.linkable_spec_set.as_tuple
Expand All @@ -67,7 +71,7 @@ def merge(self, other: WhereFilterSpec) -> WhereFilterSpec: # noqa: D102
where_sql=f"({self.where_sql}) AND ({other.where_sql})",
bind_parameters=self.bind_parameters.combine(other.bind_parameters),
linkable_spec_set=self.linkable_spec_set.merge(other.linkable_spec_set).dedupe(),
linkable_elements=ordered_dedupe(self.linkable_elements, other.linkable_elements),
linkable_element_unions=ordered_dedupe(self.linkable_element_unions, other.linkable_element_unions),
)

@classmethod
Expand All @@ -81,5 +85,5 @@ def empty_instance(cls) -> WhereFilterSpec:
where_sql="TRUE",
bind_parameters=SqlBindParameters(),
linkable_spec_set=LinkableSpecSet(),
linkable_elements=(),
linkable_element_unions=(),
)
Original file line number Diff line number Diff line change
Expand Up @@ -110,7 +110,7 @@ def create_from_where_filter_intersection( # noqa: D102
where_sql=where_sql,
bind_parameters=SqlBindParameters(),
linkable_spec_set=LinkableSpecSet.create_from_specs(rendered_specs),
linkable_elements=linkable_elements,
linkable_element_unions=tuple(linkable_element.as_union for linkable_element in linkable_elements),
)
)

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -79,7 +79,7 @@ def test_branch_state_propagation(branch_state_tracker: PredicatePushdownBranchS
WhereFilterSpec(
where_sql="x is true",
bind_parameters=SqlBindParameters(),
linkable_elements=(),
linkable_element_unions=(),
linkable_spec_set=LinkableSpecSet(),
),
),
Expand Down Expand Up @@ -116,13 +116,13 @@ def test_applied_filter_back_propagation(branch_state_tracker: PredicatePushdown
where_spec_x_is_true = WhereFilterSpec(
where_sql="x is true",
bind_parameters=SqlBindParameters(),
linkable_elements=(),
linkable_element_unions=(),
linkable_spec_set=LinkableSpecSet(),
)
where_spec_y_is_null = WhereFilterSpec(
where_sql="y is null",
bind_parameters=SqlBindParameters(),
linkable_elements=(),
linkable_element_unions=(),
linkable_spec_set=LinkableSpecSet(),
)

Expand Down
4 changes: 2 additions & 2 deletions tests_metricflow/plan_conversion/test_dataflow_to_sql_plan.py
Original file line number Diff line number Diff line change
Expand Up @@ -201,7 +201,7 @@ def test_filter_with_where_constraint_node(
),
),
),
linkable_elements=(
linkable_element_unions=(
LinkableDimension.create(
defined_in_semantic_model=SemanticModelReference("bookings_source"),
element_name="ds",
Expand All @@ -213,7 +213,7 @@ def test_filter_with_where_constraint_node(
join_path=SemanticModelJoinPath(
left_semantic_model_reference=SemanticModelReference("bookings_source"),
),
),
).as_union,
),
),
),
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -26,39 +26,45 @@
<WhereConstraintNode>
<!-- description = 'Constrain Output with WHERE' -->
<!-- node_id = NodeId(id_str='wcc_0') -->
<!-- where_condition = -->
<!-- WhereFilterSpec( -->
<!-- where_sql="listing__country_latest = 'us'", -->
<!-- bind_parameters=SqlBindParameters(), -->
<!-- linkable_elements=( -->
<!-- LinkableDimension( -->
<!-- properties=(LOCAL,), -->
<!-- defined_in_semantic_model=SemanticModelReference( -->
<!-- semantic_model_name='listings_latest', -->
<!-- ), -->
<!-- element_name='country_latest', -->
<!-- dimension_type=CATEGORICAL, -->
<!-- entity_links=(EntityReference(element_name='listing'),), -->
<!-- join_path=SemanticModelJoinPath( -->
<!-- left_semantic_model_reference=SemanticModelReference( -->
<!-- semantic_model_name='listings_latest', -->
<!-- ), -->
<!-- ), -->
<!-- ), -->
<!-- ), -->
<!-- linkable_spec_set=LinkableSpecSet( -->
<!-- dimension_specs=( -->
<!-- DimensionSpec( -->
<!-- element_name='country_latest', -->
<!-- entity_links=( -->
<!-- EntityReference( -->
<!-- element_name='listing', -->
<!-- ), -->
<!-- ), -->
<!-- ), -->
<!-- ), -->
<!-- ), -->
<!-- ) -->
<!-- where_condition = -->
<!-- WhereFilterSpec( -->
<!-- where_sql="listing__country_latest = 'us'", -->
<!-- bind_parameters=SqlBindParameters(), -->
<!-- linkable_element_unions=( -->
<!-- LinkableElementUnion( -->
<!-- linkable_dimension=LinkableDimension( -->
<!-- properties=(LOCAL,), -->
<!-- defined_in_semantic_model=SemanticModelReference( -->
<!-- semantic_model_name='listings_latest', -->
<!-- ), -->
<!-- element_name='country_latest', -->
<!-- dimension_type=CATEGORICAL, -->
<!-- entity_links=( -->
<!-- EntityReference( -->
<!-- element_name='listing', -->
<!-- ), -->
<!-- ), -->
<!-- join_path=SemanticModelJoinPath( -->
<!-- left_semantic_model_reference=SemanticModelReference( -->
<!-- semantic_model_name='listings_latest', -->
<!-- ), -->
<!-- ), -->
<!-- ), -->
<!-- ), -->
<!-- ), -->
<!-- linkable_spec_set=LinkableSpecSet( -->
<!-- dimension_specs=( -->
<!-- DimensionSpec( -->
<!-- element_name='country_latest', -->
<!-- entity_links=( -->
<!-- EntityReference( -->
<!-- element_name='listing', -->
<!-- ), -->
<!-- ), -->
<!-- ), -->
<!-- ), -->
<!-- ), -->
<!-- ) -->
<ReadSqlSourceNode>
<!-- description = "Read From SemanticModelDataSet('listings_latest')" -->
<!-- node_id = NodeId(id_str='rss_28018') -->
Expand Down
Loading

0 comments on commit 17a7fb3

Please sign in to comment.