Skip to content

Commit

Permalink
Move time spine source logic
Browse files Browse the repository at this point in the history
  • Loading branch information
courtneyholcomb committed Nov 11, 2023
1 parent 603a397 commit 4ca4cb8
Show file tree
Hide file tree
Showing 4 changed files with 77 additions and 81 deletions.
77 changes: 3 additions & 74 deletions metricflow/dataflow/builder/dataflow_plan_builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,6 @@
JoinOverTimeRangeNode,
JoinToBaseOutputNode,
JoinToTimeSpineNode,
MetricTimeDimensionTransformNode,
OrderByLimitNode,
ReadSqlSourceNode,
SemiAdditiveJoinNode,
Expand All @@ -45,16 +44,12 @@
from metricflow.dataflow.optimizer.dataflow_plan_optimizer import DataflowPlanOptimizer
from metricflow.dataflow.sql_table import SqlTable
from metricflow.dataset.dataset import DataSet
from metricflow.dataset.sql_dataset import SqlDataSet
from metricflow.errors.errors import UnableToSatisfyQueryError
from metricflow.filters.time_constraint import TimeRangeConstraint
from metricflow.instances import InstanceSet, TimeDimensionInstance
from metricflow.model.semantic_manifest_lookup import SemanticManifestLookup
from metricflow.naming.linkable_spec_name import StructuredLinkableSpecName
from metricflow.plan_conversion.column_resolver import DunderColumnAssociationResolver
from metricflow.plan_conversion.node_processor import PreJoinNodeProcessor
from metricflow.plan_conversion.time_spine import TIME_SPINE_DATA_SET_DESCRIPTION
from metricflow.specs.column_assoc import ColumnAssociation, ColumnAssociationResolver, SingleColumnCorrelationKey
from metricflow.specs.column_assoc import ColumnAssociationResolver
from metricflow.specs.specs import (
InstanceSpecSet,
LinkableInstanceSpec,
Expand All @@ -69,8 +64,7 @@
TimeDimensionSpec,
WhereFilterSpec,
)
from metricflow.sql.sql_exprs import SqlColumnReference, SqlColumnReferenceExpression, SqlDateTruncExpression
from metricflow.sql.sql_plan import SqlJoinType, SqlSelectColumn, SqlSelectStatementNode, SqlTableFromClauseNode
from metricflow.sql.sql_plan import SqlJoinType

logger = logging.getLogger(__name__)

Expand Down Expand Up @@ -480,70 +474,6 @@ def _build_measure_spec_properties(self, measure_specs: Sequence[MeasureSpec]) -
non_additive_dimension_spec=non_additive_dimension_spec,
)

# TODO: this should live somewhere else. Figure out where makes sense
def _create_metric_time_node_from_time_spine(self) -> MetricTimeDimensionTransformNode:
    """Build a source node that reads the time spine table and marks its column as the aggregation time dimension.

    Returns a MetricTimeDimensionTransformNode wrapping a ReadSqlSourceNode over the
    time spine table. The underlying SELECT produces one DATE_TRUNC'd column (and one
    matching TimeDimensionInstance) per granularity at or above the spine column's own
    granularity, as ordered by `TimeGranularity.to_int()`.
    """
    time_spine_source = self._time_spine_source
    # Unique table alias for the FROM clause, e.g. "time_spine_src_0".
    from_source_alias = IdGeneratorRegistry.for_class(self.__class__).create_id("time_spine_src")

    # TODO: add date part to instances & select columns. Can we use the same logic as elsewhere??
    time_spine_instances: List[TimeDimensionInstance] = []
    select_columns: List[SqlSelectColumn] = []
    for granularity in TimeGranularity:
        # Only granularities the spine column can support: a column stored at (say) day
        # granularity cannot be truncated to something finer than day.
        if granularity.to_int() >= time_spine_source.time_column_granularity.to_int():
            # Qualified name such as "<time_column_name>__<granularity>" used as the column alias.
            column_alias = StructuredLinkableSpecName(
                entity_link_names=(),
                element_name=time_spine_source.time_column_name,
                time_granularity=granularity,
            ).qualified_name
            # Instance metadata so downstream plan nodes can resolve this column by spec.
            time_spine_instance = TimeDimensionInstance(
                defined_from=(),
                associated_columns=(
                    ColumnAssociation(
                        column_name=column_alias,
                        single_column_correlation_key=SingleColumnCorrelationKey(),
                    ),
                ),
                spec=TimeDimensionSpec(
                    element_name=time_spine_source.time_column_name, entity_links=(), time_granularity=granularity
                ),
            )
            time_spine_instances.append(time_spine_instance)
            # SELECT DATE_TRUNC(<granularity>, <alias>.<time_column>) AS <column_alias>
            select_column = SqlSelectColumn(
                SqlDateTruncExpression(
                    time_granularity=granularity,
                    arg=SqlColumnReferenceExpression(
                        SqlColumnReference(
                            table_alias=from_source_alias,
                            column_name=time_spine_source.time_column_name,
                        ),
                    ),
                ),
                column_alias=column_alias,
            )
            select_columns.append(select_column)

    time_spine_instance_set = InstanceSet(time_dimension_instances=tuple(time_spine_instances))
    # Plain SELECT over the spine table — no joins, group-bys, or order-bys needed.
    time_spine_data_set = SqlDataSet(
        instance_set=time_spine_instance_set,
        sql_select_node=SqlSelectStatementNode(
            description=TIME_SPINE_DATA_SET_DESCRIPTION,
            select_columns=tuple(select_columns),
            from_source=SqlTableFromClauseNode(sql_table=time_spine_source.spine_table),
            from_source_alias=from_source_alias,
            joins_descs=(),
            group_bys=(),
            order_bys=(),
        ),
    )
    # NOTE(review): unclear whether this transform wrapper is still required given the
    # instance set built above — confirm before simplifying.
    return MetricTimeDimensionTransformNode(
        parent_node=ReadSqlSourceNode(data_set=time_spine_data_set),
        aggregation_time_dimension_reference=TimeDimensionReference(
            element_name=time_spine_source.time_column_name
        ),
    )

def _find_dataflow_recipe(
self,
linkable_spec_set: LinkableSpecSet,
Expand Down Expand Up @@ -572,8 +502,7 @@ def _find_dataflow_recipe(
]
if requested_metric_time_specs:
# Add time_spine to potential source nodes for requested metric_time specs
time_spine_node = self._create_metric_time_node_from_time_spine()
potential_source_nodes = list(potential_source_nodes) + [time_spine_node]
potential_source_nodes = list(potential_source_nodes) + [self._time_spine_source.build_source_node()]

logger.info(f"There are {len(potential_source_nodes)} potential source nodes")

Expand Down
8 changes: 1 addition & 7 deletions metricflow/dataset/convert_semantic_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
from metricflow.aggregation_properties import AggregationState
from metricflow.dag.id_generation import IdGeneratorRegistry
from metricflow.dataflow.sql_table import SqlTable
from metricflow.dataset.semantic_model_adapter import SemanticModelDataSet, SqlDataSet
from metricflow.dataset.semantic_model_adapter import SemanticModelDataSet
from metricflow.instances import (
DimensionInstance,
EntityInstance,
Expand All @@ -26,7 +26,6 @@
)
from metricflow.model.semantics.semantic_model_lookup import SemanticModelLookup
from metricflow.model.spec_converters import MeasureConverter
from metricflow.plan_conversion.time_spine import TimeSpineSource
from metricflow.specs.column_assoc import ColumnAssociationResolver
from metricflow.specs.specs import (
DEFAULT_TIME_GRANULARITY,
Expand Down Expand Up @@ -479,8 +478,3 @@ def create_sql_source_data_set(self, semantic_model: SemanticModel) -> SemanticM
),
sql_select_node=select_statement_node,
)

# TODO: move the time spine data set construction logic here from the dataflow plan builder?
def create_data_set_from_time_spine(self, time_spine_source: TimeSpineSource) -> SqlDataSet:
    """Create a SQL source data set from the time spine table.

    Not implemented yet. Raising keeps the declared contract honest: the signature
    promises a SqlDataSet, and the previous `pass` body silently handed callers None.

    Args:
        time_spine_source: Configuration describing the time spine table to read.

    Raises:
        NotImplementedError: always, until the time spine logic is moved here.
    """
    raise NotImplementedError("create_data_set_from_time_spine is not implemented yet.")
73 changes: 73 additions & 0 deletions metricflow/plan_conversion/time_spine.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,10 +2,21 @@

import logging
from dataclasses import dataclass
from typing import List

from dbt_semantic_interfaces.references import TimeDimensionReference
from dbt_semantic_interfaces.type_enums.time_granularity import TimeGranularity

from metricflow.dag.id_generation import IdGeneratorRegistry
from metricflow.dataflow.dataflow_plan import MetricTimeDimensionTransformNode, ReadSqlSourceNode
from metricflow.dataflow.sql_table import SqlTable
from metricflow.dataset.sql_dataset import SqlDataSet
from metricflow.instances import InstanceSet, TimeDimensionInstance
from metricflow.naming.linkable_spec_name import StructuredLinkableSpecName
from metricflow.specs.column_assoc import ColumnAssociation, SingleColumnCorrelationKey
from metricflow.specs.specs import TimeDimensionSpec
from metricflow.sql.sql_exprs import SqlColumnReference, SqlColumnReferenceExpression, SqlDateTruncExpression
from metricflow.sql.sql_plan import SqlSelectColumn, SqlSelectStatementNode, SqlTableFromClauseNode

logger = logging.getLogger(__name__)

Expand All @@ -27,3 +38,65 @@ class TimeSpineSource:
def spine_table(self) -> SqlTable:
    """Return the time spine table (containing all dates) as a SqlTable built from this source's schema and table name."""
    return SqlTable(schema_name=self.schema_name, table_name=self.table_name)

def build_source_node(self) -> MetricTimeDimensionTransformNode:
    """Build the dataflow source node for the time spine.

    Produces a MetricTimeDimensionTransformNode over a ReadSqlSourceNode whose SELECT
    truncates the spine's time column to each granularity the column can support
    (per `TimeGranularity.to_int()` ordering), one select column and one
    TimeDimensionInstance per granularity.
    """
    source_table_alias = IdGeneratorRegistry.for_class(self.__class__).create_id("time_spine_src")

    # TODO: add date part to instances & select columns. Can we use the same logic as elsewhere??
    # TODO: add test cases for date part
    dimension_instances: List[TimeDimensionInstance] = []
    selected_columns: List[SqlSelectColumn] = []
    base_granularity_rank = self.time_column_granularity.to_int()
    for output_granularity in TimeGranularity:
        # Skip granularities finer than what the spine column itself stores.
        if output_granularity.to_int() < base_granularity_rank:
            continue

        # Column alias in qualified form, e.g. "<time_column_name>__<granularity>".
        alias = StructuredLinkableSpecName(
            entity_link_names=(),
            element_name=self.time_column_name,
            time_granularity=output_granularity,
        ).qualified_name

        dimension_instances.append(
            TimeDimensionInstance(
                defined_from=(),
                associated_columns=(
                    ColumnAssociation(
                        column_name=alias,
                        single_column_correlation_key=SingleColumnCorrelationKey(),
                    ),
                ),
                spec=TimeDimensionSpec(
                    element_name=self.time_column_name, entity_links=(), time_granularity=output_granularity
                ),
            )
        )
        # SELECT DATE_TRUNC(<granularity>, <alias>.<time_column>) AS <alias>
        selected_columns.append(
            SqlSelectColumn(
                SqlDateTruncExpression(
                    time_granularity=output_granularity,
                    arg=SqlColumnReferenceExpression(
                        SqlColumnReference(
                            table_alias=source_table_alias,
                            column_name=self.time_column_name,
                        ),
                    ),
                ),
                column_alias=alias,
            )
        )

    # Simple SELECT over the spine table — no joins, group-bys, or order-bys.
    spine_data_set = SqlDataSet(
        instance_set=InstanceSet(time_dimension_instances=tuple(dimension_instances)),
        sql_select_node=SqlSelectStatementNode(
            description=TIME_SPINE_DATA_SET_DESCRIPTION,
            select_columns=tuple(selected_columns),
            from_source=SqlTableFromClauseNode(sql_table=self.spine_table),
            from_source_alias=source_table_alias,
            joins_descs=(),
            group_bys=(),
            order_bys=(),
        ),
    )

    return MetricTimeDimensionTransformNode(
        parent_node=ReadSqlSourceNode(data_set=spine_data_set),
        aggregation_time_dimension_reference=TimeDimensionReference(element_name=self.time_column_name),
    )

0 comments on commit 4ca4cb8

Please sign in to comment.