Skip to content

Commit

Permalink
addressed smaller comments. atom2res now significantly faster with pr…
Browse files Browse the repository at this point in the history
…earranged dictionary
  • Loading branch information
csbrasnett committed Dec 18, 2024
1 parent 3468c0e commit 3678176
Show file tree
Hide file tree
Showing 2 changed files with 111 additions and 120 deletions.
16 changes: 7 additions & 9 deletions bin/martinize2
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,6 @@ import argparse
import functools
import logging
import itertools
import os.path
from pathlib import Path
import sys
import networkx as nx
Expand Down Expand Up @@ -831,9 +830,9 @@ def entry():
"""
Sort out the use of the go model:
go = True : do the go model
go_file = None : calculate the contact map in martinize2
go_file = some file : read in the contact map from elsewhere
go_file = True: calculate contact map
go_file = str: parse contact map
bool(go_file) = False: no go
"""
go = False
go_file = None
Expand Down Expand Up @@ -985,13 +984,15 @@ def entry():
elif args.cystein_bridge != "auto":
vermouth.AddCysteinBridgesThreshold(args.cystein_bridge).run_system(system)

go_map = False
if go:
# need this here because have to get contact map at atomistic resolution
if go_file is None:
LOGGER.info("Generating Go model contact map.", type="step")
GenerateContactMap().run_system(system)
go_map = True
else:
LOGGER.info("Reading Go model contact map.", type="step")
GenerateContactMap(path=go_file).run_system(system)


# Run martinize on the system.
system = martinize(
Expand All @@ -1016,9 +1017,6 @@ def entry():
# Generate the Go model if required

if go:
if not system.go_params["go_map"]:
LOGGER.info("Reading Go model contact map.", type="step")
GenerateContactMap(path=go_file).run_system(system)
go_name_prefix = args.molname
LOGGER.info("Generating the Go model.", type="step")
GoPipeline.run_system(system,
Expand Down
215 changes: 104 additions & 111 deletions vermouth/rcsu/contact_map.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,11 +14,11 @@

from ..processors.processor import Processor
import numpy as np
from scipy.spatial.distance import euclidean, cdist
from scipy.spatial.distance import euclidean
from scipy.spatial import cKDTree as KDTree
from .. import MergeAllMolecules
from ..graph_utils import make_residue_graph

from itertools import product

PROTEIN_MAP = {
"ALA": {
Expand Down Expand Up @@ -311,41 +311,35 @@ def make_surface(position, fiba, fibb, vrad):
position: centre of sphere
"""
x, y, z = position
phi_aux = 0

surface = np.zeros((0, 3))
for k in range(fibb):
x, y, z = position

phi_aux += fiba
if phi_aux > fibb:
phi_aux -= fibb
k = np.arange(fibb)
phi_aux = (np.arange(fibb) * fiba) % fibb
theta = np.arccos(1.0 - 2.0 * k / fibb)
phi = 2.0 * np.pi * phi_aux / fibb
surface_x = x + vrad * np.sin(theta) * np.cos(phi)
surface_y = y + vrad * np.sin(theta) * np.sin(phi)
surface_z = z + vrad * np.cos(theta)
surface = np.stack((surface_x, surface_y, surface_z), axis=-1)

theta = np.arccos(1.0 - 2.0 * k / fibb)
phi = 2.0 * np.pi * phi_aux / fibb
surface_x = x + vrad * np.sin(theta) * np.cos(phi)
surface_y = y + vrad * np.sin(theta) * np.sin(phi)
surface_z = z + vrad * np.cos(theta)
surface = np.vstack((surface, np.array([surface_x, surface_y, surface_z])))
return surface


def atom2res(arrin, residues, nresidues, norm=False):
"""
def atom2res(arrin, nresidues, atom_map, norm=False):
'''
take an array with atom level data and sum the entries over within the residue
arrin: np.ndarray
NxN array of entries for each atom
residues: np.array
array of length N indicating which residue an atom belongs to
nresidues: int
number of residues in the molecule
norm: bool
if True, then any entry > 0 in the summed array = 1
"""
out = np.array([int(arrin[np.where(residues == i)[0], np.where(residues == j)[0][:, np.newaxis]].sum())
for i in range(nresidues)
for j in range(nresidues)]).reshape((nresidues, nresidues))
'''

out = np.zeros((nresidues, nresidues))
for res_idx, res_jdx in product(np.arange(nresidues), np.arange(nresidues)):
atom_idxs = np.array(atom_map[res_idx])
atom_jdxs = np.array(atom_map[res_jdx])
value = arrin[atom_idxs,
atom_jdxs[:, np.newaxis]].sum()
out[res_idx, res_jdx] = value

if norm:
out[out > 0] = 1

Expand Down Expand Up @@ -404,20 +398,22 @@ def contact_info(system):
resids.append(G.nodes[residue]['resid'])
chains.append(G.nodes[residue]['chain'])
nodes.append(G.nodes[residue]['_res_serial'])
subgraph = G.nodes[residue]['graph']

for atom in sorted(G.nodes[residue]['graph'].nodes):
if 'position' in G.nodes[residue]['graph'].nodes[atom]:
res_serial.append(G.nodes[residue]['graph'].nodes[atom]['_res_serial'])
position = subgraph.nodes[atom].get('position', [np.nan]*3)
if np.isfinite(position).all():
res_serial.append(subgraph.nodes[atom]['_res_serial'])

positions_all.append(G.nodes[residue]['graph'].nodes[atom]['position'] * 10)
positions_all.append(subgraph.nodes[atom]['position'] * 10)

vdw_list.append(get_vdw_radius(G.nodes[residue]['graph'].nodes[atom]['resname'],
G.nodes[residue]['graph'].nodes[atom]['atomname']))
atypes.append(get_atype(G.nodes[residue]['graph'].nodes[atom]['resname'],
G.nodes[residue]['graph'].nodes[atom]['atomname']))
vdw_list.append(get_vdw_radius(subgraph.nodes[atom]['resname'],
subgraph.nodes[atom]['atomname']))
atypes.append(get_atype(subgraph.nodes[atom]['resname'],
subgraph.nodes[atom]['atomname']))

if G.nodes[residue]['graph'].nodes[atom]['atomname'] == 'CA':
ca_pos.append(G.nodes[residue]['graph'].nodes[atom]['position'])
if subgraph.nodes[atom]['atomname'] == 'CA':
ca_pos.append(subgraph.nodes[atom]['position'])


vdw_list = np.array(vdw_list)
Expand All @@ -435,7 +431,7 @@ def contact_info(system):

return vdw_list, atypes, coords, res_serial, resids, chains, resnames, nodes, ca_pos, nresidues, G

def calculate_overlap(coords_tree, vdw_list, natoms, vdw_max):
def calculate_overlap(coords_tree, vdw_list, natoms, vdw_max, alpha):
"""
Find enlarged (OV) overlap contacts
Expand All @@ -447,8 +443,9 @@ def calculate_overlap(coords_tree, vdw_list, natoms, vdw_max):
number of atoms in the molecule
vdw_max: float
maximum possible vdw radius of atoms
alpha: float
Enlargement factor for attraction effects
"""
alpha = 1.24 # Enlargement factor for attraction effects
over = np.zeros((natoms, natoms))
over_sdm = coords_tree.sparse_distance_matrix(coords_tree, 2 * vdw_max * alpha)
for (idx, jdx), distance_between in over_sdm.items():
Expand All @@ -457,7 +454,7 @@ def calculate_overlap(coords_tree, vdw_list, natoms, vdw_max):
over[idx, jdx] = 1
return over

def calculate_csu(coords, vdw_list, fiba, fibb, natoms, coords_tree, vdw_max):
def calculate_csu(coords, vdw_list, fiba, fibb, natoms, coords_tree, vdw_max, water_radius):
"""
Calculate contacts of structural units (CSU)
Expand All @@ -473,19 +470,16 @@ def calculate_csu(coords, vdw_list, fiba, fibb, natoms, coords_tree, vdw_max):
KDTree of the input coordinates
vdw_max: float
maximum possible vdw radius of atoms
water_radius: float
radius of water molecule in A
Returns:
hit_results: natoms x fibb np.array
each i,j entry is the index of the atom in coords which is the closest atom to atom i at index j of the
fibonacci sphere
"""
water_radius = 2.80 # Radius of a water molecule in A

# generate fibonacci spheres for all atoms.
# can't decide whether quicker/better for memory to generate all in one go here or incorporate into the loop
spheres = np.stack([KDTree(make_surface(i, fiba, fibb, water_radius + j))
for i, j in zip(coords, vdw_list)])

#setup arrays to keep track
hit_results = np.full((natoms, fibb), -1)
Expand All @@ -496,27 +490,28 @@ def calculate_csu(coords, vdw_list, fiba, fibb, natoms, coords_tree, vdw_max):
# n.b. this loop works because sparse_distance_matrix is sorted by (idx, jdx) pairs
for (idx, jdx), distance_between in surface_sdm.items():
# don't take atoms which are identical
if idx != jdx:
# check that the distance between them is shorter than the vdw sum and the water radius
if distance_between < (vdw_list[idx] + vdw_list[jdx] + water_radius):
# get the KDTree corresponding to this base point
base_tree = spheres[idx]
if idx == jdx:
continue

# get the vdw distance of the target point
vdw = vdw_list[jdx] + water_radius
# check that the distance between them is shorter than the vdw sum and the water radius
if distance_between >= (vdw_list[idx] + vdw_list[jdx] + water_radius):
continue

# find points on the base point sphere which are within the vdw cutoff of the target point's coordinate
res = np.array(base_tree.query_ball_point(coords[jdx], vdw))
# Generate the fibonacci sphere for this point and make a KDTree from it
base_tree = KDTree(make_surface(coords[idx], fiba, fibb, vdw_list[idx]+water_radius))

# if we have any results
if len(res) > 0:
# find where the distance between the two points is smaller than the current recorded distance
# at the points which are within the cutoff
to_fill = np.where(distance_between < dists_counter[idx][res])[0]
# find points on the base point sphere which are within the vdw cutoff of the target point's coordinate
res = np.array(base_tree.query_ball_point(coords[jdx], vdw_list[jdx] + water_radius))

# record the new distances and indices of the points
dists_counter[idx][res[to_fill]] = distance_between
hit_results[idx][res[to_fill]] = jdx
# if we have any results
if len(res) > 0:
# find where the distance between the two points is smaller than the current recorded distance
# at the points which are within the cutoff
to_fill = np.where(distance_between < dists_counter[idx][res])[0]

# record the new distances and indices of the points
dists_counter[idx][res[to_fill]] = distance_between
hit_results[idx][res[to_fill]] = jdx

return hit_results

Expand Down Expand Up @@ -578,50 +573,50 @@ def calculate_contacts(vdw_list, atypes, coords, res_serial, nresidues):

natoms = len(coords)

# all_vdw = np.array([x for xs in
# [[PROTEIN_MAP[i][j]['vrad'] for j in PROTEIN_MAP[i].keys()]
# for i in PROTEIN_MAP.keys()]
# for x in xs])
vdw_max = 1.88 # all_vdw.max()
vdw_max = max(item['vmax'] for atoms in PROTEIN_MAP.values() for item in atoms.values())

# make the KDTree of the input coordinates
coords_tree = KDTree(coords)

# calculate the OV contacts of the molecule
over = calculate_overlap(coords_tree, vdw_list, natoms, vdw_max)
over = calculate_overlap(coords_tree, vdw_list, natoms, vdw_max, alpha=1.24)

# Calculate the CSU contacts of the molecule
hit_results = calculate_csu(coords, vdw_list, fiba, fibb, natoms, coords_tree, vdw_max)
hit_results = calculate_csu(coords, vdw_list, fiba, fibb, natoms, coords_tree, vdw_max, water_radius=2.80)

# find the types of contacts we have
contactcounter_1, stabilisercounter_1, destabilisercounter_1 = contact_types(hit_results, natoms, atypes)

atom_map = {}
for i in range(nresidues):
atom_map[i] = np.where(res_serial == i)[0]

# transform the resolution between atoms and residues
overlapcounter_2 = atom2res(over, res_serial, nresidues, norm=True)
contactcounter_2 = atom2res(contactcounter_1, res_serial, nresidues)
stabilisercounter_2 = atom2res(stabilisercounter_1, res_serial, nresidues)
destabilisercounter_2 = atom2res(destabilisercounter_1, res_serial, nresidues)
overlapcounter_2 = atom2res(over, nresidues, atom_map, norm=True)
contactcounter_2 = atom2res(contactcounter_1, nresidues, atom_map)
stabilisercounter_2 = atom2res(stabilisercounter_1, nresidues, atom_map)
destabilisercounter_2 = atom2res(destabilisercounter_1, nresidues, atom_map)

return overlapcounter_2, contactcounter_2, stabilisercounter_2, destabilisercounter_2


def get_contacts(nresidues, overlaps, contacts, stabilisers, destabilisers, nodes, G):
contacts_list = []
for i1 in range(nresidues):
for i2 in range(nresidues):
over = overlaps[i1, i2]
cont = contacts[i1, i2]
stab = stabilisers[i1, i2]
dest = destabilisers[i1, i2]
rcsu = 1 if (stab - dest) > 0 else 0

if (over > 0 or cont > 0) and (i1 != i2):
a = np.where(nodes == i1)[0][0]
b = np.where(nodes == i2)[0][0]
if over == 1 or (over == 0 and rcsu == 1):
# this is a OV or rCSU contact we take it
contacts_list.append((int(G.nodes[a]['resid']), G.nodes[a]['chain'],
int(G.nodes[b]['resid']), G.nodes[b]['chain']))
for i1, i2 in product(np.arange(nresidues), np.arange(nresidues)):
if i1 == i2: continue
over = overlaps[i1, i2]
cont = contacts[i1, i2]
stab = stabilisers[i1, i2]
dest = destabilisers[i1, i2]
rcsu = (stab - dest) > 0

if (over > 0 or cont > 0):
a = np.where(nodes == i1)[0][0]
b = np.where(nodes == i2)[0][0]
if over == 1 or (over == 0 and rcsu):
# this is a OV or rCSU contact we take it
contacts_list.append((int(G.nodes[a]['resid']), G.nodes[a]['chain'],
int(G.nodes[b]['resid']), G.nodes[b]['chain']))

return contacts_list

Expand All @@ -630,28 +625,26 @@ def write_contacts(nresidues, overlaps, contacts, stabilisers, destabilisers, no
# this to write out the file if needed
with open('contact_map_vermouth.out', 'w') as f:
count = 0
for i1 in range(nresidues):
for i2 in range(nresidues):
over = overlaps[i1, i2]
cont = contacts[i1, i2]
stab = stabilisers[i1, i2]
dest = destabilisers[i1, i2]
ocsu = stab
rcsu = stab - dest

if (over > 0 or cont > 0) and (i1 != i2):
a = np.where(nodes == i1)[0][0]
b = np.where(nodes == i2)[0][0]
count += 1
msg = (f"R {int(count):6d} "
f"{int(i1 + 1):5d} {G.nodes[a]['resname']:3s}"
f"{G.nodes[a]['chain']:1s} {int(G.nodes[a]['resid']):4d} "
f"{int(i2 + 1):5d} {G.nodes[b]['resname']:3s}"
f"{G.nodes[b]['chain']:1s} {int(G.nodes[b]['resid']):4d} "
f"{euclidean(ca_pos[a], ca_pos[b])*10:9.4f} "
f"{int(over):1d} {1 if cont != 0 else 0} {1 if ocsu != 0 else 0} {1 if rcsu > 0 else 0}"
f"{int(rcsu):6d} {int(cont):6d}\n")
f.writelines(msg)
for i1, i2 in product(np.arange(nresidues), np.arange(nresidues)):
over = overlaps[i1, i2]
cont = contacts[i1, i2]
stab = stabilisers[i1, i2]
dest = destabilisers[i1, i2]
rcsu = (stab - dest) > 0

if (over > 0 or cont > 0) and (i1 != i2):
a = np.where(nodes == i1)[0][0]
b = np.where(nodes == i2)[0][0]
count += 1
msg = (f"R {int(count):6d} "
f"{int(i1 + 1):5d} {G.nodes[a]['resname']:3s}"
f"{G.nodes[a]['chain']:1s} {int(G.nodes[a]['resid']):4d} "
f"{int(i2 + 1):5d} {G.nodes[b]['resname']:3s}"
f"{G.nodes[b]['chain']:1s} {int(G.nodes[b]['resid']):4d} "
f"{euclidean(ca_pos[a], ca_pos[b])*10:9.4f} "
f"{int(over):1d} {1 if cont != 0 else 0} {1 if stab != 0 else 0} {1 if rcsu else 0}"
f"{int(rcsu):6d} {int(cont):6d}\n")
f.writelines(msg)


"""
Expand Down

0 comments on commit 3678176

Please sign in to comment.