Skip to content

Commit

Permalink
/* PR_START p--query-resolution-perf 10 */ Add LinkableElementUnion
Browse files Browse the repository at this point in the history
… and use it in `WhereFilterSpec`.
  • Loading branch information
plypaul committed Jul 15, 2024
1 parent 24758ee commit dd4be3d
Show file tree
Hide file tree
Showing 5 changed files with 67 additions and 11 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
from dbt_semantic_interfaces.type_enums.time_granularity import TimeGranularity
from typing_extensions import override

from metricflow_semantics.assert_one_arg import assert_exactly_one_arg_set
from metricflow_semantics.model.linkable_element_property import LinkableElementProperty
from metricflow_semantics.model.semantic_model_derivation import SemanticModelDerivation
from metricflow_semantics.workarounds.reference import sorted_semantic_model_references
Expand Down Expand Up @@ -115,6 +116,12 @@ def semantic_model_origin(self) -> SemanticModelReference:
"""
raise NotImplementedError

@property
@abstractmethod
def as_union(self) -> LinkableElementUnion:
"""Return `self` in a union-type container for better serialization support."""
raise NotImplementedError


@dataclass(frozen=True)
class LinkableDimension(LinkableElement, SerializableDataclass):
Expand Down Expand Up @@ -177,6 +184,11 @@ def semantic_model_origin(self) -> SemanticModelReference:
else SemanticModelDerivation.VIRTUAL_SEMANTIC_MODEL_REFERENCE
)

@property
@override
def as_union(self) -> LinkableElementUnion:
return LinkableElementUnion(linkable_dimension=self)


@dataclass(frozen=True)
class LinkableEntity(LinkableElement, SerializableDataclass):
Expand Down Expand Up @@ -216,6 +228,11 @@ def derived_from_semantic_models(self) -> Sequence[SemanticModelReference]:
def semantic_model_origin(self) -> SemanticModelReference:
return self.defined_in_semantic_model

@property
@override
def as_union(self) -> LinkableElementUnion:
return LinkableElementUnion(linkable_entity=self)


# TODO: add to DSI
@dataclass(frozen=True)
Expand Down Expand Up @@ -306,6 +323,38 @@ def metric_subquery_entity_links(self) -> Tuple[EntityReference, ...]:
"""
return self.join_path.metric_subquery_entity_links

@property
@override
def as_union(self) -> LinkableElementUnion:
return LinkableElementUnion(linkable_metric=self)


@dataclass(frozen=True)
class LinkableElementUnion(SerializableDataclass):
"""A union type to use in classes that require a concrete implementation for serialization."""

linkable_dimension: Optional[LinkableDimension] = None
linkable_entity: Optional[LinkableEntity] = None
linkable_metric: Optional[LinkableMetric] = None

def __post_init__(self) -> None: # noqa: D105
assert_exactly_one_arg_set(
linkable_dimension=self.linkable_dimension,
linkable_entity=self.linkable_entity,
linkable_metric=self.linkable_metric,
)

@property
def linkable_element(self) -> LinkableElement: # noqa: D102
if self.linkable_dimension is not None:
return self.linkable_dimension
elif self.linkable_entity is not None:
return self.linkable_entity
elif self.linkable_metric is not None:
return self.linkable_metric

assert False, "All fields are None - this should have been caught in object initialization."


@dataclass(frozen=True)
class SemanticModelJoinPath(SemanticModelDerivation, SerializableDataclass):
Expand Down
Original file line number Diff line number Diff line change
@@ -1,14 +1,14 @@
from __future__ import annotations

from dataclasses import dataclass
from typing import Tuple
from typing import Sequence, Tuple

from dbt_semantic_interfaces.dataclass_serialization import SerializableDataclass
from typing_extensions import override

from metricflow_semantics.collection_helpers.dedupe import ordered_dedupe
from metricflow_semantics.collection_helpers.merger import Mergeable
from metricflow_semantics.model.semantics.linkable_element import LinkableElement
from metricflow_semantics.model.semantics.linkable_element import LinkableElement, LinkableElementUnion
from metricflow_semantics.specs.instance_spec import LinkableInstanceSpec
from metricflow_semantics.specs.linkable_spec_set import LinkableSpecSet
from metricflow_semantics.sql.sql_bind_parameters import SqlBindParameters
Expand Down Expand Up @@ -46,9 +46,16 @@ class WhereFilterSpec(Mergeable, SerializableDataclass):
# quoted identifiers later.
where_sql: str
bind_parameters: SqlBindParameters
linkable_elements: Tuple[LinkableElement, ...]
linkable_element_unions: Tuple[LinkableElementUnion, ...]
linkable_spec_set: LinkableSpecSet

# def __init__(self, *args, **kwargs):
# pass

@property
def linkable_elements(self) -> Sequence[LinkableElement]: # noqa: D102
return tuple(linkable_element_union.linkable_element for linkable_element_union in self.linkable_element_unions)

@property
def linkable_specs(self) -> Tuple[LinkableInstanceSpec, ...]: # noqa: D102
return self.linkable_spec_set.as_tuple
Expand All @@ -67,7 +74,7 @@ def merge(self, other: WhereFilterSpec) -> WhereFilterSpec: # noqa: D102
where_sql=f"({self.where_sql}) AND ({other.where_sql})",
bind_parameters=self.bind_parameters.combine(other.bind_parameters),
linkable_spec_set=self.linkable_spec_set.merge(other.linkable_spec_set).dedupe(),
linkable_elements=ordered_dedupe(self.linkable_elements, other.linkable_elements),
linkable_element_unions=ordered_dedupe(self.linkable_element_unions, other.linkable_element_unions),
)

@classmethod
Expand All @@ -81,5 +88,5 @@ def empty_instance(cls) -> WhereFilterSpec:
where_sql="TRUE",
bind_parameters=SqlBindParameters(),
linkable_spec_set=LinkableSpecSet(),
linkable_elements=(),
linkable_element_unions=(),
)
Original file line number Diff line number Diff line change
Expand Up @@ -110,7 +110,7 @@ def create_from_where_filter_intersection( # noqa: D102
where_sql=where_sql,
bind_parameters=SqlBindParameters(),
linkable_spec_set=LinkableSpecSet.create_from_specs(rendered_specs),
linkable_elements=linkable_elements,
linkable_element_unions=tuple(linkable_element.as_union for linkable_element in linkable_elements),
)
)

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -79,7 +79,7 @@ def test_branch_state_propagation(branch_state_tracker: PredicatePushdownBranchS
WhereFilterSpec(
where_sql="x is true",
bind_parameters=SqlBindParameters(),
linkable_elements=(),
linkable_element_unions=(),
linkable_spec_set=LinkableSpecSet(),
),
),
Expand Down Expand Up @@ -116,13 +116,13 @@ def test_applied_filter_back_propagation(branch_state_tracker: PredicatePushdown
where_spec_x_is_true = WhereFilterSpec(
where_sql="x is true",
bind_parameters=SqlBindParameters(),
linkable_elements=(),
linkable_element_unions=(),
linkable_spec_set=LinkableSpecSet(),
)
where_spec_y_is_null = WhereFilterSpec(
where_sql="y is null",
bind_parameters=SqlBindParameters(),
linkable_elements=(),
linkable_element_unions=(),
linkable_spec_set=LinkableSpecSet(),
)

Expand Down
4 changes: 2 additions & 2 deletions tests_metricflow/plan_conversion/test_dataflow_to_sql_plan.py
Original file line number Diff line number Diff line change
Expand Up @@ -201,7 +201,7 @@ def test_filter_with_where_constraint_node(
),
),
),
linkable_elements=(
linkable_element_unions=(
LinkableDimension(
defined_in_semantic_model=SemanticModelReference("bookings_source"),
element_name="ds",
Expand All @@ -213,7 +213,7 @@ def test_filter_with_where_constraint_node(
join_path=SemanticModelJoinPath(
left_semantic_model_reference=SemanticModelReference("bookings_source"),
),
),
).as_union,
),
),
),
Expand Down

0 comments on commit dd4be3d

Please sign in to comment.