Querying with DATE PART #772

Merged (14 commits) on Sep 19, 2023
1 change: 1 addition & 0 deletions metricflow/dag/id_generation.py
@@ -35,6 +35,7 @@
SQL_EXPR_IS_NULL_PREFIX = "isn"
SQL_EXPR_CAST_TO_TIMESTAMP_PREFIX = "ctt"
SQL_EXPR_DATE_TRUNC = "dt"
SQL_EXPR_EXTRACT = "ex"
SQL_EXPR_RATIO_COMPUTATION = "rc"
SQL_EXPR_BETWEEN_PREFIX = "betw"
SQL_EXPR_WINDOW_FUNCTION_ID_PREFIX = "wfnc"
48 changes: 32 additions & 16 deletions metricflow/dataset/convert_semantic_model.py
@@ -38,6 +38,7 @@
SqlColumnReferenceExpression,
SqlDateTruncExpression,
SqlExpressionNode,
SqlExtractExpression,
SqlStringExpression,
)
from metricflow.sql.sql_plan import (
@@ -46,6 +47,7 @@
SqlSelectStatementNode,
SqlTableFromClauseNode,
)
from metricflow.time.date_part import DatePart

logger = logging.getLogger(__name__)

@@ -102,12 +104,14 @@ def _create_time_dimension_instance(
time_dimension: Dimension,
entity_links: Tuple[EntityReference, ...],
time_granularity: TimeGranularity = DEFAULT_TIME_GRANULARITY,
date_part: Optional[DatePart] = None,
) -> TimeDimensionInstance:
"""Create a time dimension instance from the dimension object from a semantic model in the model."""
time_dimension_spec = TimeDimensionSpec(
element_name=time_dimension.reference.element_name,
entity_links=entity_links,
time_granularity=time_granularity,
date_part=date_part,
)

return TimeDimensionInstance(
@@ -219,6 +223,11 @@ def _convert_dimensions(
select_columns = []

for dimension in dimensions or []:
dimension_select_expr = SemanticModelToDataSetConverter._make_element_sql_expr(
table_alias=table_alias,
element_name=dimension.reference.element_name,
element_expr=dimension.expr,
)
if dimension.type == DimensionType.CATEGORICAL:
dimension_instance = self._create_dimension_instance(
semantic_model_name=semantic_model_name,
@@ -228,11 +237,7 @@
dimension_instances.append(dimension_instance)
select_columns.append(
SqlSelectColumn(
expr=SemanticModelToDataSetConverter._make_element_sql_expr(
table_alias=table_alias,
element_name=dimension.reference.element_name,
element_expr=dimension.expr,
),
expr=dimension_select_expr,
column_alias=dimension_instance.associated_column.column_name,
)
)
@@ -251,11 +256,7 @@
time_dimension_instances.append(time_dimension_instance)
select_columns.append(
SqlSelectColumn(
expr=SemanticModelToDataSetConverter._make_element_sql_expr(
table_alias=table_alias,
element_name=dimension.reference.element_name,
element_expr=dimension.expr,
),
expr=dimension_select_expr,
column_alias=time_dimension_instance.associated_column.column_name,
)
)
@@ -274,16 +275,31 @@
select_columns.append(
SqlSelectColumn(
expr=SqlDateTruncExpression(
time_granularity=time_granularity,
arg=SemanticModelToDataSetConverter._make_element_sql_expr(
table_alias=table_alias,
element_name=dimension.reference.element_name,
element_expr=dimension.expr,
),
time_granularity=time_granularity, arg=dimension_select_expr
),
column_alias=time_dimension_instance.associated_column.column_name,
)
)

# Add all date part options for easy query resolution
for date_part in DatePart:
if date_part.to_int() >= defined_time_granularity.to_int():
time_dimension_instance = self._create_time_dimension_instance(
semantic_model_name=semantic_model_name,
time_dimension=dimension,
entity_links=entity_links,
time_granularity=defined_time_granularity,
date_part=date_part,
)
time_dimension_instances.append(time_dimension_instance)

select_columns.append(
SqlSelectColumn(
expr=SqlExtractExpression(date_part=date_part, arg=dimension_select_expr),
column_alias=time_dimension_instance.associated_column.column_name,
)
)

else:
assert False, f"Unhandled dimension type: {dimension.type}"

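For context on the new loop above: for every DatePart at least as coarse as the dimension's defined granularity, the converter now emits one extra EXTRACT select column whose alias carries an extract_ suffix. A minimal standalone sketch of those generated columns for a hypothetical day-grained dimension ds (the table alias and the eligible part names are illustrative assumptions, not taken from the PR):

# Standalone sketch, not MetricFlow code: approximate the extra select columns that the
# date-part loop above generates for a time dimension "ds" defined at DAY granularity.
eligible_date_parts = ["doy", "dow", "week", "month", "quarter", "year"]  # assumed part values

table_alias = "bookings_source_src"  # hypothetical alias assigned by the converter
element_expr = f"{table_alias}.ds"   # the dimension's SQL expression

for date_part in eligible_date_parts:
    # Alias format matches StructuredLinkableSpecName.date_part_suffix: "extract_<part>".
    column_alias = f"ds__extract_{date_part}"
    print(f"EXTRACT({date_part} FROM {element_expr}) AS {column_alias}")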
8 changes: 6 additions & 2 deletions metricflow/dataset/dataset.py
@@ -1,14 +1,15 @@
from __future__ import annotations

import logging
from typing import Sequence
from typing import Optional, Sequence

from dbt_semantic_interfaces.references import TimeDimensionReference
from dbt_semantic_interfaces.type_enums.time_granularity import TimeGranularity
from dbt_semantic_interfaces.validations.unique_valid_name import MetricFlowReservedKeywords

from metricflow.instances import InstanceSet, TimeDimensionInstance
from metricflow.specs.specs import TimeDimensionSpec
from metricflow.time.date_part import DatePart

logger = logging.getLogger(__name__)

@@ -48,12 +49,15 @@ def metric_time_dimension_name() -> str:
return DataSet.metric_time_dimension_reference().element_name

@staticmethod
def metric_time_dimension_spec(time_granularity: TimeGranularity) -> TimeDimensionSpec:
def metric_time_dimension_spec(
time_granularity: TimeGranularity, date_part: Optional[DatePart] = None
) -> TimeDimensionSpec:
"""Spec that corresponds to DataSet.metric_time_dimension_reference."""
return TimeDimensionSpec(
element_name=DataSet.metric_time_dimension_reference().element_name,
entity_links=(),
time_granularity=time_granularity,
date_part=date_part,
)

def __repr__(self) -> str: # noqa: D
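A hedged usage sketch of the widened metric_time_dimension_spec helper, assuming this branch of metricflow (and dbt_semantic_interfaces) is importable; the grain and date-part combination is illustrative only:

# Sketch only: requesting metric_time at a monthly grain vs. as a day-of-week extract.
from dbt_semantic_interfaces.type_enums.time_granularity import TimeGranularity

from metricflow.dataset.dataset import DataSet
from metricflow.time.date_part import DatePart

monthly = DataSet.metric_time_dimension_spec(time_granularity=TimeGranularity.MONTH)
day_of_week = DataSet.metric_time_dimension_spec(
    time_granularity=TimeGranularity.DAY, date_part=DatePart.DOW
)

# With date_part set, the granularity is omitted from the qualified name (see the
# StructuredLinkableSpecName change below), e.g. "metric_time__extract_dow".
print(monthly.qualified_name, day_of_week.qualified_name)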
32 changes: 26 additions & 6 deletions metricflow/naming/linkable_spec_name.py
@@ -6,6 +6,8 @@

from dbt_semantic_interfaces.type_enums.time_granularity import TimeGranularity

from metricflow.time.date_part import DatePart

DUNDER = "__"

logger = logging.getLogger(__name__)
@@ -24,6 +26,7 @@ class StructuredLinkableSpecName:
entity_link_names: Tuple[str, ...]
element_name: str
time_granularity: Optional[TimeGranularity] = None
date_part: Optional[DatePart] = None

@staticmethod
def from_name(qualified_name: str) -> StructuredLinkableSpecName:
@@ -32,7 +35,7 @@ def from_name(qualified_name: str) -> StructuredLinkableSpecName:

# No dunder, e.g. "ds"
if len(name_parts) == 1:
return StructuredLinkableSpecName((), name_parts[0])
return StructuredLinkableSpecName(entity_link_names=(), element_name=name_parts[0])

associated_granularity = None
granularity: TimeGranularity
@@ -44,18 +47,30 @@ def from_name(qualified_name: str) -> StructuredLinkableSpecName:
if associated_granularity:
# e.g. "ds__month"
if len(name_parts) == 2:
return StructuredLinkableSpecName((), name_parts[0], associated_granularity)
return StructuredLinkableSpecName(
entity_link_names=(), element_name=name_parts[0], time_granularity=associated_granularity
)
# e.g. "messages__ds__month"
return StructuredLinkableSpecName(tuple(name_parts[:-2]), name_parts[-2], associated_granularity)
return StructuredLinkableSpecName(
entity_link_names=tuple(name_parts[:-2]),
element_name=name_parts[-2],
time_granularity=associated_granularity,
)

# e.g. "messages__ds"
else:
return StructuredLinkableSpecName(tuple(name_parts[:-1]), name_parts[-1])
return StructuredLinkableSpecName(entity_link_names=tuple(name_parts[:-1]), element_name=name_parts[-1])

@property
def qualified_name(self) -> str:
"""Return the full name form. e.g. ds or listing__ds__month."""
"""Return the full name form. e.g. ds or listing__ds__month.

If date_part is specified, don't include granularity in qualified_name since it will not impact the result.
"""
items = list(self.entity_link_names) + [self.element_name]
if self.time_granularity:
if self.date_part:
items.append(self.date_part_suffix(date_part=self.date_part))
elif self.time_granularity:
items.append(self.time_granularity.value)
return DUNDER.join(items)

Expand All @@ -66,3 +81,8 @@ def entity_prefix(self) -> Optional[str]:
return DUNDER.join(self.entity_link_names)

return None

@staticmethod
def date_part_suffix(date_part: DatePart) -> str:
"""Suffix used for names with a date_part."""
return f"extract_{date_part.value}"
1 change: 1 addition & 0 deletions metricflow/plan_conversion/column_resolver.py
@@ -55,6 +55,7 @@ def visit_time_dimension_spec(self, time_dimension_spec: TimeDimensionSpec) -> C
entity_link_names=tuple(x.element_name for x in time_dimension_spec.entity_links),
element_name=time_dimension_spec.element_name,
time_granularity=time_dimension_spec.time_granularity,
date_part=time_dimension_spec.date_part,
).qualified_name

return ColumnAssociation(
5 changes: 5 additions & 0 deletions metricflow/specs/specs.py
@@ -34,6 +34,7 @@
from metricflow.naming.linkable_spec_name import StructuredLinkableSpecName
from metricflow.sql.sql_bind_parameters import SqlBindParameters
from metricflow.sql.sql_column_type import SqlColumnType
from metricflow.time.date_part import DatePart
from metricflow.visitor import VisitorOutputT


@@ -284,6 +285,7 @@ def accept(self, visitor: InstanceSpecVisitor[VisitorOutputT]) -> VisitorOutputT
@dataclass(frozen=True)
class TimeDimensionSpec(DimensionSpec): # noqa: D
time_granularity: TimeGranularity = DEFAULT_TIME_GRANULARITY
date_part: Optional[DatePart] = None

# Used for semi-additive joins. Some more thought is needed, but this may be useful in InstanceSpec.
aggregation_state: Optional[AggregationState] = None
@@ -295,6 +297,7 @@ def without_first_entity_link(self) -> TimeDimensionSpec: # noqa: D
element_name=self.element_name,
entity_links=self.entity_links[1:],
time_granularity=self.time_granularity,
date_part=self.date_part,
)

@property
@@ -324,6 +327,7 @@ def qualified_name(self) -> str: # noqa: D
entity_link_names=tuple(x.element_name for x in self.entity_links),
element_name=self.element_name,
time_granularity=self.time_granularity,
date_part=self.date_part,
).qualified_name

@property
@@ -338,6 +342,7 @@ def with_aggregation_state(self, aggregation_state: AggregationState) -> TimeDim
element_name=self.element_name,
entity_links=self.entity_links,
time_granularity=self.time_granularity,
date_part=self.date_part,
aggregation_state=aggregation_state,
)

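A short hedged sketch showing that the new field rides along through the spec's copy-style helpers. It assumes this branch is importable, that EntityReference takes an element_name keyword, and that without_first_entity_link is a property; the names are illustrative:

# Sketch only: date_part is preserved by without_first_entity_link (and, per the hunk
# above, by with_aggregation_state as well).
from dbt_semantic_interfaces.references import EntityReference
from dbt_semantic_interfaces.type_enums.time_granularity import TimeGranularity

from metricflow.specs.specs import TimeDimensionSpec
from metricflow.time.date_part import DatePart

spec = TimeDimensionSpec(
    element_name="ds",
    entity_links=(EntityReference(element_name="listing"),),
    time_granularity=TimeGranularity.DAY,
    date_part=DatePart.DOW,
)

trimmed = spec.without_first_entity_link       # assumed to be a property
print(trimmed.entity_links, trimmed.date_part)  # () DatePart.DOW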
12 changes: 12 additions & 0 deletions metricflow/sql/render/big_query.py
@@ -24,6 +24,7 @@
SqlTimeDeltaExpression,
)
from metricflow.sql.sql_plan import SqlSelectColumn
from metricflow.time.date_part import DatePart


class BigQuerySqlExpressionRenderer(DefaultSqlExpressionRenderer):
@@ -129,6 +130,17 @@ def visit_date_trunc_expr(self, node: SqlDateTruncExpression) -> SqlExpressionRe
bind_parameters=arg_rendered.bind_parameters,
)

@override
def render_date_part(self, date_part: DatePart) -> str:
if date_part == DatePart.DOY:
return "dayofyear"
if date_part == DatePart.DOW:
return "dayofweek"
if date_part == DatePart.WEEK:
return "isoweek"

return super().render_date_part(date_part)

@override
def visit_time_delta_expr(self, node: SqlTimeDeltaExpression) -> SqlExpressionRenderResult:
"""Render time delta for BigQuery, which requires ISO prefixing for the WEEK granularity value."""
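BigQuery spells a few EXTRACT part names differently from the default renderer (shown in the next file), which uses DatePart.value verbatim; the override above handles that. A standalone sketch of the resulting SQL, with a hypothetical column reference:

# Standalone sketch, not MetricFlow code: BigQuery part-name overrides applied on top of
# the default DatePart.value rendering.
BIGQUERY_DATE_PART_NAMES = {"doy": "dayofyear", "dow": "dayofweek", "week": "isoweek"}


def render_extract_for_bigquery(date_part_value: str, arg_sql: str) -> str:
    """Build the EXTRACT expression the renderers above would emit for BigQuery."""
    part_name = BIGQUERY_DATE_PART_NAMES.get(date_part_value, date_part_value)
    return f"EXTRACT({part_name} FROM {arg_sql})"


print(render_extract_for_bigquery("dow", "bookings_src.ds"))    # EXTRACT(dayofweek FROM bookings_src.ds)
print(render_extract_for_bigquery("month", "bookings_src.ds"))  # EXTRACT(month FROM bookings_src.ds)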
14 changes: 14 additions & 0 deletions metricflow/sql/render/expr_renderer.py
@@ -22,6 +22,7 @@
SqlDateTruncExpression,
SqlExpressionNode,
SqlExpressionNodeVisitor,
SqlExtractExpression,
SqlFunction,
SqlGenerateUuidExpression,
SqlIsNullExpression,
@@ -36,6 +37,7 @@
SqlWindowFunctionExpression,
)
from metricflow.sql.sql_plan import SqlSelectColumn
from metricflow.time.date_part import DatePart

logger = logging.getLogger(__name__)

@@ -267,6 +269,18 @@ def visit_date_trunc_expr(self, node: SqlDateTruncExpression) -> SqlExpressionRe
bind_parameters=arg_rendered.bind_parameters,
)

def visit_extract_expr(self, node: SqlExtractExpression) -> SqlExpressionRenderResult: # noqa: D
arg_rendered = self.render_sql_expr(node.arg)

return SqlExpressionRenderResult(
sql=f"EXTRACT({self.render_date_part(node.date_part)} FROM {arg_rendered.sql})",
bind_parameters=arg_rendered.bind_parameters,
)

def render_date_part(self, date_part: DatePart) -> str:
"""Render DATE PART for an EXTRACT expression."""
return date_part.value

def visit_time_delta_expr(self, node: SqlTimeDeltaExpression) -> SqlExpressionRenderResult: # noqa: D
arg_rendered = node.arg.accept(self)
if node.grain_to_date:
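Finally, a hedged end-to-end sketch of the new expression node going through the default renderer. It assumes this branch is importable and that SqlColumnReferenceExpression wraps a SqlColumnReference(table_alias=..., column_name=...) as elsewhere in sql_exprs; treat the constructor shapes, the table alias, and the "doy" value as assumptions:

# Sketch only: render the new SqlExtractExpression with the default expression renderer.
from metricflow.sql.render.expr_renderer import DefaultSqlExpressionRenderer
from metricflow.sql.sql_exprs import (  # constructor shapes assumed from existing usage
    SqlColumnReference,
    SqlColumnReferenceExpression,
    SqlExtractExpression,
)
from metricflow.time.date_part import DatePart

renderer = DefaultSqlExpressionRenderer()
extract_expr = SqlExtractExpression(
    date_part=DatePart.DOY,
    arg=SqlColumnReferenceExpression(SqlColumnReference(table_alias="bookings_src", column_name="ds")),
)

result = renderer.render_sql_expr(extract_expr)
print(result.sql)  # expected: EXTRACT(doy FROM bookings_src.ds)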