Skip to content

Commit

Permalink
Add copyright check
Browse files Browse the repository at this point in the history
Fixes: #2
  • Loading branch information
KyleFromNVIDIA committed Jan 20, 2024
1 parent ad67587 commit f307716
Show file tree
Hide file tree
Showing 3 changed files with 575 additions and 0 deletions.
3 changes: 3 additions & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -37,11 +37,14 @@ dependencies = [

[project.optional-dependencies]
dev = [
"freezegun",
"gitpython",
"pytest",
]

[project.scripts]
verify-conda-yes = "rapids_pre_commit_hooks.shell.verify_conda_yes:main"
verify-copyright = "rapids_pre_commit_hooks.copyright:main"

[tool.setuptools]
packages = { "find" = { where = ["src"] } }
Expand Down
136 changes: 136 additions & 0 deletions src/rapids_pre_commit_hooks/copyright.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,136 @@
# Copyright (c) 2024, NVIDIA CORPORATION.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import datetime
import functools
import re

import git

from .lint import LintMain

COPYRIGHT_RE = re.compile(
r"Copyright *(?:\(c\))? *(?P<years>(?P<first_year>\d{4})(-(?P<last_year>\d{4}))?),?"
r" *NVIDIA C(?:ORPORATION|orporation)"
)


class ConflictingFilesError(RuntimeError):
pass


def match_copyright(content):
return list(COPYRIGHT_RE.finditer(content))


def strip_copyright(content, copyright_matches):
lines = []

def append_stripped(start, item):
lines.append(content[start : item.start()])
return item.end()

start = functools.reduce(append_stripped, copyright_matches, 0)
lines.append(content[start:])
return lines


def apply_copyright_check(linter, old_content):
if linter.content != old_content:
current_year = datetime.datetime.now().year
new_copyright_matches = match_copyright(linter.content)

if old_content is not None:
old_copyright_matches = match_copyright(old_content)

if old_content is not None and strip_copyright(
old_content, old_copyright_matches
) == strip_copyright(linter.content, new_copyright_matches):
for old_match, new_match in zip(
old_copyright_matches, new_copyright_matches
):
if old_match.group() != new_match.group():
if old_match.group("years") == new_match.group("years"):
warning_pos = new_match.span()
else:
warning_pos = new_match.span("years")
linter.add_warning(
warning_pos,
"copyright is not out of date and should not be updated",
).add_replacement(new_match.span(), old_match.group())
else:
for match in new_copyright_matches:
if (
int(match.group("last_year") or match.group("first_year"))
< current_year
):
linter.add_warning(
match.span("years"), "copyright is out of date"
).add_replacement(
match.span(),
f"Copyright (c) {match.group('first_year')}-{current_year}"
", NVIDIA CORPORATION",
)


def get_target_branch(repo):
# TODO
raise NotImplementedError


def get_changed_files(repo, target_branch):
changed_files = {}

diffs = target_branch.commit.diff(
other=None,
merge_base=True,
find_copies=True,
find_copies_harder=True,
find_renames=True,
)
for diff in diffs:
if diff.change_type == "A":
changed_files[diff.b_path] = None
elif diff.change_type != "D":
changed_files[diff.b_path] = diff.a_blob

changed_files.update({f: None for f in repo.untracked_files})
return changed_files


def check_copyright():
repo = git.Repo()
target_branch = get_target_branch(repo)
changed_files = get_changed_files(repo, target_branch)

def the_check(linter, args):
try:
changed_file = changed_files[linter.filename]
except KeyError:
return

old_content = changed_file.data_stream.read().decode("utf-8")
apply_copyright_check(linter, old_content)

return the_check


def main():
m = LintMain()
with m.execute() as ctx:
ctx.add_check(check_copyright())


if __name__ == "__main__":
main()
Loading

0 comments on commit f307716

Please sign in to comment.