Skip to content

Commit

Permalink
another einx set_at removal (#282)
Browse files Browse the repository at this point in the history
* another einx set_at removal

* fix an issue with hard validate atom indices
  • Loading branch information
lucidrains authored Sep 22, 2024
1 parent 6931bdc commit 492b4b8
Show file tree
Hide file tree
Showing 2 changed files with 25 additions and 14 deletions.
37 changes: 24 additions & 13 deletions alphafold3_pytorch/inputs.py
Original file line number Diff line number Diff line change
Expand Up @@ -264,7 +264,7 @@ def hard_validate_atom_indices_ascending(

# NOTE: this is a relaxed assumption, i.e., that if empty, all -1, or only one molecule, then it passes the test

if present_indices.numel() == 0 or present_indices.shape[-1] <= 1:
if present_indices.numel() == 0 or present_indices.shape[0] <= 1:
continue

difference = einx.subtract(
Expand Down Expand Up @@ -1134,7 +1134,6 @@ def molecule_lengthed_molecule_input_to_atom_input(
if mol_is_one_token_per_atom:
coordinates = []

has_bond = torch.zeros(num_atoms, num_atoms).bool()
bonds = mol.GetBonds()
num_bonds = len(bonds)

Expand All @@ -1150,6 +1149,8 @@ def molecule_lengthed_molecule_input_to_atom_input(
)

if num_bonds > 0:
has_bond = torch.zeros(num_atoms, num_atoms).bool()

coordinates = tensor(coordinates).long()

# has_bond = einx.set_at("[h w], c [2], c -> [h w]", has_bond, coordinates, updates)
Expand All @@ -1163,8 +1164,8 @@ def molecule_lengthed_molecule_input_to_atom_input(

# / ein.set_at

row_col_slice = slice(offset, offset + num_atoms)
token_bonds[row_col_slice, row_col_slice] = has_bond
row_col_slice = slice(offset, offset + num_atoms)
token_bonds[row_col_slice, row_col_slice] = has_bond

offset += num_atoms if mol_is_one_token_per_atom else 1

Expand Down Expand Up @@ -3572,13 +3573,14 @@ def pdb_input_to_molecule_input(
# construct ligand and modified polymer chain token bonds

coordinates = []
updates = []

ligand = molecules[ligand_offset]
num_atoms = ligand.GetNumAtoms()
has_bond = torch.zeros(num_atoms, num_atoms).bool()

for bond in ligand.GetBonds():
bonds = ligand.GetBonds()
num_bonds = len(bonds)

for bond in bonds:
atom_start_index = bond.GetBeginAtomIdx()
atom_end_index = bond.GetEndAtomIdx()

Expand All @@ -3589,15 +3591,24 @@ def pdb_input_to_molecule_input(
]
)

updates.extend([True, True])
if num_bonds > 0:
has_bond = torch.zeros(num_atoms, num_atoms).bool()

coordinates = tensor(coordinates).long()
updates = tensor(updates).bool()
coordinates = tensor(coordinates).long()

# has_bond = einx.set_at("[h w], c [2], c -> [h w]", has_bond, coordinates, updates)

has_bond_stride = tensor(has_bond.stride())
flattened_coordinates = (coordinates * has_bond_stride).sum(dim = -1)
packed_has_bond, unpack_has_bond = pack_one(has_bond, '*')

packed_has_bond[flattened_coordinates] = True
has_bond = unpack_has_bond(packed_has_bond, '*')

has_bond = einx.set_at("[h w], c [2], c -> [h w]", has_bond, coordinates, updates)
# / einx.set_at

row_col_slice = slice(polymer_offset, polymer_offset + num_atoms)
token_bonds[row_col_slice, row_col_slice] = has_bond
row_col_slice = slice(polymer_offset, polymer_offset + num_atoms)
token_bonds[row_col_slice, row_col_slice] = has_bond

polymer_offset += num_atoms
ligand_offset += 1
Expand Down
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[project]
name = "alphafold3-pytorch"
version = "0.5.35"
version = "0.5.36"
description = "Alphafold 3 - Pytorch"
authors = [
{ name = "Phil Wang", email = "[email protected]" },
Expand Down

0 comments on commit 492b4b8

Please sign in to comment.