Skip to content

Commit

Permalink
Create SqlAddTimeExpression (#1562)
Browse files Browse the repository at this point in the history
Adds a new SQL expression that adds a time interval to a date/time
expression. This differs from the existing time delta expression because
it 1) adds instead of subtracts, and 2) accepts a SQL expression for the
count instead of an integer. This will be used to build the SQL for
custom granularity offset windows.
  • Loading branch information
courtneyholcomb authored Dec 17, 2024
1 parent dd0e8c3 commit 728474d
Show file tree
Hide file tree
Showing 15 changed files with 270 additions and 12 deletions.
1 change: 1 addition & 0 deletions metricflow-semantics/metricflow_semantics/dag/id_prefix.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,6 +69,7 @@ class StaticIdPrefix(IdPrefix, Enum, metaclass=EnumMetaClassHelper):
SQL_EXPR_CAST_TO_TIMESTAMP_PREFIX = "ctt"
SQL_EXPR_DATE_TRUNC = "dt"
SQL_EXPR_SUBTRACT_TIME_INTERVAL_PREFIX = "sti"
SQL_EXPR_ADD_TIME_PREFIX = "ati"
SQL_EXPR_EXTRACT = "ex"
SQL_EXPR_RATIO_COMPUTATION = "rc"
SQL_EXPR_BETWEEN_PREFIX = "betw"
Expand Down
14 changes: 13 additions & 1 deletion metricflow/sql/render/big_query.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
)
from metricflow.sql.render.sql_plan_renderer import DefaultSqlQueryPlanRenderer
from metricflow.sql.sql_exprs import (
SqlAddTimeExpression,
SqlCastToTimestampExpression,
SqlDateTruncExpression,
SqlExtractExpression,
Expand Down Expand Up @@ -167,7 +168,7 @@ def visit_extract_expr(self, node: SqlExtractExpression) -> SqlExpressionRenderR
)

@override
def visit_time_delta_expr(self, node: SqlSubtractTimeIntervalExpression) -> SqlExpressionRenderResult:
def visit_subtract_time_interval_expr(self, node: SqlSubtractTimeIntervalExpression) -> SqlExpressionRenderResult:
"""Render time delta for BigQuery, which requires ISO prefixing for the WEEK granularity value."""
column = node.arg.accept(self)

Expand All @@ -176,6 +177,17 @@ def visit_time_delta_expr(self, node: SqlSubtractTimeIntervalExpression) -> SqlE
bind_parameter_set=column.bind_parameter_set,
)

@override
def visit_add_time_expr(self, node: SqlAddTimeExpression) -> SqlExpressionRenderResult:
"""Render time delta for BigQuery, which requires ISO prefixing for the WEEK granularity value."""
column = node.arg.accept(self)
count = node.count_expr.accept(self)

return SqlExpressionRenderResult(
sql=f"DATE_ADD(CAST({column.sql} AS {self.timestamp_data_type}), INTERVAL {count} {node.granularity.value})",
bind_parameter_set=column.bind_parameter_set,
)

@override
def visit_generate_uuid_expr(self, node: SqlGenerateUuidExpression) -> SqlExpressionRenderResult:
return SqlExpressionRenderResult(
Expand Down
21 changes: 19 additions & 2 deletions metricflow/sql/render/duckdb_renderer.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
)
from metricflow.sql.render.sql_plan_renderer import DefaultSqlQueryPlanRenderer
from metricflow.sql.sql_exprs import (
SqlAddTimeExpression,
SqlGenerateUuidExpression,
SqlPercentileExpression,
SqlPercentileFunctionType,
Expand All @@ -37,13 +38,13 @@ def supported_percentile_function_types(self) -> Collection[SqlPercentileFunctio
}

@override
def visit_time_delta_expr(self, node: SqlSubtractTimeIntervalExpression) -> SqlExpressionRenderResult:
def visit_subtract_time_interval_expr(self, node: SqlSubtractTimeIntervalExpression) -> SqlExpressionRenderResult:
"""Render time delta expression for DuckDB, which requires slightly different syntax from other engines."""
arg_rendered = node.arg.accept(self)

count = node.count
granularity = node.granularity
if granularity == TimeGranularity.QUARTER:
if granularity is TimeGranularity.QUARTER:
granularity = TimeGranularity.MONTH
count *= 3

Expand All @@ -52,6 +53,22 @@ def visit_time_delta_expr(self, node: SqlSubtractTimeIntervalExpression) -> SqlE
bind_parameter_set=arg_rendered.bind_parameter_set,
)

@override
def visit_add_time_expr(self, node: SqlAddTimeExpression) -> SqlExpressionRenderResult:
"""Render time delta expression for DuckDB, which requires slightly different syntax from other engines."""
arg_rendered = node.arg.accept(self)
count_rendered = node.count_expr.accept(self).sql

granularity = node.granularity
if granularity is TimeGranularity.QUARTER:
granularity = TimeGranularity.MONTH
count_rendered = f"({count_rendered} * 3)"

return SqlExpressionRenderResult(
sql=f"{arg_rendered.sql} + INTERVAL {count_rendered} {granularity.value}",
bind_parameter_set=arg_rendered.bind_parameter_set,
)

@override
def visit_generate_uuid_expr(self, node: SqlGenerateUuidExpression) -> SqlExpressionRenderResult:
return SqlExpressionRenderResult(
Expand Down
21 changes: 19 additions & 2 deletions metricflow/sql/render/expr_renderer.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@

from metricflow.sql.render.rendering_constants import SqlRenderingConstants
from metricflow.sql.sql_exprs import (
SqlAddTimeExpression,
SqlAggregateFunctionExpression,
SqlBetweenExpression,
SqlCastToTimestampExpression,
Expand Down Expand Up @@ -303,19 +304,35 @@ def render_date_part(self, date_part: DatePart) -> str:

return date_part.value

def visit_time_delta_expr(self, node: SqlSubtractTimeIntervalExpression) -> SqlExpressionRenderResult: # noqa: D102
def visit_subtract_time_interval_expr( # noqa: D102
self, node: SqlSubtractTimeIntervalExpression
) -> SqlExpressionRenderResult:
arg_rendered = node.arg.accept(self)

count = node.count
granularity = node.granularity
if granularity == TimeGranularity.QUARTER:
if granularity is TimeGranularity.QUARTER:
granularity = TimeGranularity.MONTH
count *= 3
return SqlExpressionRenderResult(
sql=f"DATEADD({granularity.value}, -{count}, {arg_rendered.sql})",
bind_parameter_set=arg_rendered.bind_parameter_set,
)

def visit_add_time_expr(self, node: SqlAddTimeExpression) -> SqlExpressionRenderResult: # noqa: D102
arg_rendered = node.arg.accept(self)
count_rendered = node.count_expr.accept(self).sql

granularity = node.granularity
if granularity is TimeGranularity.QUARTER:
granularity = TimeGranularity.MONTH
count_rendered = f"({count_rendered} * 3)"

return SqlExpressionRenderResult(
sql=f"DATEADD({granularity.value}, {count_rendered}, {arg_rendered.sql})",
bind_parameter_set=arg_rendered.bind_parameter_set,
)

def visit_ratio_computation_expr(self, node: SqlRatioComputationExpression) -> SqlExpressionRenderResult:
"""Render the ratio computation for a ratio metric.
Expand Down
21 changes: 19 additions & 2 deletions metricflow/sql/render/postgres.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
)
from metricflow.sql.render.sql_plan_renderer import DefaultSqlQueryPlanRenderer
from metricflow.sql.sql_exprs import (
SqlAddTimeExpression,
SqlGenerateUuidExpression,
SqlPercentileExpression,
SqlPercentileFunctionType,
Expand All @@ -40,20 +41,36 @@ def supported_percentile_function_types(self) -> Collection[SqlPercentileFunctio
return {SqlPercentileFunctionType.CONTINUOUS, SqlPercentileFunctionType.DISCRETE}

@override
def visit_time_delta_expr(self, node: SqlSubtractTimeIntervalExpression) -> SqlExpressionRenderResult:
def visit_subtract_time_interval_expr(self, node: SqlSubtractTimeIntervalExpression) -> SqlExpressionRenderResult:
"""Render time delta operations for PostgreSQL, which needs custom support for quarterly granularity."""
arg_rendered = node.arg.accept(self)

count = node.count
granularity = node.granularity
if granularity == TimeGranularity.QUARTER:
if granularity is TimeGranularity.QUARTER:
granularity = TimeGranularity.MONTH
count *= 3
return SqlExpressionRenderResult(
sql=f"{arg_rendered.sql} - MAKE_INTERVAL({granularity.value}s => {count})",
bind_parameter_set=arg_rendered.bind_parameter_set,
)

@override
def visit_add_time_expr(self, node: SqlAddTimeExpression) -> SqlExpressionRenderResult:
"""Render time delta operations for PostgreSQL, which needs custom support for quarterly granularity."""
arg_rendered = node.arg.accept(self)
count_rendered = node.count_expr.accept(self).sql

granularity = node.granularity
if granularity is TimeGranularity.QUARTER:
granularity = TimeGranularity.MONTH
count_rendered = f"({count_rendered} * 3)"

return SqlExpressionRenderResult(
sql=f"{arg_rendered.sql} + MAKE_INTERVAL({granularity.value}s => {count_rendered})",
bind_parameter_set=arg_rendered.bind_parameter_set,
)

@override
def visit_generate_uuid_expr(self, node: SqlGenerateUuidExpression) -> SqlExpressionRenderResult:
return SqlExpressionRenderResult(
Expand Down
21 changes: 19 additions & 2 deletions metricflow/sql/render/trino.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
)
from metricflow.sql.render.sql_plan_renderer import DefaultSqlQueryPlanRenderer
from metricflow.sql.sql_exprs import (
SqlAddTimeExpression,
SqlBetweenExpression,
SqlGenerateUuidExpression,
SqlPercentileExpression,
Expand Down Expand Up @@ -45,20 +46,36 @@ def visit_generate_uuid_expr(self, node: SqlGenerateUuidExpression) -> SqlExpres
)

@override
def visit_time_delta_expr(self, node: SqlSubtractTimeIntervalExpression) -> SqlExpressionRenderResult:
def visit_subtract_time_interval_expr(self, node: SqlSubtractTimeIntervalExpression) -> SqlExpressionRenderResult:
"""Render time delta for Trino, require granularity in quotes and function name change."""
arg_rendered = node.arg.accept(self)

count = node.count
granularity = node.granularity
if granularity == TimeGranularity.QUARTER:
if granularity is TimeGranularity.QUARTER:
granularity = TimeGranularity.MONTH
count *= 3
return SqlExpressionRenderResult(
sql=f"DATE_ADD('{granularity.value}', -{count}, {arg_rendered.sql})",
bind_parameter_set=arg_rendered.bind_parameter_set,
)

@override
def visit_add_time_expr(self, node: SqlAddTimeExpression) -> SqlExpressionRenderResult:
"""Render time delta for Trino, require granularity in quotes and function name change."""
arg_rendered = node.arg.accept(self)
count_rendered = node.count_expr.accept(self).sql

granularity = node.granularity
if granularity is TimeGranularity.QUARTER:
granularity = TimeGranularity.MONTH
count_rendered = f"({count_rendered} * 3)"

return SqlExpressionRenderResult(
sql=f"DATE_ADD('{granularity.value}', {count_rendered}, {arg_rendered.sql})",
bind_parameter_set=arg_rendered.bind_parameter_set,
)

@override
def visit_percentile_expr(self, node: SqlPercentileExpression) -> SqlExpressionRenderResult:
"""Render a percentile expression for Trino."""
Expand Down
71 changes: 68 additions & 3 deletions metricflow/sql/sql_exprs.py
Original file line number Diff line number Diff line change
Expand Up @@ -212,7 +212,13 @@ def visit_extract_expr(self, node: SqlExtractExpression) -> VisitorOutputT: # n
pass

@abstractmethod
def visit_time_delta_expr(self, node: SqlSubtractTimeIntervalExpression) -> VisitorOutputT: # noqa: D102
def visit_subtract_time_interval_expr( # noqa: D102
self, node: SqlSubtractTimeIntervalExpression
) -> VisitorOutputT:
pass

@abstractmethod
def visit_add_time_expr(self, node: SqlAddTimeExpression) -> VisitorOutputT: # noqa: D102
pass

@abstractmethod
Expand Down Expand Up @@ -1289,11 +1295,11 @@ def requires_parenthesis(self) -> bool: # noqa: D102
return False

def accept(self, visitor: SqlExpressionNodeVisitor[VisitorOutputT]) -> VisitorOutputT: # noqa: D102
return visitor.visit_time_delta_expr(self)
return visitor.visit_subtract_time_interval_expr(self)

@property
def description(self) -> str: # noqa: D102
return "Time delta"
return "Subtract time interval"

def rewrite( # noqa: D102
self,
Expand All @@ -1318,6 +1324,65 @@ def matches(self, other: SqlExpressionNode) -> bool: # noqa: D102
return self.count == other.count and self.granularity == other.granularity and self._parents_match(other)


@dataclass(frozen=True, eq=False)
class SqlAddTimeExpression(SqlExpressionNode):
"""Add a time interval expr to a timestamp."""

arg: SqlExpressionNode
count_expr: SqlExpressionNode
granularity: TimeGranularity

@staticmethod
def create( # noqa: D102
arg: SqlExpressionNode,
count_expr: SqlExpressionNode,
granularity: TimeGranularity,
) -> SqlAddTimeExpression:
return SqlAddTimeExpression(
parent_nodes=(arg, count_expr),
arg=arg,
count_expr=count_expr,
granularity=granularity,
)

@classmethod
def id_prefix(cls) -> IdPrefix: # noqa: D102
return StaticIdPrefix.SQL_EXPR_ADD_TIME_PREFIX

@property
def requires_parenthesis(self) -> bool: # noqa: D102
return False

def accept(self, visitor: SqlExpressionNodeVisitor[VisitorOutputT]) -> VisitorOutputT: # noqa: D102
return visitor.visit_add_time_expr(self)

@property
def description(self) -> str: # noqa: D102
return "Add time interval"

def rewrite( # noqa: D102
self,
column_replacements: Optional[SqlColumnReplacements] = None,
should_render_table_alias: Optional[bool] = None,
) -> SqlExpressionNode:
return SqlAddTimeExpression.create(
arg=self.arg.rewrite(column_replacements, should_render_table_alias),
count_expr=self.count_expr,
granularity=self.granularity,
)

@property
def lineage(self) -> SqlExpressionTreeLineage: # noqa: D102
return SqlExpressionTreeLineage.merge_iterable(
tuple(x.lineage for x in self.parent_nodes) + (SqlExpressionTreeLineage(other_exprs=(self,)),)
)

def matches(self, other: SqlExpressionNode) -> bool: # noqa: D102
if not isinstance(other, SqlAddTimeExpression):
return False
return self.count_expr == other.count_expr and self.granularity == other.granularity and self.arg == other.arg


@dataclass(frozen=True, eq=False)
class SqlCastToTimestampExpression(SqlExpressionNode):
"""Cast to the timestamp type like CAST('2020-01-01' AS TIMESTAMP)."""
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
test_name: test_add_time_expr
test_filename: test_engine_specific_rendering.py
docstring:
Tests rendering of the SqlAddTimeExpr in a query.
sql_engine: BigQuery
---
-- Test Add Time Expression
SELECT
DATE_ADD(CAST('2020-01-01' AS DATETIME), INTERVAL SqlExpressionRenderResult(sql='1', bind_parameter_set=SqlBindParameterSet(param_items=())) quarter) AS add_time
FROM foo.bar a
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
test_name: test_add_time_expr
test_filename: test_engine_specific_rendering.py
docstring:
Tests rendering of the SqlAddTimeExpr in a query.
sql_engine: Databricks
---
-- Test Add Time Expression
SELECT
DATEADD(month, (1 * 3), '2020-01-01') AS add_time
FROM foo.bar a
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
test_name: test_add_time_expr
test_filename: test_engine_specific_rendering.py
docstring:
Tests rendering of the SqlAddTimeExpr in a query.
sql_engine: DuckDB
---
-- Test Add Time Expression
SELECT
'2020-01-01' + INTERVAL (1 * 3) month AS add_time
FROM foo.bar a
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
test_name: test_add_time_expr
test_filename: test_engine_specific_rendering.py
docstring:
Tests rendering of the SqlAddTimeExpr in a query.
sql_engine: Postgres
---
-- Test Add Time Expression
SELECT
'2020-01-01' + MAKE_INTERVAL(months => (1 * 3)) AS add_time
FROM foo.bar a
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
test_name: test_add_time_expr
test_filename: test_engine_specific_rendering.py
docstring:
Tests rendering of the SqlAddTimeExpr in a query.
sql_engine: Redshift
---
-- Test Add Time Expression
SELECT
DATEADD(month, (1 * 3), '2020-01-01') AS add_time
FROM foo.bar a
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
test_name: test_add_time_expr
test_filename: test_engine_specific_rendering.py
docstring:
Tests rendering of the SqlAddTimeExpr in a query.
sql_engine: Snowflake
---
-- Test Add Time Expression
SELECT
DATEADD(month, (1 * 3), '2020-01-01') AS add_time
FROM foo.bar a
Loading

0 comments on commit 728474d

Please sign in to comment.