Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix: two line dot bracket encoding #147

Merged
merged 5 commits into from
Nov 20, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion src/eltetrado/VERSION
Original file line number Diff line number Diff line change
@@ -1 +1 @@
1.5.16
1.5.17
216 changes: 147 additions & 69 deletions src/eltetrado/analysis.py
Original file line number Diff line number Diff line change
@@ -1,18 +1,19 @@
import copy
import itertools
import logging
import math
import os
import string
import subprocess
import sys
import tempfile
from collections import Counter, defaultdict
from dataclasses import dataclass, field
from functools import lru_cache
from typing import IO, Dict, FrozenSet, Iterable, List, Optional, Set, Tuple

import numpy
import numpy.typing
from rnapolis.common import BaseInteractions, GlycosidicBond
from rnapolis.common import BaseInteractions, BpSeq, Entry, GlycosidicBond
from rnapolis.tertiary import Atom, BasePair3D, Mapping2D3D, Residue3D, Structure3D

from eltetrado.model import (
Expand All @@ -29,6 +30,69 @@
logging.basicConfig(level=os.environ.get("LOGLEVEL", "INFO"))


@lru_cache(maxsize=None)
def is_nucleotide(nt):
    """
    Decide heuristically whether a residue looks like a nucleotide.

    Four scores in [0, 1] are combined with equal weights of 0.25:

    * phosphate: fraction of the expected phosphate atom names present,
    * sugar: fraction of the expected sugar atom names present,
    * base: best fraction over the A/G/C/T/U base atom-name templates,
    * connections: 0.5 if the P and O5' atoms lie within 2.0 of each other,
      plus 0.5 if C1' lies within 2.0 of N9 or, failing that, of N1.

    The residue only needs an ``atoms`` iterable whose items expose ``name``
    and ``coordinates``; coordinates are subtracted and passed to
    ``numpy.linalg.norm`` (presumably numpy arrays in angstroms -- TODO
    confirm). The residue must be hashable because results are memoised.

    Returns True when the weighted sum is strictly greater than 0.5.
    """
    backbone_template = {"P", "OP1", "OP2", "O3'", "O5'"}
    ribose_template = {"C1'", "C2'", "C3'", "C4'", "C5'", "O4'"}
    base_templates = {
        "A": {"N1", "C2", "N3", "C4", "C5", "C6", "N6", "N7", "C8", "N9"},
        "G": {"N1", "C2", "N2", "N3", "C4", "C5", "C6", "O6", "N7", "C8", "N9"},
        "C": {"N1", "C2", "O2", "N3", "C4", "N4", "C5", "C6"},
        "T": {"N1", "C2", "O2", "N3", "C4", "O4", "C5", "C5M", "C6"},
        "U": {"N1", "C2", "O2", "N3", "C4", "O4", "C5", "C6"},
    }
    component_weights = {
        "phosphate": 0.25,
        "sugar": 0.25,
        "base": 0.25,
        "connections": 0.25,
    }
    bond_cutoff = 2.0

    # Keep the first atom seen under each name, so that duplicated names
    # resolve exactly like a first-match scan over the atom list.
    first_by_name = {}
    for atom in nt.atoms:
        first_by_name.setdefault(atom.name, atom)
    present = set(first_by_name)

    def coverage(template):
        # Fraction of the template's atom names found in the residue.
        return len(present & template) / len(template)

    def within_cutoff(name_a, name_b):
        # True only when both atoms exist and are at most bond_cutoff apart.
        atom_a = first_by_name.get(name_a)
        atom_b = first_by_name.get(name_b)
        if atom_a is None or atom_b is None:
            return False
        separation = numpy.linalg.norm(atom_a.coordinates - atom_b.coordinates)
        return separation <= bond_cutoff

    connectivity = 0.0
    if within_cutoff("P", "O5'"):
        connectivity += 0.5
    # N9 is tried before N1; the first one close enough to C1' earns the
    # half point and the other is not examined.
    if any(within_cutoff("C1'", base_atom) for base_atom in ("N9", "N1")):
        connectivity += 0.5

    component_scores = {
        "phosphate": coverage(backbone_template),
        "sugar": coverage(ribose_template),
        "base": max(coverage(template) for template in base_templates.values()),
        "connections": connectivity,
    }

    total = sum(
        component_scores[name] * component_weights[name]
        for name in component_scores
    )
    return total > 0.5


@dataclass(order=True)
class Tetrad:
@staticmethod
Expand Down Expand Up @@ -526,7 +590,7 @@ def __find_loops(self) -> List[Loop]:
else:
nts = list(
filter(
lambda nt: nt.is_nucleotide
lambda nt: is_nucleotide(nt)
and self.global_index[nprev]
< self.global_index[nt]
< self.global_index[ncur],
Expand Down Expand Up @@ -639,7 +703,7 @@ def __str__(self):
f" {self.loop_class.value} {self.loop_class.loop_progression()}"
)
else:
builder += f" n/a"
builder += " n/a"
builder += f" quadruplex with {len(self.tetrads)} tetrads\n"
builder += str(self.tetrad_pairs[0].tetrad1)
for tetrad_pair in self.tetrad_pairs:
Expand Down Expand Up @@ -746,7 +810,6 @@ class Analysis:
sequence: str = field(init=False)
line1: str = field(init=False)
line2: str = field(init=False)
shifts: Dict[Residue3D, int] = field(init=False)

def __post_init__(self):
self.global_index = self.__prepare_global_index()
Expand All @@ -768,7 +831,6 @@ def __post_init__(self):
self.sequence,
self.line1,
self.line2,
self.shifts,
) = self.__generate_twoline_dotbracket()
self.ions = self.__find_ions()
self.__assign_ions_to_tetrads()
Expand Down Expand Up @@ -1197,64 +1259,67 @@ def __generate_twoline_dotbracket(
self,
) -> Tuple[str, str, str, Dict[Residue3D, int]]:
layer1, layer2 = [], []

for tetrad in self.tetrads:
layer1.extend([tetrad.pair_12, tetrad.pair_34])
layer2.extend([tetrad.pair_23, tetrad.pair_41])
sequence, line1, shifts = self.__elimination_conflicts(layer1)
_, line2, _ = self.__elimination_conflicts(layer2)
return sequence, line1, line2, shifts

def __elimination_conflicts(
self, pairs: List[BasePair3D]
) -> Tuple[str, str, Dict[Residue3D, int]]:
orders: Dict[BasePair3D, int] = {}
order = 0
queue = list(pairs)
removed = []

while queue:
conflicts = defaultdict(list)
for pi, pj in itertools.combinations(queue, 2):
if self.__is_conflicted(pi, pj):
conflicts[pi].append(pj)
conflicts[pj].append(pi)
if conflicts:
pair, _ = max(conflicts.items(), key=lambda x: (len(x[1]), x[0].nt1))
removed.append(pair)
queue.remove(pair)
tetrad_copy = copy.deepcopy(tetrad)
tetrad_copy.reorder_to_match_5p_3p()
score_org = BasePair3D.score_table.get(tetrad_copy.pair_12.lw, 20)
score_rev = BasePair3D.score_table.get(tetrad_copy.pair_12.lw.reverse, 20)

if score_org < score_rev:
layer1.extend([tetrad_copy.pair_12, tetrad_copy.pair_34])
layer2.extend([tetrad_copy.pair_23, tetrad_copy.pair_41])
else:
orders.update({pair: order for pair in queue})
queue, removed = removed, []
order += 1

opening = list("([{<" + string.ascii_uppercase)
closing = list(")]}>" + string.ascii_lowercase)
dotbracket: Dict[Residue3D, str] = {}
for pair, order in orders.items():
nt1, nt2 = sorted(
[pair.nt1_3d, pair.nt2_3d], key=lambda nt: self.global_index[nt]
)
dotbracket[nt1] = opening[order]
dotbracket[nt2] = closing[order]
layer1.extend([tetrad_copy.pair_23, tetrad_copy.pair_41])
layer2.extend([tetrad_copy.pair_12, tetrad_copy.pair_34])

sequence = ""
structure = ""
sequence, line1 = self.__dot_bracket(layer1)
_, line2 = self.__dot_bracket(layer2)
return sequence, line1, line2

def __dot_bracket(self, pairs: List[BasePair3D]) -> Tuple[str, str]:
    """
    Build the sequence and one dot-bracket line for the given base pairs.

    Residues accepted by ``is_nucleotide`` are numbered from 1 in
    ``self.global_index`` order, every pair in ``pairs`` is recorded on both
    of its partners, and the bracket string is taken from
    ``BpSeq(...).fcfs``. A "-" is then inserted in front of the first
    nucleotide of every chain except the first one.

    :param pairs: base pairs to encode; both partners must be nucleotides of
        ``self.structure3d``, otherwise a ``KeyError`` is raised.
    :return: ``(sequence, structure)``, both with "-" between chains.
    """
    nucleotides = sorted(
        filter(is_nucleotide, self.structure3d.residues),
        key=lambda nt: self.global_index[nt],
    )

    bpseq_index = {}
    bpseq_entries = []
    # 0-based string positions of the first nucleotide of each later chain.
    # The separator has to go in front of that nucleotide; taking the
    # position of the previous chain's last nucleotide instead puts the dash
    # one residue early (e.g. "UGGUG-UUGGUG" rather than "UGGUGU-UGGUGU").
    # NOTE: outputs recorded with the earlier placement need regenerating.
    chain_starts = set()
    chain = None

    for i, nt in enumerate(nucleotides, start=1):
        if chain and chain != nt.chain:
            chain_starts.add(i - 1)
        chain = nt.chain
        bpseq_index[nt] = i
        # Third field is the partner index; 0 presumably means "unpaired"
        # as in the BPSEQ format -- TODO confirm against rnapolis.
        bpseq_entries.append(Entry(i, nt.one_letter_name, 0))

    for pair in pairs:
        ni = bpseq_index[pair.nt1_3d]
        nj = bpseq_index[pair.nt2_3d]
        bpseq_entries[ni - 1].pair = nj
        bpseq_entries[nj - 1].pair = ni

    dot_bracket = BpSeq(bpseq_entries).fcfs

    def insert_dashes(text, positions):
        # Put a "-" in front of every character whose index is in positions.
        return "".join(
            f"-{c}" if i in positions else c for i, c in enumerate(text)
        )

    return (
        insert_dashes(dot_bracket.sequence, chain_starts),
        insert_dashes(dot_bracket.structure, chain_starts),
    )

def __str__(self):
builder = f'Chain order: {" ".join(self.__chain_order())}\n'
Expand All @@ -1265,7 +1330,7 @@ def __str__(self):

def __chain_order(self) -> List[str]:
only_nucleic_acids = filter(
lambda nt: nt.is_nucleotide, self.structure3d.residues
lambda nt: is_nucleotide(nt), self.structure3d.residues
)
return list(
{
Expand Down Expand Up @@ -1322,11 +1387,7 @@ class Visualizer:
onz_dict: Dict[BasePair3D, ONZ] = field(init=False)

def __post_init__(self):
self.onz_dict = {
pair: tetrad.onz
for tetrad in self.tetrads
for pair in [tetrad.pair_12, tetrad.pair_23, tetrad.pair_34, tetrad.pair_41]
}
self.onz_dict = {}

def visualize(self, prefix: str, suffix: str):
fasta = tempfile.NamedTemporaryFile("w+", suffix=".fasta")
Expand All @@ -1336,16 +1397,23 @@ def visualize(self, prefix: str, suffix: str):

layer1, layer2 = [], []
for tetrad in self.tetrads:
plus_ordered = self.global_index[tetrad.nt2] < self.global_index[tetrad.nt4]
plus_assigned = tetrad.onz in (ONZ.O_PLUS, ONZ.N_PLUS, ONZ.Z_PLUS)
if (plus_ordered and not plus_assigned) or (
not plus_ordered and plus_assigned
):
layer1.extend([tetrad.pair_41, tetrad.pair_23])
layer2.extend([tetrad.pair_12, tetrad.pair_34])
tetrad_copy = copy.deepcopy(tetrad)
tetrad_copy.reorder_to_match_5p_3p()
score_org = BasePair3D.score_table.get(tetrad_copy.pair_12.lw, 20)
score_rev = BasePair3D.score_table.get(tetrad_copy.pair_12.lw.reverse, 20)

if score_org < score_rev:
layer1.extend([tetrad_copy.pair_12, tetrad_copy.pair_34])
layer2.extend([tetrad_copy.pair_23, tetrad_copy.pair_41])
else:
layer1.extend([tetrad.pair_12, tetrad.pair_34])
layer2.extend([tetrad.pair_23, tetrad.pair_41])
layer1.extend([tetrad_copy.pair_23, tetrad_copy.pair_41])
layer2.extend([tetrad_copy.pair_12, tetrad_copy.pair_34])

self.onz_dict[tetrad_copy.pair_12] = tetrad.onz
self.onz_dict[tetrad_copy.pair_23] = tetrad.onz
self.onz_dict[tetrad_copy.pair_34] = tetrad.onz
self.onz_dict[tetrad_copy.pair_41] = tetrad.onz

helix1 = self.__to_helix(
layer1, self.analysis.canonical() if self.complete2d else []
)
Expand Down Expand Up @@ -1382,8 +1450,18 @@ def __to_helix(
ONZ.Z_PLUS: 5,
ONZ.Z_MINUS: 6,
}
nucleotides = self.analysis.structure3d.residues
shifts = self.analysis.shifts
nucleotides = [
nt for nt in self.analysis.structure3d.residues if is_nucleotide(nt)
]
chain = None
shifts = dict()
shift_value = 0

for nucleotide in nucleotides:
if chain and chain != nucleotide.chain:
shift_value += 1
shifts[nucleotide] = shift_value
chain = nucleotide.chain

helix = tempfile.NamedTemporaryFile("w+", suffix=".helix")
helix.write(f"#{len(self.analysis.sequence) + 1}\n")
Expand Down
4 changes: 2 additions & 2 deletions src/eltetrado/dto.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
from dataclasses import dataclass
from typing import List, Optional

from eltetrado.analysis import Analysis, Quadruplex, TetradPair
from eltetrado.analysis import Analysis, Quadruplex, TetradPair, is_nucleotide
from eltetrado.model import Ion


Expand Down Expand Up @@ -127,7 +127,7 @@ def convert_nucleotides(analysis: Analysis) -> List[NucleotideDTO]:
nt.chi_class.value if nt.chi_class else None,
)
for nt in analysis.structure3d.residues
if nt.is_nucleotide
if is_nucleotide(nt)
]


Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,6 @@ n4-helix with 4 tetrads
propeller- H.U4
propeller- F.U4

UGGUGU-UGGUGU-UGGUGU-UGGUGU-UGGUGU-UGGUGU-UGGUGU-UGGGU
.([.{<-.)].{<-.([.}>-.)].}>-.([.{<-.([.}>-.)].{<-.)]}>
.{<.([-.{<.)]-.}>.([-.}>.)]-.{<.([-.}>.([-.{<.)]-.}>)]
UGGUG-UUGGUG-UUGGUG-UUGGUG-UUGGUG-UUGGUG-UUGGUG-UUGGGU
.([.{-[.)(.{-].[).}-(.]].}-).((.[-{.{).]-(.}[.{-).)]}}
.((.[-{.{).]-(.}[.{-).)].}-}.([.{-[.)(.{-].[).}-(.]]})
Binary file added tests/files/5v3f-assembly-1.cif.gz
Binary file not shown.
1 change: 1 addition & 0 deletions tests/files/5v3f-assembly-1.json

Large diffs are not rendered by default.

Original file line number Diff line number Diff line change
Expand Up @@ -17,5 +17,5 @@ n4-helix with 2 tetrads
lateral+ A.DT24, A.DT25

GGTTGGCGCGAAGCATTCGCGGGTTGG
((..))...............((..))
((..((...............))..))
([..[)...............(]..])
((..)(...............)(..))
4 changes: 2 additions & 2 deletions tests/test_cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@ def test_eltetrado_with_dssr(capfd):
]
)
out, err = capfd.readouterr()
with open("tests/files/2awe-assembly-1.out.json") as f:
with open("tests/files/2awe-assembly-1.out") as f:
assert out == f.read()


Expand All @@ -49,5 +49,5 @@ def test_eltetrado_without_dssr(capfd):
]
)
out, err = capfd.readouterr()
with open("tests/files/6fc9-assembly-1.out.json") as f:
with open("tests/files/6fc9-assembly-1.out") as f:
assert out == f.read()
18 changes: 18 additions & 0 deletions tests/test_dto.py
Original file line number Diff line number Diff line change
Expand Up @@ -91,3 +91,21 @@ def test_2awe():
for l in q.loops
for nt in l.nucleotides
]


def test_5v3f():
    """
    5V3F contains both O+ and O- tetrads, so the two-line dot-bracket
    encoding has to reflect that difference.
    """
    # NOTE(review): the "-" position below is the recorded output. In line1
    # the bracket motif starts right at "GGG" in the first segment but one
    # character after the dash in the second, so the separator may sit one
    # residue early -- confirm against the actual chain lengths.
    expected_sequence = "GUGCGAAGGGACGGUGCGGAGAGGAGAGCA-CGGGACGGUGCGGAGAGGAG"
    expected_line1 = ".......([{..)].(.[{.).]}.}....-.([{..)].(.[{.).]}.}"
    expected_line2 = ".......([(..[{.).]}.{.)].}....-.([(..[{.).]}.{.)].}"

    structure3d = rnapolis.parser.read_3d_structure(
        handle_input_file("tests/files/5v3f-assembly-1.cif.gz"), 1
    )
    structure2d = read_secondary_structure_from_dssr(
        structure3d, 1, "tests/files/5v3f-assembly-1.json"
    )
    result = generate_dto(eltetrado(structure2d, structure3d, False, False, 2))

    assert result.dotBracket.sequence == expected_sequence
    assert result.dotBracket.line1 == expected_line1
    assert result.dotBracket.line2 == expected_line2