Skip to content

Commit

Permalink
WIP
Browse files Browse the repository at this point in the history
  • Loading branch information
courtneyholcomb committed Nov 8, 2024
1 parent 35fef1a commit 4ffe95b
Show file tree
Hide file tree
Showing 2 changed files with 70 additions and 128 deletions.
9 changes: 7 additions & 2 deletions metricflow/dataflow/builder/dataflow_plan_builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -653,7 +653,9 @@ def _build_derived_metric_output_node(
), "Joining to time spine requires querying with metric_time or the appropriate agg_time_dimension."
output_node = JoinToTimeSpineNode.create(
parent_node=output_node,
requested_agg_time_dimension_specs=queried_agg_time_dimension_specs,
requested_agg_time_dimension_specs=[
spec.with_base_grain() for spec in queried_agg_time_dimension_specs
],
use_custom_agg_time_dimension=not queried_linkable_specs.contains_metric_time,
time_range_constraint=predicate_pushdown_state.time_range_constraint,
offset_window=metric_spec.offset_window,
Expand Down Expand Up @@ -1618,7 +1620,9 @@ def _build_aggregated_measure_from_measure_source_node(
# in join rendering
unaggregated_measure_node = JoinToTimeSpineNode.create(
parent_node=unaggregated_measure_node,
requested_agg_time_dimension_specs=queried_agg_time_dimension_specs,
requested_agg_time_dimension_specs=[
spec.with_base_grain() for spec in queried_agg_time_dimension_specs
],
use_custom_agg_time_dimension=not queried_linkable_specs.contains_metric_time,
time_range_constraint=predicate_pushdown_state.time_range_constraint,
offset_window=before_aggregation_time_spine_join_description.offset_window,
Expand Down Expand Up @@ -1688,6 +1692,7 @@ def _build_aggregated_measure_from_measure_source_node(

# TODO: split this node into TimeSpineSourceNode and JoinToTimeSpineNode - then can use standard nodes here
# like JoinToCustomGranularityNode, WhereConstraintNode, etc.
print("queried_agg_time_dimension_specs:", queried_agg_time_dimension_specs)
output_node: DataflowPlanNode = JoinToTimeSpineNode.create(
parent_node=aggregate_measures_node,
requested_agg_time_dimension_specs=queried_agg_time_dimension_specs,
Expand Down
189 changes: 63 additions & 126 deletions metricflow/plan_conversion/dataflow_to_sql.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,12 +3,11 @@
import datetime as dt
import logging
from collections import OrderedDict
from typing import List, Optional, Sequence, Set, Tuple, Union
from typing import List, Optional, Sequence, Set, Tuple

from dbt_semantic_interfaces.enum_extension import assert_values_exhausted
from dbt_semantic_interfaces.naming.keywords import METRIC_TIME_ELEMENT_NAME
from dbt_semantic_interfaces.protocols.metric import MetricInputMeasure, MetricType
from dbt_semantic_interfaces.references import EntityReference, MetricModelReference, SemanticModelElementReference
from dbt_semantic_interfaces.references import MetricModelReference, SemanticModelElementReference
from dbt_semantic_interfaces.type_enums.aggregation_type import AggregationType
from dbt_semantic_interfaces.type_enums.conversion_calculation_type import ConversionCalculationType
from dbt_semantic_interfaces.type_enums.period_agg import PeriodAggregation
Expand Down Expand Up @@ -278,29 +277,36 @@ def _make_time_spine_data_set(
)
select_columns: Tuple[SqlSelectColumn, ...] = ()
apply_group_by = True
for agg_time_dimension_spec in required_time_spine_specs:
column_alias = self.column_association_resolver.resolve_spec(agg_time_dimension_spec).column_name
for time_spine_spec in required_time_spine_specs:
if time_spine_spec.time_granularity.base_granularity.to_int() < time_spine_source.base_granularity.to_int():
raise RuntimeError(
f"Can't join to time spine for a time dimension with a smaller granularity than that of the time "
f"spine column. Got {time_spine_spec.time_granularity} for time dimension, "
f"{time_spine_source.base_granularity} for time spine."
)
column_alias = self.column_association_resolver.resolve_spec(time_spine_spec).column_name
# If the requested granularity is the same as the granularity of the spine, do a direct select.
agg_time_grain = agg_time_dimension_spec.time_granularity
if (
agg_time_grain.base_granularity == time_spine_source.base_granularity
and not agg_time_grain.is_custom_granularity
):
expr: SqlExpressionNode = base_column_expr
apply_group_by = False
agg_time_grain = time_spine_spec.time_granularity
if time_spine_spec.date_part:
# For any requested date parts, apply an EXTRACT expression to the base column.
expr: SqlExpressionNode = SqlExtractExpression.create(
date_part=time_spine_spec.date_part, arg=base_column_expr
)
elif agg_time_grain.is_custom_granularity:
# If any dimensions require a custom granularity, select the appropriate column.
for custom_granularity in time_spine_source.custom_granularities:
expr = SqlColumnReferenceExpression.from_table_and_column_names(
table_alias=time_spine_table_alias, column_name=custom_granularity.parsed_column_name
)
elif agg_time_grain.base_granularity == time_spine_source.base_granularity:
expr = base_column_expr
apply_group_by = False
else:
# If any dimensions require a different standard granularity, apply a DATE_TRUNC() to the base column.
expr = SqlDateTruncExpression.create(
time_granularity=agg_time_grain.base_granularity, arg=base_column_expr
)
select_columns += (SqlSelectColumn(expr=expr, column_alias=column_alias),)
# TODO: also handle date part.

output_instance_set = InstanceSet(
time_dimension_instances=tuple(
Expand Down Expand Up @@ -1167,9 +1173,9 @@ def visit_metric_time_dimension_transform_node(self, node: MetricTimeDimensionTr
spec=metric_time_dimension_spec,
)
)
output_column_to_input_column[
metric_time_dimension_column_association.column_name
] = matching_time_dimension_instance.associated_column.column_name
output_column_to_input_column[metric_time_dimension_column_association.column_name] = (
matching_time_dimension_instance.associated_column.column_name
)

output_instance_set = InstanceSet(
measure_instances=tuple(output_measure_instances),
Expand Down Expand Up @@ -1318,36 +1324,26 @@ def visit_join_to_time_spine_node(self, node: JoinToTimeSpineNode) -> SqlDataSet
parent_data_set = node.parent_node.accept(self)
parent_alias = self._next_unique_table_alias()

if node.use_custom_agg_time_dimension:
agg_time_dimension = node.requested_agg_time_dimension_specs[0]
agg_time_element_name = agg_time_dimension.element_name
agg_time_entity_links: Tuple[EntityReference, ...] = agg_time_dimension.entity_links
else:
agg_time_element_name = METRIC_TIME_ELEMENT_NAME
agg_time_entity_links = ()

# Find the time dimension instances in the parent data set that match the one we want to join with.
agg_time_dimension_instances: List[TimeDimensionInstance] = []
for instance in parent_data_set.instance_set.time_dimension_instances:
if (
instance.spec.date_part is None # Ensure we don't join using an instance with date part
and instance.spec.element_name == agg_time_element_name
and instance.spec.entity_links == agg_time_entity_links
):
agg_time_dimension_instances.append(instance)

# Choose the instance with the smallest base granularity available.
agg_time_dimension_instances.sort(key=lambda instance: instance.spec.time_granularity.base_granularity.to_int())
assert len(agg_time_dimension_instances) > 0, (
"Couldn't find requested agg_time_dimension in parent data set. The dataflow plan may have been "
# Find the agg time dimension instance with the smallest base granularity available.
parent_agg_time_dimension_instances = tuple(
instance
for instance in parent_data_set.instance_set.time_dimension_instances
if instance.spec in node.requested_agg_time_dimension_specs
)
assert len(parent_agg_time_dimension_instances) > 0, (
"Couldn't find requested agg_time_dimensions in parent data set. The dataflow plan may have been "
"configured incorrectly."
)
agg_time_dimension_instance_for_join = agg_time_dimension_instances[0]
time_dimension_instance_for_join = sorted(
parent_agg_time_dimension_instances,
key=lambda instance: instance.spec.time_granularity.base_granularity.to_int(),
)[0]

# Build time spine data set using the requested agg_time_dimension name.
# Build time spine data set using the agg time dimension with the smallest granularity requested.
time_spine_alias = self._next_unique_table_alias()
print("parent_agg_time_dimension_instances: ", parent_agg_time_dimension_instances)
time_spine_dataset = self._make_time_spine_data_set(
agg_time_dimension_instances=(agg_time_dimension_instance_for_join,),
agg_time_dimension_instances=parent_agg_time_dimension_instances,
time_range_constraint=node.time_range_constraint,
time_spine_where_constraints=node.time_spine_filters or (),
)
Expand All @@ -1357,7 +1353,7 @@ def visit_join_to_time_spine_node(self, node: JoinToTimeSpineNode) -> SqlDataSet
node=node,
time_spine_alias=time_spine_alias,
agg_time_dimension_column_name=self.column_association_resolver.resolve_spec(
agg_time_dimension_instance_for_join.spec
time_dimension_instance_for_join.spec
).column_name,
parent_sql_select_node=parent_data_set.checked_sql_select_node,
parent_alias=parent_alias,
Expand All @@ -1368,24 +1364,14 @@ def visit_join_to_time_spine_node(self, node: JoinToTimeSpineNode) -> SqlDataSet
time_dimensions_to_select_from_parent: Tuple[TimeDimensionInstance, ...] = ()
time_dimensions_to_select_from_time_spine: Tuple[TimeDimensionInstance, ...] = ()
for time_dimension_instance in parent_data_set.instance_set.time_dimension_instances:
if (
time_dimension_instance.spec.element_name == agg_time_element_name
and time_dimension_instance.spec.entity_links == agg_time_entity_links
):
if time_dimension_instance in parent_agg_time_dimension_instances:
time_dimensions_to_select_from_time_spine += (time_dimension_instance,)
else:
time_dimensions_to_select_from_parent += (time_dimension_instance,)
parent_instance_set = InstanceSet(
measure_instances=parent_data_set.instance_set.measure_instances,
dimension_instances=parent_data_set.instance_set.dimension_instances,
time_dimension_instances=tuple(
time_dimension_instance
for time_dimension_instance in parent_data_set.instance_set.time_dimension_instances
if not (
time_dimension_instance.spec.element_name == agg_time_element_name
and time_dimension_instance.spec.entity_links == agg_time_entity_links
)
),
time_dimension_instances=time_dimensions_to_select_from_parent,
entity_instances=parent_data_set.instance_set.entity_instances,
metric_instances=parent_data_set.instance_set.metric_instances,
metadata_instances=parent_data_set.instance_set.metadata_instances,
Expand All @@ -1394,96 +1380,47 @@ def visit_join_to_time_spine_node(self, node: JoinToTimeSpineNode) -> SqlDataSet
self._column_association_resolver, OrderedDict({parent_alias: parent_instance_set})
)

# Select matching instance from time spine data set (using base grain - custom grain will be joined in a later node).
original_time_spine_dim_instance: Optional[TimeDimensionInstance] = None
for time_dimension_instance in time_spine_dataset.instance_set.time_dimension_instances:
if time_dimension_instance.spec == agg_time_dimension_instance_for_join.spec:
original_time_spine_dim_instance = time_dimension_instance
break
assert original_time_spine_dim_instance, (
"Couldn't find requested agg_time_dimension_instance_for_join in time spine data set, which "
f"indicates it may have been configured incorrectly. Expected: {agg_time_dimension_instance_for_join.spec};"
f" Got: {[instance.spec for instance in time_spine_dataset.instance_set.time_dimension_instances]}"
)
time_spine_column_select_expr: Union[
SqlColumnReferenceExpression, SqlDateTruncExpression
] = SqlColumnReferenceExpression.create(
SqlColumnReference(
table_alias=time_spine_alias, column_name=original_time_spine_dim_instance.spec.qualified_name
)
base_time_spine_column_expr = SqlColumnReferenceExpression.from_table_and_column_names(
table_alias=time_spine_alias, column_name=time_dimension_instance_for_join.associated_column.column_name
)

time_spine_select_columns = []
time_spine_dim_instances = []
where_filter: Optional[SqlExpressionNode] = None

# TODO: consolidate these comments
# If offset_to_grain is used, will need to filter down to rows that match selected granularities.
# Does not apply if one of the granularities selected matches the time spine column granularity.
# Does not apply if one of the granularities selected matches the time spine base granularity.
# Filter down to one row per granularity period requested in the group by. Any other granularities
# included here will be filtered out in later nodes so should not be included in where filter.
where_filter: Optional[SqlExpressionNode] = None
need_where_filter = (
node.offset_to_grain
and original_time_spine_dim_instance.spec not in node.requested_agg_time_dimension_specs
and time_dimension_instance_for_join.spec not in node.requested_agg_time_dimension_specs
)

# Add requested granularities (if different from time_spine) and date_parts to time spine column.
for time_dimension_instance in time_dimensions_to_select_from_time_spine:
time_dimension_spec = time_dimension_instance.spec
if (
time_dimension_spec.time_granularity.base_granularity.to_int()
< original_time_spine_dim_instance.spec.time_granularity.base_granularity.to_int()
):
raise RuntimeError(
f"Can't join to time spine for a time dimension with a smaller granularity than that of the time "
f"spine column. Got {time_dimension_spec.time_granularity} for time dimension, "
f"{original_time_spine_dim_instance.spec.time_granularity} for time spine."
)

# Apply grain to time spine select expression, unless grain already matches original time spine column.
should_skip_date_trunc = (
time_dimension_spec.time_granularity == original_time_spine_dim_instance.spec.time_granularity
or time_dimension_spec.time_granularity.is_custom_granularity
)
select_expr: SqlExpressionNode = (
time_spine_column_select_expr
if should_skip_date_trunc
else SqlDateTruncExpression.create(
time_granularity=time_dimension_spec.time_granularity.base_granularity,
arg=time_spine_column_select_expr,
)
time_spine_columns: Tuple[SqlSelectColumn, ...] = ()
for select_column in time_spine_dataset.checked_sql_select_node.select_columns:
time_spine_columns += (
SqlSelectColumn(
expr=SqlColumnReferenceExpression.from_table_and_column_names(
table_alias=time_spine_alias, column_name=select_column.column_alias
),
column_alias=select_column.column_alias,
),
)
# Filter down to one row per granularity period requested in the group by. Any other granularities
# included here will be filtered out in later nodes so should not be included in where filter.
if need_where_filter and time_dimension_spec in node.requested_agg_time_dimension_specs:
if need_where_filter:
left_expr = select_column.expr
new_where_filter = SqlComparisonExpression.create(
left_expr=select_expr, comparison=SqlComparison.EQUALS, right_expr=time_spine_column_select_expr
left_expr=left_expr, comparison=SqlComparison.EQUALS, right_expr=base_time_spine_column_expr
)
where_filter = (
SqlLogicalExpression.create(operator=SqlLogicalOperator.OR, args=(where_filter, new_where_filter))
if where_filter
else new_where_filter
)

# Apply date_part to time spine column select expression.
if time_dimension_spec.date_part:
select_expr = SqlExtractExpression.create(date_part=time_dimension_spec.date_part, arg=select_expr)
time_dim_spec = original_time_spine_dim_instance.spec.with_grain_and_date_part(
time_granularity=time_dimension_spec.time_granularity, date_part=time_dimension_spec.date_part
)
time_spine_dim_instance = TimeDimensionInstance(
defined_from=original_time_spine_dim_instance.defined_from,
associated_columns=(self._column_association_resolver.resolve_spec(time_dim_spec),),
spec=time_dim_spec,
)
time_spine_dim_instances.append(time_spine_dim_instance)
time_spine_select_columns.append(
SqlSelectColumn(expr=select_expr, column_alias=time_spine_dim_instance.associated_column.column_name)
)
time_spine_instance_set = InstanceSet(time_dimension_instances=tuple(time_spine_dim_instances))

print("parent_select_columns: ", parent_select_columns)
return SqlDataSet(
instance_set=InstanceSet.merge([time_spine_instance_set, parent_instance_set]),
instance_set=InstanceSet.merge([time_spine_dataset.instance_set, parent_instance_set]),
sql_select_node=SqlSelectStatementNode.create(
description=node.description,
select_columns=tuple(time_spine_select_columns) + parent_select_columns,
select_columns=time_spine_columns + parent_select_columns,
from_source=time_spine_dataset.checked_sql_select_node,
from_source_alias=time_spine_alias,
join_descs=(join_description,),
Expand Down

0 comments on commit 4ffe95b

Please sign in to comment.