Commit 7a44edb

refactor out a common block shared between the pairformer and msa module

lucidrains committed May 15, 2024
1 parent 9694800 commit 7a44edb

Showing 1 changed file with 78 additions and 85 deletions.

alphafold3_pytorch/alphafold3.py (163 changes: 78 additions & 85 deletions)
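For orientation: this commit folds the five pairwise-representation updates (outgoing/incoming triangle multiplication, starting/ending triangle attention, and a transition, each pre-layernormed with a residual) into one reusable PairwiseBlock, shared by MSAModule and the Pairformer. Below is a minimal usage sketch of the new block, assuming PairwiseBlock is importable from alphafold3_pytorch/alphafold3.py as defined in the diff that follows; shapes follow its Float['b n n d'] and Bool['b n'] annotations.

import torch
from alphafold3_pytorch.alphafold3 import PairwiseBlock

block = PairwiseBlock(dim_pairwise = 128)

pairwise_repr = torch.randn(2, 16, 16, 128)   # Float['b n n d']
mask = torch.ones(2, 16).bool()               # Bool['b n']

# the block applies all five residual updates internally
out = block(pairwise_repr = pairwise_repr, mask = mask)
assert out.shape == (2, 16, 16, 128)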
@@ -371,6 +371,57 @@ def forward(
 
         return self.dropout(out)
 
+# PairwiseBlock
+# used in both MSAModule and Pairformer
+# consists of all the "Triangle" modules + Transition
+
+class PairwiseBlock(Module):
+    def __init__(
+        self,
+        *,
+        dim_pairwise = 128,
+        tri_mult_dim_hidden = None,
+        tri_attn_dim_head = 32,
+        tri_attn_heads = 4,
+        dropout_row_prob = 0.25,
+        dropout_col_prob = 0.25,
+    ):
+        super().__init__()
+
+        pre_ln = partial(PreLayerNorm, dim = dim_pairwise)
+
+        tri_mult_kwargs = dict(
+            dim = dim_pairwise,
+            dim_hidden = tri_mult_dim_hidden
+        )
+
+        tri_attn_kwargs = dict(
+            dim = dim_pairwise,
+            heads = tri_attn_heads,
+            dim_head = tri_attn_dim_head
+        )
+
+        self.tri_mult_outgoing = pre_ln(TriangleMultiplication(mix = 'outgoing', dropout = dropout_row_prob, **tri_mult_kwargs))
+        self.tri_mult_incoming = pre_ln(TriangleMultiplication(mix = 'incoming', dropout = dropout_row_prob, **tri_mult_kwargs))
+        self.tri_attn_starting = pre_ln(TriangleAttention(node_type = 'starting', dropout = dropout_row_prob, **tri_attn_kwargs))
+        self.tri_attn_ending = pre_ln(TriangleAttention(node_type = 'ending', dropout = dropout_col_prob, **tri_attn_kwargs))
+        self.pairwise_transition = pre_ln(Transition(dim = dim_pairwise))
+
+    @typecheck
+    def forward(
+        self,
+        *,
+        pairwise_repr: Float['b n n d'],
+        mask: Bool['b n'] | None = None
+    ):
+
+        pairwise_repr = self.tri_mult_outgoing(pairwise_repr, mask = mask) + pairwise_repr
+        pairwise_repr = self.tri_mult_incoming(pairwise_repr, mask = mask) + pairwise_repr
+        pairwise_repr = self.tri_attn_starting(pairwise_repr, mask = mask) + pairwise_repr
+        pairwise_repr = self.tri_attn_ending(pairwise_repr, mask = mask) + pairwise_repr
+
+        pairwise_repr = self.pairwise_transition(pairwise_repr) + pairwise_repr
+        return pairwise_repr
+
 # msa module
 
 class OuterProductMean(Module):
@@ -380,15 +431,15 @@ def __init__(
         self,
         *,
         dim_msa = 64,
-        dim_pairwise_repr = 128,
+        dim_pairwise = 128,
         dim_hidden = 32,
         eps = 1e-5
     ):
         super().__init__()
         self.eps = eps
         self.norm = nn.LayerNorm(dim_msa)
         self.to_hidden = LinearNoBias(dim_msa, dim_hidden * 2)
-        self.to_pairwise_repr = nn.Linear(dim_hidden ** 2, dim_pairwise_repr)
+        self.to_pairwise_repr = nn.Linear(dim_hidden ** 2, dim_pairwise)
 
     @typecheck
     def forward(
@@ -441,7 +492,7 @@ def __init__(
         self,
         *,
         dim_msa = 64,
-        dim_pairwise_repr = 128,
+        dim_pairwise = 128,
         dim_head = 32,
         heads = 8,
         dropout = 0.
@@ -456,8 +507,8 @@ def __init__(
         )
 
         self.pairwise_repr_to_attn = nn.Sequential(
-            nn.LayerNorm(dim_pairwise_repr),
-            LinearNoBias(dim_pairwise_repr, heads),
+            nn.LayerNorm(dim_pairwise),
+            LinearNoBias(dim_pairwise, heads),
             Rearrange('b i j h -> b h i j')
         )
 
@@ -508,19 +559,15 @@ def __init__(
         self,
         *,
         dim_single,
+        dim_pairwise = 128,
         depth = 4,
         dim_msa = 64,
         dim_msa_input = None,
-        dim_pairwise = 128,
         outer_product_mean_dim_hidden = 32,
-        dropout_row_prob = 0.25,
-        dropout_col_prob = 0.25,
-        tri_mult_dim_hidden = None,
-        tri_attn_dim_head = 32,
-        tri_attn_heads = 4,
         msa_pwa_dropout_row_prob = 0.15,
         msa_pwa_heads = 8,
         msa_pwa_dim_head = 32,
+        pairwise_block_kwargs: dict = dict()
     ):
         super().__init__()
 
@@ -530,52 +577,35 @@ def __init__(
 
         layers = ModuleList([])
 
-        tri_mult_kwargs = dict(
-            dim = dim_pairwise,
-            dim_hidden = tri_mult_dim_hidden
-        )
-
-        tri_attn_kwargs = dict(
-            dim = dim_pairwise,
-            heads = tri_attn_heads,
-            dim_head = tri_attn_dim_head
-        )
-
         for _ in range(depth):
 
-            pairwise_pre_ln = partial(PreLayerNorm, dim = dim_pairwise)
             msa_pre_ln = partial(PreLayerNorm, dim = dim_msa)
 
             outer_product_mean = OuterProductMean(
                 dim_msa = dim_msa,
-                dim_pairwise_repr = dim_pairwise,
+                dim_pairwise = dim_pairwise,
                 dim_hidden = outer_product_mean_dim_hidden
             )
 
             msa_pair_weighted_avg = MSAPairWeightedAveraging(
                 dim_msa = dim_msa,
-                dim_pairwise_repr = dim_pairwise,
+                dim_pairwise = dim_pairwise,
                 heads = msa_pwa_heads,
                 dim_head = msa_pwa_dim_head
             )
 
             msa_transition = Transition(dim = dim_msa)
 
-            tri_mult_outgoing = TriangleMultiplication(mix = 'outgoing', dropout = dropout_row_prob, **tri_mult_kwargs)
-            tri_mult_incoming = TriangleMultiplication(mix = 'incoming', dropout = dropout_row_prob, **tri_mult_kwargs)
-            tri_attn_starting = TriangleAttention(node_type = 'starting', dropout = dropout_row_prob, **tri_attn_kwargs)
-            tri_attn_ending = TriangleAttention(node_type = 'ending', dropout = dropout_col_prob, **tri_attn_kwargs)
-            pairwise_transition = Transition(dim = dim_pairwise)
+            pairwise_block = PairwiseBlock(
+                dim_pairwise = dim_pairwise,
+                **pairwise_block_kwargs
+            )
 
             layers.append(ModuleList([
                 outer_product_mean,
                 msa_pair_weighted_avg,
                 msa_pre_ln(msa_transition),
-                pairwise_pre_ln(tri_mult_outgoing),
-                pairwise_pre_ln(tri_mult_incoming),
-                pairwise_pre_ln(tri_attn_starting),
-                pairwise_pre_ln(tri_attn_ending),
-                pairwise_pre_ln(pairwise_transition),
+                pairwise_block
             ]))
 
         self.layers = layers
Expand All @@ -602,11 +632,7 @@ def forward(
outer_product_mean,
msa_pair_weighted_avg,
msa_transition,
tri_mult_outgoing,
tri_mult_incoming,
tri_attn_starting,
tri_attn_ending,
pairwise_transition
pairwise_block
)in self.layers:

# communication between msa and pairwise rep
@@ -616,14 +642,9 @@ def forward(
             msa = msa_pair_weighted_avg(msa = msa, pairwise_repr = pairwise_repr, mask = mask) + msa
             msa = msa_transition(msa) + msa
 
-            # pairwise tri mult + attn + transition
-
-            pairwise_repr = tri_mult_outgoing(pairwise_repr, mask = mask) + pairwise_repr
-            pairwise_repr = tri_mult_incoming(pairwise_repr, mask = mask) + pairwise_repr
-            pairwise_repr = tri_attn_starting(pairwise_repr, mask = mask) + pairwise_repr
-            pairwise_repr = tri_attn_ending(pairwise_repr, mask = mask) + pairwise_repr
+            # pairwise block
 
-            pairwise_repr = pairwise_transition(pairwise_repr) + pairwise_repr
+            pairwise_repr = pairwise_block(pairwise_repr = pairwise_repr, mask = mask)
 
         return pairwise_repr

@@ -638,55 +659,36 @@ def __init__(
         dim_single,
         dim_pairwise = 128,
         depth = 48,
-        tri_mult_dim_hidden = None,
-        tri_attn_dim_head = 32,
-        tri_attn_heads = 4,
         pair_bias_attn_dim_head = 64,
         pair_bias_attn_heads = 16,
         dropout_row_prob = 0.25,
-        dropout_col_prob = 0.25
+        pairwise_block_kwargs: dict = dict()
     ):
         super().__init__()
         layers = ModuleList([])
 
-        tri_mult_kwargs = dict(
-            dim = dim_pairwise,
-            dim_hidden = tri_mult_dim_hidden
-        )
-
-        tri_attn_kwargs = dict(
-            dim = dim_pairwise,
-            heads = tri_attn_heads,
-            dim_head = tri_attn_dim_head
-        )
-
         pair_bias_attn_kwargs = dict(
             dim = dim_single,
             dim_pairwise_repr = dim_pairwise,
             heads = pair_bias_attn_heads,
-            dim_head = pair_bias_attn_dim_head
+            dim_head = pair_bias_attn_dim_head,
+            dropout = dropout_row_prob
         )
 
         for _ in range(depth):
 
-            pairwise_pre_ln = partial(PreLayerNorm, dim = dim_pairwise)
             single_pre_ln = partial(PreLayerNorm, dim = dim_single)
 
-            tri_mult_outgoing = TriangleMultiplication(mix = 'outgoing', dropout = dropout_row_prob, **tri_mult_kwargs)
-            tri_mult_incoming = TriangleMultiplication(mix = 'incoming', dropout = dropout_row_prob, **tri_mult_kwargs)
-            tri_attn_starting = TriangleAttention(node_type = 'starting', dropout = dropout_row_prob, **tri_attn_kwargs)
-            tri_attn_ending = TriangleAttention(node_type = 'ending', dropout = dropout_col_prob, **tri_attn_kwargs)
-            pairwise_transition = Transition(dim = dim_pairwise)
+            pairwise_block = PairwiseBlock(
+                dim_pairwise = dim_pairwise,
+                **pairwise_block_kwargs
+            )
 
             pair_bias_attn = AttentionPairBias(**pair_bias_attn_kwargs)
             single_transition = Transition(dim = dim_single)
 
             layers.append(ModuleList([
-                pairwise_pre_ln(tri_mult_outgoing),
-                pairwise_pre_ln(tri_mult_incoming),
-                pairwise_pre_ln(tri_attn_starting),
-                pairwise_pre_ln(tri_attn_ending),
-                pairwise_pre_ln(pairwise_transition),
+                pairwise_block,
                 single_pre_ln(pair_bias_attn),
                 single_pre_ln(single_transition),
             ]))
@@ -704,21 +706,12 @@ def forward(
     ) -> Tuple[Float['b n ds'], Float['b n n dp']]:
 
         for (
-            tri_mult_outgoing,
-            tri_mult_incoming,
-            tri_attn_starting,
-            tri_attn_ending,
-            pairwise_transition,
+            pairwise_block,
             pair_bias_attn,
            single_transition
         ) in self.layers:
 
-            pairwise_repr = tri_mult_outgoing(pairwise_repr, mask = mask) + pairwise_repr
-            pairwise_repr = tri_mult_incoming(pairwise_repr, mask = mask) + pairwise_repr
-            pairwise_repr = tri_attn_starting(pairwise_repr, mask = mask) + pairwise_repr
-            pairwise_repr = tri_attn_ending(pairwise_repr, mask = mask) + pairwise_repr
-
-            pairwise_repr = pairwise_transition(pairwise_repr) + pairwise_repr
+            pairwise_repr = pairwise_block(pairwise_repr = pairwise_repr, mask = mask)
 
             single_repr = pair_bias_attn(single_repr, pairwise_repr = pairwise_repr, mask = mask) + single_repr
             single_repr = single_transition(single_repr) + single_repr
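Since the triangle hyperparameters now live on PairwiseBlock, both consumers accept a pairwise_block_kwargs dict in place of the old tri_mult_* / tri_attn_* and pairwise dropout arguments. A hedged construction sketch under the new MSAModule signature; the dim_single value is illustrative, while every keyword name is taken from the diff above.

msa_module = MSAModule(
    dim_single = 384,    # illustrative value, not specified in this diff
    dim_pairwise = 128,
    depth = 4,
    pairwise_block_kwargs = dict(
        tri_attn_heads = 4,
        tri_attn_dim_head = 32,
        dropout_row_prob = 0.25,
        dropout_col_prob = 0.25
    )
)

The same dict flows into the per-depth PairwiseBlock instances of the pairformer stack.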
