Skip to content

Commit

Permalink
Use local branch if newer than remote branch (#33)
Browse files Browse the repository at this point in the history
* Use local branch if newer than remote branch

If upstream is a user's fork that doesn't get updated, we want to
instead use the local branch, which is more likely to be up to date.

* Try same-name remote branches even if local branch exists

* Add explanation comments, simplify logic
  • Loading branch information
KyleFromNVIDIA authored May 16, 2024
1 parent b467a5e commit 9bec08c
Show file tree
Hide file tree
Showing 2 changed files with 64 additions and 10 deletions.
31 changes: 23 additions & 8 deletions src/rapids_pre_commit_hooks/copyright.py
Original file line number Diff line number Diff line change
Expand Up @@ -177,21 +177,30 @@ def get_target_branch(repo, args):


def get_target_branch_upstream_commit(repo, args):
# If no target branch can be determined, use HEAD if it exists
target_branch_name = get_target_branch(repo, args)
if target_branch_name is None:
try:
return repo.head.commit
except ValueError:
return None

commits_to_try = []

try:
target_branch = repo.heads[target_branch_name]
except IndexError:
target_branch = None
if target_branch:
target_branch_upstream = target_branch.tracking_branch()
if target_branch_upstream:
return target_branch_upstream.commit
pass
else:
# Try the branch specified by the branch name
commits_to_try.append(target_branch.commit)

# If the branch has an upstream, try it and exit
if target_branch_upstream := target_branch.tracking_branch():
return max(
[target_branch.commit, target_branch_upstream.commit],
key=lambda commit: commit.committed_datetime,
)

def try_get_ref(remote):
try:
Expand All @@ -200,15 +209,21 @@ def try_get_ref(remote):
return None

try:
return max(
# Try branches in all remotes that have the branch name
upstream_commit = max(
(upstream for remote in repo.remotes if (upstream := try_get_ref(remote))),
key=lambda upstream: upstream.commit.committed_datetime,
).commit
except ValueError:
pass
else:
commits_to_try.append(upstream_commit)

if commits_to_try:
return max(commits_to_try, key=lambda commit: commit.committed_datetime)

if target_branch:
return target_branch.commit
# No branch with the specified name, local or remote, can be found, so return HEAD
# if it exists
try:
return repo.head.commit
except ValueError:
Expand Down
43 changes: 41 additions & 2 deletions test/rapids_pre_commit_hooks/test_copyright.py
Original file line number Diff line number Diff line change
Expand Up @@ -283,6 +283,7 @@ def mock_target_branch(branch):
write_file(remote_repo_1, "file4.txt", "File 4")
write_file(remote_repo_1, "file5.txt", "File 5")
write_file(remote_repo_1, "file6.txt", "File 6")
write_file(remote_repo_1, "file7.txt", "File 7")
remote_repo_1.index.add(
[
"file1.txt",
Expand All @@ -291,6 +292,7 @@ def mock_target_branch(branch):
"file4.txt",
"file5.txt",
"file6.txt",
"file7.txt",
]
)
remote_repo_1.index.commit("Initial commit")
Expand All @@ -302,7 +304,10 @@ def mock_target_branch(branch):
remote_repo_1.head.reset(index=True, working_tree=True)
write_file(remote_repo_1, "file1.txt", "File 1 modified")
remote_repo_1.index.add(["file1.txt"])
remote_repo_1.index.commit("Update file1.txt")
remote_repo_1.index.commit(
"Update file1.txt",
commit_date=datetime.datetime(2024, 2, 1, tzinfo=datetime.timezone.utc),
)

remote_1_branch_2 = remote_repo_1.create_head(
"branch-2", remote_1_master.commit
Expand Down Expand Up @@ -337,6 +342,18 @@ def mock_target_branch(branch):
commit_date=datetime.datetime(2024, 1, 1, tzinfo=datetime.timezone.utc),
)

remote_1_branch_7 = remote_repo_1.create_head(
"branch-7", remote_1_master.commit
)
remote_repo_1.head.reference = remote_1_branch_7
remote_repo_1.head.reset(index=True, working_tree=True)
write_file(remote_repo_1, "file7.txt", "File 7 modified")
remote_repo_1.index.add(["file7.txt"])
remote_repo_1.index.commit(
"Update file7.txt",
commit_date=datetime.datetime(2024, 1, 1, tzinfo=datetime.timezone.utc),
)

remote_2_1 = remote_repo_2.create_remote("remote-1", remote_dir_1)
remote_2_1.fetch(["master"])
remote_2_master = remote_repo_2.create_head("master", remote_2_1.refs["master"])
Expand Down Expand Up @@ -387,6 +404,7 @@ def mock_target_branch(branch):
"branch-2",
"branch-3",
"branch-4",
"branch-7",
])
remote_2 = git_repo.create_remote("unconventional/remote/name/2", remote_dir_2)
remote_2.fetch(["branch-3", "branch-4", "branch-5"])
Expand All @@ -400,14 +418,29 @@ def mock_target_branch(branch):
git_repo.head.reference = branch_1
git_repo.head.reset(index=True, working_tree=True)
git_repo.index.remove("file1.txt", working_tree=True)
git_repo.index.commit("Remove file1.txt")
git_repo.index.commit(
"Remove file1.txt",
commit_date=datetime.datetime(2024, 1, 1, tzinfo=datetime.timezone.utc),
)

branch_6 = git_repo.create_head("branch-6", remote_1.refs["master"])
git_repo.head.reference = branch_6
git_repo.head.reset(index=True, working_tree=True)
git_repo.index.remove(["file6.txt"], working_tree=True)
git_repo.index.commit("Remove file6.txt")

branch_7 = git_repo.create_head("branch-7", remote_1.refs["master"])
with branch_7.config_writer() as w:
w.set_value("remote", "unconventional/remote/name/1")
w.set_value("merge", "branch-7")
git_repo.head.reference = branch_7
git_repo.head.reset(index=True, working_tree=True)
git_repo.index.remove(["file7.txt"], working_tree=True)
git_repo.index.commit(
"Remove file7.txt",
commit_date=datetime.datetime(2024, 2, 1, tzinfo=datetime.timezone.utc),
)

git_repo.head.reference = main
git_repo.head.reset(index=True, working_tree=True)

Expand Down Expand Up @@ -448,6 +481,12 @@ def mock_target_branch(branch):
)

with mock_target_branch("branch-7"):
assert (
copyright.get_target_branch_upstream_commit(git_repo, None)
== branch_7.commit
)

with mock_target_branch("nonexistent-branch"):
assert (
copyright.get_target_branch_upstream_commit(git_repo, None)
== main.commit
Expand Down

0 comments on commit 9bec08c

Please sign in to comment.