Skip to content

Commit

Permalink
output type for distogram head
Browse files Browse the repository at this point in the history
  • Loading branch information
lucidrains committed May 19, 2024
1 parent 216eea7 commit b0198d4
Showing 1 changed file with 5 additions and 3 deletions.
8 changes: 5 additions & 3 deletions alphafold3_pytorch/alphafold3.py
Original file line number Diff line number Diff line change
Expand Up @@ -1720,11 +1720,13 @@ def forward(
# distogram head

class DistogramHead(Module):

@typecheck
def __init__(
self,
*,
dim_pairwise = 128,
num_dist_bins = 38 # think it is 38?
num_dist_bins = 38, # think it is 38?
):
super().__init__()

Expand All @@ -1737,9 +1739,9 @@ def __init__(
def forward(
self,
pairwise_repr: Float['b n n d']
):
logits = self.to_distogram_logits(pairwise_repr)
) -> Float['b l n n']:

logits = self.to_distogram_logits(pairwise_repr)
return logits

# confidence head
Expand Down

0 comments on commit b0198d4

Please sign in to comment.