Skip to content

Commit

Permalink
tidy up for readability and testing, added comments and docstrings
Browse files Browse the repository at this point in the history
  • Loading branch information
csbrasnett committed Dec 16, 2024
1 parent a1a217f commit 3468c0e
Showing 1 changed file with 175 additions and 71 deletions.
246 changes: 175 additions & 71 deletions vermouth/rcsu/contact_map.py
Original file line number Diff line number Diff line change
Expand Up @@ -382,6 +382,9 @@ def bondtype(i, j):


def contact_info(system):
"""
get the atom attributes that we need to calculate the contacts
"""

system = MergeAllMolecules().run_system(system)
G = make_residue_graph(system.molecules[0])
Expand Down Expand Up @@ -432,59 +435,105 @@ def contact_info(system):

return vdw_list, atypes, coords, res_serial, resids, chains, resnames, nodes, ca_pos, nresidues, G


def calculate_contact_map(vdw_list, atypes, coords, res_serial,
resids, chains, resnames, nodes, ca_pos, nresidues,
G):

# some initial definitions of variables that we need
fib = 14
fiba, fibb = 0, 1
for _ in range(fib):
fiba, fibb = fibb, fiba + fibb
natoms = len(coords)

def calculate_overlap(coords_tree, vdw_list, natoms, vdw_max):
"""
Find enlarged (OV) overlap contacts
coords_tree: KDTree
KDTree of the input coordinates
vdw_list: list
list of vdw radii of the input coordinates
natoms: int
number of atoms in the molecule
vdw_max: float
maximum possible vdw radius of atoms
"""
alpha = 1.24 # Enlargement factor for attraction effects
water_radius = 2.80 # Radius of a water molecule in A

coords_tree = KDTree(coords)
# all_vdw = np.array([x for xs in
# [[PROTEIN_MAP[i][j]['vrad'] for j in PROTEIN_MAP[i].keys()]
# for i in PROTEIN_MAP.keys()]
# for x in xs])
over = np.zeros((len(coords), len(coords)))
vdw_max = 1.88 # all_vdw.max()
over = np.zeros((natoms, natoms))
over_sdm = coords_tree.sparse_distance_matrix(coords_tree, 2 * vdw_max * alpha)
for (idx, jdx), distance_between in over_sdm.items():
if idx != jdx:
if distance_between < (vdw_list[idx] + vdw_list[jdx]) * alpha:
over[idx, jdx] = 1
return over

def calculate_csu(coords, vdw_list, fiba, fibb, natoms, coords_tree, vdw_max):
"""
Calculate contacts of structural units (CSU)
coords: Nx3 numpy array
coordinates of atoms in the molecule
vdw_list: list
vdw radii of the atoms in the molecule
fiba, fibb: int
n-1th and nth fibonacci numbers from which to generate points on a sphere around the input coordinate
natoms: int
number of atoms in the molecule
coords_tree: KDTree
KDTree of the input coordinates
vdw_max: float
maximum possible vdw radius of atoms
Returns:
hit_results: natoms x fibb np.array
each i,j entry is the index of the atom in coords which is the closest atom to atom i at index j of the
fibonacci sphere
"""
water_radius = 2.80 # Radius of a water molecule in A

# set up the surface overlap criterion
# generate fibonacci spheres for all atoms.
# can't decide whether quicker/better for memory to generate all in one go here
# or incorporate into the loop. code to do it more on the fly is left in the loop for now
# can't decide whether quicker/better for memory to generate all in one go here or incorporate into the loop
spheres = np.stack([KDTree(make_surface(i, fiba, fibb, water_radius + j))
for i, j in zip(coords, vdw_list)])

#setup arrays to keep track
hit_results = np.full((natoms, fibb), -1)
dists_counter = np.full((natoms, fibb), np.inf)

# sparse matrix with a cutoff at the maximum possible distance for a contact
surface_sdm = coords_tree.sparse_distance_matrix(coords_tree, (2 * vdw_max) + water_radius)
# n.b. this loop works because sparse_distance_matrix is sorted by (idx, jdx) pairs
for (idx, jdx), distance_between in surface_sdm.items():
# don't take atoms which are identical
if idx != jdx:

# check that the distance between them is shorter than the vdw sum and the water radius
if distance_between < (vdw_list[idx] + vdw_list[jdx] + water_radius):

# get the KDTree corresponding to this base point
base_tree = spheres[idx]

# get the vdw distance of the target point
vdw = vdw_list[jdx] + water_radius

# find points on the base point sphere which are within the vdw cutoff of the target point's coordinate
res = np.array(base_tree.query_ball_point(coords[jdx], vdw))

# if we have any results
if len(res) > 0:
# find where the distance between the two points is smaller than the current recorded distance
# at the points which are within the cutoff
to_fill = np.where(distance_between < dists_counter[idx][res])[0]

# record the new distances and indices of the points
dists_counter[idx][res[to_fill]] = distance_between
hit_results[idx][res[to_fill]] = jdx

return hit_results


def contact_types(hit_results, natoms, atypes):
"""
From CSU contacts, establish contact types from atomtypes
hit_results: NxM ndarray
array for N atoms in molecule for M fibonnaci points on each atom.
Each i,j entry is the index of the atom which is the closest contact to i
natoms: int
number of atoms in the molecule
atypes: array
list of the atomtypes of each atom in the molecule
"""

contactcounter_1 = np.zeros((natoms, natoms))
stabilisercounter_1 = np.zeros((natoms, natoms))
destabilisercounter_1 = np.zeros((natoms, natoms))
Expand All @@ -502,54 +551,107 @@ def calculate_contact_map(vdw_list, atypes, coords, res_serial,
if btype == 5:
destabilisercounter_1[i, k] += 1

return contactcounter_1, stabilisercounter_1, destabilisercounter_1


def calculate_contacts(vdw_list, atypes, coords, res_serial, nresidues):
"""
run the contact calculation functions
vdw_list: np.array
list of the vdw radii of the atoms in the system
atypes: np.array
list of the atom types in the system to determine the nature of contacts
coords: nx3 array
coordinates of all the atoms in the system
res_serial: np.array
list of the serial residue number of each atom in the system
nresidues: int
number of residues in the system
"""

# some initial definitions of variables that we need
fib = 14
fiba, fibb = 0, 1
for _ in range(fib):
fiba, fibb = fibb, fiba + fibb

natoms = len(coords)

# all_vdw = np.array([x for xs in
# [[PROTEIN_MAP[i][j]['vrad'] for j in PROTEIN_MAP[i].keys()]
# for i in PROTEIN_MAP.keys()]
# for x in xs])
vdw_max = 1.88 # all_vdw.max()

# make the KDTree of the input coordinates
coords_tree = KDTree(coords)

# calculate the OV contacts of the molecule
over = calculate_overlap(coords_tree, vdw_list, natoms, vdw_max)

# Calculate the CSU contacts of the molecule
hit_results = calculate_csu(coords, vdw_list, fiba, fibb, natoms, coords_tree, vdw_max)

# find the types of contacts we have
contactcounter_1, stabilisercounter_1, destabilisercounter_1 = contact_types(hit_results, natoms, atypes)

# transform the resolution between atoms and residues
overlapcounter_2 = atom2res(over, res_serial, nresidues, norm=True)
contactcounter_2 = atom2res(contactcounter_1, res_serial, nresidues)
stabilisercounter_2 = atom2res(stabilisercounter_1, res_serial, nresidues)
destabilisercounter_2 = atom2res(destabilisercounter_1, res_serial, nresidues)

# # this to write out the file if needed
# with open('contact_map_vermouth.out', 'w') as f:
# count = 0
# for i1 in range(nresidues):
# for i2 in range(nresidues):
# over = overlapcounter_2[i1, i2]
# cont = contactcounter_2[i1, i2]
# stab = stabilisercounter_2[i1, i2]
# dest = destabilisercounter_2[i1, i2]
# ocsu = stab
# rcsu = stab - dest
#
# if (over > 0 or cont > 0) and (i1 != i2):
# a = np.where(nodes == i1)[0][0]
# b = np.where(nodes == i2)[0][0]
# count += 1
# msg = (f"R {int(count):6d} "
# f"{int(i1 + 1):5d} {G.nodes[a]['resname']:3s} {G.nodes[a]['chain']:1s} {int(G.nodes[a]['resid']):4d} "
# f"{int(i2 + 1):5d} {G.nodes[b]['resname']:3s} {G.nodes[b]['chain']:1s} {int(G.nodes[b]['resid']):4d} "
# f"{euclidean(ca_pos[a], ca_pos[b])*10:9.4f} "
# f"{int(over):1d} {1 if cont != 0 else 0} {1 if ocsu != 0 else 0} {1 if rcsu > 0 else 0}"
# f"{int(rcsu):6d} {int(cont):6d}\n")
# f.writelines(msg)

contacts = []
return overlapcounter_2, contactcounter_2, stabilisercounter_2, destabilisercounter_2


def get_contacts(nresidues, overlaps, contacts, stabilisers, destabilisers, nodes, G):
contacts_list = []
for i1 in range(nresidues):
for i2 in range(nresidues):
over = overlapcounter_2[i1, i2]
cont = contactcounter_2[i1, i2]
stab = stabilisercounter_2[i1, i2]
dest = destabilisercounter_2[i1, i2]
# ocsu = stab
over = overlaps[i1, i2]
cont = contacts[i1, i2]
stab = stabilisers[i1, i2]
dest = destabilisers[i1, i2]
rcsu = 1 if (stab - dest) > 0 else 0

if (over > 0 or cont > 0) and (i1 != i2):
a = np.where(nodes == i1)[0][0]
b = np.where(nodes == i2)[0][0]
if over == 1 or (over == 0 and rcsu == 1):
# this is a OV or rCSU contact we take it
contacts.append((int(G.nodes[a]['resid']), G.nodes[a]['chain'],
int(G.nodes[b]['resid']), G.nodes[b]['chain']))

return contacts
contacts_list.append((int(G.nodes[a]['resid']), G.nodes[a]['chain'],
int(G.nodes[b]['resid']), G.nodes[b]['chain']))

return contacts_list


def write_contacts(nresidues, overlaps, contacts, stabilisers, destabilisers, nodes, ca_pos, G):
# this to write out the file if needed
with open('contact_map_vermouth.out', 'w') as f:
count = 0
for i1 in range(nresidues):
for i2 in range(nresidues):
over = overlaps[i1, i2]
cont = contacts[i1, i2]
stab = stabilisers[i1, i2]
dest = destabilisers[i1, i2]
ocsu = stab
rcsu = stab - dest

if (over > 0 or cont > 0) and (i1 != i2):
a = np.where(nodes == i1)[0][0]
b = np.where(nodes == i2)[0][0]
count += 1
msg = (f"R {int(count):6d} "
f"{int(i1 + 1):5d} {G.nodes[a]['resname']:3s}"
f"{G.nodes[a]['chain']:1s} {int(G.nodes[a]['resid']):4d} "
f"{int(i2 + 1):5d} {G.nodes[b]['resname']:3s}"
f"{G.nodes[b]['chain']:1s} {int(G.nodes[b]['resid']):4d} "
f"{euclidean(ca_pos[a], ca_pos[b])*10:9.4f} "
f"{int(over):1d} {1 if cont != 0 else 0} {1 if ocsu != 0 else 0} {1 if rcsu > 0 else 0}"
f"{int(rcsu):6d} {int(cont):6d}\n")
f.writelines(msg)


"""
Expand Down Expand Up @@ -614,16 +716,18 @@ def run_system(self, system):
if self.path is None:
vdw_list, atypes, coords, res_serial, resids, chains, resnames, nodes, ca_pos, nresidues, G = contact_info(
system)
self.system.go_params["go_map"].append(calculate_contact_map(vdw_list,
atypes,
coords,
res_serial,
resids,
chains,
resnames,
nodes,
ca_pos,
nresidues,
G))

overlaps, contacts, stabilisers, destabilisers = calculate_contacts(vdw_list,
atypes,
coords,
res_serial,
nresidues)

self.system.go_params["go_map"].append(get_contacts(nresidues,
overlaps, contacts,
stabilisers,
destabilisers,
nodes,
G))
else:
self.system.go_params["go_map"].append(read_go_map(file_path=self.path))

0 comments on commit 3468c0e

Please sign in to comment.