Merge pull request #38 from YeoLab/08012024_memprof_more_work_in_shards
memory optimizations
ekofman authored Aug 12, 2024
2 parents 22973ad + de24816 commit 9d7a419
Showing 11 changed files with 13,520 additions and 5,816 deletions.
5,812 changes: 2,906 additions & 2,906 deletions examples/bulk_subset_AI/final_filtered_site_info.tsv

Large diffs are not rendered by default.

2,460 changes: 1,230 additions & 1,230 deletions examples/bulk_subset_CT/final_filtered_site_info.tsv

Large diffs are not rendered by default.

7,260 changes: 7,260 additions & 0 deletions examples/sc_subset_CT/final_filtered_site_info.tsv

Large diffs are not rendered by default.

2 changes: 1 addition & 1 deletion examples/test_sc_subset_CT.sh
@@ -1,3 +1,3 @@
 #! /bin/bash
 
-marine.py --bam_filepath $MARINE/examples/data/single_cell_CT.md.subset.bam --output_folder $MARINE/examples/sc_subset_CT --barcode_whitelist_file $MARINE/examples/data/sc_barcodes.tsv.gz --barcode_tag "CB" --num_intervals_per_contig 16 --strandedness 2 --contigs "1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,17,18,19,20,21,22,X,Y"
+$MARINE/marine.py --bam_filepath $MARINE/examples/data/single_cell_CT.md.subset.bam --output_folder $MARINE/examples/sc_subset_CT --barcode_whitelist_file $MARINE/examples/data/sc_barcodes.tsv.gz --barcode_tag "CB" --num_intervals_per_contig 16 --strandedness 2 --contigs "1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,17,18,19,20,21,22,X,Y"
238 changes: 114 additions & 124 deletions marine.py

Large diffs are not rendered by default.

35 changes: 23 additions & 12 deletions src/core.py
@@ -23,7 +23,6 @@
 
 import os, psutil
 
-
 def run_edit_identifier(bampath, output_folder, strandedness, barcode_tag="CB", barcode_whitelist=None, contigs=[], num_intervals_per_contig=16, verbose=False, cores=64, min_read_quality = 0):
     # Make subfolder in which to store information about edits
     edit_info_subfolder = '{}/edit_info'.format(output_folder)
@@ -42,24 +41,28 @@ def run_edit_identifier(bampath, output_folder, strandedness, barcode_tag="CB",
 
     start_time = time.perf_counter()
 
-    overall_count_summary_dict = defaultdict(lambda:0)
-    multiprocessing.set_start_method('spawn')
+    all_counts_summary_dfs = []
+    counts_summary_dicts = []
 
+    #multiprocessing.set_start_method('spawn')
     with get_context("spawn").Pool(processes=cores, maxtasksperchild=4) as p:
         max_ = len(edit_finding_jobs)
         with tqdm(total=max_) as pbar:
             for _ in p.imap_unordered(find_edits_and_split_bams_wrapper, edit_finding_jobs):
+                # values returned within array _ are:
+                # ~~~~ contig, label, barcode_to_concatted_reads_pl, total_reads, counts_df, time_df, total_time
+                # So the line overall_label_to_list_of_contents[_[0]][_[1]] = _[2]
+                # is equivalent to overall_label_to_list_of_contents[contig][label] = barcode_to_concatted_reads_pl
                 pbar.update()
 
                 if barcode_tag:
                     # Only keep this information for single cell requirements
                     overall_label_to_list_of_contents[_[0]][_[1]] = _[2]
 
                 total_reads = _[3]
-                counts_summary_dict = _[4]
-                for k, v in counts_summary_dict.items():
-                    overall_count_summary_dict[k] += v
+                counts_summary_df = _[4]
+                all_counts_summary_dfs.append(counts_summary_df)
 
                 total_time = time.perf_counter() - start_time
 
                 overall_total_reads += total_reads
@@ -68,7 +71,11 @@ def run_edit_identifier(bampath, output_folder, strandedness, barcode_tag="CB",
 
     overall_time = time.perf_counter() - start_time
 
-    overall_count_summary_df = pd.DataFrame.from_dict(overall_count_summary_dict).sum(axis=1)
+    all_counts_summary_dfs_combined = pd.concat(all_counts_summary_dfs, axis=1)
+    #print(all_counts_summary_dfs_combined.index, all_counts_summary_dfs_combined.columns)
+
+    overall_count_summary_df = pd.DataFrame.from_dict(all_counts_summary_dfs_combined).sum(axis=1)
+    #print(overall_count_summary_df)
 
     return overall_label_to_list_of_contents, results, overall_time, overall_total_reads, total_seconds_for_reads, overall_count_summary_df
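
The change above is the core of the memory optimization: instead of accumulating every shard's counts into one shared overall_count_summary_dict, each worker returns a small per-shard summary DataFrame that is collected in all_counts_summary_dfs and combined once at the end, while the spawn-context Pool with maxtasksperchild=4 recycles workers before they accumulate too much memory. A minimal sketch of the concat-then-sum aggregation pattern, using hypothetical shard values rather than MARINE's actual return objects:

    import pandas as pd

    # Hypothetical per-shard summaries; in MARINE these are the counts_df
    # values returned by find_edits_and_split_bams_wrapper.
    shard_summaries = [
        pd.Series({'CT': 10, 'AG': 4}),
        pd.Series({'CT': 7, 'GA': 2}),
    ]

    # axis=1 aligns shards on the shared index (edit type), filling gaps
    # with NaN; summing across columns gives the overall tally.
    combined = pd.concat(shard_summaries, axis=1)
    overall = combined.sum(axis=1)
    print(overall)  # CT 17.0, AG 4.0, GA 2.0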

@@ -236,6 +243,7 @@ def find_edits_and_split_bams(bampath, contig, split_index, start, end, output_f
 
 import random
 
+
 def find_edits_and_split_bams_wrapper(parameters):
     try:
         start_time = time.perf_counter()
@@ -292,17 +300,20 @@ def find_edits_and_split_bams_wrapper(parameters):
 
 
 
-def run_coverage_calculator(edit_info_grouped_per_contig_combined, output_folder,
+def run_coverage_calculator(edit_info_grouped_per_contig_combined,
+                            output_folder,
                             barcode_tag='CB',
                             paired_end=False,
                             verbose=False,
-                            processes=16
+                            processes=16,
+                            filters=None,
                             ):
     coverage_counting_job_params = get_job_params_for_coverage_for_edits_in_contig(
         edit_info_grouped_per_contig_combined,
         output_folder,
         barcode_tag=barcode_tag,
         paired_end=paired_end,
+        filters=filters,
         verbose=verbose
     )

@@ -331,12 +342,12 @@ def run_coverage_calculator(edit_info_grouped_per_contig_combined, output_folder
 
 
 def get_job_params_for_coverage_for_edits_in_contig(edit_info_grouped_per_contig_combined, output_folder,
-                                                    barcode_tag='CB', paired_end=False, verbose=False):
+                                                    barcode_tag='CB', paired_end=False, filters=None, verbose=False):
     job_params = []
 
     for contig, edit_info in edit_info_grouped_per_contig_combined.items():
 
-        job_params.append([edit_info, contig, output_folder, barcode_tag, paired_end, verbose])
+        job_params.append([edit_info, contig, output_folder, barcode_tag, paired_end, filters, verbose])
 
     return job_params
125 changes: 100 additions & 25 deletions src/utils.py
@@ -8,7 +8,6 @@
 import sys
 from collections import OrderedDict, defaultdict
 
-
 suffixes = {
     'CB': [
         "A-1", "C-1", "G-1", "T-1"
@@ -312,6 +311,7 @@ def get_bulk_coverage_at_pos(samfile_for_barcode, contig_bam, just_contig, pos,
 
     return coverage_at_pos
 
+
 def get_coverage_df(edit_info, contig, output_folder, barcode_tag='CB', paired_end=False,
                     verbose=False):
 
@@ -391,35 +391,111 @@ def get_coverage_df(edit_info, contig, output_folder, barcode_tag='CB', paired_e
 
     return coverage_df
 
 
-def get_coverage_wrapper(parameters):
-    edit_info, contig, output_folder, barcode_tag, paired_end, verbose = parameters
-
-    output_filename = '{}/coverage/{}.tsv'.format(output_folder, contig, header=False)
-    if os.path.exists(output_filename):
-        return output_filename
-
-    edit_info = edit_info.with_columns(
-        pl.concat_str(
-            [
-                pl.col("barcode"),
-                pl.col("position")
-            ],
-            separator=":",
-        ).alias("barcode_position"))
-
-    edit_info_df = edit_info.to_pandas()
-    edit_info_df.index = edit_info_df['barcode_position']
-
-    coverage_df = get_coverage_df(edit_info, contig, output_folder, barcode_tag=barcode_tag,
-                                  paired_end=paired_end, verbose=verbose)
-
-    # Combine edit information with coverage information
-    edit_info_and_coverage_joined = edit_info_df.join(coverage_df, how='inner')
-
-    edit_info_and_coverage_joined['position_barcode'] = edit_info_and_coverage_joined['position'].astype(str) + '_' + edit_info_and_coverage_joined['barcode'].astype(str)
-    edit_info_and_coverage_joined.to_csv(output_filename, sep='\t', header=False)
-
-    return output_filename
+def filter_output_df(output_df, filters, output_filename):
+    filter_stats = {}
+    filter_stats['original'] = len(output_df)
+    if output_df.empty:
+        filter_stats['filtered'] = len(output_df)
+        output_df['coverage'] = []
+        output_df.to_csv(output_filename, sep='\t', header=False)
+        return filter_stats
+
+    filtered_output_df = output_df[
+        (output_df.dist_from_end >= filters.get('dist_from_end')) &
+        (output_df.base_quality >= filters.get('base_quality'))]
+
+    coverage_per_unique_position_df = pd.DataFrame(filtered_output_df.groupby(
+        [
+            "position_barcode"
+        ]).coverage.max())
+
+    distinguishing_columns = [
+        "barcode",
+        "contig",
+        "position",
+        "ref",
+        "alt",
+        "read_id",
+        "strand",
+        "dist_from_end",
+        "base_quality",
+        "mapping_quality"
+    ]
+
+    all_edit_info_unique_position_df = filtered_output_df.drop_duplicates(distinguishing_columns)[distinguishing_columns]
+
+    all_edit_info_unique_position_df.index = all_edit_info_unique_position_df['position'].astype(str)\
+        + '_' + all_edit_info_unique_position_df['barcode']
+
+    all_edit_info_unique_position_with_coverage_df = all_edit_info_unique_position_df.join(coverage_per_unique_position_df)
+
+    if filters.get('max_edits_per_read'):
+        max_edits_per_read = filters.get('max_edits_per_read')
+        #pretty_print("\t\tFiltering out reads with more than {} edits...".format(max_edits_per_read))
+        read_edits = all_edit_info_unique_position_with_coverage_df.groupby('read_id').count().sort_values('barcode')
+        all_edit_info_unique_position_with_coverage_df = all_edit_info_unique_position_with_coverage_df[all_edit_info_unique_position_with_coverage_df.read_id.isin(read_edits[read_edits['barcode'] <= max_edits_per_read].index)]
+
+    distinguishing_columns = [
+        "barcode",
+        "contig",
+        "position",
+        "ref",
+        "alt",
+        "read_id",
+        "strand",
+        "mapping_quality",
+        "coverage"
+    ]
+
+    all_edit_info_unique_position_with_coverage_df = all_edit_info_unique_position_with_coverage_df.drop_duplicates(
+        distinguishing_columns)[distinguishing_columns]
+
+    filter_stats['filtered'] = len(all_edit_info_unique_position_with_coverage_df)
+
+    all_edit_info_unique_position_with_coverage_df.to_csv(output_filename, sep='\t', header=False)
+
+    return filter_stats
+
+
+def get_coverage_wrapper(parameters):
+    edit_info, contig, output_folder, barcode_tag, paired_end, filters, verbose = parameters
+
+    output_filename = '{}/coverage/{}.tsv'.format(output_folder, contig)
+    filtered_output_filename = '{}/coverage/{}_filtered.tsv'.format(output_folder, contig)
+
+    if os.path.exists(output_filename):
+        # filter
+        edit_info_and_coverage_joined = pd.read_csv(output_filename, sep='\t', names=[
+            'barcode', 'contig', 'position', 'ref', 'alt', 'read_id', 'strand',
+            'dist_from_end', 'base_quality', 'mapping_quality', 'barcode_position',
+            'coverage', 'source', 'position_barcode'], dtype={'base_quality': int, 'dist_from_end': int, 'contig': str})
+    else:
+        edit_info = edit_info.with_columns(
+            pl.concat_str(
+                [
+                    pl.col("barcode"),
+                    pl.col("position")
+                ],
+                separator=":",
+            ).alias("barcode_position"))
+
+        edit_info_df = edit_info.to_pandas()
+        edit_info_df.index = edit_info_df['barcode_position']
+
+        coverage_df = get_coverage_df(edit_info, contig, output_folder, barcode_tag=barcode_tag,
+                                      paired_end=paired_end, verbose=verbose)
+
+        # Combine edit information with coverage information
+        edit_info_and_coverage_joined = edit_info_df.join(coverage_df, how='inner')
+        edit_info_and_coverage_joined['position_barcode'] = edit_info_and_coverage_joined['position'].astype(str) + '_' + edit_info_and_coverage_joined['barcode'].astype(str)
+        edit_info_and_coverage_joined.to_csv(output_filename, sep='\t', header=False)
+
+    filter_stats = filter_output_df(edit_info_and_coverage_joined, filters, filtered_output_filename)
+    assert(os.path.exists(output_filename))
+    assert(os.path.exists(filtered_output_filename))
+    return filtered_output_filename, filter_stats
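
The new filter_output_df drops edits that fall too close to a read end or below the base-quality floor, optionally caps the number of edits per read, deduplicates on the distinguishing columns, and keeps one coverage value per position_barcode key (the maximum). A toy illustration of those filter semantics on hypothetical values (the real function also writes the filtered TSV and returns filter_stats):

    import pandas as pd

    df = pd.DataFrame({
        'barcode': ['AAA', 'AAA', 'CCC'],
        'position': [100, 100, 250],
        'dist_from_end': [2, 9, 12],
        'base_quality': [30, 35, 10],
        'coverage': [20, 22, 7],
    })
    df['position_barcode'] = df['position'].astype(str) + '_' + df['barcode']

    filters = {'dist_from_end': 5, 'base_quality': 15}  # hypothetical thresholds
    kept = df[(df.dist_from_end >= filters.get('dist_from_end')) &
              (df.base_quality >= filters.get('base_quality'))]

    # One coverage value per unique position/barcode: take the max observed.
    max_cov = kept.groupby('position_barcode').coverage.max()
    print(max_cov)  # only 100_AAA survives, with coverage 22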



@@ -486,8 +562,7 @@ def write_reads_to_file(reads, bam_file_name, header_string, barcode_tag="BC"):
 
     bam_handle.close()
 
 
-
 
 def concat_and_write_bams(contig, df_dict, header_string, split_bams_folder, barcode_tag='CB', number_of_expected_bams=4, verbose=False):
     job_params = []
 