Fix Incorrect Column Name Rendering For WhereConstraintNode #909

Merged 7 commits on Nov 28, 2023
6 changes: 6 additions & 0 deletions .changes/unreleased/Fixes-20231128-142315.yaml
@@ -0,0 +1,6 @@
kind: Fixes
body: Fix Incorrect SQL Column Name Rendering for WhereConstraintNode
time: 2023-11-28T14:23:15.269611-08:00
custom:
Author: plypaul
Issue: "908"
10 changes: 7 additions & 3 deletions metricflow/plan_conversion/dataflow_to_sql.py
@@ -755,8 +755,12 @@ def visit_pass_elements_filter_node(self, node: FilterElementsNode) -> SqlDataSet:

def visit_where_constraint_node(self, node: WhereConstraintNode) -> SqlDataSet:
"""Adds where clause to SQL statement from parent node."""
from_data_set: SqlDataSet = node.parent_node.accept(self)
output_instance_set = from_data_set.instance_set
parent_data_set: SqlDataSet = node.parent_node.accept(self)
# Since we're copying the instance set from the parent to conveniently generate the output instance set for this
# node, we'll need to change the column names.
output_instance_set = parent_data_set.instance_set.transform(
ChangeAssociatedColumns(self._column_association_resolver)
)
from_data_set_alias = self._next_unique_table_alias()

column_associations_in_where_sql: Sequence[ColumnAssociation] = CreateColumnAssociations(
@@ -771,7 +775,7 @@ def visit_where_constraint_node(self, node: WhereConstraintNode) -> SqlDataSet:
select_columns=output_instance_set.transform(
CreateSelectColumnsForInstances(from_data_set_alias, self._column_association_resolver)
).as_tuple(),
from_source=from_data_set.sql_select_node,
from_source=parent_data_set.sql_select_node,
from_source_alias=from_data_set_alias,
joins_descs=(),
group_bys=(),
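To make the hunk above easier to follow: the old code reused the parent's instance set verbatim, so the column names rendered in the outer query could diverge from the columns the parent's SELECT actually exposes. Below is a hedged, self-contained sketch of that failure mode; the table name and the exact SQL shape are illustrative assumptions (plain strings stand in for MetricFlow's SQL plan objects), and the column names follow the ChangeAssociatedColumns docstring example in the next file.

```python
# Toy illustration only -- not MetricFlow code. The parent SELECT exposes the
# dimension under its resolved output name ("is_lux_latest"); "listings_src"
# is an invented table name for the example.
parent_sql = """\
SELECT
  listings_src.is_lux AS is_lux_latest
FROM listings_src"""

# Before the fix: copying the parent's instance set as-is lets the outer query
# reference the parent's internal column name ("is_lux"), which the subquery
# does not expose.
broken_outer_sql = f"""\
SELECT
  subq.is_lux
FROM (
{parent_sql}
) subq
WHERE subq.is_lux_latest = 'filter_value'"""

# After the fix: the instance set is re-resolved via ChangeAssociatedColumns,
# so the outer query selects the column the subquery actually exposes.
fixed_outer_sql = f"""\
SELECT
  subq.is_lux_latest
FROM (
{parent_sql}
) subq
WHERE subq.is_lux_latest = 'filter_value'"""

print(fixed_outer_sql)
```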
29 changes: 28 additions & 1 deletion metricflow/plan_conversion/instance_converters.py
@@ -845,7 +845,34 @@ def transform(self, instance_set: InstanceSet) -> SelectColumnSet: # noqa: D


class ChangeAssociatedColumns(InstanceSetTransform[InstanceSet]):
"""Change the columns associated with instances to the one specified by the resolver."""
"""Change the columns associated with instances to the one specified by the resolver.

This is useful for conveniently generating output instances for a node that serve as a "pass-through". The output
instances can be a copy of the parent's instances, except that the column names need to be changed.

e.g. the parent may have a data set:
sql:
SELECT
is_lux AS is_lux_latest
...
instance:
DimensionInstance(column_name="is_lux", ...)

but for the current node, we want a data set like:

sql:
SELECT
is_lux_latest
...
FROM (
-- SQL from parent
is_lux AS is_lux_latest
...
)
...
instance:
DimensionInstance(column_name="is_lux_latest")
"""

def __init__(self, column_association_resolver: ColumnAssociationResolver) -> None: # noqa: D
self._column_association_resolver = column_association_resolver
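As a usage note for the docstring example above, here is a minimal, hypothetical stand-in for the rename pass. The ToyInstance class and the lambda resolver are invented for illustration; the real ChangeAssociatedColumns operates on MetricFlow's InstanceSet and ColumnAssociationResolver types.

```python
# Simplified sketch of a "change associated columns" pass-through rename.
from dataclasses import dataclass, replace
from typing import Callable, Tuple


@dataclass(frozen=True)
class ToyInstance:
    """Mimics an instance: a spec name plus the column it is associated with."""

    spec_name: str
    column_name: str


def change_associated_columns(
    instances: Tuple[ToyInstance, ...],
    resolve_column_name: Callable[[str], str],
) -> Tuple[ToyInstance, ...]:
    """Return a copy of the instances with each column name re-resolved."""
    return tuple(
        replace(instance, column_name=resolve_column_name(instance.spec_name))
        for instance in instances
    )


# The parent associated the dimension with "is_lux", but the resolver says the
# output column for this spec should be "is_lux_latest" (per the docstring).
parent_instances = (ToyInstance(spec_name="listing__is_lux_latest", column_name="is_lux"),)
output_instances = change_associated_columns(
    parent_instances,
    lambda spec: {"listing__is_lux_latest": "is_lux_latest"}[spec],
)
assert output_instances[0].column_name == "is_lux_latest"
```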
@@ -0,0 +1,79 @@
from __future__ import annotations

import pytest
from _pytest.fixtures import FixtureRequest
from dbt_semantic_interfaces.implementations.filters.where_filter import PydanticWhereFilter
from dbt_semantic_interfaces.references import EntityReference

from metricflow.dataflow.builder.dataflow_plan_builder import DataflowPlanBuilder
from metricflow.plan_conversion.dataflow_to_sql import DataflowToSqlQueryPlanConverter
from metricflow.protocols.sql_client import SqlClient
from metricflow.specs.column_assoc import ColumnAssociationResolver
from metricflow.specs.specs import DimensionSpec, MetricFlowQuerySpec
from metricflow.specs.where_filter_transform import WhereSpecFactory
from metricflow.test.fixtures.setup_fixtures import MetricFlowTestSessionState
from metricflow.test.plan_conversion.test_dataflow_to_sql_plan import convert_and_check


@pytest.mark.sql_engine_snapshot
def test_dimensions_requiring_join(
request: FixtureRequest,
mf_test_session_state: MetricFlowTestSessionState,
dataflow_plan_builder: DataflowPlanBuilder,
dataflow_to_sql_converter: DataflowToSqlQueryPlanConverter,
sql_client: SqlClient,
) -> None:
"""Tests querying 2 dimensions that require a join."""
dimension_specs = (
DimensionSpec(element_name="home_state_latest", entity_links=(EntityReference(element_name="user"),)),
DimensionSpec(element_name="is_lux_latest", entity_links=(EntityReference(element_name="listing"),)),
)
dataflow_plan = dataflow_plan_builder.build_plan_for_distinct_values(
query_spec=MetricFlowQuerySpec(dimension_specs=dimension_specs)
)

convert_and_check(
request=request,
mf_test_session_state=mf_test_session_state,
dataflow_to_sql_converter=dataflow_to_sql_converter,
sql_client=sql_client,
node=dataflow_plan.sink_output_nodes[0].parent_node,
)


@pytest.mark.sql_engine_snapshot
def test_dimension_values_with_a_join_and_a_filter(
request: FixtureRequest,
mf_test_session_state: MetricFlowTestSessionState,
column_association_resolver: ColumnAssociationResolver,
dataflow_plan_builder: DataflowPlanBuilder,
dataflow_to_sql_converter: DataflowToSqlQueryPlanConverter,
sql_client: SqlClient,
) -> None:
"""Tests querying 2 dimensions that require a join and a filter."""
dimension_specs = (
DimensionSpec(element_name="home_state_latest", entity_links=(EntityReference(element_name="user"),)),
DimensionSpec(element_name="is_lux_latest", entity_links=(EntityReference(element_name="listing"),)),
)
dataflow_plan = dataflow_plan_builder.build_plan_for_distinct_values(
query_spec=MetricFlowQuerySpec(
dimension_specs=dimension_specs,
where_constraint=(
WhereSpecFactory(
column_association_resolver=column_association_resolver,
).create_from_where_filter(
PydanticWhereFilter(
where_sql_template="{{ Dimension('user__home_state_latest') }} = 'us'",
)
)
),
)
)

convert_and_check(
request=request,
mf_test_session_state=mf_test_session_state,
dataflow_to_sql_converter=dataflow_to_sql_converter,
sql_client=sql_client,
node=dataflow_plan.sink_output_nodes[0].parent_node,
)
26 changes: 0 additions & 26 deletions metricflow/test/plan_conversion/test_dataflow_to_sql_plan.py
@@ -1068,29 +1068,3 @@ def test_combine_output_node( # noqa: D
sql_client=sql_client,
node=combine_output_node,
)


@pytest.mark.sql_engine_snapshot
def test_dimensions_requiring_join(
request: FixtureRequest,
mf_test_session_state: MetricFlowTestSessionState,
dataflow_plan_builder: DataflowPlanBuilder,
dataflow_to_sql_converter: DataflowToSqlQueryPlanConverter,
sql_client: SqlClient,
) -> None:
"""Tests querying 2 dimensions that require a join."""
dimension_specs = (
DimensionSpec(element_name="home_state_latest", entity_links=(EntityReference(element_name="user"),)),
DimensionSpec(element_name="is_lux_latest", entity_links=(EntityReference(element_name="listing"),)),
)
dataflow_plan = dataflow_plan_builder.build_plan_for_distinct_values(
query_spec=MetricFlowQuerySpec(dimension_specs=dimension_specs)
)

convert_and_check(
request=request,
mf_test_session_state=mf_test_session_state,
dataflow_to_sql_converter=dataflow_to_sql_converter,
sql_client=sql_client,
node=dataflow_plan.sink_output_nodes[0].parent_node,
)