Skip to content

Commit

Permalink
new excluded volume calculation routine; fixes #75
Browse files Browse the repository at this point in the history
  • Loading branch information
aozalevsky committed Mar 13, 2024
1 parent bfa02c4 commit 4eef0f0
Showing 1 changed file with 32 additions and 14 deletions.
46 changes: 32 additions & 14 deletions ihm_validation/excludedvolume.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
import multiprocessing as mp
import pandas as pd
import numpy as np
from scipy.spatial import KDTree
import math
import os
import csv
Expand Down Expand Up @@ -82,20 +83,37 @@ def get_xyzr_complete(self, model_ID, spheres: list) -> pd.DataFrame:
def get_violation_dict(self, model_spheres_df: pd.DataFrame) -> dict:
""" get violation from model_sphere df"""
viols = {}
for indx, col in model_spheres_df.items():
if indx < model_spheres_df.shape[1]:
sphere_R = model_spheres_df.iloc[-1, indx:]
remaining = model_spheres_df.iloc[:-1, indx:]
subt_alone = remaining.sub(col[:-1], axis=0)
final_df = np.square(subt_alone)
final_df.loc['sqrt'] = np.sqrt(final_df.sum(axis=0))
final_df.loc['R_tot'] = sphere_R.add(
col[[-1]].tolist()[0]).to_list()
final_df.loc['distances'] = final_df.loc['sqrt'] - \
final_df.loc['R_tot']
final_df.loc['violations'] = final_df.loc['distances'].apply(
lambda x: 1 if x < 0 else 0)
viols[indx] = final_df.loc['violations'].sum(axis=0)

# Get coordinates
xyz = model_spheres_df.T[['X', 'Y', 'Z']].to_numpy()
# Get radii
radii = model_spheres_df.T[['R']].to_numpy()
# Get maximum radius
maxr = np.max(radii)
# Build tree
t = KDTree(xyz)

# The enumeration is done to preveserve
# compatibility as it's a drop-in replacement
for indx, i in enumerate(range(len(xyz) - 1), 1):
viols_ = 0.
# Get neighours in R1+R2 radius, where
# R1 is particle's i radius
# and R2 is the maxium radius
# Thus it's a greedy search
nb = t.query_ball_point(xyz[i], radii[i] + maxr)

# Check each neighbour
for j in nb:
# Only check pairs in a triangle
if j > i:
# np.linalg.norm is slow, but convenient
d = np.linalg.norm(xyz[i] - xyz[j])
if d < (radii[i] + radii[j]):
viols_ += 1.

viols[indx] = viols_

return viols

def get_exc_vol_for_models(self, model_dict: dict) -> dict:
Expand Down

0 comments on commit 4eef0f0

Please sign in to comment.