diff --git a/tests/client_test.py b/tests/client_test.py index 5b2fe27..3f4e481 100644 --- a/tests/client_test.py +++ b/tests/client_test.py @@ -1,11 +1,12 @@ # mypy: ignore-errors import pytest +from packaging.version import Version from pytest_httpx import HTTPXMock from tests.conftest import API_BASE_URL, mock_healthcheck, undo_mock_healthcheck from zep_python import APIError -from zep_python.zep_client import ZepClient, concat_url +from zep_python.zep_client import ZepClient, concat_url, parse_version_string _ = mock_healthcheck, undo_mock_healthcheck @@ -54,3 +55,28 @@ def test_concat_url(): concat_url("https://server.com/zep/", "v1/api") == "https://server.com/zep/v1/api" ) + + +def test_parse_version_string_with_dash(): + assert parse_version_string("1.2.3-456") == Version("1.2.3") + + +def test_parse_version_string_with_dash_and_empty_prefix(): + assert parse_version_string("-456") == Version("0.0.0") + + +def test_parse_version_string_with_dash_and_empty_prefix(): + assert parse_version_string("abc") == Version("0.0.0") + + +def test_parse_version_string_without_dash(): + assert parse_version_string("1.2.3") == Version("0.0.0") + + +def test_parse_version_string_empty(): + assert parse_version_string("") == Version("0.0.0") + + +def test_parse_version_string_none(): + with pytest.raises(TypeError): + parse_version_string(None) diff --git a/zep_python/zep_client.py b/zep_python/zep_client.py index b4517b0..4d54b4c 100644 --- a/zep_python/zep_client.py +++ b/zep_python/zep_client.py @@ -6,7 +6,7 @@ from urllib.parse import urljoin import httpx -from packaging.version import Version +from packaging.version import InvalidVersion, Version from zep_python.document.client import DocumentClient from zep_python.exceptions import APIError @@ -117,7 +117,8 @@ def _healthcheck(self, base_url: str) -> None: if zep_server_version_str: if "dev" in zep_server_version_str: return - zep_server_version = Version(zep_server_version_str.split("-")[0]) + + zep_server_version = parse_version_string(zep_server_version_str) else: zep_server_version = Version("0.0.0") @@ -268,3 +269,28 @@ def deprecated_warning(func: Callable[..., Any]) -> Callable[..., Any]: stacklevel=3, ) return func + + +def parse_version_string(version_string: str) -> Version: + """ + Parse a string into a Version object. + + Parameters + ---------- + version_string : str + The version string to parse. + + Returns + ------- + Version + The parsed version. + """ + + try: + if "-" in version_string: + version_str = version_string.split("-")[0] + return Version(version_str if version_str else "0.0.0") + except InvalidVersion: + return Version("0.0.0") + + return Version("0.0.0")