From 872831735ccb88292cc0f202d0d4323774922b41 Mon Sep 17 00:00:00 2001 From: Allen Goodman Date: Wed, 15 May 2024 14:35:21 -0400 Subject: [PATCH] api cleanup --- pyproject.toml | 1 + src/beignet/datasets/__uni_ref_dataset.py | 39 ++++++++--------- src/beignet/datasets/_fasta_dataset.py | 42 ++++++++++--------- .../datasets/_sized_sequence_dataset.py | 4 +- src/beignet/datasets/_uniref100_dataset.py | 6 --- src/beignet/datasets/_uniref50_dataset.py | 6 --- src/beignet/datasets/_uniref90_dataset.py | 1 - 7 files changed, 44 insertions(+), 55 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index 169c88d1d2..16bb1154bc 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -11,6 +11,7 @@ authors = [{ email = "allen.goodman@icloud.com", name = "Allen Goodman" }] dependencies = [ "pooch", "torch", + "tqdm", ] dynamic = ["version"] license = { file = "LICENSE" } diff --git a/src/beignet/datasets/__uni_ref_dataset.py b/src/beignet/datasets/__uni_ref_dataset.py index f59c1f8756..f7e1e45e71 100644 --- a/src/beignet/datasets/__uni_ref_dataset.py +++ b/src/beignet/datasets/__uni_ref_dataset.py @@ -18,7 +18,6 @@ def __init__( root: str | PathLike | None = None, known_hash: str | None = None, *, - index: bool = True, transform: Callable | Transform | None = None, target_transform: Callable | Transform | None = None, ) -> None: @@ -34,10 +33,6 @@ def __init__( `download` is `True`, the directory where the dataset subdirectory will be created and the dataset downloaded. - index : bool, optional - If `True`, caches the sequence indexes to disk for faster - re-initialization (default: `True`). - transform : Callable | Transform, optional A `Callable` or `Transform` that that maps a sequence to a transformed sequence (default: `None`). @@ -56,32 +51,34 @@ def __init__( name = self.__class__.__name__.replace("Dataset", "") - path = pooch.retrieve( - url, - known_hash, - f"{name}.fasta.gz", - root / name, - processor=Decompress(), - progressbar=True, - ) - self._pattern = re.compile(r"^UniRef.+_([A-Z0-9]+)\s.+$") - super().__init__(path, index=index) + super().__init__( + pooch.retrieve( + url, + known_hash, + f"{name}.fasta.gz", + root / name, + processor=Decompress( + name=f"{name}.fasta", + ), + progressbar=True, + ), + ) - self._transform = transform + self.transform = transform - self._target_transform = target_transform + self.target_transform = target_transform def __getitem__(self, index: int) -> (str, str): target, sequence = self.get(index) (target,) = re.search(self._pattern, target).groups() - if self._transform: - sequence = self._transform(sequence) + if self.transform: + sequence = self.transform(sequence) - if self._target_transform: - target = self._target_transform(target) + if self.target_transform: + target = self.target_transform(target) return sequence, target diff --git a/src/beignet/datasets/_fasta_dataset.py b/src/beignet/datasets/_fasta_dataset.py index 934337da60..e3d93ebe6f 100644 --- a/src/beignet/datasets/_fasta_dataset.py +++ b/src/beignet/datasets/_fasta_dataset.py @@ -1,4 +1,5 @@ import subprocess +from os import PathLike from pathlib import Path from typing import Callable, Tuple, TypeVar @@ -6,6 +7,7 @@ from beignet.io import ThreadSafeFile +from ..transforms import Transform from ._sized_sequence_dataset import SizedSequenceDataset T = TypeVar("T") @@ -14,39 +16,41 @@ class FASTADataset(SizedSequenceDataset): def __init__( self, - root: str | Path, + root: str | PathLike, *, - index: bool = True, - transform: Callable[[T], T] | None = None, + transform: Callable | Transform | None = None, ) -> None: - self.root = Path(root) + if isinstance(root, str): + self.root = Path(root) + + self.root = self.root.resolve() if not self.root.exists(): raise FileNotFoundError - self._thread_safe_file = ThreadSafeFile(root, open) - - self._index = Path(f"{self.root}.index.npy") + self.data = ThreadSafeFile(self.root, open) - if index: - if self._index.exists(): - self.offsets, sizes = numpy.load(str(self._index)) - else: - self.offsets, sizes = self._build_index() + offsets = Path(f"{self.root}.index.npy") - numpy.save(str(self._index), numpy.stack([self.offsets, sizes])) + if offsets.exists(): + self.offsets, sizes = numpy.load(f"{offsets}") else: self.offsets, sizes = self._build_index() - self._transform = transform + numpy.save( + f"{offsets}", + numpy.stack([self.offsets, sizes]), + ) + + self.transform = transform super().__init__(self.root, sizes) def __getitem__(self, index: int) -> Tuple[str, str]: x = self.get(index) - if self._transform: - x = self._transform(x) + if self.transform: + x = self.transform(x) return x @@ -54,12 +58,12 @@ def __len__(self) -> int: return self.offsets.size def get(self, index: int) -> str: - self._thread_safe_file.seek(self.offsets[index]) + self.data.seek(self.offsets[index]) if index == len(self) - 1: - data = self._thread_safe_file.read() + data = self.data.read() else: - data = self._thread_safe_file.read( + data = self.data.read( self.offsets[index + 1] - self.offsets[index], ) diff --git a/src/beignet/datasets/_sized_sequence_dataset.py b/src/beignet/datasets/_sized_sequence_dataset.py index cb11dc0bff..9145136872 100644 --- a/src/beignet/datasets/_sized_sequence_dataset.py +++ b/src/beignet/datasets/_sized_sequence_dataset.py @@ -1,4 +1,4 @@ -from pathlib import Path +from os import PathLike import numpy @@ -8,7 +8,7 @@ class SizedSequenceDataset(SequenceDataset): def __init__( self, - root: str | Path, + root: str | PathLike, sizes: numpy.ndarray, *args, **kwargs, diff --git a/src/beignet/datasets/_uniref100_dataset.py b/src/beignet/datasets/_uniref100_dataset.py index 0e7e3b7c0e..3652ffa64f 100644 --- a/src/beignet/datasets/_uniref100_dataset.py +++ b/src/beignet/datasets/_uniref100_dataset.py @@ -10,7 +10,6 @@ def __init__( self, root: str | Path, *, - index: bool = True, transform: Callable | Transform | None = None, target_transform: Callable | Transform | None = None, ) -> None: @@ -22,10 +21,6 @@ def __init__( `download` is `True`, the directory where the dataset subdirectory will be created and the dataset downloaded. - index : bool, optional - If `True`, caches the sequence indicies to disk for faster - re-initialization (default: `True`). - transform : Callable, optional A `Callable` or `Transform` that that maps a sequence to a transformed sequence (default: `None`). @@ -38,7 +33,6 @@ def __init__( "http://ftp.uniprot.org/pub/databases/uniprot/uniref/uniref100/uniref100.fasta.gz", root, "md5:0354240a56f4ca91ff426f8241cfeb7d", - index=index, transform=transform, target_transform=target_transform, ) diff --git a/src/beignet/datasets/_uniref50_dataset.py b/src/beignet/datasets/_uniref50_dataset.py index 1ff14948f4..22a4627ec7 100644 --- a/src/beignet/datasets/_uniref50_dataset.py +++ b/src/beignet/datasets/_uniref50_dataset.py @@ -11,7 +11,6 @@ def __init__( self, root: str | PathLike | None = None, *, - index: bool = True, transform: Callable | Transform | None = None, target_transform: Callable | Transform | None = None, ) -> None: @@ -23,10 +22,6 @@ def __init__( `download` is `True`, the directory where the dataset subdirectory will be created and the dataset downloaded. - index : bool, optional - If `True`, caches the sequence indexes to disk for faster - re-initialization (default: `True`). - transform : Callable, optional A `Callable` or `Transform` that that maps a sequence to a transformed sequence (default: `None`). @@ -39,7 +34,6 @@ def __init__( "http://ftp.uniprot.org/pub/databases/uniprot/uniref/uniref50/uniref50.fasta.gz", root, "md5:e638c63230d13ad5e2098115b9cb5d8f", - index=index, transform=transform, target_transform=target_transform, ) diff --git a/src/beignet/datasets/_uniref90_dataset.py b/src/beignet/datasets/_uniref90_dataset.py index 7b400b7ee6..e9be653c63 100644 --- a/src/beignet/datasets/_uniref90_dataset.py +++ b/src/beignet/datasets/_uniref90_dataset.py @@ -39,7 +39,6 @@ def __init__( "http://ftp.uniprot.org/pub/databases/uniprot/uniref/uniref90/uniref90.fasta.gz", root, "md5:6161bad4d7506365aee882fd5ff9c833", - index=index, transform=transform, target_transform=target_transform, )