From 4ca4cb8bede146ebfd45ca9c6e846f588cee6cfd Mon Sep 17 00:00:00 2001 From: Courtney Holcomb Date: Fri, 10 Nov 2023 16:51:57 -0800 Subject: [PATCH] Move time spine source logic --- .../dataflow/builder/dataflow_plan_builder.py | 77 +------------------ metricflow/dataset/convert_semantic_model.py | 8 +- metricflow/plan_conversion/time_spine.py | 73 ++++++++++++++++++ ...1.xml => test_metric_time_only__dfp_0.xml} | 0 4 files changed, 77 insertions(+), 81 deletions(-) rename metricflow/test/snapshots/test_dataflow_plan_builder.py/DataflowPlan/{test_metric_time_only__dfp_1.xml => test_metric_time_only__dfp_0.xml} (100%) diff --git a/metricflow/dataflow/builder/dataflow_plan_builder.py b/metricflow/dataflow/builder/dataflow_plan_builder.py index e1ae00ee61..8e0b4c8a97 100644 --- a/metricflow/dataflow/builder/dataflow_plan_builder.py +++ b/metricflow/dataflow/builder/dataflow_plan_builder.py @@ -32,7 +32,6 @@ JoinOverTimeRangeNode, JoinToBaseOutputNode, JoinToTimeSpineNode, - MetricTimeDimensionTransformNode, OrderByLimitNode, ReadSqlSourceNode, SemiAdditiveJoinNode, @@ -45,16 +44,12 @@ from metricflow.dataflow.optimizer.dataflow_plan_optimizer import DataflowPlanOptimizer from metricflow.dataflow.sql_table import SqlTable from metricflow.dataset.dataset import DataSet -from metricflow.dataset.sql_dataset import SqlDataSet from metricflow.errors.errors import UnableToSatisfyQueryError from metricflow.filters.time_constraint import TimeRangeConstraint -from metricflow.instances import InstanceSet, TimeDimensionInstance from metricflow.model.semantic_manifest_lookup import SemanticManifestLookup -from metricflow.naming.linkable_spec_name import StructuredLinkableSpecName from metricflow.plan_conversion.column_resolver import DunderColumnAssociationResolver from metricflow.plan_conversion.node_processor import PreJoinNodeProcessor -from metricflow.plan_conversion.time_spine import TIME_SPINE_DATA_SET_DESCRIPTION -from metricflow.specs.column_assoc import ColumnAssociation, ColumnAssociationResolver, SingleColumnCorrelationKey +from metricflow.specs.column_assoc import ColumnAssociationResolver from metricflow.specs.specs import ( InstanceSpecSet, LinkableInstanceSpec, @@ -69,8 +64,7 @@ TimeDimensionSpec, WhereFilterSpec, ) -from metricflow.sql.sql_exprs import SqlColumnReference, SqlColumnReferenceExpression, SqlDateTruncExpression -from metricflow.sql.sql_plan import SqlJoinType, SqlSelectColumn, SqlSelectStatementNode, SqlTableFromClauseNode +from metricflow.sql.sql_plan import SqlJoinType logger = logging.getLogger(__name__) @@ -480,70 +474,6 @@ def _build_measure_spec_properties(self, measure_specs: Sequence[MeasureSpec]) - non_additive_dimension_spec=non_additive_dimension_spec, ) - # TODO: this should live somewhere else. Figure out where makes sense - def _create_metric_time_node_from_time_spine(self) -> MetricTimeDimensionTransformNode: - """Build a ReadSqlSourceNode that represents reading from the time spine table.""" - time_spine_source = self._time_spine_source - from_source_alias = IdGeneratorRegistry.for_class(self.__class__).create_id("time_spine_src") - - # TODO: add date part to instances & select columns. Can we use the same logic as elsewhere?? - time_spine_instances: List[TimeDimensionInstance] = [] - select_columns: List[SqlSelectColumn] = [] - for granularity in TimeGranularity: - if granularity.to_int() >= time_spine_source.time_column_granularity.to_int(): - column_alias = StructuredLinkableSpecName( - entity_link_names=(), - element_name=time_spine_source.time_column_name, - time_granularity=granularity, - ).qualified_name - time_spine_instance = TimeDimensionInstance( - defined_from=(), - associated_columns=( - ColumnAssociation( - column_name=column_alias, - single_column_correlation_key=SingleColumnCorrelationKey(), - ), - ), - spec=TimeDimensionSpec( - element_name=time_spine_source.time_column_name, entity_links=(), time_granularity=granularity - ), - ) - time_spine_instances.append(time_spine_instance) - select_column = SqlSelectColumn( - SqlDateTruncExpression( - time_granularity=granularity, - arg=SqlColumnReferenceExpression( - SqlColumnReference( - table_alias=from_source_alias, - column_name=time_spine_source.time_column_name, - ), - ), - ), - column_alias=column_alias, - ) - select_columns.append(select_column) - - time_spine_instance_set = InstanceSet(time_dimension_instances=tuple(time_spine_instances)) - time_spine_data_set = SqlDataSet( - instance_set=time_spine_instance_set, - sql_select_node=SqlSelectStatementNode( - description=TIME_SPINE_DATA_SET_DESCRIPTION, - select_columns=tuple(select_columns), - from_source=SqlTableFromClauseNode(sql_table=time_spine_source.spine_table), - from_source_alias=from_source_alias, - joins_descs=(), - group_bys=(), - order_bys=(), - ), - ) - # need this if we have above?? - return MetricTimeDimensionTransformNode( - parent_node=ReadSqlSourceNode(data_set=time_spine_data_set), - aggregation_time_dimension_reference=TimeDimensionReference( - element_name=time_spine_source.time_column_name - ), - ) - def _find_dataflow_recipe( self, linkable_spec_set: LinkableSpecSet, @@ -572,8 +502,7 @@ def _find_dataflow_recipe( ] if requested_metric_time_specs: # Add time_spine to potential source nodes for requested metric_time specs - time_spine_node = self._create_metric_time_node_from_time_spine() - potential_source_nodes = list(potential_source_nodes) + [time_spine_node] + potential_source_nodes = list(potential_source_nodes) + [self._time_spine_source.build_source_node()] logger.info(f"There are {len(potential_source_nodes)} potential source nodes") diff --git a/metricflow/dataset/convert_semantic_model.py b/metricflow/dataset/convert_semantic_model.py index 5f713c7bf6..cc29f45bca 100644 --- a/metricflow/dataset/convert_semantic_model.py +++ b/metricflow/dataset/convert_semantic_model.py @@ -16,7 +16,7 @@ from metricflow.aggregation_properties import AggregationState from metricflow.dag.id_generation import IdGeneratorRegistry from metricflow.dataflow.sql_table import SqlTable -from metricflow.dataset.semantic_model_adapter import SemanticModelDataSet, SqlDataSet +from metricflow.dataset.semantic_model_adapter import SemanticModelDataSet from metricflow.instances import ( DimensionInstance, EntityInstance, @@ -26,7 +26,6 @@ ) from metricflow.model.semantics.semantic_model_lookup import SemanticModelLookup from metricflow.model.spec_converters import MeasureConverter -from metricflow.plan_conversion.time_spine import TimeSpineSource from metricflow.specs.column_assoc import ColumnAssociationResolver from metricflow.specs.specs import ( DEFAULT_TIME_GRANULARITY, @@ -479,8 +478,3 @@ def create_sql_source_data_set(self, semantic_model: SemanticModel) -> SemanticM ), sql_select_node=select_statement_node, ) - - # move logic here? - def create_data_set_from_time_spine(self, time_spine_source: TimeSpineSource) -> SqlDataSet: - """Create a SQL source data set from time spine table.""" - pass diff --git a/metricflow/plan_conversion/time_spine.py b/metricflow/plan_conversion/time_spine.py index 71a5d773d8..aeb0a6adf1 100644 --- a/metricflow/plan_conversion/time_spine.py +++ b/metricflow/plan_conversion/time_spine.py @@ -2,10 +2,21 @@ import logging from dataclasses import dataclass +from typing import List +from dbt_semantic_interfaces.references import TimeDimensionReference from dbt_semantic_interfaces.type_enums.time_granularity import TimeGranularity +from metricflow.dag.id_generation import IdGeneratorRegistry +from metricflow.dataflow.dataflow_plan import MetricTimeDimensionTransformNode, ReadSqlSourceNode from metricflow.dataflow.sql_table import SqlTable +from metricflow.dataset.sql_dataset import SqlDataSet +from metricflow.instances import InstanceSet, TimeDimensionInstance +from metricflow.naming.linkable_spec_name import StructuredLinkableSpecName +from metricflow.specs.column_assoc import ColumnAssociation, SingleColumnCorrelationKey +from metricflow.specs.specs import TimeDimensionSpec +from metricflow.sql.sql_exprs import SqlColumnReference, SqlColumnReferenceExpression, SqlDateTruncExpression +from metricflow.sql.sql_plan import SqlSelectColumn, SqlSelectStatementNode, SqlTableFromClauseNode logger = logging.getLogger(__name__) @@ -27,3 +38,65 @@ class TimeSpineSource: def spine_table(self) -> SqlTable: """Table containing all dates.""" return SqlTable(schema_name=self.schema_name, table_name=self.table_name) + + def build_source_node(self) -> MetricTimeDimensionTransformNode: + """Build data set for time spine.""" + from_source_alias = IdGeneratorRegistry.for_class(self.__class__).create_id("time_spine_src") + + # TODO: add date part to instances & select columns. Can we use the same logic as elsewhere?? + # TODO: add test cases for date part + time_spine_instances: List[TimeDimensionInstance] = [] + select_columns: List[SqlSelectColumn] = [] + for granularity in TimeGranularity: + if granularity.to_int() >= self.time_column_granularity.to_int(): + column_alias = StructuredLinkableSpecName( + entity_link_names=(), + element_name=self.time_column_name, + time_granularity=granularity, + ).qualified_name + time_spine_instance = TimeDimensionInstance( + defined_from=(), + associated_columns=( + ColumnAssociation( + column_name=column_alias, + single_column_correlation_key=SingleColumnCorrelationKey(), + ), + ), + spec=TimeDimensionSpec( + element_name=self.time_column_name, entity_links=(), time_granularity=granularity + ), + ) + time_spine_instances.append(time_spine_instance) + select_column = SqlSelectColumn( + SqlDateTruncExpression( + time_granularity=granularity, + arg=SqlColumnReferenceExpression( + SqlColumnReference( + table_alias=from_source_alias, + column_name=self.time_column_name, + ), + ), + ), + column_alias=column_alias, + ) + select_columns.append(select_column) + + time_spine_instance_set = InstanceSet(time_dimension_instances=tuple(time_spine_instances)) + + data_set = SqlDataSet( + instance_set=time_spine_instance_set, + sql_select_node=SqlSelectStatementNode( + description=TIME_SPINE_DATA_SET_DESCRIPTION, + select_columns=tuple(select_columns), + from_source=SqlTableFromClauseNode(sql_table=self.spine_table), + from_source_alias=from_source_alias, + joins_descs=(), + group_bys=(), + order_bys=(), + ), + ) + + return MetricTimeDimensionTransformNode( + parent_node=ReadSqlSourceNode(data_set=data_set), + aggregation_time_dimension_reference=TimeDimensionReference(element_name=self.time_column_name), + ) diff --git a/metricflow/test/snapshots/test_dataflow_plan_builder.py/DataflowPlan/test_metric_time_only__dfp_1.xml b/metricflow/test/snapshots/test_dataflow_plan_builder.py/DataflowPlan/test_metric_time_only__dfp_0.xml similarity index 100% rename from metricflow/test/snapshots/test_dataflow_plan_builder.py/DataflowPlan/test_metric_time_only__dfp_1.xml rename to metricflow/test/snapshots/test_dataflow_plan_builder.py/DataflowPlan/test_metric_time_only__dfp_0.xml