diff --git a/.pre-commit-hooks.yaml b/.pre-commit-hooks.yaml index 9cfc074..89b5ec1 100644 --- a/.pre-commit-hooks.yaml +++ b/.pre-commit-hooks.yaml @@ -18,3 +18,16 @@ language: python types: [shell] args: [--fix] +- id: verify-copyright + name: copyright headers + description: make sure copyright headers are up to date + entry: verify-copyright + language: python + files: | + (?x) + [.](cmake|cpp|cu|cuh|h|hpp|sh|pxd|py|pyx)$| + CMakeLists[.]txt$| + CMakeLists_standalone[.]txt$| + setup[.]cfg$| + meta[.]yaml$ + args: [--fix] diff --git a/ci/build-test.sh b/ci/build-test.sh index ce157d1..a784659 100755 --- a/ci/build-test.sh +++ b/ci/build-test.sh @@ -3,13 +3,13 @@ set -ue -pip install build pytest +pip install build python -m build . for PKG in dist/*; do echo "$PKG" pip uninstall -y rapids-pre-commit-hooks - pip install "$PKG" + pip install "$PKG[test]" pytest done diff --git a/pyproject.toml b/pyproject.toml index 487bf24..7ac84b1 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -33,15 +33,18 @@ classifiers = [ requires-python = ">=3.9" dependencies = [ "bashlex", + "gitpython", ] [project.optional-dependencies] -dev = [ +test = [ + "freezegun", "pytest", ] [project.scripts] verify-conda-yes = "rapids_pre_commit_hooks.shell.verify_conda_yes:main" +verify-copyright = "rapids_pre_commit_hooks.copyright:main" [tool.setuptools] packages = { "find" = { where = ["src"] } } diff --git a/src/rapids_pre_commit_hooks/copyright.py b/src/rapids_pre_commit_hooks/copyright.py new file mode 100644 index 0000000..8849c92 --- /dev/null +++ b/src/rapids_pre_commit_hooks/copyright.py @@ -0,0 +1,213 @@ +# Copyright (c) 2024, NVIDIA CORPORATION. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Pre-commit hook that checks NVIDIA copyright notices in changed files.

A file is only examined when it differs from the merge base with the target
branch (or is new/untracked).  Changed files must carry a copyright notice
whose last year is the current year; files whose only change is the
copyright notice itself are flagged so the notice can be reverted.
"""

import datetime
import functools
import os
import re

import git

from .lint import LintMain

# Matches e.g. "Copyright (c) 2021-2024, NVIDIA CORPORATION".  Named groups:
#   years      - the whole year expression ("2021" or "2021-2024")
#   first_year - the first (or only) year
#   last_year  - the second year of a range; unset for a single year
COPYRIGHT_RE = re.compile(
    r"Copyright *(?:\(c\))? *(?P<years>(?P<first_year>\d{4})(-(?P<last_year>\d{4}))?),?"
    r" *NVIDIA C(?:ORPORATION|orporation)"
)
# Matches release branches named "branch-<major>.<minor>".
BRANCH_RE = re.compile(r"^branch-(?P<major>[0-9]+)\.(?P<minor>[0-9]+)$")

# Environment variables consulted for the target branch name, in priority
# order (the first one naming an existing local head wins).
_TARGET_BRANCH_ENV_VARS = ("GITHUB_BASE_REF", "TARGET_BRANCH", "RAPIDS_BASE_BRANCH")


class ConflictingFilesError(RuntimeError):
    """Error type reserved for conflicting file states (not raised here)."""


def match_copyright(content):
    """Return every ``COPYRIGHT_RE`` match in *content*, in document order."""
    return list(COPYRIGHT_RE.finditer(content))


def strip_copyright(content, copyright_matches):
    """Split *content* into the text segments surrounding the copyright matches.

    Args:
        content: The full text the matches were found in.
        copyright_matches: Matches as returned by :func:`match_copyright`.

    Returns:
        A list with ``len(copyright_matches) + 1`` strings: the text before
        the first match, between consecutive matches, and after the last one.
        Two revisions with equal lists differ only in their copyright notices.
    """
    lines = []

    def append_stripped(start, item):
        # Record the text between the previous match and this one, then move
        # the cursor to the end of this match.
        lines.append(content[start : item.start()])
        return item.end()

    start = functools.reduce(append_stripped, copyright_matches, 0)
    lines.append(content[start:])
    return lines


def apply_copyright_check(linter, old_content):
    """Add copyright warnings (with fixes) to *linter* for one file.

    Args:
        linter: Linter for the file; ``linter.content`` is the new text and
            warnings are registered through ``linter.add_warning``.
        old_content: Text of the file at the merge base, or ``None`` when the
            file has no previous revision.

    Behavior:
        * Content identical to *old_content*: nothing is reported.
        * Only the copyright notices changed: every altered notice is flagged
          with a replacement restoring the old notice.
        * Anything else changed (or the file is new): every notice whose last
          year is before the current year is flagged with an updated notice;
          if there is no notice at all, that is reported at position (0, 0).
    """
    if linter.content == old_content:
        return

    current_year = datetime.datetime.now().year
    new_copyright_matches = match_copyright(linter.content)

    if old_content is not None:
        old_copyright_matches = match_copyright(old_content)

    if old_content is not None and strip_copyright(
        old_content, old_copyright_matches
    ) == strip_copyright(linter.content, new_copyright_matches):
        # Equal stripped segments imply both revisions have the same number
        # of notices, so pairing them up with zip() is safe.
        for old_match, new_match in zip(old_copyright_matches, new_copyright_matches):
            if old_match.group() == new_match.group():
                continue
            if old_match.group("years") == new_match.group("years"):
                # Years are untouched; something else in the notice changed.
                warning_pos = new_match.span()
            else:
                warning_pos = new_match.span("years")
            linter.add_warning(
                warning_pos,
                "copyright is not out of date and should not be updated",
            ).add_replacement(new_match.span(), old_match.group())
        return

    if not new_copyright_matches:
        linter.add_warning((0, 0), "no copyright notice found")
        return

    for match in new_copyright_matches:
        last_year = int(match.group("last_year") or match.group("first_year"))
        if last_year < current_year:
            linter.add_warning(
                match.span("years"), "copyright is out of date"
            ).add_replacement(
                match.span(),
                f"Copyright (c) {match.group('first_year')}-{current_year}"
                ", NVIDIA CORPORATION",
            )


def _find_head(repo, name):
    """Return the local head of *repo* called *name*, or ``None``.

    ``None`` is returned both for an empty/unset *name* and when no such head
    exists (``repo.heads[...]`` signals that with ``IndexError``).
    """
    if not name:
        return None
    try:
        return repo.heads[name]
    except IndexError:
        return None


def get_target_branch(repo):
    """Determine the local branch that changes should be compared against.

    Candidates are tried in this order; a name that does not correspond to an
    existing local head is skipped:

    1. ``GITHUB_BASE_REF``, ``TARGET_BRANCH``, ``RAPIDS_BASE_BRANCH``
       environment variables.
    2. The ``rapidsai.baseBranch`` git config value.
    3. The ``branch-<major>.<minor>`` head with the highest version.

    Returns:
        The matching head, or ``None`` if no candidate applies.
    """
    # Try environment
    for env_var in _TARGET_BRANCH_ENV_VARS:
        branch = _find_head(repo, os.getenv(env_var))
        if branch is not None:
            return branch

    # Try config
    with repo.config_reader() as r:
        target_branch_name = r.get("rapidsai", "baseBranch", fallback=None)
    branch = _find_head(repo, target_branch_name)
    if branch is not None:
        return branch

    # Try newest branch-xx.yy.  Versions are compared as integers so that,
    # for example, branch-10.02 ranks above branch-9.12.
    versioned_branches = [
        (branch, (int(match.group("major")), int(match.group("minor"))))
        for branch in repo.heads
        if (match := BRANCH_RE.search(branch.name))
    ]
    if versioned_branches:
        return max(versioned_branches, key=lambda item: item[1])[0]

    # Appropriate branch not found
    return None


def get_target_branch_upstream_commit(repo):
    """Return the commit that represents the upstream of the target branch.

    Resolution order:

    1. No target branch at all: the current ``HEAD`` commit.
    2. The target branch has a configured tracking branch: its commit.
    3. Otherwise, among all remotes that have a ref with the target branch's
       name, the one whose commit has the most recent commit date.
    4. No remote has such a ref: the local target branch's own commit.
    """
    target_branch = get_target_branch(repo)
    if target_branch is None:
        return repo.head.commit

    target_branch_upstream = target_branch.tracking_branch()
    if target_branch_upstream:
        return target_branch_upstream.commit

    def try_get_ref(remote):
        try:
            return remote.refs[target_branch.name]
        except IndexError:
            return None

    candidate_upstreams = sorted(
        (upstream for remote in repo.remotes if (upstream := try_get_ref(remote))),
        key=lambda upstream: upstream.commit.committed_datetime,
        reverse=True,
    )
    if candidate_upstreams:
        return candidate_upstreams[0].commit

    return target_branch.commit


def get_changed_files(repo, target_branch_upstream_commit):
    """Map each changed path to the blob it should be compared against.

    The working tree is diffed against the merge base of
    *target_branch_upstream_commit*, with copy and rename detection enabled.

    Returns:
        A dict keyed by the file's current path.  The value is the old blob
        for modified/renamed/copied files and ``None`` for added and
        untracked files.  Deleted files are omitted.
    """
    changed_files = {}

    diffs = target_branch_upstream_commit.diff(
        other=None,
        merge_base=True,
        find_copies=True,
        find_copies_harder=True,
        find_renames=True,
    )
    for diff in diffs:
        if diff.change_type == "A":
            changed_files[diff.b_path] = None
        elif diff.change_type != "D":
            changed_files[diff.b_path] = diff.a_blob

    changed_files.update({f: None for f in repo.untracked_files})
    return changed_files


def check_copyright():
    """Build the per-file check callable used by the linter.

    The repository in the current working directory is inspected once, up
    front; the returned callable then only performs a dictionary lookup and
    the copyright comparison for each file.

    Returns:
        A function ``the_check(linter, args)`` that does nothing for files
        that are not changed and otherwise calls
        :func:`apply_copyright_check`.
    """
    repo = git.Repo()
    target_branch_upstream_commit = get_target_branch_upstream_commit(repo)
    changed_files = get_changed_files(repo, target_branch_upstream_commit)

    def the_check(linter, args):
        try:
            changed_file = changed_files[linter.filename]
        except KeyError:
            # File is unchanged relative to the target branch: nothing to do.
            return

        # Added and untracked files are recorded with a ``None`` blob; they
        # have no previous content to compare against.
        if changed_file is None:
            old_content = None
        else:
            old_content = changed_file.data_stream.read().decode("utf-8")
        apply_copyright_check(linter, old_content)

    return the_check


def main():
    """Entry point for the ``verify-copyright`` console script."""
    m = LintMain()
    with m.execute() as ctx:
        ctx.add_check(check_copyright())


if __name__ == "__main__":
    main()

# ---------------------------------------------------------------------------
# End of src/rapids_pre_commit_hooks/copyright.py.  The lines below are the
# patch header and license preamble of the next file in this diff, kept
# verbatim as comments.
# ---------------------------------------------------------------------------
# diff --git a/test/rapids_pre_commit_hooks/test_copyright.py b/test/rapids_pre_commit_hooks/test_copyright.py
# new file mode 100644
# index 0000000..f83d747
# --- /dev/null
# +++ b/test/rapids_pre_commit_hooks/test_copyright.py
# @@ -0,0 +1,700 @@
# Copyright (c) 2024, NVIDIA CORPORATION.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import datetime
import os.path
import tempfile
from unittest.mock import Mock, patch

import git
import pytest
from freezegun import freeze_time

from rapids_pre_commit_hooks import copyright
from rapids_pre_commit_hooks.lint import Linter


def test_match_copyright():
    """Every notice variant is found and its named groups have the right spans."""
    # The third notice deliberately has two spaces after the comma; the
    # asserted span (84, 119) depends on it.
    CONTENT = r"""
Copyright (c) 2024 NVIDIA CORPORATION
Copyright (c) 2021-2024 NVIDIA CORPORATION
# Copyright 2021,  NVIDIA Corporation and affiliates
"""

    re_matches = copyright.match_copyright(CONTENT)
    matches = [
        {
            "span": match.span(),
            "years": match.span("years"),
            "first_year": match.span("first_year"),
            "last_year": match.span("last_year"),
        }
        for match in re_matches
    ]
    assert matches == [
        {
            "span": (1, 38),
            "years": (15, 19),
            "first_year": (15, 19),
            "last_year": (-1, -1),
        },
        {
            "span": (39, 81),
            "years": (53, 62),
            "first_year": (53, 57),
            "last_year": (58, 62),
        },
        {
            "span": (84, 119),
            "years": (94, 98),
            "first_year": (94, 98),
            "last_year": (-1, -1),
        },
    ]


def test_strip_copyright():
    """Stripping returns the text segments around each notice."""
    CONTENT = r"""
This is a line before the first copyright statement
Copyright (c) 2024 NVIDIA CORPORATION
This is a line between the first two copyright statements
Copyright (c) 2021-2024 NVIDIA CORPORATION
This is a line between the next two copyright statements
# Copyright 2021, NVIDIA Corporation and affiliates
This is a line after the last copyright statement
"""
    matches = copyright.match_copyright(CONTENT)
    stripped = copyright.strip_copyright(CONTENT, matches)
    assert stripped == [
        "\nThis is a line before the first copyright statement\n",
        "\nThis is a line between the first two copyright statements\n",
        "\nThis is a line between the next two copyright statements\n# ",
        " and affiliates\nThis is a line after the last copyright statement\n",
    ]

    stripped = copyright.strip_copyright("No copyright here", [])
    assert stripped == ["No copyright here"]


@freeze_time("2024-01-18")
def test_apply_copyright_check():
    """Warnings and replacements for new, unchanged, changed and notice-only edits."""

    def run_apply_copyright_check(old_content, new_content):
        linter = Linter("file.txt", new_content)
        copyright.apply_copyright_check(linter, old_content)
        return linter

    # New file without any notice.
    expected_linter = Linter("file.txt", "No copyright notice")
    expected_linter.add_warning((0, 0), "no copyright notice found")

    linter = run_apply_copyright_check(None, "No copyright notice")
    assert linter.warnings == expected_linter.warnings

    # Unchanged file: never reported, even without a notice.
    linter = run_apply_copyright_check("No copyright notice", "No copyright notice")
    assert linter.warnings == []

    OLD_CONTENT = r"""
Copyright (c) 2021-2023 NVIDIA CORPORATION
Copyright (c) 2023 NVIDIA CORPORATION
Copyright (c) 2024 NVIDIA CORPORATION
Copyright (c) 2025 NVIDIA CORPORATION
This file has not been changed
"""
    linter = run_apply_copyright_check(OLD_CONTENT, OLD_CONTENT)
    assert linter.warnings == []

    # Body changed: notices older than the (frozen) current year are flagged.
    NEW_CONTENT = r"""
Copyright (c) 2021-2023 NVIDIA CORPORATION
Copyright (c) 2023 NVIDIA CORPORATION
Copyright (c) 2024 NVIDIA CORPORATION
Copyright (c) 2025 NVIDIA CORPORATION
This file has been changed
"""
    expected_linter = Linter("file.txt", NEW_CONTENT)
    expected_linter.add_warning((15, 24), "copyright is out of date").add_replacement(
        (1, 43), "Copyright (c) 2021-2024, NVIDIA CORPORATION"
    )
    expected_linter.add_warning((58, 62), "copyright is out of date").add_replacement(
        (44, 81), "Copyright (c) 2023-2024, NVIDIA CORPORATION"
    )

    linter = run_apply_copyright_check(OLD_CONTENT, NEW_CONTENT)
    assert linter.warnings == expected_linter.warnings

    # Same content treated as a brand-new file gives the same result.
    expected_linter = Linter("file.txt", NEW_CONTENT)
    expected_linter.add_warning((15, 24), "copyright is out of date").add_replacement(
        (1, 43), "Copyright (c) 2021-2024, NVIDIA CORPORATION"
    )
    expected_linter.add_warning((58, 62), "copyright is out of date").add_replacement(
        (44, 81), "Copyright (c) 2023-2024, NVIDIA CORPORATION"
    )

    linter = run_apply_copyright_check(None, NEW_CONTENT)
    assert linter.warnings == expected_linter.warnings

    # Only the notices changed: the edits must be reverted.
    NEW_CONTENT = r"""
Copyright (c) 2021-2024 NVIDIA CORPORATION
Copyright (c) 2023 NVIDIA CORPORATION
Copyright (c) 2024 NVIDIA CORPORATION
Copyright (c) 2025 NVIDIA Corporation
This file has not been changed
"""
    expected_linter = Linter("file.txt", NEW_CONTENT)
    expected_linter.add_warning(
        (15, 24), "copyright is not out of date and should not be updated"
    ).add_replacement((1, 43), "Copyright (c) 2021-2023 NVIDIA CORPORATION")
    expected_linter.add_warning(
        (120, 157), "copyright is not out of date and should not be updated"
    ).add_replacement((120, 157), "Copyright (c) 2025 NVIDIA CORPORATION")

    linter = run_apply_copyright_check(OLD_CONTENT, NEW_CONTENT)
    assert linter.warnings == expected_linter.warnings


@pytest.fixture
def git_repo():
    """Yield an empty temporary repository with a committer identity set."""
    with tempfile.TemporaryDirectory() as d:
        # The tests refer to the initial branch as "master"; pin it so they do
        # not depend on the machine's ``init.defaultBranch`` setting.
        repo = git.Repo.init(d, initial_branch="master")
        with repo.config_writer() as w:
            w.set_value("user", "name", "RAPIDS Test Fixtures")
            w.set_value("user", "email", "testfixtures@rapids.ai")
        yield repo


def test_get_target_branch(git_repo):
    """Target branch precedence: env vars, then config, then newest branch-X.Y."""
    master = git_repo.head.reference

    with open(os.path.join(git_repo.working_tree_dir, "file.txt"), "w") as f:
        f.write("File\n")
    git_repo.index.add(["file.txt"])
    git_repo.index.commit("Initial commit")
    assert copyright.get_target_branch(git_repo) is None

    branch_24_02 = git_repo.create_head("branch-24.02")
    assert copyright.get_target_branch(git_repo) == branch_24_02

    branch_24_04 = git_repo.create_head("branch-24.04")
    branch_24_03 = git_repo.create_head("branch-24.03")
    assert copyright.get_target_branch(git_repo) == branch_24_04

    branch_25_01 = git_repo.create_head("branch-25.01")
    assert copyright.get_target_branch(git_repo) == branch_25_01

    with git_repo.config_writer() as w:
        w.set_value("rapidsai", "baseBranch", "nonexistent")
    assert copyright.get_target_branch(git_repo) == branch_25_01

    with git_repo.config_writer() as w:
        w.set_value("rapidsai", "baseBranch", "branch-24.03")
    assert copyright.get_target_branch(git_repo) == branch_24_03

    with patch.dict("os.environ", {"RAPIDS_BASE_BRANCH": "nonexistent"}):
        assert copyright.get_target_branch(git_repo) == branch_24_03

    with patch.dict("os.environ", {"RAPIDS_BASE_BRANCH": "master"}):
        assert copyright.get_target_branch(git_repo) == master

    with patch.dict(
        "os.environ", {"RAPIDS_BASE_BRANCH": "master", "TARGET_BRANCH": "nonexistent"}
    ):
        assert copyright.get_target_branch(git_repo) == master

    with patch.dict(
        "os.environ", {"RAPIDS_BASE_BRANCH": "master", "TARGET_BRANCH": "branch-24.02"}
    ):
        assert copyright.get_target_branch(git_repo) == branch_24_02

    with patch.dict(
        "os.environ",
        {
            "RAPIDS_BASE_BRANCH": "master",
            "TARGET_BRANCH": "branch-24.02",
            "GITHUB_BASE_REF": "nonexistent",
        },
    ):
        assert copyright.get_target_branch(git_repo) == branch_24_02

    with patch.dict(
        "os.environ",
        {
            "RAPIDS_BASE_BRANCH": "master",
            "TARGET_BRANCH": "branch-24.02",
            "GITHUB_BASE_REF": "branch-24.04",
        },
    ):
        assert copyright.get_target_branch(git_repo) == branch_24_04


def test_get_target_branch_upstream_commit(git_repo):
    """Upstream resolution: tracking branch, newest matching remote ref, fallbacks."""

    def fn(repo, filename):
        return os.path.join(repo.working_tree_dir, filename)

    def write_file(repo, filename, contents):
        with open(fn(repo, filename), "w") as f:
            f.write(contents)

    # fmt: off
    with tempfile.TemporaryDirectory() as remote_dir_1, \
         tempfile.TemporaryDirectory() as remote_dir_2:
        # fmt: on
        # Pin the initial branch name: "master" is fetched by name below.
        remote_repo_1 = git.Repo.init(remote_dir_1, initial_branch="master")
        remote_repo_2 = git.Repo.init(remote_dir_2, initial_branch="master")

        remote_1_master = remote_repo_1.head.reference

        write_file(remote_repo_1, "file1.txt", "File 1")
        write_file(remote_repo_1, "file2.txt", "File 2")
        write_file(remote_repo_1, "file3.txt", "File 3")
        write_file(remote_repo_1, "file4.txt", "File 4")
        write_file(remote_repo_1, "file5.txt", "File 5")
        write_file(remote_repo_1, "file6.txt", "File 6")
        remote_repo_1.index.add(
            [
                "file1.txt",
                "file2.txt",
                "file3.txt",
                "file4.txt",
                "file5.txt",
                "file6.txt",
            ]
        )
        remote_repo_1.index.commit("Initial commit")

        remote_1_branch_1 = remote_repo_1.create_head(
            "branch-1", remote_1_master.commit
        )
        remote_repo_1.head.reference = remote_1_branch_1
        remote_repo_1.head.reset(index=True, working_tree=True)
        write_file(remote_repo_1, "file1.txt", "File 1 modified")
        remote_repo_1.index.add(["file1.txt"])
        remote_repo_1.index.commit("Update file1.txt")

        remote_1_branch_2 = remote_repo_1.create_head(
            "branch-2", remote_1_master.commit
        )
        remote_repo_1.head.reference = remote_1_branch_2
        remote_repo_1.head.reset(index=True, working_tree=True)
        write_file(remote_repo_1, "file2.txt", "File 2 modified")
        remote_repo_1.index.add(["file2.txt"])
        remote_repo_1.index.commit("Update file2.txt")

        # branch-3: remote 1 carries the newer commit date (2025 vs 2024).
        remote_1_branch_3 = remote_repo_1.create_head(
            "branch-3", remote_1_master.commit
        )
        remote_repo_1.head.reference = remote_1_branch_3
        remote_repo_1.head.reset(index=True, working_tree=True)
        write_file(remote_repo_1, "file3.txt", "File 3 modified")
        remote_repo_1.index.add(["file3.txt"])
        remote_repo_1.index.commit(
            "Update file3.txt",
            commit_date=datetime.datetime(2025, 1, 1, tzinfo=datetime.timezone.utc),
        )

        # branch-4: remote 2 carries the newer commit date (2025 vs 2024).
        remote_1_branch_4 = remote_repo_1.create_head(
            "branch-4", remote_1_master.commit
        )
        remote_repo_1.head.reference = remote_1_branch_4
        remote_repo_1.head.reset(index=True, working_tree=True)
        write_file(remote_repo_1, "file4.txt", "File 4 modified")
        remote_repo_1.index.add(["file4.txt"])
        remote_repo_1.index.commit(
            "Update file4.txt",
            commit_date=datetime.datetime(2024, 1, 1, tzinfo=datetime.timezone.utc),
        )

        remote_2_1 = remote_repo_2.create_remote("remote-1", remote_dir_1)
        remote_2_1.fetch(["master"])
        remote_2_master = remote_repo_2.create_head("master", remote_2_1.refs["master"])

        remote_2_branch_3 = remote_repo_2.create_head(
            "branch-3", remote_2_master.commit
        )
        remote_repo_2.head.reference = remote_2_branch_3
        remote_repo_2.head.reset(index=True, working_tree=True)
        write_file(remote_repo_2, "file3.txt", "File 3 modified")
        remote_repo_2.index.add(["file3.txt"])
        remote_repo_2.index.commit(
            "Update file3.txt",
            commit_date=datetime.datetime(2024, 1, 1, tzinfo=datetime.timezone.utc),
        )

        remote_2_branch_4 = remote_repo_2.create_head(
            "branch-4", remote_2_master.commit
        )
        remote_repo_2.head.reference = remote_2_branch_4
        remote_repo_2.head.reset(index=True, working_tree=True)
        write_file(remote_repo_2, "file4.txt", "File 4 modified")
        remote_repo_2.index.add(["file4.txt"])
        remote_repo_2.index.commit(
            "Update file4.txt",
            commit_date=datetime.datetime(2025, 1, 1, tzinfo=datetime.timezone.utc),
        )

        remote_2_branch_5 = remote_repo_2.create_head(
            "branch-5", remote_2_master.commit
        )
        remote_repo_2.head.reference = remote_2_branch_5
        remote_repo_2.head.reset(index=True, working_tree=True)
        write_file(remote_repo_2, "file5.txt", "File 5 modified")
        remote_repo_2.index.add(["file5.txt"])
        remote_repo_2.index.commit("Update file5.txt")

        remote_1 = git_repo.create_remote("unconventional/remote/name/1", remote_dir_1)
        remote_1.fetch(["master", "branch-1", "branch-2", "branch-3", "branch-4"])
        remote_2 = git_repo.create_remote("unconventional/remote/name/2", remote_dir_2)
        remote_2.fetch(["branch-3", "branch-4", "branch-5"])

        main = git_repo.create_head("main", remote_1.refs["master"])

        # branch-1 is renamed locally but explicitly tracks remote 1's branch-1.
        branch_1 = git_repo.create_head("branch-1-renamed", remote_1.refs["master"])
        with branch_1.config_writer() as w:
            w.set_value("remote", "unconventional/remote/name/1")
            w.set_value("merge", "branch-1")
        git_repo.head.reference = branch_1
        git_repo.head.reset(index=True, working_tree=True)
        git_repo.index.remove("file1.txt", working_tree=True)
        git_repo.index.commit("Remove file1.txt")

        branch_2 = git_repo.create_head("branch-2", remote_1.refs["master"])
        git_repo.head.reference = branch_2
        git_repo.head.reset(index=True, working_tree=True)
        git_repo.index.remove("file2.txt", working_tree=True)
        git_repo.index.commit("Remove file2.txt")

        branch_3 = git_repo.create_head("branch-3", remote_1.refs["master"])
        git_repo.head.reference = branch_3
        git_repo.head.reset(index=True, working_tree=True)
        git_repo.index.remove("file3.txt", working_tree=True)
        git_repo.index.commit("Remove file3.txt")

        branch_4 = git_repo.create_head("branch-4", remote_1.refs["master"])
        git_repo.head.reference = branch_4
        git_repo.head.reset(index=True, working_tree=True)
        git_repo.index.remove(["file4.txt"], working_tree=True)
        git_repo.index.commit("Remove file4.txt")

        branch_5 = git_repo.create_head("branch-5", remote_1.refs["master"])
        git_repo.head.reference = branch_5
        git_repo.head.reset(index=True, working_tree=True)
        git_repo.index.remove(["file5.txt"], working_tree=True)
        git_repo.index.commit("Remove file5.txt")

        # branch-6 exists only locally, so it is its own "upstream".
        branch_6 = git_repo.create_head("branch-6", remote_1.refs["master"])
        git_repo.head.reference = branch_6
        git_repo.head.reset(index=True, working_tree=True)
        git_repo.index.remove(["file6.txt"], working_tree=True)
        git_repo.index.commit("Remove file6.txt")

        git_repo.head.reference = main
        git_repo.head.reset(index=True, working_tree=True)

        def mock_target_branch(branch):
            return patch(
                "rapids_pre_commit_hooks.copyright.get_target_branch",
                Mock(return_value=branch),
            )

        with mock_target_branch(branch_1):
            assert (
                copyright.get_target_branch_upstream_commit(git_repo)
                == remote_1.refs["branch-1"].commit
            )

        with mock_target_branch(branch_2):
            assert (
                copyright.get_target_branch_upstream_commit(git_repo)
                == remote_1.refs["branch-2"].commit
            )

        with mock_target_branch(branch_3):
            assert (
                copyright.get_target_branch_upstream_commit(git_repo)
                == remote_1.refs["branch-3"].commit
            )

        with mock_target_branch(branch_4):
            assert (
                copyright.get_target_branch_upstream_commit(git_repo)
                == remote_2.refs["branch-4"].commit
            )

        with mock_target_branch(branch_5):
            assert (
                copyright.get_target_branch_upstream_commit(git_repo)
                == remote_2.refs["branch-5"].commit
            )

        with mock_target_branch(branch_6):
            assert (
                copyright.get_target_branch_upstream_commit(git_repo) == branch_6.commit
            )

        with mock_target_branch(None):
            assert copyright.get_target_branch_upstream_commit(git_repo) == main.commit


def test_get_changed_files(git_repo):
    """Changed-file detection covers adds, copies, renames, chmod and untracked."""

    def fn(filename):
        return os.path.join(git_repo.working_tree_dir, filename)

    def write_file(filename, contents):
        with open(fn(filename), "w") as f:
            f.write(contents)

    def file_contents(verbed):
        # Long bodies keep git's similarity detection stable after small edits.
        return f"This file will be {verbed}\n" * 100

    write_file("untouched.txt", file_contents("untouched"))
    write_file("copied.txt", file_contents("copied"))
    write_file("modified_and_copied.txt", file_contents("modified and copied"))
    write_file("copied_and_modified.txt", file_contents("copied and modified"))
    write_file("deleted.txt", file_contents("deleted"))
    write_file("renamed.txt", file_contents("renamed"))
    write_file("modified_and_renamed.txt", file_contents("modified and renamed"))
    write_file("modified.txt", file_contents("modified"))
    write_file("chmodded.txt", file_contents("chmodded"))
    git_repo.index.add(
        [
            "untouched.txt",
            "copied.txt",
            "modified_and_copied.txt",
            "copied_and_modified.txt",
            "deleted.txt",
            "renamed.txt",
            "modified_and_renamed.txt",
            "modified.txt",
            "chmodded.txt",
        ]
    )
    git_repo.index.commit("Initial commit")

    # Ensure that diff is done against merge base, not branch tip
    git_repo.index.remove(["modified.txt"], working_tree=True)
    git_repo.index.commit("Remove modified.txt")

    pr_branch = git_repo.create_head("pr", "HEAD~")
    git_repo.head.reference = pr_branch
    git_repo.head.reset(index=True, working_tree=True)

    write_file("copied_2.txt", file_contents("copied"))
    git_repo.index.remove(
        ["deleted.txt", "modified_and_renamed.txt"], working_tree=True
    )
    git_repo.index.move(["renamed.txt", "renamed_2.txt"])
    write_file(
        "modified.txt", file_contents("modified") + "This file has been modified\n"
    )
    os.chmod(fn("chmodded.txt"), 0o755)
    write_file("untouched.txt", file_contents("untouched") + "Oops\n")
    write_file("added.txt", file_contents("added"))
    write_file("added_and_deleted.txt", file_contents("added and deleted"))
    write_file(
        "modified_and_copied.txt",
        file_contents("modified and copied") + "This file has been modified\n",
    )
    write_file("modified_and_copied_2.txt", file_contents("modified and copied"))
    write_file(
        "copied_and_modified_2.txt",
        file_contents("copied and modified") + "This file has been modified\n",
    )
    write_file(
        "modified_and_renamed_2.txt",
        file_contents("modified and renamed") + "This file has been modified\n",
    )
    git_repo.index.add(
        [
            "untouched.txt",
            "added.txt",
            "added_and_deleted.txt",
            "modified_and_copied.txt",
            "modified_and_copied_2.txt",
            "copied_and_modified_2.txt",
            "copied_2.txt",
            "modified_and_renamed_2.txt",
            "modified.txt",
            "chmodded.txt",
        ]
    )
    # Working-tree-only changes made after staging.
    write_file("untracked.txt", file_contents("untracked"))
    write_file("untouched.txt", file_contents("untouched"))
    os.unlink(fn("added_and_deleted.txt"))

    target_branch = git_repo.heads["master"]
    merge_base = git_repo.merge_base(target_branch, "HEAD")[0]
    old_files = {
        blob.path: blob
        for blob in merge_base.tree.traverse(lambda b, _: isinstance(b, git.Blob))
    }

    # Truly need to be checked
    changed = {
        "added.txt": None,
        "untracked.txt": None,
        "modified_and_renamed_2.txt": "modified_and_renamed.txt",
        "modified.txt": "modified.txt",
        "copied_and_modified_2.txt": "copied_and_modified.txt",
        "modified_and_copied.txt": "modified_and_copied.txt",
    }

    # Superfluous, but harmless because the content is identical
    superfluous = {
        "chmodded.txt": "chmodded.txt",
        "modified_and_copied_2.txt": "modified_and_copied.txt",
        "copied_2.txt": "copied.txt",
        "renamed_2.txt": "renamed.txt",
    }

    changed_files = copyright.get_changed_files(git_repo, target_branch.commit)
    assert {
        path: old_blob.path if old_blob else None
        for path, old_blob in changed_files.items()
    } == changed | superfluous

    for new, old in changed.items():
        if old:
            with open(fn(new), "rb") as f:
                new_contents = f.read()
            old_contents = old_files[old].data_stream.read()
            assert new_contents != old_contents
            assert changed_files[new].data_stream.read() == old_contents

    for new, old in superfluous.items():
        if old:
            with open(fn(new), "rb") as f:
                new_contents = f.read()
            old_contents = old_files[old].data_stream.read()
            assert new_contents == old_contents
            assert changed_files[new].data_stream.read() == old_contents


@freeze_time("2024-01-18")
def test_check_copyright(git_repo):
    """The built checker calls apply_copyright_check only for changed files."""

    def fn(filename):
        return os.path.join(git_repo.working_tree_dir, filename)

    def write_file(filename, contents):
        with open(fn(filename), "w") as f:
            f.write(contents)

    def file_contents(num):
        return rf"""
Copyright (c) 2021-2023 NVIDIA CORPORATION
File {num}
"""

    def file_contents_modified(num):
        return rf"""
Copyright (c) 2021-2023 NVIDIA CORPORATION
File {num} modified
"""

    write_file("file1.txt", file_contents(1))
    write_file("file2.txt", file_contents(2))
    write_file("file3.txt", file_contents(3))
    write_file("file4.txt", file_contents(4))
    git_repo.index.add(["file1.txt", "file2.txt", "file3.txt", "file4.txt"])
    git_repo.index.commit("Initial commit")

    branch_1 = git_repo.create_head("branch-1", "master")
    git_repo.head.reference = branch_1
    git_repo.head.reset(index=True, working_tree=True)
    write_file("file1.txt", file_contents_modified(1))
    git_repo.index.add(["file1.txt"])
    git_repo.index.commit("Update file1.txt")

    branch_2 = git_repo.create_head("branch-2", "master")
    git_repo.head.reference = branch_2
    git_repo.head.reset(index=True, working_tree=True)
    write_file("file2.txt", file_contents_modified(2))
    git_repo.index.add(["file2.txt"])
    git_repo.index.commit("Update file2.txt")

    pr = git_repo.create_head("pr", "branch-1")
    git_repo.head.reference = pr
    git_repo.head.reset(index=True, working_tree=True)
    write_file("file3.txt", file_contents_modified(3))
    git_repo.index.add(["file3.txt"])
    git_repo.index.commit("Update file3.txt")
    write_file("file4.txt", file_contents_modified(4))
    git_repo.index.add(["file4.txt"])
    git_repo.index.commit("Update file4.txt")
    git_repo.index.move(["file2.txt", "file5.txt"])
    git_repo.index.commit("Rename file2.txt to file5.txt")

    def mock_repo_cwd():
        # check_copyright() opens the repository in the current directory.
        return patch("os.getcwd", Mock(return_value=git_repo.working_tree_dir))

    def mock_target_branch_upstream_commit(branch_name):
        def func(repo):
            return repo.heads[branch_name].commit

        return patch(
            "rapids_pre_commit_hooks.copyright.get_target_branch_upstream_commit", func
        )

    def mock_apply_copyright_check():
        return patch("rapids_pre_commit_hooks.copyright.apply_copyright_check", Mock())

    #############################
    # branch-1 is target branch
    #############################

    with mock_repo_cwd(), mock_target_branch_upstream_commit("branch-1"):
        copyright_checker = copyright.check_copyright()

    linter = Linter("file1.txt", file_contents_modified(1))
    with mock_apply_copyright_check() as apply_copyright_check:
        copyright_checker(linter, None)
        apply_copyright_check.assert_not_called()

    linter = Linter("file5.txt", file_contents(2))
    with mock_apply_copyright_check() as apply_copyright_check:
        copyright_checker(linter, None)
        apply_copyright_check.assert_called_once_with(linter, file_contents(2))

    linter = Linter("file3.txt", file_contents_modified(3))
    with mock_apply_copyright_check() as apply_copyright_check:
        copyright_checker(linter, None)
        apply_copyright_check.assert_called_once_with(linter, file_contents(3))

    linter = Linter("file4.txt", file_contents_modified(4))
    with mock_apply_copyright_check() as apply_copyright_check:
        copyright_checker(linter, None)
        apply_copyright_check.assert_called_once_with(linter, file_contents(4))

    #############################
    # branch-2 is target branch
    #############################

    with mock_repo_cwd(), mock_target_branch_upstream_commit("branch-2"):
        copyright_checker = copyright.check_copyright()

    linter = Linter("file1.txt", file_contents_modified(1))
    with mock_apply_copyright_check() as apply_copyright_check:
        copyright_checker(linter, None)
        apply_copyright_check.assert_called_once_with(linter, file_contents(1))

    linter = Linter("file5.txt", file_contents(2))
    with mock_apply_copyright_check() as apply_copyright_check:
        copyright_checker(linter, None)
        apply_copyright_check.assert_called_once_with(linter, file_contents(2))

    linter = Linter("file3.txt", file_contents_modified(3))
    with mock_apply_copyright_check() as apply_copyright_check:
        copyright_checker(linter, None)
        apply_copyright_check.assert_called_once_with(linter, file_contents(3))

    linter = Linter("file4.txt", file_contents_modified(4))
    with mock_apply_copyright_check() as apply_copyright_check:
        copyright_checker(linter, None)
        apply_copyright_check.assert_called_once_with(linter, file_contents(4))