From d99e8b88bacce3b99b14413cf8c09bef16df7063 Mon Sep 17 00:00:00 2001 From: David Kelley Date: Mon, 1 Apr 2024 16:19:44 -0700 Subject: [PATCH 1/8] dok is empirically faster --- src/baskerville/dataset.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/src/baskerville/dataset.py b/src/baskerville/dataset.py index f061b97..e127e42 100644 --- a/src/baskerville/dataset.py +++ b/src/baskerville/dataset.py @@ -319,7 +319,7 @@ def make_strand_transform(targets_df, targets_strand_df): targets_strand_df (pd.DataFrame): Targets DataFrame, with strand pairs collapsed. Returns: - scipy.sparse.csr_matrix: Sparse matrix to sum strand pairs. + scipy.sparse.dok_matrix: Sparse matrix to sum strand pairs. """ # initialize sparse matrix @@ -336,7 +336,6 @@ def make_strand_transform(targets_df, targets_strand_df): if target.identifier[-1] == "-": sti += 1 ti += 1 - strand_transform = strand_transform.tocsr() return strand_transform From 4e38acdb1d63f9da3288e1a0bfd4c6ee889cb033 Mon Sep 17 00:00:00 2001 From: David Kelley Date: Mon, 1 Apr 2024 16:20:40 -0700 Subject: [PATCH 2/8] strand transform scores instead of predictions --- src/baskerville/snps.py | 123 +++++++++++++++------------------------- 1 file changed, 47 insertions(+), 76 deletions(-) diff --git a/src/baskerville/snps.py b/src/baskerville/snps.py index cba0354..00e42b3 100644 --- a/src/baskerville/snps.py +++ b/src/baskerville/snps.py @@ -169,10 +169,6 @@ def score_snps(params_file, model_file, vcf_file, worker_index, options): ref_preds_shift, targets_df ) - # sum strand pairs - if strand_transform is not None: - ref_preds_shift = ref_preds_shift * strand_transform - # save shift prediction ref_preds.append(ref_preds_shift) ref_preds = np.array(ref_preds) @@ -213,10 +209,6 @@ def score_snps(params_file, model_file, vcf_file, worker_index, options): alt_preds_shift, targets_df ) - # sum strand pairs - if strand_transform is not None: - alt_preds_shift = alt_preds_shift * strand_transform - # save shift prediction alt_preds.append(alt_preds_shift) @@ -237,7 +229,7 @@ def score_snps(params_file, model_file, vcf_file, worker_index, options): write_snp(rp_snp, ap_snp, scores_out, si, options.snp_stats) else: # write_snp_len(rp_snp, ap_snp, scores_out, si, options.snp_stats) - scores = compute_scores(rp_snp, ap_snp, options.snp_stats) + scores = compute_scores(rp_snp, ap_snp, options.snp_stats, strand_transform) for snp_stat in options.snp_stats: scores_out[snp_stat][si] = scores[snp_stat] @@ -280,13 +272,14 @@ def cluster_snps(snps, seq_len: int, center_pct: float): return snp_clusters -def compute_scores(ref_preds, alt_preds, snp_stats): +def compute_scores(ref_preds, alt_preds, snp_stats, strand_transform): """Compute SNP scores from reference and alternative predictions. Args: ref_preds (np.array): Reference allele predictions. alt_preds (np.array): Alternative allele predictions. snp_stats [str]: List of SAD stats to compute. + strand_transform (scipy.sparse): Strand transform matrix. 
""" num_shifts, seq_length, num_targets = ref_preds.shape @@ -296,98 +289,75 @@ def compute_scores(ref_preds, alt_preds, snp_stats): ref_preds_sqrt = np.sqrt(ref_preds) alt_preds_sqrt = np.sqrt(alt_preds) - # sum across length - ref_preds_sum = ref_preds.sum(axis=(0, 1)) - alt_preds_sum = alt_preds.sum(axis=(0, 1)) - ref_preds_log_sum = ref_preds_log.sum(axis=(0, 1)) - alt_preds_log_sum = alt_preds_log.sum(axis=(0, 1)) - ref_preds_sqrt_sum = ref_preds_sqrt.sum(axis=(0, 1)) - alt_preds_sqrt_sum = alt_preds_sqrt.sum(axis=(0, 1)) + # sum across length, mean across shifts + ref_preds_sum = ref_preds.sum(axis=(0, 1)) / num_shifts + alt_preds_sum = alt_preds.sum(axis=(0, 1)) / num_shifts + ref_preds_log_sum = ref_preds_log.sum(axis=(0, 1)) / num_shifts + alt_preds_log_sum = alt_preds_log.sum(axis=(0, 1)) / num_shifts + ref_preds_sqrt_sum = ref_preds_sqrt.sum(axis=(0, 1)) / num_shifts + alt_preds_sqrt_sum = alt_preds_sqrt.sum(axis=(0, 1)) / num_shifts # difference altref_diff = alt_preds - ref_preds - altref_adiff = np.abs(altref_diff) altref_log_diff = alt_preds_log - ref_preds_log - altref_log_adiff = np.abs(altref_log_diff) altref_sqrt_diff = alt_preds_sqrt - ref_preds_sqrt - altref_sqrt_adiff = np.abs(altref_sqrt_diff) # initialize scores dict scores = {} + def strand_clip_save(key, score, d2=False): + if strand_transform is not None: + if d2: + score = np.power(score, 2) + score = score @ strand_transform + score = np.sqrt(score) + else: + score = score @ strand_transform + score = np.clip(score, np.finfo(np.float16).min, np.finfo(np.float16).max) + scores[key] = score.astype("float16") + # compare reference to alternative via sum subtraction if "SUM" in snp_stats: sad = alt_preds_sum - ref_preds_sum - sad = np.clip(sad, np.finfo(np.float16).min, np.finfo(np.float16).max) - scores["SUM"] = sad.astype("float16") + strand_clip_save("SUM", sad) if "logSUM" in snp_stats: log_sad = alt_preds_log_sum - ref_preds_log_sum - log_sad = np.clip(log_sad, np.finfo(np.float16).min, np.finfo(np.float16).max) - scores["logSUM"] = log_sad.astype("float16") + strand_clip_save("logSUM", log_sad) if "sqrtSUM" in snp_stats: sqrt_sad = alt_preds_sqrt_sum - ref_preds_sqrt_sum - sqrt_sad = np.clip(sqrt_sad, np.finfo(np.float16).min, np.finfo(np.float16).max) - scores["sqrtSUM"] = sqrt_sad.astype("float16") - - # TEMP during name change - if "SAD" in snp_stats: - sad = alt_preds_sum - ref_preds_sum - sad = np.clip(sad, np.finfo(np.float16).min, np.finfo(np.float16).max) - scores["SAD"] = sad.astype("float16") - if "logSAD" in snp_stats: - log_sad = alt_preds_log_sum - ref_preds_log_sum - log_sad = np.clip(log_sad, np.finfo(np.float16).min, np.finfo(np.float16).max) - scores["logSAD"] = log_sad.astype("float16") - if "sqrtSAD" in snp_stats: - sqrt_sad = alt_preds_sqrt_sum - ref_preds_sqrt_sum - sqrt_sad = np.clip(sqrt_sad, np.finfo(np.float16).min, np.finfo(np.float16).max) - scores["sqrtSAD"] = sqrt_sad.astype("sqrtSAD") + strand_clip_save("sqrtSUM", sqrt_sad) # compare reference to alternative via max subtraction if "SAX" in snp_stats: + altref_adiff = np.abs(altref_diff) sax = [] for s in range(num_shifts): max_i = np.argmax(altref_adiff[s], axis=0) sax.append(altref_diff[s, max_i, np.arange(num_targets)]) sax = np.array(sax).mean(axis=0) - scores["SAX"] = sax.astype("float16") + strand_clip_save("SAX", sax) # L1 norm of difference vector if "D1" in snp_stats: - sad_d1 = altref_adiff.sum(axis=1) - sad_d1 = np.clip(sad_d1, np.finfo(np.float16).min, np.finfo(np.float16).max) - sad_d1 = sad_d1.mean(axis=0) - 
scores["D1"] = sad_d1.mean().astype("float16") + sad_d1 = np.linalg.norm(altref_diff, ord=1, axis=1) + strand_clip_save("D1", sad_d1.mean(axis=0)) if "logD1" in snp_stats: - log_d1 = altref_log_adiff.sum(axis=1) - log_d1 = np.clip(log_d1, np.finfo(np.float16).min, np.finfo(np.float16).max) - log_d1 = log_d1.mean(axis=0) - scores["logD1"] = log_d1.astype("float16") + log_d1 = np.linalg.norm(altref_log_diff, ord=1, axis=1) + strand_clip_save("logD1", log_d1.mean(axis=0)) if "sqrtD1" in snp_stats: - sqrt_d1 = altref_sqrt_adiff.sum(axis=1) - sqrt_d1 = np.clip(sqrt_d1, np.finfo(np.float16).min, np.finfo(np.float16).max) - sqrt_d1 = sqrt_d1.mean(axis=0) - scores["sqrtD1"] = sqrt_d1.astype("float16") + sqrt_d1 = np.linalg.norm(altref_sqrt_diff, ord=1, axis=1) + strand_clip_save("sqrtD1", sqrt_d1.mean(axis=0)) # L2 norm of difference vector if "D2" in snp_stats: - altref_diff2 = np.power(altref_diff, 2) - sad_d2 = np.sqrt(altref_diff2.sum(axis=1)) - sad_d2 = np.clip(sad_d2, np.finfo(np.float16).min, np.finfo(np.float16).max) - sad_d2 = sad_d2.mean(axis=0) - scores["D2"] = sad_d2.astype("float16") + sad_d2 = np.linalg.norm(altref_diff, ord=2, axis=1) + strand_clip_save("D2", sad_d2.mean(axis=0), d2=True) if "logD2" in snp_stats: - altref_log_diff2 = np.power(altref_log_diff, 2) - log_d2 = np.sqrt(altref_log_diff2.sum(axis=1)) - log_d2 = np.clip(log_d2, np.finfo(np.float16).min, np.finfo(np.float16).max) - log_d2 = log_d2.mean(axis=0) - scores["logD2"] = log_d2.astype("float16") + log_d2 = np.linalg.norm(altref_log_diff, ord=2, axis=1) + strand_clip_save("logD2", log_d2.mean(axis=0), d2=True) if "sqrtD2" in snp_stats: - altref_sqrt_diff2 = np.power(altref_sqrt_diff, 2) - sqrt_d2 = np.sqrt(altref_sqrt_diff2.sum(axis=1)) - sqrt_d2 = np.clip(sqrt_d2, np.finfo(np.float16).min, np.finfo(np.float16).max) - sqrt_d2 = sqrt_d2.mean(axis=0) - scores["sqrtD2"] = sqrt_d2.astype("float16") + sqrt_d2 = np.linalg.norm(altref_sqrt_diff, ord=2, axis=1) + strand_clip_save("sqrtD2", sqrt_d2.mean(axis=0), d2=True) if "JS" in snp_stats: # normalized scores @@ -404,7 +374,9 @@ def compute_scores(ref_preds, alt_preds, snp_stats): alt_ref_entr = rel_entr(alt_preds_norm[s], ref_preds_norm[s]).sum(axis=0) js_dist.append((ref_alt_entr + alt_ref_entr) / 2) js_dist = np.mean(js_dist, axis=0) - scores["JS"] = js_dist.astype("float16") + # handling strand this way is incorrect, but I'm punting for now + strand_clip_save("JS", js_dist) + if "logJS" in snp_stats: # normalized scores pseudocounts = np.percentile(ref_preds_log, 25, axis=0) @@ -416,15 +388,14 @@ def compute_scores(ref_preds, alt_preds, snp_stats): # compare normalized JS log_js_dist = [] for s in range(num_shifts): - ref_alt_entr = rel_entr(ref_preds_log_norm[s], alt_preds_log_norm[s]).sum( - axis=0 - ) - alt_ref_entr = rel_entr(alt_preds_log_norm[s], ref_preds_log_norm[s]).sum( - axis=0 - ) + rps = ref_preds_log_norm[s] + aps = alt_preds_log_norm[s] + ref_alt_entr = rel_entr(rps, aps).sum(axis=0) + alt_ref_entr = rel_entr(aps, rps).sum(axis=0) log_js_dist.append((ref_alt_entr + alt_ref_entr) / 2) log_js_dist = np.mean(log_js_dist, axis=0) - scores["logJS"] = log_js_dist.astype("float16") + # handling strand this way is incorrect, but I'm punting for now + strand_clip_save("logJS", log_js_dist) # predictions if "REF" in snp_stats: From c033a822c673b365b8adbcce946d81a5b461ed35 Mon Sep 17 00:00:00 2001 From: David Kelley Date: Thu, 4 Apr 2024 14:50:39 -0700 Subject: [PATCH 3/8] concurrent futures --- src/baskerville/snps.py | 162 ++++++++++++++++++++-------------------- 
1 file changed, 82 insertions(+), 80 deletions(-) diff --git a/src/baskerville/snps.py b/src/baskerville/snps.py index 00e42b3..2d37d9a 100644 --- a/src/baskerville/snps.py +++ b/src/baskerville/snps.py @@ -1,3 +1,4 @@ +import concurrent import json import pdb import sys @@ -145,96 +146,97 @@ def score_snps(params_file, model_file, vcf_file, worker_index, options): num_shifts, ) + # CPU computation + def score_write(ref_preds, alt_preds, si): + scores = compute_scores(ref_preds, alt_preds, options.snp_stats, strand_transform) + for snp_stat in options.snp_stats: + scores_out[snp_stat][si] = scores[snp_stat] + + if options.untransform_old: + untransform = dataset.untransform_preds1 + else: + untransform = dataset.untransform_preds + # SNP index si = 0 - for sc in tqdm(snp_clusters): - snp_1hot_list = sc.get_1hots(genome_open) - ref_1hot = np.expand_dims(snp_1hot_list[0], axis=0) - - # predict reference - ref_preds = [] - for shift in options.shifts: - ref_1hot_shift = dna.hot1_augment(ref_1hot, shift=shift) - ref_preds_shift = seqnn_model(ref_1hot_shift)[0] - - # untransform predictions - if options.targets_file is not None: - if options.untransform_old: - ref_preds_shift = dataset.untransform_preds1( - ref_preds_shift, targets_df - ) - else: - ref_preds_shift = dataset.untransform_preds( - ref_preds_shift, targets_df - ) - - # save shift prediction - ref_preds.append(ref_preds_shift) - ref_preds = np.array(ref_preds) - - ai = 0 - for alt_1hot in snp_1hot_list[1:]: - alt_1hot = np.expand_dims(alt_1hot, axis=0) - - # add compensation shifts for indels - indel_size = sc.snps[ai].indel_size() - if indel_size == 0: - alt_shifts = options.shifts - else: - # repeat reference predictions, unless stitching - if not options.indel_stitch: - ref_preds = np.repeat(ref_preds, 2, axis=0) - - # add compensation shifts - alt_shifts = [] - for shift in options.shifts: - alt_shifts.append(shift) - alt_shifts.append(shift - indel_size) - - # predict alternate - alt_preds = [] - for shift in alt_shifts: - alt_1hot_shift = dna.hot1_augment(alt_1hot, shift=shift) - alt_preds_shift = seqnn_model(alt_1hot_shift)[0] + with concurrent.futures.ThreadPoolExecutor() as executor: + for sc in tqdm(snp_clusters): + snp_1hot_list = sc.get_1hots(genome_open) + ref_1hot = np.expand_dims(snp_1hot_list[0], axis=0) + + # predict reference + ref_preds = [] + for shift in options.shifts: + ref_1hot_shift = dna.hot1_augment(ref_1hot, shift=shift) + ref_preds_shift = seqnn_model(ref_1hot_shift)[0] # untransform predictions - if options.targets_file is not None: - if options.untransform_old: - alt_preds_shift = dataset.untransform_preds1( - alt_preds_shift, targets_df - ) - else: - alt_preds_shift = dataset.untransform_preds( - alt_preds_shift, targets_df - ) + if options.targets_file is None: + ref_preds.append(ref_preds_shift) + else: + rpsf = executor.submit(untransform, ref_preds_shift, targets_df) + ref_preds.append(rpsf) - # save shift prediction - alt_preds.append(alt_preds_shift) + ai = 0 + for alt_1hot in snp_1hot_list[1:]: + alt_1hot = np.expand_dims(alt_1hot, axis=0) - # stitch indel compensation shifts - if indel_size != 0 and options.indel_stitch: - alt_preds = stitch_preds(alt_preds, options.shifts) + # add compensation shifts for indels + indel_size = sc.snps[ai].indel_size() + if indel_size == 0: + alt_shifts = options.shifts + else: + # repeat reference predictions, unless stitching + if not options.indel_stitch: + ref_preds = np.repeat(ref_preds, 2, axis=0) + + # add compensation shifts + alt_shifts = [] + for 
shift in options.shifts: + alt_shifts.append(shift) + alt_shifts.append(shift - indel_size) + + # predict alternate + alt_preds = [] + for shift in alt_shifts: + alt_1hot_shift = dna.hot1_augment(alt_1hot, shift=shift) + alt_preds_shift = seqnn_model(alt_1hot_shift)[0] + + # untransform predictions + if options.targets_file is None: + alt_preds.append(alt_preds_shift) + else: + apsf = executor.submit(untransform, alt_preds_shift, targets_df) + alt_preds.append(apsf) - # flip reference and alternate - if snps[si].flipped: - rp_snp = np.array(alt_preds) - ap_snp = np.array(ref_preds) - else: - rp_snp = np.array(ref_preds) - ap_snp = np.array(alt_preds) + # result + if options.targets_file is not None: + # get result, only if not already gotten + if isinstance(ref_preds[0], concurrent.futures.Future): + ref_preds = [rpsf.result() for rpsf in ref_preds] + alt_preds = [apsf.result() for apsf in alt_preds] + + # stitch indel compensation shifts + if indel_size != 0 and options.indel_stitch: + alt_preds = stitch_preds(alt_preds, options.shifts) + + # flip reference and alternate + if snps[si].flipped: + rp_snp = np.array(alt_preds) + ap_snp = np.array(ref_preds) + else: + rp_snp = np.array(ref_preds) + ap_snp = np.array(alt_preds) - # write SNP - if sum_length: - write_snp(rp_snp, ap_snp, scores_out, si, options.snp_stats) - else: - # write_snp_len(rp_snp, ap_snp, scores_out, si, options.snp_stats) - scores = compute_scores(rp_snp, ap_snp, options.snp_stats, strand_transform) - for snp_stat in options.snp_stats: - scores_out[snp_stat][si] = scores[snp_stat] + # write SNP + if sum_length: + write_snp(rp_snp, ap_snp, scores_out, si, options.snp_stats) + else: + executor.submit(score_write, rp_snp, ap_snp, si) - # update SNP index - si += 1 + # update SNP index + si += 1 # close genome genome_open.close() From 528892ebdaff5f142c6f0dd141ab9e7be26d3799 Mon Sep 17 00:00:00 2001 From: David Kelley Date: Thu, 4 Apr 2024 14:51:04 -0700 Subject: [PATCH 4/8] delete old score code --- src/baskerville/snps.py | 145 ---------------------------------------- 1 file changed, 145 deletions(-) diff --git a/src/baskerville/snps.py b/src/baskerville/snps.py index 2d37d9a..443d2ed 100644 --- a/src/baskerville/snps.py +++ b/src/baskerville/snps.py @@ -595,151 +595,6 @@ def write_snp(ref_preds_sum, alt_preds_sum, scores_out, si, snp_stats): scores_out["SAD"][si] = sad.astype("float16") -def write_snp_len(ref_preds, alt_preds, scores_out, si, snp_stats): - """Write SNP predictions to HDF, assuming the length dimension has - been maintained. - - Args: - ref_preds (np.array): Reference allele predictions. - alt_preds (np.array): Alternative allele predictions. - scores_out (h5py.File): Output HDF5 file. - si (int): SNP index. - snp_stats [str]: List of SAD stats to compute. 
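
Stepping back to PATCH 3/8: the ThreadPoolExecutor overlaps GPU inference with CPU post-processing. The untransform calls and the score/write step are submitted to worker threads while the main loop keeps issuing forward passes, and futures are resolved only once all shifts for a variant have been queued. A stripped-down sketch of that shape follows; model, postprocess, write_scores, shifted_batches, and vi are hypothetical stand-ins rather than baskerville names.

from concurrent.futures import ThreadPoolExecutor

def score_variant(shifted_batches, model, postprocess, write_scores, vi):
    """Overlap GPU inference with CPU post-processing and output writes (sketch)."""
    with ThreadPoolExecutor() as executor:
        pred_futures = []
        for batch in shifted_batches:
            preds = model(batch)  # GPU-bound forward pass
            # CPU-bound post-processing runs in the pool while the next
            # forward pass is already being issued.
            pred_futures.append(executor.submit(postprocess, preds))

        # Resolve futures only after every shift has been submitted.
        preds_ready = [f.result() for f in pred_futures]

        # Scoring and HDF5 writes are likewise pushed off the main thread.
        executor.submit(write_scores, preds_ready, vi)
    # Exiting the with-block performs an implicit shutdown(wait=True),
    # so all queued work finishes before the function returns.

In score_snps the score_write futures are never resolved explicitly; the implicit shutdown on exiting the with-block is what guarantees all queued writes complete before the outputs are closed.
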
- """ - num_shifts, seq_length, num_targets = ref_preds.shape - - # log/sqrt - ref_preds_log = np.log2(ref_preds + 1) - alt_preds_log = np.log2(alt_preds + 1) - ref_preds_sqrt = np.sqrt(ref_preds) - alt_preds_sqrt = np.sqrt(alt_preds) - - # sum across length - ref_preds_sum = ref_preds.sum(axis=(0, 1)) - alt_preds_sum = alt_preds.sum(axis=(0, 1)) - ref_preds_log_sum = ref_preds_log.sum(axis=(0, 1)) - alt_preds_log_sum = alt_preds_log.sum(axis=(0, 1)) - ref_preds_sqrt_sum = ref_preds_sqrt.sum(axis=(0, 1)) - alt_preds_sqrt_sum = alt_preds_sqrt.sum(axis=(0, 1)) - - # difference - altref_diff = alt_preds - ref_preds - altref_adiff = np.abs(altref_diff) - altref_log_diff = alt_preds_log - ref_preds_log - altref_log_adiff = np.abs(altref_log_diff) - altref_sqrt_diff = alt_preds_sqrt - ref_preds_sqrt - altref_sqrt_adiff = np.abs(altref_sqrt_diff) - - # compare reference to alternative via sum subtraction - if "SAD" in snp_stats: - sad = alt_preds_sum - ref_preds_sum - sad = np.clip(sad, np.finfo(np.float16).min, np.finfo(np.float16).max) - scores_out["SAD"][si] = sad.astype("float16") - if "logSAD" in snp_stats: - log_sad = alt_preds_log_sum - ref_preds_log_sum - log_sad = np.clip(log_sad, np.finfo(np.float16).min, np.finfo(np.float16).max) - scores_out["logSAD"][si] = log_sad.astype("float16") - if "sqrtSAD" in snp_stats: - sqrt_sad = alt_preds_sqrt_sum - ref_preds_sqrt_sum - sqrt_sad = np.clip(sqrt_sad, np.finfo(np.float16).min, np.finfo(np.float16).max) - scores_out["sqrtSAD"][si] = sqrt_sad.astype("float16") - - # compare reference to alternative via max subtraction - if "SAX" in snp_stats: - sax = [] - for s in range(num_shifts): - max_i = np.argmax(altref_adiff[s], axis=0) - sax.append(altref_diff[s, max_i, np.arange(num_targets)]) - sax = np.array(sax).mean(axis=0) - scores_out["SAX"][si] = sax.astype("float16") - - # L1 norm of difference vector - if "D1" in snp_stats: - sad_d1 = altref_adiff.sum(axis=1) - sad_d1 = np.clip(sad_d1, np.finfo(np.float16).min, np.finfo(np.float16).max) - sad_d1 = sad_d1.mean(axis=0) - scores_out["D1"][si] = sad_d1.mean().astype("float16") - if "logD1" in snp_stats: - log_d1 = altref_log_adiff.sum(axis=1) - log_d1 = np.clip(log_d1, np.finfo(np.float16).min, np.finfo(np.float16).max) - log_d1 = log_d1.mean(axis=0) - scores_out["logD1"][si] = log_d1.astype("float16") - if "sqrtD1" in snp_stats: - sqrt_d1 = altref_sqrt_adiff.sum(axis=1) - sqrt_d1 = np.clip(sqrt_d1, np.finfo(np.float16).min, np.finfo(np.float16).max) - sqrt_d1 = sqrt_d1.mean(axis=0) - scores_out["sqrtD1"][si] = sqrt_d1.astype("float16") - - # L2 norm of difference vector - if "D2" in snp_stats: - altref_diff2 = np.power(altref_diff, 2) - sad_d2 = np.sqrt(altref_diff2.sum(axis=1)) - sad_d2 = np.clip(sad_d2, np.finfo(np.float16).min, np.finfo(np.float16).max) - sad_d2 = sad_d2.mean(axis=0) - scores_out["D2"][si] = sad_d2.astype("float16") - if "logD2" in snp_stats: - altref_log_diff2 = np.power(altref_log_diff, 2) - log_d2 = np.sqrt(altref_log_diff2.sum(axis=1)) - log_d2 = np.clip(log_d2, np.finfo(np.float16).min, np.finfo(np.float16).max) - log_d2 = log_d2.mean(axis=0) - scores_out["logD2"][si] = log_d2.astype("float16") - if "sqrtD2" in snp_stats: - altref_sqrt_diff2 = np.power(altref_sqrt_diff, 2) - sqrt_d2 = np.sqrt(altref_sqrt_diff2.sum(axis=1)) - sqrt_d2 = np.clip(sqrt_d2, np.finfo(np.float16).min, np.finfo(np.float16).max) - sqrt_d2 = sqrt_d2.mean(axis=0) - scores_out["sqrtD2"][si] = sqrt_d2.astype("float16") - - if "JS" in snp_stats: - # normalized scores - pseudocounts = 
np.percentile(ref_preds, 25, axis=1) - ref_preds_norm = ref_preds + pseudocounts - ref_preds_norm /= ref_preds_norm.sum(axis=1) - alt_preds_norm = alt_preds + pseudocounts - alt_preds_norm /= alt_preds_norm.sum(axis=1) - - # compare normalized JS - js_dist = [] - for s in range(num_shifts): - ref_alt_entr = rel_entr(ref_preds_norm[s], alt_preds_norm[s]).sum(axis=0) - alt_ref_entr = rel_entr(alt_preds_norm[s], ref_preds_norm[s]).sum(axis=0) - js_dist.append((ref_alt_entr + alt_ref_entr) / 2) - js_dist = np.mean(js_dist, axis=0) - scores_out["JS"][si] = js_dist.astype("float16") - if "logJS" in snp_stats: - # normalized scores - pseudocounts = np.percentile(ref_preds_log, 25, axis=0) - ref_preds_log_norm = ref_preds_log + pseudocounts - ref_preds_log_norm /= ref_preds_log_norm.sum(axis=0) - alt_preds_log_norm = alt_preds_log + pseudocounts - alt_preds_log_norm /= alt_preds_log_norm.sum(axis=0) - - # compare normalized JS - log_js_dist = [] - for s in range(num_shifts): - ref_alt_entr = rel_entr(ref_preds_log_norm[s], alt_preds_log_norm[s]).sum( - axis=0 - ) - alt_ref_entr = rel_entr(alt_preds_log_norm[s], ref_preds_log_norm[s]).sum( - axis=0 - ) - log_js_dist.append((ref_alt_entr + alt_ref_entr) / 2) - log_js_dist = np.mean(log_js_dist, axis=0) - scores_out["logJS"][si] = log_js_dist.astype("float16") - - # predictions - if "REF" in snp_stats: - ref_preds = np.clip( - ref_preds, np.finfo(np.float16).min, np.finfo(np.float16).max - ) - scores_out["REF"][si] = ref_preds.astype("float16") - if "ALT" in snp_stats: - alt_preds = np.clip( - alt_preds, np.finfo(np.float16).min, np.finfo(np.float16).max - ) - scores_out["ALT"][si] = alt_preds.astype("float16") - - class SNPCluster: def __init__(self): self.snps = [] From 0ca1681fc819dd178badf5c0becc72e8afc8d761 Mon Sep 17 00:00:00 2001 From: David Kelley Date: Thu, 4 Apr 2024 14:53:06 -0700 Subject: [PATCH 5/8] black --- src/baskerville/snps.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/src/baskerville/snps.py b/src/baskerville/snps.py index 443d2ed..085db8c 100644 --- a/src/baskerville/snps.py +++ b/src/baskerville/snps.py @@ -148,7 +148,9 @@ def score_snps(params_file, model_file, vcf_file, worker_index, options): # CPU computation def score_write(ref_preds, alt_preds, si): - scores = compute_scores(ref_preds, alt_preds, options.snp_stats, strand_transform) + scores = compute_scores( + ref_preds, alt_preds, options.snp_stats, strand_transform + ) for snp_stat in options.snp_stats: scores_out[snp_stat][si] = scores[snp_stat] From 6f0ee9b14a3f7e660faba37926353d826edda816 Mon Sep 17 00:00:00 2001 From: lruizcalico Date: Tue, 9 Apr 2024 12:55:18 -0700 Subject: [PATCH 6/8] add strand_transform to fix pytest error --- src/baskerville/scripts/hound_ism_snp.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/src/baskerville/scripts/hound_ism_snp.py b/src/baskerville/scripts/hound_ism_snp.py index 12b8ba5..e8a8024 100755 --- a/src/baskerville/scripts/hound_ism_snp.py +++ b/src/baskerville/scripts/hound_ism_snp.py @@ -156,6 +156,7 @@ def main(): else: targets_strand_df = targets_df strand_transform = None + num_targets = targets_strand_df.shape[0] ################################################################# @@ -249,7 +250,7 @@ def main(): alt_preds = np.array(alt_preds) ism_scores = snps.compute_scores( - ref_preds, alt_preds, options.snp_stats + ref_preds, alt_preds, options.snp_stats, strand_transform ) for snp_stat in options.snp_stats: scores_h5[snp_stat][si, mi - mut_start, ni] = ism_scores[ From 
0c8021bc119d5a3470e0c4e7712d89e9558ee1cc Mon Sep 17 00:00:00 2001 From: lruizcalico Date: Tue, 9 Apr 2024 13:05:59 -0700 Subject: [PATCH 7/8] add strand_transform to hound_ism_bed line 278 fix missing arg --- src/baskerville/scripts/hound_ism_bed.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/baskerville/scripts/hound_ism_bed.py b/src/baskerville/scripts/hound_ism_bed.py index 7449c18..fea5f12 100755 --- a/src/baskerville/scripts/hound_ism_bed.py +++ b/src/baskerville/scripts/hound_ism_bed.py @@ -276,7 +276,7 @@ def main(): alt_preds = np.array(alt_preds) ism_scores = snps.compute_scores( - ref_preds, alt_preds, options.snp_stats + ref_preds, alt_preds, options.snp_stats, strand_transform ) for snp_stat in options.snp_stats: scores_h5[snp_stat][si, mi - mut_start, ni] = ism_scores[ From 95e0b090d4c9adfc8fc9bb583aee994da75d7568 Mon Sep 17 00:00:00 2001 From: lruizcalico Date: Tue, 9 Apr 2024 13:19:15 -0700 Subject: [PATCH 8/8] put None for strand_transform in hound_ism_bed hound_ism_snp to pass test --- src/baskerville/scripts/hound_ism_bed.py | 2 +- src/baskerville/scripts/hound_ism_snp.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/src/baskerville/scripts/hound_ism_bed.py b/src/baskerville/scripts/hound_ism_bed.py index fea5f12..7d4668e 100755 --- a/src/baskerville/scripts/hound_ism_bed.py +++ b/src/baskerville/scripts/hound_ism_bed.py @@ -276,7 +276,7 @@ def main(): alt_preds = np.array(alt_preds) ism_scores = snps.compute_scores( - ref_preds, alt_preds, options.snp_stats, strand_transform + ref_preds, alt_preds, options.snp_stats, None ) for snp_stat in options.snp_stats: scores_h5[snp_stat][si, mi - mut_start, ni] = ism_scores[ diff --git a/src/baskerville/scripts/hound_ism_snp.py b/src/baskerville/scripts/hound_ism_snp.py index e8a8024..0318765 100755 --- a/src/baskerville/scripts/hound_ism_snp.py +++ b/src/baskerville/scripts/hound_ism_snp.py @@ -250,7 +250,7 @@ def main(): alt_preds = np.array(alt_preds) ism_scores = snps.compute_scores( - ref_preds, alt_preds, options.snp_stats, strand_transform + ref_preds, alt_preds, options.snp_stats, None ) for snp_stat in options.snp_stats: scores_h5[snp_stat][si, mi - mut_start, ni] = ism_scores[
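
For the ISM scripts patched above, passing None as the final argument keeps scores on the original target axis: inside compute_scores, strand_clip_save then only clips to the float16 range and casts, with no strand-pair summation. The call, as in these patches, looks like:

# No strand transform: per-target scores are clipped/cast but not collapsed.
ism_scores = snps.compute_scores(ref_preds, alt_preds, options.snp_stats, None)
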