From 24e5395563a6ad3af5ba5c67c5357d821f8ee9f9 Mon Sep 17 00:00:00 2001 From: Phillip Cloud <417981+cpcloud@users.noreply.github.com> Date: Wed, 25 Sep 2024 08:09:07 -0400 Subject: [PATCH] fix(clickhouse): fix truncating to date from a timestamp (#10220) --- .../test_timestamp_truncate/d/out.sql | 2 +- .../test_timestamp_truncate/w/out.sql | 2 +- .../test_timestamp_truncate/y/out.sql | 2 +- .../clickhouse/tests/test_functions.py | 16 ++++++++++++ ibis/backends/sql/compilers/clickhouse.py | 26 ++++++++++++++----- 5 files changed, 39 insertions(+), 9 deletions(-) diff --git a/ibis/backends/clickhouse/tests/snapshots/test_functions/test_timestamp_truncate/d/out.sql b/ibis/backends/clickhouse/tests/snapshots/test_functions/test_timestamp_truncate/d/out.sql index 570242c29545..59d136762a04 100644 --- a/ibis/backends/clickhouse/tests/snapshots/test_functions/test_timestamp_truncate/d/out.sql +++ b/ibis/backends/clickhouse/tests/snapshots/test_functions/test_timestamp_truncate/d/out.sql @@ -1,2 +1,2 @@ SELECT - toStartOfDay(parseDateTimeBestEffort('2009-05-17T12:34:56')) AS "TimestampTruncate(datetime.datetime(2009, 5, 17, 12, 34, 56), DAY)" \ No newline at end of file + CAST(toStartOfDay(parseDateTimeBestEffort('2009-05-17T12:34:56')) AS Nullable(DateTime)) AS "TimestampTruncate(datetime.datetime(2009, 5, 17, 12, 34, 56), DAY)" \ No newline at end of file diff --git a/ibis/backends/clickhouse/tests/snapshots/test_functions/test_timestamp_truncate/w/out.sql b/ibis/backends/clickhouse/tests/snapshots/test_functions/test_timestamp_truncate/w/out.sql index b6d895e7c92d..7a20e46b5652 100644 --- a/ibis/backends/clickhouse/tests/snapshots/test_functions/test_timestamp_truncate/w/out.sql +++ b/ibis/backends/clickhouse/tests/snapshots/test_functions/test_timestamp_truncate/w/out.sql @@ -1,2 +1,2 @@ SELECT - toMonday(parseDateTimeBestEffort('2009-05-17T12:34:56')) AS "TimestampTruncate(datetime.datetime(2009, 5, 17, 12, 34, 56), WEEK)" \ No newline at end of file + CAST(toMonday(parseDateTimeBestEffort('2009-05-17T12:34:56')) AS Nullable(DateTime)) AS "TimestampTruncate(datetime.datetime(2009, 5, 17, 12, 34, 56), WEEK)" \ No newline at end of file diff --git a/ibis/backends/clickhouse/tests/snapshots/test_functions/test_timestamp_truncate/y/out.sql b/ibis/backends/clickhouse/tests/snapshots/test_functions/test_timestamp_truncate/y/out.sql index 634f92da101b..9ccea5d38a34 100644 --- a/ibis/backends/clickhouse/tests/snapshots/test_functions/test_timestamp_truncate/y/out.sql +++ b/ibis/backends/clickhouse/tests/snapshots/test_functions/test_timestamp_truncate/y/out.sql @@ -1,2 +1,2 @@ SELECT - toStartOfYear(parseDateTimeBestEffort('2009-05-17T12:34:56')) AS "TimestampTruncate(datetime.datetime(2009, 5, 17, 12, 34, 56), YEAR)" \ No newline at end of file + CAST(toStartOfYear(parseDateTimeBestEffort('2009-05-17T12:34:56')) AS Nullable(DateTime)) AS "TimestampTruncate(datetime.datetime(2009, 5, 17, 12, 34, 56), YEAR)" \ No newline at end of file diff --git a/ibis/backends/clickhouse/tests/test_functions.py b/ibis/backends/clickhouse/tests/test_functions.py index 6a8d185e4d34..2637860c4a74 100644 --- a/ibis/backends/clickhouse/tests/test_functions.py +++ b/ibis/backends/clickhouse/tests/test_functions.py @@ -492,3 +492,19 @@ def my_eq(a: int, b: int) -> bool: ... expr = alltypes.int_col.collect().filter(lambda x: my_eq(x, 1)) result = expr.execute() assert set(result) == {1} + + +def test_timestamp_to_start_of_week(con): + pytest.importorskip("pyarrow") + + expr = ibis.timestamp("2024-02-03 00:00:00").truncate("W") + result = con.to_pyarrow(expr).as_py() + assert result == datetime(2024, 1, 29, 0, 0, 0) + + +def test_date_to_start_of_day(con): + pytest.importorskip("pyarrow") + + expr = ibis.date("2024-02-03") + expr1 = expr.truncate("D") + assert con.to_pyarrow(expr1) == con.to_pyarrow(expr) diff --git a/ibis/backends/sql/compilers/clickhouse.py b/ibis/backends/sql/compilers/clickhouse.py index 46201bfff8f4..185b44052cdd 100644 --- a/ibis/backends/sql/compilers/clickhouse.py +++ b/ibis/backends/sql/compilers/clickhouse.py @@ -377,21 +377,35 @@ def visit_TimestampFromUNIX(self, op, *, arg, unit): def visit_TimestampTruncate(self, op, *, arg, unit): if (short := unit.short) == "W": - func = "toMonday" + funcname = "toMonday" else: - func = f"toStartOf{unit.singular.capitalize()}" + funcname = f"toStartOf{unit.singular.capitalize()}" - if short in ("s", "ms", "us", "ns"): - arg = self.f.toDateTime64(arg, op.arg.dtype.scale or 0) - return self.f[func](arg) + func = self.f[funcname] + + if short in ("Y", "Q", "M", "W", "D"): + # these units return `Date` so we have to cast back to the + # corresponding Ibis type + return self.cast(func(arg), op.dtype) + elif short in ("s", "ms", "us", "ns"): + return func(self.f.toDateTime64(arg, op.arg.dtype.scale or 0)) + else: + assert short in ("h", "m"), short + return func(arg) visit_TimeTruncate = visit_TimestampTruncate def visit_DateTruncate(self, op, *, arg, unit): - if unit.short == "W": + if unit.short == "D": + # no op because truncating a date to a date has no effect + return arg + elif unit.short == "W": func = "toMonday" else: func = f"toStartOf{unit.singular.capitalize()}" + + # no cast needed here because all of the allowed units return `Date` + # values return self.f[func](arg) def visit_TimestampBucket(self, op, *, arg, interval, offset):