Commit 1e1bc21
move towards using residue_atom_lens for both packed and unpacked representations for clarity. atom mask internally derived for atom transformer
lucidrains committed May 25, 2024
1 parent 9ab2eb0 commit 1e1bc21
Showing 4 changed files with 31 additions and 30 deletions.
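The mechanism at the heart of this commit: callers now supply per-residue atom counts (residue_atom_lens) and the boolean atom mask is derived inside the model by the lens_to_mask helper referenced in the diffs below. A minimal sketch of what such a helper does, assuming the (lens, max_len = ...) call signature visible in the diff; this is an illustrative reimplementation, not the repository's exact code:

import torch

def lens_to_mask(lens, max_len):
    # each length n becomes a boolean row of n Trues followed by (max_len - n) Falses
    positions = torch.arange(max_len, device = lens.device)
    return positions < lens[..., None]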
4 changes: 2 additions & 2 deletions README.md
@@ -36,7 +36,7 @@ seq_len = 16
atom_seq_len = seq_len * 27

atom_inputs = torch.randn(2, atom_seq_len, 77)
-atom_mask = torch.ones((2, atom_seq_len)).bool()
+atom_lens = torch.randint(0, 27, (2, seq_len))
atompair_feats = torch.randn(2, atom_seq_len, atom_seq_len, 16)
additional_residue_feats = torch.randn(2, seq_len, 33)

@@ -61,7 +61,7 @@ resolved_labels = torch.randint(0, 2, (2, seq_len))
loss = alphafold3(
num_recycling_steps = 2,
atom_inputs = atom_inputs,
-    atom_mask = atom_mask,
+    residue_atom_lens = atom_lens,
atompair_feats = atompair_feats,
additional_residue_feats = additional_residue_feats,
msa = msa,
41 changes: 23 additions & 18 deletions alphafold3_pytorch/alphafold3.py
@@ -55,6 +55,8 @@

# constants

+DIM_ADDITIONAL_RESIDUE_FEATS = 10

LinearNoBias = partial(Linear, bias = False)

# helper functions
@@ -1435,7 +1437,8 @@ def forward(
w = self.atoms_per_window
is_unpacked_repr = exists(w)

-assert is_unpacked_repr ^ exists(residue_atom_lens), '`residue_atom_lens` must be passed in if using packed_atom_repr (atoms_per_window is None)'
+if not is_unpacked_repr:
+    assert exists(residue_atom_lens), '`residue_atom_lens` must be passed in if using packed_atom_repr (atoms_per_window is None)'

atom_feats = self.proj(atom_feats)

@@ -1613,7 +1616,8 @@ def forward(
w = self.atoms_per_window
is_unpacked_repr = exists(w)

-assert is_unpacked_repr ^ exists(residue_atom_lens)
+if not is_unpacked_repr:
+    assert exists(residue_atom_lens)

# in the paper, it seems they pack the atom feats
# but in this impl, will just use windows for simplicity when communicating between atom and residue resolutions. bit less efficient
@@ -2350,7 +2354,6 @@ def __init__(
self,
*,
dim_atom_inputs,
-dim_additional_residue_feats = 10,
atoms_per_window = 27,
dim_atom = 128,
dim_atompair = 16,
@@ -2396,9 +2399,7 @@ def __init__(
atoms_per_window = atoms_per_window
)

-dim_single_input = dim_token + dim_additional_residue_feats
-
-self.dim_additional_residue_feats = dim_additional_residue_feats
+dim_single_input = dim_token + DIM_ADDITIONAL_RESIDUE_FEATS

self.single_input_to_single_init = LinearNoBias(dim_single_input, dim_single)
self.single_input_to_pairwise_init = LinearNoBiasThenOuterSum(dim_single_input, dim_pairwise)
@@ -2415,7 +2416,7 @@ def forward(

) -> EmbeddedInputs:

-assert additional_residue_feats.shape[-1] == self.dim_additional_residue_feats
+assert additional_residue_feats.shape[-1] == DIM_ADDITIONAL_RESIDUE_FEATS

w = self.atoms_per_window

@@ -2608,7 +2609,6 @@ def __init__(
self,
*,
dim_atom_inputs,
-dim_additional_residue_feats,
dim_template_feats,
dim_template_model = 64,
atoms_per_window = 27,
@@ -2713,7 +2713,6 @@ def __init__(

self.input_embedder = InputFeatureEmbedder(
dim_atom_inputs = dim_atom_inputs,
-dim_additional_residue_feats = dim_additional_residue_feats,
atoms_per_window = atoms_per_window,
dim_atom = dim_atom,
dim_atompair = dim_atompair,
@@ -2723,7 +2722,7 @@
**input_embedder_kwargs
)

-dim_single_inputs = dim_input_embedder_token + dim_additional_residue_feats
+dim_single_inputs = dim_input_embedder_token + DIM_ADDITIONAL_RESIDUE_FEATS

# relative positional encoding
# used by pairwise in main alphafold2 trunk
@@ -2866,22 +2865,28 @@ def forward(

atom_seq_len = atom_inputs.shape[-2]

+assert exists(residue_atom_lens) or exists(atom_mask)
+
+# determine whether using packed or unpacked atom rep

-assert exists(residue_atom_lens) ^ exists(atom_mask), 'either atom_lens or atom_mask must be given depending on whether packed_atom_repr kwarg is True or False'
+if self.packed_atom_repr:
+    assert exists(residue_atom_lens), 'residue_atom_lens must be given if using packed atom repr'

-if exists(residue_atom_lens):
-    assert self.packed_atom_repr, '`packed_atom_repr` kwarg on Alphafold3 must be True when passing in `atom_lens`'

-# handle atom mask
+if self.packed_atom_repr:
+    # handle atom mask

-total_atoms = residue_atom_lens.sum(dim = -1)
-atom_mask = lens_to_mask(total_atoms, max_len = atom_seq_len)
+    total_atoms = residue_atom_lens.sum(dim = -1)
+    atom_mask = lens_to_mask(total_atoms, max_len = atom_seq_len)

-# handle offsets for residue atom indices
+    # handle offsets for residue atom indices

-if exists(residue_atom_indices):
-    residue_atom_indices += F.pad(residue_atom_lens, (-1, 1), value = 0)
+    if exists(residue_atom_indices):
+        residue_atom_indices += F.pad(residue_atom_lens, (-1, 1), value = 0)
+else:
+    atom_mask = lens_to_mask(residue_atom_lens, max_len = self.atoms_per_window)
+    atom_mask = rearrange(atom_mask, 'b ... -> b (...)')

# get atom sequence length and residue sequence length depending on whether using packed atomic seq

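To see the two mask derivations above end to end, here is a self-contained sketch under the same assumed lens_to_mask behavior (illustrative, not the repo's exact helper): the packed path masks the first sum(residue_atom_lens) atoms of one flat atom sequence, while the unpacked path masks each fixed-size window and flattens.

import torch
from einops import rearrange

def lens_to_mask(lens, max_len):
    # assumed stand-in for the repo helper of the same name
    return torch.arange(max_len, device = lens.device) < lens[..., None]

batch, seq_len, w = 2, 16, 27
residue_atom_lens = torch.randint(0, w, (batch, seq_len))

# packed repr: atoms stored contiguously, so mask the first sum(lens) positions
atom_seq_len = seq_len * w
total_atoms = residue_atom_lens.sum(dim = -1)                    # (batch,)
packed_mask = lens_to_mask(total_atoms, max_len = atom_seq_len)  # (batch, atom_seq_len)

# unpacked repr: mask each fixed-size atom window, then flatten to atom resolution
window_mask = lens_to_mask(residue_atom_lens, max_len = w)       # (batch, seq_len, w)
unpacked_mask = rearrange(window_mask, 'b ... -> b (...)')       # (batch, seq_len * w)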
2 changes: 1 addition & 1 deletion pyproject.toml
@@ -1,6 +1,6 @@
[project]
name = "alphafold3-pytorch"
version = "0.0.34"
version = "0.0.36"
description = "Alphafold 3 - Pytorch"
authors = [
{ name = "Phil Wang", email = "[email protected]" }
14 changes: 5 additions & 9 deletions tests/test_af3.py
@@ -344,7 +344,6 @@ def test_input_embedder():

embedder = InputFeatureEmbedder(
dim_atom_inputs = 77,
-dim_additional_residue_feats = 10
)

embedder(
@@ -369,7 +368,7 @@ def test_alphafold3():
token_bond = torch.randint(0, 2, (2, seq_len, seq_len)).bool()

atom_inputs = torch.randn(2, atom_seq_len, 77)
-atom_mask = torch.ones((2, atom_seq_len)).bool()
+atom_lens = torch.randint(0, 27, (2, seq_len))
atompair_feats = torch.randn(2, atom_seq_len, atom_seq_len, 16)
additional_residue_feats = torch.randn(2, seq_len, 10)

@@ -389,7 +388,6 @@

alphafold3 = Alphafold3(
dim_atom_inputs = 77,
-dim_additional_residue_feats = 10,
dim_template_feats = 44,
num_dist_bins = 38,
confidence_head_kwargs = dict(
@@ -414,7 +412,7 @@
loss, breakdown = alphafold3(
num_recycling_steps = 2,
atom_inputs = atom_inputs,
-    atom_mask = atom_mask,
+    residue_atom_lens = atom_lens,
atompair_feats = atompair_feats,
additional_residue_feats = additional_residue_feats,
token_bond = token_bond,
@@ -437,7 +435,7 @@
sampled_atom_pos = alphafold3(
num_sample_steps = 16,
atom_inputs = atom_inputs,
-    atom_mask = atom_mask,
+    residue_atom_lens = atom_lens,
atompair_feats = atompair_feats,
additional_residue_feats = additional_residue_feats,
msa = msa,
@@ -452,7 +450,7 @@ def test_alphafold3_without_msa_and_templates():
atom_seq_len = seq_len * 27

atom_inputs = torch.randn(2, atom_seq_len, 77)
-atom_mask = torch.ones((2, atom_seq_len)).bool()
+atom_lens = torch.randint(0, 27, (2, seq_len))
atompair_feats = torch.randn(2, atom_seq_len, atom_seq_len, 16)
additional_residue_feats = torch.randn(2, seq_len, 10)

@@ -467,7 +465,6 @@

alphafold3 = Alphafold3(
dim_atom_inputs = 77,
-dim_additional_residue_feats = 10,
dim_template_feats = 44,
num_dist_bins = 38,
confidence_head_kwargs = dict(
@@ -492,7 +489,7 @@
loss, breakdown = alphafold3(
num_recycling_steps = 2,
atom_inputs = atom_inputs,
-    atom_mask = atom_mask,
+    residue_atom_lens = atom_lens,
atompair_feats = atompair_feats,
additional_residue_feats = additional_residue_feats,
atom_pos = atom_pos,
@@ -536,7 +533,6 @@ def test_alphafold3_with_packed_atom_repr():

alphafold3 = Alphafold3(
dim_atom_inputs = 77,
-dim_additional_residue_feats = 10,
dim_template_feats = 44,
num_dist_bins = 38,
packed_atom_repr = True,
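Lastly, a small sanity check in the spirit of the tests above, again written against the assumed stand-in for lens_to_mask rather than the package's own helper:

import torch

def lens_to_mask(lens, max_len):
    # assumed stand-in for the repo helper of the same name
    return torch.arange(max_len, device = lens.device) < lens[..., None]

def test_lens_to_mask():
    lens = torch.tensor([[2, 5, 0, 1]])
    mask = lens_to_mask(lens, max_len = 5)
    assert mask.shape == (1, 4, 5)
    assert mask.sum().item() == 8  # 2 + 5 + 0 + 1 valid atom positions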
