From 38c874cac76e4d55d707d99c221ce4c4d5e1f579 Mon Sep 17 00:00:00 2001 From: David Kelley Date: Thu, 9 Nov 2023 14:16:07 -0800 Subject: [PATCH] D2 scores within shift before averaging; compensation shifts --- src/baskerville/snps.py | 160 ++++++++++++++++++++++++++-------------- src/baskerville/vcf.py | 4 + 2 files changed, 108 insertions(+), 56 deletions(-) diff --git a/src/baskerville/snps.py b/src/baskerville/snps.py index dc610e9..7286ebb 100644 --- a/src/baskerville/snps.py +++ b/src/baskerville/snps.py @@ -68,7 +68,10 @@ def score_snps(params_file, model_file, vcf_file, worker_index, options): seqnn_model.build_slice(targets_df.index) if sum_length: seqnn_model.build_sad() - seqnn_model.build_ensemble(options.rc, options.shifts) + seqnn_model.build_ensemble(options.rc) + + # shift outside seqnn + num_shifts = len(options.shifts) targets_length = seqnn_model.target_lengths[0] num_targets = seqnn_model.num_targets() @@ -128,7 +131,7 @@ def score_snps(params_file, model_file, vcf_file, worker_index, options): # setup output scores_out = initialize_output_h5( - options.out_dir, options.snp_stats, snps, targets_length, targets_strand_df + options.out_dir, options.snp_stats, snps, targets_length, targets_strand_df, num_shifts ) # SNP index @@ -136,46 +139,74 @@ def score_snps(params_file, model_file, vcf_file, worker_index, options): for sc in tqdm(snp_clusters): snp_1hot_list = sc.get_1hots(genome_open) + ref_1hot = np.expand_dims(snp_1hot_list[0], axis=0) # predict reference - ref_1hot = np.expand_dims(snp_1hot_list[0], axis=0) - ref_preds = seqnn_model(ref_1hot)[0] + ref_preds = [] + for shift in options.shifts: + ref_1hot_shift = dna.hot1_augment(ref_1hot, shift=shift) + ref_preds_shift = seqnn_model(ref_1hot_shift)[0] - # untransform predictions - if options.targets_file is not None: - if options.untransform_old: - ref_preds = dataset.untransform_preds1(ref_preds, targets_df) - else: - ref_preds = dataset.untransform_preds(ref_preds, targets_df) + # untransform predictions + if options.targets_file is not None: + if options.untransform_old: + ref_preds_shift = dataset.untransform_preds1(ref_preds_shift, targets_df) + else: + ref_preds_shift = dataset.untransform_preds(ref_preds_shift, targets_df) + + # sum strand pairs + if strand_transform is not None: + ref_preds_shift = ref_preds_shift * strand_transform - # sum strand pairs - if strand_transform is not None: - ref_preds = ref_preds * strand_transform + # save shift prediction + ref_preds.append(ref_preds_shift) + ref_preds = np.array(ref_preds) + ai = 0 for alt_1hot in snp_1hot_list[1:]: alt_1hot = np.expand_dims(alt_1hot, axis=0) + # add compensation shifts for indels + indel_size = sc.snps[ai].indel_size() + if indel_size == 0: + alt_shifts = options.shifts + else: + # repeat reference predictions + ref_preds = np.repeat(ref_preds, 2, axis=0) + + # add compensation shifts + alt_shifts = [] + for shift in options.shifts: + alt_shifts.append(shift) + alt_shifts.append(shift - indel_size) + # predict alternate - alt_preds = seqnn_model(alt_1hot)[0] + alt_preds = [] + for shift in alt_shifts: + alt_1hot_shift = dna.hot1_augment(alt_1hot, shift=shift) + alt_preds_shift = seqnn_model(alt_1hot_shift)[0] - # untransform predictions - if options.targets_file is not None: - if options.untransform_old: - alt_preds = dataset.untransform_preds1(alt_preds, targets_df) - else: - alt_preds = dataset.untransform_preds(alt_preds, targets_df) + # untransform predictions + if options.targets_file is not None: + if options.untransform_old: + alt_preds_shift = dataset.untransform_preds1(alt_preds_shift, targets_df) + else: + alt_preds_shift = dataset.untransform_preds(alt_preds_shift, targets_df) - # sum strand pairs - if strand_transform is not None: - alt_preds = alt_preds * strand_transform + # sum strand pairs + if strand_transform is not None: + alt_preds_shift = alt_preds_shift * strand_transform + + # save shift prediction + alt_preds.append(alt_preds_shift) # flip reference and alternate if snps[si].flipped: - rp_snp = alt_preds - ap_snp = ref_preds + rp_snp = np.array(alt_preds) + ap_snp = np.array(ref_preds) else: - rp_snp = ref_preds - ap_snp = alt_preds + rp_snp = np.array(ref_preds) + ap_snp = np.array(alt_preds) # write SNP if sum_length: @@ -222,7 +253,7 @@ def cluster_snps(snps, seq_len: int, center_pct: float): return snp_clusters -def initialize_output_h5(out_dir, snp_stats, snps, targets_length, targets_df): +def initialize_output_h5(out_dir, snp_stats, snps, targets_length, targets_df, num_shifts): """Initialize an output HDF5 file for SAD stats. Args: @@ -230,7 +261,8 @@ def initialize_output_h5(out_dir, snp_stats, snps, targets_length, targets_df): snp_stats [str]: List of SAD stats to compute. snps [SNP]: List of SNPs. targets_length (int): Targets' sequence length - targets_df (pd.DataFrame): Targets AataFrame. + targets_df (pd.DataFrame): Targets DataFrame. + num_shifts (int): Number of shifts. """ num_targets = targets_df.shape[0] @@ -275,7 +307,7 @@ def initialize_output_h5(out_dir, snp_stats, snps, targets_length, targets_df): for snp_stat in snp_stats: if snp_stat in ["REF", "ALT"]: scores_out.create_dataset( - snp_stat, shape=(num_snps, targets_length, num_targets), dtype="float16" + snp_stat, shape=(num_snps, num_shifts, targets_length, num_targets), dtype="float16" ) else: scores_out.create_dataset( @@ -407,7 +439,8 @@ def write_snp(ref_preds_sum, alt_preds_sum, scores_out, si, snp_stats): # compare reference to alternative via mean subtraction if "SAD" in snp_stats: sad = alt_preds_sum - ref_preds_sum - scores_out["SAD"][si, :] = sad.astype("float16") + sad = sad.mean(axis=0) + scores_out["SAD"][si] = sad.astype("float16") def write_snp_len(ref_preds, alt_preds, scores_out, si, snp_stats): @@ -421,7 +454,7 @@ def write_snp_len(ref_preds, alt_preds, scores_out, si, snp_stats): si (int): SNP index. snp_stats [str]: List of SAD stats to compute. """ - seq_length, num_targets = ref_preds.shape + num_shifts, seq_length, num_targets = ref_preds.shape # log/sqrt ref_preds_log = np.log2(ref_preds + 1) @@ -430,12 +463,12 @@ def write_snp_len(ref_preds, alt_preds, scores_out, si, snp_stats): alt_preds_sqrt = np.sqrt(alt_preds) # sum across length - ref_preds_sum = ref_preds.sum(axis=0) - alt_preds_sum = alt_preds.sum(axis=0) - ref_preds_log_sum = ref_preds_log.sum(axis=0) - alt_preds_log_sum = alt_preds_log.sum(axis=0) - ref_preds_sqrt_sum = ref_preds_sqrt.sum(axis=0) - alt_preds_sqrt_sum = alt_preds_sqrt.sum(axis=0) + ref_preds_sum = ref_preds.sum(axis=(0,1)) + alt_preds_sum = alt_preds.sum(axis=(0,1)) + ref_preds_log_sum = ref_preds_log.sum(axis=(0,1)) + alt_preds_log_sum = alt_preds_log.sum(axis=(0,1)) + ref_preds_sqrt_sum = ref_preds_sqrt.sum(axis=(0,1)) + alt_preds_sqrt_sum = alt_preds_sqrt.sum(axis=(0,1)) # difference altref_diff = alt_preds - ref_preds @@ -461,53 +494,65 @@ def write_snp_len(ref_preds, alt_preds, scores_out, si, snp_stats): # compare reference to alternative via max subtraction if "SAX" in snp_stats: - max_i = np.argmax(altref_adiff, axis=0) - sax = altref_diff[max_i, np.arange(num_targets)] + sax = [] + for s in range(num_shifts): + max_i = np.argmax(altref_adiff[s], axis=0) + sax.append(altref_diff[s, max_i, np.arange(num_targets)]) + sax = np.array(sax).mean(axis=0) scores_out["SAX"][si] = sax.astype("float16") # L1 norm of difference vector if "D1" in snp_stats: - sad_d1 = altref_adiff.sum(axis=0) + sad_d1 = altref_adiff.sum(axis=1) sad_d1 = np.clip(sad_d1, np.finfo(np.float16).min, np.finfo(np.float16).max) - scores_out["D1"][si] = sad_d1.astype("float16") + sad_d1 = sad_d1.mean(axis=0) + scores_out["D1"][si] = sad_d1.mean().astype("float16") if "logD1" in snp_stats: - log_d1 = altref_log_adiff.sum(axis=0) + log_d1 = altref_log_adiff.sum(axis=1) log_d1 = np.clip(log_d1, np.finfo(np.float16).min, np.finfo(np.float16).max) + log_d1 = log_d1.mean(axis=0) scores_out["logD1"][si] = log_d1.astype("float16") if "sqrtD1" in snp_stats: - sqrt_d1 = altref_sqrt_adiff.sum(axis=0) + sqrt_d1 = altref_sqrt_adiff.sum(axis=1) sqrt_d1 = np.clip(sqrt_d1, np.finfo(np.float16).min, np.finfo(np.float16).max) + sqrt_d1 = sqrt_d1.mean(axis=0) scores_out["sqrtD1"][si] = sqrt_d1.astype("float16") # L2 norm of difference vector if "D2" in snp_stats: altref_diff2 = np.power(altref_diff, 2) - sad_d2 = np.sqrt(altref_diff2.sum(axis=0)) + sad_d2 = np.sqrt(altref_diff2.sum(axis=1)) sad_d2 = np.clip(sad_d2, np.finfo(np.float16).min, np.finfo(np.float16).max) + sad_d2 = sad_d2.mean(axis=0) scores_out["D2"][si] = sad_d2.astype("float16") if "logD2" in snp_stats: altref_log_diff2 = np.power(altref_log_diff, 2) - log_d2 = np.sqrt(altref_log_diff2.sum(axis=0)) + log_d2 = np.sqrt(altref_log_diff2.sum(axis=1)) log_d2 = np.clip(log_d2, np.finfo(np.float16).min, np.finfo(np.float16).max) + log_d2 = log_d2.mean(axis=0) scores_out["logD2"][si] = log_d2.astype("float16") if "sqrtD2" in snp_stats: altref_sqrt_diff2 = np.power(altref_sqrt_diff, 2) - sqrt_d2 = np.sqrt(altref_sqrt_diff2.sum(axis=0)) + sqrt_d2 = np.sqrt(altref_sqrt_diff2.sum(axis=1)) sqrt_d2 = np.clip(sqrt_d2, np.finfo(np.float16).min, np.finfo(np.float16).max) + sqrt_d2 = sqrt_d2.mean(axis=0) scores_out["sqrtD2"][si] = sqrt_d2.astype("float16") if "JS" in snp_stats: # normalized scores - pseudocounts = np.percentile(ref_preds, 25, axis=0) + pseudocounts = np.percentile(ref_preds, 25, axis=1) ref_preds_norm = ref_preds + pseudocounts - ref_preds_norm /= ref_preds_norm.sum(axis=0) + ref_preds_norm /= ref_preds_norm.sum(axis=1) alt_preds_norm = alt_preds + pseudocounts - alt_preds_norm /= alt_preds_norm.sum(axis=0) + alt_preds_norm /= alt_preds_norm.sum(axis=1) # compare normalized JS - ref_alt_entr = rel_entr(ref_preds_norm, alt_preds_norm).sum(axis=0) - alt_ref_entr = rel_entr(alt_preds_norm, ref_preds_norm).sum(axis=0) - js_dist = (ref_alt_entr + alt_ref_entr) / 2 + js_dist = [] + for s in range(num_shifts): + ref_alt_entr = rel_entr(ref_preds_norm[s], alt_preds_norm[s]).sum(axis=0) + alt_ref_entr = rel_entr(alt_preds_norm[s], ref_preds_norm[s]).sum(axis=0) + js_dist.append((ref_alt_entr + alt_ref_entr) / 2) + js_dist = np.mean(js_dist, axis=0) scores_out["JS"][si] = js_dist.astype("float16") if "logJS" in snp_stats: # normalized scores @@ -518,9 +563,12 @@ def write_snp_len(ref_preds, alt_preds, scores_out, si, snp_stats): alt_preds_log_norm /= alt_preds_log_norm.sum(axis=0) # compare normalized JS - ref_alt_entr = rel_entr(ref_preds_log_norm, alt_preds_log_norm).sum(axis=0) - alt_ref_entr = rel_entr(alt_preds_log_norm, ref_preds_log_norm).sum(axis=0) - log_js_dist = (ref_alt_entr + alt_ref_entr) / 2 + log_js_dist = [] + for s in range(num_shifts): + ref_alt_entr = rel_entr(ref_preds_log_norm[s], alt_preds_log_norm[s]).sum(axis=0) + alt_ref_entr = rel_entr(alt_preds_log_norm[s], ref_preds_log_norm[s]).sum(axis=0) + log_js_dist.append((ref_alt_entr + alt_ref_entr) / 2) + log_js_dist = np.mean(log_js_dist, axis=0) scores_out["logJS"][si] = log_js_dist.astype("float16") # predictions diff --git a/src/baskerville/vcf.py b/src/baskerville/vcf.py index 520be53..7cca40f 100644 --- a/src/baskerville/vcf.py +++ b/src/baskerville/vcf.py @@ -697,6 +697,10 @@ def get_alleles(self): """Return a list of all alleles""" alleles = [self.ref_allele] + self.alt_alleles return alleles + + def indel_size(self): + """Return the size of the indel.""" + return len(self.alt_allele) - len(self.ref_allele) def longest_alt(self): """Return the longest alt allele."""