make one training step work with fabric + alphafold3 on cpu
lucidrains committed May 25, 2024
1 parent 4532883 commit 422efba
Showing 6 changed files with 159 additions and 15 deletions.
2 changes: 1 addition & 1 deletion README.md
@@ -1,4 +1,4 @@
<img src="./alphafold3.png" width="450px"></img>
<img src="./alphafold3.png" width="500px"></img>

## Alphafold 3 - Pytorch (wip)

4 changes: 3 additions & 1 deletion alphafold3_pytorch/__init__.py
@@ -32,7 +32,8 @@
)

from alphafold3_pytorch.trainer import (
Trainer
Trainer,
Alphafold3Input
)

__all__ = [
@@ -63,5 +64,6 @@
ConfidenceHead,
DistogramHead,
Alphafold3,
Alphafold3Input,
Trainer
]
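
With this change Alphafold3Input is re-exported from the package root alongside Trainer, so downstream code (including the updated test below) can pull everything from one place:

from alphafold3_pytorch import (
    Alphafold3,
    Alphafold3Input,
    Trainer
)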
5 changes: 3 additions & 2 deletions alphafold3_pytorch/alphafold3.py
@@ -2853,7 +2853,8 @@ def forward(
pde_labels: Int['b n n'] | None = None,
plddt_labels: Int['b n'] | None = None,
resolved_labels: Int['b n'] | None = None,
return_loss_breakdown = False
return_loss_breakdown = False,
return_loss_if_possible: bool = True
) -> Float['b m 3'] | Float[''] | Tuple[Float[''], LossBreakdown]:

atom_seq_len = atom_inputs.shape[-2]
@@ -3016,7 +3017,7 @@

# if neither atom positions nor any labels are passed in, sample a structure and return

if not return_loss:
if not return_loss_if_possible or not return_loss:
return self.edm.sample(
num_sample_steps = num_sample_steps,
atom_feats = atom_feats,
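
The new return_loss_if_possible flag forces a sampling pass even when atom positions or labels are present, the case in which forward would otherwise compute a loss. A minimal sketch, assuming alphafold3 is a constructed Alphafold3 and inputs is a dict of batched tensors matching the forward signature (both stand-ins, not shown in this diff; the num_sample_steps value is arbitrary):

# sample coordinates even though labels are available in the inputs
sampled_atom_pos = alphafold3(
    **inputs,
    num_sample_steps = 16,
    return_loss_if_possible = False
)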
88 changes: 80 additions & 8 deletions alphafold3_pytorch/trainer.py
@@ -1,16 +1,41 @@
from __future__ import annotations

from alphafold3_pytorch.alphafold3 import Alphafold3
from alphafold3_pytorch.typing import typecheck

from typing import TypedDict
from alphafold3_pytorch.typing import (
typecheck,
Int, Bool, Float
)

import torch
from torch.optim import Adam, Optimizer
from torch.utils.data import Dataset, DataLoader
from torch.optim.lr_scheduler import LambdaLR, LRScheduler

from ema_pytorch import EMA

from lightning import Fabric

# constants

@typecheck
class Alphafold3Input(TypedDict):
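# single-sample (unbatched) shapes: m = atoms, n = residues, s = MSA sequences, t = templates,
# d* = feature dims; compare the mock AtomDataset in tests/test_trainer.py below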
atom_inputs: Float['m dai']
residue_atom_lens: Int['n 2']
atompair_feats: Float['m m dap']
additional_residue_feats: Float['n 10']
templates: Float['t n n dt']
template_mask: Bool['t'] | None
msa: Float['s n dm']
msa_mask: Bool['s'] | None
atom_pos: Float['m 3'] | None
residue_atom_indices: Int['n'] | None
distance_labels: Int['n n'] | None
pae_labels: Int['n n'] | None
pde_labels: Int['n'] | None
resolved_labels: Int['n'] | None

# helpers

def exists(val):
@@ -19,6 +44,11 @@ def exists(val):
def default(v, d):
return v if exists(v) else d

def cycle(dataloader: DataLoader):
while True:
for batch in dataloader:
yield batch

def default_lambda_lr_fn(steps):
# 1000 step warmup
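
The body of default_lambda_lr_fn is collapsed in this view. Purely for orientation, a linear 1000-step warmup lambda (an illustrative assumption, not necessarily the decay schedule used here) would look like:

def warmup_only_lambda(steps, warmup = 1000):
    # ramp the LR multiplier from 0 to 1 over the first `warmup` steps, then hold
    if steps < warmup:
        return steps / warmup
    return 1.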

@@ -40,6 +70,10 @@ def __init__(
self,
model: Alphafold3,
*,
dataset: Dataset,
num_train_steps: int,
batch_size: int,
grad_accum_every: int = 1,
optimizer: Optimizer | None = None,
scheduler: LRScheduler | None = None,
ema_decay = 0.999,
@@ -69,12 +103,13 @@ def __init__(

# exponential moving average

self.ema_model = EMA(
model,
beta = ema_decay,
include_online_model = False,
**ema_kwargs
)
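# keep the EMA copy of the weights only on the main process (global rank 0, see is_main below),
# so other workers do not hold a redundant averaged model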
if self.is_main:
self.ema_model = EMA(
model,
beta = ema_decay,
include_online_model = False,
**ema_kwargs
)

# optimizer

@@ -87,10 +122,19 @@

self.optimizer = optimizer

# data

self.dataloader = DataLoader(dataset, batch_size = batch_size, shuffle = True, drop_last = True)

self.num_train_steps = num_train_steps
self.grad_accum_every = grad_accum_every

# setup fabric

self.model, self.optimizer = fabric.setup(self.model, self.optimizer)

self.dataloader = fabric.setup_dataloaders(self.dataloader) # setup_dataloaders returns the wrapped dataloader, so keep the reference

# scheduler

if not exists(scheduler):
@@ -102,7 +146,35 @@ def __init__(

self.clip_grad_norm = clip_grad_norm

@property
def is_main(self):
return self.fabric.global_rank == 0

def __call__(
self
):
pass
dl = cycle(self.dataloader) # use the cycle helper above so the iterator never exhausts mid-training

steps = 0

while steps < self.num_train_steps:
for _ in range(self.grad_accum_every):
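# each of the grad_accum_every micro-batches contributes loss / grad_accum_every to the gradient,
# so after the inner loop the accumulated gradient matches that of one full effective batch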
inputs = next(dl)

loss = self.model(**inputs)

self.fabric.backward(loss / self.grad_accum_every)

print(f'loss: {loss.item():.3f}')

self.optimizer.step()

if self.is_main:
self.ema_model.update()

self.scheduler.step()
self.optimizer.zero_grad()

steps += 1

print('training complete')
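
The constructor also accepts an externally built optimizer and scheduler; when they are omitted, defaults are created internally (the Adam and LambdaLR imports at the top hint at what those are). A minimal sketch under the signature above, where alphafold3 and dataset stand for already constructed instances and the learning rate is an arbitrary illustrative value:

from torch.optim import Adam
from torch.optim.lr_scheduler import LambdaLR

optimizer = Adam(alphafold3.parameters(), lr = 1e-3)
scheduler = LambdaLR(optimizer, lr_lambda = lambda step: 1.)  # constant multiplier, purely illustrative

trainer = Trainer(
    alphafold3,
    dataset = dataset,
    num_train_steps = 2,
    batch_size = 1,
    optimizer = optimizer,
    scheduler = scheduler
)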
2 changes: 1 addition & 1 deletion pyproject.toml
@@ -1,6 +1,6 @@
[project]
name = "alphafold3-pytorch"
version = "0.0.38"
version = "0.0.39"
description = "Alphafold 3 - Pytorch"
authors = [
{ name = "Phil Wang", email = "[email protected]" }
73 changes: 71 additions & 2 deletions tests/test_trainer.py
@@ -1,14 +1,74 @@
import os
os.environ['TYPECHECK'] = 'True'

import torch
import pytest
import torch
from torch.utils.data import Dataset

from alphafold3_pytorch import (
Alphafold3,
Alphafold3Input,
Trainer
)

# mock dataset

class AtomDataset(Dataset):
def __init__(
self,
seq_len = 16,
atoms_per_window = 27
):
self.seq_len = seq_len
self.atom_seq_len = seq_len * atoms_per_window

def __len__(self):
return 100

def __getitem__(self, idx):
seq_len = self.seq_len
atom_seq_len = self.atom_seq_len

atom_inputs = torch.randn(atom_seq_len, 77)
residue_atom_lens = torch.randint(0, 27, (seq_len,))
atompair_feats = torch.randn(atom_seq_len, atom_seq_len, 16)
additional_residue_feats = torch.randn(seq_len, 10)

templates = torch.randn(2, seq_len, seq_len, 44)
template_mask = torch.ones((2,)).bool()

msa = torch.randn(7, seq_len, 64)
msa_mask = torch.ones((7,)).bool()

# required for training, but omitted at inference

atom_pos = torch.randn(atom_seq_len, 3)
residue_atom_indices = torch.randint(0, 27, (seq_len,))

distance_labels = torch.randint(0, 37, (seq_len, seq_len))
pae_labels = torch.randint(0, 64, (seq_len, seq_len))
pde_labels = torch.randint(0, 64, (seq_len, seq_len))
plddt_labels = torch.randint(0, 50, (seq_len,))
resolved_labels = torch.randint(0, 2, (seq_len,))

return Alphafold3Input(
atom_inputs = atom_inputs,
residue_atom_lens = residue_atom_lens,
atompair_feats = atompair_feats,
additional_residue_feats = additional_residue_feats,
templates = templates,
template_mask = template_mask,
msa = msa,
msa_mask = msa_mask,
atom_pos = atom_pos,
residue_atom_indices = residue_atom_indices,
distance_labels = distance_labels,
pae_labels = pae_labels,
pde_labels = pde_labels,
plddt_labels = plddt_labels,
resolved_labels = resolved_labels
)

def test_trainer():
alphafold3 = Alphafold3(
dim_atom_inputs = 77,
@@ -33,6 +93,15 @@ def test_trainer():
),
)

trainer = Trainer(alphafold3)
dataset = AtomDataset()

trainer = Trainer(
alphafold3,
dataset = dataset,
accelerator = 'cpu',
num_train_steps = 2,
batch_size = 1,
grad_accum_every = 2
)

trainer()
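
Each AtomDataset.__getitem__ returns an unbatched Alphafold3Input; the default DataLoader collation inside Trainer stacks each dictionary field, so with batch_size = 1 the model sees tensors with a leading batch dimension by the time the training loop calls self.model(**inputs). A quick standalone check (not part of the test file):

from torch.utils.data import DataLoader

dl = DataLoader(AtomDataset(), batch_size = 1)
batch = next(iter(dl))
print(batch['atom_inputs'].shape)  # torch.Size([1, 432, 77]) with the default seq_len = 16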
