From a1a217f02a771c9521da348e1e668e5312de9fca Mon Sep 17 00:00:00 2001 From: Chris Brasnett <35073246+csbrasnett@users.noreply.github.com> Date: Mon, 16 Dec 2024 13:48:07 +0000 Subject: [PATCH] change implementation to using KDTree. --- bin/martinize2 | 20 +-- vermouth/rcsu/contact_map.py | 224 +++++++++++------------------ vermouth/rcsu/go_structure_bias.py | 2 - 3 files changed, 92 insertions(+), 154 deletions(-) diff --git a/bin/martinize2 b/bin/martinize2 index f6824eea..5b3d0a48 100755 --- a/bin/martinize2 +++ b/bin/martinize2 @@ -549,16 +549,10 @@ def entry(): dest="go", nargs='?', const=None, - type=str, + type=Path, help="Use Martini Go model. Accepts either an input file from the server, " "or just provide the flag to calculate as part of Martinize. Can be slow for large proteins " "(> 1500 residues)" - # required=False, - # type=Path, - # action='store_true', - # default=False, - # help="Contact map to be used for the Martini Go model." - # "Currently, only one format is supported. See docs.", ) go_group.add_argument( "-go-eps", @@ -834,12 +828,19 @@ def entry(): "be used together." ) + """ + Sort out the use of the go model: + + go = True : do the go model + go_file = None : calculate the contact map in martinize2 + go_file = some file : read in the contact map from elsewhere + """ go = False go_file = None if args.go is None: go = True else: - if os.path.isfile(args.go): + if Path(args.go).is_file(): go_file = args.go go = True @@ -1015,14 +1016,13 @@ def entry(): # Generate the Go model if required if go: - if not go_map: + if not system.go_params["go_map"]: LOGGER.info("Reading Go model contact map.", type="step") GenerateContactMap(path=go_file).run_system(system) go_name_prefix = args.molname LOGGER.info("Generating the Go model.", type="step") GoPipeline.run_system(system, moltype=go_name_prefix, - # contact_map=go_map, cutoff_short=args.go_low, cutoff_long=args.go_up, go_eps=args.go_eps, diff --git a/vermouth/rcsu/contact_map.py b/vermouth/rcsu/contact_map.py index c16606da..859060f2 100644 --- a/vermouth/rcsu/contact_map.py +++ b/vermouth/rcsu/contact_map.py @@ -15,6 +15,7 @@ from ..processors.processor import Processor import numpy as np from scipy.spatial.distance import euclidean, cdist +from scipy.spatial import cKDTree as KDTree from .. import MergeAllMolecules from ..graph_utils import make_residue_graph @@ -280,10 +281,10 @@ def get_vdw_radius(resname, atomname): """ res_vdw = PROTEIN_MAP[resname] try: - atom_vdw = res_vdw[atomname]['vrad'] + atom_vdw = res_vdw[atomname] except KeyError: - atom_vdw = res_vdw['default']['vrad'] - return atom_vdw + atom_vdw = res_vdw['default'] + return atom_vdw['vrad'] def get_atype(resname, atomname): @@ -292,19 +293,25 @@ def get_atype(resname, atomname): """ res_vdw = PROTEIN_MAP[resname] try: - atom_vdw = res_vdw[atomname]['atype'] + atom_vdw = res_vdw[atomname] except KeyError: - atom_vdw = res_vdw['default']['atype'] - return atom_vdw + atom_vdw = res_vdw['default'] + return atom_vdw['atype'] def make_surface(position, fiba, fibb, vrad): """ Generate points on a sphere using Fibonacci points + position: np.array + shape (3,) array of an atomic position to build a sphere around + fiba: int. n-1 fibonacci number to build number of points on sphere + fibb: int. n fibonacci number to build number of points on sphere + vrad: float. VdW radius of the input atom to build a sphere around. + position: centre of sphere """ - x, y, z = position[0], position[1], position[2] + x, y, z = position phi_aux = 0 surface = np.zeros((0, 3)) @@ -323,49 +330,18 @@ def make_surface(position, fiba, fibb, vrad): return surface -def res2atom(arrin, residues, nresidues): - """ - take an array with residue level data and repeat the entries over - each atom within the residue - - would be nice to do this with list comprehension but I can't work out how - to do something like: - - [np.tile(res_dists[i,j], - np.where(residues == i)[0].size) - for i in range(nresidues) - for j in range(nresidues) - ] - - to get it correct in the 2 dimensional way we actually need. - - At the moment we only need this function once - (to get the residue COGs foreach atom) - so it's not too much of a limiting factor, but something to optimise - better in future - - """ - # find out how many residues we have, and how many atoms are in each of them - unique_values, counts = np.unique(residues, return_counts=True) - - assert len(unique_values) == nresidues - - out = np.zeros((len(residues), len(residues))) - start0 = 0 - for i, j in zip(unique_values, counts): - start1 = 0 - for k, l in zip(unique_values, counts): - target_value = arrin[i, k] - out[start0:start0 + j, start1:start1 + l] = target_value.sum() - start1 += l - start0 += j - - return out - - def atom2res(arrin, residues, nresidues, norm=False): """ take an array with atom level data and sum the entries over within the residue + + arrin: np.ndarray + NxN array of entries for each atom + residues: np.array + array of length N indicating which residue an atom belongs to + nresidues: int + number of residues in the molecule + norm: bool + if True, then any entry > 0 in the summed array = 1 """ out = np.array([int(arrin[np.where(residues == i)[0], np.where(residues == j)[0][:, np.newaxis]].sum()) for i in range(nresidues) @@ -378,10 +354,8 @@ def atom2res(arrin, residues, nresidues, norm=False): def bondtype(i, j): maxatomtype = 10 - assert i >= 1 - assert i <= maxatomtype - assert j >= 1 - assert j <= maxatomtype + assert 1 <= i <= maxatomtype + assert 1 <= j <= maxatomtype i -= 1 j -= 1 @@ -416,38 +390,33 @@ def contact_info(system): chains = [] resnames = [] positions_all = [] - cogs = [] ca_pos = [] vdw_list = [] atypes = [] res_serial = [] nodes = [] - for node in G.nodes: + for residue in G.nodes: # we only need these for writing at the end - resnames.append(G.nodes[node]['resname']) - resids.append(G.nodes[node]['resid']) - chains.append(G.nodes[node]['chain']) - nodes.append(G.nodes[node]['_res_serial']) + resnames.append(G.nodes[residue]['resname']) + resids.append(G.nodes[residue]['resid']) + chains.append(G.nodes[residue]['chain']) + nodes.append(G.nodes[residue]['_res_serial']) - res_pos = [] - for subnode in sorted(G.nodes[node]['graph'].nodes): - if 'position' in G.nodes[node]['graph'].nodes[subnode]: - res_serial.append(G.nodes[node]['graph'].nodes[subnode]['_res_serial']) + for atom in sorted(G.nodes[residue]['graph'].nodes): + if 'position' in G.nodes[residue]['graph'].nodes[atom]: + res_serial.append(G.nodes[residue]['graph'].nodes[atom]['_res_serial']) - res_pos.append(G.nodes[node]['graph'].nodes[subnode]['position'] * 10) - positions_all.append(G.nodes[node]['graph'].nodes[subnode]['position'] * 10) + positions_all.append(G.nodes[residue]['graph'].nodes[atom]['position'] * 10) - vdw_list.append(get_vdw_radius(G.nodes[node]['graph'].nodes[subnode]['resname'], - G.nodes[node]['graph'].nodes[subnode]['atomname'])) - atypes.append(get_atype(G.nodes[node]['graph'].nodes[subnode]['resname'], - G.nodes[node]['graph'].nodes[subnode]['atomname'])) + vdw_list.append(get_vdw_radius(G.nodes[residue]['graph'].nodes[atom]['resname'], + G.nodes[residue]['graph'].nodes[atom]['atomname'])) + atypes.append(get_atype(G.nodes[residue]['graph'].nodes[atom]['resname'], + G.nodes[residue]['graph'].nodes[atom]['atomname'])) - if G.nodes[node]['graph'].nodes[subnode]['atomname'] == 'CA': - ca_pos.append(G.nodes[node]['graph'].nodes[subnode]['position']) + if G.nodes[residue]['graph'].nodes[atom]['atomname'] == 'CA': + ca_pos.append(G.nodes[residue]['graph'].nodes[atom]['position']) - cogs.append(np.stack(np.tile(np.stack(res_pos).mean(axis=0), (len(res_pos), 1)))) - cogs = np.vstack(cogs) vdw_list = np.array(vdw_list) atypes = np.array(atypes) coords = np.stack(positions_all) @@ -461,10 +430,10 @@ def contact_info(system): # 2) find the number of residues that we have nresidues = len(G) - return cogs, vdw_list, atypes, coords, res_serial, resids, chains, resnames, nodes, ca_pos, nresidues, G + return vdw_list, atypes, coords, res_serial, resids, chains, resnames, nodes, ca_pos, nresidues, G -def calculate_contact_map(cogs, vdw_list, atypes, coords, res_serial, +def calculate_contact_map(vdw_list, atypes, coords, res_serial, resids, chains, resnames, nodes, ca_pos, nresidues, G): @@ -473,80 +442,52 @@ def calculate_contact_map(cogs, vdw_list, atypes, coords, res_serial, fiba, fibb = 0, 1 for _ in range(fib): fiba, fibb = fibb, fiba + fibb + natoms = len(coords) alpha = 1.24 # Enlargement factor for attraction effects water_radius = 2.80 # Radius of a water molecule in A - # 4) set up final bits of information - # sums of vdw pairs for each atom we have - vdw_sum_array = vdw_list[:, np.newaxis] + vdw_list[np.newaxis, :] + coords_tree = KDTree(coords) + # all_vdw = np.array([x for xs in + # [[PROTEIN_MAP[i][j]['vrad'] for j in PROTEIN_MAP[i].keys()] + # for i in PROTEIN_MAP.keys()] + # for x in xs]) + over = np.zeros((len(coords), len(coords))) + vdw_max = 1.88 # all_vdw.max() + over_sdm = coords_tree.sparse_distance_matrix(coords_tree, 2 * vdw_max * alpha) + for (idx, jdx), distance_between in over_sdm.items(): + if idx != jdx: + if distance_between < (vdw_list[idx] + vdw_list[jdx]) * alpha: + over[idx, jdx] = 1 + + # set up the surface overlap criterion + # generate fibonacci spheres for all atoms. + # can't decide whether quicker/better for memory to generate all in one go here + # or incorporate into the loop. code to do it more on the fly is left in the loop for now + spheres = np.stack([KDTree(make_surface(i, fiba, fibb, water_radius + j)) + for i, j in zip(coords, vdw_list)]) - # distances between all the atoms that we have - atomic_distances = cdist(coords, coords) + hit_results = np.full((natoms, fibb), -1) + dists_counter = np.full((natoms, fibb), np.inf) + surface_sdm = coords_tree.sparse_distance_matrix(coords_tree, (2 * vdw_max) + water_radius) + for (idx, jdx), distance_between in surface_sdm.items(): + if idx != jdx: - # array with 1 on the diagonal, so we can exclude the self atoms - diagonal_ones = np.diagflat(np.ones(atomic_distances.shape[0], dtype=int)) + if distance_between < (vdw_list[idx] + vdw_list[jdx] + water_radius): - # get the coordinates of the centres of geometry for each residue - res_dists = cdist(cogs, cogs) + base_tree = spheres[idx] + vdw = vdw_list[jdx] + water_radius - # 5) find atoms which meet the overlap criterion - over = np.zeros_like(atomic_distances) - overlaps = np.where((atomic_distances <= (vdw_sum_array * alpha)) & - (diagonal_ones != 1) & - (res_dists < 14)) - over[overlaps[0], overlaps[1]] = 1 + res = np.array(base_tree.query_ball_point(coords[jdx], vdw)) + if len(res) > 0: + to_fill = np.where(distance_between < dists_counter[idx][res])[0] - # 6) set up the surface overlap criterion - # generate fibonacci spheres for all atoms. - # can't decide whether quicker/better for memory to generate all in one go here - # or incorporate into the loop. code to do it more on the fly is left in the loop for now - spheres = np.stack([make_surface(i, fiba, fibb, water_radius + j) for i, j in zip(coords, vdw_list)]) - surface_overlaps = np.where((atomic_distances <= (vdw_sum_array + water_radius)) & - (diagonal_ones != 1) & - (res_dists < 14)) - # find which atoms are uniquely involved as base points - base_points = np.unique(surface_overlaps[0]) - - hit_results = np.ones((spheres.shape[0], spheres.shape[1]), dtype=int) * -1 - - # loop over all base points - for base_point in base_points: - - # generate the base point sphere now if we didn't earlier. - # sphere = make_surface(coords[base_point], fiba, fibb, vdw_list[base_point] + water_radius) - # get the target points - target_points = surface_overlaps[1][np.where(surface_overlaps[0] == base_point)[0]] - # array of all the target point coordinates - target_point_coords = coords[target_points] - # distances between the points on the base sphere surface and the target point coordinates - surface_to_point = cdist(spheres[base_point], target_point_coords) - # surface_to_point = cdist(sphere, target_point_coords) - # cutoff distances for each of the target points - target_distances = vdw_list[target_points] + water_radius - - for i, j in enumerate(surface_to_point): - ''' - first find where the radius condition is met, i.e. where the distance between - the target point and this point on the surface is smaller than the vdw radius - of the target point - ''' - radius_condition = j < target_distances - if any(radius_condition): - ''' - For all the points that meet this condition, look at the distance between the - target point and the base point - ''' - distances_to_compare = atomic_distances[base_point][target_points[radius_condition]] - ''' - the point that we need is the point with the smallest distance - ''' - point_needed = target_points[radius_condition][distances_to_compare.argmin()] - hit_results[base_point, i] = point_needed - - contactcounter_1 = np.zeros_like(atomic_distances) - stabilisercounter_1 = np.zeros_like(atomic_distances) - destabilisercounter_1 = np.zeros_like(atomic_distances) + dists_counter[idx][res[to_fill]] = distance_between + hit_results[idx][res[to_fill]] = jdx + + contactcounter_1 = np.zeros((natoms, natoms)) + stabilisercounter_1 = np.zeros((natoms, natoms)) + destabilisercounter_1 = np.zeros((natoms, natoms)) for i, j in enumerate(hit_results): for k in j: @@ -670,11 +611,10 @@ def run_system(self, system): """ self.system = system - cogs, vdw_list, atypes, coords, res_serial, resids, chains, resnames, nodes, ca_pos, nresidues, G = contact_info(system) - if self.path is None: - self.system.go_params["go_map"].append(calculate_contact_map(cogs, - vdw_list, + vdw_list, atypes, coords, res_serial, resids, chains, resnames, nodes, ca_pos, nresidues, G = contact_info( + system) + self.system.go_params["go_map"].append(calculate_contact_map(vdw_list, atypes, coords, res_serial, diff --git a/vermouth/rcsu/go_structure_bias.py b/vermouth/rcsu/go_structure_bias.py index b516657c..e509e709 100644 --- a/vermouth/rcsu/go_structure_bias.py +++ b/vermouth/rcsu/go_structure_bias.py @@ -46,7 +46,6 @@ class ComputeStructuralGoBias(Processor): replacement in the GoPipeline. """ def __init__(self, - # contact_map, cutoff_short, cutoff_long, go_eps, @@ -85,7 +84,6 @@ def __init__(self, magic number for Go contacts from the old GoVirt script. """ - # self.contact_map = contact_map self.cutoff_short = cutoff_short self.cutoff_long = cutoff_long self.go_eps = go_eps