complete attention annotations
lucidrains committed May 11, 2024
1 parent 682b1bc commit 3de453f
Showing 1 changed file with 9 additions and 7 deletions.
alphafold3_pytorch/attention.py
@@ -84,10 +84,11 @@ def __init__(
 
     def forward(
         self,
-        seq,
-        mask = None,
-        context = None
-    ):
+        seq: Float['b i d'],
+        mask: Bool['b n'] | None = None,
+        context: Float['b j d'] | None = None
+    ) -> Float['b i d']:
+
         q = self.to_q(seq)
 
         context_seq = default(context, seq)
@@ -112,7 +113,7 @@ def __init__(
         self,
         dropout = 0.,
         flash = False,
-        scale = None,
+        scale: float | None = None,
         attn_config: Config = Config(True, True, True)
     ):
         super().__init__()
@@ -138,7 +139,8 @@ def flash_attn(
         k: Float['b h j d'],
         v: Float['b h j d'],
         mask: Bool['b j'] | None = None
-    ):
+    ) -> Float['b h i d']:
+
         _, heads, seq_len, _ = q.shape
 
         attn_mask = None
@@ -162,7 +164,7 @@ def forward(
         k: Float['b h j d'],
         v: Float['b h j d'],
         mask: Bool['b j'] | None = None
-    ):
+    ) -> Float['b h i d']:
 
         if self.flash:
             return self.flash_attn(q, k, v, mask = mask)
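
The annotations use shape-typed tensor aliases: Float['b i d'] denotes a float tensor of shape (batch, sequence, dim), and Bool['b n'] a boolean mask. Below is a minimal sketch of how such aliases can be built, assuming they are backed by jaxtyping; the TorchTyping helper name is illustrative, not necessarily the repository's own.

    import jaxtyping
    from torch import Tensor

    class TorchTyping:
        # wraps a jaxtyping abstract dtype so that Float['b i d']
        # expands to jaxtyping.Float[Tensor, 'b i d']
        def __init__(self, abstract_dtype):
            self.abstract_dtype = abstract_dtype

        def __getitem__(self, shapes: str):
            return self.abstract_dtype[Tensor, shapes]

    Float = TorchTyping(jaxtyping.Float)
    Bool = TorchTyping(jaxtyping.Bool)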

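With those aliases in place, the annotated forward documents its shapes directly. A hypothetical usage follows; the Attention constructor arguments are assumed, since they are not shown in this diff.

    import torch

    attn = Attention(dim = 64)  # assumed constructor; not part of this diff

    seq     = torch.randn(2, 16, 64)                 # Float['b i d']: b = 2, i = 16, d = 64
    context = torch.randn(2, 32, 64)                 # Float['b j d']: j = 32
    mask    = torch.ones(2, 32, dtype = torch.bool)  # Bool['b n']: True = attend

    out = attn(seq, mask = mask, context = context)  # Float['b i d'] -> (2, 16, 64)
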
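The flash_attn method now declares that it maps per-head queries Float['b h i d'] and keys/values Float['b h j d'] to Float['b h i d']. A sketch of the core call on that path, assuming it dispatches to PyTorch's built-in scaled_dot_product_attention; the mask broadcasting shown is illustrative.

    import torch
    import torch.nn.functional as F

    q = torch.randn(2, 8, 16, 64)                 # Float['b h i d']
    k = torch.randn(2, 8, 32, 64)                 # Float['b h j d']
    v = torch.randn(2, 8, 32, 64)                 # Float['b h j d']
    mask = torch.ones(2, 32, dtype = torch.bool)  # Bool['b j'], True = attend

    # a key padding mask must broadcast against (b, h, i, j),
    # so expand it to (b, 1, 1, j) before passing it in
    attn_mask = mask[:, None, None, :]

    out = F.scaled_dot_product_attention(q, k, v, attn_mask = attn_mask)
    assert out.shape == (2, 8, 16, 64)            # Float['b h i d']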