Skip to content

Commit

Permalink
Add support for lDDT computation
Browse files Browse the repository at this point in the history
  • Loading branch information
padix-key committed Nov 12, 2024
1 parent d9f0c13 commit 7532607
Show file tree
Hide file tree
Showing 5 changed files with 425 additions and 5 deletions.
32 changes: 32 additions & 0 deletions benchmarks/structure/benchmark_compare.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,32 @@
import itertools
from pathlib import Path
import pytest
import biotite.structure as struc
import biotite.structure.io.pdbx as pdbx
from tests.util import data_dir


@pytest.fixture
def atoms():
    """
    Provide the structure used as input for the benchmarks.

    The structure is read from the ``1gya.bcif`` file in the ``structure``
    test data directory and all atoms whose element is ``"H"`` are removed.
    """
    bcif_path = Path(data_dir("structure")) / "1gya.bcif"
    structure = pdbx.get_structure(pdbx.BinaryCIFFile.read(bcif_path))
    # Fewer atoms mean fewer contacts, which keeps the benchmark fast
    heavy_atom_mask = structure.element != "H"
    return structure[heavy_atom_mask]


@pytest.mark.benchmark
@pytest.mark.parametrize(
    "multi_model, aggregation",
    itertools.product([False, True], ["all", "chain", "residue", "atom"]),
)
def benchmark_lddt(atoms, multi_model, aggregation):
    """
    Benchmark the lDDT computation for every aggregation level, with either
    the first model only or all models of the fixture as subject.
    """
    first_model = atoms[0]
    # The reference is always the first model, only the subject varies
    subject = atoms if multi_model else atoms[0]
    struc.lddt(first_model, subject, aggregation=aggregation)
3 changes: 2 additions & 1 deletion doc/apidoc.json
Original file line number Diff line number Diff line change
Expand Up @@ -318,7 +318,8 @@
"average",
"rmsd",
"rmspd",
"rmsf"
"rmsf",
"lddt"
],
"General analysis" : [
"sasa",
Expand Down
14 changes: 14 additions & 0 deletions doc/references.bib
Original file line number Diff line number Diff line change
Expand Up @@ -542,6 +542,20 @@ @article{Ma2002
doi = {10.1093/bioinformatics/18.3.440}
}

@article{Mariani2013,
title = {{{lDDT}}: A Local Superposition-Free Score for Comparing Protein Structures and Models Using Distance Difference Tests},
shorttitle = {{{lDDT}}},
author = {Mariani, Valerio and Biasini, Marco and Barbato, Alessandro and Schwede, Torsten},
year = {2013},
month = nov,
journal = {Bioinformatics},
volume = {29},
number = {21},
pages = {2722--2728},
issn = {1367-4803},
doi = {10.1093/bioinformatics/btt473}
}

@article{Martin2005,
title = {Using Information Theory to Search for Co-Evolving Residues in Proteins},
author = {Martin, L. C. and Gloor, G. B. and Dunn, S. D. and Wahl, L. M.},
Expand Down
245 changes: 243 additions & 2 deletions src/biotite/structure/compare.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,11 +9,16 @@

__name__ = "biotite.structure"
__author__ = "Patrick Kunzmann"
__all__ = ["rmsd", "rmspd", "rmsf", "average"]
__all__ = ["rmsd", "rmspd", "rmsf", "average", "lddt"]

import collections.abc
import warnings
import numpy as np
from biotite.structure.atoms import AtomArrayStack, coord
from biotite.structure.atoms import AtomArray, AtomArrayStack, coord
from biotite.structure.celllist import CellList
from biotite.structure.chains import get_chain_positions
from biotite.structure.geometry import index_distance
from biotite.structure.residues import get_residue_positions
from biotite.structure.util import vector_dot


Expand Down Expand Up @@ -242,6 +247,182 @@ def average(atoms):
return mean_coords


def lddt(
    reference,
    subject,
    aggregation="all",
    inclusion_radius=15,
    distance_bins=(0.5, 1.0, 2.0, 4.0),
    exclude_same_residue=True,
):
    """
    Calculate the *local Distance Difference Test* (lDDT) score of a structure with
    respect to its reference.
    :footcite:`Mariani2013`

    Parameters
    ----------
    reference : AtomArray or ndarray, dtype=float, shape=(n,3)
        The reference structure.
        Alternatively, coordinates can be provided directly as
        :class:`ndarray`, if `exclude_same_residue` is set to ``False``.
    subject : AtomArray or AtomArrayStack or ndarray, dtype=float, shape=(n,3) or shape=(m,n,3)
        The structure(s) to evaluate with respect to `reference`.
        The number of atoms must be the same as in `reference`.
        Alternatively, coordinates can be provided directly as
        :class:`ndarray`.
    aggregation : {'all', 'chain', 'residue', 'atom'} or ndarray, shape=(n,), dtype=int, optional
        Defines on which scale the lDDT score is calculated.

        - `'all'`: The score is computed over all contacts.
        - `'chain'`: The score is calculated for each chain separately.
        - `'residue'`: The score is calculated for each residue separately.
        - `'atom'`: The score is calculated for each atom separately.

        Alternatively, an array of aggregation bins can be provided, i.e. each contact
        is assigned to the corresponding bin.
    inclusion_radius : float, optional
        Pairwise atom distances are considered within this radius in `reference`.
    distance_bins : list of float, optional
        The distance bins for the score calculation, i.e. if a distance deviation is
        within the first bin, the score is 1, if it is outside all bins, the score is 0.
    exclude_same_residue : bool, optional
        If set to False, distances between atoms of the same residue are also
        considered.
        By default, only atom distances between different residues are considered.

    Returns
    -------
    lddt : float or ndarray, dtype=float
        The lDDT score for each model and aggregation bin.
        The shape depends on `subject` and `aggregation`:
        If `subject` is an :class:`AtomArrayStack` (or equivalent coordinate
        :class:`ndarray`), a dimension depicting each model is added.
        if `aggregation` is not ``'all'``, a second dimension with the length equal to
        the number of aggregation bins is added (i.e. number of chains, residues, etc.).
        For the ``'chain'``, ``'residue'`` and ``'atom'`` levels, a bin that
        contains no contact is *NaN*.
        If both, an :class:`AtomArray` as `subject` and ``aggregation='all'`` is passed,
        a float is returned.

    Raises
    ------
    ValueError
        If `aggregation` is neither a valid aggregation level nor an array of bins,
        or if `exclude_same_residue` is true, but `reference` is not an
        :class:`AtomArray`.

    Notes
    -----
    The lDDT score measures how well the pairwise atom distances in a model match the
    corresponding distances in a reference.
    Hence, like :func:`rmspd()` it works superimposition-free, but instead of capturing
    the global deviation, only the local environment within the `inclusion_radius` is
    considered.

    Note that by default, also hydrogen atoms are considered in the distance
    calculation.
    If this is undesired, the hydrogen atoms can be removed prior to the calculation.

    References
    ----------

    .. footbibliography::

    Examples
    --------

    Calculate the global lDDT of all models to the first model:

    >>> reference = atom_array_stack[0]
    >>> subject = atom_array_stack[1:]
    >>> print(lddt(reference, subject))
    [0.799 0.769 0.792 0.836 0.799 0.752 0.860 0.769 0.825 0.777 0.760 0.787
     0.790 0.783 0.804 0.842 0.769 0.797 0.757 0.852 0.811 0.786 0.805 0.755
     0.734 0.794 0.771 0.778 0.842 0.772 0.815 0.789 0.828 0.750 0.826 0.739
     0.760]

    Calculate the residue-wise lDDT for a single model:

    >>> subject = atom_array_stack[1]
    >>> print(lddt(reference, subject, aggregation="residue"))
    [0.599 0.692 0.870 0.780 0.830 0.881 0.872 0.658 0.782 0.901 0.888 0.885
     0.856 0.795 0.847 0.603 0.895 0.878 0.871 0.789]

    As example for custom aggregation, calculate the lDDT for each chemical element:

    >>> unique_elements = np.unique(reference.element)
    >>> element_bins = np.array(
    ...     [np.where(unique_elements == element)[0][0] for element in reference.element]
    ... )
    >>> element_lddt = lddt(reference, subject, aggregation=element_bins)
    >>> for element, score in zip(unique_elements, element_lddt):
    ...     print(f"{element}: {score:.3f}")
    C: 0.837
    H: 0.770
    N: 0.811
    O: 0.808
    """
    # A string is also a 'Sequence'
    # -> distinguish between string and array, list, etc.
    custom_bins = isinstance(
        aggregation, (np.ndarray, collections.abc.Sequence)
    ) and not isinstance(aggregation, str)
    # Reject an invalid level before the expensive contact search is performed
    if not custom_bins and aggregation not in ("all", "chain", "residue", "atom"):
        raise ValueError(f"Invalid aggregation level '{aggregation}'")
    aggregate_all = not custom_bins and aggregation == "all"

    distance_bins = np.asarray(distance_bins)
    reference_coord = coord(reference)
    subject_coord = coord(subject)

    # Use a cell list to find atoms within inclusion radius in O(n) time complexity
    cell_list = CellList(reference_coord, inclusion_radius)
    # Pairs of indices for atoms within the inclusion radius
    contacts = _to_sparse_indices(
        cell_list.get_atoms(reference_coord, inclusion_radius)
    )

    if aggregate_all:
        # Remove duplicate pairs as each pair appears twice
        # (if i is in threshold distance to j, j is also in threshold distance to i)
        # keep only the pair where i < j
        # This improves performance due to less distances that need to be computed
        # and also removes self-contacts
        contacts = contacts[contacts[:, 0] < contacts[:, 1]]
    else:
        # On all other aggregation levels, the duplicate contacts cannot be removed,
        # as i and j are possibly in different aggregation bins
        # Still, self-contacts are removed
        contacts = contacts[contacts[:, 0] != contacts[:, 1]]

    if exclude_same_residue:
        if not isinstance(reference, AtomArray):
            raise ValueError(
                "If 'exclude_same_residue' is set to True, "
                "'reference' must be an AtomArray"
            )
        # Find the index of the residue for each atom
        residue_indices = get_residue_positions(reference, contacts.flatten()).reshape(
            contacts.shape
        )
        # Remove contacts between atoms of the same residue
        contacts = contacts[residue_indices[:, 0] != residue_indices[:, 1]]
    if len(contacts) == 0:
        warnings.warn("No contacts found within the inclusion radius")

    # Measure the deviation in the distances for the filtered contacts
    reference_distances = index_distance(reference_coord, contacts)
    subject_distances = index_distance(subject_coord, contacts)
    deviations = np.abs(subject_distances - reference_distances)
    fraction_preserved_bins = np.count_nonzero(
        deviations[..., np.newaxis] <= distance_bins[np.newaxis, :], axis=-1
    ) / len(distance_bins)

    # Aggregate the fractions over the desired level
    if aggregate_all:
        # Average over all contacts
        return np.mean(fraction_preserved_bins, axis=-1)

    n_atoms = reference_coord.shape[-2]
    # The position of the last atom's chain/residue gives the total number of
    # chains/residues, independent of which atoms actually have contacts
    last_atom = np.array([n_atoms - 1])
    if custom_bins:
        aggregation_bins = np.asarray(aggregation)[contacts[:, 0]]
        # The intended number of custom bins is unknown -> no padding
        n_bins = None
    elif aggregation == "chain":
        aggregation_bins = get_chain_positions(reference, contacts[:, 0])
        n_bins = (
            get_chain_positions(reference, last_atom)[0] + 1 if n_atoms > 0 else 0
        )
    elif aggregation == "residue":
        aggregation_bins = get_residue_positions(reference, contacts[:, 0])
        n_bins = (
            get_residue_positions(reference, last_atom)[0] + 1 if n_atoms > 0 else 0
        )
    else:
        # Only 'atom' remains, as the level was validated above
        aggregation_bins = contacts[:, 0]
        n_bins = n_atoms

    scores = _average_over_indices(fraction_preserved_bins, aggregation_bins)
    if n_bins is not None:
        # The averaging only covers bins up to the last one that has a contact.
        # Append NaN for trailing bins without contacts, so the result always
        # has one entry per chain/residue/atom
        n_missing = n_bins - scores.shape[-1]
        if n_missing > 0:
            padding = np.full((*scores.shape[:-1], n_missing), np.nan)
            scores = np.concatenate((scores, padding), axis=-1)
    return scores


def _sq_euclidian(reference, subject):
"""
Calculate squared euclidian distance between atoms in two
Expand Down Expand Up @@ -272,3 +453,63 @@ def _sq_euclidian(reference, subject):
)
dif = subject_coord - reference_coord
return vector_dot(dif, dif)


def _to_sparse_indices(all_contacts):
    """
    Convert the padded contact array returned by :meth:`CellList.get_atoms()` into
    explicit pairs of atom indices.
    In other words, the pairs would mark the non-zero elements in a dense contact
    matrix.

    Parameters
    ----------
    all_contacts : ndarray, dtype=int, shape=(m,n)
        The contact indices as returned by :meth:`CellList.get_atoms()`.
        Padded with -1, in the second dimension.
        Dimension *m* marks the query atoms, dimension *n* marks the contact atoms.

    Returns
    -------
    combined_indices : ndarray, dtype=int, shape=(l,2)
        The contact indices.
        Each row is one contact: the first column holds the query atom index,
        the second column the contact atom index.
    """
    # Every entry that is not a padding value represents an actual contact
    is_contact = all_contacts != -1
    # The row of such an entry is the query atom index,
    # entries are reported row by row, i.e. sorted by query atom
    query_indices, column_indices = np.nonzero(is_contact)
    # The value stored at the entry is the index of the contact atom
    contact_indices = all_contacts[query_indices, column_indices]
    return np.stack([query_indices, contact_indices], axis=1)


def _average_over_indices(values, bins, n_bins=None):
    """
    For each unique index in `bins`, average the corresponding values in `values`.

    Based on
    https://stackoverflow.com/questions/79140661/how-to-sum-values-based-on-a-second-index-array-in-a-vectorized-manner

    Parameters
    ----------
    values : ndarray, shape=(..., n)
        The values to average.
        Any number of leading dimensions is supported.
    bins : ndarray, shape=(n,) dtype=int
        The non-negative bin index for each element in the last dimension of
        `values`.
    n_bins : int, optional
        The minimum number of bins in the result.
        By default, the number of bins is the maximum value in `bins` + 1.

    Returns
    -------
    averaged : ndarray, shape=(..., k)
        The averaged values.
        *k* is the maximum value in `bins` + 1, or `n_bins` if that is larger.
        Bins without any element are *NaN*.
    """
    if n_bins is None:
        n_elements_per_bin = np.bincount(bins)
    else:
        n_elements_per_bin = np.bincount(bins, minlength=n_bins)
    n_bins = len(n_elements_per_bin)
    # The last dimension is replaced by the number of bins
    # Broadcasting in 'np.add.at()' requires the replaced dimension to be the first
    aggregated = np.zeros((n_bins, *values.shape[:-1]), dtype=values.dtype)
    # 'moveaxis' keeps the order of the remaining dimensions, so the shapes also
    # match for more than one leading dimension ('swapaxes' would reverse them)
    np.add.at(aggregated, bins, np.moveaxis(values, -1, 0))
    # Bring the bin dimension into the last dimension again
    return np.moveaxis(aggregated, 0, -1) / n_elements_per_bin
Loading

0 comments on commit 7532607

Please sign in to comment.