Skip to content

Commit

Permalink
Refactor _make_time_spine_data_set() for readability & simplicity
Browse files Browse the repository at this point in the history
Also supports an upcoming change to allow multiple agg_time_dimensions in this function.
  • Loading branch information
courtneyholcomb committed Jun 12, 2024
1 parent 0796a98 commit 7fc74c0
Show file tree
Hide file tree
Showing 2 changed files with 46 additions and 78 deletions.
120 changes: 42 additions & 78 deletions metricflow/plan_conversion/dataflow_to_sql.py
Original file line number Diff line number Diff line change
Expand Up @@ -223,89 +223,59 @@ def _next_unique_table_alias(self) -> str:
def _make_time_spine_data_set(
    self,
    agg_time_dimension_instance: TimeDimensionInstance,
    time_spine_source: TimeSpineSource,
    time_range_constraint: Optional[TimeRangeConstraint] = None,
) -> SqlDataSet:
    """Make a time spine data set, which contains all date/time values like '2020-01-01', '2020-01-02'...

    This is useful in computing cumulative metrics.

    Returns a data set with a column for the agg_time_dimension requested.
    Column alias will use 'metric_time' or the agg_time_dimension name depending on which the user requested.

    Args:
        agg_time_dimension_instance: the aggregation-time dimension the output column represents.
        time_spine_source: describes the physical time spine table (table, time column, its granularity).
        time_range_constraint: optional constraint used to filter the spine's time column via a WHERE clause.
    """
    time_spine_instance_set = InstanceSet(time_dimension_instances=(agg_time_dimension_instance,))
    time_spine_table_alias = self._next_unique_table_alias()

    # Reference to the spine's raw time column; reused by both branches below.
    column_expr = SqlColumnReferenceExpression.from_table_and_column_names(
        table_alias=time_spine_table_alias, column_name=time_spine_source.time_column_name
    )

    select_columns: Tuple[SqlSelectColumn, ...] = ()
    apply_group_by = False
    # Output alias is resolved from the spec so it matches what the user requested
    # ('metric_time' vs. the agg_time_dimension's own name).
    column_alias = self.column_association_resolver.resolve_spec(agg_time_dimension_instance.spec).column_name
    # If the requested granularity matches that of the time spine, do a direct select.
    # TODO: also handle date part.
    if agg_time_dimension_instance.spec.time_granularity == time_spine_source.time_column_granularity:
        select_columns += (SqlSelectColumn(expr=column_expr, column_alias=column_alias),)
    # Otherwise, apply a DATE_TRUNC() and aggregate via group_by to deduplicate coarser rows.
    else:
        select_columns += (
            SqlSelectColumn(
                expr=SqlDateTruncExpression(
                    time_granularity=agg_time_dimension_instance.spec.time_granularity, arg=column_expr
                ),
                column_alias=column_alias,
            ),
        )
        apply_group_by = True

    return SqlDataSet(
        instance_set=time_spine_instance_set,
        sql_select_node=SqlSelectStatementNode(
            description=TIME_SPINE_DATA_SET_DESCRIPTION,
            select_columns=select_columns,
            from_source=SqlTableFromClauseNode(sql_table=time_spine_source.spine_table),
            from_source_alias=time_spine_table_alias,
            # Group by the selected (possibly truncated) column only when truncation occurred.
            group_bys=select_columns if apply_group_by else (),
            # Filter on the raw spine column; only emitted when a constraint was supplied.
            where=(
                _make_time_range_comparison_expr(
                    table_alias=time_spine_table_alias,
                    column_alias=time_spine_source.time_column_name,
                    time_range_constraint=time_range_constraint,
                )
                if time_range_constraint
                else None
            ),
        ),
    )

def visit_source_node(self, node: ReadSqlSourceNode) -> SqlDataSet:
"""Generate the SQL to read from the source."""
Expand All @@ -332,15 +302,10 @@ def visit_join_over_time_range_node(self, node: JoinOverTimeRangeNode) -> SqlDat

time_spine_data_set_alias = self._next_unique_table_alias()

agg_time_dimension_column_name = self.column_association_resolver.resolve_spec(
agg_time_dimension_instance.spec
).column_name

# Assemble time_spine dataset with metric_time_dimension to join.
# Granularity of time_spine column should match granularity of metric_time column from parent dataset.
time_spine_data_set = self._make_time_spine_data_set(
agg_time_dimension_instance=agg_time_dimension_instance,
agg_time_dimension_column_name=agg_time_dimension_column_name,
time_spine_source=self._time_spine_source,
time_range_constraint=node.time_range_constraint,
)
Expand Down Expand Up @@ -1275,22 +1240,21 @@ def visit_join_to_time_spine_node(self, node: JoinToTimeSpineNode) -> SqlDataSet
agg_time_dimension_instance_for_join = agg_time_dimension_instances[0]

# Build time spine data set using the requested agg_time_dimension name.
agg_time_dimension_column_name = self.column_association_resolver.resolve_spec(
agg_time_dimension_instance_for_join.spec
).column_name
time_spine_alias = self._next_unique_table_alias()
time_spine_dataset = self._make_time_spine_data_set(
agg_time_dimension_instance=agg_time_dimension_instance_for_join,
agg_time_dimension_column_name=agg_time_dimension_column_name,
time_spine_source=self._time_spine_source,
time_range_constraint=node.time_range_constraint,
)

# Build join expression.

join_description = SqlQueryPlanJoinBuilder.make_join_to_time_spine_join_description(
node=node,
time_spine_alias=time_spine_alias,
agg_time_dimension_column_name=agg_time_dimension_column_name,
agg_time_dimension_column_name=self.column_association_resolver.resolve_spec(
agg_time_dimension_instance_for_join.spec
).column_name,
parent_sql_select_node=parent_data_set.checked_sql_select_node,
parent_alias=parent_alias,
)
Expand Down
4 changes: 4 additions & 0 deletions metricflow/sql/sql_exprs.py
Original file line number Diff line number Diff line change
Expand Up @@ -476,6 +476,10 @@ def matches(self, other: SqlExpressionNode) -> bool: # noqa: D102
return False
return self.col_ref == other.col_ref

@staticmethod
def from_table_and_column_names(table_alias: str, column_name: str) -> SqlColumnReferenceExpression:
    """Build a column-reference expression for the given table alias and column name."""
    col_ref = SqlColumnReference(table_alias=table_alias, column_name=column_name)
    return SqlColumnReferenceExpression(col_ref)


class SqlColumnAliasReferenceExpression(SqlExpressionNode):
"""An expression that evaluates to the alias of a column, but is not qualified with a table alias.
Expand Down

0 comments on commit 7fc74c0

Please sign in to comment.