diff --git a/alphafold3_pytorch/attention.py b/alphafold3_pytorch/attention.py
index 55cd1ded..6cc5cca9 100644
--- a/alphafold3_pytorch/attention.py
+++ b/alphafold3_pytorch/attention.py
@@ -84,10 +84,11 @@ def __init__(
 
     def forward(
         self,
-        seq,
-        mask = None,
-        context = None
-    ):
+        seq: Float['b i d'],
+        mask: Bool['b n'] | None = None,
+        context: Float['b j d'] | None = None
+    ) -> Float['b i d']:
+
         q = self.to_q(seq)
 
         context_seq = default(context, seq)
@@ -112,7 +113,7 @@ def __init__(
         self,
         dropout = 0.,
         flash = False,
-        scale = None,
+        scale: float | None = None,
         attn_config: Config = Config(True, True, True)
     ):
         super().__init__()
@@ -138,7 +139,8 @@ def flash_attn(
         k: Float['b h j d'],
         v: Float['b h j d'],
         mask: Bool['b j'] | None = None
-    ):
+    ) -> Float['b h i d']:
+
         _, heads, seq_len, _ = q.shape
 
         attn_mask = None
@@ -162,7 +164,7 @@ def forward(
         k: Float['b h j d'],
         v: Float['b h j d'],
         mask: Bool['b j'] | None = None
-    ):
+    ) -> Float['b h i d']:
 
         if self.flash:
             return self.flash_attn(q, k, v, mask = mask)
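
Note (not part of the diff): the Float['b i d'] / Bool['b n'] annotations added above are shape-typed tensor hints in the jaxtyping style, where each string names the expected dimensions. A minimal sketch of the same pattern, assuming jaxtyping's standard two-argument Float[Tensor, '...'] form (the repo's bracket-only form presumably comes from a local wrapper); TinyAttention and all shapes below are hypothetical, for illustration only:

import torch
from torch import Tensor, nn
from jaxtyping import Bool, Float

class TinyAttention(nn.Module):
    # hypothetical stand-in module, for illustration only
    def __init__(self, dim: int):
        super().__init__()
        self.to_q = nn.Linear(dim, dim)

    def forward(
        self,
        seq: Float[Tensor, 'b i d'],
        mask: Bool[Tensor, 'b n'] | None = None,
        context: Float[Tensor, 'b j d'] | None = None
    ) -> Float[Tensor, 'b i d']:
        # body elided; the return annotation documents that the output
        # has the same (batch, sequence, feature) shape as `seq`
        return self.to_q(seq)

model = TinyAttention(dim = 64)
seq = torch.randn(2, 16, 64)                    # (b, i, d)
mask = torch.ones(2, 16, dtype = torch.bool)    # (b, n)
out = model(seq, mask = mask)                   # (b, i, d)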