Skip to content

Commit

Permalink
RFC, feat: infer datetime format for pyarrow backend (#1195)
Browse files Browse the repository at this point in the history
* feat: infer datetime format for pyarrow

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* fix for date format

* use first 10 non null values only to infer format

* test with null

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
  • Loading branch information
FBruzzesi and pre-commit-ci[bot] authored Oct 29, 2024
1 parent 9f8809c commit f349cb2
Show file tree
Hide file tree
Showing 3 changed files with 139 additions and 14 deletions.
4 changes: 2 additions & 2 deletions narwhals/_arrow/series.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
from narwhals._arrow.utils import floordiv_compat
from narwhals._arrow.utils import narwhals_to_native_dtype
from narwhals._arrow.utils import native_to_narwhals_dtype
from narwhals._arrow.utils import parse_datetime_format
from narwhals._arrow.utils import validate_column_comparand
from narwhals.utils import Implementation
from narwhals.utils import generate_temporary_column_name
Expand Down Expand Up @@ -1115,8 +1116,7 @@ def to_datetime(self: Self, format: str | None) -> ArrowSeries: # noqa: A002
import pyarrow.compute as pc # ignore-banned-import()

if format is None:
msg = "`format` is required for pyarrow backend."
raise ValueError(msg)
format = parse_datetime_format(self._arrow_series._native_series)

return self._arrow_series._from_native_series(
pc.strptime(self._arrow_series._native_series, format=format, unit="us")
Expand Down
85 changes: 85 additions & 0 deletions narwhals/_arrow/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -335,3 +335,88 @@ def convert_str_slice_to_int_slice(
stop = columns.index(str_slice.stop) + 1 if str_slice.stop is not None else None
step = str_slice.step
return (start, stop, step)


# Regex for date, time, separator and timezone components
DATE_RE = r"(?P<date>\d{1,4}[-/.]\d{1,2}[-/.]\d{1,4})"
SEP_RE = r"(?P<sep>\s|T)"
TIME_RE = r"(?P<time>\d{2}:\d{2}:\d{2})" # \s*(?P<period>[AP]M)?)?
TZ_RE = r"(?P<tz>Z|[+-]\d{2}:?\d{2})" # Matches 'Z', '+02:00', '+0200', '+02', etc.
FULL_RE = rf"{DATE_RE}{SEP_RE}?{TIME_RE}?{TZ_RE}?$"

# Separate regexes for different date formats
YMD_RE = r"^(?P<year>(?:[12][0-9])?[0-9]{2})(?P<sep1>[-/.])(?P<month>0[1-9]|1[0-2])(?P<sep2>[-/.])(?P<day>0[1-9]|[12][0-9]|3[01])$"
DMY_RE = r"^(?P<day>0[1-9]|[12][0-9]|3[01])(?P<sep1>[-/.])(?P<month>0[1-9]|1[0-2])(?P<sep2>[-/.])(?P<year>(?:[12][0-9])?[0-9]{2})$"
MDY_RE = r"^(?P<month>0[1-9]|1[0-2])(?P<sep1>[-/.])(?P<day>0[1-9]|[12][0-9]|3[01])(?P<sep2>[-/.])(?P<year>(?:[12][0-9])?[0-9]{2})$"

DATE_FORMATS = (
(YMD_RE, "%Y-%m-%d"),
(DMY_RE, "%d-%m-%Y"),
(MDY_RE, "%m-%d-%Y"),
)


def parse_datetime_format(arr: pa.StringArray) -> str:
"""Try to infer datetime format from StringArray."""
import pyarrow as pa # ignore-banned-import
import pyarrow.compute as pc # ignore-banned-import

matches = pa.concat_arrays( # converts from ChunkedArray to StructArray
pc.extract_regex(pc.drop_null(arr).slice(0, 10), pattern=FULL_RE).chunks
)

if not pc.all(matches.is_valid()).as_py():
msg = (
"Unable to infer datetime format, provided format is not supported. "
"Please report a bug to https://github.com/narwhals-dev/narwhals/issues"
)
raise NotImplementedError(msg)

dates = matches.field("date")
separators = matches.field("sep")
times = matches.field("time")
tz = matches.field("tz")

# separators and time zones must be unique
if pc.count(pc.unique(separators)).as_py() > 1:
msg = "Found multiple separator values while inferring datetime format."
raise ValueError(msg)

if pc.count(pc.unique(tz)).as_py() > 1:
msg = "Found multiple timezone values while inferring datetime format."
raise ValueError(msg)

date_value = _parse_date_format(dates)
time_value = _parse_time_format(times)

sep_value = separators[0].as_py()
tz_value = "%z" if tz[0].as_py() else ""

return f"{date_value}{sep_value}{time_value}{tz_value}"


def _parse_date_format(arr: pa.Array) -> str:
import pyarrow.compute as pc # ignore-banned-import

for date_rgx, date_fmt in DATE_FORMATS:
matches = pc.extract_regex(arr, pattern=date_rgx)
if (
pc.all(matches.is_valid()).as_py()
and pc.count(pc.unique(sep1 := matches.field("sep1"))).as_py() == 1
and pc.count(pc.unique(sep2 := matches.field("sep2"))).as_py() == 1
and (date_sep_value := sep1[0].as_py()) == sep2[0].as_py()
):
return date_fmt.replace("-", date_sep_value)

msg = (
"Unable to infer datetime format. "
"Please report a bug to https://github.com/narwhals-dev/narwhals/issues"
)
raise ValueError(msg)


def _parse_time_format(arr: pa.Array) -> str:
import pyarrow.compute as pc # ignore-banned-import

matches = pc.extract_regex(arr, pattern=TIME_RE)
return "%H:%M:%S" if pc.all(matches.is_valid()).as_py() else ""
64 changes: 52 additions & 12 deletions tests/expr_and_series/str/to_datetime_test.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,19 @@
from __future__ import annotations

from datetime import datetime
from typing import TYPE_CHECKING

import pyarrow as pa
import pytest

import narwhals.stable.v1 as nw
from narwhals._arrow.utils import parse_datetime_format
from tests.utils import assert_equal_data

if TYPE_CHECKING:
from tests.utils import Constructor
from tests.utils import ConstructorEager

data = {"a": ["2020-01-01T12:34:56"]}


Expand Down Expand Up @@ -42,12 +47,7 @@ def test_to_datetime_series(constructor_eager: ConstructorEager) -> None:
assert str(result) == expected


def test_to_datetime_infer_fmt(
request: pytest.FixtureRequest, constructor: Constructor
) -> None:
if "pyarrow_table" in str(constructor):
request.applymarker(pytest.mark.xfail)

def test_to_datetime_infer_fmt(constructor: Constructor) -> None:
if "cudf" in str(constructor): # pragma: no cover
expected = "2020-01-01T12:34:56.000000000"
else:
Expand All @@ -63,12 +63,7 @@ def test_to_datetime_infer_fmt(
assert str(result) == expected


def test_to_datetime_series_infer_fmt(
request: pytest.FixtureRequest, constructor_eager: ConstructorEager
) -> None:
if "pyarrow_table" in str(constructor_eager):
request.applymarker(pytest.mark.xfail)

def test_to_datetime_series_infer_fmt(constructor_eager: ConstructorEager) -> None:
if "cudf" in str(constructor_eager): # pragma: no cover
expected = "2020-01-01T12:34:56.000000000"
else:
Expand All @@ -78,3 +73,48 @@ def test_to_datetime_series_infer_fmt(
nw.from_native(constructor_eager(data), eager_only=True)["a"].str.to_datetime()
).item(0)
assert str(result) == expected


def test_to_datetime_infer_fmt_from_date(constructor: Constructor) -> None:
data = {"z": ["2020-01-01", "2020-01-02", None]}
expected = [datetime(2020, 1, 1), datetime(2020, 1, 2), None]
result = (
nw.from_native(constructor(data))
.lazy()
.select(nw.col("z").str.to_datetime())
.collect()
)
assert_equal_data(result, {"z": expected})


def test_pyarrow_infer_datetime_raise_invalid() -> None:
with pytest.raises(
NotImplementedError,
match="Unable to infer datetime format, provided format is not supported.",
):
parse_datetime_format(pa.chunked_array([["2024-01-01", "abc"]]))


@pytest.mark.parametrize(
("data", "duplicate"),
[
(["2024-01-01T00:00:00", "2024-01-01 01:00:00"], "separator"),
(["2024-01-01 00:00:00+01:00", "2024-01-01 01:00:00+02:00"], "timezone"),
],
)
def test_pyarrow_infer_datetime_raise_not_unique(
data: list[str | None], duplicate: str
) -> None:
with pytest.raises(
ValueError,
match=f"Found multiple {duplicate} values while inferring datetime format.",
):
parse_datetime_format(pa.chunked_array([data]))


@pytest.mark.parametrize("data", [["2024-01-01", "2024-12-01", "02-02-2024"]])
def test_pyarrow_infer_datetime_raise_inconsistent_date_fmt(
data: list[str | None],
) -> None:
with pytest.raises(ValueError, match="Unable to infer datetime format. "):
parse_datetime_format(pa.chunked_array([data]))

0 comments on commit f349cb2

Please sign in to comment.