Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add LinkableElementUnion and use it in WhereFilterSpec #1337

Merged
merged 2 commits into from
Jul 17, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
from dbt_semantic_interfaces.type_enums.time_granularity import TimeGranularity
from typing_extensions import override

from metricflow_semantics.assert_one_arg import assert_exactly_one_arg_set
from metricflow_semantics.model.linkable_element_property import LinkableElementProperty
from metricflow_semantics.model.semantic_model_derivation import SemanticModelDerivation
from metricflow_semantics.workarounds.reference import sorted_semantic_model_references
Expand Down Expand Up @@ -130,6 +131,12 @@ def semantic_model_origin(self) -> SemanticModelReference:
"""
raise NotImplementedError

@property
@abstractmethod
def as_union(self) -> LinkableElementUnion:
"""Return `self` in a union-type container for better serialization support."""
raise NotImplementedError


@dataclass(frozen=True)
class LinkableDimension(LinkableElement, SerializableDataclass):
Expand Down Expand Up @@ -213,6 +220,11 @@ def semantic_model_origin(self) -> SemanticModelReference:
else SemanticModelDerivation.VIRTUAL_SEMANTIC_MODEL_REFERENCE
)

@property
@override
def as_union(self) -> LinkableElementUnion:
return LinkableElementUnion(linkable_dimension=self)


@dataclass(frozen=True)
class LinkableEntity(LinkableElement, SerializableDataclass):
Expand Down Expand Up @@ -267,6 +279,11 @@ def derived_from_semantic_models(self) -> Sequence[SemanticModelReference]:
def semantic_model_origin(self) -> SemanticModelReference:
return self.defined_in_semantic_model

@property
@override
def as_union(self) -> LinkableElementUnion:
return LinkableElementUnion(linkable_entity=self)


@dataclass(frozen=True)
class LinkableMetric(LinkableElement, SerializableDataclass):
Expand Down Expand Up @@ -355,6 +372,38 @@ def metric_subquery_entity_links(self) -> Tuple[EntityReference, ...]:
"""
return self.join_path.metric_subquery_entity_links

@property
@override
def as_union(self) -> LinkableElementUnion:
return LinkableElementUnion(linkable_metric=self)


@dataclass(frozen=True)
class LinkableElementUnion(SerializableDataclass):
"""A union type to use in classes that require a concrete implementation for serialization."""

linkable_dimension: Optional[LinkableDimension] = None
linkable_entity: Optional[LinkableEntity] = None
linkable_metric: Optional[LinkableMetric] = None

def __post_init__(self) -> None: # noqa: D105
assert_exactly_one_arg_set(
linkable_dimension=self.linkable_dimension,
linkable_entity=self.linkable_entity,
linkable_metric=self.linkable_metric,
)

@property
def linkable_element(self) -> LinkableElement: # noqa: D102
if self.linkable_dimension is not None:
return self.linkable_dimension
elif self.linkable_entity is not None:
return self.linkable_entity
elif self.linkable_metric is not None:
return self.linkable_metric

assert False, "All fields are None - this should have been caught in object initialization."


@dataclass(frozen=True)
class SemanticModelJoinPath(SemanticModelDerivation, SerializableDataclass):
Expand Down
Original file line number Diff line number Diff line change
@@ -1,14 +1,14 @@
from __future__ import annotations

from dataclasses import dataclass
from typing import Tuple
from typing import Sequence, Tuple

from dbt_semantic_interfaces.dataclass_serialization import SerializableDataclass
from typing_extensions import override

from metricflow_semantics.collection_helpers.dedupe import ordered_dedupe
from metricflow_semantics.collection_helpers.merger import Mergeable
from metricflow_semantics.model.semantics.linkable_element import LinkableElement
from metricflow_semantics.model.semantics.linkable_element import LinkableElement, LinkableElementUnion
from metricflow_semantics.specs.instance_spec import LinkableInstanceSpec
from metricflow_semantics.specs.linkable_spec_set import LinkableSpecSet
from metricflow_semantics.sql.sql_bind_parameters import SqlBindParameters
Expand Down Expand Up @@ -46,9 +46,13 @@ class WhereFilterSpec(Mergeable, SerializableDataclass):
# quoted identifiers later.
where_sql: str
bind_parameters: SqlBindParameters
linkable_elements: Tuple[LinkableElement, ...]
linkable_element_unions: Tuple[LinkableElementUnion, ...]
linkable_spec_set: LinkableSpecSet

@property
def linkable_elements(self) -> Sequence[LinkableElement]: # noqa: D102
return tuple(linkable_element_union.linkable_element for linkable_element_union in self.linkable_element_unions)

@property
def linkable_specs(self) -> Tuple[LinkableInstanceSpec, ...]: # noqa: D102
return self.linkable_spec_set.as_tuple
Expand All @@ -67,7 +71,7 @@ def merge(self, other: WhereFilterSpec) -> WhereFilterSpec: # noqa: D102
where_sql=f"({self.where_sql}) AND ({other.where_sql})",
bind_parameters=self.bind_parameters.combine(other.bind_parameters),
linkable_spec_set=self.linkable_spec_set.merge(other.linkable_spec_set).dedupe(),
linkable_elements=ordered_dedupe(self.linkable_elements, other.linkable_elements),
linkable_element_unions=ordered_dedupe(self.linkable_element_unions, other.linkable_element_unions),
)

@classmethod
Expand All @@ -81,5 +85,5 @@ def empty_instance(cls) -> WhereFilterSpec:
where_sql="TRUE",
bind_parameters=SqlBindParameters(),
linkable_spec_set=LinkableSpecSet(),
linkable_elements=(),
linkable_element_unions=(),
)
Original file line number Diff line number Diff line change
Expand Up @@ -110,7 +110,7 @@ def create_from_where_filter_intersection( # noqa: D102
where_sql=where_sql,
bind_parameters=SqlBindParameters(),
linkable_spec_set=LinkableSpecSet.create_from_specs(rendered_specs),
linkable_elements=linkable_elements,
linkable_element_unions=tuple(linkable_element.as_union for linkable_element in linkable_elements),
)
)

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -79,7 +79,7 @@ def test_branch_state_propagation(branch_state_tracker: PredicatePushdownBranchS
WhereFilterSpec(
where_sql="x is true",
bind_parameters=SqlBindParameters(),
linkable_elements=(),
linkable_element_unions=(),
linkable_spec_set=LinkableSpecSet(),
),
),
Expand Down Expand Up @@ -116,13 +116,13 @@ def test_applied_filter_back_propagation(branch_state_tracker: PredicatePushdown
where_spec_x_is_true = WhereFilterSpec(
where_sql="x is true",
bind_parameters=SqlBindParameters(),
linkable_elements=(),
linkable_element_unions=(),
linkable_spec_set=LinkableSpecSet(),
)
where_spec_y_is_null = WhereFilterSpec(
where_sql="y is null",
bind_parameters=SqlBindParameters(),
linkable_elements=(),
linkable_element_unions=(),
linkable_spec_set=LinkableSpecSet(),
)

Expand Down
4 changes: 2 additions & 2 deletions tests_metricflow/plan_conversion/test_dataflow_to_sql_plan.py
Original file line number Diff line number Diff line change
Expand Up @@ -201,7 +201,7 @@ def test_filter_with_where_constraint_node(
),
),
),
linkable_elements=(
linkable_element_unions=(
LinkableDimension.create(
defined_in_semantic_model=SemanticModelReference("bookings_source"),
element_name="ds",
Expand All @@ -213,7 +213,7 @@ def test_filter_with_where_constraint_node(
join_path=SemanticModelJoinPath(
left_semantic_model_reference=SemanticModelReference("bookings_source"),
),
),
).as_union,
),
),
),
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -26,39 +26,45 @@
<WhereConstraintNode>
<!-- description = 'Constrain Output with WHERE' -->
<!-- node_id = NodeId(id_str='wcc_0') -->
<!-- where_condition = -->
<!-- WhereFilterSpec( -->
<!-- where_sql="listing__country_latest = 'us'", -->
<!-- bind_parameters=SqlBindParameters(), -->
<!-- linkable_elements=( -->
<!-- LinkableDimension( -->
<!-- properties=(LOCAL,), -->
<!-- defined_in_semantic_model=SemanticModelReference( -->
<!-- semantic_model_name='listings_latest', -->
<!-- ), -->
<!-- element_name='country_latest', -->
<!-- dimension_type=CATEGORICAL, -->
<!-- entity_links=(EntityReference(element_name='listing'),), -->
<!-- join_path=SemanticModelJoinPath( -->
<!-- left_semantic_model_reference=SemanticModelReference( -->
<!-- semantic_model_name='listings_latest', -->
<!-- ), -->
<!-- ), -->
<!-- ), -->
<!-- ), -->
<!-- linkable_spec_set=LinkableSpecSet( -->
<!-- dimension_specs=( -->
<!-- DimensionSpec( -->
<!-- element_name='country_latest', -->
<!-- entity_links=( -->
<!-- EntityReference( -->
<!-- element_name='listing', -->
<!-- ), -->
<!-- ), -->
<!-- ), -->
<!-- ), -->
<!-- ), -->
<!-- ) -->
<!-- where_condition = -->
<!-- WhereFilterSpec( -->
<!-- where_sql="listing__country_latest = 'us'", -->
<!-- bind_parameters=SqlBindParameters(), -->
<!-- linkable_element_unions=( -->
<!-- LinkableElementUnion( -->
<!-- linkable_dimension=LinkableDimension( -->
<!-- properties=(LOCAL,), -->
<!-- defined_in_semantic_model=SemanticModelReference( -->
<!-- semantic_model_name='listings_latest', -->
<!-- ), -->
<!-- element_name='country_latest', -->
<!-- dimension_type=CATEGORICAL, -->
<!-- entity_links=( -->
<!-- EntityReference( -->
<!-- element_name='listing', -->
<!-- ), -->
<!-- ), -->
<!-- join_path=SemanticModelJoinPath( -->
<!-- left_semantic_model_reference=SemanticModelReference( -->
<!-- semantic_model_name='listings_latest', -->
<!-- ), -->
<!-- ), -->
<!-- ), -->
<!-- ), -->
<!-- ), -->
<!-- linkable_spec_set=LinkableSpecSet( -->
<!-- dimension_specs=( -->
<!-- DimensionSpec( -->
<!-- element_name='country_latest', -->
<!-- entity_links=( -->
<!-- EntityReference( -->
<!-- element_name='listing', -->
<!-- ), -->
<!-- ), -->
<!-- ), -->
<!-- ), -->
<!-- ), -->
<!-- ) -->
<ReadSqlSourceNode>
<!-- description = "Read From SemanticModelDataSet('listings_latest')" -->
<!-- node_id = NodeId(id_str='rss_28018') -->
Expand Down
Loading
Loading