Skip to content

Commit

Permalink
add types
Browse files Browse the repository at this point in the history
  • Loading branch information
jsstevenson committed May 27, 2024
1 parent 37c7dfe commit e77f3a5
Show file tree
Hide file tree
Showing 12 changed files with 251 additions and 161 deletions.
13 changes: 13 additions & 0 deletions docs/sequence_loading.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
Storing New Sequences and Aliases
!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!


.. code-block:: python

    sequence = metadata.target_sequence
    # Add custom digest to SeqRepo for both Protein and DNA Sequence
    psequence_id = f"SQ.{sha512t24u(sequence.encode('ascii'))}"
    alias_dict_list = [{"namespace": "ga4gh", "alias": psequence_id}]
    sr.sr.store(sequence, nsaliases=alias_dict_list)
11 changes: 7 additions & 4 deletions src/biocommons/seqrepo/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,11 @@

from .seqrepo import SeqRepo # noqa: F401

if __package__ is None:
    # Module is being executed outside of any package (e.g. the file was run
    # directly), so there is no distribution name to look up metadata for.
    __version__ = None
else:
    try:
        __version__ = version(__package__)
    except PackageNotFoundError:  # pragma: no cover
        # package is not installed
        __version__ = None
11 changes: 8 additions & 3 deletions src/biocommons/seqrepo/_internal/logging_support.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,18 @@
import logging
from typing import Optional


class DuplicateFilter:
"""
Filters away duplicate log messages.
Modified from https://stackoverflow.com/a/60462619/342839
"""

def __init__(self, logger: Optional[logging.Logger] = None) -> None:
    """Create the filter, optionally bound to ``logger``.

    :param logger: logger the filter attaches to / detaches from when used
        as a context manager; may be None.
    """
    # (logger name, line number, message) keys of records already seen.
    self.log_keys = set()
    self.logger = logger

def filter(self, record):
def filter(self, record: logging.LogRecord) -> bool:
log_key = (record.name, record.lineno, str(record.msg))
is_duplicate = log_key in self.log_keys
if not is_duplicate:
Expand All @@ -23,4 +27,5 @@ def __enter__(self):
self.logger.addFilter(self)

def __exit__(self, exc_type, exc_val, exc_tb):
    """Detach this filter from the logger on context exit.

    ``self.logger`` may be None (the filter can be constructed without a
    logger), so guard before calling ``removeFilter``.
    """
    if self.logger is not None:
        self.logger.removeFilter(self)
9 changes: 5 additions & 4 deletions src/biocommons/seqrepo/_internal/translate.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,9 +15,10 @@

import copy
import datetime
from typing import Iterable, Iterator, Optional


def translate_db2api(namespace, alias):
def translate_db2api(namespace: str, alias: str) -> list[tuple[str, Optional[str]]]:
"""
>>> translate_db2api("VMC", "GS_1234")
[('sha512t24u', '1234'), ('ga4gh', 'SQ.1234')]
Expand All @@ -38,7 +39,7 @@ def translate_db2api(namespace, alias):
return []


def translate_api2db(namespace, alias):
def translate_api2db(namespace: str, alias: Optional[str]) -> list[tuple[str, Optional[str]]]:
"""
>>> translate_api2db("ga4gh", "SQ.1234")
[('VMC', 'GS_1234')]
Expand All @@ -55,14 +56,14 @@ def translate_api2db(namespace, alias):
return [
("VMC", "GS_" + alias if alias else None),
]
if namespace == "ga4gh":
if namespace == "ga4gh" and alias is not None:
return [
("VMC", "GS_" + alias[3:]),
]
return []


def translate_alias_records(aliases_itr):
def translate_alias_records(aliases_itr: Iterable[dict]) -> Iterator[dict]:
"""given an iterator of find_aliases results, return a stream with
translated records"""

Expand Down
69 changes: 39 additions & 30 deletions src/biocommons/seqrepo/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@
import subprocess
import sys
import tempfile
from typing import Iterable, Iterator, Optional

import bioutils.assemblies
import bioutils.seqfetcher
Expand All @@ -50,7 +51,7 @@
_logger = logging.getLogger(__name__)


def _get_remote_instances(opts):
def _get_remote_instances(opts: argparse.Namespace) -> list[str]:
line_re = re.compile(r"d[-rwx]{9}\s+[\d,]+ \d{4}/\d{2}/\d{2} \d{2}:\d{2}:\d{2} (.+)")
rsync_cmd = [
opts.rsync_exe,
Expand All @@ -64,21 +65,21 @@ def _get_remote_instances(opts):
return sorted(list(filter(instance_name_new_re.match, dirs)))


def _get_local_instances(opts: argparse.Namespace) -> list[str]:
    """Return sorted names of seqrepo instance directories under ``opts.root_directory``.

    An entry counts as an instance when its name matches ``instance_name_re``.
    """
    return sorted(list(filter(instance_name_re.match, os.listdir(opts.root_directory))))


def _latest_instance(opts: argparse.Namespace) -> Optional[str]:
    """Return the name of the latest local instance, or None when none exist.

    The last name in sorted order is taken as latest (instance names are
    assumed to sort oldest-to-newest — e.g. date-stamped names).
    """
    instances = _get_local_instances(opts)
    if not instances:
        return None
    return instances[-1]


def _latest_instance_path(opts: argparse.Namespace) -> Optional[str]:
    """Return the path of the latest local instance, or None when none exist."""
    latest = _latest_instance(opts)
    if latest is None:
        return None
    return os.path.join(opts.root_directory, latest)


def parse_arguments():
def parse_arguments() -> argparse.Namespace:
epilog = (
f"seqrepo {__version__}"
"See https://github.com/biocommons/biocommons.seqrepo for more information"
Expand Down Expand Up @@ -297,7 +298,7 @@ def parse_arguments():
############################################################################


def add_assembly_names(opts):
def add_assembly_names(opts: argparse.Namespace) -> None:
"""add assembly names as aliases to existing sequences
Specifically, associate aliases like GRCh37.p9:1 with existing
Expand Down Expand Up @@ -382,7 +383,7 @@ def add_assembly_names(opts):
sr.commit()


def export(opts):
def export(opts: argparse.Namespace) -> None:
seqrepo_dir = os.path.join(opts.root_directory, opts.instance_name)
sr = SeqRepo(seqrepo_dir)

Expand All @@ -396,7 +397,7 @@ def alias_generator():
translate_ncbi_namespace=True,
)

def _rec_iterator():
def _rec_iterator_aliases():
"""yield (srec, [arec]) tuples to export"""
grouped_alias_iterator = itertools.groupby(
alias_generator(), key=lambda arec: (arec["seq_id"])
Expand All @@ -406,9 +407,11 @@ def _rec_iterator():
srec["seq"] = sr.sequences.fetch(seq_id)
yield srec, arecs

_rec_iterator = _rec_iterator_aliases

elif opts.namespace:

def _rec_iterator():
def _rec_iterator_namespace():
"""yield (srec, [arec]) tuples to export"""
alias_iterator = sr.aliases.find_aliases(
namespace=opts.namespace, translate_ncbi_namespace=True
Expand All @@ -421,11 +424,15 @@ def _rec_iterator():
srec["seq"] = sr.sequences.fetch(seq_id)
yield srec, arecs

_rec_iterator = _rec_iterator_namespace

else:

def _rec_iterator():
def _rec_iterator_sr():
yield from sr

_rec_iterator = _rec_iterator_sr

for srec, arecs in _rec_iterator():
nsad = _convert_alias_records_to_ns_dict(arecs)
aliases = [
Expand Down Expand Up @@ -453,7 +460,7 @@ def export_aliases(opts):
print("\t".join(nsaliases))


def fetch_load(opts):
def fetch_load(opts: argparse.Namespace) -> None:
disable_bar = _logger.getEffectiveLevel() < logging.WARNING

seqrepo_dir = os.path.join(opts.root_directory, opts.instance_name)
Expand All @@ -463,7 +470,7 @@ def fetch_load(opts):
for ac in ac_bar:
ac_bar.set_description(ac)
aliases_cur = sr.aliases.find_aliases(namespace=opts.namespace, alias=ac)
if aliases_cur.fetchone() is not None:
if aliases_cur:
_logger.info("{ac} already in {sr}".format(ac=ac, sr=sr))
continue
seq = bioutils.seqfetcher.fetch_seq(ac)
Expand All @@ -473,28 +480,28 @@ def fetch_load(opts):
sr.commit()


def init(opts: argparse.Namespace) -> None:
    """Initialize a new seqrepo instance directory.

    :raises IOError: if the target directory already exists and is non-empty.
    """
    seqrepo_dir = os.path.join(opts.root_directory, opts.instance_name)
    if os.path.exists(seqrepo_dir) and len(os.listdir(seqrepo_dir)) > 0:
        raise IOError("{seqrepo_dir} exists and is not empty".format(seqrepo_dir=seqrepo_dir))
    # Constructing SeqRepo with writeable=True creates the instance layout.
    sr = SeqRepo(seqrepo_dir, writeable=True)  # noqa: F841


def list_local_instances(opts: argparse.Namespace) -> None:
    """Print the names of local seqrepo instances, one per indented line."""
    instances = _get_local_instances(opts)
    print("Local instances ({})".format(opts.root_directory))
    for i in instances:
        print(" " + i)


def list_remote_instances(opts: argparse.Namespace) -> None:
    """Print the names of seqrepo instances on the remote host, one per indented line."""
    instances = _get_remote_instances(opts)
    print("Remote instances ({})".format(opts.remote_host))
    for i in instances:
        print(" " + i)


def load(opts):
def load(opts: argparse.Namespace) -> None:
# TODO: drop this test
if opts.namespace == "-":
raise RuntimeError("namespace == '-' is no longer supported")
Expand All @@ -519,8 +526,10 @@ def load(opts):
else:
fh = io.open(fn, mode="rt", encoding="ascii")
_logger.info("Opened " + fn)
seq_bar = tqdm.tqdm(FastaIter(fh), unit=" seqs", disable=disable_bar, leave=False)
for defline, seq in seq_bar:
seq_bar = tqdm.tqdm(
FastaIter(fh), unit=" seqs", disable=disable_bar, leave=False # type: ignore noqa: E501
)
for defline, seq in seq_bar: # type: ignore
n_seqs_seen += 1
seq_bar.set_description(
"sequences: {nsa}/{nss} added/seen; aliases: {naa} added".format(
Expand All @@ -535,7 +544,7 @@ def load(opts):
sr.commit()


def pull(opts):
def pull(opts: argparse.Namespace) -> None:
remote_instances = _get_remote_instances(opts)
if opts.instance_name:
instance_name = opts.instance_name
Expand Down Expand Up @@ -569,7 +578,7 @@ def pull(opts):
update_latest(opts, instance_name)


def show_status(opts):
def show_status(opts: argparse.Namespace) -> SeqRepo:
seqrepo_dir = os.path.join(opts.root_directory, opts.instance_name)
tot_size = sum(
os.path.getsize(os.path.join(dirpath, filename))
Expand All @@ -596,7 +605,7 @@ def show_status(opts):
return sr


def snapshot(opts):
def snapshot(opts: argparse.Namespace) -> None:
"""snapshot a seqrepo data directory by hardlinking sequence files,
copying sqlite databases, and remove write permissions from directories
Expand Down Expand Up @@ -675,7 +684,7 @@ def _drop_write(p):
os.chdir(wd)


def start_shell(opts):
def start_shell(opts: argparse.Namespace) -> None:
seqrepo_dir = os.path.join(opts.root_directory, opts.instance_name)
sr = SeqRepo(seqrepo_dir) # noqa: 682
import IPython
Expand All @@ -691,20 +700,20 @@ def start_shell(opts):
)


def upgrade(opts):
def upgrade(opts: argparse.Namespace) -> None:
seqrepo_dir = os.path.join(opts.root_directory, opts.instance_name)
sr = SeqRepo(seqrepo_dir, writeable=True)
print("upgraded to schema version {}".format(sr.seqinfo.schema_version()))
sr = SeqRepo(seqrepo_dir, writeable=False)
print("upgraded to schema version {}".format(sr.sequences.schema_version()))


def update_digests(opts: argparse.Namespace) -> None:
    """Recompute digest-based aliases for every sequence in the instance.

    Opens the instance writeable and calls the repo's own
    ``_update_digest_aliases`` for each stored sequence record.
    """
    seqrepo_dir = os.path.join(opts.root_directory, opts.instance_name)
    sr = SeqRepo(seqrepo_dir, writeable=True)
    # tqdm wraps the sequence iterator to display progress.
    for srec in tqdm.tqdm(sr.sequences):
        sr._update_digest_aliases(srec["seq_id"], srec["seq"])


def update_latest(opts, mri=None):
def update_latest(opts: argparse.Namespace, mri: Optional[str] = None) -> None:
if not mri:
instances = _get_local_instances(opts)
if not instances:
Expand All @@ -720,7 +729,7 @@ def update_latest(opts, mri=None):
_logger.info("Linked `latest` -> `{}`".format(mri))


def main():
def main() -> None:
opts = parse_arguments()

verbose_log_level = (
Expand All @@ -734,7 +743,7 @@ def main():
# INTERNAL


def _convert_alias_records_to_ns_dict(records):
def _convert_alias_records_to_ns_dict(records: Iterable[dict]) -> dict:
"""converts a set of alias db records to a dict like {ns: [aliases], ...}
aliases are lexicographicaly sorted
"""
Expand All @@ -745,7 +754,7 @@ def _convert_alias_records_to_ns_dict(records):
}


def _wrap_lines(seq, line_width):
def _wrap_lines(seq: str, line_width: int) -> Iterator[str]:
for i in range(0, len(seq), line_width):
yield seq[i : i + line_width]

Expand Down
Loading

0 comments on commit e77f3a5

Please sign in to comment.