Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Use LinkableSpecSet in WhereFilterSpec #1335

Merged
merged 3 commits into from
Jul 17, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -87,7 +87,7 @@ def __post_init__(self) -> None:


@dataclass(frozen=True)
class SemanticModelJoinPathElement:
class SemanticModelJoinPathElement(SerializableDataclass):
"""Describes joining a semantic model by the given entity."""

semantic_model_reference: SemanticModelReference
Expand Down Expand Up @@ -308,7 +308,7 @@ def metric_subquery_entity_links(self) -> Tuple[EntityReference, ...]:


@dataclass(frozen=True)
class SemanticModelJoinPath(SemanticModelDerivation):
class SemanticModelJoinPath(SemanticModelDerivation, SerializableDataclass):
"""Describes a series of joins between the measure semantic model, and other semantic models by entity.

For example:
Expand Down Expand Up @@ -368,14 +368,14 @@ def derived_from_semantic_models(self) -> Sequence[SemanticModelReference]:


@dataclass(frozen=True)
class MetricSubqueryJoinPathElement:
class MetricSubqueryJoinPathElement(SerializableDataclass):
"""Describes joining from a semantic model to a metric subquery.

Args:
metric_reference: The metric that's aggregated in the subquery.
derived_from_semantic_models: The semantic models that the measure's input metrics are defined in.
join_on_entity: The entity that the metric is grouped by in the subquery. This will be updated in V2 to allow a list
of entitites & dimensions.
of entities & dimensions.
entity_links: Sequence of entities joined to get from a metric source to the `join_on_entity`. Should not include
the `join_on_entity`.
metric_to_entity_join_path: Describes the join path used in the subquery to join the metric to the `join_on_entity`.
Expand All @@ -395,7 +395,7 @@ def __post_init__(self) -> None: # noqa: D105


@dataclass(frozen=True)
class SemanticModelToMetricSubqueryJoinPath:
class SemanticModelToMetricSubqueryJoinPath(SerializableDataclass):
"""Describes how to join from a semantic model to a metric subquery.

Starts with semantic model join path, if needed. Always ends with metric subquery join path.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
from metricflow_semantics.collection_helpers.merger import Mergeable
from metricflow_semantics.model.semantics.linkable_element import LinkableElement
from metricflow_semantics.specs.instance_spec import LinkableInstanceSpec
from metricflow_semantics.specs.linkable_spec_set import LinkableSpecSet
from metricflow_semantics.sql.sql_bind_parameters import SqlBindParameters


Expand Down Expand Up @@ -45,8 +46,12 @@ class WhereFilterSpec(Mergeable, SerializableDataclass):
# quoted identifiers later.
where_sql: str
bind_parameters: SqlBindParameters
linkable_specs: Tuple[LinkableInstanceSpec, ...]
linkable_elements: Tuple[LinkableElement, ...]
linkable_spec_set: LinkableSpecSet

@property
def linkable_specs(self) -> Tuple[LinkableInstanceSpec, ...]: # noqa: D102
return self.linkable_spec_set.as_tuple

def merge(self, other: WhereFilterSpec) -> WhereFilterSpec: # noqa: D102
if self == WhereFilterSpec.empty_instance():
Expand All @@ -61,7 +66,7 @@ def merge(self, other: WhereFilterSpec) -> WhereFilterSpec: # noqa: D102
return WhereFilterSpec(
where_sql=f"({self.where_sql}) AND ({other.where_sql})",
bind_parameters=self.bind_parameters.combine(other.bind_parameters),
linkable_specs=ordered_dedupe(self.linkable_specs, other.linkable_specs),
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Does the ordering matter from the previous code? Just checking since it looks like it was removed in the new code

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The .dedupe() call is actually ordered, so it should be the same.

linkable_spec_set=self.linkable_spec_set.merge(other.linkable_spec_set).dedupe(),
linkable_elements=ordered_dedupe(self.linkable_elements, other.linkable_elements),
)

Expand All @@ -75,6 +80,6 @@ def empty_instance(cls) -> WhereFilterSpec:
return WhereFilterSpec(
where_sql="TRUE",
bind_parameters=SqlBindParameters(),
linkable_specs=(),
linkable_spec_set=LinkableSpecSet(),
linkable_elements=(),
)
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
FilterSpecResolutionLookUp,
)
from metricflow_semantics.specs.column_assoc import ColumnAssociationResolver
from metricflow_semantics.specs.linkable_spec_set import LinkableSpecSet
from metricflow_semantics.specs.rendered_spec_tracker import RenderedSpecTracker
from metricflow_semantics.specs.where_filter.where_filter_dimension import WhereFilterDimensionFactory
from metricflow_semantics.specs.where_filter.where_filter_entity import WhereFilterEntityFactory
Expand Down Expand Up @@ -108,7 +109,7 @@ def create_from_where_filter_intersection( # noqa: D102
WhereFilterSpec(
where_sql=where_sql,
bind_parameters=SqlBindParameters(),
linkable_specs=rendered_specs,
linkable_spec_set=LinkableSpecSet.create_from_specs(rendered_specs),
linkable_elements=linkable_elements,
)
)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
from dbt_semantic_interfaces.implementations.filters.where_filter import PydanticWhereFilter
from metricflow_semantics.filters.time_constraint import TimeRangeConstraint
from metricflow_semantics.query.query_parser import MetricFlowQueryParser
from metricflow_semantics.specs.linkable_spec_set import LinkableSpecSet
from metricflow_semantics.specs.query_spec import MetricFlowQuerySpec
from metricflow_semantics.specs.where_filter.where_filter_spec import WhereFilterSpec
from metricflow_semantics.sql.sql_bind_parameters import SqlBindParameters
Expand Down Expand Up @@ -76,7 +77,10 @@ def test_branch_state_propagation(branch_state_tracker: PredicatePushdownBranchS
original_pushdown_state=base_state,
where_filter_specs=(
WhereFilterSpec(
where_sql="x is true", bind_parameters=SqlBindParameters(), linkable_elements=(), linkable_specs=()
where_sql="x is true",
bind_parameters=SqlBindParameters(),
linkable_elements=(),
linkable_spec_set=LinkableSpecSet(),
),
),
)
Expand Down Expand Up @@ -110,10 +114,16 @@ def test_applied_filter_back_propagation(branch_state_tracker: PredicatePushdown
"""
base_state = branch_state_tracker.last_pushdown_state
where_spec_x_is_true = WhereFilterSpec(
where_sql="x is true", bind_parameters=SqlBindParameters(), linkable_elements=(), linkable_specs=()
where_sql="x is true",
bind_parameters=SqlBindParameters(),
linkable_elements=(),
linkable_spec_set=LinkableSpecSet(),
)
where_spec_y_is_null = WhereFilterSpec(
where_sql="y is null", bind_parameters=SqlBindParameters(), linkable_elements=(), linkable_specs=()
where_sql="y is null",
bind_parameters=SqlBindParameters(),
linkable_elements=(),
linkable_spec_set=LinkableSpecSet(),
)

where_state = PredicatePushdownState.with_where_filter_specs(
Expand Down
13 changes: 8 additions & 5 deletions tests_metricflow/plan_conversion/test_dataflow_to_sql_plan.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
from metricflow_semantics.specs.column_assoc import ColumnAssociationResolver
from metricflow_semantics.specs.dimension_spec import DimensionSpec
from metricflow_semantics.specs.entity_spec import LinklessEntitySpec
from metricflow_semantics.specs.linkable_spec_set import LinkableSpecSet
from metricflow_semantics.specs.measure_spec import MeasureSpec, MetricInputMeasureSpec
from metricflow_semantics.specs.metric_spec import MetricSpec
from metricflow_semantics.specs.non_additive_dimension_spec import NonAdditiveDimensionSpec
Expand Down Expand Up @@ -191,11 +192,13 @@ def test_filter_with_where_constraint_node(
WhereFilterSpec(
where_sql="booking__ds__day = '2020-01-01'",
bind_parameters=SqlBindParameters(),
linkable_specs=(
TimeDimensionSpec(
element_name="ds",
entity_links=(EntityReference(element_name="booking"),),
time_granularity=TimeGranularity.DAY,
linkable_spec_set=LinkableSpecSet(
time_dimension_specs=(
TimeDimensionSpec(
element_name="ds",
entity_links=(EntityReference(element_name="booking"),),
time_granularity=TimeGranularity.DAY,
),
),
),
linkable_elements=(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -30,12 +30,6 @@
<!-- WhereFilterSpec( -->
<!-- where_sql="listing__country_latest = 'us'", -->
<!-- bind_parameters=SqlBindParameters(), -->
<!-- linkable_specs=( -->
<!-- DimensionSpec( -->
<!-- element_name='country_latest', -->
<!-- entity_links=(EntityReference(element_name='listing'),), -->
<!-- ), -->
<!-- ), -->
<!-- linkable_elements=( -->
<!-- LinkableDimension( -->
<!-- defined_in_semantic_model=SemanticModelReference( -->
Expand All @@ -52,6 +46,18 @@
<!-- properties=frozenset('LOCAL',), -->
<!-- ), -->
<!-- ), -->
<!-- linkable_spec_set=LinkableSpecSet( -->
<!-- dimension_specs=( -->
<!-- DimensionSpec( -->
<!-- element_name='country_latest', -->
<!-- entity_links=( -->
<!-- EntityReference( -->
<!-- element_name='listing', -->
<!-- ), -->
<!-- ), -->
<!-- ), -->
<!-- ), -->
<!-- ), -->
<!-- ) -->
<ReadSqlSourceNode>
<!-- description = "Read From SemanticModelDataSet('listings_latest')" -->
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -35,12 +35,6 @@
<!-- WhereFilterSpec( -->
<!-- where_sql="listing__country_latest = 'us'", -->
<!-- bind_parameters=SqlBindParameters(), -->
<!-- linkable_specs=( -->
<!-- DimensionSpec( -->
<!-- element_name='country_latest', -->
<!-- entity_links=(EntityReference(element_name='listing'),), -->
<!-- ), -->
<!-- ), -->
<!-- linkable_elements=( -->
<!-- LinkableDimension( -->
<!-- defined_in_semantic_model=SemanticModelReference( -->
Expand All @@ -57,6 +51,18 @@
<!-- properties=frozenset('LOCAL',), -->
<!-- ), -->
<!-- ), -->
<!-- linkable_spec_set=LinkableSpecSet( -->
<!-- dimension_specs=( -->
<!-- DimensionSpec( -->
<!-- element_name='country_latest', -->
<!-- entity_links=( -->
<!-- EntityReference( -->
<!-- element_name='listing', -->
<!-- ), -->
<!-- ), -->
<!-- ), -->
<!-- ), -->
<!-- ), -->
<!-- ) -->
<JoinOnEntitiesNode>
<!-- description = 'Join Standard Outputs' -->
Expand Down
Loading
Loading