Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Clean up SQL time interval math expressions #796

Merged
merged 3 commits on Oct 5, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions metricflow/dag/id_generation.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@
SQL_EXPR_IS_NULL_PREFIX = "isn"
SQL_EXPR_CAST_TO_TIMESTAMP_PREFIX = "ctt"
SQL_EXPR_DATE_TRUNC = "dt"
SQL_EXPR_SUBTRACT_TIME_INTERVAL_PREFIX = "sti"
SQL_EXPR_EXTRACT = "ex"
SQL_EXPR_RATIO_COMPUTATION = "rc"
SQL_EXPR_BETWEEN_PREFIX = "betw"
Expand Down
6 changes: 3 additions & 3 deletions metricflow/plan_conversion/sql_join_builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
SqlIsNullExpression,
SqlLogicalExpression,
SqlLogicalOperator,
SqlTimeDeltaExpression,
SqlSubtractTimeIntervalExpression,
)
from metricflow.sql.sql_plan import SqlExpressionNode, SqlJoinDescription, SqlJoinType, SqlSelectStatementNode

Expand Down Expand Up @@ -441,7 +441,7 @@ def make_cumulative_metric_time_range_join_description(
start_of_range_comparison_expr = SqlComparisonExpression(
left_expr=metric_time_column_expr,
comparison=SqlComparison.GREATER_THAN,
right_expr=SqlTimeDeltaExpression(
right_expr=SqlSubtractTimeIntervalExpression(
arg=time_spine_column_expr,
count=node.window.count,
granularity=node.window.granularity,
Expand Down Expand Up @@ -481,7 +481,7 @@ def make_join_to_time_spine_join_description(
col_ref=SqlColumnReference(table_alias=time_spine_alias, column_name=metric_time_dimension_column_name)
)
if node.offset_window:
left_expr = SqlTimeDeltaExpression(
left_expr = SqlSubtractTimeIntervalExpression(
arg=left_expr, count=node.offset_window.count, granularity=node.offset_window.granularity
)
elif node.offset_to_grain:
Expand Down
14 changes: 2 additions & 12 deletions metricflow/sql/render/big_query.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@
SqlGenerateUuidExpression,
SqlPercentileExpression,
SqlPercentileFunctionType,
SqlTimeDeltaExpression,
SqlSubtractTimeIntervalExpression,
)
from metricflow.sql.sql_plan import SqlSelectColumn
from metricflow.time.date_part import DatePart
Expand Down Expand Up @@ -142,19 +142,9 @@ def render_date_part(self, date_part: DatePart) -> str:
return super().render_date_part(date_part)

@override
def visit_time_delta_expr(self, node: SqlTimeDeltaExpression) -> SqlExpressionRenderResult:
def visit_time_delta_expr(self, node: SqlSubtractTimeIntervalExpression) -> SqlExpressionRenderResult:
"""Render time delta for BigQuery, which requires ISO prefixing for the WEEK granularity value."""
column = node.arg.accept(self)
if node.grain_to_date:
granularity = node.granularity
if granularity == TimeGranularity.WEEK:
granularity_value = "ISO" + granularity.value.upper()
else:
granularity_value = granularity.value
return SqlExpressionRenderResult(
sql=f"DATE_TRUNC({column.sql}, {granularity_value})",
bind_parameters=column.bind_parameters,
)

return SqlExpressionRenderResult(
sql=f"DATE_SUB(CAST({column.sql} AS {self.timestamp_data_type}), INTERVAL {node.count} {node.granularity.value})",
Expand Down
9 changes: 2 additions & 7 deletions metricflow/sql/render/duckdb_renderer.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
SqlGenerateUuidExpression,
SqlPercentileExpression,
SqlPercentileFunctionType,
SqlTimeDeltaExpression,
SqlSubtractTimeIntervalExpression,
)


Expand All @@ -34,14 +34,9 @@ def supported_percentile_function_types(self) -> Collection[SqlPercentileFunctio
}

@override
def visit_time_delta_expr(self, node: SqlTimeDeltaExpression) -> SqlExpressionRenderResult:
def visit_time_delta_expr(self, node: SqlSubtractTimeIntervalExpression) -> SqlExpressionRenderResult:
"""Render time delta expression for DuckDB, which requires slightly different syntax from other engines."""
arg_rendered = node.arg.accept(self)
if node.grain_to_date:
return SqlExpressionRenderResult(
sql=f"DATE_TRUNC('{node.granularity.value}', {arg_rendered.sql}::timestamp)",
bind_parameters=arg_rendered.bind_parameters,
)

count = node.count
granularity = node.granularity
Expand Down
9 changes: 2 additions & 7 deletions metricflow/sql/render/expr_renderer.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@
SqlRatioComputationExpression,
SqlStringExpression,
SqlStringLiteralExpression,
SqlTimeDeltaExpression,
SqlSubtractTimeIntervalExpression,
SqlWindowFunctionExpression,
)
from metricflow.sql.sql_plan import SqlSelectColumn
Expand Down Expand Up @@ -281,13 +281,8 @@ def render_date_part(self, date_part: DatePart) -> str:
"""Render DATE PART for an EXTRACT expression."""
return date_part.value

def visit_time_delta_expr(self, node: SqlTimeDeltaExpression) -> SqlExpressionRenderResult: # noqa: D
def visit_time_delta_expr(self, node: SqlSubtractTimeIntervalExpression) -> SqlExpressionRenderResult: # noqa: D
arg_rendered = node.arg.accept(self)
if node.grain_to_date:
return SqlExpressionRenderResult(
sql=f"DATE_TRUNC('{node.granularity.value}', {arg_rendered.sql}::timestamp)",
bind_parameters=arg_rendered.bind_parameters,
)

count = node.count
granularity = node.granularity
Expand Down
9 changes: 2 additions & 7 deletions metricflow/sql/render/postgres.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
SqlGenerateUuidExpression,
SqlPercentileExpression,
SqlPercentileFunctionType,
SqlTimeDeltaExpression,
SqlSubtractTimeIntervalExpression,
)


Expand All @@ -37,14 +37,9 @@ def supported_percentile_function_types(self) -> Collection[SqlPercentileFunctio
return {SqlPercentileFunctionType.CONTINUOUS, SqlPercentileFunctionType.DISCRETE}

@override
def visit_time_delta_expr(self, node: SqlTimeDeltaExpression) -> SqlExpressionRenderResult:
def visit_time_delta_expr(self, node: SqlSubtractTimeIntervalExpression) -> SqlExpressionRenderResult:
"""Render time delta operations for PostgreSQL, which needs custom support for quarterly granularity."""
arg_rendered = node.arg.accept(self)
if node.grain_to_date:
return SqlExpressionRenderResult(
sql=f"DATE_TRUNC('{node.granularity.value}', {arg_rendered.sql}::timestamp)",
bind_parameters=arg_rendered.bind_parameters,
)

count = node.count
granularity = node.granularity
Expand Down
36 changes: 16 additions & 20 deletions metricflow/sql/sql_exprs.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@

from metricflow.dag.id_generation import (
SQL_EXPR_BETWEEN_PREFIX,
SQL_EXPR_CAST_TO_TIMESTAMP_PREFIX,
SQL_EXPR_COLUMN_REFERENCE_ID_PREFIX,
SQL_EXPR_COMPARISON_ID_PREFIX,
SQL_EXPR_DATE_TRUNC,
Expand All @@ -29,6 +30,7 @@
SQL_EXPR_RATIO_COMPUTATION,
SQL_EXPR_STRING_ID_PREFIX,
SQL_EXPR_STRING_LITERAL_PREFIX,
SQL_EXPR_SUBTRACT_TIME_INTERVAL_PREFIX,
SQL_EXPR_WINDOW_FUNCTION_ID_PREFIX,
)
from metricflow.dag.mf_dag import DagNode, DisplayedProperty, NodeId
Expand Down Expand Up @@ -225,7 +227,7 @@ def visit_extract_expr(self, node: SqlExtractExpression) -> VisitorOutputT: # n
pass

@abstractmethod
def visit_time_delta_expr(self, node: SqlTimeDeltaExpression) -> VisitorOutputT: # noqa: D
def visit_time_delta_expr(self, node: SqlSubtractTimeIntervalExpression) -> VisitorOutputT: # noqa: D
pass

@abstractmethod
Expand Down Expand Up @@ -1243,25 +1245,29 @@ def matches(self, other: SqlExpressionNode) -> bool: # noqa: D
return self._parents_match(other)


class SqlTimeDeltaExpression(SqlExpressionNode):
"""create time delta between eg `DATE_SUB(ds, 2, month)`."""
class SqlSubtractTimeIntervalExpression(SqlExpressionNode):
"""Represents an interval subtraction from a given timestamp.

This node contains the information required to produce a SQL statement which subtracts an interval with the given
count and granularity (which together define the interval duration) from the input timestamp expression. The return
value from the SQL rendering for this expression should be a timestamp expression offset from the initial input
value.
"""

def __init__( # noqa: D
self,
arg: SqlExpressionNode,
count: int,
granularity: TimeGranularity,
grain_to_date: Optional[TimeGranularity] = None,
) -> None:
super().__init__(node_id=self.create_unique_id(), parent_nodes=[arg])
self._count = count
self._time_granularity = granularity
self._arg = arg
self._grain_to_date = grain_to_date

@classmethod
def id_prefix(cls) -> str: # noqa: D
return SQL_EXPR_IS_NULL_PREFIX
return SQL_EXPR_SUBTRACT_TIME_INTERVAL_PREFIX

@property
def requires_parenthesis(self) -> bool: # noqa: D
Expand All @@ -1278,10 +1284,6 @@ def description(self) -> str: # noqa: D
def arg(self) -> SqlExpressionNode: # noqa: D
return self._arg

@property
def grain_to_date(self) -> Optional[TimeGranularity]: # noqa: D
return self._grain_to_date

@property
def count(self) -> int: # noqa: D
return self._count
Expand All @@ -1295,11 +1297,10 @@ def rewrite( # noqa: D
column_replacements: Optional[SqlColumnReplacements] = None,
should_render_table_alias: Optional[bool] = None,
) -> SqlExpressionNode:
return SqlTimeDeltaExpression(
return SqlSubtractTimeIntervalExpression(
arg=self.arg.rewrite(column_replacements, should_render_table_alias),
count=self.count,
granularity=self.granularity,
grain_to_date=self.grain_to_date,
)

@property
Expand All @@ -1309,14 +1310,9 @@ def lineage(self) -> SqlExpressionTreeLineage: # noqa: D
)

def matches(self, other: SqlExpressionNode) -> bool: # noqa: D
if not isinstance(other, SqlTimeDeltaExpression):
if not isinstance(other, SqlSubtractTimeIntervalExpression):
return False
return (
self.count == other.count
and self.granularity == other.granularity
and self.grain_to_date == other.grain_to_date
and self._parents_match(other)
)
return self.count == other.count and self.granularity == other.granularity and self._parents_match(other)


class SqlCastToTimestampExpression(SqlExpressionNode):
Expand All @@ -1327,7 +1323,7 @@ def __init__(self, arg: SqlExpressionNode) -> None: # noqa: D

@classmethod
def id_prefix(cls) -> str: # noqa: D
return SQL_EXPR_IS_NULL_PREFIX
return SQL_EXPR_CAST_TO_TIMESTAMP_PREFIX

@property
def requires_parenthesis(self) -> bool: # noqa: D
Expand Down
4 changes: 2 additions & 2 deletions metricflow/test/integration/test_configured_cases.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@
SqlPercentileExpressionArgument,
SqlPercentileFunctionType,
SqlStringExpression,
SqlTimeDeltaExpression,
SqlSubtractTimeIntervalExpression,
)
from metricflow.test.compare_df import assert_dataframes_equal
from metricflow.test.fixtures.setup_fixtures import MetricFlowTestSessionState
Expand Down Expand Up @@ -85,7 +85,7 @@ def render_date_sub(
granularity: TimeGranularity,
) -> str:
"""Renders a date subtract expression."""
expr = SqlTimeDeltaExpression(
expr = SqlSubtractTimeIntervalExpression(
arg=SqlColumnReferenceExpression(SqlColumnReference(table_alias, column_alias)),
count=count,
granularity=granularity,
Expand Down
Loading