Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Use LinkableSpecSet in WhereFilterSpec #1335

Merged
merged 3 commits into from
Jul 17, 2024
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
from metricflow_semantics.collection_helpers.merger import Mergeable
from metricflow_semantics.model.semantics.linkable_element import LinkableElement
from metricflow_semantics.specs.instance_spec import LinkableInstanceSpec
from metricflow_semantics.specs.linkable_spec_set import LinkableSpecSet
from metricflow_semantics.sql.sql_bind_parameters import SqlBindParameters


Expand Down Expand Up @@ -45,8 +46,12 @@ class WhereFilterSpec(Mergeable, SerializableDataclass):
# quoted identifiers later.
where_sql: str
bind_parameters: SqlBindParameters
linkable_specs: Tuple[LinkableInstanceSpec, ...]
linkable_elements: Tuple[LinkableElement, ...]
linkable_spec_set: LinkableSpecSet

@property
def linkable_specs(self) -> Tuple[LinkableInstanceSpec, ...]: # noqa: D102
return self.linkable_spec_set.as_tuple

def merge(self, other: WhereFilterSpec) -> WhereFilterSpec: # noqa: D102
if self == WhereFilterSpec.empty_instance():
Expand All @@ -61,7 +66,7 @@ def merge(self, other: WhereFilterSpec) -> WhereFilterSpec: # noqa: D102
return WhereFilterSpec(
where_sql=f"({self.where_sql}) AND ({other.where_sql})",
bind_parameters=self.bind_parameters.combine(other.bind_parameters),
linkable_specs=ordered_dedupe(self.linkable_specs, other.linkable_specs),
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Does the ordering matter from the previous code? Just checking since it looks like it was removed in the new code

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The .dedupe() call is actually ordered, so it should be the same.

linkable_spec_set=self.linkable_spec_set.merge(other.linkable_spec_set).dedupe(),
linkable_elements=ordered_dedupe(self.linkable_elements, other.linkable_elements),
)

Expand All @@ -75,6 +80,6 @@ def empty_instance(cls) -> WhereFilterSpec:
return WhereFilterSpec(
where_sql="TRUE",
bind_parameters=SqlBindParameters(),
linkable_specs=(),
linkable_spec_set=LinkableSpecSet(),
linkable_elements=(),
)
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
FilterSpecResolutionLookUp,
)
from metricflow_semantics.specs.column_assoc import ColumnAssociationResolver
from metricflow_semantics.specs.linkable_spec_set import LinkableSpecSet
from metricflow_semantics.specs.rendered_spec_tracker import RenderedSpecTracker
from metricflow_semantics.specs.where_filter.where_filter_dimension import WhereFilterDimensionFactory
from metricflow_semantics.specs.where_filter.where_filter_entity import WhereFilterEntityFactory
Expand Down Expand Up @@ -108,7 +109,7 @@ def create_from_where_filter_intersection( # noqa: D102
WhereFilterSpec(
where_sql=where_sql,
bind_parameters=SqlBindParameters(),
linkable_specs=rendered_specs,
linkable_spec_set=LinkableSpecSet.create_from_specs(rendered_specs),
linkable_elements=linkable_elements,
)
)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
from dbt_semantic_interfaces.implementations.filters.where_filter import PydanticWhereFilter
from metricflow_semantics.filters.time_constraint import TimeRangeConstraint
from metricflow_semantics.query.query_parser import MetricFlowQueryParser
from metricflow_semantics.specs.linkable_spec_set import LinkableSpecSet
from metricflow_semantics.specs.query_spec import MetricFlowQuerySpec
from metricflow_semantics.specs.where_filter.where_filter_spec import WhereFilterSpec
from metricflow_semantics.sql.sql_bind_parameters import SqlBindParameters
Expand Down Expand Up @@ -76,7 +77,10 @@ def test_branch_state_propagation(branch_state_tracker: PredicatePushdownBranchS
original_pushdown_state=base_state,
where_filter_specs=(
WhereFilterSpec(
where_sql="x is true", bind_parameters=SqlBindParameters(), linkable_elements=(), linkable_specs=()
where_sql="x is true",
bind_parameters=SqlBindParameters(),
linkable_elements=(),
linkable_spec_set=LinkableSpecSet(),
),
),
)
Expand Down Expand Up @@ -110,10 +114,16 @@ def test_applied_filter_back_propagation(branch_state_tracker: PredicatePushdown
"""
base_state = branch_state_tracker.last_pushdown_state
where_spec_x_is_true = WhereFilterSpec(
where_sql="x is true", bind_parameters=SqlBindParameters(), linkable_elements=(), linkable_specs=()
where_sql="x is true",
bind_parameters=SqlBindParameters(),
linkable_elements=(),
linkable_spec_set=LinkableSpecSet(),
)
where_spec_y_is_null = WhereFilterSpec(
where_sql="y is null", bind_parameters=SqlBindParameters(), linkable_elements=(), linkable_specs=()
where_sql="y is null",
bind_parameters=SqlBindParameters(),
linkable_elements=(),
linkable_spec_set=LinkableSpecSet(),
)

where_state = PredicatePushdownState.with_where_filter_specs(
Expand Down
13 changes: 8 additions & 5 deletions tests_metricflow/plan_conversion/test_dataflow_to_sql_plan.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
from metricflow_semantics.specs.column_assoc import ColumnAssociationResolver
from metricflow_semantics.specs.dimension_spec import DimensionSpec
from metricflow_semantics.specs.entity_spec import LinklessEntitySpec
from metricflow_semantics.specs.linkable_spec_set import LinkableSpecSet
from metricflow_semantics.specs.measure_spec import MeasureSpec, MetricInputMeasureSpec
from metricflow_semantics.specs.metric_spec import MetricSpec
from metricflow_semantics.specs.non_additive_dimension_spec import NonAdditiveDimensionSpec
Expand Down Expand Up @@ -191,11 +192,13 @@ def test_filter_with_where_constraint_node(
WhereFilterSpec(
where_sql="booking__ds__day = '2020-01-01'",
bind_parameters=SqlBindParameters(),
linkable_specs=(
TimeDimensionSpec(
element_name="ds",
entity_links=(EntityReference(element_name="booking"),),
time_granularity=TimeGranularity.DAY,
linkable_spec_set=LinkableSpecSet(
time_dimension_specs=(
TimeDimensionSpec(
element_name="ds",
entity_links=(EntityReference(element_name="booking"),),
time_granularity=TimeGranularity.DAY,
),
),
),
linkable_elements=(
Expand Down