From 5871257bad594fabedeab972088b45ca648fb0de Mon Sep 17 00:00:00 2001 From: David Kelley Date: Wed, 21 Jun 2023 16:14:58 -0700 Subject: [PATCH] untransform --- bin/borzoi_sed.py | 110 +++++++++++++++++++++++++--------------------- 1 file changed, 59 insertions(+), 51 deletions(-) diff --git a/bin/borzoi_sed.py b/bin/borzoi_sed.py index 964d4174..e50d294d 100755 --- a/bin/borzoi_sed.py +++ b/bin/borzoi_sed.py @@ -23,6 +23,7 @@ import pdb import sys import time +from tqdm import tqdm import h5py import numpy as np @@ -32,12 +33,11 @@ from scipy.special import rel_entr import tensorflow as tf -from basenji import dna_io from basenji import gene as bgene from basenji import seqnn from basenji import stream from basenji import vcf as bvcf - +from basenji_sad import untransform_preds, untransform_preds1 ''' borzoi_sed.py @@ -81,6 +81,8 @@ def main(): parser.add_option('-t', dest='targets_file', default=None, type='str', help='File specifying target indexes and labels in table format') + parser.add_option('-u', dest='untransform_old', + default=False, action='store_true') (options, args) = parser.parse_args() if len(args) == 3: @@ -194,17 +196,6 @@ def main(): snpseq_gene_slice = [snpseq_gene_slice[si] for si in range(num_snps_pre) if snp_gene_mask[si]] num_snps = len(snps) - # create SNP seq generator - genome_open = pysam.Fastafile(options.genome_fasta) - - def snp_gen(): - for snp in snps: - # get SNP sequences - snp_1hot_list = bvcf.snp_seq1(snp, seq_len, genome_open) - for snp_1hot in snp_1hot_list: - yield snp_1hot - - ################################################################# # setup output @@ -214,22 +205,34 @@ def snp_gen(): ################################################################# # predict SNP scores, write output - # initialize predictions stream - preds_stream = stream.PredStreamGen(seqnn_model, snp_gen(), params_train['batch_size']) - - # predictions index - pi = 0 + # create SNP seq generator + genome_open = pysam.Fastafile(options.genome_fasta) # SNP/gene index xi = 0 # for each SNP sequence - for si in range(num_snps): + for si, snp in tqdm(enumerate(snps), total=len(snps)): + # get SNP sequences + snp_1hot_list = bvcf.snp_seq1(snp, seq_len, genome_open) + snps_1hot = np.array(snp_1hot_list) + # get predictions - ref_preds = preds_stream[pi] - pi += 1 - alt_preds = preds_stream[pi] - pi += 1 + if params_train['batch_size'] == 1: + ref_preds = seqnn_model(snps_1hot[:1])[0] + alt_preds = seqnn_model(snps_1hot[1:])[0] + else: + snp_preds = seqnn_model(snps_1hot) + ref_preds, alt_preds = snp_preds[0], snp_preds[1] + + # untransform predictions + if options.targets_file is not None: + if options.untransform_old: + ref_preds = untransform_preds1(ref_preds, targets_df) + alt_preds = untransform_preds1(alt_preds, targets_df) + else: + ref_preds = untransform_preds(ref_preds, targets_df) + alt_preds = untransform_preds(alt_preds, targets_df) if options.bedgraph: write_bedgraph_snp(snps[si], ref_preds, alt_preds, options.out_dir, model_stride) @@ -271,6 +274,10 @@ def snp_gen(): sed_out.close() +def clip_float(x, dtype=np.float16): + return np.clip(x, np.finfo(dtype).min, np.finfo(dtype).max) + + def initialize_output_h5(out_dir: str, sed_stats, snps, snpseq_gene_slice, targets_df): """Initialize an output HDF5 file for SAD stats. @@ -543,8 +550,8 @@ def write_snp(ref_preds, alt_preds, sed_out, xi: int, sed_stats, pseudocounts): been maintained. Args: - ref_preds (np.ndarray): Reference predictions. - alt_preds (np.ndarray): Alternate predictions. + ref_preds (np.ndarray): Reference predictions, (gene length x tasks) + alt_preds (np.ndarray): Alternate predictions, (gene length x tasks) sed_out (h5py.File): HDF5 output file. xi (int): SNP index. sed_stats (list): SED statistics to compute. @@ -552,31 +559,32 @@ def write_snp(ref_preds, alt_preds, sed_out, xi: int, sed_stats, pseudocounts): """ # ref/alt_preds is L x T - ref_preds = ref_preds.astype('float64') - alt_preds = alt_preds.astype('float64') seq_len, num_targets = ref_preds.shape # sum across bins ref_preds_sum = ref_preds.sum(axis=0) alt_preds_sum = alt_preds.sum(axis=0) - # compare reference to alternative via mean subtraction + # difference of sums if 'SED' in sed_stats: - sad = alt_preds_sum - ref_preds_sum - sed_out['SED'][xi] = sad.astype('float16') - - # compare reference to alternative via mean log division - if 'SEDR' in sed_stats: - sar = np.log2(alt_preds_sum + pseudocounts) \ - - np.log2(ref_preds_sum + pseudocounts) - sed_out['SEDR'][xi] = sar.astype('float16') - - # compare geometric means - if 'SER' in sed_stats: - sar_vec = np.log2(alt_preds + pseudocounts) \ - - np.log2(ref_preds + pseudocounts) - geo_sad = sar_vec.sum(axis=0) - sed_out['SER'][xi] = geo_sad.astype('float16') + sed = alt_preds_sum - ref_preds_sum + sed_out['SED'][xi] = clip_float(sed).astype('float16') + if 'logSED' in sed_stats: + log_sed = np.log2(alt_preds_sum + 1) \ + - np.log2(ref_preds_sum + 1) + sed_out['logSED'][xi] = log_sed.astype('float16') + + # difference L1 norm + if 'D1' in sed_stats: + diff_abs = np.abs(ref_preds - alt_preds) + diff_norm1 = diff_abs.sum(axis=0) + sed_out['D1'][xi] = clip_float(diff_norm1).astype('float16') + + # difference L2 norm + if 'D2' in sed_stats: + diff2 = np.power(ref_preds - alt_preds, 2) + diff_norm2 = np.sqrt(diff2.sum(axis=0)) + sed_out['D2'][xi] = clip_float(diff_norm2).astype('float16') # normalized scores ref_preds_norm = ref_preds + pseudocounts @@ -585,16 +593,16 @@ def write_snp(ref_preds, alt_preds, sed_out, xi: int, sed_stats, pseudocounts): alt_preds_norm /= alt_preds_norm.sum(axis=0) # compare normalized squared difference - if 'D2' in sed_stats: - diff_norm2 = np.power(ref_preds_norm - alt_preds_norm, 2) - diff_norm2 = diff_norm2.sum(axis=0) - sed_out['D2'][xi] = diff_norm2.astype('float16') + if 'nD2' in sed_stats: + ndiff2 = np.power(ref_preds_norm - alt_preds_norm, 2) + ndiff_norm2 = np.sqrt(ndiff2.sum(axis=0)) + sed_out['nD2'][xi] = ndiff_norm2.astype('float16') # compare normalized abs max - if 'D0' in sed_stats: - diff_norm0 = np.abs(ref_preds_norm - alt_preds_norm) - diff_norm0 = diff_norm0.max(axis=0) - sed_out['D0'][xi] = diff_norm0.astype('float16') + if 'nDi' in sed_stats: + ndiff_abs = np.abs(ref_preds_norm - alt_preds_norm) + ndiff_normi = ndiff_abs.max(axis=0) + sed_out['nDi'][xi] = ndiff_normi.astype('float16') # compare normalized JS if 'JS' in sed_stats: