Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Simplify dataflow to SQL logic for JoinOverTimeRangeNode #1540

Merged
merged 1 commit into from
Dec 9, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 3 additions & 3 deletions metricflow/dataflow/nodes/join_over_time.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
from __future__ import annotations

from dataclasses import dataclass
from typing import Optional, Sequence
from typing import Optional, Sequence, Tuple

from dbt_semantic_interfaces.protocols import MetricTimeWindow
from dbt_semantic_interfaces.type_enums import TimeGranularity
Expand All @@ -26,7 +26,7 @@ class JoinOverTimeRangeNode(DataflowPlanNode):
time_range_constraint: Time range to aggregate over.
"""

queried_agg_time_dimension_specs: Sequence[TimeDimensionSpec]
queried_agg_time_dimension_specs: Tuple[TimeDimensionSpec, ...]
window: Optional[MetricTimeWindow]
grain_to_date: Optional[TimeGranularity]
time_range_constraint: Optional[TimeRangeConstraint]
Expand All @@ -38,7 +38,7 @@ def __post_init__(self) -> None: # noqa: D105
@staticmethod
def create( # noqa: D102
parent_node: DataflowPlanNode,
queried_agg_time_dimension_specs: Sequence[TimeDimensionSpec],
queried_agg_time_dimension_specs: Tuple[TimeDimensionSpec, ...],
window: Optional[MetricTimeWindow] = None,
grain_to_date: Optional[TimeGranularity] = None,
time_range_constraint: Optional[TimeRangeConstraint] = None,
Expand Down
23 changes: 23 additions & 0 deletions metricflow/dataset/sql_dataset.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from __future__ import annotations

from dataclasses import dataclass
from typing import List, Optional, Sequence, Tuple

from dbt_semantic_interfaces.references import SemanticModelReference
Expand Down Expand Up @@ -160,3 +161,25 @@ def column_association_for_time_dimension(self, time_dimension_spec: TimeDimensi
@override
def semantic_model_reference(self) -> Optional[SemanticModelReference]:
return None

def annotate(self, alias: str, metric_time_spec: TimeDimensionSpec) -> AnnotatedSqlDataSet:
"""Convert to an AnnotatedSqlDataSet with specified metadata."""
metric_time_column_name = self.column_association_for_time_dimension(metric_time_spec).column_name
return AnnotatedSqlDataSet(data_set=self, alias=alias, _metric_time_column_name=metric_time_column_name)


@dataclass(frozen=True)
class AnnotatedSqlDataSet:
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This class should be unchanged. I needed to move it to resolve circular imports.

"""Class to bind a DataSet to transient properties associated with it at a given point in the SqlQueryPlan."""

data_set: SqlDataSet
alias: str
_metric_time_column_name: Optional[str] = None

@property
def metric_time_column_name(self) -> str:
"""Direct accessor for the optional metric time name, only safe to call when we know that value is set."""
assert (
self._metric_time_column_name
), "Expected a valid metric time dimension name to be associated with this dataset, but did not get one!"
return self._metric_time_column_name
84 changes: 35 additions & 49 deletions metricflow/plan_conversion/dataflow_to_sql.py
Original file line number Diff line number Diff line change
Expand Up @@ -468,70 +468,41 @@ def visit_source_node(self, node: ReadSqlSourceNode) -> SqlDataSet:
def visit_join_over_time_range_node(self, node: JoinOverTimeRangeNode) -> SqlDataSet:
"""Generate time range join SQL."""
table_alias_to_instance_set: OrderedDict[str, InstanceSet] = OrderedDict()
input_data_set = node.parent_node.accept(self)
input_data_set_alias = self._next_unique_table_alias()
parent_data_set = node.parent_node.accept(self)
parent_data_set_alias = self._next_unique_table_alias()

# Find requested agg_time_dimensions in parent instance set.
# Will use instance with the smallest base granularity in time spine join.
agg_time_dimension_instance_for_join: Optional[TimeDimensionInstance] = None
requested_agg_time_dimension_instances: Tuple[TimeDimensionInstance, ...] = ()
for instance in input_data_set.instance_set.time_dimension_instances:
if instance.spec in node.queried_agg_time_dimension_specs:
requested_agg_time_dimension_instances += (instance,)
if not agg_time_dimension_instance_for_join or (
instance.spec.time_granularity.base_granularity.to_int()
< agg_time_dimension_instance_for_join.spec.time_granularity.base_granularity.to_int()
):
agg_time_dimension_instance_for_join = instance
assert (
agg_time_dimension_instance_for_join
), "Specified metric time spec not found in parent data set. This should have been caught by validations."
# For the purposes of this node, use base grains. Custom grains will be joined later in the dataflow plan.
agg_time_dimension_specs = tuple({spec.with_base_grain() for spec in node.queried_agg_time_dimension_specs})

# Assemble time_spine dataset with a column for each agg_time_dimension requested.
agg_time_dimension_instances = parent_data_set.instances_for_time_dimensions(agg_time_dimension_specs)
time_spine_data_set_alias = self._next_unique_table_alias()

# Assemble time_spine dataset with requested agg time dimension instances selected.
time_spine_data_set = self._make_time_spine_data_set(
agg_time_dimension_instances=requested_agg_time_dimension_instances,
time_range_constraint=node.time_range_constraint,
agg_time_dimension_instances=agg_time_dimension_instances, time_range_constraint=node.time_range_constraint
)
table_alias_to_instance_set[time_spine_data_set_alias] = time_spine_data_set.instance_set

# Build the join description.
join_spec = self._choose_instance_for_time_spine_join(agg_time_dimension_instances).spec
annotated_parent = parent_data_set.annotate(alias=parent_data_set_alias, metric_time_spec=join_spec)
annotated_time_spine = time_spine_data_set.annotate(alias=time_spine_data_set_alias, metric_time_spec=join_spec)
join_desc = SqlQueryPlanJoinBuilder.make_cumulative_metric_time_range_join_description(
node=node,
metric_data_set=AnnotatedSqlDataSet(
data_set=input_data_set,
alias=input_data_set_alias,
_metric_time_column_name=input_data_set.column_association_for_time_dimension(
agg_time_dimension_instance_for_join.spec
).column_name,
),
time_spine_data_set=AnnotatedSqlDataSet(
data_set=time_spine_data_set,
alias=time_spine_data_set_alias,
_metric_time_column_name=time_spine_data_set.column_association_for_time_dimension(
agg_time_dimension_instance_for_join.spec
).column_name,
),
node=node, metric_data_set=annotated_parent, time_spine_data_set=annotated_time_spine
)

# Remove instances of agg_time_dimension from input data set. They'll be replaced with time spine instances.
agg_time_dimension_specs = tuple(dim.spec for dim in requested_agg_time_dimension_instances)
modified_input_instance_set = input_data_set.instance_set.transform(
# Build select columns, replacing agg_time_dimensions from the parent node with columns from the time spine.
table_alias_to_instance_set[time_spine_data_set_alias] = time_spine_data_set.instance_set
table_alias_to_instance_set[parent_data_set_alias] = parent_data_set.instance_set.transform(
FilterElements(exclude_specs=InstanceSpecSet(time_dimension_specs=agg_time_dimension_specs))
)
table_alias_to_instance_set[input_data_set_alias] = modified_input_instance_set

# The output instances are the same as the input instances.
output_instance_set = ChangeAssociatedColumns(self._column_association_resolver).transform(
input_data_set.instance_set
select_columns = create_simple_select_columns_for_instance_sets(
column_resolver=self._column_association_resolver, table_alias_to_instance_set=table_alias_to_instance_set
)

return SqlDataSet(
instance_set=output_instance_set,
instance_set=parent_data_set.instance_set, # The output instances are the same as the input instances.
sql_select_node=SqlSelectStatementNode.create(
description=node.description,
select_columns=create_simple_select_columns_for_instance_sets(
self._column_association_resolver, table_alias_to_instance_set
),
select_columns=select_columns,
from_source=time_spine_data_set.checked_sql_select_node,
from_source_alias=time_spine_data_set_alias,
join_descs=(join_desc,),
Expand Down Expand Up @@ -1392,6 +1363,21 @@ def visit_semi_additive_join_node(self, node: SemiAdditiveJoinNode) -> SqlDataSe
),
)

def _choose_instance_for_time_spine_join(
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This helper function is only used once at this point, but it will be used again for the JoinToTimeSpineNode farther up the stack.

self, agg_time_dimension_instances: Sequence[TimeDimensionInstance]
) -> TimeDimensionInstance:
"""Find the agg_time_dimension instance with the smallest grain to use for the time spine join."""
# We can't use a date part spec to join to the time spine, so filter those out.
agg_time_dimension_instances = [
instance for instance in agg_time_dimension_instances if not instance.spec.date_part
]
assert len(agg_time_dimension_instances) > 0, (
"No appropriate agg_time_dimension was found to join to the time spine. "
"This indicates that the dataflow plan was configured incorrectly."
)
agg_time_dimension_instances.sort(key=lambda instance: instance.spec.time_granularity.base_granularity.to_int())
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Not related to this PR, but we probably should just make the TimeGranularity enum orderable in DSI at some point since this operation has come up a few times.

return agg_time_dimension_instances[0]

def visit_join_to_time_spine_node(self, node: JoinToTimeSpineNode) -> SqlDataSet: # noqa: D102
parent_data_set = node.parent_node.accept(self)
parent_alias = self._next_unique_table_alias()
Expand Down
1 change: 1 addition & 0 deletions metricflow/plan_conversion/instance_converters.py
Original file line number Diff line number Diff line change
Expand Up @@ -803,6 +803,7 @@ def transform(self, instance_set: InstanceSet) -> SelectColumnSet: # noqa: D102
)


# TODO: delete this class & all uses. It doesn't do anything.
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This class is supposed to change the column names, but if the specs didn't change, then the column names shouldn't either, so it seems like it doesn't do anything. LMK if I'm overlooking something here!

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This used to be needed but maybe things have changed. There were cases where nodes did not output data sets where the column name in the generated SQL did not match the defined format. Trying to remember what that was though, but there was a bug fix that required this transform.

class ChangeAssociatedColumns(InstanceSetTransform[InstanceSet]):
"""Change the columns associated with instances to the one specified by the resolver.
Expand Down
19 changes: 1 addition & 18 deletions metricflow/plan_conversion/sql_join_builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
from metricflow.dataflow.nodes.join_over_time import JoinOverTimeRangeNode
from metricflow.dataflow.nodes.join_to_base import JoinDescription
from metricflow.dataflow.nodes.join_to_time_spine import JoinToTimeSpineNode
from metricflow.dataset.sql_dataset import SqlDataSet
from metricflow.dataset.sql_dataset import AnnotatedSqlDataSet
from metricflow.plan_conversion.sql_expression_builders import make_coalesced_expr
from metricflow.sql.sql_exprs import (
SqlColumnReference,
Expand Down Expand Up @@ -46,23 +46,6 @@ class ColumnEqualityDescription:
treat_nulls_as_equal: bool = False


@dataclass(frozen=True)
class AnnotatedSqlDataSet:
"""Class to bind a DataSet to transient properties associated with it at a given point in the SqlQueryPlan."""

data_set: SqlDataSet
alias: str
_metric_time_column_name: Optional[str] = None

@property
def metric_time_column_name(self) -> str:
"""Direct accessor for the optional metric time name, only safe to call when we know that value is set."""
assert (
self._metric_time_column_name
), "Expected a valid metric time dimension name to be associated with this dataset, but did not get one!"
return self._metric_time_column_name


class SqlQueryPlanJoinBuilder:
"""Helper class for constructing various join components in a SqlQueryPlan."""

Expand Down
Loading