From 4aacb5f3f5b24ab326af7a67d08df5c0f5edf57d Mon Sep 17 00:00:00 2001 From: Marco Gorelli <33491632+MarcoGorelli@users.noreply.github.com> Date: Wed, 8 Jan 2025 13:07:54 +0000 Subject: [PATCH] fix: parse_version was not parsing duckdb pre-preleases correctly --- narwhals/utils.py | 9 +++++---- tests/utils_test.py | 13 +++++++++++++ 2 files changed, 18 insertions(+), 4 deletions(-) diff --git a/narwhals/utils.py b/narwhals/utils.py index c03642c90..509a0e36a 100644 --- a/narwhals/utils.py +++ b/narwhals/utils.py @@ -372,7 +372,7 @@ def _is_iterable(arg: Any | Iterable[Any]) -> bool: return isinstance(arg, Iterable) and not isinstance(arg, (str, bytes, Series)) -def parse_version(version: Sequence[str | int]) -> tuple[int, ...]: +def parse_version(version: str) -> tuple[int, ...]: """Simple version parser; split into a tuple of ints for comparison. Arguments: @@ -382,9 +382,10 @@ def parse_version(version: Sequence[str | int]) -> tuple[int, ...]: Parsed version number. """ # lifted from Polars - if isinstance(version, str): # pragma: no cover - version = version.split(".") - return tuple(int(re.sub(r"\D", "", str(v))) for v in version) + # [marco]: Take care of DuckDB pre-releases which end with e.g. `-dev4108` + # and pandas pre-releases which end with e.g. .dev0+618.gb552dc95c9 + version = re.sub(r"(\D?dev.*$)", "", version) + return tuple(int(re.sub(r"\D", "", str(v))) for v in version.split(".")) def isinstance_or_issubclass(obj: Any, cls: Any) -> bool: diff --git a/tests/utils_test.py b/tests/utils_test.py index 26bd2ecf9..e999696d3 100644 --- a/tests/utils_test.py +++ b/tests/utils_test.py @@ -13,6 +13,7 @@ from pandas.testing import assert_series_equal import narwhals.stable.v1 as nw +from narwhals.utils import parse_version from tests.utils import PANDAS_VERSION from tests.utils import get_module_version_as_tuple @@ -271,3 +272,15 @@ def test_generate_temporary_column_name_raise() -> None: match="Internal Error: Narwhals was not able to generate a column name with ", ): nw.generate_temporary_column_name(n_bytes=1, columns=columns) + + +@pytest.mark.parametrize( + ("version", "expected"), + [ + ("2020.1.2", (2020, 1, 2)), + ("2020.1.2-dev123", (2020, 1, 2)), + ("3.0.0.dev0+618.gb552dc95c9", (3, 0, 0)), + ], +) +def test_parse_version(version: str, expected: tuple[int, ...]) -> None: + assert parse_version(version) == expected