step11 + stats + threading
naumenko-sa committed Jan 31, 2022
1 parent ecb5233 commit 31405e7
Showing 1 changed file with 83 additions and 20 deletions.
103 changes: 83 additions & 20 deletions hivdrm.py
@@ -23,6 +23,11 @@
from Bio import AlignIO
from Bio.Align import AlignInfo

import threading, queue

queue = queue.Queue()
mutex = threading.Lock()

hivdrm_work_dir = "_hivdrm_tmp"
s5_prefix = "s5_demultiplex"
s7_prefix = "s7_blast_result"
@@ -126,8 +131,12 @@ def s4_trim_left_right(file_fa, bp_left = 4, bp_right = 4):
def s5_demultiplex_samples(file_fa, barcodes_csv):
""" demultiplex samples based on dual barcodes """
print("Step5 - demultiplex samples")
# index sample name by barcode
samples_barcode = {}
result_file = "step5.stats.csv"
# demultiplex stats
demultiplex_stats = {}
demultiplex_stats["unmatched"] = 0
result_file = os.path.join(hivdrm_work_dir, s5_prefix, "step5.demultiplex_stats.tsv")
control_file = os.path.join(hivdrm_work_dir, "step5.done")

# still populate samples even if the demultiplexing work is already done
@@ -155,6 +164,7 @@ def s5_demultiplex_samples(file_fa, barcodes_csv):
combined_barcode = f_barcode + s.reverse_complement()
samples_barcode[combined_barcode] = sample_name
samples.append(sample_name)
demultiplex_stats[sample_name] = 0

os.makedirs(os.path.join(hivdrm_work_dir, s5_prefix), exist_ok = True)

@@ -168,13 +178,24 @@
for record in SeqIO.parse(fa_in, "fasta-2line"):
barcode = str(record.seq[0:barcode_f_len] + record.seq[-barcode_r_len:])
if barcode in samples_barcode:
SeqIO.write(record, filedata[samples_barcode[barcode]], "fasta-2line")
sample_name = samples_barcode[barcode]
SeqIO.write(record, filedata[sample_name], "fasta-2line")
demultiplex_stats[sample_name] += 1
else:
SeqIO.write(record, unmatched, "fasta-2line")
demultiplex_stats["unmatched"] += 1

for f in filedata.values():
f.close()
unmatched.close()
# write demultiplex stats
total = sum(demultiplex_stats.values())
with open(result_file, "wt") as fout:
fout.write("sample\tbarcodes\tpct\n")
for sample_name in demultiplex_stats:
sample_stats = demultiplex_stats[sample_name]
sample_pct = int(round(sample_stats / total, 2) * 100)
fout.write(f"{sample_name}\t{sample_stats}\t{sample_pct}\n")
touch(control_file)
touch(result_file)
return result_file
@@ -202,7 +223,7 @@ def s6_create_blast_db(reference_fasta):
touch(control_file)
return blastdb_path

def s7_blastn_xml(qry, base):
def s7_blastn_xml(qry, base, threads):
""" run blastn with xml output qry and base are absolute paths """
print("Step7 ... :" + qry)
os.makedirs(os.path.join(hivdrm_work_dir, s7_prefix), exist_ok = True)
@@ -213,7 +234,7 @@ def s7_blastn_xml(qry, base):
result_path = os.path.realpath(os.path.join(hivdrm_work_dir, s7_prefix, result_file))
if os.path.exists(result_path):
return result_file
cmd = (f"blastn -num_threads 1 "
cmd = (f"blastn -num_threads {threads} "
f"-query {qry} "
f"-db {base} "
f"-out {result_path} " \
@@ -223,15 +244,15 @@
subprocess.check_call(cmd, shell = True)
return result_file

def s7_blast_all(barcodes_csv, reference_path):
def s7_blast_all(barcodes_csv, reference_path, threads):
with open(barcodes_csv, "rt") as csvfile:
csvreader = csv.reader(csvfile)
# skip header
next(csvfile, None)
for row in csvreader:
sample_fa = row[0] + ".fa"
qry_path = os.path.realpath(os.path.join(hivdrm_work_dir, s5_prefix, sample_fa))
s7_blastn_xml(qry_path, reference_path)
s7_blastn_xml(qry_path, reference_path, threads)

def hamming_dist(s1, s2):
return sum(1 for (a, b) in zip(s1, s2) if a != b)
@@ -275,7 +296,19 @@ def get_consensus(family):
return s_consensus

def s8_make_consensus(sample, min_family_size = 5, max_family_size = 20):
mutex.acquire()
print(f"Step8: consensus for {sample}")
mutex.release()

os.makedirs(os.path.join(hivdrm_work_dir, s8_prefix), exist_ok = True)

consensus_fasta = sample + ".consensus.fasta"
result_file = os.path.join(hivdrm_work_dir, s8_prefix, consensus_fasta)
# making a consensus takes long; if it already exists for this sample, return
if os.path.exists(result_file):
#os.remove(result_file)
return

sample_xml = sample + ".xml"
blast_xml = os.path.realpath(os.path.join(hivdrm_work_dir, s7_prefix, sample_xml))
result_handle = open(blast_xml)
@@ -317,22 +350,18 @@ def s8_make_consensus(sample, min_family_size = 5, max_family_size = 20):
pos = ref.find("-")
ref = ref[0:pos] + ref[pos+1:]
seq = seq[0:pos] + seq[pos+1:]

# use exact match
if umi in families:
families[umi].append(seq)
else:
families[umi] = [ seq ]

if i%1000 == 1:
print(f"Total families: {len(families)}")
print(f"Blast records processed: {i}")
#if i > 10000:
# break
if i%10000 == 1:
mutex.acquire()
print(f"Step8: {sample} : STotal families: {len(families)}")
print(f"Step8: {sample} : Blast records processed: {i}")
mutex.release()
results = {}
consensus_fasta = sample + ".consensus.fasta"
result_file = os.path.join(hivdrm_work_dir, s8_prefix, consensus_fasta)
if os.path.exists(result_file):
os.remove(result_file)

for umi in families:
filename = umi + ".fasta"
@@ -358,12 +387,41 @@
# f.write(seq + "\n")
result_handle.close()

def s8_make_consensus_all():
def s8_make_consensus_worker():
while True:
sample = queue.get()
if sample is None:
break
s8_make_consensus(sample)
queue.task_done()

def s8_make_consensus_all(n_threads):
control_file = os.path.join(hivdrm_work_dir, "step8.done")
if os.path.exists(control_file):
return

print("Step8: start workers")

threads = []
for i in range(n_threads):
t = threading.Thread(target = s8_make_consensus_worker)
t.start()
threads.append(t)

for s in samples:
s8_make_consensus(s)
queue.put(s)

queue.join()

mutex.acquire()
print("Step8: stop workers")
mutex.release()

for i in threads:
queue.put(None)
for t in threads:
t.join()

touch(control_file)

def s9_write_simple_mutations():
@@ -497,6 +555,10 @@ def s11_freq_table_all():
s11_dir = os.path.join(hivdrm_work_dir, s11_prefix)
tsvs = os.listdir(s11_dir)
with pd.ExcelWriter("freq.xlsx", engine = "openpyxl") as writer:
step5_result = os.path.join(hivdrm_work_dir, s5_prefix, "step5.demultiplex_stats.tsv")
if os.path.exists(step5_result):
df = pd.read_csv(step5_result, sep = "\t")
df.to_excel(writer, sheet_name = "step5_demultiplex_stats", index = False, header = True)
for tsv_file in sorted(tsvs):
tsv_path = os.path.join(s11_dir, tsv_file)
if tsv_file.endswith(".all_alleles.freq.tsv"):
@@ -511,6 +573,7 @@ def get_args(description):
parser = argparse.ArgumentParser(description = description, usage = "%(prog)s [options]")
parser.add_argument("--barcodes", required = True, help = "barcodes.csv")
parser.add_argument("--reference", required = True, help = "reference.fasta")
parser.add_argument("--threads", required = False, help = "N threads", default = 1, type = int)
parser.add_argument("fastq_files", nargs = 2, help = "f1.fq.gz f2.fq.gz")
args = parser.parse_args()
return args
@@ -527,8 +590,8 @@ def get_args(description):
step4_out = s4_trim_left_right(step3_out)
step5_out = s5_demultiplex_samples(step4_out, args.barcodes)
s6_out_fasta_ref = s6_create_blast_db(args.reference)
s7_blast_all(args.barcodes, s6_out_fasta_ref)
s8_make_consensus_all()
s7_blast_all(args.barcodes, s6_out_fasta_ref, args.threads)
s8_make_consensus_all(args.threads)
s9_write_simple_mutations()
s9_sierrapy_all()
s10_drlink_all()
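For reference, a hypothetical invocation of the updated script (input file names are placeholders taken from the argparse help strings); the --threads value is passed to blastn as -num_threads in step 7 and sets the number of consensus worker threads in step 8:

python hivdrm.py --barcodes barcodes.csv --reference reference.fasta --threads 4 f1.fq.gz f2.fq.gz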
