diff --git a/metricflow/dataflow/builder/dataflow_plan_builder.py b/metricflow/dataflow/builder/dataflow_plan_builder.py
index 896dad428..87d6490bb 100644
--- a/metricflow/dataflow/builder/dataflow_plan_builder.py
+++ b/metricflow/dataflow/builder/dataflow_plan_builder.py
@@ -1,5 +1,6 @@
 from __future__ import annotations

+import copy
 import logging
 import time
 from typing import Dict, FrozenSet, List, Optional, Sequence, Set, Tuple, Union
@@ -92,6 +93,7 @@
 from metricflow.dataflow.nodes.order_by_limit import OrderByLimitNode
 from metricflow.dataflow.nodes.read_sql_source import ReadSqlSourceNode
 from metricflow.dataflow.nodes.semi_additive_join import SemiAdditiveJoinNode
+from metricflow.dataflow.nodes.transform_time_dimensions import TransformTimeDimensionsNode
 from metricflow.dataflow.nodes.where_filter import WhereConstraintNode
 from metricflow.dataflow.nodes.window_reaggregation_node import WindowReaggregationNode
 from metricflow.dataflow.nodes.write_to_data_table import WriteToResultDataTableNode
@@ -648,14 +650,19 @@ def _build_derived_metric_output_node(
             metric_reference=metric_spec.reference, metric_lookup=self._metric_lookup
         )
         if metric_spec.has_time_offset and queried_agg_time_dimension_specs:
+            # TODO: move this to a helper method
+            time_spine_node = self._build_time_spine_node(queried_agg_time_dimension_specs)
             output_node = JoinToTimeSpineNode.create(
                 parent_node=output_node,
+                time_spine_node=time_spine_node,
                 requested_agg_time_dimension_specs=queried_agg_time_dimension_specs,
+                join_on_time_dimension_spec=self._sort_by_base_granularity(queried_agg_time_dimension_specs)[0],
                 offset_window=metric_spec.offset_window,
                 offset_to_grain=metric_spec.offset_to_grain,
                 join_type=SqlJoinType.INNER,
             )

+        # TODO: fix bug here where filter specs are being included when aggregating.
         if len(metric_spec.filter_spec_set.all_filter_specs) > 0 or predicate_pushdown_state.time_range_constraint:
             # FilterElementsNode will only be needed if there are where filter specs that were selected in the group by.
             specs_in_filters = set(
@@ -1616,15 +1623,22 @@ def _build_aggregated_measure_from_measure_source_node(

         # If querying an offset metric, join to time spine before aggregation.
         if before_aggregation_time_spine_join_description and base_queried_agg_time_dimension_specs:
+            # TODO: move all of this to a helper function
             assert before_aggregation_time_spine_join_description.join_type is SqlJoinType.INNER, (
                 f"Expected {SqlJoinType.INNER} for joining to time spine before aggregation. Remove this if there's a "
                 f"new use case."
             )
-            # This also uses the original time range constraint due to the application of the time window intervals
-            # in join rendering
+
+            join_on_time_dimension_spec = self._determine_time_spine_join_spec(
+                measure_properties=measure_properties, required_time_spine_specs=base_queried_agg_time_dimension_specs
+            )
+            required_time_spine_specs = (join_on_time_dimension_spec,) + base_queried_agg_time_dimension_specs
+            time_spine_node = self._build_time_spine_node(required_time_spine_specs)
             unaggregated_measure_node = JoinToTimeSpineNode.create(
                 parent_node=unaggregated_measure_node,
+                time_spine_node=time_spine_node,
                 requested_agg_time_dimension_specs=base_queried_agg_time_dimension_specs,
+                join_on_time_dimension_spec=join_on_time_dimension_spec,
                 offset_window=before_aggregation_time_spine_join_description.offset_window,
                 offset_to_grain=before_aggregation_time_spine_join_description.offset_to_grain,
                 join_type=before_aggregation_time_spine_join_description.join_type,
@@ -1667,10 +1681,13 @@ def _build_aggregated_measure_from_measure_source_node(
             measure_reference=measure_spec.reference, semantic_model_lookup=self._semantic_model_lookup
         )
         if after_aggregation_time_spine_join_description and queried_agg_time_dimension_specs:
+            # TODO: move all of this to a helper function
             assert after_aggregation_time_spine_join_description.join_type is SqlJoinType.LEFT_OUTER, (
                 f"Expected {SqlJoinType.LEFT_OUTER} for joining to time spine after aggregation. Remove this if "
                 f"there's a new use case."
             )
+            time_spine_required_specs = copy.deepcopy(queried_agg_time_dimension_specs)
+
             # Find filters that contain only metric_time or agg_time_dimension. They will be applied to the time spine table.
             agg_time_only_filters: List[WhereFilterSpec] = []
             non_agg_time_filters: List[WhereFilterSpec] = []
@@ -1680,24 +1697,23 @@ def _build_aggregated_measure_from_measure_source_node(
                 )
                 if set(included_agg_time_specs) == set(filter_spec.linkable_spec_set.as_tuple):
                     agg_time_only_filters.append(filter_spec)
-                    if filter_spec.linkable_spec_set.time_dimension_specs_with_custom_grain:
-                        raise ValueError(
-                            "Using custom granularity in filters for `join_to_timespine` metrics is not yet fully supported. "
-                            "This feature is coming soon!"
-                        )
+                    for agg_time_spec in included_agg_time_specs:
+                        if agg_time_spec not in time_spine_required_specs:
+                            time_spine_required_specs.append(agg_time_spec)
                 else:
                     non_agg_time_filters.append(filter_spec)

-            # TODO: split this node into TimeSpineSourceNode and JoinToTimeSpineNode - then can use standard nodes here
-            # like JoinToCustomGranularityNode, WhereConstraintNode, etc.
+            time_spine_node = self._build_time_spine_node(
+                queried_time_spine_specs=queried_agg_time_dimension_specs,
+                time_range_constraint=predicate_pushdown_state.time_range_constraint,
+                where_filter_specs=agg_time_only_filters,
+            )
             output_node: DataflowPlanNode = JoinToTimeSpineNode.create(
                 parent_node=aggregate_measures_node,
+                time_spine_node=time_spine_node,
                 requested_agg_time_dimension_specs=queried_agg_time_dimension_specs,
+                join_on_time_dimension_spec=self._sort_by_base_granularity(queried_agg_time_dimension_specs)[0],
                 join_type=after_aggregation_time_spine_join_description.join_type,
-                time_range_constraint=predicate_pushdown_state.time_range_constraint,
-                offset_window=after_aggregation_time_spine_join_description.offset_window,
-                offset_to_grain=after_aggregation_time_spine_join_description.offset_to_grain,
-                time_spine_filters=agg_time_only_filters,
             )

             # Since new rows might have been added due to time spine join, re-apply constraints here. Only re-apply filters
@@ -1824,3 +1840,68 @@ def _choose_time_spine_metric_time_node(
     def _choose_time_spine_read_node(self, time_spine_source: TimeSpineSource) -> ReadSqlSourceNode:
         """Return the MetricTimeDimensionTransform time spine node needed to satisfy the specs."""
         return self._source_node_set.time_spine_read_nodes[time_spine_source.base_granularity]
+
+    def _build_time_spine_node(
+        self,
+        queried_time_spine_specs: Sequence[TimeDimensionSpec],
+        where_filter_specs: Sequence[WhereFilterSpec] = (),
+        time_range_constraint: Optional[TimeRangeConstraint] = None,
+    ) -> DataflowPlanNode:
+        """Return the time spine node needed to satisfy the specs."""
+        required_time_spine_spec_set = self.__get_required_linkable_specs(
+            queried_linkable_specs=LinkableSpecSet(time_dimension_specs=tuple(queried_time_spine_specs)),
+            filter_specs=where_filter_specs,
+        )
+        required_time_spine_specs = required_time_spine_spec_set.time_dimension_specs
+
+        # TODO: support multiple time spines here. Build node on the one with the smallest base grain.
+        # Then, pass custom_granularity_specs into _build_pre_aggregation_plan if they aren't satisfied by smallest time spine.
+        time_spine_source = self._choose_time_spine_source(required_time_spine_specs)
+        time_spine_node = TransformTimeDimensionsNode.create(
+            parent_node=self._choose_time_spine_read_node(time_spine_source),
+            requested_time_dimension_specs=required_time_spine_specs,
+        )
+
+        # If the base grain of the time spine isn't selected, it will have duplicate rows that need deduping.
+        should_dedupe = ExpandedTimeGranularity.from_time_granularity(time_spine_source.base_granularity) not in {
+            spec.time_granularity for spec in queried_time_spine_specs
+        }
+
+        return self._build_pre_aggregation_plan(
+            source_node=time_spine_node,
+            filter_to_specs=InstanceSpecSet(time_dimension_specs=tuple(queried_time_spine_specs)),
+            time_range_constraint=time_range_constraint,
+            where_filter_specs=where_filter_specs,
+            distinct=should_dedupe,
+        )
+
+    def _sort_by_base_granularity(self, time_dimension_specs: Sequence[TimeDimensionSpec]) -> List[TimeDimensionSpec]:
+        """Sort the time dimensions by their base granularity.
+
+        Specs with date part will come after specs without it. Standard grains will come before custom.
+        """
+        return sorted(
+            time_dimension_specs,
+            key=lambda spec: (
+                spec.date_part is not None,
+                spec.time_granularity.is_custom_granularity,
+                spec.time_granularity.base_granularity.to_int(),
+            ),
+        )
+
+    def _determine_time_spine_join_spec(
+        self, measure_properties: MeasureSpecProperties, required_time_spine_specs: Tuple[TimeDimensionSpec, ...]
+    ) -> TimeDimensionSpec:
+        """Determine the spec to join on for a time spine join.
+
+        Defaults to metric_time if it is included in the request, else the agg_time_dimension.
+        Will use the smallest available grain for the measure.
+        """
+        join_spec_grain = ExpandedTimeGranularity.from_time_granularity(measure_properties.agg_time_dimension_grain)
+        join_on_time_dimension_spec = DataSet.metric_time_dimension_spec(time_granularity=join_spec_grain)
+        if not LinkableSpecSet(time_dimension_specs=required_time_spine_specs).contains_metric_time:
+            sample_agg_time_dimension_spec = required_time_spine_specs[0]
+            join_on_time_dimension_spec = sample_agg_time_dimension_spec.with_grain_and_date_part(
+                time_granularity=join_spec_grain, date_part=None
+            )
+        return join_on_time_dimension_spec
diff --git a/metricflow/dataflow/nodes/join_to_time_spine.py b/metricflow/dataflow/nodes/join_to_time_spine.py
index fd8650381..b33a3ff6d 100644
--- a/metricflow/dataflow/nodes/join_to_time_spine.py
+++ b/metricflow/dataflow/nodes/join_to_time_spine.py
@@ -8,9 +8,7 @@
 from dbt_semantic_interfaces.type_enums import TimeGranularity
 from metricflow_semantics.dag.id_prefix import IdPrefix, StaticIdPrefix
 from metricflow_semantics.dag.mf_dag import DisplayedProperty
-from metricflow_semantics.filters.time_constraint import TimeRangeConstraint
 from metricflow_semantics.specs.time_dimension_spec import TimeDimensionSpec
-from metricflow_semantics.specs.where_filter.where_filter_spec import WhereFilterSpec
 from metricflow_semantics.sql.sql_join_type import SqlJoinType
 from metricflow_semantics.visitor import VisitorOutputT

@@ -25,17 +23,17 @@ class JoinToTimeSpineNode(DataflowPlanNode, ABC):
     Attributes:
         requested_agg_time_dimension_specs: Time dimensions requested in the query.
         join_type: Join type to use when joining to time spine.
-        time_range_constraint: Time range to constrain the time spine to.
+        join_on_time_dimension_spec: The time dimension to use in the join ON condition.
         offset_window: Time window to offset the parent dataset by when joining to time spine.
         offset_to_grain: Granularity period to offset the parent dataset to when joining to time spine.
     """

+    time_spine_node: DataflowPlanNode
     requested_agg_time_dimension_specs: Sequence[TimeDimensionSpec]
+    join_on_time_dimension_spec: TimeDimensionSpec
     join_type: SqlJoinType
-    time_range_constraint: Optional[TimeRangeConstraint]
     offset_window: Optional[MetricTimeWindow]
     offset_to_grain: Optional[TimeGranularity]
-    time_spine_filters: Optional[Sequence[WhereFilterSpec]] = None

     def __post_init__(self) -> None:  # noqa: D105
         super().__post_init__()
@@ -51,21 +49,21 @@ def __post_init__(self) -> None:  # noqa: D105
     @staticmethod
     def create(  # noqa: D102
         parent_node: DataflowPlanNode,
+        time_spine_node: DataflowPlanNode,
         requested_agg_time_dimension_specs: Sequence[TimeDimensionSpec],
+        join_on_time_dimension_spec: TimeDimensionSpec,
         join_type: SqlJoinType,
-        time_range_constraint: Optional[TimeRangeConstraint] = None,
         offset_window: Optional[MetricTimeWindow] = None,
         offset_to_grain: Optional[TimeGranularity] = None,
-        time_spine_filters: Optional[Sequence[WhereFilterSpec]] = None,
     ) -> JoinToTimeSpineNode:
         return JoinToTimeSpineNode(
             parent_nodes=(parent_node,),
+            time_spine_node=time_spine_node,
             requested_agg_time_dimension_specs=tuple(requested_agg_time_dimension_specs),
+            join_on_time_dimension_spec=join_on_time_dimension_spec,
             join_type=join_type,
-            time_range_constraint=time_range_constraint,
             offset_window=offset_window,
             offset_to_grain=offset_to_grain,
-            time_spine_filters=time_spine_filters,
         )

     @classmethod
@@ -83,20 +81,13 @@ def description(self) -> str:  # noqa: D102
     def displayed_properties(self) -> Sequence[DisplayedProperty]:  # noqa: D102
         props = tuple(super().displayed_properties) + (
             DisplayedProperty("requested_agg_time_dimension_specs", self.requested_agg_time_dimension_specs),
+            DisplayedProperty("join_on_time_dimension_spec", self.join_on_time_dimension_spec),
             DisplayedProperty("join_type", self.join_type),
         )
         if self.offset_window:
             props += (DisplayedProperty("offset_window", self.offset_window),)
         if self.offset_to_grain:
             props += (DisplayedProperty("offset_to_grain", self.offset_to_grain),)
-        if self.time_range_constraint:
-            props += (DisplayedProperty("time_range_constraint", self.time_range_constraint),)
-        if self.time_spine_filters:
-            props += (
-                DisplayedProperty(
-                    "time_spine_filters", [time_spine_filter.where_sql for time_spine_filter in self.time_spine_filters]
-                ),
-            )
         return props

     @property
@@ -106,22 +97,21 @@ def parent_node(self) -> DataflowPlanNode:  # noqa: D102
     def functionally_identical(self, other_node: DataflowPlanNode) -> bool:  # noqa: D102
         return (
             isinstance(other_node, self.__class__)
-            and other_node.time_range_constraint == self.time_range_constraint
             and other_node.offset_window == self.offset_window
             and other_node.offset_to_grain == self.offset_to_grain
             and other_node.requested_agg_time_dimension_specs == self.requested_agg_time_dimension_specs
+            and other_node.join_on_time_dimension_spec == self.join_on_time_dimension_spec
             and other_node.join_type == self.join_type
-            and other_node.time_spine_filters == self.time_spine_filters
         )

     def with_new_parents(self, new_parent_nodes: Sequence[DataflowPlanNode]) -> JoinToTimeSpineNode:  # noqa: D102
         assert len(new_parent_nodes) == 1
         return JoinToTimeSpineNode.create(
             parent_node=new_parent_nodes[0],
+            time_spine_node=self.time_spine_node,
             requested_agg_time_dimension_specs=self.requested_agg_time_dimension_specs,
-            time_range_constraint=self.time_range_constraint,
             offset_window=self.offset_window,
             offset_to_grain=self.offset_to_grain,
             join_type=self.join_type,
-            time_spine_filters=self.time_spine_filters,
+            join_on_time_dimension_spec=self.join_on_time_dimension_spec,
         )
diff --git a/metricflow/plan_conversion/dataflow_to_sql.py b/metricflow/plan_conversion/dataflow_to_sql.py
index a4d01496b..dccec33e0 100644
--- a/metricflow/plan_conversion/dataflow_to_sql.py
+++ b/metricflow/plan_conversion/dataflow_to_sql.py
@@ -313,6 +313,7 @@ def _next_unique_table_alias(self) -> str:
         """Return the next unique table alias to use in generating queries."""
         return SequentialIdGenerator.create_next_id(StaticIdPrefix.SUB_QUERY).str_value

+    # TODO: replace this with a dataflow plan node for cumulative metrics
     def _make_time_spine_data_set(
         self,
         agg_time_dimension_instances: Tuple[TimeDimensionInstance, ...],
@@ -1205,9 +1206,9 @@ def visit_metric_time_dimension_transform_node(self, node: MetricTimeDimensionTr
                     spec=metric_time_dimension_spec,
                 )
             )
-            output_column_to_input_column[
-                metric_time_dimension_column_association.column_name
-            ] = matching_time_dimension_instance.associated_column.column_name
+            output_column_to_input_column[metric_time_dimension_column_association.column_name] = (
+                matching_time_dimension_instance.associated_column.column_name
+            )

         output_instance_set = InstanceSet(
             measure_instances=tuple(output_measure_instances),
@@ -1372,97 +1373,72 @@ def _choose_instance_for_time_spine_join(
     def visit_join_to_time_spine_node(self, node: JoinToTimeSpineNode) -> SqlDataSet:  # noqa: D102
         parent_data_set = node.parent_node.accept(self)
         parent_alias = self._next_unique_table_alias()
-
-        agg_time_dimension_instances = parent_data_set.instances_for_time_dimensions(
-            node.requested_agg_time_dimension_specs
-        )
-
-        # Select the dimension for the join from the parent node because it may not have been included in the request.
-        # Default to using metric_time for the join if it was requested, otherwise use the agg_time_dimension.
-        included_metric_time_instances = [
-            instance for instance in agg_time_dimension_instances if instance.spec.is_metric_time
-        ]
-        if included_metric_time_instances:
-            join_on_time_dimension_sample = included_metric_time_instances[0].spec
-        else:
-            join_on_time_dimension_sample = agg_time_dimension_instances[0].spec
-        agg_time_dimension_instance_for_join = self._choose_instance_for_time_spine_join(
-            [
-                instance
-                for instance in parent_data_set.instance_set.time_dimension_instances
-                if instance.spec.element_name == join_on_time_dimension_sample.element_name
-                and instance.spec.entity_links == join_on_time_dimension_sample.entity_links
-            ]
-        )
-        if agg_time_dimension_instance_for_join not in agg_time_dimension_instances:
-            agg_time_dimension_instances = (agg_time_dimension_instance_for_join,) + agg_time_dimension_instances
-
-        # Build time spine data set with just the agg_time_dimension instance needed for the join.
+        time_spine_data_set = node.time_spine_node.accept(self)
         time_spine_alias = self._next_unique_table_alias()
-        time_spine_dataset = self._make_time_spine_data_set(
-            agg_time_dimension_instances=agg_time_dimension_instances,
-            time_range_constraint=node.time_range_constraint,
-            time_spine_where_constraints=node.time_spine_filters or (),
-        )
+
+        required_agg_time_dimension_specs = tuple(node.requested_agg_time_dimension_specs)
+        if node.join_on_time_dimension_spec not in node.requested_agg_time_dimension_specs:
+            required_agg_time_dimension_specs += (node.join_on_time_dimension_spec,)

         # Build join expression.
+        join_column_name = self._column_association_resolver.resolve_spec(node.join_on_time_dimension_spec).column_name
         join_description = SqlQueryPlanJoinBuilder.make_join_to_time_spine_join_description(
             node=node,
             time_spine_alias=time_spine_alias,
-            agg_time_dimension_column_name=self._column_association_resolver.resolve_spec(
-                agg_time_dimension_instance_for_join.spec
-            ).column_name,
+            agg_time_dimension_column_name=join_column_name,
             parent_sql_select_node=parent_data_set.checked_sql_select_node,
             parent_alias=parent_alias,
         )

-        # Remove time spine instances from parent instance set.
-        time_spine_instances = time_spine_dataset.instance_set
-        time_spine_specs = time_spine_instances.spec_set
-        parent_instance_set = parent_data_set.instance_set.transform(FilterElements(exclude_specs=time_spine_specs))
+        # Build combined instance set.
+        time_spine_required_spec_set = InstanceSpecSet(time_dimension_specs=required_agg_time_dimension_specs)
+        parent_instance_set = parent_data_set.instance_set.transform(
+            FilterElements(exclude_specs=time_spine_required_spec_set)
+        )
+        time_spine_instance_set = time_spine_data_set.instance_set.transform(
+            FilterElements(include_specs=time_spine_required_spec_set)
+        )
+        output_instance_set = InstanceSet.merge([parent_instance_set, time_spine_instance_set])

-        # Build select columns
+        # Build new simple select columns.
         select_columns = create_simple_select_columns_for_instance_sets(
             self._column_association_resolver,
-            OrderedDict({parent_alias: parent_instance_set, time_spine_alias: time_spine_dataset.instance_set}),
+            OrderedDict({parent_alias: parent_instance_set, time_spine_alias: time_spine_instance_set}),
         )

         # If offset_to_grain is used, will need to filter down to rows that match selected granularities.
         # Does not apply if one of the granularities selected matches the time spine column granularity.
         where_filter: Optional[SqlExpressionNode] = None
         need_where_filter = (
-            node.offset_to_grain
-            and agg_time_dimension_instance_for_join.spec not in node.requested_agg_time_dimension_specs
+            node.offset_to_grain and node.join_on_time_dimension_spec not in node.requested_agg_time_dimension_specs
         )
+
+        # Filter down to one row per granularity period requested in the group by. Any other granularities
+        # included here will be filtered out before aggregation and so should not be included in the where filter.
         if need_where_filter:
             join_column_expr = SqlColumnReferenceExpression.from_table_and_column_names(
-                table_alias=time_spine_alias,
-                column_name=agg_time_dimension_instance_for_join.associated_column.column_name,
+                table_alias=time_spine_alias, column_name=join_column_name
             )
-            for time_spine_instance in time_spine_instances.as_tuple:
-                # Filter down to one row per granularity period requested in the group by. Any other granularities
-                # included here will be filtered out in later nodes so should not be included in where filter.
-                if need_where_filter and time_spine_instance.spec in node.requested_agg_time_dimension_specs:
-                    column_to_filter_expr = SqlColumnReferenceExpression.from_table_and_column_names(
-                        table_alias=time_spine_alias, column_name=time_spine_instance.associated_column.column_name
-                    )
-                    new_where_filter = SqlComparisonExpression.create(
-                        left_expr=column_to_filter_expr, comparison=SqlComparison.EQUALS, right_expr=join_column_expr
-                    )
-                    where_filter = (
-                        SqlLogicalExpression.create(
-                            operator=SqlLogicalOperator.OR, args=(where_filter, new_where_filter)
-                        )
-                        if where_filter
-                        else new_where_filter
-                    )
+            for requested_spec in node.requested_agg_time_dimension_specs:
+                column_name = self._column_association_resolver.resolve_spec(requested_spec).column_name
+                column_to_filter_expr = SqlColumnReferenceExpression.from_table_and_column_names(
+                    table_alias=time_spine_alias, column_name=column_name
+                )
+                new_where_filter = SqlComparisonExpression.create(
+                    left_expr=column_to_filter_expr, comparison=SqlComparison.EQUALS, right_expr=join_column_expr
+                )
+                where_filter = (
+                    SqlLogicalExpression.create(operator=SqlLogicalOperator.OR, args=(where_filter, new_where_filter))
+                    if where_filter
+                    else new_where_filter
+                )

         return SqlDataSet(
-            instance_set=InstanceSet.merge([time_spine_dataset.instance_set, parent_instance_set]),
+            instance_set=output_instance_set,
             sql_select_node=SqlSelectStatementNode.create(
                 description=node.description,
                 select_columns=select_columns,
-                from_source=time_spine_dataset.checked_sql_select_node,
+                from_source=time_spine_data_set.checked_sql_select_node,
                 from_source_alias=time_spine_alias,
                 join_descs=(join_description,),
                 where=where_filter,
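The builder now derives join_on_time_dimension_spec by taking element [0] of _sort_by_base_granularity, so the sort key does the real work: specs without a date part sort first, standard grains sort before custom ones, and smaller base grains sort first. A standalone toy with stand-in types (not MetricFlow's TimeDimensionSpec or ExpandedTimeGranularity) shows the resulting order:

# Toy illustration of the _sort_by_base_granularity key; these dataclasses are
# stand-ins for TimeDimensionSpec / ExpandedTimeGranularity, not the real types.
from dataclasses import dataclass
from typing import Optional

GRAIN_TO_INT = {"day": 1, "week": 2, "month": 3, "quarter": 4, "year": 5}


@dataclass(frozen=True)
class ToySpec:
    name: str
    base_granularity: str  # the standard grain underlying the spec
    is_custom_granularity: bool  # e.g. a fiscal calendar defined on the time spine
    date_part: Optional[str] = None


def sort_by_base_granularity(specs):
    # False sorts before True: no date part first, standard before custom,
    # then smallest base grain first.
    return sorted(
        specs,
        key=lambda spec: (
            spec.date_part is not None,
            spec.is_custom_granularity,
            GRAIN_TO_INT[spec.base_granularity],
        ),
    )


specs = [
    ToySpec("metric_time__fiscal_quarter", "quarter", True),
    ToySpec("metric_time__day", "day", False),
    ToySpec("metric_time__extract_dow", "day", False, date_part="dow"),
    ToySpec("metric_time__month", "month", False),
]
print([spec.name for spec in sort_by_base_granularity(specs)])
# ['metric_time__day', 'metric_time__month', 'metric_time__fiscal_quarter', 'metric_time__extract_dow']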
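_build_time_spine_node only requests deduping when none of the queried specs uses the time spine's base grain: projecting a day-grained spine to, say, month repeats each month once per day. A self-contained sketch of that check, using plain datetime.date values rather than a real time spine table:

# Toy illustration of the should_dedupe logic in _build_time_spine_node.
import datetime

days = [datetime.date(2024, 1, 1) + datetime.timedelta(days=i) for i in range(60)]
queried_grains = {"month"}  # base grain "day" was not selected

# Truncate each day to month grain; without deduping, every month repeats.
months = [day.replace(day=1) for day in days]
should_dedupe = "day" not in queried_grains
rows = sorted(set(months)) if should_dedupe else months
print(should_dedupe, rows)  # True [datetime.date(2024, 1, 1), datetime.date(2024, 2, 1)]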
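In visit_join_to_time_spine_node, when offset_to_grain is set and the join spec itself was not requested, the rewritten loop OR-chains one equality per requested spec against the join column. A toy with plain strings standing in for the SqlExpressionNode tree mirrors that accumulation:

# Toy version of the where-filter accumulation; strings stand in for
# SqlColumnReferenceExpression / SqlComparisonExpression / SqlLogicalExpression.
def build_offset_where_filter(time_spine_alias, join_column, requested_columns):
    where_filter = None
    for column in requested_columns:
        comparison = f"{time_spine_alias}.{column} = {time_spine_alias}.{join_column}"
        # Mirrors the SqlLogicalExpression.create(operator=OR, ...) chaining.
        where_filter = comparison if where_filter is None else f"({where_filter}) OR ({comparison})"
    return where_filter


print(build_offset_where_filter("ts", "metric_time__day", ["metric_time__month", "metric_time__year"]))
# (ts.metric_time__month = ts.metric_time__day) OR (ts.metric_time__year = ts.metric_time__day)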