From 35733a4c73c3d22b4f4249fd64f0e0398006e82e Mon Sep 17 00:00:00 2001 From: Zhiyuan Chen Date: Sun, 27 Oct 2024 21:07:13 +0800 Subject: [PATCH] add Rivas dataset --- docs/docs/datasets/rivas.md | 9 ++ docs/mkdocs.yml | 1 + multimolecule/datasets/bprna_new/bprna_new.py | 48 ++++++--- multimolecule/datasets/rivas/README.md | 101 ++++++++++++++++++ multimolecule/datasets/rivas/rivas.py | 60 +++++++++++ 5 files changed, 206 insertions(+), 13 deletions(-) create mode 100644 docs/docs/datasets/rivas.md create mode 100644 multimolecule/datasets/rivas/README.md create mode 100644 multimolecule/datasets/rivas/rivas.py diff --git a/docs/docs/datasets/rivas.md b/docs/docs/datasets/rivas.md new file mode 100644 index 00000000..b20a7e25 --- /dev/null +++ b/docs/docs/datasets/rivas.md @@ -0,0 +1,9 @@ +--- +authors: + - Zhiyuan Chen +date: 2024-05-04 +--- + +# RIVAS + +--8<-- "multimolecule/datasets/rivas/README.md:21:" diff --git a/docs/mkdocs.yml b/docs/mkdocs.yml index 9c8d52de..93d53a45 100644 --- a/docs/mkdocs.yml +++ b/docs/mkdocs.yml @@ -19,6 +19,7 @@ nav: - RNA: - RNAcentral: datasets/rnacentral.md - Rfam: datasets/rfam.md + - RIVAS: datasets/rivas.md - bpRNA-1m: datasets/bprna.md - bpRNA-spot: datasets/bprna-spot.md - bpRNA-new: datasets/bprna-new.md diff --git a/multimolecule/datasets/bprna_new/bprna_new.py b/multimolecule/datasets/bprna_new/bprna_new.py index d622f9a8..8a20e6bb 100644 --- a/multimolecule/datasets/bprna_new/bprna_new.py +++ b/multimolecule/datasets/bprna_new/bprna_new.py @@ -18,6 +18,7 @@ import os from collections import namedtuple +from collections.abc import Mapping from pathlib import Path import torch @@ -30,20 +31,41 @@ RNA_SS_data = namedtuple("RNA_SS_data", "seq ss_label length name pairs") -def convert_bpseq(bpseq): - if isinstance(bpseq, str): - bpseq = Path(bpseq) - with open(bpseq) as f: +def convert_bpseq(file) -> Mapping: + if not isinstance(file, Path): + file = Path(file) + with open(file) as f: lines = f.read().splitlines() - lines = [[int(i) if i.isdigit() else i for i in j.split()] for j in lines] - sequence, structure = [], ["."] * len(lines) - for row in lines: - index, nucleotide, paired_index = row - sequence.append(nucleotide) - if paired_index > 0 and index < paired_index: - structure[index - 1] = "(" - structure[paired_index - 1] = ")" - return {"id": bpseq.stem.split("-")[0], "sequence": "".join(sequence), "secondary_structure": "".join(structure)} + + num_bases = len(lines) + sequence = [] + dot_bracket = ["."] * num_bases + pairs = [-1] * num_bases + + for line in lines: + parts = line.strip().split() + index = int(parts[0]) - 1 + base = parts[1] + paired_index = int(parts[2]) - 1 + + sequence.append(base) + + if paired_index >= 0: + if paired_index > index: + dot_bracket[index] = "(" + dot_bracket[paired_index] = ")" + elif pairs[paired_index] != index: + raise ValueError( + f"Inconsistent pairing: Base {index} is paired with {paired_index}, " + f"but {paired_index} is not paired with {index}." + ) + pairs[index] = paired_index + + return { + "id": file.stem.split("-")[0], + "sequence": "".join(sequence), + "secondary_structure": "".join(dot_bracket), + } def convert_dataset(convert_config): diff --git a/multimolecule/datasets/rivas/README.md b/multimolecule/datasets/rivas/README.md new file mode 100644 index 00000000..cfc93c8f --- /dev/null +++ b/multimolecule/datasets/rivas/README.md @@ -0,0 +1,101 @@ +--- +language: rna +tags: + - Biology + - RNA +license: + - agpl-3.0 +size_categories: + - 1K. + +from __future__ import annotations + +import os + +import torch +from tqdm import tqdm + +from multimolecule.datasets.bprna_new.bprna_new import convert_bpseq +from multimolecule.datasets.conversion_utils import ConvertConfig as ConvertConfig_ +from multimolecule.datasets.conversion_utils import get_files, save_dataset + +torch.manual_seed(1016) + + +def _convert_dataset(root): + files = get_files(root) + return [convert_bpseq(file) for file in tqdm(files, total=len(files))] + + +def convert_dataset(convert_config): + root = convert_config.dataset_path + train_a = _convert_dataset(os.path.join(root, "TrainSetA")) + train_b = _convert_dataset(os.path.join(root, "TrainSetB")) + test_a = _convert_dataset(os.path.join(root, "TestSetA")) + test_b = _convert_dataset(os.path.join(root, "TestSetB")) + output_path, repo_id = convert_config.output_path, convert_config.repo_id + save_dataset(convert_config, {"train": train_a, "validation": test_a, "test": test_b}) + convert_config.output_path = output_path + "-a" + convert_config.repo_id = repo_id + "-a" + save_dataset(convert_config, {"train": train_a, "test": test_a}) + convert_config.output_path = output_path + "-b" + convert_config.repo_id = repo_id + "-b" + save_dataset(convert_config, {"train": train_b, "test": test_b}) + + +class ConvertConfig(ConvertConfig_): + root: str = os.path.dirname(__file__) + output_path: str = os.path.basename(os.path.dirname(__file__)) + + +if __name__ == "__main__": + config = ConvertConfig() + config.parse() # type: ignore[attr-defined] + convert_dataset(config)