
Commit

clip gradients as in paper and also make sure not to do gradient sync until last step
lucidrains committed May 25, 2024
1 parent 422efba commit 936210c
Showing 2 changed files with 9 additions and 3 deletions.
10 changes: 8 additions & 2 deletions alphafold3_pytorch/trainer.py
@@ -158,15 +158,21 @@ def __call__(
         steps = 0
 
         while steps < self.num_train_steps:
-            for _ in range(self.grad_accum_every):
+
+            for grad_accum_step in range(self.grad_accum_every):
+                is_accumulating = grad_accum_step < (self.grad_accum_every - 1)
+
                 inputs = next(dl)
 
-                loss = self.model(**inputs)
+                with self.fabric.no_backward_sync(self.model, enabled = is_accumulating):
+                    loss = self.model(**inputs)
 
                 self.fabric.backward(loss / self.grad_accum_every)
 
             print(f'loss: {loss.item():.3f}')
 
+            self.fabric.clip_gradients(self.model, self.optimizer, max_norm = self.clip_grad_norm)
+
             self.optimizer.step()
 
             if self.is_main:
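For context, below is a minimal, self-contained sketch of the Lightning Fabric pattern this commit adopts: skip gradient synchronization across ranks on every accumulation step except the last via fabric.no_backward_sync, then clip the accumulated gradients with fabric.clip_gradients just before the optimizer step. The toy model, data, and the values of grad_accum_every and clip_grad_norm are illustrative stand-ins rather than the repository's Trainer; following Fabric's documented usage, the backward call is placed inside the no_backward_sync block.

import torch
from torch import nn
from lightning.fabric import Fabric

fabric = Fabric(accelerator = 'cpu', devices = 1)
fabric.launch()

model = nn.Linear(16, 1)                                      # stand-in for the actual model
optimizer = torch.optim.Adam(model.parameters(), lr = 1e-3)
model, optimizer = fabric.setup(model, optimizer)

grad_accum_every = 4                                          # illustrative values, not repo defaults
clip_grad_norm = 10.0

for grad_accum_step in range(grad_accum_every):
    # only sync gradients across processes on the final accumulation step
    is_accumulating = grad_accum_step < (grad_accum_every - 1)

    inputs = torch.randn(8, 16)
    target = torch.randn(8, 1)

    with fabric.no_backward_sync(model, enabled = is_accumulating):
        loss = nn.functional.mse_loss(model(inputs), target)
        fabric.backward(loss / grad_accum_every)

# clip the accumulated gradients once, right before stepping the optimizer
fabric.clip_gradients(model, optimizer, max_norm = clip_grad_norm)

optimizer.step()
optimizer.zero_grad()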
2 changes: 1 addition & 1 deletion pyproject.toml
@@ -1,6 +1,6 @@
 [project]
 name = "alphafold3-pytorch"
-version = "0.0.39"
+version = "0.0.40"
 description = "Alphafold 3 - Pytorch"
 authors = [
     { name = "Phil Wang", email = "[email protected]" }
