Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Querying with DATE PART #772

Merged
merged 14 commits into from
Sep 19, 2023
  •  
  •  
  •  
6 changes: 6 additions & 0 deletions .changes/unreleased/Features-20230911-190924.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
kind: Features
courtneyholcomb marked this conversation as resolved.
Show resolved Hide resolved
body: Enable DATE PART aggregation for time dimensions
time: 2023-09-11T19:09:24.960342-07:00
custom:
Author: courtneyholcomb
Issue: "770"
1 change: 1 addition & 0 deletions metricflow/dag/id_generation.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@
SQL_EXPR_IS_NULL_PREFIX = "isn"
SQL_EXPR_CAST_TO_TIMESTAMP_PREFIX = "ctt"
SQL_EXPR_DATE_TRUNC = "dt"
SQL_EXPR_EXTRACT = "ex"
SQL_EXPR_RATIO_COMPUTATION = "rc"
SQL_EXPR_BETWEEN_PREFIX = "betw"
SQL_EXPR_WINDOW_FUNCTION_ID_PREFIX = "wfnc"
Expand Down
48 changes: 32 additions & 16 deletions metricflow/dataset/convert_semantic_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@
SqlColumnReferenceExpression,
SqlDateTruncExpression,
SqlExpressionNode,
SqlExtractExpression,
SqlStringExpression,
)
from metricflow.sql.sql_plan import (
Expand All @@ -46,6 +47,7 @@
SqlSelectStatementNode,
SqlTableFromClauseNode,
)
from metricflow.time.date_part import DatePart

logger = logging.getLogger(__name__)

Expand Down Expand Up @@ -102,12 +104,14 @@ def _create_time_dimension_instance(
time_dimension: Dimension,
entity_links: Tuple[EntityReference, ...],
time_granularity: TimeGranularity = DEFAULT_TIME_GRANULARITY,
date_part: Optional[DatePart] = None,
) -> TimeDimensionInstance:
"""Create a time dimension instance from the dimension object from a semantic model in the model."""
time_dimension_spec = TimeDimensionSpec(
element_name=time_dimension.reference.element_name,
entity_links=entity_links,
time_granularity=time_granularity,
date_part=date_part,
)

return TimeDimensionInstance(
Expand Down Expand Up @@ -219,6 +223,11 @@ def _convert_dimensions(
select_columns = []

for dimension in dimensions or []:
dimension_select_expr = SemanticModelToDataSetConverter._make_element_sql_expr(
table_alias=table_alias,
element_name=dimension.reference.element_name,
element_expr=dimension.expr,
)
if dimension.type == DimensionType.CATEGORICAL:
dimension_instance = self._create_dimension_instance(
semantic_model_name=semantic_model_name,
Expand All @@ -228,11 +237,7 @@ def _convert_dimensions(
dimension_instances.append(dimension_instance)
select_columns.append(
SqlSelectColumn(
expr=SemanticModelToDataSetConverter._make_element_sql_expr(
table_alias=table_alias,
element_name=dimension.reference.element_name,
element_expr=dimension.expr,
),
expr=dimension_select_expr,
column_alias=dimension_instance.associated_column.column_name,
)
)
Expand All @@ -251,11 +256,7 @@ def _convert_dimensions(
time_dimension_instances.append(time_dimension_instance)
select_columns.append(
SqlSelectColumn(
expr=SemanticModelToDataSetConverter._make_element_sql_expr(
table_alias=table_alias,
element_name=dimension.reference.element_name,
element_expr=dimension.expr,
),
expr=dimension_select_expr,
column_alias=time_dimension_instance.associated_column.column_name,
)
)
Expand All @@ -274,16 +275,31 @@ def _convert_dimensions(
select_columns.append(
SqlSelectColumn(
expr=SqlDateTruncExpression(
time_granularity=time_granularity,
arg=SemanticModelToDataSetConverter._make_element_sql_expr(
table_alias=table_alias,
element_name=dimension.reference.element_name,
element_expr=dimension.expr,
),
time_granularity=time_granularity, arg=dimension_select_expr
),
column_alias=time_dimension_instance.associated_column.column_name,
)
)

# Add all date part options for easy query resolution
for date_part in DatePart:
if date_part.to_int() >= defined_time_granularity.to_int():
time_dimension_instance = self._create_time_dimension_instance(
semantic_model_name=semantic_model_name,
time_dimension=dimension,
entity_links=entity_links,
time_granularity=defined_time_granularity,
date_part=date_part,
)
time_dimension_instances.append(time_dimension_instance)

select_columns.append(
SqlSelectColumn(
expr=SqlExtractExpression(date_part=date_part, arg=dimension_select_expr),
column_alias=time_dimension_instance.associated_column.column_name,
)
)

else:
assert False, f"Unhandled dimension type: {dimension.type}"

Expand Down
8 changes: 6 additions & 2 deletions metricflow/dataset/dataset.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,15 @@
from __future__ import annotations

import logging
from typing import Sequence
from typing import Optional, Sequence

from dbt_semantic_interfaces.references import TimeDimensionReference
from dbt_semantic_interfaces.type_enums.time_granularity import TimeGranularity
from dbt_semantic_interfaces.validations.unique_valid_name import MetricFlowReservedKeywords

from metricflow.instances import InstanceSet, TimeDimensionInstance
from metricflow.specs.specs import TimeDimensionSpec
from metricflow.time.date_part import DatePart

logger = logging.getLogger(__name__)

Expand Down Expand Up @@ -48,12 +49,15 @@ def metric_time_dimension_name() -> str:
return DataSet.metric_time_dimension_reference().element_name

@staticmethod
def metric_time_dimension_spec(time_granularity: TimeGranularity) -> TimeDimensionSpec:
def metric_time_dimension_spec(
time_granularity: TimeGranularity, date_part: Optional[DatePart] = None
) -> TimeDimensionSpec:
"""Spec that corresponds to DataSet.metric_time_dimension_reference."""
return TimeDimensionSpec(
element_name=DataSet.metric_time_dimension_reference().element_name,
entity_links=(),
time_granularity=time_granularity,
date_part=date_part,
)

def __repr__(self) -> str: # noqa: D
Expand Down
1 change: 1 addition & 0 deletions metricflow/engine/metricflow_engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -456,6 +456,7 @@ def _create_execution_plan(self, mf_query_request: MetricFlowQueryRequest) -> Me
query_spec = self._query_parser.parse_and_validate_query(
metric_names=mf_query_request.metric_names,
group_by_names=mf_query_request.group_by_names,
group_by=mf_query_request.group_by,
courtneyholcomb marked this conversation as resolved.
Show resolved Hide resolved
limit=mf_query_request.limit,
time_constraint_start=mf_query_request.time_constraint_start,
time_constraint_end=mf_query_request.time_constraint_end,
Expand Down
32 changes: 26 additions & 6 deletions metricflow/naming/linkable_spec_name.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,8 @@

from dbt_semantic_interfaces.type_enums.time_granularity import TimeGranularity

from metricflow.time.date_part import DatePart

DUNDER = "__"

logger = logging.getLogger(__name__)
Expand All @@ -24,6 +26,7 @@ class StructuredLinkableSpecName:
entity_link_names: Tuple[str, ...]
element_name: str
time_granularity: Optional[TimeGranularity] = None
date_part: Optional[DatePart] = None

@staticmethod
def from_name(qualified_name: str) -> StructuredLinkableSpecName:
Expand All @@ -32,7 +35,7 @@ def from_name(qualified_name: str) -> StructuredLinkableSpecName:

# No dunder, e.g. "ds"
if len(name_parts) == 1:
return StructuredLinkableSpecName((), name_parts[0])
return StructuredLinkableSpecName(entity_link_names=(), element_name=name_parts[0])

associated_granularity = None
granularity: TimeGranularity
Expand All @@ -44,18 +47,30 @@ def from_name(qualified_name: str) -> StructuredLinkableSpecName:
if associated_granularity:
# e.g. "ds__month"
if len(name_parts) == 2:
return StructuredLinkableSpecName((), name_parts[0], associated_granularity)
return StructuredLinkableSpecName(
entity_link_names=(), element_name=name_parts[0], time_granularity=associated_granularity
)
# e.g. "messages__ds__month"
return StructuredLinkableSpecName(tuple(name_parts[:-2]), name_parts[-2], associated_granularity)
return StructuredLinkableSpecName(
entity_link_names=tuple(name_parts[:-2]),
element_name=name_parts[-2],
time_granularity=associated_granularity,
)

# e.g. "messages__ds"
else:
return StructuredLinkableSpecName(tuple(name_parts[:-1]), name_parts[-1])
return StructuredLinkableSpecName(entity_link_names=tuple(name_parts[:-1]), element_name=name_parts[-1])

@property
def qualified_name(self) -> str:
"""Return the full name form. e.g. ds or listing__ds__month."""
"""Return the full name form. e.g. ds or listing__ds__month.

If date_part is specified, don't include granularity in qualified_name since it will not impact the result.
"""
items = list(self.entity_link_names) + [self.element_name]
if self.time_granularity:
if self.date_part:
items.append(self.date_part_suffix(date_part=self.date_part))
elif self.time_granularity:
tlento marked this conversation as resolved.
Show resolved Hide resolved
items.append(self.time_granularity.value)
return DUNDER.join(items)

Expand All @@ -66,3 +81,8 @@ def entity_prefix(self) -> Optional[str]:
return DUNDER.join(self.entity_link_names)

return None

@staticmethod
def date_part_suffix(date_part: DatePart) -> str:
"""Suffix used for names with a date_part."""
return f"extract_{date_part.value}"
1 change: 1 addition & 0 deletions metricflow/plan_conversion/column_resolver.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,7 @@ def visit_time_dimension_spec(self, time_dimension_spec: TimeDimensionSpec) -> C
entity_link_names=tuple(x.element_name for x in time_dimension_spec.entity_links),
element_name=time_dimension_spec.element_name,
time_granularity=time_dimension_spec.time_granularity,
date_part=time_dimension_spec.date_part,
).qualified_name

return ColumnAssociation(
Expand Down
50 changes: 29 additions & 21 deletions metricflow/plan_conversion/dataflow_to_sql.py
Original file line number Diff line number Diff line change
Expand Up @@ -88,6 +88,7 @@
SqlComparisonExpression,
SqlDateTruncExpression,
SqlExpressionNode,
SqlExtractExpression,
SqlFunctionExpression,
SqlLogicalExpression,
SqlLogicalOperator,
Expand Down Expand Up @@ -287,7 +288,8 @@ def visit_join_over_time_range_node(self, node: JoinOverTimeRangeNode) -> SqlDat
metric_time_dimension_spec
).column_name

# assemble dataset with metric_time_dimension to join
# Assemble time_spine dataset with metric_time_dimension to join.
# Granularity of time_spine column should match granularity of metric_time column from parent dataset.
assert metric_time_dimension_instance
time_spine_data_set = self._make_time_spine_data_set(
metric_time_dimension_instance=metric_time_dimension_instance,
Expand Down Expand Up @@ -1119,7 +1121,8 @@ def visit_metric_time_dimension_transform_node(self, node: MetricTimeDimensionTr
# For those matching time dimension instances, create the analog metric time dimension instances for the output.
for matching_time_dimension_instance in matching_time_dimension_instances:
metric_time_dimension_spec = DataSet.metric_time_dimension_spec(
matching_time_dimension_instance.spec.time_granularity
time_granularity=matching_time_dimension_instance.spec.time_granularity,
date_part=matching_time_dimension_instance.spec.date_part,
)
metric_time_dimension_column_association = self._column_association_resolver.resolve_spec(
metric_time_dimension_spec
Expand All @@ -1134,6 +1137,7 @@ def visit_metric_time_dimension_transform_node(self, node: MetricTimeDimensionTr
output_column_to_input_column[
metric_time_dimension_column_association.column_name
] = matching_time_dimension_instance.associated_column.column_name

output_instance_set = InstanceSet(
measure_instances=tuple(output_measure_instances),
dimension_instances=input_data_set.instance_set.dimension_instances,
Expand Down Expand Up @@ -1345,38 +1349,25 @@ def visit_join_to_time_spine_node(self, node: JoinToTimeSpineNode) -> SqlDataSet
len(time_spine_dataset.instance_set.time_dimension_instances) == 1
and len(time_spine_dataset.sql_select_node.select_columns) == 1
), "Time spine dataset not configured properly. Expected exactly one column."
original_time_dim_instance = time_spine_dataset.instance_set.time_dimension_instances[0]
time_dim_instance = time_spine_dataset.instance_set.time_dimension_instances[0]
time_spine_column_select_expr: Union[
SqlColumnReferenceExpression, SqlDateTruncExpression
] = SqlColumnReferenceExpression(
SqlColumnReference(table_alias=time_spine_alias, column_name=original_time_dim_instance.spec.qualified_name)
SqlColumnReference(table_alias=time_spine_alias, column_name=time_dim_instance.spec.qualified_name)
)

# Add requested granularities (skip for default granularity).
# Add requested granularities (skip for default granularity) and date_parts.
metric_time_select_columns = []
metric_time_dimension_instances = []
where: Optional[SqlExpressionNode] = None
for metric_time_dimension_spec in node.metric_time_dimension_specs:
# Apply granularity to SQL.
if metric_time_dimension_spec.time_granularity == self._time_spine_source.time_column_granularity:
select_expr = time_spine_column_select_expr
time_dim_instance = original_time_dim_instance
column_alias = original_time_dim_instance.associated_column.column_name
select_expr: SqlExpressionNode = time_spine_column_select_expr
else:
select_expr = SqlDateTruncExpression(
time_granularity=metric_time_dimension_spec.time_granularity, arg=time_spine_column_select_expr
)
new_time_dim_spec = TimeDimensionSpec(
element_name=original_time_dim_instance.spec.element_name,
entity_links=original_time_dim_instance.spec.entity_links,
time_granularity=metric_time_dimension_spec.time_granularity,
aggregation_state=original_time_dim_instance.spec.aggregation_state,
)
time_dim_instance = TimeDimensionInstance(
defined_from=original_time_dim_instance.defined_from,
associated_columns=(self._column_association_resolver.resolve_spec(new_time_dim_spec),),
spec=new_time_dim_spec,
)
column_alias = time_dim_instance.associated_column.column_name
if node.offset_to_grain:
# Filter down to one row per granularity period
new_filter = SqlComparisonExpression(
Expand All @@ -1386,8 +1377,25 @@ def visit_join_to_time_spine_node(self, node: JoinToTimeSpineNode) -> SqlDataSet
where = new_filter
else:
where = SqlLogicalExpression(operator=SqlLogicalOperator.OR, args=(where, new_filter))
# Apply date_part to SQL.
if metric_time_dimension_spec.date_part:
select_expr = SqlExtractExpression(date_part=metric_time_dimension_spec.date_part, arg=select_expr)
time_dim_spec = TimeDimensionSpec(
element_name=time_dim_instance.spec.element_name,
entity_links=time_dim_instance.spec.entity_links,
time_granularity=metric_time_dimension_spec.time_granularity,
date_part=metric_time_dimension_spec.date_part,
aggregation_state=time_dim_instance.spec.aggregation_state,
)
time_dim_instance = TimeDimensionInstance(
defined_from=time_dim_instance.defined_from,
associated_columns=(self._column_association_resolver.resolve_spec(time_dim_spec),),
spec=time_dim_spec,
)
metric_time_dimension_instances.append(time_dim_instance)
metric_time_select_columns.append(SqlSelectColumn(expr=select_expr, column_alias=column_alias))
metric_time_select_columns.append(
SqlSelectColumn(expr=select_expr, column_alias=time_dim_instance.associated_column.column_name)
)
metric_time_instance_set = InstanceSet(time_dimension_instances=tuple(metric_time_dimension_instances))

return SqlDataSet(
Expand Down
11 changes: 9 additions & 2 deletions metricflow/plan_conversion/instance_converters.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,7 @@
SqlFunctionExpression,
)
from metricflow.sql.sql_plan import SqlSelectColumn
from metricflow.time.date_part import DatePart

logger = logging.getLogger(__name__)

Expand Down Expand Up @@ -262,6 +263,7 @@ class _DimensionValidityParams:

dimension_name: str
time_granularity: TimeGranularity
date_part: Optional[DatePart] = None


class CreateValidityWindowJoinDescription(InstanceSetTransform[Optional[ValidityWindowJoinDescription]]):
Expand Down Expand Up @@ -324,12 +326,16 @@ def transform(self, instance_set: InstanceSet) -> Optional[ValidityWindowJoinDes
start_specs = [
spec
for spec in specs
if spec.element_name == start_dim.dimension_name and spec.time_granularity == start_dim.time_granularity
if spec.element_name == start_dim.dimension_name
and spec.time_granularity == start_dim.time_granularity
and spec.date_part == start_dim.date_part
]
end_specs = [
spec
for spec in specs
if spec.element_name == end_dim.dimension_name and spec.time_granularity == end_dim.time_granularity
if spec.element_name == end_dim.dimension_name
and spec.time_granularity == end_dim.time_granularity
and spec.date_part == end_dim.date_part
]
linkless_start_specs = {spec.without_entity_links for spec in start_specs}
linkless_end_specs = {spec.without_entity_links for spec in end_specs}
Expand Down Expand Up @@ -401,6 +407,7 @@ def transform(self, instance_set: InstanceSet) -> InstanceSet: # noqa: D
+ time_dimension_instance.spec.entity_links
),
time_granularity=time_dimension_instance.spec.time_granularity,
date_part=time_dimension_instance.spec.date_part,
)
time_dimension_instances_with_additional_link.append(
TimeDimensionInstance(
Expand Down
Loading