Skip to content

Commit

Permalink
Merge remote-tracking branch 'origin/fea-mdbuffer' into fea-mdbuffer
Browse files Browse the repository at this point in the history
  • Loading branch information
wphicks committed Nov 22, 2023
2 parents e6ce9c3 + 8e9ade2 commit 40d75cc
Show file tree
Hide file tree
Showing 39 changed files with 2,239 additions and 386 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/build.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ on:
default: nightly

concurrency:
group: ${{ github.workflow }}-${{ github.ref }}
group: ${{ github.workflow }}-${{ github.ref }}-${{ github.event_name }}
cancel-in-progress: true

jobs:
Expand Down
201 changes: 124 additions & 77 deletions ci/checks/copyright.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,8 @@
import os
import sys

import git

SCRIPT_DIR = os.path.dirname(os.path.realpath(os.path.expanduser(__file__)))

# Add the scripts dir for gitutils
Expand All @@ -37,7 +39,12 @@
re.compile(r"setup[.]cfg$"),
re.compile(r"meta[.]yaml$")
]
ExemptFiles = ["cpp/include/raft/neighbors/detail/faiss_select/"]
ExemptFiles = [
re.compile("cpp/include/raft/neighbors/detail/faiss_select/"),
re.compile("cpp/include/raft/thirdparty/"),
re.compile("docs/source/sphinxext/github_link.py"),
re.compile("cpp/cmake/modules/FindAVX.cmake")
]

# this will break starting at year 10000, which is probably OK :)
CheckSimple = re.compile(
Expand All @@ -48,10 +55,12 @@


def checkThisFile(f):
# This check covers things like symlinks which point to files that DNE
if not (os.path.exists(f)):
return False
if gitutils and gitutils.isFileEmpty(f):
if isinstance(f, git.Diff):
if f.deleted_file or f.b_blob.size == 0:
return False
f = f.b_path
elif not os.path.exists(f) or os.stat(f).st_size == 0:
# This check covers things like symlinks which point to files that DNE
return False
for exempt in ExemptFiles:
if exempt.search(f):
Expand All @@ -62,36 +71,90 @@ def checkThisFile(f):
return False


def modifiedFiles():
"""Get a set of all modified files, as Diff objects.
The files returned have been modified in git since the merge base of HEAD
and the upstream of the target branch. We return the Diff objects so that
we can read only the staged changes.
"""
repo = git.Repo()
# Use the environment variable TARGET_BRANCH or RAPIDS_BASE_BRANCH (defined in CI) if possible
target_branch = os.environ.get("TARGET_BRANCH", os.environ.get("RAPIDS_BASE_BRANCH"))
if target_branch is None:
# Fall back to the closest branch if not on CI
target_branch = repo.git.describe(
all=True, tags=True, match="branch-*", abbrev=0
).lstrip("heads/")

upstream_target_branch = None
if target_branch in repo.heads:
# Use the tracking branch of the local reference if it exists. This
# returns None if no tracking branch is set.
upstream_target_branch = repo.heads[target_branch].tracking_branch()
if upstream_target_branch is None:
# Fall back to the remote with the newest target_branch. This code
# path is used on CI because the only local branch reference is
# current-pr-branch, and thus target_branch is not in repo.heads.
# This also happens if no tracking branch is defined for the local
# target_branch. We use the remote with the latest commit if
# multiple remotes are defined.
candidate_branches = [
remote.refs[target_branch] for remote in repo.remotes
if target_branch in remote.refs
]
if len(candidate_branches) > 0:
upstream_target_branch = sorted(
candidate_branches,
key=lambda branch: branch.commit.committed_datetime,
)[-1]
else:
# If no remotes are defined, try to use the local version of the
# target_branch. If this fails, the repo configuration must be very
# strange and we can fix this script on a case-by-case basis.
upstream_target_branch = repo.heads[target_branch]
merge_base = repo.merge_base("HEAD", upstream_target_branch.commit)[0]
diff = merge_base.diff()
changed_files = {f for f in diff if f.b_path is not None}
return changed_files


def getCopyrightYears(line):
res = CheckSimple.search(line)
if res:
return (int(res.group(1)), int(res.group(1)))
return int(res.group(1)), int(res.group(1))
res = CheckDouble.search(line)
if res:
return (int(res.group(1)), int(res.group(2)))
return (None, None)
return int(res.group(1)), int(res.group(2))
return None, None


def replaceCurrentYear(line, start, end):
# first turn a simple regex into double (if applicable). then update years
res = CheckSimple.sub(r"Copyright (c) \1-\1, NVIDIA CORPORATION", line)
res = CheckDouble.sub(
r"Copyright (c) {:04d}-{:04d}, NVIDIA CORPORATION".format(start, end),
res)
rf"Copyright (c) {start:04d}-{end:04d}, NVIDIA CORPORATION",
res,
)
return res


def checkCopyright(f, update_current_year):
"""
Checks for copyright headers and their years
"""
"""Checks for copyright headers and their years."""
errs = []
thisYear = datetime.datetime.now().year
lineNum = 0
crFound = False
yearMatched = False
with io.open(f, "r", encoding="utf-8") as fp:
lines = fp.readlines()

if isinstance(f, git.Diff):
path = f.b_path
lines = f.b_blob.data_stream.read().decode().splitlines(keepends=True)
else:
path = f
with open(f, encoding="utf-8") as fp:
lines = fp.readlines()

for line in lines:
lineNum += 1
start, end = getCopyrightYears(line)
Expand All @@ -100,20 +163,19 @@ def checkCopyright(f, update_current_year):
crFound = True
if start > end:
e = [
f,
path,
lineNum,
"First year after second year in the copyright "
"header (manual fix required)",
None
None,
]
errs.append(e)
if thisYear < start or thisYear > end:
elif thisYear < start or thisYear > end:
e = [
f,
path,
lineNum,
"Current year not included in the "
"copyright header",
None
"Current year not included in the copyright header",
None,
]
if thisYear < start:
e[-1] = replaceCurrentYear(line, thisYear, end)
Expand All @@ -122,15 +184,14 @@ def checkCopyright(f, update_current_year):
errs.append(e)
else:
yearMatched = True
fp.close()
# copyright header itself not found
if not crFound:
e = [
f,
path,
0,
"Copyright header missing or formatted incorrectly "
"(manual fix required)",
None
None,
]
errs.append(e)
# even if the year matches a copyright header, make the check pass
Expand All @@ -140,21 +201,19 @@ def checkCopyright(f, update_current_year):
if update_current_year:
errs_update = [x for x in errs if x[-1] is not None]
if len(errs_update) > 0:
print("File: {}. Changing line(s) {}".format(
f, ', '.join(str(x[1]) for x in errs if x[-1] is not None)))
lines_changed = ", ".join(str(x[1]) for x in errs_update)
print(f"File: {path}. Changing line(s) {lines_changed}")
for _, lineNum, __, replacement in errs_update:
lines[lineNum - 1] = replacement
with io.open(f, "w", encoding="utf-8") as out_file:
for new_line in lines:
out_file.write(new_line)
errs = [x for x in errs if x[-1] is None]
with open(path, "w", encoding="utf-8") as out_file:
out_file.writelines(lines)

return errs


def getAllFilesUnderDir(root, pathFilter=None):
retList = []
for (dirpath, dirnames, filenames) in os.walk(root):
for dirpath, dirnames, filenames in os.walk(root):
for fn in filenames:
filePath = os.path.join(dirpath, fn)
if pathFilter(filePath):
Expand All @@ -169,49 +228,37 @@ def checkCopyright_main():
it compares between branches "$PR_TARGET_BRANCH" and "current-pr-branch"
"""
retVal = 0
global ExemptFiles

argparser = argparse.ArgumentParser(
"Checks for a consistent copyright header in git's modified files")
argparser.add_argument("--update-current-year",
dest='update_current_year',
action="store_true",
required=False,
help="If set, "
"update the current year if a header "
"is already present and well formatted.")
argparser.add_argument("--git-modified-only",
dest='git_modified_only',
action="store_true",
required=False,
help="If set, "
"only files seen as modified by git will be "
"processed.")
argparser.add_argument("--exclude",
dest='exclude',
action="append",
required=False,
default=["python/cuml/_thirdparty/",
"cpp/include/raft/thirdparty/",
"cpp/cmake/modules/FindAVX.cmake"],
help=("Exclude the paths specified (regexp). "
"Can be specified multiple times."))

(args, dirs) = argparser.parse_known_args()
try:
ExemptFiles = ExemptFiles + [pathName for pathName in args.exclude]
ExemptFiles = [re.compile(file) for file in ExemptFiles]
except re.error as reException:
print("Regular expression error:")
print(reException)
return 1
"Checks for a consistent copyright header in git's modified files"
)
argparser.add_argument(
"--update-current-year",
dest="update_current_year",
action="store_true",
required=False,
help="If set, "
"update the current year if a header is already "
"present and well formatted.",
)
argparser.add_argument(
"--git-modified-only",
dest="git_modified_only",
action="store_true",
required=False,
help="If set, "
"only files seen as modified by git will be "
"processed.",
)

args, dirs = argparser.parse_known_args()

if args.git_modified_only:
files = gitutils.modifiedFiles(pathFilter=checkThisFile)
files = [f for f in modifiedFiles() if checkThisFile(f)]
else:
files = []
for d in [os.path.abspath(d) for d in dirs]:
if not (os.path.isdir(d)):
if not os.path.isdir(d):
raise ValueError(f"{d} is not a directory.")
files += getAllFilesUnderDir(d, pathFilter=checkThisFile)

Expand All @@ -220,24 +267,24 @@ def checkCopyright_main():
errors += checkCopyright(f, args.update_current_year)

if len(errors) > 0:
print("Copyright headers incomplete in some of the files!")
if any(e[-1] is None for e in errors):
print("Copyright headers incomplete in some of the files!")
for e in errors:
print(" %s:%d Issue: %s" % (e[0], e[1], e[2]))
print("")
n_fixable = sum(1 for e in errors if e[-1] is not None)
path_parts = os.path.abspath(__file__).split(os.sep)
file_from_repo = os.sep.join(path_parts[path_parts.index("ci"):])
if n_fixable > 0:
print(("You can run `python {} --git-modified-only "
"--update-current-year` to fix {} of these "
"errors.\n").format(file_from_repo, n_fixable))
file_from_repo = os.sep.join(path_parts[path_parts.index("ci") :])
if n_fixable > 0 and not args.update_current_year:
print(
f"You can run `python {file_from_repo} --git-modified-only "
"--update-current-year` and stage the results in git to "
f"fix {n_fixable} of these errors.\n"
)
retVal = 1
else:
print("Copyright check passed")

return retVal


if __name__ == "__main__":
import sys
sys.exit(checkCopyright_main())
1 change: 0 additions & 1 deletion ci/release/update-version.sh
Original file line number Diff line number Diff line change
Expand Up @@ -74,7 +74,6 @@ for FILE in python/*/pyproject.toml; do
for DEP in "${DEPENDENCIES[@]}"; do
sed_runner "/\"${DEP}==/ s/==.*\"/==${NEXT_SHORT_TAG_PEP440}.*\"/g" ${FILE}
done
sed_runner "s/^version = .*/version = \"${NEXT_FULL_TAG}\"/g" "${FILE}"
sed_runner "/\"ucx-py==/ s/==.*\"/==${NEXT_UCX_PY_SHORT_TAG_PEP440}.*\"/g" ${FILE}
done

Expand Down
15 changes: 15 additions & 0 deletions ci/wheel_smoke_test_pylibraft.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,18 @@
# Copyright (c) 2023, NVIDIA CORPORATION.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

import numpy as np
from scipy.spatial.distance import cdist

Expand Down
2 changes: 1 addition & 1 deletion conda/recipes/pylibraft/meta.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,6 @@ requirements:
- cython >=3.0.0
- libraft {{ version }}
- libraft-headers {{ version }}
- numpy >=1.21
- python x.x
- rmm ={{ minor_version }}
- scikit-build >=0.13.1
Expand All @@ -60,6 +59,7 @@ requirements:
{% endif %}
- libraft {{ version }}
- libraft-headers {{ version }}
- numpy >=1.21
- python x.x
- rmm ={{ minor_version }}

Expand Down
5 changes: 3 additions & 2 deletions cpp/bench/ann/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -90,7 +90,7 @@ if(RAFT_ANN_BENCH_USE_RAFT_IVF_PQ
OR RAFT_ANN_BENCH_USE_RAFT_BRUTE_FORCE
OR RAFT_ANN_BENCH_USE_RAFT_IVF_FLAT
OR RAFT_ANN_BENCH_USE_RAFT_CAGRA
OR RAFT_ANN_BENCH_USE_RAFT_CAGRA_HNSWLIB
OR RAFT_ANN_BENCH_USE_RAFT_CAGRA_HNSWLIB
)
set(RAFT_ANN_BENCH_USE_RAFT ON)
endif()
Expand Down Expand Up @@ -263,7 +263,8 @@ if(RAFT_ANN_BENCH_USE_RAFT_CAGRA_HNSWLIB)
${CMAKE_CURRENT_BINARY_DIR}/_deps/hnswlib-src/hnswlib
LINKS
raft::compiled
CXXFLAGS "${HNSW_CXX_FLAGS}"
CXXFLAGS
"${HNSW_CXX_FLAGS}"
)
endif()

Expand Down
Loading

0 comments on commit 40d75cc

Please sign in to comment.