Skip to content

Commit

Permalink
Merge pull request #27 from JianjunJin/new_search_likelihood
Browse files Browse the repository at this point in the history
estimate_min_align_counts: depth_factor updated
  • Loading branch information
JianjunJin authored May 3, 2024
2 parents 68cb204 + a23ea82 commit 4be0af9
Show file tree
Hide file tree
Showing 3 changed files with 35 additions and 8 deletions.
9 changes: 8 additions & 1 deletion traversome/GraphAlignConflicts.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,14 @@ def __init__(
self.n_balls = None
self.max_load = None

def run(self):
def detect(self):
"""
Executes the algorithm to find conflicts in the graph and determine the maximum load.
Returns:
conflict_n (list): A list of vertex numbers with conflicts.
max_loads (list): A list of maximum loads for each vertex with conflicts.
"""
v_window_conflicts = self._find_vertex_window_wise_conflicts()
all_conflicts = []
for conflicts in v_window_conflicts.values():
Expand Down
2 changes: 1 addition & 1 deletion traversome/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
"""


__version__ = "0.1.4"
__version__ = "0.1.6"
__author__ = "JianJun Jin and Deren Eaton"


Expand Down
32 changes: 26 additions & 6 deletions traversome/traversome.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@
from typing import Set, Union
from multiprocessing import Manager, Pool
import gc
import math
import numpy as np
# import time

Expand Down Expand Up @@ -221,6 +222,9 @@ def run(self):
min_identity=0.8 if min_alignment_identity_cutoff == "auto" else min_alignment_identity_cutoff,
build_records=False)
self.filter_alignment(alignment, min_alignment_len_cutoff, min_alignment_identity_cutoff)
if not alignment.raw_records:
logger.error("Insufficient alignment records remains after filtering!")
raise SystemExit(0)
if self.min_alignment_counts == "auto":
min_alignment_counts = self.estimate_min_align_counts(alignment=alignment)
logger.info(f"Setting minimum alignment counts to {min_alignment_counts}")
Expand Down Expand Up @@ -448,9 +452,23 @@ def filter_alignment(self, alignment, min_alignment_len_cutoff, min_alignment_id

def estimate_min_align_counts(self, alignment):
    """Empirically estimate the minimum number of alignments required.

    The estimate is the number of aligned bases per graph base, scaled down
    by an empirical ``depth_factor`` derived from the length-weighted mean
    identity and the mean length of the alignment records, and never lower
    than 3.

    Args:
        alignment: object exposing ``raw_records``; each record provides
            ``p_align_len`` and ``identity``.

    Returns:
        int: the estimated minimum alignment count (at least 3).
    """
    floor = 3
    records = alignment.raw_records
    if not records:
        # Nothing to average over; the weighted mean below would raise
        # ZeroDivisionError ("weights sum to zero").
        return floor

    lengths = [r.p_align_len for r in records]
    identities = [r.identity for r in records]
    mean_len = np.average(lengths)
    # Weight by alignment length so that long alignments dominate the mean.
    mean_identity = np.average(identities, weights=lengths)
    graph_len = sum(self.graph.vertex_info[v_].len for v_ in self.graph.vertex_info)
    valid_bases = sum(lengths)

    if mean_identity >= 1:
        # -log(1 - identity) diverges for perfect identity, which drives the
        # estimate down to the floor; return it directly instead of taking
        # log(0) and relying on inf arithmetic.
        return floor

    # Length negatively correlates with the total number of reads and,
    # therefore, with the expected number of correct same-path reads.
    # -np.log10(1 - mean_identity) * 10 is the phred quality score; the
    # natural-log form below is the empirically tuned variant.
    depth_factor = -np.log(1 - mean_identity) * (3 + mean_len / 16000)
    return max(floor, math.ceil(valid_bases / (graph_len * depth_factor) + 1))

def do_subsampling(self):
"""
Expand Down Expand Up @@ -881,16 +899,18 @@ def detect_alignment_abnormals(self, alignment):
"""
logger.info("Detecting abnormal vertices in the graph alignment")
detect_conflict = GraphAlignConflicts(self.graph, alignment)
abnormal_vertices, max_loads = detect_conflict.run()
abnormal_vertices, max_loads = detect_conflict.detect()
if abnormal_vertices:
logger.info(f"Total number of bins: {detect_conflict.n_bins}")
logger.info(f"Total number of balls: {detect_conflict.n_balls}")
logger.info(f"Maximum load per bin: {detect_conflict.max_load}")
logger.info(f"Total number of windows: {detect_conflict.n_bins}")
logger.info(f"Total number of conflicts: {detect_conflict.n_balls}")
logger.info(f"Maximum conflicts per window: {detect_conflict.max_load}")
if len(abnormal_vertices) > 20:
info_line = f"Detected abnormal vertices (max=[{min(max_loads)}, {max(max_loads)}]): {', '.join([f'{v}(max={ld})' for v, ld in zip(abnormal_vertices[:20], max_loads[:20])])} ..."
else:
info_line = f"Detected abnormal vertices (max=[{min(max_loads)}, {max(max_loads)}]): {', '.join([f'{v}(max={ld})' for v, ld in zip(abnormal_vertices, max_loads)])}"
if self.kwargs.get("ignore_conflicts", False):
if max(max_loads) == 1:
pass
elif self.kwargs.get("ignore_conflicts", False):
logger.warning(info_line)
logger.warning(f"Ignoring conflicts and continue...")
else:
Expand Down

0 comments on commit 4be0af9

Please sign in to comment.