From 097b022183b51ae7878356c00c4402afa7bbd9dd Mon Sep 17 00:00:00 2001 From: BRAUN REMI Date: Tue, 13 Aug 2024 13:33:22 +0200 Subject: [PATCH] ENH: Allow to pass directly a version object to `misc.compare_version` --- CHANGES.md | 1 + CI/SCRIPTS/test_snap.py | 7 +++---- sertit/misc.py | 17 ++++++++++++++--- 3 files changed, 18 insertions(+), 7 deletions(-) diff --git a/CHANGES.md b/CHANGES.md index d323cb5..6e60237 100644 --- a/CHANGES.md +++ b/CHANGES.md @@ -3,6 +3,7 @@ ## 1.42.0 (2024-mm-dd) - **ENH: Add a function `snap.get_snap_version` to retrieve current SNAP version** ([#172](https://github.com/sertit/eoreader/issues/172)) +- **ENH: Allow to pass directly a version object to `misc.compare_version`** ## 1.41.0 (2024-08-06) diff --git a/CI/SCRIPTS/test_snap.py b/CI/SCRIPTS/test_snap.py index bce6b08..cd2addc 100644 --- a/CI/SCRIPTS/test_snap.py +++ b/CI/SCRIPTS/test_snap.py @@ -19,7 +19,6 @@ import shutil import pytest -from packaging.version import Version from sertit import ci, misc, snap from sertit.snap import TILE_SIZE @@ -53,9 +52,9 @@ def test_snap(): assert substr in cli -def snap_version(): +def test_snap_version(): """Test SNAP version""" snap_version = snap.get_snap_version() - assert misc.compare( - snap_version, Version("10.0.0"), "==" + assert misc.compare_version( + snap_version, "10.0.0", "==" ), f"Unexpected SNAP version: {snap_version}." diff --git a/sertit/misc.py b/sertit/misc.py index 48052af..feef5cd 100644 --- a/sertit/misc.py +++ b/sertit/misc.py @@ -25,6 +25,8 @@ from enum import Enum, unique from typing import Any, Union +from packaging.version import Version + from sertit import AnyPath from sertit.logs import SU_NAME from sertit.types import AnyPathStrType @@ -470,7 +472,9 @@ def compare(a, b, operation: str) -> bool: return ops[operation](a, b) -def compare_version(lib: str, version_to_check: str, operator: str) -> bool: +def compare_version( + lib: Union[str, Version], version_to_check: str, operator: str +) -> bool: """ Compare the version of a librarie to a reference, giving the operator. @@ -489,6 +493,13 @@ def compare_version(lib: str, version_to_check: str, operator: str) -> bool: """ from importlib.metadata import version - from packaging.version import Version + if isinstance(lib, Version): + lib_version = lib + elif isinstance(lib, str): + lib_version = Version(version(lib)) + else: + raise TypeError( + "'lib' should either be the name of your library as a string or directly a 'Version' object." + ) - return compare(Version(version(lib)), Version(version_to_check), operator) + return compare(lib_version, Version(version_to_check), operator)