Dataflow plan for metric_time alone
courtneyholcomb committed Nov 10, 2023
1 parent b5e6110 commit f0837c6
Showing 17 changed files with 575 additions and 48 deletions.
113 changes: 96 additions & 17 deletions metricflow/dataflow/builder/dataflow_plan_builder.py
@@ -32,6 +32,7 @@
     JoinOverTimeRangeNode,
     JoinToBaseOutputNode,
     JoinToTimeSpineNode,
+    MetricTimeDimensionTransformNode,
     OrderByLimitNode,
     ReadSqlSourceNode,
     SemiAdditiveJoinNode,
@@ -44,12 +45,16 @@
 from metricflow.dataflow.optimizer.dataflow_plan_optimizer import DataflowPlanOptimizer
 from metricflow.dataflow.sql_table import SqlTable
 from metricflow.dataset.dataset import DataSet
+from metricflow.dataset.sql_dataset import SqlDataSet
 from metricflow.errors.errors import UnableToSatisfyQueryError
 from metricflow.filters.time_constraint import TimeRangeConstraint
+from metricflow.instances import InstanceSet, TimeDimensionInstance
 from metricflow.model.semantic_manifest_lookup import SemanticManifestLookup
+from metricflow.naming.linkable_spec_name import StructuredLinkableSpecName
 from metricflow.plan_conversion.column_resolver import DunderColumnAssociationResolver
 from metricflow.plan_conversion.node_processor import PreJoinNodeProcessor
-from metricflow.specs.column_assoc import ColumnAssociationResolver
+from metricflow.plan_conversion.time_spine import TIME_SPINE_DATA_SET_DESCRIPTION
+from metricflow.specs.column_assoc import ColumnAssociation, ColumnAssociationResolver, SingleColumnCorrelationKey
 from metricflow.specs.specs import (
     InstanceSpecSet,
     LinkableInstanceSpec,
@@ -64,7 +69,8 @@
     TimeDimensionSpec,
     WhereFilterSpec,
 )
-from metricflow.sql.sql_plan import SqlJoinType
+from metricflow.sql.sql_exprs import SqlColumnReference, SqlColumnReferenceExpression, SqlDateTruncExpression
+from metricflow.sql.sql_plan import SqlJoinType, SqlSelectColumn, SqlSelectStatementNode, SqlTableFromClauseNode

 logger = logging.getLogger(__name__)

@@ -148,6 +154,7 @@ def __init__( # noqa: D
     ) -> None:
         self._semantic_model_lookup = semantic_manifest_lookup.semantic_model_lookup
         self._metric_lookup = semantic_manifest_lookup.metric_lookup
+        self._time_spine_source = semantic_manifest_lookup.time_spine_source
         self._metric_time_dimension_reference = DataSet.metric_time_dimension_reference()
         self._source_nodes = source_nodes
         self._read_nodes = read_nodes
@@ -408,18 +415,18 @@ def _select_source_nodes_with_measures(

     def _select_read_nodes_with_linkable_specs(
         self, linkable_specs: LinkableSpecSet, read_nodes: Sequence[ReadSqlSourceNode]
-    ) -> Dict[BaseOutput, Set[LinkableInstanceSpec]]:
+    ) -> List[ReadSqlSourceNode]:
         """Find source nodes with requested linkable specs and no measures."""
-        nodes_to_linkable_specs: Dict[BaseOutput, Set[LinkableInstanceSpec]] = {}
-        linkable_specs_set = set(linkable_specs.as_tuple)
+        selected_nodes: List[ReadSqlSourceNode] = []
+        requested_linkable_specs_set = set(linkable_specs.as_tuple)
         for read_node in read_nodes:
             output_spec_set = self._node_data_set_resolver.get_output_data_set(read_node).instance_set.spec_set
-            linkable_specs_in_node = set(output_spec_set.linkable_specs)
-            requested_linkable_specs_in_node = linkable_specs_set.intersection(linkable_specs_in_node)
+            all_linkable_specs_in_node = set(output_spec_set.linkable_specs)
+            requested_linkable_specs_in_node = requested_linkable_specs_set.intersection(all_linkable_specs_in_node)
             if requested_linkable_specs_in_node:
-                nodes_to_linkable_specs[read_node] = requested_linkable_specs_in_node
+                selected_nodes.append(read_node)

-        return nodes_to_linkable_specs
+        return selected_nodes

     def _find_non_additive_dimension_in_linkable_specs(
         self,
@@ -473,31 +480,103 @@ def _build_measure_spec_properties(self, measure_specs: Sequence[MeasureSpec]) -
             non_additive_dimension_spec=non_additive_dimension_spec,
         )

+    # TODO: this should live somewhere else. Figure out where makes sense
+    def _create_metric_time_node_from_time_spine(self) -> MetricTimeDimensionTransformNode:
+        """Build a MetricTimeDimensionTransformNode over a ReadSqlSourceNode that reads from the time spine table."""
+        time_spine_source = self._time_spine_source
+        from_source_alias = IdGeneratorRegistry.for_class(self.__class__).create_id("time_spine_src")
+
+        # TODO: add date part to instances & select columns. Can we use the same logic as elsewhere??
+        time_spine_instances: List[TimeDimensionInstance] = []
+        select_columns: List[SqlSelectColumn] = []
+        for granularity in TimeGranularity:
+            if granularity.to_int() >= time_spine_source.time_column_granularity.to_int():
+                column_alias = StructuredLinkableSpecName(
+                    entity_link_names=(),
+                    element_name=time_spine_source.time_column_name,
+                    time_granularity=granularity,
+                ).qualified_name
+                time_spine_instance = TimeDimensionInstance(
+                    defined_from=(),
+                    associated_columns=(
+                        ColumnAssociation(
+                            column_name=column_alias,
+                            single_column_correlation_key=SingleColumnCorrelationKey(),
+                        ),
+                    ),
+                    spec=TimeDimensionSpec(
+                        element_name=time_spine_source.time_column_name, entity_links=(), time_granularity=granularity
+                    ),
+                )
+                time_spine_instances.append(time_spine_instance)
+                select_column = SqlSelectColumn(
+                    SqlDateTruncExpression(
+                        time_granularity=granularity,
+                        arg=SqlColumnReferenceExpression(
+                            SqlColumnReference(
+                                table_alias=from_source_alias,
+                                column_name=time_spine_source.time_column_name,
+                            ),
+                        ),
+                    ),
+                    column_alias=column_alias,
+                )
+                select_columns.append(select_column)
+
+        time_spine_instance_set = InstanceSet(time_dimension_instances=tuple(time_spine_instances))
+        time_spine_data_set = SqlDataSet(
+            instance_set=time_spine_instance_set,
+            sql_select_node=SqlSelectStatementNode(
+                description=TIME_SPINE_DATA_SET_DESCRIPTION,
+                select_columns=tuple(select_columns),
+                from_source=SqlTableFromClauseNode(sql_table=time_spine_source.spine_table),
+                from_source_alias=from_source_alias,
+                joins_descs=(),
+                group_bys=(),
+                order_bys=(),
+            ),
+        )
+        # TODO: do we still need this transform node given the data set above?
+        return MetricTimeDimensionTransformNode(
+            parent_node=ReadSqlSourceNode(data_set=time_spine_data_set),
+            aggregation_time_dimension_reference=TimeDimensionReference(
+                element_name=time_spine_source.time_column_name
+            ),
+        )
+
     def _find_dataflow_recipe(
         self,
         linkable_spec_set: LinkableSpecSet,
         measure_spec_properties: Optional[MeasureSpecProperties] = None,
         time_range_constraint: Optional[TimeRangeConstraint] = None,
     ) -> Optional[DataflowRecipe]:
         linkable_specs = linkable_spec_set.as_tuple
+        potential_source_nodes: Sequence[BaseOutput]
         if measure_spec_properties:
             source_nodes = self._source_nodes
-            potential_source_nodes: Sequence[BaseOutput] = self._select_source_nodes_with_measures(
+            potential_source_nodes = self._select_source_nodes_with_measures(
                 measure_specs=set(measure_spec_properties.measure_specs), source_nodes=source_nodes
             )
         else:
             # Only read nodes can be source nodes for queries without measures
-            source_nodes = list(self._read_nodes)
-            source_nodes_to_linkable_specs = self._select_read_nodes_with_linkable_specs(
-                linkable_specs=linkable_spec_set, read_nodes=source_nodes
-            )
-            # Add time_spine to potential source nodes w/ metric_time as linkable spec
-            # Maybe only do this if requested
-            potential_source_nodes = list(source_nodes_to_linkable_specs.keys())
+            potential_source_nodes = self._select_read_nodes_with_linkable_specs(
+                linkable_specs=linkable_spec_set, read_nodes=self._read_nodes
+            )
+            # `metric_time` does not exist if there is no metric in the query.
+            # In that case, we'll use the time spine table to represent `metric_time` values.
+            requested_metric_time_specs = [
+                time_dimension_spec
+                for time_dimension_spec in linkable_spec_set.time_dimension_specs
+                if time_dimension_spec.element_name == self._metric_time_dimension_reference.element_name
+            ]
+            if requested_metric_time_specs:
+                # Add time_spine to potential source nodes for requested metric_time specs
+                time_spine_node = self._create_metric_time_node_from_time_spine()
+                potential_source_nodes = list(potential_source_nodes) + [time_spine_node]

-        logger.info(f"Starting search with {len(source_nodes)} source nodes")
+        logger.info(f"There are {len(potential_source_nodes)} potential source nodes")

         start_time = time.time()

         node_processor = PreJoinNodeProcessor(
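For intuition, the following is a minimal, self-contained sketch of the projection rule that the new `_create_metric_time_node_from_time_spine` method applies: one DATE_TRUNC select column per granularity at or coarser than the spine table's grain, each aliased with the dunder naming convention. The enum values and the `ds` column here are simplified stand-ins, not MetricFlow's actual types.

# Sketch only: mirrors the granularity loop in the diff above with simplified types.
from enum import Enum
from typing import List


class TimeGranularity(Enum):
    DAY = 1
    WEEK = 2
    MONTH = 3
    QUARTER = 4
    YEAR = 5

    def to_int(self) -> int:
        return self.value


def time_spine_select_columns(time_column_name: str, spine_granularity: TimeGranularity) -> List[str]:
    """Build one DATE_TRUNC column per granularity at or coarser than the spine's grain."""
    select_columns = []
    for granularity in TimeGranularity:
        if granularity.to_int() >= spine_granularity.to_int():
            grain = granularity.name.lower()
            # Dunder-style alias, e.g. "ds__month", as StructuredLinkableSpecName would produce.
            select_columns.append(f"DATE_TRUNC('{grain}', {time_column_name}) AS {time_column_name}__{grain}")
    return select_columns


# A daily spine with time column `ds` yields ds__day, ds__week, ds__month, ds__quarter, ds__year.
for column in time_spine_select_columns("ds", TimeGranularity.DAY):
    print(column)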
12 changes: 8 additions & 4 deletions metricflow/dataflow/builder/node_evaluator.py
@@ -41,6 +41,7 @@
     LinkableInstanceSpec,
     LinklessEntitySpec,
 )
+from metricflow.test.time.metric_time_dimension import MTD

 logger = logging.getLogger(__name__)

@@ -322,11 +323,11 @@ def evaluate_node(

         data_set_linkable_specs = candidate_spec_set.linkable_specs

-        # These are linkable specs in the same data set as the measure. Those are considered "local".
-        local_linkable_specs = []
+        # These are linkable specs in the start node data set. Those are considered "local".
+        local_linkable_specs: List[LinkableInstanceSpec] = []

         # These are linkable specs that aren't in the data set, but they might be able to be joined in.
-        possibly_joinable_linkable_specs = []
+        possibly_joinable_linkable_specs: List[LinkableInstanceSpec] = []

         # Group required_linkable_specs into local / un-joinable / or possibly joinable.
         unjoinable_linkable_specs = []
@@ -364,7 +365,10 @@ def evaluate_node(
                     "There are no more candidate nodes that can be joined, but not all linkable specs have "
                     "been acquired."
                 )
-                unjoinable_linkable_specs.extend(possibly_joinable_linkable_specs)
+                if all(spec.element_name == MTD for spec in possibly_joinable_linkable_specs):
+                    pass
+                else:
+                    unjoinable_linkable_specs.extend(possibly_joinable_linkable_specs)
                 break

             # Join the best candidate to realize the linkable specs
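For context, the new branch above exempts `metric_time` from the unjoinable bucket: when no join candidates remain, specs consisting solely of `metric_time` can still be satisfied by the time spine source. A standalone restatement of that rule, using plain element-name strings in place of MetricFlow's spec objects:

# Sketch only: simplified restatement of the guard, assuming MTD == "metric_time".
from typing import List

MTD = "metric_time"


def specs_to_mark_unjoinable(possibly_joinable_element_names: List[str]) -> List[str]:
    """Return the specs to flag as unjoinable once join candidates run out.

    If every remaining spec is metric_time, none are unjoinable, since the
    time spine can always provide metric_time.
    """
    if all(name == MTD for name in possibly_joinable_element_names):
        return []
    return list(possibly_joinable_element_names)


assert specs_to_mark_unjoinable([MTD]) == []
assert specs_to_mark_unjoinable([MTD, "user__country"]) == [MTD, "user__country"]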
8 changes: 7 additions & 1 deletion metricflow/dataset/convert_semantic_model.py
@@ -16,7 +16,7 @@
 from metricflow.aggregation_properties import AggregationState
 from metricflow.dag.id_generation import IdGeneratorRegistry
 from metricflow.dataflow.sql_table import SqlTable
-from metricflow.dataset.semantic_model_adapter import SemanticModelDataSet
+from metricflow.dataset.semantic_model_adapter import SemanticModelDataSet, SqlDataSet
 from metricflow.instances import (
     DimensionInstance,
     EntityInstance,
@@ -26,6 +26,7 @@
 )
 from metricflow.model.semantics.semantic_model_lookup import SemanticModelLookup
 from metricflow.model.spec_converters import MeasureConverter
+from metricflow.plan_conversion.time_spine import TimeSpineSource
 from metricflow.specs.column_assoc import ColumnAssociationResolver
 from metricflow.specs.specs import (
     DEFAULT_TIME_GRANULARITY,
@@ -478,3 +479,8 @@ def create_sql_source_data_set(self, semantic_model: SemanticModel) -> SemanticM
             ),
             sql_select_node=select_statement_node,
         )
+
+    # TODO: move the time spine data set creation logic here?
+    def create_data_set_from_time_spine(self, time_spine_source: TimeSpineSource) -> SqlDataSet:
+        """Create a SQL source data set from the time spine table."""
+        pass
2 changes: 2 additions & 0 deletions metricflow/dataset/sql_dataset.py
@@ -28,6 +28,8 @@ def __init__(self, instance_set: InstanceSet, sql_select_node: SqlSelectStatemen
         self._sql_select_node = sql_select_node
         super().__init__(instance_set=instance_set)

+    # TODO: add an optional __repr__ to display the table name, or pass in a custom name (time spine).
+
     @property
     def sql_select_node(self) -> SqlSelectStatementNode:
         """Return a SELECT node that can be used to read data from the given SQL table or SQL query."""
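One possible shape for the `__repr__` TODO above (an assumption about the eventual API, not code from this commit):

# Hypothetical sketch: give SqlDataSet an optional display name so a time spine
# data set can identify itself. `display_name` is invented for illustration; the
# real class takes only instance_set and sql_select_node.
from typing import Optional


class SqlDataSet:  # stub of the real class, for illustration only
    def __init__(self, sql_select_node: object, display_name: Optional[str] = None) -> None:
        self._sql_select_node = sql_select_node
        self._display_name = display_name

    def __repr__(self) -> str:
        name = self._display_name or self.__class__.__name__
        return f"{self.__class__.__name__}(name={name!r})"


print(SqlDataSet(sql_select_node=None, display_name="time spine"))  # SqlDataSet(name='time spine')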
23 changes: 9 additions & 14 deletions metricflow/plan_conversion/dataflow_to_sql.py
@@ -67,7 +67,7 @@
     ColumnEqualityDescription,
     SqlQueryPlanJoinBuilder,
 )
-from metricflow.plan_conversion.time_spine import TimeSpineSource
+from metricflow.plan_conversion.time_spine import TIME_SPINE_DATA_SET_DESCRIPTION, TimeSpineSource
 from metricflow.protocols.sql_client import SqlEngine
 from metricflow.specs.column_assoc import ColumnAssociation, ColumnAssociationResolver, SingleColumnCorrelationKey
 from metricflow.specs.specs import (
@@ -185,19 +185,15 @@ def _make_time_spine_data_set(
                 spec=metric_time_dimension_instance.spec,
             ),
         )
-        time_spine_instance_set = InstanceSet(
-            time_dimension_instances=time_spine_instance,
-        )
-        description = "Date Spine"
+        time_spine_instance_set = InstanceSet(time_dimension_instances=time_spine_instance)
         time_spine_table_alias = self._next_unique_table_alias()

         # If the requested granularity is the same as the granularity of the spine, do a direct select.
         if metric_time_dimension_instance.spec.time_granularity == time_spine_source.time_column_granularity:
             return SqlDataSet(
                 instance_set=time_spine_instance_set,
                 sql_select_node=SqlSelectStatementNode(
-                    description=description,
-                    # This creates select expressions for all columns referenced in the instance set.
+                    description=TIME_SPINE_DATA_SET_DESCRIPTION,
                     select_columns=(
                         SqlSelectColumn(
                             expr=SqlColumnReferenceExpression(
@@ -242,8 +238,7 @@
         return SqlDataSet(
             instance_set=time_spine_instance_set,
             sql_select_node=SqlSelectStatementNode(
-                description=description,
-                # This creates select expressions for all columns referenced in the instance set.
+                description=TIME_SPINE_DATA_SET_DESCRIPTION,
                 select_columns=select_columns,
                 from_source=SqlTableFromClauseNode(sql_table=time_spine_source.spine_table),
                 from_source_alias=time_spine_table_alias,
@@ -1035,11 +1030,11 @@ def visit_metric_time_dimension_transform_node(self, node: MetricTimeDimensionTr
             if aggregation_time_dimension_for_measure == node.aggregation_time_dimension_reference:
                 output_measure_instances.append(measure_instance)

-        if len(output_measure_instances) == 0:
-            raise RuntimeError(
-                f"No measure instances in the input source match the aggregation time dimension "
-                f"{node.aggregation_time_dimension_reference}. Check if the dataflow plan was constructed correctly."
-            )
+        # if len(output_measure_instances) == 0:
+        #     raise RuntimeError(
+        #         f"No measure instances in the input source match the aggregation time dimension "
+        #         f"{node.aggregation_time_dimension_reference}. Check if the dataflow plan was constructed correctly."
+        #     )

         # Find time dimension instances that refer to the same dimension as the one specified in the node.
         matching_time_dimension_instances = []
2 changes: 2 additions & 0 deletions metricflow/plan_conversion/time_spine.py
@@ -9,6 +9,8 @@

 logger = logging.getLogger(__name__)

+TIME_SPINE_DATA_SET_DESCRIPTION = "Date Spine"
+

 @dataclass(frozen=True)
 class TimeSpineSource: