add value residual learning
lucidrains committed Nov 2, 2024
1 parent 952bf6d commit b8acadf
Showing 5 changed files with 98 additions and 26 deletions.
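For context: value residual learning lets attention layers after the first add the first layer's value projections to their own values before attending, which is exactly the plumbing this commit threads through `Attention`, `AttentionPairBias`, and the diffusion transformer layers. Below is a minimal, single-head sketch of the idea in plain PyTorch; the class and variable names are illustrative only, not this repo's modules.

```python
import torch
from torch import nn

class TinySelfAttention(nn.Module):
    # single-head attention that can mix in another layer's value projections
    def __init__(self, dim):
        super().__init__()
        self.to_qkv = nn.Linear(dim, dim * 3, bias = False)
        self.to_out = nn.Linear(dim, dim, bias = False)

    def forward(self, x, value_residual = None):
        q, k, v = self.to_qkv(x).chunk(3, dim = -1)

        orig_v = v  # values before any residual, handed back for later layers

        if value_residual is not None:
            v = v + value_residual

        attn = (q @ k.transpose(-1, -2) * q.shape[-1] ** -0.5).softmax(dim = -1)
        return self.to_out(attn @ v), orig_v

layers = nn.ModuleList([TinySelfAttention(64) for _ in range(4)])

x = torch.randn(2, 16, 64)
value_residual = None

for layer in layers:
    out, values = layer(x, value_residual = value_residual)
    x = x + out
    # only the first layer's values are kept as the residual for all later layers
    value_residual = value_residual if value_residual is not None else values
```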
1 change: 0 additions & 1 deletion README.md
@@ -503,4 +503,3 @@ docker run -v .:/data --gpus all -it af3
url = {https://api.semanticscholar.org/CorpusID:273532030}
}
```

87 changes: 68 additions & 19 deletions alphafold3_pytorch/alphafold3.py
@@ -678,13 +678,26 @@ def forward(
*,
cond: Float['b n dc'],
**kwargs
) -> Float['b n d']:
) -> (
Float['b n d'] |
tuple[Float['b n d'], Float['b _ _']]
):
x = self.adaptive_norm(x, cond = cond)

out = self.fn(x, **kwargs)

tuple_output = isinstance(out, tuple)

if tuple_output:
out, *rest = out

gamma = self.to_adaln_zero_gamma(cond)
return out * gamma
out = out * gamma

if tuple_output:
out = (out, *rest)

return out
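
The adaptive-norm conditioning wrapper's gamma gate now has to cope with a wrapped `fn` that returns `(out, values)` rather than a lone tensor: only the first element is gated, everything else passes through untouched. A tiny self-contained illustration of that pattern (the `gate_first` helper below is illustrative, not part of the repo):

```python
import torch

def gate_first(out, gamma):
    # gate only the primary output of a function that may return a tensor
    # or a (tensor, extras) tuple
    tuple_output = isinstance(out, tuple)
    if tuple_output:
        out, *rest = out
    out = out * gamma
    return (out, *rest) if tuple_output else out

gamma = torch.tensor(0.5)
gated = gate_first(torch.ones(3), gamma)                                # plain tensor path
gated_out, values = gate_first((torch.ones(3), torch.ones(3)), gamma)   # tuple path: values left ungated
```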

# triangle multiplicative module
# seems to be unchanged from alphafold2
@@ -762,7 +775,10 @@ def __init__(self, *, heads, dim_pairwise, window_size=None, num_memory_kv=0, **
self.window_size = window_size

self.attn = Attention(
heads=heads, window_size=window_size, num_memory_kv=num_memory_kv, **attn_kwargs
heads = heads,
window_size = window_size,
num_memory_kv = num_memory_kv,
**attn_kwargs
)

# line 8 of Algorithm 24
@@ -777,8 +793,14 @@ def forward(
*,
pairwise_repr: Float["b n n dp"] | Float["b nw w (w*2) dp"], # type: ignore
attn_bias: Float["b n n"] | Float["b nw w (w*2)"] | None = None, # type: ignore
return_values: bool = False,
value_residual: Float['b _ _'] | None = None,
**kwargs,
) -> Float["b n ds"]: # type: ignore
) -> (
Float['b n ds'] |
tuple[Float['b n ds'], Float['b _ _']]
): # type: ignore

"""Perform the forward pass.
:param single_repr: The single representation tensor.
Expand Down Expand Up @@ -837,9 +859,22 @@ def forward(
else:
attn_bias = self.to_attn_bias(self.to_attn_bias_norm(pairwise_repr)) + attn_bias

out = self.attn(single_repr, attn_bias=attn_bias, **kwargs)
# attention

return out
out, values = self.attn(
single_repr,
attn_bias = attn_bias,
value_residual = value_residual,
return_values = True,
**kwargs
)

# whether to return values for value residual learning

if not return_values:
return out

return out, values
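
A hypothetical caller-side sketch of the two new keyword arguments, assuming the class above is `AttentionPairBias`: one block asks for its value projections with `return_values = True`, a later block feeds them back via `value_residual`. The constructor keywords `dim` and `dim_head` below are assumptions about what `**attn_kwargs` forwards to `Attention` (check the class for the exact arguments), and the tensor shapes are illustrative.

```python
import torch
from alphafold3_pytorch.alphafold3 import AttentionPairBias

# dim / dim_head are assumed to reach Attention through **attn_kwargs
attn_pair_bias = AttentionPairBias(heads = 8, dim_pairwise = 32, dim = 64, dim_head = 16)

single = torch.randn(1, 16, 64)
pairwise = torch.randn(1, 16, 16, 32)

# first block: also get the value projections back
out, values = attn_pair_bias(single, pairwise_repr = pairwise, return_values = True)

# a later block (the same module reused here purely for illustration)
out = attn_pair_bias(out, pairwise_repr = pairwise, value_residual = values)
```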

class TriangleAttention(Module):
def __init__(
@@ -1915,9 +1950,9 @@ def __init__(
attn_num_memory_kv = False,
trans_expansion_factor = 2,
num_register_tokens = 0,
add_residual = True,
use_linear_attn = False,
checkpoint = False,
add_value_residual = False,
linear_attn_kwargs = dict(
heads = 8,
dim_head = 16
Expand Down Expand Up @@ -1997,7 +2032,7 @@ def __init__(

self.layers = layers

self.add_residual = add_residual
self.add_value_residual = add_value_residual

self.has_registers = num_register_tokens > 0
self.num_registers = num_register_tokens
@@ -2018,32 +2053,37 @@ def to_checkpointed_serial_layers(
windowed_mask: Bool['b nw w (w*2)'] | None = None
):

inputs = (noised_repr, single_repr, pairwise_repr, mask, windowed_mask)
inputs = (noised_repr, single_repr, pairwise_repr, mask, windowed_mask, None)

wrapped_layers = []

def efficient_attn_wrapper(fn):
@wraps(fn)
def inner(inputs):
noised_repr, single_repr, pairwise_repr, mask, windowed_mask = inputs
noised_repr, single_repr, pairwise_repr, mask, windowed_mask, maybe_value_residual = inputs
noised_repr = fn(noised_repr, mask = mask) + noised_repr
return noised_repr, single_repr, pairwise_repr, mask, windowed_mask
return noised_repr, single_repr, pairwise_repr, mask, windowed_mask, maybe_value_residual
return inner

def attn_wrapper(fn):
@wraps(fn)
def inner(inputs):
noised_repr, single_repr, pairwise_repr, mask, windowed_mask = inputs
noised_repr = fn(noised_repr, cond = single_repr, pairwise_repr = pairwise_repr, mask = mask, windowed_mask = windowed_mask) + noised_repr
return noised_repr, single_repr, pairwise_repr, mask, windowed_mask
noised_repr, single_repr, pairwise_repr, mask, windowed_mask, maybe_value_residual = inputs
attn_out, attn_values = fn(noised_repr, cond = single_repr, pairwise_repr = pairwise_repr, mask = mask, windowed_mask = windowed_mask, value_residual = maybe_value_residual, return_values = True)
noised_repr = attn_out + noised_repr

if self.add_value_residual:
maybe_value_residual = default(maybe_value_residual, attn_values)

return noised_repr, single_repr, pairwise_repr, mask, windowed_mask, maybe_value_residual
return inner

def transition_wrapper(fn):
@wraps(fn)
def inner(inputs):
noised_repr, single_repr, pairwise_repr, mask, windowed_mask = inputs
noised_repr, single_repr, pairwise_repr, mask, windowed_mask, maybe_value_residual = inputs
noised_repr = fn(noised_repr, cond = single_repr) + noised_repr
return noised_repr, single_repr, pairwise_repr, mask, windowed_mask
return noised_repr, single_repr, pairwise_repr, mask, windowed_mask, maybe_value_residual
return inner

for linear_attn, colt5_attn, attn, transition in self.layers:
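
In the checkpointed path every layer is wrapped as a pure function of one `inputs` tuple, so the value residual travels as an explicit extra slot in that tuple: it starts as `None` and only the first attention wrapper fills it. A stand-alone sketch of that threading pattern with plain functions (names are illustrative; no actual checkpointing is invoked here):

```python
import torch

def make_attn_step(layer_index):
    # each step takes and returns a single tuple, the shape checkpoint wrappers need
    def step(inputs):
        x, maybe_value_residual = inputs

        values = torch.tanh(x * (layer_index + 1))   # stand-in for this layer's value projections
        if maybe_value_residual is not None:
            values = values + maybe_value_residual

        x = x + values
        # only the first step populates the residual slot
        maybe_value_residual = values if maybe_value_residual is None else maybe_value_residual
        return x, maybe_value_residual
    return step

inputs = (torch.randn(2, 8), None)

for step in [make_attn_step(i) for i in range(3)]:
    inputs = step(inputs)

x, _ = inputs
```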
@@ -2074,6 +2114,8 @@ def to_serial_layers(
windowed_mask: Bool['b nw w (w*2)'] | None = None
):

value_residual = None

for linear_attn, colt5_attn, attn, transition in self.layers:

if exists(linear_attn):
Expand All @@ -2082,13 +2124,20 @@ def to_serial_layers(
if exists(colt5_attn):
noised_repr = colt5_attn(noised_repr, mask = mask) + noised_repr

noised_repr = attn(
attn_out, attn_values = attn(
noised_repr,
cond = single_repr,
pairwise_repr = pairwise_repr,
mask = mask,
windowed_mask = windowed_mask
) + noised_repr
windowed_mask = windowed_mask,
return_values = True,
value_residual = value_residual
)

noised_repr = noised_repr + attn_out

if self.add_value_residual:
value_residual = default(value_residual, attn_values)

noised_repr = transition(
noised_repr,
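The `default(value_residual, attn_values)` line in the serial path is what makes this "first layer only": once the residual exists, later layers' values are ignored. A tiny illustration using the usual `exists`/`default` helpers, restated here so the snippet stands alone:

```python
def exists(v):
    return v is not None

def default(v, fallback):
    return v if exists(v) else fallback

value_residual = None

for layer_values in ('values_of_layer_1', 'values_of_layer_2', 'values_of_layer_3'):
    # with add_value_residual enabled, only the first layer's values are ever stored
    value_residual = default(value_residual, layer_values)

assert value_residual == 'values_of_layer_1'
```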
27 changes: 24 additions & 3 deletions alphafold3_pytorch/attention.py
@@ -237,15 +237,29 @@ def forward(
mask: Bool['b n']| None = None,
context: Float['b j d'] | None = None,
windowed_mask: Bool['b nw w (w*2)'] | None = None,
attn_bias: Float['... i j'] | Float['... nw w (w*2)'] | None = None
attn_bias: Float['... i j'] | Float['... nw w (w*2)'] | None = None,
return_values: bool = False,
value_residual: Float['b j dh'] | None = None,

) -> Float['b i d']:
) -> (
Float['b i d'] |
tuple[Float['b i d'], Float['b j _']]
):

q = self.to_q(seq)

context_seq = default(context, seq)
k, v = self.to_kv(context_seq).chunk(2, dim = -1)

# handle value residual

orig_v = v

if exists(value_residual):
v = v + value_residual

# split heads

q, k, v = tuple(self.split_heads(t) for t in (q, k, v))

# attention
@@ -270,7 +284,14 @@ def forward(

# combine heads

return self.to_out(out)
out = self.to_out(out)

# maybe return values

if not return_values:
return out

return out, orig_v

# the main attention function

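Two details of the `Attention` change worth noting: the residual is added while the values are still in their merged `'b n (h d)'` layout (before `split_heads`), and the tuple returns `orig_v`, the values from before the residual was mixed in. A condensed sketch of just that slice of the forward pass; the projections and dimensions below are illustrative, not the class's actual configuration:

```python
import torch
from torch import nn
from einops import rearrange

dim, heads, dim_head = 64, 8, 16
to_q = nn.Linear(dim, heads * dim_head, bias = False)
to_kv = nn.Linear(dim, heads * dim_head * 2, bias = False)

def attend(seq, value_residual = None):
    q = to_q(seq)
    k, v = to_kv(seq).chunk(2, dim = -1)

    orig_v = v                      # cached before the residual, still 'b n (h d)'

    if value_residual is not None:
        v = v + value_residual      # residual added before heads are split

    q, k, v = (rearrange(t, 'b n (h d) -> b h n d', h = heads) for t in (q, k, v))
    attn = (q @ k.transpose(-1, -2) * dim_head ** -0.5).softmax(dim = -1)
    out = rearrange(attn @ v, 'b h n d -> b n (h d)')
    return out, orig_v              # pre-residual values are what propagate onward

seq = torch.randn(1, 16, dim)
out, values = attend(seq)
out_later, _ = attend(seq, value_residual = values)
```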
2 changes: 1 addition & 1 deletion pyproject.toml
@@ -1,6 +1,6 @@
[project]
name = "alphafold3-pytorch"
version = "0.6.5"
version = "0.6.6"
description = "Alphafold 3 - Pytorch"
authors = [
{ name = "Phil Wang", email = "[email protected]" },
7 changes: 5 additions & 2 deletions tests/test_af3.py
@@ -368,10 +368,12 @@ def test_msa_module(
@pytest.mark.parametrize('checkpoint', (False, True))
@pytest.mark.parametrize('use_linear_attn', (False, True))
@pytest.mark.parametrize('use_colt5_attn', (False, True))
@pytest.mark.parametrize('add_value_residual', (False, True))
def test_diffusion_transformer(
checkpoint,
use_linear_attn,
use_colt5_attn
use_colt5_attn,
add_value_residual
):

single = torch.randn(2, 16, 384).requires_grad_()
@@ -383,7 +385,8 @@ def test_diffusion_transformer(
heads = 16,
checkpoint = checkpoint,
use_linear_attn = use_linear_attn,
use_colt5_attn = use_colt5_attn
use_colt5_attn = use_colt5_attn,
add_value_residual = add_value_residual
)

single_out = diffusion_transformer(
