Skip to content

Commit

Permalink
some cleanup to trainer
Browse files Browse the repository at this point in the history
  • Loading branch information
lucidrains committed May 28, 2024
1 parent c59b6ca commit 340dd87
Show file tree
Hide file tree
Showing 2 changed files with 17 additions and 7 deletions.
22 changes: 16 additions & 6 deletions alphafold3_pytorch/trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -146,18 +146,26 @@ def __init__(

self.clip_grad_norm = clip_grad_norm

# steps

self.steps = 0

@property
def is_main(self):
    """True when this process is the main one (fabric global rank 0)."""
    rank = self.fabric.global_rank
    return rank == 0

def print(self, *args, **kwargs):
    """Delegate printing to ``self.fabric.print`` (rank-aware output under fabric)."""
    fabric_print = self.fabric.print
    fabric_print(*args, **kwargs)

def log(self, **log_data):
    """Forward keyword metrics to ``self.fabric.log_dict``, stamped with the current step counter."""
    current_step = self.steps
    self.fabric.log_dict(log_data, step = current_step)

def __call__(
self
):
dl = iter(self.dataloader)

steps = 0
dl = cycle(self.dataloader)

while steps < self.num_train_steps:
while self.steps < self.num_train_steps:

for grad_accum_step in range(self.grad_accum_every):
is_accumulating = grad_accum_step < (self.grad_accum_every - 1)
Expand All @@ -169,7 +177,9 @@ def __call__(

self.fabric.backward(loss / self.grad_accum_every)

print(f'loss: {loss.item():.3f}')
self.log(loss = loss)

self.print(f'loss: {loss.item():.3f}')

self.fabric.clip_gradients(self.model, self.optimizer, max_norm = self.clip_grad_norm)

Expand All @@ -181,6 +191,6 @@ def __call__(
self.scheduler.step()
self.optimizer.zero_grad()

steps += 1
self.steps += 1

print(f'training complete')
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[project]
name = "alphafold3-pytorch"
version = "0.0.44"
version = "0.0.45"
description = "Alphafold 3 - Pytorch"
authors = [
{ name = "Phil Wang", email = "[email protected]" }
Expand Down

0 comments on commit 340dd87

Please sign in to comment.