Skip to content

Commit

Permalink
Support multiple time spine sources
Browse files Browse the repository at this point in the history
  • Loading branch information
courtneyholcomb committed Jul 25, 2024
1 parent cb3003e commit 987a6ff
Show file tree
Hide file tree
Showing 11 changed files with 110 additions and 61 deletions.
4 changes: 2 additions & 2 deletions metricflow/dataflow/builder/dataflow_plan_builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -996,7 +996,7 @@ def _find_dataflow_recipe(
desired_linkable_specs=linkable_specs,
nodes=candidate_nodes_for_right_side_of_join,
metric_time_dimension_reference=self._metric_time_dimension_reference,
time_spine_node=self._source_node_set.time_spine_node,
time_spine_nodes=self._source_node_set.time_spine_nodes_tuple,
)
logger.info(
f"After removing unnecessary nodes, there are {len(candidate_nodes_for_right_side_of_join)} candidate "
Expand Down Expand Up @@ -1034,7 +1034,7 @@ def _find_dataflow_recipe(
semantic_model_lookup=self._semantic_model_lookup,
nodes_available_for_joins=self._sort_by_suitability(candidate_nodes_for_right_side_of_join),
node_data_set_resolver=self._node_data_set_resolver,
time_spine_node=self._source_node_set.time_spine_node,
time_spine_nodes=self._source_node_set.time_spine_nodes_tuple,
)

# Dict from the node that contains the source node to the evaluation results.
Expand Down
6 changes: 3 additions & 3 deletions metricflow/dataflow/builder/node_evaluator.py
Original file line number Diff line number Diff line change
Expand Up @@ -168,7 +168,7 @@ def __init__(
semantic_model_lookup: SemanticModelLookup,
nodes_available_for_joins: Sequence[DataflowPlanNode],
node_data_set_resolver: DataflowPlanNodeOutputDataSetResolver,
time_spine_node: MetricTimeDimensionTransformNode,
time_spine_nodes: Sequence[MetricTimeDimensionTransformNode],
) -> None:
"""Initializer.
Expand All @@ -185,7 +185,7 @@ def __init__(
self._node_data_set_resolver = node_data_set_resolver
self._partition_resolver = PartitionJoinResolver(self._semantic_model_lookup)
self._join_evaluator = SemanticModelJoinEvaluator(self._semantic_model_lookup)
self._time_spine_node = time_spine_node
self._time_spine_nodes = time_spine_nodes

def _find_joinable_candidate_nodes_that_can_satisfy_linkable_specs(
self,
Expand All @@ -201,7 +201,7 @@ def _find_joinable_candidate_nodes_that_can_satisfy_linkable_specs(
left_node_spec_set = left_node_instance_set.spec_set
for right_node in self._nodes_available_for_joins:
# If right node is time spine source node, use cross join.
if right_node == self._time_spine_node:
if right_node in self._time_spine_nodes:
needed_metric_time_specs = group_specs_by_type(needed_linkable_specs).metric_time_specs
candidates_for_join.append(
JoinLinkableInstancesRecipe(
Expand Down
40 changes: 25 additions & 15 deletions metricflow/dataflow/builder/source_node.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,10 @@
from __future__ import annotations

from dataclasses import dataclass
from typing import List, Sequence, Tuple
from typing import Dict, List, Sequence, Tuple

from dbt_semantic_interfaces.references import TimeDimensionReference
from dbt_semantic_interfaces.type_enums import TimeGranularity
from metricflow_semantics.model.semantic_manifest_lookup import SemanticManifestLookup
from metricflow_semantics.query.query_parser import MetricFlowQueryParser
from metricflow_semantics.specs.column_assoc import ColumnAssociationResolver
Expand Down Expand Up @@ -36,14 +37,16 @@ class SourceNodeSet:
# below. See usage in `DataflowPlanBuilder`.
source_nodes_for_group_by_item_queries: Tuple[DataflowPlanNode, ...]

# Provides the time spine.
time_spine_node: MetricTimeDimensionTransformNode
# Provides the time spines.
time_spine_nodes: Dict[TimeGranularity, MetricTimeDimensionTransformNode]

@property
def all_nodes(self) -> Sequence[DataflowPlanNode]: # noqa: D102
return (
self.source_nodes_for_metric_queries + self.source_nodes_for_group_by_item_queries + (self.time_spine_node,)
)
return self.source_nodes_for_metric_queries + self.source_nodes_for_group_by_item_queries

@property
def time_spine_nodes_tuple(self) -> Tuple[MetricTimeDimensionTransformNode, ...]:  # noqa: D102
    # Flatten the granularity -> node mapping in `time_spine_nodes` into a tuple of just the nodes
    # (in the mapping's iteration order) for callers that take a sequence of time spine nodes.
    return tuple(self.time_spine_nodes.values())


class SourceNodeBuilder:
Expand All @@ -56,13 +59,19 @@ def __init__( # noqa: D107
) -> None:
self._semantic_manifest_lookup = semantic_manifest_lookup
data_set_converter = SemanticModelToDataSetConverter(column_association_resolver)
time_spine_source = TimeSpineSource.create_from_manifest(semantic_manifest_lookup.semantic_manifest)
time_spine_data_set = data_set_converter.build_time_spine_source_data_set(time_spine_source)
time_dim_reference = TimeDimensionReference(element_name=time_spine_source.time_column_name)
self._time_spine_source_node = MetricTimeDimensionTransformNode.create(
parent_node=ReadSqlSourceNode.create(data_set=time_spine_data_set),
aggregation_time_dimension_reference=time_dim_reference,
)
self._time_spine_source_nodes = {
granularity: MetricTimeDimensionTransformNode.create(
parent_node=ReadSqlSourceNode.create(
data_set=data_set_converter.build_time_spine_source_data_set(time_spine_source)
),
aggregation_time_dimension_reference=TimeDimensionReference(
element_name=time_spine_source.time_column_name
),
)
for granularity, time_spine_source in TimeSpineSource.create_from_manifest(
semantic_manifest_lookup.semantic_manifest
).items()
}
self._query_parser = MetricFlowQueryParser(semantic_manifest_lookup)

def create_from_data_sets(self, data_sets: Sequence[SemanticModelDataSet]) -> SourceNodeSet:
Expand Down Expand Up @@ -93,8 +102,9 @@ def create_from_data_sets(self, data_sets: Sequence[SemanticModelDataSet]) -> So
source_nodes_for_metric_queries.append(metric_time_transform_node)

return SourceNodeSet(
time_spine_node=self._time_spine_source_node,
source_nodes_for_group_by_item_queries=tuple(group_by_item_source_nodes) + (self._time_spine_source_node,),
time_spine_nodes=self._time_spine_source_nodes,
source_nodes_for_group_by_item_queries=tuple(group_by_item_source_nodes)
+ tuple(self._time_spine_source_nodes.values()),
source_nodes_for_metric_queries=tuple(source_nodes_for_metric_queries),
)

Expand Down
2 changes: 1 addition & 1 deletion metricflow/engine/metricflow_engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -368,7 +368,7 @@ def __init__(
DunderColumnAssociationResolver(semantic_manifest_lookup)
)
self._time_source = time_source
self._time_spine_source = TimeSpineSource.create_from_manifest(semantic_manifest_lookup.semantic_manifest)
self._time_spine_sources = TimeSpineSource.create_from_manifest(semantic_manifest_lookup.semantic_manifest)
self._source_data_sets: List[SemanticModelDataSet] = []
converter = SemanticModelToDataSetConverter(column_association_resolver=self._column_association_resolver)
for semantic_model in sorted(
Expand Down
35 changes: 28 additions & 7 deletions metricflow/plan_conversion/dataflow_to_sql.py
Original file line number Diff line number Diff line change
Expand Up @@ -183,7 +183,7 @@ def __init__(
self._semantic_manifest_lookup = semantic_manifest_lookup
self._metric_lookup = semantic_manifest_lookup.metric_lookup
self._semantic_model_lookup = semantic_manifest_lookup.semantic_model_lookup
self._time_spine_source = TimeSpineSource.create_from_manifest(semantic_manifest_lookup.semantic_manifest)
self._time_spine_sources = TimeSpineSource.create_from_manifest(semantic_manifest_lookup.semantic_manifest)

@property
def column_association_resolver(self) -> ColumnAssociationResolver: # noqa: D102
Expand Down Expand Up @@ -222,24 +222,47 @@ def _next_unique_table_alias(self) -> str:
"""Return the next unique table alias to use in generating queries."""
return SequentialIdGenerator.create_next_id(StaticIdPrefix.SUB_QUERY).str_value

def _choose_time_spine_source(
    self, agg_time_dimension_instances: Tuple[TimeDimensionInstance, ...]
) -> TimeSpineSource:
    """Determine which time spine source to use when building time spine dataset.

    Will choose the time spine with the largest granularity that can be used to get the smallest granularity requested
    in the agg time dimension instances. Example:
    - Time spines available: SECOND, MINUTE, DAY
    - Agg time dimension granularity needed for request: HOUR, DAY
    --> Selected time spine: MINUTE

    Args:
        agg_time_dimension_instances: The agg time dimension instances requested; must be non-empty.

    Returns:
        The entry of `self._time_spine_sources` keyed by the selected granularity.

    Raises:
        RuntimeError: If no configured time spine has a granularity less than or equal to the smallest
            granularity requested.
    """
    assert (
        agg_time_dimension_instances
    ), "Building time spine dataset requires agg_time_dimension_instances, but none were found."

    # Order granularities by `to_int()` everywhere (selection of the smallest requested grain, the compatibility
    # filter, and selection of the largest compatible spine) so all three steps agree on a single ordering and
    # do not depend on the granularity type implementing rich comparison operators.
    smallest_agg_time_grain = min(
        (dim.spec.time_granularity for dim in agg_time_dimension_instances),
        key=lambda grain: grain.to_int(),
    )
    smallest_agg_time_grain_int = smallest_agg_time_grain.to_int()
    compatible_time_spine_grains = [
        grain for grain in self._time_spine_sources if grain.to_int() <= smallest_agg_time_grain_int
    ]
    if not compatible_time_spine_grains:
        raise RuntimeError(
            # TODO: update docs link when new docs are available
            f"No compatible time spine found. This query requires a time spine with granularity {smallest_agg_time_grain} or smaller. See docs to configure a new time spine: https://docs.getdbt.com/docs/build/metricflow-time-spine"
        )
    # The coarsest compatible spine can still produce every requested granularity.
    return self._time_spine_sources[max(compatible_time_spine_grains, key=lambda grain: grain.to_int())]

def _make_time_spine_data_set(
self,
agg_time_dimension_instances: Tuple[TimeDimensionInstance, ...],
time_spine_source: TimeSpineSource,
time_range_constraint: Optional[TimeRangeConstraint] = None,
) -> SqlDataSet:
"""Make a time spine data set, which contains all date/time values like '2020-01-01', '2020-01-02'...
"""Returns a dataset with a datetime column for each agg_time_dimension granularity requested.
Returns a dataset with a column selected for each agg_time_dimension requested.
Column alias will use 'metric_time' or the agg_time_dimension name depending on which the user requested.
"""
time_spine_instance_set = InstanceSet(time_dimension_instances=agg_time_dimension_instances)
time_spine_table_alias = self._next_unique_table_alias()

time_spine_source = self._choose_time_spine_source(agg_time_dimension_instances)
column_expr = SqlColumnReferenceExpression.from_table_and_column_names(
table_alias=time_spine_table_alias, column_name=time_spine_source.time_column_name
)

select_columns: Tuple[SqlSelectColumn, ...] = ()
apply_group_by = False
for agg_time_dimension_instance in agg_time_dimension_instances:
Expand Down Expand Up @@ -314,7 +337,6 @@ def visit_join_over_time_range_node(self, node: JoinOverTimeRangeNode) -> SqlDat
# Assemble time_spine dataset with requested agg time dimension instances selected.
time_spine_data_set = self._make_time_spine_data_set(
agg_time_dimension_instances=requested_agg_time_dimension_instances,
time_spine_source=self._time_spine_source,
time_range_constraint=node.time_range_constraint,
)
table_alias_to_instance_set[time_spine_data_set_alias] = time_spine_data_set.instance_set
Expand Down Expand Up @@ -1234,7 +1256,6 @@ def visit_join_to_time_spine_node(self, node: JoinToTimeSpineNode) -> SqlDataSet
time_spine_alias = self._next_unique_table_alias()
time_spine_dataset = self._make_time_spine_data_set(
agg_time_dimension_instances=(agg_time_dimension_instance_for_join,),
time_spine_source=self._time_spine_source,
time_range_constraint=node.time_range_constraint,
)

Expand Down
4 changes: 2 additions & 2 deletions metricflow/plan_conversion/node_processor.py
Original file line number Diff line number Diff line change
Expand Up @@ -619,7 +619,7 @@ def remove_unnecessary_nodes(
desired_linkable_specs: Sequence[LinkableInstanceSpec],
nodes: Sequence[DataflowPlanNode],
metric_time_dimension_reference: TimeDimensionReference,
time_spine_node: MetricTimeDimensionTransformNode,
time_spine_nodes: Sequence[MetricTimeDimensionTransformNode],
) -> List[DataflowPlanNode]:
"""Filters out many of the nodes that can't possibly be useful for joins to obtain the desired linkable specs.
Expand Down Expand Up @@ -661,7 +661,7 @@ def remove_unnecessary_nodes(
continue

# Used for group-by-item-values queries.
if node == time_spine_node:
if node in time_spine_nodes:
logger.debug(f"Including {node} since it matches `time_spine_node`")
relevant_nodes.append(node)
continue
Expand Down
59 changes: 37 additions & 22 deletions metricflow/plan_conversion/time_spine.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,11 +2,10 @@

import logging
from dataclasses import dataclass
from typing import Optional
from typing import Dict, Optional

from dbt_semantic_interfaces.protocols import SemanticManifest
from dbt_semantic_interfaces.type_enums.time_granularity import TimeGranularity
from metricflow_semantics.mf_logging.pretty_print import mf_pformat
from metricflow_semantics.specs.time_dimension_spec import DEFAULT_TIME_GRANULARITY

from metricflow.sql.sql_table import SqlTable
Expand Down Expand Up @@ -34,26 +33,42 @@ def spine_table(self) -> SqlTable:
return SqlTable(schema_name=self.schema_name, table_name=self.table_name, db_name=self.db_name)

@staticmethod
def create_from_manifest(semantic_manifest: SemanticManifest) -> TimeSpineSource:
def create_from_manifest(semantic_manifest: SemanticManifest) -> Dict[TimeGranularity, TimeSpineSource]:
"""Creates a time spine source based on what's in the manifest."""
time_spine_table_configurations = semantic_manifest.project_configuration.time_spine_table_configurations

if not (
len(time_spine_table_configurations) == 1
and time_spine_table_configurations[0].grain == DEFAULT_TIME_GRANULARITY
):
raise NotImplementedError(
f"Only a single time spine table configuration with {DEFAULT_TIME_GRANULARITY} is currently "
f"supported. Got:\n"
f"{mf_pformat(time_spine_table_configurations)}"
time_spine_sources = {
time_spine.primary_column.time_granularity: TimeSpineSource(
schema_name=time_spine.node_relation.schema_name,
table_name=time_spine.node_relation.relation_name, # is relation name the table name? double check
db_name=time_spine.node_relation.database,
time_column_name=time_spine.primary_column.name,
time_column_granularity=time_spine.primary_column.time_granularity,
)
for time_spine in semantic_manifest.project_configuration.time_spines
}

time_spine_table_configuration = time_spine_table_configurations[0]
time_spine_table = SqlTable.from_string(time_spine_table_configuration.location)
return TimeSpineSource(
schema_name=time_spine_table.schema_name,
table_name=time_spine_table.table_name,
db_name=time_spine_table.db_name,
time_column_name=time_spine_table_configuration.column_name,
time_column_granularity=time_spine_table_configuration.grain,
)
# For backward compatibility: if a legacy time spine config exists in the manifest, add that time spine here.
# Ignore it if there is a new time spine config with the same granularity.
legacy_time_spines = semantic_manifest.project_configuration.time_spine_table_configurations
for legacy_time_spine in legacy_time_spines:
if not time_spine_sources.get(legacy_time_spine.grain):
time_spine_table = SqlTable.from_string(legacy_time_spine.location)
time_spine_sources[legacy_time_spine.grain] = TimeSpineSource(
schema_name=time_spine_table.schema_name,
table_name=time_spine_table.table_name,
db_name=time_spine_table.db_name,
time_column_name=legacy_time_spine.column_name,
time_column_granularity=legacy_time_spine.grain,
)

# Sanity check: this should have been validated during manifest parsing.
if not time_spine_sources:
raise RuntimeError(
"At least one time spine must be configured to use the semantic layer, but none were found."
)

return time_spine_sources


# DSI validations to add:
# - Check that there is only one time spine for each granularity option
# - Check that there is a time spine defined at minimum DAY
3 changes: 0 additions & 3 deletions metricflow/validation/data_warehouse_model_validator.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,6 @@
from metricflow.dataset.dataset_classes import DataSet
from metricflow.engine.metricflow_engine import MetricFlowEngine, MetricFlowExplainResult, MetricFlowQueryRequest
from metricflow.plan_conversion.dataflow_to_sql import DataflowToSqlQueryPlanConverter
from metricflow.plan_conversion.time_spine import TimeSpineSource
from metricflow.protocols.sql_client import SqlClient


Expand All @@ -59,7 +58,6 @@ class QueryRenderingTools:
semantic_manifest_lookup: SemanticManifestLookup
source_node_builder: SourceNodeBuilder
converter: SemanticModelToDataSetConverter
time_spine_source: TimeSpineSource
plan_converter: DataflowToSqlQueryPlanConverter

def __init__(self, manifest: SemanticManifest) -> None: # noqa: D107
Expand All @@ -68,7 +66,6 @@ def __init__(self, manifest: SemanticManifest) -> None: # noqa: D107
column_association_resolver=DunderColumnAssociationResolver(self.semantic_manifest_lookup),
semantic_manifest_lookup=self.semantic_manifest_lookup,
)
self.time_spine_source = TimeSpineSource.create_from_manifest(manifest)
self.converter = SemanticModelToDataSetConverter(
column_association_resolver=DunderColumnAssociationResolver(
semantic_manifest_lookup=self.semantic_manifest_lookup
Expand Down
8 changes: 4 additions & 4 deletions tests_metricflow/dataflow/builder/test_node_evaluator.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,7 @@ def node_evaluator(
].semantic_manifest_lookup.semantic_model_lookup,
nodes_available_for_joins=tuple(mf_engine_fixture.read_node_mapping.values()),
node_data_set_resolver=node_data_set_resolver,
time_spine_node=mf_engine_fixture.source_node_set.time_spine_node,
time_spine_nodes=mf_engine_fixture.source_node_set.time_spine_nodes_tuple,
)


Expand All @@ -72,7 +72,7 @@ def make_multihop_node_evaluator(
desired_linkable_specs=desired_linkable_specs,
nodes=source_node_set.source_nodes_for_metric_queries,
metric_time_dimension_reference=DataSet.metric_time_dimension_reference(),
time_spine_node=source_node_set.time_spine_node,
time_spine_nodes=source_node_set.time_spine_nodes_tuple,
)

nodes_available_for_joins = list(
Expand All @@ -87,7 +87,7 @@ def make_multihop_node_evaluator(
semantic_model_lookup=semantic_manifest_lookup_with_multihop_links.semantic_model_lookup,
nodes_available_for_joins=nodes_available_for_joins,
node_data_set_resolver=node_data_set_resolver,
time_spine_node=source_node_set.time_spine_node,
time_spine_nodes=source_node_set.time_spine_nodes_tuple,
)


Expand Down Expand Up @@ -518,7 +518,7 @@ def test_node_evaluator_with_scd_target(
# Use all nodes in the simple model as candidates for joins.
nodes_available_for_joins=tuple(mf_engine_fixture.read_node_mapping.values()),
node_data_set_resolver=node_data_set_resolver,
time_spine_node=mf_engine_fixture.source_node_set.time_spine_node,
time_spine_nodes=mf_engine_fixture.source_node_set.time_spine_nodes_tuple,
)

evaluation = node_evaluator.evaluate_node(
Expand Down
Loading

0 comments on commit 987a6ff

Please sign in to comment.