Skip to content

Commit

Permalink
Fix ci/checks/copyright.py to mirror RAPIDS reference (#2008)
Browse files Browse the repository at this point in the history
At some point, the `ci/checks/copyright.py` implementation diverged from the one used in other RAPIDS repos. This PR uses https://github.com/rapidsai/cudf/blob/branch-24.02/ci/checks/copyright.py as a reference to update the script. The new implementation uses git history to figure out the year in which a file was last modified and then adds that to the copyright year.

The PR also:
1. Excludes third-party files/licences
2. Adds missing copyright headers

Authors:
  - Divye Gala (https://github.com/divyegala)

Approvers:
  - Corey J. Nolet (https://github.com/cjnolet)
  - Bradley Dice (https://github.com/bdice)
  - Jake Awe (https://github.com/AyodeAwe)

URL: #2008
  • Loading branch information
divyegala authored Nov 17, 2023
1 parent b5b5202 commit 7e307b9
Show file tree
Hide file tree
Showing 2 changed files with 139 additions and 77 deletions.
201 changes: 124 additions & 77 deletions ci/checks/copyright.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,8 @@
import os
import sys

import git

SCRIPT_DIR = os.path.dirname(os.path.realpath(os.path.expanduser(__file__)))

# Add the scripts dir for gitutils
Expand All @@ -37,7 +39,12 @@
re.compile(r"setup[.]cfg$"),
re.compile(r"meta[.]yaml$")
]
# Compiled path patterns that are exempt from the copyright check. They are
# applied with ``search`` (see checkThisFile), so each one may match anywhere
# inside a relative or absolute path. Literal dots are written as ``[.]`` so
# that they do not act as "any character" wildcards.
ExemptFiles = [
    re.compile(r"cpp/include/raft/neighbors/detail/faiss_select/"),
    re.compile(r"cpp/include/raft/thirdparty/"),
    re.compile(r"docs/source/sphinxext/github_link[.]py"),
    re.compile(r"cpp/cmake/modules/FindAVX[.]cmake"),
]

# this will break starting at year 10000, which is probably OK :)
CheckSimple = re.compile(
Expand All @@ -48,10 +55,12 @@


def checkThisFile(f):
# This check covers things like symlinks which point to files that DNE
if not (os.path.exists(f)):
return False
if gitutils and gitutils.isFileEmpty(f):
if isinstance(f, git.Diff):
if f.deleted_file or f.b_blob.size == 0:
return False
f = f.b_path
elif not os.path.exists(f) or os.stat(f).st_size == 0:
# This check covers things like symlinks which point to files that DNE
return False
for exempt in ExemptFiles:
if exempt.search(f):
Expand All @@ -62,36 +71,90 @@ def checkThisFile(f):
return False


def modifiedFiles():
    """Get a set of all modified files, as Diff objects.

    The files returned have been modified in git since the merge base of HEAD
    and the upstream of the target branch. We return the Diff objects so that
    we can read only the staged changes.

    The target branch is taken from the TARGET_BRANCH environment variable,
    then RAPIDS_BASE_BRANCH, and finally from ``git describe`` restricted to
    refs matching ``branch-*``.

    Returns:
        A set of the entries of the diff whose ``b_path`` is not ``None``.
    """
    repo = git.Repo()
    # Use the environment variable TARGET_BRANCH or RAPIDS_BASE_BRANCH
    # (defined in CI) if possible
    target_branch = os.environ.get(
        "TARGET_BRANCH", os.environ.get("RAPIDS_BASE_BRANCH")
    )
    if target_branch is None:
        # Fall back to the closest branch if not on CI
        target_branch = repo.git.describe(
            all=True, tags=True, match="branch-*", abbrev=0
        )
        # Remove the literal "heads/" prefix. str.lstrip("heads/") would
        # instead strip any leading run of the characters h, e, a, d, s and
        # "/", which can also eat the beginning of the branch name itself.
        heads_prefix = "heads/"
        if target_branch.startswith(heads_prefix):
            target_branch = target_branch[len(heads_prefix):]

    upstream_target_branch = None
    if target_branch in repo.heads:
        # Use the tracking branch of the local reference if it exists. This
        # returns None if no tracking branch is set.
        upstream_target_branch = repo.heads[target_branch].tracking_branch()
    if upstream_target_branch is None:
        # Fall back to the remote with the newest target_branch. This code
        # path is used on CI because the only local branch reference is
        # current-pr-branch, and thus target_branch is not in repo.heads.
        # This also happens if no tracking branch is defined for the local
        # target_branch. We use the remote with the latest commit if
        # multiple remotes are defined.
        candidate_branches = [
            remote.refs[target_branch]
            for remote in repo.remotes
            if target_branch in remote.refs
        ]
        if len(candidate_branches) > 0:
            upstream_target_branch = sorted(
                candidate_branches,
                key=lambda branch: branch.commit.committed_datetime,
            )[-1]
        else:
            # If no remotes are defined, try to use the local version of the
            # target_branch. If this fails, the repo configuration must be
            # very strange and we can fix this script on a case-by-case basis.
            upstream_target_branch = repo.heads[target_branch]
    merge_base = repo.merge_base("HEAD", upstream_target_branch.commit)[0]
    diff = merge_base.diff()
    # Entries without a b_path have no "after" side to inspect.
    changed_files = {f for f in diff if f.b_path is not None}
    return changed_files


def getCopyrightYears(line):
    """Extract the copyright year range from a single line of text.

    The line is tried against ``CheckSimple`` first and ``CheckDouble``
    second.

    Args:
        line: One line of a file's contents.

    Returns:
        A ``(start, end)`` tuple of ints. When ``CheckSimple`` matches, both
        values come from its first group and are therefore equal. When
        ``CheckDouble`` matches, they come from its first and second groups.
        ``(None, None)`` is returned when neither pattern matches.
    """
    res = CheckSimple.search(line)
    if res:
        year = int(res.group(1))
        return year, year
    res = CheckDouble.search(line)
    if res:
        return int(res.group(1)), int(res.group(2))
    return None, None


def replaceCurrentYear(line, start, end):
    """Rewrite the copyright years found in ``line`` to ``start``-``end``.

    Args:
        line: The text containing the copyright header.
        start: First year to write, as an int.
        end: Last year to write, as an int.

    Returns:
        ``line`` with the header replaced by
        ``Copyright (c) <start>-<end>, NVIDIA CORPORATION``, each year
        zero-padded to four digits.
    """
    # first turn a simple regex into double (if applicable). then update years
    res = CheckSimple.sub(r"Copyright (c) \1-\1, NVIDIA CORPORATION", line)
    res = CheckDouble.sub(
        rf"Copyright (c) {start:04d}-{end:04d}, NVIDIA CORPORATION",
        res,
    )
    return res


def checkCopyright(f, update_current_year):
"""
Checks for copyright headers and their years
"""
"""Checks for copyright headers and their years."""
errs = []
thisYear = datetime.datetime.now().year
lineNum = 0
crFound = False
yearMatched = False
with io.open(f, "r", encoding="utf-8") as fp:
lines = fp.readlines()

if isinstance(f, git.Diff):
path = f.b_path
lines = f.b_blob.data_stream.read().decode().splitlines(keepends=True)
else:
path = f
with open(f, encoding="utf-8") as fp:
lines = fp.readlines()

for line in lines:
lineNum += 1
start, end = getCopyrightYears(line)
Expand All @@ -100,20 +163,19 @@ def checkCopyright(f, update_current_year):
crFound = True
if start > end:
e = [
f,
path,
lineNum,
"First year after second year in the copyright "
"header (manual fix required)",
None
None,
]
errs.append(e)
if thisYear < start or thisYear > end:
elif thisYear < start or thisYear > end:
e = [
f,
path,
lineNum,
"Current year not included in the "
"copyright header",
None
"Current year not included in the copyright header",
None,
]
if thisYear < start:
e[-1] = replaceCurrentYear(line, thisYear, end)
Expand All @@ -122,15 +184,14 @@ def checkCopyright(f, update_current_year):
errs.append(e)
else:
yearMatched = True
fp.close()
# copyright header itself not found
if not crFound:
e = [
f,
path,
0,
"Copyright header missing or formatted incorrectly "
"(manual fix required)",
None
None,
]
errs.append(e)
# even if the year matches a copyright header, make the check pass
Expand All @@ -140,21 +201,19 @@ def checkCopyright(f, update_current_year):
if update_current_year:
errs_update = [x for x in errs if x[-1] is not None]
if len(errs_update) > 0:
print("File: {}. Changing line(s) {}".format(
f, ', '.join(str(x[1]) for x in errs if x[-1] is not None)))
lines_changed = ", ".join(str(x[1]) for x in errs_update)
print(f"File: {path}. Changing line(s) {lines_changed}")
for _, lineNum, __, replacement in errs_update:
lines[lineNum - 1] = replacement
with io.open(f, "w", encoding="utf-8") as out_file:
for new_line in lines:
out_file.write(new_line)
errs = [x for x in errs if x[-1] is None]
with open(path, "w", encoding="utf-8") as out_file:
out_file.writelines(lines)

return errs


def getAllFilesUnderDir(root, pathFilter=None):
retList = []
for (dirpath, dirnames, filenames) in os.walk(root):
for dirpath, dirnames, filenames in os.walk(root):
for fn in filenames:
filePath = os.path.join(dirpath, fn)
if pathFilter(filePath):
Expand All @@ -169,49 +228,37 @@ def checkCopyright_main():
it compares between branches "$PR_TARGET_BRANCH" and "current-pr-branch"
"""
retVal = 0
global ExemptFiles

argparser = argparse.ArgumentParser(
"Checks for a consistent copyright header in git's modified files")
argparser.add_argument("--update-current-year",
dest='update_current_year',
action="store_true",
required=False,
help="If set, "
"update the current year if a header "
"is already present and well formatted.")
argparser.add_argument("--git-modified-only",
dest='git_modified_only',
action="store_true",
required=False,
help="If set, "
"only files seen as modified by git will be "
"processed.")
argparser.add_argument("--exclude",
dest='exclude',
action="append",
required=False,
default=["python/cuml/_thirdparty/",
"cpp/include/raft/thirdparty/",
"cpp/cmake/modules/FindAVX.cmake"],
help=("Exclude the paths specified (regexp). "
"Can be specified multiple times."))

(args, dirs) = argparser.parse_known_args()
try:
ExemptFiles = ExemptFiles + [pathName for pathName in args.exclude]
ExemptFiles = [re.compile(file) for file in ExemptFiles]
except re.error as reException:
print("Regular expression error:")
print(reException)
return 1
"Checks for a consistent copyright header in git's modified files"
)
argparser.add_argument(
"--update-current-year",
dest="update_current_year",
action="store_true",
required=False,
help="If set, "
"update the current year if a header is already "
"present and well formatted.",
)
argparser.add_argument(
"--git-modified-only",
dest="git_modified_only",
action="store_true",
required=False,
help="If set, "
"only files seen as modified by git will be "
"processed.",
)

args, dirs = argparser.parse_known_args()

if args.git_modified_only:
files = gitutils.modifiedFiles(pathFilter=checkThisFile)
files = [f for f in modifiedFiles() if checkThisFile(f)]
else:
files = []
for d in [os.path.abspath(d) for d in dirs]:
if not (os.path.isdir(d)):
if not os.path.isdir(d):
raise ValueError(f"{d} is not a directory.")
files += getAllFilesUnderDir(d, pathFilter=checkThisFile)

Expand All @@ -220,24 +267,24 @@ def checkCopyright_main():
errors += checkCopyright(f, args.update_current_year)

if len(errors) > 0:
print("Copyright headers incomplete in some of the files!")
if any(e[-1] is None for e in errors):
print("Copyright headers incomplete in some of the files!")
for e in errors:
print(" %s:%d Issue: %s" % (e[0], e[1], e[2]))
print("")
n_fixable = sum(1 for e in errors if e[-1] is not None)
path_parts = os.path.abspath(__file__).split(os.sep)
file_from_repo = os.sep.join(path_parts[path_parts.index("ci"):])
if n_fixable > 0:
print(("You can run `python {} --git-modified-only "
"--update-current-year` to fix {} of these "
"errors.\n").format(file_from_repo, n_fixable))
file_from_repo = os.sep.join(path_parts[path_parts.index("ci") :])
if n_fixable > 0 and not args.update_current_year:
print(
f"You can run `python {file_from_repo} --git-modified-only "
"--update-current-year` and stage the results in git to "
f"fix {n_fixable} of these errors.\n"
)
retVal = 1
else:
print("Copyright check passed")

return retVal


if __name__ == "__main__":
    # Use the checker's return value as the process exit status. ``sys`` is
    # already imported at the top of the file, so no local import is needed.
    sys.exit(checkCopyright_main())
15 changes: 15 additions & 0 deletions ci/wheel_smoke_test_pylibraft.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,18 @@
# Copyright (c) 2023, NVIDIA CORPORATION.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

import numpy as np
from scipy.spatial.distance import cdist

Expand Down

0 comments on commit 7e307b9

Please sign in to comment.