Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

JoinToTimeSpineNode bug fix #1541

Merged
merged 6 commits into from
Dec 9, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions .changes/unreleased/Fixes-20241121-073923.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
kind: Fixes
body: Always treat metric_time and the agg_time_dimension the same in the JoinToTimeSpineNode.
time: 2024-11-21T07:39:23.698194-08:00
custom:
Author: courtneyholcomb
Issue: "1541"
6 changes: 6 additions & 0 deletions metricflow-semantics/metricflow_semantics/instances.py
Original file line number Diff line number Diff line change
Expand Up @@ -145,6 +145,12 @@ def with_entity_prefix(
spec=transformed_spec,
)

def with_new_defined_from(self, defined_from: Sequence[SemanticModelElementReference]) -> TimeDimensionInstance:
    """Build a copy of this instance whose `defined_from` references are swapped for the given ones.

    The associated columns and spec are carried over unchanged; `defined_from` is normalized to a tuple.
    """
    replacement_defined_from = tuple(defined_from)
    return TimeDimensionInstance(
        associated_columns=self.associated_columns,
        defined_from=replacement_defined_from,
        spec=self.spec,
    )


@dataclass(frozen=True)
class EntityInstance(LinkableInstance[EntitySpec], SemanticModelElementInstance): # noqa: D101
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -244,3 +244,16 @@ def with_entity_prefix(self, entity_prefix: EntityReference) -> TimeDimensionSpe
date_part=self.date_part,
aggregation_state=self.aggregation_state,
)

@staticmethod
def with_base_grains(time_dimension_specs: Sequence[TimeDimensionSpec]) -> Sequence[TimeDimensionSpec]:
    """Return the list of time dimension specs, replacing any custom grains with base grains.

    Dedupes new specs, but preserves the initial order.
    """
    deduped_specs: List[TimeDimensionSpec] = []
    for base_spec in (spec.with_base_grain() for spec in time_dimension_specs):
        if base_spec in deduped_specs:
            continue
        deduped_specs.append(base_spec)
    return deduped_specs
31 changes: 22 additions & 9 deletions metricflow/dataflow/builder/dataflow_plan_builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
MetricTimeWindow,
MetricType,
)
from dbt_semantic_interfaces.references import MetricReference, TimeDimensionReference
from dbt_semantic_interfaces.references import MeasureReference, MetricReference, TimeDimensionReference
from dbt_semantic_interfaces.type_enums.time_granularity import TimeGranularity
from dbt_semantic_interfaces.validations.unique_valid_name import MetricFlowReservedKeywords
from metricflow_semantics.dag.id_prefix import StaticIdPrefix
Expand Down Expand Up @@ -653,9 +653,6 @@ def _build_derived_metric_output_node(
queried_agg_time_dimension_specs = queried_linkable_specs.included_agg_time_dimension_specs_for_metric(
metric_reference=metric_spec.reference, metric_lookup=self._metric_lookup
)
assert (
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This error is handled in the __post_init__ for JoinToTimeSpineNode, so I removed the duplication.

queried_agg_time_dimension_specs
), "Joining to time spine requires querying with metric_time or the appropriate agg_time_dimension."
output_node = JoinToTimeSpineNode.create(
parent_node=output_node,
requested_agg_time_dimension_specs=queried_agg_time_dimension_specs,
Expand Down Expand Up @@ -684,6 +681,19 @@ def _build_derived_metric_output_node(
)
return output_node

def _get_base_agg_time_dimensions(
    self, queried_linkable_specs: LinkableSpecSet, measure_reference: MeasureReference
) -> Sequence[TimeDimensionSpec]:
    """Get queried agg_time_dimensions with their base grains, deduped.

    Custom grains are joined right before measure aggregation and after all other pre-aggregation joins,
    so only base grains are needed prior to that point.
    """
    # Look up the agg_time_dimension specs requested for this measure, then collapse any
    # custom grains down to their base grains (deduping while preserving order).
    queried_agg_time_dimension_specs = queried_linkable_specs.included_agg_time_dimension_specs_for_measure(
        measure_reference=measure_reference, semantic_model_lookup=self._semantic_model_lookup
    )
    return TimeDimensionSpec.with_base_grains(queried_agg_time_dimension_specs)

def _build_any_metric_output_node(self, parameter_set: BuildAnyMetricOutputNodeParameterSet) -> DataflowPlanNode:
"""Builds a node to compute a metric of any type."""
result = self._cache.get_build_any_metric_output_node_result(parameter_set)
Expand Down Expand Up @@ -1603,17 +1613,17 @@ def _build_aggregated_measure_from_measure_source_node(
f"Unable to join all items in request. Measure: {measure_spec.element_name}; Specs to join: {required_linkable_specs}"
)

queried_agg_time_dimension_specs = queried_linkable_specs.included_agg_time_dimension_specs_for_measure(
measure_reference=measure_spec.reference, semantic_model_lookup=self._semantic_model_lookup
base_agg_time_dimension_specs = self._get_base_agg_time_dimensions(
queried_linkable_specs=queried_linkable_specs, measure_reference=measure_spec.reference
)

# If a cumulative metric is queried with metric_time / agg_time_dimension, join over time range.
# Otherwise, the measure will be aggregated over all time.
unaggregated_measure_node: DataflowPlanNode = measure_recipe.source_node
if cumulative and queried_agg_time_dimension_specs:
if cumulative and base_agg_time_dimension_specs:
unaggregated_measure_node = JoinOverTimeRangeNode.create(
parent_node=unaggregated_measure_node,
queried_agg_time_dimension_specs=tuple(queried_agg_time_dimension_specs),
queried_agg_time_dimension_specs=tuple(base_agg_time_dimension_specs),
window=cumulative_window,
grain_to_date=cumulative_grain_to_date,
# Note: we use the original constraint here because the JoinOverTimeRangeNode will eventually get
Expand All @@ -1635,7 +1645,7 @@ def _build_aggregated_measure_from_measure_source_node(
# in join rendering
unaggregated_measure_node = JoinToTimeSpineNode.create(
parent_node=unaggregated_measure_node,
requested_agg_time_dimension_specs=queried_agg_time_dimension_specs,
requested_agg_time_dimension_specs=base_agg_time_dimension_specs,
use_custom_agg_time_dimension=not queried_linkable_specs.contains_metric_time,
time_range_constraint=predicate_pushdown_state.time_range_constraint,
offset_window=before_aggregation_time_spine_join_description.offset_window,
Expand Down Expand Up @@ -1705,6 +1715,9 @@ def _build_aggregated_measure_from_measure_source_node(

# TODO: split this node into TimeSpineSourceNode and JoinToTimeSpineNode - then can use standard nodes here
# like JoinToCustomGranularityNode, WhereConstraintNode, etc.
queried_agg_time_dimension_specs = queried_linkable_specs.included_agg_time_dimension_specs_for_measure(
measure_reference=measure_spec.reference, semantic_model_lookup=self._semantic_model_lookup
)
output_node: DataflowPlanNode = JoinToTimeSpineNode.create(
parent_node=aggregate_measures_node,
requested_agg_time_dimension_specs=queried_agg_time_dimension_specs,
Expand Down
83 changes: 32 additions & 51 deletions metricflow/plan_conversion/dataflow_to_sql.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,9 +6,8 @@
from typing import Callable, Dict, FrozenSet, List, Optional, Sequence, Set, Tuple, TypeVar, Union

from dbt_semantic_interfaces.enum_extension import assert_values_exhausted
from dbt_semantic_interfaces.naming.keywords import METRIC_TIME_ELEMENT_NAME
from dbt_semantic_interfaces.protocols.metric import MetricInputMeasure, MetricType
from dbt_semantic_interfaces.references import EntityReference, MetricModelReference, SemanticModelElementReference
from dbt_semantic_interfaces.references import MetricModelReference, SemanticModelElementReference
from dbt_semantic_interfaces.type_enums.aggregation_type import AggregationType
from dbt_semantic_interfaces.type_enums.conversion_calculation_type import ConversionCalculationType
from dbt_semantic_interfaces.type_enums.period_agg import PeriodAggregation
Expand Down Expand Up @@ -471,11 +470,10 @@ def visit_join_over_time_range_node(self, node: JoinOverTimeRangeNode) -> SqlDat
parent_data_set = node.parent_node.accept(self)
parent_data_set_alias = self._next_unique_table_alias()

# For the purposes of this node, use base grains. Custom grains will be joined later in the dataflow plan.
agg_time_dimension_specs = tuple({spec.with_base_grain() for spec in node.queried_agg_time_dimension_specs})

# Assemble time_spine dataset with a column for each agg_time_dimension requested.
agg_time_dimension_instances = parent_data_set.instances_for_time_dimensions(agg_time_dimension_specs)
agg_time_dimension_instances = parent_data_set.instances_for_time_dimensions(
node.queried_agg_time_dimension_specs
)
time_spine_data_set_alias = self._next_unique_table_alias()
time_spine_data_set = self._make_time_spine_data_set(
agg_time_dimension_instances=agg_time_dimension_instances, time_range_constraint=node.time_range_constraint
Expand All @@ -492,7 +490,7 @@ def visit_join_over_time_range_node(self, node: JoinOverTimeRangeNode) -> SqlDat
# Build select columns, replacing agg_time_dimensions from the parent node with columns from the time spine.
table_alias_to_instance_set[time_spine_data_set_alias] = time_spine_data_set.instance_set
table_alias_to_instance_set[parent_data_set_alias] = parent_data_set.instance_set.transform(
FilterElements(exclude_specs=InstanceSpecSet(time_dimension_specs=agg_time_dimension_specs))
FilterElements(exclude_specs=InstanceSpecSet(time_dimension_specs=node.queried_agg_time_dimension_specs))
)
select_columns = create_simple_select_columns_for_instance_sets(
column_resolver=self._column_association_resolver, table_alias_to_instance_set=table_alias_to_instance_set
Expand Down Expand Up @@ -1382,33 +1380,30 @@ def visit_join_to_time_spine_node(self, node: JoinToTimeSpineNode) -> SqlDataSet
parent_data_set = node.parent_node.accept(self)
parent_alias = self._next_unique_table_alias()

if node.use_custom_agg_time_dimension:
agg_time_dimension = node.requested_agg_time_dimension_specs[0]
agg_time_element_name = agg_time_dimension.element_name
agg_time_entity_links: Tuple[EntityReference, ...] = agg_time_dimension.entity_links
else:
agg_time_element_name = METRIC_TIME_ELEMENT_NAME
agg_time_entity_links = ()
agg_time_dimension_instances = parent_data_set.instances_for_time_dimensions(
node.requested_agg_time_dimension_specs
)

# Find the time dimension instances in the parent data set that match the one we want to join with.
agg_time_dimension_instances: List[TimeDimensionInstance] = []
for instance in parent_data_set.instance_set.time_dimension_instances:
if (
instance.spec.date_part is None # Ensure we don't join using an instance with date part
and instance.spec.element_name == agg_time_element_name
and instance.spec.entity_links == agg_time_entity_links
):
agg_time_dimension_instances.append(instance)
# Select the dimension for the join from the parent node because it may not have been included in the request.
# Default to using metric_time for the join if it was requested, otherwise use the agg_time_dimension.
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is the one place we still have some differentiation between metric_time and agg_time. If you request both, we have to choose one of them to use as the join column, and we default to using metric_time.

included_metric_time_instances = [
instance for instance in agg_time_dimension_instances if instance.spec.is_metric_time
]
if included_metric_time_instances:
join_on_time_dimension_sample = included_metric_time_instances[0].spec
else:
join_on_time_dimension_sample = agg_time_dimension_instances[0].spec

# Choose the instance with the smallest base granularity available.
agg_time_dimension_instances.sort(key=lambda instance: instance.spec.time_granularity.base_granularity.to_int())
assert len(agg_time_dimension_instances) > 0, (
"Couldn't find requested agg_time_dimension in parent data set. The dataflow plan may have been "
"configured incorrectly."
agg_time_dimension_instance_for_join = self._choose_instance_for_time_spine_join(
[
instance
for instance in parent_data_set.instance_set.time_dimension_instances
if instance.spec.element_name == join_on_time_dimension_sample.element_name
and instance.spec.entity_links == join_on_time_dimension_sample.entity_links
]
)
agg_time_dimension_instance_for_join = agg_time_dimension_instances[0]

# Build time spine data set using the requested agg_time_dimension name.
# Build time spine data set with just the agg_time_dimension instance needed for the join.
time_spine_alias = self._next_unique_table_alias()
time_spine_dataset = self._make_time_spine_data_set(
agg_time_dimension_instances=(agg_time_dimension_instance_for_join,),
Expand All @@ -1432,24 +1427,14 @@ def visit_join_to_time_spine_node(self, node: JoinToTimeSpineNode) -> SqlDataSet
time_dimensions_to_select_from_parent: Tuple[TimeDimensionInstance, ...] = ()
time_dimensions_to_select_from_time_spine: Tuple[TimeDimensionInstance, ...] = ()
for time_dimension_instance in parent_data_set.instance_set.time_dimension_instances:
if (
time_dimension_instance.spec.element_name == agg_time_element_name
and time_dimension_instance.spec.entity_links == agg_time_entity_links
):
if time_dimension_instance in agg_time_dimension_instances:
time_dimensions_to_select_from_time_spine += (time_dimension_instance,)
else:
time_dimensions_to_select_from_parent += (time_dimension_instance,)
parent_instance_set = InstanceSet(
measure_instances=parent_data_set.instance_set.measure_instances,
dimension_instances=parent_data_set.instance_set.dimension_instances,
time_dimension_instances=tuple(
time_dimension_instance
for time_dimension_instance in parent_data_set.instance_set.time_dimension_instances
if not (
time_dimension_instance.spec.element_name == agg_time_element_name
and time_dimension_instance.spec.entity_links == agg_time_entity_links
)
),
time_dimension_instances=time_dimensions_to_select_from_parent,
entity_instances=parent_data_set.instance_set.entity_instances,
metric_instances=parent_data_set.instance_set.metric_instances,
metadata_instances=parent_data_set.instance_set.metadata_instances,
Expand Down Expand Up @@ -1481,8 +1466,8 @@ def visit_join_to_time_spine_node(self, node: JoinToTimeSpineNode) -> SqlDataSet
)

# Add requested granularities (if different from time_spine) and date_parts to time spine column.
for time_dimension_instance in time_dimensions_to_select_from_time_spine:
time_dimension_spec = time_dimension_instance.spec
for parent_time_dimension_instance in time_dimensions_to_select_from_time_spine:
time_dimension_spec = parent_time_dimension_instance.spec
if (
time_dimension_spec.time_granularity.base_granularity.to_int()
< original_time_spine_dim_instance.spec.time_granularity.base_granularity.to_int()
Expand Down Expand Up @@ -1521,13 +1506,9 @@ def visit_join_to_time_spine_node(self, node: JoinToTimeSpineNode) -> SqlDataSet
# Apply date_part to time spine column select expression.
if time_dimension_spec.date_part:
select_expr = SqlExtractExpression.create(date_part=time_dimension_spec.date_part, arg=select_expr)
time_dim_spec = original_time_spine_dim_instance.spec.with_grain_and_date_part(
time_granularity=time_dimension_spec.time_granularity, date_part=time_dimension_spec.date_part
)
time_spine_dim_instance = TimeDimensionInstance(
defined_from=original_time_spine_dim_instance.defined_from,
associated_columns=(self._column_association_resolver.resolve_spec(time_dim_spec),),
spec=time_dim_spec,

time_spine_dim_instance = parent_time_dimension_instance.with_new_defined_from(
original_time_spine_dim_instance.defined_from
)
time_spine_dim_instances.append(time_spine_dim_instance)
time_spine_select_columns.append(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,34 +8,34 @@ sql_engine: BigQuery
---
-- Re-aggregate Metric via Group By
SELECT
subq_11.metric_time__week
, subq_11.booking__ds__month
subq_11.booking__ds__month
, subq_11.metric_time__week
, subq_11.every_two_days_bookers_fill_nulls_with_0
FROM (
-- Window Function for Metric Re-aggregation
SELECT
subq_10.metric_time__week
, subq_10.booking__ds__month
subq_10.booking__ds__month
, subq_10.metric_time__week
, FIRST_VALUE(subq_10.every_two_days_bookers_fill_nulls_with_0) OVER (
PARTITION BY
subq_10.metric_time__week
, subq_10.booking__ds__month
subq_10.booking__ds__month
, subq_10.metric_time__week
ORDER BY subq_10.metric_time__day
ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING
) AS every_two_days_bookers_fill_nulls_with_0
FROM (
-- Compute Metrics via Expressions
SELECT
subq_9.metric_time__day
subq_9.booking__ds__month
, subq_9.metric_time__day
, subq_9.metric_time__week
, subq_9.booking__ds__month
, COALESCE(subq_9.bookers, 0) AS every_two_days_bookers_fill_nulls_with_0
FROM (
-- Join to Time Spine Dataset
SELECT
subq_7.metric_time__day AS metric_time__day
DATETIME_TRUNC(subq_7.metric_time__day, month) AS booking__ds__month
, subq_7.metric_time__day AS metric_time__day
, DATETIME_TRUNC(subq_7.metric_time__day, isoweek) AS metric_time__week
, subq_6.booking__ds__month AS booking__ds__month
, subq_6.bookers AS bookers
FROM (
-- Time Spine
Expand Down Expand Up @@ -380,6 +380,6 @@ FROM (
) subq_10
) subq_11
GROUP BY
metric_time__week
, booking__ds__month
booking__ds__month
, metric_time__week
, every_two_days_bookers_fill_nulls_with_0
Original file line number Diff line number Diff line change
Expand Up @@ -8,28 +8,28 @@ sql_engine: BigQuery
---
-- Re-aggregate Metric via Group By
SELECT
metric_time__week
, booking__ds__month
booking__ds__month
, metric_time__week
, every_two_days_bookers_fill_nulls_with_0
FROM (
-- Compute Metrics via Expressions
-- Window Function for Metric Re-aggregation
SELECT
metric_time__week
, booking__ds__month
booking__ds__month
, metric_time__week
, FIRST_VALUE(COALESCE(bookers, 0)) OVER (
PARTITION BY
metric_time__week
, booking__ds__month
booking__ds__month
, metric_time__week
ORDER BY metric_time__day
ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING
) AS every_two_days_bookers_fill_nulls_with_0
FROM (
-- Join to Time Spine Dataset
SELECT
subq_20.ds AS metric_time__day
DATETIME_TRUNC(subq_20.ds, month) AS booking__ds__month
, subq_20.ds AS metric_time__day
, DATETIME_TRUNC(subq_20.ds, isoweek) AS metric_time__week
, subq_18.booking__ds__month AS booking__ds__month
, subq_18.bookers AS bookers
FROM ***************************.mf_time_spine subq_20
LEFT OUTER JOIN (
Expand Down Expand Up @@ -60,6 +60,6 @@ FROM (
) subq_21
) subq_23
GROUP BY
metric_time__week
, booking__ds__month
booking__ds__month
, metric_time__week
, every_two_days_bookers_fill_nulls_with_0
Loading
Loading