Skip to content

Commit

Permalink
untransform
Browse files Browse the repository at this point in the history
  • Loading branch information
davek44 committed Jun 21, 2023
1 parent 19edc2c commit 5871257
Showing 1 changed file with 59 additions and 51 deletions.
110 changes: 59 additions & 51 deletions bin/borzoi_sed.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
import pdb
import sys
import time
from tqdm import tqdm

import h5py
import numpy as np
Expand All @@ -32,12 +33,11 @@
from scipy.special import rel_entr
import tensorflow as tf

from basenji import dna_io
from basenji import gene as bgene
from basenji import seqnn
from basenji import stream
from basenji import vcf as bvcf

from basenji_sad import untransform_preds, untransform_preds1
'''
borzoi_sed.py
Expand Down Expand Up @@ -81,6 +81,8 @@ def main():
parser.add_option('-t', dest='targets_file',
default=None, type='str',
help='File specifying target indexes and labels in table format')
parser.add_option('-u', dest='untransform_old',
default=False, action='store_true')
(options, args) = parser.parse_args()

if len(args) == 3:
Expand Down Expand Up @@ -194,17 +196,6 @@ def main():
snpseq_gene_slice = [snpseq_gene_slice[si] for si in range(num_snps_pre) if snp_gene_mask[si]]
num_snps = len(snps)

# create SNP seq generator
genome_open = pysam.Fastafile(options.genome_fasta)

def snp_gen():
for snp in snps:
# get SNP sequences
snp_1hot_list = bvcf.snp_seq1(snp, seq_len, genome_open)
for snp_1hot in snp_1hot_list:
yield snp_1hot


#################################################################
# setup output

Expand All @@ -214,22 +205,34 @@ def snp_gen():
#################################################################
# predict SNP scores, write output

# initialize predictions stream
preds_stream = stream.PredStreamGen(seqnn_model, snp_gen(), params_train['batch_size'])

# predictions index
pi = 0
# create SNP seq generator
genome_open = pysam.Fastafile(options.genome_fasta)

# SNP/gene index
xi = 0

# for each SNP sequence
for si in range(num_snps):
for si, snp in tqdm(enumerate(snps), total=len(snps)):
# get SNP sequences
snp_1hot_list = bvcf.snp_seq1(snp, seq_len, genome_open)
snps_1hot = np.array(snp_1hot_list)

# get predictions
ref_preds = preds_stream[pi]
pi += 1
alt_preds = preds_stream[pi]
pi += 1
if params_train['batch_size'] == 1:
ref_preds = seqnn_model(snps_1hot[:1])[0]
alt_preds = seqnn_model(snps_1hot[1:])[0]
else:
snp_preds = seqnn_model(snps_1hot)
ref_preds, alt_preds = snp_preds[0], snp_preds[1]

# untransform predictions
if options.targets_file is not None:
if options.untransform_old:
ref_preds = untransform_preds1(ref_preds, targets_df)
alt_preds = untransform_preds1(alt_preds, targets_df)
else:
ref_preds = untransform_preds(ref_preds, targets_df)
alt_preds = untransform_preds(alt_preds, targets_df)

if options.bedgraph:
write_bedgraph_snp(snps[si], ref_preds, alt_preds, options.out_dir, model_stride)
Expand Down Expand Up @@ -271,6 +274,10 @@ def snp_gen():
sed_out.close()


def clip_float(x, dtype=np.float16):
return np.clip(x, np.finfo(dtype).min, np.finfo(dtype).max)


def initialize_output_h5(out_dir: str, sed_stats, snps, snpseq_gene_slice, targets_df):
"""Initialize an output HDF5 file for SAD stats.
Expand Down Expand Up @@ -543,40 +550,41 @@ def write_snp(ref_preds, alt_preds, sed_out, xi: int, sed_stats, pseudocounts):
been maintained.
Args:
ref_preds (np.ndarray): Reference predictions.
alt_preds (np.ndarray): Alternate predictions.
ref_preds (np.ndarray): Reference predictions, (gene length x tasks)
alt_preds (np.ndarray): Alternate predictions, (gene length x tasks)
sed_out (h5py.File): HDF5 output file.
xi (int): SNP index.
sed_stats (list): SED statistics to compute.
pseudocounts (np.ndarray): Target pseudocounts for safe logs.
"""

# ref/alt_preds is L x T
ref_preds = ref_preds.astype('float64')
alt_preds = alt_preds.astype('float64')
seq_len, num_targets = ref_preds.shape

# sum across bins
ref_preds_sum = ref_preds.sum(axis=0)
alt_preds_sum = alt_preds.sum(axis=0)

# compare reference to alternative via mean subtraction
# difference of sums
if 'SED' in sed_stats:
sad = alt_preds_sum - ref_preds_sum
sed_out['SED'][xi] = sad.astype('float16')

# compare reference to alternative via mean log division
if 'SEDR' in sed_stats:
sar = np.log2(alt_preds_sum + pseudocounts) \
- np.log2(ref_preds_sum + pseudocounts)
sed_out['SEDR'][xi] = sar.astype('float16')

# compare geometric means
if 'SER' in sed_stats:
sar_vec = np.log2(alt_preds + pseudocounts) \
- np.log2(ref_preds + pseudocounts)
geo_sad = sar_vec.sum(axis=0)
sed_out['SER'][xi] = geo_sad.astype('float16')
sed = alt_preds_sum - ref_preds_sum
sed_out['SED'][xi] = clip_float(sed).astype('float16')
if 'logSED' in sed_stats:
log_sed = np.log2(alt_preds_sum + 1) \
- np.log2(ref_preds_sum + 1)
sed_out['logSED'][xi] = log_sed.astype('float16')

# difference L1 norm
if 'D1' in sed_stats:
diff_abs = np.abs(ref_preds - alt_preds)
diff_norm1 = diff_abs.sum(axis=0)
sed_out['D1'][xi] = clip_float(diff_norm1).astype('float16')

# difference L2 norm
if 'D2' in sed_stats:
diff2 = np.power(ref_preds - alt_preds, 2)
diff_norm2 = np.sqrt(diff2.sum(axis=0))
sed_out['D2'][xi] = clip_float(diff_norm2).astype('float16')

# normalized scores
ref_preds_norm = ref_preds + pseudocounts
Expand All @@ -585,16 +593,16 @@ def write_snp(ref_preds, alt_preds, sed_out, xi: int, sed_stats, pseudocounts):
alt_preds_norm /= alt_preds_norm.sum(axis=0)

# compare normalized squared difference
if 'D2' in sed_stats:
diff_norm2 = np.power(ref_preds_norm - alt_preds_norm, 2)
diff_norm2 = diff_norm2.sum(axis=0)
sed_out['D2'][xi] = diff_norm2.astype('float16')
if 'nD2' in sed_stats:
ndiff2 = np.power(ref_preds_norm - alt_preds_norm, 2)
ndiff_norm2 = np.sqrt(ndiff2.sum(axis=0))
sed_out['nD2'][xi] = ndiff_norm2.astype('float16')

# compare normalized abs max
if 'D0' in sed_stats:
diff_norm0 = np.abs(ref_preds_norm - alt_preds_norm)
diff_norm0 = diff_norm0.max(axis=0)
sed_out['D0'][xi] = diff_norm0.astype('float16')
if 'nDi' in sed_stats:
ndiff_abs = np.abs(ref_preds_norm - alt_preds_norm)
ndiff_normi = ndiff_abs.max(axis=0)
sed_out['nDi'][xi] = ndiff_normi.astype('float16')

# compare normalized JS
if 'JS' in sed_stats:
Expand Down

0 comments on commit 5871257

Please sign in to comment.