From c42884e017e05599adea226b22fa12ccd0663ca5 Mon Sep 17 00:00:00 2001 From: lucidrains Date: Sun, 29 Dec 2024 09:52:18 -0800 Subject: [PATCH] complete hyper connected alphafold3 --- alphafold3_pytorch/alphafold3.py | 16 ++++++++++------ pyproject.toml | 4 ++-- 2 files changed, 12 insertions(+), 8 deletions(-) diff --git a/alphafold3_pytorch/alphafold3.py b/alphafold3_pytorch/alphafold3.py index d61a113d..ffda1784 100644 --- a/alphafold3_pytorch/alphafold3.py +++ b/alphafold3_pytorch/alphafold3.py @@ -119,7 +119,7 @@ from colt5_attention import ConditionalRoutedAttention -from hyper_connections import HyperConnections +from hyper_connections.hyper_connections_with_multi_input_streams import HyperConnections # other external libs @@ -995,8 +995,8 @@ def __init__( @typecheck def forward( self, - *, pairwise_repr: Float['b n n d'], + *, mask: Bool['b n'] | None = None, value_residuals: tuple[Tensor, Tensor] | None = None, return_values = False, @@ -1470,8 +1470,8 @@ def __init__( single_transition = Transition(dim = dim_single) layers.append(ModuleList([ - pairwise_block, - init_hyper_conn(dim = dim_single, branch = single_pre_ln(pair_bias_attn)), + init_hyper_conn(dim = dim_pairwise, branch = pairwise_block), + init_hyper_conn(dim = dim_single, additional_input_paths = [('pairwise_repr', dim_pairwise)], branch = single_pre_ln(pair_bias_attn)), init_hyper_conn(dim = dim_single, branch = single_pre_ln(single_transition)), ])) @@ -1508,6 +1508,7 @@ def to_layers( ) -> Tuple[Float['b n ds'], Float['b n n dp']]: single_repr = self.expand_streams(single_repr) + pairwise_repr = self.expand_streams(pairwise_repr) for _ in range(self.recurrent_depth): @@ -1520,7 +1521,7 @@ def to_layers( single_transition ) in self.layers: - pairwise_repr, pairwise_attn_values = pairwise_block(pairwise_repr = pairwise_repr, mask = mask, value_residuals = pairwise_value_residuals, return_values = True) + pairwise_repr, pairwise_attn_values = pairwise_block(pairwise_repr, mask = mask, value_residuals = pairwise_value_residuals, return_values = True) single_repr, attn_values = pair_bias_attn(single_repr, pairwise_repr = pairwise_repr, mask = mask, return_values = True, value_residual = value_residual) @@ -1531,6 +1532,7 @@ def to_layers( single_repr = single_transition(single_repr) single_repr = self.reduce_streams(single_repr) + pairwise_repr = self.reduce_streams(pairwise_repr) return single_repr, pairwise_repr @@ -1548,7 +1550,7 @@ def pairwise_block_wrapper(layer): @wraps(layer) def inner(inputs, *args, **kwargs): single_repr, pairwise_repr, mask, maybe_value_residual, maybe_pairwise_value_residuals = inputs - pairwise_repr, pairwise_attn_values = layer(pairwise_repr = pairwise_repr, mask = mask, value_residuals = maybe_pairwise_value_residuals, return_values = True) + pairwise_repr, pairwise_attn_values = layer(pairwise_repr, mask = mask, value_residuals = maybe_pairwise_value_residuals, return_values = True) if self.add_value_residual: maybe_pairwise_value_residuals = default(maybe_pairwise_value_residuals, pairwise_attn_values) @@ -1589,6 +1591,7 @@ def inner(inputs, *args, **kwargs): wrapped_layers.append(single_transition_wrapper(single_transition)) single_repr = self.expand_streams(single_repr) + pairwise_repr = self.expand_streams(pairwise_repr) for _ in range(self.recurrent_depth): inputs = (single_repr, pairwise_repr, mask, None, None) @@ -1599,6 +1602,7 @@ def inner(inputs, *args, **kwargs): single_repr, pairwise_repr, *_ = inputs single_repr = self.reduce_streams(single_repr) + pairwise_repr = self.reduce_streams(pairwise_repr) return single_repr, pairwise_repr diff --git a/pyproject.toml b/pyproject.toml index 0ee6c70a..a29f3c4e 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,6 +1,6 @@ [project] name = "alphafold3-pytorch" -version = "0.7.7" +version = "0.7.8" description = "Alphafold 3 - Pytorch" authors = [ { name = "Phil Wang", email = "lucidrains@gmail.com" }, @@ -41,7 +41,7 @@ dependencies = [ "fair-esm", "fastapi", "frame-averaging-pytorch>=0.0.18", - "hyper-connections>=0.0.21", + "hyper-connections>=0.0.23", "gradio", "gradio_molecule3d", "huggingface_hub>=0.21.4",