Add SQL rendering logic for custom granularities
courtneyholcomb committed Sep 5, 2024
1 parent d8ee7c8 commit 910f21a
Showing 3 changed files with 141 additions and 17 deletions.
@@ -4,6 +4,7 @@
from dataclasses import dataclass
from typing import Dict, Optional, Sequence

from dbt_semantic_interfaces.implementations.time_spine import PydanticTimeSpineCustomGranularityColumn
from dbt_semantic_interfaces.protocols import SemanticManifest
from dbt_semantic_interfaces.type_enums.time_granularity import TimeGranularity

@@ -29,7 +30,7 @@ class TimeSpineSource:
# The time granularity of the base column.
base_granularity: TimeGranularity = DEFAULT_TIME_GRANULARITY
db_name: Optional[str] = None
custom_granularities: Sequence[str] = ()
custom_granularities: Sequence[PydanticTimeSpineCustomGranularityColumn] = ()

@property
def spine_table(self) -> SqlTable:
@@ -48,7 +49,14 @@ def build_standard_time_spine_sources(
db_name=time_spine.node_relation.database,
base_column=time_spine.primary_column.name,
base_granularity=time_spine.primary_column.time_granularity,
custom_granularities=[column.name for column in time_spine.custom_granularities],
custom_granularities=tuple(
[
PydanticTimeSpineCustomGranularityColumn(
name=custom_granularity.name, column_name=custom_granularity.column_name
)
for custom_granularity in time_spine.custom_granularities
]
),
)
for time_spine in semantic_manifest.project_configuration.time_spines
}
@@ -74,3 +82,12 @@ def build_standard_time_spine_sources(
)

return time_spine_sources

@staticmethod
def build_custom_time_spine_sources(time_spine_sources: Sequence[TimeSpineSource]) -> Dict[str, TimeSpineSource]:
"""Creates a set of time spine sources with custom granularities based on what's in the manifest."""
return {
custom_granularity.name: time_spine_source
for time_spine_source in time_spine_sources
for custom_granularity in time_spine_source.custom_granularities
}
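
For reference, a minimal standalone sketch of the name-to-source inversion that build_custom_time_spine_sources performs; the _Grain and _Spine stand-in classes and the granularity names "fiscal_quarter" and "retail_week" are hypothetical, not part of this commit.

from dataclasses import dataclass
from typing import Dict, Sequence, Tuple


@dataclass(frozen=True)
class _Grain:  # stand-in for PydanticTimeSpineCustomGranularityColumn
    name: str
    column_name: str = ""


@dataclass(frozen=True)
class _Spine:  # stand-in for TimeSpineSource
    table_name: str
    custom_granularities: Tuple[_Grain, ...] = ()


def index_by_custom_granularity(spines: Sequence[_Spine]) -> Dict[str, _Spine]:
    # Same shape as build_custom_time_spine_sources: each custom granularity name
    # points back at the time spine source that defines it.
    return {grain.name: spine for spine in spines for grain in spine.custom_granularities}


day_spine = _Spine("mf_time_spine_day", (_Grain("fiscal_quarter"), _Grain("retail_week")))
assert index_by_custom_granularity([day_spine]) == {"fiscal_quarter": day_spine, "retail_week": day_spine}
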
6 changes: 6 additions & 0 deletions metricflow/dataflow/nodes/join_to_custom_granularity.py
@@ -18,6 +18,12 @@ class JoinToCustomGranularityNode(DataflowPlanNode, ABC):

time_dimension_spec: TimeDimensionSpec

def __post_init__(self) -> None:  # noqa: D105
    assert self.time_dimension_spec.time_granularity.is_custom_granularity, (
        "Time granularity for time dimension spec in JoinToCustomGranularityNode must be qualified as custom granularity."
        f" Instead, found {self.time_dimension_spec.time_granularity.name}. This indicates internal misconfiguration."
    )

@staticmethod
def create( # noqa: D102
parent_node: DataflowPlanNode, time_dimension_spec: TimeDimensionSpec
131 changes: 116 additions & 15 deletions metricflow/plan_conversion/dataflow_to_sql.py
@@ -199,6 +199,9 @@ def __init__(
self._time_spine_sources = TimeSpineSource.build_standard_time_spine_sources(
semantic_manifest_lookup.semantic_manifest
)
self._custom_granularity_time_spine_sources = TimeSpineSource.build_custom_time_spine_sources(
tuple(self._time_spine_sources.values())
)

@property
def column_association_resolver(self) -> ColumnAssociationResolver: # noqa: D102
@@ -237,7 +240,7 @@ def _next_unique_table_alias(self) -> str:
"""Return the next unique table alias to use in generating queries."""
return SequentialIdGenerator.create_next_id(StaticIdPrefix.SUB_QUERY).str_value

def _choose_time_spine_source(
def _choose_time_spine_source_for_standard_granularity(
self, agg_time_dimension_instances: Tuple[TimeDimensionInstance, ...]
) -> TimeSpineSource:
"""Determine which time spine source to use when building time spine dataset.
@@ -254,9 +257,10 @@ def _choose_time_spine_source(
assert (
agg_time_dimension_instances
), "Building time spine dataset requires agg_time_dimension_instances, but none were found."
smallest_agg_time_grain = min(
dim.spec.time_granularity.base_granularity for dim in agg_time_dimension_instances
)
smallest_agg_time_grain = sorted(
agg_time_dimension_instances,
key=lambda x: x.spec.time_granularity.base_granularity.to_int(),
)[0].spec.time_granularity.base_granularity
compatible_time_spine_grains = [
grain for grain in self._time_spine_sources.keys() if grain.to_int() <= smallest_agg_time_grain.to_int()
]
@@ -268,6 +272,48 @@
)
return self._time_spine_sources[max(compatible_time_spine_grains)]
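
A minimal standalone sketch of the selection rule above: take the finest requested standard grain, then pick the coarsest compatible time spine. The grain-rank table and the choose_spine_grain helper are hypothetical stand-ins that mirror TimeGranularity.to_int(), not MetricFlow APIs.

from typing import Sequence

# Hypothetical ordinal ranks, finest to coarsest (mirroring TimeGranularity.to_int()).
_GRAIN_RANK = {"day": 1, "week": 2, "month": 3, "quarter": 4, "year": 5}


def choose_spine_grain(requested_grains: Sequence[str], available_spine_grains: Sequence[str]) -> str:
    # The chosen time spine must be at least as fine as the finest requested grain;
    # among the compatible spines, the coarsest one is preferred (fewer rows to roll up).
    finest_requested = min(requested_grains, key=_GRAIN_RANK.__getitem__)
    compatible = [grain for grain in available_spine_grains if _GRAIN_RANK[grain] <= _GRAIN_RANK[finest_requested]]
    if not compatible:
        raise ValueError(f"No time spine available with grain <= {finest_requested}.")
    return max(compatible, key=_GRAIN_RANK.__getitem__)


assert choose_spine_grain(["month", "week"], ["day", "week", "month"]) == "week"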

def _get_time_spine_for_custom_granularity(self, custom_granularity: str) -> TimeSpineSource:
time_spine_source = self._custom_granularity_time_spine_sources.get(custom_granularity)
assert time_spine_source, (
f"Custom granularity {custom_granularity} does not not exist in time spine sources. "
f"Available custom granularities: {list(self._custom_granularity_time_spine_sources.keys())}"
)
return time_spine_source

def _get_custom_granularity_column_name(self, custom_granularity_name: str) -> str:
    time_spine_source = self._get_time_spine_for_custom_granularity(custom_granularity_name)
    for custom_granularity in time_spine_source.custom_granularities:
        if custom_granularity.name == custom_granularity_name:
            # Fall back to the granularity name if no explicit column name was configured.
            return custom_granularity.column_name if custom_granularity.column_name else custom_granularity.name

    raise RuntimeError(
        f"Custom granularity {custom_granularity_name} not found. This indicates internal misconfiguration."
    )

def _make_custom_granularity_dataset(self, time_dimension_instance: TimeDimensionInstance) -> SqlDataSet:
time_spine_instance_set = InstanceSet(time_dimension_instances=(time_dimension_instance,))
time_spine_table_alias = self._next_unique_table_alias()
assert (
time_dimension_instance.spec.time_granularity.is_custom_granularity
), "_make_custom_granularity_dataset() should only be called for custom granularities."

custom_granularity_name = time_dimension_instance.spec.time_granularity.name
time_spine_source = self._get_time_spine_for_custom_granularity(custom_granularity_name)
custom_granularity_column_name = self._get_custom_granularity_column_name(custom_granularity_name)
column_expr = SqlColumnReferenceExpression.from_table_and_column_names(
table_alias=time_spine_table_alias, column_name=custom_granularity_column_name
)
column_alias = self.column_association_resolver.resolve_spec(time_dimension_instance.spec).column_name
return SqlDataSet(
instance_set=time_spine_instance_set,
sql_select_node=SqlSelectStatementNode.create(
description=TIME_SPINE_DATA_SET_DESCRIPTION,
select_columns=(SqlSelectColumn(expr=column_expr, column_alias=column_alias),),
from_source=SqlTableFromClauseNode.create(sql_table=time_spine_source.spine_table),
from_source_alias=time_spine_table_alias,
),
)
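
Roughly, the dataset built here should render to a single-column select over the configured time spine table. A hedged illustration follows; the table, column, and alias names are hypothetical, and the output alias comes from the column association resolver.

# Illustrative only: approximate shape of the SQL rendered for the custom-granularity
# time spine dataset, assuming a custom granularity "fiscal_quarter" stored in a column
# of the same name on the time spine table.
EXPECTED_CUSTOM_GRANULARITY_TIME_SPINE_SQL = """
SELECT
  subq_1.fiscal_quarter AS metric_time__fiscal_quarter
FROM my_schema.mf_time_spine subq_1
"""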

def _make_time_spine_data_set(
self,
agg_time_dimension_instances: Tuple[TimeDimensionInstance, ...],
@@ -280,7 +326,7 @@ def _make_time_spine_data_set(
time_spine_instance_set = InstanceSet(time_dimension_instances=agg_time_dimension_instances)
time_spine_table_alias = self._next_unique_table_alias()

time_spine_source = self._choose_time_spine_source(agg_time_dimension_instances)
time_spine_source = self._choose_time_spine_source_for_standard_granularity(agg_time_dimension_instances)
column_expr = SqlColumnReferenceExpression.from_table_and_column_names(
table_alias=time_spine_table_alias, column_name=time_spine_source.base_column
)
@@ -290,9 +336,7 @@
column_alias = self.column_association_resolver.resolve_spec(agg_time_dimension_instance.spec).column_name
# If the requested granularity is the same as the granularity of the spine, do a direct select.
# TODO: also handle date part.
# TODO: [custom granularity] add support for custom granularities to make_time_spine_data_set
agg_time_grain = agg_time_dimension_instance.spec.time_granularity
assert not agg_time_grain.is_custom_granularity, "Custom time granularities are not yet supported!"
if agg_time_grain.base_granularity == time_spine_source.base_granularity:
select_columns += (SqlSelectColumn(expr=column_expr, column_alias=column_alias),)
# If any columns have a different granularity, apply a DATE_TRUNC() and aggregate via group_by.
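
For the branch described in the comment above, a hedged illustration of the expected output when a requested standard grain is coarser than the spine's base grain; the dialect, table, and alias names are hypothetical.

# Illustrative only: the spine's base column is truncated to the requested grain and
# deduplicated via GROUP BY.
EXPECTED_COARSER_GRAIN_TIME_SPINE_SQL = """
SELECT
  DATE_TRUNC('month', subq_2.ds) AS metric_time__month
FROM my_schema.mf_time_spine subq_2
GROUP BY
  DATE_TRUNC('month', subq_2.ds)
"""
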
@@ -334,6 +378,7 @@ def visit_source_node(self, node: ReadSqlSourceNode) -> SqlDataSet:
instance_set=node.data_set.instance_set,
)

# TODO: write tests for custom granularities that hit this node
def visit_join_over_time_range_node(self, node: JoinOverTimeRangeNode) -> SqlDataSet:
"""Generate time range join SQL."""
table_alias_to_instance_set: OrderedDict[str, InstanceSet] = OrderedDict()
@@ -1257,6 +1302,7 @@ def visit_semi_additive_join_node(self, node: SemiAdditiveJoinNode) -> SqlDataSe
),
)

# TODO: write tests for custom granularities that hit this node
def visit_join_to_time_spine_node(self, node: JoinToTimeSpineNode) -> SqlDataSet: # noqa: D102
parent_data_set = node.parent_node.accept(self)
parent_alias = self._next_unique_table_alias()
Expand All @@ -1280,10 +1326,6 @@ def visit_join_to_time_spine_node(self, node: JoinToTimeSpineNode) -> SqlDataSet
agg_time_dimension_instances.append(instance)

# Choose the instance with the smallest standard granularity available.
# TODO: [custom granularity] Update to account for custom granularity instances
assert all(
[not instance.spec.time_granularity.is_custom_granularity for instance in agg_time_dimension_instances]
), "Custom granularities are not yet supported!"
agg_time_dimension_instances.sort(key=lambda instance: instance.spec.time_granularity.base_granularity.to_int())
assert len(agg_time_dimension_instances) > 0, (
"Couldn't find requested agg_time_dimension in parent data set. The dataflow plan may have been "
@@ -1369,9 +1411,6 @@ def visit_join_to_time_spine_node(self, node: JoinToTimeSpineNode) -> SqlDataSet
# Add requested granularities (if different from time_spine) and date_parts to time spine column.
for time_dimension_instance in time_dimensions_to_select_from_time_spine:
time_dimension_spec = time_dimension_instance.spec

# TODO: this will break when we start supporting smaller grain than DAY unless the time spine table is
# updated to use the smallest available grain.
if (
time_dimension_spec.time_granularity.base_granularity.to_int()
< original_time_spine_dim_instance.spec.time_granularity.base_granularity.to_int()
@@ -1437,7 +1476,69 @@ def visit_join_to_time_spine_node(self, node: JoinToTimeSpineNode) -> SqlDataSet
)

def visit_join_to_custom_granularity_node(self, node: JoinToCustomGranularityNode) -> SqlDataSet: # noqa: D102
raise NotImplementedError # TODO in later commit
parent_data_set = node.parent_node.accept(self)
parent_alias = self._next_unique_table_alias()
parent_time_dimension_instance: Optional[TimeDimensionInstance] = None
for instance in parent_data_set.instance_set.time_dimension_instances:
if instance.spec == node.time_dimension_spec:
parent_time_dimension_instance = instance
break
assert parent_time_dimension_instance, (
"JoinToCustomGranularityNode's expected time_dimension_spec not found in parent dataset instances. "
"This indicates internal misconfiguration."
)

# Build time spine dataset.
time_spine_dataset = self._make_custom_granularity_dataset(parent_time_dimension_instance)
assert (
time_spine_dataset.instance_set.time_dimension_instances
), "No time dimensions found in time spine dataset. This indicates internal misconfiguration."
time_spine_instance = time_spine_dataset.instance_set.time_dimension_instances[0]

# Build join expression.
time_spine_source = self._get_time_spine_for_custom_granularity(node.time_dimension_spec.time_granularity.name)
left_expr_for_join: SqlExpressionNode = SqlColumnReferenceExpression.from_table_and_column_names(
table_alias=parent_alias, column_name=parent_time_dimension_instance.associated_column.column_name
)
left_expr_for_join = (
left_expr_for_join
if parent_time_dimension_instance.spec.time_granularity == time_spine_source.base_granularity
else SqlDateTruncExpression.create(
time_granularity=time_spine_source.base_granularity, arg=left_expr_for_join
)
)
time_spine_alias = self._next_unique_table_alias()
join_description = SqlJoinDescription(
right_source=time_spine_dataset.checked_sql_select_node,
right_source_alias=time_spine_alias,
on_condition=SqlComparisonExpression.create(
left_expr=left_expr_for_join,
comparison=SqlComparison.EQUALS,
right_expr=SqlColumnReferenceExpression.from_table_and_column_names(
table_alias=time_spine_alias, column_name=time_spine_source.base_column
),
),
join_type=SqlJoinType.LEFT_OUTER,
)

# Build output dataset, replacing the custom time dimension from the parent dataset with the one from the time spine.
parent_instance_set = parent_data_set.instance_set.transform(
FilterElements(exclude_specs=InstanceSpecSet(time_dimension_specs=(parent_time_dimension_instance.spec,)))
)
time_spine_instance_set = InstanceSet(time_dimension_instances=(time_spine_instance,))
return SqlDataSet(
instance_set=InstanceSet.merge([time_spine_instance_set, parent_instance_set]),
sql_select_node=SqlSelectStatementNode.create(
description=node.description,
select_columns=create_select_columns_for_instance_sets(
self._column_association_resolver,
OrderedDict({parent_alias: parent_instance_set, time_spine_alias: time_spine_instance_set}),
),
from_source=parent_data_set.checked_sql_select_node,
from_source_alias=parent_alias,
join_descs=(join_description,),
),
)
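
Putting the pieces together, a hedged sketch of what this visitor's output could look like once rendered (and after sub-query reduction). The measure name, table names, and aliases are hypothetical; a DATE_TRUNC is applied to the left side of the ON clause when the parent column's grain differs from the spine's base grain.

# Illustrative only: approximate shape of the SQL for a join to a custom granularity,
# assuming a daily time spine table with a "fiscal_quarter" column and a parent dataset
# keyed by metric_time__day.
EXPECTED_JOIN_TO_CUSTOM_GRANULARITY_SQL = """
SELECT
  subq_3.bookings AS bookings
  , subq_4.fiscal_quarter AS metric_time__fiscal_quarter
FROM (
  -- parent dataset
  ...
) subq_3
LEFT OUTER JOIN my_schema.mf_time_spine subq_4
  ON subq_3.metric_time__day = subq_4.ds
"""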

def visit_min_max_node(self, node: MinMaxNode) -> SqlDataSet: # noqa: D102
parent_data_set = node.parent_node.accept(self)
