From adc54a7fe8c3b1142531f89834334efb8e0d6d2d Mon Sep 17 00:00:00 2001 From: Kyle Edwards Date: Fri, 17 May 2024 18:48:30 -0400 Subject: [PATCH] Check -cu* suffixed packages --- src/rapids_pre_commit_hooks/alpha_spec.py | 33 ++++++++++++++++++- .../test_alpha_spec.py | 21 ++++++++++++ 2 files changed, 53 insertions(+), 1 deletion(-) diff --git a/src/rapids_pre_commit_hooks/alpha_spec.py b/src/rapids_pre_commit_hooks/alpha_spec.py index 58f1cf7..d7f5b90 100644 --- a/src/rapids_pre_commit_hooks/alpha_spec.py +++ b/src/rapids_pre_commit_hooks/alpha_spec.py @@ -47,6 +47,29 @@ "distributed-ucxx", } +RAPIDS_CUDA_VERSIONED_PACKAGES = { + "rmm", + "pylibcugraphops", + "pylibcugraph", + "nx-cugraph", + "dask-cudf", + "cuspatial", + "cuproj", + "cuml", + "cugraph", + "cudf", + "ptxcompiler", + "cubinlinker", + "cugraph-dgl", + "cugraph-pyg", + "cugraph-equivariant", + "raft-dask", + "pylibwholegraph", + "pylibraft", + "cuxfilter", + "cucim", +} + ALPHA_SPECIFIER = ">=0.0.0a0" ALPHA_SPEC_OUTPUT_TYPES = { @@ -55,10 +78,18 @@ } +def is_rapids_cuda_versioned_package(name): + return any( + name.startswith(f"{package}-cu") for package in RAPIDS_CUDA_VERSIONED_PACKAGES + ) + + def check_package_spec(linter, args, node): if node.tag == "tag:yaml.org,2002:str": req = Requirement(node.value) - if req.name in RAPIDS_VERSIONED_PACKAGES: + if req.name in RAPIDS_VERSIONED_PACKAGES or is_rapids_cuda_versioned_package( + req.name + ): has_alpha_spec = any( filter(lambda s: str(s) == ALPHA_SPECIFIER, req.specifier) ) diff --git a/test/rapids_pre_commit_hooks/test_alpha_spec.py b/test/rapids_pre_commit_hooks/test_alpha_spec.py index 322be56..687bee5 100644 --- a/test/rapids_pre_commit_hooks/test_alpha_spec.py +++ b/test/rapids_pre_commit_hooks/test_alpha_spec.py @@ -36,6 +36,27 @@ for p in alpha_spec.RAPIDS_VERSIONED_PACKAGES ) ), + *chain( + *( + [ + (f"{p}-cu12", f"{p}-cu12", "development", f"{p}-cu12>=0.0.0a0"), + (f"{p}-cu12", f"{p}-cu12", "release", None), + (f"{p}-cu12", f"{p}-cu12>=0.0.0a0", "development", None), + (f"{p}-cu12", f"{p}-cu12>=0.0.0a0", "release", f"{p}-cu12"), + ] + for p in alpha_spec.RAPIDS_CUDA_VERSIONED_PACKAGES + ) + ), + *chain( + *( + [ + (f"{p}-cu12", f"{p}-cu12", "development", None), + (f"{p}-cu12", f"{p}-cu12>=0.0.0a0", "release", None), + ] + for p in alpha_spec.RAPIDS_VERSIONED_PACKAGES + - alpha_spec.RAPIDS_CUDA_VERSIONED_PACKAGES + ) + ), ("cuml", "cuml>=24.04,<=24.06", "development", "cuml<=24.06,>=0.0.0a0,>=24.04"), ("cuml", "cuml>=24.04,<=24.06,>=0.0.0a0", "release", "cuml<=24.06,>=24.04"), ("packaging", "packaging", "development", None),