diff --git a/vermouth/rcsu/contact_map.py b/vermouth/rcsu/contact_map.py index 859060f2..48231b32 100644 --- a/vermouth/rcsu/contact_map.py +++ b/vermouth/rcsu/contact_map.py @@ -382,6 +382,9 @@ def bondtype(i, j): def contact_info(system): + """ + get the atom attributes that we need to calculate the contacts + """ system = MergeAllMolecules().run_system(system) G = make_residue_graph(system.molecules[0]) @@ -432,59 +435,105 @@ def contact_info(system): return vdw_list, atypes, coords, res_serial, resids, chains, resnames, nodes, ca_pos, nresidues, G - -def calculate_contact_map(vdw_list, atypes, coords, res_serial, - resids, chains, resnames, nodes, ca_pos, nresidues, - G): - - # some initial definitions of variables that we need - fib = 14 - fiba, fibb = 0, 1 - for _ in range(fib): - fiba, fibb = fibb, fiba + fibb - natoms = len(coords) - +def calculate_overlap(coords_tree, vdw_list, natoms, vdw_max): + """ + Find enlarged (OV) overlap contacts + + coords_tree: KDTree + KDTree of the input coordinates + vdw_list: list + list of vdw radii of the input coordinates + natoms: int + number of atoms in the molecule + vdw_max: float + maximum possible vdw radius of atoms + """ alpha = 1.24 # Enlargement factor for attraction effects - water_radius = 2.80 # Radius of a water molecule in A - - coords_tree = KDTree(coords) - # all_vdw = np.array([x for xs in - # [[PROTEIN_MAP[i][j]['vrad'] for j in PROTEIN_MAP[i].keys()] - # for i in PROTEIN_MAP.keys()] - # for x in xs]) - over = np.zeros((len(coords), len(coords))) - vdw_max = 1.88 # all_vdw.max() + over = np.zeros((natoms, natoms)) over_sdm = coords_tree.sparse_distance_matrix(coords_tree, 2 * vdw_max * alpha) for (idx, jdx), distance_between in over_sdm.items(): if idx != jdx: if distance_between < (vdw_list[idx] + vdw_list[jdx]) * alpha: over[idx, jdx] = 1 + return over + +def calculate_csu(coords, vdw_list, fiba, fibb, natoms, coords_tree, vdw_max): + """ + Calculate contacts of structural units (CSU) + + coords: Nx3 numpy array + coordinates of atoms in the molecule + vdw_list: list + vdw radii of the atoms in the molecule + fiba, fibb: int + n-1th and nth fibonacci numbers from which to generate points on a sphere around the input coordinate + natoms: int + number of atoms in the molecule + coords_tree: KDTree + KDTree of the input coordinates + vdw_max: float + maximum possible vdw radius of atoms + + Returns: + hit_results: natoms x fibb np.array + each i,j entry is the index of the atom in coords which is the closest atom to atom i at index j of the + fibonacci sphere + + """ + water_radius = 2.80 # Radius of a water molecule in A - # set up the surface overlap criterion # generate fibonacci spheres for all atoms. - # can't decide whether quicker/better for memory to generate all in one go here - # or incorporate into the loop. code to do it more on the fly is left in the loop for now + # can't decide whether quicker/better for memory to generate all in one go here or incorporate into the loop spheres = np.stack([KDTree(make_surface(i, fiba, fibb, water_radius + j)) for i, j in zip(coords, vdw_list)]) + #setup arrays to keep track hit_results = np.full((natoms, fibb), -1) dists_counter = np.full((natoms, fibb), np.inf) + + # sparse matrix with a cutoff at the maximum possible distance for a contact surface_sdm = coords_tree.sparse_distance_matrix(coords_tree, (2 * vdw_max) + water_radius) + # n.b. this loop works because sparse_distance_matrix is sorted by (idx, jdx) pairs for (idx, jdx), distance_between in surface_sdm.items(): + # don't take atoms which are identical if idx != jdx: - + # check that the distance between them is shorter than the vdw sum and the water radius if distance_between < (vdw_list[idx] + vdw_list[jdx] + water_radius): - + # get the KDTree corresponding to this base point base_tree = spheres[idx] + + # get the vdw distance of the target point vdw = vdw_list[jdx] + water_radius + # find points on the base point sphere which are within the vdw cutoff of the target point's coordinate res = np.array(base_tree.query_ball_point(coords[jdx], vdw)) + + # if we have any results if len(res) > 0: + # find where the distance between the two points is smaller than the current recorded distance + # at the points which are within the cutoff to_fill = np.where(distance_between < dists_counter[idx][res])[0] + # record the new distances and indices of the points dists_counter[idx][res[to_fill]] = distance_between hit_results[idx][res[to_fill]] = jdx + return hit_results + + +def contact_types(hit_results, natoms, atypes): + """ + From CSU contacts, establish contact types from atomtypes + + hit_results: NxM ndarray + array for N atoms in molecule for M fibonnaci points on each atom. + Each i,j entry is the index of the atom which is the closest contact to i + natoms: int + number of atoms in the molecule + atypes: array + list of the atomtypes of each atom in the molecule + """ + contactcounter_1 = np.zeros((natoms, natoms)) stabilisercounter_1 = np.zeros((natoms, natoms)) destabilisercounter_1 = np.zeros((natoms, natoms)) @@ -502,43 +551,68 @@ def calculate_contact_map(vdw_list, atypes, coords, res_serial, if btype == 5: destabilisercounter_1[i, k] += 1 + return contactcounter_1, stabilisercounter_1, destabilisercounter_1 + + +def calculate_contacts(vdw_list, atypes, coords, res_serial, nresidues): + """ + run the contact calculation functions + + vdw_list: np.array + list of the vdw radii of the atoms in the system + atypes: np.array + list of the atom types in the system to determine the nature of contacts + coords: nx3 array + coordinates of all the atoms in the system + res_serial: np.array + list of the serial residue number of each atom in the system + nresidues: int + number of residues in the system + """ + + # some initial definitions of variables that we need + fib = 14 + fiba, fibb = 0, 1 + for _ in range(fib): + fiba, fibb = fibb, fiba + fibb + + natoms = len(coords) + + # all_vdw = np.array([x for xs in + # [[PROTEIN_MAP[i][j]['vrad'] for j in PROTEIN_MAP[i].keys()] + # for i in PROTEIN_MAP.keys()] + # for x in xs]) + vdw_max = 1.88 # all_vdw.max() + + # make the KDTree of the input coordinates + coords_tree = KDTree(coords) + + # calculate the OV contacts of the molecule + over = calculate_overlap(coords_tree, vdw_list, natoms, vdw_max) + + # Calculate the CSU contacts of the molecule + hit_results = calculate_csu(coords, vdw_list, fiba, fibb, natoms, coords_tree, vdw_max) + + # find the types of contacts we have + contactcounter_1, stabilisercounter_1, destabilisercounter_1 = contact_types(hit_results, natoms, atypes) + + # transform the resolution between atoms and residues overlapcounter_2 = atom2res(over, res_serial, nresidues, norm=True) contactcounter_2 = atom2res(contactcounter_1, res_serial, nresidues) stabilisercounter_2 = atom2res(stabilisercounter_1, res_serial, nresidues) destabilisercounter_2 = atom2res(destabilisercounter_1, res_serial, nresidues) - # # this to write out the file if needed - # with open('contact_map_vermouth.out', 'w') as f: - # count = 0 - # for i1 in range(nresidues): - # for i2 in range(nresidues): - # over = overlapcounter_2[i1, i2] - # cont = contactcounter_2[i1, i2] - # stab = stabilisercounter_2[i1, i2] - # dest = destabilisercounter_2[i1, i2] - # ocsu = stab - # rcsu = stab - dest - # - # if (over > 0 or cont > 0) and (i1 != i2): - # a = np.where(nodes == i1)[0][0] - # b = np.where(nodes == i2)[0][0] - # count += 1 - # msg = (f"R {int(count):6d} " - # f"{int(i1 + 1):5d} {G.nodes[a]['resname']:3s} {G.nodes[a]['chain']:1s} {int(G.nodes[a]['resid']):4d} " - # f"{int(i2 + 1):5d} {G.nodes[b]['resname']:3s} {G.nodes[b]['chain']:1s} {int(G.nodes[b]['resid']):4d} " - # f"{euclidean(ca_pos[a], ca_pos[b])*10:9.4f} " - # f"{int(over):1d} {1 if cont != 0 else 0} {1 if ocsu != 0 else 0} {1 if rcsu > 0 else 0}" - # f"{int(rcsu):6d} {int(cont):6d}\n") - # f.writelines(msg) - - contacts = [] + return overlapcounter_2, contactcounter_2, stabilisercounter_2, destabilisercounter_2 + + +def get_contacts(nresidues, overlaps, contacts, stabilisers, destabilisers, nodes, G): + contacts_list = [] for i1 in range(nresidues): for i2 in range(nresidues): - over = overlapcounter_2[i1, i2] - cont = contactcounter_2[i1, i2] - stab = stabilisercounter_2[i1, i2] - dest = destabilisercounter_2[i1, i2] - # ocsu = stab + over = overlaps[i1, i2] + cont = contacts[i1, i2] + stab = stabilisers[i1, i2] + dest = destabilisers[i1, i2] rcsu = 1 if (stab - dest) > 0 else 0 if (over > 0 or cont > 0) and (i1 != i2): @@ -546,10 +620,38 @@ def calculate_contact_map(vdw_list, atypes, coords, res_serial, b = np.where(nodes == i2)[0][0] if over == 1 or (over == 0 and rcsu == 1): # this is a OV or rCSU contact we take it - contacts.append((int(G.nodes[a]['resid']), G.nodes[a]['chain'], - int(G.nodes[b]['resid']), G.nodes[b]['chain'])) - - return contacts + contacts_list.append((int(G.nodes[a]['resid']), G.nodes[a]['chain'], + int(G.nodes[b]['resid']), G.nodes[b]['chain'])) + + return contacts_list + + +def write_contacts(nresidues, overlaps, contacts, stabilisers, destabilisers, nodes, ca_pos, G): + # this to write out the file if needed + with open('contact_map_vermouth.out', 'w') as f: + count = 0 + for i1 in range(nresidues): + for i2 in range(nresidues): + over = overlaps[i1, i2] + cont = contacts[i1, i2] + stab = stabilisers[i1, i2] + dest = destabilisers[i1, i2] + ocsu = stab + rcsu = stab - dest + + if (over > 0 or cont > 0) and (i1 != i2): + a = np.where(nodes == i1)[0][0] + b = np.where(nodes == i2)[0][0] + count += 1 + msg = (f"R {int(count):6d} " + f"{int(i1 + 1):5d} {G.nodes[a]['resname']:3s}" + f"{G.nodes[a]['chain']:1s} {int(G.nodes[a]['resid']):4d} " + f"{int(i2 + 1):5d} {G.nodes[b]['resname']:3s}" + f"{G.nodes[b]['chain']:1s} {int(G.nodes[b]['resid']):4d} " + f"{euclidean(ca_pos[a], ca_pos[b])*10:9.4f} " + f"{int(over):1d} {1 if cont != 0 else 0} {1 if ocsu != 0 else 0} {1 if rcsu > 0 else 0}" + f"{int(rcsu):6d} {int(cont):6d}\n") + f.writelines(msg) """ @@ -614,16 +716,18 @@ def run_system(self, system): if self.path is None: vdw_list, atypes, coords, res_serial, resids, chains, resnames, nodes, ca_pos, nresidues, G = contact_info( system) - self.system.go_params["go_map"].append(calculate_contact_map(vdw_list, - atypes, - coords, - res_serial, - resids, - chains, - resnames, - nodes, - ca_pos, - nresidues, - G)) + + overlaps, contacts, stabilisers, destabilisers = calculate_contacts(vdw_list, + atypes, + coords, + res_serial, + nresidues) + + self.system.go_params["go_map"].append(get_contacts(nresidues, + overlaps, contacts, + stabilisers, + destabilisers, + nodes, + G)) else: self.system.go_params["go_map"].append(read_go_map(file_path=self.path))