From 6841fcfbbb409a1ac3a7e3b8cf0de594dfdaf720 Mon Sep 17 00:00:00 2001 From: Kyle Edwards Date: Tue, 28 May 2024 15:15:17 -0400 Subject: [PATCH] Pretty-sort version specifier --- src/rapids_pre_commit_hooks/alpha_spec.py | 48 +++++++++++++++---- .../test_alpha_spec.py | 6 +-- 2 files changed, 41 insertions(+), 13 deletions(-) diff --git a/src/rapids_pre_commit_hooks/alpha_spec.py b/src/rapids_pre_commit_hooks/alpha_spec.py index aed845b..95515d2 100644 --- a/src/rapids_pre_commit_hooks/alpha_spec.py +++ b/src/rapids_pre_commit_hooks/alpha_spec.py @@ -13,11 +13,10 @@ # limitations under the License. import re -from functools import reduce +from functools import total_ordering import yaml from packaging.requirements import Requirement -from packaging.specifiers import SpecifierSet from .lint import LintMain @@ -74,6 +73,29 @@ def is_rapids_cuda_suffixed_package(name): def check_package_spec(linter, args, node): + @total_ordering + class SpecPriority: + def __init__(self, spec): + self.spec = spec + + def __eq__(self, other): + return self.spec == other.spec + + def __lt__(self, other): + if self.spec == other.spec: + return False + if self.spec == ALPHA_SPECIFIER: + return False + if other.spec == ALPHA_SPECIFIER: + return True + return self.sort_str() < other.sort_str() + + def sort_str(self): + return "".join(c for c in self.spec if c not in "<>=") + + def create_specifier_string(specifiers): + return ",".join(sorted(specifiers, key=SpecPriority)) + if node_has_type(node, "str"): req = Requirement(node.value) if req.name in RAPIDS_ALPHA_SPEC_PACKAGES or is_rapids_cuda_suffixed_package( @@ -81,24 +103,30 @@ def check_package_spec(linter, args, node): ): has_alpha_spec = any(str(s) == ALPHA_SPECIFIER for s in req.specifier) if args.mode == "development" and not has_alpha_spec: - req.specifier &= ALPHA_SPECIFIER linter.add_warning( (node.start_mark.index, node.end_mark.index), f"add alpha spec for RAPIDS package {req.name}", ).add_replacement( - (node.start_mark.index, node.end_mark.index), str(req) + (node.start_mark.index, node.end_mark.index), + str( + req.name + + create_specifier_string( + {str(s) for s in req.specifier} | {ALPHA_SPECIFIER} + ) + ), ) elif args.mode == "release" and has_alpha_spec: - req.specifier = reduce( - lambda ss, s: ss & str(s), - filter(lambda s: str(s) != ALPHA_SPECIFIER, req.specifier), - SpecifierSet(), - ) linter.add_warning( (node.start_mark.index, node.end_mark.index), f"remove alpha spec for RAPIDS package {req.name}", ).add_replacement( - (node.start_mark.index, node.end_mark.index), str(req) + (node.start_mark.index, node.end_mark.index), + str( + req.name + + create_specifier_string( + {str(s) for s in req.specifier} - {ALPHA_SPECIFIER} + ) + ), ) diff --git a/test/rapids_pre_commit_hooks/test_alpha_spec.py b/test/rapids_pre_commit_hooks/test_alpha_spec.py index f75a547..9661cf5 100644 --- a/test/rapids_pre_commit_hooks/test_alpha_spec.py +++ b/test/rapids_pre_commit_hooks/test_alpha_spec.py @@ -87,13 +87,13 @@ def test_is_rapids_cuda_suffixed_package(name, is_suffixed): - alpha_spec.RAPIDS_CUDA_SUFFIXED_PACKAGES ) ), - ("cuml", "cuml>=24.04,<24.06", "development", "cuml<24.06,>=0.0.0a0,>=24.04"), - ("cuml", "cuml>=24.04,<24.06,>=0.0.0a0", "release", "cuml<24.06,>=24.04"), + ("cuml", "cuml>=24.04,<24.06", "development", "cuml>=24.04,<24.06,>=0.0.0a0"), + ("cuml", "cuml>=24.04,<24.06,>=0.0.0a0", "release", "cuml>=24.04,<24.06"), ( "cuml", "&cuml cuml>=24.04,<24.06,>=0.0.0a0", "release", - "cuml<24.06,>=24.04", + "cuml>=24.04,<24.06", ), ("packaging", "packaging", "development", None), ],