Skip to content

Commit

Permalink
PR feedback and SQL rendering test for new expr
Browse files Browse the repository at this point in the history
  • Loading branch information
courtneyholcomb committed Dec 16, 2024
1 parent 94579a1 commit b939ff8
Show file tree
Hide file tree
Showing 7 changed files with 74 additions and 12 deletions.
6 changes: 3 additions & 3 deletions metricflow/sql/render/duckdb_renderer.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,7 @@ def visit_subtract_time_interval_expr(self, node: SqlSubtractTimeIntervalExpress

count = node.count
granularity = node.granularity
if granularity == TimeGranularity.QUARTER:
if granularity is TimeGranularity.QUARTER:
granularity = TimeGranularity.MONTH
count *= 3

Expand All @@ -60,9 +60,9 @@ def visit_add_time_expr(self, node: SqlAddTimeExpression) -> SqlExpressionRender
count_rendered = node.count_expr.accept(self).sql

granularity = node.granularity
if granularity == TimeGranularity.QUARTER:
if granularity is TimeGranularity.QUARTER:
granularity = TimeGranularity.MONTH
count_rendered = f"{count_rendered} * 3"
count_rendered = f"({count_rendered} * 3)"

return SqlExpressionRenderResult(
sql=f"{arg_rendered.sql} + INTERVAL {count_rendered} {granularity.value}",
Expand Down
6 changes: 3 additions & 3 deletions metricflow/sql/render/expr_renderer.py
Original file line number Diff line number Diff line change
Expand Up @@ -311,7 +311,7 @@ def visit_subtract_time_interval_expr( # noqa: D102

count = node.count
granularity = node.granularity
if granularity == TimeGranularity.QUARTER:
if granularity is TimeGranularity.QUARTER:
granularity = TimeGranularity.MONTH
count *= 3
return SqlExpressionRenderResult(
Expand All @@ -324,9 +324,9 @@ def visit_add_time_expr(self, node: SqlAddTimeExpression) -> SqlExpressionRender
count_rendered = node.count_expr.accept(self).sql

granularity = node.granularity
if granularity == TimeGranularity.QUARTER:
if granularity is TimeGranularity.QUARTER:
granularity = TimeGranularity.MONTH
count_rendered = f"{count_rendered} * 3"
count_rendered = f"({count_rendered} * 3)"

return SqlExpressionRenderResult(
sql=f"DATEADD({granularity.value}, {count_rendered}, {arg_rendered.sql})",
Expand Down
6 changes: 3 additions & 3 deletions metricflow/sql/render/postgres.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@ def visit_subtract_time_interval_expr(self, node: SqlSubtractTimeIntervalExpress

count = node.count
granularity = node.granularity
if granularity == TimeGranularity.QUARTER:
if granularity is TimeGranularity.QUARTER:
granularity = TimeGranularity.MONTH
count *= 3
return SqlExpressionRenderResult(
Expand All @@ -62,9 +62,9 @@ def visit_add_time_expr(self, node: SqlAddTimeExpression) -> SqlExpressionRender
count_rendered = node.count_expr.accept(self).sql

granularity = node.granularity
if granularity == TimeGranularity.QUARTER:
if granularity is TimeGranularity.QUARTER:
granularity = TimeGranularity.MONTH
count_rendered = f"{count_rendered} * 3"
count_rendered = f"({count_rendered} * 3)"

return SqlExpressionRenderResult(
sql=f"{arg_rendered.sql} + MAKE_INTERVAL({granularity.value}s => {count_rendered})",
Expand Down
6 changes: 3 additions & 3 deletions metricflow/sql/render/trino.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,7 @@ def visit_subtract_time_interval_expr(self, node: SqlSubtractTimeIntervalExpress

count = node.count
granularity = node.granularity
if granularity == TimeGranularity.QUARTER:
if granularity is TimeGranularity.QUARTER:
granularity = TimeGranularity.MONTH
count *= 3
return SqlExpressionRenderResult(
Expand All @@ -67,9 +67,9 @@ def visit_add_time_expr(self, node: SqlAddTimeExpression) -> SqlExpressionRender
count_rendered = node.count_expr.accept(self).sql

granularity = node.granularity
if granularity == TimeGranularity.QUARTER:
if granularity is TimeGranularity.QUARTER:
granularity = TimeGranularity.MONTH
count_rendered = f"{count_rendered} * 3"
count_rendered = f"({count_rendered} * 3)"

return SqlExpressionRenderResult(
sql=f"DATE_ADD('{granularity.value}', {count_rendered}, {arg_rendered.sql})",
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
test_name: test_add_time_expr
test_filename: test_engine_specific_rendering.py
docstring:
Tests rendering of the SqlAddTimeExpr in a query.
sql_engine: DuckDB
---
-- Test Add Time Expression
SELECT
'2020-01-01' + INTERVAL (1 * 3) month AS add_time
FROM foo.bar a
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
test_name: test_add_time_expr
test_filename: test_engine_specific_rendering.py
docstring:
Tests rendering of the SqlAddTimeExpr in a query.
sql_engine: Redshift
---
-- Test Approximate Discrete Percentile Expression
SELECT
DATEADD(month, (1 * 3), '2020-01-01') AS add_time
FROM foo.bar a
42 changes: 42 additions & 0 deletions tests_metricflow/sql/test_engine_specific_rendering.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,18 +4,21 @@

import pytest
from _pytest.fixtures import FixtureRequest
from dbt_semantic_interfaces.type_enums.time_granularity import TimeGranularity
from metricflow_semantics.sql.sql_table import SqlTable
from metricflow_semantics.test_helpers.config_helpers import MetricFlowTestConfiguration

from metricflow.protocols.sql_client import SqlClient
from metricflow.sql.sql_exprs import (
SqlAddTimeExpression,
SqlCastToTimestampExpression,
SqlColumnReference,
SqlColumnReferenceExpression,
SqlGenerateUuidExpression,
SqlPercentileExpression,
SqlPercentileExpressionArgument,
SqlPercentileFunctionType,
SqlStringExpression,
SqlStringLiteralExpression,
)
from metricflow.sql.sql_plan import (
Expand Down Expand Up @@ -295,3 +298,42 @@ def test_approximate_discrete_percentile_expr(
plan_id="plan0",
sql_client=sql_client,
)


@pytest.mark.sql_engine_snapshot
def test_add_time_expr(
request: FixtureRequest,
mf_test_configuration: MetricFlowTestConfiguration,
sql_client: SqlClient,
) -> None:
"""Tests rendering of the SqlAddTimeExpr in a query."""
select_columns = [
SqlSelectColumn(
expr=SqlAddTimeExpression.create(
arg=SqlStringLiteralExpression.create(
"2020-01-01",
),
count_expr=SqlStringExpression.create(
"1",
),
granularity=TimeGranularity.QUARTER,
),
column_alias="add_time",
),
]

from_source = SqlTableNode.create(sql_table=SqlTable(schema_name="foo", table_name="bar"))
from_source_alias = "a"

assert_rendered_sql_equal(
request=request,
mf_test_configuration=mf_test_configuration,
sql_plan_node=SqlSelectStatementNode.create(
description="Test Add Time Expression",
select_columns=tuple(select_columns),
from_source=from_source,
from_source_alias=from_source_alias,
),
plan_id="plan0",
sql_client=sql_client,
)

0 comments on commit b939ff8

Please sign in to comment.