diff --git a/src/rapids_pre_commit_hooks/copyright.py b/src/rapids_pre_commit_hooks/copyright.py index 64b56f8..8e182f0 100644 --- a/src/rapids_pre_commit_hooks/copyright.py +++ b/src/rapids_pre_commit_hooks/copyright.py @@ -177,6 +177,7 @@ def get_target_branch(repo, args): def get_target_branch_upstream_commit(repo, args): + # If no target branch can be determined, use HEAD if it exists target_branch_name = get_target_branch(repo, args) if target_branch_name is None: try: @@ -184,14 +185,22 @@ def get_target_branch_upstream_commit(repo, args): except ValueError: return None + commits_to_try = [] + try: target_branch = repo.heads[target_branch_name] except IndexError: - target_branch = None - if target_branch: - target_branch_upstream = target_branch.tracking_branch() - if target_branch_upstream: - return target_branch_upstream.commit + pass + else: + # Try the branch specified by the branch name + commits_to_try.append(target_branch.commit) + + # If the branch has an upstream, try it and exit + if target_branch_upstream := target_branch.tracking_branch(): + return max( + [target_branch.commit, target_branch_upstream.commit], + key=lambda commit: commit.committed_datetime, + ) def try_get_ref(remote): try: @@ -200,15 +209,21 @@ def try_get_ref(remote): return None try: - return max( + # Try branches in all remotes that have the branch name + upstream_commit = max( (upstream for remote in repo.remotes if (upstream := try_get_ref(remote))), key=lambda upstream: upstream.commit.committed_datetime, ).commit except ValueError: pass + else: + commits_to_try.append(upstream_commit) + + if commits_to_try: + return max(commits_to_try, key=lambda commit: commit.committed_datetime) - if target_branch: - return target_branch.commit + # No branch with the specified name, local or remote, can be found, so return HEAD + # if it exists try: return repo.head.commit except ValueError: diff --git a/test/rapids_pre_commit_hooks/test_copyright.py b/test/rapids_pre_commit_hooks/test_copyright.py index 884b21c..50e118b 100644 --- a/test/rapids_pre_commit_hooks/test_copyright.py +++ b/test/rapids_pre_commit_hooks/test_copyright.py @@ -283,6 +283,7 @@ def mock_target_branch(branch): write_file(remote_repo_1, "file4.txt", "File 4") write_file(remote_repo_1, "file5.txt", "File 5") write_file(remote_repo_1, "file6.txt", "File 6") + write_file(remote_repo_1, "file7.txt", "File 7") remote_repo_1.index.add( [ "file1.txt", @@ -291,6 +292,7 @@ def mock_target_branch(branch): "file4.txt", "file5.txt", "file6.txt", + "file7.txt", ] ) remote_repo_1.index.commit("Initial commit") @@ -302,7 +304,10 @@ def mock_target_branch(branch): remote_repo_1.head.reset(index=True, working_tree=True) write_file(remote_repo_1, "file1.txt", "File 1 modified") remote_repo_1.index.add(["file1.txt"]) - remote_repo_1.index.commit("Update file1.txt") + remote_repo_1.index.commit( + "Update file1.txt", + commit_date=datetime.datetime(2024, 2, 1, tzinfo=datetime.timezone.utc), + ) remote_1_branch_2 = remote_repo_1.create_head( "branch-2", remote_1_master.commit @@ -337,6 +342,18 @@ def mock_target_branch(branch): commit_date=datetime.datetime(2024, 1, 1, tzinfo=datetime.timezone.utc), ) + remote_1_branch_7 = remote_repo_1.create_head( + "branch-7", remote_1_master.commit + ) + remote_repo_1.head.reference = remote_1_branch_7 + remote_repo_1.head.reset(index=True, working_tree=True) + write_file(remote_repo_1, "file7.txt", "File 7 modified") + remote_repo_1.index.add(["file7.txt"]) + remote_repo_1.index.commit( + "Update file7.txt", + commit_date=datetime.datetime(2024, 1, 1, tzinfo=datetime.timezone.utc), + ) + remote_2_1 = remote_repo_2.create_remote("remote-1", remote_dir_1) remote_2_1.fetch(["master"]) remote_2_master = remote_repo_2.create_head("master", remote_2_1.refs["master"]) @@ -387,6 +404,7 @@ def mock_target_branch(branch): "branch-2", "branch-3", "branch-4", + "branch-7", ]) remote_2 = git_repo.create_remote("unconventional/remote/name/2", remote_dir_2) remote_2.fetch(["branch-3", "branch-4", "branch-5"]) @@ -400,7 +418,10 @@ def mock_target_branch(branch): git_repo.head.reference = branch_1 git_repo.head.reset(index=True, working_tree=True) git_repo.index.remove("file1.txt", working_tree=True) - git_repo.index.commit("Remove file1.txt") + git_repo.index.commit( + "Remove file1.txt", + commit_date=datetime.datetime(2024, 1, 1, tzinfo=datetime.timezone.utc), + ) branch_6 = git_repo.create_head("branch-6", remote_1.refs["master"]) git_repo.head.reference = branch_6 @@ -408,6 +429,18 @@ def mock_target_branch(branch): git_repo.index.remove(["file6.txt"], working_tree=True) git_repo.index.commit("Remove file6.txt") + branch_7 = git_repo.create_head("branch-7", remote_1.refs["master"]) + with branch_7.config_writer() as w: + w.set_value("remote", "unconventional/remote/name/1") + w.set_value("merge", "branch-7") + git_repo.head.reference = branch_7 + git_repo.head.reset(index=True, working_tree=True) + git_repo.index.remove(["file7.txt"], working_tree=True) + git_repo.index.commit( + "Remove file7.txt", + commit_date=datetime.datetime(2024, 2, 1, tzinfo=datetime.timezone.utc), + ) + git_repo.head.reference = main git_repo.head.reset(index=True, working_tree=True) @@ -448,6 +481,12 @@ def mock_target_branch(branch): ) with mock_target_branch("branch-7"): + assert ( + copyright.get_target_branch_upstream_commit(git_repo, None) + == branch_7.commit + ) + + with mock_target_branch("nonexistent-branch"): assert ( copyright.get_target_branch_upstream_commit(git_repo, None) == main.commit