# Post-patch contents of multimolecule/data/functional.py, reconstructed from
# commit c4c3a7c4945af248f712f6429cacdec97c58bc82
# ("fix pseudo-knots process in secondary structure", Zhiyuan Chen, 2024-11-22).

import string
from collections import defaultdict

import numpy as np

# Bracket families: opening symbol -> closing symbol.
dot_bracket_pair_table = {"(": ")", "[": "]", "{": "}", "<": ">"}

# Letter families: an uppercase letter opens a pair, the matching lowercase letter closes it.
_LETTER_PAIRS = dict(zip(string.ascii_uppercase, string.ascii_lowercase))

# Symbols that denote an unpaired position.
_UNPAIRED_SYMBOLS = frozenset({".", ",", "_"})


class _UnmatchedClosingSymbolError(ValueError, IndexError):
    """A closing symbol has no pending opening symbol of the same family.

    It is a ``ValueError`` because the input string is malformed, and also an
    ``IndexError`` because that is what the bare ``list.pop()`` raised here
    before, so existing ``except IndexError`` handlers keep working.
    """


def dot_bracket_to_contact_map(dot_bracket: str) -> np.ndarray:
    """Convert a dot-bracket string into a symmetric contact map.

    Every bracket family (``()``, ``[]``, ``{}``, ``<>`` and the letter pairs
    ``Aa`` ... ``Zz``) is matched on its own stack, so pairs of different
    families may cross each other (pseudo-knots).

    Args:
        dot_bracket: The structure string. ``.``, ``,`` and ``_`` mark
            unpaired positions.

    Returns:
        An ``(n, n)`` integer array, ``n = len(dot_bracket)``, where entries
        ``[i, j]`` and ``[j, i]`` are 1 when positions ``i`` and ``j`` are paired
        and 0 otherwise.

    Raises:
        ValueError: If the string contains a symbol that is neither a pairing
            symbol nor an unpaired marker, or a closing symbol without a
            matching opening symbol (the latter is also an ``IndexError``).

    Note:
        Opening symbols that are never closed are not reported; their
        positions simply stay unpaired in the result.
    """
    n = len(dot_bracket)
    contact_map = np.zeros((n, n), dtype=int)

    # Build the lookup tables locally instead of updating the module-level
    # table in place, so calling this function has no side effects.
    open_to_close = {**dot_bracket_pair_table, **_LETTER_PAIRS}
    close_to_open = {close: open_ for open_, close in open_to_close.items()}

    # One stack of pending opening positions per family, keyed by opening symbol.
    stacks: defaultdict[str, list[int]] = defaultdict(list)
    for i, symbol in enumerate(dot_bracket):
        if symbol in open_to_close:
            stacks[symbol].append(i)
        elif symbol in close_to_open:
            pending = stacks[close_to_open[symbol]]
            if not pending:
                raise _UnmatchedClosingSymbolError(
                    f"Unmatched closing symbol {symbol} at position {i} in dot-bracket notation"
                )
            j = pending.pop()
            contact_map[i, j] = contact_map[j, i] = 1
        elif symbol not in _UNPAIRED_SYMBOLS:
            raise ValueError(f"Invalid symbol {symbol} in dot-bracket notation")
    return contact_map