Commit
change implementation to use KDTree.
csbrasnett committed Dec 16, 2024
1 parent cf7cc8b commit a1a217f
Showing 3 changed files with 92 additions and 154 deletions.
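The core of the change is replacing the dense all-pairs distance matrices built with scipy.spatial.distance.cdist by neighbour queries on a scipy.spatial.cKDTree, so only atom pairs within a cutoff are ever enumerated. A minimal sketch of the two approaches on made-up coordinates (the array and cutoff below are illustrative, not values from the diff):

import numpy as np
from scipy.spatial import cKDTree
from scipy.spatial.distance import cdist

coords = np.random.default_rng(0).random((1000, 3)) * 50.0  # toy atom positions in Angstrom
cutoff = 4.7

# old approach: dense N x N matrix, O(N^2) memory regardless of how many pairs are close
dense = cdist(coords, coords)
dense_pairs = np.argwhere((dense > 0) & (dense < cutoff))

# new approach: sparse matrix holding only the pairs closer than the cutoff
tree = cKDTree(coords)
sparse = tree.sparse_distance_matrix(tree, cutoff)
sparse_pairs = [(i, j) for (i, j), d in sparse.items() if i != j]

assert len(dense_pairs) == len(sparse_pairs)  # same neighbour pairs, far less memory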
20 changes: 10 additions & 10 deletions bin/martinize2
@@ -549,16 +549,10 @@ def entry():
dest="go",
nargs='?',
const=None,
type=str,
type=Path,
help="Use Martini Go model. Accepts either an input file from the server, "
"or just provide the flag to calculate as part of Martinize. Can be slow for large proteins "
"(> 1500 residues)"
# required=False,
# type=Path,
# action='store_true',
# default=False,
# help="Contact map to be used for the Martini Go model."
# "Currently, only one format is supported. See docs.",
)
go_group.add_argument(
"-go-eps",
@@ -834,12 +828,19 @@ def entry():
"be used together."
)

"""
Sort out the use of the go model:
go = True : do the go model
go_file = None : calculate the contact map in martinize2
go_file = some file : read in the contact map from elsewhere
"""
go = False
go_file = None
if args.go is None:
go = True
else:
if os.path.isfile(args.go):
if Path(args.go).is_file():
go_file = args.go
go = True

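The nargs='?'/const combination on the -go argument is what makes the branching above work: the flag given on its own resolves to const (None, i.e. calculate the contact map inside martinize2), while the flag given with an argument resolves to a Path to an existing map. A small self-contained sketch of that argparse pattern (the parser and the default=False sentinel are illustrative assumptions, not the actual martinize2 entry point):

import argparse
from pathlib import Path

parser = argparse.ArgumentParser()
# default=False is an assumed sentinel for "flag not given at all"
parser.add_argument("-go", dest="go", nargs="?", const=None, type=Path, default=False)

print(parser.parse_args([]).go)                  # False -> no Go model
print(parser.parse_args(["-go"]).go)             # None  -> calculate the contact map internally
print(parser.parse_args(["-go", "map.out"]).go)  # Path("map.out") -> read this existing contact map
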
@@ -1015,14 +1016,13 @@ def entry():
# Generate the Go model if required

if go:
if not go_map:
if not system.go_params["go_map"]:
LOGGER.info("Reading Go model contact map.", type="step")
GenerateContactMap(path=go_file).run_system(system)
go_name_prefix = args.molname
LOGGER.info("Generating the Go model.", type="step")
GoPipeline.run_system(system,
moltype=go_name_prefix,
# contact_map=go_map,
cutoff_short=args.go_low,
cutoff_long=args.go_up,
go_eps=args.go_eps,
224 changes: 82 additions & 142 deletions vermouth/rcsu/contact_map.py
@@ -15,6 +15,7 @@
from ..processors.processor import Processor
import numpy as np
from scipy.spatial.distance import euclidean, cdist
from scipy.spatial import cKDTree as KDTree
from .. import MergeAllMolecules
from ..graph_utils import make_residue_graph

@@ -280,10 +281,10 @@ def get_vdw_radius(resname, atomname):
"""
res_vdw = PROTEIN_MAP[resname]
try:
atom_vdw = res_vdw[atomname]['vrad']
atom_vdw = res_vdw[atomname]
except KeyError:
atom_vdw = res_vdw['default']['vrad']
return atom_vdw
atom_vdw = res_vdw['default']
return atom_vdw['vrad']


def get_atype(resname, atomname):
@@ -292,19 +293,25 @@ def get_atype(resname, atomname):
"""
res_vdw = PROTEIN_MAP[resname]
try:
atom_vdw = res_vdw[atomname]['atype']
atom_vdw = res_vdw[atomname]
except KeyError:
atom_vdw = res_vdw['default']['atype']
return atom_vdw
atom_vdw = res_vdw['default']
return atom_vdw['atype']


def make_surface(position, fiba, fibb, vrad):
"""
Generate points on a sphere using Fibonacci points
position: np.array
shape (3,) array of an atomic position to build a sphere around
fiba: int. n-1 fibonacci number to build number of points on sphere
fibb: int. n fibonacci number to build number of points on sphere
vrad: float. VdW radius of the input atom to build a sphere around.
position: centre of sphere
"""
x, y, z = position[0], position[1], position[2]
x, y, z = position
phi_aux = 0

surface = np.zeros((0, 3))
@@ -323,49 +330,18 @@ def make_surface(position, fiba, fibb, vrad):
return surface

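make_surface distributes fibb points over a sphere of radius vrad (here the vdW radius plus the water radius) around the atom position, stepping the azimuthal angle by the ratio of the consecutive Fibonacci numbers fiba/fibb. A rough, hypothetical sketch of one common golden-section construction, not necessarily the exact recurrence used inside make_surface:

import numpy as np

def fibonacci_sphere(centre, radius, fiba, fibb):
    # illustrative helper only
    k = np.arange(fibb)
    theta = np.arccos(1.0 - 2.0 * (k + 0.5) / fibb)   # polar angles spread evenly in cos(theta)
    phi = 2.0 * np.pi * ((k * fiba / fibb) % 1.0)     # azimuth stepped by the fiba/fibb ratio
    return centre + radius * np.stack([np.sin(theta) * np.cos(phi),
                                       np.sin(theta) * np.sin(phi),
                                       np.cos(theta)], axis=1)
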

def res2atom(arrin, residues, nresidues):
"""
take an array with residue level data and repeat the entries over
each atom within the residue
would be nice to do this with list comprehension but I can't work out how
to do something like:
[np.tile(res_dists[i,j],
np.where(residues == i)[0].size)
for i in range(nresidues)
for j in range(nresidues)
]
to get it correct in the 2 dimensional way we actually need.
At the moment we only need this function once
(to get the residue COGs foreach atom)
so it's not too much of a limiting factor, but something to optimise
better in future
"""
# find out how many residues we have, and how many atoms are in each of them
unique_values, counts = np.unique(residues, return_counts=True)

assert len(unique_values) == nresidues

out = np.zeros((len(residues), len(residues)))
start0 = 0
for i, j in zip(unique_values, counts):
start1 = 0
for k, l in zip(unique_values, counts):
target_value = arrin[i, k]
out[start0:start0 + j, start1:start1 + l] = target_value.sum()
start1 += l
start0 += j

return out


def atom2res(arrin, residues, nresidues, norm=False):
"""
take an array with atom level data and sum the entries over within the residue
arrin: np.ndarray
NxN array of entries for each atom
residues: np.array
array of length N indicating which residue an atom belongs to
nresidues: int
number of residues in the molecule
norm: bool
if True, then any entry > 0 in the summed array = 1
"""
out = np.array([int(arrin[np.where(residues == i)[0], np.where(residues == j)[0][:, np.newaxis]].sum())
for i in range(nresidues)
@@ -378,10 +354,8 @@ def atom2res(arrin, residues, nresidues, norm=False):

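atom2res above collapses an atom-level N x N matrix to residue resolution by summing the block belonging to each residue pair, binarising the result when norm=True. A toy illustration with a made-up 4-atom contact matrix and two residues:

import numpy as np
from vermouth.rcsu.contact_map import atom2res

arrin = np.array([[0, 1, 1, 0],
                  [1, 0, 0, 0],
                  [1, 0, 0, 1],
                  [0, 0, 1, 0]])
residues = np.array([0, 0, 1, 1])   # atom index -> residue index

atom2res(arrin, residues, nresidues=2, norm=True)
# expected: array([[1, 1],
#                  [1, 1]])  -- every residue pair has at least one atomic contact
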
def bondtype(i, j):
maxatomtype = 10
assert i >= 1
assert i <= maxatomtype
assert j >= 1
assert j <= maxatomtype
assert 1 <= i <= maxatomtype
assert 1 <= j <= maxatomtype

i -= 1
j -= 1
@@ -416,38 +390,33 @@ def contact_info(system):
chains = []
resnames = []
positions_all = []
cogs = []
ca_pos = []
vdw_list = []
atypes = []
res_serial = []
nodes = []
for node in G.nodes:
for residue in G.nodes:
# we only need these for writing at the end
resnames.append(G.nodes[node]['resname'])
resids.append(G.nodes[node]['resid'])
chains.append(G.nodes[node]['chain'])
nodes.append(G.nodes[node]['_res_serial'])
resnames.append(G.nodes[residue]['resname'])
resids.append(G.nodes[residue]['resid'])
chains.append(G.nodes[residue]['chain'])
nodes.append(G.nodes[residue]['_res_serial'])

res_pos = []
for subnode in sorted(G.nodes[node]['graph'].nodes):
if 'position' in G.nodes[node]['graph'].nodes[subnode]:
res_serial.append(G.nodes[node]['graph'].nodes[subnode]['_res_serial'])
for atom in sorted(G.nodes[residue]['graph'].nodes):
if 'position' in G.nodes[residue]['graph'].nodes[atom]:
res_serial.append(G.nodes[residue]['graph'].nodes[atom]['_res_serial'])

res_pos.append(G.nodes[node]['graph'].nodes[subnode]['position'] * 10)
positions_all.append(G.nodes[node]['graph'].nodes[subnode]['position'] * 10)
positions_all.append(G.nodes[residue]['graph'].nodes[atom]['position'] * 10)

vdw_list.append(get_vdw_radius(G.nodes[node]['graph'].nodes[subnode]['resname'],
G.nodes[node]['graph'].nodes[subnode]['atomname']))
atypes.append(get_atype(G.nodes[node]['graph'].nodes[subnode]['resname'],
G.nodes[node]['graph'].nodes[subnode]['atomname']))
vdw_list.append(get_vdw_radius(G.nodes[residue]['graph'].nodes[atom]['resname'],
G.nodes[residue]['graph'].nodes[atom]['atomname']))
atypes.append(get_atype(G.nodes[residue]['graph'].nodes[atom]['resname'],
G.nodes[residue]['graph'].nodes[atom]['atomname']))

if G.nodes[node]['graph'].nodes[subnode]['atomname'] == 'CA':
ca_pos.append(G.nodes[node]['graph'].nodes[subnode]['position'])
if G.nodes[residue]['graph'].nodes[atom]['atomname'] == 'CA':
ca_pos.append(G.nodes[residue]['graph'].nodes[atom]['position'])

cogs.append(np.stack(np.tile(np.stack(res_pos).mean(axis=0), (len(res_pos), 1))))

cogs = np.vstack(cogs)
vdw_list = np.array(vdw_list)
atypes = np.array(atypes)
coords = np.stack(positions_all)
@@ -461,10 +430,10 @@ def contact_info(system):
# 2) find the number of residues that we have
nresidues = len(G)

return cogs, vdw_list, atypes, coords, res_serial, resids, chains, resnames, nodes, ca_pos, nresidues, G
return vdw_list, atypes, coords, res_serial, resids, chains, resnames, nodes, ca_pos, nresidues, G


def calculate_contact_map(cogs, vdw_list, atypes, coords, res_serial,
def calculate_contact_map(vdw_list, atypes, coords, res_serial,
resids, chains, resnames, nodes, ca_pos, nresidues,
G):

@@ -473,80 +442,52 @@ def calculate_contact_map(cogs, vdw_list, atypes, coords, res_serial,
fiba, fibb = 0, 1
for _ in range(fib):
fiba, fibb = fibb, fiba + fibb
natoms = len(coords)

alpha = 1.24 # Enlargement factor for attraction effects
water_radius = 2.80 # Radius of a water molecule in A

# 4) set up final bits of information
# sums of vdw pairs for each atom we have
vdw_sum_array = vdw_list[:, np.newaxis] + vdw_list[np.newaxis, :]
coords_tree = KDTree(coords)
# all_vdw = np.array([x for xs in
# [[PROTEIN_MAP[i][j]['vrad'] for j in PROTEIN_MAP[i].keys()]
# for i in PROTEIN_MAP.keys()]
# for x in xs])
over = np.zeros((len(coords), len(coords)))
vdw_max = 1.88 # all_vdw.max()
over_sdm = coords_tree.sparse_distance_matrix(coords_tree, 2 * vdw_max * alpha)
for (idx, jdx), distance_between in over_sdm.items():
if idx != jdx:
if distance_between < (vdw_list[idx] + vdw_list[jdx]) * alpha:
over[idx, jdx] = 1

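The query radius 2 * vdw_max * alpha handed to sparse_distance_matrix is an upper bound on every per-pair threshold (vdw_i + vdw_j) * alpha, so restricting the neighbour search to that radius cannot miss an overlapping pair; the exact per-pair test is still applied inside the loop. A quick numeric check with the constants above:

alpha = 1.24
vdw_max = 1.88

query_radius = 2 * vdw_max * alpha        # 4.6624 A, search radius for the overlap pass
largest_pair = (1.88 + 1.88) * alpha      # 4.6624 A, threshold for the largest possible pair
assert largest_pair <= query_radius       # any pair with smaller radii has a smaller threshold
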
# set up the surface overlap criterion
# generate fibonacci spheres for all atoms.
# can't decide whether quicker/better for memory to generate all in one go here
# or incorporate into the loop. code to do it more on the fly is left in the loop for now
spheres = np.stack([KDTree(make_surface(i, fiba, fibb, water_radius + j))
for i, j in zip(coords, vdw_list)])

# distances between all the atoms that we have
atomic_distances = cdist(coords, coords)
hit_results = np.full((natoms, fibb), -1)
dists_counter = np.full((natoms, fibb), np.inf)
surface_sdm = coords_tree.sparse_distance_matrix(coords_tree, (2 * vdw_max) + water_radius)
for (idx, jdx), distance_between in surface_sdm.items():
if idx != jdx:

# array with 1 on the diagonal, so we can exclude the self atoms
diagonal_ones = np.diagflat(np.ones(atomic_distances.shape[0], dtype=int))
if distance_between < (vdw_list[idx] + vdw_list[jdx] + water_radius):

# get the coordinates of the centres of geometry for each residue
res_dists = cdist(cogs, cogs)
base_tree = spheres[idx]
vdw = vdw_list[jdx] + water_radius

# 5) find atoms which meet the overlap criterion
over = np.zeros_like(atomic_distances)
overlaps = np.where((atomic_distances <= (vdw_sum_array * alpha)) &
(diagonal_ones != 1) &
(res_dists < 14))
over[overlaps[0], overlaps[1]] = 1
res = np.array(base_tree.query_ball_point(coords[jdx], vdw))
if len(res) > 0:
to_fill = np.where(distance_between < dists_counter[idx][res])[0]

# 6) set up the surface overlap criterion
# generate fibonacci spheres for all atoms.
# can't decide whether quicker/better for memory to generate all in one go here
# or incorporate into the loop. code to do it more on the fly is left in the loop for now
spheres = np.stack([make_surface(i, fiba, fibb, water_radius + j) for i, j in zip(coords, vdw_list)])
surface_overlaps = np.where((atomic_distances <= (vdw_sum_array + water_radius)) &
(diagonal_ones != 1) &
(res_dists < 14))
# find which atoms are uniquely involved as base points
base_points = np.unique(surface_overlaps[0])

hit_results = np.ones((spheres.shape[0], spheres.shape[1]), dtype=int) * -1

# loop over all base points
for base_point in base_points:

# generate the base point sphere now if we didn't earlier.
# sphere = make_surface(coords[base_point], fiba, fibb, vdw_list[base_point] + water_radius)
# get the target points
target_points = surface_overlaps[1][np.where(surface_overlaps[0] == base_point)[0]]
# array of all the target point coordinates
target_point_coords = coords[target_points]
# distances between the points on the base sphere surface and the target point coordinates
surface_to_point = cdist(spheres[base_point], target_point_coords)
# surface_to_point = cdist(sphere, target_point_coords)
# cutoff distances for each of the target points
target_distances = vdw_list[target_points] + water_radius

for i, j in enumerate(surface_to_point):
'''
first find where the radius condition is met, i.e. where the distance between
the target point and this point on the surface is smaller than the vdw radius
of the target point
'''
radius_condition = j < target_distances
if any(radius_condition):
'''
For all the points that meet this condition, look at the distance between the
target point and the base point
'''
distances_to_compare = atomic_distances[base_point][target_points[radius_condition]]
'''
the point that we need is the point with the smallest distance
'''
point_needed = target_points[radius_condition][distances_to_compare.argmin()]
hit_results[base_point, i] = point_needed

contactcounter_1 = np.zeros_like(atomic_distances)
stabilisercounter_1 = np.zeros_like(atomic_distances)
destabilisercounter_1 = np.zeros_like(atomic_distances)
dists_counter[idx][res[to_fill]] = distance_between
hit_results[idx][res[to_fill]] = jdx

contactcounter_1 = np.zeros((natoms, natoms))
stabilisercounter_1 = np.zeros((natoms, natoms))
destabilisercounter_1 = np.zeros((natoms, natoms))

for i, j in enumerate(hit_results):
for k in j:
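In the rewritten surface criterion, each atom's Fibonacci surface points live in their own KDTree; query_ball_point then returns the indices of surface points that fall inside the water-extended radius of a neighbouring atom, and hit_results records the closest such neighbour per surface point. A small self-contained illustration of the query_ball_point call on toy coordinates (not values from the contact map):

import numpy as np
from scipy.spatial import cKDTree as KDTree

surface = np.array([[0.0, 0.0, 0.0],
                    [1.0, 0.0, 0.0],
                    [3.0, 0.0, 0.0]])    # stand-in for Fibonacci points on one atom's surface
base_tree = KDTree(surface)

neighbour = np.array([0.5, 0.0, 0.0])    # neighbouring atom position
vdw = 1.2                                # its vdW radius + water radius (toy value)

res = np.array(base_tree.query_ball_point(neighbour, vdw))
# surface points 0 and 1 are within vdw of the neighbour and count as occluded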
@@ -670,11 +611,10 @@ def run_system(self, system):
"""
self.system = system

cogs, vdw_list, atypes, coords, res_serial, resids, chains, resnames, nodes, ca_pos, nresidues, G = contact_info(system)

if self.path is None:
self.system.go_params["go_map"].append(calculate_contact_map(cogs,
vdw_list,
vdw_list, atypes, coords, res_serial, resids, chains, resnames, nodes, ca_pos, nresidues, G = contact_info(
system)
self.system.go_params["go_map"].append(calculate_contact_map(vdw_list,
atypes,
coords,
res_serial,
2 changes: 0 additions & 2 deletions vermouth/rcsu/go_structure_bias.py
@@ -46,7 +46,6 @@ class ComputeStructuralGoBias(Processor):
replacement in the GoPipeline.
"""
def __init__(self,
# contact_map,
cutoff_short,
cutoff_long,
go_eps,
@@ -85,7 +84,6 @@ def __init__(self,
magic number for Go contacts from the old
GoVirt script.
"""
# self.contact_map = contact_map
self.cutoff_short = cutoff_short
self.cutoff_long = cutoff_long
self.go_eps = go_eps
