diff --git a/ihm_validation/excludedvolume.py b/ihm_validation/excludedvolume.py index 30b7b0fe..a0564d4c 100644 --- a/ihm_validation/excludedvolume.py +++ b/ihm_validation/excludedvolume.py @@ -12,6 +12,7 @@ import multiprocessing as mp import pandas as pd import numpy as np +from scipy.spatial import KDTree import math import os import csv @@ -82,20 +83,37 @@ def get_xyzr_complete(self, model_ID, spheres: list) -> pd.DataFrame: def get_violation_dict(self, model_spheres_df: pd.DataFrame) -> dict: """ get violation from model_sphere df""" viols = {} - for indx, col in model_spheres_df.items(): - if indx < model_spheres_df.shape[1]: - sphere_R = model_spheres_df.iloc[-1, indx:] - remaining = model_spheres_df.iloc[:-1, indx:] - subt_alone = remaining.sub(col[:-1], axis=0) - final_df = np.square(subt_alone) - final_df.loc['sqrt'] = np.sqrt(final_df.sum(axis=0)) - final_df.loc['R_tot'] = sphere_R.add( - col[[-1]].tolist()[0]).to_list() - final_df.loc['distances'] = final_df.loc['sqrt'] - \ - final_df.loc['R_tot'] - final_df.loc['violations'] = final_df.loc['distances'].apply( - lambda x: 1 if x < 0 else 0) - viols[indx] = final_df.loc['violations'].sum(axis=0) + + # Get coordinates + xyz = model_spheres_df.T[['X', 'Y', 'Z']].to_numpy() + # Get radii + radii = model_spheres_df.T[['R']].to_numpy() + # Get maximum radius + maxr = np.max(radii) + # Build tree + t = KDTree(xyz) + + # The enumeration is done to preveserve + # compatibility as it's a drop-in replacement + for indx, i in enumerate(range(len(xyz) - 1), 1): + viols_ = 0. + # Get neighours in R1+R2 radius, where + # R1 is particle's i radius + # and R2 is the maxium radius + # Thus it's a greedy search + nb = t.query_ball_point(xyz[i], radii[i] + maxr) + + # Check each neighbour + for j in nb: + # Only check pairs in a triangle + if j > i: + # np.linalg.norm is slow, but convenient + d = np.linalg.norm(xyz[i] - xyz[j]) + if d < (radii[i] + radii[j]): + viols_ += 1. + + viols[indx] = viols_ + return viols def get_exc_vol_for_models(self, model_dict: dict) -> dict: