complete the msa module
lucidrains committed May 14, 2024
1 parent 218477e commit f35ac78
Showing 2 changed files with 131 additions and 11 deletions.
alphafold3_pytorch/__init__.py (2 additions, 0 deletions)
@@ -13,6 +13,7 @@
     AttentionPairBias,
     TriangleAttention,
     Transition,
+    MSAModule,
     PairformerStack,
     Alphafold3
 )
@@ -29,6 +30,7 @@
     AttentionPairBias,
     TriangleAttention,
     Transition,
+    MSAModule,
     PairformerStack,
     Alphafold3
 ]
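
With MSAModule added to both the package imports and the export list above, it can now be pulled in straight from the package root; a minimal sketch:

    from alphafold3_pytorch import MSAModule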
alphafold3_pytorch/alphafold3.py (129 additions, 11 deletions)
@@ -443,7 +443,8 @@ def __init__(
         dim_msa = 64,
         dim_pairwise_repr = 128,
         dim_head = 32,
-        heads = 8
+        heads = 8,
+        dropout = 0.
     ):
         super().__init__()
         dim_inner = dim_head * heads
@@ -462,7 +463,8 @@ def __init__(

         self.to_out = nn.Sequential(
             Rearrange('b h s n d -> b s n (h d)'),
-            LinearNoBias(dim_inner, dim_msa)
+            LinearNoBias(dim_inner, dim_msa),
+            Dropout(dropout)
         )
 
     @typecheck
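
The two hunks above thread a configurable dropout (default 0.) through MSAPairWeightedAveraging, applied after its output projection. A hedged construction sketch follows, importing from the defining module since this class does not appear in the package-root exports above; the 0.1 rate and tensor sizes are illustrative, the call mirrors how MSAModule invokes the block below, and mask is assumed optional:

    import torch
    from alphafold3_pytorch.alphafold3 import MSAPairWeightedAveraging

    msa_pwa = MSAPairWeightedAveraging(
        dim_msa = 64,
        dim_pairwise_repr = 128,
        dim_head = 32,
        heads = 8,
        dropout = 0.1   # newly exposed knob; defaults to 0.
    )

    msa = torch.randn(1, 7, 16, 64)         # (batch, msa sequences s, residues n, dim_msa)
    pairwise = torch.randn(1, 16, 16, 128)  # (batch, n, n, dim_pairwise_repr)

    out = msa_pwa(msa = msa, pairwise_repr = pairwise)  # same shape as msa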
@@ -504,21 +506,129 @@ def __init__(
         self,
         *,
         dim_single,
-        dim_pairwise = 128,
+        depth = 4,
         dim_msa = 64,
-        depth = 4
+        dim_msa_input = None,
+        dim_pairwise = 128,
+        outer_product_mean_dim_hidden = 32,
+        dropout_row_prob = 0.25,
+        dropout_col_prob = 0.25,
+        tri_mult_dim_hidden = None,
+        tri_attn_dim_head = 32,
+        tri_attn_heads = 4,
+        msa_pwa_dropout_row_prob = 0.15,
+        msa_pwa_heads = 8,
+        msa_pwa_dim_head = 32,
     ):
         super().__init__()
-        raise NotImplementedError
+
+        self.msa_init_proj = LinearNoBias(dim_msa_input, dim_msa) if exists(dim_msa_input) else nn.Identity()
+
+        self.single_to_msa_feats = LinearNoBias(dim_single, dim_msa)
+
+        layers = ModuleList([])
+
+        tri_mult_kwargs = dict(
+            dim = dim_pairwise,
+            dim_hidden = tri_mult_dim_hidden
+        )
+
+        tri_attn_kwargs = dict(
+            dim = dim_pairwise,
+            heads = tri_attn_heads,
+            dim_head = tri_attn_dim_head
+        )
+
+        for _ in range(depth):
+
+            pairwise_pre_ln = partial(PreLayerNorm, dim = dim_pairwise)
+            msa_pre_ln = partial(PreLayerNorm, dim = dim_msa)
+
+            outer_product_mean = OuterProductMean(
+                dim_msa = dim_msa,
+                dim_pairwise_repr = dim_pairwise,
+                dim_hidden = outer_product_mean_dim_hidden
+            )
+
+            msa_pair_weighted_avg = MSAPairWeightedAveraging(
+                dim_msa = dim_msa,
+                dim_pairwise_repr = dim_pairwise,
+                heads = msa_pwa_heads,
+                dim_head = msa_pwa_dim_head
+            )
+
+            msa_transition = Transition(dim = dim_msa)
+
+            tri_mult_outgoing = TriangleMultiplication(mix = 'outgoing', dropout = dropout_row_prob, **tri_mult_kwargs)
+            tri_mult_incoming = TriangleMultiplication(mix = 'incoming', dropout = dropout_row_prob, **tri_mult_kwargs)
+            tri_attn_starting = TriangleAttention(node_type = 'starting', dropout = dropout_row_prob, **tri_attn_kwargs)
+            tri_attn_ending = TriangleAttention(node_type = 'ending', dropout = dropout_col_prob, **tri_attn_kwargs)
+            pairwise_transition = Transition(dim = dim_pairwise)
+
+            layers.append(ModuleList([
+                outer_product_mean,
+                msa_pair_weighted_avg,
+                msa_pre_ln(msa_transition),
+                pairwise_pre_ln(tri_mult_outgoing),
+                pairwise_pre_ln(tri_mult_incoming),
+                pairwise_pre_ln(tri_attn_starting),
+                pairwise_pre_ln(tri_attn_ending),
+                pairwise_pre_ln(pairwise_transition),
+            ]))
+
+        self.layers = layers
 
     @typecheck
     def forward(
         self,
         *,
-        single_repr,
-        pairwise_repr,
-        msa
-    ):
-        raise NotImplementedError
+        single_repr: Float['b n ds'],
+        pairwise_repr: Float['b n n dp'],
+        msa: Float['b s n dm'],
+        mask: Bool['b n'] | None = None,
+        msa_mask: Bool['b s'] | None = None,
+    ) -> Float['b n n dp']:
+
+        msa = self.msa_init_proj(msa)
+
+        single_msa_feats = self.single_to_msa_feats(single_repr)
+
+        msa = rearrange(single_msa_feats, 'b n d -> b 1 n d') + msa
+
+        for (
+            outer_product_mean,
+            msa_pair_weighted_avg,
+            msa_transition,
+            tri_mult_outgoing,
+            tri_mult_incoming,
+            tri_attn_starting,
+            tri_attn_ending,
+            pairwise_transition
+        ) in self.layers:
+
+            # communication between msa and pairwise rep
+
+            pairwise_repr = outer_product_mean(msa, mask = mask, msa_mask = msa_mask) + pairwise_repr
+
+            msa = msa_pair_weighted_avg(msa = msa, pairwise_repr = pairwise_repr, mask = mask) + msa
+
+            msa, msa_packed_shape = pack_one(msa, 'b * d')
+            msa = msa_transition(msa) + msa
+            msa = unpack_one(msa, msa_packed_shape, 'b * d')
+
+            # pairwise tri mult + attn + transition
+
+            pairwise_repr = tri_mult_outgoing(pairwise_repr, mask = mask) + pairwise_repr
+            pairwise_repr = tri_mult_incoming(pairwise_repr, mask = mask) + pairwise_repr
+            pairwise_repr = tri_attn_starting(pairwise_repr, mask = mask) + pairwise_repr
+            pairwise_repr = tri_attn_ending(pairwise_repr, mask = mask) + pairwise_repr
+
+            pairwise_repr, packed_shape = pack_one(pairwise_repr, 'b * d')
+            pairwise_repr = pairwise_transition(pairwise_repr) + pairwise_repr
+            pairwise_repr = unpack_one(pairwise_repr, packed_shape, 'b * d')
+
+        return pairwise_repr

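The completed MSAModule can now be exercised end to end. A minimal sketch, assuming only the constructor and forward signatures shown above; the concrete sizes (batch 2, s = 7 sequences, n = 16 residues, dim_single = 384, dim_msa_input = 32) are illustrative:

    import torch
    from alphafold3_pytorch import MSAModule

    msa_module = MSAModule(
        dim_single = 384,    # required keyword; 384 is an illustrative choice
        dim_msa_input = 32,  # optional; when set, raw MSA features are projected to dim_msa
        dim_msa = 64,
        dim_pairwise = 128,
        depth = 4
    )

    single = torch.randn(2, 16, 384)        # Float['b n ds']
    pairwise = torch.randn(2, 16, 16, 128)  # Float['b n n dp']
    msa = torch.randn(2, 7, 16, 32)         # Float['b s n dm']

    pairwise = msa_module(
        single_repr = single,
        pairwise_repr = pairwise,
        msa = msa
    )  # Float['b n n dp']

Note that only the updated pairwise representation is returned; the MSA representation is used to communicate with the pairwise track inside the stack and then discarded.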
# pairformer stack

@@ -596,7 +706,15 @@ def forward(

     ) -> Tuple[Float['b n ds'], Float['b n n dp']]:
 
-        for tri_mult_outgoing, tri_mult_incoming, tri_attn_starting, tri_attn_ending, pairwise_transition, pair_bias_attn, single_transition in self.layers:
+        for (
+            tri_mult_outgoing,
+            tri_mult_incoming,
+            tri_attn_starting,
+            tri_attn_ending,
+            pairwise_transition,
+            pair_bias_attn,
+            single_transition
+        ) in self.layers:
 
             pairwise_repr = tri_mult_outgoing(pairwise_repr, mask = mask) + pairwise_repr
             pairwise_repr = tri_mult_incoming(pairwise_repr, mask = mask) + pairwise_repr
