From a9fb71278c31077f8ce6410b7dbcc36c7e390697 Mon Sep 17 00:00:00 2001 From: Kyle Edwards Date: Thu, 18 Jan 2024 14:19:40 -0500 Subject: [PATCH 01/24] Add copyright check Fixes: https://github.com/rapidsai/pre-commit-hooks/issues/2 --- .pre-commit-hooks.yaml | 13 + ci/build-test.sh | 4 +- pyproject.toml | 5 +- src/rapids_pre_commit_hooks/copyright.py | 213 ++++++ .../rapids_pre_commit_hooks/test_copyright.py | 700 ++++++++++++++++++ 5 files changed, 932 insertions(+), 3 deletions(-) create mode 100644 src/rapids_pre_commit_hooks/copyright.py create mode 100644 test/rapids_pre_commit_hooks/test_copyright.py diff --git a/.pre-commit-hooks.yaml b/.pre-commit-hooks.yaml index 9cfc074..89b5ec1 100644 --- a/.pre-commit-hooks.yaml +++ b/.pre-commit-hooks.yaml @@ -18,3 +18,16 @@ language: python types: [shell] args: [--fix] +- id: verify-copyright + name: copyright headers + description: make sure copyright headers are up to date + entry: verify-copyright + language: python + files: | + (?x) + [.](cmake|cpp|cu|cuh|h|hpp|sh|pxd|py|pyx)$| + CMakeLists[.]txt$| + CMakeLists_standalone[.]txt$| + setup[.]cfg$| + meta[.]yaml$ + args: [--fix] diff --git a/ci/build-test.sh b/ci/build-test.sh index ce157d1..a784659 100755 --- a/ci/build-test.sh +++ b/ci/build-test.sh @@ -3,13 +3,13 @@ set -ue -pip install build pytest +pip install build python -m build . 
for PKG in dist/*; do echo "$PKG" pip uninstall -y rapids-pre-commit-hooks - pip install "$PKG" + pip install "$PKG[test]" pytest done diff --git a/pyproject.toml b/pyproject.toml index 487bf24..7ac84b1 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -33,15 +33,18 @@ classifiers = [ requires-python = ">=3.9" dependencies = [ "bashlex", + "gitpython", ] [project.optional-dependencies] -dev = [ +test = [ + "freezegun", "pytest", ] [project.scripts] verify-conda-yes = "rapids_pre_commit_hooks.shell.verify_conda_yes:main" +verify-copyright = "rapids_pre_commit_hooks.copyright:main" [tool.setuptools] packages = { "find" = { where = ["src"] } } diff --git a/src/rapids_pre_commit_hooks/copyright.py b/src/rapids_pre_commit_hooks/copyright.py new file mode 100644 index 0000000..8849c92 --- /dev/null +++ b/src/rapids_pre_commit_hooks/copyright.py @@ -0,0 +1,213 @@ +# Copyright (c) 2024, NVIDIA CORPORATION. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import datetime +import functools +import os +import re + +import git + +from .lint import LintMain + +COPYRIGHT_RE = re.compile( + r"Copyright *(?:\(c\))? *(?P<years>(?P<first_year>\d{4})(-(?P<last_year>\d{4}))?),?"
+ r" *NVIDIA C(?:ORPORATION|orporation)" +) +BRANCH_RE = re.compile(r"^branch-(?P<major>[0-9]+)\.(?P<minor>[0-9]+)$") + + +class ConflictingFilesError(RuntimeError): + pass + + +def match_copyright(content): + return list(COPYRIGHT_RE.finditer(content)) + + +def strip_copyright(content, copyright_matches): + lines = [] + + def append_stripped(start, item): + lines.append(content[start : item.start()]) + return item.end() + + start = functools.reduce(append_stripped, copyright_matches, 0) + lines.append(content[start:]) + return lines + + +def apply_copyright_check(linter, old_content): + if linter.content != old_content: + current_year = datetime.datetime.now().year + new_copyright_matches = match_copyright(linter.content) + + if old_content is not None: + old_copyright_matches = match_copyright(old_content) + + if old_content is not None and strip_copyright( + old_content, old_copyright_matches + ) == strip_copyright(linter.content, new_copyright_matches): + for old_match, new_match in zip( + old_copyright_matches, new_copyright_matches + ): + if old_match.group() != new_match.group(): + if old_match.group("years") == new_match.group("years"): + warning_pos = new_match.span() + else: + warning_pos = new_match.span("years") + linter.add_warning( + warning_pos, + "copyright is not out of date and should not be updated", + ).add_replacement(new_match.span(), old_match.group()) + else: + if new_copyright_matches: + for match in new_copyright_matches: + if ( + int(match.group("last_year") or match.group("first_year")) + < current_year + ): + linter.add_warning( + match.span("years"), "copyright is out of date" + ).add_replacement( + match.span(), + f"Copyright (c) {match.group('first_year')}-{current_year}" + ", NVIDIA CORPORATION", + ) + else: + linter.add_warning((0, 0), "no copyright notice found") + + +def get_target_branch(repo): + # Try environment + target_branch_name = os.getenv("GITHUB_BASE_REF") + if target_branch_name: + try: + return repo.heads[target_branch_name] +
except IndexError: + pass + target_branch_name = os.getenv("TARGET_BRANCH") + if target_branch_name: + try: + return repo.heads[target_branch_name] + except IndexError: + pass + target_branch_name = os.getenv("RAPIDS_BASE_BRANCH") + if target_branch_name: + try: + return repo.heads[target_branch_name] + except IndexError: + pass + + # Try config + with repo.config_reader() as r: + target_branch_name = r.get("rapidsai", "baseBranch", fallback=None) + if target_branch_name: + try: + return repo.heads[target_branch_name] + except IndexError: + pass + + # Try newest branch-xx.yy + branches = sorted( + ( + (branch, (match.group("major"), match.group("minor"))) + for branch in repo.heads + if (match := BRANCH_RE.search(branch.name)) + ), + key=lambda i: i[1], + reverse=True, + ) + try: + return branches[0][0] + except IndexError: + pass + + # Appropriate branch not found + return None + + +def get_target_branch_upstream_commit(repo): + target_branch = get_target_branch(repo) + if target_branch is None: + return repo.head.commit + + target_branch_upstream = target_branch.tracking_branch() + if target_branch_upstream: + return target_branch_upstream.commit + + def try_get_ref(remote): + try: + return remote.refs[target_branch.name] + except IndexError: + return None + + candidate_upstreams = sorted( + (upstream for remote in repo.remotes if (upstream := try_get_ref(remote))), + key=lambda upstream: upstream.commit.committed_datetime, + reverse=True, + ) + try: + return candidate_upstreams[0].commit + except IndexError: + pass + + return target_branch.commit + + +def get_changed_files(repo, target_branch_upstream_commit): + changed_files = {} + + diffs = target_branch_upstream_commit.diff( + other=None, + merge_base=True, + find_copies=True, + find_copies_harder=True, + find_renames=True, + ) + for diff in diffs: + if diff.change_type == "A": + changed_files[diff.b_path] = None + elif diff.change_type != "D": + changed_files[diff.b_path] = diff.a_blob + + 
changed_files.update({f: None for f in repo.untracked_files}) + return changed_files + + +def check_copyright(): + repo = git.Repo() + target_branch_upstream_commit = get_target_branch_upstream_commit(repo) + changed_files = get_changed_files(repo, target_branch_upstream_commit) + + def the_check(linter, args): + try: + changed_file = changed_files[linter.filename] + except KeyError: + return + + old_content = changed_file.data_stream.read().decode("utf-8") + apply_copyright_check(linter, old_content) + + return the_check + + +def main(): + m = LintMain() + with m.execute() as ctx: + ctx.add_check(check_copyright()) + + +if __name__ == "__main__": + main() diff --git a/test/rapids_pre_commit_hooks/test_copyright.py b/test/rapids_pre_commit_hooks/test_copyright.py new file mode 100644 index 0000000..f83d747 --- /dev/null +++ b/test/rapids_pre_commit_hooks/test_copyright.py @@ -0,0 +1,700 @@ +# Copyright (c) 2024, NVIDIA CORPORATION. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +import datetime +import os.path +import tempfile +from unittest.mock import Mock, patch + +import git +import pytest +from freezegun import freeze_time + +from rapids_pre_commit_hooks import copyright +from rapids_pre_commit_hooks.lint import Linter + + +def test_match_copyright(): + CONTENT = r""" +Copyright (c) 2024 NVIDIA CORPORATION +Copyright (c) 2021-2024 NVIDIA CORPORATION +# Copyright 2021, NVIDIA Corporation and affiliates +""" + + re_matches = copyright.match_copyright(CONTENT) + matches = [ + { + "span": match.span(), + "years": match.span("years"), + "first_year": match.span("first_year"), + "last_year": match.span("last_year"), + } + for match in re_matches + ] + assert matches == [ + { + "span": (1, 38), + "years": (15, 19), + "first_year": (15, 19), + "last_year": (-1, -1), + }, + { + "span": (39, 81), + "years": (53, 62), + "first_year": (53, 57), + "last_year": (58, 62), + }, + { + "span": (84, 119), + "years": (94, 98), + "first_year": (94, 98), + "last_year": (-1, -1), + }, + ] + + +def test_strip_copyright(): + CONTENT = r""" +This is a line before the first copyright statement +Copyright (c) 2024 NVIDIA CORPORATION +This is a line between the first two copyright statements +Copyright (c) 2021-2024 NVIDIA CORPORATION +This is a line between the next two copyright statements +# Copyright 2021, NVIDIA Corporation and affiliates +This is a line after the last copyright statement +""" + matches = copyright.match_copyright(CONTENT) + stripped = copyright.strip_copyright(CONTENT, matches) + assert stripped == [ + "\nThis is a line before the first copyright statement\n", + "\nThis is a line between the first two copyright statements\n", + "\nThis is a line between the next two copyright statements\n# ", + " and affiliates\nThis is a line after the last copyright statement\n", + ] + + stripped = copyright.strip_copyright("No copyright here", []) + assert stripped == ["No copyright here"] + + +@freeze_time("2024-01-18") +def 
test_apply_copyright_check(): + def run_apply_copyright_check(old_content, new_content): + linter = Linter("file.txt", new_content) + copyright.apply_copyright_check(linter, old_content) + return linter + + expected_linter = Linter("file.txt", "No copyright notice") + expected_linter.add_warning((0, 0), "no copyright notice found") + + linter = run_apply_copyright_check(None, "No copyright notice") + assert linter.warnings == expected_linter.warnings + + linter = run_apply_copyright_check("No copyright notice", "No copyright notice") + assert linter.warnings == [] + + OLD_CONTENT = r""" +Copyright (c) 2021-2023 NVIDIA CORPORATION +Copyright (c) 2023 NVIDIA CORPORATION +Copyright (c) 2024 NVIDIA CORPORATION +Copyright (c) 2025 NVIDIA CORPORATION +This file has not been changed +""" + linter = run_apply_copyright_check(OLD_CONTENT, OLD_CONTENT) + assert linter.warnings == [] + + NEW_CONTENT = r""" +Copyright (c) 2021-2023 NVIDIA CORPORATION +Copyright (c) 2023 NVIDIA CORPORATION +Copyright (c) 2024 NVIDIA CORPORATION +Copyright (c) 2025 NVIDIA CORPORATION +This file has been changed +""" + expected_linter = Linter("file.txt", NEW_CONTENT) + expected_linter.add_warning((15, 24), "copyright is out of date").add_replacement( + (1, 43), "Copyright (c) 2021-2024, NVIDIA CORPORATION" + ) + expected_linter.add_warning((58, 62), "copyright is out of date").add_replacement( + (44, 81), "Copyright (c) 2023-2024, NVIDIA CORPORATION" + ) + + linter = run_apply_copyright_check(OLD_CONTENT, NEW_CONTENT) + assert linter.warnings == expected_linter.warnings + + expected_linter = Linter("file.txt", NEW_CONTENT) + expected_linter.add_warning((15, 24), "copyright is out of date").add_replacement( + (1, 43), "Copyright (c) 2021-2024, NVIDIA CORPORATION" + ) + expected_linter.add_warning((58, 62), "copyright is out of date").add_replacement( + (44, 81), "Copyright (c) 2023-2024, NVIDIA CORPORATION" + ) + + linter = run_apply_copyright_check(None, NEW_CONTENT) + assert linter.warnings == 
expected_linter.warnings + + NEW_CONTENT = r""" +Copyright (c) 2021-2024 NVIDIA CORPORATION +Copyright (c) 2023 NVIDIA CORPORATION +Copyright (c) 2024 NVIDIA CORPORATION +Copyright (c) 2025 NVIDIA Corporation +This file has not been changed +""" + expected_linter = Linter("file.txt", NEW_CONTENT) + expected_linter.add_warning( + (15, 24), "copyright is not out of date and should not be updated" + ).add_replacement((1, 43), "Copyright (c) 2021-2023 NVIDIA CORPORATION") + expected_linter.add_warning( + (120, 157), "copyright is not out of date and should not be updated" + ).add_replacement((120, 157), "Copyright (c) 2025 NVIDIA CORPORATION") + + linter = run_apply_copyright_check(OLD_CONTENT, NEW_CONTENT) + assert linter.warnings == expected_linter.warnings + + +@pytest.fixture +def git_repo(): + with tempfile.TemporaryDirectory() as d: + repo = git.Repo.init(d) + with repo.config_writer() as w: + w.set_value("user", "name", "RAPIDS Test Fixtures") + w.set_value("user", "email", "testfixtures@rapids.ai") + yield repo + + +def test_get_target_branch(git_repo): + master = git_repo.head.reference + + with open(os.path.join(git_repo.working_tree_dir, "file.txt"), "w") as f: + f.write("File\n") + git_repo.index.add(["file.txt"]) + git_repo.index.commit("Initial commit") + assert copyright.get_target_branch(git_repo) is None + + branch_24_02 = git_repo.create_head("branch-24.02") + assert copyright.get_target_branch(git_repo) == branch_24_02 + + branch_24_04 = git_repo.create_head("branch-24.04") + branch_24_03 = git_repo.create_head("branch-24.03") + assert copyright.get_target_branch(git_repo) == branch_24_04 + + branch_25_01 = git_repo.create_head("branch-25.01") + assert copyright.get_target_branch(git_repo) == branch_25_01 + + with git_repo.config_writer() as w: + w.set_value("rapidsai", "baseBranch", "nonexistent") + assert copyright.get_target_branch(git_repo) == branch_25_01 + + with git_repo.config_writer() as w: + w.set_value("rapidsai", "baseBranch", 
"branch-24.03") + assert copyright.get_target_branch(git_repo) == branch_24_03 + + with patch.dict("os.environ", {"RAPIDS_BASE_BRANCH": "nonexistent"}): + assert copyright.get_target_branch(git_repo) == branch_24_03 + + with patch.dict("os.environ", {"RAPIDS_BASE_BRANCH": "master"}): + assert copyright.get_target_branch(git_repo) == master + + with patch.dict( + "os.environ", {"RAPIDS_BASE_BRANCH": "master", "TARGET_BRANCH": "nonexistent"} + ): + assert copyright.get_target_branch(git_repo) == master + + with patch.dict( + "os.environ", {"RAPIDS_BASE_BRANCH": "master", "TARGET_BRANCH": "branch-24.02"} + ): + assert copyright.get_target_branch(git_repo) == branch_24_02 + + with patch.dict( + "os.environ", + { + "RAPIDS_BASE_BRANCH": "master", + "TARGET_BRANCH": "branch-24.02", + "GITHUB_BASE_REF": "nonexistent", + }, + ): + assert copyright.get_target_branch(git_repo) == branch_24_02 + + with patch.dict( + "os.environ", + { + "RAPIDS_BASE_BRANCH": "master", + "TARGET_BRANCH": "branch-24.02", + "GITHUB_BASE_REF": "branch-24.04", + }, + ): + assert copyright.get_target_branch(git_repo) == branch_24_04 + + +def test_get_target_branch_upstream_commit(git_repo): + def fn(repo, filename): + return os.path.join(repo.working_tree_dir, filename) + + def write_file(repo, filename, contents): + with open(fn(repo, filename), "w") as f: + f.write(contents) + + # fmt: off + with tempfile.TemporaryDirectory() as remote_dir_1, \ + tempfile.TemporaryDirectory() as remote_dir_2: + # fmt: on + remote_repo_1 = git.Repo.init(remote_dir_1) + remote_repo_2 = git.Repo.init(remote_dir_2) + + remote_1_master = remote_repo_1.head.reference + + write_file(remote_repo_1, "file1.txt", "File 1") + write_file(remote_repo_1, "file2.txt", "File 2") + write_file(remote_repo_1, "file3.txt", "File 3") + write_file(remote_repo_1, "file4.txt", "File 4") + write_file(remote_repo_1, "file5.txt", "File 5") + write_file(remote_repo_1, "file6.txt", "File 6") + remote_repo_1.index.add( + [ + "file1.txt", + 
"file2.txt", + "file3.txt", + "file4.txt", + "file5.txt", + "file6.txt", + ] + ) + remote_repo_1.index.commit("Initial commit") + + remote_1_branch_1 = remote_repo_1.create_head( + "branch-1", remote_1_master.commit + ) + remote_repo_1.head.reference = remote_1_branch_1 + remote_repo_1.head.reset(index=True, working_tree=True) + write_file(remote_repo_1, "file1.txt", "File 1 modified") + remote_repo_1.index.add(["file1.txt"]) + remote_repo_1.index.commit("Update file1.txt") + + remote_1_branch_2 = remote_repo_1.create_head( + "branch-2", remote_1_master.commit + ) + remote_repo_1.head.reference = remote_1_branch_2 + remote_repo_1.head.reset(index=True, working_tree=True) + write_file(remote_repo_1, "file2.txt", "File 2 modified") + remote_repo_1.index.add(["file2.txt"]) + remote_repo_1.index.commit("Update file2.txt") + + remote_1_branch_3 = remote_repo_1.create_head( + "branch-3", remote_1_master.commit + ) + remote_repo_1.head.reference = remote_1_branch_3 + remote_repo_1.head.reset(index=True, working_tree=True) + write_file(remote_repo_1, "file3.txt", "File 3 modified") + remote_repo_1.index.add(["file3.txt"]) + remote_repo_1.index.commit( + "Update file3.txt", + commit_date=datetime.datetime(2025, 1, 1, tzinfo=datetime.timezone.utc), + ) + + remote_1_branch_4 = remote_repo_1.create_head( + "branch-4", remote_1_master.commit + ) + remote_repo_1.head.reference = remote_1_branch_4 + remote_repo_1.head.reset(index=True, working_tree=True) + write_file(remote_repo_1, "file4.txt", "File 4 modified") + remote_repo_1.index.add(["file4.txt"]) + remote_repo_1.index.commit( + "Update file4.txt", + commit_date=datetime.datetime(2024, 1, 1, tzinfo=datetime.timezone.utc), + ) + + remote_2_1 = remote_repo_2.create_remote("remote-1", remote_dir_1) + remote_2_1.fetch(["master"]) + remote_2_master = remote_repo_2.create_head("master", remote_2_1.refs["master"]) + + remote_2_branch_3 = remote_repo_2.create_head( + "branch-3", remote_2_master.commit + ) + 
remote_repo_2.head.reference = remote_2_branch_3 + remote_repo_2.head.reset(index=True, working_tree=True) + write_file(remote_repo_2, "file3.txt", "File 3 modified") + remote_repo_2.index.add(["file3.txt"]) + remote_repo_2.index.commit( + "Update file3.txt", + commit_date=datetime.datetime(2024, 1, 1, tzinfo=datetime.timezone.utc), + ) + + remote_2_branch_4 = remote_repo_2.create_head( + "branch-4", remote_2_master.commit + ) + remote_repo_2.head.reference = remote_2_branch_4 + remote_repo_2.head.reset(index=True, working_tree=True) + write_file(remote_repo_2, "file4.txt", "File 4 modified") + remote_repo_2.index.add(["file4.txt"]) + remote_repo_2.index.commit( + "Update file4.txt", + commit_date=datetime.datetime(2025, 1, 1, tzinfo=datetime.timezone.utc), + ) + + remote_2_branch_5 = remote_repo_2.create_head( + "branch-5", remote_2_master.commit + ) + remote_repo_2.head.reference = remote_2_branch_5 + remote_repo_2.head.reset(index=True, working_tree=True) + write_file(remote_repo_2, "file5.txt", "File 5 modified") + remote_repo_2.index.add(["file5.txt"]) + remote_repo_2.index.commit("Update file5.txt") + + remote_1 = git_repo.create_remote("unconventional/remote/name/1", remote_dir_1) + remote_1.fetch(["master", "branch-1", "branch-2", "branch-3", "branch-4"]) + remote_2 = git_repo.create_remote("unconventional/remote/name/2", remote_dir_2) + remote_2.fetch(["branch-3", "branch-4", "branch-5"]) + + main = git_repo.create_head("main", remote_1.refs["master"]) + + branch_1 = git_repo.create_head("branch-1-renamed", remote_1.refs["master"]) + with branch_1.config_writer() as w: + w.set_value("remote", "unconventional/remote/name/1") + w.set_value("merge", "branch-1") + git_repo.head.reference = branch_1 + git_repo.head.reset(index=True, working_tree=True) + git_repo.index.remove("file1.txt", working_tree=True) + git_repo.index.commit("Remove file1.txt") + + branch_2 = git_repo.create_head("branch-2", remote_1.refs["master"]) + git_repo.head.reference = branch_2 + 
git_repo.head.reset(index=True, working_tree=True) + git_repo.index.remove("file2.txt", working_tree=True) + git_repo.index.commit("Remove file2.txt") + + branch_3 = git_repo.create_head("branch-3", remote_1.refs["master"]) + git_repo.head.reference = branch_3 + git_repo.head.reset(index=True, working_tree=True) + git_repo.index.remove("file3.txt", working_tree=True) + git_repo.index.commit("Remove file3.txt") + + branch_4 = git_repo.create_head("branch-4", remote_1.refs["master"]) + git_repo.head.reference = branch_4 + git_repo.head.reset(index=True, working_tree=True) + git_repo.index.remove(["file4.txt"], working_tree=True) + git_repo.index.commit("Remove file4.txt") + + branch_5 = git_repo.create_head("branch-5", remote_1.refs["master"]) + git_repo.head.reference = branch_5 + git_repo.head.reset(index=True, working_tree=True) + git_repo.index.remove(["file5.txt"], working_tree=True) + git_repo.index.commit("Remove file5.txt") + + branch_6 = git_repo.create_head("branch-6", remote_1.refs["master"]) + git_repo.head.reference = branch_6 + git_repo.head.reset(index=True, working_tree=True) + git_repo.index.remove(["file6.txt"], working_tree=True) + git_repo.index.commit("Remove file6.txt") + + git_repo.head.reference = main + git_repo.head.reset(index=True, working_tree=True) + + def mock_target_branch(branch): + return patch( + "rapids_pre_commit_hooks.copyright.get_target_branch", + Mock(return_value=branch), + ) + + with mock_target_branch(branch_1): + assert ( + copyright.get_target_branch_upstream_commit(git_repo) + == remote_1.refs["branch-1"].commit + ) + + with mock_target_branch(branch_2): + assert ( + copyright.get_target_branch_upstream_commit(git_repo) + == remote_1.refs["branch-2"].commit + ) + + with mock_target_branch(branch_3): + assert ( + copyright.get_target_branch_upstream_commit(git_repo) + == remote_1.refs["branch-3"].commit + ) + + with mock_target_branch(branch_4): + assert ( + copyright.get_target_branch_upstream_commit(git_repo) + == 
remote_2.refs["branch-4"].commit + ) + + with mock_target_branch(branch_5): + assert ( + copyright.get_target_branch_upstream_commit(git_repo) + == remote_2.refs["branch-5"].commit + ) + + with mock_target_branch(branch_6): + assert ( + copyright.get_target_branch_upstream_commit(git_repo) == branch_6.commit + ) + + with mock_target_branch(None): + assert copyright.get_target_branch_upstream_commit(git_repo) == main.commit + + +def test_get_changed_files(git_repo): + def fn(filename): + return os.path.join(git_repo.working_tree_dir, filename) + + def write_file(filename, contents): + with open(fn(filename), "w") as f: + f.write(contents) + + def file_contents(verbed): + return f"This file will be {verbed}\n" * 100 + + write_file("untouched.txt", file_contents("untouched")) + write_file("copied.txt", file_contents("copied")) + write_file("modified_and_copied.txt", file_contents("modified and copied")) + write_file("copied_and_modified.txt", file_contents("copied and modified")) + write_file("deleted.txt", file_contents("deleted")) + write_file("renamed.txt", file_contents("renamed")) + write_file("modified_and_renamed.txt", file_contents("modified and renamed")) + write_file("modified.txt", file_contents("modified")) + write_file("chmodded.txt", file_contents("chmodded")) + git_repo.index.add( + [ + "untouched.txt", + "copied.txt", + "modified_and_copied.txt", + "copied_and_modified.txt", + "deleted.txt", + "renamed.txt", + "modified_and_renamed.txt", + "modified.txt", + "chmodded.txt", + ] + ) + git_repo.index.commit("Initial commit") + + # Ensure that diff is done against merge base, not branch tip + git_repo.index.remove(["modified.txt"], working_tree=True) + git_repo.index.commit("Remove modified.txt") + + pr_branch = git_repo.create_head("pr", "HEAD~") + git_repo.head.reference = pr_branch + git_repo.head.reset(index=True, working_tree=True) + + write_file("copied_2.txt", file_contents("copied")) + git_repo.index.remove( + ["deleted.txt", 
"modified_and_renamed.txt"], working_tree=True + ) + git_repo.index.move(["renamed.txt", "renamed_2.txt"]) + write_file( + "modified.txt", file_contents("modified") + "This file has been modified\n" + ) + os.chmod(fn("chmodded.txt"), 0o755) + write_file("untouched.txt", file_contents("untouched") + "Oops\n") + write_file("added.txt", file_contents("added")) + write_file("added_and_deleted.txt", file_contents("added and deleted")) + write_file( + "modified_and_copied.txt", + file_contents("modified and copied") + "This file has been modified\n", + ) + write_file("modified_and_copied_2.txt", file_contents("modified and copied")) + write_file( + "copied_and_modified_2.txt", + file_contents("copied and modified") + "This file has been modified\n", + ) + write_file( + "modified_and_renamed_2.txt", + file_contents("modified and renamed") + "This file has been modified\n", + ) + git_repo.index.add( + [ + "untouched.txt", + "added.txt", + "added_and_deleted.txt", + "modified_and_copied.txt", + "modified_and_copied_2.txt", + "copied_and_modified_2.txt", + "copied_2.txt", + "modified_and_renamed_2.txt", + "modified.txt", + "chmodded.txt", + ] + ) + write_file("untracked.txt", file_contents("untracked")) + write_file("untouched.txt", file_contents("untouched")) + os.unlink(fn("added_and_deleted.txt")) + + target_branch = git_repo.heads["master"] + merge_base = git_repo.merge_base(target_branch, "HEAD")[0] + old_files = { + blob.path: blob + for blob in merge_base.tree.traverse(lambda b, _: isinstance(b, git.Blob)) + } + + # Truly need to be checked + changed = { + "added.txt": None, + "untracked.txt": None, + "modified_and_renamed_2.txt": "modified_and_renamed.txt", + "modified.txt": "modified.txt", + "copied_and_modified_2.txt": "copied_and_modified.txt", + "modified_and_copied.txt": "modified_and_copied.txt", + } + + # Superfluous, but harmless because the content is identical + superfluous = { + "chmodded.txt": "chmodded.txt", + "modified_and_copied_2.txt": 
"modified_and_copied.txt", + "copied_2.txt": "copied.txt", + "renamed_2.txt": "renamed.txt", + } + + changed_files = copyright.get_changed_files(git_repo, target_branch.commit) + assert { + path: old_blob.path if old_blob else None + for path, old_blob in changed_files.items() + } == changed | superfluous + + for new, old in changed.items(): + if old: + with open(fn(new), "rb") as f: + new_contents = f.read() + old_contents = old_files[old].data_stream.read() + assert new_contents != old_contents + assert changed_files[new].data_stream.read() == old_contents + + for new, old in superfluous.items(): + if old: + with open(fn(new), "rb") as f: + new_contents = f.read() + old_contents = old_files[old].data_stream.read() + assert new_contents == old_contents + assert changed_files[new].data_stream.read() == old_contents + + +@freeze_time("2024-01-18") +def test_check_copyright(git_repo): + def fn(filename): + return os.path.join(git_repo.working_tree_dir, filename) + + def write_file(filename, contents): + with open(fn(filename), "w") as f: + f.write(contents) + + def file_contents(num): + return rf""" +Copyright (c) 2021-2023 NVIDIA CORPORATION +File {num} +""" + + def file_contents_modified(num): + return rf""" +Copyright (c) 2021-2023 NVIDIA CORPORATION +File {num} modified +""" + + write_file("file1.txt", file_contents(1)) + write_file("file2.txt", file_contents(2)) + write_file("file3.txt", file_contents(3)) + write_file("file4.txt", file_contents(4)) + git_repo.index.add(["file1.txt", "file2.txt", "file3.txt", "file4.txt"]) + git_repo.index.commit("Initial commit") + + branch_1 = git_repo.create_head("branch-1", "master") + git_repo.head.reference = branch_1 + git_repo.head.reset(index=True, working_tree=True) + write_file("file1.txt", file_contents_modified(1)) + git_repo.index.add(["file1.txt"]) + git_repo.index.commit("Update file1.txt") + + branch_2 = git_repo.create_head("branch-2", "master") + git_repo.head.reference = branch_2 + 
git_repo.head.reset(index=True, working_tree=True) + write_file("file2.txt", file_contents_modified(2)) + git_repo.index.add(["file2.txt"]) + git_repo.index.commit("Update file2.txt") + + pr = git_repo.create_head("pr", "branch-1") + git_repo.head.reference = pr + git_repo.head.reset(index=True, working_tree=True) + write_file("file3.txt", file_contents_modified(3)) + git_repo.index.add(["file3.txt"]) + git_repo.index.commit("Update file3.txt") + write_file("file4.txt", file_contents_modified(4)) + git_repo.index.add(["file4.txt"]) + git_repo.index.commit("Update file4.txt") + git_repo.index.move(["file2.txt", "file5.txt"]) + git_repo.index.commit("Rename file2.txt to file5.txt") + + def mock_repo_cwd(): + return patch("os.getcwd", Mock(return_value=git_repo.working_tree_dir)) + + def mock_target_branch_upstream_commit(branch_name): + def func(repo): + return repo.heads[branch_name].commit + + return patch( + "rapids_pre_commit_hooks.copyright.get_target_branch_upstream_commit", func + ) + + def mock_apply_copyright_check(): + return patch("rapids_pre_commit_hooks.copyright.apply_copyright_check", Mock()) + + ############################# + # branch-1 is target branch + ############################# + + with mock_repo_cwd(), mock_target_branch_upstream_commit("branch-1"): + copyright_checker = copyright.check_copyright() + + linter = Linter("file1.txt", file_contents_modified(1)) + with mock_apply_copyright_check() as apply_copyright_check: + copyright_checker(linter, None) + apply_copyright_check.assert_not_called() + + linter = Linter("file5.txt", file_contents(2)) + with mock_apply_copyright_check() as apply_copyright_check: + copyright_checker(linter, None) + apply_copyright_check.assert_called_once_with(linter, file_contents(2)) + + linter = Linter("file3.txt", file_contents_modified(3)) + with mock_apply_copyright_check() as apply_copyright_check: + copyright_checker(linter, None) + apply_copyright_check.assert_called_once_with(linter, file_contents(3)) + + 
linter = Linter("file4.txt", file_contents_modified(4)) + with mock_apply_copyright_check() as apply_copyright_check: + copyright_checker(linter, None) + apply_copyright_check.assert_called_once_with(linter, file_contents(4)) + + ############################# + # branch-2 is target branch + ############################# + + with mock_repo_cwd(), mock_target_branch_upstream_commit("branch-2"): + copyright_checker = copyright.check_copyright() + + linter = Linter("file1.txt", file_contents_modified(1)) + with mock_apply_copyright_check() as apply_copyright_check: + copyright_checker(linter, None) + apply_copyright_check.assert_called_once_with(linter, file_contents(1)) + + linter = Linter("file5.txt", file_contents(2)) + with mock_apply_copyright_check() as apply_copyright_check: + copyright_checker(linter, None) + apply_copyright_check.assert_called_once_with(linter, file_contents(2)) + + linter = Linter("file3.txt", file_contents_modified(3)) + with mock_apply_copyright_check() as apply_copyright_check: + copyright_checker(linter, None) + apply_copyright_check.assert_called_once_with(linter, file_contents(3)) + + linter = Linter("file4.txt", file_contents_modified(4)) + with mock_apply_copyright_check() as apply_copyright_check: + copyright_checker(linter, None) + apply_copyright_check.assert_called_once_with(linter, file_contents(4)) From f2016f4336ed644e1bad15c40503d18a0a005831 Mon Sep 17 00:00:00 2001 From: Kyle Edwards Date: Mon, 22 Jan 2024 15:04:18 -0500 Subject: [PATCH 02/24] Add support for non-repos and repos with no commits --- src/rapids_pre_commit_hooks/copyright.py | 27 ++++++--- .../rapids_pre_commit_hooks/test_copyright.py | 56 ++++++++++++++++++- 2 files changed, 74 insertions(+), 9 deletions(-) diff --git a/src/rapids_pre_commit_hooks/copyright.py b/src/rapids_pre_commit_hooks/copyright.py index 8849c92..cb9022c 100644 --- a/src/rapids_pre_commit_hooks/copyright.py +++ b/src/rapids_pre_commit_hooks/copyright.py @@ -141,7 +141,10 @@ def 
get_target_branch(repo): def get_target_branch_upstream_commit(repo): target_branch = get_target_branch(repo) if target_branch is None: - return repo.head.commit + try: + return repo.head.commit + except ValueError: + return None target_branch_upstream = target_branch.tracking_branch() if target_branch_upstream: @@ -166,8 +169,21 @@ def try_get_ref(remote): return target_branch.commit -def get_changed_files(repo, target_branch_upstream_commit): - changed_files = {} +def get_changed_files(): + try: + repo = git.Repo() + except git.InvalidGitRepositoryError: + return { + os.path.relpath(os.path.join(dirpath, filename), "."): None + for dirpath, dirnames, filenames in os.walk(".") + for filename in filenames + } + + changed_files = {f: None for f in repo.untracked_files} + target_branch_upstream_commit = get_target_branch_upstream_commit(repo) + if target_branch_upstream_commit is None: + changed_files.update({blob.path: None for _, blob in repo.index.iter_blobs()}) + return changed_files diffs = target_branch_upstream_commit.diff( other=None, @@ -182,14 +198,11 @@ def get_changed_files(repo, target_branch_upstream_commit): elif diff.change_type != "D": changed_files[diff.b_path] = diff.a_blob - changed_files.update({f: None for f in repo.untracked_files}) return changed_files def check_copyright(): - repo = git.Repo() - target_branch_upstream_commit = get_target_branch_upstream_commit(repo) - changed_files = get_changed_files(repo, target_branch_upstream_commit) + changed_files = get_changed_files() def the_check(linter, args): try: diff --git a/test/rapids_pre_commit_hooks/test_copyright.py b/test/rapids_pre_commit_hooks/test_copyright.py index f83d747..a6f9e3d 100644 --- a/test/rapids_pre_commit_hooks/test_copyright.py +++ b/test/rapids_pre_commit_hooks/test_copyright.py @@ -444,6 +444,37 @@ def mock_target_branch(branch): def test_get_changed_files(git_repo): + def mock_os_walk(top): + return patch( + "os.walk", + Mock( + return_value=( + ( + "." 
+ if (rel := os.path.relpath(dirpath, top)) == "." + else os.path.join(".", rel), + dirnames, + filenames, + ) + for dirpath, dirnames, filenames in os.walk(top) + ) + ), + ) + + with tempfile.TemporaryDirectory() as non_git_dir, patch( + "os.getcwd", Mock(return_value=non_git_dir) + ), mock_os_walk(non_git_dir): + with open(os.path.join(non_git_dir, "top.txt"), "w") as f: + f.write("Top file\n") + os.mkdir(os.path.join(non_git_dir, "subdir1")) + os.mkdir(os.path.join(non_git_dir, "subdir1/subdir2")) + with open(os.path.join(non_git_dir, "subdir1", "subdir2", "sub.txt"), "w") as f: + f.write("Subdir file\n") + assert copyright.get_changed_files() == { + "top.txt": None, + "subdir1/subdir2/sub.txt": None, + } + def fn(filename): return os.path.join(git_repo.working_tree_dir, filename) @@ -463,6 +494,7 @@ def file_contents(verbed): write_file("modified_and_renamed.txt", file_contents("modified and renamed")) write_file("modified.txt", file_contents("modified")) write_file("chmodded.txt", file_contents("chmodded")) + write_file("untracked.txt", file_contents("untracked")) git_repo.index.add( [ "untouched.txt", @@ -476,6 +508,23 @@ def file_contents(verbed): "chmodded.txt", ] ) + + with patch("os.getcwd", Mock(return_value=git_repo.working_tree_dir)), mock_os_walk( + git_repo.working_tree_dir + ): + assert copyright.get_changed_files() == { + "untouched.txt": None, + "copied.txt": None, + "modified_and_copied.txt": None, + "copied_and_modified.txt": None, + "deleted.txt": None, + "renamed.txt": None, + "modified_and_renamed.txt": None, + "modified.txt": None, + "chmodded.txt": None, + "untracked.txt": None, + } + git_repo.index.commit("Initial commit") # Ensure that diff is done against merge base, not branch tip @@ -525,7 +574,6 @@ def file_contents(verbed): "chmodded.txt", ] ) - write_file("untracked.txt", file_contents("untracked")) write_file("untouched.txt", file_contents("untouched")) os.unlink(fn("added_and_deleted.txt")) @@ -554,7 +602,11 @@ def 
file_contents(verbed): "renamed_2.txt": "renamed.txt", } - changed_files = copyright.get_changed_files(git_repo, target_branch.commit) + with patch("os.getcwd", Mock(return_value=git_repo.working_tree_dir)), patch( + "rapids_pre_commit_hooks.copyright.get_target_branch_upstream_commit", + Mock(return_value=target_branch.commit), + ): + changed_files = copyright.get_changed_files() assert { path: old_blob.path if old_blob else None for path, old_blob in changed_files.items() From 1b3b78e2fd70ce347756afad67064411a7ab2ab1 Mon Sep 17 00:00:00 2001 From: Kyle Edwards Date: Mon, 22 Jan 2024 15:09:13 -0500 Subject: [PATCH 03/24] Fix issue with new files --- src/rapids_pre_commit_hooks/copyright.py | 6 +++++- test/rapids_pre_commit_hooks/test_copyright.py | 12 ++++++++++++ 2 files changed, 17 insertions(+), 1 deletion(-) diff --git a/src/rapids_pre_commit_hooks/copyright.py b/src/rapids_pre_commit_hooks/copyright.py index cb9022c..737f0a2 100644 --- a/src/rapids_pre_commit_hooks/copyright.py +++ b/src/rapids_pre_commit_hooks/copyright.py @@ -210,7 +210,11 @@ def the_check(linter, args): except KeyError: return - old_content = changed_file.data_stream.read().decode("utf-8") + old_content = ( + changed_file.data_stream.read().decode("utf-8") + if changed_file is not None + else None + ) apply_copyright_check(linter, old_content) return the_check diff --git a/test/rapids_pre_commit_hooks/test_copyright.py b/test/rapids_pre_commit_hooks/test_copyright.py index a6f9e3d..6f28c7d 100644 --- a/test/rapids_pre_commit_hooks/test_copyright.py +++ b/test/rapids_pre_commit_hooks/test_copyright.py @@ -683,6 +683,8 @@ def file_contents_modified(num): git_repo.index.move(["file2.txt", "file5.txt"]) git_repo.index.commit("Rename file2.txt to file5.txt") + write_file("file6.txt", file_contents(6)) + def mock_repo_cwd(): return patch("os.getcwd", Mock(return_value=git_repo.working_tree_dir)) @@ -724,6 +726,11 @@ def mock_apply_copyright_check(): copyright_checker(linter, None) 
apply_copyright_check.assert_called_once_with(linter, file_contents(4)) + linter = Linter("file6.txt", file_contents(6)) + with mock_apply_copyright_check() as apply_copyright_check: + copyright_checker(linter, None) + apply_copyright_check.assert_called_once_with(linter, None) + ############################# # branch-2 is target branch ############################# @@ -750,3 +757,8 @@ def mock_apply_copyright_check(): with mock_apply_copyright_check() as apply_copyright_check: copyright_checker(linter, None) apply_copyright_check.assert_called_once_with(linter, file_contents(4)) + + linter = Linter("file6.txt", file_contents(6)) + with mock_apply_copyright_check() as apply_copyright_check: + copyright_checker(linter, None) + apply_copyright_check.assert_called_once_with(linter, None) From 3509fe45e3a13378160763e848d65f95fb79dd96 Mon Sep 17 00:00:00 2001 From: Kyle Edwards Date: Mon, 22 Jan 2024 16:08:13 -0500 Subject: [PATCH 04/24] Update with review feedback --- .pre-commit-hooks.yaml | 1 + src/rapids_pre_commit_hooks/copyright.py | 41 +++++++++++++++---- .../rapids_pre_commit_hooks/test_copyright.py | 7 +++- 3 files changed, 39 insertions(+), 10 deletions(-) diff --git a/.pre-commit-hooks.yaml b/.pre-commit-hooks.yaml index 89b5ec1..572fe25 100644 --- a/.pre-commit-hooks.yaml +++ b/.pre-commit-hooks.yaml @@ -29,5 +29,6 @@ CMakeLists[.]txt$| CMakeLists_standalone[.]txt$| setup[.]cfg$| + pyproject[.]toml$| meta[.]yaml$ args: [--fix] diff --git a/src/rapids_pre_commit_hooks/copyright.py b/src/rapids_pre_commit_hooks/copyright.py index 737f0a2..a7ff754 100644 --- a/src/rapids_pre_commit_hooks/copyright.py +++ b/src/rapids_pre_commit_hooks/copyright.py @@ -16,6 +16,7 @@ import functools import os import re +import warnings import git @@ -26,6 +27,11 @@ r" *NVIDIA C(?:ORPORATION|orporation)" ) BRANCH_RE = re.compile(r"^branch-(?P<major>[0-9]+)\.(?P<minor>[0-9]+)$") +COPYRIGHT_REPLACEMENT = "Copyright (c) {first_year}-{last_year}, NVIDIA CORPORATION" + + +class
NoTargetBranchWarning(Warning): + pass class ConflictingFilesError(RuntimeError): @@ -82,29 +88,41 @@ def apply_copyright_check(linter, old_content): match.span("years"), "copyright is out of date" ).add_replacement( match.span(), - f"Copyright (c) {match.group('first_year')}-{current_year}" - ", NVIDIA CORPORATION", + COPYRIGHT_REPLACEMENT.format( + first_year=match.group("first_year"), + last_year=current_year, + ), ) else: linter.add_warning((0, 0), "no copyright notice found") def get_target_branch(repo): + """Determine which branch is the "target" branch. + + The target branch is determined in the following order: + + * If any of the ``$GITHUB_BASE_REF``, ``$TARGET_BRANCH``, or ``$RAPIDS_BASE_BRANCH`` + environment variables, in that order, are defined and point to a valid branch, + that branch is used. + * If the configuration option ``rapidsai.baseBranch`` points to a valid branch, that + branch is used. + * If a ``branch-<major>.<minor>`` branch exists, that branch is used. If more than + one such branch exists, the one with the latest version is used. + * Otherwise, None is returned and a warning is issued. + """ # Try environment - target_branch_name = os.getenv("GITHUB_BASE_REF") - if target_branch_name: + if target_branch_name := os.getenv("GITHUB_BASE_REF"): try: return repo.heads[target_branch_name] except IndexError: pass - target_branch_name = os.getenv("TARGET_BRANCH") - if target_branch_name: + if target_branch_name := os.getenv("TARGET_BRANCH"): try: return repo.heads[target_branch_name] except IndexError: pass - target_branch_name = os.getenv("RAPIDS_BASE_BRANCH") - if target_branch_name: + if target_branch_name := os.getenv("RAPIDS_BASE_BRANCH"): try: return repo.heads[target_branch_name] except IndexError: @@ -135,6 +153,11 @@ def get_target_branch(repo): pass # Appropriate branch not found + warnings.warn( + "Could not determine target branch.
Try setting the TARGET_BRANCH or " + "RAPIDS_BASE_BRANCH environment variable.", + NoTargetBranchWarning, + ) return None @@ -211,7 +234,7 @@ def the_check(linter, args): return old_content = ( - changed_file.data_stream.read().decode("utf-8") + changed_file.data_stream.read().decode() if changed_file is not None else None ) diff --git a/test/rapids_pre_commit_hooks/test_copyright.py b/test/rapids_pre_commit_hooks/test_copyright.py index 6f28c7d..b9f2e5e 100644 --- a/test/rapids_pre_commit_hooks/test_copyright.py +++ b/test/rapids_pre_commit_hooks/test_copyright.py @@ -178,7 +178,12 @@ def test_get_target_branch(git_repo): f.write("File\n") git_repo.index.add(["file.txt"]) git_repo.index.commit("Initial commit") - assert copyright.get_target_branch(git_repo) is None + with pytest.warns( + copyright.NoTargetBranchWarning, + match=r"^Could not determine target branch[.] Try setting the TARGET_BRANCH or " + r"RAPIDS_BASE_BRANCH environment variable[.]$", + ): + assert copyright.get_target_branch(git_repo) is None branch_24_02 = git_repo.create_head("branch-24.02") assert copyright.get_target_branch(git_repo) == branch_24_02 From dc8f6f3b6534812904fa3391726e15741c4bf047 Mon Sep 17 00:00:00 2001 From: Kyle Edwards Date: Mon, 22 Jan 2024 16:14:11 -0500 Subject: [PATCH 05/24] Explain get_target_branch() more thoroughly --- src/rapids_pre_commit_hooks/copyright.py | 14 +++++++++----- 1 file changed, 9 insertions(+), 5 deletions(-) diff --git a/src/rapids_pre_commit_hooks/copyright.py b/src/rapids_pre_commit_hooks/copyright.py index a7ff754..221a97a 100644 --- a/src/rapids_pre_commit_hooks/copyright.py +++ b/src/rapids_pre_commit_hooks/copyright.py @@ -102,13 +102,17 @@ def get_target_branch(repo): The target branch is determined in the following order: - * If any of the ``$GITHUB_BASE_REF``, ``$TARGET_BRANCH``, or ``$RAPIDS_BASE_BRANCH`` - environment variables, in that order, are defined and point to a valid branch, - that branch is used. 
+ * If the ``$GITHUB_BASE_REF`` environment variable is defined and points to a valid + branch, that branch is used. This allows GitHub Actions to easily use this tool. + * If either of the ``$TARGET_BRANCH`` or ``$RAPIDS_BASE_BRANCH`` environment + variables, in that order, are defined and point to a valid branch, that branch is + used. This allows users to locally set a base branch on a one-time basis. * If the configuration option ``rapidsai.baseBranch`` points to a valid branch, that - branch is used. + branch is used. This allows users to locally set a base branch on a long-term + basis. * If a ``branch-<major>.<minor>`` branch exists, that branch is used. If more than - one such branch exists, the one with the latest version is used. + one such branch exists, the one with the latest version is used. This supports the + expected default. * Otherwise, None is returned and a warning is issued. """ # Try environment From a7e190ff041ccab8014267ae76e5340b0e63fa3b Mon Sep 17 00:00:00 2001 From: Kyle Edwards Date: Mon, 22 Jan 2024 16:18:26 -0500 Subject: [PATCH 06/24] Update order of environment variables --- src/rapids_pre_commit_hooks/copyright.py | 10 ++++----- .../rapids_pre_commit_hooks/test_copyright.py | 21 ++++++++++--------- 2 files changed, 16 insertions(+), 15 deletions(-) diff --git a/src/rapids_pre_commit_hooks/copyright.py b/src/rapids_pre_commit_hooks/copyright.py index 221a97a..5a7c8b0 100644 --- a/src/rapids_pre_commit_hooks/copyright.py +++ b/src/rapids_pre_commit_hooks/copyright.py @@ -102,11 +102,11 @@ def get_target_branch(repo): The target branch is determined in the following order: - * If the ``$GITHUB_BASE_REF`` environment variable is defined and points to a valid - branch, that branch is used. This allows GitHub Actions to easily use this tool. * If either of the ``$TARGET_BRANCH`` or ``$RAPIDS_BASE_BRANCH`` environment variables, in that order, are defined and point to a valid branch, that branch is used.
This allows users to locally set a base branch on a one-time basis. + * If the ``$GITHUB_BASE_REF`` environment variable is defined and points to a valid + branch, that branch is used. This allows GitHub Actions to easily use this tool. * If the configuration option ``rapidsai.baseBranch`` points to a valid branch, that branch is used. This allows users to locally set a base branch on a long-term basis. @@ -116,17 +116,17 @@ def get_target_branch(repo): * Otherwise, None is returned and a warning is issued. """ # Try environment - if target_branch_name := os.getenv("GITHUB_BASE_REF"): + if target_branch_name := os.getenv("TARGET_BRANCH"): try: return repo.heads[target_branch_name] except IndexError: pass - if target_branch_name := os.getenv("TARGET_BRANCH"): + if target_branch_name := os.getenv("RAPIDS_BASE_BRANCH"): try: return repo.heads[target_branch_name] except IndexError: pass - if target_branch_name := os.getenv("RAPIDS_BASE_BRANCH"): + if target_branch_name := os.getenv("GITHUB_BASE_REF"): try: return repo.heads[target_branch_name] except IndexError: diff --git a/test/rapids_pre_commit_hooks/test_copyright.py b/test/rapids_pre_commit_hooks/test_copyright.py index b9f2e5e..3afa5f7 100644 --- a/test/rapids_pre_commit_hooks/test_copyright.py +++ b/test/rapids_pre_commit_hooks/test_copyright.py @@ -203,28 +203,29 @@ def test_get_target_branch(git_repo): w.set_value("rapidsai", "baseBranch", "branch-24.03") assert copyright.get_target_branch(git_repo) == branch_24_03 - with patch.dict("os.environ", {"RAPIDS_BASE_BRANCH": "nonexistent"}): + with patch.dict("os.environ", {"GITHUB_BASE_REF": "nonexistent"}): assert copyright.get_target_branch(git_repo) == branch_24_03 - with patch.dict("os.environ", {"RAPIDS_BASE_BRANCH": "master"}): + with patch.dict("os.environ", {"GITHUB_BASE_REF": "master"}): assert copyright.get_target_branch(git_repo) == master with patch.dict( - "os.environ", {"RAPIDS_BASE_BRANCH": "master", "TARGET_BRANCH": "nonexistent"} + "os.environ", 
{"GITHUB_BASE_REF": "master", "RAPIDS_BASE_BRANCH": "nonexistent"} ): assert copyright.get_target_branch(git_repo) == master with patch.dict( - "os.environ", {"RAPIDS_BASE_BRANCH": "master", "TARGET_BRANCH": "branch-24.02"} + "os.environ", + {"GITHUB_BASE_REF": "master", "RAPIDS_BASE_BRANCH": "branch-24.02"}, ): assert copyright.get_target_branch(git_repo) == branch_24_02 with patch.dict( "os.environ", { - "RAPIDS_BASE_BRANCH": "master", - "TARGET_BRANCH": "branch-24.02", - "GITHUB_BASE_REF": "nonexistent", + "GITHUB_BASE_REF": "master", + "RAPIDS_BASE_BRANCH": "branch-24.02", + "TARGET_BRANCH": "nonexistent", }, ): assert copyright.get_target_branch(git_repo) == branch_24_02 @@ -232,9 +233,9 @@ def test_get_target_branch(git_repo): with patch.dict( "os.environ", { - "RAPIDS_BASE_BRANCH": "master", - "TARGET_BRANCH": "branch-24.02", - "GITHUB_BASE_REF": "branch-24.04", + "GITHUB_BASE_REF": "master", + "RAPIDS_BASE_BRANCH": "branch-24.02", + "TARGET_BRANCH": "branch-24.04", }, ): assert copyright.get_target_branch(git_repo) == branch_24_04 From 958d0891d2c0608e4a209ecc04adc5e9833a3dc1 Mon Sep 17 00:00:00 2001 From: Kyle Edwards Date: Mon, 22 Jan 2024 16:20:25 -0500 Subject: [PATCH 07/24] Add configuration option to warning message --- src/rapids_pre_commit_hooks/copyright.py | 3 ++- test/rapids_pre_commit_hooks/test_copyright.py | 3 ++- 2 files changed, 4 insertions(+), 2 deletions(-) diff --git a/src/rapids_pre_commit_hooks/copyright.py b/src/rapids_pre_commit_hooks/copyright.py index 5a7c8b0..1490f0e 100644 --- a/src/rapids_pre_commit_hooks/copyright.py +++ b/src/rapids_pre_commit_hooks/copyright.py @@ -159,7 +159,8 @@ def get_target_branch(repo): # Appropriate branch not found warnings.warn( "Could not determine target branch. 
Try setting the TARGET_BRANCH or " - "RAPIDS_BASE_BRANCH environment variable.", + "RAPIDS_BASE_BRANCH environment variable, or setting the rapidsai.baseBranch " + "configuration option.", NoTargetBranchWarning, ) return None diff --git a/test/rapids_pre_commit_hooks/test_copyright.py b/test/rapids_pre_commit_hooks/test_copyright.py index 3afa5f7..2307220 100644 --- a/test/rapids_pre_commit_hooks/test_copyright.py +++ b/test/rapids_pre_commit_hooks/test_copyright.py @@ -181,7 +181,8 @@ def test_get_target_branch(git_repo): with pytest.warns( copyright.NoTargetBranchWarning, match=r"^Could not determine target branch[.] Try setting the TARGET_BRANCH or " - r"RAPIDS_BASE_BRANCH environment variable[.]$", + r"RAPIDS_BASE_BRANCH environment variable, or setting the rapidsai.baseBranch " + r"configuration option[.]$", ): assert copyright.get_target_branch(git_repo) is None From 78214e86218d8fd9c547642e1a045c464a12de7c Mon Sep 17 00:00:00 2001 From: Kyle Edwards Date: Mon, 22 Jan 2024 16:34:36 -0500 Subject: [PATCH 08/24] More review feedback --- .pre-commit-hooks.yaml | 4 +-- src/rapids_pre_commit_hooks/copyright.py | 33 ++++++++++++------------ 2 files changed, 18 insertions(+), 19 deletions(-) diff --git a/.pre-commit-hooks.yaml b/.pre-commit-hooks.yaml index 572fe25..6bfd017 100644 --- a/.pre-commit-hooks.yaml +++ b/.pre-commit-hooks.yaml @@ -28,7 +28,7 @@ [.](cmake|cpp|cu|cuh|h|hpp|sh|pxd|py|pyx)$| CMakeLists[.]txt$| CMakeLists_standalone[.]txt$| - setup[.]cfg$| + meta[.]yaml$| pyproject[.]toml$| - meta[.]yaml$ + setup[.]cfg$ args: [--fix] diff --git a/src/rapids_pre_commit_hooks/copyright.py b/src/rapids_pre_commit_hooks/copyright.py index 1490f0e..08dff11 100644 --- a/src/rapids_pre_commit_hooks/copyright.py +++ b/src/rapids_pre_commit_hooks/copyright.py @@ -77,24 +77,23 @@ def apply_copyright_check(linter, old_content): warning_pos, "copyright is not out of date and should not be updated", ).add_replacement(new_match.span(), old_match.group()) + elif 
new_copyright_matches: + for match in new_copyright_matches: + if ( + int(match.group("last_year") or match.group("first_year")) + < current_year + ): + linter.add_warning( + match.span("years"), "copyright is out of date" + ).add_replacement( + match.span(), + COPYRIGHT_REPLACEMENT.format( + first_year=match.group("first_year"), + last_year=current_year, + ), + ) else: - if new_copyright_matches: - for match in new_copyright_matches: - if ( - int(match.group("last_year") or match.group("first_year")) - < current_year - ): - linter.add_warning( - match.span("years"), "copyright is out of date" - ).add_replacement( - match.span(), - COPYRIGHT_REPLACEMENT.format( - first_year=match.group("first_year"), - last_year=current_year, - ), - ) - else: - linter.add_warning((0, 0), "no copyright notice found") + linter.add_warning((0, 0), "no copyright notice found") def get_target_branch(repo): From dfd944cdd3e8d8e7c78d175ea8040f4e6254ec31 Mon Sep 17 00:00:00 2001 From: Kyle Edwards Date: Mon, 22 Jan 2024 16:45:46 -0500 Subject: [PATCH 09/24] Use max() instead of sorted() --- src/rapids_pre_commit_hooks/copyright.py | 12 +++++------- 1 file changed, 5 insertions(+), 7 deletions(-) diff --git a/src/rapids_pre_commit_hooks/copyright.py b/src/rapids_pre_commit_hooks/copyright.py index 08dff11..ab52a59 100644 --- a/src/rapids_pre_commit_hooks/copyright.py +++ b/src/rapids_pre_commit_hooks/copyright.py @@ -183,14 +183,12 @@ def try_get_ref(remote): except IndexError: return None - candidate_upstreams = sorted( - (upstream for remote in repo.remotes if (upstream := try_get_ref(remote))), - key=lambda upstream: upstream.commit.committed_datetime, - reverse=True, - ) try: - return candidate_upstreams[0].commit - except IndexError: + return max( + (upstream for remote in repo.remotes if (upstream := try_get_ref(remote))), + key=lambda upstream: upstream.commit.committed_datetime, + ).commit + except ValueError: pass return target_branch.commit From 
cf900c4b9657f8f00fbd29832191f71880673525 Mon Sep 17 00:00:00 2001 From: Kyle Edwards Date: Mon, 22 Jan 2024 17:02:11 -0500 Subject: [PATCH 10/24] Fix warning in test --- test/rapids_pre_commit_hooks/test_copyright.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/test/rapids_pre_commit_hooks/test_copyright.py b/test/rapids_pre_commit_hooks/test_copyright.py index 2307220..1a8129d 100644 --- a/test/rapids_pre_commit_hooks/test_copyright.py +++ b/test/rapids_pre_commit_hooks/test_copyright.py @@ -518,6 +518,9 @@ def file_contents(verbed): with patch("os.getcwd", Mock(return_value=git_repo.working_tree_dir)), mock_os_walk( git_repo.working_tree_dir + ), patch( + "rapids_pre_commit_hooks.copyright.get_target_branch_upstream_commit", + Mock(return_value=None), ): assert copyright.get_changed_files() == { "untouched.txt": None, From 8eb681402d90afdcce21ac7d3716eb2723ccf2c0 Mon Sep 17 00:00:00 2001 From: Kyle Edwards Date: Tue, 23 Jan 2024 09:54:22 -0500 Subject: [PATCH 11/24] Add --target-branch argument, update help --- src/rapids_pre_commit_hooks/copyright.py | 45 ++++++++++++++---- src/rapids_pre_commit_hooks/lint.py | 4 +- .../rapids_pre_commit_hooks/test_copyright.py | 47 ++++++++++++------- 3 files changed, 69 insertions(+), 27 deletions(-) diff --git a/src/rapids_pre_commit_hooks/copyright.py b/src/rapids_pre_commit_hooks/copyright.py index ab52a59..eb171be 100644 --- a/src/rapids_pre_commit_hooks/copyright.py +++ b/src/rapids_pre_commit_hooks/copyright.py @@ -30,6 +30,10 @@ COPYRIGHT_REPLACEMENT = "Copyright (c) {first_year}-{last_year}, NVIDIA CORPORATION" +class NoSuchBranchWarning(Warning): + pass + + class NoTargetBranchWarning(Warning): pass @@ -96,11 +100,13 @@ def apply_copyright_check(linter, old_content): linter.add_warning((0, 0), "no copyright notice found") -def get_target_branch(repo): +def get_target_branch(repo, target_branch_arg=None): """Determine which branch is the "target" branch. 
The target branch is determined in the following order: + * If the ``--target-branch`` argument is passed, and points to a valid branch, that + branch is used. This allows users to set a base branch on the command line. * If either of the ``$TARGET_BRANCH`` or ``$RAPIDS_BASE_BRANCH`` environment variables, in that order, are defined and point to a valid branch, that branch is used. This allows users to locally set a base branch on a one-time basis. @@ -114,6 +120,16 @@ def get_target_branch(repo): expected default. * Otherwise, None is returned and a warning is issued. """ + # Try command line + if target_branch_arg: + try: + return repo.heads[target_branch_arg] + except IndexError: + warnings.warn( + f'--target-branch: branch name "{target_branch_arg}" does not exist.', + NoSuchBranchWarning, + ) + # Try environment if target_branch_name := os.getenv("TARGET_BRANCH"): try: @@ -165,8 +181,8 @@ def get_target_branch(repo): return None -def get_target_branch_upstream_commit(repo): - target_branch = get_target_branch(repo) +def get_target_branch_upstream_commit(repo, target_branch_arg=None): + target_branch = get_target_branch(repo, target_branch_arg) if target_branch is None: try: return repo.head.commit @@ -194,7 +210,7 @@ def try_get_ref(remote): return target_branch.commit -def get_changed_files(): +def get_changed_files(target_branch_arg): try: repo = git.Repo() except git.InvalidGitRepositoryError: @@ -205,7 +221,9 @@ def get_changed_files(): } changed_files = {f: None for f in repo.untracked_files} - target_branch_upstream_commit = get_target_branch_upstream_commit(repo) + target_branch_upstream_commit = get_target_branch_upstream_commit( + repo, target_branch_arg + ) if target_branch_upstream_commit is None: changed_files.update({blob.path: None for _, blob in repo.index.iter_blobs()}) return changed_files @@ -226,8 +244,8 @@ def get_changed_files(): return changed_files -def check_copyright(): - changed_files = get_changed_files() +def check_copyright(args): 
+ changed_files = get_changed_files(args.target_branch) def the_check(linter, args): try: @@ -247,8 +265,19 @@ def the_check(linter, args): def main(): m = LintMain() + m.argparser.description = ( + "Verify that all files have had their copyright notices updated. Each file " + "will be compared against the target branch (determined automatically or with " + "the --target-branch argument) to decide whether or not they need a copyright " + "update." + ) + m.argparser.add_argument( + "--target-branch", + metavar="", + help="target branch to check modified files against", + ) with m.execute() as ctx: - ctx.add_check(check_copyright()) + ctx.add_check(check_copyright(ctx.args)) if __name__ == "__main__": diff --git a/src/rapids_pre_commit_hooks/lint.py b/src/rapids_pre_commit_hooks/lint.py index c688a79..5e46c0f 100644 --- a/src/rapids_pre_commit_hooks/lint.py +++ b/src/rapids_pre_commit_hooks/lint.py @@ -247,7 +247,9 @@ class LintMain: def __init__(self): self.argparser = argparse.ArgumentParser() - self.argparser.add_argument("--fix", action="store_true") + self.argparser.add_argument( + "--fix", action="store_true", help="automatically fix warnings" + ) self.argparser.add_argument("files", nargs="+", metavar="file") def execute(self): diff --git a/test/rapids_pre_commit_hooks/test_copyright.py b/test/rapids_pre_commit_hooks/test_copyright.py index 1a8129d..fba0830 100644 --- a/test/rapids_pre_commit_hooks/test_copyright.py +++ b/test/rapids_pre_commit_hooks/test_copyright.py @@ -240,6 +240,12 @@ def test_get_target_branch(git_repo): }, ): assert copyright.get_target_branch(git_repo) == branch_24_04 + with pytest.warns( + copyright.NoSuchBranchWarning, + match=r'^--target-branch: branch name "nonexistent" does not exist\.$', + ): + assert copyright.get_target_branch(git_repo, "nonexistent") == branch_24_04 + assert copyright.get_target_branch(git_repo, "master") == master def test_get_target_branch_upstream_commit(git_repo): @@ -477,7 +483,7 @@ def mock_os_walk(top): 
os.mkdir(os.path.join(non_git_dir, "subdir1/subdir2")) with open(os.path.join(non_git_dir, "subdir1", "subdir2", "sub.txt"), "w") as f: f.write("Subdir file\n") - assert copyright.get_changed_files() == { + assert copyright.get_changed_files(Mock(target_branch=None)) == { "top.txt": None, "subdir1/subdir2/sub.txt": None, } @@ -522,7 +528,7 @@ def file_contents(verbed): "rapids_pre_commit_hooks.copyright.get_target_branch_upstream_commit", Mock(return_value=None), ): - assert copyright.get_changed_files() == { + assert copyright.get_changed_files(Mock(target_branch=None)) == { "untouched.txt": None, "copied.txt": None, "modified_and_copied.txt": None, @@ -616,7 +622,7 @@ def file_contents(verbed): "rapids_pre_commit_hooks.copyright.get_target_branch_upstream_commit", Mock(return_value=target_branch.commit), ): - changed_files = copyright.get_changed_files() + changed_files = copyright.get_changed_files(Mock(target_branch=None)) assert { path: old_blob.path if old_blob else None for path, old_blob in changed_files.items() @@ -698,9 +704,10 @@ def file_contents_modified(num): def mock_repo_cwd(): return patch("os.getcwd", Mock(return_value=git_repo.working_tree_dir)) - def mock_target_branch_upstream_commit(branch_name): - def func(repo): - return repo.heads[branch_name].commit + def mock_target_branch_upstream_commit(target_branch): + def func(repo, target_branch_arg): + assert target_branch == target_branch_arg + return repo.heads[target_branch].commit return patch( "rapids_pre_commit_hooks.copyright.get_target_branch_upstream_commit", func @@ -713,62 +720,66 @@ def mock_apply_copyright_check(): # branch-1 is target branch ############################# + mock_args = Mock(target_branch="branch-1") + with mock_repo_cwd(), mock_target_branch_upstream_commit("branch-1"): - copyright_checker = copyright.check_copyright() + copyright_checker = copyright.check_copyright(mock_args) linter = Linter("file1.txt", file_contents_modified(1)) with mock_apply_copyright_check() as 
apply_copyright_check: - copyright_checker(linter, None) + copyright_checker(linter, mock_args) apply_copyright_check.assert_not_called() linter = Linter("file5.txt", file_contents(2)) with mock_apply_copyright_check() as apply_copyright_check: - copyright_checker(linter, None) + copyright_checker(linter, mock_args) apply_copyright_check.assert_called_once_with(linter, file_contents(2)) linter = Linter("file3.txt", file_contents_modified(3)) with mock_apply_copyright_check() as apply_copyright_check: - copyright_checker(linter, None) + copyright_checker(linter, mock_args) apply_copyright_check.assert_called_once_with(linter, file_contents(3)) linter = Linter("file4.txt", file_contents_modified(4)) with mock_apply_copyright_check() as apply_copyright_check: - copyright_checker(linter, None) + copyright_checker(linter, mock_args) apply_copyright_check.assert_called_once_with(linter, file_contents(4)) linter = Linter("file6.txt", file_contents(6)) with mock_apply_copyright_check() as apply_copyright_check: - copyright_checker(linter, None) + copyright_checker(linter, mock_args) apply_copyright_check.assert_called_once_with(linter, None) ############################# # branch-2 is target branch ############################# + mock_args = Mock(target_branch="branch-2") + with mock_repo_cwd(), mock_target_branch_upstream_commit("branch-2"): - copyright_checker = copyright.check_copyright() + copyright_checker = copyright.check_copyright(mock_args) linter = Linter("file1.txt", file_contents_modified(1)) with mock_apply_copyright_check() as apply_copyright_check: - copyright_checker(linter, None) + copyright_checker(linter, mock_args) apply_copyright_check.assert_called_once_with(linter, file_contents(1)) linter = Linter("file5.txt", file_contents(2)) with mock_apply_copyright_check() as apply_copyright_check: - copyright_checker(linter, None) + copyright_checker(linter, mock_args) apply_copyright_check.assert_called_once_with(linter, file_contents(2)) linter = 
Linter("file3.txt", file_contents_modified(3)) with mock_apply_copyright_check() as apply_copyright_check: - copyright_checker(linter, None) + copyright_checker(linter, mock_args) apply_copyright_check.assert_called_once_with(linter, file_contents(3)) linter = Linter("file4.txt", file_contents_modified(4)) with mock_apply_copyright_check() as apply_copyright_check: - copyright_checker(linter, None) + copyright_checker(linter, mock_args) apply_copyright_check.assert_called_once_with(linter, file_contents(4)) linter = Linter("file6.txt", file_contents(6)) with mock_apply_copyright_check() as apply_copyright_check: - copyright_checker(linter, None) + copyright_checker(linter, mock_args) apply_copyright_check.assert_called_once_with(linter, None) From f901713339b23589b47bcb7e5c1bdb1006bc20f3 Mon Sep 17 00:00:00 2001 From: Kyle Edwards Date: Tue, 23 Jan 2024 10:18:31 -0500 Subject: [PATCH 12/24] Use max() instead of sorted() --- src/rapids_pre_commit_hooks/copyright.py | 20 +++++++++----------- 1 file changed, 9 insertions(+), 11 deletions(-) diff --git a/src/rapids_pre_commit_hooks/copyright.py b/src/rapids_pre_commit_hooks/copyright.py index eb171be..ba84b72 100644 --- a/src/rapids_pre_commit_hooks/copyright.py +++ b/src/rapids_pre_commit_hooks/copyright.py @@ -157,18 +157,16 @@ def get_target_branch(repo, target_branch_arg=None): pass # Try newest branch-xx.yy - branches = sorted( - ( - (branch, (match.group("major"), match.group("minor"))) - for branch in repo.heads - if (match := BRANCH_RE.search(branch.name)) - ), - key=lambda i: i[1], - reverse=True, - ) try: - return branches[0][0] - except IndexError: + return max( + ( + (branch, (match.group("major"), match.group("minor"))) + for branch in repo.heads + if (match := BRANCH_RE.search(branch.name)) + ), + key=lambda i: i[1], + )[0] + except ValueError: pass # Appropriate branch not found From e4be581ed9de1b100194722f99b385f5f2dceaa3 Mon Sep 17 00:00:00 2001 From: Kyle Edwards Date: Tue, 23 Jan 2024 10:33:29 -0500 
Subject: [PATCH 13/24] Factor out helper methods from apply_copyright_check() --- src/rapids_pre_commit_hooks/copyright.py | 55 +++++++++++++----------- 1 file changed, 29 insertions(+), 26 deletions(-) diff --git a/src/rapids_pre_commit_hooks/copyright.py b/src/rapids_pre_commit_hooks/copyright.py index ba84b72..135e6f5 100644 --- a/src/rapids_pre_commit_hooks/copyright.py +++ b/src/rapids_pre_commit_hooks/copyright.py @@ -58,6 +58,33 @@ def append_stripped(start, item): return lines +def apply_copyright_revert(linter, old_copyright_matches, new_copyright_matches): + for old_match, new_match in zip(old_copyright_matches, new_copyright_matches): + if old_match.group() != new_match.group(): + if old_match.group("years") == new_match.group("years"): + warning_pos = new_match.span() + else: + warning_pos = new_match.span("years") + linter.add_warning( + warning_pos, + "copyright is not out of date and should not be updated", + ).add_replacement(new_match.span(), old_match.group()) + + +def apply_copyright_update(linter, copyright_matches, year): + for match in copyright_matches: + if int(match.group("last_year") or match.group("first_year")) < year: + linter.add_warning( + match.span("years"), "copyright is out of date" + ).add_replacement( + match.span(), + COPYRIGHT_REPLACEMENT.format( + first_year=match.group("first_year"), + last_year=year, + ), + ) + + def apply_copyright_check(linter, old_content): if linter.content != old_content: current_year = datetime.datetime.now().year @@ -69,33 +96,9 @@ def apply_copyright_check(linter, old_content): if old_content is not None and strip_copyright( old_content, old_copyright_matches ) == strip_copyright(linter.content, new_copyright_matches): - for old_match, new_match in zip( - old_copyright_matches, new_copyright_matches - ): - if old_match.group() != new_match.group(): - if old_match.group("years") == new_match.group("years"): - warning_pos = new_match.span() - else: - warning_pos = new_match.span("years") - 
linter.add_warning( - warning_pos, - "copyright is not out of date and should not be updated", - ).add_replacement(new_match.span(), old_match.group()) + apply_copyright_revert(linter, old_copyright_matches, new_copyright_matches) elif new_copyright_matches: - for match in new_copyright_matches: - if ( - int(match.group("last_year") or match.group("first_year")) - < current_year - ): - linter.add_warning( - match.span("years"), "copyright is out of date" - ).add_replacement( - match.span(), - COPYRIGHT_REPLACEMENT.format( - first_year=match.group("first_year"), - last_year=current_year, - ), - ) + apply_copyright_update(linter, new_copyright_matches, current_year) else: linter.add_warning((0, 0), "no copyright notice found") From 07c45b4b55fe9aa672f5b8161af770b320a083ec Mon Sep 17 00:00:00 2001 From: Kyle Edwards Date: Tue, 23 Jan 2024 10:36:11 -0500 Subject: [PATCH 14/24] zip(strict=True) --- src/rapids_pre_commit_hooks/copyright.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/src/rapids_pre_commit_hooks/copyright.py b/src/rapids_pre_commit_hooks/copyright.py index 135e6f5..1735eeb 100644 --- a/src/rapids_pre_commit_hooks/copyright.py +++ b/src/rapids_pre_commit_hooks/copyright.py @@ -59,7 +59,9 @@ def append_stripped(start, item): def apply_copyright_revert(linter, old_copyright_matches, new_copyright_matches): - for old_match, new_match in zip(old_copyright_matches, new_copyright_matches): + for old_match, new_match in zip( + old_copyright_matches, new_copyright_matches, strict=True + ): if old_match.group() != new_match.group(): if old_match.group("years") == new_match.group("years"): warning_pos = new_match.span() From 2c6061a79a93112ccdff0d6e2409998395c32e74 Mon Sep 17 00:00:00 2001 From: Kyle Edwards Date: Tue, 23 Jan 2024 12:46:01 -0500 Subject: [PATCH 15/24] More refactoring --- src/rapids_pre_commit_hooks/copyright.py | 59 ++++++++++++------------ 1 file changed, 30 insertions(+), 29 deletions(-) diff --git 
a/src/rapids_pre_commit_hooks/copyright.py b/src/rapids_pre_commit_hooks/copyright.py index 1735eeb..7655620 100644 --- a/src/rapids_pre_commit_hooks/copyright.py +++ b/src/rapids_pre_commit_hooks/copyright.py @@ -58,33 +58,25 @@ def append_stripped(start, item): return lines -def apply_copyright_revert(linter, old_copyright_matches, new_copyright_matches): - for old_match, new_match in zip( - old_copyright_matches, new_copyright_matches, strict=True - ): - if old_match.group() != new_match.group(): - if old_match.group("years") == new_match.group("years"): - warning_pos = new_match.span() - else: - warning_pos = new_match.span("years") - linter.add_warning( - warning_pos, - "copyright is not out of date and should not be updated", - ).add_replacement(new_match.span(), old_match.group()) - - -def apply_copyright_update(linter, copyright_matches, year): - for match in copyright_matches: - if int(match.group("last_year") or match.group("first_year")) < year: - linter.add_warning( - match.span("years"), "copyright is out of date" - ).add_replacement( - match.span(), - COPYRIGHT_REPLACEMENT.format( - first_year=match.group("first_year"), - last_year=year, - ), - ) +def apply_copyright_revert(linter, old_match, new_match): + if old_match.group("years") == new_match.group("years"): + warning_pos = new_match.span() + else: + warning_pos = new_match.span("years") + linter.add_warning( + warning_pos, + "copyright is not out of date and should not be updated", + ).add_replacement(new_match.span(), old_match.group()) + + +def apply_copyright_update(linter, match, year): + linter.add_warning(match.span("years"), "copyright is out of date").add_replacement( + match.span(), + COPYRIGHT_REPLACEMENT.format( + first_year=match.group("first_year"), + last_year=year, + ), + ) def apply_copyright_check(linter, old_content): @@ -98,9 +90,18 @@ def apply_copyright_check(linter, old_content): if old_content is not None and strip_copyright( old_content, old_copyright_matches ) == 
strip_copyright(linter.content, new_copyright_matches): - apply_copyright_revert(linter, old_copyright_matches, new_copyright_matches) + for old_match, new_match in zip( + old_copyright_matches, new_copyright_matches, strict=True + ): + if old_match.group() != new_match.group(): + apply_copyright_revert(linter, old_match, new_match) elif new_copyright_matches: - apply_copyright_update(linter, new_copyright_matches, current_year) + for match in new_copyright_matches: + if ( + int(match.group("last_year") or match.group("first_year")) + < current_year + ): + apply_copyright_update(linter, match, current_year) else: linter.add_warning((0, 0), "no copyright notice found") From 2ea13fa6e67a1107c1dd26ac4d4e03d7d84e17c0 Mon Sep 17 00:00:00 2001 From: Kyle Edwards Date: Tue, 23 Jan 2024 14:10:35 -0500 Subject: [PATCH 16/24] Make warnings RuntimeWarnings --- src/rapids_pre_commit_hooks/copyright.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/rapids_pre_commit_hooks/copyright.py b/src/rapids_pre_commit_hooks/copyright.py index 7655620..28fff6d 100644 --- a/src/rapids_pre_commit_hooks/copyright.py +++ b/src/rapids_pre_commit_hooks/copyright.py @@ -30,11 +30,11 @@ COPYRIGHT_REPLACEMENT = "Copyright (c) {first_year}-{last_year}, NVIDIA CORPORATION" -class NoSuchBranchWarning(Warning): +class NoSuchBranchWarning(RuntimeWarning): pass -class NoTargetBranchWarning(Warning): +class NoTargetBranchWarning(RuntimeWarning): pass From dc4fb6c90ebbbc498a4427cec6d867031fda5215 Mon Sep 17 00:00:00 2001 From: Kyle Edwards Date: Tue, 23 Jan 2024 12:43:08 -0500 Subject: [PATCH 17/24] Add ability to update files to last modified date --- src/rapids_pre_commit_hooks/copyright.py | 140 +++++ .../rapids_pre_commit_hooks/test_copyright.py | 531 +++++++++++++++++- 2 files changed, 659 insertions(+), 12 deletions(-) diff --git a/src/rapids_pre_commit_hooks/copyright.py b/src/rapids_pre_commit_hooks/copyright.py index 28fff6d..91b110d 100644 --- 
a/src/rapids_pre_commit_hooks/copyright.py +++ b/src/rapids_pre_commit_hooks/copyright.py @@ -38,6 +38,10 @@ class NoTargetBranchWarning(RuntimeWarning): pass +class ConflictingFilesWarning(RuntimeWarning): + pass + + class ConflictingFilesError(RuntimeError): pass @@ -248,7 +252,138 @@ def get_changed_files(target_branch_arg): return changed_files +def find_blob(tree, filename): + try: + return next( + blob + for blob in tree.traverse() + if blob.type == "blob" and blob.path == filename + ) + except StopIteration: + return None + + +def get_file_last_modified(commit, filename): + blob = find_blob(commit.tree, filename) + if not blob: + return (None, None) + + queue = [(commit, blob)] + last_modified = None + checked = set() + + while queue: + commit, blob = queue.pop(0) + if (commit.hexsha, blob.path) in checked: + continue + checked.add((commit.hexsha, blob.path)) + all_modified = True + + for parent_commit in commit.parents: + + def compare_files(old_blob): + nonlocal all_modified + + if old_blob.hexsha == blob.hexsha: + # Same file contents + all_modified = False + queue.append((parent_commit, old_blob)) + else: + # Different file contents, but non-copyright-header content might be + # the same + old_content, new_content = ( + old_blob.data_stream.read().decode(), + blob.data_stream.read().decode(), + ) + old_copyright_matches, new_copyright_matches = match_copyright( + old_content + ), match_copyright(new_content) + + if strip_copyright( + old_content, old_copyright_matches + ) == strip_copyright(new_content, new_copyright_matches): + all_modified = False + queue.append((parent_commit, old_blob)) + + if parent_blob := find_blob(parent_commit.tree, blob.path): + compare_files(parent_blob) + else: + diffs = parent_commit.diff( + other=commit, + find_copies=True, + find_copies_harder=True, + find_renames=True, + ) + diff = next(diff for diff in diffs if diff.b_path == blob.path) + if diff.change_type != "A": + compare_files(diff.a_blob) + + if all_modified: + if 
( + not last_modified + or commit.committed_datetime > last_modified[0].committed_datetime + ): + last_modified = (commit, blob) + + assert last_modified + return last_modified + + +def apply_batch_copyright_check(repo, linter): + current_blob = find_blob(repo.head.commit.tree, linter.filename) + if not current_blob: + warnings.warn( + f'File "{linter.filename}" not in Git history. Not running batch copyright ' + "update.", + ConflictingFilesWarning, + ) + return + if current_blob.data_stream.read().decode() != linter.content: + warnings.warn( + f'File "{linter.filename}" differs from Git history. Not running batch ' + "copyright update.", + ConflictingFilesWarning, + ) + return + + commit, old_blob = get_file_last_modified(repo.head.commit, linter.filename) + year = commit.committed_datetime.year + old_content = old_blob.data_stream.read().decode() + + old_copyright_matches, new_copyright_matches = match_copyright( + old_content + ), match_copyright(linter.content) + assert strip_copyright(old_content, old_copyright_matches) == strip_copyright( + linter.content, new_copyright_matches + ) + if new_copyright_matches: + for old_match, new_match in zip( + old_copyright_matches, new_copyright_matches, strict=True + ): + if ( + int(new_match.group("last_year") or new_match.group("first_year")) + < year + ): + apply_copyright_update(linter, new_match, year) + elif ( + old_match.group() != new_match.group() + and int(old_match.group("last_year") or old_match.group("first_year")) + >= year + ): + apply_copyright_revert(linter, old_match, new_match) + else: + linter.add_warning((0, 0), "no copyright notice found") + + def check_copyright(args): + if args.batch: + repo = git.Repo() + + def the_check(linter, args): + apply_batch_copyright_check(repo, linter) + + return the_check + changed_files = get_changed_files(args.target_branch) def the_check(linter, args): @@ -280,6 +415,11 @@ def main(): metavar="", help="target branch to check modified files against", ) + 
m.argparser.add_argument( + "--batch", + action="store_true", + help="batch update files based on last modification commit", + ) with m.execute() as ctx: ctx.add_check(check_copyright(ctx.args)) diff --git a/test/rapids_pre_commit_hooks/test_copyright.py b/test/rapids_pre_commit_hooks/test_copyright.py index fba0830..0902d27 100644 --- a/test/rapids_pre_commit_hooks/test_copyright.py +++ b/test/rapids_pre_commit_hooks/test_copyright.py @@ -12,6 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. +import contextlib import datetime import os.path import tempfile @@ -645,6 +646,459 @@ def file_contents(verbed): assert changed_files[new].data_stream.read() == old_contents +def test_find_blob(git_repo): + with open(os.path.join(git_repo.working_tree_dir, "top.txt"), "w"): + pass + os.mkdir(os.path.join(git_repo.working_tree_dir, "sub1")) + os.mkdir(os.path.join(git_repo.working_tree_dir, "sub1", "sub2")) + with open(os.path.join(git_repo.working_tree_dir, "sub1", "sub2", "sub.txt"), "w"): + pass + git_repo.index.add(["top.txt", "sub1/sub2/sub.txt"]) + git_repo.index.commit("Initial commit") + + assert copyright.find_blob(git_repo.head.commit.tree, "top.txt").path == "top.txt" + assert ( + copyright.find_blob(git_repo.head.commit.tree, "sub1/sub2/sub.txt").path + == "sub1/sub2/sub.txt" + ) + assert copyright.find_blob(git_repo.head.commit.tree, "nonexistent.txt") is None + + +def test_get_file_last_modified(git_repo): + def fn(filename): + return os.path.join(git_repo.working_tree_dir, filename) + + def write_file(filename, contents): + with open(fn(filename), "w") as f: + f.write(contents) + + def expected_return_value(commit, filename): + return (commit, copyright.find_blob(commit.tree, filename)) + + @contextlib.contextmanager + def no_match_copyright(): + with patch( + "rapids_pre_commit_hooks.copyright.match_copyright", Mock() + ) as match_copyright, patch( + 
"rapids_pre_commit_hooks.copyright.strip_copyright", Mock() + ) as strip_copyright: + yield + match_copyright.assert_not_called() + strip_copyright.assert_not_called() + + write_file("file1.txt", "File 1") + git_repo.index.add("file1.txt") + git_repo.index.commit("Initial commit") + with no_match_copyright(): + assert copyright.get_file_last_modified( + git_repo.head.commit, "file1.txt" + ) == expected_return_value(git_repo.head.commit, "file1.txt") + + write_file("file2.txt", "File 2") + git_repo.index.add("file2.txt") + git_repo.index.commit("Add file2.txt") + with no_match_copyright(): + assert copyright.get_file_last_modified( + git_repo.head.commit, "file1.txt" + ) == expected_return_value(git_repo.head.commit.parents[0], "file1.txt") + assert copyright.get_file_last_modified( + git_repo.head.commit.parents[0], "file1.txt" + ) == expected_return_value(git_repo.head.commit.parents[0], "file1.txt") + assert copyright.get_file_last_modified( + git_repo.head.commit, "file2.txt" + ) == expected_return_value(git_repo.head.commit, "file2.txt") + assert copyright.get_file_last_modified( + git_repo.head.commit.parents[0], "file2.txt" + ) == (None, None) + + git_repo.index.remove("file1.txt", working_tree=True) + write_file("file1_2.txt", "File 1") + write_file("file2_2.txt", "File 2") + git_repo.index.add(["file1_2.txt", "file2_2.txt"]) + git_repo.index.commit("Rename and copy") + with no_match_copyright(): + assert copyright.get_file_last_modified( + git_repo.head.commit, "file1_2.txt" + ) == expected_return_value( + git_repo.head.commit.parents[0].parents[0], "file1.txt" + ) + assert copyright.get_file_last_modified( + git_repo.head.commit, "file2_2.txt" + ) == expected_return_value(git_repo.head.commit.parents[0], "file2.txt") + + write_file( + "copyright.txt", + """ +Beginning of copyrighted file +Copyright (c) 2023 NVIDIA CORPORATION +End of copyrighted file +""", + ) + git_repo.index.add("copyright.txt") + git_repo.index.commit("Add copyrighted file") + 
write_file( + "copyright.txt", + """ +Beginning of copyrighted file +Copyright (c) 2023-2024 NVIDIA CORPORATION +End of copyrighted file +""", + ) + git_repo.index.add("copyright.txt") + git_repo.index.commit("Update copyright") + assert copyright.get_file_last_modified( + git_repo.head.commit, "copyright.txt" + ) == expected_return_value(git_repo.head.commit.parents[0], "copyright.txt") + + write_file( + "copyright.txt", + """ +Beginning of copyrighted file +Copyright (c) 2023-2024 NVIDIA CORPORATION +New content +End of copyrighted file +""", + ) + git_repo.index.add("copyright.txt") + git_repo.index.commit("New contents") + assert copyright.get_file_last_modified( + git_repo.head.commit, "copyright.txt" + ) == expected_return_value(git_repo.head.commit, "copyright.txt") + + write_file( + "copyright.txt", + """ +Beginning of copyrighted file +Copyright (c) 2023-2024 NVIDIA CORPORATION +Updated content +End of copyrighted file +""", + ) + git_repo.index.add("copyright.txt") + commit_1 = git_repo.index.commit( + "Update contents", + commit_date=datetime.datetime(2024, 1, 23, tzinfo=datetime.timezone.utc), + ) + commit_2 = git_repo.index.commit( + "Update contents", + commit_date=datetime.datetime(2024, 1, 24, tzinfo=datetime.timezone.utc), + parent_commits=commit_1.parents, + ) + git_repo.index.commit("Merge", parent_commits=[commit_1, commit_2]) + assert copyright.get_file_last_modified( + git_repo.head.commit, "copyright.txt" + ) == expected_return_value(commit_2, "copyright.txt") + + write_file( + "copyright.txt", + """ +Beginning of copyrighted file +Copyright (c) 2023-2024 NVIDIA CORPORATION +New updated content +End of copyrighted file +""", + ) + git_repo.index.add("copyright.txt") + commit_1 = git_repo.index.commit( + "Update contents again", + commit_date=datetime.datetime(2024, 1, 24, tzinfo=datetime.timezone.utc), + ) + commit_2 = git_repo.index.commit( + "Update contents again", + commit_date=datetime.datetime(2024, 1, 23, tzinfo=datetime.timezone.utc), 
+ parent_commits=commit_1.parents, + ) + git_repo.index.commit("Merge", parent_commits=[commit_1, commit_2]) + assert copyright.get_file_last_modified( + git_repo.head.commit, "copyright.txt" + ) == expected_return_value(commit_1, "copyright.txt") + + write_file( + "copyright.txt", + """ +Beginning of copyrighted file +Copyright (c) 2023-2024 NVIDIA CORPORATION +Old content +End of copyrighted file +""", + ) + git_repo.index.add("copyright.txt") + git_repo.index.commit( + "Old content", + commit_date=datetime.datetime(2024, 1, 23, tzinfo=datetime.timezone.utc), + ) + old_commit = git_repo.head.commit + write_file( + "copyright.txt", + """ +Beginning of copyrighted file +Copyright (c) 2023-2024 NVIDIA CORPORATION +New content +End of copyrighted file +""", + ) + git_repo.index.add("copyright.txt") + new_commit = git_repo.index.commit( + "New content", + commit_date=datetime.datetime(2024, 1, 24, tzinfo=datetime.timezone.utc), + ) + git_repo.index.commit( + "Merge", + commit_date=datetime.datetime(2024, 1, 25, tzinfo=datetime.timezone.utc), + parent_commits=[git_repo.head.commit, old_commit], + ) + assert copyright.get_file_last_modified( + git_repo.head.commit, "copyright.txt" + ) == expected_return_value(new_commit, "copyright.txt") + + write_file( + "copyright.txt", + """ +Beginning of copyrighted file +Copyright (c) 2023-2024 NVIDIA CORPORATION +Old content +End of copyrighted file +""", + ) + git_repo.index.add("copyright.txt") + git_repo.index.commit( + "Old content", + commit_date=datetime.datetime(2024, 1, 23, tzinfo=datetime.timezone.utc), + ) + old_commit = git_repo.head.commit + write_file( + "copyright.txt", + """ +Beginning of copyrighted file +Copyright (c) 2023 NVIDIA CORPORATION +New content +End of copyrighted file +""", + ) + git_repo.index.add("copyright.txt") + new_commit = git_repo.index.commit( + "New content", + commit_date=datetime.datetime(2024, 1, 25, tzinfo=datetime.timezone.utc), + ) + write_file( + "copyright.txt", + """ +Beginning of 
copyrighted file +Copyright (c) 2023-2024 NVIDIA CORPORATION +New content +End of copyrighted file +""", + ) + git_repo.index.add("copyright.txt") + git_repo.index.commit( + "Merge", + commit_date=datetime.datetime(2024, 1, 24, tzinfo=datetime.timezone.utc), + parent_commits=[git_repo.head.commit, old_commit], + ) + assert copyright.get_file_last_modified( + git_repo.head.commit, "copyright.txt" + ) == expected_return_value(new_commit, "copyright.txt") + + write_file( + "copyright.txt", + """ +Beginning of copyrighted file +Copyright (c) 2023 NVIDIA CORPORATION +Copyrighted content +End of copyrighted file +""", + ) + git_repo.index.add("copyright.txt") + git_repo.index.commit("Add copyrighted content") + write_file( + "copyright2.txt", + """ +Beginning of copyrighted file +Copyright (c) 2023-2024 NVIDIA CORPORATION +Copyrighted content +End of copyrighted file +""", + ) + git_repo.index.add("copyright2.txt") + git_repo.index.commit("Copy copyrighted file") + assert copyright.get_file_last_modified( + git_repo.head.commit, "copyright2.txt" + ) == expected_return_value(git_repo.head.commit.parents[0], "copyright.txt") + + git_repo.index.remove("copyright2.txt", working_tree=True) + write_file( + "copyright.txt", + f""" +Beginning of copyrighted file +Copyright (c) 2023 NVIDIA CORPORATION +{'''Lots of content +'''} * 100 +End of copyrighted file +""", + ) + git_repo.index.add("copyright.txt") + git_repo.index.commit("Add copyrighted content") + write_file( + "copyright2.txt", + f""" +Beginning of copyrighted file +Copyright (c) 2023 NVIDIA CORPORATION +{'''Lots of content +'''} * 100 +More content +End of copyrighted file +""", + ) + git_repo.index.add("copyright2.txt") + git_repo.index.commit("Copy and modify copyrighted file") + assert copyright.get_file_last_modified( + git_repo.head.commit, "copyright2.txt" + ) == expected_return_value(git_repo.head.commit, "copyright2.txt") + + +def test_apply_batch_copyright_check(git_repo): + def fn(filename): + return 
os.path.join(git_repo.working_tree_dir, filename) + + def write_file(filename, content): + with open(fn(filename), "w") as f: + f.write(content) + + CONTENT = """ +Beginning of copyrighted file +Copyright (c) 2023 NVIDIA CORPORATION +End of copyrighted file +""" + write_file("file.txt", CONTENT) + git_repo.index.add("file.txt") + git_repo.index.commit( + "Initial commit", + commit_date=datetime.datetime(2023, 2, 1, tzinfo=datetime.timezone.utc), + ) + + linter = Linter("file.txt", CONTENT) + copyright.apply_batch_copyright_check(git_repo, linter) + assert linter.warnings == [] + + linter = Linter("file.txt", CONTENT + "Oops") + with pytest.warns( + copyright.ConflictingFilesWarning, + match=r'^File "file[.]txt" differs from Git history. Not running batch ' + r"copyright update[.]$", + ): + copyright.apply_batch_copyright_check(git_repo, linter) + assert linter.warnings == [] + + linter = Linter("file2.txt", CONTENT + "Oops") + with pytest.warns( + copyright.ConflictingFilesWarning, + match=r'^File "file2[.]txt" not in Git history. 
Not running batch copyright ' + r"update[.]$", + ): + copyright.apply_batch_copyright_check(git_repo, linter) + assert linter.warnings == [] + + CONTENT = """ +Beginning of copyrighted file +Copyright (c) 2023 NVIDIA CORPORATION +New content +End of copyrighted file +""" + write_file("file.txt", CONTENT) + git_repo.index.add("file.txt") + git_repo.index.commit( + "Add content", + commit_date=datetime.datetime(2024, 2, 1, tzinfo=datetime.timezone.utc), + ) + + expected_linter = Linter("file.txt", CONTENT) + expected_linter.add_warning((45, 49), "copyright is out of date").add_replacement( + (31, 68), "Copyright (c) 2023-2024, NVIDIA CORPORATION" + ) + + linter = Linter("file.txt", CONTENT) + copyright.apply_batch_copyright_check(git_repo, linter) + assert linter.warnings == expected_linter.warnings + + CONTENT = """ +Beginning of copyrighted file +Copyright (c) 2023-2024 NVIDIA CORPORATION +New content +End of copyrighted file +""" + write_file("file.txt", CONTENT) + git_repo.index.add("file.txt") + git_repo.index.commit( + "Add content", + commit_date=datetime.datetime(2024, 2, 2, tzinfo=datetime.timezone.utc), + ) + + linter = Linter("file.txt", CONTENT) + copyright.apply_batch_copyright_check(git_repo, linter) + assert linter.warnings == [] + + CONTENT = """ +Beginning of copyrighted file +Copyright (c) 2023-2024 NVIDIA CORPORATION +Newer content +End of copyrighted file +""" + write_file("file.txt", CONTENT) + git_repo.index.add("file.txt") + git_repo.index.commit( + "Update copyright and content", + commit_date=datetime.datetime(2024, 2, 3, tzinfo=datetime.timezone.utc), + ) + + linter = Linter("file.txt", CONTENT) + copyright.apply_batch_copyright_check(git_repo, linter) + assert linter.warnings == [] + + CONTENT = """ +Beginning of copyrighted file +Copyright (c) 2023-2025 NVIDIA CORPORATION +Newer content +End of copyrighted file +""" + write_file("file.txt", CONTENT) + git_repo.index.add("file.txt") + git_repo.index.commit( + "Update copyright again", + 
commit_date=datetime.datetime(2025, 2, 1, tzinfo=datetime.timezone.utc), + ) + + expected_linter = Linter("file.txt", CONTENT) + expected_linter.add_warning( + (45, 54), "copyright is not out of date and should not be updated" + ).add_replacement((31, 73), "Copyright (c) 2023-2024 NVIDIA CORPORATION") + + linter = Linter("file.txt", CONTENT) + copyright.apply_batch_copyright_check(git_repo, linter) + assert linter.warnings == expected_linter.warnings + + CONTENT = """ +Beginning of copyrighted file +Copyright (c) 2023-2025 NVIDIA CORPORATION +Even newer content +End of copyrighted file +""" + write_file("file.txt", CONTENT) + git_repo.index.add("file.txt") + git_repo.index.commit( + "Update copyright again", + commit_date=datetime.datetime(2026, 2, 1, tzinfo=datetime.timezone.utc), + ) + + expected_linter = Linter("file.txt", CONTENT) + expected_linter.add_warning((45, 54), "copyright is out of date").add_replacement( + (31, 73), "Copyright (c) 2023-2026, NVIDIA CORPORATION" + ) + + linter = Linter("file.txt", CONTENT) + copyright.apply_batch_copyright_check(git_repo, linter) + assert linter.warnings == expected_linter.warnings + + @freeze_time("2024-01-18") def test_check_copyright(git_repo): def fn(filename): @@ -716,37 +1170,60 @@ def func(repo, target_branch_arg): def mock_apply_copyright_check(): return patch("rapids_pre_commit_hooks.copyright.apply_copyright_check", Mock()) + @contextlib.contextmanager + def no_apply_batch_copyright_check(): + with patch( + "rapids_pre_commit_hooks.copyright.apply_batch_copyright_check", Mock() + ) as apply_batch_copyright_check: + yield + apply_batch_copyright_check.assert_not_called() + ############################# # branch-1 is target branch ############################# - mock_args = Mock(target_branch="branch-1") + mock_args = Mock(target_branch="branch-1", batch=False) with mock_repo_cwd(), mock_target_branch_upstream_commit("branch-1"): copyright_checker = copyright.check_copyright(mock_args) linter = 
Linter("file1.txt", file_contents_modified(1)) - with mock_apply_copyright_check() as apply_copyright_check: + # fmt: off + with mock_apply_copyright_check() as apply_copyright_check, \ + no_apply_batch_copyright_check(): + # fmt: on copyright_checker(linter, mock_args) apply_copyright_check.assert_not_called() linter = Linter("file5.txt", file_contents(2)) - with mock_apply_copyright_check() as apply_copyright_check: + # fmt: off + with mock_apply_copyright_check() as apply_copyright_check, \ + no_apply_batch_copyright_check(): + # fmt: on copyright_checker(linter, mock_args) apply_copyright_check.assert_called_once_with(linter, file_contents(2)) linter = Linter("file3.txt", file_contents_modified(3)) - with mock_apply_copyright_check() as apply_copyright_check: + # fmt: off + with mock_apply_copyright_check() as apply_copyright_check, \ + no_apply_batch_copyright_check(): + # fmt: on copyright_checker(linter, mock_args) apply_copyright_check.assert_called_once_with(linter, file_contents(3)) linter = Linter("file4.txt", file_contents_modified(4)) - with mock_apply_copyright_check() as apply_copyright_check: + # fmt: off + with mock_apply_copyright_check() as apply_copyright_check, \ + no_apply_batch_copyright_check(): + # fmt: on copyright_checker(linter, mock_args) apply_copyright_check.assert_called_once_with(linter, file_contents(4)) linter = Linter("file6.txt", file_contents(6)) - with mock_apply_copyright_check() as apply_copyright_check: + # fmt: off + with mock_apply_copyright_check() as apply_copyright_check, \ + no_apply_batch_copyright_check(): + # fmt: on copyright_checker(linter, mock_args) apply_copyright_check.assert_called_once_with(linter, None) @@ -754,32 +1231,62 @@ def mock_apply_copyright_check(): # branch-2 is target branch ############################# - mock_args = Mock(target_branch="branch-2") + mock_args = Mock(target_branch="branch-2", batch=False) with mock_repo_cwd(), mock_target_branch_upstream_commit("branch-2"): copyright_checker = 
copyright.check_copyright(mock_args) linter = Linter("file1.txt", file_contents_modified(1)) - with mock_apply_copyright_check() as apply_copyright_check: + # fmt: off + with mock_apply_copyright_check() as apply_copyright_check, \ + no_apply_batch_copyright_check(): + # fmt: on copyright_checker(linter, mock_args) apply_copyright_check.assert_called_once_with(linter, file_contents(1)) linter = Linter("file5.txt", file_contents(2)) - with mock_apply_copyright_check() as apply_copyright_check: + # fmt: off + with mock_apply_copyright_check() as apply_copyright_check, \ + no_apply_batch_copyright_check(): + # fmt: on copyright_checker(linter, mock_args) apply_copyright_check.assert_called_once_with(linter, file_contents(2)) linter = Linter("file3.txt", file_contents_modified(3)) - with mock_apply_copyright_check() as apply_copyright_check: + # fmt: off + with mock_apply_copyright_check() as apply_copyright_check, \ + no_apply_batch_copyright_check(): + # fmt: on copyright_checker(linter, mock_args) apply_copyright_check.assert_called_once_with(linter, file_contents(3)) linter = Linter("file4.txt", file_contents_modified(4)) - with mock_apply_copyright_check() as apply_copyright_check: + # fmt: off + with mock_apply_copyright_check() as apply_copyright_check, \ + no_apply_batch_copyright_check(): + # fmt: on copyright_checker(linter, mock_args) apply_copyright_check.assert_called_once_with(linter, file_contents(4)) linter = Linter("file6.txt", file_contents(6)) - with mock_apply_copyright_check() as apply_copyright_check: + # fmt: off + with mock_apply_copyright_check() as apply_copyright_check, \ + no_apply_batch_copyright_check(): + # fmt: on copyright_checker(linter, mock_args) apply_copyright_check.assert_called_once_with(linter, None) + + +def test_check_copyright_batch(): + git_repo = Mock() + with patch("git.Repo", Mock(return_value=git_repo)), patch( + "rapids_pre_commit_hooks.copyright.apply_copyright_check", Mock() + ) as apply_copyright_check, patch( + 
"rapids_pre_commit_hooks.copyright.apply_batch_copyright_check", Mock() + ) as apply_batch_copyright_check: + mock_args = Mock(batch=True) + copyright_checker = copyright.check_copyright(mock_args) + linter = Mock() + copyright_checker(linter, mock_args) + apply_batch_copyright_check.assert_called_once_with(git_repo, linter) + apply_copyright_check.assert_not_called() From b42bb4f4dcdbce064d72916db437044b94057438 Mon Sep 17 00:00:00 2001 From: Kyle Edwards Date: Tue, 23 Jan 2024 15:11:47 -0500 Subject: [PATCH 18/24] Fix regexes --- test/rapids_pre_commit_hooks/test_copyright.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/test/rapids_pre_commit_hooks/test_copyright.py b/test/rapids_pre_commit_hooks/test_copyright.py index 0902d27..f897793 100644 --- a/test/rapids_pre_commit_hooks/test_copyright.py +++ b/test/rapids_pre_commit_hooks/test_copyright.py @@ -983,7 +983,7 @@ def write_file(filename, content): linter = Linter("file.txt", CONTENT + "Oops") with pytest.warns( copyright.ConflictingFilesWarning, - match=r'^File "file[.]txt" differs from Git history. Not running batch ' + match=r'^File "file[.]txt" differs from Git history[.] Not running batch ' r"copyright update[.]$", ): copyright.apply_batch_copyright_check(git_repo, linter) @@ -992,7 +992,7 @@ def write_file(filename, content): linter = Linter("file2.txt", CONTENT + "Oops") with pytest.warns( copyright.ConflictingFilesWarning, - match=r'^File "file2[.]txt" not in Git history. Not running batch copyright ' + match=r'^File "file2[.]txt" not in Git history[.] 
Not running batch copyright ' r"update[.]$", ): copyright.apply_batch_copyright_check(git_repo, linter) From 5770735941e5ca1adda7cecfc3864fa88ac05f22 Mon Sep 17 00:00:00 2001 From: Kyle Edwards Date: Tue, 23 Jan 2024 15:55:38 -0500 Subject: [PATCH 19/24] Drastically speed up find_blob() --- src/rapids_pre_commit_hooks/copyright.py | 19 ++++++++++++++----- 1 file changed, 14 insertions(+), 5 deletions(-) diff --git a/src/rapids_pre_commit_hooks/copyright.py b/src/rapids_pre_commit_hooks/copyright.py index 91b110d..b9f0b0e 100644 --- a/src/rapids_pre_commit_hooks/copyright.py +++ b/src/rapids_pre_commit_hooks/copyright.py @@ -253,12 +253,21 @@ def get_changed_files(target_branch_arg): def find_blob(tree, filename): + d1, d2 = os.path.split(filename) + split = [d2] + while d1: + d1, d2 = os.path.split(d1) + split.insert(0, d2) + + while len(split) > 1: + component = split.pop(0) + try: + tree = next(t for t in tree.trees if t.name == component) + except StopIteration: + return None + try: - return next( - blob - for blob in tree.traverse() - if blob.type == "blob" and blob.path == filename - ) + return next(blob for blob in tree.blobs if blob.name == split[0]) except StopIteration: return None From 2a727696e9b9470b8d2ac1e7373259bafc3705e4 Mon Sep 17 00:00:00 2001 From: Kyle Edwards Date: Wed, 24 Jan 2024 09:39:12 -0500 Subject: [PATCH 20/24] Normalize Git paths --- src/rapids_pre_commit_hooks/copyright.py | 29 +++++++++- .../rapids_pre_commit_hooks/test_copyright.py | 58 +++++++++++++++++++ 2 files changed, 84 insertions(+), 3 deletions(-) diff --git a/src/rapids_pre_commit_hooks/copyright.py b/src/rapids_pre_commit_hooks/copyright.py index b9f0b0e..76ed940 100644 --- a/src/rapids_pre_commit_hooks/copyright.py +++ b/src/rapids_pre_commit_hooks/copyright.py @@ -252,6 +252,13 @@ def get_changed_files(target_branch_arg): return changed_files +def normalize_git_filename(filename): + relpath = os.path.relpath(filename) + if re.search(r"^\.\.(/|$)", relpath): + return None 
+ return relpath + + def find_blob(tree, filename): d1, d2 = os.path.split(filename) split = [d2] @@ -339,7 +346,15 @@ def compare_files(old_blob): def apply_batch_copyright_check(repo, linter): - current_blob = find_blob(repo.head.commit.tree, linter.filename) + if not (git_filename := normalize_git_filename(linter.filename)): + warnings.warn( + f'File "{linter.filename}" is outside of current directory. Not running ' + "linter on it.", + ConflictingFilesWarning, + ) + return + + current_blob = find_blob(repo.head.commit.tree, git_filename) if not current_blob: warnings.warn( f'File "{linter.filename}" not in Git history. Not running batch copyright ' @@ -355,7 +370,7 @@ def apply_batch_copyright_check(repo, linter): ) return - commit, old_blob = get_file_last_modified(repo.head.commit, linter.filename) + commit, old_blob = get_file_last_modified(repo.head.commit, git_filename) year = commit.committed_datetime.year old_content = old_blob.data_stream.read().decode() @@ -396,8 +411,16 @@ def the_check(linter, args): changed_files = get_changed_files(args.target_branch) def the_check(linter, args): + if not (git_filename := normalize_git_filename(linter.filename)): + warnings.warn( + f'File "{linter.filename}" is outside of current directory. 
Not ' + "running linter on it.", + ConflictingFilesWarning, + ) + return + try: - changed_file = changed_files[linter.filename] + changed_file = changed_files[git_filename] except KeyError: return diff --git a/test/rapids_pre_commit_hooks/test_copyright.py b/test/rapids_pre_commit_hooks/test_copyright.py index f897793..6a3db43 100644 --- a/test/rapids_pre_commit_hooks/test_copyright.py +++ b/test/rapids_pre_commit_hooks/test_copyright.py @@ -646,6 +646,25 @@ def file_contents(verbed): assert changed_files[new].data_stream.read() == old_contents +def test_normalize_git_filename(): + assert copyright.normalize_git_filename("file.txt") == "file.txt" + assert copyright.normalize_git_filename("sub/file.txt") == "sub/file.txt" + assert copyright.normalize_git_filename("sub//file.txt") == "sub/file.txt" + assert copyright.normalize_git_filename("sub/../file.txt") == "file.txt" + assert copyright.normalize_git_filename("./file.txt") == "file.txt" + assert copyright.normalize_git_filename("../file.txt") is None + assert ( + copyright.normalize_git_filename(os.path.join(os.getcwd(), "file.txt")) + == "file.txt" + ) + assert ( + copyright.normalize_git_filename( + os.path.join("..", os.path.basename(os.getcwd()), "file.txt") + ) + == "file.txt" + ) + + def test_find_blob(git_repo): with open(os.path.join(git_repo.working_tree_dir, "top.txt"), "w"): pass @@ -662,6 +681,7 @@ def test_find_blob(git_repo): == "sub1/sub2/sub.txt" ) assert copyright.find_blob(git_repo.head.commit.tree, "nonexistent.txt") is None + assert copyright.find_blob(git_repo.head.commit.tree, "nonexistent/sub.txt") is None def test_get_file_last_modified(git_repo): @@ -1020,6 +1040,24 @@ def write_file(filename, content): copyright.apply_batch_copyright_check(git_repo, linter) assert linter.warnings == expected_linter.warnings + expected_linter = Linter("./file.txt", CONTENT) + expected_linter.add_warning((45, 49), "copyright is out of date").add_replacement( + (31, 68), "Copyright (c) 2023-2024, NVIDIA 
CORPORATION" + ) + + linter = Linter("./file.txt", CONTENT) + copyright.apply_batch_copyright_check(git_repo, linter) + assert linter.warnings == expected_linter.warnings + + linter = Linter("../file.txt", CONTENT) + with pytest.warns( + copyright.ConflictingFilesWarning, + match=r'File "\.\./file.txt" is outside of current directory\. Not running ' + r"linter on it\.$", + ): + copyright.apply_batch_copyright_check(git_repo, linter) + assert linter.warnings == [] + CONTENT = """ Beginning of copyrighted file Copyright (c) 2023-2024 NVIDIA CORPORATION @@ -1244,6 +1282,26 @@ def no_apply_batch_copyright_check(): copyright_checker(linter, mock_args) apply_copyright_check.assert_called_once_with(linter, file_contents(1)) + linter = Linter("./file1.txt", file_contents_modified(1)) + # fmt: off + with mock_apply_copyright_check() as apply_copyright_check, \ + no_apply_batch_copyright_check(): + # fmt: on + copyright_checker(linter, mock_args) + apply_copyright_check.assert_called_once_with(linter, file_contents(1)) + + linter = Linter("../file1.txt", file_contents_modified(1)) + # fmt: off + with mock_apply_copyright_check() as apply_copyright_check, \ + no_apply_batch_copyright_check(): + # fmt: on + with pytest.warns( + copyright.ConflictingFilesWarning, + match=r'File "\.\./file1\.txt" is outside of current directory\. 
Not ' + r'running linter on it\.$'): + copyright_checker(linter, mock_args) + apply_copyright_check.assert_not_called() + linter = Linter("file5.txt", file_contents(2)) # fmt: off with mock_apply_copyright_check() as apply_copyright_check, \ From 2011a2bbe195d97bd8d1eaccc9c8e4667e7dfc66 Mon Sep 17 00:00:00 2001 From: Kyle Edwards Date: Wed, 24 Jan 2024 09:45:23 -0500 Subject: [PATCH 21/24] Add verify-copyright-batch pre-commit hook --- .pre-commit-hooks.yaml | 14 ++++++++++++++ 1 file changed, 14 insertions(+) diff --git a/.pre-commit-hooks.yaml b/.pre-commit-hooks.yaml index 6bfd017..5ed91c6 100644 --- a/.pre-commit-hooks.yaml +++ b/.pre-commit-hooks.yaml @@ -32,3 +32,17 @@ pyproject[.]toml$| setup[.]cfg$ args: [--fix] +- id: verify-copyright-batch + name: copyright headers + description: make sure copyright headers are up to date in Git history + entry: verify-copyright + language: python + files: | + (?x) + [.](cmake|cpp|cu|cuh|h|hpp|sh|pxd|py|pyx)$| + CMakeLists[.]txt$| + CMakeLists_standalone[.]txt$| + meta[.]yaml$| + pyproject[.]toml$| + setup[.]cfg$ + args: [--batch, --fix] From 8bd1b4ad0f0a1ea4850dacf426a240e07d1e4686 Mon Sep 17 00:00:00 2001 From: Kyle Edwards Date: Fri, 26 Jan 2024 14:21:56 -0500 Subject: [PATCH 22/24] Remove verify-copyright-batch, change name --- .pre-commit-hooks.yaml | 16 +--------------- 1 file changed, 1 insertion(+), 15 deletions(-) diff --git a/.pre-commit-hooks.yaml b/.pre-commit-hooks.yaml index 5ed91c6..566452c 100644 --- a/.pre-commit-hooks.yaml +++ b/.pre-commit-hooks.yaml @@ -19,7 +19,7 @@ types: [shell] args: [--fix] - id: verify-copyright - name: copyright headers + name: verify-copyright description: make sure copyright headers are up to date entry: verify-copyright language: python @@ -32,17 +32,3 @@ pyproject[.]toml$| setup[.]cfg$ args: [--fix] -- id: verify-copyright-batch - name: copyright headers - description: make sure copyright headers are up to date in Git history - entry: verify-copyright - language: python - 
files: | - (?x) - [.](cmake|cpp|cu|cuh|h|hpp|sh|pxd|py|pyx)$| - CMakeLists[.]txt$| - CMakeLists_standalone[.]txt$| - meta[.]yaml$| - pyproject[.]toml$| - setup[.]cfg$ - args: [--batch, --fix] From 1b55d6412a767004d740b617f3d331bc95c925cf Mon Sep 17 00:00:00 2001 From: Kyle Edwards Date: Fri, 26 Jan 2024 15:21:02 -0500 Subject: [PATCH 23/24] Simplify last modified logic at expense of smart change detection --- src/rapids_pre_commit_hooks/copyright.py | 95 ++------------ .../rapids_pre_commit_hooks/test_copyright.py | 124 +++++++++--------- 2 files changed, 70 insertions(+), 149 deletions(-) diff --git a/src/rapids_pre_commit_hooks/copyright.py b/src/rapids_pre_commit_hooks/copyright.py index 76ed940..79f3650 100644 --- a/src/rapids_pre_commit_hooks/copyright.py +++ b/src/rapids_pre_commit_hooks/copyright.py @@ -280,69 +280,12 @@ def find_blob(tree, filename): def get_file_last_modified(commit, filename): - blob = find_blob(commit.tree, filename) - if not blob: - return (None, None) - - queue = [(commit, blob)] - last_modified = None - checked = set() - - while queue: - commit, blob = queue.pop(0) - if (commit.hexsha, blob.path) in checked: - continue - checked.add((commit.hexsha, blob.path)) - all_modified = True - - for parent_commit in commit.parents: - - def compare_files(old_blob): - nonlocal all_modified - - if old_blob.hexsha == blob.hexsha: - # Same file contents - all_modified = False - queue.append((parent_commit, old_blob)) - else: - # Different file contents, but non-copyright-header content might be - # the same - old_content, new_content = ( - old_blob.data_stream.read().decode(), - blob.data_stream.read().decode(), - ) - old_copyright_matches, new_copyright_matches = match_copyright( - old_content - ), match_copyright(new_content) - - if strip_copyright( - old_content, old_copyright_matches - ) == strip_copyright(new_content, new_copyright_matches): - all_modified = False - queue.append((parent_commit, old_blob)) - - if parent_blob := 
find_blob(parent_commit.tree, blob.path): - compare_files(parent_blob) - else: - diffs = parent_commit.diff( - other=commit, - find_copies=True, - find_copies_harder=True, - find_renames=True, - ) - diff = next(diff for diff in diffs if diff.b_path == blob.path) - if diff.change_type != "A": - compare_files(diff.a_blob) - - if all_modified: - if ( - not last_modified - or commit.committed_datetime > last_modified[0].committed_datetime - ): - last_modified = (commit, blob) + try: + last_modified_commit = next(commit.repo.iter_commits(commit, filename)) + except StopIteration: + return None - assert last_modified - return last_modified + return last_modified_commit def apply_batch_copyright_check(repo, linter): @@ -370,31 +313,13 @@ def apply_batch_copyright_check(repo, linter): ) return - commit, old_blob = get_file_last_modified(repo.head.commit, git_filename) + commit = get_file_last_modified(repo.head.commit, git_filename) year = commit.committed_datetime.year - old_content = old_blob.data_stream.read().decode() - old_copyright_matches, new_copyright_matches = match_copyright( - old_content - ), match_copyright(linter.content) - assert strip_copyright(old_content, old_copyright_matches) == strip_copyright( - linter.content, new_copyright_matches - ) - if new_copyright_matches: - for old_match, new_match in zip( - old_copyright_matches, new_copyright_matches, strict=True - ): - if ( - int(new_match.group("last_year") or new_match.group("first_year")) - < year - ): - apply_copyright_update(linter, new_match, year) - elif ( - old_match.group() != new_match.group() - and int(old_match.group("last_year") or old_match.group("first_year")) - >= year - ): - apply_copyright_revert(linter, old_match, new_match) + if copyright_matches := match_copyright(linter.content): + for match in copyright_matches: + if int(match.group("last_year") or match.group("first_year")) < year: + apply_copyright_update(linter, match, year) else: linter.add_warning((0, 0), "no copyright notice 
found") diff --git a/test/rapids_pre_commit_hooks/test_copyright.py b/test/rapids_pre_commit_hooks/test_copyright.py index 6a3db43..1db8df5 100644 --- a/test/rapids_pre_commit_hooks/test_copyright.py +++ b/test/rapids_pre_commit_hooks/test_copyright.py @@ -692,59 +692,47 @@ def write_file(filename, contents): with open(fn(filename), "w") as f: f.write(contents) - def expected_return_value(commit, filename): - return (commit, copyright.find_blob(commit.tree, filename)) - - @contextlib.contextmanager - def no_match_copyright(): - with patch( - "rapids_pre_commit_hooks.copyright.match_copyright", Mock() - ) as match_copyright, patch( - "rapids_pre_commit_hooks.copyright.strip_copyright", Mock() - ) as strip_copyright: - yield - match_copyright.assert_not_called() - strip_copyright.assert_not_called() - write_file("file1.txt", "File 1") git_repo.index.add("file1.txt") git_repo.index.commit("Initial commit") - with no_match_copyright(): - assert copyright.get_file_last_modified( - git_repo.head.commit, "file1.txt" - ) == expected_return_value(git_repo.head.commit, "file1.txt") + assert ( + copyright.get_file_last_modified(git_repo.head.commit, "file1.txt") + == git_repo.head.commit + ) write_file("file2.txt", "File 2") git_repo.index.add("file2.txt") git_repo.index.commit("Add file2.txt") - with no_match_copyright(): - assert copyright.get_file_last_modified( - git_repo.head.commit, "file1.txt" - ) == expected_return_value(git_repo.head.commit.parents[0], "file1.txt") - assert copyright.get_file_last_modified( - git_repo.head.commit.parents[0], "file1.txt" - ) == expected_return_value(git_repo.head.commit.parents[0], "file1.txt") - assert copyright.get_file_last_modified( - git_repo.head.commit, "file2.txt" - ) == expected_return_value(git_repo.head.commit, "file2.txt") - assert copyright.get_file_last_modified( - git_repo.head.commit.parents[0], "file2.txt" - ) == (None, None) + assert ( + copyright.get_file_last_modified(git_repo.head.commit, "file1.txt") + == 
git_repo.head.commit.parents[0] + ) + assert ( + copyright.get_file_last_modified(git_repo.head.commit.parents[0], "file1.txt") + == git_repo.head.commit.parents[0] + ) + assert ( + copyright.get_file_last_modified(git_repo.head.commit, "file2.txt") + == git_repo.head.commit + ) + assert ( + copyright.get_file_last_modified(git_repo.head.commit.parents[0], "file2.txt") + is None + ) git_repo.index.remove("file1.txt", working_tree=True) write_file("file1_2.txt", "File 1") write_file("file2_2.txt", "File 2") git_repo.index.add(["file1_2.txt", "file2_2.txt"]) git_repo.index.commit("Rename and copy") - with no_match_copyright(): - assert copyright.get_file_last_modified( - git_repo.head.commit, "file1_2.txt" - ) == expected_return_value( - git_repo.head.commit.parents[0].parents[0], "file1.txt" - ) - assert copyright.get_file_last_modified( - git_repo.head.commit, "file2_2.txt" - ) == expected_return_value(git_repo.head.commit.parents[0], "file2.txt") + assert ( + copyright.get_file_last_modified(git_repo.head.commit, "file1_2.txt") + == git_repo.head.commit + ) + assert ( + copyright.get_file_last_modified(git_repo.head.commit, "file2_2.txt") + == git_repo.head.commit + ) write_file( "copyright.txt", @@ -766,9 +754,10 @@ def no_match_copyright(): ) git_repo.index.add("copyright.txt") git_repo.index.commit("Update copyright") - assert copyright.get_file_last_modified( - git_repo.head.commit, "copyright.txt" - ) == expected_return_value(git_repo.head.commit.parents[0], "copyright.txt") + assert ( + copyright.get_file_last_modified(git_repo.head.commit, "copyright.txt") + == git_repo.head.commit + ) write_file( "copyright.txt", @@ -781,9 +770,10 @@ def no_match_copyright(): ) git_repo.index.add("copyright.txt") git_repo.index.commit("New contents") - assert copyright.get_file_last_modified( - git_repo.head.commit, "copyright.txt" - ) == expected_return_value(git_repo.head.commit, "copyright.txt") + assert ( + copyright.get_file_last_modified(git_repo.head.commit, 
"copyright.txt") + == git_repo.head.commit + ) write_file( "copyright.txt", @@ -805,9 +795,10 @@ def no_match_copyright(): parent_commits=commit_1.parents, ) git_repo.index.commit("Merge", parent_commits=[commit_1, commit_2]) - assert copyright.get_file_last_modified( - git_repo.head.commit, "copyright.txt" - ) == expected_return_value(commit_2, "copyright.txt") + assert ( + copyright.get_file_last_modified(git_repo.head.commit, "copyright.txt") + == commit_1 + ) write_file( "copyright.txt", @@ -829,9 +820,10 @@ def no_match_copyright(): parent_commits=commit_1.parents, ) git_repo.index.commit("Merge", parent_commits=[commit_1, commit_2]) - assert copyright.get_file_last_modified( - git_repo.head.commit, "copyright.txt" - ) == expected_return_value(commit_1, "copyright.txt") + assert ( + copyright.get_file_last_modified(git_repo.head.commit, "copyright.txt") + == commit_1 + ) write_file( "copyright.txt", @@ -867,9 +859,10 @@ def no_match_copyright(): commit_date=datetime.datetime(2024, 1, 25, tzinfo=datetime.timezone.utc), parent_commits=[git_repo.head.commit, old_commit], ) - assert copyright.get_file_last_modified( - git_repo.head.commit, "copyright.txt" - ) == expected_return_value(new_commit, "copyright.txt") + assert ( + copyright.get_file_last_modified(git_repo.head.commit, "copyright.txt") + == new_commit + ) write_file( "copyright.txt", @@ -915,9 +908,10 @@ def no_match_copyright(): commit_date=datetime.datetime(2024, 1, 24, tzinfo=datetime.timezone.utc), parent_commits=[git_repo.head.commit, old_commit], ) - assert copyright.get_file_last_modified( - git_repo.head.commit, "copyright.txt" - ) == expected_return_value(new_commit, "copyright.txt") + assert ( + copyright.get_file_last_modified(git_repo.head.commit, "copyright.txt") + == git_repo.head.commit + ) write_file( "copyright.txt", @@ -941,9 +935,10 @@ def no_match_copyright(): ) git_repo.index.add("copyright2.txt") git_repo.index.commit("Copy copyrighted file") - assert 
copyright.get_file_last_modified( - git_repo.head.commit, "copyright2.txt" - ) == expected_return_value(git_repo.head.commit.parents[0], "copyright.txt") + assert ( + copyright.get_file_last_modified(git_repo.head.commit, "copyright2.txt") + == git_repo.head.commit + ) git_repo.index.remove("copyright2.txt", working_tree=True) write_file( @@ -971,9 +966,10 @@ def no_match_copyright(): ) git_repo.index.add("copyright2.txt") git_repo.index.commit("Copy and modify copyrighted file") - assert copyright.get_file_last_modified( - git_repo.head.commit, "copyright2.txt" - ) == expected_return_value(git_repo.head.commit, "copyright2.txt") + assert ( + copyright.get_file_last_modified(git_repo.head.commit, "copyright2.txt") + == git_repo.head.commit + ) def test_apply_batch_copyright_check(git_repo): From 0e823c55f1491c1de40687b03b3b315400301dc9 Mon Sep 17 00:00:00 2001 From: Kyle Edwards Date: Fri, 26 Jan 2024 15:22:34 -0500 Subject: [PATCH 24/24] Fix test --- test/rapids_pre_commit_hooks/test_copyright.py | 7 +------ 1 file changed, 1 insertion(+), 6 deletions(-) diff --git a/test/rapids_pre_commit_hooks/test_copyright.py b/test/rapids_pre_commit_hooks/test_copyright.py index 1db8df5..6dad0ad 100644 --- a/test/rapids_pre_commit_hooks/test_copyright.py +++ b/test/rapids_pre_commit_hooks/test_copyright.py @@ -1101,14 +1101,9 @@ def write_file(filename, content): commit_date=datetime.datetime(2025, 2, 1, tzinfo=datetime.timezone.utc), ) - expected_linter = Linter("file.txt", CONTENT) - expected_linter.add_warning( - (45, 54), "copyright is not out of date and should not be updated" - ).add_replacement((31, 73), "Copyright (c) 2023-2024 NVIDIA CORPORATION") - linter = Linter("file.txt", CONTENT) copyright.apply_batch_copyright_check(git_repo, linter) - assert linter.warnings == expected_linter.warnings + assert linter.warnings == [] CONTENT = """ Beginning of copyrighted file