Skip to content

Commit

Permalink
Use LinkableSpecSet in WhereFilterSpec (#1335)
Browse files Browse the repository at this point in the history
In `WhereFilterSpec`, the field `linkable_specs` can't be used with
`SerializableDataclass` since it's an interface. This PR changes that
field to a `LinkableSpecSet`.
  • Loading branch information
plypaul authored Jul 17, 2024
1 parent abadb80 commit 5bc927b
Show file tree
Hide file tree
Showing 33 changed files with 1,497 additions and 1,033 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -87,7 +87,7 @@ def __post_init__(self) -> None:


@dataclass(frozen=True)
class SemanticModelJoinPathElement:
class SemanticModelJoinPathElement(SerializableDataclass):
"""Describes joining a semantic model by the given entity."""

semantic_model_reference: SemanticModelReference
Expand Down Expand Up @@ -297,7 +297,7 @@ def metric_subquery_entity_links(self) -> Tuple[EntityReference, ...]:


@dataclass(frozen=True)
class SemanticModelJoinPath(SemanticModelDerivation):
class SemanticModelJoinPath(SemanticModelDerivation, SerializableDataclass):
"""Describes a series of joins between the measure semantic model, and other semantic models by entity.
For example:
Expand Down Expand Up @@ -357,14 +357,14 @@ def derived_from_semantic_models(self) -> Sequence[SemanticModelReference]:


@dataclass(frozen=True)
class MetricSubqueryJoinPathElement:
class MetricSubqueryJoinPathElement(SerializableDataclass):
"""Describes joining from a semantic model to a metric subquery.
Args:
metric_reference: The metric that's aggregated in the subquery.
derived_from_semantic_models: The semantic models that the measure's input metrics are defined in.
join_on_entity: The entity that the metric is grouped by in the subquery. This will be updated in V2 to allow a list
of entitites & dimensions.
of entities & dimensions.
entity_links: Sequence of entities joined to get from a metric source to the `join_on_entity`. Should not include
the `join_on_entity`.
metric_to_entity_join_path: Describes the join path used in the subquery to join the metric to the `join_on_entity`.
Expand All @@ -384,7 +384,7 @@ def __post_init__(self) -> None: # noqa: D105


@dataclass(frozen=True)
class SemanticModelToMetricSubqueryJoinPath:
class SemanticModelToMetricSubqueryJoinPath(SerializableDataclass):
"""Describes how to join from a semantic model to a metric subquery.
Starts with semantic model join path, if needed. Always ends with metric subquery join path.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
from metricflow_semantics.collection_helpers.merger import Mergeable
from metricflow_semantics.model.semantics.linkable_element import LinkableElement
from metricflow_semantics.specs.instance_spec import LinkableInstanceSpec
from metricflow_semantics.specs.linkable_spec_set import LinkableSpecSet
from metricflow_semantics.sql.sql_bind_parameters import SqlBindParameters


Expand Down Expand Up @@ -45,8 +46,12 @@ class WhereFilterSpec(Mergeable, SerializableDataclass):
# quoted identifiers later.
where_sql: str
bind_parameters: SqlBindParameters
linkable_specs: Tuple[LinkableInstanceSpec, ...]
linkable_elements: Tuple[LinkableElement, ...]
linkable_spec_set: LinkableSpecSet

@property
def linkable_specs(self) -> Tuple[LinkableInstanceSpec, ...]: # noqa: D102
return self.linkable_spec_set.as_tuple

def merge(self, other: WhereFilterSpec) -> WhereFilterSpec: # noqa: D102
if self == WhereFilterSpec.empty_instance():
Expand All @@ -61,7 +66,7 @@ def merge(self, other: WhereFilterSpec) -> WhereFilterSpec: # noqa: D102
return WhereFilterSpec(
where_sql=f"({self.where_sql}) AND ({other.where_sql})",
bind_parameters=self.bind_parameters.combine(other.bind_parameters),
linkable_specs=ordered_dedupe(self.linkable_specs, other.linkable_specs),
linkable_spec_set=self.linkable_spec_set.merge(other.linkable_spec_set).dedupe(),
linkable_elements=ordered_dedupe(self.linkable_elements, other.linkable_elements),
)

Expand All @@ -75,6 +80,6 @@ def empty_instance(cls) -> WhereFilterSpec:
return WhereFilterSpec(
where_sql="TRUE",
bind_parameters=SqlBindParameters(),
linkable_specs=(),
linkable_spec_set=LinkableSpecSet(),
linkable_elements=(),
)
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
FilterSpecResolutionLookUp,
)
from metricflow_semantics.specs.column_assoc import ColumnAssociationResolver
from metricflow_semantics.specs.linkable_spec_set import LinkableSpecSet
from metricflow_semantics.specs.rendered_spec_tracker import RenderedSpecTracker
from metricflow_semantics.specs.where_filter.where_filter_dimension import WhereFilterDimensionFactory
from metricflow_semantics.specs.where_filter.where_filter_entity import WhereFilterEntityFactory
Expand Down Expand Up @@ -108,7 +109,7 @@ def create_from_where_filter_intersection( # noqa: D102
WhereFilterSpec(
where_sql=where_sql,
bind_parameters=SqlBindParameters(),
linkable_specs=rendered_specs,
linkable_spec_set=LinkableSpecSet.create_from_specs(rendered_specs),
linkable_elements=linkable_elements,
)
)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
from dbt_semantic_interfaces.implementations.filters.where_filter import PydanticWhereFilter
from metricflow_semantics.filters.time_constraint import TimeRangeConstraint
from metricflow_semantics.query.query_parser import MetricFlowQueryParser
from metricflow_semantics.specs.linkable_spec_set import LinkableSpecSet
from metricflow_semantics.specs.query_spec import MetricFlowQuerySpec
from metricflow_semantics.specs.where_filter.where_filter_spec import WhereFilterSpec
from metricflow_semantics.sql.sql_bind_parameters import SqlBindParameters
Expand Down Expand Up @@ -76,7 +77,10 @@ def test_branch_state_propagation(branch_state_tracker: PredicatePushdownBranchS
original_pushdown_state=base_state,
where_filter_specs=(
WhereFilterSpec(
where_sql="x is true", bind_parameters=SqlBindParameters(), linkable_elements=(), linkable_specs=()
where_sql="x is true",
bind_parameters=SqlBindParameters(),
linkable_elements=(),
linkable_spec_set=LinkableSpecSet(),
),
),
)
Expand Down Expand Up @@ -110,10 +114,16 @@ def test_applied_filter_back_propagation(branch_state_tracker: PredicatePushdown
"""
base_state = branch_state_tracker.last_pushdown_state
where_spec_x_is_true = WhereFilterSpec(
where_sql="x is true", bind_parameters=SqlBindParameters(), linkable_elements=(), linkable_specs=()
where_sql="x is true",
bind_parameters=SqlBindParameters(),
linkable_elements=(),
linkable_spec_set=LinkableSpecSet(),
)
where_spec_y_is_null = WhereFilterSpec(
where_sql="y is null", bind_parameters=SqlBindParameters(), linkable_elements=(), linkable_specs=()
where_sql="y is null",
bind_parameters=SqlBindParameters(),
linkable_elements=(),
linkable_spec_set=LinkableSpecSet(),
)

where_state = PredicatePushdownState.with_where_filter_specs(
Expand Down
13 changes: 8 additions & 5 deletions tests_metricflow/plan_conversion/test_dataflow_to_sql_plan.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
from metricflow_semantics.specs.column_assoc import ColumnAssociationResolver
from metricflow_semantics.specs.dimension_spec import DimensionSpec
from metricflow_semantics.specs.entity_spec import LinklessEntitySpec
from metricflow_semantics.specs.linkable_spec_set import LinkableSpecSet
from metricflow_semantics.specs.measure_spec import MeasureSpec, MetricInputMeasureSpec
from metricflow_semantics.specs.metric_spec import MetricSpec
from metricflow_semantics.specs.non_additive_dimension_spec import NonAdditiveDimensionSpec
Expand Down Expand Up @@ -191,11 +192,13 @@ def test_filter_with_where_constraint_node(
WhereFilterSpec(
where_sql="booking__ds__day = '2020-01-01'",
bind_parameters=SqlBindParameters(),
linkable_specs=(
TimeDimensionSpec(
element_name="ds",
entity_links=(EntityReference(element_name="booking"),),
time_granularity=TimeGranularity.DAY,
linkable_spec_set=LinkableSpecSet(
time_dimension_specs=(
TimeDimensionSpec(
element_name="ds",
entity_links=(EntityReference(element_name="booking"),),
time_granularity=TimeGranularity.DAY,
),
),
),
linkable_elements=(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -30,12 +30,6 @@
<!-- WhereFilterSpec( -->
<!-- where_sql="listing__country_latest = 'us'", -->
<!-- bind_parameters=SqlBindParameters(), -->
<!-- linkable_specs=( -->
<!-- DimensionSpec( -->
<!-- element_name='country_latest', -->
<!-- entity_links=(EntityReference(element_name='listing'),), -->
<!-- ), -->
<!-- ), -->
<!-- linkable_elements=( -->
<!-- LinkableDimension( -->
<!-- defined_in_semantic_model=SemanticModelReference( -->
Expand All @@ -52,6 +46,18 @@
<!-- properties=frozenset('LOCAL',), -->
<!-- ), -->
<!-- ), -->
<!-- linkable_spec_set=LinkableSpecSet( -->
<!-- dimension_specs=( -->
<!-- DimensionSpec( -->
<!-- element_name='country_latest', -->
<!-- entity_links=( -->
<!-- EntityReference( -->
<!-- element_name='listing', -->
<!-- ), -->
<!-- ), -->
<!-- ), -->
<!-- ), -->
<!-- ), -->
<!-- ) -->
<ReadSqlSourceNode>
<!-- description = "Read From SemanticModelDataSet('listings_latest')" -->
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -35,12 +35,6 @@
<!-- WhereFilterSpec( -->
<!-- where_sql="listing__country_latest = 'us'", -->
<!-- bind_parameters=SqlBindParameters(), -->
<!-- linkable_specs=( -->
<!-- DimensionSpec( -->
<!-- element_name='country_latest', -->
<!-- entity_links=(EntityReference(element_name='listing'),), -->
<!-- ), -->
<!-- ), -->
<!-- linkable_elements=( -->
<!-- LinkableDimension( -->
<!-- defined_in_semantic_model=SemanticModelReference( -->
Expand All @@ -57,6 +51,18 @@
<!-- properties=frozenset('LOCAL',), -->
<!-- ), -->
<!-- ), -->
<!-- linkable_spec_set=LinkableSpecSet( -->
<!-- dimension_specs=( -->
<!-- DimensionSpec( -->
<!-- element_name='country_latest', -->
<!-- entity_links=( -->
<!-- EntityReference( -->
<!-- element_name='listing', -->
<!-- ), -->
<!-- ), -->
<!-- ), -->
<!-- ), -->
<!-- ), -->
<!-- ) -->
<JoinOnEntitiesNode>
<!-- description = 'Join Standard Outputs' -->
Expand Down
Loading

0 comments on commit 5bc927b

Please sign in to comment.