diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 3045be5..c71b841 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -56,4 +56,7 @@ repos: [.]pre-commit-config[.]yaml$| [.]pre-commit-hooks[.]yaml$| pyproject[.]toml$ + exclude: | + (?x) + test_copyright[.]py$ args: [--fix, --main-branch=main] diff --git a/src/rapids_pre_commit_hooks/copyright.py b/src/rapids_pre_commit_hooks/copyright.py index 232afe5..e93470b 100644 --- a/src/rapids_pre_commit_hooks/copyright.py +++ b/src/rapids_pre_commit_hooks/copyright.py @@ -231,18 +231,20 @@ def get_changed_files(args): changed_files.update({blob.path: None for _, blob in repo.index.iter_blobs()}) return changed_files - diffs = target_branch_upstream_commit.diff( - other=None, - merge_base=True, - find_copies=True, - find_copies_harder=True, - find_renames=True, - ) - for diff in diffs: - if diff.change_type == "A": - changed_files[diff.b_path] = None - elif diff.change_type != "D": - changed_files[diff.b_path] = diff.a_blob + for merge_base in repo.merge_base( + repo.head.commit, target_branch_upstream_commit, all=True + ): + diffs = merge_base.diff( + other=None, + find_copies=True, + find_copies_harder=True, + find_renames=True, + ) + for diff in diffs: + if diff.change_type == "A": + changed_files[diff.b_path] = None + elif diff.change_type != "D": + changed_files[diff.b_path] = diff.a_blob return changed_files diff --git a/test/rapids_pre_commit_hooks/test_copyright.py b/test/rapids_pre_commit_hooks/test_copyright.py index 166d169..884b21c 100644 --- a/test/rapids_pre_commit_hooks/test_copyright.py +++ b/test/rapids_pre_commit_hooks/test_copyright.py @@ -649,6 +649,85 @@ def file_contents(verbed): assert changed_files[new].data_stream.read() == old_contents +def test_get_changed_files_multiple_merge_bases(git_repo): + def fn(filename): + return os.path.join(git_repo.working_tree_dir, filename) + + def write_file(filename, contents): + with open(fn(filename), "w") as f: + f.write(contents) + + write_file("file1.txt", "File 1\n") + write_file("file2.txt", "File 2\n") + write_file("file3.txt", "File 3\n") + git_repo.index.add(["file1.txt", "file2.txt", "file3.txt"]) + git_repo.index.commit("Initial commit") + + branch_1 = git_repo.create_head("branch-1", "master") + git_repo.head.reference = branch_1 + git_repo.index.reset(index=True, working_tree=True) + write_file("file1.txt", "File 1 modified\n") + git_repo.index.add("file1.txt") + git_repo.index.commit( + "Modify file1.txt", + commit_date=datetime.datetime(2024, 1, 1, tzinfo=datetime.timezone.utc), + ) + + branch_2 = git_repo.create_head("branch-2", "master") + git_repo.head.reference = branch_2 + git_repo.index.reset(index=True, working_tree=True) + write_file("file2.txt", "File 2 modified\n") + git_repo.index.add("file2.txt") + git_repo.index.commit( + "Modify file2.txt", + commit_date=datetime.datetime(2024, 2, 1, tzinfo=datetime.timezone.utc), + ) + + branch_1_2 = git_repo.create_head("branch-1-2", "master") + git_repo.head.reference = branch_1_2 + git_repo.index.reset(index=True, working_tree=True) + write_file("file1.txt", "File 1 modified\n") + write_file("file2.txt", "File 2 modified\n") + git_repo.index.add(["file1.txt", "file2.txt"]) + git_repo.index.commit( + "Merge branches branch-1 and branch-2", + parent_commits=[branch_1.commit, branch_2.commit], + commit_date=datetime.datetime(2024, 3, 1, tzinfo=datetime.timezone.utc), + ) + + branch_3 = git_repo.create_head("branch-3", "master") + git_repo.head.reference = branch_3 + git_repo.index.reset(index=True, working_tree=True) + write_file("file1.txt", "File 1 modified\n") + write_file("file2.txt", "File 2 modified\n") + git_repo.index.add(["file1.txt", "file2.txt"]) + git_repo.index.commit( + "Merge branches branch-1 and branch-2", + parent_commits=[branch_1.commit, branch_2.commit], + commit_date=datetime.datetime(2024, 4, 1, tzinfo=datetime.timezone.utc), + ) + write_file("file3.txt", "File 3 modified\n") + git_repo.index.add("file3.txt") + git_repo.index.commit( + "Modify file3.txt", + commit_date=datetime.datetime(2024, 5, 1, tzinfo=datetime.timezone.utc), + ) + + with patch("os.getcwd", Mock(return_value=git_repo.working_tree_dir)), patch( + "rapids_pre_commit_hooks.copyright.get_target_branch", + Mock(return_value="branch-1-2"), + ): + changed_files = copyright.get_changed_files(None) + assert { + path: old_blob.path if old_blob else None + for path, old_blob in changed_files.items() + } == { + "file1.txt": "file1.txt", + "file2.txt": "file2.txt", + "file3.txt": "file3.txt", + } + + def test_normalize_git_filename(): assert copyright.normalize_git_filename("file.txt") == "file.txt" assert copyright.normalize_git_filename("sub/file.txt") == "sub/file.txt"