Skip to content

Commit

Permalink
exploit ellipsis in jaxtyping
Browse files Browse the repository at this point in the history
  • Loading branch information
lucidrains committed May 14, 2024
1 parent d7c68a2 commit 9694800
Showing 1 changed file with 5 additions and 12 deletions.
17 changes: 5 additions & 12 deletions alphafold3_pytorch/alphafold3.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,8 +57,8 @@ class SwiGLU(Module):
@typecheck
def forward(
self,
x: Float['b n d']
) -> Float['b n (d//2)']:
x: Float['... d']
) -> Float['... (d//2)']:

x, gates = x.chunk(2, dim = -1)
return F.silu(gates) * x
Expand All @@ -82,8 +82,8 @@ def __init__(
@typecheck
def forward(
self,
x: Float['b n d']
) -> Float['b n d']:
x: Float['... d']
) -> Float['... d']:

return self.ff(x)

Expand Down Expand Up @@ -614,10 +614,7 @@ def forward(
pairwise_repr = outer_product_mean(msa, mask = mask, msa_mask = msa_mask) + pairwise_repr

msa = msa_pair_weighted_avg(msa = msa, pairwise_repr = pairwise_repr, mask = mask) + msa

msa, msa_packed_shape = pack_one(msa, 'b * d')
msa = msa_transition(msa) + msa
msa = unpack_one(msa, msa_packed_shape, 'b * d')
msa = msa_transition(msa) + msa

# pairwise tri mult + attn + transition

Expand All @@ -626,9 +623,7 @@ def forward(
pairwise_repr = tri_attn_starting(pairwise_repr, mask = mask) + pairwise_repr
pairwise_repr = tri_attn_ending(pairwise_repr, mask = mask) + pairwise_repr

pairwise_repr, packed_shape = pack_one(pairwise_repr, 'b * d')
pairwise_repr = pairwise_transition(pairwise_repr) + pairwise_repr
pairwise_repr = unpack_one(pairwise_repr, packed_shape, 'b * d')

return pairwise_repr

Expand Down Expand Up @@ -723,9 +718,7 @@ def forward(
pairwise_repr = tri_attn_starting(pairwise_repr, mask = mask) + pairwise_repr
pairwise_repr = tri_attn_ending(pairwise_repr, mask = mask) + pairwise_repr

pairwise_repr, packed_shape = pack_one(pairwise_repr, 'b * d')
pairwise_repr = pairwise_transition(pairwise_repr) + pairwise_repr
pairwise_repr = unpack_one(pairwise_repr, packed_shape, 'b * d')

single_repr = pair_bias_attn(single_repr, pairwise_repr = pairwise_repr, mask = mask) + single_repr
single_repr = single_transition(single_repr) + single_repr
Expand Down

0 comments on commit 9694800

Please sign in to comment.