From 9964d9383070d706d39d8d3fd6949c901970b8aa Mon Sep 17 00:00:00 2001 From: Kyle Edwards Date: Thu, 16 May 2024 15:46:20 -0400 Subject: [PATCH 1/3] Use local branch if newer than remote branch If upstream is a user's fork that doesn't get updated, we want to instead use the local branch, which is more likely to be up to date. --- src/rapids_pre_commit_hooks/copyright.py | 44 +++++++++++-------- .../rapids_pre_commit_hooks/test_copyright.py | 43 +++++++++++++++++- 2 files changed, 67 insertions(+), 20 deletions(-) diff --git a/src/rapids_pre_commit_hooks/copyright.py b/src/rapids_pre_commit_hooks/copyright.py index 64b56f8..97d1260 100644 --- a/src/rapids_pre_commit_hooks/copyright.py +++ b/src/rapids_pre_commit_hooks/copyright.py @@ -184,31 +184,39 @@ def get_target_branch_upstream_commit(repo, args): except ValueError: return None + commits_to_try = [] try: target_branch = repo.heads[target_branch_name] except IndexError: - target_branch = None - if target_branch: - target_branch_upstream = target_branch.tracking_branch() - if target_branch_upstream: - return target_branch_upstream.commit - def try_get_ref(remote): + def try_get_ref(remote): + try: + return remote.refs[target_branch_name] + except IndexError: + return None + try: - return remote.refs[target_branch_name] - except IndexError: - return None + upstream_commit = max( + ( + upstream + for remote in repo.remotes + if (upstream := try_get_ref(remote)) + ), + key=lambda upstream: upstream.commit.committed_datetime, + ).commit + except ValueError: + pass + else: + commits_to_try.append(upstream_commit) + else: + commits_to_try.append(target_branch.commit) + target_branch_upstream = target_branch.tracking_branch() + if target_branch_upstream: + commits_to_try.append(target_branch_upstream.commit) - try: - return max( - (upstream for remote in repo.remotes if (upstream := try_get_ref(remote))), - key=lambda upstream: upstream.commit.committed_datetime, - ).commit - except ValueError: - pass + if commits_to_try: + return max(commits_to_try, key=lambda commit: commit.committed_datetime) - if target_branch: - return target_branch.commit try: return repo.head.commit except ValueError: diff --git a/test/rapids_pre_commit_hooks/test_copyright.py b/test/rapids_pre_commit_hooks/test_copyright.py index 884b21c..50e118b 100644 --- a/test/rapids_pre_commit_hooks/test_copyright.py +++ b/test/rapids_pre_commit_hooks/test_copyright.py @@ -283,6 +283,7 @@ def mock_target_branch(branch): write_file(remote_repo_1, "file4.txt", "File 4") write_file(remote_repo_1, "file5.txt", "File 5") write_file(remote_repo_1, "file6.txt", "File 6") + write_file(remote_repo_1, "file7.txt", "File 7") remote_repo_1.index.add( [ "file1.txt", @@ -291,6 +292,7 @@ def mock_target_branch(branch): "file4.txt", "file5.txt", "file6.txt", + "file7.txt", ] ) remote_repo_1.index.commit("Initial commit") @@ -302,7 +304,10 @@ def mock_target_branch(branch): remote_repo_1.head.reset(index=True, working_tree=True) write_file(remote_repo_1, "file1.txt", "File 1 modified") remote_repo_1.index.add(["file1.txt"]) - remote_repo_1.index.commit("Update file1.txt") + remote_repo_1.index.commit( + "Update file1.txt", + commit_date=datetime.datetime(2024, 2, 1, tzinfo=datetime.timezone.utc), + ) remote_1_branch_2 = remote_repo_1.create_head( "branch-2", remote_1_master.commit @@ -337,6 +342,18 @@ def mock_target_branch(branch): commit_date=datetime.datetime(2024, 1, 1, tzinfo=datetime.timezone.utc), ) + remote_1_branch_7 = remote_repo_1.create_head( + "branch-7", remote_1_master.commit + ) + remote_repo_1.head.reference = remote_1_branch_7 + remote_repo_1.head.reset(index=True, working_tree=True) + write_file(remote_repo_1, "file7.txt", "File 7 modified") + remote_repo_1.index.add(["file7.txt"]) + remote_repo_1.index.commit( + "Update file7.txt", + commit_date=datetime.datetime(2024, 1, 1, tzinfo=datetime.timezone.utc), + ) + remote_2_1 = remote_repo_2.create_remote("remote-1", remote_dir_1) remote_2_1.fetch(["master"]) remote_2_master = remote_repo_2.create_head("master", remote_2_1.refs["master"]) @@ -387,6 +404,7 @@ def mock_target_branch(branch): "branch-2", "branch-3", "branch-4", + "branch-7", ]) remote_2 = git_repo.create_remote("unconventional/remote/name/2", remote_dir_2) remote_2.fetch(["branch-3", "branch-4", "branch-5"]) @@ -400,7 +418,10 @@ def mock_target_branch(branch): git_repo.head.reference = branch_1 git_repo.head.reset(index=True, working_tree=True) git_repo.index.remove("file1.txt", working_tree=True) - git_repo.index.commit("Remove file1.txt") + git_repo.index.commit( + "Remove file1.txt", + commit_date=datetime.datetime(2024, 1, 1, tzinfo=datetime.timezone.utc), + ) branch_6 = git_repo.create_head("branch-6", remote_1.refs["master"]) git_repo.head.reference = branch_6 @@ -408,6 +429,18 @@ def mock_target_branch(branch): git_repo.index.remove(["file6.txt"], working_tree=True) git_repo.index.commit("Remove file6.txt") + branch_7 = git_repo.create_head("branch-7", remote_1.refs["master"]) + with branch_7.config_writer() as w: + w.set_value("remote", "unconventional/remote/name/1") + w.set_value("merge", "branch-7") + git_repo.head.reference = branch_7 + git_repo.head.reset(index=True, working_tree=True) + git_repo.index.remove(["file7.txt"], working_tree=True) + git_repo.index.commit( + "Remove file7.txt", + commit_date=datetime.datetime(2024, 2, 1, tzinfo=datetime.timezone.utc), + ) + git_repo.head.reference = main git_repo.head.reset(index=True, working_tree=True) @@ -448,6 +481,12 @@ def mock_target_branch(branch): ) with mock_target_branch("branch-7"): + assert ( + copyright.get_target_branch_upstream_commit(git_repo, None) + == branch_7.commit + ) + + with mock_target_branch("nonexistent-branch"): assert ( copyright.get_target_branch_upstream_commit(git_repo, None) == main.commit From 9c8e12517712b97265a49f09753f17577abb381c Mon Sep 17 00:00:00 2001 From: Kyle Edwards Date: Thu, 16 May 2024 16:31:40 -0400 Subject: [PATCH 2/3] Try same-name remote branches even if local branch exists --- src/rapids_pre_commit_hooks/copyright.py | 41 +++++++++++------------- 1 file changed, 19 insertions(+), 22 deletions(-) diff --git a/src/rapids_pre_commit_hooks/copyright.py b/src/rapids_pre_commit_hooks/copyright.py index 97d1260..d03b353 100644 --- a/src/rapids_pre_commit_hooks/copyright.py +++ b/src/rapids_pre_commit_hooks/copyright.py @@ -188,31 +188,28 @@ def get_target_branch_upstream_commit(repo, args): try: target_branch = repo.heads[target_branch_name] except IndexError: - - def try_get_ref(remote): - try: - return remote.refs[target_branch_name] - except IndexError: - return None - - try: - upstream_commit = max( - ( - upstream - for remote in repo.remotes - if (upstream := try_get_ref(remote)) - ), - key=lambda upstream: upstream.commit.committed_datetime, - ).commit - except ValueError: - pass - else: - commits_to_try.append(upstream_commit) + pass else: commits_to_try.append(target_branch.commit) - target_branch_upstream = target_branch.tracking_branch() - if target_branch_upstream: + if target_branch_upstream := target_branch.tracking_branch(): commits_to_try.append(target_branch_upstream.commit) + return max(commits_to_try, key=lambda commit: commit.committed_datetime) + + def try_get_ref(remote): + try: + return remote.refs[target_branch_name] + except IndexError: + return None + + try: + upstream_commit = max( + (upstream for remote in repo.remotes if (upstream := try_get_ref(remote))), + key=lambda upstream: upstream.commit.committed_datetime, + ).commit + except ValueError: + pass + else: + commits_to_try.append(upstream_commit) if commits_to_try: return max(commits_to_try, key=lambda commit: commit.committed_datetime) From bbcd14da7b82f0a108e9170d467ca6204ff94613 Mon Sep 17 00:00:00 2001 From: Kyle Edwards Date: Thu, 16 May 2024 16:44:45 -0400 Subject: [PATCH 3/3] Add explanation comments, simplify logic --- src/rapids_pre_commit_hooks/copyright.py | 14 ++++++++++++-- 1 file changed, 12 insertions(+), 2 deletions(-) diff --git a/src/rapids_pre_commit_hooks/copyright.py b/src/rapids_pre_commit_hooks/copyright.py index d03b353..8e182f0 100644 --- a/src/rapids_pre_commit_hooks/copyright.py +++ b/src/rapids_pre_commit_hooks/copyright.py @@ -177,6 +177,7 @@ def get_target_branch(repo, args): def get_target_branch_upstream_commit(repo, args): + # If no target branch can be determined, use HEAD if it exists target_branch_name = get_target_branch(repo, args) if target_branch_name is None: try: @@ -185,15 +186,21 @@ def get_target_branch_upstream_commit(repo, args): return None commits_to_try = [] + try: target_branch = repo.heads[target_branch_name] except IndexError: pass else: + # Try the branch specified by the branch name commits_to_try.append(target_branch.commit) + + # If the branch has an upstream, try it and exit if target_branch_upstream := target_branch.tracking_branch(): - commits_to_try.append(target_branch_upstream.commit) - return max(commits_to_try, key=lambda commit: commit.committed_datetime) + return max( + [target_branch.commit, target_branch_upstream.commit], + key=lambda commit: commit.committed_datetime, + ) def try_get_ref(remote): try: @@ -202,6 +209,7 @@ def try_get_ref(remote): return None try: + # Try branches in all remotes that have the branch name upstream_commit = max( (upstream for remote in repo.remotes if (upstream := try_get_ref(remote))), key=lambda upstream: upstream.commit.committed_datetime, @@ -214,6 +222,8 @@ def try_get_ref(remote): if commits_to_try: return max(commits_to_try, key=lambda commit: commit.committed_datetime) + # No branch with the specified name, local or remote, can be found, so return HEAD + # if it exists try: return repo.head.commit except ValueError: