From 91f0d3f2215f8c6c04f4b883fef87775b6eaf143 Mon Sep 17 00:00:00 2001 From: Johannes Linder Date: Fri, 21 Jul 2023 16:41:21 -0700 Subject: [PATCH] New GPU-friendly gradient function in seqnn.py; Initial commit of various Borzoi benchmarking- and large-scale attribution scripts. --- basenji/seqnn.py | 264 +++++++- bin/borzoi_bench_ipaqtl_folds.py | 479 ++++++++++++++ bin/borzoi_bench_paqtl_folds.py | 479 ++++++++++++++ bin/borzoi_bench_trip_folds.py | 202 ++++++ bin/borzoi_satg_gene_gpu.py | 688 +++++++++++++++++++++ bin/borzoi_satg_gene_gpu_focused_ism.py | 673 ++++++++++++++++++++ bin/borzoi_satg_polya_gpu.py | 487 +++++++++++++++ bin/borzoi_satg_splice_gpu.py | 499 +++++++++++++++ bin/borzoi_sed_ipaqtl_cov.py | 773 +++++++++++++++++++++++ bin/borzoi_sed_paqtl_cov.py | 789 ++++++++++++++++++++++++ bin/borzoi_test_apa_folds_polaydb.py | 181 ++++++ bin/borzoi_test_apa_polaydb.py | 301 +++++++++ bin/borzoi_trip.py | 275 +++++++++ 13 files changed, 6086 insertions(+), 4 deletions(-) create mode 100644 bin/borzoi_bench_ipaqtl_folds.py create mode 100644 bin/borzoi_bench_paqtl_folds.py create mode 100644 bin/borzoi_bench_trip_folds.py create mode 100644 bin/borzoi_satg_gene_gpu.py create mode 100644 bin/borzoi_satg_gene_gpu_focused_ism.py create mode 100644 bin/borzoi_satg_polya_gpu.py create mode 100644 bin/borzoi_satg_splice_gpu.py create mode 100644 bin/borzoi_sed_ipaqtl_cov.py create mode 100644 bin/borzoi_sed_paqtl_cov.py create mode 100644 bin/borzoi_test_apa_folds_polaydb.py create mode 100644 bin/borzoi_test_apa_polaydb.py create mode 100644 bin/borzoi_trip.py diff --git a/basenji/seqnn.py b/basenji/seqnn.py index 608f1edf..f9bb8ec0 100644 --- a/basenji/seqnn.py +++ b/basenji/seqnn.py @@ -381,9 +381,265 @@ def get_conv_weights(self, conv_layer_i=0): weights = np.transpose(weights, [2,1,0]) return weights + def gradients(self, seq_1hot, head_i=None, target_slice=None, pos_slice=None, pos_mask=None, pos_slice_denom=None, pos_mask_denom=None, chunk_size=None, batch_size=1, track_scale=1., track_transform=1., clip_soft=None, pseudo_count=0., no_transform=False, use_mean=False, use_ratio=False, use_logodds=False, subtract_avg=True, input_gate=True, smooth_grad=False, n_samples=5, sample_prob=0.875, dtype='float16'): + """ Compute input gradients for sequences (GPU-friendly). """ + + # start time + t0 = time.time() + + # choose model + if self.ensemble is not None: + model = self.ensemble + elif head_i is not None: + model = self.models[head_i] + else: + model = self.model + + # verify tensor shape(s) + seq_1hot = seq_1hot.astype('float32') + target_slice = np.array(target_slice).astype('int32') + pos_slice = np.array(pos_slice).astype('int32') + + # convert constants to tf tensors + track_scale = tf.constant(track_scale, dtype=tf.float32) + track_transform = tf.constant(track_transform, dtype=tf.float32) + if clip_soft is not None : + clip_soft = tf.constant(clip_soft, dtype=tf.float32) + pseudo_count = tf.constant(pseudo_count, dtype=tf.float32) + + if pos_mask is not None : + pos_mask = np.array(pos_mask).astype('float32') + + if use_ratio and pos_slice_denom is not None : + pos_slice_denom = np.array(pos_slice_denom).astype('int32') + + if pos_mask_denom is not None : + pos_mask_denom = np.array(pos_mask_denom).astype('float32') + + if len(seq_1hot.shape) < 3: + seq_1hot = seq_1hot[None, ...] + + if len(target_slice.shape) < 2: + target_slice = target_slice[None, ...] + + if len(pos_slice.shape) < 2: + pos_slice = pos_slice[None, ...] 
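+    # after the expansions above, seq_1hot is (num_seqs, seq_len, 4) and
+    # target_slice / pos_slice are (num_seqs, num_targets) / (num_seqs, num_positions);
+    # the optional masks below are broadcast to a batch dimension the same way.
+    # illustrative call (hypothetical track indexes and transform settings;
+    # real values come from the model's targets table):
+    #   grads = seqnn_model.gradients(seq_1hot, head_i=0,
+    #                                 target_slice=np.array([[0, 1]]),
+    #                                 pos_slice=np.array([[3056, 3057]]),
+    #                                 track_scale=0.01, track_transform=0.75,
+    #                                 clip_soft=384., smooth_grad=True)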
+ + if pos_mask is not None and len(pos_mask.shape) < 2: + pos_mask = pos_mask[None, ...] + + if use_ratio and pos_slice_denom is not None and len(pos_slice_denom.shape) < 2: + pos_slice_denom = pos_slice_denom[None, ...] + + if pos_mask_denom is not None and len(pos_mask_denom.shape) < 2: + pos_mask_denom = pos_mask_denom[None, ...] + + # chunk parameters + num_chunks = 1 + if chunk_size is None : + chunk_size = seq_1hot.shape[0] + else : + num_chunks = int(np.ceil(seq_1hot.shape[0] / chunk_size)) + + # loop over chunks + grad_chunks = [] + for ci in range(num_chunks) : + + # collect chunk + seq_1hot_chunk = seq_1hot[ci * chunk_size:(ci+1) * chunk_size, ...] + target_slice_chunk = target_slice[ci * chunk_size:(ci+1) * chunk_size, ...] + pos_slice_chunk = pos_slice[ci * chunk_size:(ci+1) * chunk_size, ...] + + pos_mask_chunk = None + if pos_mask is not None : + pos_mask_chunk = pos_mask[ci * chunk_size:(ci+1) * chunk_size, ...] + + pos_slice_denom_chunk = None + pos_mask_denom_chunk = None + if use_ratio and pos_slice_denom is not None : + pos_slice_denom_chunk = pos_slice_denom[ci * chunk_size:(ci+1) * chunk_size, ...] + + if pos_mask_denom is not None : + pos_mask_denom_chunk = pos_mask_denom[ci * chunk_size:(ci+1) * chunk_size, ...] + + actual_chunk_size = seq_1hot_chunk.shape[0] + + # sample noisy (discrete) perturbations of the input pattern chunk + if smooth_grad : + seq_1hot_chunk_corrupted = np.repeat(np.copy(seq_1hot_chunk), n_samples, axis=0) + + for example_ix in range(seq_1hot_chunk.shape[0]) : + for sample_ix in range(n_samples) : + corrupt_index = np.nonzero(np.random.rand(seq_1hot_chunk.shape[1]) >= sample_prob)[0] + + rand_nt_index = np.random.choice([0, 1, 2, 3], size=(corrupt_index.shape[0],)) + + seq_1hot_chunk_corrupted[example_ix * n_samples + sample_ix, corrupt_index, :] = 0. + seq_1hot_chunk_corrupted[example_ix * n_samples + sample_ix, corrupt_index, rand_nt_index] = 1. + + seq_1hot_chunk = seq_1hot_chunk_corrupted + target_slice_chunk = np.repeat(np.copy(target_slice_chunk), n_samples, axis=0) + pos_slice_chunk = np.repeat(np.copy(pos_slice_chunk), n_samples, axis=0) + + if pos_mask is not None : + pos_mask_chunk = np.repeat(np.copy(pos_mask_chunk), n_samples, axis=0) + + if use_ratio and pos_slice_denom is not None : + pos_slice_denom_chunk = np.repeat(np.copy(pos_slice_denom_chunk), n_samples, axis=0) + + if pos_mask_denom is not None : + pos_mask_denom_chunk = np.repeat(np.copy(pos_mask_denom_chunk), n_samples, axis=0) + + # convert to tf tensors + seq_1hot_chunk = tf.convert_to_tensor(seq_1hot_chunk, dtype=tf.float32) + target_slice_chunk = tf.convert_to_tensor(target_slice_chunk, dtype=tf.int32) + pos_slice_chunk = tf.convert_to_tensor(pos_slice_chunk, dtype=tf.int32) + + if pos_mask is not None : + pos_mask_chunk = tf.convert_to_tensor(pos_mask_chunk, dtype=tf.float32) + + if use_ratio and pos_slice_denom is not None : + pos_slice_denom_chunk = tf.convert_to_tensor(pos_slice_denom_chunk, dtype=tf.int32) + + if pos_mask_denom is not None : + pos_mask_denom_chunk = tf.convert_to_tensor(pos_mask_denom_chunk, dtype=tf.float32) + + # batching parameters + num_batches = int(np.ceil(actual_chunk_size * (n_samples if smooth_grad else 1) / batch_size)) + + # loop over batches + grad_batches = [] + for bi in range(num_batches) : + + # collect batch + seq_1hot_batch = seq_1hot_chunk[bi * batch_size:(bi+1) * batch_size, ...] + target_slice_batch = target_slice_chunk[bi * batch_size:(bi+1) * batch_size, ...] 
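+        # note: gradients_func below is wrapped in @tf.function, so keeping the
+        # batch shapes constant across calls avoids retracing (only the final,
+        # possibly smaller, batch triggers a second trace)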
+ pos_slice_batch = pos_slice_chunk[bi * batch_size:(bi+1) * batch_size, ...] + + pos_mask_batch = None + if pos_mask is not None : + pos_mask_batch = pos_mask_chunk[bi * batch_size:(bi+1) * batch_size, ...] + + pos_slice_denom_batch = None + pos_mask_denom_batch = None + if use_ratio and pos_slice_denom is not None : + pos_slice_denom_batch = pos_slice_denom_chunk[bi * batch_size:(bi+1) * batch_size, ...] + + if pos_mask_denom is not None : + pos_mask_denom_batch = pos_mask_denom_chunk[bi * batch_size:(bi+1) * batch_size, ...] + + grad_batch = self.gradients_func(model, seq_1hot_batch, target_slice_batch, pos_slice_batch, pos_mask_batch, pos_slice_denom_batch, pos_mask_denom_batch, track_scale, track_transform, clip_soft, pseudo_count, no_transform, use_mean, use_ratio, use_logodds, subtract_avg, input_gate).numpy().astype(dtype) + + grad_batches.append(grad_batch) + + # concat gradient batches + grads = np.concatenate(grad_batches, axis=0) + + # aggregate noisy gradient perturbations + if smooth_grad : + grads_smoothed = np.zeros((grads.shape[0] // n_samples, grads.shape[1], grads.shape[2]), dtype='float32') + + for example_ix in range(grads_smoothed.shape[0]) : + for sample_ix in range(n_samples) : + grads_smoothed[example_ix, ...] += grads[example_ix * n_samples + sample_ix, ...] + + grads = grads_smoothed / float(n_samples) + grads = grads.astype(dtype) + + grad_chunks.append(grads) + + # collect garbage + gc.collect() + + # concat gradient chunks + grads = np.concatenate(grad_chunks, axis=0) + + # aggregate and broadcast to original input pattern + if input_gate : + grads = np.sum(grads, axis=-1, keepdims=True) * seq_1hot + + print('Completed gradient computation in %ds' % (time.time()-t0)) + + return grads + + @tf.function + def gradients_func(self, model, seq_1hot, target_slice, pos_slice, pos_mask=None, pos_slice_denom=None, pos_mask_denom=True, track_scale=1., track_transform=1., clip_soft=None, pseudo_count=0., no_transform=False, use_mean=False, use_ratio=False, use_logodds=False, subtract_avg=True, input_gate=True): + + with tf.GradientTape() as tape: + tape.watch(seq_1hot) + + # predict + preds = tf.gather(model(seq_1hot, training=False), target_slice, axis=-1, batch_dims=1) + + if not no_transform : + + # undo scale + preds = preds / track_scale + + # undo soft_clip + if clip_soft is not None : + preds = tf.where(preds > clip_soft, (preds - clip_soft)**2 + clip_soft, preds) + + # undo sqrt + preds = preds**(1. 
/ track_transform) + + # aggregate over tracks (average) + preds = tf.reduce_mean(preds, axis=-1) + + # slice specified positions + preds_slice = tf.gather(preds, pos_slice, axis=-1, batch_dims=1) + if pos_mask is not None : + preds_slice = preds_slice * pos_mask + + # slice denominator positions + if use_ratio and pos_slice_denom is not None: + preds_slice_denom = tf.gather(preds, pos_slice_denom, axis=-1, batch_dims=1) + if pos_mask_denom is not None : + preds_slice_denom = preds_slice_denom * pos_mask_denom + + # aggregate over positions + if not use_mean : + preds_agg = tf.reduce_sum(preds_slice, axis=-1) + if use_ratio and pos_slice_denom is not None: + preds_agg_denom = tf.reduce_sum(preds_slice_denom, axis=-1) + else : + if pos_mask is not None : + preds_agg = tf.reduce_sum(preds_slice, axis=-1) / tf.reduce_sum(pos_mask, axis=-1) + else : + preds_agg = tf.reduce_mean(preds_slice, axis=-1) + + if use_ratio and pos_slice_denom is not None: + if pos_mask_denom is not None : + preds_agg_denom = tf.reduce_sum(preds_slice_denom, axis=-1) / tf.reduce_sum(pos_mask_denom, axis=-1) + else : + preds_agg_denom = tf.reduce_mean(preds_slice_denom, axis=-1) + + # compute final statistic to take gradient of + if no_transform : + score_ratios = preds_agg + elif not use_ratio : + score_ratios = tf.math.log(preds_agg + pseudo_count + 1e-6) + else : + if not use_logodds : + score_ratios = tf.math.log((preds_agg + pseudo_count) / (preds_agg_denom + pseudo_count) + 1e-6) + else : + score_ratios = tf.math.log(((preds_agg + pseudo_count) / (preds_agg_denom + pseudo_count)) / (1. - ((preds_agg + pseudo_count) / (preds_agg_denom + pseudo_count))) + 1e-6) + + # compute gradient + grads = tape.gradient(score_ratios, seq_1hot) + + # zero mean each position + if subtract_avg : + grads = grads - tf.reduce_mean(grads, axis=-1, keepdims=True) + + # multiply by input + if input_gate : + grads = grads * seq_1hot + + return grads - def gradients(self, seq_1hot, head_i=None, pos_slice=None, batch_size=8, dtype='float16'): - """ Compute input gradients sequence. """ + def gradients_orig(self, seq_1hot, head_i=None, pos_slice=None, batch_size=2, dtype='float16'): + """ Compute input gradients sequence (original version of code). """ # choose model if self.ensemble is not None: model = self.ensemble @@ -443,7 +699,7 @@ def gradients(self, seq_1hot, head_i=None, pos_slice=None, batch_size=8, dtype=' # grads_batch = grads_batch - tf.reduce_mean(grads_batch, axis=-2, keepdims=True) - grads_batch = self.gradients_func(model_batch, seq_1hot, pos_slice) + grads_batch = self.gradients_func_orig(model_batch, seq_1hot, pos_slice) print('Batch gradient computation in %ds' % (time.time()-t0)) # convert numpy dtype @@ -459,7 +715,7 @@ def gradients(self, seq_1hot, head_i=None, pos_slice=None, batch_size=8, dtype=' return grads @tf.function - def gradients_func(self, model, seq_1hot, pos_slice): + def gradients_func_orig(self, model, seq_1hot, pos_slice): with tf.GradientTape() as tape: tape.watch(seq_1hot) diff --git a/bin/borzoi_bench_ipaqtl_folds.py b/bin/borzoi_bench_ipaqtl_folds.py new file mode 100644 index 00000000..c540c2ff --- /dev/null +++ b/bin/borzoi_bench_ipaqtl_folds.py @@ -0,0 +1,479 @@ +#!/usr/bin/env python +# Copyright 2019 Calico LLC + +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at + +# https://www.apache.org/licenses/LICENSE-2.0 + +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ========================================================================= +from optparse import OptionParser, OptionGroup +import glob +import h5py +import json +import pdb +import os +import shutil +import sys + +import numpy as np +import pandas as pd + +import slurm +#import util + +from basenji_test_folds import stat_tests + +""" +borzoi_bench_ipaqtl_folds.py + +Benchmark Basenji model replicates on GTEx ipaQTL classification task. +""" + +################################################################################ +# main +################################################################################ +def main(): + usage = 'usage: %prog [options] ' + parser = OptionParser(usage) + + # sed + sed_options = OptionGroup(parser, 'borzoi_sed_ipaqtl_cov.py options') + sed_options.add_option('-f', dest='genome_fasta', + default='%s/data/hg38.fa' % os.environ['BASENJIDIR'], + help='Genome FASTA for sequences [Default: %default]') + sed_options.add_option('-g', dest='genes_gtf', + default='%s/genes/gencode41/gencode41_basic_nort.gtf' % os.environ['HG38'], + help='GTF for gene definition [Default %default]') + sed_options.add_option('--apafile', dest='apa_file', + default='polyadb_human_v3.csv.gz') + sed_options.add_option('-o',dest='out_dir', + default='ipaqtl', + help='Output directory for tables and plots [Default: %default]') + sed_options.add_option('-p', dest='processes', + default=None, type='int', + help='Number of processes, passed by multi script') + sed_options.add_option('--pseudo', dest='cov_pseudo', + default=50, type='float', + help='Coverage pseudocount [Default: %default]') + sed_options.add_option('--cov', dest='cov_min', + default=100, type='float', + help='Coverage pseudocount [Default: %default]') + sed_options.add_option('--rc', dest='rc', + default=False, action='store_true', + help='Average forward and reverse complement predictions [Default: %default]') + sed_options.add_option('--shifts', dest='shifts', + default='0', type='str', + help='Ensemble prediction shifts [Default: %default]') + sed_options.add_option('--stats', dest='sed_stats', + default='REF,ALT', + help='Comma-separated list of stats to save. 
[Default: %default]') + sed_options.add_option('-t', dest='targets_file', + default=None, type='str', + help='File specifying target indexes and labels in table format') + parser.add_option_group(sed_options) + + # classify + class_options = OptionGroup(parser, 'basenji_bench_classify.py options') + class_options.add_option('--msl', dest='msl', + default=1, type='int', + help='Random forest min_samples_leaf [Default: %default]') + parser.add_option_group(class_options) + + # cross-fold + fold_options = OptionGroup(parser, 'cross-fold options') + fold_options.add_option('-c', dest='crosses', + default=1, type='int', + help='Number of cross-fold rounds [Default:%default]') + fold_options.add_option('-d', dest='data_head', + default=None, type='int', + help='Index for dataset/head [Default: %default]') + fold_options.add_option('-e', dest='conda_env', + default='tf210', + help='Anaconda environment [Default: %default]') + fold_options.add_option('--name', dest='name', + default='ipaqtl', help='SLURM name prefix [Default: %default]') + fold_options.add_option('--max_proc', dest='max_proc', + default=None, type='int', + help='Maximum concurrent processes [Default: %default]') + fold_options.add_option('-q', dest='queue', + default='geforce', + help='SLURM queue on which to run the jobs [Default: %default]') + fold_options.add_option('-r', dest='restart', + default=False, action='store_true', + help='Restart a partially completed job [Default: %default]') + fold_options.add_option('--vcf', dest='vcf_dir', + default='/home/jlinder/seqnn/data/qtl_cat/ipaqtl_pip90ea') + parser.add_option_group(fold_options) + + (options, args) = parser.parse_args() + + if len(args) != 2: + print(len(args)) + print(args) + parser.error('Must provide parameters file and cross-fold directory') + else: + params_file = args[0] + exp_dir = args[1] + + ####################################################### + # prep work + + # count folds + num_folds = 0 + fold0_dir = '%s/f%dc0' % (exp_dir, num_folds) + model_file = '%s/train/model_best.h5' % fold0_dir + if options.data_head is not None: + model_file = '%s/train/model%d_best.h5' % (fold0_dir, options.data_head) + while os.path.isfile(model_file): + num_folds += 1 + fold0_dir = '%s/f%dc0' % (exp_dir, num_folds) + model_file = '%s/train/model_best.h5' % fold0_dir + if options.data_head is not None: + model_file = '%s/train/model%d_best.h5' % (fold0_dir, options.data_head) + print('Found %d folds' % num_folds) + if num_folds == 0: + exit(1) + + sed_stats = options.sed_stats.split(',') + + # merge study/tissue variants + mpos_vcf_file = '%s/pos_merge.vcf' % options.vcf_dir + mneg_vcf_file = '%s/neg_merge.vcf' % options.vcf_dir + + ################################################################ + # SNP scores + + # command base + cmd_base = '. 
/home/drk/anaconda3/etc/profile.d/conda.sh;' + cmd_base += ' conda activate %s;' % options.conda_env + cmd_base += ' echo $HOSTNAME;' + + jobs = [] + + for ci in range(options.crosses): + for fi in range(num_folds): + it_dir = '%s/f%dc%d' % (exp_dir, fi, ci) + name = '%s-f%dc%d' % (options.name, fi, ci) + + # update output directory + it_out_dir = '%s/%s' % (it_dir, options.out_dir) + os.makedirs(it_out_dir, exist_ok=True) + + model_file = '%s/train/model_best.h5' % it_dir + if options.data_head is not None: + model_file = '%s/train/model%d_best.h5' % (it_dir, options.data_head) + + cmd_fold = '%s time borzoi_sed_ipaqtl_cov.py %s %s' % (cmd_base, params_file, model_file) + + # positive job + job_out_dir = '%s/merge_pos' % it_out_dir + if not options.restart or not os.path.isfile('%s/sed.h5'%job_out_dir): + cmd_job = '%s %s' % (cmd_fold, mpos_vcf_file) + cmd_job += ' %s' % options_string(options, sed_options, job_out_dir) + j = slurm.Job(cmd_job, '%s_pos' % name, + '%s.out'%job_out_dir, '%s.err'%job_out_dir, + queue=options.queue, gpu=1, + mem=30000, time='7-0:0:0') + jobs.append(j) + + # negative job + job_out_dir = '%s/merge_neg' % it_out_dir + if not options.restart or not os.path.isfile('%s/sed.h5'%job_out_dir): + cmd_job = '%s %s' % (cmd_fold, mneg_vcf_file) + cmd_job += ' %s' % options_string(options, sed_options, job_out_dir) + j = slurm.Job(cmd_job, '%s_neg' % name, + '%s.out'%job_out_dir, '%s.err'%job_out_dir, + queue=options.queue, gpu=1, + mem=30000, time='7-0:0:0') + jobs.append(j) + + slurm.multi_run(jobs, max_proc=options.max_proc, verbose=True, + launch_sleep=10, update_sleep=60) + + ################################################################ + # split study/tissue variants + + for ci in range(options.crosses): + for fi in range(num_folds): + it_dir = '%s/f%dc%d' % (exp_dir, fi, ci) + it_out_dir = '%s/%s' % (it_dir, options.out_dir) + + # split positives + split_sed(it_out_dir, 'pos', options.vcf_dir, sed_stats) + + # split negatives + split_sed(it_out_dir, 'neg', options.vcf_dir, sed_stats) + + ################################################################ + # ensemble + + ensemble_dir = '%s/ensemble' % exp_dir + if not os.path.isdir(ensemble_dir): + os.mkdir(ensemble_dir) + + sqtl_dir = '%s/%s' % (ensemble_dir, options.out_dir) + if not os.path.isdir(sqtl_dir): + os.mkdir(sqtl_dir) + + for pos_vcf in glob.glob('%s/*_pos.vcf' % options.vcf_dir): + neg_vcf = pos_vcf.replace('_pos.','_neg.') + pos_base = os.path.splitext(os.path.split(pos_vcf)[1])[0] + neg_base = os.path.splitext(os.path.split(neg_vcf)[1])[0] + + # collect SED files + sed_pos_files = [] + sed_neg_files = [] + for ci in range(options.crosses): + for fi in range(num_folds): + it_dir = '%s/f%dc%d' % (exp_dir, fi, ci) + it_out_dir = '%s/%s' % (it_dir, options.out_dir) + + sed_pos_file = '%s/%s/sed.h5' % (it_out_dir, pos_base) + sed_pos_files.append(sed_pos_file) + + sed_neg_file = '%s/%s/sed.h5' % (it_out_dir, neg_base) + sed_neg_files.append(sed_neg_file) + + # ensemble + ens_pos_dir = '%s/%s' % (sqtl_dir, pos_base) + os.makedirs(ens_pos_dir, exist_ok=True) + ens_pos_file = '%s/sed.h5' % (ens_pos_dir) + if not os.path.isfile(ens_pos_file): + ensemble_sed_h5(ens_pos_file, sed_pos_files, sed_stats) + + ens_neg_dir = '%s/%s' % (sqtl_dir, neg_base) + os.makedirs(ens_neg_dir, exist_ok=True) + ens_neg_file = '%s/sed.h5' % (ens_neg_dir) + if not os.path.isfile(ens_neg_file): + ensemble_sed_h5(ens_neg_file, sed_neg_files, sed_stats) + + ################################################################ + # fit 
classifiers + + cmd_base = 'basenji_bench_classify.py -i 100 -p 2 -r 44 -s --stat COVR' + cmd_base += ' --msl %d' % options.msl + + jobs = [] + for ci in range(options.crosses): + for fi in range(num_folds): + it_dir = '%s/f%dc%d' % (exp_dir, fi, ci) + it_out_dir = '%s/%s' % (it_dir, options.out_dir) + + for sqtl_pos_vcf in glob.glob('%s/*_pos.vcf' % options.vcf_dir): + tissue = os.path.splitext(os.path.split(sqtl_pos_vcf)[1])[0][:-4] + sed_pos = '%s/%s_pos/sed.h5' % (it_out_dir, tissue) + sed_neg = '%s/%s_neg/sed.h5' % (it_out_dir, tissue) + class_out_dir = '%s/%s_class' % (it_out_dir, tissue) + + if not options.restart or not os.path.isfile('%s/stats.txt' % class_out_dir): + cmd_class = '%s -o %s %s %s' % (cmd_base, class_out_dir, sed_pos, sed_neg) + j = slurm.Job(cmd_class, tissue, + '%s.out'%class_out_dir, '%s.err'%class_out_dir, + queue='standard', cpu=2, + mem=22000, time='1-0:0:0') + jobs.append(j) + + # ensemble + for sqtl_pos_vcf in glob.glob('%s/*_pos.vcf' % options.vcf_dir): + tissue = os.path.splitext(os.path.split(sqtl_pos_vcf)[1])[0][:-4] + sed_pos = '%s/%s_pos/sed.h5' % (sqtl_dir, tissue) + sed_neg = '%s/%s_neg/sed.h5' % (sqtl_dir, tissue) + class_out_dir = '%s/%s_class' % (sqtl_dir, tissue) + + if not options.restart or not os.path.isfile('%s/stats.txt' % class_out_dir): + cmd_class = '%s -o %s %s %s' % (cmd_base, class_out_dir, sed_pos, sed_neg) + j = slurm.Job(cmd_class, tissue, + '%s.out'%class_out_dir, '%s.err'%class_out_dir, + queue='standard', cpu=2, + mem=22000, time='1-0:0:0') + jobs.append(j) + + slurm.multi_run(jobs, verbose=True) + + +def complete_h5(h5_file, sed_stats): + if os.path.isfile(h5_file): + try: + with h5py.File(h5_file, 'r') as h5_open: + for ss in sed_stats: + sed = h5_open[ss][:] + if (sed != 0).sum() == 0: + return False + return True + except: + return False + else: + return False + + +def ensemble_sed_h5(ensemble_h5_file, scores_files, sed_stats): + # open ensemble + ensemble_h5 = h5py.File(ensemble_h5_file, 'w') + + # transfer base + sed_shapes = {} + with h5py.File(scores_files[0], 'r') as scores0_h5: + for key in scores0_h5.keys(): + if key not in sed_stats: + ensemble_h5.create_dataset(key, data=scores0_h5[key]) + else: + sed_shapes[key] = scores0_h5[key].shape + + # average stats + num_folds = len(scores_files) + for si, sed_stat in enumerate(sed_stats): + # initialize ensemble array + sed_values = np.zeros(shape=sed_shapes[sed_stat], dtype='float32') + + # read and add folds + for scores_file in scores_files: + with h5py.File(scores_file, 'r') as scores_h5: + sed_values += scores_h5[sed_stat][:].astype('float32') + + # normalize and downcast + sed_values /= num_folds + sed_values = sed_values.astype('float16') + + # save + ensemble_h5.create_dataset(sed_stat, data=sed_values) + + ensemble_h5.close() + + +def options_string(options, group_options, rep_dir): + options_str = '' + + for opt in group_options.option_list: + opt_str = opt.get_opt_string() + opt_value = options.__dict__[opt.dest] + + # wrap askeriks in "" + if type(opt_value) == str and opt_value.find('*') != -1: + opt_value = '"%s"' % opt_value + + # no value for bools + elif type(opt_value) == bool: + if not opt_value: + opt_str = '' + opt_value = '' + + # skip Nones + elif opt_value is None: + opt_str = '' + opt_value = '' + + # modify + elif opt.dest == 'out_dir': + opt_value = rep_dir + + options_str += ' %s %s' % (opt_str, opt_value) + + return options_str + + +def split_sed(it_out_dir, posneg, vcf_dir, sed_stats): + """Split merged VCF predictions in HDF5 into 
tissue-specific + predictions in HDF5, aggregating statistics over genes.""" + + merge_h5_file = '%s/merge_%s/sed.h5' % (it_out_dir, posneg) + merge_h5 = h5py.File(merge_h5_file, 'r') + + # hash snp indexes + snp_i = {} + for i in range(merge_h5['snp'].shape[0]): + snp = merge_h5['snp'][i].decode('UTF-8') + snp_i.setdefault(snp,[]).append(i) + + # for each tissue VCF + vcf_glob = '%s/*_%s.vcf' % (vcf_dir, posneg) + for tissue_vcf_file in glob.glob(vcf_glob): + tissue_label = tissue_vcf_file.split('/')[-1] + tissue_label = tissue_label.replace('_pos.vcf','') + tissue_label = tissue_label.replace('_neg.vcf','') + + # initialize HDF5 arrays + sed_si = [] + sed_snp = [] + sed_chr = [] + sed_pos = [] + sed_ref = [] + sed_alt = [] + sed_scores = {} + for ss in sed_stats: + sed_scores[ss] = [] + + # fill HDF5 arrays with ordered SNPs + for line in open(tissue_vcf_file): + if not line.startswith('#'): + snp = line.split()[2] + i0 = snp_i[snp][0] + sed_si.append(merge_h5['si'][i0]) + sed_snp.append(merge_h5['snp'][i0]) + sed_chr.append(merge_h5['chr'][i0]) + sed_pos.append(merge_h5['pos'][i0]) + sed_ref.append(merge_h5['ref_allele'][i0]) + sed_alt.append(merge_h5['alt_allele'][i0]) + + for ss in sed_stats: + # take max over each gene + # (may not be appropriate for all stats!) + sed_scores_si = np.array([merge_h5[ss][i] for i in snp_i[snp]]) + sed_scores[ss].append(sed_scores_si.max(axis=0)) + + # write tissue HDF5 + tissue_dir = '%s/%s_%s' % (it_out_dir, tissue_label, posneg) + os.makedirs(tissue_dir, exist_ok=True) + with h5py.File('%s/sed.h5' % tissue_dir, 'w') as tissue_h5: + + # write SNP indexes + tissue_h5.create_dataset('si', + data=np.array(sed_si, dtype='uint32')) + + # write genes + # tissue_h5.create_dataset('gene', + # data=np.array(sed_gene, 'S')) + + # write SNPs + tissue_h5.create_dataset('snp', + data=np.array(sed_snp, 'S')) + + # write SNP chr + tissue_h5.create_dataset('chr', + data=np.array(sed_chr, 'S')) + + # write SNP pos + tissue_h5.create_dataset('pos', + data=np.array(sed_pos, dtype='uint32')) + + # write ref allele + tissue_h5.create_dataset('ref_allele', + data=np.array(sed_ref, dtype='S')) + + # write alt allele + tissue_h5.create_dataset('alt_allele', + data=np.array(sed_alt, dtype='S')) + + # write targets + tissue_h5.create_dataset('target_ids', data=merge_h5['target_ids']) + tissue_h5.create_dataset('target_labels', data=merge_h5['target_labels']) + + # write sed stats + for ss in sed_stats: + tissue_h5.create_dataset(ss, + data=np.array(sed_scores[ss], dtype='float16')) + + merge_h5.close() + + +################################################################################ +# __main__ +################################################################################ +if __name__ == '__main__': + main() diff --git a/bin/borzoi_bench_paqtl_folds.py b/bin/borzoi_bench_paqtl_folds.py new file mode 100644 index 00000000..330503cc --- /dev/null +++ b/bin/borzoi_bench_paqtl_folds.py @@ -0,0 +1,479 @@ +#!/usr/bin/env python +# Copyright 2019 Calico LLC + +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at + +# https://www.apache.org/licenses/LICENSE-2.0 + +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+# See the License for the specific language governing permissions and +# limitations under the License. +# ========================================================================= +from optparse import OptionParser, OptionGroup +import glob +import h5py +import json +import pdb +import os +import shutil +import sys + +import numpy as np +import pandas as pd + +import slurm +#import util + +from basenji_test_folds import stat_tests + +""" +borzoi_bench_paqtl_folds.py + +Benchmark Basenji model replicates on GTEx paQTL classification task. +""" + +################################################################################ +# main +################################################################################ +def main(): + usage = 'usage: %prog [options] ' + parser = OptionParser(usage) + + # sed + sed_options = OptionGroup(parser, 'borzoi_sed_paqtl_cov.py options') + sed_options.add_option('-f', dest='genome_fasta', + default='%s/data/hg38.fa' % os.environ['BASENJIDIR'], + help='Genome FASTA for sequences [Default: %default]') + sed_options.add_option('-g', dest='genes_gtf', + default='%s/genes/gencode41/gencode41_basic_nort.gtf' % os.environ['HG38'], + help='GTF for gene definition [Default %default]') + sed_options.add_option('--apafile', dest='apa_file', + default='polyadb_human_v3.csv.gz') + sed_options.add_option('-o',dest='out_dir', + default='paqtl', + help='Output directory for tables and plots [Default: %default]') + sed_options.add_option('-p', dest='processes', + default=None, type='int', + help='Number of processes, passed by multi script') + sed_options.add_option('--pseudo', dest='cov_pseudo', + default=50, type='float', + help='Coverage pseudocount [Default: %default]') + sed_options.add_option('--cov', dest='cov_min', + default=100, type='float', + help='Coverage pseudocount [Default: %default]') + sed_options.add_option('--rc', dest='rc', + default=False, action='store_true', + help='Average forward and reverse complement predictions [Default: %default]') + sed_options.add_option('--shifts', dest='shifts', + default='0', type='str', + help='Ensemble prediction shifts [Default: %default]') + sed_options.add_option('--stats', dest='sed_stats', + default='REF,ALT', + help='Comma-separated list of stats to save. 
[Default: %default]') + sed_options.add_option('-t', dest='targets_file', + default=None, type='str', + help='File specifying target indexes and labels in table format') + parser.add_option_group(sed_options) + + # classify + class_options = OptionGroup(parser, 'basenji_bench_classify.py options') + class_options.add_option('--msl', dest='msl', + default=1, type='int', + help='Random forest min_samples_leaf [Default: %default]') + parser.add_option_group(class_options) + + # cross-fold + fold_options = OptionGroup(parser, 'cross-fold options') + fold_options.add_option('-c', dest='crosses', + default=1, type='int', + help='Number of cross-fold rounds [Default:%default]') + fold_options.add_option('-d', dest='data_head', + default=None, type='int', + help='Index for dataset/head [Default: %default]') + fold_options.add_option('-e', dest='conda_env', + default='tf210', + help='Anaconda environment [Default: %default]') + fold_options.add_option('--name', dest='name', + default='paqtl', help='SLURM name prefix [Default: %default]') + fold_options.add_option('--max_proc', dest='max_proc', + default=None, type='int', + help='Maximum concurrent processes [Default: %default]') + fold_options.add_option('-q', dest='queue', + default='geforce', + help='SLURM queue on which to run the jobs [Default: %default]') + fold_options.add_option('-r', dest='restart', + default=False, action='store_true', + help='Restart a partially completed job [Default: %default]') + fold_options.add_option('--vcf', dest='vcf_dir', + default='/home/jlinder/seqnn/data/qtl_cat/paqtl_pip90ea') + parser.add_option_group(fold_options) + + (options, args) = parser.parse_args() + + if len(args) != 2: + print(len(args)) + print(args) + parser.error('Must provide parameters file and cross-fold directory') + else: + params_file = args[0] + exp_dir = args[1] + + ####################################################### + # prep work + + # count folds + num_folds = 0 + fold0_dir = '%s/f%dc0' % (exp_dir, num_folds) + model_file = '%s/train/model_best.h5' % fold0_dir + if options.data_head is not None: + model_file = '%s/train/model%d_best.h5' % (fold0_dir, options.data_head) + while os.path.isfile(model_file): + num_folds += 1 + fold0_dir = '%s/f%dc0' % (exp_dir, num_folds) + model_file = '%s/train/model_best.h5' % fold0_dir + if options.data_head is not None: + model_file = '%s/train/model%d_best.h5' % (fold0_dir, options.data_head) + print('Found %d folds' % num_folds) + if num_folds == 0: + exit(1) + + sed_stats = options.sed_stats.split(',') + + # merge study/tissue variants + mpos_vcf_file = '%s/pos_merge.vcf' % options.vcf_dir + mneg_vcf_file = '%s/neg_merge.vcf' % options.vcf_dir + + ################################################################ + # SNP scores + + # command base + cmd_base = '. 
/home/drk/anaconda3/etc/profile.d/conda.sh;' + cmd_base += ' conda activate %s;' % options.conda_env + cmd_base += ' echo $HOSTNAME;' + + jobs = [] + + for ci in range(options.crosses): + for fi in range(num_folds): + it_dir = '%s/f%dc%d' % (exp_dir, fi, ci) + name = '%s-f%dc%d' % (options.name, fi, ci) + + # update output directory + it_out_dir = '%s/%s' % (it_dir, options.out_dir) + os.makedirs(it_out_dir, exist_ok=True) + + model_file = '%s/train/model_best.h5' % it_dir + if options.data_head is not None: + model_file = '%s/train/model%d_best.h5' % (it_dir, options.data_head) + + cmd_fold = '%s time borzoi_sed_paqtl_cov.py %s %s' % (cmd_base, params_file, model_file) + + # positive job + job_out_dir = '%s/merge_pos' % it_out_dir + if not options.restart or not os.path.isfile('%s/sed.h5'%job_out_dir): + cmd_job = '%s %s' % (cmd_fold, mpos_vcf_file) + cmd_job += ' %s' % options_string(options, sed_options, job_out_dir) + j = slurm.Job(cmd_job, '%s_pos' % name, + '%s.out'%job_out_dir, '%s.err'%job_out_dir, + queue=options.queue, gpu=1, + mem=30000, time='7-0:0:0') + jobs.append(j) + + # negative job + job_out_dir = '%s/merge_neg' % it_out_dir + if not options.restart or not os.path.isfile('%s/sed.h5'%job_out_dir): + cmd_job = '%s %s' % (cmd_fold, mneg_vcf_file) + cmd_job += ' %s' % options_string(options, sed_options, job_out_dir) + j = slurm.Job(cmd_job, '%s_neg' % name, + '%s.out'%job_out_dir, '%s.err'%job_out_dir, + queue=options.queue, gpu=1, + mem=30000, time='7-0:0:0') + jobs.append(j) + + slurm.multi_run(jobs, max_proc=options.max_proc, verbose=True, + launch_sleep=10, update_sleep=60) + + ################################################################ + # split study/tissue variants + + for ci in range(options.crosses): + for fi in range(num_folds): + it_dir = '%s/f%dc%d' % (exp_dir, fi, ci) + it_out_dir = '%s/%s' % (it_dir, options.out_dir) + + # split positives + split_sed(it_out_dir, 'pos', options.vcf_dir, sed_stats) + + # split negatives + split_sed(it_out_dir, 'neg', options.vcf_dir, sed_stats) + + ################################################################ + # ensemble + + ensemble_dir = '%s/ensemble' % exp_dir + if not os.path.isdir(ensemble_dir): + os.mkdir(ensemble_dir) + + sqtl_dir = '%s/%s' % (ensemble_dir, options.out_dir) + if not os.path.isdir(sqtl_dir): + os.mkdir(sqtl_dir) + + for pos_vcf in glob.glob('%s/*_pos.vcf' % options.vcf_dir): + neg_vcf = pos_vcf.replace('_pos.','_neg.') + pos_base = os.path.splitext(os.path.split(pos_vcf)[1])[0] + neg_base = os.path.splitext(os.path.split(neg_vcf)[1])[0] + + # collect SED files + sed_pos_files = [] + sed_neg_files = [] + for ci in range(options.crosses): + for fi in range(num_folds): + it_dir = '%s/f%dc%d' % (exp_dir, fi, ci) + it_out_dir = '%s/%s' % (it_dir, options.out_dir) + + sed_pos_file = '%s/%s/sed.h5' % (it_out_dir, pos_base) + sed_pos_files.append(sed_pos_file) + + sed_neg_file = '%s/%s/sed.h5' % (it_out_dir, neg_base) + sed_neg_files.append(sed_neg_file) + + # ensemble + ens_pos_dir = '%s/%s' % (sqtl_dir, pos_base) + os.makedirs(ens_pos_dir, exist_ok=True) + ens_pos_file = '%s/sed.h5' % (ens_pos_dir) + if not os.path.isfile(ens_pos_file): + ensemble_sed_h5(ens_pos_file, sed_pos_files, sed_stats) + + ens_neg_dir = '%s/%s' % (sqtl_dir, neg_base) + os.makedirs(ens_neg_dir, exist_ok=True) + ens_neg_file = '%s/sed.h5' % (ens_neg_dir) + if not os.path.isfile(ens_neg_file): + ensemble_sed_h5(ens_neg_file, sed_neg_files, sed_stats) + + ################################################################ + # fit 
classifiers + + cmd_base = 'basenji_bench_classify.py -i 100 -p 2 -r 44 -s --stat COVR' + cmd_base += ' --msl %d' % options.msl + + jobs = [] + for ci in range(options.crosses): + for fi in range(num_folds): + it_dir = '%s/f%dc%d' % (exp_dir, fi, ci) + it_out_dir = '%s/%s' % (it_dir, options.out_dir) + + for sqtl_pos_vcf in glob.glob('%s/*_pos.vcf' % options.vcf_dir): + tissue = os.path.splitext(os.path.split(sqtl_pos_vcf)[1])[0][:-4] + sed_pos = '%s/%s_pos/sed.h5' % (it_out_dir, tissue) + sed_neg = '%s/%s_neg/sed.h5' % (it_out_dir, tissue) + class_out_dir = '%s/%s_class' % (it_out_dir, tissue) + + if not options.restart or not os.path.isfile('%s/stats.txt' % class_out_dir): + cmd_class = '%s -o %s %s %s' % (cmd_base, class_out_dir, sed_pos, sed_neg) + j = slurm.Job(cmd_class, tissue, + '%s.out'%class_out_dir, '%s.err'%class_out_dir, + queue='standard', cpu=2, + mem=22000, time='1-0:0:0') + jobs.append(j) + + # ensemble + for sqtl_pos_vcf in glob.glob('%s/*_pos.vcf' % options.vcf_dir): + tissue = os.path.splitext(os.path.split(sqtl_pos_vcf)[1])[0][:-4] + sed_pos = '%s/%s_pos/sed.h5' % (sqtl_dir, tissue) + sed_neg = '%s/%s_neg/sed.h5' % (sqtl_dir, tissue) + class_out_dir = '%s/%s_class' % (sqtl_dir, tissue) + + if not options.restart or not os.path.isfile('%s/stats.txt' % class_out_dir): + cmd_class = '%s -o %s %s %s' % (cmd_base, class_out_dir, sed_pos, sed_neg) + j = slurm.Job(cmd_class, tissue, + '%s.out'%class_out_dir, '%s.err'%class_out_dir, + queue='standard', cpu=2, + mem=22000, time='1-0:0:0') + jobs.append(j) + + slurm.multi_run(jobs, verbose=True) + + +def complete_h5(h5_file, sed_stats): + if os.path.isfile(h5_file): + try: + with h5py.File(h5_file, 'r') as h5_open: + for ss in sed_stats: + sed = h5_open[ss][:] + if (sed != 0).sum() == 0: + return False + return True + except: + return False + else: + return False + + +def ensemble_sed_h5(ensemble_h5_file, scores_files, sed_stats): + # open ensemble + ensemble_h5 = h5py.File(ensemble_h5_file, 'w') + + # transfer base + sed_shapes = {} + with h5py.File(scores_files[0], 'r') as scores0_h5: + for key in scores0_h5.keys(): + if key not in sed_stats: + ensemble_h5.create_dataset(key, data=scores0_h5[key]) + else: + sed_shapes[key] = scores0_h5[key].shape + + # average stats + num_folds = len(scores_files) + for si, sed_stat in enumerate(sed_stats): + # initialize ensemble array + sed_values = np.zeros(shape=sed_shapes[sed_stat], dtype='float32') + + # read and add folds + for scores_file in scores_files: + with h5py.File(scores_file, 'r') as scores_h5: + sed_values += scores_h5[sed_stat][:].astype('float32') + + # normalize and downcast + sed_values /= num_folds + sed_values = sed_values.astype('float16') + + # save + ensemble_h5.create_dataset(sed_stat, data=sed_values) + + ensemble_h5.close() + + +def options_string(options, group_options, rep_dir): + options_str = '' + + for opt in group_options.option_list: + opt_str = opt.get_opt_string() + opt_value = options.__dict__[opt.dest] + + # wrap askeriks in "" + if type(opt_value) == str and opt_value.find('*') != -1: + opt_value = '"%s"' % opt_value + + # no value for bools + elif type(opt_value) == bool: + if not opt_value: + opt_str = '' + opt_value = '' + + # skip Nones + elif opt_value is None: + opt_str = '' + opt_value = '' + + # modify + elif opt.dest == 'out_dir': + opt_value = rep_dir + + options_str += ' %s %s' % (opt_str, opt_value) + + return options_str + + +def split_sed(it_out_dir, posneg, vcf_dir, sed_stats): + """Split merged VCF predictions in HDF5 into 
tissue-specific + predictions in HDF5, aggregating statistics over genes.""" + + merge_h5_file = '%s/merge_%s/sed.h5' % (it_out_dir, posneg) + merge_h5 = h5py.File(merge_h5_file, 'r') + + # hash snp indexes + snp_i = {} + for i in range(merge_h5['snp'].shape[0]): + snp = merge_h5['snp'][i].decode('UTF-8') + snp_i.setdefault(snp,[]).append(i) + + # for each tissue VCF + vcf_glob = '%s/*_%s.vcf' % (vcf_dir, posneg) + for tissue_vcf_file in glob.glob(vcf_glob): + tissue_label = tissue_vcf_file.split('/')[-1] + tissue_label = tissue_label.replace('_pos.vcf','') + tissue_label = tissue_label.replace('_neg.vcf','') + + # initialize HDF5 arrays + sed_si = [] + sed_snp = [] + sed_chr = [] + sed_pos = [] + sed_ref = [] + sed_alt = [] + sed_scores = {} + for ss in sed_stats: + sed_scores[ss] = [] + + # fill HDF5 arrays with ordered SNPs + for line in open(tissue_vcf_file): + if not line.startswith('#'): + snp = line.split()[2] + i0 = snp_i[snp][0] + sed_si.append(merge_h5['si'][i0]) + sed_snp.append(merge_h5['snp'][i0]) + sed_chr.append(merge_h5['chr'][i0]) + sed_pos.append(merge_h5['pos'][i0]) + sed_ref.append(merge_h5['ref_allele'][i0]) + sed_alt.append(merge_h5['alt_allele'][i0]) + + for ss in sed_stats: + # take max over each gene + # (may not be appropriate for all stats!) + sed_scores_si = np.array([merge_h5[ss][i] for i in snp_i[snp]]) + sed_scores[ss].append(sed_scores_si.max(axis=0)) + + # write tissue HDF5 + tissue_dir = '%s/%s_%s' % (it_out_dir, tissue_label, posneg) + os.makedirs(tissue_dir, exist_ok=True) + with h5py.File('%s/sed.h5' % tissue_dir, 'w') as tissue_h5: + + # write SNP indexes + tissue_h5.create_dataset('si', + data=np.array(sed_si, dtype='uint32')) + + # write genes + # tissue_h5.create_dataset('gene', + # data=np.array(sed_gene, 'S')) + + # write SNPs + tissue_h5.create_dataset('snp', + data=np.array(sed_snp, 'S')) + + # write SNP chr + tissue_h5.create_dataset('chr', + data=np.array(sed_chr, 'S')) + + # write SNP pos + tissue_h5.create_dataset('pos', + data=np.array(sed_pos, dtype='uint32')) + + # write ref allele + tissue_h5.create_dataset('ref_allele', + data=np.array(sed_ref, dtype='S')) + + # write alt allele + tissue_h5.create_dataset('alt_allele', + data=np.array(sed_alt, dtype='S')) + + # write targets + tissue_h5.create_dataset('target_ids', data=merge_h5['target_ids']) + tissue_h5.create_dataset('target_labels', data=merge_h5['target_labels']) + + # write sed stats + for ss in sed_stats: + tissue_h5.create_dataset(ss, + data=np.array(sed_scores[ss], dtype='float16')) + + merge_h5.close() + + +################################################################################ +# __main__ +################################################################################ +if __name__ == '__main__': + main() diff --git a/bin/borzoi_bench_trip_folds.py b/bin/borzoi_bench_trip_folds.py new file mode 100644 index 00000000..2c24e638 --- /dev/null +++ b/bin/borzoi_bench_trip_folds.py @@ -0,0 +1,202 @@ +#!/usr/bin/env python +# Copyright 2019 Calico LLC + +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at + +# https://www.apache.org/licenses/LICENSE-2.0 + +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+# See the License for the specific language governing permissions and +# limitations under the License. +# ========================================================================= +from optparse import OptionParser, OptionGroup +import glob +import h5py +import json +import pdb +import os +import shutil +import sys + +import numpy as np +import pandas as pd + +import slurm +#import util + +from basenji_test_folds import stat_tests + +""" +borzoi_borzoi_trip_folds.py + +Benchmark Basenji model replicates on TRIP prediction task. +""" + +################################################################################ +# main +################################################################################ +def main(): + usage = 'usage: %prog [options] ' + parser = OptionParser(usage) + + # trip + trip_options = OptionGroup(parser, 'borzoi_trip.py options') + trip_options.add_option('-f', dest='genome_fasta', + default='%s/data/hg38.fa' % os.environ['BASENJIDIR'], + help='Genome FASTA for sequences [Default: %default]') + trip_options.add_option('-o',dest='out_dir', + default='trip', + help='Output directory for tables and plots [Default: %default]') + trip_options.add_option('--site', dest='site', + default=False, action='store_true', + help='Return the insertion site without the promoter [Default: %default]') + trip_options.add_option('--reporter', dest='reporter', + default=False, action='store_true', + help='Insert the flanking piggyback reporter with the promoter [Default: %default]') + trip_options.add_option('--reporter_bare', dest='reporter_bare', + default=False, action='store_true', + help='Insert the flanking piggyback reporter with the promoter (no terminal repeats) [Default: %default]') + trip_options.add_option('--rc', dest='rc', + default=False, action='store_true', + help='Average forward and reverse complement predictions [Default: %default]') + trip_options.add_option('--shifts', dest='shifts', + default='0', type='str', + help='Ensemble prediction shifts [Default: %default]') + trip_options.add_option('-t', dest='targets_file', + default=None, type='str', + help='File specifying target indexes and labels in table format') + parser.add_option_group(trip_options) + + # cross-fold + fold_options = OptionGroup(parser, 'cross-fold options') + fold_options.add_option('-c', dest='crosses', + default=1, type='int', + help='Number of cross-fold rounds [Default:%default]') + fold_options.add_option('-d', dest='data_head', + default=None, type='int', + help='Index for dataset/head [Default: %default]') + fold_options.add_option('-e', dest='conda_env', + default='tf210', + help='Anaconda environment [Default: %default]') + fold_options.add_option('--name', dest='name', + default='trip', help='SLURM name prefix [Default: %default]') + fold_options.add_option('--max_proc', dest='max_proc', + default=None, type='int', + help='Maximum concurrent processes [Default: %default]') + fold_options.add_option('-q', dest='queue', + default='geforce', + help='SLURM queue on which to run the jobs [Default: %default]') + fold_options.add_option('-r', dest='restart', + default=False, action='store_true', + help='Restart a partially completed job [Default: %default]') + parser.add_option_group(fold_options) + + (options, args) = parser.parse_args() + + if len(args) != 4: + print(len(args)) + print(args) + parser.error('Must provide parameters file, cross-fold directory, TRIP promoter sequences, and TRIP insertion sites') + else: + params_file = args[0] + exp_dir = args[1] + promoters_file = args[2] + 
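+    # (the promoter and insertion-site files are forwarded unchanged to
+    # borzoi_trip.py in the job commands below)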
insertions_file = args[3] + + ####################################################### + # prep work + + # count folds + num_folds = 0 + fold0_dir = '%s/f%dc0' % (exp_dir, num_folds) + model_file = '%s/train/model_best.h5' % fold0_dir + if options.data_head is not None: + model_file = '%s/train/model%d_best.h5' % (fold0_dir, options.data_head) + while os.path.isfile(model_file): + num_folds += 1 + fold0_dir = '%s/f%dc0' % (exp_dir, num_folds) + model_file = '%s/train/model_best.h5' % fold0_dir + if options.data_head is not None: + model_file = '%s/train/model%d_best.h5' % (fold0_dir, options.data_head) + print('Found %d folds' % num_folds) + if num_folds == 0: + exit(1) + + ################################################################ + # TRIP prediction jobs + + # command base + cmd_base = '. /home/drk/anaconda3/etc/profile.d/conda.sh;' + cmd_base += ' conda activate %s;' % options.conda_env + cmd_base += ' echo $HOSTNAME;' + + jobs = [] + + for ci in range(options.crosses): + for fi in range(num_folds): + it_dir = '%s/f%dc%d' % (exp_dir, fi, ci) + name = '%s-f%dc%d' % (options.name, fi, ci) + + # update output directory + it_out_dir = '%s/%s' % (it_dir, options.out_dir) + os.makedirs(it_out_dir, exist_ok=True) + + model_file = '%s/train/model_best.h5' % it_dir + if options.data_head is not None: + model_file = '%s/train/model%d_best.h5' % (it_dir, options.data_head) + + cmd_fold = '%s time borzoi_trip.py %s %s %s %s' % (cmd_base, params_file, model_file, promoters_file, insertions_file) + + # TRIP job + job_out_dir = it_out_dir + if not options.restart or not os.path.isfile('%s/preds.h5'%job_out_dir): + cmd_job = cmd_fold + cmd_job += ' %s' % options_string(options, trip_options, job_out_dir) + j = slurm.Job(cmd_job, name, + '%s.out'%job_out_dir, '%s.err'%job_out_dir, + queue=options.queue, gpu=1, + mem=60000, time='7-0:0:0') + jobs.append(j) + + slurm.multi_run(jobs, max_proc=options.max_proc, verbose=True, + launch_sleep=10, update_sleep=60) + +def options_string(options, group_options, rep_dir): + options_str = '' + + for opt in group_options.option_list: + opt_str = opt.get_opt_string() + opt_value = options.__dict__[opt.dest] + + # wrap askeriks in "" + if type(opt_value) == str and opt_value.find('*') != -1: + opt_value = '"%s"' % opt_value + + # no value for bools + elif type(opt_value) == bool: + if not opt_value: + opt_str = '' + opt_value = '' + + # skip Nones + elif opt_value is None: + opt_str = '' + opt_value = '' + + # modify + elif opt.dest == 'out_dir': + opt_value = rep_dir + + options_str += ' %s %s' % (opt_str, opt_value) + + return options_str + +################################################################################ +# __main__ +################################################################################ +if __name__ == '__main__': + main() diff --git a/bin/borzoi_satg_gene_gpu.py b/bin/borzoi_satg_gene_gpu.py new file mode 100644 index 00000000..bc4c9d42 --- /dev/null +++ b/bin/borzoi_satg_gene_gpu.py @@ -0,0 +1,688 @@ +#!/usr/bin/env python +# Copyright 2017 Calico LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+# See the License for the specific language governing permissions and +# limitations under the License. +# ========================================================================= +from __future__ import print_function + +from optparse import OptionParser + +import gc +import json +import os +import pdb +import pickle +from queue import Queue +import random +import sys +from threading import Thread +import time + +import h5py +import numpy as np +import pandas as pd +import pysam +import tensorflow as tf + +from basenji import dna_io +from basenji import gene as bgene +from basenji import seqnn +from borzoi_sed import targets_prep_strand + +''' +borzoi_satg_gene_gpu.py + +Perform a gradient saliency analysis for genes specified in a GTF file (GPU-friendly). +''' + +# tf code for predicting raw sum-of-expression counts on GPU +@tf.function +def _count_func(model, seq_1hot, target_slice, pos_slice, pos_mask=None, track_scale=1., track_transform=1., clip_soft=None, use_mean=False) : + + # predict + preds = tf.gather(model(seq_1hot, training=False), target_slice, axis=-1, batch_dims=1) + + # undo scale + preds = preds / track_scale + + # undo soft_clip + if clip_soft is not None : + preds = tf.where(preds > clip_soft, (preds - clip_soft)**2 + clip_soft, preds) + + # undo sqrt + preds = preds**(1. / track_transform) + + # aggregate over tracks (average) + preds = tf.reduce_mean(preds, axis=-1) + + # slice specified positions + preds_slice = tf.gather(preds, pos_slice, axis=-1, batch_dims=1) + if pos_mask is not None : + preds_slice = preds_slice * pos_mask + + # aggregate over positions + if not use_mean : + preds_agg = tf.reduce_sum(preds_slice, axis=-1) + else : + if pos_mask is not None : + preds_agg = tf.reduce_sum(preds_slice, axis=-1) / tf.reduce_sum(pos_mask, axis=-1) + else : + preds_agg = tf.reduce_mean(preds_slice, axis=-1) + + return preds_agg + +# code for getting model predictions from a tensor of input sequence patterns +def predict_counts(seqnn_model, seq_1hot, head_i=None, target_slice=None, pos_slice=None, pos_mask=None, chunk_size=None, batch_size=1, track_scale=1., track_transform=1., clip_soft=None, use_mean=False, dtype='float32'): + + # start time + t0 = time.time() + + # choose model + if seqnn_model.ensemble is not None: + model = seqnn_model.ensemble + elif head_i is not None: + model = seqnn_model.models[head_i] + else: + model = seqnn_model.model + + # verify tensor shape(s) + seq_1hot = seq_1hot.astype('float32') + target_slice = np.array(target_slice).astype('int32') + pos_slice = np.array(pos_slice).astype('int32') + + # convert constants to tf tensors + track_scale = tf.constant(track_scale, dtype=tf.float32) + track_transform = tf.constant(track_transform, dtype=tf.float32) + if clip_soft is not None : + clip_soft = tf.constant(clip_soft, dtype=tf.float32) + + if pos_mask is not None : + pos_mask = np.array(pos_mask).astype('float32') + + if len(seq_1hot.shape) < 3: + seq_1hot = seq_1hot[None, ...] + + if len(target_slice.shape) < 2: + target_slice = target_slice[None, ...] + + if len(pos_slice.shape) < 2: + pos_slice = pos_slice[None, ...] + + if pos_mask is not None and len(pos_mask.shape) < 2: + pos_mask = pos_mask[None, ...] + + # chunk parameters + num_chunks = 1 + if chunk_size is None : + chunk_size = seq_1hot.shape[0] + else : + num_chunks = int(np.ceil(seq_1hot.shape[0] / chunk_size)) + + # loop over chunks + pred_chunks = [] + for ci in range(num_chunks) : + + # collect chunk + seq_1hot_chunk = seq_1hot[ci * chunk_size:(ci+1) * chunk_size, ...] 
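+    # (target_slice, pos_slice and the optional mask are sliced per-chunk
+    # below; chunking keeps host memory bounded for large gene batches)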
+ target_slice_chunk = target_slice[ci * chunk_size:(ci+1) * chunk_size, ...] + pos_slice_chunk = pos_slice[ci * chunk_size:(ci+1) * chunk_size, ...] + + pos_mask_chunk = None + if pos_mask is not None : + pos_mask_chunk = pos_mask[ci * chunk_size:(ci+1) * chunk_size, ...] + + actual_chunk_size = seq_1hot_chunk.shape[0] + + # convert to tf tensors + seq_1hot_chunk = tf.convert_to_tensor(seq_1hot_chunk, dtype=tf.float32) + target_slice_chunk = tf.convert_to_tensor(target_slice_chunk, dtype=tf.int32) + pos_slice_chunk = tf.convert_to_tensor(pos_slice_chunk, dtype=tf.int32) + + if pos_mask is not None : + pos_mask_chunk = tf.convert_to_tensor(pos_mask_chunk, dtype=tf.float32) + + # batching parameters + num_batches = int(np.ceil(actual_chunk_size / batch_size)) + + # loop over batches + pred_batches = [] + for bi in range(num_batches) : + + # collect batch + seq_1hot_batch = seq_1hot_chunk[bi * batch_size:(bi+1) * batch_size, ...] + target_slice_batch = target_slice_chunk[bi * batch_size:(bi+1) * batch_size, ...] + pos_slice_batch = pos_slice_chunk[bi * batch_size:(bi+1) * batch_size, ...] + + pos_mask_batch = None + if pos_mask is not None : + pos_mask_batch = pos_mask_chunk[bi * batch_size:(bi+1) * batch_size, ...] + + pred_batch = _count_func(model, seq_1hot_batch, target_slice_batch, pos_slice_batch, pos_mask_batch, track_scale, track_transform, clip_soft, use_mean).numpy().astype(dtype) + + pred_batches.append(pred_batch) + + # concat predicted batches + preds = np.concatenate(pred_batches, axis=0) + + pred_chunks.append(preds) + + # collect garbage + gc.collect() + + # concat predicted chunks + preds = np.concatenate(pred_chunks, axis=0) + + print('Made predictions in %ds' % (time.time()-t0)) + + return preds + + +################################################################################ +# main +# ############################################################################### +def main(): + usage = 'usage: %prog [options] ' + parser = OptionParser(usage) + parser.add_option('--fa', dest='genome_fasta', + default='%s/data/hg38.fa' % os.environ['BASENJIDIR'], + help='Genome FASTA for sequences [Default: %default]') + parser.add_option('-o', dest='out_dir', + default='satg_out', help='Output directory [Default: %default]') + parser.add_option('--rc', dest='rc', + default=0, type='int', + help='Ensemble forward and reverse complement predictions [Default: %default]') + parser.add_option('-f', dest='folds', + default='0', type='str', + help='Model folds to use in ensemble [Default: %default]') + parser.add_option('--shifts', dest='shifts', + default='0', type='str', + help='Ensemble prediction shifts [Default: %default]') + parser.add_option('--span', dest='span', + default=0, type='int', + help='Aggregate entire gene span [Default: %default]') + parser.add_option('--smoothgrad', dest='smooth_grad', + default=0, type='int', + help='Run smoothgrad [Default: %default]') + parser.add_option('--samples', dest='n_samples', + default=5, type='int', + help='Number of smoothgrad samples [Default: %default]') + parser.add_option('--sampleprob', dest='sample_prob', + default=0.875, type='float', + help='Probability of not mutating a position in smoothgrad [Default: %default]') + parser.add_option('--clip_soft', dest='clip_soft', + default=None, type='float', + help='Model clip_soft setting [Default: %default]') + parser.add_option('--no_transform', dest='no_transform', + default=0, type='int', + help='Run gradients with no inverse transforms [Default: %default]') + 
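+  # note: several flags here (--rc, --span, --smoothgrad, --no_transform) are
+  # 0/1 integers rather than store_true booleans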
parser.add_option('--get_preds', dest='get_preds',
+      default=0, type='int',
+      help='Store scalar predictions in addition to their gradients [Default: %default]')
+  parser.add_option('--pseudo_qtl', dest='pseudo_qtl',
+      default=None, type='float',
+      help='Quantile of predicted scalars to choose as pseudo count [Default: %default]')
+  parser.add_option('--pseudo_tissue', dest='pseudo_tissue',
+      default=None, type='str',
+      help='Tissue to filter genes on when calculating pseudo count [Default: %default]')
+  parser.add_option('--gene_file', dest='gene_file',
+      default=None, type='str',
+      help='CSV file of gene metadata [Default: %default]')
+  parser.add_option('-t', dest='targets_file',
+      default=None, type='str',
+      help='File specifying target indexes and labels in table format')
+  (options, args) = parser.parse_args()
+
+  if len(args) == 3:
+    # single worker
+    params_file = args[0]
+    model_folder = args[1]
+    genes_gtf_file = args[2]
+  else:
+    parser.error('Must provide parameter file, model folder and GTF file')
+
+  if not os.path.isdir(options.out_dir):
+    os.mkdir(options.out_dir)
+
+  options.folds = [int(fold) for fold in options.folds.split(',')]
+  options.shifts = [int(shift) for shift in options.shifts.split(',')]
+
+  #################################################################
+  # read parameters and targets
+
+  # read model parameters
+  with open(params_file) as params_open:
+    params = json.load(params_open)
+    params_model = params['model']
+    params_train = params['train']
+    seq_len = params_model['seq_length']
+
+  if options.targets_file is None:
+    parser.error('Must provide targets table to properly handle strands.')
+  else:
+    targets_df = pd.read_csv(options.targets_file, sep='\t', index_col=0)
+
+  # prep strand
+  orig_new_index = dict(zip(targets_df.index, np.arange(targets_df.shape[0])))
+  targets_strand_pair = np.array([orig_new_index[ti] for ti in targets_df.strand_pair])
+  targets_strand_df = targets_prep_strand(targets_df)
+  num_targets = 1
+
+  #Load gene dataframe and select tissue
+  tissue_genes = None
+  if options.gene_file is not None and options.pseudo_tissue is not None :
+    gene_df = pd.read_csv(options.gene_file, sep='\t')
+    gene_df = gene_df.query("tissue == '" + str(options.pseudo_tissue) + "'").copy().reset_index(drop=True)
+    gene_df = gene_df.drop(columns=['Unnamed: 0'])
+
+    #Get list of genes for tissue
+    tissue_genes = gene_df['gene_base'].values.tolist()
+
+    print("len(tissue_genes) = " + str(len(tissue_genes)))
+
+  #################################################################
+  # load first model fold to get parameters
+
+  seqnn_model = seqnn.SeqNN(params_model)
+  seqnn_model.restore(model_folder + "/f0c0/model0_best.h5", 0, by_name=False)
+  seqnn_model.build_slice(targets_df.index, False)
+  # seqnn_model.build_ensemble(options.rc, options.shifts)
+
+  model_stride = seqnn_model.model_strides[0]
+  model_crop = seqnn_model.target_crops[0]
+  target_length = seqnn_model.target_lengths[0]
+
+  #################################################################
+  # read genes
+
+  # parse GTF
+  transcriptome = bgene.Transcriptome(genes_gtf_file)
+
+  # order valid genes
+  genome_open = pysam.Fastafile(options.genome_fasta)
+  gene_list = sorted(transcriptome.genes.keys())
+  num_genes = len(gene_list)
+
+  #################################################################
+  # setup output
+
+  min_start = -model_stride*model_crop
+
+  # choose gene sequences
+  genes_chr = []
+  genes_start = []
+  genes_end = []
+  genes_strand = []
+  for gene_id in gene_list:
+    gene = 
transcriptome.genes[gene_id] + genes_chr.append(gene.chrom) + genes_strand.append(gene.strand) + + gene_midpoint = gene.midpoint() + gene_start = max(min_start, gene_midpoint - seq_len//2) + gene_end = gene_start + seq_len + genes_start.append(gene_start) + genes_end.append(gene_end) + + ################################################################# + # predict scores, write output + + buffer_size = 1024 + + print("clip_soft = " + str(options.clip_soft)) + + print("n genes = " + str(len(genes_chr))) + + # loop over folds + for fold_ix in options.folds : + print("-- Fold = " + str(fold_ix) + " --") + + # (re-)initialize HDF5 + scores_h5_file = '%s/scores_f%dc0.h5' % (options.out_dir, fold_ix) + if os.path.isfile(scores_h5_file): + os.remove(scores_h5_file) + scores_h5 = h5py.File(scores_h5_file, 'w') + scores_h5.create_dataset('seqs', dtype='bool', + shape=(num_genes, seq_len, 4)) + scores_h5.create_dataset('grads', dtype='float16', + shape=(num_genes, seq_len, 4, num_targets)) + if options.get_preds == 1 : + scores_h5.create_dataset('preds', dtype='float32', + shape=(num_genes, num_targets)) + scores_h5.create_dataset('gene', data=np.array(gene_list, dtype='S')) + scores_h5.create_dataset('chr', data=np.array(genes_chr, dtype='S')) + scores_h5.create_dataset('start', data=np.array(genes_start)) + scores_h5.create_dataset('end', data=np.array(genes_end)) + scores_h5.create_dataset('strand', data=np.array(genes_strand, dtype='S')) + + # load model fold + seqnn_model = seqnn.SeqNN(params_model) + seqnn_model.restore(model_folder + "/f" + str(fold_ix) + "c0/model0_best.h5", 0, by_name=False) + seqnn_model.build_slice(targets_df.index, False) + + track_scale = targets_df.iloc[0]['scale'] + track_transform = 3. / 4. + + # optionally get (and store) scalar predictions before computing their gradients + if options.get_preds == 1 : + print(" - (prediction) - ", flush=True) + + for shift in options.shifts : + print('Processing shift %d' % shift, flush=True) + + for rev_comp in ([False, True] if options.rc == 1 else [False]) : + + if options.rc == 1 : + print('Fwd/rev = %s' % ('fwd' if not rev_comp else 'rev'), flush=True) + + seq_1hots = [] + gene_slices = [] + gene_targets = [] + + for gi, gene_id in enumerate(gene_list): + + if gi % 500 == 0 : + print('Processing %d, %s' % (gi, gene_id), flush=True) + + gene = transcriptome.genes[gene_id] + + # make sequence + seq_1hot = make_seq_1hot(genome_open, genes_chr[gi], genes_start[gi], genes_end[gi], seq_len) + seq_1hot = dna_io.hot1_augment(seq_1hot, shift=shift) + + # determine output sequence start + seq_out_start = genes_start[gi] + model_stride*model_crop + seq_out_len = model_stride*target_length + + # determine output positions + gene_slice = gene.output_slice(seq_out_start, seq_out_len, model_stride, options.span == 1) + + if rev_comp: + seq_1hot = dna_io.hot1_rc(seq_1hot) + gene_slice = target_length - gene_slice - 1 + + # slice relevant strand targets + if genes_strand[gi] == '+': + gene_strand_mask = (targets_df.strand != '-') if not rev_comp else (targets_df.strand != '+') + else: + gene_strand_mask = (targets_df.strand != '+') if not rev_comp else (targets_df.strand != '-') + + gene_target = np.array(targets_df.index[gene_strand_mask].values) + + # accumulate data tensors + seq_1hots.append(seq_1hot[None, ...]) + gene_slices.append(gene_slice[None, ...]) + gene_targets.append(gene_target[None, ...]) + + if gi == len(gene_list) - 1 or len(seq_1hots) >= buffer_size : + + # concat sequences + seq_1hots = np.concatenate(seq_1hots, axis=0) + + 
# pad gene slices to same length (mark valid positions in mask tensor) + max_slice_len = int(np.max([gene_slice.shape[1] for gene_slice in gene_slices])) + + gene_masks = np.zeros((len(gene_slices), max_slice_len), dtype='float32') + gene_slices_padded = np.zeros((len(gene_slices), max_slice_len), dtype='int32') + for gii, gene_slice in enumerate(gene_slices) : + for j in range(gene_slice.shape[1]) : + gene_masks[gii, j] = 1. + gene_slices_padded[gii, j] = gene_slice[0, j] + + gene_slices = gene_slices_padded + + # concat gene-specific targets + gene_targets = np.concatenate(gene_targets, axis=0) + + # batch call count predictions + preds = predict_counts( + seqnn_model, + seq_1hots, + head_i=0, + target_slice=gene_targets, + pos_slice=gene_slices, + pos_mask=gene_masks, + chunk_size=buffer_size, + batch_size=1, + track_scale=track_scale, + track_transform=track_transform, + clip_soft=options.clip_soft, + use_mean=False, + dtype='float32' + ) + + # save predictions + for gii, gene_slice in enumerate(gene_slices) : + h5_gi = (gi // buffer_size) * buffer_size + gii + + # write to HDF5 + scores_h5['preds'][h5_gi, :] += (preds[gii] / float(len(options.shifts))) + + #clear sequence buffer + seq_1hots = [] + gene_slices = [] + gene_targets = [] + + # collect garbage + gc.collect() + + # optionally set pseudo count from predictions + pseudo_count = 0. + if options.pseudo_qtl is not None : + gene_preds = scores_h5['preds'][:] + + # filter on tissue + tissue_preds = None + + if tissue_genes is not None : + tissue_set = set(tissue_genes) + + # get subset of genes and predictions belonging to the pseudo count tissue + tissue_preds = [] + for gi, gene_id in enumerate(gene_list) : + if gene_id.split(".")[0] in tissue_set : + tissue_preds.append(gene_preds[gi, 0]) + + tissue_preds = np.array(tissue_preds, dtype='float32') + else : + tissue_preds = np.array(gene_preds[:, 0], dtype='float32') + + print("tissue_preds.shape[0] = " + str(tissue_preds.shape[0])) + + print("np.min(tissue_preds) = " + str(np.min(tissue_preds))) + print("np.max(tissue_preds) = " + str(np.max(tissue_preds))) + + # set pseudo count based on quantile of predictions + pseudo_count = np.quantile(tissue_preds, q=options.pseudo_qtl) + + print("") + print("pseudo_count = " + str(round(pseudo_count, 6))) + + # compute gradients + print(" - (gradients) - ", flush=True) + + for shift in options.shifts : + print('Processing shift %d' % shift, flush=True) + + for rev_comp in ([False, True] if options.rc == 1 else [False]) : + + if options.rc == 1 : + print('Fwd/rev = %s' % ('fwd' if not rev_comp else 'rev'), flush=True) + + seq_1hots = [] + gene_slices = [] + gene_targets = [] + + for gi, gene_id in enumerate(gene_list): + + if gi % 500 == 0 : + print('Processing %d, %s' % (gi, gene_id), flush=True) + + gene = transcriptome.genes[gene_id] + + # make sequence + seq_1hot = make_seq_1hot(genome_open, genes_chr[gi], genes_start[gi], genes_end[gi], seq_len) + seq_1hot = dna_io.hot1_augment(seq_1hot, shift=shift) + + # determine output sequence start + seq_out_start = genes_start[gi] + model_stride*model_crop + seq_out_len = model_stride*target_length + + # determine output positions + gene_slice = gene.output_slice(seq_out_start, seq_out_len, model_stride, options.span == 1) + + if rev_comp: + seq_1hot = dna_io.hot1_rc(seq_1hot) + gene_slice = target_length - gene_slice - 1 + + # slice relevant strand targets + if genes_strand[gi] == '+': + gene_strand_mask = (targets_df.strand != '-') if not rev_comp else (targets_df.strand != '+') + else: + 
gene_strand_mask = (targets_df.strand != '+') if not rev_comp else (targets_df.strand != '-') + + gene_target = np.array(targets_df.index[gene_strand_mask].values) + + # accumulate data tensors + seq_1hots.append(seq_1hot[None, ...]) + gene_slices.append(gene_slice[None, ...]) + gene_targets.append(gene_target[None, ...]) + + if gi == len(gene_list) - 1 or len(seq_1hots) >= buffer_size : + + # concat sequences + seq_1hots = np.concatenate(seq_1hots, axis=0) + + # pad gene slices to same length (mark valid positions in mask tensor) + max_slice_len = int(np.max([gene_slice.shape[1] for gene_slice in gene_slices])) + + gene_masks = np.zeros((len(gene_slices), max_slice_len), dtype='float32') + gene_slices_padded = np.zeros((len(gene_slices), max_slice_len), dtype='int32') + for gii, gene_slice in enumerate(gene_slices) : + for j in range(gene_slice.shape[1]) : + gene_masks[gii, j] = 1. + gene_slices_padded[gii, j] = gene_slice[0, j] + + gene_slices = gene_slices_padded + + # concat gene-specific targets + gene_targets = np.concatenate(gene_targets, axis=0) + + # batch call gradient computation + grads = seqnn_model.gradients( + seq_1hots, + head_i=0, + target_slice=gene_targets, + pos_slice=gene_slices, + pos_mask=gene_masks, + chunk_size=buffer_size if options.smooth_grad != 1 else buffer_size // options.n_samples, + batch_size=1, + track_scale=track_scale, + track_transform=track_transform, + clip_soft=options.clip_soft, + pseudo_count=pseudo_count, + no_transform=options.no_transform == 1, + use_mean=False, + use_ratio=False, + use_logodds=False, + subtract_avg=True, + input_gate=False, + smooth_grad=options.smooth_grad == 1, + n_samples=options.n_samples, + sample_prob=options.sample_prob, + dtype='float16' + ) + + # undo augmentations and save gradients + for gii, gene_slice in enumerate(gene_slices) : + grad = unaugment_grads(grads[gii, :, :, None], fwdrc=(not rev_comp), shift=shift) + + h5_gi = (gi // buffer_size) * buffer_size + gii + + # write to HDF5 + scores_h5['grads'][h5_gi] += grad + + #clear sequence buffer + seq_1hots = [] + gene_slices = [] + gene_targets = [] + + # collect garbage + gc.collect() + + # save sequences and normalize gradients by total size of ensemble + for gi, gene_id in enumerate(gene_list): + + # re-make original sequence + seq_1hot = make_seq_1hot(genome_open, genes_chr[gi], genes_start[gi], genes_end[gi], seq_len) + + # write to HDF5 + scores_h5['seqs'][gi] = seq_1hot + scores_h5['grads'][gi] /= float((len(options.shifts) * (2 if options.rc == 1 else 1))) + + # collect garbage + gc.collect() + + # close files + genome_open.close() + scores_h5.close() + + +def unaugment_grads(grads, fwdrc=False, shift=0): + """ Undo sequence augmentation.""" + # reverse complement + if not fwdrc: + # reverse + grads = grads[::-1, :, :] + + # swap A and T + grads[:, [0, 3], :] = grads[:, [3, 0], :] + + # swap C and G + grads[:, [1, 2], :] = grads[:, [2, 1], :] + + # undo shift + if shift < 0: + # shift sequence right + grads[-shift:, :, :] = grads[:shift, :, :] + + # fill in left unknowns + grads[:-shift, :, :] = 0 + + elif shift > 0: + # shift sequence left + grads[:-shift, :, :] = grads[shift:, :, :] + + # fill in right unknowns + grads[-shift:, :, :] = 0 + + return grads + + +def make_seq_1hot(genome_open, chrm, start, end, seq_len): + if start < 0: + seq_dna = 'N'*(-start) + genome_open.fetch(chrm, 0, end) + else: + seq_dna = genome_open.fetch(chrm, start, end) + + # extend to full length + if len(seq_dna) < seq_len: + seq_dna += 'N'*(seq_len-len(seq_dna)) + + seq_1hot = 
dna_io.dna_1hot(seq_dna)
+  return seq_1hot
+
+################################################################################
+# __main__
+# ###############################################################################
+if __name__ == '__main__':
+  main()
diff --git a/bin/borzoi_satg_gene_gpu_focused_ism.py b/bin/borzoi_satg_gene_gpu_focused_ism.py
new file mode 100644
index 00000000..0f34b27a
--- /dev/null
+++ b/bin/borzoi_satg_gene_gpu_focused_ism.py
@@ -0,0 +1,673 @@
+#!/usr/bin/env python
+# Copyright 2017 Calico LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# https://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# =========================================================================
+from __future__ import print_function
+
+from optparse import OptionParser
+
+import gc
+import json
+import os
+import pdb
+import pickle
+from queue import Queue
+import random
+import sys
+from threading import Thread
+import time
+
+import h5py
+import numpy as np
+import pandas as pd
+import pysam
+import tensorflow as tf
+
+from basenji import dna_io
+from basenji import gene as bgene
+from basenji import seqnn
+from borzoi_sed import targets_prep_strand
+
+from scipy.ndimage import gaussian_filter1d
+
+'''
+borzoi_satg_gene_gpu_focused_ism.py
+
+Perform an ISM analysis for genes specified in a GTF file, targeting high-saliency regions based on gradient scores.
+'''
+
+# tf code for computing ISM scores on GPU
+@tf.function
+def _score_func(model, seq_1hot, target_slice, pos_slice, pos_mask=None, pos_slice_denom=None, pos_mask_denom=None, track_scale=1., track_transform=1., clip_soft=None, pseudo_count=0., no_transform=False, aggregate_tracks=None, use_mean=False, use_ratio=False, use_logodds=False) :
+
+  # predict
+  preds = tf.gather(model(seq_1hot, training=False), target_slice, axis=-1, batch_dims=1)
+
+  if not no_transform :
+
+    # undo scale
+    preds = preds / track_scale
+
+    # undo soft_clip
+    if clip_soft is not None :
+      preds = tf.where(preds > clip_soft, (preds - clip_soft)**2 + clip_soft, preds)
+
+    # undo sqrt
+    preds = preds**(1. 
/ track_transform) + + if aggregate_tracks is not None : + preds = tf.reduce_mean(tf.reshape(preds, (preds.shape[0], preds.shape[1], preds.shape[2] // aggregate_tracks, aggregate_tracks)), axis=-1) + + # slice specified positions + preds_slice = tf.gather(preds, pos_slice, axis=1, batch_dims=1) + if pos_mask is not None : + preds_slice = preds_slice * pos_mask + + # slice denominator positions + if use_ratio and pos_slice_denom is not None: + preds_slice_denom = tf.gather(preds, pos_slice_denom, axis=1, batch_dims=1) + if pos_mask_denom is not None : + preds_slice_denom = preds_slice_denom * pos_mask_denom + + # aggregate over positions + if not use_mean : + preds_agg = tf.reduce_sum(preds_slice, axis=1) + if use_ratio and pos_slice_denom is not None: + preds_agg_denom = tf.reduce_sum(preds_slice_denom, axis=1) + else : + if pos_mask is not None : + preds_agg = tf.reduce_sum(preds_slice, axis=1) / tf.reduce_sum(pos_mask, axis=1) + else : + preds_agg = tf.reduce_mean(preds_slice, axis=1) + + if use_ratio and pos_slice_denom is not None: + if pos_mask_denom is not None : + preds_agg_denom = tf.reduce_sum(preds_slice_denom, axis=1) / tf.reduce_sum(pos_mask_denom, axis=1) + else : + preds_agg_denom = tf.reduce_mean(preds_slice_denom, axis=1) + + # compute final statistic + if no_transform : + score_ratios = preds_agg + elif not use_ratio : + score_ratios = tf.math.log(preds_agg + pseudo_count + 1e-6) + else : + if not use_logodds : + score_ratios = tf.math.log((preds_agg + pseudo_count) / (preds_agg_denom + pseudo_count) + 1e-6) + else : + score_ratios = tf.math.log(((preds_agg + pseudo_count) / (preds_agg_denom + pseudo_count)) / (1. - ((preds_agg + pseudo_count) / (preds_agg_denom + pseudo_count))) + 1e-6) + + return score_ratios + +def get_ism(seqnn_model, seq_1hot_wt, ism_start=0, ism_end=524288, head_i=None, target_slice=None, pos_slice=None, pos_mask=None, pos_slice_denom=None, pos_mask_denom=None, track_scale=1., track_transform=1., clip_soft=None, pseudo_count=0., no_transform=False, aggregate_tracks=None, use_mean=False, use_ratio=False, use_logodds=False, bases=[0, 1, 2, 3]) : + + # choose model + if seqnn_model.ensemble is not None: + model = seqnn_model.ensemble + elif head_i is not None: + model = seqnn_model.models[head_i] + else: + model = seqnn_model.model + + # verify tensor shape(s) + seq_1hot_wt = seq_1hot_wt.astype('float32') + target_slice = np.array(target_slice).astype('int32') + pos_slice = np.array(pos_slice).astype('int32') + + # convert constants to tf tensors + track_scale = tf.constant(track_scale, dtype=tf.float32) + track_transform = tf.constant(track_transform, dtype=tf.float32) + if clip_soft is not None : + clip_soft = tf.constant(clip_soft, dtype=tf.float32) + pseudo_count = tf.constant(pseudo_count, dtype=tf.float32) + + if pos_mask is not None : + pos_mask = np.array(pos_mask).astype('float32') + + if use_ratio and pos_slice_denom is not None : + pos_slice_denom = np.array(pos_slice_denom).astype('int32') + + if pos_mask_denom is not None : + pos_mask_denom = np.array(pos_mask_denom).astype('float32') + + if len(seq_1hot_wt.shape) < 3: + seq_1hot_wt = seq_1hot_wt[None, ...] + + if len(target_slice.shape) < 2: + target_slice = target_slice[None, ...] + + if len(pos_slice.shape) < 2: + pos_slice = pos_slice[None, ...] + + if pos_mask is not None and len(pos_mask.shape) < 2: + pos_mask = pos_mask[None, ...] + + if use_ratio and pos_slice_denom is not None and len(pos_slice_denom.shape) < 2: + pos_slice_denom = pos_slice_denom[None, ...] 
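
The pos_slice/pos_mask pair used throughout _score_func implements ragged position sets on dense tensors: per-sequence bin indices are padded to a common length, and the 0/1 mask zeroes the padded slots before aggregation. A self-contained toy example of the pattern:

import tensorflow as tf

preds = tf.constant([[1., 2., 3., 4.]])                      # (batch=1, bins=4)
pos_slice = tf.constant([[1, 2, 0]])                         # last slot is padding
pos_mask = tf.constant([[1., 1., 0.]])                       # 0 marks the padded slot
sliced = tf.gather(preds, pos_slice, axis=1, batch_dims=1)   # [[2., 3., 1.]]
masked_sum = tf.reduce_sum(sliced * pos_mask, axis=1)        # [5.]
masked_mean = masked_sum / tf.reduce_sum(pos_mask, axis=1)   # [2.5]
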
+
+  if pos_mask_denom is not None and len(pos_mask_denom.shape) < 2:
+    pos_mask_denom = pos_mask_denom[None, ...]
+
+  # convert to tf tensors
+  seq_1hot_wt_tf = tf.convert_to_tensor(seq_1hot_wt, dtype=tf.float32)
+  target_slice = tf.convert_to_tensor(target_slice, dtype=tf.int32)
+  pos_slice = tf.convert_to_tensor(pos_slice, dtype=tf.int32)
+
+  if pos_mask is not None :
+    pos_mask = tf.convert_to_tensor(pos_mask, dtype=tf.float32)
+
+  if use_ratio and pos_slice_denom is not None :
+    pos_slice_denom = tf.convert_to_tensor(pos_slice_denom, dtype=tf.int32)
+
+    if pos_mask_denom is not None :
+      pos_mask_denom = tf.convert_to_tensor(pos_mask_denom, dtype=tf.float32)
+
+  # allocate ism result tensor
+  pred_ism = np.zeros((524288, 4, target_slice.shape[1] // (aggregate_tracks if aggregate_tracks is not None else 1)))
+
+  # get wt pred
+  score_wt = _score_func(model, seq_1hot_wt_tf, target_slice, pos_slice, pos_mask, pos_slice_denom, pos_mask_denom, track_scale, track_transform, clip_soft, pseudo_count, no_transform, aggregate_tracks, use_mean, use_ratio, use_logodds).numpy()
+
+  for j in range(ism_start, ism_end) :
+    for b in bases :
+      if seq_1hot_wt[0, j, b] != 1. :
+        seq_1hot_mut = np.copy(seq_1hot_wt)
+        seq_1hot_mut[0, j, :] = 0.
+        seq_1hot_mut[0, j, b] = 1.
+
+        # convert to tf tensor
+        seq_1hot_mut_tf = tf.convert_to_tensor(seq_1hot_mut, dtype=tf.float32)
+
+        # get mut pred
+        score_mut = _score_func(model, seq_1hot_mut_tf, target_slice, pos_slice, pos_mask, pos_slice_denom, pos_mask_denom, track_scale, track_transform, clip_soft, pseudo_count, no_transform, aggregate_tracks, use_mean, use_ratio, use_logodds).numpy()
+
+        pred_ism[j, b, :] = score_wt - score_mut
+
+  pred_ism = np.tile(np.mean(pred_ism, axis=1, keepdims=True), (1, 4, 1)) * seq_1hot_wt[0, ..., None]
+
+  return pred_ism
+
+
+################################################################################
+# main
+# ###############################################################################
+def main():
+  usage = 'usage: %prog [options] '
+  parser = OptionParser(usage)
+  parser.add_option('--fa', dest='genome_fasta',
+      default='%s/data/hg38.fa' % os.environ['BASENJIDIR'],
+      help='Genome FASTA for sequences [Default: %default]')
+  parser.add_option('-o', dest='out_dir',
+      default='satg_out', help='Output directory [Default: %default]')
+  parser.add_option('--rc', dest='rc',
+      default=0, type='int',
+      help='Ensemble forward and reverse complement predictions [Default: %default]')
+  parser.add_option('-f', dest='folds',
+      default='0', type='str',
+      help='Model folds to use in ensemble [Default: %default]')
+  parser.add_option('--shifts', dest='shifts',
+      default='0', type='str',
+      help='Ensemble prediction shifts [Default: %default]')
+  parser.add_option('--span', dest='span',
+      default=0, type='int',
+      help='Aggregate entire gene span [Default: %default]')
+  parser.add_option('--clip_soft', dest='clip_soft',
+      default=None, type='float',
+      help='Model clip_soft setting [Default: %default]')
+  parser.add_option('--no_transform', dest='no_transform',
+      default=0, type='int',
+      help='Run gradients with no inverse transforms [Default: %default]')
+  parser.add_option('--pseudo_qtl', dest='pseudo_qtl',
+      default=None, type='float',
+      help='Quantile of predicted scalars to choose as pseudo count [Default: %default]')
+  parser.add_option('--aggregate_tracks', dest='aggregate_tracks',
+      default=None, type='int',
+      help='Number of consecutive tracks to average into one group [Default: %default]')
+  parser.add_option('-t', dest='targets_file',
+      default=None, type='str',
+      help='File specifying target indexes and labels in table format')
+  parser.add_option('--tissue_files', dest='tissue_files',
+      default=None, type='str',
+      help='Comma-separated list of files containing saliency scores (h5 format).')
+  parser.add_option('--tissues', dest='tissues',
+      default=None, type='str',
+      help='Comma-separated list of tissue names.')
+  parser.add_option('--tissue', dest='tissue',
+      default=None, type='str',
+      help='Tissue name to filter on in gene_file.')
+  parser.add_option('--main_tissue_ix', dest='main_tissue_ix',
+      default=0, type='int',
+      help='Main tissue index.')
+  parser.add_option('--ism_size', dest='ism_size',
+      default=192, type='int',
+      help='Length of sequence window to run ISM across.')
+  parser.add_option('--gene_file', dest='gene_file',
+      default=None, type='str',
+      help='CSV file of gene metadata.')
+  parser.add_option('--max_n_genes', dest='max_n_genes',
+      default=10, type='int',
+      help='Maximum number of genes in the GTF to compute ISMs for [Default: %default]')
+  parser.add_option('--gaussian_sigma', dest='gaussian_sigma',
+      default=8, type='int',
+      help='Sigma value for 1D gaussian smoothing filter [Default: %default]')
+  parser.add_option('--min_padding', dest='min_padding',
+      default=65536, type='int',
+      help='Minimum crop to apply to scores before searching for smoothed maximum [Default: %default]')
+
+  (options, args) = parser.parse_args()
+
+  if len(args) == 3:
+    # single worker
+    params_file = args[0]
+    model_folder = args[1]
+    genes_gtf_file = args[2]
+  else:
+    parser.error('Must provide parameter file, model folder and GTF file')
+
+  if not os.path.isdir(options.out_dir):
+    os.mkdir(options.out_dir)
+
+  options.folds = [int(fold) for fold in options.folds.split(',')]
+  options.shifts = [int(shift) for shift in options.shifts.split(',')]
+
+  options.tissue_files = [tissue for tissue in options.tissue_files.split(",")]
+  options.tissues = [tissue for tissue in options.tissues.split(",")]
+
+  #################################################################
+  # read parameters and targets
+
+  # read model parameters
+  with open(params_file) as params_open:
+    params = json.load(params_open)
+    params_model = params['model']
+    params_train = params['train']
+    seq_len = params_model['seq_length']
+
+  if options.targets_file is None:
+    parser.error('Must provide targets table to properly handle strands.')
+  else:
+    targets_df = pd.read_csv(options.targets_file, sep='\t', index_col=0)
+
+  # prep strand
+  orig_new_index = dict(zip(targets_df.index, np.arange(targets_df.shape[0])))
+  targets_strand_pair = np.array([orig_new_index[ti] for ti in targets_df.strand_pair])
+  targets_strand_df = targets_prep_strand(targets_df)
+  num_targets = len(targets_strand_df)
+
+  # specify relative target indices
+  targets_df['row_index'] = np.arange(len(targets_df), dtype='int32')
+
+  #################################################################
+  # load first model fold to get parameters
+
+  seqnn_model = seqnn.SeqNN(params_model)
+  seqnn_model.restore(model_folder + "/f0c0/model0_best.h5", 0, by_name=False)
+  seqnn_model.build_slice(targets_df.index, False)
+  # seqnn_model.build_ensemble(options.rc, options.shifts)
+
+  model_stride = seqnn_model.model_strides[0]
+  model_crop = seqnn_model.target_crops[0]
+  target_length = seqnn_model.target_lengths[0]
+
+  #################################################################
+  # read genes
+
+  # parse GTF
+  transcriptome = bgene.Transcriptome(genes_gtf_file)
+
+  # order valid genes
+  genome_open = pysam.Fastafile(options.genome_fasta)
+  gene_list = sorted(transcriptome.genes.keys())
+  num_genes = len(gene_list)
+
+  #Make copy of unfiltered gene list
+  gene_list_all = gene_list.copy()
+
+  #################################################################
+  # load tissue gene list
+
+  #Load gene dataframe and select tissue
+  gene_df = pd.read_csv(options.gene_file, sep='\t')
+  gene_df = gene_df.query("tissue == '" + str(options.tissue) + "'").copy().reset_index(drop=True)
+  gene_df = gene_df.drop(columns=['Unnamed: 0'])
+
+  print("len(gene_df) = " + str(len(gene_df)))
+
+  #Truncate by maximum number of genes
+  gene_df = gene_df.iloc[:options.max_n_genes].copy().reset_index(drop=True)
+
+  #Get list of genes for tissue
+  tissue_genes = gene_df['gene_base'].values.tolist()
+
+  #print("len(tissue_genes) = " + str(len(tissue_genes)))
+
+  #Filter transcriptome gene list
+  gene_list = [gene for gene in gene_list if gene.split(".")[0] in set(tissue_genes)]
+  num_genes = len(gene_list)
+
+  print("num_genes = " + str(num_genes))
+
+  #################################################################
+  # load h5 scores
+
+  seqs = None
+  strands = None
+  chrs = None
+  starts = None
+  ends = None
+  genes = None
+  all_scores = []
+  pseudo_counts = []
+
+  for scores_h5_file, scores_h5_tissue in zip(options.tissue_files, options.tissues) :
+
+    print("Reading '" + scores_h5_file + "'")
+
+    with h5py.File(scores_h5_file, 'r') as score_file:
+
+      #Get scores and onehots
+      scores = score_file['grads'][()][..., 0]
+      seqs = score_file['seqs'][()]
+
+      #Get auxiliary information
+      strands = score_file['strand'][()]
+      strands = np.array([strands[j].decode() for j in range(strands.shape[0])])
+
+      chrs = score_file['chr'][()]
+      chrs = np.array([chrs[j].decode() for j in range(chrs.shape[0])])
+
+      starts = np.array(score_file['start'][()])
+      ends = np.array(score_file['end'][()])
+
+      genes = score_file['gene'][()]
+      genes = np.array([genes[j].decode() for j in range(genes.shape[0])]) #.split(".")[0]
+
+      gene_dict = {gene : gene_i for gene_i, gene in enumerate(genes.tolist())}
+
+      #Get index of rows to keep
+      keep_index = []
+      for gene in gene_list :
+        keep_index.append(gene_dict[gene])
+
+      #Optionally compute pseudo-counts
+      if options.aggregate_tracks is not None and options.pseudo_qtl is not None :
+
+        #Load gene dataframe and select active tissue
+        gene_df_all = pd.read_csv(options.gene_file, sep='\t')
+        gene_df_all = gene_df_all.query("tissue == '" + str(scores_h5_tissue) + "'").copy().reset_index(drop=True)
+        gene_df_all = gene_df_all.drop(columns=['Unnamed: 0'])
+
+        #Get list of genes for active tissue
+        tissue_genes_all = gene_df_all['gene_base'].values.tolist()
+
+        #Filter transcriptome gene list
+        gene_list_tissue = [gene for gene in gene_list_all if gene.split(".")[0] in set(tissue_genes_all)]
+        num_genes_tissue = len(gene_list_tissue)
+
+        print(" - num_genes_tissue = " + str(num_genes_tissue))
+
+        #Get index of genes belonging to active tissue
+        gene_index = []
+        for gene in gene_list_tissue :
+          gene_index.append(gene_dict[gene])
+
+        #Compute pseudo-count
+        pseudo_count = np.quantile(np.array(score_file['preds'][()][gene_index, 0]), q=options.pseudo_qtl)
+        pseudo_counts.append(pseudo_count)
+
+      #Filter/sub-select data
+      scores = scores[keep_index, ...]
+      seqs = seqs[keep_index, ...]
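
The pseudo-count block above reduces to one quantile call: the predicted scalars for the tissue's genes are summarized, and the chosen quantile is later added inside the log (log(pred + pseudo_count) in _score_func) so that weakly expressed genes do not dominate the fold-change scores. For example:

import numpy as np

tissue_preds = np.array([0.5, 1.0, 2.0, 4.0, 8.0], dtype='float32')
pseudo_count = np.quantile(tissue_preds, q=0.5)   # median -> 2.0
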
+ strands = strands[keep_index] + chrs = chrs[keep_index] + starts = starts[keep_index] + ends = ends[keep_index] + genes = genes[keep_index] + + #Append input-gated scores + all_scores.append((scores * seqs)[None, ...]) + + #Collect garbage + gc.collect() + + #Collect final scores + scores = np.concatenate(all_scores, axis=0) + + print("scores.shape = " + str(scores.shape)) + + #Collect pseudo-counts + pseudo_count = 0. + if options.aggregate_tracks is not None and options.pseudo_qtl is not None : + pseudo_count = np.array(pseudo_counts, dtype='float32')[None, :] + + print("pseudo_count = " + str(np.round(pseudo_count, 2))) + else : + print("pseudo_count = " + str(round(pseudo_count, 2))) + + ################################################################# + # setup output + + # choose gene sequences + genes_chr = chrs.tolist() + genes_start = starts.tolist() + genes_end = ends.tolist() + genes_strand = strands.tolist() + + ################################################################# + # calculate ism start and end positions per gene + + print("main_tissue_ix = " + str(options.main_tissue_ix)) + + genes_ism_start = [] + genes_ism_end = [] + for gi in range(len(gene_list)) : + score_2 = scores[options.main_tissue_ix, gi, ...] + score_1 = np.mean(scores[np.arange(scores.shape[0]) != options.main_tissue_ix, gi, ...], axis=0) + + diff_score = np.sum(score_2 - score_1, axis=-1) + + #Apply gaussian filter + diff_score = gaussian_filter1d(diff_score.astype('float32'), sigma=options.gaussian_sigma, truncate=2).astype('float16') + + max_pos = np.argmax(diff_score[options.min_padding:-options.min_padding]) + options.min_padding + + genes_ism_start.append(max_pos - options.ism_size // 2) + genes_ism_end.append(max_pos + options.ism_size // 2) + + ################################################################# + # predict ISM scores, write output + + print("clip_soft = " + str(options.clip_soft)) + + print("n genes = " + str(len(genes_chr))) + + # loop over folds + for fold_ix in options.folds : + print("-- Fold = " + str(fold_ix) + " --") + + # (re-)initialize HDF5 + scores_h5_file = '%s/ism_f%dc0.h5' % (options.out_dir, fold_ix) + if os.path.isfile(scores_h5_file): + os.remove(scores_h5_file) + scores_h5 = h5py.File(scores_h5_file, 'w') + scores_h5.create_dataset('seqs', dtype='bool', + shape=(num_genes, options.ism_size, 4)) + scores_h5.create_dataset('isms', dtype='float16', + shape=(num_genes, options.ism_size, 4, num_targets // (options.aggregate_tracks if options.aggregate_tracks is not None else 1))) + scores_h5.create_dataset('gene', data=np.array(gene_list, dtype='S')) + scores_h5.create_dataset('chr', data=np.array(genes_chr, dtype='S')) + scores_h5.create_dataset('start', data=np.array(genes_start)) + scores_h5.create_dataset('end', data=np.array(genes_end)) + scores_h5.create_dataset('ism_start', data=np.array(genes_ism_start)) + scores_h5.create_dataset('ism_end', data=np.array(genes_ism_end)) + scores_h5.create_dataset('strand', data=np.array(genes_strand, dtype='S')) + + # load model fold + seqnn_model = seqnn.SeqNN(params_model) + seqnn_model.restore(model_folder + "/f" + str(fold_ix) + "c0/model0_best.h5", 0, by_name=False) + seqnn_model.build_slice(targets_df.index, False) + + track_scale = targets_df.iloc[0]['scale'] + track_transform = 3. / 4. 
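
For reference, the ISM-window selection performed earlier in this script condensed into a standalone sketch: smooth the main-tissue-vs-rest saliency differential with a gaussian, then center a fixed-size window on the smoothed peak while ignoring the outer padding. The defaults sigma=8, min_padding=65536 and ism_size=192 mirror this script's options; the random array stands in for a real saliency differential:

import numpy as np
from scipy.ndimage import gaussian_filter1d

diff_score = np.random.randn(524288).astype('float32')   # stand-in saliency diff
smoothed = gaussian_filter1d(diff_score, sigma=8, truncate=2)
min_padding, ism_size = 65536, 192
max_pos = np.argmax(smoothed[min_padding:-min_padding]) + min_padding
ism_start, ism_end = max_pos - ism_size // 2, max_pos + ism_size // 2
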
+ + for shift in options.shifts : + print('Processing shift %d' % shift, flush=True) + + for rev_comp in ([False, True] if options.rc == 1 else [False]) : + + if options.rc == 1 : + print('Fwd/rev = %s' % ('fwd' if not rev_comp else 'rev'), flush=True) + + seq_1hots = [] + gene_slices = [] + gene_targets = [] + + for gi, gene_id in enumerate(gene_list): + + if gi % 50 == 0 : + print('Processing %d, %s' % (gi, gene_id), flush=True) + + gene = transcriptome.genes[gene_id] + + # make sequence + seq_1hot = make_seq_1hot(genome_open, genes_chr[gi], genes_start[gi], genes_end[gi], seq_len) + seq_1hot = dna_io.hot1_augment(seq_1hot, shift=shift) + + # determine output sequence start + seq_out_start = genes_start[gi] + model_stride*model_crop + seq_out_len = model_stride*target_length + + # determine output positions + gene_slice = gene.output_slice(seq_out_start, seq_out_len, model_stride, options.span == 1) + + # determine ism window + gene_ism_start = genes_ism_start[gi] + gene_ism_end = genes_ism_end[gi] + + if rev_comp: + seq_1hot = dna_io.hot1_rc(seq_1hot) + gene_slice = target_length - gene_slice - 1 + + gene_ism_start = seq_len - genes_ism_end[gi] - 1 + gene_ism_end = seq_len - genes_ism_start[gi] - 1 + + # slice relevant strand targets + if genes_strand[gi] == '+': + gene_strand_mask = (targets_df.strand != '-') if not rev_comp else (targets_df.strand != '+') + else: + gene_strand_mask = (targets_df.strand != '+') if not rev_comp else (targets_df.strand != '-') + + gene_target = np.array(targets_df.index[gene_strand_mask].values) + + # broadcast to singleton batch + seq_1hot = seq_1hot[None, ...] + gene_slice = gene_slice[None, ...] + gene_target = gene_target[None, ...] + + # ism computation + ism = get_ism( + seqnn_model, + seq_1hot, + gene_ism_start, + gene_ism_end, + head_i=0, + target_slice=gene_target, + pos_slice=gene_slice, + track_scale=track_scale, + track_transform=track_transform, + clip_soft=options.clip_soft, + pseudo_count=pseudo_count, + no_transform=options.no_transform == 1, + aggregate_tracks=options.aggregate_tracks, + use_mean=False, + use_ratio=False, + use_logodds=False, + ) + + # undo augmentations and save ism + ism = unaugment_grads(ism, fwdrc=(not rev_comp), shift=shift) + + # write to HDF5 + scores_h5['isms'][gi] += ism[genes_ism_start[gi]:genes_ism_end[gi], ...] + + # collect garbage + gc.collect() + + # save sequences and normalize isms by total size of ensemble + for gi, gene_id in enumerate(gene_list): + + # re-make original sequence + seq_1hot = make_seq_1hot(genome_open, genes_chr[gi], genes_start[gi], genes_end[gi], seq_len) + + # write to HDF5 + scores_h5['seqs'][gi] = seq_1hot[genes_ism_start[gi]:genes_ism_end[gi], ...] 
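
The strand mask used above keeps every target that is not on the opposing strand, so unstranded tracks pass through for either gene orientation (assuming '.' marks unstranded rows, as in typical Borzoi targets tables). A toy illustration:

import pandas as pd

targets = pd.DataFrame({'strand': ['+', '-', '.']}, index=[10, 11, 12])
mask = targets.strand != '-'          # '+' gene, forward orientation
print(targets.index[mask].values)     # [10 12]
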
+ scores_h5['isms'][gi] /= float((len(options.shifts) * (2 if options.rc == 1 else 1))) + + # collect garbage + gc.collect() + + # close files + genome_open.close() + scores_h5.close() + + +def unaugment_grads(grads, fwdrc=False, shift=0): + """ Undo sequence augmentation.""" + # reverse complement + if not fwdrc: + # reverse + grads = grads[::-1, :, :] + + # swap A and T + grads[:, [0, 3], :] = grads[:, [3, 0], :] + + # swap C and G + grads[:, [1, 2], :] = grads[:, [2, 1], :] + + # undo shift + if shift < 0: + # shift sequence right + grads[-shift:, :, :] = grads[:shift, :, :] + + # fill in left unknowns + grads[:-shift, :, :] = 0 + + elif shift > 0: + # shift sequence left + grads[:-shift, :, :] = grads[shift:, :, :] + + # fill in right unknowns + grads[-shift:, :, :] = 0 + + return grads + + +def make_seq_1hot(genome_open, chrm, start, end, seq_len): + if start < 0: + seq_dna = 'N'*(-start) + genome_open.fetch(chrm, 0, end) + else: + seq_dna = genome_open.fetch(chrm, start, end) + + # extend to full length + if len(seq_dna) < seq_len: + seq_dna += 'N'*(seq_len-len(seq_dna)) + + seq_1hot = dna_io.dna_1hot(seq_dna) + return seq_1hot + +################################################################################ +# __main__ +# ############################################################################### +if __name__ == '__main__': + main() diff --git a/bin/borzoi_satg_polya_gpu.py b/bin/borzoi_satg_polya_gpu.py new file mode 100644 index 00000000..a306a2c4 --- /dev/null +++ b/bin/borzoi_satg_polya_gpu.py @@ -0,0 +1,487 @@ +#!/usr/bin/env python +# Copyright 2017 Calico LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ========================================================================= +from __future__ import print_function + +from optparse import OptionParser + +import gc +import json +import os +import pdb +import pickle +from queue import Queue +import random +import sys +from threading import Thread +import time + +import h5py +import numpy as np +import pandas as pd +import pysam +import tensorflow as tf + +from basenji import dna_io +from basenji import gene as bgene +from basenji import seqnn +from borzoi_sed import targets_prep_strand + +''' +borzoi_satg_polya_gpu.py + +Perform a gradient saliency analysis for genes specified in a GTF file (polyadenylation-centric). 
+''' + +################################################################################ +# main +# ############################################################################### +def main(): + usage = 'usage: %prog [options] ' + parser = OptionParser(usage) + parser.add_option('--fa', dest='genome_fasta', + default='%s/data/hg38.fa' % os.environ['BASENJIDIR'], + help='Genome FASTA for sequences [Default: %default]') + parser.add_option('-o', dest='out_dir', + default='satg_out', help='Output directory [Default: %default]') + parser.add_option('--rc', dest='rc', + default=0, type='int', + help='Ensemble forward and reverse complement predictions [Default: %default]') + parser.add_option('-f', dest='folds', + default='0', type='str', + help='Model folds to use in ensemble [Default: %default]') + parser.add_option('--shifts', dest='shifts', + default='0', type='str', + help='Ensemble prediction shifts [Default: %default]') + parser.add_option('--span', dest='span', + default=0, type='int', + help='Aggregate entire gene span [Default: %default]') + parser.add_option('--smoothgrad', dest='smooth_grad', + default=0, type='int', + help='Run smoothgrad [Default: %default]') + parser.add_option('--samples', dest='n_samples', + default=5, type='int', + help='Number of smoothgrad samples [Default: %default]') + parser.add_option('--sampleprob', dest='sample_prob', + default=0.875, type='float', + help='Probability of not mutating a position in smoothgrad [Default: %default]') + parser.add_option('--clip_soft', dest='clip_soft', + default=None, type='float', + help='Model clip_soft setting [Default: %default]') + parser.add_option('-t', dest='targets_file', + default=None, type='str', + help='File specifying target indexes and labels in table format') + parser.add_option('-a', dest='apa_file', + default='%s/genes/polyadb/polyadb_human_v3.csv.gz' % os.environ['HG38'], + help='Polyadenylation site annotation [Default: %default]') + (options, args) = parser.parse_args() + + if len(args) == 3: + # single worker + params_file = args[0] + model_folder = args[1] + genes_gtf_file = args[2] + else: + parser.error('Must provide parameter file, model folder and GTF file') + + if not os.path.isdir(options.out_dir): + os.mkdir(options.out_dir) + + options.folds = [int(fold) for fold in options.folds.split(',')] + options.shifts = [int(shift) for shift in options.shifts.split(',')] + + ################################################################# + # read parameters and targets + + # read model parameters + with open(params_file) as params_open: + params = json.load(params_open) + params_model = params['model'] + params_train = params['train'] + seq_len = params_model['seq_length'] + + if options.targets_file is None: + parser.error('Must provide targets table to properly handle strands.') + else: + targets_df = pd.read_csv(options.targets_file, sep='\t', index_col=0) + + # prep strand + orig_new_index = dict(zip(targets_df.index, np.arange(targets_df.shape[0]))) + targets_strand_pair = np.array([orig_new_index[ti] for ti in targets_df.strand_pair]) + targets_strand_df = targets_prep_strand(targets_df) + num_targets = 1 + + ################################################################# + # load first model fold to get parameters + + seqnn_model = seqnn.SeqNN(params_model) + seqnn_model.restore(model_folder + "/f0c0/model0_best.h5", 0, by_name=False) + seqnn_model.build_slice(targets_df.index, False) + # seqnn_model.build_ensemble(options.rc, options.shifts) + + model_stride = seqnn_model.model_strides[0] + 
model_crop = seqnn_model.target_crops[0]
+  target_length = seqnn_model.target_lengths[0]
+
+  #################################################################
+  # read genes
+
+  # parse GTF
+  transcriptome = bgene.Transcriptome(genes_gtf_file)
+
+  # order valid genes
+  genome_open = pysam.Fastafile(options.genome_fasta)
+  gene_list = sorted(transcriptome.genes.keys())
+  num_genes = len(gene_list)
+
+  #################################################################
+  # setup output
+
+  min_start = -model_stride*model_crop
+
+  # choose gene sequences
+  genes_chr = []
+  genes_start = []
+  genes_end = []
+  genes_strand = []
+  for gene_id in gene_list:
+    gene = transcriptome.genes[gene_id]
+    genes_chr.append(gene.chrom)
+    genes_strand.append(gene.strand)
+
+    gene_midpoint = gene.midpoint()
+    gene_start = max(min_start, gene_midpoint - seq_len//2)
+    gene_end = gene_start + seq_len
+    genes_start.append(gene_start)
+    genes_end.append(gene_end)
+
+  #######################################################
+  # make APA BED (PolyADB)
+
+  apa_df = pd.read_csv(options.apa_file, sep='\t', compression='gzip')
+
+  # filter for 3' UTR polyA sites only
+  apa_df = apa_df.query("site_type == '3\\' most exon'").copy().reset_index(drop=True)
+
+  #Remove non-contiguous sites, starting from distal-most site
+
+  apa_df = apa_df.sort_values(by=['site_num'], ascending=False).copy().reset_index(drop=True)
+
+  gene_dict = {}
+  keep_index = []
+  for i, [_, row] in enumerate(apa_df.iterrows()) :
+
+    if row['gene'] not in gene_dict :
+      gene_dict[row['gene']] = row['site_num']
+      keep_index.append(i)
+    else :
+      if row['site_num'] == gene_dict[row['gene']] - 1 :
+        gene_dict[row['gene']] = row['site_num']
+        keep_index.append(i)
+
+  apa_df = apa_df.iloc[keep_index].copy().reset_index(drop=True)
+
+  apa_df = apa_df.sort_values(by=['gene', 'site_num'], ascending=True).copy().reset_index(drop=True)
+
+  apa_df['start_hg38'] = apa_df['position_hg38']
+  apa_df['end_hg38'] = apa_df['position_hg38'] + 1
+
+  apa_df = apa_df.rename(columns={'chrom' : 'Chromosome', 'start_hg38' : 'Start', 'end_hg38' : 'End', 'position_hg38' : 'cut_mode', 'strand' : 'pas_strand'})
+
+  apa_df = apa_df[['Chromosome', 'Start', 'End', 'pas_id', 'pas_strand', 'gene', 'site_num']]
+
+  #################################################################
+  # predict scores, write output
+
+  buffer_size = 1024
+  pas_ext = 50
+
+  print("clip_soft = " + str(options.clip_soft))
+
+  print("n genes = " + str(len(genes_chr)))
+
+  # loop over folds
+  for fold_ix in options.folds :
+    print("-- Fold = " + str(fold_ix) + " --")
+
+    # (re-)initialize HDF5
+    scores_h5_file = '%s/scores_f%dc0.h5' % (options.out_dir, fold_ix)
+    if os.path.isfile(scores_h5_file):
+      os.remove(scores_h5_file)
+    scores_h5 = h5py.File(scores_h5_file, 'w')
+    scores_h5.create_dataset('seqs', dtype='bool',
+        shape=(num_genes, seq_len, 4))
+    scores_h5.create_dataset('grads', dtype='float16',
+        shape=(num_genes, seq_len, 4, num_targets))
+    scores_h5.create_dataset('gene', data=np.array(gene_list, dtype='S'))
+    scores_h5.create_dataset('chr', data=np.array(genes_chr, dtype='S'))
+    scores_h5.create_dataset('start', data=np.array(genes_start))
+    scores_h5.create_dataset('end', data=np.array(genes_end))
+    scores_h5.create_dataset('strand', data=np.array(genes_strand, dtype='S'))
+
+    # load model fold
+    seqnn_model = seqnn.SeqNN(params_model)
+    seqnn_model.restore(model_folder + "/f" + str(fold_ix) + "c0/model0_best.h5", 0, by_name=False)
+    seqnn_model.build_slice(targets_df.index, False)
+
+    track_scale = 
targets_df.iloc[0]['scale'] + track_transform = 3. / 4. + + for shift in options.shifts : + print('Processing shift %d' % shift, flush=True) + + for rev_comp in ([False, True] if options.rc == 1 else [False]) : + + if options.rc == 1 : + print('Fwd/rev = %s' % ('fwd' if not rev_comp else 'rev'), flush=True) + + seq_1hots = [] + gene_slices = [] + gene_slices_denom = [] + gene_targets = [] + + for gi, gene_id in enumerate(gene_list): + + if gi % 500 == 0 : + print('Processing %d, %s' % (gi, gene_id), flush=True) + + gene = transcriptome.genes[gene_id] + + # make sequence + seq_1hot = make_seq_1hot(genome_open, genes_chr[gi], genes_start[gi], genes_end[gi], seq_len) + seq_1hot = dna_io.hot1_augment(seq_1hot, shift=shift) + + # get apa dataframe + gene_apa_df = apa_df.query("Chromosome == '" + genes_chr[gi] + "' and ((End > " + str(genes_start[gi]-pas_ext) + " and End <= " + str(genes_end[gi]+pas_ext) + ") or (Start < " + str(genes_end[gi]+pas_ext) + " and Start >= " + str(genes_start[gi]-pas_ext) + ")) and pas_strand == '" + str(genes_strand[gi]) + "'").sort_values(by=['gene', 'site_num'], ascending=True) + + gene_slice = None + gene_slice_denom = None + + if len(gene_apa_df) > 0 : + # get distal-most PAS position + pas_start = gene_apa_df.iloc[-1]['Start'] + pas_end = gene_apa_df.iloc[-1]['End'] + pas_strand = gene_apa_df.iloc[-1]['pas_strand'] + + # determine output sequence start + seq_out_start = genes_start[gi] + model_stride*model_crop + + # get relative pas positions + pas_seq_start = max(0, pas_start - seq_out_start) + pas_seq_end = max(0, pas_end - seq_out_start) + + # determine output positions + + # upstream coverage (before PAS) + bin_start = None + bin_end = None + if pas_strand == '+' : + bin_end = int(np.round(pas_seq_start / model_stride)) + 1 + bin_start = bin_end - 3 - 1 + else : + bin_start = int(np.round(pas_seq_end / model_stride)) + bin_end = bin_start + 3 + 1 + + # clip right boundaries + bin_max = int((seq_len - 2.*model_stride*model_crop)/model_stride) + bin_start = max(min(bin_start, bin_max), 0) + bin_end = max(min(bin_end, bin_max), 0) + + gene_slice = np.arange(bin_start, bin_end) + + # downstream coverage (after PAS) + bin_start = None + bin_end = None + if pas_strand == '+' : + bin_start = int(np.round(pas_seq_end / model_stride)) + 1 + bin_end = bin_start + 3 + 1 + 1 + else : + bin_end = int(np.round(pas_seq_start / model_stride)) + 1 - 1 + bin_start = bin_end - 3 - 1 - 1 + + # clip right boundaries + bin_max = int((seq_len - 2.*model_stride*model_crop)/model_stride) + bin_start = max(min(bin_start, bin_max), 0) + bin_end = max(min(bin_end, bin_max), 0) + + gene_slice_denom = np.arange(bin_start, bin_end) + + else : + gene_slice = np.array([0]) + gene_slice_denom = np.array([0]) + + if gene_slice.shape[0] == 0 or gene_slice_denom.shape[0] == 0 : + gene_slice = np.array([0]) + gene_slice_denom = np.array([0]) + + if rev_comp: + seq_1hot = dna_io.hot1_rc(seq_1hot) + gene_slice = target_length - gene_slice - 1 + gene_slice_denom = target_length - gene_slice_denom - 1 + + # slice relevant strand targets + if genes_strand[gi] == '+': + gene_strand_mask = (targets_df.strand != '-') if not rev_comp else (targets_df.strand != '+') + else: + gene_strand_mask = (targets_df.strand != '+') if not rev_comp else (targets_df.strand != '-') + + gene_target = np.array(targets_df.index[gene_strand_mask].values) + + # accumulate data tensors + seq_1hots.append(seq_1hot[None, ...]) + gene_slices.append(gene_slice[None, ...]) + gene_slices_denom.append(gene_slice_denom[None, ...]) 
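
Worked example of the PAS-relative windows assembled above for a '+' strand site (model_stride=32 is an assumption for illustration): four coverage bins just upstream of the cleavage site form the numerator slice and five bins just downstream form the denominator, so the gradient statistic requested below (use_mean=True, use_ratio=True) tracks the upstream/downstream coverage ratio at the distal PAS.

import numpy as np

model_stride = 32
pas_seq_start, pas_seq_end = 1000, 1001
bin_end = int(np.round(pas_seq_start / model_stride)) + 1   # 32
upstream = np.arange(bin_end - 3 - 1, bin_end)              # bins [28..31]
bin_start = int(np.round(pas_seq_end / model_stride)) + 1   # 32
downstream = np.arange(bin_start, bin_start + 3 + 1 + 1)    # bins [32..36]
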
+ gene_targets.append(gene_target[None, ...]) + + if gi == len(gene_list) - 1 or len(seq_1hots) >= buffer_size : + + # concat sequences + seq_1hots = np.concatenate(seq_1hots, axis=0) + + # pad gene slices to same length (mark valid positions in mask tensor) + max_slice_len = int(np.max([gene_slice.shape[1] for gene_slice in gene_slices])) + max_slice_denom_len = int(np.max([gene_slice_denom.shape[1] for gene_slice_denom in gene_slices_denom])) + + gene_masks = np.zeros((len(gene_slices), max_slice_len), dtype='float32') + gene_slices_padded = np.zeros((len(gene_slices), max_slice_len), dtype='int32') + for gii, gene_slice in enumerate(gene_slices) : + for j in range(gene_slice.shape[1]) : + gene_masks[gii, j] = 1. + gene_slices_padded[gii, j] = gene_slice[0, j] + + gene_slices = gene_slices_padded + + gene_masks_denom = np.zeros((len(gene_slices_denom), max_slice_denom_len), dtype='float32') + gene_slices_denom_padded = np.zeros((len(gene_slices_denom), max_slice_denom_len), dtype='int32') + for gii, gene_slice_denom in enumerate(gene_slices_denom) : + for j in range(gene_slice_denom.shape[1]) : + gene_masks_denom[gii, j] = 1. + gene_slices_denom_padded[gii, j] = gene_slice_denom[0, j] + + gene_slices_denom = gene_slices_denom_padded + + # concat gene-specific targets + gene_targets = np.concatenate(gene_targets, axis=0) + + # batch call gradient computation + grads = seqnn_model.gradients( + seq_1hots, + head_i=0, + target_slice=gene_targets, + pos_slice=gene_slices, + pos_mask=gene_masks, + pos_slice_denom=gene_slices_denom, + pos_mask_denom=gene_masks_denom, + chunk_size=buffer_size if options.smooth_grad != 1 else buffer_size // options.n_samples, + batch_size=1, + track_scale=track_scale, + track_transform=track_transform, + clip_soft=options.clip_soft, + use_mean=True, + use_ratio=True, + use_logodds=False, + subtract_avg=True, + input_gate=False, + smooth_grad=options.smooth_grad == 1, + n_samples=options.n_samples, + sample_prob=options.sample_prob, + dtype='float16' + ) + + # undo augmentations and save gradients + for gii, gene_slice in enumerate(gene_slices) : + grad = unaugment_grads(grads[gii, :, :, None], fwdrc=(not rev_comp), shift=shift) + + h5_gi = (gi // buffer_size) * buffer_size + gii + + # write to HDF5 + scores_h5['grads'][h5_gi] += grad + + #clear sequence buffer + seq_1hots = [] + gene_slices = [] + gene_slices_denom = [] + gene_targets = [] + + # collect garbage + gc.collect() + + # save sequences and normalize gradients by total size of ensemble + for gi, gene_id in enumerate(gene_list): + + # re-make original sequence + seq_1hot = make_seq_1hot(genome_open, genes_chr[gi], genes_start[gi], genes_end[gi], seq_len) + + # write to HDF5 + scores_h5['seqs'][gi] = seq_1hot + scores_h5['grads'][gi] /= float((len(options.shifts) * (2 if options.rc == 1 else 1))) + + # collect garbage + gc.collect() + + # close files + genome_open.close() + scores_h5.close() + + +def unaugment_grads(grads, fwdrc=False, shift=0): + """ Undo sequence augmentation.""" + # reverse complement + if not fwdrc: + # reverse + grads = grads[::-1, :, :] + + # swap A and T + grads[:, [0, 3], :] = grads[:, [3, 0], :] + + # swap C and G + grads[:, [1, 2], :] = grads[:, [2, 1], :] + + # undo shift + if shift < 0: + # shift sequence right + grads[-shift:, :, :] = grads[:shift, :, :] + + # fill in left unknowns + grads[:-shift, :, :] = 0 + + elif shift > 0: + # shift sequence left + grads[:-shift, :, :] = grads[shift:, :, :] + + # fill in right unknowns + grads[-shift:, :, :] = 0 + + return grads + + +def 
make_seq_1hot(genome_open, chrm, start, end, seq_len): + if start < 0: + seq_dna = 'N'*(-start) + genome_open.fetch(chrm, 0, end) + else: + seq_dna = genome_open.fetch(chrm, start, end) + + # extend to full length + if len(seq_dna) < seq_len: + seq_dna += 'N'*(seq_len-len(seq_dna)) + + seq_1hot = dna_io.dna_1hot(seq_dna) + return seq_1hot + +################################################################################ +# __main__ +# ############################################################################### +if __name__ == '__main__': + main() diff --git a/bin/borzoi_satg_splice_gpu.py b/bin/borzoi_satg_splice_gpu.py new file mode 100644 index 00000000..52e366ed --- /dev/null +++ b/bin/borzoi_satg_splice_gpu.py @@ -0,0 +1,499 @@ +#!/usr/bin/env python +# Copyright 2017 Calico LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ========================================================================= +from __future__ import print_function + +from optparse import OptionParser + +import gc +import json +import os +import pdb +import pickle +from queue import Queue +import random +import sys +from threading import Thread +import time + +import h5py +import numpy as np +import pandas as pd +import pysam +import tensorflow as tf + +from basenji import dna_io +from basenji import gene as bgene +from basenji import seqnn +from borzoi_sed import targets_prep_strand + +''' +borzoi_satg_splice_gpu.py + +Perform a gradient saliency analysis for genes specified in a GTF file (splice-centric). 
+''' + +################################################################################ +# main +# ############################################################################### +def main(): + usage = 'usage: %prog [options] ' + parser = OptionParser(usage) + parser.add_option('--fa', dest='genome_fasta', + default='%s/data/hg38.fa' % os.environ['BASENJIDIR'], + help='Genome FASTA for sequences [Default: %default]') + parser.add_option('-o', dest='out_dir', + default='satg_out', help='Output directory [Default: %default]') + parser.add_option('--rc', dest='rc', + default=0, type='int', + help='Ensemble forward and reverse complement predictions [Default: %default]') + parser.add_option('-f', dest='folds', + default='0', type='str', + help='Model folds to use in ensemble [Default: %default]') + parser.add_option('--shifts', dest='shifts', + default='0', type='str', + help='Ensemble prediction shifts [Default: %default]') + parser.add_option('--span', dest='span', + default=0, type='int', + help='Aggregate entire gene span [Default: %default]') + parser.add_option('--smoothgrad', dest='smooth_grad', + default=0, type='int', + help='Run smoothgrad [Default: %default]') + parser.add_option('--samples', dest='n_samples', + default=5, type='int', + help='Number of smoothgrad samples [Default: %default]') + parser.add_option('--sampleprob', dest='sample_prob', + default=0.875, type='float', + help='Probability of not mutating a position in smoothgrad [Default: %default]') + parser.add_option('--clip_soft', dest='clip_soft', + default=None, type='float', + help='Model clip_soft setting [Default: %default]') + parser.add_option('-t', dest='targets_file', + default=None, type='str', + help='File specifying target indexes and labels in table format') + parser.add_option('-s', dest='splice_gff', + default='%s/genes/gencode41/gencode41_basic_protein_splice.gff' % os.environ['HG38'], + help='Splice site annotation [Default: %default]') + (options, args) = parser.parse_args() + + if len(args) == 3: + # single worker + params_file = args[0] + model_folder = args[1] + genes_gtf_file = args[2] + else: + parser.error('Must provide parameter file, model folder and GTF file') + + if not os.path.isdir(options.out_dir): + os.mkdir(options.out_dir) + + options.folds = [int(fold) for fold in options.folds.split(',')] + options.shifts = [int(shift) for shift in options.shifts.split(',')] + + ################################################################# + # read parameters and targets + + # read model parameters + with open(params_file) as params_open: + params = json.load(params_open) + params_model = params['model'] + params_train = params['train'] + seq_len = params_model['seq_length'] + + if options.targets_file is None: + parser.error('Must provide targets table to properly handle strands.') + else: + targets_df = pd.read_csv(options.targets_file, sep='\t', index_col=0) + + # prep strand + orig_new_index = dict(zip(targets_df.index, np.arange(targets_df.shape[0]))) + targets_strand_pair = np.array([orig_new_index[ti] for ti in targets_df.strand_pair]) + targets_strand_df = targets_prep_strand(targets_df) + num_targets = 1 + + ################################################################# + # load first model fold to get parameters + + seqnn_model = seqnn.SeqNN(params_model) + seqnn_model.restore(model_folder + "/f0c0/model0_best.h5", 0, by_name=False) + seqnn_model.build_slice(targets_df.index, False) + # seqnn_model.build_ensemble(options.rc, options.shifts) + + model_stride = 
seqnn_model.model_strides[0] + model_crop = seqnn_model.target_crops[0] + target_length = seqnn_model.target_lengths[0] + + ################################################################# + # read genes + + # parse GTF + transcriptome = bgene.Transcriptome(genes_gtf_file) + + # order valid genes + genome_open = pysam.Fastafile(options.genome_fasta) + gene_list = sorted(transcriptome.genes.keys()) + num_genes = len(gene_list) + + ################################################################# + # setup output + + min_start = -model_stride*model_crop + + # choose gene sequences + genes_chr = [] + genes_start = [] + genes_end = [] + genes_strand = [] + for gene_id in gene_list: + gene = transcriptome.genes[gene_id] + genes_chr.append(gene.chrom) + genes_strand.append(gene.strand) + + gene_midpoint = gene.midpoint() + gene_start = max(min_start, gene_midpoint - seq_len//2) + gene_end = gene_start + seq_len + genes_start.append(gene_start) + genes_end.append(gene_end) + + ####################################################### + # load splice site annotation + + splice_df = pd.read_csv(options.splice_gff, sep='\t', names=['Chromosome', 'havana_str', 'feature', 'Start', 'End', 'feat1', 'Strand', 'feat2', 'id_str'], usecols=['Chromosome', 'Start', 'End', 'feature', 'feat1', 'Strand'])[['Chromosome', 'Start', 'End', 'feature', 'feat1', 'Strand']].drop_duplicates(subset=['Chromosome', 'Start', 'Strand'], keep='first').copy().reset_index(drop=True) + + ################################################################# + # predict scores, write output + + buffer_size = 1024 + + print("clip_soft = " + str(options.clip_soft)) + + print("n genes = " + str(len(genes_chr))) + + # loop over folds + for fold_ix in options.folds : + print("-- Fold = " + str(fold_ix) + " --") + + # (re-)initialize HDF5 + scores_h5_file = '%s/scores_f%dc0.h5' % (options.out_dir, fold_ix) + if os.path.isfile(scores_h5_file): + os.remove(scores_h5_file) + scores_h5 = h5py.File(scores_h5_file, 'w') + scores_h5.create_dataset('seqs', dtype='bool', + shape=(num_genes, seq_len, 4)) + scores_h5.create_dataset('grads', dtype='float16', + shape=(num_genes, seq_len, 4, num_targets)) + scores_h5.create_dataset('gene', data=np.array(gene_list, dtype='S')) + scores_h5.create_dataset('chr', data=np.array(genes_chr, dtype='S')) + scores_h5.create_dataset('start', data=np.array(genes_start)) + scores_h5.create_dataset('end', data=np.array(genes_end)) + scores_h5.create_dataset('strand', data=np.array(genes_strand, dtype='S')) + + # load model fold + seqnn_model = seqnn.SeqNN(params_model) + seqnn_model.restore(model_folder + "/f" + str(fold_ix) + "c0/model0_best.h5", 0, by_name=False) + seqnn_model.build_slice(targets_df.index, False) + + track_scale = targets_df.iloc[0]['scale'] + track_transform = 3. / 4. 
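+
+    # note (documentation only): track_scale is read off the targets table, and
+    # track_transform=3/4 presumably mirrors the power-law squashing applied to
+    # coverage tracks during training; gradients() takes both (plus clip_soft)
+    # so that saliencies are computed on approximately de-squashed coverage.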
+ + for shift in options.shifts : + print('Processing shift %d' % shift, flush=True) + + for rev_comp in ([False, True] if options.rc == 1 else [False]) : + + if options.rc == 1 : + print('Fwd/rev = %s' % ('fwd' if not rev_comp else 'rev'), flush=True) + + seq_1hots = [] + gene_slices = [] + gene_slices_denom = [] + gene_targets = [] + + for gi, gene_id in enumerate(gene_list): + + if gi % 500 == 0 : + print('Processing %d, %s' % (gi, gene_id), flush=True) + + gene = transcriptome.genes[gene_id] + + # make sequence + seq_1hot = make_seq_1hot(genome_open, genes_chr[gi], genes_start[gi], genes_end[gi], seq_len) + seq_1hot = dna_io.hot1_augment(seq_1hot, shift=shift) + + # get splice dataframe + gene_splice_df = splice_df.query("Chromosome == '" + genes_chr[gi] + "' and ((End > " + str(genes_start[gi]) + " and End <= " + str(genes_end[gi]) + ") or (Start < " + str(genes_end[gi]) + " and Start >= " + str(genes_start[gi]) + ")) and Strand == '" + str(genes_strand[gi]) + "'").sort_values(by=['Chromosome', 'Start'], ascending=True) + + gene_slice = None + gene_slice_denom = None + + if len(gene_splice_df) > 0 : + + # get random splice junction (donor or acceptor) + rand_ix = np.random.randint(len(gene_splice_df)) + + # get splice junction position + splice_start = gene_splice_df.iloc[rand_ix]['Start'] + splice_end = gene_splice_df.iloc[rand_ix]['End'] + splice_strand = gene_splice_df.iloc[rand_ix]['Strand'] + donor_or_acceptor = gene_splice_df.iloc[rand_ix]['feature'] + + # determine output sequence start + seq_out_start = genes_start[gi] + model_stride*model_crop + + # get relative splice positions + splice_seq_start = max(0, splice_start - seq_out_start) + splice_seq_end = max(0, splice_end - seq_out_start) + + # determine output positions + + if donor_or_acceptor == 'donor' : + + # upstream coverage (before donor) + bin_start = None + bin_end = None + if splice_strand == '+' : + bin_end = int(np.round(splice_seq_start / model_stride)) + 1 + bin_start = bin_end - 3 + else : + bin_start = int(np.round(splice_seq_end / model_stride)) + bin_end = bin_start + 3 + + # clip right boundaries + bin_max = int((seq_len - 2.*model_stride*model_crop)/model_stride) + bin_start = max(min(bin_start, bin_max), 0) + bin_end = max(min(bin_end, bin_max), 0) + + gene_slice = np.arange(bin_start, bin_end) + + # downstream coverage (after donor) + bin_start = None + bin_end = None + if splice_strand == '+' : + bin_start = int(np.round(splice_seq_end / model_stride)) + 1 + bin_end = bin_start + 3 + else : + bin_end = int(np.round(splice_seq_start / model_stride)) + bin_start = bin_end - 3 + + # clip right boundaries + bin_max = int((seq_len - 2.*model_stride*model_crop)/model_stride) + bin_start = max(min(bin_start, bin_max), 0) + bin_end = max(min(bin_end, bin_max), 0) + + gene_slice_denom = np.arange(bin_start, bin_end) + + elif donor_or_acceptor == 'acceptor' : + + # downstream coverage (after acceptor) + bin_start = None + bin_end = None + if splice_strand == '+' : + bin_start = int(np.round(splice_seq_end / model_stride)) + bin_end = bin_start + 3 + else : + bin_end = int(np.round(splice_seq_start / model_stride)) + 1 + bin_start = bin_end - 3 + + # clip right boundaries + bin_max = int((seq_len - 2.*model_stride*model_crop)/model_stride) + bin_start = max(min(bin_start, bin_max), 0) + bin_end = max(min(bin_end, bin_max), 0) + + gene_slice = np.arange(bin_start, bin_end) + + # upstream coverage (before acceptor) + bin_start = None + bin_end = None + if splice_strand == '+' : + bin_end = 
int(np.round(splice_seq_start / model_stride)) + bin_start = bin_end - 3 + else : + bin_start = int(np.round(splice_seq_end / model_stride)) + 1 + bin_end = bin_start + 3 + + # clip right boundaries + bin_max = int((seq_len - 2.*model_stride*model_crop)/model_stride) + bin_start = max(min(bin_start, bin_max), 0) + bin_end = max(min(bin_end, bin_max), 0) + + gene_slice_denom = np.arange(bin_start, bin_end) + + else : + gene_slice = np.array([0]) + gene_slice_denom = np.array([0]) + + if gene_slice.shape[0] == 0 or gene_slice_denom.shape[0] == 0 : + gene_slice = np.array([0]) + gene_slice_denom = np.array([0]) + + if rev_comp: + seq_1hot = dna_io.hot1_rc(seq_1hot) + gene_slice = target_length - gene_slice - 1 + gene_slice_denom = target_length - gene_slice_denom - 1 + + # slice relevant strand targets + if genes_strand[gi] == '+': + gene_strand_mask = (targets_df.strand != '-') if not rev_comp else (targets_df.strand != '+') + else: + gene_strand_mask = (targets_df.strand != '+') if not rev_comp else (targets_df.strand != '-') + + gene_target = np.array(targets_df.index[gene_strand_mask].values) + + # accumulate data tensors + seq_1hots.append(seq_1hot[None, ...]) + gene_slices.append(gene_slice[None, ...]) + gene_slices_denom.append(gene_slice_denom[None, ...]) + gene_targets.append(gene_target[None, ...]) + + if gi == len(gene_list) - 1 or len(seq_1hots) >= buffer_size : + + # concat sequences + seq_1hots = np.concatenate(seq_1hots, axis=0) + + # pad gene slices to same length (mark valid positions in mask tensor) + max_slice_len = int(np.max([gene_slice.shape[1] for gene_slice in gene_slices])) + max_slice_denom_len = int(np.max([gene_slice_denom.shape[1] for gene_slice_denom in gene_slices_denom])) + + gene_masks = np.zeros((len(gene_slices), max_slice_len), dtype='float32') + gene_slices_padded = np.zeros((len(gene_slices), max_slice_len), dtype='int32') + for gii, gene_slice in enumerate(gene_slices) : + for j in range(gene_slice.shape[1]) : + gene_masks[gii, j] = 1. + gene_slices_padded[gii, j] = gene_slice[0, j] + + gene_slices = gene_slices_padded + + gene_masks_denom = np.zeros((len(gene_slices_denom), max_slice_denom_len), dtype='float32') + gene_slices_denom_padded = np.zeros((len(gene_slices_denom), max_slice_denom_len), dtype='int32') + for gii, gene_slice_denom in enumerate(gene_slices_denom) : + for j in range(gene_slice_denom.shape[1]) : + gene_masks_denom[gii, j] = 1. 
+ gene_slices_denom_padded[gii, j] = gene_slice_denom[0, j] + + gene_slices_denom = gene_slices_denom_padded + + # concat gene-specific targets + gene_targets = np.concatenate(gene_targets, axis=0) + + # batch call gradient computation + grads = seqnn_model.gradients( + seq_1hots, + head_i=0, + target_slice=gene_targets, + pos_slice=gene_slices, + pos_mask=gene_masks, + pos_slice_denom=gene_slices_denom, + pos_mask_denom=gene_masks_denom, + chunk_size=buffer_size if options.smooth_grad != 1 else buffer_size // options.n_samples, + batch_size=1, + track_scale=track_scale, + track_transform=track_transform, + clip_soft=options.clip_soft, + use_mean=True, + use_ratio=True, + use_logodds=False, + subtract_avg=True, + input_gate=False, + smooth_grad=options.smooth_grad == 1, + n_samples=options.n_samples, + sample_prob=options.sample_prob, + dtype='float16' + ) + + # undo augmentations and save gradients + for gii, gene_slice in enumerate(gene_slices) : + grad = unaugment_grads(grads[gii, :, :, None], fwdrc=(not rev_comp), shift=shift) + + h5_gi = (gi // buffer_size) * buffer_size + gii + + # write to HDF5 + scores_h5['grads'][h5_gi] += grad + + #clear sequence buffer + seq_1hots = [] + gene_slices = [] + gene_slices_denom = [] + gene_targets = [] + + # collect garbage + gc.collect() + + # save sequences and normalize gradients by total size of ensemble + for gi, gene_id in enumerate(gene_list): + + # re-make original sequence + seq_1hot = make_seq_1hot(genome_open, genes_chr[gi], genes_start[gi], genes_end[gi], seq_len) + + # write to HDF5 + scores_h5['seqs'][gi] = seq_1hot + scores_h5['grads'][gi] /= float((len(options.shifts) * (2 if options.rc == 1 else 1))) + + # collect garbage + gc.collect() + + # close files + genome_open.close() + scores_h5.close() + + +def unaugment_grads(grads, fwdrc=False, shift=0): + """ Undo sequence augmentation.""" + # reverse complement + if not fwdrc: + # reverse + grads = grads[::-1, :, :] + + # swap A and T + grads[:, [0, 3], :] = grads[:, [3, 0], :] + + # swap C and G + grads[:, [1, 2], :] = grads[:, [2, 1], :] + + # undo shift + if shift < 0: + # shift sequence right + grads[-shift:, :, :] = grads[:shift, :, :] + + # fill in left unknowns + grads[:-shift, :, :] = 0 + + elif shift > 0: + # shift sequence left + grads[:-shift, :, :] = grads[shift:, :, :] + + # fill in right unknowns + grads[-shift:, :, :] = 0 + + return grads + + +def make_seq_1hot(genome_open, chrm, start, end, seq_len): + if start < 0: + seq_dna = 'N'*(-start) + genome_open.fetch(chrm, 0, end) + else: + seq_dna = genome_open.fetch(chrm, start, end) + + # extend to full length + if len(seq_dna) < seq_len: + seq_dna += 'N'*(seq_len-len(seq_dna)) + + seq_1hot = dna_io.dna_1hot(seq_dna) + return seq_1hot + +################################################################################ +# __main__ +# ############################################################################### +if __name__ == '__main__': + main() diff --git a/bin/borzoi_sed_ipaqtl_cov.py b/bin/borzoi_sed_ipaqtl_cov.py new file mode 100644 index 00000000..8ec4dd7e --- /dev/null +++ b/bin/borzoi_sed_ipaqtl_cov.py @@ -0,0 +1,773 @@ +#!/usr/bin/env python +# Copyright 2022 Calico LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ========================================================================= +from __future__ import print_function + +from optparse import OptionParser +from collections import OrderedDict +import json +import pickle +import os +import pdb +import sys +import time + +import h5py +import numpy as np +import pandas as pd +import pybedtools +import pysam +from scipy.special import rel_entr +import tensorflow as tf + +from basenji import dna_io +from basenji import gene as bgene +from basenji import seqnn +from basenji import stream +from basenji import vcf as bvcf + +''' +borzoi_sed_ipaqtl_cov.py + +Compute SNP COVerage Ratio (COVR) scores for SNPs in a VCF file, +relative to intronic polyadenylation sites in an annotation file. +''' + +def eprint(*args, **kwargs): + print(*args, file=sys.stderr, **kwargs) + +################################################################################ +# main +################################################################################ +def main(): + usage = 'usage: %prog [options] <params_file> <model_file> <vcf_file>' + parser = OptionParser(usage) + parser.add_option('-f', dest='genome_fasta', + default='%s/data/hg38.fa' % os.environ['BASENJIDIR'], + help='Genome FASTA for sequences [Default: %default]') + parser.add_option('-g', dest='genes_gtf', + default='%s/genes/gencode41/gencode41_basic_nort.gtf' % os.environ['HG38'], + help='GTF for gene definition [Default: %default]') + parser.add_option('--apafile', dest='apa_file', + default='polyadb_human_v3.csv.gz') + parser.add_option('-o', dest='out_dir', + default='sed', + help='Output directory for tables and plots [Default: %default]') + parser.add_option('-p', dest='processes', + default=None, type='int', + help='Number of processes, passed by multi script') + parser.add_option('--pseudo', dest='cov_pseudo', + default=50, type='float', + help='Coverage pseudocount [Default: %default]') + parser.add_option('--cov', dest='cov_min', + default=100, type='float', + help='Min coverage [Default: %default]') + parser.add_option('--paext', dest='pas_ext', + default=50, type='float', + help='Extension in bp past gene span annotation [Default: %default]') + parser.add_option('--rc', dest='rc', + default=True, action='store_true', + help='Average forward and reverse complement predictions [Default: %default]') + parser.add_option('--shifts', dest='shifts', + default='0', type='str', + help='Ensemble prediction shifts [Default: %default]') + parser.add_option('--stats', dest='sed_stats', + default='COVR,SCOVR', + help='Comma-separated list of stats to save. [Default: %default]') + parser.add_option('-t', dest='targets_file', + default=None, type='str', + help='File specifying target indexes and labels in table format') + (options, args) = parser.parse_args() + + if len(args) == 3: + # single worker + params_file = args[0] + model_file = args[1] + vcf_file = args[2] + + elif len(args) == 4: + # multi separate + options_pkl_file = args[0] + params_file = args[1] + model_file = args[2] + vcf_file = args[3] + + # save out dir + out_dir = options.out_dir + + # load options + options_pkl = open(options_pkl_file, 'rb') + options = pickle.load(options_pkl) + options_pkl.close() + + # update output directory + options.out_dir = out_dir + + elif len(args) == 5: + # multi worker + options_pkl_file = args[0] + params_file = args[1] + model_file = args[2] + vcf_file = args[3] + worker_index = int(args[4]) + + # load options + options_pkl = open(options_pkl_file, 'rb') + options = pickle.load(options_pkl) + options_pkl.close() + + # update output directory + options.out_dir = '%s/job%d' % (options.out_dir, worker_index) + + else: + parser.error('Must provide parameters/model, VCF, and genes GTF') + + if not os.path.isdir(options.out_dir): + os.mkdir(options.out_dir) + + options.shifts = [int(shift) for shift in options.shifts.split(',')] + options.sed_stats = options.sed_stats.split(',') + + ################################################################# + # read parameters and targets + + # read model parameters + with open(params_file) as params_open: + params = json.load(params_open) + params_model = params['model'] + params_train = params['train'] + seq_len = params_model['seq_length'] + + if options.targets_file is None: + parser.error('Must provide targets table to properly handle strands.') + else: + targets_df = pd.read_csv(options.targets_file, sep='\t', index_col=0) + + # prep strand + targets_strand_df = targets_prep_strand(targets_df) + + ################################################################# + # setup model + + seqnn_model = seqnn.SeqNN(params_model) + seqnn_model.restore(model_file) + seqnn_model.build_slice(targets_df.index) + seqnn_model.build_ensemble(options.rc, options.shifts) + + model_stride = seqnn_model.model_strides[0] + out_seq_len = seqnn_model.target_lengths[0]*model_stride + + ################################################################# + # read SNPs / genes + + # filter for worker SNPs + if options.processes is not None: + # determine boundaries + num_snps = bvcf.vcf_count(vcf_file) + worker_bounds = np.linspace(0, num_snps, options.processes+1, dtype='int') + + # read SNPs from VCF + snps = bvcf.vcf_snps(vcf_file, start_i=worker_bounds[worker_index], + end_i=worker_bounds[worker_index+1]) + + else: + # read SNPs from VCF + snps = bvcf.vcf_snps(vcf_file) + + # read genes + transcriptome = bgene.Transcriptome(options.genes_gtf) + gene_strand = {} + for gene_id, gene in transcriptome.genes.items(): + gene_strand[gene_id] = gene.strand + + ####################################################### + # make APA BED (PolyADB) + + apa_df = pd.read_csv(options.apa_file, sep='\t', compression='gzip') + + # filter for intronic or 3' UTR polyA sites only + apa_df = apa_df.query("site_type == '3\\' most exon' or site_type == 'Intron'").copy().reset_index(drop=True) + apa_df = apa_df.sort_values(by=['gene', 'site_num'], ascending=True).copy().reset_index(drop=True) + + print("n intron sites = " + str(len(apa_df.query("site_type == 'Intron'"))), flush=True) + print("n utr3 sites = " + str(len(apa_df.query("site_type == '3\\' 
most exon'"))), flush=True) + + apa_df['start_hg38'] = apa_df['position_hg38'] + apa_df['end_hg38'] = apa_df['position_hg38'] + 1 + + apa_df = apa_df.rename(columns={'chrom' : 'Chromosome', 'start_hg38' : 'Start', 'end_hg38' : 'End', 'position_hg38' : 'cut_mode', 'strand' : 'pas_strand'}) + + apa_df = apa_df[['Chromosome', 'Start', 'End', 'pas_id', 'pas_strand', 'gene', 'site_num']] + + # map SNP sequences to gene / polyA signal positions + snpseq_gene_slice, snpseq_apa_slice = map_snpseq_apa(snps, out_seq_len, transcriptome, apa_df, model_stride, options.sed_stats, options.pas_ext) + + # remove SNPs w/o genes + num_snps_pre = len(snps) + snp_gene_mask = np.array([len(sgs) > 0 for sgs in snpseq_gene_slice]) + snps = [snps[si] for si in range(num_snps_pre) if snp_gene_mask[si]] + snpseq_gene_slice = [snpseq_gene_slice[si] for si in range(num_snps_pre) if snp_gene_mask[si]] + snpseq_apa_slice = [snpseq_apa_slice[si] for si in range(num_snps_pre) if snp_gene_mask[si]] + num_snps = len(snps) + + # create SNP seq generator + genome_open = pysam.Fastafile(options.genome_fasta) + + def snp_gen(): + for snp in snps: + # get SNP sequences + snp_1hot_list = bvcf.snp_seq1(snp, seq_len, genome_open) + for snp_1hot in snp_1hot_list: + yield snp_1hot + + ################################################################# + # setup output + + sed_out = initialize_output_h5(options.out_dir, options.sed_stats, snps, snpseq_gene_slice, targets_strand_df, out_name='sed') + sed_out_apa = initialize_output_h5(options.out_dir, ['REF','ALT'], snps, snpseq_apa_slice, targets_strand_df, out_name='sed_pas') + + ################################################################# + # predict SNP scores, write output + + # initialize predictions stream + preds_stream = stream.PredStreamGen(seqnn_model, snp_gen(), params_train['batch_size']) + + # predictions index + pi = 0 + + # SNP/gene index + xi_gene = 0 + + # SNP/pas index + xi_pas = 0 + + # for each SNP sequence + for si in range(num_snps): + # get predictions + ref_preds = preds_stream[pi] + pi += 1 + alt_preds = preds_stream[pi] + pi += 1 + + # undo scale + ref_preds /= np.expand_dims(targets_df.scale, axis=0) + alt_preds /= np.expand_dims(targets_df.scale, axis=0) + + # undo sqrt + ref_preds = ref_preds**(4/3) + alt_preds = alt_preds**(4/3) + + # for each overlapping gene + for gene_id, gene_slice_dup in snpseq_gene_slice[si]['bins'].items(): + + # remove duplicate bin coordinates (artifact of PASs that are within <32bp) + gene_slice = [] + gene_slice_dict = {} + for gslpas in gene_slice_dup: + gslpas_key = str(gslpas[0]) + "_" + str(gslpas[1]) + + if gslpas_key not in gene_slice_dict : + gene_slice_dict[gslpas_key] = True + gene_slice.append(gslpas) + + # slice gene positions + ref_preds_gene = np.concatenate([np.sum(ref_preds[gene_slice_start:gene_slice_end, :], axis=0)[None, :] for [gene_slice_start, gene_slice_end] in gene_slice], axis=0) + alt_preds_gene = np.concatenate([np.sum(alt_preds[gene_slice_start:gene_slice_end, :], axis=0)[None, :] for [gene_slice_start, gene_slice_end] in gene_slice], axis=0) + + if gene_strand[gene_id] == '+': + gene_strand_mask = (targets_df.strand != '-') + else: + gene_strand_mask = (targets_df.strand != '+') + + ref_preds_gene = ref_preds_gene[...,gene_strand_mask] + alt_preds_gene = alt_preds_gene[...,gene_strand_mask] + + # write scores to HDF + write_snp(ref_preds_gene, alt_preds_gene, sed_out, xi_gene, options.sed_stats, options.cov_pseudo, options.cov_min) + + xi_gene += 1 + + # for each overlapping PAS + for pas_id, 
pas_slice in snpseq_apa_slice[si]['bins'].items(): + if len(pas_slice) > len(set(pas_slice)): + print('WARNING: %d %s has overlapping bins' % (si,pas_id)) + eprint('WARNING: %d %s has overlapping bins' % (si,pas_id)) + + # slice pas positions + ref_preds_pas = ref_preds[pas_slice] + alt_preds_pas = alt_preds[pas_slice] + + # slice relevant strand targets + if '+' in pas_id: + pas_strand_mask = (targets_df.strand != '-') + else: + pas_strand_mask = (targets_df.strand != '+') + + ref_preds_pas = ref_preds_pas[...,pas_strand_mask] + alt_preds_pas = alt_preds_pas[...,pas_strand_mask] + + # write scores to HDF + write_snp(ref_preds_pas, alt_preds_pas, sed_out_apa, xi_pas, ['REF','ALT'], options.cov_pseudo, options.cov_min) + + xi_pas += 1 + + # close genome + genome_open.close() + + ################################################### + # compute SAD distributions across variants + + # write_pct(sed_out, options.sed_stats) + sed_out.close() + sed_out_apa.close() + +def map_snpseq_apa(snps, seq_len, transcriptome, apa_df, model_stride, sed_stats, pas_ext): + """Intersect SNP sequences with genes and polyA sites, constructing a list + mapping sequence indexes to dictionaries of gene_ids or pas_ids.""" + + # make gene BEDtool + genes_bedt = transcriptome.bedtool_span() + + # make SNP sequence BEDtool + snpseq_bedt = make_snpseq_bedt(snps, seq_len) + + # map SNPs to genes and polyA sites + snpseq_gene_slice = [] + snpseq_apa_slice = [] + for snp in snps: + snpseq_gene_slice.append({ 'bins' : OrderedDict(), 'distances' : OrderedDict() }) + snpseq_apa_slice.append({ 'bins' : OrderedDict(), 'distances' : OrderedDict() }) + + for i1, overlap in enumerate(genes_bedt.intersect(snpseq_bedt, wo=True)): + gene_id = overlap[3] + gene_chrom = overlap[0] + gene_start = int(overlap[1]) + gene_end = int(overlap[2]) + seq_start = int(overlap[7]) + seq_end = int(overlap[8]) + si = int(overlap[9]) + + snp_pos = snps[si].pos + + # get apa dataframe + gene_apa_df = apa_df.query("Chromosome == '" + gene_chrom + "' and ((End > " + str(gene_start-pas_ext) + " and End <= " + str(gene_end+pas_ext) + ") or (Start < " + str(gene_end+pas_ext) + " and Start >= " + str(gene_start-pas_ext) + "))").sort_values(by=['gene', 'site_num'], ascending=True) + + # make sure 80% of all polyA signals are contained within the sequence input window + if len(gene_apa_df) <= 0 or np.mean((gene_apa_df['Start'] >= seq_start).values) < 0.8 or np.mean((gene_apa_df['End'] < seq_end).values) < 0.8: + continue + + # adjust for left overhang + seq_len_chop = seq_end - seq_start + seq_start -= (seq_len - seq_len_chop) + + for _, apa_row in gene_apa_df.iterrows(): + pas_id = apa_row['pas_id'] + pas_start = apa_row['Start'] + pas_end = apa_row['End'] + pas_strand = apa_row['pas_strand'] + + pas_distance = int(np.abs(pas_start - snp_pos)) + + if 'PROP3' in sed_stats and pas_id in snpseq_apa_slice[si]['bins']: + continue + elif pas_id + '_up' in snpseq_apa_slice[si]['bins'] or pas_id + '_dn' in snpseq_apa_slice[si]['bins']: + continue + + # clip left boundaries + pas_seq_start = max(0, pas_start - seq_start) + pas_seq_end = max(0, pas_end - seq_start) + + # accumulate list of pas-snp distances + snpseq_gene_slice[si]['distances'].setdefault(gene_id,[]).append(pas_distance) + if 'PROP3' in sed_stats: + snpseq_apa_slice[si]['distances'].setdefault(pas_id,[]).append(pas_distance) + elif 'COVR3' in sed_stats or 'COVR3WIDE' in sed_stats: + snpseq_apa_slice[si]['distances'].setdefault(pas_id + '_up',[]).append(pas_distance) + else : + 
snpseq_apa_slice[si]['distances'].setdefault(pas_id + '_up',[]).append(pas_distance) + snpseq_apa_slice[si]['distances'].setdefault(pas_id + '_dn',[]).append(pas_distance) + + if 'PROP3' in sed_stats: + # coverage (overlapping PAS) + bin_end = int(np.round(pas_seq_start / model_stride)) + 3 + bin_start = bin_end - 5 + + # clip right boundaries + bin_max = int(seq_len/model_stride) + bin_start = max(min(bin_start, bin_max), 0) + bin_end = max(min(bin_end, bin_max), 0) + + if bin_end - bin_start > 0: + # save gene bin positions + snpseq_gene_slice[si]['bins'].setdefault(gene_id,[]).append([bin_start, bin_end]) + snpseq_apa_slice[si]['bins'].setdefault(pas_id,[]).extend(range(bin_start, bin_end)) + + elif 'COVR3' in sed_stats: + # upstream coverage (before PAS) + bin_start = None + bin_end = None + if pas_strand == '+' : + bin_end = int(np.round(pas_seq_start / model_stride)) + 1 + bin_start = bin_end - 4 - 1 + else : + bin_start = int(np.round(pas_seq_end / model_stride)) + bin_end = bin_start + 4 + 1 + + # clip right boundaries + bin_max = int(seq_len/model_stride) + bin_start = max(min(bin_start, bin_max), 0) + bin_end = max(min(bin_end, bin_max), 0) + + if bin_end - bin_start > 0: + # save gene bin positions + snpseq_gene_slice[si]['bins'].setdefault(gene_id,[]).append([bin_start, bin_end]) + snpseq_apa_slice[si]['bins'].setdefault(pas_id + '_up',[]).extend(range(bin_start, bin_end)) + + elif 'COVR3WIDE' in sed_stats: + # upstream coverage (before PAS); wider + bin_start = None + bin_end = None + if pas_strand == '+' : + bin_end = int(np.round(pas_seq_start / model_stride)) + 1 + bin_start = bin_end - 9 - 1 + else : + bin_start = int(np.round(pas_seq_end / model_stride)) + bin_end = bin_start + 9 + 1 + + # clip right boundaries + bin_max = int(seq_len/model_stride) + bin_start = max(min(bin_start, bin_max), 0) + bin_end = max(min(bin_end, bin_max), 0) + + if bin_end - bin_start > 0: + # save gene bin positions + snpseq_gene_slice[si]['bins'].setdefault(gene_id,[]).append([bin_start, bin_end]) + snpseq_apa_slice[si]['bins'].setdefault(pas_id + '_up',[]).extend(range(bin_start, bin_end)) + + else: # default (COVR) + # upstream coverage (before PAS) + bin_start = None + bin_end = None + if pas_strand == '+' : + bin_end = int(np.round(pas_seq_start / model_stride)) + 1 + bin_start = bin_end - 3 - 1 + else : + bin_start = int(np.round(pas_seq_end / model_stride)) + bin_end = bin_start + 3 + 1 + + # clip right boundaries + bin_max = int(seq_len/model_stride) + bin_start = max(min(bin_start, bin_max), 0) + bin_end = max(min(bin_end, bin_max), 0) + + if bin_end - bin_start > 0: + # save gene bin positions + snpseq_gene_slice[si]['bins'].setdefault(gene_id,[]).append([bin_start, bin_end]) + snpseq_apa_slice[si]['bins'].setdefault(pas_id + '_up',[]).extend(range(bin_start, bin_end)) + + # downstream coverage (after PAS) + bin_start = None + bin_end = None + if pas_strand == '+' : + bin_start = int(np.round(pas_seq_end / model_stride)) + bin_end = bin_start + 3 + 1 + else : + bin_end = int(np.round(pas_seq_start / model_stride)) + 1 + bin_start = bin_end - 3 - 1 + + # clip right boundaries + bin_max = int(seq_len/model_stride) + bin_start = max(min(bin_start, bin_max), 0) + bin_end = max(min(bin_end, bin_max), 0) + + if bin_end - bin_start > 0: + # save gene bin positions + snpseq_gene_slice[si]['bins'].setdefault(gene_id,[]).append([bin_start, bin_end]) + snpseq_apa_slice[si]['bins'].setdefault(pas_id + '_dn',[]).extend(range(bin_start, bin_end)) + + return snpseq_gene_slice, snpseq_apa_slice + 
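+# Illustrative shape of the two mappings returned by map_snpseq_apa (hypothetical
+# ids and bin indexes, for documentation only):
+#   snpseq_gene_slice[si] = {'bins': {'ENSG...': [[210, 214], ...]},
+#                            'distances': {'ENSG...': [1874, ...]}}
+#   snpseq_apa_slice[si] = {'bins': {'pas123_up': [210, 211, 212, 213], ...},
+#                           'distances': {'pas123_up': [1874], ...}}
+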
+def initialize_output_h5(out_dir, sed_stats, snps, snpseq_gene_slice, targets_df, out_name='sed'): + """Initialize an output HDF5 file for SAD stats.""" + + sed_out = h5py.File('%s/%s.h5' % (out_dir, out_name), 'w') + + # collect identifier tuples + snp_indexes = [] + gene_ids = [] + distances = [] + ns = [] + snp_ids = [] + snp_a1 = [] + snp_a2 = [] + snp_flips = [] + for si, gene_slice in enumerate(snpseq_gene_slice): + snp_genes = list(gene_slice['bins'].keys()) + gene_ids += snp_genes + distances += [int(np.min(gene_slice['distances'][snp_gene])) for snp_gene in snp_genes] + ns += [len(gene_slice['distances'][snp_gene]) for snp_gene in snp_genes] + snp_indexes += [si]*len(snp_genes) + num_scores = len(snp_indexes) + + # write SNP indexes + snp_indexes = np.array(snp_indexes) + sed_out.create_dataset('si', data=snp_indexes) + + # write genes + gene_ids = np.array(gene_ids, 'S') + sed_out.create_dataset('gene', data=gene_ids) + + # write distances + distances = np.array(distances, 'int32') + sed_out.create_dataset('distance', data=distances) + + # write number of sites + ns = np.array(ns, 'int32') + sed_out.create_dataset('n', data=ns) + + # write SNPs + snp_ids = np.array([snp.rsid for snp in snps], 'S') + sed_out.create_dataset('snp', data=snp_ids) + + # write SNP chr + snp_chr = np.array([snp.chr for snp in snps], 'S') + sed_out.create_dataset('chr', data=snp_chr) + + # write SNP pos + snp_pos = np.array([snp.pos for snp in snps], dtype='uint32') + sed_out.create_dataset('pos', data=snp_pos) + + # check flips + snp_flips = [snp.flipped for snp in snps] + + # write SNP reference allele + snp_refs = [] + snp_alts = [] + for snp in snps: + if snp.flipped: + snp_refs.append(snp.alt_alleles[0]) + snp_alts.append(snp.ref_allele) + else: + snp_refs.append(snp.ref_allele) + snp_alts.append(snp.alt_alleles[0]) + snp_refs = np.array(snp_refs, 'S') + snp_alts = np.array(snp_alts, 'S') + sed_out.create_dataset('ref_allele', data=snp_refs) + sed_out.create_dataset('alt_allele', data=snp_alts) + + # write targets + sed_out.create_dataset('target_ids', data=np.array(targets_df.identifier, 'S')) + sed_out.create_dataset('target_labels', data=np.array(targets_df.description, 'S')) + + # initialize SED stats + num_targets = targets_df.shape[0] + for sed_stat in sed_stats: + sed_out.create_dataset(sed_stat, + shape=(num_scores, num_targets), + dtype='float16') + + return sed_out + + +def make_snpseq_bedt(snps, seq_len): + """Make a BedTool object for all SNP sequences.""" + num_snps = len(snps) + left_len = seq_len // 2 + right_len = seq_len // 2 + + snpseq_bed_lines = [] + for si in range(num_snps): + snpseq_start = max(0, snps[si].pos - left_len) + snpseq_end = snps[si].pos + right_len + snpseq_end += max(0, len(snps[si].ref_allele) - snps[si].longest_alt()) + snpseq_bed_lines.append('%s %d %d %d' % (snps[si].chr, snpseq_start, snpseq_end, si)) + + snpseq_bedt = pybedtools.BedTool('\n'.join(snpseq_bed_lines), from_string=True) + return snpseq_bedt + + +def targets_prep_strand(targets_df): + # attach strand + targets_strand = [] + for _, target in targets_df.iterrows(): + if target.strand_pair == target.name: + targets_strand.append('.') + else: + targets_strand.append(target.identifier[-1]) + targets_df['strand'] = targets_strand + + # collapse stranded + strand_mask = (targets_df.strand != '-') + targets_strand_df = targets_df[strand_mask] + + return targets_strand_df + + +def write_pct(sed_out, sed_stats): + """Compute percentile values for each target and write to HDF5.""" + + # define percentiles 
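+  # (0.1% steps in the distribution tails below 10% and above 90%, 1% steps in
+  # the bulk; roughly 280 grid points in total)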
+ d_fine = 0.001 + d_coarse = 0.01 + percentiles_neg = np.arange(d_fine, 0.1, d_fine) + percentiles_base = np.arange(0.1, 0.9, d_coarse) + percentiles_pos = np.arange(0.9, 1, d_fine) + + percentiles = np.concatenate([percentiles_neg, percentiles_base, percentiles_pos]) + sed_out.create_dataset('percentiles', data=percentiles) + pct_len = len(percentiles) + + for sad_stat in sed_stats: + if sad_stat not in ['REF','ALT']: + sad_stat_pct = '%s_pct' % sad_stat + + # compute + sad_pct = np.percentile(sed_out[sad_stat], 100*percentiles, axis=0).T + sad_pct = sad_pct.astype('float16') + + # save + sed_out.create_dataset(sad_stat_pct, data=sad_pct, dtype='float16') + + +def write_snp(ref_preds, alt_preds, sed_out, xi, sed_stats, cov_pseudo, cov_min): + """Write SNP predictions to HDF, assuming the length dimension has + been maintained.""" + + # ref/alt_preds is L x T + ref_preds = ref_preds.astype('float64') + alt_preds = alt_preds.astype('float64') + seq_len, num_targets = ref_preds.shape + + # sum across bins + ref_preds_sum = ref_preds.sum(axis=0) + alt_preds_sum = alt_preds.sum(axis=0) + + # compare reference to alternative via mean downstream/upstream coverage ratios + if 'COVR' in sed_stats: + + cov_vec = (alt_preds + cov_pseudo) / (ref_preds + cov_pseudo) + + if np.sum(np.mean(ref_preds, axis=1) > cov_min) >= 1 : + cov_vec = cov_vec[np.mean(ref_preds, axis=1) > cov_min, :] + + cov_vec = np.concatenate([ + np.ones((1, cov_vec.shape[1])), + cov_vec, + np.ones((1, cov_vec.shape[1])), + ], axis=0) + + max_scores = np.zeros(cov_vec.shape[1]) + for j in range(1, cov_vec.shape[0]) : + avg_up = np.mean(cov_vec[:j, :], axis=0) + avg_dn = np.mean(cov_vec[j:, :], axis=0) + + scores = np.abs(np.log2(avg_dn / avg_up)) + + if np.mean(scores) > np.mean(max_scores) : + max_scores = scores + + sed_out['COVR'][xi] = max_scores.astype('float16') + + # compare reference to alternative via mean downstream/upstream coverage ratios (signed) + if 'SCOVR' in sed_stats: + + cov_vec = (alt_preds + cov_pseudo) / (ref_preds + cov_pseudo) + + if np.sum(np.mean(ref_preds, axis=1) > cov_min) >= 1 : + cov_vec = cov_vec[np.mean(ref_preds, axis=1) > cov_min, :] + + cov_vec = np.concatenate([ + np.ones((1, cov_vec.shape[1])), + cov_vec, + np.ones((1, cov_vec.shape[1])), + ], axis=0) + + max_scores = np.zeros(cov_vec.shape[1]) + max_scores_s = np.zeros(cov_vec.shape[1]) + for j in range(1, cov_vec.shape[0]) : + avg_up = np.mean(cov_vec[:j, :], axis=0) + avg_dn = np.mean(cov_vec[j:, :], axis=0) + + scores = np.abs(np.log2(avg_dn / avg_up)) + scores_s = np.log2(avg_dn / avg_up) + + if np.mean(scores) > np.mean(max_scores) : + max_scores = scores + max_scores_s = scores_s + + sed_out['SCOVR'][xi] = max_scores_s.astype('float16') + + # compare reference to alternative via mean downstream/upstream proportion ratios (PAS-seq) + if 'PROP3' in sed_stats or 'COVR3' in sed_stats or 'COVR3WIDE' in sed_stats: + + prop_vec_ref = (ref_preds + cov_pseudo) / np.sum(ref_preds + cov_pseudo, axis=0)[None, :] + prop_vec_alt = (alt_preds + cov_pseudo) / np.sum(alt_preds + cov_pseudo, axis=0)[None, :] + + max_scores = np.zeros(prop_vec_ref.shape[1]) + for j in range(1, prop_vec_ref.shape[0]) : + + dist_usage_ref = np.sum(prop_vec_ref[j:, :], axis=0) + dist_usage_alt = np.sum(prop_vec_alt[j:, :], axis=0) + + scores = np.abs(np.log2(dist_usage_alt / (1. - dist_usage_alt)) - np.log2(dist_usage_ref / (1. 
- dist_usage_ref))) + + if np.mean(scores) > np.mean(max_scores) : + max_scores = scores + + if 'PROP3' in sed_stats: + sed_out['PROP3'][xi] = max_scores.astype('float16') + if 'COVR3' in sed_stats: + sed_out['COVR3'][xi] = max_scores.astype('float16') + if 'COVR3WIDE' in sed_stats: + sed_out['COVR3WIDE'][xi] = max_scores.astype('float16') + + # compare reference to alternative via mean downstream/upstream proportion ratios (PAS-seq; signed) + if 'SPROP3' in sed_stats or 'SCOVR3' in sed_stats or 'SCOVR3WIDE' in sed_stats: + + prop_vec_ref = (ref_preds + cov_pseudo) / np.sum(ref_preds + cov_pseudo, axis=0)[None, :] + prop_vec_alt = (alt_preds + cov_pseudo) / np.sum(alt_preds + cov_pseudo, axis=0)[None, :] + + max_scores = np.zeros(prop_vec_ref.shape[1]) + max_scores_s = np.zeros(prop_vec_ref.shape[1]) + for j in range(1, prop_vec_ref.shape[0]) : + + dist_usage_ref = np.sum(prop_vec_ref[j:, :], axis=0) + dist_usage_alt = np.sum(prop_vec_alt[j:, :], axis=0) + + scores = np.abs(np.log2(dist_usage_alt / (1. - dist_usage_alt)) - np.log2(dist_usage_ref / (1. - dist_usage_ref))) + scores_s = np.log2(dist_usage_alt / (1. - dist_usage_alt)) - np.log2(dist_usage_ref / (1. - dist_usage_ref)) + + if np.mean(scores) > np.mean(max_scores) : + max_scores = scores + max_scores_s = scores_s + + if 'SPROP3' in sed_stats: + sed_out['SPROP3'][xi] = max_scores_s.astype('float16') + if 'SCOVR3' in sed_stats: + sed_out['SCOVR3'][xi] = max_scores_s.astype('float16') + if 'SCOVR3WIDE' in sed_stats: + sed_out['SCOVR3WIDE'][xi] = max_scores_s.astype('float16') + + # predictions + if 'REF' in sed_stats: + sed_out['REF'][xi] = ref_preds_sum.astype('float16') + if 'ALT' in sed_stats: + sed_out['ALT'][xi] = alt_preds_sum.astype('float16') + +################################################################################ +# __main__ +################################################################################ +if __name__ == '__main__': + main() diff --git a/bin/borzoi_sed_paqtl_cov.py b/bin/borzoi_sed_paqtl_cov.py new file mode 100644 index 00000000..4313dd44 --- /dev/null +++ b/bin/borzoi_sed_paqtl_cov.py @@ -0,0 +1,789 @@ +#!/usr/bin/env python +# Copyright 2022 Calico LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ========================================================================= +from __future__ import print_function + +from optparse import OptionParser +from collections import OrderedDict +import json +import pickle +import os +import pdb +import sys +import time + +import h5py +import numpy as np +import pandas as pd +import pybedtools +import pysam +from scipy.special import rel_entr +import tensorflow as tf + +from basenji import dna_io +from basenji import gene as bgene +from basenji import seqnn +from basenji import stream +from basenji import vcf as bvcf + +''' +borzoi_sed_paqtl_cov.py + +Compute SNP COVerage Ratio (COVR) scores for SNPs in a VCF file, +relative to 3' UTR polyadenylation sites in an annotation file. 
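+
+Example invocation (hypothetical paths, for illustration only):
+  borzoi_sed_paqtl_cov.py --stats COVR,SCOVR -t targets.txt \
+    params.json model0_best.h5 snps.vcf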
+''' + +def eprint(*args, **kwargs): + print(*args, file=sys.stderr, **kwargs) + +################################################################################ +# main +################################################################################ +def main(): + usage = 'usage: %prog [options] <params_file> <model_file> <vcf_file>' + parser = OptionParser(usage) + parser.add_option('-f', dest='genome_fasta', + default='%s/data/hg38.fa' % os.environ['BASENJIDIR'], + help='Genome FASTA for sequences [Default: %default]') + parser.add_option('-g', dest='genes_gtf', + default='%s/genes/gencode41/gencode41_basic_nort.gtf' % os.environ['HG38'], + help='GTF for gene definition [Default: %default]') + parser.add_option('--apafile', dest='apa_file', + default='polyadb_human_v3.csv.gz') + parser.add_option('-o', dest='out_dir', + default='sed', + help='Output directory for tables and plots [Default: %default]') + parser.add_option('-p', dest='processes', + default=None, type='int', + help='Number of processes, passed by multi script') + parser.add_option('--pseudo', dest='cov_pseudo', + default=50, type='float', + help='Coverage pseudocount [Default: %default]') + parser.add_option('--cov', dest='cov_min', + default=100, type='float', + help='Min coverage [Default: %default]') + parser.add_option('--paext', dest='pas_ext', + default=50, type='float', + help='Extension in bp past gene span annotation [Default: %default]') + parser.add_option('--rc', dest='rc', + default=True, action='store_true', + help='Average forward and reverse complement predictions [Default: %default]') + parser.add_option('--shifts', dest='shifts', + default='0', type='str', + help='Ensemble prediction shifts [Default: %default]') + parser.add_option('--stats', dest='sed_stats', + default='COVR,SCOVR', + help='Comma-separated list of stats to save. [Default: %default]') + parser.add_option('-t', dest='targets_file', + default=None, type='str', + help='File specifying target indexes and labels in table format') + (options, args) = parser.parse_args() + + if len(args) == 3: + # single worker + params_file = args[0] + model_file = args[1] + vcf_file = args[2] + + elif len(args) == 4: + # multi separate + options_pkl_file = args[0] + params_file = args[1] + model_file = args[2] + vcf_file = args[3] + + # save out dir + out_dir = options.out_dir + + # load options + options_pkl = open(options_pkl_file, 'rb') + options = pickle.load(options_pkl) + options_pkl.close() + + # update output directory + options.out_dir = out_dir + + elif len(args) == 5: + # multi worker + options_pkl_file = args[0] + params_file = args[1] + model_file = args[2] + vcf_file = args[3] + worker_index = int(args[4]) + + # load options + options_pkl = open(options_pkl_file, 'rb') + options = pickle.load(options_pkl) + options_pkl.close() + + # update output directory + options.out_dir = '%s/job%d' % (options.out_dir, worker_index) + + else: + parser.error('Must provide parameters/model, VCF, and genes GTF') + + if not os.path.isdir(options.out_dir): + os.mkdir(options.out_dir) + + options.shifts = [int(shift) for shift in options.shifts.split(',')] + options.sed_stats = options.sed_stats.split(',') + + ################################################################# + # read parameters and targets + + # read model parameters + with open(params_file) as params_open: + params = json.load(params_open) + params_model = params['model'] + params_train = params['train'] + seq_len = params_model['seq_length'] + + if options.targets_file is None: + parser.error('Must provide targets table to properly handle strands.') + else: + targets_df = pd.read_csv(options.targets_file, sep='\t', index_col=0) + + # prep strand + targets_strand_df = targets_prep_strand(targets_df) + + ################################################################# + # setup model + + seqnn_model = seqnn.SeqNN(params_model) + seqnn_model.restore(model_file) + seqnn_model.build_slice(targets_df.index) + seqnn_model.build_ensemble(options.rc, options.shifts) + + model_stride = seqnn_model.model_strides[0] + out_seq_len = seqnn_model.target_lengths[0]*model_stride + + ################################################################# + # read SNPs / genes + + # filter for worker SNPs + if options.processes is not None: + # determine boundaries + num_snps = bvcf.vcf_count(vcf_file) + worker_bounds = np.linspace(0, num_snps, options.processes+1, dtype='int') + + # read SNPs from VCF + snps = bvcf.vcf_snps(vcf_file, start_i=worker_bounds[worker_index], + end_i=worker_bounds[worker_index+1]) + + else: + # read SNPs from VCF + snps = bvcf.vcf_snps(vcf_file) + + # read genes + transcriptome = bgene.Transcriptome(options.genes_gtf) + gene_strand = {} + for gene_id, gene in transcriptome.genes.items(): + gene_strand[gene_id] = gene.strand + + ####################################################### + # make APA BED (PolyADB) + + apa_df = pd.read_csv(options.apa_file, sep='\t', compression='gzip') + + # filter for 3' UTR polyA sites only + apa_df = apa_df.query("site_type == '3\\' most exon'").copy().reset_index(drop=True) + + # remove non-contiguous sites, starting from the distal-most site + + apa_df = apa_df.sort_values(by=['site_num'], ascending=False).copy().reset_index(drop=True) + + gene_dict = {} + keep_index = [] + for i, [_, row] in enumerate(apa_df.iterrows()) : + + if row['gene'] not in gene_dict : 
gene_dict[row['gene']] = row['site_num'] + keep_index.append(i) + else : + if row['site_num'] == gene_dict[row['gene']] - 1 : + gene_dict[row['gene']] = row['site_num'] + keep_index.append(i) + + apa_df = apa_df.iloc[keep_index].copy().reset_index(drop=True) + + apa_df = apa_df.sort_values(by=['gene', 'site_num'], ascending=True).copy().reset_index(drop=True) + + apa_df['start_hg38'] = apa_df['position_hg38'] + apa_df['end_hg38'] = apa_df['position_hg38'] + 1 + + apa_df = apa_df.rename(columns={'chrom' : 'Chromosome', 'start_hg38' : 'Start', 'end_hg38' : 'End', 'position_hg38' : 'cut_mode', 'strand' : 'pas_strand'}) + + apa_df = apa_df[['Chromosome', 'Start', 'End', 'pas_id', 'pas_strand', 'gene', 'site_num']] + + # map SNP sequences to gene / polyA signal positions + snpseq_gene_slice, snpseq_apa_slice = map_snpseq_apa(snps, out_seq_len, transcriptome, apa_df, model_stride, options.sed_stats, options.pas_ext) + + # remove SNPs w/o genes + num_snps_pre = len(snps) + snp_gene_mask = np.array([len(sgs) > 0 for sgs in snpseq_gene_slice]) + snps = [snps[si] for si in range(num_snps_pre) if snp_gene_mask[si]] + snpseq_gene_slice = [snpseq_gene_slice[si] for si in range(num_snps_pre) if snp_gene_mask[si]] + snpseq_apa_slice = [snpseq_apa_slice[si] for si in range(num_snps_pre) if snp_gene_mask[si]] + num_snps = len(snps) + + # create SNP seq generator + genome_open = pysam.Fastafile(options.genome_fasta) + + def snp_gen(): + for snp in snps: + # get SNP sequences + snp_1hot_list = bvcf.snp_seq1(snp, seq_len, genome_open) + for snp_1hot in snp_1hot_list: + yield snp_1hot + + ################################################################# + # setup output + + sed_out = initialize_output_h5(options.out_dir, options.sed_stats, snps, snpseq_gene_slice, targets_strand_df, out_name='sed') + sed_out_apa = initialize_output_h5(options.out_dir, ['REF','ALT'], snps, snpseq_apa_slice, targets_strand_df, out_name='sed_pas') + + ################################################################# + # predict SNP scores, write output + + # initialize predictions stream + preds_stream = stream.PredStreamGen(seqnn_model, snp_gen(), params_train['batch_size']) + + # predictions index + pi = 0 + + # SNP/gene index + xi_gene = 0 + + # SNP/pas index + xi_pas = 0 + + # for each SNP sequence + for si in range(num_snps): + # get predictions + ref_preds = preds_stream[pi] + pi += 1 + alt_preds = preds_stream[pi] + pi += 1 + + # undo scale + ref_preds /= np.expand_dims(targets_df.scale, axis=0) + alt_preds /= np.expand_dims(targets_df.scale, axis=0) + + # undo sqrt + ref_preds = ref_preds**(4/3) + alt_preds = alt_preds**(4/3) + + # for each overlapping gene + for gene_id, gene_slice_dup in snpseq_gene_slice[si]['bins'].items(): + + # remove duplicate bin coordinates (artifact of PASs that are within <32bp) + gene_slice = [] + gene_slice_dict = {} + for gslpas in gene_slice_dup: + gslpas_key = str(gslpas[0]) + "_" + str(gslpas[1]) + + if gslpas_key not in gene_slice_dict : + gene_slice_dict[gslpas_key] = True + gene_slice.append(gslpas) + + # slice gene positions + ref_preds_gene = np.concatenate([np.sum(ref_preds[gene_slice_start:gene_slice_end, :], axis=0)[None, :] for [gene_slice_start, gene_slice_end] in gene_slice], axis=0) + alt_preds_gene = np.concatenate([np.sum(alt_preds[gene_slice_start:gene_slice_end, :], axis=0)[None, :] for [gene_slice_start, gene_slice_end] in gene_slice], axis=0) + + if gene_strand[gene_id] == '+': + gene_strand_mask = (targets_df.strand != '-') + else: + gene_strand_mask = 
(targets_df.strand != '+') + + ref_preds_gene = ref_preds_gene[...,gene_strand_mask] + alt_preds_gene = alt_preds_gene[...,gene_strand_mask] + + # write scores to HDF + write_snp(ref_preds_gene, alt_preds_gene, sed_out, xi_gene, options.sed_stats, options.cov_pseudo, options.cov_min) + + xi_gene += 1 + + # for each overlapping PAS + for pas_id, pas_slice in snpseq_apa_slice[si]['bins'].items(): + if len(pas_slice) > len(set(pas_slice)): + print('WARNING: %d %s has overlapping bins' % (si,pas_id)) + eprint('WARNING: %d %s has overlapping bins' % (si,pas_id)) + + # slice pas positions + ref_preds_pas = ref_preds[pas_slice] + alt_preds_pas = alt_preds[pas_slice] + + # slice relevant strand targets + if '+' in pas_id: + pas_strand_mask = (targets_df.strand != '-') + else: + pas_strand_mask = (targets_df.strand != '+') + + ref_preds_pas = ref_preds_pas[...,pas_strand_mask] + alt_preds_pas = alt_preds_pas[...,pas_strand_mask] + + # write scores to HDF + write_snp(ref_preds_pas, alt_preds_pas, sed_out_apa, xi_pas, ['REF','ALT'], options.cov_pseudo, options.cov_min) + + xi_pas += 1 + + # close genome + genome_open.close() + + ################################################### + # compute SAD distributions across variants + + # write_pct(sed_out, options.sed_stats) + sed_out.close() + sed_out_apa.close() + +def map_snpseq_apa(snps, seq_len, transcriptome, apa_df, model_stride, sed_stats, pas_ext): + """Intersect SNP sequences with genes and polyA sites, constructing a list + mapping sequence indexes to dictionaries of gene_ids or pas_ids.""" + + # make gene BEDtool + genes_bedt = transcriptome.bedtool_span() + + # make SNP sequence BEDtool + snpseq_bedt = make_snpseq_bedt(snps, seq_len) + + # map SNPs to genes and polyA sites + snpseq_gene_slice = [] + snpseq_apa_slice = [] + for snp in snps: + snpseq_gene_slice.append({ 'bins' : OrderedDict(), 'distances' : OrderedDict() }) + snpseq_apa_slice.append({ 'bins' : OrderedDict(), 'distances' : OrderedDict() }) + + for i1, overlap in enumerate(genes_bedt.intersect(snpseq_bedt, wo=True)): + gene_id = overlap[3] + gene_chrom = overlap[0] + gene_start = int(overlap[1]) + gene_end = int(overlap[2]) + seq_start = int(overlap[7]) + seq_end = int(overlap[8]) + si = int(overlap[9]) + + snp_pos = snps[si].pos + + # get apa dataframe + gene_apa_df = apa_df.query("Chromosome == '" + gene_chrom + "' and ((End > " + str(gene_start-pas_ext) + " and End <= " + str(gene_end+pas_ext) + ") or (Start < " + str(gene_end+pas_ext) + " and Start >= " + str(gene_start-pas_ext) + "))").sort_values(by=['gene', 'site_num'], ascending=True) + + # make sure 80% of all polyA signals are contained within the sequence input window + if len(gene_apa_df) <= 0 or np.mean((gene_apa_df['Start'] >= seq_start).values) < 0.8 or np.mean((gene_apa_df['End'] < seq_end).values) < 0.8: + continue + + # adjust for left overhang + seq_len_chop = seq_end - seq_start + seq_start -= (seq_len - seq_len_chop) + + for _, apa_row in gene_apa_df.iterrows(): + pas_id = apa_row['pas_id'] + pas_start = apa_row['Start'] + pas_end = apa_row['End'] + pas_strand = apa_row['pas_strand'] + + pas_distance = int(np.abs(pas_start - snp_pos)) + + if 'PROP3' in sed_stats and pas_id in snpseq_apa_slice[si]['bins']: + continue + elif pas_id + '_up' in snpseq_apa_slice[si]['bins'] or pas_id + '_dn' in snpseq_apa_slice[si]['bins']: + continue + + # clip left boundaries + pas_seq_start = max(0, pas_start - seq_start) + pas_seq_end = max(0, pas_end - seq_start) + + # accumulate list of pas-snp distances + 
snpseq_gene_slice[si]['distances'].setdefault(gene_id,[]).append(pas_distance) + if 'PROP3' in sed_stats: + snpseq_apa_slice[si]['distances'].setdefault(pas_id,[]).append(pas_distance) + elif 'COVR3' in sed_stats or 'COVR3WIDE' in sed_stats: + snpseq_apa_slice[si]['distances'].setdefault(pas_id + '_up',[]).append(pas_distance) + else : + snpseq_apa_slice[si]['distances'].setdefault(pas_id + '_up',[]).append(pas_distance) + snpseq_apa_slice[si]['distances'].setdefault(pas_id + '_dn',[]).append(pas_distance) + + if 'PROP3' in sed_stats: + # coverage (overlapping PAS) + bin_end = int(np.round(pas_seq_start / model_stride)) + 3 + bin_start = bin_end - 5 + + # clip right boundaries + bin_max = int(seq_len/model_stride) + bin_start = max(min(bin_start, bin_max), 0) + bin_end = max(min(bin_end, bin_max), 0) + + if bin_end - bin_start > 0: + # save gene bin positions + snpseq_gene_slice[si]['bins'].setdefault(gene_id,[]).append([bin_start, bin_end]) + snpseq_apa_slice[si]['bins'].setdefault(pas_id,[]).extend(range(bin_start, bin_end)) + + elif 'COVR3' in sed_stats: + # upstream coverage (before PAS) + bin_start = None + bin_end = None + if pas_strand == '+' : + bin_end = int(np.round(pas_seq_start / model_stride)) + 1 + bin_start = bin_end - 4 - 1 + else : + bin_start = int(np.round(pas_seq_end / model_stride)) + bin_end = bin_start + 4 + 1 + + # clip right boundaries + bin_max = int(seq_len/model_stride) + bin_start = max(min(bin_start, bin_max), 0) + bin_end = max(min(bin_end, bin_max), 0) + + if bin_end - bin_start > 0: + # save gene bin positions + snpseq_gene_slice[si]['bins'].setdefault(gene_id,[]).append([bin_start, bin_end]) + snpseq_apa_slice[si]['bins'].setdefault(pas_id + '_up',[]).extend(range(bin_start, bin_end)) + + elif 'COVR3WIDE' in sed_stats: + # upstream coverage (before PAS); wider + bin_start = None + bin_end = None + if pas_strand == '+' : + bin_end = int(np.round(pas_seq_start / model_stride)) + 1 + bin_start = bin_end - 9 - 1 + else : + bin_start = int(np.round(pas_seq_end / model_stride)) + bin_end = bin_start + 9 + 1 + + # clip right boundaries + bin_max = int(seq_len/model_stride) + bin_start = max(min(bin_start, bin_max), 0) + bin_end = max(min(bin_end, bin_max), 0) + + if bin_end - bin_start > 0: + # save gene bin positions + snpseq_gene_slice[si]['bins'].setdefault(gene_id,[]).append([bin_start, bin_end]) + snpseq_apa_slice[si]['bins'].setdefault(pas_id + '_up',[]).extend(range(bin_start, bin_end)) + + else: # default (COVR) + # upstream coverage (before PAS) + bin_start = None + bin_end = None + if pas_strand == '+' : + bin_end = int(np.round(pas_seq_start / model_stride)) + 1 + bin_start = bin_end - 3 - 1 + else : + bin_start = int(np.round(pas_seq_end / model_stride)) + bin_end = bin_start + 3 + 1 + + # clip right boundaries + bin_max = int(seq_len/model_stride) + bin_start = max(min(bin_start, bin_max), 0) + bin_end = max(min(bin_end, bin_max), 0) + + if bin_end - bin_start > 0: + # save gene bin positions + snpseq_gene_slice[si]['bins'].setdefault(gene_id,[]).append([bin_start, bin_end]) + snpseq_apa_slice[si]['bins'].setdefault(pas_id + '_up',[]).extend(range(bin_start, bin_end)) + + # downstream coverage (after PAS) + bin_start = None + bin_end = None + if pas_strand == '+' : + bin_start = int(np.round(pas_seq_end / model_stride)) + bin_end = bin_start + 3 + 1 + else : + bin_end = int(np.round(pas_seq_start / model_stride)) + 1 + bin_start = bin_end - 3 - 1 + + # clip right boundaries + bin_max = int(seq_len/model_stride) + bin_start = max(min(bin_start, 
bin_max), 0) + bin_end = max(min(bin_end, bin_max), 0) + + if bin_end - bin_start > 0: + # save gene bin positions + snpseq_gene_slice[si]['bins'].setdefault(gene_id,[]).append([bin_start, bin_end]) + snpseq_apa_slice[si]['bins'].setdefault(pas_id + '_dn',[]).extend(range(bin_start, bin_end)) + + return snpseq_gene_slice, snpseq_apa_slice + +def initialize_output_h5(out_dir, sed_stats, snps, snpseq_gene_slice, targets_df, out_name='sed'): + """Initialize an output HDF5 file for SAD stats.""" + + sed_out = h5py.File('%s/%s.h5' % (out_dir, out_name), 'w') + + # collect identifier tuples + snp_indexes = [] + gene_ids = [] + distances = [] + ns = [] + snp_ids = [] + snp_a1 = [] + snp_a2 = [] + snp_flips = [] + for si, gene_slice in enumerate(snpseq_gene_slice): + snp_genes = list(gene_slice['bins'].keys()) + gene_ids += snp_genes + distances += [int(np.min(gene_slice['distances'][snp_gene])) for snp_gene in snp_genes] + ns += [len(gene_slice['distances'][snp_gene]) for snp_gene in snp_genes] + snp_indexes += [si]*len(snp_genes) + num_scores = len(snp_indexes) + + # write SNP indexes + snp_indexes = np.array(snp_indexes) + sed_out.create_dataset('si', data=snp_indexes) + + # write genes + gene_ids = np.array(gene_ids, 'S') + sed_out.create_dataset('gene', data=gene_ids) + + # write distances + distances = np.array(distances, 'int32') + sed_out.create_dataset('distance', data=distances) + + # write number of sites + ns = np.array(ns, 'int32') + sed_out.create_dataset('n', data=ns) + + # write SNPs + snp_ids = np.array([snp.rsid for snp in snps], 'S') + sed_out.create_dataset('snp', data=snp_ids) + + # write SNP chr + snp_chr = np.array([snp.chr for snp in snps], 'S') + sed_out.create_dataset('chr', data=snp_chr) + + # write SNP pos + snp_pos = np.array([snp.pos for snp in snps], dtype='uint32') + sed_out.create_dataset('pos', data=snp_pos) + + # check flips + snp_flips = [snp.flipped for snp in snps] + + # write SNP reference allele + snp_refs = [] + snp_alts = [] + for snp in snps: + if snp.flipped: + snp_refs.append(snp.alt_alleles[0]) + snp_alts.append(snp.ref_allele) + else: + snp_refs.append(snp.ref_allele) + snp_alts.append(snp.alt_alleles[0]) + snp_refs = np.array(snp_refs, 'S') + snp_alts = np.array(snp_alts, 'S') + sed_out.create_dataset('ref_allele', data=snp_refs) + sed_out.create_dataset('alt_allele', data=snp_alts) + + # write targets + sed_out.create_dataset('target_ids', data=np.array(targets_df.identifier, 'S')) + sed_out.create_dataset('target_labels', data=np.array(targets_df.description, 'S')) + + # initialize SED stats + num_targets = targets_df.shape[0] + for sed_stat in sed_stats: + sed_out.create_dataset(sed_stat, + shape=(num_scores, num_targets), + dtype='float16') + + return sed_out + + +def make_snpseq_bedt(snps, seq_len): + """Make a BedTool object for all SNP sequences.""" + num_snps = len(snps) + left_len = seq_len // 2 + right_len = seq_len // 2 + + snpseq_bed_lines = [] + for si in range(num_snps): + snpseq_start = max(0, snps[si].pos - left_len) + snpseq_end = snps[si].pos + right_len + snpseq_end += max(0, len(snps[si].ref_allele) - snps[si].longest_alt()) + snpseq_bed_lines.append('%s %d %d %d' % (snps[si].chr, snpseq_start, snpseq_end, si)) + + snpseq_bedt = pybedtools.BedTool('\n'.join(snpseq_bed_lines), from_string=True) + return snpseq_bedt + + +def targets_prep_strand(targets_df): + # attach strand + targets_strand = [] + for _, target in targets_df.iterrows(): + if target.strand_pair == target.name: + targets_strand.append('.') + else: + 
+
+
+def targets_prep_strand(targets_df):
+  """Attach a strand column and collapse stranded pairs to one row."""
+
+  # attach strand
+  targets_strand = []
+  for _, target in targets_df.iterrows():
+    if target.strand_pair == target.name:
+      targets_strand.append('.')
+    else:
+      targets_strand.append(target.identifier[-1])
+  targets_df['strand'] = targets_strand
+
+  # collapse stranded
+  strand_mask = (targets_df.strand != '-')
+  targets_strand_df = targets_df[strand_mask]
+
+  return targets_strand_df
+
+
+def write_pct(sed_out, sed_stats):
+  """Compute percentile values for each target and write to HDF5."""
+
+  # define percentiles (finer resolution in the tails)
+  d_fine = 0.001
+  d_coarse = 0.01
+  percentiles_neg = np.arange(d_fine, 0.1, d_fine)
+  percentiles_base = np.arange(0.1, 0.9, d_coarse)
+  percentiles_pos = np.arange(0.9, 1, d_fine)
+
+  percentiles = np.concatenate([percentiles_neg, percentiles_base, percentiles_pos])
+  sed_out.create_dataset('percentiles', data=percentiles)
+
+  for sed_stat in sed_stats:
+    if sed_stat not in ['REF', 'ALT']:
+      sed_stat_pct = '%s_pct' % sed_stat
+
+      # compute
+      sed_pct = np.percentile(sed_out[sed_stat], 100 * percentiles, axis=0).T
+      sed_pct = sed_pct.astype('float16')
+
+      # save
+      sed_out.create_dataset(sed_stat_pct, data=sed_pct, dtype='float16')
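+
+# The COVR-family stats below scan every split point j of the per-bin alt/ref
+# coverage ratio profile and keep the split that maximizes the mean |log2|
+# contrast between the downstream and upstream means. Illustrative
+# micro-example (hypothetical numbers, ignoring the padding rows of ones): for
+# a single-target ratio profile [1, 1, 2, 2], the best split is j=2, scoring
+# |log2(mean([2, 2]) / mean([1, 1]))| = 1.0.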
+
+
+def write_snp(ref_preds, alt_preds, sed_out, xi, sed_stats, cov_pseudo, cov_min):
+  """Write SNP predictions to HDF5, assuming the length dimension has
+     been maintained."""
+
+  # ref_preds/alt_preds are L x T
+  ref_preds = ref_preds.astype('float64')
+  alt_preds = alt_preds.astype('float64')
+  seq_len, num_targets = ref_preds.shape
+
+  # sum across bins
+  ref_preds_sum = ref_preds.sum(axis=0)
+  alt_preds_sum = alt_preds.sum(axis=0)
+
+  # compare reference to alternative via mean downstream/upstream coverage ratios
+  if 'COVR' in sed_stats:
+    cov_vec = (alt_preds + cov_pseudo) / (ref_preds + cov_pseudo)
+
+    # restrict to bins with sufficient reference coverage
+    if np.sum(np.mean(ref_preds, axis=1) > cov_min) >= 1:
+      cov_vec = cov_vec[np.mean(ref_preds, axis=1) > cov_min, :]
+
+    # pad with neutral ratios so a split can fall before or after every bin
+    cov_vec = np.concatenate([
+      np.ones((1, cov_vec.shape[1])),
+      cov_vec,
+      np.ones((1, cov_vec.shape[1])),
+    ], axis=0)
+
+    max_scores = np.zeros(cov_vec.shape[1])
+    for j in range(1, cov_vec.shape[0]):
+      avg_up = np.mean(cov_vec[:j, :], axis=0)
+      avg_dn = np.mean(cov_vec[j:, :], axis=0)
+
+      scores = np.abs(np.log2(avg_dn / avg_up))
+
+      if np.mean(scores) > np.mean(max_scores):
+        max_scores = scores
+
+    sed_out['COVR'][xi] = max_scores.astype('float16')
+
+  # same scan, but store the signed log2 contrast of the best split
+  if 'SCOVR' in sed_stats:
+    cov_vec = (alt_preds + cov_pseudo) / (ref_preds + cov_pseudo)
+
+    if np.sum(np.mean(ref_preds, axis=1) > cov_min) >= 1:
+      cov_vec = cov_vec[np.mean(ref_preds, axis=1) > cov_min, :]
+
+    cov_vec = np.concatenate([
+      np.ones((1, cov_vec.shape[1])),
+      cov_vec,
+      np.ones((1, cov_vec.shape[1])),
+    ], axis=0)
+
+    max_scores = np.zeros(cov_vec.shape[1])
+    max_scores_s = np.zeros(cov_vec.shape[1])
+    for j in range(1, cov_vec.shape[0]):
+      avg_up = np.mean(cov_vec[:j, :], axis=0)
+      avg_dn = np.mean(cov_vec[j:, :], axis=0)
+
+      scores = np.abs(np.log2(avg_dn / avg_up))
+      scores_s = np.log2(avg_dn / avg_up)
+
+      if np.mean(scores) > np.mean(max_scores):
+        max_scores = scores
+        max_scores_s = scores_s
+
+    sed_out['SCOVR'][xi] = max_scores_s.astype('float16')
+
+  # compare reference to alternative via distal polyA usage log-odds (PAS-seq)
+  if 'PROP3' in sed_stats or 'COVR3' in sed_stats or 'COVR3WIDE' in sed_stats:
+    prop_vec_ref = (ref_preds + cov_pseudo) / np.sum(ref_preds + cov_pseudo, axis=0)[None, :]
+    prop_vec_alt = (alt_preds + cov_pseudo) / np.sum(alt_preds + cov_pseudo, axis=0)[None, :]
+
+    max_scores = np.zeros(prop_vec_ref.shape[1])
+    for j in range(1, prop_vec_ref.shape[0]):
+      dist_usage_ref = np.sum(prop_vec_ref[j:, :], axis=0)
+      dist_usage_alt = np.sum(prop_vec_alt[j:, :], axis=0)
+
+      scores = np.abs(np.log2(dist_usage_alt / (1. - dist_usage_alt)) - np.log2(dist_usage_ref / (1. - dist_usage_ref)))
+
+      if np.mean(scores) > np.mean(max_scores):
+        max_scores = scores
+
+    if 'PROP3' in sed_stats:
+      sed_out['PROP3'][xi] = max_scores.astype('float16')
+    if 'COVR3' in sed_stats:
+      sed_out['COVR3'][xi] = max_scores.astype('float16')
+    if 'COVR3WIDE' in sed_stats:
+      sed_out['COVR3WIDE'][xi] = max_scores.astype('float16')
+
+  # compare reference to alternative via distal polyA usage log-odds (PAS-seq; signed)
+  if 'SPROP3' in sed_stats or 'SCOVR3' in sed_stats or 'SCOVR3WIDE' in sed_stats:
+    prop_vec_ref = (ref_preds + cov_pseudo) / np.sum(ref_preds + cov_pseudo, axis=0)[None, :]
+    prop_vec_alt = (alt_preds + cov_pseudo) / np.sum(alt_preds + cov_pseudo, axis=0)[None, :]
+
+    max_scores = np.zeros(prop_vec_ref.shape[1])
+    max_scores_s = np.zeros(prop_vec_ref.shape[1])
+    for j in range(1, prop_vec_ref.shape[0]):
+      dist_usage_ref = np.sum(prop_vec_ref[j:, :], axis=0)
+      dist_usage_alt = np.sum(prop_vec_alt[j:, :], axis=0)
+
+      scores = np.abs(np.log2(dist_usage_alt / (1. - dist_usage_alt)) - np.log2(dist_usage_ref / (1. - dist_usage_ref)))
+      scores_s = np.log2(dist_usage_alt / (1. - dist_usage_alt)) - np.log2(dist_usage_ref / (1. - dist_usage_ref))
+
+      if np.mean(scores) > np.mean(max_scores):
+        max_scores = scores
+        max_scores_s = scores_s
+
+    if 'SPROP3' in sed_stats:
+      sed_out['SPROP3'][xi] = max_scores_s.astype('float16')
+    if 'SCOVR3' in sed_stats:
+      sed_out['SCOVR3'][xi] = max_scores_s.astype('float16')
+    if 'SCOVR3WIDE' in sed_stats:
+      sed_out['SCOVR3WIDE'][xi] = max_scores_s.astype('float16')
+
+  # predictions
+  if 'REF' in sed_stats:
+    sed_out['REF'][xi] = ref_preds_sum.astype('float16')
+  if 'ALT' in sed_stats:
+    sed_out['ALT'][xi] = alt_preds_sum.astype('float16')
+
+################################################################################
+# __main__
+################################################################################
+if __name__ == '__main__':
+  main()
diff --git a/bin/borzoi_test_apa_folds_polaydb.py b/bin/borzoi_test_apa_folds_polaydb.py
new file mode 100644
index 00000000..805fd295
--- /dev/null
+++ b/bin/borzoi_test_apa_folds_polaydb.py
@@ -0,0 +1,181 @@
+#!/usr/bin/env python
+# Copyright 2019 Calico LLC
+
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+
+# https://www.apache.org/licenses/LICENSE-2.0
+
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# =========================================================================
+from optparse import OptionParser, OptionGroup
+import glob
+import json
+import os
+import pdb
+import sys
+
+from natsort import natsorted
+import numpy as np
+import pandas as pd
+from scipy.stats import wilcoxon, ttest_rel
+import matplotlib.pyplot as plt
+import seaborn as sns
+
+import slurm
+
+"""
+borzoi_test_apa_folds_polaydb.py
+
+Launch borzoi_test_apa_polaydb.py across cross-fold models to measure
+test accuracy at the polyadenylation-site level (PolyADB).
+"""
+
+################################################################################
+# main
+################################################################################
+def main():
+  usage = 'usage: %prog [options] <params_file> <data_dir> ...'
+  parser = OptionParser(usage)
+  parser.add_option('-a', '--alt', dest='alternative',
+      default='two-sided', help='Statistical test alternative [Default: %default]')
+  parser.add_option('-c', dest='crosses',
+      default=1, type='int',
+      help='Number of cross-fold rounds [Default: %default]')
+  parser.add_option('-d', dest='dataset_i',
+      default=None, type='int',
+      help='Dataset index [Default: %default]')
+  parser.add_option('--d_ref', dest='dataset_ref_i',
+      default=None, type='int',
+      help='Reference dataset index [Default: %default]')
+  parser.add_option('-e', dest='conda_env',
+      default='tf210',
+      help='Anaconda environment [Default: %default]')
+  parser.add_option('-f', dest='fold_subset',
+      default=None, type='int',
+      help='Run a subset of folds [Default: %default]')
+  parser.add_option('-g', dest='apa_file',
+      default='polyadb_human_v3.csv.gz',
+      help='PolyADB annotation file [Default: %default]')
+  parser.add_option('--label_exp', dest='label_exp',
+      default='Experiment', help='Experiment label [Default: %default]')
+  parser.add_option('--label_ref', dest='label_ref',
+      default='Reference', help='Reference label [Default: %default]')
+  parser.add_option('-m', dest='metric',
+      default='pearsonr', help='Train/test metric [Default: %default]')
+  parser.add_option('--name', dest='name',
+      default='teste', help='SLURM name prefix [Default: %default]')
+  parser.add_option('-o', dest='exp_dir',
+      default=None, help='Output experiment directory [Default: %default]')
+  parser.add_option('-p', dest='out_stem',
+      default=None, help='Output plot stem [Default: %default]')
+  parser.add_option('-q', dest='queue',
+      default='geforce',
+      help='SLURM queue on which to run the jobs [Default: %default]')
+  parser.add_option('-r', dest='ref_dir',
+      default=None, help='Reference directory for statistical tests')
+  parser.add_option('--rc', dest='rc',
+      default=False, action='store_true',
+      help='Average forward and reverse complement predictions [Default: %default]')
+  parser.add_option('--shifts', dest='shifts',
+      default='0', type='str',
+      help='Ensemble prediction shifts [Default: %default]')
+  parser.add_option('--status', dest='status',
+      default=False, action='store_true',
+      help='Update metric status; do not run jobs [Default: %default]')
+  parser.add_option('-t', dest='targets_file',
+      default=None, type='str',
+      help='File specifying target indexes and labels in table format')
+  (options, args) = parser.parse_args()
+
+  if len(args) < 2:
+    parser.error('Must provide parameters file and data directory')
+  else:
+    params_file = args[0]
+    data_dirs = [os.path.abspath(arg) for arg in args[1:]]
+
+  # -o is required; kept as an option for compatibility with the training script
+  assert(options.exp_dir is not None)
+
+  # read data parameters
+  data_stats_file = '%s/statistics.json' % data_dirs[0]
+  with open(data_stats_file) as data_stats_open:
+    data_stats = json.load(data_stats_open)
+
+  if options.dataset_i is None:
+    head_i = 0
+  else:
+    head_i = options.dataset_i
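+
+  # Expected directory layout (illustrative): each fold/cross pair lives in
+  # <exp_dir>/f<fold>c<cross>, e.g. f0c0/train/model_best.h5 for the default
+  # head and f0c0/train/model1_best.h5 when -d 1 selects dataset 1.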
+
+  # count folds
+  num_folds = len([dkey for dkey in data_stats if dkey.startswith('fold')])
+
+  # subset folds
+  if options.fold_subset is not None:
+    num_folds = min(options.fold_subset, num_folds)
+
+  if options.queue == 'standard':
+    num_cpu = 4
+    num_gpu = 0
+  else:
+    num_cpu = 2
+    num_gpu = 1
+
+  ################################################################
+  # test best
+  ################################################################
+  jobs = []
+
+  for ci in range(options.crosses):
+    for fi in range(num_folds):
+      it_dir = '%s/f%dc%d' % (options.exp_dir, fi, ci)
+
+      if options.dataset_i is None:
+        out_dir = '%s/teste' % it_dir
+        model_file = '%s/train/model_best.h5' % it_dir
+      else:
+        out_dir = '%s/teste%d' % (it_dir, options.dataset_i)
+        model_file = '%s/train/model%d_best.h5' % (it_dir, options.dataset_i)
+
+      # check if done
+      acc_file = '%s/acc.txt' % out_dir
+      if os.path.isfile(acc_file):
+        # print('%s already generated.' % acc_file)
+        pass
+      else:
+        # borzoi APA test command
+        cmd = '. /home/drk/anaconda3/etc/profile.d/conda.sh;'
+        cmd += ' conda activate %s;' % options.conda_env
+        cmd += ' time borzoi_test_apa_polaydb.py'
+        cmd += ' --head %d' % head_i
+        cmd += ' -o %s' % out_dir
+        if options.rc:
+          cmd += ' --rc'
+        if options.shifts:
+          cmd += ' --shifts %s' % options.shifts
+        if options.targets_file is not None:
+          cmd += ' -t %s' % options.targets_file
+        cmd += ' %s' % params_file
+        cmd += ' %s' % model_file
+        cmd += ' %s/data%d' % (it_dir, head_i)
+        cmd += ' %s' % options.apa_file
+
+        name = '%s-f%dc%d' % (options.name, fi, ci)
+        j = slurm.Job(cmd,
+            name=name,
+            out_file='%s.out' % out_dir,
+            err_file='%s.err' % out_dir,
+            queue=options.queue,
+            cpu=num_cpu, gpu=num_gpu,
+            mem=45000,
+            time='2-00:00:00')
+        jobs.append(j)
+
+  slurm.multi_run(jobs, verbose=True)
+
+################################################################################
+# __main__
+################################################################################
+if __name__ == '__main__':
+  main()
diff --git a/bin/borzoi_test_apa_polaydb.py b/bin/borzoi_test_apa_polaydb.py
new file mode 100644
index 00000000..b0810edf
--- /dev/null
+++ b/bin/borzoi_test_apa_polaydb.py
@@ -0,0 +1,301 @@
+#!/usr/bin/env python
+# Copyright 2021 Calico LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# https://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# =========================================================================
+
+from optparse import OptionParser
+import gc
+import json
+import pdb
+import os
+import time
+import sys
+
+import h5py
+#from intervaltree import IntervalTree
+import numpy as np
+import pandas as pd
+import pyranges as pr
+from scipy.stats import pearsonr
+from sklearn.metrics import explained_variance_score
+import tensorflow as tf
+#from tqdm import tqdm
+
+from basenji import bed
+from basenji import dataset
+from basenji import seqnn
+from basenji import trainer
+#import pygene
+#from qnorm import quantile_normalize
+
+'''
+borzoi_test_apa_polaydb.py
+
+Measure accuracy at the polyadenylation-site level (PolyADB).
+'''
+
+def eprint(*args, **kwargs):
+  print(*args, file=sys.stderr, **kwargs)
+
+################################################################################
+# main
+################################################################################
+def main():
+  usage = 'usage: %prog [options] <params_file> <model_file> <data_dir> <apa_file>'
+  parser = OptionParser(usage)
+  parser.add_option('--head', dest='head_i',
+      default=0, type='int',
+      help='Parameters head [Default: %default]')
+  parser.add_option('-o', dest='out_dir',
+      default='teste_out',
+      help='Output directory for predictions [Default: %default]')
+  parser.add_option('--rc', dest='rc',
+      default=False, action='store_true',
+      help='Average the fwd and rc predictions [Default: %default]')
+  parser.add_option('--shifts', dest='shifts',
+      default='0',
+      help='Ensemble prediction shifts [Default: %default]')
+  parser.add_option('-t', dest='targets_file',
+      default=None, type='str',
+      help='File specifying target indexes and labels in table format')
+  parser.add_option('--split', dest='split_label',
+      default='test',
+      help='Dataset split label, e.g. for the TFR pattern [Default: %default]')
+  parser.add_option('--tfr', dest='tfr_pattern',
+      default=None,
+      help='TFR pattern string appended to data_dir/tfrecords for subsetting [Default: %default]')
+  (options, args) = parser.parse_args()
+
+  if len(args) != 4:
+    parser.error('Must provide parameters, model, data directory, and APA annotation')
+  else:
+    params_file = args[0]
+    model_file = args[1]
+    data_dir = args[2]
+    apa_file = args[3]
+
+  if not os.path.isdir(options.out_dir):
+    os.mkdir(options.out_dir)
+
+  # parse shifts to integers
+  options.shifts = [int(shift) for shift in options.shifts.split(',')]
+
+  #######################################################
+  # inputs
+
+  # read targets
+  if options.targets_file is None:
+    options.targets_file = '%s/targets.txt' % data_dir
+  targets_df = pd.read_csv(options.targets_file, index_col=0, sep='\t')
+
+  # attach strand
+  targets_strand = []
+  for ti, identifier in enumerate(targets_df.identifier):
+    if targets_df.index[ti] == targets_df.strand_pair.iloc[ti]:
+      targets_strand.append('.')
+    else:
+      targets_strand.append(identifier[-1])
+  targets_df['strand'] = targets_strand
+
+  # collapse stranded
+  strand_mask = (targets_df.strand != '-')
+  targets_strand_df = targets_df[strand_mask]
+
+  # count targets
+  num_targets = targets_df.shape[0]
+  num_targets_strand = targets_strand_df.shape[0]
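+
+  # e.g. (illustrative): a stranded pair of tracks 'RNA:K562+' / 'RNA:K562-'
+  # keeps only the '+' row in targets_strand_df, while unstranded tracks get
+  # strand '.'; num_targets_strand is therefore <= num_targets.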
+
+  # mark sqrt-transformed tracks (to undo the transform later)
+  sqrt_mask = np.array([ss.find('sqrt') != -1 for ss in targets_strand_df.sum_stat])
+
+  # read model parameters
+  with open(params_file) as params_open:
+    params = json.load(params_open)
+  params_model = params['model']
+  params_train = params['train']
+
+  # set strand pairs
+  params_model['strand_pair'] = [np.array(targets_df.strand_pair)]
+
+  # construct eval data
+  eval_data = dataset.SeqDataset(data_dir,
+    split_label=options.split_label,
+    batch_size=params_train['batch_size'],
+    mode='eval',
+    tfr_pattern=options.tfr_pattern)
+
+  # initialize model
+  seqnn_model = seqnn.SeqNN(params_model)
+  seqnn_model.restore(model_file, options.head_i)
+  seqnn_model.build_slice(targets_df.index)
+  seqnn_model.build_ensemble(options.rc, options.shifts)
+
+  #######################################################
+  # sequence intervals
+
+  # read data parameters
+  with open('%s/statistics.json' % data_dir) as data_open:
+    data_stats = json.load(data_open)
+    crop_bp = data_stats['crop_bp']
+    pool_width = data_stats['pool_width']
+
+  # read sequence positions
+  seqs_df = pd.read_csv('%s/sequences.bed' % data_dir, sep='\t',
+    names=['Chromosome', 'Start', 'End', 'Name'])
+  seqs_df = seqs_df[seqs_df.Name == options.split_label]
+  seqs_pr = pr.PyRanges(seqs_df)
+
+  #######################################################
+  # make APA BED (PolyADB)
+
+  apa_df = pd.read_csv(apa_file, sep='\t', compression='gzip')
+
+  # filter for 3' UTR polyA sites only
+  apa_df = apa_df.query("site_type == '3\\' most exon'").copy().reset_index(drop=True)
+
+  eprint("len(apa_df) = " + str(len(apa_df)))
+  print("len(apa_df) = " + str(len(apa_df)))
+
+  apa_df['start_hg38'] = apa_df['position_hg38']
+  apa_df['end_hg38'] = apa_df['position_hg38'] + 1
+
+  apa_df = apa_df.rename(columns={
+    'chrom': 'Chromosome',
+    'start_hg38': 'Start',
+    'end_hg38': 'End',
+    'position_hg38': 'cut_mode',
+    'strand': 'pas_strand',
+  })
+
+  apa_pr = pr.PyRanges(apa_df[['Chromosome', 'Start', 'End', 'pas_id', 'cut_mode', 'pas_strand']])
+
+  #######################################################
+  # intersect APA sites w/ preds, targets
+
+  # intersect seqs, APA sites
+  seqs_apa_pr = seqs_pr.join(apa_pr)
+
+  eprint("len(seqs_apa_pr.df) = " + str(len(seqs_apa_pr.df)))
+  print("len(seqs_apa_pr.df) = " + str(len(seqs_apa_pr.df)))
+
+  # hash preds/targets by pas_id
+  apa_preds_dict = {}
+  apa_targets_dict = {}
+
+  si = 0
+  for x, y in eval_data.dataset:
+    # predict only if an APA site overlaps the sequence
+    yh = None
+    y = y.numpy()[..., targets_df.index]
+
+    t0 = time.time()
+    eprint('Sequence %d...' % si)
+    print('Sequence %d...' % si, end='')
+    for bsi in range(x.shape[0]):
+      seq = seqs_df.iloc[si + bsi]
+
+      cseqs_apa_df = seqs_apa_pr[seq.Chromosome].df
+      if cseqs_apa_df.shape[0] == 0:
+        # empty; no APA sites on this chromosome
+        seq_apa_df = cseqs_apa_df
+      else:
+        seq_apa_df = cseqs_apa_df[cseqs_apa_df.Start == seq.Start]
+
+      for _, seq_apa in seq_apa_df.iterrows():
+        pas_id = seq_apa.pas_id
+        pas_start = seq_apa.Start_b
+        pas_end = seq_apa.End_b
+        seq_start = seq_apa.Start
+        cut_mode = seq_apa.cut_mode
+        pas_strand = seq_apa.pas_strand
+
+        # clip boundaries
+        pas_seq_start = max(0, pas_start - seq_start)
+        pas_seq_end = max(0, pas_end - seq_start)
+        cut_seq_mode = max(0, cut_mode - seq_start)
+
+        # define a fixed bin window immediately upstream of the PAS
+        bin_start = None
+        bin_end = None
+        if pas_strand == '+':
+          bin_end = int(np.round(pas_seq_start / pool_width)) + 1
+          bin_start = bin_end - 3 - 1
+        else:
+          bin_start = int(np.round(pas_seq_end / pool_width))
+          bin_end = bin_start + 3 + 1
+
+        # predict (lazily, once per batch)
+        if yh is None:
+          yh = seqnn_model(x)
+
+        # slice PAS bin window
+        yhb = yh[bsi, bin_start:bin_end].astype('float16')
+        yb = y[bsi, bin_start:bin_end].astype('float16')
+
+        if len(yb) > 0:
+          apa_preds_dict.setdefault(pas_id, []).append(yhb)
+          apa_targets_dict.setdefault(pas_id, []).append(yb)
+        else:
+          eprint("(Warning: len(yb) <= 0)")
+
+    # advance sequence table index
+    si += x.shape[0]
+    eprint('DONE in %ds.' % (time.time() - t0))
+    print('DONE in %ds.' % (time.time() - t0))
+
+    eprint("len(apa_preds_dict) = " + str(len(apa_preds_dict)))
+
+    if si % 128 == 0:
+      gc.collect()
+
+  #######################################################
+  # aggregate pA bin values into arrays
+
+  apa_targets = []
+  apa_preds = []
+  pas_ids = np.array(sorted(apa_targets_dict.keys()))
+
+  for pas_id in pas_ids:
+    apa_preds_gi = np.concatenate(apa_preds_dict[pas_id], axis=0).astype('float32')
+    apa_targets_gi = np.concatenate(apa_targets_dict[pas_id], axis=0).astype('float32')
+
+    # undo scale
+    apa_preds_gi /= np.expand_dims(targets_strand_df.scale, axis=0)
+    apa_targets_gi /= np.expand_dims(targets_strand_df.scale, axis=0)
+
+    # undo sqrt-like transform (inverse of the **(3/4) power)
+    apa_preds_gi[:, sqrt_mask] = apa_preds_gi[:, sqrt_mask]**(4/3)
+    apa_targets_gi[:, sqrt_mask] = apa_targets_gi[:, sqrt_mask]**(4/3)
+
+    # mean coverage across the PAS window
+    apa_preds_gi = apa_preds_gi.mean(axis=0)
+    apa_targets_gi = apa_targets_gi.mean(axis=0)
+
+    apa_preds.append(apa_preds_gi)
+    apa_targets.append(apa_targets_gi)
+
+  apa_targets = np.array(apa_targets)
+  apa_preds = np.array(apa_preds)
+
+  # TEMP
+  np.save('%s/apa_targets_polyadb.npy' % options.out_dir, apa_targets)
+  np.save('%s/apa_preds_polyadb.npy' % options.out_dir, apa_preds)
+
+  # save values
+  apa_targets_df = pd.DataFrame(apa_targets, index=pas_ids)
+  apa_targets_df.to_csv('%s/apa_targets_polyadb.tsv.gz' % options.out_dir, sep='\t')
+  apa_preds_df = pd.DataFrame(apa_preds, index=pas_ids)
+  apa_preds_df.to_csv('%s/apa_preds_polyadb.tsv.gz' % options.out_dir, sep='\t')
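+
+  # Example invocation (hypothetical paths):
+  #   borzoi_test_apa_polaydb.py --rc --shifts "0" -t targets.txt \
+  #     params.json model_best.h5 data0 polyadb_human_v3.csv.gz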
+
+################################################################################
+# __main__
+################################################################################
+if __name__ == '__main__':
+  main()
diff --git a/bin/borzoi_trip.py b/bin/borzoi_trip.py
new file mode 100644
index 00000000..14638597
--- /dev/null
+++ b/bin/borzoi_trip.py
@@ -0,0 +1,275 @@
+#!/usr/bin/env python
+# Copyright 2022 Calico LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# https://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# =========================================================================
+from __future__ import print_function
+from optparse import OptionParser
+import os
+import random
+import gc
+import sys
+
+import h5py
+import json
+import numpy as np
+import pandas as pd
+import pysam
+import tensorflow as tf
+
+from basenji import dna_io
+from basenji import seqnn
+from basenji import stream
+
+'''
+borzoi_trip.py
+
+Predict coverage for reporter insertions from a TRIP assay.
+'''
+
+def eprint(*args, **kwargs):
+  print(*args, file=sys.stderr, **kwargs)
+
+################################################################################
+# main
+################################################################################
+def main():
+  usage = 'usage: %prog [options] <params_file> <model_file> <promoters_file> <insertions_file>'
+  parser = OptionParser(usage)
+  parser.add_option('-f', dest='fasta',
+      default='%s/data/hg38.fa' % os.environ['BASENJIDIR'],
+      help='Genome FASTA for sequences [Default: %default]')
+  parser.add_option('--site', dest='site',
+      default=False, action='store_true',
+      help='Return the insertion site without the promoter [Default: %default]')
+  parser.add_option('--reporter', dest='reporter',
+      default=False, action='store_true',
+      help='Insert the flanking piggyBac reporter with the promoter [Default: %default]')
+  parser.add_option('--reporter_bare', dest='reporter_bare',
+      default=False, action='store_true',
+      help='Insert the flanking piggyBac reporter with the promoter (no terminal repeats) [Default: %default]')
+  parser.add_option('-o', dest='out_dir',
+      default='trip',
+      help='Output directory [Default: %default]')
+  parser.add_option('--rc', dest='rc',
+      default=True, action='store_true',
+      help='Average forward and reverse complement predictions [Default: %default]')
+  parser.add_option('--shifts', dest='shifts',
+      default='0', type='str',
+      help='Ensemble prediction shifts [Default: %default]')
+  parser.add_option('-t', dest='targets_file',
+      default=None, type='str',
+      help='File specifying target indexes and labels in table format')
+  (options, args) = parser.parse_args()
+
+  if len(args) != 4:
+    parser.error('Must provide parameters, model, TRIP promoter sequences, and TRIP insertion sites')
+  else:
+    params_file = args[0]
+    model_file = args[1]
+    promoters_file = args[2]
+    insertions_file = args[3]
+
+  if not os.path.isdir(options.out_dir):
+    os.mkdir(options.out_dir)
+
+  options.shifts = [int(shift) for shift in options.shifts.split(',')]
+
+  #################################################################
+  # read parameters and targets
+
+  # read model parameters
+  with open(params_file) as params_open:
+    params = json.load(params_open)
+  params_model = params['model']
+  params_train = params['train']
+
+  if options.targets_file is None:
+    parser.error('Must provide targets table to properly handle strands.')
+  else:
+    targets_df = pd.read_csv(options.targets_file, sep='\t', index_col=0)
+
+  #################################################################
+  # setup model
+
+  target_slice = np.array(targets_df.index.values, dtype='int32')
+  strand_pair = np.array(targets_df.strand_pair.values, dtype='int32')
+
+  # create local index of strand_pair (relative to sliced targets)
+  target_slice_dict = {ix: i for i, ix in enumerate(target_slice.tolist())}
+  slice_pair = np.array([target_slice_dict[ix] for ix in strand_pair.tolist()], dtype='int32')
+
+  seqnn_model = seqnn.SeqNN(params_model)
+  seqnn_model.restore(model_file)
+  seqnn_model.build_slice(targets_df.index)
+  seqnn_model.strand_pair.append(slice_pair)
+  seqnn_model.build_ensemble(options.rc, options.shifts)
+
+  ################################################################
+  # promoters
+
+  # read promoter info
+  promoters_df = pd.read_excel(promoters_file)
+
+  # genome fasta
+  fasta_open = pysam.Fastafile(options.fasta)
+
+  # define piggyBac reporter sequences
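+  # Construct layout (descriptive note, inferred from the option help strings
+  # and the sequences below): --reporter builds
+  #   [left piggyBac arm] + promoter + [reporter ORF + right piggyBac arm]
+  # while --reporter_bare appends only the ORF-bearing right arm without
+  # terminal repeats. The right-arm sequence opens with an eGFP-like coding
+  # sequence ('ATGGTGAGCAAGGGCGAGGAG...').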
+  reporter_left = 'TTAACCCTAGAAAGATAGTCTGCGTAAAATTGACGCATGCATTCTTGAAATATTGCTCTCTCTTTCTAAATAGCGCGAATCCGTCGCTGTGCATTTAGGACATCTCAGTCGCCGCTTGGAGCTCCCGTGAGGCGTGCTTGTCAATGCGGTAAGTGTCACTGATTTTGAACTATAACGACCGCGTGAGTCAAAATGACGCATGATTATCTTTTACGTGACTTTTAAGATTTAACTCATACGATAATTATATTGTTATTTCATGTTCTACTTACGTGATAACTTATTATATATATATTTTCTTGTTATAGATATCAACTAGAATGCTAGCATGGGCCCATCTCGAGGATCCACCGGTCTAGAAAGCTTAGGCCTCCAAGG'
+
+  reporter_right = 'ATGGTGAGCAAGGGCGAGGAGCTGTTCACCGGGGTGGTGCCCATCCTGGTCGAGCTGGACGGCGACGTAAACGGCCACAAGTTCAGCGTGTCCGGCGAGGGCGAGGGCGATGCCACCTACGGCAAGCTGACCCTGAAGTTCATCTGCACCACCGGCAAGCTGCCCGTGCCCTGGCCCACCCTCGTGACCACCCTGACCTACGGCGTGCAGTGCTTCAGCCGCTACCCCGACCACATGAAGCAGCACGACTTCTTCAAGTCCGCCATGCCCGAAGGCTACGTCCAGGAGCGCACCATCTTCTTCAAGGACGACGGCAACTACAAGACCCGCGCCGAGGTGAAGTTCGAGGGCGACACCCTGGTGAACCGCATCGAGCTGAAGGGCATCGACTTCAAGGAGGACGGCAACATCCTGGGGCACAAGCTGGAGTACAACTACAACAGCCACAACGTCTATATCATGGCCGACAAGCAGAAGAACGGCATCAAGGTGAACTTCAAGATCCGCCACAACATCGAGGACGGCAGCGTGCAGCTCGCCGACCACTACCAGCAGAACACCCCCATCGGCGACGGCCCCGTGCTGCTGCCCGACAACCACTACCTGAGCACCCAGTCCGCCCTGAGCAAAGACCCCAACGAGAAGCGCGATCACATGGTCCTGCTGGAGTTCGTGACCGCCGCCGGGATCACTCTCGGCATGGACGAGCTGTACAAGTAAGAATTCGCGGCCGCATACGATTTAGGTGACACTGCAGATCATATGACAATTGTGGCCGGCCCTTGTGACTGGGAAAACCCTGGCGTAAATAAAATACGAAATGACTAGTTAAAAGTTTTGTTACTTTATAGAAGAAATTTTGAGTTTTTGTTTTTTTTTAATAAATAAATAAACATAAATAAATTGTTTGTTGAATTTATTATTAGTATGTAAGTGTAAATATAATAAAACTTAATATCTATTCAAATTAATAAATAAACCTCGATATACAGACCGATAAAACACATGCGTCAATTTTACGCATGATTATCTTTAACGTACGTCACAATATGATTATCTTTCTAGGGTTAA'
+
+  reporter_right_bare = 'ATGGTGAGCAAGGGCGAGGAGCTGTTCACCGGGGTGGTGCCCATCCTGGTCGAGCTGGACGGCGACGTAAACGGCCACAAGTTCAGCGTGTCCGGCGAGGGCGAGGGCGATGCCACCTACGGCAAGCTGACCCTGAAGTTCATCTGCACCACCGGCAAGCTGCCCGTGCCCTGGCCCACCCTCGTGACCACCCTGACCTACGGCGTGCAGTGCTTCAGCCGCTACCCCGACCACATGAAGCAGCACGACTTCTTCAAGTCCGCCATGCCCGAAGGCTACGTCCAGGAGCGCACCATCTTCTTCAAGGACGACGGCAACTACAAGACCCGCGCCGAGGTGAAGTTCGAGGGCGACACCCTGGTGAACCGCATCGAGCTGAAGGGCATCGACTTCAAGGAGGACGGCAACATCCTGGGGCACAAGCTGGAGTACAACTACAACAGCCACAACGTCTATATCATGGCCGACAAGCAGAAGAACGGCATCAAGGTGAACTTCAAGATCCGCCACAACATCGAGGACGGCAGCGTGCAGCTCGCCGACCACTACCAGCAGAACACCCCCATCGGCGACGGCCCCGTGCTGCTGCCCGACAACCACTACCTGAGCACCCAGTCCGCCCTGAGCAAAGACCCCAACGAGAAGCGCGATCACATGGTCCTGCTGGAGTTCGTGACCGCCGCCGGGATCACTCTCGGCATGGACGAGCTGTACAAGTAAGAATTCGCGGCCGCATACGATTTAGGTGACACTGCAGATCATATGACAATTGTGGCCGGCCCTTGTGACTGGGAAAACCCTGGCGTAAATAAAATACGAAATGACTAGTTTAATGTTTGTTTTCTTATA'
+
+  # read promoter sequence
+  promoter_seq1 = {}
+  for pi in range(promoters_df.shape[0]):
+    promoter_chr, promoter_range = promoters_df.iloc[pi].Region.split(':')
+    promoter_start, promoter_end = promoter_range.split('-')
+    promoter_start, promoter_end = int(promoter_start), int(promoter_end)
+
+    promoter_dna = fasta_open.fetch(promoter_chr, promoter_start, promoter_end)
+    promoter_1hot = dna_io.dna_1hot(promoter_dna)
+    if promoters_df.iloc[pi].Strand == '-':
+      promoter_1hot = dna_io.hot1_rc(promoter_1hot)
+
+    # optionally insert full piggyBac reporter
+    if options.reporter:
+      if pi == 0:
+        print("Using full reporter construct, " + reporter_left[:5] + "...", flush=True)
+
+      reporter_left_1hot = dna_io.dna_1hot(reporter_left)
+      reporter_right_1hot = dna_io.dna_1hot(reporter_right)
+
+      promoter_1hot = np.concatenate([reporter_left_1hot, promoter_1hot, reporter_right_1hot], axis=0)
+    elif options.reporter_bare:
+      if pi == 0:
+        print("Using bare reporter construct, " + reporter_right_bare[:5] + "...", flush=True)
+
+      reporter_right_bare_1hot = dna_io.dna_1hot(reporter_right_bare)
+
+      promoter_1hot = np.concatenate([promoter_1hot, reporter_right_bare_1hot], axis=0)
+
+    promoter_seq1[promoters_df.iloc[pi].Gene] = promoter_1hot
+
+  ################################################################
+  # insertions
+
+  # read insertion info
+  insertions_df = pd.read_csv(insertions_file, sep='\t')
+
+  # construct sequence generator
+  def insertion_seqs():
+    for ii in range(insertions_df.shape[0]):
+      chrm = insertions_df.iloc[ii].seqname
+      pos = insertions_df.iloc[ii].position
+      strand = insertions_df.iloc[ii].strand
+
+      if options.site:
+        flank_len = params_model['seq_length']
+        flank_start = pos - flank_len // 2
+        flank_end = flank_start + flank_len
+
+        # left flank (pad with random nucleotides past the chromosome start)
+        if flank_start < 0:
+          flank_dna = ''.join(random.choices('ACGT', k=-flank_start))
+          flank_start = 0
+        else:
+          flank_dna = ''
+
+        # fetch DNA
+        flank_dna += fasta_open.fetch(chrm, flank_start, flank_end)
+
+        # right flank (pad likewise past the chromosome end)
+        if len(flank_dna) < flank_len:
+          over_len = flank_len - len(flank_dna)
+          flank_dna += ''.join(random.choices('ACGT', k=over_len))
+
+        # 1 hot
+        insertion_1hot = dna_io.dna_1hot(flank_dna)
+
+      else:
+        promoter = insertions_df.iloc[ii].promoter
+        promoter_1hot = promoter_seq1[promoter]
+
+        # reverse complement
+        if strand == '-':
+          promoter_1hot = dna_io.hot1_rc(promoter_1hot)
+
+        # get flanking sequence
+        flank_len = params_model['seq_length'] - promoter_1hot.shape[0]
+        flank_start = pos - flank_len // 2
+        flank_end = flank_start + flank_len
+
+        # left flank
+        if flank_start < 0:
+          flank_dna_left = ''.join(random.choices('ACGT', k=-flank_start))
+          flank_start = 0
+        else:
+          flank_dna_left = ''
+        flank_dna_left += fasta_open.fetch(chrm, flank_start, pos)
+        flank_1hot_left = dna_io.dna_1hot(flank_dna_left)
+
+        # right flank
+        flank_dna_right = fasta_open.fetch(chrm, pos, flank_end)
+        if len(flank_dna_right) < flank_end - pos:
+          over_len = flank_end - pos - len(flank_dna_right)
+          flank_dna_right += ''.join(random.choices('ACGT', k=over_len))
+        flank_1hot_right = dna_io.dna_1hot(flank_dna_right)
+
+        # combine insertion sequence
+        insertion_1hot = np.concatenate([flank_1hot_left, promoter_1hot, flank_1hot_right], axis=0)
+
+        # orient promoters forward
+        if strand == '-':
+          insertion_1hot = dna_io.hot1_rc(insertion_1hot)
+
+      assert(insertion_1hot.shape[0] == params_model['seq_length'])
+
+      yield insertion_1hot
+
+  # initialize prediction stream
+  pred_stream = stream.PredStreamGen(seqnn_model, insertion_seqs(), 1)
+
+  ################################################################
+  # predict
+
+  # initialize h5
+  preds_h5 = h5py.File('%s/preds.h5' % options.out_dir, 'w')
+  preds_h5.create_dataset('preds', dtype='float16',
+    shape=(len(insertions_df), seqnn_model.target_lengths[0], len(targets_df)))
+
+  # collect garbage after a fixed number of predictions
+  collect_every = 256
+
+  # predict for all sequences
+  for pi in range(insertions_df.shape[0]):
+    # get predictions
+    preds = pred_stream[pi]
+    preds_h5['preds'][pi, ...] = preds.astype('float16')
+
+    if pi % collect_every == 0:
+      gc.collect()
+
+  # close h5 (flushes predictions to disk)
+  preds_h5.close()
+
+################################################################################
+# __main__
+################################################################################
+if __name__ == '__main__':
+  main()
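+
+# Example invocation (hypothetical paths; the promoter table is read with
+# pd.read_excel and the insertion table as a TSV):
+#   borzoi_trip.py -o trip_out -t targets.txt \
+#     params.json model_best.h5 trip_promoters.xlsx trip_insertions.tsv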