From 26891d5bb6919e4f1bfde8796e22968028ad98e0 Mon Sep 17 00:00:00 2001 From: David Kelley Date: Fri, 29 Jun 2018 18:01:47 -0700 Subject: [PATCH 01/71] failing with map_fn --- basenji/seqnn.py | 33 ++++++++++++++++++++++++--------- basenji/tfrecord_batcher.py | 21 ++++++++++++++------- 2 files changed, 38 insertions(+), 16 deletions(-) diff --git a/basenji/seqnn.py b/basenji/seqnn.py index c3f782f4..e62c9063 100755 --- a/basenji/seqnn.py +++ b/basenji/seqnn.py @@ -21,6 +21,7 @@ import numpy as np import tensorflow as tf +from basenji import augmentation from basenji import layers from basenji import params from basenji import seqnn_util @@ -43,6 +44,7 @@ def build(self, job, target_subset=None): def build_from_data_ops(self, job, data_ops, augment_rc=False, augment_shifts=[], + ensemble_rc=False, ensemble_shifts=[], target_subset=None): """Build training ops from input data ops.""" if not self.hparams_set: @@ -52,19 +54,32 @@ def build_from_data_ops(self, job, data_ops, self.inputs = data_ops['sequence'] self.targets_na = data_ops['na'] + # training data_ops w/ stochastic augmentation + data_ops_train, transform_repr_train = augmentation.augment_stochastic( + data_ops, augment_rc, augment_shifts) + + # eval data ops w/ deterministic augmentation + data_ops_eval, transform_repr_eval = augmentation.augment_deterministic( + data_ops, ensemble_rc, ensemble_shifts) + # training conditional self.is_training = tf.placeholder(tf.bool, name='is_training') - # active only via basenji_train_queues.py for TFRecords - if augment_rc or len(augment_shifts) > 0: - # augment data ops - data_ops_aug, _ = tfrecord_batcher.data_augmentation_from_data_ops( - data_ops, augment_rc, augment_shifts) - - # condition on training - data_ops = tf.cond(self.is_training, lambda: data_ops_aug, lambda: data_ops) + # condition on training + data_ops_list = tf.cond(self.is_training, + lambda: [data_ops_train], + lambda: data_ops_eval) + transform_repr_list = tf.cond(self.is_training, + lambda: [transform_repr_train], + lambda: transform_repr_eval) + + # compute representation for every input + def repr_i(i): + return transform_repr_list[i](self.build_representation(data_ops_list[i])) + seqs_repr_list = tf.map_fn(repr_i, tf.range(len(data_ops_list))) + seqs_repr = tf.reduce_mean(seqs_repr_list) + # seqs_repr = self.build_representation(data_ops, target_subset) - seqs_repr = self.build_representation(data_ops, target_subset) self.loss_op, self.loss_adhoc = self.build_loss(seqs_repr, data_ops, target_subset) self.build_optimizer(self.loss_op) diff --git a/basenji/tfrecord_batcher.py b/basenji/tfrecord_batcher.py index e8bf062b..0826d105 100755 --- a/basenji/tfrecord_batcher.py +++ b/basenji/tfrecord_batcher.py @@ -57,7 +57,7 @@ def _shift_left(_seq): return output # TODO(dbelanger) change inputs to be (features, labels) like for Estimator. -def rc_data_augmentation(dataset): +def rc_data_augmentation(dataset, stochastic=False): """Apply reverse complement to seq and flip label/na along the time axis. 
Args: @@ -71,13 +71,20 @@ def rc_data_augmentation(dataset): """ seq, label, na = [dataset[k] for k in ['sequence', 'label', 'na']] - do_flip = tf.random_uniform(shape=[]) > 0.5 - seq, label, na = tf.cond(do_flip, lambda: ops.reverse_complement_transform(seq, label, na), - lambda: (seq, label, na)) + if stochastic: + do_flip = tf.random_uniform(shape=[]) > 0.5 + seq, label, na = tf.cond(do_flip, lambda: ops.reverse_complement_transform(seq, label, na), + lambda: (seq, label, na)) - def process_predictions_fn(predictions): - return tf.cond(do_flip, lambda: tf.reverse(predictions, axis=[1]), - lambda: predictions) + def process_predictions_fn(predictions): + return tf.cond(do_flip, lambda: tf.reverse(predictions, axis=[1]), + lambda: predictions) + + else: + seq, label, na = ops.reverse_complement_transform(seq, label, na) + + def process_predictions_fn(predictions): + return lambda: tf.reverse(predictions, axis=[1]) transformed_dataset = {'sequence': seq, 'label': label, 'na': na} return transformed_dataset, process_predictions_fn From 41444afd3ae839278a2e9984b20e468e4ce72e6d Mon Sep 17 00:00:00 2001 From: David Kelley Date: Sat, 30 Jun 2018 13:11:56 -0700 Subject: [PATCH 02/71] cond refuses transform_fn --- basenji/seqnn.py | 24 +++++++++++++++--------- bin/basenji_train_queues.py | 4 +++- 2 files changed, 18 insertions(+), 10 deletions(-) diff --git a/basenji/seqnn.py b/basenji/seqnn.py index e62c9063..8ec54eab 100755 --- a/basenji/seqnn.py +++ b/basenji/seqnn.py @@ -57,9 +57,11 @@ def build_from_data_ops(self, job, data_ops, # training data_ops w/ stochastic augmentation data_ops_train, transform_repr_train = augmentation.augment_stochastic( data_ops, augment_rc, augment_shifts) + data_ops_train = [data_ops_train] + transform_repr_train = [transform_repr_train] # eval data ops w/ deterministic augmentation - data_ops_eval, transform_repr_eval = augmentation.augment_deterministic( + data_ops_eval, transform_repr_eval = augmentation.augment_deterministic_set( data_ops, ensemble_rc, ensemble_shifts) # training conditional @@ -67,16 +69,16 @@ def build_from_data_ops(self, job, data_ops, # condition on training data_ops_list = tf.cond(self.is_training, - lambda: [data_ops_train], - lambda: data_ops_eval) + lambda: data_ops_train, + lambda: data_ops_eval, strict=False) transform_repr_list = tf.cond(self.is_training, - lambda: [transform_repr_train], - lambda: transform_repr_eval) + lambda: transform_repr_train, + lambda: transform_repr_eval, strict=False) # compute representation for every input - def repr_i(i): - return transform_repr_list[i](self.build_representation(data_ops_list[i])) - seqs_repr_list = tf.map_fn(repr_i, tf.range(len(data_ops_list))) + map_elems = (data_ops_list, transform_repr_list) + build_rep = lambda me: self.build_representation(me[0], me[1], target_subset) + seqs_repr_list = tf.map_fn(build_rep, map_elems) # back_prop=False seqs_repr = tf.reduce_mean(seqs_repr_list) # seqs_repr = self.build_representation(data_ops, target_subset) @@ -126,7 +128,7 @@ def _make_conv_block_args(self, layer_index): 'name': 'conv-%d' % layer_index } - def build_representation(self, data_ops, target_subset): + def build_representation(self, data_ops, transform_preds_fn=lambda x: x, target_subset=None): """Construct per-location real-valued predictions.""" inputs = data_ops['sequence'] assert inputs is not None @@ -213,6 +215,10 @@ def build_representation(self, data_ops, target_subset): (self.hp.batch_size, -1, self.hp.num_targets, self.hp.target_classes)) + + # transform for reverse 
complement + final_repr = transform_preds_fn(final_repr) + return final_repr def build_optimizer(self, loss_op): diff --git a/bin/basenji_train_queues.py b/bin/basenji_train_queues.py index 0a29fb1b..b3ca245d 100755 --- a/bin/basenji_train_queues.py +++ b/bin/basenji_train_queues.py @@ -57,7 +57,9 @@ def run(params_file, train_file, test_file, train_epochs, train_epoch_batches, # initialize model model = seqnn.SeqNN() - model.build_from_data_ops(job, data_ops, FLAGS.augment_rc, augment_shifts) + model.build_from_data_ops(job, data_ops, + FLAGS.augment_rc, augment_shifts, + FLAGS.ensemble_rc, ensemble_shifts) # checkpoints saver = tf.train.Saver() From 0bb926d3c51deaedeb4107b779dff657e9566ee4 Mon Sep 17 00:00:00 2001 From: David Kelley Date: Sat, 30 Jun 2018 20:14:50 -0700 Subject: [PATCH 03/71] another failed attempt --- basenji/seqnn.py | 24 +++---- basenji/tfrecord_batcher.py | 137 ------------------------------------ 2 files changed, 10 insertions(+), 151 deletions(-) diff --git a/basenji/seqnn.py b/basenji/seqnn.py index 8ec54eab..e4c199e4 100755 --- a/basenji/seqnn.py +++ b/basenji/seqnn.py @@ -55,13 +55,11 @@ def build_from_data_ops(self, job, data_ops, self.targets_na = data_ops['na'] # training data_ops w/ stochastic augmentation - data_ops_train, transform_repr_train = augmentation.augment_stochastic( + data_ops_train = augmentation.augment_stochastic( data_ops, augment_rc, augment_shifts) - data_ops_train = [data_ops_train] - transform_repr_train = [transform_repr_train] # eval data ops w/ deterministic augmentation - data_ops_eval, transform_repr_eval = augmentation.augment_deterministic_set( + data_ops_eval = augmentation.augment_deterministic_set( data_ops, ensemble_rc, ensemble_shifts) # training conditional @@ -69,16 +67,12 @@ def build_from_data_ops(self, job, data_ops, # condition on training data_ops_list = tf.cond(self.is_training, - lambda: data_ops_train, - lambda: data_ops_eval, strict=False) - transform_repr_list = tf.cond(self.is_training, - lambda: transform_repr_train, - lambda: transform_repr_eval, strict=False) + lambda: [data_ops_train], + lambda: data_ops_eval, strict=True) # compute representation for every input - map_elems = (data_ops_list, transform_repr_list) - build_rep = lambda me: self.build_representation(me[0], me[1], target_subset) - seqs_repr_list = tf.map_fn(build_rep, map_elems) # back_prop=False + build_rep = lambda do: self.build_representation(do, target_subset) + seqs_repr_list = tf.map_fn(build_rep, data_ops_list) # back_prop=False seqs_repr = tf.reduce_mean(seqs_repr_list) # seqs_repr = self.build_representation(data_ops, target_subset) @@ -128,7 +122,7 @@ def _make_conv_block_args(self, layer_index): 'name': 'conv-%d' % layer_index } - def build_representation(self, data_ops, transform_preds_fn=lambda x: x, target_subset=None): + def build_representation(self, data_ops, target_subset=None): """Construct per-location real-valued predictions.""" inputs = data_ops['sequence'] assert inputs is not None @@ -217,7 +211,9 @@ def build_representation(self, data_ops, transform_preds_fn=lambda x: x, target_ # transform for reverse complement - final_repr = transform_preds_fn(final_repr) + final_repr = tf.cond(data_ops['reverse_preds'], + lambda: tf.reverse(final_repr, axis=1), + lambda: final_repr) return final_repr diff --git a/basenji/tfrecord_batcher.py b/basenji/tfrecord_batcher.py index 0826d105..cd22a14b 100755 --- a/basenji/tfrecord_batcher.py +++ b/basenji/tfrecord_batcher.py @@ -28,143 +28,6 @@ # datasets. 
NUM_FILES_TO_PARALLEL_INTERLEAVE = 10 -def shift_sequence(seq, shift_amount, pad_value): - """Shift a sequence left or right by shift_amount. - Args: - seq: a [batch_size, sequence_length, sequence_depth] sequence to shift - shift_amount: the signed amount to shift (tf.int32 or int) - pad_value: value to fill the padding (primitive or scalar tf.Tensor) - """ - if seq.shape.ndims != 3: - raise ValueError('input sequence should be rank 3') - input_shape = seq.shape - - pad = pad_value * tf.ones_like(seq[:, 0:tf.abs(shift_amount), :]) - - def _shift_right(_seq): - sliced_seq = _seq[:, :-shift_amount:, :] - return tf.concat([pad, sliced_seq], axis=1) - - def _shift_left(_seq): - sliced_seq = _seq[:, -shift_amount:, :] - return tf.concat([sliced_seq, pad], axis=1) - - output = tf.cond( - tf.greater(shift_amount, 0), lambda: _shift_right(seq), - lambda: _shift_left(seq)) - - output.set_shape(input_shape) - return output - -# TODO(dbelanger) change inputs to be (features, labels) like for Estimator. -def rc_data_augmentation(dataset, stochastic=False): - """Apply reverse complement to seq and flip label/na along the time axis. - - Args: - dataset: dict with keys 'sequence,' 'label,' and 'na.' - Returns - transformed_dataset: augmented data - process_predictions_fn: callable to be applied to predictions - such that they are directly comparable to the input dataset['label'] - rather than transformed_dataset['label']. Here, it flips the prediction - along the time axis. - """ - seq, label, na = [dataset[k] for k in ['sequence', 'label', 'na']] - - if stochastic: - do_flip = tf.random_uniform(shape=[]) > 0.5 - seq, label, na = tf.cond(do_flip, lambda: ops.reverse_complement_transform(seq, label, na), - lambda: (seq, label, na)) - - def process_predictions_fn(predictions): - return tf.cond(do_flip, lambda: tf.reverse(predictions, axis=[1]), - lambda: predictions) - - else: - seq, label, na = ops.reverse_complement_transform(seq, label, na) - - def process_predictions_fn(predictions): - return lambda: tf.reverse(predictions, axis=[1]) - - transformed_dataset = {'sequence': seq, 'label': label, 'na': na} - return transformed_dataset, process_predictions_fn - - -def shift_sequence_augmentation(seq, shift_augment_offsets, pad_value): - """Shift seq by a random amount. Pad to maintain the input size. - - Args: - seq: input sequence of size [batch_size, length, depth] - shift_augment_offsets: list of int offsets to sample from. If `None` or - `[]`, then only "shift" by 0 (the identity). - pad_value: value to fill the padding with. - Returns: - shifted and padded sequence of size [batch_size, length, depth] - """ - # The value of the parameter shift_augment_offsets are the set of things to - # _augment_ the original data with, and we want to, in addition to including - # those augmentations, actually include the original data. - total_set_of_shifts = [] - if shift_augment_offsets: - total_set_of_shifts += shift_augment_offsets - if 0 not in total_set_of_shifts: - total_set_of_shifts.append(0) - - shift_index = tf.random_uniform( - shape=[], minval=0, maxval=len(total_set_of_shifts), dtype=tf.int64) - shift_value = tf.gather(tf.constant(total_set_of_shifts), shift_index) - - seq = tf.cond( - tf.not_equal(shift_value, 0), - lambda: shift_sequence(seq, shift_value, pad_value), lambda: seq) - - return seq - - -def apply_data_augmentation(input_ops, label_ops, augment_with_complement, - shift_augment_offsets): - """Apply data augmentation to input and label ops. - Args: - input_ops: dict containing input Tensors. 
- label_ops: dict containing label Tensors. - augment_with_complement: whether to do reverse complement augmentation. - shift_augment_offsets: offsets used for doing shift-based augmentation. - Can be `None` or `[]` to indicate no shift-augmentation. - - Returns: - transformed_inputs: inputs with augmentation applied. - transformed_labels: labels transformed in accordance with the augmentation. - process_predictions_fn: callable to be applied to predictions - such that they are directly comparable to the label_ops - rather than transformed_labels. - """ - data_ops = {} - data_ops.update(input_ops) - data_ops.update(label_ops) - - augmented_data_ops, process_predictions_fn = data_augmentation_from_data_ops( - data_ops, augment_with_complement, shift_augment_offsets) - return ({ - 'sequence': augmented_data_ops['sequence'] - }, {name: augmented_data_ops[name] - for name in ['label', 'na']}, process_predictions_fn) - -# TODO(dbelanger) switch to directly calling apply_data_augmentation -def data_augmentation_from_data_ops(data_ops, augment_with_complement, - shift_augment_offsets): - process_predictions_fn = None - - if shift_augment_offsets and len(shift_augment_offsets) > 1: - pad_value = 0.25 - data_ops['sequence'] = shift_sequence_augmentation( - data_ops['sequence'], shift_augment_offsets, pad_value) - - if augment_with_complement: - data_ops, process_predictions_fn = rc_data_augmentation(data_ops) - - return data_ops, process_predictions_fn - - def tfrecord_dataset(tfr_data_files_pattern, batch_size, seq_length, From 32a529c19ec90dda9984c6bd9601662e993fdfa4 Mon Sep 17 00:00:00 2001 From: David Kelley Date: Sat, 30 Jun 2018 20:39:35 -0700 Subject: [PATCH 04/71] moving cond downstream but map_fn still fails --- basenji/seqnn.py | 24 ++++++++++++------------ 1 file changed, 12 insertions(+), 12 deletions(-) diff --git a/basenji/seqnn.py b/basenji/seqnn.py index e4c199e4..a1612bb4 100755 --- a/basenji/seqnn.py +++ b/basenji/seqnn.py @@ -54,6 +54,9 @@ def build_from_data_ops(self, job, data_ops, self.inputs = data_ops['sequence'] self.targets_na = data_ops['na'] + # training conditional + self.is_training = tf.placeholder(tf.bool, name='is_training') + # training data_ops w/ stochastic augmentation data_ops_train = augmentation.augment_stochastic( data_ops, augment_rc, augment_shifts) @@ -62,19 +65,17 @@ def build_from_data_ops(self, job, data_ops, data_ops_eval = augmentation.augment_deterministic_set( data_ops, ensemble_rc, ensemble_shifts) - # training conditional - self.is_training = tf.placeholder(tf.bool, name='is_training') + # compute train representation + seqs_repr_train = self.build_representation(data_ops_train, target_subset) - # condition on training - data_ops_list = tf.cond(self.is_training, - lambda: [data_ops_train], - lambda: data_ops_eval, strict=True) + pdb.set_trace() - # compute representation for every input + # compute eval representation build_rep = lambda do: self.build_representation(do, target_subset) - seqs_repr_list = tf.map_fn(build_rep, data_ops_list) # back_prop=False - seqs_repr = tf.reduce_mean(seqs_repr_list) - # seqs_repr = self.build_representation(data_ops, target_subset) + seqs_repr_list = tf.map_fn(build_rep, data_ops_eval, dtype=seqs_repr_train.dtype) # back_prop=False + seqs_repr_eval = tf.reduce_mean(seqs_repr_list) + + seqs_repr = tf.cond(self.is_training, lambda: seqs_repr_train, lambda: seqs_repr_eval) self.loss_op, self.loss_adhoc = self.build_loss(seqs_repr, data_ops, target_subset) self.build_optimizer(self.loss_op) @@ -209,10 +210,9 @@ def 
build_representation(self, data_ops, target_subset=None): (self.hp.batch_size, -1, self.hp.num_targets, self.hp.target_classes)) - # transform for reverse complement final_repr = tf.cond(data_ops['reverse_preds'], - lambda: tf.reverse(final_repr, axis=1), + lambda: tf.reverse(final_repr, axis=[1]), lambda: final_repr) return final_repr From bf42d4e920639fda0e59a043f57439ff21dd6fde Mon Sep 17 00:00:00 2001 From: David Kelley Date: Tue, 3 Jul 2018 13:18:04 -0700 Subject: [PATCH 05/71] separate train and eval loss --- basenji/seqnn.py | 147 ++++++++++++++++++++++-------------------- basenji/seqnn_util.py | 18 +++--- 2 files changed, 87 insertions(+), 78 deletions(-) diff --git a/basenji/seqnn.py b/basenji/seqnn.py index a1612bb4..3d604712 100755 --- a/basenji/seqnn.py +++ b/basenji/seqnn.py @@ -57,28 +57,44 @@ def build_from_data_ops(self, job, data_ops, # training conditional self.is_training = tf.placeholder(tf.bool, name='is_training') + ################################################## + # training + # training data_ops w/ stochastic augmentation data_ops_train = augmentation.augment_stochastic( data_ops, augment_rc, augment_shifts) + # compute train representation + seqs_repr_train = self.build_representation(data_ops_train['sequence'], + None, target_subset) + + # training losses + loss_returns = self.build_loss(seqs_repr_train, data_ops_train, target_subset) + self.loss_train, self.loss_train_targets, self.preds_train, self.targets_train = loss_returns + + # optimizer + self.update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS) + self.build_optimizer(self.loss_train) + + ################################################## + # eval + # eval data ops w/ deterministic augmentation data_ops_eval = augmentation.augment_deterministic_set( data_ops, ensemble_rc, ensemble_shifts) - - # compute train representation - seqs_repr_train = self.build_representation(data_ops_train, target_subset) - - pdb.set_trace() + data_seq_eval = tf.stack([do['sequence'] for do in data_ops_eval]) + data_rev_eval = tf.stack([do['reverse_preds'] for do in data_ops_eval]) # compute eval representation - build_rep = lambda do: self.build_representation(do, target_subset) - seqs_repr_list = tf.map_fn(build_rep, data_ops_eval, dtype=seqs_repr_train.dtype) # back_prop=False - seqs_repr_eval = tf.reduce_mean(seqs_repr_list) + map_elems_eval = (data_seq_eval, data_rev_eval) + build_rep = lambda do: self.build_representation(do[0], do[1], target_subset) + seqs_repr_list = tf.map_fn(build_rep, map_elems_eval, dtype=seqs_repr_train.dtype) # back_prop=False + seqs_repr_eval = tf.reduce_mean(seqs_repr_list, axis=0) - seqs_repr = tf.cond(self.is_training, lambda: seqs_repr_train, lambda: seqs_repr_eval) + # eval loss + loss_returns = self.build_loss(seqs_repr_eval, data_ops, target_subset) + self.loss_eval, self.loss_eval_targets, self.preds_eval, self.targets_eval = loss_returns - self.loss_op, self.loss_adhoc = self.build_loss(seqs_repr, data_ops, target_subset) - self.build_optimizer(self.loss_op) def make_placeholders(self): """Allocates placeholders to be used in place of input data ops.""" @@ -109,7 +125,7 @@ def make_placeholders(self): } return data - def _make_conv_block_args(self, layer_index): + def _make_conv_block_args(self, layer_index, layer_reprs): """Packages arguments to be used by layers.conv_block.""" return { 'conv_params': self.hp.cnn_params[layer_index], @@ -119,37 +135,39 @@ def _make_conv_block_args(self, layer_index): 'batch_renorm': self.hp.batch_renorm, 'batch_renorm_momentum': 
self.hp.batch_renorm_momentum, 'l2_scale': self.hp.cnn_l2_scale, - 'layer_reprs': self.layer_reprs, + 'layer_reprs': layer_reprs, 'name': 'conv-%d' % layer_index } - def build_representation(self, data_ops, target_subset=None): + def build_representation(self, inputs, reverse_preds=None, target_subset=None): """Construct per-location real-valued predictions.""" - inputs = data_ops['sequence'] assert inputs is not None - print('Targets pooled by %d to length %d' % (self.hp.target_pool, self.hp.seq_length // self.hp.target_pool)) ################################################### # convolution layers ################################################### - self.filter_weights = [] - self.layer_reprs = [inputs] + filter_weights = [] + layer_reprs = [inputs] seqs_repr = inputs for layer_index in range(self.hp.cnn_layers): - with tf.variable_scope('cnn%d' % layer_index): + with tf.variable_scope('cnn%d' % layer_index, reuse=tf.AUTO_REUSE): # convolution block - args_for_block = self._make_conv_block_args(layer_index) + args_for_block = self._make_conv_block_args(layer_index, layer_reprs) seqs_repr = layers.conv_block(seqs_repr=seqs_repr, **args_for_block) # save representation - self.layer_reprs.append(seqs_repr) + layer_reprs.append(seqs_repr) # final nonlinearity seqs_repr = tf.nn.relu(seqs_repr) + ################################################### + # slice out side buffer + ################################################### + # update batch buffer to reflect pooling seq_length = seqs_repr.shape[1].value pool_preds = self.hp.seq_length // seq_length @@ -158,26 +176,18 @@ def build_representation(self, data_ops, target_subset=None): ' by the CNN pooling %d') % (self.hp.batch_buffer, pool_preds) batch_buffer_pool = self.hp.batch_buffer // pool_preds - - ################################################### - # slice out side buffer - ################################################### - - # predictions + # slice out buffer seq_length = seqs_repr.shape[1] seqs_repr = seqs_repr[:, batch_buffer_pool: seq_length - batch_buffer_pool, :] - seq_length = seqs_repr.shape[1].value - self.preds_length = seq_length # save penultimate representation - self.penultimate_op = seqs_repr - + # self.penultimate_op = seqs_repr ################################################### # final layer ################################################### - with tf.variable_scope('final'): + with tf.variable_scope('final', reuse=tf.AUTO_REUSE): final_filters = self.hp.num_targets * self.hp.target_classes final_repr = tf.layers.dense( inputs=seqs_repr, @@ -211,9 +221,10 @@ def build_representation(self, data_ops, target_subset=None): self.hp.target_classes)) # transform for reverse complement - final_repr = tf.cond(data_ops['reverse_preds'], - lambda: tf.reverse(final_repr, axis=[1]), - lambda: final_repr) + if reverse_preds is not None: + final_repr = tf.cond(reverse_preds, + lambda: tf.reverse(final_repr, axis=[1]), + lambda: final_repr) return final_repr @@ -266,8 +277,6 @@ def build_optimizer(self, loss_op): self.step_op = self.opt.apply_gradients( self.gvs, global_step=self.global_step) - self.update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS) - # summary self.merged_summary = tf.summary.merge_all() @@ -278,7 +287,6 @@ def build_loss(self, seqs_repr, data_ops, target_subset=None): # targets tstart = self.hp.batch_buffer // self.hp.target_pool tend = (self.hp.seq_length - self.hp.batch_buffer) // self.hp.target_pool - self.target_length = tend - tstart targets = data_ops['label'] targets = tf.identity(targets[:, 
tstart:tend, :], name='targets_op') @@ -287,8 +295,8 @@ def build_loss(self, seqs_repr, data_ops, target_subset=None): targets = tf.gather(targets, target_subset, axis=2) # work-around for specifying my own predictions - self.preds_adhoc = tf.placeholder( - tf.float32, shape=seqs_repr.shape, name='preds-adhoc') + # self.preds_adhoc = tf.placeholder( + # tf.float32, shape=seqs_repr.shape, name='preds-adhoc') # float 32 exponential clip max # exp_max = np.floor(np.log(0.5*tf.float32.max)) @@ -296,17 +304,17 @@ def build_loss(self, seqs_repr, data_ops, target_subset=None): # choose link if self.hp.link in ['identity', 'linear']: - self.preds_op = tf.identity(seqs_repr, name='preds') + preds_op = tf.identity(seqs_repr, name='preds') elif self.hp.link == 'relu': - self.preds_op = tf.relu(seqs_repr, name='preds') + preds_op = tf.relu(seqs_repr, name='preds') elif self.hp.link == 'exp': seqs_repr_clip = tf.clip_by_value(seqs_repr, -exp_max, exp_max) - self.preds_op = tf.exp(seqs_repr_clip, name='preds') + preds_op = tf.exp(seqs_repr_clip, name='preds') elif self.hp.link == 'exp_linear': - self.preds_op = tf.where( + preds_op = tf.where( seqs_repr > 0, seqs_repr + 1, tf.exp(tf.clip_by_value(seqs_repr, -exp_max, exp_max)), @@ -314,7 +322,7 @@ def build_loss(self, seqs_repr, data_ops, target_subset=None): elif self.hp.link == 'softplus': seqs_repr_clip = tf.clip_by_value(seqs_repr, -exp_max, 10000) - self.preds_op = tf.nn.softplus(seqs_repr_clip, name='preds') + preds_op = tf.nn.softplus(seqs_repr_clip, name='preds') elif self.hp.link == 'softmax': # performed in the loss function, but saving probabilities @@ -326,33 +334,33 @@ def build_loss(self, seqs_repr, data_ops, target_subset=None): # clip if self.hp.target_clip is not None: - self.preds_op = tf.clip_by_value(self.preds_op, 0, self.hp.target_clip) + preds_op = tf.clip_by_value(preds_op, 0, self.hp.target_clip) targets = tf.clip_by_value(targets, 0, self.hp.target_clip) # sqrt if self.hp.target_sqrt: - self.preds_op = tf.sqrt(self.preds_op) + preds_op = tf.sqrt(preds_op) targets = tf.sqrt(targets) loss_op = None - loss_adhoc = None + # loss_adhoc = None # choose loss if self.hp.loss == 'gaussian': - loss_op = tf.squared_difference(self.preds_op, targets) - loss_adhoc = tf.squared_difference(self.preds_adhoc, targets) + loss_op = tf.squared_difference(preds_op, targets) + # loss_adhoc = tf.squared_difference(self.preds_adhoc, targets) elif self.hp.loss == 'poisson': loss_op = tf.nn.log_poisson_loss( - targets, tf.log(self.preds_op), compute_full_loss=True) - loss_adhoc = tf.nn.log_poisson_loss( - targets, tf.log(self.preds_adhoc), compute_full_loss=True) + targets, tf.log(preds_op), compute_full_loss=True) + # loss_adhoc = tf.nn.log_poisson_loss( + # targets, tf.log(self.preds_adhoc), compute_full_loss=True) elif self.hp.loss == 'cross_entropy': loss_op = tf.nn.sparse_softmax_cross_entropy_with_logits( - labels=(targets - 1), logits=self.preds_op) - loss_adhoc = tf.nn.sparse_softmax_cross_entropy_with_logits( - labels=(targets - 1), logits=self.preds_adhoc) + labels=(targets - 1), logits=preds_op) + # loss_adhoc = tf.nn.sparse_softmax_cross_entropy_with_logits( + # labels=(targets - 1), logits=self.preds_adhoc) else: raise ValueError('Cannot identify loss function %s' % self.hp.loss) @@ -361,29 +369,29 @@ def build_loss(self, seqs_repr, data_ops, target_subset=None): loss_op = tf.reduce_mean(loss_op, axis=[0, 1], name='target_loss') loss_op = tf.check_numerics(loss_op, 'Invalid loss', name='loss_check') - loss_adhoc = tf.reduce_mean( - 
loss_adhoc, axis=[0, 1], name='target_loss_adhoc') + # loss_adhoc = tf.reduce_mean( + # loss_adhoc, axis=[0, 1], name='target_loss_adhoc') tf.summary.histogram('target_loss', loss_op) for ti in np.linspace(0, self.hp.num_targets - 1, 10).astype('int'): tf.summary.scalar('loss_t%d' % ti, loss_op[ti]) - self.target_losses = loss_op - self.target_losses_adhoc = loss_adhoc + target_losses = loss_op + # self.target_losses_adhoc = loss_adhoc # fully reduce loss_op = tf.reduce_mean(loss_op, name='loss') - loss_adhoc = tf.reduce_mean(loss_adhoc, name='loss_adhoc') + # loss_adhoc = tf.reduce_mean(loss_adhoc, name='loss_adhoc') # add regularization terms reg_losses = tf.get_collection(tf.GraphKeys.REGULARIZATION_LOSSES) reg_sum = tf.reduce_sum(reg_losses) tf.summary.scalar('regularizers', reg_sum) loss_op += reg_sum - loss_adhoc += reg_sum + # loss_adhoc += reg_sum # track tf.summary.scalar('loss', loss_op) - self.targets_op = targets - return loss_op, loss_adhoc + + return loss_op, target_losses, preds_op, targets def set_mode(self, mode): @@ -438,12 +446,12 @@ def train_epoch(self, fd[self.targets_na] = NAb if no_steps: - run_returns = sess.run([self.merged_summary, self.loss_op] + \ + run_returns = sess.run([self.merged_summary, self.loss_train] + \ self.update_ops, feed_dict=fd) summary, loss_batch = run_returns[:2] else: run_returns = sess.run( - [self.merged_summary, self.loss_op, self.global_step, self.step_op] + self.update_ops, + [self.merged_summary, self.loss_train, self.global_step, self.step_op] + self.update_ops, feed_dict=fd) summary, loss_batch, global_step = run_returns[:3] @@ -482,10 +490,11 @@ def train_epoch_from_data_ops(self, data_available = True batch_num = 0 while data_available and (epoch_batches is None or batch_num < epoch_batches): + print(batch_num) try: - run_returns = sess.run( - [self.merged_summary, self.loss_op, self.global_step, self.step_op] + self.update_ops, - feed_dict=fd) + # update_ops won't run + run_ops = [self.merged_summary, self.loss_train, self.global_step, self.step_op] + self.update_ops + run_returns = sess.run(run_ops, feed_dict=fd) summary, loss_batch, global_step = run_returns[:3] # add summary diff --git a/basenji/seqnn_util.py b/basenji/seqnn_util.py index 3d2c8110..76c06b73 100644 --- a/basenji/seqnn_util.py +++ b/basenji/seqnn_util.py @@ -27,7 +27,7 @@ def build_grads(self, layers=[0]): self.grad_ops = [] for ti in range(self.hp.num_targets): - grad_ti_op = tf.gradients(self.preds_op[:,:,ti], [self.layer_reprs[li] for li in self.grad_layers]) + grad_ti_op = tf.gradients(self.preds_eval[:,:,ti], [self.layer_reprs[li] for li in self.grad_layers]) self.grad_ops.append(grad_ti_op) @@ -60,7 +60,7 @@ def build_grads_genes(self, gene_seqs, layers=[0]): if pi in tss_pos: # build position-specific, target-specific gradient ops for ti in range(self.hp.num_targets): - grad_piti_op = tf.gradients(self.preds_op[:,pi,ti], + grad_piti_op = tf.gradients(self.preds_eval[:,pi,ti], [self.layer_reprs[li] for li in self.grad_layers]) self.grad_pos_ops[-1].append(grad_piti_op) @@ -312,7 +312,7 @@ def _gradients_ensemble(self, sess, fd, Xb, ensemble_fwdrc, ensemble_shifts, mc_ # prediction # predict - preds_ei, layer_reprs_ei = sess.run([self.preds_op, self.layer_reprs], feed_dict=fd) + preds_ei, layer_reprs_ei = sess.run([self.preds_eval, self.layer_reprs], feed_dict=fd) # reverse if ensemble_fwdrc[ei] is False: @@ -451,7 +451,7 @@ def gradients_genes(self, sess, batcher, gene_seqs): fd[self.inputs] = Xb # predict - reprs_batch, _ = sess.run([self.layer_reprs, 
self.preds_op], feed_dict=fd) + reprs_batch, _ = sess.run([self.layer_reprs, self.preds_eval], feed_dict=fd) # save representations for lii in range(len(self.grad_layers)): @@ -511,7 +511,7 @@ def hidden(self, sess, batcher, layers=None): # compute predictions layer_reprs_batch, preds_batch = sess.run( - [self.layer_reprs, self.preds_op], feed_dict=fd) + [self.layer_reprs, self.preds_eval], feed_dict=fd) # accumulate representationsmakes the number of members for self smaller and also for li in layers: @@ -599,7 +599,7 @@ def _predict_ensemble(self, if penultimate: preds_ei = sess.run(self.penultimate_op, feed_dict=fd) else: - preds_ei = sess.run(self.preds_op, feed_dict=fd) + preds_ei = sess.run(self.preds_eval, feed_dict=fd) # reverse if ensemble_fwdrc[ei] is False: @@ -866,8 +866,8 @@ def test_from_data_ops(self, sess, test_batches=None): while data_available and (test_batches is None or batch_num < test_batches): try: # make non-ensembled predictions - run_ops = [self.targets_op, self.preds_op, - self.loss_op, self.target_losses] + run_ops = [self.targets_eval, self.preds_eval, + self.loss_eval, self.loss_eval_targets] run_returns = sess.run(run_ops, feed_dict=fd) targets_batch, preds_batch, loss_batch, target_losses_batch = run_returns @@ -974,7 +974,7 @@ def test(self, # recompute loss w/ ensembled prediction fd[self.preds_adhoc] = preds_batch targets_batch, loss_batch, target_losses_batch = sess.run( - [self.targets_op, self.loss_adhoc, self.target_losses_adhoc], + [self.targets_train, self.loss_adhoc, self.target_losses_adhoc], feed_dict=fd) # accumulate predictions and targets From 76f3340afffa162bff98a924eb992af312e69616 Mon Sep 17 00:00:00 2001 From: David Kelley Date: Tue, 3 Jul 2018 18:33:27 -0700 Subject: [PATCH 06/71] needs preds_length --- basenji/seqnn.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/basenji/seqnn.py b/basenji/seqnn.py index 3d604712..e33c9372 100755 --- a/basenji/seqnn.py +++ b/basenji/seqnn.py @@ -95,6 +95,9 @@ def build_from_data_ops(self, job, data_ops, loss_returns = self.build_loss(seqs_repr_eval, data_ops, target_subset) self.loss_eval, self.loss_eval_targets, self.preds_eval, self.targets_eval = loss_returns + # helper variables + self.preds_length = self.preds_train.shape[1] + def make_placeholders(self): """Allocates placeholders to be used in place of input data ops.""" @@ -490,7 +493,6 @@ def train_epoch_from_data_ops(self, data_available = True batch_num = 0 while data_available and (epoch_batches is None or batch_num < epoch_batches): - print(batch_num) try: # update_ops won't run run_ops = [self.merged_summary, self.loss_train, self.global_step, self.step_op] + self.update_ops From 2e57fd6409e63b1e6e280d6d4f35c9638e40644a Mon Sep 17 00:00:00 2001 From: David Kelley Date: Sat, 7 Jul 2018 10:46:15 -0700 Subject: [PATCH 07/71] augmentation methods --- basenji/augmentation.py | 164 ++++++++++++++++++++++++++++++++++++++++ 1 file changed, 164 insertions(+) create mode 100644 basenji/augmentation.py diff --git a/basenji/augmentation.py b/basenji/augmentation.py new file mode 100644 index 00000000..340c313e --- /dev/null +++ b/basenji/augmentation.py @@ -0,0 +1,164 @@ +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at + +# https://www.apache.org/licenses/LICENSE-2.0 + +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ========================================================================= + +import pdb +import tensorflow as tf + +from basenji import ops + +def shift_sequence(seq, shift_amount, pad_value=0.25): + """Shift a sequence left or right by shift_amount. + + Args: + seq: a [batch_size, sequence_length, sequence_depth] sequence to shift + shift_amount: the signed amount to shift (tf.int32 or int) + pad_value: value to fill the padding (primitive or scalar tf.Tensor) + """ + if seq.shape.ndims != 3: + raise ValueError('input sequence should be rank 3') + input_shape = seq.shape + + pad = pad_value * tf.ones_like(seq[:, 0:tf.abs(shift_amount), :]) + + def _shift_right(_seq): + sliced_seq = _seq[:, :-shift_amount:, :] + return tf.concat([pad, sliced_seq], axis=1) + + def _shift_left(_seq): + sliced_seq = _seq[:, -shift_amount:, :] + return tf.concat([sliced_seq, pad], axis=1) + + output = tf.cond( + tf.greater(shift_amount, 0), lambda: _shift_right(seq), + lambda: _shift_left(seq)) + + output.set_shape(input_shape) + return output + +def augment_deterministic_set(data_ops, augment_rc=False, augment_shifts=[0]): + """ + + Args: + data_ops: dict with keys 'sequence,' 'label,' and 'na.' + augment_rc: Boolean + augment_shifts: List of ints. + Returns + data_ops_list: + """ + augment_pairs = [] + for ashift in augment_shifts: + augment_pairs.append((False, ashift)) + if augment_rc: + augment_pairs.append((True, ashift)) + + data_ops_list = [] + for arc, ashift in augment_pairs: + data_ops_aug = augment_deterministic(data_ops, arc, ashift) + data_ops_list.append(data_ops_aug) + + return data_ops_list + + +def augment_deterministic(data_ops, augment_rc=False, augment_shift=0): + """Apply a deterministic augmentation, specified by the parameters. + + Args: + data_ops: dict with keys 'sequence,' 'label,' and 'na.' + augment_rc: Boolean + augment_shifts: Int + Returns + data_ops: augmented data + """ + if augment_shift != 0: + shift_amount = tf.constant(augment_shift, shape=(), dtype=tf.int64) + data_ops['sequence'] = shift_sequence(data_ops['sequence'], shift_amount) + + if augment_rc: + data_ops = augment_deterministic_rc(data_ops) + else: + data_ops['reverse_preds'] = tf.zeros((), dtype=tf.bool) + + return data_ops + + +def augment_deterministic_rc(data_ops): + """Apply a deterministic reverse complement augmentation. + + Args: + data_ops: dict with keys 'sequence,' 'label,' and 'na.' + Returns + data_ops_aug: augmented data ops + """ + seq, label, na = [data_ops[k] for k in ['sequence', 'label', 'na']] + seq, label, na = ops.reverse_complement_transform(seq, label, na) + reverse_preds = tf.ones((), dtype=tf.bool) + data_ops_aug = {'sequence': seq, 'label': label, 'na': na, 'reverse_preds':reverse_preds} + return data_ops_aug + + +def augment_stochastic_rc(data_ops): + """Apply a stochastic reverse complement augmentation. + + Args: + data_ops: dict with keys 'sequence,' 'label,' and 'na.' 
+ Returns + data_ops_aug: augmented data + """ + seq, label, na = [data_ops[k] for k in ['sequence', 'label', 'na']] + reverse_preds = tf.random_uniform(shape=[]) > 0.5 + seq, label, na = tf.cond(reverse_preds, lambda: ops.reverse_complement_transform(seq, label, na), + lambda: (seq, label, na)) + data_ops_aug = {'sequence': seq, 'label': label, 'na': na, 'reverse_preds':reverse_preds} + return data_ops_aug + + +def augment_stochastic_shifts(seq, augment_shifts): + """Apply a stochastic shift augmentation. + + Args: + seq: input sequence of size [batch_size, length, depth] + augment_shifts: list of int offsets to sample from + Returns: + shifted and padded sequence of size [batch_size, length, depth] + """ + shift_index = tf.random_uniform(shape=[], minval=0, + maxval=len(augment_shifts), dtype=tf.int64) + shift_value = tf.gather(tf.constant(augment_shifts), shift_index) + + seq = tf.cond(tf.not_equal(shift_value, 0), + lambda: shift_sequence(seq, shift_value), + lambda: seq) + + return seq + + +def augment_stochastic(data_ops, augment_rc=False, augment_shifts=[]): + """Apply stochastic augmentations, + + Args: + data_ops: dict with keys 'sequence,' 'label,' and 'na.' + augment_rc: Boolean for whether to apply reverse complement augmentation. + augment_shifts: list of int offsets to sample shift augmentations. + Returns: + data_ops_aug: augmented data + """ + if augment_shifts: + data_ops['sequence'] = augment_stochastic_shifts(data_ops['sequence'], + augment_shifts) + + if augment_rc: + data_ops = augment_stochastic_rc(data_ops) + else: + data_ops['reverse_preds'] = tf.zeros((), dtype=tf.bool) + + return data_ops From 2797ccb149e924885156b644b394dece35fedf9b Mon Sep 17 00:00:00 2001 From: David Kelley Date: Sat, 7 Jul 2018 10:48:57 -0700 Subject: [PATCH 08/71] h5 in-graph augmentation --- basenji/seqnn.py | 97 ++++++++++++---- bin/basenji_train_h5.py | 216 ++++++++++++++++++++++++++++++++++++ bin/basenji_train_queues.py | 2 +- 3 files changed, 291 insertions(+), 24 deletions(-) create mode 100755 bin/basenji_train_h5.py diff --git a/basenji/seqnn.py b/basenji/seqnn.py index e33c9372..a6f8471f 100755 --- a/basenji/seqnn.py +++ b/basenji/seqnn.py @@ -33,14 +33,20 @@ def __init__(self): self.global_step = tf.train.get_or_create_global_step() self.hparams_set = False - def build(self, job, target_subset=None): + def build(self, job, augment_rc=False, augment_shifts=[], + ensemble_rc=False, ensemble_shifts=[], target_subset=None): """Build training ops that depend on placeholders.""" self.hp = params.make_hparams(job) self.hparams_set = True data_ops = self.make_placeholders() - self.build_from_data_ops(job, data_ops, target_subset=target_subset) + self.build_from_data_ops(job, data_ops, + augment_rc=augment_rc, + augment_shifts=augment_shifts, + ensemble_rc=ensemble_rc, + ensemble_shifts=ensemble_shifts, + target_subset=target_subset) def build_from_data_ops(self, job, data_ops, augment_rc=False, augment_shifts=[], @@ -50,9 +56,6 @@ def build_from_data_ops(self, job, data_ops, if not self.hparams_set: self.hp = params.make_hparams(job) self.hparams_set = True - self.targets = data_ops['label'] - self.inputs = data_ops['sequence'] - self.targets_na = data_ops['na'] # training conditional self.is_training = tf.placeholder(tf.bool, name='is_training') @@ -102,29 +105,26 @@ def build_from_data_ops(self, job, data_ops, def make_placeholders(self): """Allocates placeholders to be used in place of input data ops.""" # batches - self.inputs = tf.placeholder( + self.inputs_ph = tf.placeholder( 
tf.float32, shape=(self.hp.batch_size, self.hp.seq_length, self.hp.seq_depth), name='inputs') if self.hp.target_classes == 1: - self.targets = tf.placeholder( + self.targets_ph = tf.placeholder( tf.float32, shape=(self.hp.batch_size, self.hp.seq_length // self.hp.target_pool, self.hp.num_targets), name='targets') else: - self.targets = tf.placeholder( + self.targets_ph = tf.placeholder( tf.int32, shape=(self.hp.batch_size, self.hp.seq_length // self.hp.target_pool, self.hp.num_targets), name='targets') - self.targets_na = tf.placeholder( - tf.bool, shape=(self.hp.batch_size, self.hp.seq_length // self.hp.target_pool)) data = { - 'sequence': self.inputs, - 'label': self.targets, - 'na': self.targets_na + 'sequence': self.inputs_ph, + 'label': self.targets_ph } return data @@ -419,7 +419,7 @@ def set_mode(self, mode): return fd - def train_epoch(self, + def train_epoch_h5_manual(self, sess, batcher, fwdrc=True, @@ -427,7 +427,8 @@ def train_epoch(self, sum_writer=None, epoch_batches=None, no_steps=False): - """Execute one training epoch.""" + """Execute one training epoch, using HDF5 data + and manual augmentation.""" # initialize training loss train_loss = [] @@ -444,9 +445,8 @@ def train_epoch(self, epoch_batches is None or batch_num < epoch_batches): # update feed dict - fd[self.inputs] = Xb - fd[self.targets] = Yb - fd[self.targets_na] = NAb + fd[self.inputs_ph] = Xb + fd[self.targets_ph] = Yb if no_steps: run_returns = sess.run([self.merged_summary, self.loss_train] + \ @@ -477,11 +477,62 @@ def train_epoch(self, return np.mean(train_loss), global_step - def train_epoch_from_data_ops(self, - sess, - sum_writer=None, - epoch_batches=None): - """ Execute one training epoch """ + def train_epoch_h5(self, + sess, + batcher, + sum_writer=None, + epoch_batches=None, + no_steps=False): + """Execute one training epoch using HDF5 data, + and compute-graph augmentation""" + + # initialize training loss + train_loss = [] + global_step = 0 + + # setup feed dict + fd = self.set_mode('train') + + # get first batch + Xb, Yb, NAb, Nb = batcher.next() + + batch_num = 0 + while Xb is not None and Nb == self.hp.batch_size and ( + epoch_batches is None or batch_num < epoch_batches): + + # update feed dict + fd[self.inputs_ph] = Xb + fd[self.targets_ph] = Yb + + if no_steps: + run_returns = sess.run([self.merged_summary, self.loss_train] + \ + self.update_ops, feed_dict=fd) + summary, loss_batch = run_returns[:2] + else: + run_ops = [self.merged_summary, self.loss_train, self.global_step, self.step_op] + run_ops += self.update_ops + summary, loss_batch, global_step = sess.run(run_ops, feed_dict=fd) + + # add summary + if sum_writer is not None: + sum_writer.add_summary(summary, global_step) + + # accumulate loss + train_loss.append(loss_batch) + + # next batch + Xb, Yb, NAb, Nb = batcher.next(fwdrc, shift) + batch_num += 1 + + # reset training batcher if epoch considered all of the data + if epoch_batches is None: + batcher.reset() + + return np.mean(train_loss), global_step + + + def train_epoch_tfr(self, sess, sum_writer=None, epoch_batches=None): + """ Execute one training epoch, using TFRecords data. """ # initialize training loss train_loss = [] diff --git a/bin/basenji_train_h5.py b/bin/basenji_train_h5.py new file mode 100755 index 00000000..6e09f9e3 --- /dev/null +++ b/bin/basenji_train_h5.py @@ -0,0 +1,216 @@ +#!/usr/bin/env python +# Copyright 2017 Calico LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ========================================================================= +from __future__ import print_function + +import sys +import time + +import h5py +import numpy as np +import tensorflow as tf + +from basenji import batcher +from basenji import params +from basenji import seqnn +from basenji import shared_flags + +FLAGS = tf.app.flags.FLAGS + +################################################################################ +# main +################################################################################ +def main(_): + np.random.seed(FLAGS.seed) + + run(params_file=FLAGS.params, + data_file=FLAGS.data, + train_epochs=FLAGS.train_epochs, + train_epoch_batches=FLAGS.train_epoch_batches, + test_epoch_batches=FLAGS.test_epoch_batches) + + +def run(params_file, data_file, train_epochs, train_epoch_batches, test_epoch_batches): + + ####################################################### + # load data + ####################################################### + data_open = h5py.File(data_file) + + train_seqs = data_open['train_in'] + train_targets = data_open['train_out'] + train_na = None + if 'train_na' in data_open: + train_na = data_open['train_na'] + + valid_seqs = data_open['valid_in'] + valid_targets = data_open['valid_out'] + valid_na = None + if 'valid_na' in data_open: + valid_na = data_open['valid_na'] + + ####################################################### + # model parameters and placeholders + ####################################################### + job = params.read_job_params(params_file) + + job['seq_length'] = train_seqs.shape[1] + job['seq_depth'] = train_seqs.shape[2] + job['num_targets'] = train_targets.shape[2] + job['target_pool'] = int(np.array(data_open.get('pool_width', 1))) + + augment_shifts = [int(shift) for shift in FLAGS.augment_shifts.split(',')] + ensemble_shifts = [int(shift) for shift in FLAGS.ensemble_shifts.split(',')] + + t0 = time.time() + model = seqnn.SeqNN() + model.build(job, augment_rc=FLAGS.augment_rc, augment_shifts=augment_shifts, + ensemble_rc=FLAGS.ensemble_rc, ensemble_shifts=ensemble_shifts) + print('Model building time %f' % (time.time() - t0)) + + # adjust for fourier + job['fourier'] = 'train_out_imag' in data_open + if job['fourier']: + train_targets_imag = data_open['train_out_imag'] + valid_targets_imag = data_open['valid_out_imag'] + + ####################################################### + # prepare batcher + ####################################################### + if job['fourier']: + batcher_train = batcher.BatcherF( + train_seqs, + train_targets, + train_targets_imag, + train_na, + model.hp.batch_size, + model.hp.target_pool, + shuffle=True) + batcher_valid = batcher.BatcherF(valid_seqs, valid_targets, + valid_targets_imag, valid_na, + model.batch_size, model.target_pool) + else: + batcher_train = batcher.Batcher( + train_seqs, + train_targets, + train_na, + model.hp.batch_size, + model.hp.target_pool, + shuffle=True) + batcher_valid = batcher.Batcher(valid_seqs, valid_targets, valid_na, + model.hp.batch_size, model.hp.target_pool) + print('Batcher initialized') + + 
####################################################### + # train + ####################################################### + + # checkpoints + saver = tf.train.Saver() + + config = tf.ConfigProto() + if FLAGS.log_device_placement: + config.log_device_placement = True + with tf.Session(config=config) as sess: + t0 = time.time() + + # set seed + tf.set_random_seed(FLAGS.seed) + + if FLAGS.logdir: + train_writer = tf.summary.FileWriter(FLAGS.logdir + '/train', sess.graph) + else: + train_writer = None + + if FLAGS.restart: + # load variables into session + saver.restore(sess, FLAGS.restart) + else: + # initialize variables + print('Initializing...') + sess.run(tf.global_variables_initializer()) + print('Initialization time %f' % (time.time() - t0)) + + train_loss = None + best_loss = None + early_stop_i = 0 + + epoch = 0 + while (train_epochs is None or epoch < train_epochs) and early_stop_i < FLAGS.early_stop: + t0 = time.time() + + # alternate forward and reverse batches + fwdrc = True + if FLAGS.augment_rc and epoch % 2 == 1: + fwdrc = False + + # cycle shifts + shift_i = epoch % len(augment_shifts) + + # train + train_loss, steps = model.train_epoch(sess, batcher_train, fwdrc=fwdrc, + shift=augment_shifts[shift_i], + sum_writer=train_writer, + epoch_batches=train_epoch_batches, + no_steps=FLAGS.no_steps) + + # validate + valid_acc = model.test(sess, batcher_valid, mc_n=FLAGS.ensemble_mc, + rc=FLAGS.ensemble_rc, shifts=ensemble_shifts, + test_batches=test_epoch_batches) + valid_loss = valid_acc.loss + valid_r2 = valid_acc.r2().mean() + del valid_acc + + best_str = '' + if best_loss is None or valid_loss < best_loss: + best_loss = valid_loss + best_str = ', best!' + early_stop_i = 0 + saver.save(sess, '%s/model_best.tf' % FLAGS.logdir) + else: + early_stop_i += 1 + + # measure time + et = time.time() - t0 + if et < 600: + time_str = '%3ds' % et + elif et < 6000: + time_str = '%3dm' % (et / 60) + else: + time_str = '%3.1fh' % (et / 3600) + + # print update + print( + 'Epoch: %3d, Steps: %7d, Train loss: %7.5f, Valid loss: %7.5f, Valid R2: %7.5f, Time: %s%s' + % (epoch + 1, steps, train_loss, valid_loss, valid_r2, time_str, best_str)) + sys.stdout.flush() + + if FLAGS.check_all: + saver.save(sess, '%s/model_check%d.tf' % (FLAGS.logdir, epoch)) + + # update epoch + epoch += 1 + + + if FLAGS.logdir: + train_writer.close() + + +################################################################################ +# __main__ +################################################################################ +if __name__ == '__main__': + tf.app.run(main) diff --git a/bin/basenji_train_queues.py b/bin/basenji_train_queues.py index b3ca245d..50f4d5e5 100755 --- a/bin/basenji_train_queues.py +++ b/bin/basenji_train_queues.py @@ -95,7 +95,7 @@ def run(params_file, train_file, test_file, train_epochs, train_epoch_batches, # train epoch sess.run(training_init_op) - train_loss, steps = model.train_epoch_from_data_ops(sess, train_writer, train_epoch_batches) + train_loss, steps = model.train_epoch_tfr(sess, train_writer, train_epoch_batches) # test validation sess.run(test_init_op) From b66bc9c83fff1a68f6ac938d8d594c6bb00a1a0b Mon Sep 17 00:00:00 2001 From: David Kelley Date: Sat, 7 Jul 2018 12:56:28 -0700 Subject: [PATCH 09/71] tuning --- basenji/seqnn.py | 10 ++++++++-- bin/basenji_train_h5.py | 14 +++----------- 2 files changed, 11 insertions(+), 13 deletions(-) diff --git a/basenji/seqnn.py b/basenji/seqnn.py index a6f8471f..27f34043 100755 --- a/basenji/seqnn.py +++ b/basenji/seqnn.py @@ -109,6 +109,7 @@ 
def make_placeholders(self): tf.float32, shape=(self.hp.batch_size, self.hp.seq_length, self.hp.seq_depth), name='inputs') + if self.hp.target_classes == 1: self.targets_ph = tf.placeholder( tf.float32, @@ -122,9 +123,14 @@ def make_placeholders(self): self.hp.num_targets), name='targets') + self.targets_na_ph = tf.placeholder(tf.bool, + shape=(self.hp.batch_size, self.hp.seq_length // self.hp.target_pool), + name='targets_na') + data = { 'sequence': self.inputs_ph, - 'label': self.targets_ph + 'label': self.targets_ph, + 'na': self.targets_na_ph } return data @@ -511,7 +517,7 @@ def train_epoch_h5(self, else: run_ops = [self.merged_summary, self.loss_train, self.global_step, self.step_op] run_ops += self.update_ops - summary, loss_batch, global_step = sess.run(run_ops, feed_dict=fd) + summary, loss_batch, global_step = sess.run(run_ops, feed_dict=fd)[:3] # add summary if sum_writer is not None: diff --git a/bin/basenji_train_h5.py b/bin/basenji_train_h5.py index 6e09f9e3..41f86789 100755 --- a/bin/basenji_train_h5.py +++ b/bin/basenji_train_h5.py @@ -77,7 +77,8 @@ def run(params_file, data_file, train_epochs, train_epoch_batches, test_epoch_ba t0 = time.time() model = seqnn.SeqNN() model.build(job, augment_rc=FLAGS.augment_rc, augment_shifts=augment_shifts, - ensemble_rc=FLAGS.ensemble_rc, ensemble_shifts=ensemble_shifts) + ensemble_rc=FLAGS.ensemble_rc, ensemble_shifts=ensemble_shifts) + print('Model building time %f' % (time.time() - t0)) # adjust for fourier @@ -151,17 +152,8 @@ def run(params_file, data_file, train_epochs, train_epoch_batches, test_epoch_ba while (train_epochs is None or epoch < train_epochs) and early_stop_i < FLAGS.early_stop: t0 = time.time() - # alternate forward and reverse batches - fwdrc = True - if FLAGS.augment_rc and epoch % 2 == 1: - fwdrc = False - - # cycle shifts - shift_i = epoch % len(augment_shifts) - # train - train_loss, steps = model.train_epoch(sess, batcher_train, fwdrc=fwdrc, - shift=augment_shifts[shift_i], + train_loss, steps = model.train_epoch_h5(sess, batcher_train, sum_writer=train_writer, epoch_batches=train_epoch_batches, no_steps=FLAGS.no_steps) From 257544b9eb4177b2e838ff309a5e0e90a2938ba4 Mon Sep 17 00:00:00 2001 From: David Kelley Date: Sat, 7 Jul 2018 13:19:48 -0700 Subject: [PATCH 10/71] tran epoch bug --- basenji/seqnn.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/basenji/seqnn.py b/basenji/seqnn.py index 27f34043..7f1d1fc5 100755 --- a/basenji/seqnn.py +++ b/basenji/seqnn.py @@ -527,7 +527,7 @@ def train_epoch_h5(self, train_loss.append(loss_batch) # next batch - Xb, Yb, NAb, Nb = batcher.next(fwdrc, shift) + Xb, Yb, NAb, Nb = batcher.next() batch_num += 1 # reset training batcher if epoch considered all of the data From a013f9ec2d178ae58801e9c3688c0a09bbb57fda Mon Sep 17 00:00:00 2001 From: David Kelley Date: Sat, 7 Jul 2018 13:22:02 -0700 Subject: [PATCH 11/71] h5 ensembling in-graph --- basenji/seqnn_util.py | 89 +++++++++++++++++++++++++++++++------ bin/basenji_train_h5.py | 10 ++--- bin/basenji_train_queues.py | 2 +- 3 files changed, 80 insertions(+), 21 deletions(-) diff --git a/basenji/seqnn_util.py b/basenji/seqnn_util.py index 76c06b73..123ab054 100644 --- a/basenji/seqnn_util.py +++ b/basenji/seqnn_util.py @@ -835,7 +835,7 @@ def predict_genes(self, return tss_preds - def test_from_data_ops(self, sess, test_batches=None): + def test_tfr(self, sess, test_batches=None): """ Compute model accuracy on a test set, where data is loaded from a queue. 
Args: @@ -845,11 +845,6 @@ def test_from_data_ops(self, sess, test_batches=None): Returns: acc: Accuracy object """ - - # TODO(dbelanger) this ignores rc and shift ensembling for now. - # Accuracy will be slightly lower than if we had used this. - # The rc and shift data augmentation need to be pulled into the graph. - fd = self.set_mode('test') # initialize prediction and target arrays @@ -865,7 +860,7 @@ def test_from_data_ops(self, sess, test_batches=None): batch_num = 0 while data_available and (test_batches is None or batch_num < test_batches): try: - # make non-ensembled predictions + # make predictions run_ops = [self.targets_eval, self.preds_eval, self.loss_eval, self.loss_eval_targets] run_returns = sess.run(run_ops, feed_dict=fd) @@ -900,13 +895,79 @@ def test_from_data_ops(self, sess, test_batches=None): return acc - def test(self, - sess, - batcher, - rc=False, - shifts=[0], - mc_n=0, - test_batches=None): + def test_h5(self, sess, batcher, test_batches=None): + """ Compute model accuracy on a test set. + + Args: + sess: TensorFlow session + batcher: Batcher object to provide data + mc_n: Monte Carlo iterations per rc/shift. + test_batches: Number of test batches + + Returns: + acc: Accuracy object + """ + # setup feed dict + fd = self.set_mode('test') + + # initialize prediction and target arrays + preds = [] + targets = [] + targets_na = [] + + batch_losses = [] + batch_target_losses = [] + + # get first batch + batch_num = 0 + Xb, Yb, NAb, Nb = batcher.next() + + while Xb is not None and (test_batches is None or + batch_num < test_batches): + # make predictions + run_ops = [self.targets_eval, self.preds_eval, + self.loss_eval, self.loss_eval_targets] + run_returns = sess.run(run_ops, feed_dict=fd) + targets_batch, preds_batch, loss_batch, target_losses_batch = run_returns + + # accumulate predictions and targets + preds.append(preds_batch.astype('float16')) + targets.append(targets_batch.astype('float16')) + targets_na.append(np.zeros([preds_batch.shape[0], self.preds_length], dtype='bool')) + + # accumulate loss + batch_losses.append(loss_batch) + batch_target_losses.append(target_losses_batch) + + # next batch + batch_num += 1 + Xb, Yb, NAb, Nb = batcher.next() + + # reset batcher + batcher.reset() + + # construct arrays + targets = np.concatenate(targets, axis=0) + preds = np.concatenate(preds, axis=0) + targets_na = np.concatenate(targets_na, axis=0) + + # mean across batches + batch_losses = np.mean(batch_losses) + batch_target_losses = np.array(batch_target_losses).mean(axis=0) + + # instantiate accuracy object + acc = accuracy.Accuracy(targets, preds, targets_na, + batch_losses, batch_target_losses) + + return acc + + def test_h5_manual(self, + sess, + batcher, + rc=False, + shifts=[0], + mc_n=0, + test_batches=None): """ Compute model accuracy on a test set. 
Args: diff --git a/bin/basenji_train_h5.py b/bin/basenji_train_h5.py index 41f86789..875ee3af 100755 --- a/bin/basenji_train_h5.py +++ b/bin/basenji_train_h5.py @@ -154,14 +154,12 @@ def run(params_file, data_file, train_epochs, train_epoch_batches, test_epoch_ba # train train_loss, steps = model.train_epoch_h5(sess, batcher_train, - sum_writer=train_writer, - epoch_batches=train_epoch_batches, - no_steps=FLAGS.no_steps) + sum_writer=train_writer, + epoch_batches=train_epoch_batches, + no_steps=FLAGS.no_steps) # validate - valid_acc = model.test(sess, batcher_valid, mc_n=FLAGS.ensemble_mc, - rc=FLAGS.ensemble_rc, shifts=ensemble_shifts, - test_batches=test_epoch_batches) + valid_acc = model.test_h5(sess, batcher_valid, test_batches=test_epoch_batches) valid_loss = valid_acc.loss valid_r2 = valid_acc.r2().mean() del valid_acc diff --git a/bin/basenji_train_queues.py b/bin/basenji_train_queues.py index 50f4d5e5..8ac51ec5 100755 --- a/bin/basenji_train_queues.py +++ b/bin/basenji_train_queues.py @@ -99,7 +99,7 @@ def run(params_file, train_file, test_file, train_epochs, train_epoch_batches, # test validation sess.run(test_init_op) - valid_acc = model.test_from_data_ops(sess, test_epoch_batches) + valid_acc = model.test_tfr(sess, test_epoch_batches) valid_loss = valid_acc.loss valid_r2 = valid_acc.r2().mean() del valid_acc From 18386bc4c1e44b4ff31a088961f0699ef65aca62 Mon Sep 17 00:00:00 2001 From: David Kelley Date: Sun, 8 Jul 2018 10:24:23 -0700 Subject: [PATCH 12/71] test feed dict --- basenji/seqnn_util.py | 14 +++++++++----- 1 file changed, 9 insertions(+), 5 deletions(-) diff --git a/basenji/seqnn_util.py b/basenji/seqnn_util.py index 123ab054..071a13a7 100644 --- a/basenji/seqnn_util.py +++ b/basenji/seqnn_util.py @@ -924,6 +924,10 @@ def test_h5(self, sess, batcher, test_batches=None): while Xb is not None and (test_batches is None or batch_num < test_batches): + # update feed dict + fd[self.inputs_ph] = Xb + fd[self.targets_ph] = Yb + # make predictions run_ops = [self.targets_eval, self.preds_eval, self.loss_eval, self.loss_eval_targets] @@ -931,9 +935,9 @@ def test_h5(self, sess, batcher, test_batches=None): targets_batch, preds_batch, loss_batch, target_losses_batch = run_returns # accumulate predictions and targets - preds.append(preds_batch.astype('float16')) - targets.append(targets_batch.astype('float16')) - targets_na.append(np.zeros([preds_batch.shape[0], self.preds_length], dtype='bool')) + preds.append(preds_batch[:Nb,:,:].astype('float16')) + targets.append(targets_batch[:Nb,:,:].astype('float16')) + targets_na.append(np.zeros([Nb, self.preds_length], dtype='bool')) # accumulate loss batch_losses.append(loss_batch) @@ -1027,8 +1031,8 @@ def test_h5_manual(self, sess, fd, Xb, ensemble_fwdrc, ensemble_shifts, mc_n) # add target info - fd[self.targets] = Yb - fd[self.targets_na] = NAb + fd[self.targets_ph] = Yb + fd[self.targets_na_ph] = NAb targets_na.append(np.zeros([Nb, self.preds_length], dtype='bool')) From 9686338c6fbd35f761a29f95334ee98aef3de5ee Mon Sep 17 00:00:00 2001 From: David Kelley Date: Wed, 11 Jul 2018 14:34:00 -0700 Subject: [PATCH 13/71] update placeholder --- basenji/seqnn_util.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/basenji/seqnn_util.py b/basenji/seqnn_util.py index 071a13a7..2df94556 100644 --- a/basenji/seqnn_util.py +++ b/basenji/seqnn_util.py @@ -300,7 +300,7 @@ def _gradients_ensemble(self, sess, fd, Xb, ensemble_fwdrc, ensemble_shifts, mc_ Xb_ensemble = hot1_augment(Xb, ensemble_fwdrc[ei], ensemble_shifts[ei]) # 
update feed dict - fd[self.inputs] = Xb_ensemble + fd[self.inputs_ph] = Xb_ensemble # for each monte carlo (or non-mc single) iteration for mi in range(mc_n): @@ -448,7 +448,7 @@ def gradients_genes(self, sess, batcher, gene_seqs): while Xb is not None: # update feed dict - fd[self.inputs] = Xb + fd[self.inputs_ph] = Xb # predict reprs_batch, _ = sess.run([self.layer_reprs, self.preds_eval], feed_dict=fd) @@ -507,7 +507,7 @@ def hidden(self, sess, batcher, layers=None): while Xb is not None: # update feed dict - fd[self.inputs] = Xb + fd[self.inputs_ph] = Xb # compute predictions layer_reprs_batch, preds_batch = sess.run( @@ -589,7 +589,7 @@ def _predict_ensemble(self, Xb_ensemble = hot1_augment(Xb, ensemble_fwdrc[ei], ensemble_shifts[ei]) # update feed dict - fd[self.inputs] = Xb_ensemble + fd[self.inputs_ph] = Xb_ensemble # for each monte carlo (or non-mc single) iteration for mi in range(mc_n): From cd9a8752b31dd7fc0414ae4e596ec2cb3ac38d1c Mon Sep 17 00:00:00 2001 From: David Kelley Date: Wed, 11 Jul 2018 14:34:10 -0700 Subject: [PATCH 14/71] 0 shift defaults --- basenji/seqnn.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/basenji/seqnn.py b/basenji/seqnn.py index 7f1d1fc5..fbc0c413 100755 --- a/basenji/seqnn.py +++ b/basenji/seqnn.py @@ -33,8 +33,8 @@ def __init__(self): self.global_step = tf.train.get_or_create_global_step() self.hparams_set = False - def build(self, job, augment_rc=False, augment_shifts=[], - ensemble_rc=False, ensemble_shifts=[], target_subset=None): + def build(self, job, augment_rc=False, augment_shifts=[0], + ensemble_rc=False, ensemble_shifts=[0], target_subset=None): """Build training ops that depend on placeholders.""" self.hp = params.make_hparams(job) From 558f53f0a49aa8b2bc40f738d26af03aa3f7f943 Mon Sep 17 00:00:00 2001 From: David Kelley Date: Wed, 11 Jul 2018 14:34:38 -0700 Subject: [PATCH 15/71] testing --- bin/basenji_test.py | 4 +- bin/basenji_test_h5.py | 624 +++++++++++++++++++++++++++++++++++++++++ bin/basenji_testq.py | 6 +- 3 files changed, 630 insertions(+), 4 deletions(-) create mode 100755 bin/basenji_test_h5.py diff --git a/bin/basenji_test.py b/bin/basenji_test.py index 6f7ed4ae..1555d719 100755 --- a/bin/basenji_test.py +++ b/bin/basenji_test.py @@ -255,8 +255,8 @@ def main(): # test t0 = time.time() - test_acc = dr.test(sess, batcher_test, rc=options.rc, - shifts=options.shifts, mc_n=options.mc_n) + test_acc = dr.test_h5_manual(sess, batcher_test, rc=options.rc, + shifts=options.shifts, mc_n=options.mc_n) if options.save: np.save('%s/preds.npy' % options.out_dir, test_acc.preds) diff --git a/bin/basenji_test_h5.py b/bin/basenji_test_h5.py new file mode 100755 index 00000000..7dd64484 --- /dev/null +++ b/bin/basenji_test_h5.py @@ -0,0 +1,624 @@ +#!/usr/bin/env python +# Copyright 2017 Calico LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# ========================================================================= + +from __future__ import print_function +from optparse import OptionParser +import os +import random +import sys +import time + +import h5py +import joblib +import matplotlib +matplotlib.use('PDF') +import matplotlib.pyplot as plt +import numpy as np +import pandas as pd +import pyBigWig +from scipy.stats import spearmanr, poisson +import seaborn as sns +from sklearn.linear_model import LinearRegression, LogisticRegression +from sklearn.metrics import roc_auc_score, roc_curve, precision_recall_curve, average_precision_score +import tensorflow as tf + +from basenji import batcher +from basenji import params +from basenji import plots +from basenji import seqnn + +""" +basenji_test.py + +Test the accuracy of a trained model. + +Notes + -This probably needs work for the pooled large sequence version. I tried to + update the "full" comparison, but it's not tested. The notion of peak calls + will need to completely change; we probably want to predict in each bin. +""" + +################################################################################ +# main +################################################################################ +def main(): + usage = 'usage: %prog [options] ' + parser = OptionParser(usage) + parser.add_option( + '--ai', + dest='accuracy_indexes', + help= + 'Comma-separated list of target indexes to make accuracy plots comparing true versus predicted values' + ) + parser.add_option( + '--clip', + dest='target_clip', + default=None, + type='float', + help='Clip targets and predictions to a maximum value [Default: %default]' + ) + parser.add_option( + '-d', + dest='down_sample', + default=1, + type='int', + help= + 'Down sample test computation by taking uniformly spaced positions [Default: %default]' + ) + parser.add_option( + '-g', + dest='genome_file', + default='%s/tutorials/data/human.hg19.genome' % os.environ['BASENJIDIR'], + help='Chromosome length information [Default: %default]') + parser.add_option( + '--mc', + dest='mc_n', + default=0, + type='int', + help='Monte carlo test iterations [Default: %default]') + parser.add_option( + '--peak','--peaks', + dest='peaks', + default=False, + action='store_true', + help='Compute expensive peak accuracy [Default: %default]') + parser.add_option( + '-o', + dest='out_dir', + default='test_out', + help='Output directory for test statistics [Default: %default]') + parser.add_option( + '--rc', + dest='rc', + default=False, + action='store_true', + help= + 'Average the fwd and rc predictions [Default: %default]') + parser.add_option( + '--sample', + dest='sample_pct', + default=1, + type='float', + help='Sample percentage') + parser.add_option( + '--save', + dest='save', + default=False, + action='store_true') + parser.add_option( + '--shifts', + dest='shifts', + default='0', + help='Ensemble prediction shifts [Default: %default]') + parser.add_option( + '-t', + dest='track_bed', + help='BED file describing regions so we can output BigWig tracks') + parser.add_option( + '--ti', + dest='track_indexes', + help='Comma-separated list of target indexes to output BigWig tracks') + parser.add_option( + '--train', + dest='train', + default=False, + action='store_true', + help='Process the training set [Default: %default]') + parser.add_option( + '-v', + dest='valid', + default=False, + action='store_true', + help='Process the validation set [Default: %default]') + parser.add_option( + '-w', + dest='pool_width', + default=1, + type='int', + help= + 'Max pool 
width for regressing nt predictions to predict peak calls [Default: %default]' + ) + (options, args) = parser.parse_args() + + if len(args) != 3: + parser.error('Must provide parameters, model, and test data HDF5') + else: + params_file = args[0] + model_file = args[1] + test_hdf5_file = args[2] + + if not os.path.isdir(options.out_dir): + os.mkdir(options.out_dir) + + options.shifts = [int(shift) for shift in options.shifts.split(',')] + + ####################################################### + # load data + ####################################################### + data_open = h5py.File(test_hdf5_file) + + if options.train: + test_seqs = data_open['train_in'] + test_targets = data_open['train_out'] + if 'train_na' in data_open: + test_na = data_open['train_na'] + + elif options.valid: + test_seqs = data_open['valid_in'] + test_targets = data_open['valid_out'] + test_na = None + if 'valid_na' in data_open: + test_na = data_open['valid_na'] + + else: + test_seqs = data_open['test_in'] + test_targets = data_open['test_out'] + test_na = None + if 'test_na' in data_open: + test_na = data_open['test_na'] + + if options.sample_pct < 1: + sample_n = int(test_seqs.shape[0]*options.sample_pct) + print('Sampling %d sequences' % sample_n) + sample_indexes = sorted(np.random.choice(np.arange(test_seqs.shape[0]), + size=sample_n, replace=False)) + test_seqs = test_seqs[sample_indexes] + test_targets = test_targets[sample_indexes] + if test_na is not None: + test_na = test_na[sample_indexes] + + target_labels = [tl.decode('UTF-8') for tl in data_open['target_labels']] + + ####################################################### + # model parameters and placeholders + + job = params.read_job_params(params_file) + + job['seq_length'] = test_seqs.shape[1] + job['seq_depth'] = test_seqs.shape[2] + job['num_targets'] = test_targets.shape[2] + job['target_pool'] = int(np.array(data_open.get('pool_width', 1))) + + t0 = time.time() + model = seqnn.SeqNN() + model.build(job, ensemble_rc=options.rc, ensemble_shifts=options.shifts) + print('Model building time %ds' % (time.time() - t0)) + + # adjust for fourier + job['fourier'] = 'train_out_imag' in data_open + if job['fourier']: + test_targets_imag = data_open['test_out_imag'] + if options.valid: + test_targets_imag = data_open['valid_out_imag'] + + ####################################################### + # test + + # initialize batcher + if job['fourier']: + batcher_test = batcher.BatcherF(test_seqs, test_targets, + test_targets_imag, test_na, + model.hp.batch_size, model.hp.target_pool) + else: + batcher_test = batcher.Batcher(test_seqs, test_targets, test_na, + model.hp.batch_size, model.hp.target_pool) + + # initialize saver + saver = tf.train.Saver() + + with tf.Session() as sess: + # load variables into session + saver.restore(sess, model_file) + + # test + t0 = time.time() + test_acc = model.test_h5(sess, batcher_test) + + if options.save: + np.save('%s/preds.npy' % options.out_dir, test_acc.preds) + np.save('%s/targets.npy' % options.out_dir, test_acc.targets) + + test_preds = test_acc.preds + print('SeqNN test: %ds' % (time.time() - t0)) + + # compute stats + t0 = time.time() + test_r2 = test_acc.r2(clip=options.target_clip) + # test_log_r2 = test_acc.r2(log=True, clip=options.target_clip) + test_pcor = test_acc.pearsonr(clip=options.target_clip) + test_log_pcor = test_acc.pearsonr(log=True, clip=options.target_clip) + #test_scor = test_acc.spearmanr() # too slow; mostly driven by low values + print('Compute stats: %ds' % (time.time()-t0)) + + # print 
+ print('Test Loss: %7.5f' % test_acc.loss) + print('Test R2: %7.5f' % test_r2.mean()) + # print('Test log R2: %7.5f' % test_log_r2.mean()) + print('Test PearsonR: %7.5f' % test_pcor.mean()) + print('Test log PearsonR: %7.5f' % test_log_pcor.mean()) + # print('Test SpearmanR: %7.5f' % test_scor.mean()) + + acc_out = open('%s/acc.txt' % options.out_dir, 'w') + for ti in range(len(test_r2)): + print( + '%4d %7.5f %.5f %.5f %.5f %s' % + (ti, test_acc.target_losses[ti], test_r2[ti], test_pcor[ti], + test_log_pcor[ti], target_labels[ti]), file=acc_out) + acc_out.close() + + # print normalization factors + target_means = test_preds.mean(axis=(0,1), dtype='float64') + target_means_median = np.median(target_means) + target_means /= target_means_median + norm_out = open('%s/normalization.txt' % options.out_dir, 'w') + print('\n'.join([str(tu) for tu in target_means]), file=norm_out) + norm_out.close() + + # clean up + del test_acc + + + ####################################################### + # peak call accuracy + + if options.peaks: + # sample every few bins to decrease correlations + ds_indexes_preds = np.arange(0, test_preds.shape[1], 8) + ds_indexes_targets = ds_indexes_preds + (model.hp.batch_buffer // model.hp.target_pool) + + aurocs = [] + auprcs = [] + + peaks_out = open('%s/peaks.txt' % options.out_dir, 'w') + for ti in range(test_targets.shape[2]): + test_targets_ti = test_targets[:, :, ti] + + # subset and flatten + test_targets_ti_flat = test_targets_ti[:, ds_indexes_targets].flatten( + ).astype('float32') + test_preds_ti_flat = test_preds[:, ds_indexes_preds, ti].flatten().astype( + 'float32') + + # call peaks + test_targets_ti_lambda = np.mean(test_targets_ti_flat) + test_targets_pvals = 1 - poisson.cdf( + np.round(test_targets_ti_flat) - 1, mu=test_targets_ti_lambda) + test_targets_qvals = np.array(ben_hoch(test_targets_pvals)) + test_targets_peaks = test_targets_qvals < 0.01 + + if test_targets_peaks.sum() == 0: + aurocs.append(0.5) + auprcs.append(0) + + else: + # compute prediction accuracy + aurocs.append(roc_auc_score(test_targets_peaks, test_preds_ti_flat)) + auprcs.append( + average_precision_score(test_targets_peaks, test_preds_ti_flat)) + + print('%4d %6d %.5f %.5f' % (ti, test_targets_peaks.sum(), + aurocs[-1], auprcs[-1]), + file=peaks_out) + + peaks_out.close() + + print('Test AUROC: %7.5f' % np.mean(aurocs)) + print('Test AUPRC: %7.5f' % np.mean(auprcs)) + + ####################################################### + # BigWig tracks + + # NOTE: THESE ASSUME THERE WAS NO DOWN-SAMPLING ABOVE + + # print bigwig tracks for visualization + if options.track_bed: + if options.genome_file is None: + parser.error('Must provide genome file in order to print valid BigWigs') + + if not os.path.isdir('%s/tracks' % options.out_dir): + os.mkdir('%s/tracks' % options.out_dir) + + track_indexes = range(test_preds.shape[2]) + if options.track_indexes: + track_indexes = [int(ti) for ti in options.track_indexes.split(',')] + + bed_set = 'test' + if options.valid: + bed_set = 'valid' + + for ti in track_indexes: + test_targets_ti = test_targets[:, :, ti] + + # make true targets bigwig + bw_file = '%s/tracks/t%d_true.bw' % (options.out_dir, ti) + bigwig_write( + bw_file, + test_targets_ti, + options.track_bed, + options.genome_file, + bed_set=bed_set) + + # make predictions bigwig + bw_file = '%s/tracks/t%d_preds.bw' % (options.out_dir, ti) + bigwig_write( + bw_file, + test_preds[:, :, ti], + options.track_bed, + options.genome_file, + model.hp.batch_buffer, + bed_set=bed_set) + + # make NA 
bigwig + # bw_file = '%s/tracks/na.bw' % options.out_dir + # bigwig_write( + # bw_file, + # test_na, + # options.track_bed, + # options.genome_file, + # bed_set=bed_set) + + ####################################################### + # accuracy plots + + if options.accuracy_indexes is not None: + accuracy_indexes = [int(ti) for ti in options.accuracy_indexes.split(',')] + + if not os.path.isdir('%s/scatter' % options.out_dir): + os.mkdir('%s/scatter' % options.out_dir) + + if not os.path.isdir('%s/violin' % options.out_dir): + os.mkdir('%s/violin' % options.out_dir) + + if not os.path.isdir('%s/roc' % options.out_dir): + os.mkdir('%s/roc' % options.out_dir) + + if not os.path.isdir('%s/pr' % options.out_dir): + os.mkdir('%s/pr' % options.out_dir) + + for ti in accuracy_indexes: + test_targets_ti = test_targets[:, :, ti] + + ############################################ + # scatter + + # sample every few bins (adjust to plot the # points I want) + ds_indexes_preds = np.arange(0, test_preds.shape[1], 8) + ds_indexes_targets = ds_indexes_preds + ( + model.hp.batch_buffer // model.hp.target_pool) + + # subset and flatten + test_targets_ti_flat = test_targets_ti[:, ds_indexes_targets].flatten( + ).astype('float32') + test_preds_ti_flat = test_preds[:, ds_indexes_preds, ti].flatten().astype( + 'float32') + + # take log2 + test_targets_ti_log = np.log2(test_targets_ti_flat + 1) + test_preds_ti_log = np.log2(test_preds_ti_flat + 1) + + # plot log2 + sns.set(font_scale=1.2, style='ticks') + out_pdf = '%s/scatter/t%d.pdf' % (options.out_dir, ti) + plots.regplot( + test_targets_ti_log, + test_preds_ti_log, + out_pdf, + poly_order=1, + alpha=0.3, + sample=500, + figsize=(6, 6), + x_label='log2 Experiment', + y_label='log2 Prediction', + table=True) + + ############################################ + # violin + + # call peaks + test_targets_ti_lambda = np.mean(test_targets_ti_flat) + test_targets_pvals = 1 - poisson.cdf( + np.round(test_targets_ti_flat) - 1, mu=test_targets_ti_lambda) + test_targets_qvals = np.array(ben_hoch(test_targets_pvals)) + test_targets_peaks = test_targets_qvals < 0.01 + test_targets_peaks_str = np.where(test_targets_peaks, 'Peak', + 'Background') + + # violin plot + sns.set(font_scale=1.3, style='ticks') + plt.figure() + df = pd.DataFrame({ + 'log2 Prediction': np.log2(test_preds_ti_flat + 1), + 'Experimental coverage status': test_targets_peaks_str + }) + ax = sns.violinplot( + x='Experimental coverage status', y='log2 Prediction', data=df) + ax.grid(True, linestyle=':') + plt.savefig('%s/violin/t%d.pdf' % (options.out_dir, ti)) + plt.close() + + # ROC + plt.figure() + fpr, tpr, _ = roc_curve(test_targets_peaks, test_preds_ti_flat) + auroc = roc_auc_score(test_targets_peaks, test_preds_ti_flat) + plt.plot( + [0, 1], [0, 1], c='black', linewidth=1, linestyle='--', alpha=0.7) + plt.plot(fpr, tpr, c='black') + ax = plt.gca() + ax.set_xlabel('False positive rate') + ax.set_ylabel('True positive rate') + ax.text( + 0.99, 0.02, 'AUROC %.3f' % auroc, + horizontalalignment='right') # , fontsize=14) + ax.grid(True, linestyle=':') + plt.savefig('%s/roc/t%d.pdf' % (options.out_dir, ti)) + plt.close() + + # PR + plt.figure() + prec, recall, _ = precision_recall_curve(test_targets_peaks, + test_preds_ti_flat) + auprc = average_precision_score(test_targets_peaks, test_preds_ti_flat) + plt.axhline( + y=test_targets_peaks.mean(), + c='black', + linewidth=1, + linestyle='--', + alpha=0.7) + plt.plot(recall, prec, c='black') + ax = plt.gca() + ax.set_xlabel('Recall') + ax.set_ylabel('Precision') + 
ax.text( + 0.99, 0.95, 'AUPRC %.3f' % auprc, + horizontalalignment='right') # , fontsize=14) + ax.grid(True, linestyle=':') + plt.savefig('%s/pr/t%d.pdf' % (options.out_dir, ti)) + plt.close() + + data_open.close() + + +def ben_hoch(p_values): + """ Convert the given p-values to q-values using Benjamini-Hochberg FDR. """ + m = len(p_values) + + # attach original indexes to p-values + p_k = [(p_values[k], k) for k in range(m)] + + # sort by p-value + p_k.sort() + + # compute q-value and attach original index to front + k_q = [(p_k[i][1], p_k[i][0] * m // (i + 1)) for i in range(m)] + + # re-sort by original index + k_q.sort() + + # drop original indexes + q_values = [k_q[k][1] for k in range(m)] + + return q_values + + +def bigwig_open(bw_file, genome_file): + """ Open the bigwig file for writing and write the header. """ + + bw_out = pyBigWig.open(bw_file, 'w') + + chrom_sizes = [] + for line in open(genome_file): + a = line.split() + chrom_sizes.append((a[0], int(a[1]))) + + bw_out.addHeader(chrom_sizes) + + return bw_out + + +def bigwig_write(bw_file, + signal_ti, + track_bed, + genome_file, + buffer=0, + bed_set='test'): + """ Write a signal track to a BigWig file over the regions + specified by track_bed. + + Args + bw_file: BigWig filename + signal_ti: Sequences X Length array for some target + track_bed: BED file specifying sequence coordinates + genome_file: Chromosome lengths file + buffer: Length skipped on each side of the region. + """ + + bw_out = bigwig_open(bw_file, genome_file) + + si = 0 + bw_hash = {} + + # set entries + for line in open(track_bed): + a = line.split() + if a[3] == bed_set: + chrom = a[0] + start = int(a[1]) + end = int(a[2]) + + preds_pool = (end - start - 2 * buffer) // signal_ti.shape[1] + + bw_start = start + buffer + for li in range(signal_ti.shape[1]): + bw_end = bw_start + preds_pool + bw_hash.setdefault((chrom,bw_start,bw_end),[]).append(signal_ti[si,li]) + bw_start = bw_end + + si += 1 + + # average duplicates + bw_entries = [] + for bw_key in bw_hash: + bw_signal = np.mean(bw_hash[bw_key]) + bwe = tuple(list(bw_key)+[bw_signal]) + bw_entries.append(bwe) + + # sort entries + bw_entries.sort() + + # add entries + for line in open(genome_file): + chrom = line.split()[0] + + bw_entries_chroms = [be[0] for be in bw_entries if be[0] == chrom] + bw_entries_starts = [be[1] for be in bw_entries if be[0] == chrom] + bw_entries_ends = [be[2] for be in bw_entries if be[0] == chrom] + bw_entries_values = [float(be[3]) for be in bw_entries if be[0] == chrom] + + if len(bw_entries_chroms) > 0: + bw_out.addEntries( + bw_entries_chroms, + bw_entries_starts, + ends=bw_entries_ends, + values=bw_entries_values) + + bw_out.close() + + +################################################################################ +# __main__ +################################################################################ +if __name__ == '__main__': + main() diff --git a/bin/basenji_testq.py b/bin/basenji_testq.py index 5af5614d..578d7a19 100755 --- a/bin/basenji_testq.py +++ b/bin/basenji_testq.py @@ -120,7 +120,9 @@ def main(): # initialize model model = seqnn.SeqNN() - model.build_from_data_ops(job, data_ops) + model.build_from_data_ops(job, data_ops, + ensemble_rc=options.rc, + ensemble_shifts=options.shifts) # initialize saver saver = tf.train.Saver() @@ -136,7 +138,7 @@ def main(): # test t0 = time.time() sess.run(test_init_op) - test_acc = model.test_from_data_ops(sess) + test_acc = model.test_tfr(sess) test_preds = test_acc.preds print('SeqNN test: %ds' % (time.time() - t0)) 
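The preceding patches converge on a single in-graph evaluation path: the model is built once with ensemble_rc/ensemble_shifts baked into preds_eval, and the scripts then call test_h5 (HDF5 batchers) or test_tfr (TFRecord iterators) without passing rc/shift arguments, while test_h5_manual keeps the old feed-dict ensembling for callers that have not migrated. Condensed to its essentials, the HDF5 path added in bin/basenji_test_h5.py looks roughly like the sketch below; the file paths, the shift list, and the omitted job bookkeeping are placeholders, not part of the patch.

import h5py
import tensorflow as tf
from basenji import batcher, params, seqnn

job = params.read_job_params(params_file)        # placeholder path
data_open = h5py.File(test_hdf5_file)            # placeholder path
test_seqs = data_open['test_in']
test_targets = data_open['test_out']
# job['seq_length'], job['seq_depth'], job['num_targets'], job['target_pool']
# are filled in from the data shapes, as in the full script

# build the graph once; rc/shift averaging now lives inside preds_eval
model = seqnn.SeqNN()
model.build(job, ensemble_rc=True, ensemble_shifts=[-1, 0, 1])

batcher_test = batcher.Batcher(test_seqs, test_targets, None,
                               model.hp.batch_size, model.hp.target_pool)

saver = tf.train.Saver()
with tf.Session() as sess:
  saver.restore(sess, model_file)                # placeholder checkpoint
  test_acc = model.test_h5(sess, batcher_test)   # no rc/shifts arguments here
  print('loss %.5f  R2 %.5f' % (test_acc.loss, test_acc.r2().mean()))
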
From 03fad02b79792909a2f4c3ee9a840c07ff1f60af Mon Sep 17 00:00:00 2001 From: David Kelley Date: Wed, 11 Jul 2018 15:07:00 -0700 Subject: [PATCH 16/71] target labels --- bin/basenji_testq.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/bin/basenji_testq.py b/bin/basenji_testq.py index 578d7a19..fe29057e 100755 --- a/bin/basenji_testq.py +++ b/bin/basenji_testq.py @@ -169,7 +169,7 @@ def main(): print( '%4d %7.5f %.5f %.5f %.5f %s' % (ti, test_acc.target_losses[ti], test_r2[ti], test_pcor[ti], - test_log_pcor[ti], target_labels[ti]), file=acc_out) + test_log_pcor[ti], targets_df.description.iloc[ti]), file=acc_out) acc_out.close() # print normalization factors From bcef7ab9bb41ccf064fd9f133fa267ba9278a35e Mon Sep 17 00:00:00 2001 From: David Kelley Date: Wed, 11 Jul 2018 15:07:34 -0700 Subject: [PATCH 17/71] float64 loss mean --- basenji/seqnn_util.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/basenji/seqnn_util.py b/basenji/seqnn_util.py index 2df94556..b3c9bea9 100644 --- a/basenji/seqnn_util.py +++ b/basenji/seqnn_util.py @@ -886,8 +886,8 @@ def test_tfr(self, sess, test_batches=None): targets_na = np.concatenate(targets_na, axis=0) # mean across batches - batch_losses = np.mean(batch_losses) - batch_target_losses = np.array(batch_target_losses).mean(axis=0) + batch_losses = np.mean(batch_losses, dtype='float64') + batch_target_losses = np.array(batch_target_losses).mean(axis=0, dtype='float64') # instantiate accuracy object acc = accuracy.Accuracy(targets, preds, targets_na, @@ -956,8 +956,8 @@ def test_h5(self, sess, batcher, test_batches=None): targets_na = np.concatenate(targets_na, axis=0) # mean across batches - batch_losses = np.mean(batch_losses) - batch_target_losses = np.array(batch_target_losses).mean(axis=0) + batch_losses = np.mean(batch_losses, dtype='float64') + batch_target_losses = np.array(batch_target_losses).mean(axis=0, dtype='float64') # instantiate accuracy object acc = accuracy.Accuracy(targets, preds, targets_na, From 19e110bd443dca5d97cbc80fa28bc8a44587a0f5 Mon Sep 17 00:00:00 2001 From: David Kelley Date: Wed, 11 Jul 2018 15:09:02 -0700 Subject: [PATCH 18/71] default shift 0 --- basenji/seqnn.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/basenji/seqnn.py b/basenji/seqnn.py index fbc0c413..d8c2f282 100755 --- a/basenji/seqnn.py +++ b/basenji/seqnn.py @@ -49,8 +49,8 @@ def build(self, job, augment_rc=False, augment_shifts=[0], target_subset=target_subset) def build_from_data_ops(self, job, data_ops, - augment_rc=False, augment_shifts=[], - ensemble_rc=False, ensemble_shifts=[], + augment_rc=False, augment_shifts=[0], + ensemble_rc=False, ensemble_shifts=[0], target_subset=None): """Build training ops from input data ops.""" if not self.hparams_set: From 23d2f70bf6186c0812a57083a5ac524d0e81fcff Mon Sep 17 00:00:00 2001 From: David Kelley Date: Wed, 11 Jul 2018 15:10:18 -0700 Subject: [PATCH 19/71] no data open --- bin/basenji_testq.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/bin/basenji_testq.py b/bin/basenji_testq.py index fe29057e..c6c83b99 100755 --- a/bin/basenji_testq.py +++ b/bin/basenji_testq.py @@ -400,8 +400,6 @@ def main(): plt.savefig('%s/pr/t%d.pdf' % (options.out_dir, ti)) plt.close() - data_open.close() - def ben_hoch(p_values): """ Convert the given p-values to q-values using Benjamini-Hochberg FDR. 
""" From b4e5df581fdd7080a751e20b009b0ff73c186191 Mon Sep 17 00:00:00 2001 From: David Kelley Date: Thu, 12 Jul 2018 06:39:57 -0700 Subject: [PATCH 20/71] create new data_ops dict --- basenji/augmentation.py | 15 ++++++++++----- 1 file changed, 10 insertions(+), 5 deletions(-) diff --git a/basenji/augmentation.py b/basenji/augmentation.py index 340c313e..7b7fb437 100644 --- a/basenji/augmentation.py +++ b/basenji/augmentation.py @@ -79,16 +79,21 @@ def augment_deterministic(data_ops, augment_rc=False, augment_shift=0): Returns data_ops: augmented data """ - if augment_shift != 0: + + data_ops_aug = {'label': data_ops['label'], 'na': data_ops['na']} + + if augment_shift == 0: + data_ops_aug['sequence'] = data_ops['sequence'] + else: shift_amount = tf.constant(augment_shift, shape=(), dtype=tf.int64) - data_ops['sequence'] = shift_sequence(data_ops['sequence'], shift_amount) + data_ops_aug['sequence'] = shift_sequence(data_ops['sequence'], shift_amount) if augment_rc: - data_ops = augment_deterministic_rc(data_ops) + data_ops_aug = augment_deterministic_rc(data_ops_aug) else: - data_ops['reverse_preds'] = tf.zeros((), dtype=tf.bool) + data_ops_aug['reverse_preds'] = tf.zeros((), dtype=tf.bool) - return data_ops + return data_ops_aug def augment_deterministic_rc(data_ops): From aaa46fa60a3e314e3e0dbd58d846af423b9795a6 Mon Sep 17 00:00:00 2001 From: David Kelley Date: Thu, 12 Jul 2018 06:40:48 -0700 Subject: [PATCH 21/71] average predictions, not representations --- basenji/seqnn.py | 132 ++++++++++++++++++++++------------------------- 1 file changed, 62 insertions(+), 70 deletions(-) diff --git a/basenji/seqnn.py b/basenji/seqnn.py index d8c2f282..f9db82df 100755 --- a/basenji/seqnn.py +++ b/basenji/seqnn.py @@ -68,12 +68,12 @@ def build_from_data_ops(self, job, data_ops, data_ops, augment_rc, augment_shifts) # compute train representation - seqs_repr_train = self.build_representation(data_ops_train['sequence'], - None, target_subset) + self.preds_train = self.build_predict(data_ops_train['sequence'], + None, target_subset) # training losses - loss_returns = self.build_loss(seqs_repr_train, data_ops_train, target_subset) - self.loss_train, self.loss_train_targets, self.preds_train, self.targets_train = loss_returns + loss_returns = self.build_loss(self.preds_train, data_ops_train['label'], target_subset) + self.loss_train, self.loss_train_targets, self.targets_train = loss_returns # optimizer self.update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS) @@ -90,13 +90,13 @@ def build_from_data_ops(self, job, data_ops, # compute eval representation map_elems_eval = (data_seq_eval, data_rev_eval) - build_rep = lambda do: self.build_representation(do[0], do[1], target_subset) - seqs_repr_list = tf.map_fn(build_rep, map_elems_eval, dtype=seqs_repr_train.dtype) # back_prop=False - seqs_repr_eval = tf.reduce_mean(seqs_repr_list, axis=0) + build_rep = lambda do: self.build_predict(do[0], do[1], target_subset) + self.preds_ensemble = tf.map_fn(build_rep, map_elems_eval, dtype=self.preds_train.dtype) # back_prop=False + self.preds_eval = tf.reduce_mean(self.preds_ensemble, axis=0) # eval loss - loss_returns = self.build_loss(seqs_repr_eval, data_ops, target_subset) - self.loss_eval, self.loss_eval_targets, self.preds_eval, self.targets_eval = loss_returns + loss_returns = self.build_loss(self.preds_eval, data_ops['label'], target_subset) + self.loss_eval, self.loss_eval_targets, self.targets_eval = loss_returns # helper variables self.preds_length = self.preds_train.shape[1] @@ -148,7 +148,7 @@ def 
_make_conv_block_args(self, layer_index, layer_reprs): 'name': 'conv-%d' % layer_index } - def build_representation(self, inputs, reverse_preds=None, target_subset=None): + def build_predict(self, inputs, reverse_preds=None, target_subset=None): """Construct per-location real-valued predictions.""" assert inputs is not None print('Targets pooled by %d to length %d' % @@ -235,7 +235,52 @@ def build_representation(self, inputs, reverse_preds=None, target_subset=None): lambda: tf.reverse(final_repr, axis=[1]), lambda: final_repr) - return final_repr + ################################################### + # link function + ################################################### + + # work-around for specifying my own predictions + # self.preds_adhoc = tf.placeholder( + # tf.float32, shape=final_repr.shape, name='preds-adhoc') + + # float 32 exponential clip max + exp_max = 50 + + # choose link + if self.hp.link in ['identity', 'linear']: + predictions = tf.identity(final_repr, name='preds') + + elif self.hp.link == 'relu': + predictions = tf.relu(final_repr, name='preds') + + elif self.hp.link == 'exp': + final_repr_clip = tf.clip_by_value(final_repr, -exp_max, exp_max) + predictions = tf.exp(final_repr_clip, name='preds') + + elif self.hp.link == 'exp_linear': + predictions = tf.where( + final_repr > 0, + final_repr + 1, + tf.exp(tf.clip_by_value(final_repr, -exp_max, exp_max)), + name='preds') + + elif self.hp.link == 'softplus': + final_repr_clip = tf.clip_by_value(final_repr, -exp_max, 10000) + predictions = tf.nn.softplus(final_repr_clip, name='preds') + + else: + print('Unknown link function %s' % self.hp.link, file=sys.stderr) + exit(1) + + # clip + if self.hp.target_clip is not None: + predictions = tf.clip_by_value(predictions, 0, self.hp.target_clip) + + # sqrt + if self.hp.target_sqrt: + predictions = tf.sqrt(predictions) + + return predictions def build_optimizer(self, loss_op): """Construct optimization op that minimizes loss_op.""" @@ -290,86 +335,38 @@ def build_optimizer(self, loss_op): self.merged_summary = tf.summary.merge_all() - def build_loss(self, seqs_repr, data_ops, target_subset=None): + def build_loss(self, preds, targets, target_subset=None): """Convert per-location real-valued predictions to a loss.""" - # targets + # slice buffer tstart = self.hp.batch_buffer // self.hp.target_pool tend = (self.hp.seq_length - self.hp.batch_buffer) // self.hp.target_pool - - targets = data_ops['label'] targets = tf.identity(targets[:, tstart:tend, :], name='targets_op') if target_subset is not None: targets = tf.gather(targets, target_subset, axis=2) - # work-around for specifying my own predictions - # self.preds_adhoc = tf.placeholder( - # tf.float32, shape=seqs_repr.shape, name='preds-adhoc') - - # float 32 exponential clip max - # exp_max = np.floor(np.log(0.5*tf.float32.max)) - exp_max = 50 - - # choose link - if self.hp.link in ['identity', 'linear']: - preds_op = tf.identity(seqs_repr, name='preds') - - elif self.hp.link == 'relu': - preds_op = tf.relu(seqs_repr, name='preds') - - elif self.hp.link == 'exp': - seqs_repr_clip = tf.clip_by_value(seqs_repr, -exp_max, exp_max) - preds_op = tf.exp(seqs_repr_clip, name='preds') - - elif self.hp.link == 'exp_linear': - preds_op = tf.where( - seqs_repr > 0, - seqs_repr + 1, - tf.exp(tf.clip_by_value(seqs_repr, -exp_max, exp_max)), - name='preds') - - elif self.hp.link == 'softplus': - seqs_repr_clip = tf.clip_by_value(seqs_repr, -exp_max, 10000) - preds_op = tf.nn.softplus(seqs_repr_clip, name='preds') - - elif self.hp.link == 
'softmax': - # performed in the loss function, but saving probabilities - self.preds_prob = tf.nn.softmax(seqs_repr, name='preds') - - else: - print('Unknown link function %s' % self.hp.link, file=sys.stderr) - exit(1) - # clip if self.hp.target_clip is not None: - preds_op = tf.clip_by_value(preds_op, 0, self.hp.target_clip) targets = tf.clip_by_value(targets, 0, self.hp.target_clip) # sqrt if self.hp.target_sqrt: - preds_op = tf.sqrt(preds_op) targets = tf.sqrt(targets) loss_op = None - # loss_adhoc = None # choose loss if self.hp.loss == 'gaussian': - loss_op = tf.squared_difference(preds_op, targets) - # loss_adhoc = tf.squared_difference(self.preds_adhoc, targets) + loss_op = tf.squared_difference(preds, targets) elif self.hp.loss == 'poisson': loss_op = tf.nn.log_poisson_loss( - targets, tf.log(preds_op), compute_full_loss=True) - # loss_adhoc = tf.nn.log_poisson_loss( - # targets, tf.log(self.preds_adhoc), compute_full_loss=True) + targets, tf.log(preds), compute_full_loss=True) elif self.hp.loss == 'cross_entropy': loss_op = tf.nn.sparse_softmax_cross_entropy_with_logits( - labels=(targets - 1), logits=preds_op) - # loss_adhoc = tf.nn.sparse_softmax_cross_entropy_with_logits( - # labels=(targets - 1), logits=self.preds_adhoc) + labels=(targets - 1), logits=preds) else: raise ValueError('Cannot identify loss function %s' % self.hp.loss) @@ -378,29 +375,24 @@ def build_loss(self, seqs_repr, data_ops, target_subset=None): loss_op = tf.reduce_mean(loss_op, axis=[0, 1], name='target_loss') loss_op = tf.check_numerics(loss_op, 'Invalid loss', name='loss_check') - # loss_adhoc = tf.reduce_mean( - # loss_adhoc, axis=[0, 1], name='target_loss_adhoc') tf.summary.histogram('target_loss', loss_op) for ti in np.linspace(0, self.hp.num_targets - 1, 10).astype('int'): tf.summary.scalar('loss_t%d' % ti, loss_op[ti]) target_losses = loss_op - # self.target_losses_adhoc = loss_adhoc # fully reduce loss_op = tf.reduce_mean(loss_op, name='loss') - # loss_adhoc = tf.reduce_mean(loss_adhoc, name='loss_adhoc') # add regularization terms reg_losses = tf.get_collection(tf.GraphKeys.REGULARIZATION_LOSSES) reg_sum = tf.reduce_sum(reg_losses) tf.summary.scalar('regularizers', reg_sum) loss_op += reg_sum - # loss_adhoc += reg_sum # track tf.summary.scalar('loss', loss_op) - return loss_op, target_losses, preds_op, targets + return loss_op, target_losses, targets def set_mode(self, mode): From cd77912075bde39b5f531d631140a767ea19a5d8 Mon Sep 17 00:00:00 2001 From: David Kelley Date: Fri, 20 Jul 2018 16:36:02 -0700 Subject: [PATCH 22/71] rename build --- basenji/seqnn.py | 2 +- bin/basenji_sad.py | 2 +- bin/basenji_sat.py | 2 +- bin/basenji_sat_vcf.py | 2 +- bin/basenji_sed.py | 2 +- bin/basenji_test.py | 2 +- bin/basenji_test_genes.py | 2 +- bin/basenji_test_h5.py | 2 +- bin/basenji_test_reps.py | 2 +- bin/basenji_train.py | 2 +- bin/basenji_train_h5.py | 2 +- 11 files changed, 11 insertions(+), 11 deletions(-) diff --git a/basenji/seqnn.py b/basenji/seqnn.py index f9db82df..d58ed141 100755 --- a/basenji/seqnn.py +++ b/basenji/seqnn.py @@ -33,7 +33,7 @@ def __init__(self): self.global_step = tf.train.get_or_create_global_step() self.hparams_set = False - def build(self, job, augment_rc=False, augment_shifts=[0], + def build_feed(self, job, augment_rc=False, augment_shifts=[0], ensemble_rc=False, ensemble_shifts=[0], target_subset=None): """Build training ops that depend on placeholders.""" diff --git a/bin/basenji_sad.py b/bin/basenji_sad.py index 366e4693..b7250f06 100755 --- a/bin/basenji_sad.py +++ 
b/bin/basenji_sad.py @@ -173,7 +173,7 @@ def main(): # build model t0 = time.time() model = basenji.seqnn.SeqNN() - model.build(job, target_subset=target_subset) + model.build_feed(job, target_subset=target_subset) print('Model building time %f' % (time.time() - t0), flush=True) if options.penultimate: diff --git a/bin/basenji_sat.py b/bin/basenji_sat.py index f5b3602f..4b27c0a8 100755 --- a/bin/basenji_sat.py +++ b/bin/basenji_sat.py @@ -189,7 +189,7 @@ def main(): t0 = time.time() dr = seqnn.SeqNN() - dr.build(job, target_subset=target_subset) + dr.build_feed(job, target_subset=target_subset) print('Model building time %f' % (time.time() - t0), flush=True) if options.batch_size is not None: diff --git a/bin/basenji_sat_vcf.py b/bin/basenji_sat_vcf.py index c0236b76..de8e610c 100755 --- a/bin/basenji_sat_vcf.py +++ b/bin/basenji_sat_vcf.py @@ -180,7 +180,7 @@ def main(): # build model dr = basenji.seqnn.SeqNN() - dr.build(job) + dr.build_feed(job) # initialize saver saver = tf.train.Saver() diff --git a/bin/basenji_sed.py b/bin/basenji_sed.py index 8e25e3eb..7956934c 100755 --- a/bin/basenji_sed.py +++ b/bin/basenji_sed.py @@ -230,7 +230,7 @@ def main(): # build model model = seqnn.SeqNN() - model.build(job, target_subset=target_subset) + model.build_feed(job, target_subset=target_subset) if options.penultimate: # labels become inappropriate diff --git a/bin/basenji_test.py b/bin/basenji_test.py index 1555d719..7750ea62 100755 --- a/bin/basenji_test.py +++ b/bin/basenji_test.py @@ -218,7 +218,7 @@ def main(): t0 = time.time() dr = seqnn.SeqNN() - dr.build(job) + dr.build_feed(job) print('Model building time %ds' % (time.time() - t0)) # adjust for fourier diff --git a/bin/basenji_test_genes.py b/bin/basenji_test_genes.py index b78d688a..c72eb0ef 100755 --- a/bin/basenji_test_genes.py +++ b/bin/basenji_test_genes.py @@ -181,7 +181,7 @@ def main(): # build model model = seqnn.SeqNN() - model.build(job) + model.build_feed(job) if options.batch_size is not None: model.hp.batch_size = options.batch_size diff --git a/bin/basenji_test_h5.py b/bin/basenji_test_h5.py index 7dd64484..56810cdb 100755 --- a/bin/basenji_test_h5.py +++ b/bin/basenji_test_h5.py @@ -214,7 +214,7 @@ def main(): t0 = time.time() model = seqnn.SeqNN() - model.build(job, ensemble_rc=options.rc, ensemble_shifts=options.shifts) + model.build_feed(job, ensemble_rc=options.rc, ensemble_shifts=options.shifts) print('Model building time %ds' % (time.time() - t0)) # adjust for fourier diff --git a/bin/basenji_test_reps.py b/bin/basenji_test_reps.py index ea7ce82e..3a180e05 100755 --- a/bin/basenji_test_reps.py +++ b/bin/basenji_test_reps.py @@ -136,7 +136,7 @@ def main(): job['target_pool'] = int(np.array(data_open.get('pool_width', 1))) dr = basenji.seqnn.SeqNN() - dr.build(job) + dr.build_feed(job) # adjust for fourier job['fourier'] = 'train_out_imag' in data_open diff --git a/bin/basenji_train.py b/bin/basenji_train.py index f8154e23..6a540d49 100755 --- a/bin/basenji_train.py +++ b/bin/basenji_train.py @@ -73,7 +73,7 @@ def run(params_file, data_file, train_epochs, train_epoch_batches, test_epoch_ba t0 = time.time() model = seqnn.SeqNN() - model.build(job) + model.build_feed(job) print('Model building time %f' % (time.time() - t0)) # adjust for fourier diff --git a/bin/basenji_train_h5.py b/bin/basenji_train_h5.py index 875ee3af..b6cbb87f 100755 --- a/bin/basenji_train_h5.py +++ b/bin/basenji_train_h5.py @@ -76,7 +76,7 @@ def run(params_file, data_file, train_epochs, train_epoch_batches, test_epoch_ba t0 = time.time() model = 
seqnn.SeqNN() - model.build(job, augment_rc=FLAGS.augment_rc, augment_shifts=augment_shifts, + model.build_feed(job, augment_rc=FLAGS.augment_rc, augment_shifts=augment_shifts, ensemble_rc=FLAGS.ensemble_rc, ensemble_shifts=ensemble_shifts) print('Model building time %f' % (time.time() - t0)) From df296153dc1b90871705ea6f2072bbb0b9d9c70f Mon Sep 17 00:00:00 2001 From: David Kelley Date: Fri, 20 Jul 2018 18:56:12 -0700 Subject: [PATCH 23/71] predict in-graph ensembling --- basenji/seqnn.py | 2 + basenji/seqnn_util.py | 160 ++++++++++++++++++++++++++++++++++++++---- bin/basenji_sad.py | 13 ++-- 3 files changed, 156 insertions(+), 19 deletions(-) diff --git a/basenji/seqnn.py b/basenji/seqnn.py index d58ed141..873b8978 100755 --- a/basenji/seqnn.py +++ b/basenji/seqnn.py @@ -341,6 +341,8 @@ def build_loss(self, preds, targets, target_subset=None): # slice buffer tstart = self.hp.batch_buffer // self.hp.target_pool tend = (self.hp.seq_length - self.hp.batch_buffer) // self.hp.target_pool + self.target_length = tend - tstart + targets = tf.identity(targets[:, tstart:tend, :], name='targets_op') if target_subset is not None: diff --git a/basenji/seqnn_util.py b/basenji/seqnn_util.py index b3c9bea9..d4f3ab78 100644 --- a/basenji/seqnn_util.py +++ b/basenji/seqnn_util.py @@ -631,19 +631,12 @@ def _predict_ensemble(self, return preds_batch, preds_batch_var, preds_all - def predict(self, - sess, - batcher, - rc=False, - shifts=[0], - mc_n=0, - target_indexes=None, - return_var=False, - return_all=False, - down_sample=1, - penultimate=False, - test_batches=None, - dtype='float32'): + def predict_h5_manual(self, sess, batcher, + rc=False, shifts=[0], mc_n=0, + target_indexes=None, + return_var=False, return_all=False, + down_sample=1, penultimate=False, + test_batches=None, dtype='float32'): """ Compute predictions on a test set. In @@ -765,6 +758,146 @@ def predict(self, else: return preds + def predict_h5(self, sess, batcher, + return_var=False, return_all=False, + penultimate=False, test_batches=None): + """ Compute preidctions on an HDF5 test set. + + Args: + sess: TensorFlow session + return_var: Return variance estimates + return_all: Retyrn all predictions. + penultimate: Predict the penultimate layer. + test_batches: Number of test batches to use. 
+ + Returns: + preds: S (sequences) x L (unbuffered length) x T (targets) array + """ + fd = self.set_mode('test') + + # initialize prediction data structures + preds = [] + if return_var: + preds_var = [] + if return_all: + preds_all = [] + + # get first batch + batch_num = 0 + Xb, _, _, Nb = batcher.next() + + while Xb is not None and (test_batches is None or + batch_num < test_batches): + # update feed dict + fd[self.inputs_ph] = Xb + + # make predictions + if return_var or return_all: + preds_batch, preds_ensemble_batch = sess.run([self.preds_eval, self.preds_ensemble], feed_dict=fd) + + # move ensemble to back + preds_ensemble_batch = np.moveaxis(preds_ensemble_batch, 0, -1) + + else: + preds_batch = sess.run(self.preds_eval, feed_dict=fd) + + # accumulate predictions and targets + preds.append(preds_batch[:Nb]) + if return_var: + preds_var_batch = np.var(preds_ensemble_batch, axis=-1) + preds_var.append(preds_var_batch[:Nb]) + if return_all: + preds_all.append(preds_ensemble_batch[:Nb]) + + # next batch + batch_num += 1 + Xb, _, _, Nb = batcher.next() + + # reset batcher + batcher.reset() + + # construct arrays + preds = np.concatenate(preds, axis=0) + if return_var: + preds_var = np.concatenate(preds_var, axis=0) + if return_all: + preds_all = np.concatenate(preds_all, axis=0) + + if return_var: + if return_all: + return preds, preds_var, preds_all + else: + return preds, preds_var + else: + return preds + + def predict_tfr(self, sess, + return_var=False, return_all=False, + penultimate=False, test_batches=None): + """ Compute preidctions on a TFRecord test set. + + Args: + sess: TensorFlow session + return_var: Return variance estimates + return_all: Retyrn all predictions. + penultimate: Predict the penultimate layer. + test_batches: Number of test batches to use. + + Returns: + preds: S (sequences) x L (unbuffered length) x T (targets) array + """ + fd = self.set_mode('test') + + # initialize prediction data structures + preds = [] + if return_var: + preds_var = [] + if return_all: + preds_all = [] + + # sequence index + data_available = True + batch_num = 0 + while data_available and (test_batches is None or batch_num < test_batches): + try: + # make predictions + if return_var or return_all: + preds_batch, preds_ensemble_batch = sess.run([self.preds_eval, self.preds_ensemble], feed_dict=fd) + + # move ensemble to back + preds_ensemble_batch = np.moveaxis(preds_ensemble_batch, 0, -1) + + else: + preds_batch = sess.run(self.preds_eval, feed_dict=fd) + + # accumulate predictions and targets + preds.append(preds_batch) + if return_var: + preds_var_batch = np.var(preds_ensemble_batch, axis=-1) + preds_var.append(preds_var_batch) + if return_all: + preds_all.append(preds_ensemble_batch) + + batch_num += 1 + + except tf.errors.OutOfRangeError: + data_available = False + + # construct arrays + preds = np.concatenate(preds, axis=0) + if return_var: + preds_var = np.concatenate(preds_var, axis=0) + if return_all: + preds_all = np.concatenate(preds_all, axis=0) + + if return_var: + if return_all: + return preds, preds_var, preds_all + else: + return preds, preds_var + else: + return preds + def predict_genes(self, sess, batcher, @@ -901,7 +1034,6 @@ def test_h5(self, sess, batcher, test_batches=None): Args: sess: TensorFlow session batcher: Batcher object to provide data - mc_n: Monte Carlo iterations per rc/shift. 
test_batches: Number of test batches Returns: diff --git a/bin/basenji_sad.py b/bin/basenji_sad.py index b7250f06..2cd7a088 100755 --- a/bin/basenji_sad.py +++ b/bin/basenji_sad.py @@ -173,7 +173,9 @@ def main(): # build model t0 = time.time() model = basenji.seqnn.SeqNN() - model.build_feed(job, target_subset=target_subset) + # model.build_feed(job, target_subset=target_subset) + model.build_feed(job, target_subset=target_subset, + ensemble_rc=options.rc, ensemble_shifts=options.shifts) print('Model building time %f' % (time.time() - t0), flush=True) if options.penultimate: @@ -246,7 +248,6 @@ def main(): # initialize saver saver = tf.train.Saver() - with tf.Session() as sess: # load variables into session saver.restore(sess, model_file) @@ -263,9 +264,11 @@ def main(): batcher = basenji.batcher.Batcher(batch_1hot, batch_size=model.hp.batch_size) # predict - batch_preds = model.predict(sess, batcher, - rc=options.rc, shifts=options.shifts, - penultimate=options.penultimate) + # batch_preds = model.predict(sess, batcher, + # rc=options.rc, shifts=options.shifts, + # penultimate=options.penultimate) + batch_preds = model.predict_h5(sess, batcher, + penultimate=options.penultimate) # normalize batch_preds /= target_norms From 269a2ed8fef50ad26461ad835ed9ab5a53965da1 Mon Sep 17 00:00:00 2001 From: David Kelley Date: Sat, 21 Jul 2018 09:23:07 -0700 Subject: [PATCH 24/71] penultimate draft --- basenji/seqnn.py | 154 ++++++++++++++++++++++-------------------- basenji/seqnn_util.py | 30 ++++---- 2 files changed, 91 insertions(+), 93 deletions(-) diff --git a/basenji/seqnn.py b/basenji/seqnn.py index 873b8978..5116971b 100755 --- a/basenji/seqnn.py +++ b/basenji/seqnn.py @@ -34,7 +34,8 @@ def __init__(self): self.hparams_set = False def build_feed(self, job, augment_rc=False, augment_shifts=[0], - ensemble_rc=False, ensemble_shifts=[0], target_subset=None): + ensemble_rc=False, ensemble_shifts=[0], + penultimate=False, target_subset=None): """Build training ops that depend on placeholders.""" self.hp = params.make_hparams(job) @@ -46,12 +47,13 @@ def build_feed(self, job, augment_rc=False, augment_shifts=[0], augment_shifts=augment_shifts, ensemble_rc=ensemble_rc, ensemble_shifts=ensemble_shifts, + penultimate=penultimate, target_subset=target_subset) def build_from_data_ops(self, job, data_ops, augment_rc=False, augment_shifts=[0], ensemble_rc=False, ensemble_shifts=[0], - target_subset=None): + penultimate=False, target_subset=None): """Build training ops from input data ops.""" if not self.hparams_set: self.hp = params.make_hparams(job) @@ -69,7 +71,7 @@ def build_from_data_ops(self, job, data_ops, # compute train representation self.preds_train = self.build_predict(data_ops_train['sequence'], - None, target_subset) + None, penultimate, target_subset) # training losses loss_returns = self.build_loss(self.preds_train, data_ops_train['label'], target_subset) @@ -90,7 +92,7 @@ def build_from_data_ops(self, job, data_ops, # compute eval representation map_elems_eval = (data_seq_eval, data_rev_eval) - build_rep = lambda do: self.build_predict(do[0], do[1], target_subset) + build_rep = lambda do: self.build_predict(do[0], do[1], penultimate, target_subset) self.preds_ensemble = tf.map_fn(build_rep, map_elems_eval, dtype=self.preds_train.dtype) # back_prop=False self.preds_eval = tf.reduce_mean(self.preds_ensemble, axis=0) @@ -148,7 +150,7 @@ def _make_conv_block_args(self, layer_index, layer_reprs): 'name': 'conv-%d' % layer_index } - def build_predict(self, inputs, reverse_preds=None, 
target_subset=None): + def build_predict(self, inputs, reverse_preds=None, penultimate=False, target_subset=None): """Construct per-location real-valued predictions.""" assert inputs is not None print('Targets pooled by %d to length %d' % @@ -158,17 +160,17 @@ def build_predict(self, inputs, reverse_preds=None, target_subset=None): # convolution layers ################################################### filter_weights = [] - layer_reprs = [inputs] + self.layer_reprs = [inputs] seqs_repr = inputs for layer_index in range(self.hp.cnn_layers): with tf.variable_scope('cnn%d' % layer_index, reuse=tf.AUTO_REUSE): # convolution block - args_for_block = self._make_conv_block_args(layer_index, layer_reprs) + args_for_block = self._make_conv_block_args(layer_index, self.layer_reprs) seqs_repr = layers.conv_block(seqs_repr=seqs_repr, **args_for_block) # save representation - layer_reprs.append(seqs_repr) + self.layer_reprs.append(seqs_repr) # final nonlinearity seqs_repr = tf.nn.relu(seqs_repr) @@ -190,44 +192,44 @@ def build_predict(self, inputs, reverse_preds=None, target_subset=None): seqs_repr = seqs_repr[:, batch_buffer_pool: seq_length - batch_buffer_pool, :] - # save penultimate representation - # self.penultimate_op = seqs_repr - ################################################### # final layer ################################################### - with tf.variable_scope('final', reuse=tf.AUTO_REUSE): - final_filters = self.hp.num_targets * self.hp.target_classes - final_repr = tf.layers.dense( - inputs=seqs_repr, - units=final_filters, - activation=None, - kernel_initializer=tf.variance_scaling_initializer(scale=2.0, mode='fan_in'), - kernel_regularizer=tf.contrib.layers.l1_regularizer(self.hp.final_l1_scale)) - print('Convolution w/ %d %dx1 filters to final targets' % - (final_filters, seqs_repr.shape[2])) - - if target_subset is not None: - # get convolution parameters - filters_full = tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES, 'final/dense/kernel')[0] - bias_full = tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES, 'final/dense/bias')[0] - - # subset to specific targets - filters_subset = tf.gather(filters_full, target_subset, axis=1) - bias_subset = tf.gather(bias_full, target_subset, axis=0) - - # substitute a new limited convolution - final_repr = tf.tensordot(seqs_repr, filters_subset, 1) - final_repr = tf.nn.bias_add(final_repr, bias_subset) - - # update # targets - self.hp.num_targets = len(target_subset) - - # expand length back out - if self.hp.target_classes > 1: - final_repr = tf.reshape(final_repr, - (self.hp.batch_size, -1, self.hp.num_targets, - self.hp.target_classes)) + if penultimate: + final_repr = seqs_repr + else: + with tf.variable_scope('final', reuse=tf.AUTO_REUSE): + final_filters = self.hp.num_targets * self.hp.target_classes + final_repr = tf.layers.dense( + inputs=seqs_repr, + units=final_filters, + activation=None, + kernel_initializer=tf.variance_scaling_initializer(scale=2.0, mode='fan_in'), + kernel_regularizer=tf.contrib.layers.l1_regularizer(self.hp.final_l1_scale)) + print('Convolution w/ %d %dx1 filters to final targets' % + (final_filters, seqs_repr.shape[2])) + + if target_subset is not None: + # get convolution parameters + filters_full = tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES, 'final/dense/kernel')[0] + bias_full = tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES, 'final/dense/bias')[0] + + # subset to specific targets + filters_subset = tf.gather(filters_full, target_subset, axis=1) + bias_subset = tf.gather(bias_full, target_subset, 
axis=0) + + # substitute a new limited convolution + final_repr = tf.tensordot(seqs_repr, filters_subset, 1) + final_repr = tf.nn.bias_add(final_repr, bias_subset) + + # update # targets + self.hp.num_targets = len(target_subset) + + # expand length back out + if self.hp.target_classes > 1: + final_repr = tf.reshape(final_repr, + (self.hp.batch_size, -1, self.hp.num_targets, + self.hp.target_classes)) # transform for reverse complement if reverse_preds is not None: @@ -238,47 +240,49 @@ def build_predict(self, inputs, reverse_preds=None, target_subset=None): ################################################### # link function ################################################### + if penultimate: + predictions = final_repr + else: + # work-around for specifying my own predictions + # self.preds_adhoc = tf.placeholder( + # tf.float32, shape=final_repr.shape, name='preds-adhoc') - # work-around for specifying my own predictions - # self.preds_adhoc = tf.placeholder( - # tf.float32, shape=final_repr.shape, name='preds-adhoc') - - # float 32 exponential clip max - exp_max = 50 + # float 32 exponential clip max + exp_max = 50 - # choose link - if self.hp.link in ['identity', 'linear']: - predictions = tf.identity(final_repr, name='preds') + # choose link + if self.hp.link in ['identity', 'linear']: + predictions = tf.identity(final_repr, name='preds') - elif self.hp.link == 'relu': - predictions = tf.relu(final_repr, name='preds') + elif self.hp.link == 'relu': + predictions = tf.relu(final_repr, name='preds') - elif self.hp.link == 'exp': - final_repr_clip = tf.clip_by_value(final_repr, -exp_max, exp_max) - predictions = tf.exp(final_repr_clip, name='preds') + elif self.hp.link == 'exp': + final_repr_clip = tf.clip_by_value(final_repr, -exp_max, exp_max) + predictions = tf.exp(final_repr_clip, name='preds') - elif self.hp.link == 'exp_linear': - predictions = tf.where( - final_repr > 0, - final_repr + 1, - tf.exp(tf.clip_by_value(final_repr, -exp_max, exp_max)), - name='preds') + elif self.hp.link == 'exp_linear': + predictions = tf.where( + final_repr > 0, + final_repr + 1, + tf.exp(tf.clip_by_value(final_repr, -exp_max, exp_max)), + name='preds') - elif self.hp.link == 'softplus': - final_repr_clip = tf.clip_by_value(final_repr, -exp_max, 10000) - predictions = tf.nn.softplus(final_repr_clip, name='preds') + elif self.hp.link == 'softplus': + final_repr_clip = tf.clip_by_value(final_repr, -exp_max, 10000) + predictions = tf.nn.softplus(final_repr_clip, name='preds') - else: - print('Unknown link function %s' % self.hp.link, file=sys.stderr) - exit(1) + else: + print('Unknown link function %s' % self.hp.link, file=sys.stderr) + exit(1) - # clip - if self.hp.target_clip is not None: - predictions = tf.clip_by_value(predictions, 0, self.hp.target_clip) + # clip + if self.hp.target_clip is not None: + predictions = tf.clip_by_value(predictions, 0, self.hp.target_clip) - # sqrt - if self.hp.target_sqrt: - predictions = tf.sqrt(predictions) + # sqrt + if self.hp.target_sqrt: + predictions = tf.sqrt(predictions) return predictions diff --git a/basenji/seqnn_util.py b/basenji/seqnn_util.py index d4f3ab78..f9d9dae1 100644 --- a/basenji/seqnn_util.py +++ b/basenji/seqnn_util.py @@ -596,10 +596,7 @@ def _predict_ensemble(self, # print('ei=%d, mi=%d, fwdrc=%d, shifts=%d' % (ei, mi, ensemble_fwdrc[ei], ensemble_shifts[ei]), flush=True) # predict - if penultimate: - preds_ei = sess.run(self.penultimate_op, feed_dict=fd) - else: - preds_ei = sess.run(self.preds_eval, feed_dict=fd) + preds_ei = 
sess.run(self.preds_eval, feed_dict=fd) # reverse if ensemble_fwdrc[ei] is False: @@ -641,7 +638,7 @@ def predict_h5_manual(self, sess, batcher, In sess: TensorFlow session - batcher: Batcher class with transcript-covering sequences. + batcher: Batcher class with sequences. rc: Average predictions from the forward and reverse complement sequences. shifts: Average predictions from sequence shifts left/right. @@ -758,17 +755,16 @@ def predict_h5_manual(self, sess, batcher, else: return preds - def predict_h5(self, sess, batcher, - return_var=False, return_all=False, - penultimate=False, test_batches=None): + def predict_h5(self, sess, batcher, test_batches=None, + return_var=False, return_all=False): """ Compute preidctions on an HDF5 test set. Args: sess: TensorFlow session + batcher: Batcher class with sequences. + test_batches: Number of test batches to use. return_var: Return variance estimates return_all: Retyrn all predictions. - penultimate: Predict the penultimate layer. - test_batches: Number of test batches to use. Returns: preds: S (sequences) x L (unbuffered length) x T (targets) array @@ -831,17 +827,15 @@ def predict_h5(self, sess, batcher, else: return preds - def predict_tfr(self, sess, - return_var=False, return_all=False, - penultimate=False, test_batches=None): + def predict_tfr(self, sess, test_batches=None + return_var=False, return_all=False): """ Compute preidctions on a TFRecord test set. Args: sess: TensorFlow session + test_batches: Number of test batches to use. return_var: Return variance estimates return_all: Retyrn all predictions. - penultimate: Predict the penultimate layer. - test_batches: Number of test batches to use. Returns: preds: S (sequences) x L (unbuffered length) x T (targets) array @@ -952,9 +946,9 @@ def predict_genes(self, while not batcher.empty(): # predict gene sequences - gseq_preds = self.predict(sess, batcher, rc=rc, shifts=shifts, mc_n=mc_n, - target_indexes=target_indexes, penultimate=penultimate, - test_batches=test_batches_per) + gseq_preds = self.predict_h5_manual(sess, batcher, rc=rc, shifts=shifts, mc_n=mc_n, + target_indexes=target_indexes, penultimate=penultimate, + test_batches=test_batches_per) # slice TSSs for bsi in range(gseq_preds.shape[0]): for tss in gene_seqs[si].tss_list: From a16a31b79f5c932ea36eb072b0093d72016c31d7 Mon Sep 17 00:00:00 2001 From: David Kelley Date: Sun, 22 Jul 2018 09:13:40 -0700 Subject: [PATCH 25/71] missing comma --- basenji/seqnn_util.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/basenji/seqnn_util.py b/basenji/seqnn_util.py index f9d9dae1..824a7067 100644 --- a/basenji/seqnn_util.py +++ b/basenji/seqnn_util.py @@ -827,7 +827,7 @@ def predict_h5(self, sess, batcher, test_batches=None, else: return preds - def predict_tfr(self, sess, test_batches=None + def predict_tfr(self, sess, test_batches=None, return_var=False, return_all=False): """ Compute preidctions on a TFRecord test set. 
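The predict_h5/predict_tfr refactor above leans on the eval-time ensemble built in build_from_data_ops: each sequence is run through every deterministic augmentation (reverse complement and small shifts), reverse-complemented predictions are flipped back along the length axis, and the member predictions are averaged. Below is a minimal NumPy sketch of that averaging, assuming ACGT one-hot channel order; predict_fn, rc_input, and shift_input are illustrative stand-ins, not Basenji functions.

import numpy as np

def rc_input(seq_1hot):
  # reverse complement a one-hot (length x 4) sequence; with ACGT channel
  # order, flipping both axes reverses the sequence and complements the bases
  return seq_1hot[::-1, ::-1]

def shift_input(seq_1hot, offset, pad=0.25):
  # shift along the length axis, padding with a uniform base distribution
  if offset == 0:
    return seq_1hot
  padding = np.full((abs(offset), seq_1hot.shape[1]), pad)
  if offset > 0:
    return np.concatenate([padding, seq_1hot[:-offset]], axis=0)
  return np.concatenate([seq_1hot[-offset:], padding], axis=0)

def ensemble_predict(predict_fn, seq_1hot, rc=True, shifts=(0,)):
  # average predictions over the deterministic augmentation ensemble
  preds = []
  for offset in shifts:
    seq_shift = shift_input(seq_1hot, offset)
    preds.append(predict_fn(seq_shift))
    if rc:
      # flip reverse-complement predictions back along the length axis
      preds.append(predict_fn(rc_input(seq_shift))[::-1])
  return np.mean(preds, axis=0)

Under the same assumptions, return_var/return_all would keep the stacked member predictions (or their variance) instead of only the mean, mirroring preds_ensemble and preds_eval in the graph code above.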
From 2e5af1c9052ab01855b9d4680f80438ab7383429 Mon Sep 17 00:00:00 2001 From: David Kelley Date: Sun, 22 Jul 2018 09:15:00 -0700 Subject: [PATCH 26/71] penultimate loss fix --- basenji/seqnn.py | 19 ++++++++++--------- bin/basenji_sad.py | 7 +++---- 2 files changed, 13 insertions(+), 13 deletions(-) diff --git a/basenji/seqnn.py b/basenji/seqnn.py index 5116971b..5dfd39bf 100755 --- a/basenji/seqnn.py +++ b/basenji/seqnn.py @@ -72,14 +72,16 @@ def build_from_data_ops(self, job, data_ops, # compute train representation self.preds_train = self.build_predict(data_ops_train['sequence'], None, penultimate, target_subset) + self.target_length = self.preds_train.shape[1].value # training losses - loss_returns = self.build_loss(self.preds_train, data_ops_train['label'], target_subset) - self.loss_train, self.loss_train_targets, self.targets_train = loss_returns + if not penultimate: + loss_returns = self.build_loss(self.preds_train, data_ops_train['label'], target_subset) + self.loss_train, self.loss_train_targets, self.targets_train = loss_returns - # optimizer - self.update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS) - self.build_optimizer(self.loss_train) + # optimizer + self.update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS) + self.build_optimizer(self.loss_train) ################################################## # eval @@ -97,8 +99,9 @@ def build_from_data_ops(self, job, data_ops, self.preds_eval = tf.reduce_mean(self.preds_ensemble, axis=0) # eval loss - loss_returns = self.build_loss(self.preds_eval, data_ops['label'], target_subset) - self.loss_eval, self.loss_eval_targets, self.targets_eval = loss_returns + if not penultimate: + loss_returns = self.build_loss(self.preds_eval, data_ops['label'], target_subset) + self.loss_eval, self.loss_eval_targets, self.targets_eval = loss_returns # helper variables self.preds_length = self.preds_train.shape[1] @@ -345,8 +348,6 @@ def build_loss(self, preds, targets, target_subset=None): # slice buffer tstart = self.hp.batch_buffer // self.hp.target_pool tend = (self.hp.seq_length - self.hp.batch_buffer) // self.hp.target_pool - self.target_length = tend - tstart - targets = tf.identity(targets[:, tstart:tend, :], name='targets_op') if target_subset is not None: diff --git a/bin/basenji_sad.py b/bin/basenji_sad.py index 2cd7a088..00537e1d 100755 --- a/bin/basenji_sad.py +++ b/bin/basenji_sad.py @@ -174,8 +174,8 @@ def main(): t0 = time.time() model = basenji.seqnn.SeqNN() # model.build_feed(job, target_subset=target_subset) - model.build_feed(job, target_subset=target_subset, - ensemble_rc=options.rc, ensemble_shifts=options.shifts) + model.build_feed(job, ensemble_rc=options.rc, ensemble_shifts=options.shifts, + target_subset=target_subset, penultimate=options.penultimate) print('Model building time %f' % (time.time() - t0), flush=True) if options.penultimate: @@ -267,8 +267,7 @@ def main(): # batch_preds = model.predict(sess, batcher, # rc=options.rc, shifts=options.shifts, # penultimate=options.penultimate) - batch_preds = model.predict_h5(sess, batcher, - penultimate=options.penultimate) + batch_preds = model.predict_h5(sess, batcher) # normalize batch_preds /= target_norms From f376d01b4e25e3b34ccc1356210f89b4c58e567d Mon Sep 17 00:00:00 2001 From: David Kelley Date: Mon, 23 Jul 2018 15:48:14 -0700 Subject: [PATCH 27/71] hidden and map --- basenji/seqnn.py | 31 ++++++++++-------- basenji/seqnn_util.py | 31 +++++++++++++----- bin/basenji_hidden.py | 76 +++++++++++++++++++++++-------------------- bin/basenji_map.py | 2 +- 4 files changed, 
83 insertions(+), 57 deletions(-) diff --git a/basenji/seqnn.py b/basenji/seqnn.py index 5dfd39bf..49c2d179 100755 --- a/basenji/seqnn.py +++ b/basenji/seqnn.py @@ -71,7 +71,8 @@ def build_from_data_ops(self, job, data_ops, # compute train representation self.preds_train = self.build_predict(data_ops_train['sequence'], - None, penultimate, target_subset) + None, penultimate, target_subset, + save_reprs=True) self.target_length = self.preds_train.shape[1].value # training losses @@ -103,10 +104,13 @@ def build_from_data_ops(self, job, data_ops, loss_returns = self.build_loss(self.preds_eval, data_ops['label'], target_subset) self.loss_eval, self.loss_eval_targets, self.targets_eval = loss_returns + # update # targets + if target_subset is not None: + self.hp.num_targets = len(target_subset) + # helper variables self.preds_length = self.preds_train.shape[1] - def make_placeholders(self): """Allocates placeholders to be used in place of input data ops.""" # batches @@ -153,7 +157,7 @@ def _make_conv_block_args(self, layer_index, layer_reprs): 'name': 'conv-%d' % layer_index } - def build_predict(self, inputs, reverse_preds=None, penultimate=False, target_subset=None): + def build_predict(self, inputs, reverse_preds=None, penultimate=False, target_subset=None, save_reprs=False): """Construct per-location real-valued predictions.""" assert inputs is not None print('Targets pooled by %d to length %d' % @@ -163,17 +167,20 @@ def build_predict(self, inputs, reverse_preds=None, penultimate=False, target_su # convolution layers ################################################### filter_weights = [] - self.layer_reprs = [inputs] + layer_reprs = [inputs] seqs_repr = inputs for layer_index in range(self.hp.cnn_layers): with tf.variable_scope('cnn%d' % layer_index, reuse=tf.AUTO_REUSE): # convolution block - args_for_block = self._make_conv_block_args(layer_index, self.layer_reprs) + args_for_block = self._make_conv_block_args(layer_index, layer_reprs) seqs_repr = layers.conv_block(seqs_repr=seqs_repr, **args_for_block) # save representation - self.layer_reprs.append(seqs_repr) + layer_reprs.append(seqs_repr) + + if save_reprs: + self.layer_reprs = layer_reprs # final nonlinearity seqs_repr = tf.nn.relu(seqs_repr) @@ -225,9 +232,6 @@ def build_predict(self, inputs, reverse_preds=None, penultimate=False, target_su final_repr = tf.tensordot(seqs_repr, filters_subset, 1) final_repr = tf.nn.bias_add(final_repr, bias_subset) - # update # targets - self.hp.num_targets = len(target_subset) - # expand length back out if self.hp.target_classes > 1: final_repr = tf.reshape(final_repr, @@ -381,12 +385,13 @@ def build_loss(self, preds, targets, target_subset=None): # reduce lossses by batch and position loss_op = tf.reduce_mean(loss_op, axis=[0, 1], name='target_loss') loss_op = tf.check_numerics(loss_op, 'Invalid loss', name='loss_check') - - tf.summary.histogram('target_loss', loss_op) - for ti in np.linspace(0, self.hp.num_targets - 1, 10).astype('int'): - tf.summary.scalar('loss_t%d' % ti, loss_op[ti]) target_losses = loss_op + if target_subset is None: + tf.summary.histogram('target_loss', loss_op) + for ti in np.linspace(0, self.hp.num_targets - 1, 10).astype('int'): + tf.summary.scalar('loss_t%d' % ti, loss_op[ti]) + # fully reduce loss_op = tf.reduce_mean(loss_op, name='loss') diff --git a/basenji/seqnn_util.py b/basenji/seqnn_util.py index 824a7067..ac5bbb09 100644 --- a/basenji/seqnn_util.py +++ b/basenji/seqnn_util.py @@ -27,7 +27,7 @@ def build_grads(self, layers=[0]): self.grad_ops = [] for ti in 
range(self.hp.num_targets): - grad_ti_op = tf.gradients(self.preds_eval[:,:,ti], [self.layer_reprs[li] for li in self.grad_layers]) + grad_ti_op = tf.gradients(self.preds_train[:,:,ti], [self.layer_reprs[li] for li in self.grad_layers]) self.grad_ops.append(grad_ti_op) @@ -224,7 +224,9 @@ def gradients(self, return layer_grads, layer_reprs, preds - def _gradients_ensemble(self, sess, fd, Xb, ensemble_fwdrc, ensemble_shifts, mc_n, return_var=False, return_all=False): + def _gradients_ensemble(self, sess, fd, Xb, + ensemble_fwdrc, ensemble_shifts, mc_n, + return_var=False, return_all=False): """ Compute gradients over an ensemble of input augmentations. In @@ -312,7 +314,7 @@ def _gradients_ensemble(self, sess, fd, Xb, ensemble_fwdrc, ensemble_shifts, mc_ # prediction # predict - preds_ei, layer_reprs_ei = sess.run([self.preds_eval, self.layer_reprs], feed_dict=fd) + preds_ei, layer_reprs_ei = sess.run([self.preds_train, self.layer_reprs], feed_dict=fd) # reverse if ensemble_fwdrc[ei] is False: @@ -451,7 +453,7 @@ def gradients_genes(self, sess, batcher, gene_seqs): fd[self.inputs_ph] = Xb # predict - reprs_batch, _ = sess.run([self.layer_reprs, self.preds_eval], feed_dict=fd) + reprs_batch, _ = sess.run([self.layer_reprs, self.preds_train], feed_dict=fd) # save representations for lii in range(len(self.grad_layers)): @@ -487,8 +489,18 @@ def gradients_genes(self, sess, batcher, gene_seqs): return layer_grads, layer_reprs - def hidden(self, sess, batcher, layers=None): - """ Compute hidden representations for a test set. """ + def hidden(self, sess, batcher, layers=None, test_batches=None): + """ Compute hidden representations for a test set. + + In + sess: TensorFlow session + batcher: Batcher class with sequences. + layers: Layer indexes to return representations. + test_batches: Number of test batches to use. 
+ + Out + preds: S (sequences) x L (unbuffered length) x T (targets) array + """ if layers is None: layers = list(range(self.hp.cnn_layers)) @@ -505,13 +517,15 @@ def hidden(self, sess, batcher, layers=None): # get first batch Xb, _, _, Nb = batcher.next() - while Xb is not None: + batch_num = 0 + while Xb is not None and (test_batches is None or + batch_num < test_batches): # update feed dict fd[self.inputs_ph] = Xb # compute predictions layer_reprs_batch, preds_batch = sess.run( - [self.layer_reprs, self.preds_eval], feed_dict=fd) + [self.layer_reprs, self.preds_train], feed_dict=fd) # accumulate representationsmakes the number of members for self smaller and also for li in layers: @@ -527,6 +541,7 @@ def hidden(self, sess, batcher, layers=None): # next batch Xb, _, _, Nb = batcher.next() + batch_num += 1 # reset batcher batcher.reset() diff --git a/bin/basenji_hidden.py b/bin/basenji_hidden.py index 2fc9bd4e..c7ba14e2 100755 --- a/bin/basenji_hidden.py +++ b/bin/basenji_hidden.py @@ -30,7 +30,10 @@ import statsmodels import tensorflow as tf -import basenji +from basenji import batcher +from basenji import params +from basenji import plots +from basenji import seqnn ################################################################################ # basenji_hidden.py @@ -45,24 +48,15 @@ def main(): usage = 'usage: %prog [options] ' parser = OptionParser(usage) - parser.add_option( - '-l', - dest='layers', - default=None, - help='Comma-separated list of layers to plot') - parser.add_option( - '-n', - dest='num_seqs', - default=None, - type='int', + parser.add_option('-l', dest='layers', + default=None, help='Comma-separated list of layers to plot') + parser.add_option('-n', dest='num_seqs', + default=None, type='int', help='Number of sequences to process') - parser.add_option( - '-o', - dest='out_dir', - default='hidden', - help='Output directory [Default: %default]') - parser.add_option( - '-t', dest='target_indexes', default=None, help='Target indexes to plot') + parser.add_option('-o', dest='out_dir', + default='hidden', help='Output directory [Default: %default]') + parser.add_option('-t', dest='target_indexes', + default=None, help='Paint 2D plots with these target index values.') (options, args) = parser.parse_args() if len(args) != 3: @@ -92,17 +86,16 @@ def main(): ####################################################### # model parameters and placeholders ####################################################### - job = basenji.dna_io.read_job_params(params_file) + job = params.read_job_params(params_file) job['seq_length'] = test_seqs.shape[1] job['seq_depth'] = test_seqs.shape[2] job['num_targets'] = test_targets.shape[2] job['target_pool'] = int(np.array(data_open.get('pool_width', 1))) - job['save_reprs'] = True t0 = time.time() - model = basenji.seqnn.SeqNN() - model.build(job) + model = seqnn.SeqNN() + model.build_feed(job) if options.target_indexes is None: options.target_indexes = range(job['num_targets']) @@ -115,11 +108,11 @@ def main(): # test ####################################################### # initialize batcher - batcher_test = basenji.batcher.Batcher( + batcher_test = batcher.Batcher( test_seqs, test_targets, - batch_size=model.batch_size, - pool_width=model.target_pool) + batch_size=model.hp.batch_size, + pool_width=model.hp.target_pool) # initialize saver saver = tf.train.Saver() @@ -144,12 +137,14 @@ def main(): # sample one nt per sequence ds_indexes = np.arange(0, layer_repr.shape[1], 256) nt_reprs = layer_repr[:, ds_indexes, :].reshape((-1, 
layer_repr.shape[2])) + print('nt_reprs', nt_reprs.shape) ######################################################## # plot raw sns.set(style='ticks', font_scale=1.2) plt.figure() - g = sns.clustermap(nt_reprs, xticklabels=False, yticklabels=False) + g = sns.clustermap(nt_reprs, cmap='RdBu_r', + xticklabels=False, yticklabels=False) g.ax_heatmap.set_xlabel('Representation') g.ax_heatmap.set_ylabel('Sequences') plt.savefig('%s/l%d_reprs.pdf' % (options.out_dir, li)) @@ -182,7 +177,18 @@ def main(): nt_2d = model2.fit_transform(nt_reprs) for ti in options.target_indexes: - nt_targets = np.log2(test_targets[:, ds_indexes, ti].flatten() + 1) + # slice for target + test_targets_ti = test_targets[:,:,ti] + + # repeat to match layer_repr + target_repeat = layer_repr.shape[1] // test_targets.shape[1] + test_targets_ti = np.repeat(test_targets_ti, target_repeat, axis=1) + + # downsample indexes + nt_targets = test_targets_ti[:,ds_indexes].flatten() + + # log transform + nt_targets = np.log1p(nt_targets) plt.figure() plt.scatter( @@ -193,17 +199,17 @@ def main(): plt.savefig('%s/l%d_nt2d_t%d.pdf' % (options.out_dir, li, ti)) plt.close() + ######################################################## # plot neuron-neuron correlations - # mean-normalize representation - nt_reprs_norm = nt_reprs - nt_reprs.mean(axis=0) - - # compute covariance matrix - hidden_cov = np.dot(nt_reprs_norm.T, nt_reprs_norm) + # compute correlation matrix + hidden_cov = np.corrcoef(nt_reprs.T) + print('hidden_cov', hidden_cov.shape) plt.figure() - g = sns.clustermap(hidden_cov, xticklabels=False, yticklabels=False) + g = sns.clustermap(hidden_cov, cmap='RdBu_r', + xticklabels=False, yticklabels=False) plt.savefig('%s/l%d_cov.pdf' % (options.out_dir, li)) plt.close() @@ -258,8 +264,8 @@ def regplot(vals1, vals2, out_pdf, alpha=0.5, x_label=None, y_label=None): 'alpha': alpha}, line_kws={'color': gold}) - xmin, xmax = basenji.plots.scatter_lims(vals1) - ymin, ymax = basenji.plots.scatter_lims(vals2) + xmin, xmax = plots.scatter_lims(vals1) + ymin, ymax = plots.scatter_lims(vals2) ax.set_xlim(xmin, xmax) if x_label is not None: diff --git a/bin/basenji_map.py b/bin/basenji_map.py index e5034947..30fd904d 100755 --- a/bin/basenji_map.py +++ b/bin/basenji_map.py @@ -128,7 +128,7 @@ def main(): # build model model = seqnn.SeqNN() - model.build(job, target_subset=target_subset) + model.build_feed(job, target_subset=target_subset) # determine latest pre-dilated layer cnn_dilation = np.array([cp.dilation for cp in model.hp.cnn_params]) From 29dd294bf104eb6f38559a6665fc2ff7d233afc9 Mon Sep 17 00:00:00 2001 From: David Kelley Date: Wed, 25 Jul 2018 17:01:02 -0700 Subject: [PATCH 28/71] align predict_h5 with predict_h5_manual --- basenji/seqnn_util.py | 53 ++++++++++++++++++++++--------------------- 1 file changed, 27 insertions(+), 26 deletions(-) diff --git a/basenji/seqnn_util.py b/basenji/seqnn_util.py index ac5bbb09..192d5d82 100644 --- a/basenji/seqnn_util.py +++ b/basenji/seqnn_util.py @@ -736,7 +736,6 @@ def predict_h5_manual(self, sess, batcher, # while we want more batches while test_batches is None or batch_num < test_batches: - # get batch Xb, _, _, Nb = batcher.next() @@ -793,39 +792,41 @@ def predict_h5(self, sess, batcher, test_batches=None, if return_all: preds_all = [] - # get first batch + # count batches batch_num = 0 - Xb, _, _, Nb = batcher.next() - while Xb is not None and (test_batches is None or - batch_num < test_batches): - # update feed dict - fd[self.inputs_ph] = Xb + # while we want more batches + while 
test_batches is None or batch_num < test_batches: + # get batch + Xb, _, _, Nb = batcher.next() - # make predictions - if return_var or return_all: - preds_batch, preds_ensemble_batch = sess.run([self.preds_eval, self.preds_ensemble], feed_dict=fd) + # verify fidelity + if Xb is None: + break + else: + # update feed dict + fd[self.inputs_ph] = Xb - # move ensemble to back - preds_ensemble_batch = np.moveaxis(preds_ensemble_batch, 0, -1) + # make predictions + if return_var or return_all: + preds_batch, preds_ensemble_batch = sess.run([self.preds_eval, self.preds_ensemble], feed_dict=fd) - else: - preds_batch = sess.run(self.preds_eval, feed_dict=fd) + # move ensemble to back + preds_ensemble_batch = np.moveaxis(preds_ensemble_batch, 0, -1) - # accumulate predictions and targets - preds.append(preds_batch[:Nb]) - if return_var: - preds_var_batch = np.var(preds_ensemble_batch, axis=-1) - preds_var.append(preds_var_batch[:Nb]) - if return_all: - preds_all.append(preds_ensemble_batch[:Nb]) + else: + preds_batch = sess.run(self.preds_eval, feed_dict=fd) - # next batch - batch_num += 1 - Xb, _, _, Nb = batcher.next() + # accumulate predictions and targets + preds.append(preds_batch[:Nb]) + if return_var: + preds_var_batch = np.var(preds_ensemble_batch, axis=-1) + preds_var.append(preds_var_batch[:Nb]) + if return_all: + preds_all.append(preds_ensemble_batch[:Nb]) - # reset batcher - batcher.reset() + # next batch + batch_num += 1 # construct arrays preds = np.concatenate(preds, axis=0) From 0f0f34a15b798bfaecf7be9743b3189517965b30 Mon Sep 17 00:00:00 2001 From: David Kelley Date: Fri, 10 Aug 2018 15:46:17 -0700 Subject: [PATCH 29/71] sqrt soft clip --- bin/basenji_data.py | 11 ++++++++--- bin/basenji_data_read.py | 11 +++++++++-- 2 files changed, 17 insertions(+), 5 deletions(-) diff --git a/bin/basenji_data.py b/bin/basenji_data.py index 5fcf78c1..6ca0bf62 100755 --- a/bin/basenji_data.py +++ b/bin/basenji_data.py @@ -70,6 +70,9 @@ def main(): parser.add_option('-p', dest='processes', default=None, type='int', help='Number parallel processes [Default: %default]') + parser.add_option('-r', dest='seqs_per_tfr', + default=256, type='int', + help='Sequences per TFRecord file [Default: %default]') parser.add_option('--seed', dest='seed', default=44, type='int', help='Random seed [Default: %default]') @@ -82,9 +85,9 @@ def main(): parser.add_option('-s', dest='sum_stat', default='sum', help='Summary statistic to compute in windows [Default: %default]') - parser.add_option('-r', dest='seqs_per_tfr', - default=256, type='int', - help='Sequences per TFRecord file [Default: %default]') + parser.add_option('--soft', dest='soft_clip', + default=False, action='store_true', + help='Soft clip values, applying sqrt to the execess above the threshold [Default: %default]') parser.add_option('-t', dest='test_pct_or_chr', default=0.05, type='str', help='Proportion of the data for testing [Default: %default]') @@ -238,6 +241,8 @@ def main(): cmd += ' -w %d' % options.pool_width cmd += ' -s %s' % options.sum_stat cmd += ' -c %f' % clip_ti + if options.soft_clip: + cmd += ' --soft' if options.blacklist_bed: cmd += ' -b %s' % options.blacklist_bed cmd += ' %s' % genome_cov_file diff --git a/bin/basenji_data_read.py b/bin/basenji_data_read.py index 15107c38..9e91b98a 100755 --- a/bin/basenji_data_read.py +++ b/bin/basenji_data_read.py @@ -41,7 +41,10 @@ def main(): help='Set blacklist nucleotides to a baseline value.') parser.add_option('-c', dest='clip', default=None, type='float', - help='Clip absolute values 
post-summary to a maximum [Default: %default]') + help='Clip values post-summary to a maximum [Default: %default]') + parser.add_option('--soft', dest='soft_clip', + default=False, action='store_true', + help='Soft clip values, applying sqrt to the execess above the threshold [Default: %default]') parser.add_option('-s', dest='sum_stat', default='sum', help='Summary statistic to compute in windows [Default: %default]') @@ -121,7 +124,11 @@ def main(): # clip if options.clip is not None: - seq_cov = np.clip(seq_cov, -options.clip, options.clip) + if options.soft_clip: + clip_mask = (seq_cov > options.clip) + seq_cov[clip_mask] = options.clip + np.sqrt(seq_cov[clip_mask] - options.clip) + else: + seq_cov = np.clip(seq_cov, 0, options.clip) # write seqs_cov_open['seqs_cov'][si,:] = seq_cov.astype('float16') From 672b434d4397b3d66f80340e3569f18ac0d840c5 Mon Sep 17 00:00:00 2001 From: David Kelley Date: Fri, 10 Aug 2018 18:05:25 -0700 Subject: [PATCH 30/71] debugging TFR --- bin/tfr_qc.py | 130 ++++++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 130 insertions(+) create mode 100755 bin/tfr_qc.py diff --git a/bin/tfr_qc.py b/bin/tfr_qc.py new file mode 100755 index 00000000..8133db66 --- /dev/null +++ b/bin/tfr_qc.py @@ -0,0 +1,130 @@ +#!/usr/bin/env python +from optparse import OptionParser +import multiprocessing +import os + +import h5py +import numpy as np +import pandas as pd +import tensorflow as tf + +from basenji import tfrecord_batcher + +''' +tfr_qc.py + +Print quality control statistics for a TFRecords dataset. +''' + +################################################################################ +# main +################################################################################ +def main(): + usage = 'usage: %prog [options] ' + parser = OptionParser(usage) + parser.add_option('-l', dest='target_length', + default=1024, type='int') + parser.add_option('-o', dest='out_dir', default='tfr_qc') + parser.add_option('-p', dest='processes', default=16, type='int', + help='Number of parallel threads to use [Default: %default]') + parser.add_option('-s', dest='split', default='test') + (options,args) = parser.parse_args() + + if len(args) != 1: + parser.error('Must provide TFRecords data directory') + else: + tfr_data_dir = args[0] + + if not os.path.isdir(options.out_dir): + os.mkdir(options.out_dir) + + # read target datasets + targets_file = '%s/targets.txt' % tfr_data_dir + targets_df = pd.read_table(targets_file) + + # read target values + tfr_pattern = '%s/tfrecords/%s-*.tfr' % (tfr_data_dir, options.split) + targets = read_tfr(tfr_pattern, options.target_length) + + # compute stats + target_means = np.mean(targets, axis=(0,1), dtype='float64') + target_max = np.max(targets, axis=(0,1)) + + # print statistics for each target + table_out = open('%s/table.txt' % options.out_dir, 'w') + for ti in range(targets.shape[2]): + cols = (ti, target_means[ti], target_max[ti], targets_df.identifier[ti], targets_df.description[ti]) + print('%-4d %8.3f %7.3f %16s %s' % cols, file=table_out) + table_out.close() + + # plot distributions for each target + distr_dir = '%s/distr' % options.out_dir + if not os.path.isdir(distr_dir): + os.mkdir(distr_dir) + + # initialize multiprocessing pool + pool = multiprocessing.Pool(options.processes) + + plot_distr_args = [] + for ti in range(targets.shape[2]): + targets_ti = np.random.choice(targets[:,:,ti].flatten(), size=10000, replace=False) + plot_distr_args.append((targets_ti, '%s/t%d.pdf' % (distr_dir,ti))) + + pool.starmap(plot_distr, 
plot_distr_args) + +def plot_distr(targets_ti, out_pdf): + plt.figure() + sns.distplot(targets_ti) + plt.savefig(out_pdf) + plt.close() + + +def read_tfr(tfr_pattern, target_len): + tfr_files = tfrecord_batcher.order_tfrecords(tfr_pattern) + if tfr_files: + dataset = tf.data.Dataset.list_files(tf.constant(tfr_files), shuffle=False) + else: + dataset = tf.data.Dataset.list_files(tfr_pattern) + dataset = dataset.flat_map(file_to_records) + dataset = dataset.batch(1) + dataset = dataset.map(parse_proto) + + iterator = dataset.make_one_shot_iterator() + + next_op = iterator.get_next() + + targets = [] + + with tf.Session() as sess: + next_datum = sess.run(next_op) + while next_datum: + targets1 = next_datum['targets'].reshape(target_len,-1) + targets.append(targets1) + + try: + next_datum = sess.run(next_op) + except tf.errors.OutOfRangeError: + next_datum = False + + return np.array(targets) + + +def file_to_records(filename): + return tf.data.TFRecordDataset(filename, compression_type='ZLIB') + +def parse_proto(example_protos): + features = { + 'sequence': tf.FixedLenFeature([], tf.string), + 'target': tf.FixedLenFeature([], tf.string) + } + parsed_features = tf.parse_example(example_protos, features=features) + seq = tf.decode_raw(parsed_features['sequence'], tf.uint8) + targets = tf.decode_raw(parsed_features['target'], tf.float16) + return {'sequence': seq, 'targets': targets} + + +################################################################################ +# __main__ +################################################################################ +if __name__ == '__main__': + main() From e0a04083b5b7bfb6a528ae94c1dca1c0bdf1180d Mon Sep 17 00:00:00 2001 From: David Kelley Date: Sat, 11 Aug 2018 18:06:41 -0700 Subject: [PATCH 31/71] nan baseline --- bin/basenji_data_read.py | 1 + 1 file changed, 1 insertion(+) diff --git a/bin/basenji_data_read.py b/bin/basenji_data_read.py index 9e91b98a..e3389fe4 100755 --- a/bin/basenji_data_read.py +++ b/bin/basenji_data_read.py @@ -95,6 +95,7 @@ def main(): # determine baseline coverage baseline_cov = np.percentile(seq_cov_nt, 10) + baseline_cov = np.nan_to_num(baseline_cov) # set blacklist to baseline if mseq.chr in black_chr_trees: From 777b0e5fb0667ab6e22767c8510d5f23a23a4185 Mon Sep 17 00:00:00 2001 From: David Kelley Date: Sat, 11 Aug 2018 18:06:58 -0700 Subject: [PATCH 32/71] seqs_per_tfr --- bin/h5_tfr.py | 25 +++++++++++++++++++------ 1 file changed, 19 insertions(+), 6 deletions(-) diff --git a/bin/h5_tfr.py b/bin/h5_tfr.py index cae82e0e..3c181f30 100755 --- a/bin/h5_tfr.py +++ b/bin/h5_tfr.py @@ -5,6 +5,7 @@ import os import h5py +import numpy as np import tensorflow as tf from basenji import tfrecord_util @@ -21,10 +22,12 @@ def main(): usage = 'usage: %prog [options] ' parser = OptionParser(usage) - parser.add_option('-p', dest='processes', default=16, type='int', - help='Number of parallel threads to use [Default: %default]') - parser.add_option('-s', dest='shards', default=16, type='int', - help='Number of sharded files to output per dataset [Default: %default]') + parser.add_option('-p', dest='processes', + default=16, type='int', + help='Number of parallel threads to use [Default: %default]') + parser.add_option('-s', dest='seqs_per_tfr', + default=256, type='int', + help='Sequences per TFRecord file [Default: %default]') (options,args) = parser.parse_args() if len(args) != 2: @@ -43,11 +46,21 @@ def main(): # initialize multiprocessing pool pool = multiprocessing.Pool(options.processes) + h5_open = h5py.File(h5_file) + for 
dataset in ['train', 'valid', 'test']: - tfr_files = ['%s/%s-%d.tfr' % (tfr_dir,dataset,si) for si in range(options.shards)] + # count sequences + data_in = h5_open['%s_in'%dataset] + num_seqs = data_in.shape[0] + + # shards + shards = int(np.ceil(num_seqs/options.seqs_per_tfr)) - writer_args = [(tfr_files[si], tf_opts, h5_file, dataset, si, options.shards) for si in range(options.shards)] + # args + tfr_files = ['%s/%s-%d.tfr' % (tfr_dir,dataset,si) for si in range(shards)] + writer_args = [(tfr_files[si], tf_opts, h5_file, dataset, si, shards) for si in range(shards)] + # computes pool.starmap(writer_worker, writer_args) From ab0ed1d6cef280307b7a56de115692f58cbbc8b5 Mon Sep 17 00:00:00 2001 From: David Kelley Date: Sun, 12 Aug 2018 21:14:27 -0700 Subject: [PATCH 33/71] shuffle bug --- bin/basenji_hdf5_cluster.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/bin/basenji_hdf5_cluster.py b/bin/basenji_hdf5_cluster.py index 5b43a453..0ff512ac 100755 --- a/bin/basenji_hdf5_cluster.py +++ b/bin/basenji_hdf5_cluster.py @@ -343,9 +343,9 @@ def main(): set(range(len(seqs_segments))) - set(valid_indexes) - set(test_indexes)) # training may require shuffling - random.shuffle(sorted(train_indexes)) - random.shuffle(sorted(valid_indexes)) - random.shuffle(sorted(test_indexes)) + random.shuffle(train_indexes) + random.shuffle(valid_indexes) + random.shuffle(test_indexes) # write to HDF5 hdf5_out = h5py.File(hdf5_file, 'w') From c86098f2f9e9f760001d7e2526492acffaaa9a25 Mon Sep 17 00:00:00 2001 From: David Kelley Date: Fri, 29 Jun 2018 18:01:47 -0700 Subject: [PATCH 34/71] failing with map_fn --- basenji/seqnn.py | 33 ++++++++++++++++++++++++--------- basenji/tfrecord_batcher.py | 21 ++++++++++++++------- 2 files changed, 38 insertions(+), 16 deletions(-) diff --git a/basenji/seqnn.py b/basenji/seqnn.py index c3f782f4..e62c9063 100755 --- a/basenji/seqnn.py +++ b/basenji/seqnn.py @@ -21,6 +21,7 @@ import numpy as np import tensorflow as tf +from basenji import augmentation from basenji import layers from basenji import params from basenji import seqnn_util @@ -43,6 +44,7 @@ def build(self, job, target_subset=None): def build_from_data_ops(self, job, data_ops, augment_rc=False, augment_shifts=[], + ensemble_rc=False, ensemble_shifts=[], target_subset=None): """Build training ops from input data ops.""" if not self.hparams_set: @@ -52,19 +54,32 @@ def build_from_data_ops(self, job, data_ops, self.inputs = data_ops['sequence'] self.targets_na = data_ops['na'] + # training data_ops w/ stochastic augmentation + data_ops_train, transform_repr_train = augmentation.augment_stochastic( + data_ops, augment_rc, augment_shifts) + + # eval data ops w/ deterministic augmentation + data_ops_eval, transform_repr_eval = augmentation.augment_deterministic( + data_ops, ensemble_rc, ensemble_shifts) + # training conditional self.is_training = tf.placeholder(tf.bool, name='is_training') - # active only via basenji_train_queues.py for TFRecords - if augment_rc or len(augment_shifts) > 0: - # augment data ops - data_ops_aug, _ = tfrecord_batcher.data_augmentation_from_data_ops( - data_ops, augment_rc, augment_shifts) - - # condition on training - data_ops = tf.cond(self.is_training, lambda: data_ops_aug, lambda: data_ops) + # condition on training + data_ops_list = tf.cond(self.is_training, + lambda: [data_ops_train], + lambda: data_ops_eval) + transform_repr_list = tf.cond(self.is_training, + lambda: [transform_repr_train], + lambda: transform_repr_eval) + + # compute representation for every input 
+ def repr_i(i): + return transform_repr_list[i](self.build_representation(data_ops_list[i])) + seqs_repr_list = tf.map_fn(repr_i, tf.range(len(data_ops_list))) + seqs_repr = tf.reduce_mean(seqs_repr_list) + # seqs_repr = self.build_representation(data_ops, target_subset) - seqs_repr = self.build_representation(data_ops, target_subset) self.loss_op, self.loss_adhoc = self.build_loss(seqs_repr, data_ops, target_subset) self.build_optimizer(self.loss_op) diff --git a/basenji/tfrecord_batcher.py b/basenji/tfrecord_batcher.py index e8bf062b..0826d105 100755 --- a/basenji/tfrecord_batcher.py +++ b/basenji/tfrecord_batcher.py @@ -57,7 +57,7 @@ def _shift_left(_seq): return output # TODO(dbelanger) change inputs to be (features, labels) like for Estimator. -def rc_data_augmentation(dataset): +def rc_data_augmentation(dataset, stochastic=False): """Apply reverse complement to seq and flip label/na along the time axis. Args: @@ -71,13 +71,20 @@ def rc_data_augmentation(dataset): """ seq, label, na = [dataset[k] for k in ['sequence', 'label', 'na']] - do_flip = tf.random_uniform(shape=[]) > 0.5 - seq, label, na = tf.cond(do_flip, lambda: ops.reverse_complement_transform(seq, label, na), - lambda: (seq, label, na)) + if stochastic: + do_flip = tf.random_uniform(shape=[]) > 0.5 + seq, label, na = tf.cond(do_flip, lambda: ops.reverse_complement_transform(seq, label, na), + lambda: (seq, label, na)) - def process_predictions_fn(predictions): - return tf.cond(do_flip, lambda: tf.reverse(predictions, axis=[1]), - lambda: predictions) + def process_predictions_fn(predictions): + return tf.cond(do_flip, lambda: tf.reverse(predictions, axis=[1]), + lambda: predictions) + + else: + seq, label, na = ops.reverse_complement_transform(seq, label, na) + + def process_predictions_fn(predictions): + return lambda: tf.reverse(predictions, axis=[1]) transformed_dataset = {'sequence': seq, 'label': label, 'na': na} return transformed_dataset, process_predictions_fn From 391abd40ff2eb00f185ee6b1185bf90ed4878fbb Mon Sep 17 00:00:00 2001 From: David Kelley Date: Sat, 30 Jun 2018 13:11:56 -0700 Subject: [PATCH 35/71] cond refuses transform_fn --- basenji/seqnn.py | 24 +++++++++++++++--------- bin/basenji_train_queues.py | 4 +++- 2 files changed, 18 insertions(+), 10 deletions(-) diff --git a/basenji/seqnn.py b/basenji/seqnn.py index e62c9063..8ec54eab 100755 --- a/basenji/seqnn.py +++ b/basenji/seqnn.py @@ -57,9 +57,11 @@ def build_from_data_ops(self, job, data_ops, # training data_ops w/ stochastic augmentation data_ops_train, transform_repr_train = augmentation.augment_stochastic( data_ops, augment_rc, augment_shifts) + data_ops_train = [data_ops_train] + transform_repr_train = [transform_repr_train] # eval data ops w/ deterministic augmentation - data_ops_eval, transform_repr_eval = augmentation.augment_deterministic( + data_ops_eval, transform_repr_eval = augmentation.augment_deterministic_set( data_ops, ensemble_rc, ensemble_shifts) # training conditional @@ -67,16 +69,16 @@ def build_from_data_ops(self, job, data_ops, # condition on training data_ops_list = tf.cond(self.is_training, - lambda: [data_ops_train], - lambda: data_ops_eval) + lambda: data_ops_train, + lambda: data_ops_eval, strict=False) transform_repr_list = tf.cond(self.is_training, - lambda: [transform_repr_train], - lambda: transform_repr_eval) + lambda: transform_repr_train, + lambda: transform_repr_eval, strict=False) # compute representation for every input - def repr_i(i): - return transform_repr_list[i](self.build_representation(data_ops_list[i])) 
- seqs_repr_list = tf.map_fn(repr_i, tf.range(len(data_ops_list))) + map_elems = (data_ops_list, transform_repr_list) + build_rep = lambda me: self.build_representation(me[0], me[1], target_subset) + seqs_repr_list = tf.map_fn(build_rep, map_elems) # back_prop=False seqs_repr = tf.reduce_mean(seqs_repr_list) # seqs_repr = self.build_representation(data_ops, target_subset) @@ -126,7 +128,7 @@ def _make_conv_block_args(self, layer_index): 'name': 'conv-%d' % layer_index } - def build_representation(self, data_ops, target_subset): + def build_representation(self, data_ops, transform_preds_fn=lambda x: x, target_subset=None): """Construct per-location real-valued predictions.""" inputs = data_ops['sequence'] assert inputs is not None @@ -213,6 +215,10 @@ def build_representation(self, data_ops, target_subset): (self.hp.batch_size, -1, self.hp.num_targets, self.hp.target_classes)) + + # transform for reverse complement + final_repr = transform_preds_fn(final_repr) + return final_repr def build_optimizer(self, loss_op): diff --git a/bin/basenji_train_queues.py b/bin/basenji_train_queues.py index 0a29fb1b..b3ca245d 100755 --- a/bin/basenji_train_queues.py +++ b/bin/basenji_train_queues.py @@ -57,7 +57,9 @@ def run(params_file, train_file, test_file, train_epochs, train_epoch_batches, # initialize model model = seqnn.SeqNN() - model.build_from_data_ops(job, data_ops, FLAGS.augment_rc, augment_shifts) + model.build_from_data_ops(job, data_ops, + FLAGS.augment_rc, augment_shifts, + FLAGS.ensemble_rc, ensemble_shifts) # checkpoints saver = tf.train.Saver() From 77bda7f05ac9d06bdafe2a731abf574864ef201f Mon Sep 17 00:00:00 2001 From: David Kelley Date: Sat, 30 Jun 2018 20:14:50 -0700 Subject: [PATCH 36/71] another failed attempt --- basenji/seqnn.py | 24 +++---- basenji/tfrecord_batcher.py | 137 ------------------------------------ 2 files changed, 10 insertions(+), 151 deletions(-) diff --git a/basenji/seqnn.py b/basenji/seqnn.py index 8ec54eab..e4c199e4 100755 --- a/basenji/seqnn.py +++ b/basenji/seqnn.py @@ -55,13 +55,11 @@ def build_from_data_ops(self, job, data_ops, self.targets_na = data_ops['na'] # training data_ops w/ stochastic augmentation - data_ops_train, transform_repr_train = augmentation.augment_stochastic( + data_ops_train = augmentation.augment_stochastic( data_ops, augment_rc, augment_shifts) - data_ops_train = [data_ops_train] - transform_repr_train = [transform_repr_train] # eval data ops w/ deterministic augmentation - data_ops_eval, transform_repr_eval = augmentation.augment_deterministic_set( + data_ops_eval = augmentation.augment_deterministic_set( data_ops, ensemble_rc, ensemble_shifts) # training conditional @@ -69,16 +67,12 @@ def build_from_data_ops(self, job, data_ops, # condition on training data_ops_list = tf.cond(self.is_training, - lambda: data_ops_train, - lambda: data_ops_eval, strict=False) - transform_repr_list = tf.cond(self.is_training, - lambda: transform_repr_train, - lambda: transform_repr_eval, strict=False) + lambda: [data_ops_train], + lambda: data_ops_eval, strict=True) # compute representation for every input - map_elems = (data_ops_list, transform_repr_list) - build_rep = lambda me: self.build_representation(me[0], me[1], target_subset) - seqs_repr_list = tf.map_fn(build_rep, map_elems) # back_prop=False + build_rep = lambda do: self.build_representation(do, target_subset) + seqs_repr_list = tf.map_fn(build_rep, data_ops_list) # back_prop=False seqs_repr = tf.reduce_mean(seqs_repr_list) # seqs_repr = self.build_representation(data_ops, target_subset) @@ 
-128,7 +122,7 @@ def _make_conv_block_args(self, layer_index): 'name': 'conv-%d' % layer_index } - def build_representation(self, data_ops, transform_preds_fn=lambda x: x, target_subset=None): + def build_representation(self, data_ops, target_subset=None): """Construct per-location real-valued predictions.""" inputs = data_ops['sequence'] assert inputs is not None @@ -217,7 +211,9 @@ def build_representation(self, data_ops, transform_preds_fn=lambda x: x, target_ # transform for reverse complement - final_repr = transform_preds_fn(final_repr) + final_repr = tf.cond(data_ops['reverse_preds'], + lambda: tf.reverse(final_repr, axis=1), + lambda: final_repr) return final_repr diff --git a/basenji/tfrecord_batcher.py b/basenji/tfrecord_batcher.py index 0826d105..cd22a14b 100755 --- a/basenji/tfrecord_batcher.py +++ b/basenji/tfrecord_batcher.py @@ -28,143 +28,6 @@ # datasets. NUM_FILES_TO_PARALLEL_INTERLEAVE = 10 -def shift_sequence(seq, shift_amount, pad_value): - """Shift a sequence left or right by shift_amount. - Args: - seq: a [batch_size, sequence_length, sequence_depth] sequence to shift - shift_amount: the signed amount to shift (tf.int32 or int) - pad_value: value to fill the padding (primitive or scalar tf.Tensor) - """ - if seq.shape.ndims != 3: - raise ValueError('input sequence should be rank 3') - input_shape = seq.shape - - pad = pad_value * tf.ones_like(seq[:, 0:tf.abs(shift_amount), :]) - - def _shift_right(_seq): - sliced_seq = _seq[:, :-shift_amount:, :] - return tf.concat([pad, sliced_seq], axis=1) - - def _shift_left(_seq): - sliced_seq = _seq[:, -shift_amount:, :] - return tf.concat([sliced_seq, pad], axis=1) - - output = tf.cond( - tf.greater(shift_amount, 0), lambda: _shift_right(seq), - lambda: _shift_left(seq)) - - output.set_shape(input_shape) - return output - -# TODO(dbelanger) change inputs to be (features, labels) like for Estimator. -def rc_data_augmentation(dataset, stochastic=False): - """Apply reverse complement to seq and flip label/na along the time axis. - - Args: - dataset: dict with keys 'sequence,' 'label,' and 'na.' - Returns - transformed_dataset: augmented data - process_predictions_fn: callable to be applied to predictions - such that they are directly comparable to the input dataset['label'] - rather than transformed_dataset['label']. Here, it flips the prediction - along the time axis. - """ - seq, label, na = [dataset[k] for k in ['sequence', 'label', 'na']] - - if stochastic: - do_flip = tf.random_uniform(shape=[]) > 0.5 - seq, label, na = tf.cond(do_flip, lambda: ops.reverse_complement_transform(seq, label, na), - lambda: (seq, label, na)) - - def process_predictions_fn(predictions): - return tf.cond(do_flip, lambda: tf.reverse(predictions, axis=[1]), - lambda: predictions) - - else: - seq, label, na = ops.reverse_complement_transform(seq, label, na) - - def process_predictions_fn(predictions): - return lambda: tf.reverse(predictions, axis=[1]) - - transformed_dataset = {'sequence': seq, 'label': label, 'na': na} - return transformed_dataset, process_predictions_fn - - -def shift_sequence_augmentation(seq, shift_augment_offsets, pad_value): - """Shift seq by a random amount. Pad to maintain the input size. - - Args: - seq: input sequence of size [batch_size, length, depth] - shift_augment_offsets: list of int offsets to sample from. If `None` or - `[]`, then only "shift" by 0 (the identity). - pad_value: value to fill the padding with. 
- Returns: - shifted and padded sequence of size [batch_size, length, depth] - """ - # The value of the parameter shift_augment_offsets are the set of things to - # _augment_ the original data with, and we want to, in addition to including - # those augmentations, actually include the original data. - total_set_of_shifts = [] - if shift_augment_offsets: - total_set_of_shifts += shift_augment_offsets - if 0 not in total_set_of_shifts: - total_set_of_shifts.append(0) - - shift_index = tf.random_uniform( - shape=[], minval=0, maxval=len(total_set_of_shifts), dtype=tf.int64) - shift_value = tf.gather(tf.constant(total_set_of_shifts), shift_index) - - seq = tf.cond( - tf.not_equal(shift_value, 0), - lambda: shift_sequence(seq, shift_value, pad_value), lambda: seq) - - return seq - - -def apply_data_augmentation(input_ops, label_ops, augment_with_complement, - shift_augment_offsets): - """Apply data augmentation to input and label ops. - Args: - input_ops: dict containing input Tensors. - label_ops: dict containing label Tensors. - augment_with_complement: whether to do reverse complement augmentation. - shift_augment_offsets: offsets used for doing shift-based augmentation. - Can be `None` or `[]` to indicate no shift-augmentation. - - Returns: - transformed_inputs: inputs with augmentation applied. - transformed_labels: labels transformed in accordance with the augmentation. - process_predictions_fn: callable to be applied to predictions - such that they are directly comparable to the label_ops - rather than transformed_labels. - """ - data_ops = {} - data_ops.update(input_ops) - data_ops.update(label_ops) - - augmented_data_ops, process_predictions_fn = data_augmentation_from_data_ops( - data_ops, augment_with_complement, shift_augment_offsets) - return ({ - 'sequence': augmented_data_ops['sequence'] - }, {name: augmented_data_ops[name] - for name in ['label', 'na']}, process_predictions_fn) - -# TODO(dbelanger) switch to directly calling apply_data_augmentation -def data_augmentation_from_data_ops(data_ops, augment_with_complement, - shift_augment_offsets): - process_predictions_fn = None - - if shift_augment_offsets and len(shift_augment_offsets) > 1: - pad_value = 0.25 - data_ops['sequence'] = shift_sequence_augmentation( - data_ops['sequence'], shift_augment_offsets, pad_value) - - if augment_with_complement: - data_ops, process_predictions_fn = rc_data_augmentation(data_ops) - - return data_ops, process_predictions_fn - - def tfrecord_dataset(tfr_data_files_pattern, batch_size, seq_length, From 57ee2c3453a1a8993cace3d8dbac84b0fddfcfa7 Mon Sep 17 00:00:00 2001 From: David Kelley Date: Sat, 30 Jun 2018 20:39:35 -0700 Subject: [PATCH 37/71] moving cond downstream but map_fn still fails --- basenji/seqnn.py | 24 ++++++++++++------------ 1 file changed, 12 insertions(+), 12 deletions(-) diff --git a/basenji/seqnn.py b/basenji/seqnn.py index e4c199e4..a1612bb4 100755 --- a/basenji/seqnn.py +++ b/basenji/seqnn.py @@ -54,6 +54,9 @@ def build_from_data_ops(self, job, data_ops, self.inputs = data_ops['sequence'] self.targets_na = data_ops['na'] + # training conditional + self.is_training = tf.placeholder(tf.bool, name='is_training') + # training data_ops w/ stochastic augmentation data_ops_train = augmentation.augment_stochastic( data_ops, augment_rc, augment_shifts) @@ -62,19 +65,17 @@ def build_from_data_ops(self, job, data_ops, data_ops_eval = augmentation.augment_deterministic_set( data_ops, ensemble_rc, ensemble_shifts) - # training conditional - self.is_training = tf.placeholder(tf.bool, 
name='is_training') + # compute train representation + seqs_repr_train = self.build_representation(data_ops_train, target_subset) - # condition on training - data_ops_list = tf.cond(self.is_training, - lambda: [data_ops_train], - lambda: data_ops_eval, strict=True) + pdb.set_trace() - # compute representation for every input + # compute eval representation build_rep = lambda do: self.build_representation(do, target_subset) - seqs_repr_list = tf.map_fn(build_rep, data_ops_list) # back_prop=False - seqs_repr = tf.reduce_mean(seqs_repr_list) - # seqs_repr = self.build_representation(data_ops, target_subset) + seqs_repr_list = tf.map_fn(build_rep, data_ops_eval, dtype=seqs_repr_train.dtype) # back_prop=False + seqs_repr_eval = tf.reduce_mean(seqs_repr_list) + + seqs_repr = tf.cond(self.is_training, lambda: seqs_repr_train, lambda: seqs_repr_eval) self.loss_op, self.loss_adhoc = self.build_loss(seqs_repr, data_ops, target_subset) self.build_optimizer(self.loss_op) @@ -209,10 +210,9 @@ def build_representation(self, data_ops, target_subset=None): (self.hp.batch_size, -1, self.hp.num_targets, self.hp.target_classes)) - # transform for reverse complement final_repr = tf.cond(data_ops['reverse_preds'], - lambda: tf.reverse(final_repr, axis=1), + lambda: tf.reverse(final_repr, axis=[1]), lambda: final_repr) return final_repr From d917be8199f849958edbec31e8d1e63c98201f5e Mon Sep 17 00:00:00 2001 From: David Kelley Date: Tue, 3 Jul 2018 13:18:04 -0700 Subject: [PATCH 38/71] separate train and eval loss --- basenji/seqnn.py | 147 ++++++++++++++++++++++-------------------- basenji/seqnn_util.py | 18 +++--- 2 files changed, 87 insertions(+), 78 deletions(-) diff --git a/basenji/seqnn.py b/basenji/seqnn.py index a1612bb4..3d604712 100755 --- a/basenji/seqnn.py +++ b/basenji/seqnn.py @@ -57,28 +57,44 @@ def build_from_data_ops(self, job, data_ops, # training conditional self.is_training = tf.placeholder(tf.bool, name='is_training') + ################################################## + # training + # training data_ops w/ stochastic augmentation data_ops_train = augmentation.augment_stochastic( data_ops, augment_rc, augment_shifts) + # compute train representation + seqs_repr_train = self.build_representation(data_ops_train['sequence'], + None, target_subset) + + # training losses + loss_returns = self.build_loss(seqs_repr_train, data_ops_train, target_subset) + self.loss_train, self.loss_train_targets, self.preds_train, self.targets_train = loss_returns + + # optimizer + self.update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS) + self.build_optimizer(self.loss_train) + + ################################################## + # eval + # eval data ops w/ deterministic augmentation data_ops_eval = augmentation.augment_deterministic_set( data_ops, ensemble_rc, ensemble_shifts) - - # compute train representation - seqs_repr_train = self.build_representation(data_ops_train, target_subset) - - pdb.set_trace() + data_seq_eval = tf.stack([do['sequence'] for do in data_ops_eval]) + data_rev_eval = tf.stack([do['reverse_preds'] for do in data_ops_eval]) # compute eval representation - build_rep = lambda do: self.build_representation(do, target_subset) - seqs_repr_list = tf.map_fn(build_rep, data_ops_eval, dtype=seqs_repr_train.dtype) # back_prop=False - seqs_repr_eval = tf.reduce_mean(seqs_repr_list) + map_elems_eval = (data_seq_eval, data_rev_eval) + build_rep = lambda do: self.build_representation(do[0], do[1], target_subset) + seqs_repr_list = tf.map_fn(build_rep, map_elems_eval, dtype=seqs_repr_train.dtype) # 
back_prop=False + seqs_repr_eval = tf.reduce_mean(seqs_repr_list, axis=0) - seqs_repr = tf.cond(self.is_training, lambda: seqs_repr_train, lambda: seqs_repr_eval) + # eval loss + loss_returns = self.build_loss(seqs_repr_eval, data_ops, target_subset) + self.loss_eval, self.loss_eval_targets, self.preds_eval, self.targets_eval = loss_returns - self.loss_op, self.loss_adhoc = self.build_loss(seqs_repr, data_ops, target_subset) - self.build_optimizer(self.loss_op) def make_placeholders(self): """Allocates placeholders to be used in place of input data ops.""" @@ -109,7 +125,7 @@ def make_placeholders(self): } return data - def _make_conv_block_args(self, layer_index): + def _make_conv_block_args(self, layer_index, layer_reprs): """Packages arguments to be used by layers.conv_block.""" return { 'conv_params': self.hp.cnn_params[layer_index], @@ -119,37 +135,39 @@ def _make_conv_block_args(self, layer_index): 'batch_renorm': self.hp.batch_renorm, 'batch_renorm_momentum': self.hp.batch_renorm_momentum, 'l2_scale': self.hp.cnn_l2_scale, - 'layer_reprs': self.layer_reprs, + 'layer_reprs': layer_reprs, 'name': 'conv-%d' % layer_index } - def build_representation(self, data_ops, target_subset=None): + def build_representation(self, inputs, reverse_preds=None, target_subset=None): """Construct per-location real-valued predictions.""" - inputs = data_ops['sequence'] assert inputs is not None - print('Targets pooled by %d to length %d' % (self.hp.target_pool, self.hp.seq_length // self.hp.target_pool)) ################################################### # convolution layers ################################################### - self.filter_weights = [] - self.layer_reprs = [inputs] + filter_weights = [] + layer_reprs = [inputs] seqs_repr = inputs for layer_index in range(self.hp.cnn_layers): - with tf.variable_scope('cnn%d' % layer_index): + with tf.variable_scope('cnn%d' % layer_index, reuse=tf.AUTO_REUSE): # convolution block - args_for_block = self._make_conv_block_args(layer_index) + args_for_block = self._make_conv_block_args(layer_index, layer_reprs) seqs_repr = layers.conv_block(seqs_repr=seqs_repr, **args_for_block) # save representation - self.layer_reprs.append(seqs_repr) + layer_reprs.append(seqs_repr) # final nonlinearity seqs_repr = tf.nn.relu(seqs_repr) + ################################################### + # slice out side buffer + ################################################### + # update batch buffer to reflect pooling seq_length = seqs_repr.shape[1].value pool_preds = self.hp.seq_length // seq_length @@ -158,26 +176,18 @@ def build_representation(self, data_ops, target_subset=None): ' by the CNN pooling %d') % (self.hp.batch_buffer, pool_preds) batch_buffer_pool = self.hp.batch_buffer // pool_preds - - ################################################### - # slice out side buffer - ################################################### - - # predictions + # slice out buffer seq_length = seqs_repr.shape[1] seqs_repr = seqs_repr[:, batch_buffer_pool: seq_length - batch_buffer_pool, :] - seq_length = seqs_repr.shape[1].value - self.preds_length = seq_length # save penultimate representation - self.penultimate_op = seqs_repr - + # self.penultimate_op = seqs_repr ################################################### # final layer ################################################### - with tf.variable_scope('final'): + with tf.variable_scope('final', reuse=tf.AUTO_REUSE): final_filters = self.hp.num_targets * self.hp.target_classes final_repr = tf.layers.dense( inputs=seqs_repr, @@ -211,9 
+221,10 @@ def build_representation(self, data_ops, target_subset=None): self.hp.target_classes)) # transform for reverse complement - final_repr = tf.cond(data_ops['reverse_preds'], - lambda: tf.reverse(final_repr, axis=[1]), - lambda: final_repr) + if reverse_preds is not None: + final_repr = tf.cond(reverse_preds, + lambda: tf.reverse(final_repr, axis=[1]), + lambda: final_repr) return final_repr @@ -266,8 +277,6 @@ def build_optimizer(self, loss_op): self.step_op = self.opt.apply_gradients( self.gvs, global_step=self.global_step) - self.update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS) - # summary self.merged_summary = tf.summary.merge_all() @@ -278,7 +287,6 @@ def build_loss(self, seqs_repr, data_ops, target_subset=None): # targets tstart = self.hp.batch_buffer // self.hp.target_pool tend = (self.hp.seq_length - self.hp.batch_buffer) // self.hp.target_pool - self.target_length = tend - tstart targets = data_ops['label'] targets = tf.identity(targets[:, tstart:tend, :], name='targets_op') @@ -287,8 +295,8 @@ def build_loss(self, seqs_repr, data_ops, target_subset=None): targets = tf.gather(targets, target_subset, axis=2) # work-around for specifying my own predictions - self.preds_adhoc = tf.placeholder( - tf.float32, shape=seqs_repr.shape, name='preds-adhoc') + # self.preds_adhoc = tf.placeholder( + # tf.float32, shape=seqs_repr.shape, name='preds-adhoc') # float 32 exponential clip max # exp_max = np.floor(np.log(0.5*tf.float32.max)) @@ -296,17 +304,17 @@ def build_loss(self, seqs_repr, data_ops, target_subset=None): # choose link if self.hp.link in ['identity', 'linear']: - self.preds_op = tf.identity(seqs_repr, name='preds') + preds_op = tf.identity(seqs_repr, name='preds') elif self.hp.link == 'relu': - self.preds_op = tf.relu(seqs_repr, name='preds') + preds_op = tf.relu(seqs_repr, name='preds') elif self.hp.link == 'exp': seqs_repr_clip = tf.clip_by_value(seqs_repr, -exp_max, exp_max) - self.preds_op = tf.exp(seqs_repr_clip, name='preds') + preds_op = tf.exp(seqs_repr_clip, name='preds') elif self.hp.link == 'exp_linear': - self.preds_op = tf.where( + preds_op = tf.where( seqs_repr > 0, seqs_repr + 1, tf.exp(tf.clip_by_value(seqs_repr, -exp_max, exp_max)), @@ -314,7 +322,7 @@ def build_loss(self, seqs_repr, data_ops, target_subset=None): elif self.hp.link == 'softplus': seqs_repr_clip = tf.clip_by_value(seqs_repr, -exp_max, 10000) - self.preds_op = tf.nn.softplus(seqs_repr_clip, name='preds') + preds_op = tf.nn.softplus(seqs_repr_clip, name='preds') elif self.hp.link == 'softmax': # performed in the loss function, but saving probabilities @@ -326,33 +334,33 @@ def build_loss(self, seqs_repr, data_ops, target_subset=None): # clip if self.hp.target_clip is not None: - self.preds_op = tf.clip_by_value(self.preds_op, 0, self.hp.target_clip) + preds_op = tf.clip_by_value(preds_op, 0, self.hp.target_clip) targets = tf.clip_by_value(targets, 0, self.hp.target_clip) # sqrt if self.hp.target_sqrt: - self.preds_op = tf.sqrt(self.preds_op) + preds_op = tf.sqrt(preds_op) targets = tf.sqrt(targets) loss_op = None - loss_adhoc = None + # loss_adhoc = None # choose loss if self.hp.loss == 'gaussian': - loss_op = tf.squared_difference(self.preds_op, targets) - loss_adhoc = tf.squared_difference(self.preds_adhoc, targets) + loss_op = tf.squared_difference(preds_op, targets) + # loss_adhoc = tf.squared_difference(self.preds_adhoc, targets) elif self.hp.loss == 'poisson': loss_op = tf.nn.log_poisson_loss( - targets, tf.log(self.preds_op), compute_full_loss=True) - loss_adhoc = 
tf.nn.log_poisson_loss( - targets, tf.log(self.preds_adhoc), compute_full_loss=True) + targets, tf.log(preds_op), compute_full_loss=True) + # loss_adhoc = tf.nn.log_poisson_loss( + # targets, tf.log(self.preds_adhoc), compute_full_loss=True) elif self.hp.loss == 'cross_entropy': loss_op = tf.nn.sparse_softmax_cross_entropy_with_logits( - labels=(targets - 1), logits=self.preds_op) - loss_adhoc = tf.nn.sparse_softmax_cross_entropy_with_logits( - labels=(targets - 1), logits=self.preds_adhoc) + labels=(targets - 1), logits=preds_op) + # loss_adhoc = tf.nn.sparse_softmax_cross_entropy_with_logits( + # labels=(targets - 1), logits=self.preds_adhoc) else: raise ValueError('Cannot identify loss function %s' % self.hp.loss) @@ -361,29 +369,29 @@ def build_loss(self, seqs_repr, data_ops, target_subset=None): loss_op = tf.reduce_mean(loss_op, axis=[0, 1], name='target_loss') loss_op = tf.check_numerics(loss_op, 'Invalid loss', name='loss_check') - loss_adhoc = tf.reduce_mean( - loss_adhoc, axis=[0, 1], name='target_loss_adhoc') + # loss_adhoc = tf.reduce_mean( + # loss_adhoc, axis=[0, 1], name='target_loss_adhoc') tf.summary.histogram('target_loss', loss_op) for ti in np.linspace(0, self.hp.num_targets - 1, 10).astype('int'): tf.summary.scalar('loss_t%d' % ti, loss_op[ti]) - self.target_losses = loss_op - self.target_losses_adhoc = loss_adhoc + target_losses = loss_op + # self.target_losses_adhoc = loss_adhoc # fully reduce loss_op = tf.reduce_mean(loss_op, name='loss') - loss_adhoc = tf.reduce_mean(loss_adhoc, name='loss_adhoc') + # loss_adhoc = tf.reduce_mean(loss_adhoc, name='loss_adhoc') # add regularization terms reg_losses = tf.get_collection(tf.GraphKeys.REGULARIZATION_LOSSES) reg_sum = tf.reduce_sum(reg_losses) tf.summary.scalar('regularizers', reg_sum) loss_op += reg_sum - loss_adhoc += reg_sum + # loss_adhoc += reg_sum # track tf.summary.scalar('loss', loss_op) - self.targets_op = targets - return loss_op, loss_adhoc + + return loss_op, target_losses, preds_op, targets def set_mode(self, mode): @@ -438,12 +446,12 @@ def train_epoch(self, fd[self.targets_na] = NAb if no_steps: - run_returns = sess.run([self.merged_summary, self.loss_op] + \ + run_returns = sess.run([self.merged_summary, self.loss_train] + \ self.update_ops, feed_dict=fd) summary, loss_batch = run_returns[:2] else: run_returns = sess.run( - [self.merged_summary, self.loss_op, self.global_step, self.step_op] + self.update_ops, + [self.merged_summary, self.loss_train, self.global_step, self.step_op] + self.update_ops, feed_dict=fd) summary, loss_batch, global_step = run_returns[:3] @@ -482,10 +490,11 @@ def train_epoch_from_data_ops(self, data_available = True batch_num = 0 while data_available and (epoch_batches is None or batch_num < epoch_batches): + print(batch_num) try: - run_returns = sess.run( - [self.merged_summary, self.loss_op, self.global_step, self.step_op] + self.update_ops, - feed_dict=fd) + # update_ops won't run + run_ops = [self.merged_summary, self.loss_train, self.global_step, self.step_op] + self.update_ops + run_returns = sess.run(run_ops, feed_dict=fd) summary, loss_batch, global_step = run_returns[:3] # add summary diff --git a/basenji/seqnn_util.py b/basenji/seqnn_util.py index 3d2c8110..76c06b73 100644 --- a/basenji/seqnn_util.py +++ b/basenji/seqnn_util.py @@ -27,7 +27,7 @@ def build_grads(self, layers=[0]): self.grad_ops = [] for ti in range(self.hp.num_targets): - grad_ti_op = tf.gradients(self.preds_op[:,:,ti], [self.layer_reprs[li] for li in self.grad_layers]) + grad_ti_op = 
tf.gradients(self.preds_eval[:,:,ti], [self.layer_reprs[li] for li in self.grad_layers])
       self.grad_ops.append(grad_ti_op)
 
@@ -60,7 +60,7 @@ def build_grads_genes(self, gene_seqs, layers=[0]):
         if pi in tss_pos:
           # build position-specific, target-specific gradient ops
           for ti in range(self.hp.num_targets):
-            grad_piti_op = tf.gradients(self.preds_op[:,pi,ti],
+            grad_piti_op = tf.gradients(self.preds_eval[:,pi,ti],
                                 [self.layer_reprs[li] for li in self.grad_layers])
             self.grad_pos_ops[-1].append(grad_piti_op)
 
@@ -312,7 +312,7 @@ def _gradients_ensemble(self, sess, fd, Xb, ensemble_fwdrc, ensemble_shifts, mc_
 
       # prediction
       # predict
-      preds_ei, layer_reprs_ei = sess.run([self.preds_op, self.layer_reprs], feed_dict=fd)
+      preds_ei, layer_reprs_ei = sess.run([self.preds_eval, self.layer_reprs], feed_dict=fd)
 
       # reverse
       if ensemble_fwdrc[ei] is False:
@@ -451,7 +451,7 @@ def gradients_genes(self, sess, batcher, gene_seqs):
       fd[self.inputs] = Xb
 
       # predict
-      reprs_batch, _ = sess.run([self.layer_reprs, self.preds_op], feed_dict=fd)
+      reprs_batch, _ = sess.run([self.layer_reprs, self.preds_eval], feed_dict=fd)
 
       # save representations
       for lii in range(len(self.grad_layers)):
@@ -511,7 +511,7 @@ def hidden(self, sess, batcher, layers=None):
 
       # compute predictions
       layer_reprs_batch, preds_batch = sess.run(
-          [self.layer_reprs, self.preds_op], feed_dict=fd)
+          [self.layer_reprs, self.preds_eval], feed_dict=fd)
 
       # accumulate representations
       for li in layers:
@@ -599,7 +599,7 @@ def _predict_ensemble(self,
         if penultimate:
           preds_ei = sess.run(self.penultimate_op, feed_dict=fd)
         else:
-          preds_ei = sess.run(self.preds_op, feed_dict=fd)
+          preds_ei = sess.run(self.preds_eval, feed_dict=fd)
 
         # reverse
         if ensemble_fwdrc[ei] is False:
@@ -866,8 +866,8 @@ def test_from_data_ops(self, sess, test_batches=None):
     while data_available and (test_batches is None or batch_num < test_batches):
       try:
         # make non-ensembled predictions
-        run_ops = [self.targets_op, self.preds_op,
-                   self.loss_op, self.target_losses]
+        run_ops = [self.targets_eval, self.preds_eval,
+                   self.loss_eval, self.loss_eval_targets]
         run_returns = sess.run(run_ops, feed_dict=fd)
         targets_batch, preds_batch, loss_batch, target_losses_batch = run_returns
 
@@ -974,7 +974,7 @@ def test(self,
           # recompute loss w/ ensembled prediction
           fd[self.preds_adhoc] = preds_batch
           targets_batch, loss_batch, target_losses_batch = sess.run(
-              [self.targets_op, self.loss_adhoc, self.target_losses_adhoc],
+              [self.targets_train, self.loss_adhoc, self.target_losses_adhoc],
              feed_dict=fd)
 
           # accumulate predictions and targets
 

From 17fd1ce64ed2beb354c19bfcccc203030de85301 Mon Sep 17 00:00:00 2001
From: David Kelley
Date: Tue, 3 Jul 2018 18:33:27 -0700
Subject: [PATCH 39/71] needs preds_length

---
 basenji/seqnn.py | 4 +++-
 1 file changed, 3 insertions(+), 1 deletion(-)

diff --git a/basenji/seqnn.py b/basenji/seqnn.py
index 3d604712..e33c9372 100755
--- a/basenji/seqnn.py
+++ b/basenji/seqnn.py
@@ -95,6 +95,9 @@ def build_from_data_ops(self, job, data_ops,
     loss_returns = self.build_loss(seqs_repr_eval, data_ops, target_subset)
     self.loss_eval, self.loss_eval_targets, self.preds_eval, self.targets_eval = loss_returns
 
+    # helper variables
+    self.preds_length = self.preds_train.shape[1]
+
 
   def make_placeholders(self):
     """Allocates placeholders to be used in place of input data ops."""
@@ -490,7 +493,6 @@ def train_epoch_from_data_ops(self,
     data_available = True
     batch_num = 0
     while data_available and (epoch_batches is None or batch_num < epoch_batches):
-      
print(batch_num) try: # update_ops won't run run_ops = [self.merged_summary, self.loss_train, self.global_step, self.step_op] + self.update_ops From 96c01d6bab7aa2a18c8cf34c0079022731dc1943 Mon Sep 17 00:00:00 2001 From: David Kelley Date: Sat, 7 Jul 2018 10:46:15 -0700 Subject: [PATCH 40/71] augmentation methods --- basenji/augmentation.py | 164 ++++++++++++++++++++++++++++++++++++++++ 1 file changed, 164 insertions(+) create mode 100644 basenji/augmentation.py diff --git a/basenji/augmentation.py b/basenji/augmentation.py new file mode 100644 index 00000000..340c313e --- /dev/null +++ b/basenji/augmentation.py @@ -0,0 +1,164 @@ +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at + +# https://www.apache.org/licenses/LICENSE-2.0 + +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ========================================================================= + +import pdb +import tensorflow as tf + +from basenji import ops + +def shift_sequence(seq, shift_amount, pad_value=0.25): + """Shift a sequence left or right by shift_amount. + + Args: + seq: a [batch_size, sequence_length, sequence_depth] sequence to shift + shift_amount: the signed amount to shift (tf.int32 or int) + pad_value: value to fill the padding (primitive or scalar tf.Tensor) + """ + if seq.shape.ndims != 3: + raise ValueError('input sequence should be rank 3') + input_shape = seq.shape + + pad = pad_value * tf.ones_like(seq[:, 0:tf.abs(shift_amount), :]) + + def _shift_right(_seq): + sliced_seq = _seq[:, :-shift_amount:, :] + return tf.concat([pad, sliced_seq], axis=1) + + def _shift_left(_seq): + sliced_seq = _seq[:, -shift_amount:, :] + return tf.concat([sliced_seq, pad], axis=1) + + output = tf.cond( + tf.greater(shift_amount, 0), lambda: _shift_right(seq), + lambda: _shift_left(seq)) + + output.set_shape(input_shape) + return output + +def augment_deterministic_set(data_ops, augment_rc=False, augment_shifts=[0]): + """ + + Args: + data_ops: dict with keys 'sequence,' 'label,' and 'na.' + augment_rc: Boolean + augment_shifts: List of ints. + Returns + data_ops_list: + """ + augment_pairs = [] + for ashift in augment_shifts: + augment_pairs.append((False, ashift)) + if augment_rc: + augment_pairs.append((True, ashift)) + + data_ops_list = [] + for arc, ashift in augment_pairs: + data_ops_aug = augment_deterministic(data_ops, arc, ashift) + data_ops_list.append(data_ops_aug) + + return data_ops_list + + +def augment_deterministic(data_ops, augment_rc=False, augment_shift=0): + """Apply a deterministic augmentation, specified by the parameters. + + Args: + data_ops: dict with keys 'sequence,' 'label,' and 'na.' + augment_rc: Boolean + augment_shifts: Int + Returns + data_ops: augmented data + """ + if augment_shift != 0: + shift_amount = tf.constant(augment_shift, shape=(), dtype=tf.int64) + data_ops['sequence'] = shift_sequence(data_ops['sequence'], shift_amount) + + if augment_rc: + data_ops = augment_deterministic_rc(data_ops) + else: + data_ops['reverse_preds'] = tf.zeros((), dtype=tf.bool) + + return data_ops + + +def augment_deterministic_rc(data_ops): + """Apply a deterministic reverse complement augmentation. 
+ + Args: + data_ops: dict with keys 'sequence,' 'label,' and 'na.' + Returns + data_ops_aug: augmented data ops + """ + seq, label, na = [data_ops[k] for k in ['sequence', 'label', 'na']] + seq, label, na = ops.reverse_complement_transform(seq, label, na) + reverse_preds = tf.ones((), dtype=tf.bool) + data_ops_aug = {'sequence': seq, 'label': label, 'na': na, 'reverse_preds':reverse_preds} + return data_ops_aug + + +def augment_stochastic_rc(data_ops): + """Apply a stochastic reverse complement augmentation. + + Args: + data_ops: dict with keys 'sequence,' 'label,' and 'na.' + Returns + data_ops_aug: augmented data + """ + seq, label, na = [data_ops[k] for k in ['sequence', 'label', 'na']] + reverse_preds = tf.random_uniform(shape=[]) > 0.5 + seq, label, na = tf.cond(reverse_preds, lambda: ops.reverse_complement_transform(seq, label, na), + lambda: (seq, label, na)) + data_ops_aug = {'sequence': seq, 'label': label, 'na': na, 'reverse_preds':reverse_preds} + return data_ops_aug + + +def augment_stochastic_shifts(seq, augment_shifts): + """Apply a stochastic shift augmentation. + + Args: + seq: input sequence of size [batch_size, length, depth] + augment_shifts: list of int offsets to sample from + Returns: + shifted and padded sequence of size [batch_size, length, depth] + """ + shift_index = tf.random_uniform(shape=[], minval=0, + maxval=len(augment_shifts), dtype=tf.int64) + shift_value = tf.gather(tf.constant(augment_shifts), shift_index) + + seq = tf.cond(tf.not_equal(shift_value, 0), + lambda: shift_sequence(seq, shift_value), + lambda: seq) + + return seq + + +def augment_stochastic(data_ops, augment_rc=False, augment_shifts=[]): + """Apply stochastic augmentations, + + Args: + data_ops: dict with keys 'sequence,' 'label,' and 'na.' + augment_rc: Boolean for whether to apply reverse complement augmentation. + augment_shifts: list of int offsets to sample shift augmentations. 
+ Returns: + data_ops_aug: augmented data + """ + if augment_shifts: + data_ops['sequence'] = augment_stochastic_shifts(data_ops['sequence'], + augment_shifts) + + if augment_rc: + data_ops = augment_stochastic_rc(data_ops) + else: + data_ops['reverse_preds'] = tf.zeros((), dtype=tf.bool) + + return data_ops From 1770abdb5ffaa07350e87da9a9ea996fe7f3ad16 Mon Sep 17 00:00:00 2001 From: David Kelley Date: Sat, 7 Jul 2018 10:48:57 -0700 Subject: [PATCH 41/71] h5 in-graph augmentation --- basenji/seqnn.py | 97 ++++++++++++---- bin/basenji_train_h5.py | 216 ++++++++++++++++++++++++++++++++++++ bin/basenji_train_queues.py | 2 +- 3 files changed, 291 insertions(+), 24 deletions(-) create mode 100755 bin/basenji_train_h5.py diff --git a/basenji/seqnn.py b/basenji/seqnn.py index e33c9372..a6f8471f 100755 --- a/basenji/seqnn.py +++ b/basenji/seqnn.py @@ -33,14 +33,20 @@ def __init__(self): self.global_step = tf.train.get_or_create_global_step() self.hparams_set = False - def build(self, job, target_subset=None): + def build(self, job, augment_rc=False, augment_shifts=[], + ensemble_rc=False, ensemble_shifts=[], target_subset=None): """Build training ops that depend on placeholders.""" self.hp = params.make_hparams(job) self.hparams_set = True data_ops = self.make_placeholders() - self.build_from_data_ops(job, data_ops, target_subset=target_subset) + self.build_from_data_ops(job, data_ops, + augment_rc=augment_rc, + augment_shifts=augment_shifts, + ensemble_rc=ensemble_rc, + ensemble_shifts=ensemble_shifts, + target_subset=target_subset) def build_from_data_ops(self, job, data_ops, augment_rc=False, augment_shifts=[], @@ -50,9 +56,6 @@ def build_from_data_ops(self, job, data_ops, if not self.hparams_set: self.hp = params.make_hparams(job) self.hparams_set = True - self.targets = data_ops['label'] - self.inputs = data_ops['sequence'] - self.targets_na = data_ops['na'] # training conditional self.is_training = tf.placeholder(tf.bool, name='is_training') @@ -102,29 +105,26 @@ def build_from_data_ops(self, job, data_ops, def make_placeholders(self): """Allocates placeholders to be used in place of input data ops.""" # batches - self.inputs = tf.placeholder( + self.inputs_ph = tf.placeholder( tf.float32, shape=(self.hp.batch_size, self.hp.seq_length, self.hp.seq_depth), name='inputs') if self.hp.target_classes == 1: - self.targets = tf.placeholder( + self.targets_ph = tf.placeholder( tf.float32, shape=(self.hp.batch_size, self.hp.seq_length // self.hp.target_pool, self.hp.num_targets), name='targets') else: - self.targets = tf.placeholder( + self.targets_ph = tf.placeholder( tf.int32, shape=(self.hp.batch_size, self.hp.seq_length // self.hp.target_pool, self.hp.num_targets), name='targets') - self.targets_na = tf.placeholder( - tf.bool, shape=(self.hp.batch_size, self.hp.seq_length // self.hp.target_pool)) data = { - 'sequence': self.inputs, - 'label': self.targets, - 'na': self.targets_na + 'sequence': self.inputs_ph, + 'label': self.targets_ph } return data @@ -419,7 +419,7 @@ def set_mode(self, mode): return fd - def train_epoch(self, + def train_epoch_h5_manual(self, sess, batcher, fwdrc=True, @@ -427,7 +427,8 @@ def train_epoch(self, sum_writer=None, epoch_batches=None, no_steps=False): - """Execute one training epoch.""" + """Execute one training epoch, using HDF5 data + and manual augmentation.""" # initialize training loss train_loss = [] @@ -444,9 +445,8 @@ def train_epoch(self, epoch_batches is None or batch_num < epoch_batches): # update feed dict - fd[self.inputs] = Xb - fd[self.targets] = Yb - 
fd[self.targets_na] = NAb + fd[self.inputs_ph] = Xb + fd[self.targets_ph] = Yb if no_steps: run_returns = sess.run([self.merged_summary, self.loss_train] + \ @@ -477,11 +477,62 @@ def train_epoch(self, return np.mean(train_loss), global_step - def train_epoch_from_data_ops(self, - sess, - sum_writer=None, - epoch_batches=None): - """ Execute one training epoch """ + def train_epoch_h5(self, + sess, + batcher, + sum_writer=None, + epoch_batches=None, + no_steps=False): + """Execute one training epoch using HDF5 data, + and compute-graph augmentation""" + + # initialize training loss + train_loss = [] + global_step = 0 + + # setup feed dict + fd = self.set_mode('train') + + # get first batch + Xb, Yb, NAb, Nb = batcher.next() + + batch_num = 0 + while Xb is not None and Nb == self.hp.batch_size and ( + epoch_batches is None or batch_num < epoch_batches): + + # update feed dict + fd[self.inputs_ph] = Xb + fd[self.targets_ph] = Yb + + if no_steps: + run_returns = sess.run([self.merged_summary, self.loss_train] + \ + self.update_ops, feed_dict=fd) + summary, loss_batch = run_returns[:2] + else: + run_ops = [self.merged_summary, self.loss_train, self.global_step, self.step_op] + run_ops += self.update_ops + summary, loss_batch, global_step = sess.run(run_ops, feed_dict=fd) + + # add summary + if sum_writer is not None: + sum_writer.add_summary(summary, global_step) + + # accumulate loss + train_loss.append(loss_batch) + + # next batch + Xb, Yb, NAb, Nb = batcher.next(fwdrc, shift) + batch_num += 1 + + # reset training batcher if epoch considered all of the data + if epoch_batches is None: + batcher.reset() + + return np.mean(train_loss), global_step + + + def train_epoch_tfr(self, sess, sum_writer=None, epoch_batches=None): + """ Execute one training epoch, using TFRecords data. """ # initialize training loss train_loss = [] diff --git a/bin/basenji_train_h5.py b/bin/basenji_train_h5.py new file mode 100755 index 00000000..6e09f9e3 --- /dev/null +++ b/bin/basenji_train_h5.py @@ -0,0 +1,216 @@ +#!/usr/bin/env python +# Copyright 2017 Calico LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# ========================================================================= +from __future__ import print_function + +import sys +import time + +import h5py +import numpy as np +import tensorflow as tf + +from basenji import batcher +from basenji import params +from basenji import seqnn +from basenji import shared_flags + +FLAGS = tf.app.flags.FLAGS + +################################################################################ +# main +################################################################################ +def main(_): + np.random.seed(FLAGS.seed) + + run(params_file=FLAGS.params, + data_file=FLAGS.data, + train_epochs=FLAGS.train_epochs, + train_epoch_batches=FLAGS.train_epoch_batches, + test_epoch_batches=FLAGS.test_epoch_batches) + + +def run(params_file, data_file, train_epochs, train_epoch_batches, test_epoch_batches): + + ####################################################### + # load data + ####################################################### + data_open = h5py.File(data_file) + + train_seqs = data_open['train_in'] + train_targets = data_open['train_out'] + train_na = None + if 'train_na' in data_open: + train_na = data_open['train_na'] + + valid_seqs = data_open['valid_in'] + valid_targets = data_open['valid_out'] + valid_na = None + if 'valid_na' in data_open: + valid_na = data_open['valid_na'] + + ####################################################### + # model parameters and placeholders + ####################################################### + job = params.read_job_params(params_file) + + job['seq_length'] = train_seqs.shape[1] + job['seq_depth'] = train_seqs.shape[2] + job['num_targets'] = train_targets.shape[2] + job['target_pool'] = int(np.array(data_open.get('pool_width', 1))) + + augment_shifts = [int(shift) for shift in FLAGS.augment_shifts.split(',')] + ensemble_shifts = [int(shift) for shift in FLAGS.ensemble_shifts.split(',')] + + t0 = time.time() + model = seqnn.SeqNN() + model.build(job, augment_rc=FLAGS.augment_rc, augment_shifts=augment_shifts, + ensemble_rc=FLAGS.ensemble_rc, ensemble_shifts=ensemble_shifts) + print('Model building time %f' % (time.time() - t0)) + + # adjust for fourier + job['fourier'] = 'train_out_imag' in data_open + if job['fourier']: + train_targets_imag = data_open['train_out_imag'] + valid_targets_imag = data_open['valid_out_imag'] + + ####################################################### + # prepare batcher + ####################################################### + if job['fourier']: + batcher_train = batcher.BatcherF( + train_seqs, + train_targets, + train_targets_imag, + train_na, + model.hp.batch_size, + model.hp.target_pool, + shuffle=True) + batcher_valid = batcher.BatcherF(valid_seqs, valid_targets, + valid_targets_imag, valid_na, + model.batch_size, model.target_pool) + else: + batcher_train = batcher.Batcher( + train_seqs, + train_targets, + train_na, + model.hp.batch_size, + model.hp.target_pool, + shuffle=True) + batcher_valid = batcher.Batcher(valid_seqs, valid_targets, valid_na, + model.hp.batch_size, model.hp.target_pool) + print('Batcher initialized') + + ####################################################### + # train + ####################################################### + + # checkpoints + saver = tf.train.Saver() + + config = tf.ConfigProto() + if FLAGS.log_device_placement: + config.log_device_placement = True + with tf.Session(config=config) as sess: + t0 = time.time() + + # set seed + tf.set_random_seed(FLAGS.seed) + + if FLAGS.logdir: + train_writer = tf.summary.FileWriter(FLAGS.logdir + 
'/train', sess.graph) + else: + train_writer = None + + if FLAGS.restart: + # load variables into session + saver.restore(sess, FLAGS.restart) + else: + # initialize variables + print('Initializing...') + sess.run(tf.global_variables_initializer()) + print('Initialization time %f' % (time.time() - t0)) + + train_loss = None + best_loss = None + early_stop_i = 0 + + epoch = 0 + while (train_epochs is None or epoch < train_epochs) and early_stop_i < FLAGS.early_stop: + t0 = time.time() + + # alternate forward and reverse batches + fwdrc = True + if FLAGS.augment_rc and epoch % 2 == 1: + fwdrc = False + + # cycle shifts + shift_i = epoch % len(augment_shifts) + + # train + train_loss, steps = model.train_epoch(sess, batcher_train, fwdrc=fwdrc, + shift=augment_shifts[shift_i], + sum_writer=train_writer, + epoch_batches=train_epoch_batches, + no_steps=FLAGS.no_steps) + + # validate + valid_acc = model.test(sess, batcher_valid, mc_n=FLAGS.ensemble_mc, + rc=FLAGS.ensemble_rc, shifts=ensemble_shifts, + test_batches=test_epoch_batches) + valid_loss = valid_acc.loss + valid_r2 = valid_acc.r2().mean() + del valid_acc + + best_str = '' + if best_loss is None or valid_loss < best_loss: + best_loss = valid_loss + best_str = ', best!' + early_stop_i = 0 + saver.save(sess, '%s/model_best.tf' % FLAGS.logdir) + else: + early_stop_i += 1 + + # measure time + et = time.time() - t0 + if et < 600: + time_str = '%3ds' % et + elif et < 6000: + time_str = '%3dm' % (et / 60) + else: + time_str = '%3.1fh' % (et / 3600) + + # print update + print( + 'Epoch: %3d, Steps: %7d, Train loss: %7.5f, Valid loss: %7.5f, Valid R2: %7.5f, Time: %s%s' + % (epoch + 1, steps, train_loss, valid_loss, valid_r2, time_str, best_str)) + sys.stdout.flush() + + if FLAGS.check_all: + saver.save(sess, '%s/model_check%d.tf' % (FLAGS.logdir, epoch)) + + # update epoch + epoch += 1 + + + if FLAGS.logdir: + train_writer.close() + + +################################################################################ +# __main__ +################################################################################ +if __name__ == '__main__': + tf.app.run(main) diff --git a/bin/basenji_train_queues.py b/bin/basenji_train_queues.py index b3ca245d..50f4d5e5 100755 --- a/bin/basenji_train_queues.py +++ b/bin/basenji_train_queues.py @@ -95,7 +95,7 @@ def run(params_file, train_file, test_file, train_epochs, train_epoch_batches, # train epoch sess.run(training_init_op) - train_loss, steps = model.train_epoch_from_data_ops(sess, train_writer, train_epoch_batches) + train_loss, steps = model.train_epoch_tfr(sess, train_writer, train_epoch_batches) # test validation sess.run(test_init_op) From 65c7bb59691b3cf807a7cbe7c5b10003ff6dcd15 Mon Sep 17 00:00:00 2001 From: David Kelley Date: Sat, 7 Jul 2018 12:56:28 -0700 Subject: [PATCH 42/71] tuning --- basenji/seqnn.py | 10 ++++++++-- bin/basenji_train_h5.py | 14 +++----------- 2 files changed, 11 insertions(+), 13 deletions(-) diff --git a/basenji/seqnn.py b/basenji/seqnn.py index a6f8471f..27f34043 100755 --- a/basenji/seqnn.py +++ b/basenji/seqnn.py @@ -109,6 +109,7 @@ def make_placeholders(self): tf.float32, shape=(self.hp.batch_size, self.hp.seq_length, self.hp.seq_depth), name='inputs') + if self.hp.target_classes == 1: self.targets_ph = tf.placeholder( tf.float32, @@ -122,9 +123,14 @@ def make_placeholders(self): self.hp.num_targets), name='targets') + self.targets_na_ph = tf.placeholder(tf.bool, + shape=(self.hp.batch_size, self.hp.seq_length // self.hp.target_pool), + name='targets_na') + data = { 'sequence': 
self.inputs_ph, - 'label': self.targets_ph + 'label': self.targets_ph, + 'na': self.targets_na_ph } return data @@ -511,7 +517,7 @@ def train_epoch_h5(self, else: run_ops = [self.merged_summary, self.loss_train, self.global_step, self.step_op] run_ops += self.update_ops - summary, loss_batch, global_step = sess.run(run_ops, feed_dict=fd) + summary, loss_batch, global_step = sess.run(run_ops, feed_dict=fd)[:3] # add summary if sum_writer is not None: diff --git a/bin/basenji_train_h5.py b/bin/basenji_train_h5.py index 6e09f9e3..41f86789 100755 --- a/bin/basenji_train_h5.py +++ b/bin/basenji_train_h5.py @@ -77,7 +77,8 @@ def run(params_file, data_file, train_epochs, train_epoch_batches, test_epoch_ba t0 = time.time() model = seqnn.SeqNN() model.build(job, augment_rc=FLAGS.augment_rc, augment_shifts=augment_shifts, - ensemble_rc=FLAGS.ensemble_rc, ensemble_shifts=ensemble_shifts) + ensemble_rc=FLAGS.ensemble_rc, ensemble_shifts=ensemble_shifts) + print('Model building time %f' % (time.time() - t0)) # adjust for fourier @@ -151,17 +152,8 @@ def run(params_file, data_file, train_epochs, train_epoch_batches, test_epoch_ba while (train_epochs is None or epoch < train_epochs) and early_stop_i < FLAGS.early_stop: t0 = time.time() - # alternate forward and reverse batches - fwdrc = True - if FLAGS.augment_rc and epoch % 2 == 1: - fwdrc = False - - # cycle shifts - shift_i = epoch % len(augment_shifts) - # train - train_loss, steps = model.train_epoch(sess, batcher_train, fwdrc=fwdrc, - shift=augment_shifts[shift_i], + train_loss, steps = model.train_epoch_h5(sess, batcher_train, sum_writer=train_writer, epoch_batches=train_epoch_batches, no_steps=FLAGS.no_steps) From 68f5f65a8fa2ea09f5ba9b55834b29813a92065c Mon Sep 17 00:00:00 2001 From: David Kelley Date: Sat, 7 Jul 2018 13:19:48 -0700 Subject: [PATCH 43/71] tran epoch bug --- basenji/seqnn.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/basenji/seqnn.py b/basenji/seqnn.py index 27f34043..7f1d1fc5 100755 --- a/basenji/seqnn.py +++ b/basenji/seqnn.py @@ -527,7 +527,7 @@ def train_epoch_h5(self, train_loss.append(loss_batch) # next batch - Xb, Yb, NAb, Nb = batcher.next(fwdrc, shift) + Xb, Yb, NAb, Nb = batcher.next() batch_num += 1 # reset training batcher if epoch considered all of the data From 702ec1e572f4f317bb18bc6db6710392c190a8f9 Mon Sep 17 00:00:00 2001 From: David Kelley Date: Sat, 7 Jul 2018 13:22:02 -0700 Subject: [PATCH 44/71] h5 ensembling in-graph --- basenji/seqnn_util.py | 89 +++++++++++++++++++++++++++++++------ bin/basenji_train_h5.py | 10 ++--- bin/basenji_train_queues.py | 2 +- 3 files changed, 80 insertions(+), 21 deletions(-) diff --git a/basenji/seqnn_util.py b/basenji/seqnn_util.py index 76c06b73..123ab054 100644 --- a/basenji/seqnn_util.py +++ b/basenji/seqnn_util.py @@ -835,7 +835,7 @@ def predict_genes(self, return tss_preds - def test_from_data_ops(self, sess, test_batches=None): + def test_tfr(self, sess, test_batches=None): """ Compute model accuracy on a test set, where data is loaded from a queue. Args: @@ -845,11 +845,6 @@ def test_from_data_ops(self, sess, test_batches=None): Returns: acc: Accuracy object """ - - # TODO(dbelanger) this ignores rc and shift ensembling for now. - # Accuracy will be slightly lower than if we had used this. - # The rc and shift data augmentation need to be pulled into the graph. 
- fd = self.set_mode('test') # initialize prediction and target arrays @@ -865,7 +860,7 @@ def test_from_data_ops(self, sess, test_batches=None): batch_num = 0 while data_available and (test_batches is None or batch_num < test_batches): try: - # make non-ensembled predictions + # make predictions run_ops = [self.targets_eval, self.preds_eval, self.loss_eval, self.loss_eval_targets] run_returns = sess.run(run_ops, feed_dict=fd) @@ -900,13 +895,79 @@ def test_from_data_ops(self, sess, test_batches=None): return acc - def test(self, - sess, - batcher, - rc=False, - shifts=[0], - mc_n=0, - test_batches=None): + def test_h5(self, sess, batcher, test_batches=None): + """ Compute model accuracy on a test set. + + Args: + sess: TensorFlow session + batcher: Batcher object to provide data + mc_n: Monte Carlo iterations per rc/shift. + test_batches: Number of test batches + + Returns: + acc: Accuracy object + """ + # setup feed dict + fd = self.set_mode('test') + + # initialize prediction and target arrays + preds = [] + targets = [] + targets_na = [] + + batch_losses = [] + batch_target_losses = [] + + # get first batch + batch_num = 0 + Xb, Yb, NAb, Nb = batcher.next() + + while Xb is not None and (test_batches is None or + batch_num < test_batches): + # make predictions + run_ops = [self.targets_eval, self.preds_eval, + self.loss_eval, self.loss_eval_targets] + run_returns = sess.run(run_ops, feed_dict=fd) + targets_batch, preds_batch, loss_batch, target_losses_batch = run_returns + + # accumulate predictions and targets + preds.append(preds_batch.astype('float16')) + targets.append(targets_batch.astype('float16')) + targets_na.append(np.zeros([preds_batch.shape[0], self.preds_length], dtype='bool')) + + # accumulate loss + batch_losses.append(loss_batch) + batch_target_losses.append(target_losses_batch) + + # next batch + batch_num += 1 + Xb, Yb, NAb, Nb = batcher.next() + + # reset batcher + batcher.reset() + + # construct arrays + targets = np.concatenate(targets, axis=0) + preds = np.concatenate(preds, axis=0) + targets_na = np.concatenate(targets_na, axis=0) + + # mean across batches + batch_losses = np.mean(batch_losses) + batch_target_losses = np.array(batch_target_losses).mean(axis=0) + + # instantiate accuracy object + acc = accuracy.Accuracy(targets, preds, targets_na, + batch_losses, batch_target_losses) + + return acc + + def test_h5_manual(self, + sess, + batcher, + rc=False, + shifts=[0], + mc_n=0, + test_batches=None): """ Compute model accuracy on a test set. 
Args: diff --git a/bin/basenji_train_h5.py b/bin/basenji_train_h5.py index 41f86789..875ee3af 100755 --- a/bin/basenji_train_h5.py +++ b/bin/basenji_train_h5.py @@ -154,14 +154,12 @@ def run(params_file, data_file, train_epochs, train_epoch_batches, test_epoch_ba # train train_loss, steps = model.train_epoch_h5(sess, batcher_train, - sum_writer=train_writer, - epoch_batches=train_epoch_batches, - no_steps=FLAGS.no_steps) + sum_writer=train_writer, + epoch_batches=train_epoch_batches, + no_steps=FLAGS.no_steps) # validate - valid_acc = model.test(sess, batcher_valid, mc_n=FLAGS.ensemble_mc, - rc=FLAGS.ensemble_rc, shifts=ensemble_shifts, - test_batches=test_epoch_batches) + valid_acc = model.test_h5(sess, batcher_valid, test_batches=test_epoch_batches) valid_loss = valid_acc.loss valid_r2 = valid_acc.r2().mean() del valid_acc diff --git a/bin/basenji_train_queues.py b/bin/basenji_train_queues.py index 50f4d5e5..8ac51ec5 100755 --- a/bin/basenji_train_queues.py +++ b/bin/basenji_train_queues.py @@ -99,7 +99,7 @@ def run(params_file, train_file, test_file, train_epochs, train_epoch_batches, # test validation sess.run(test_init_op) - valid_acc = model.test_from_data_ops(sess, test_epoch_batches) + valid_acc = model.test_tfr(sess, test_epoch_batches) valid_loss = valid_acc.loss valid_r2 = valid_acc.r2().mean() del valid_acc From 40a960e96c6afb752c13f28a85896500c934d186 Mon Sep 17 00:00:00 2001 From: David Kelley Date: Sun, 8 Jul 2018 10:24:23 -0700 Subject: [PATCH 45/71] test feed dict --- basenji/seqnn_util.py | 14 +++++++++----- 1 file changed, 9 insertions(+), 5 deletions(-) diff --git a/basenji/seqnn_util.py b/basenji/seqnn_util.py index 123ab054..071a13a7 100644 --- a/basenji/seqnn_util.py +++ b/basenji/seqnn_util.py @@ -924,6 +924,10 @@ def test_h5(self, sess, batcher, test_batches=None): while Xb is not None and (test_batches is None or batch_num < test_batches): + # update feed dict + fd[self.inputs_ph] = Xb + fd[self.targets_ph] = Yb + # make predictions run_ops = [self.targets_eval, self.preds_eval, self.loss_eval, self.loss_eval_targets] @@ -931,9 +935,9 @@ def test_h5(self, sess, batcher, test_batches=None): targets_batch, preds_batch, loss_batch, target_losses_batch = run_returns # accumulate predictions and targets - preds.append(preds_batch.astype('float16')) - targets.append(targets_batch.astype('float16')) - targets_na.append(np.zeros([preds_batch.shape[0], self.preds_length], dtype='bool')) + preds.append(preds_batch[:Nb,:,:].astype('float16')) + targets.append(targets_batch[:Nb,:,:].astype('float16')) + targets_na.append(np.zeros([Nb, self.preds_length], dtype='bool')) # accumulate loss batch_losses.append(loss_batch) @@ -1027,8 +1031,8 @@ def test_h5_manual(self, sess, fd, Xb, ensemble_fwdrc, ensemble_shifts, mc_n) # add target info - fd[self.targets] = Yb - fd[self.targets_na] = NAb + fd[self.targets_ph] = Yb + fd[self.targets_na_ph] = NAb targets_na.append(np.zeros([Nb, self.preds_length], dtype='bool')) From e80c4fb6f9734d7557f50acac92078c87d4a780c Mon Sep 17 00:00:00 2001 From: David Kelley Date: Wed, 11 Jul 2018 14:34:00 -0700 Subject: [PATCH 46/71] update placeholder --- basenji/seqnn_util.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/basenji/seqnn_util.py b/basenji/seqnn_util.py index 071a13a7..2df94556 100644 --- a/basenji/seqnn_util.py +++ b/basenji/seqnn_util.py @@ -300,7 +300,7 @@ def _gradients_ensemble(self, sess, fd, Xb, ensemble_fwdrc, ensemble_shifts, mc_ Xb_ensemble = hot1_augment(Xb, ensemble_fwdrc[ei], ensemble_shifts[ei]) # 
update feed dict - fd[self.inputs] = Xb_ensemble + fd[self.inputs_ph] = Xb_ensemble # for each monte carlo (or non-mc single) iteration for mi in range(mc_n): @@ -448,7 +448,7 @@ def gradients_genes(self, sess, batcher, gene_seqs): while Xb is not None: # update feed dict - fd[self.inputs] = Xb + fd[self.inputs_ph] = Xb # predict reprs_batch, _ = sess.run([self.layer_reprs, self.preds_eval], feed_dict=fd) @@ -507,7 +507,7 @@ def hidden(self, sess, batcher, layers=None): while Xb is not None: # update feed dict - fd[self.inputs] = Xb + fd[self.inputs_ph] = Xb # compute predictions layer_reprs_batch, preds_batch = sess.run( @@ -589,7 +589,7 @@ def _predict_ensemble(self, Xb_ensemble = hot1_augment(Xb, ensemble_fwdrc[ei], ensemble_shifts[ei]) # update feed dict - fd[self.inputs] = Xb_ensemble + fd[self.inputs_ph] = Xb_ensemble # for each monte carlo (or non-mc single) iteration for mi in range(mc_n): From dfa82247ea3c33e3c221cc7e8256f5049b1fbdef Mon Sep 17 00:00:00 2001 From: David Kelley Date: Wed, 11 Jul 2018 14:34:10 -0700 Subject: [PATCH 47/71] 0 shift defaults --- basenji/seqnn.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/basenji/seqnn.py b/basenji/seqnn.py index 7f1d1fc5..fbc0c413 100755 --- a/basenji/seqnn.py +++ b/basenji/seqnn.py @@ -33,8 +33,8 @@ def __init__(self): self.global_step = tf.train.get_or_create_global_step() self.hparams_set = False - def build(self, job, augment_rc=False, augment_shifts=[], - ensemble_rc=False, ensemble_shifts=[], target_subset=None): + def build(self, job, augment_rc=False, augment_shifts=[0], + ensemble_rc=False, ensemble_shifts=[0], target_subset=None): """Build training ops that depend on placeholders.""" self.hp = params.make_hparams(job) From c253cb9f946cbe5877b215525db5eb4b01171131 Mon Sep 17 00:00:00 2001 From: David Kelley Date: Wed, 11 Jul 2018 14:34:38 -0700 Subject: [PATCH 48/71] testing --- bin/basenji_test.py | 4 +- bin/basenji_test_h5.py | 624 +++++++++++++++++++++++++++++++++++++++++ bin/basenji_testq.py | 6 +- 3 files changed, 630 insertions(+), 4 deletions(-) create mode 100755 bin/basenji_test_h5.py diff --git a/bin/basenji_test.py b/bin/basenji_test.py index 6f7ed4ae..1555d719 100755 --- a/bin/basenji_test.py +++ b/bin/basenji_test.py @@ -255,8 +255,8 @@ def main(): # test t0 = time.time() - test_acc = dr.test(sess, batcher_test, rc=options.rc, - shifts=options.shifts, mc_n=options.mc_n) + test_acc = dr.test_h5_manual(sess, batcher_test, rc=options.rc, + shifts=options.shifts, mc_n=options.mc_n) if options.save: np.save('%s/preds.npy' % options.out_dir, test_acc.preds) diff --git a/bin/basenji_test_h5.py b/bin/basenji_test_h5.py new file mode 100755 index 00000000..7dd64484 --- /dev/null +++ b/bin/basenji_test_h5.py @@ -0,0 +1,624 @@ +#!/usr/bin/env python +# Copyright 2017 Calico LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# ========================================================================= + +from __future__ import print_function +from optparse import OptionParser +import os +import random +import sys +import time + +import h5py +import joblib +import matplotlib +matplotlib.use('PDF') +import matplotlib.pyplot as plt +import numpy as np +import pandas as pd +import pyBigWig +from scipy.stats import spearmanr, poisson +import seaborn as sns +from sklearn.linear_model import LinearRegression, LogisticRegression +from sklearn.metrics import roc_auc_score, roc_curve, precision_recall_curve, average_precision_score +import tensorflow as tf + +from basenji import batcher +from basenji import params +from basenji import plots +from basenji import seqnn + +""" +basenji_test.py + +Test the accuracy of a trained model. + +Notes + -This probably needs work for the pooled large sequence version. I tried to + update the "full" comparison, but it's not tested. The notion of peak calls + will need to completely change; we probably want to predict in each bin. +""" + +################################################################################ +# main +################################################################################ +def main(): + usage = 'usage: %prog [options] ' + parser = OptionParser(usage) + parser.add_option( + '--ai', + dest='accuracy_indexes', + help= + 'Comma-separated list of target indexes to make accuracy plots comparing true versus predicted values' + ) + parser.add_option( + '--clip', + dest='target_clip', + default=None, + type='float', + help='Clip targets and predictions to a maximum value [Default: %default]' + ) + parser.add_option( + '-d', + dest='down_sample', + default=1, + type='int', + help= + 'Down sample test computation by taking uniformly spaced positions [Default: %default]' + ) + parser.add_option( + '-g', + dest='genome_file', + default='%s/tutorials/data/human.hg19.genome' % os.environ['BASENJIDIR'], + help='Chromosome length information [Default: %default]') + parser.add_option( + '--mc', + dest='mc_n', + default=0, + type='int', + help='Monte carlo test iterations [Default: %default]') + parser.add_option( + '--peak','--peaks', + dest='peaks', + default=False, + action='store_true', + help='Compute expensive peak accuracy [Default: %default]') + parser.add_option( + '-o', + dest='out_dir', + default='test_out', + help='Output directory for test statistics [Default: %default]') + parser.add_option( + '--rc', + dest='rc', + default=False, + action='store_true', + help= + 'Average the fwd and rc predictions [Default: %default]') + parser.add_option( + '--sample', + dest='sample_pct', + default=1, + type='float', + help='Sample percentage') + parser.add_option( + '--save', + dest='save', + default=False, + action='store_true') + parser.add_option( + '--shifts', + dest='shifts', + default='0', + help='Ensemble prediction shifts [Default: %default]') + parser.add_option( + '-t', + dest='track_bed', + help='BED file describing regions so we can output BigWig tracks') + parser.add_option( + '--ti', + dest='track_indexes', + help='Comma-separated list of target indexes to output BigWig tracks') + parser.add_option( + '--train', + dest='train', + default=False, + action='store_true', + help='Process the training set [Default: %default]') + parser.add_option( + '-v', + dest='valid', + default=False, + action='store_true', + help='Process the validation set [Default: %default]') + parser.add_option( + '-w', + dest='pool_width', + default=1, + type='int', + help= + 'Max pool 
width for regressing nt predictions to predict peak calls [Default: %default]' + ) + (options, args) = parser.parse_args() + + if len(args) != 3: + parser.error('Must provide parameters, model, and test data HDF5') + else: + params_file = args[0] + model_file = args[1] + test_hdf5_file = args[2] + + if not os.path.isdir(options.out_dir): + os.mkdir(options.out_dir) + + options.shifts = [int(shift) for shift in options.shifts.split(',')] + + ####################################################### + # load data + ####################################################### + data_open = h5py.File(test_hdf5_file) + + if options.train: + test_seqs = data_open['train_in'] + test_targets = data_open['train_out'] + if 'train_na' in data_open: + test_na = data_open['train_na'] + + elif options.valid: + test_seqs = data_open['valid_in'] + test_targets = data_open['valid_out'] + test_na = None + if 'valid_na' in data_open: + test_na = data_open['valid_na'] + + else: + test_seqs = data_open['test_in'] + test_targets = data_open['test_out'] + test_na = None + if 'test_na' in data_open: + test_na = data_open['test_na'] + + if options.sample_pct < 1: + sample_n = int(test_seqs.shape[0]*options.sample_pct) + print('Sampling %d sequences' % sample_n) + sample_indexes = sorted(np.random.choice(np.arange(test_seqs.shape[0]), + size=sample_n, replace=False)) + test_seqs = test_seqs[sample_indexes] + test_targets = test_targets[sample_indexes] + if test_na is not None: + test_na = test_na[sample_indexes] + + target_labels = [tl.decode('UTF-8') for tl in data_open['target_labels']] + + ####################################################### + # model parameters and placeholders + + job = params.read_job_params(params_file) + + job['seq_length'] = test_seqs.shape[1] + job['seq_depth'] = test_seqs.shape[2] + job['num_targets'] = test_targets.shape[2] + job['target_pool'] = int(np.array(data_open.get('pool_width', 1))) + + t0 = time.time() + model = seqnn.SeqNN() + model.build(job, ensemble_rc=options.rc, ensemble_shifts=options.shifts) + print('Model building time %ds' % (time.time() - t0)) + + # adjust for fourier + job['fourier'] = 'train_out_imag' in data_open + if job['fourier']: + test_targets_imag = data_open['test_out_imag'] + if options.valid: + test_targets_imag = data_open['valid_out_imag'] + + ####################################################### + # test + + # initialize batcher + if job['fourier']: + batcher_test = batcher.BatcherF(test_seqs, test_targets, + test_targets_imag, test_na, + model.hp.batch_size, model.hp.target_pool) + else: + batcher_test = batcher.Batcher(test_seqs, test_targets, test_na, + model.hp.batch_size, model.hp.target_pool) + + # initialize saver + saver = tf.train.Saver() + + with tf.Session() as sess: + # load variables into session + saver.restore(sess, model_file) + + # test + t0 = time.time() + test_acc = model.test_h5(sess, batcher_test) + + if options.save: + np.save('%s/preds.npy' % options.out_dir, test_acc.preds) + np.save('%s/targets.npy' % options.out_dir, test_acc.targets) + + test_preds = test_acc.preds + print('SeqNN test: %ds' % (time.time() - t0)) + + # compute stats + t0 = time.time() + test_r2 = test_acc.r2(clip=options.target_clip) + # test_log_r2 = test_acc.r2(log=True, clip=options.target_clip) + test_pcor = test_acc.pearsonr(clip=options.target_clip) + test_log_pcor = test_acc.pearsonr(log=True, clip=options.target_clip) + #test_scor = test_acc.spearmanr() # too slow; mostly driven by low values + print('Compute stats: %ds' % (time.time()-t0)) + + # print 
+ print('Test Loss: %7.5f' % test_acc.loss) + print('Test R2: %7.5f' % test_r2.mean()) + # print('Test log R2: %7.5f' % test_log_r2.mean()) + print('Test PearsonR: %7.5f' % test_pcor.mean()) + print('Test log PearsonR: %7.5f' % test_log_pcor.mean()) + # print('Test SpearmanR: %7.5f' % test_scor.mean()) + + acc_out = open('%s/acc.txt' % options.out_dir, 'w') + for ti in range(len(test_r2)): + print( + '%4d %7.5f %.5f %.5f %.5f %s' % + (ti, test_acc.target_losses[ti], test_r2[ti], test_pcor[ti], + test_log_pcor[ti], target_labels[ti]), file=acc_out) + acc_out.close() + + # print normalization factors + target_means = test_preds.mean(axis=(0,1), dtype='float64') + target_means_median = np.median(target_means) + target_means /= target_means_median + norm_out = open('%s/normalization.txt' % options.out_dir, 'w') + print('\n'.join([str(tu) for tu in target_means]), file=norm_out) + norm_out.close() + + # clean up + del test_acc + + + ####################################################### + # peak call accuracy + + if options.peaks: + # sample every few bins to decrease correlations + ds_indexes_preds = np.arange(0, test_preds.shape[1], 8) + ds_indexes_targets = ds_indexes_preds + (model.hp.batch_buffer // model.hp.target_pool) + + aurocs = [] + auprcs = [] + + peaks_out = open('%s/peaks.txt' % options.out_dir, 'w') + for ti in range(test_targets.shape[2]): + test_targets_ti = test_targets[:, :, ti] + + # subset and flatten + test_targets_ti_flat = test_targets_ti[:, ds_indexes_targets].flatten( + ).astype('float32') + test_preds_ti_flat = test_preds[:, ds_indexes_preds, ti].flatten().astype( + 'float32') + + # call peaks + test_targets_ti_lambda = np.mean(test_targets_ti_flat) + test_targets_pvals = 1 - poisson.cdf( + np.round(test_targets_ti_flat) - 1, mu=test_targets_ti_lambda) + test_targets_qvals = np.array(ben_hoch(test_targets_pvals)) + test_targets_peaks = test_targets_qvals < 0.01 + + if test_targets_peaks.sum() == 0: + aurocs.append(0.5) + auprcs.append(0) + + else: + # compute prediction accuracy + aurocs.append(roc_auc_score(test_targets_peaks, test_preds_ti_flat)) + auprcs.append( + average_precision_score(test_targets_peaks, test_preds_ti_flat)) + + print('%4d %6d %.5f %.5f' % (ti, test_targets_peaks.sum(), + aurocs[-1], auprcs[-1]), + file=peaks_out) + + peaks_out.close() + + print('Test AUROC: %7.5f' % np.mean(aurocs)) + print('Test AUPRC: %7.5f' % np.mean(auprcs)) + + ####################################################### + # BigWig tracks + + # NOTE: THESE ASSUME THERE WAS NO DOWN-SAMPLING ABOVE + + # print bigwig tracks for visualization + if options.track_bed: + if options.genome_file is None: + parser.error('Must provide genome file in order to print valid BigWigs') + + if not os.path.isdir('%s/tracks' % options.out_dir): + os.mkdir('%s/tracks' % options.out_dir) + + track_indexes = range(test_preds.shape[2]) + if options.track_indexes: + track_indexes = [int(ti) for ti in options.track_indexes.split(',')] + + bed_set = 'test' + if options.valid: + bed_set = 'valid' + + for ti in track_indexes: + test_targets_ti = test_targets[:, :, ti] + + # make true targets bigwig + bw_file = '%s/tracks/t%d_true.bw' % (options.out_dir, ti) + bigwig_write( + bw_file, + test_targets_ti, + options.track_bed, + options.genome_file, + bed_set=bed_set) + + # make predictions bigwig + bw_file = '%s/tracks/t%d_preds.bw' % (options.out_dir, ti) + bigwig_write( + bw_file, + test_preds[:, :, ti], + options.track_bed, + options.genome_file, + model.hp.batch_buffer, + bed_set=bed_set) + + # make NA 
bigwig + # bw_file = '%s/tracks/na.bw' % options.out_dir + # bigwig_write( + # bw_file, + # test_na, + # options.track_bed, + # options.genome_file, + # bed_set=bed_set) + + ####################################################### + # accuracy plots + + if options.accuracy_indexes is not None: + accuracy_indexes = [int(ti) for ti in options.accuracy_indexes.split(',')] + + if not os.path.isdir('%s/scatter' % options.out_dir): + os.mkdir('%s/scatter' % options.out_dir) + + if not os.path.isdir('%s/violin' % options.out_dir): + os.mkdir('%s/violin' % options.out_dir) + + if not os.path.isdir('%s/roc' % options.out_dir): + os.mkdir('%s/roc' % options.out_dir) + + if not os.path.isdir('%s/pr' % options.out_dir): + os.mkdir('%s/pr' % options.out_dir) + + for ti in accuracy_indexes: + test_targets_ti = test_targets[:, :, ti] + + ############################################ + # scatter + + # sample every few bins (adjust to plot the # points I want) + ds_indexes_preds = np.arange(0, test_preds.shape[1], 8) + ds_indexes_targets = ds_indexes_preds + ( + model.hp.batch_buffer // model.hp.target_pool) + + # subset and flatten + test_targets_ti_flat = test_targets_ti[:, ds_indexes_targets].flatten( + ).astype('float32') + test_preds_ti_flat = test_preds[:, ds_indexes_preds, ti].flatten().astype( + 'float32') + + # take log2 + test_targets_ti_log = np.log2(test_targets_ti_flat + 1) + test_preds_ti_log = np.log2(test_preds_ti_flat + 1) + + # plot log2 + sns.set(font_scale=1.2, style='ticks') + out_pdf = '%s/scatter/t%d.pdf' % (options.out_dir, ti) + plots.regplot( + test_targets_ti_log, + test_preds_ti_log, + out_pdf, + poly_order=1, + alpha=0.3, + sample=500, + figsize=(6, 6), + x_label='log2 Experiment', + y_label='log2 Prediction', + table=True) + + ############################################ + # violin + + # call peaks + test_targets_ti_lambda = np.mean(test_targets_ti_flat) + test_targets_pvals = 1 - poisson.cdf( + np.round(test_targets_ti_flat) - 1, mu=test_targets_ti_lambda) + test_targets_qvals = np.array(ben_hoch(test_targets_pvals)) + test_targets_peaks = test_targets_qvals < 0.01 + test_targets_peaks_str = np.where(test_targets_peaks, 'Peak', + 'Background') + + # violin plot + sns.set(font_scale=1.3, style='ticks') + plt.figure() + df = pd.DataFrame({ + 'log2 Prediction': np.log2(test_preds_ti_flat + 1), + 'Experimental coverage status': test_targets_peaks_str + }) + ax = sns.violinplot( + x='Experimental coverage status', y='log2 Prediction', data=df) + ax.grid(True, linestyle=':') + plt.savefig('%s/violin/t%d.pdf' % (options.out_dir, ti)) + plt.close() + + # ROC + plt.figure() + fpr, tpr, _ = roc_curve(test_targets_peaks, test_preds_ti_flat) + auroc = roc_auc_score(test_targets_peaks, test_preds_ti_flat) + plt.plot( + [0, 1], [0, 1], c='black', linewidth=1, linestyle='--', alpha=0.7) + plt.plot(fpr, tpr, c='black') + ax = plt.gca() + ax.set_xlabel('False positive rate') + ax.set_ylabel('True positive rate') + ax.text( + 0.99, 0.02, 'AUROC %.3f' % auroc, + horizontalalignment='right') # , fontsize=14) + ax.grid(True, linestyle=':') + plt.savefig('%s/roc/t%d.pdf' % (options.out_dir, ti)) + plt.close() + + # PR + plt.figure() + prec, recall, _ = precision_recall_curve(test_targets_peaks, + test_preds_ti_flat) + auprc = average_precision_score(test_targets_peaks, test_preds_ti_flat) + plt.axhline( + y=test_targets_peaks.mean(), + c='black', + linewidth=1, + linestyle='--', + alpha=0.7) + plt.plot(recall, prec, c='black') + ax = plt.gca() + ax.set_xlabel('Recall') + ax.set_ylabel('Precision') + 
ax.text( + 0.99, 0.95, 'AUPRC %.3f' % auprc, + horizontalalignment='right') # , fontsize=14) + ax.grid(True, linestyle=':') + plt.savefig('%s/pr/t%d.pdf' % (options.out_dir, ti)) + plt.close() + + data_open.close() + + +def ben_hoch(p_values): + """ Convert the given p-values to q-values using Benjamini-Hochberg FDR. """ + m = len(p_values) + + # attach original indexes to p-values + p_k = [(p_values[k], k) for k in range(m)] + + # sort by p-value + p_k.sort() + + # compute q-value and attach original index to front + k_q = [(p_k[i][1], p_k[i][0] * m // (i + 1)) for i in range(m)] + + # re-sort by original index + k_q.sort() + + # drop original indexes + q_values = [k_q[k][1] for k in range(m)] + + return q_values + + +def bigwig_open(bw_file, genome_file): + """ Open the bigwig file for writing and write the header. """ + + bw_out = pyBigWig.open(bw_file, 'w') + + chrom_sizes = [] + for line in open(genome_file): + a = line.split() + chrom_sizes.append((a[0], int(a[1]))) + + bw_out.addHeader(chrom_sizes) + + return bw_out + + +def bigwig_write(bw_file, + signal_ti, + track_bed, + genome_file, + buffer=0, + bed_set='test'): + """ Write a signal track to a BigWig file over the regions + specified by track_bed. + + Args + bw_file: BigWig filename + signal_ti: Sequences X Length array for some target + track_bed: BED file specifying sequence coordinates + genome_file: Chromosome lengths file + buffer: Length skipped on each side of the region. + """ + + bw_out = bigwig_open(bw_file, genome_file) + + si = 0 + bw_hash = {} + + # set entries + for line in open(track_bed): + a = line.split() + if a[3] == bed_set: + chrom = a[0] + start = int(a[1]) + end = int(a[2]) + + preds_pool = (end - start - 2 * buffer) // signal_ti.shape[1] + + bw_start = start + buffer + for li in range(signal_ti.shape[1]): + bw_end = bw_start + preds_pool + bw_hash.setdefault((chrom,bw_start,bw_end),[]).append(signal_ti[si,li]) + bw_start = bw_end + + si += 1 + + # average duplicates + bw_entries = [] + for bw_key in bw_hash: + bw_signal = np.mean(bw_hash[bw_key]) + bwe = tuple(list(bw_key)+[bw_signal]) + bw_entries.append(bwe) + + # sort entries + bw_entries.sort() + + # add entries + for line in open(genome_file): + chrom = line.split()[0] + + bw_entries_chroms = [be[0] for be in bw_entries if be[0] == chrom] + bw_entries_starts = [be[1] for be in bw_entries if be[0] == chrom] + bw_entries_ends = [be[2] for be in bw_entries if be[0] == chrom] + bw_entries_values = [float(be[3]) for be in bw_entries if be[0] == chrom] + + if len(bw_entries_chroms) > 0: + bw_out.addEntries( + bw_entries_chroms, + bw_entries_starts, + ends=bw_entries_ends, + values=bw_entries_values) + + bw_out.close() + + +################################################################################ +# __main__ +################################################################################ +if __name__ == '__main__': + main() diff --git a/bin/basenji_testq.py b/bin/basenji_testq.py index 5af5614d..578d7a19 100755 --- a/bin/basenji_testq.py +++ b/bin/basenji_testq.py @@ -120,7 +120,9 @@ def main(): # initialize model model = seqnn.SeqNN() - model.build_from_data_ops(job, data_ops) + model.build_from_data_ops(job, data_ops, + ensemble_rc=options.rc, + ensemble_shifts=options.shifts) # initialize saver saver = tf.train.Saver() @@ -136,7 +138,7 @@ def main(): # test t0 = time.time() sess.run(test_init_op) - test_acc = model.test_from_data_ops(sess) + test_acc = model.test_tfr(sess) test_preds = test_acc.preds print('SeqNN test: %ds' % (time.time() - t0)) 
From f68c6acb3caaafc4638700d27775c22f01d9855f Mon Sep 17 00:00:00 2001 From: David Kelley Date: Wed, 11 Jul 2018 15:07:00 -0700 Subject: [PATCH 49/71] target labels --- bin/basenji_testq.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/bin/basenji_testq.py b/bin/basenji_testq.py index 578d7a19..fe29057e 100755 --- a/bin/basenji_testq.py +++ b/bin/basenji_testq.py @@ -169,7 +169,7 @@ def main(): print( '%4d %7.5f %.5f %.5f %.5f %s' % (ti, test_acc.target_losses[ti], test_r2[ti], test_pcor[ti], - test_log_pcor[ti], target_labels[ti]), file=acc_out) + test_log_pcor[ti], targets_df.description.iloc[ti]), file=acc_out) acc_out.close() # print normalization factors From b344161ea2c795c338662b96b4b5f7f487b0637d Mon Sep 17 00:00:00 2001 From: David Kelley Date: Wed, 11 Jul 2018 15:07:34 -0700 Subject: [PATCH 50/71] float64 loss mean --- basenji/seqnn_util.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/basenji/seqnn_util.py b/basenji/seqnn_util.py index 2df94556..b3c9bea9 100644 --- a/basenji/seqnn_util.py +++ b/basenji/seqnn_util.py @@ -886,8 +886,8 @@ def test_tfr(self, sess, test_batches=None): targets_na = np.concatenate(targets_na, axis=0) # mean across batches - batch_losses = np.mean(batch_losses) - batch_target_losses = np.array(batch_target_losses).mean(axis=0) + batch_losses = np.mean(batch_losses, dtype='float64') + batch_target_losses = np.array(batch_target_losses).mean(axis=0, dtype='float64') # instantiate accuracy object acc = accuracy.Accuracy(targets, preds, targets_na, @@ -956,8 +956,8 @@ def test_h5(self, sess, batcher, test_batches=None): targets_na = np.concatenate(targets_na, axis=0) # mean across batches - batch_losses = np.mean(batch_losses) - batch_target_losses = np.array(batch_target_losses).mean(axis=0) + batch_losses = np.mean(batch_losses, dtype='float64') + batch_target_losses = np.array(batch_target_losses).mean(axis=0, dtype='float64') # instantiate accuracy object acc = accuracy.Accuracy(targets, preds, targets_na, From 74b12c3722f7b233c6d890f95a8d2783ad8d772e Mon Sep 17 00:00:00 2001 From: David Kelley Date: Wed, 11 Jul 2018 15:09:02 -0700 Subject: [PATCH 51/71] default shift 0 --- basenji/seqnn.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/basenji/seqnn.py b/basenji/seqnn.py index fbc0c413..d8c2f282 100755 --- a/basenji/seqnn.py +++ b/basenji/seqnn.py @@ -49,8 +49,8 @@ def build(self, job, augment_rc=False, augment_shifts=[0], target_subset=target_subset) def build_from_data_ops(self, job, data_ops, - augment_rc=False, augment_shifts=[], - ensemble_rc=False, ensemble_shifts=[], + augment_rc=False, augment_shifts=[0], + ensemble_rc=False, ensemble_shifts=[0], target_subset=None): """Build training ops from input data ops.""" if not self.hparams_set: From 99fdf38d6be8f68d924f8f171ccf590d04bbfd95 Mon Sep 17 00:00:00 2001 From: David Kelley Date: Wed, 11 Jul 2018 15:10:18 -0700 Subject: [PATCH 52/71] no data open --- bin/basenji_testq.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/bin/basenji_testq.py b/bin/basenji_testq.py index fe29057e..c6c83b99 100755 --- a/bin/basenji_testq.py +++ b/bin/basenji_testq.py @@ -400,8 +400,6 @@ def main(): plt.savefig('%s/pr/t%d.pdf' % (options.out_dir, ti)) plt.close() - data_open.close() - def ben_hoch(p_values): """ Convert the given p-values to q-values using Benjamini-Hochberg FDR. 
""" From 53b9a53b1ea610b5948d344ed7964102b55a5c5f Mon Sep 17 00:00:00 2001 From: David Kelley Date: Thu, 12 Jul 2018 06:39:57 -0700 Subject: [PATCH 53/71] create new data_ops dict --- basenji/augmentation.py | 15 ++++++++++----- 1 file changed, 10 insertions(+), 5 deletions(-) diff --git a/basenji/augmentation.py b/basenji/augmentation.py index 340c313e..7b7fb437 100644 --- a/basenji/augmentation.py +++ b/basenji/augmentation.py @@ -79,16 +79,21 @@ def augment_deterministic(data_ops, augment_rc=False, augment_shift=0): Returns data_ops: augmented data """ - if augment_shift != 0: + + data_ops_aug = {'label': data_ops['label'], 'na': data_ops['na']} + + if augment_shift == 0: + data_ops_aug['sequence'] = data_ops['sequence'] + else: shift_amount = tf.constant(augment_shift, shape=(), dtype=tf.int64) - data_ops['sequence'] = shift_sequence(data_ops['sequence'], shift_amount) + data_ops_aug['sequence'] = shift_sequence(data_ops['sequence'], shift_amount) if augment_rc: - data_ops = augment_deterministic_rc(data_ops) + data_ops_aug = augment_deterministic_rc(data_ops_aug) else: - data_ops['reverse_preds'] = tf.zeros((), dtype=tf.bool) + data_ops_aug['reverse_preds'] = tf.zeros((), dtype=tf.bool) - return data_ops + return data_ops_aug def augment_deterministic_rc(data_ops): From 6799d25a182bb8556d8971ad7e81c7d3d530f74a Mon Sep 17 00:00:00 2001 From: David Kelley Date: Thu, 12 Jul 2018 06:40:48 -0700 Subject: [PATCH 54/71] average predictions, not representations --- basenji/seqnn.py | 132 ++++++++++++++++++++++------------------------- 1 file changed, 62 insertions(+), 70 deletions(-) diff --git a/basenji/seqnn.py b/basenji/seqnn.py index d8c2f282..f9db82df 100755 --- a/basenji/seqnn.py +++ b/basenji/seqnn.py @@ -68,12 +68,12 @@ def build_from_data_ops(self, job, data_ops, data_ops, augment_rc, augment_shifts) # compute train representation - seqs_repr_train = self.build_representation(data_ops_train['sequence'], - None, target_subset) + self.preds_train = self.build_predict(data_ops_train['sequence'], + None, target_subset) # training losses - loss_returns = self.build_loss(seqs_repr_train, data_ops_train, target_subset) - self.loss_train, self.loss_train_targets, self.preds_train, self.targets_train = loss_returns + loss_returns = self.build_loss(self.preds_train, data_ops_train['label'], target_subset) + self.loss_train, self.loss_train_targets, self.targets_train = loss_returns # optimizer self.update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS) @@ -90,13 +90,13 @@ def build_from_data_ops(self, job, data_ops, # compute eval representation map_elems_eval = (data_seq_eval, data_rev_eval) - build_rep = lambda do: self.build_representation(do[0], do[1], target_subset) - seqs_repr_list = tf.map_fn(build_rep, map_elems_eval, dtype=seqs_repr_train.dtype) # back_prop=False - seqs_repr_eval = tf.reduce_mean(seqs_repr_list, axis=0) + build_rep = lambda do: self.build_predict(do[0], do[1], target_subset) + self.preds_ensemble = tf.map_fn(build_rep, map_elems_eval, dtype=self.preds_train.dtype) # back_prop=False + self.preds_eval = tf.reduce_mean(self.preds_ensemble, axis=0) # eval loss - loss_returns = self.build_loss(seqs_repr_eval, data_ops, target_subset) - self.loss_eval, self.loss_eval_targets, self.preds_eval, self.targets_eval = loss_returns + loss_returns = self.build_loss(self.preds_eval, data_ops['label'], target_subset) + self.loss_eval, self.loss_eval_targets, self.targets_eval = loss_returns # helper variables self.preds_length = self.preds_train.shape[1] @@ -148,7 +148,7 @@ def 
_make_conv_block_args(self, layer_index, layer_reprs): 'name': 'conv-%d' % layer_index } - def build_representation(self, inputs, reverse_preds=None, target_subset=None): + def build_predict(self, inputs, reverse_preds=None, target_subset=None): """Construct per-location real-valued predictions.""" assert inputs is not None print('Targets pooled by %d to length %d' % @@ -235,7 +235,52 @@ def build_representation(self, inputs, reverse_preds=None, target_subset=None): lambda: tf.reverse(final_repr, axis=[1]), lambda: final_repr) - return final_repr + ################################################### + # link function + ################################################### + + # work-around for specifying my own predictions + # self.preds_adhoc = tf.placeholder( + # tf.float32, shape=final_repr.shape, name='preds-adhoc') + + # float 32 exponential clip max + exp_max = 50 + + # choose link + if self.hp.link in ['identity', 'linear']: + predictions = tf.identity(final_repr, name='preds') + + elif self.hp.link == 'relu': + predictions = tf.relu(final_repr, name='preds') + + elif self.hp.link == 'exp': + final_repr_clip = tf.clip_by_value(final_repr, -exp_max, exp_max) + predictions = tf.exp(final_repr_clip, name='preds') + + elif self.hp.link == 'exp_linear': + predictions = tf.where( + final_repr > 0, + final_repr + 1, + tf.exp(tf.clip_by_value(final_repr, -exp_max, exp_max)), + name='preds') + + elif self.hp.link == 'softplus': + final_repr_clip = tf.clip_by_value(final_repr, -exp_max, 10000) + predictions = tf.nn.softplus(final_repr_clip, name='preds') + + else: + print('Unknown link function %s' % self.hp.link, file=sys.stderr) + exit(1) + + # clip + if self.hp.target_clip is not None: + predictions = tf.clip_by_value(predictions, 0, self.hp.target_clip) + + # sqrt + if self.hp.target_sqrt: + predictions = tf.sqrt(predictions) + + return predictions def build_optimizer(self, loss_op): """Construct optimization op that minimizes loss_op.""" @@ -290,86 +335,38 @@ def build_optimizer(self, loss_op): self.merged_summary = tf.summary.merge_all() - def build_loss(self, seqs_repr, data_ops, target_subset=None): + def build_loss(self, preds, targets, target_subset=None): """Convert per-location real-valued predictions to a loss.""" - # targets + # slice buffer tstart = self.hp.batch_buffer // self.hp.target_pool tend = (self.hp.seq_length - self.hp.batch_buffer) // self.hp.target_pool - - targets = data_ops['label'] targets = tf.identity(targets[:, tstart:tend, :], name='targets_op') if target_subset is not None: targets = tf.gather(targets, target_subset, axis=2) - # work-around for specifying my own predictions - # self.preds_adhoc = tf.placeholder( - # tf.float32, shape=seqs_repr.shape, name='preds-adhoc') - - # float 32 exponential clip max - # exp_max = np.floor(np.log(0.5*tf.float32.max)) - exp_max = 50 - - # choose link - if self.hp.link in ['identity', 'linear']: - preds_op = tf.identity(seqs_repr, name='preds') - - elif self.hp.link == 'relu': - preds_op = tf.relu(seqs_repr, name='preds') - - elif self.hp.link == 'exp': - seqs_repr_clip = tf.clip_by_value(seqs_repr, -exp_max, exp_max) - preds_op = tf.exp(seqs_repr_clip, name='preds') - - elif self.hp.link == 'exp_linear': - preds_op = tf.where( - seqs_repr > 0, - seqs_repr + 1, - tf.exp(tf.clip_by_value(seqs_repr, -exp_max, exp_max)), - name='preds') - - elif self.hp.link == 'softplus': - seqs_repr_clip = tf.clip_by_value(seqs_repr, -exp_max, 10000) - preds_op = tf.nn.softplus(seqs_repr_clip, name='preds') - - elif self.hp.link == 
'softmax': - # performed in the loss function, but saving probabilities - self.preds_prob = tf.nn.softmax(seqs_repr, name='preds') - - else: - print('Unknown link function %s' % self.hp.link, file=sys.stderr) - exit(1) - # clip if self.hp.target_clip is not None: - preds_op = tf.clip_by_value(preds_op, 0, self.hp.target_clip) targets = tf.clip_by_value(targets, 0, self.hp.target_clip) # sqrt if self.hp.target_sqrt: - preds_op = tf.sqrt(preds_op) targets = tf.sqrt(targets) loss_op = None - # loss_adhoc = None # choose loss if self.hp.loss == 'gaussian': - loss_op = tf.squared_difference(preds_op, targets) - # loss_adhoc = tf.squared_difference(self.preds_adhoc, targets) + loss_op = tf.squared_difference(preds, targets) elif self.hp.loss == 'poisson': loss_op = tf.nn.log_poisson_loss( - targets, tf.log(preds_op), compute_full_loss=True) - # loss_adhoc = tf.nn.log_poisson_loss( - # targets, tf.log(self.preds_adhoc), compute_full_loss=True) + targets, tf.log(preds), compute_full_loss=True) elif self.hp.loss == 'cross_entropy': loss_op = tf.nn.sparse_softmax_cross_entropy_with_logits( - labels=(targets - 1), logits=preds_op) - # loss_adhoc = tf.nn.sparse_softmax_cross_entropy_with_logits( - # labels=(targets - 1), logits=self.preds_adhoc) + labels=(targets - 1), logits=preds) else: raise ValueError('Cannot identify loss function %s' % self.hp.loss) @@ -378,29 +375,24 @@ def build_loss(self, seqs_repr, data_ops, target_subset=None): loss_op = tf.reduce_mean(loss_op, axis=[0, 1], name='target_loss') loss_op = tf.check_numerics(loss_op, 'Invalid loss', name='loss_check') - # loss_adhoc = tf.reduce_mean( - # loss_adhoc, axis=[0, 1], name='target_loss_adhoc') tf.summary.histogram('target_loss', loss_op) for ti in np.linspace(0, self.hp.num_targets - 1, 10).astype('int'): tf.summary.scalar('loss_t%d' % ti, loss_op[ti]) target_losses = loss_op - # self.target_losses_adhoc = loss_adhoc # fully reduce loss_op = tf.reduce_mean(loss_op, name='loss') - # loss_adhoc = tf.reduce_mean(loss_adhoc, name='loss_adhoc') # add regularization terms reg_losses = tf.get_collection(tf.GraphKeys.REGULARIZATION_LOSSES) reg_sum = tf.reduce_sum(reg_losses) tf.summary.scalar('regularizers', reg_sum) loss_op += reg_sum - # loss_adhoc += reg_sum # track tf.summary.scalar('loss', loss_op) - return loss_op, target_losses, preds_op, targets + return loss_op, target_losses, targets def set_mode(self, mode): From 2769f822f91130c5131110a30e4a908fa2c54db4 Mon Sep 17 00:00:00 2001 From: David Kelley Date: Fri, 20 Jul 2018 16:36:02 -0700 Subject: [PATCH 55/71] rename build --- basenji/seqnn.py | 2 +- bin/basenji_sad.py | 2 +- bin/basenji_sat.py | 2 +- bin/basenji_sat_vcf.py | 2 +- bin/basenji_sed.py | 2 +- bin/basenji_test.py | 2 +- bin/basenji_test_genes.py | 2 +- bin/basenji_test_h5.py | 2 +- bin/basenji_test_reps.py | 2 +- bin/basenji_train.py | 2 +- bin/basenji_train_h5.py | 2 +- 11 files changed, 11 insertions(+), 11 deletions(-) diff --git a/basenji/seqnn.py b/basenji/seqnn.py index f9db82df..d58ed141 100755 --- a/basenji/seqnn.py +++ b/basenji/seqnn.py @@ -33,7 +33,7 @@ def __init__(self): self.global_step = tf.train.get_or_create_global_step() self.hparams_set = False - def build(self, job, augment_rc=False, augment_shifts=[0], + def build_feed(self, job, augment_rc=False, augment_shifts=[0], ensemble_rc=False, ensemble_shifts=[0], target_subset=None): """Build training ops that depend on placeholders.""" diff --git a/bin/basenji_sad.py b/bin/basenji_sad.py index 366e4693..b7250f06 100755 --- a/bin/basenji_sad.py +++ 
b/bin/basenji_sad.py @@ -173,7 +173,7 @@ def main(): # build model t0 = time.time() model = basenji.seqnn.SeqNN() - model.build(job, target_subset=target_subset) + model.build_feed(job, target_subset=target_subset) print('Model building time %f' % (time.time() - t0), flush=True) if options.penultimate: diff --git a/bin/basenji_sat.py b/bin/basenji_sat.py index f5b3602f..4b27c0a8 100755 --- a/bin/basenji_sat.py +++ b/bin/basenji_sat.py @@ -189,7 +189,7 @@ def main(): t0 = time.time() dr = seqnn.SeqNN() - dr.build(job, target_subset=target_subset) + dr.build_feed(job, target_subset=target_subset) print('Model building time %f' % (time.time() - t0), flush=True) if options.batch_size is not None: diff --git a/bin/basenji_sat_vcf.py b/bin/basenji_sat_vcf.py index c0236b76..de8e610c 100755 --- a/bin/basenji_sat_vcf.py +++ b/bin/basenji_sat_vcf.py @@ -180,7 +180,7 @@ def main(): # build model dr = basenji.seqnn.SeqNN() - dr.build(job) + dr.build_feed(job) # initialize saver saver = tf.train.Saver() diff --git a/bin/basenji_sed.py b/bin/basenji_sed.py index 8e25e3eb..7956934c 100755 --- a/bin/basenji_sed.py +++ b/bin/basenji_sed.py @@ -230,7 +230,7 @@ def main(): # build model model = seqnn.SeqNN() - model.build(job, target_subset=target_subset) + model.build_feed(job, target_subset=target_subset) if options.penultimate: # labels become inappropriate diff --git a/bin/basenji_test.py b/bin/basenji_test.py index 1555d719..7750ea62 100755 --- a/bin/basenji_test.py +++ b/bin/basenji_test.py @@ -218,7 +218,7 @@ def main(): t0 = time.time() dr = seqnn.SeqNN() - dr.build(job) + dr.build_feed(job) print('Model building time %ds' % (time.time() - t0)) # adjust for fourier diff --git a/bin/basenji_test_genes.py b/bin/basenji_test_genes.py index b78d688a..c72eb0ef 100755 --- a/bin/basenji_test_genes.py +++ b/bin/basenji_test_genes.py @@ -181,7 +181,7 @@ def main(): # build model model = seqnn.SeqNN() - model.build(job) + model.build_feed(job) if options.batch_size is not None: model.hp.batch_size = options.batch_size diff --git a/bin/basenji_test_h5.py b/bin/basenji_test_h5.py index 7dd64484..56810cdb 100755 --- a/bin/basenji_test_h5.py +++ b/bin/basenji_test_h5.py @@ -214,7 +214,7 @@ def main(): t0 = time.time() model = seqnn.SeqNN() - model.build(job, ensemble_rc=options.rc, ensemble_shifts=options.shifts) + model.build_feed(job, ensemble_rc=options.rc, ensemble_shifts=options.shifts) print('Model building time %ds' % (time.time() - t0)) # adjust for fourier diff --git a/bin/basenji_test_reps.py b/bin/basenji_test_reps.py index ea7ce82e..3a180e05 100755 --- a/bin/basenji_test_reps.py +++ b/bin/basenji_test_reps.py @@ -136,7 +136,7 @@ def main(): job['target_pool'] = int(np.array(data_open.get('pool_width', 1))) dr = basenji.seqnn.SeqNN() - dr.build(job) + dr.build_feed(job) # adjust for fourier job['fourier'] = 'train_out_imag' in data_open diff --git a/bin/basenji_train.py b/bin/basenji_train.py index f8154e23..6a540d49 100755 --- a/bin/basenji_train.py +++ b/bin/basenji_train.py @@ -73,7 +73,7 @@ def run(params_file, data_file, train_epochs, train_epoch_batches, test_epoch_ba t0 = time.time() model = seqnn.SeqNN() - model.build(job) + model.build_feed(job) print('Model building time %f' % (time.time() - t0)) # adjust for fourier diff --git a/bin/basenji_train_h5.py b/bin/basenji_train_h5.py index 875ee3af..b6cbb87f 100755 --- a/bin/basenji_train_h5.py +++ b/bin/basenji_train_h5.py @@ -76,7 +76,7 @@ def run(params_file, data_file, train_epochs, train_epoch_batches, test_epoch_ba t0 = time.time() model = 
seqnn.SeqNN() - model.build(job, augment_rc=FLAGS.augment_rc, augment_shifts=augment_shifts, + model.build_feed(job, augment_rc=FLAGS.augment_rc, augment_shifts=augment_shifts, ensemble_rc=FLAGS.ensemble_rc, ensemble_shifts=ensemble_shifts) print('Model building time %f' % (time.time() - t0)) From 44b099d0e609b977ba356643824fde071d3675be Mon Sep 17 00:00:00 2001 From: David Kelley Date: Fri, 20 Jul 2018 18:56:12 -0700 Subject: [PATCH 56/71] predict in-graph ensembling --- basenji/seqnn.py | 2 + basenji/seqnn_util.py | 160 ++++++++++++++++++++++++++++++++++++++---- bin/basenji_sad.py | 13 ++-- 3 files changed, 156 insertions(+), 19 deletions(-) diff --git a/basenji/seqnn.py b/basenji/seqnn.py index d58ed141..873b8978 100755 --- a/basenji/seqnn.py +++ b/basenji/seqnn.py @@ -341,6 +341,8 @@ def build_loss(self, preds, targets, target_subset=None): # slice buffer tstart = self.hp.batch_buffer // self.hp.target_pool tend = (self.hp.seq_length - self.hp.batch_buffer) // self.hp.target_pool + self.target_length = tend - tstart + targets = tf.identity(targets[:, tstart:tend, :], name='targets_op') if target_subset is not None: diff --git a/basenji/seqnn_util.py b/basenji/seqnn_util.py index b3c9bea9..d4f3ab78 100644 --- a/basenji/seqnn_util.py +++ b/basenji/seqnn_util.py @@ -631,19 +631,12 @@ def _predict_ensemble(self, return preds_batch, preds_batch_var, preds_all - def predict(self, - sess, - batcher, - rc=False, - shifts=[0], - mc_n=0, - target_indexes=None, - return_var=False, - return_all=False, - down_sample=1, - penultimate=False, - test_batches=None, - dtype='float32'): + def predict_h5_manual(self, sess, batcher, + rc=False, shifts=[0], mc_n=0, + target_indexes=None, + return_var=False, return_all=False, + down_sample=1, penultimate=False, + test_batches=None, dtype='float32'): """ Compute predictions on a test set. In @@ -765,6 +758,146 @@ def predict(self, else: return preds + def predict_h5(self, sess, batcher, + return_var=False, return_all=False, + penultimate=False, test_batches=None): + """ Compute preidctions on an HDF5 test set. + + Args: + sess: TensorFlow session + return_var: Return variance estimates + return_all: Retyrn all predictions. + penultimate: Predict the penultimate layer. + test_batches: Number of test batches to use. 
+ + Returns: + preds: S (sequences) x L (unbuffered length) x T (targets) array + """ + fd = self.set_mode('test') + + # initialize prediction data structures + preds = [] + if return_var: + preds_var = [] + if return_all: + preds_all = [] + + # get first batch + batch_num = 0 + Xb, _, _, Nb = batcher.next() + + while Xb is not None and (test_batches is None or + batch_num < test_batches): + # update feed dict + fd[self.inputs_ph] = Xb + + # make predictions + if return_var or return_all: + preds_batch, preds_ensemble_batch = sess.run([self.preds_eval, self.preds_ensemble], feed_dict=fd) + + # move ensemble to back + preds_ensemble_batch = np.moveaxis(preds_ensemble_batch, 0, -1) + + else: + preds_batch = sess.run(self.preds_eval, feed_dict=fd) + + # accumulate predictions and targets + preds.append(preds_batch[:Nb]) + if return_var: + preds_var_batch = np.var(preds_ensemble_batch, axis=-1) + preds_var.append(preds_var_batch[:Nb]) + if return_all: + preds_all.append(preds_ensemble_batch[:Nb]) + + # next batch + batch_num += 1 + Xb, _, _, Nb = batcher.next() + + # reset batcher + batcher.reset() + + # construct arrays + preds = np.concatenate(preds, axis=0) + if return_var: + preds_var = np.concatenate(preds_var, axis=0) + if return_all: + preds_all = np.concatenate(preds_all, axis=0) + + if return_var: + if return_all: + return preds, preds_var, preds_all + else: + return preds, preds_var + else: + return preds + + def predict_tfr(self, sess, + return_var=False, return_all=False, + penultimate=False, test_batches=None): + """ Compute preidctions on a TFRecord test set. + + Args: + sess: TensorFlow session + return_var: Return variance estimates + return_all: Retyrn all predictions. + penultimate: Predict the penultimate layer. + test_batches: Number of test batches to use. + + Returns: + preds: S (sequences) x L (unbuffered length) x T (targets) array + """ + fd = self.set_mode('test') + + # initialize prediction data structures + preds = [] + if return_var: + preds_var = [] + if return_all: + preds_all = [] + + # sequence index + data_available = True + batch_num = 0 + while data_available and (test_batches is None or batch_num < test_batches): + try: + # make predictions + if return_var or return_all: + preds_batch, preds_ensemble_batch = sess.run([self.preds_eval, self.preds_ensemble], feed_dict=fd) + + # move ensemble to back + preds_ensemble_batch = np.moveaxis(preds_ensemble_batch, 0, -1) + + else: + preds_batch = sess.run(self.preds_eval, feed_dict=fd) + + # accumulate predictions and targets + preds.append(preds_batch) + if return_var: + preds_var_batch = np.var(preds_ensemble_batch, axis=-1) + preds_var.append(preds_var_batch) + if return_all: + preds_all.append(preds_ensemble_batch) + + batch_num += 1 + + except tf.errors.OutOfRangeError: + data_available = False + + # construct arrays + preds = np.concatenate(preds, axis=0) + if return_var: + preds_var = np.concatenate(preds_var, axis=0) + if return_all: + preds_all = np.concatenate(preds_all, axis=0) + + if return_var: + if return_all: + return preds, preds_var, preds_all + else: + return preds, preds_var + else: + return preds + def predict_genes(self, sess, batcher, @@ -901,7 +1034,6 @@ def test_h5(self, sess, batcher, test_batches=None): Args: sess: TensorFlow session batcher: Batcher object to provide data - mc_n: Monte Carlo iterations per rc/shift. 
test_batches: Number of test batches Returns: diff --git a/bin/basenji_sad.py b/bin/basenji_sad.py index b7250f06..2cd7a088 100755 --- a/bin/basenji_sad.py +++ b/bin/basenji_sad.py @@ -173,7 +173,9 @@ def main(): # build model t0 = time.time() model = basenji.seqnn.SeqNN() - model.build_feed(job, target_subset=target_subset) + # model.build_feed(job, target_subset=target_subset) + model.build_feed(job, target_subset=target_subset, + ensemble_rc=options.rc, ensemble_shifts=options.shifts) print('Model building time %f' % (time.time() - t0), flush=True) if options.penultimate: @@ -246,7 +248,6 @@ def main(): # initialize saver saver = tf.train.Saver() - with tf.Session() as sess: # load variables into session saver.restore(sess, model_file) @@ -263,9 +264,11 @@ def main(): batcher = basenji.batcher.Batcher(batch_1hot, batch_size=model.hp.batch_size) # predict - batch_preds = model.predict(sess, batcher, - rc=options.rc, shifts=options.shifts, - penultimate=options.penultimate) + # batch_preds = model.predict(sess, batcher, + # rc=options.rc, shifts=options.shifts, + # penultimate=options.penultimate) + batch_preds = model.predict_h5(sess, batcher, + penultimate=options.penultimate) # normalize batch_preds /= target_norms From f6c437932cf16f2103e62e8a762f950de9442400 Mon Sep 17 00:00:00 2001 From: David Kelley Date: Sat, 21 Jul 2018 09:23:07 -0700 Subject: [PATCH 57/71] penultimate draft --- basenji/seqnn.py | 154 ++++++++++++++++++++++-------------------- basenji/seqnn_util.py | 30 ++++---- 2 files changed, 91 insertions(+), 93 deletions(-) diff --git a/basenji/seqnn.py b/basenji/seqnn.py index 873b8978..5116971b 100755 --- a/basenji/seqnn.py +++ b/basenji/seqnn.py @@ -34,7 +34,8 @@ def __init__(self): self.hparams_set = False def build_feed(self, job, augment_rc=False, augment_shifts=[0], - ensemble_rc=False, ensemble_shifts=[0], target_subset=None): + ensemble_rc=False, ensemble_shifts=[0], + penultimate=False, target_subset=None): """Build training ops that depend on placeholders.""" self.hp = params.make_hparams(job) @@ -46,12 +47,13 @@ def build_feed(self, job, augment_rc=False, augment_shifts=[0], augment_shifts=augment_shifts, ensemble_rc=ensemble_rc, ensemble_shifts=ensemble_shifts, + penultimate=penultimate, target_subset=target_subset) def build_from_data_ops(self, job, data_ops, augment_rc=False, augment_shifts=[0], ensemble_rc=False, ensemble_shifts=[0], - target_subset=None): + penultimate=False, target_subset=None): """Build training ops from input data ops.""" if not self.hparams_set: self.hp = params.make_hparams(job) @@ -69,7 +71,7 @@ def build_from_data_ops(self, job, data_ops, # compute train representation self.preds_train = self.build_predict(data_ops_train['sequence'], - None, target_subset) + None, penultimate, target_subset) # training losses loss_returns = self.build_loss(self.preds_train, data_ops_train['label'], target_subset) @@ -90,7 +92,7 @@ def build_from_data_ops(self, job, data_ops, # compute eval representation map_elems_eval = (data_seq_eval, data_rev_eval) - build_rep = lambda do: self.build_predict(do[0], do[1], target_subset) + build_rep = lambda do: self.build_predict(do[0], do[1], penultimate, target_subset) self.preds_ensemble = tf.map_fn(build_rep, map_elems_eval, dtype=self.preds_train.dtype) # back_prop=False self.preds_eval = tf.reduce_mean(self.preds_ensemble, axis=0) @@ -148,7 +150,7 @@ def _make_conv_block_args(self, layer_index, layer_reprs): 'name': 'conv-%d' % layer_index } - def build_predict(self, inputs, reverse_preds=None, 
target_subset=None): + def build_predict(self, inputs, reverse_preds=None, penultimate=False, target_subset=None): """Construct per-location real-valued predictions.""" assert inputs is not None print('Targets pooled by %d to length %d' % @@ -158,17 +160,17 @@ def build_predict(self, inputs, reverse_preds=None, target_subset=None): # convolution layers ################################################### filter_weights = [] - layer_reprs = [inputs] + self.layer_reprs = [inputs] seqs_repr = inputs for layer_index in range(self.hp.cnn_layers): with tf.variable_scope('cnn%d' % layer_index, reuse=tf.AUTO_REUSE): # convolution block - args_for_block = self._make_conv_block_args(layer_index, layer_reprs) + args_for_block = self._make_conv_block_args(layer_index, self.layer_reprs) seqs_repr = layers.conv_block(seqs_repr=seqs_repr, **args_for_block) # save representation - layer_reprs.append(seqs_repr) + self.layer_reprs.append(seqs_repr) # final nonlinearity seqs_repr = tf.nn.relu(seqs_repr) @@ -190,44 +192,44 @@ def build_predict(self, inputs, reverse_preds=None, target_subset=None): seqs_repr = seqs_repr[:, batch_buffer_pool: seq_length - batch_buffer_pool, :] - # save penultimate representation - # self.penultimate_op = seqs_repr - ################################################### # final layer ################################################### - with tf.variable_scope('final', reuse=tf.AUTO_REUSE): - final_filters = self.hp.num_targets * self.hp.target_classes - final_repr = tf.layers.dense( - inputs=seqs_repr, - units=final_filters, - activation=None, - kernel_initializer=tf.variance_scaling_initializer(scale=2.0, mode='fan_in'), - kernel_regularizer=tf.contrib.layers.l1_regularizer(self.hp.final_l1_scale)) - print('Convolution w/ %d %dx1 filters to final targets' % - (final_filters, seqs_repr.shape[2])) - - if target_subset is not None: - # get convolution parameters - filters_full = tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES, 'final/dense/kernel')[0] - bias_full = tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES, 'final/dense/bias')[0] - - # subset to specific targets - filters_subset = tf.gather(filters_full, target_subset, axis=1) - bias_subset = tf.gather(bias_full, target_subset, axis=0) - - # substitute a new limited convolution - final_repr = tf.tensordot(seqs_repr, filters_subset, 1) - final_repr = tf.nn.bias_add(final_repr, bias_subset) - - # update # targets - self.hp.num_targets = len(target_subset) - - # expand length back out - if self.hp.target_classes > 1: - final_repr = tf.reshape(final_repr, - (self.hp.batch_size, -1, self.hp.num_targets, - self.hp.target_classes)) + if penultimate: + final_repr = seqs_repr + else: + with tf.variable_scope('final', reuse=tf.AUTO_REUSE): + final_filters = self.hp.num_targets * self.hp.target_classes + final_repr = tf.layers.dense( + inputs=seqs_repr, + units=final_filters, + activation=None, + kernel_initializer=tf.variance_scaling_initializer(scale=2.0, mode='fan_in'), + kernel_regularizer=tf.contrib.layers.l1_regularizer(self.hp.final_l1_scale)) + print('Convolution w/ %d %dx1 filters to final targets' % + (final_filters, seqs_repr.shape[2])) + + if target_subset is not None: + # get convolution parameters + filters_full = tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES, 'final/dense/kernel')[0] + bias_full = tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES, 'final/dense/bias')[0] + + # subset to specific targets + filters_subset = tf.gather(filters_full, target_subset, axis=1) + bias_subset = tf.gather(bias_full, target_subset, 
axis=0) + + # substitute a new limited convolution + final_repr = tf.tensordot(seqs_repr, filters_subset, 1) + final_repr = tf.nn.bias_add(final_repr, bias_subset) + + # update # targets + self.hp.num_targets = len(target_subset) + + # expand length back out + if self.hp.target_classes > 1: + final_repr = tf.reshape(final_repr, + (self.hp.batch_size, -1, self.hp.num_targets, + self.hp.target_classes)) # transform for reverse complement if reverse_preds is not None: @@ -238,47 +240,49 @@ def build_predict(self, inputs, reverse_preds=None, target_subset=None): ################################################### # link function ################################################### + if penultimate: + predictions = final_repr + else: + # work-around for specifying my own predictions + # self.preds_adhoc = tf.placeholder( + # tf.float32, shape=final_repr.shape, name='preds-adhoc') - # work-around for specifying my own predictions - # self.preds_adhoc = tf.placeholder( - # tf.float32, shape=final_repr.shape, name='preds-adhoc') - - # float 32 exponential clip max - exp_max = 50 + # float 32 exponential clip max + exp_max = 50 - # choose link - if self.hp.link in ['identity', 'linear']: - predictions = tf.identity(final_repr, name='preds') + # choose link + if self.hp.link in ['identity', 'linear']: + predictions = tf.identity(final_repr, name='preds') - elif self.hp.link == 'relu': - predictions = tf.relu(final_repr, name='preds') + elif self.hp.link == 'relu': + predictions = tf.relu(final_repr, name='preds') - elif self.hp.link == 'exp': - final_repr_clip = tf.clip_by_value(final_repr, -exp_max, exp_max) - predictions = tf.exp(final_repr_clip, name='preds') + elif self.hp.link == 'exp': + final_repr_clip = tf.clip_by_value(final_repr, -exp_max, exp_max) + predictions = tf.exp(final_repr_clip, name='preds') - elif self.hp.link == 'exp_linear': - predictions = tf.where( - final_repr > 0, - final_repr + 1, - tf.exp(tf.clip_by_value(final_repr, -exp_max, exp_max)), - name='preds') + elif self.hp.link == 'exp_linear': + predictions = tf.where( + final_repr > 0, + final_repr + 1, + tf.exp(tf.clip_by_value(final_repr, -exp_max, exp_max)), + name='preds') - elif self.hp.link == 'softplus': - final_repr_clip = tf.clip_by_value(final_repr, -exp_max, 10000) - predictions = tf.nn.softplus(final_repr_clip, name='preds') + elif self.hp.link == 'softplus': + final_repr_clip = tf.clip_by_value(final_repr, -exp_max, 10000) + predictions = tf.nn.softplus(final_repr_clip, name='preds') - else: - print('Unknown link function %s' % self.hp.link, file=sys.stderr) - exit(1) + else: + print('Unknown link function %s' % self.hp.link, file=sys.stderr) + exit(1) - # clip - if self.hp.target_clip is not None: - predictions = tf.clip_by_value(predictions, 0, self.hp.target_clip) + # clip + if self.hp.target_clip is not None: + predictions = tf.clip_by_value(predictions, 0, self.hp.target_clip) - # sqrt - if self.hp.target_sqrt: - predictions = tf.sqrt(predictions) + # sqrt + if self.hp.target_sqrt: + predictions = tf.sqrt(predictions) return predictions diff --git a/basenji/seqnn_util.py b/basenji/seqnn_util.py index d4f3ab78..f9d9dae1 100644 --- a/basenji/seqnn_util.py +++ b/basenji/seqnn_util.py @@ -596,10 +596,7 @@ def _predict_ensemble(self, # print('ei=%d, mi=%d, fwdrc=%d, shifts=%d' % (ei, mi, ensemble_fwdrc[ei], ensemble_shifts[ei]), flush=True) # predict - if penultimate: - preds_ei = sess.run(self.penultimate_op, feed_dict=fd) - else: - preds_ei = sess.run(self.preds_eval, feed_dict=fd) + preds_ei = 
sess.run(self.preds_eval, feed_dict=fd) # reverse if ensemble_fwdrc[ei] is False: @@ -641,7 +638,7 @@ def predict_h5_manual(self, sess, batcher, In sess: TensorFlow session - batcher: Batcher class with transcript-covering sequences. + batcher: Batcher class with sequences. rc: Average predictions from the forward and reverse complement sequences. shifts: Average predictions from sequence shifts left/right. @@ -758,17 +755,16 @@ def predict_h5_manual(self, sess, batcher, else: return preds - def predict_h5(self, sess, batcher, - return_var=False, return_all=False, - penultimate=False, test_batches=None): + def predict_h5(self, sess, batcher, test_batches=None, + return_var=False, return_all=False): """ Compute preidctions on an HDF5 test set. Args: sess: TensorFlow session + batcher: Batcher class with sequences. + test_batches: Number of test batches to use. return_var: Return variance estimates return_all: Retyrn all predictions. - penultimate: Predict the penultimate layer. - test_batches: Number of test batches to use. Returns: preds: S (sequences) x L (unbuffered length) x T (targets) array @@ -831,17 +827,15 @@ def predict_h5(self, sess, batcher, else: return preds - def predict_tfr(self, sess, - return_var=False, return_all=False, - penultimate=False, test_batches=None): + def predict_tfr(self, sess, test_batches=None + return_var=False, return_all=False): """ Compute preidctions on a TFRecord test set. Args: sess: TensorFlow session + test_batches: Number of test batches to use. return_var: Return variance estimates return_all: Retyrn all predictions. - penultimate: Predict the penultimate layer. - test_batches: Number of test batches to use. Returns: preds: S (sequences) x L (unbuffered length) x T (targets) array @@ -952,9 +946,9 @@ def predict_genes(self, while not batcher.empty(): # predict gene sequences - gseq_preds = self.predict(sess, batcher, rc=rc, shifts=shifts, mc_n=mc_n, - target_indexes=target_indexes, penultimate=penultimate, - test_batches=test_batches_per) + gseq_preds = self.predict_h5_manual(sess, batcher, rc=rc, shifts=shifts, mc_n=mc_n, + target_indexes=target_indexes, penultimate=penultimate, + test_batches=test_batches_per) # slice TSSs for bsi in range(gseq_preds.shape[0]): for tss in gene_seqs[si].tss_list: From ad01d86899df440271a75176b84319dd1020a3b1 Mon Sep 17 00:00:00 2001 From: David Kelley Date: Sun, 22 Jul 2018 09:13:40 -0700 Subject: [PATCH 58/71] missing comma --- basenji/seqnn_util.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/basenji/seqnn_util.py b/basenji/seqnn_util.py index f9d9dae1..824a7067 100644 --- a/basenji/seqnn_util.py +++ b/basenji/seqnn_util.py @@ -827,7 +827,7 @@ def predict_h5(self, sess, batcher, test_batches=None, else: return preds - def predict_tfr(self, sess, test_batches=None + def predict_tfr(self, sess, test_batches=None, return_var=False, return_all=False): """ Compute preidctions on a TFRecord test set. 
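In predict_h5() and predict_tfr() above, preds_ensemble is the tf.map_fn output, so the ensemble members sit on axis 0 in front of (batch, length, targets); preds_eval is their mean, and the variance and all-member outputs come from moving the member axis to the back. A toy NumPy illustration of that bookkeeping (the shapes here are made up for the example):

  import numpy as np

  E, B, L, T = 4, 2, 8, 3                            # members, batch, length, targets
  preds_ensemble = np.random.rand(E, B, L, T)        # stand-in for the map_fn output

  preds_eval = preds_ensemble.mean(axis=0)           # ensemble-averaged predictions
  members_last = np.moveaxis(preds_ensemble, 0, -1)  # (B, L, T, E), as in the code above
  preds_var = members_last.var(axis=-1)              # per-position ensemble variance

  assert preds_eval.shape == (B, L, T)
  assert np.allclose(preds_eval, members_last.mean(axis=-1))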
From 0ec3447973ba4d36dbd4e0bbba9756ae164c99c4 Mon Sep 17 00:00:00 2001 From: David Kelley Date: Sun, 22 Jul 2018 09:15:00 -0700 Subject: [PATCH 59/71] penultimate loss fix --- basenji/seqnn.py | 19 ++++++++++--------- bin/basenji_sad.py | 7 +++---- 2 files changed, 13 insertions(+), 13 deletions(-) diff --git a/basenji/seqnn.py b/basenji/seqnn.py index 5116971b..5dfd39bf 100755 --- a/basenji/seqnn.py +++ b/basenji/seqnn.py @@ -72,14 +72,16 @@ def build_from_data_ops(self, job, data_ops, # compute train representation self.preds_train = self.build_predict(data_ops_train['sequence'], None, penultimate, target_subset) + self.target_length = self.preds_train.shape[1].value # training losses - loss_returns = self.build_loss(self.preds_train, data_ops_train['label'], target_subset) - self.loss_train, self.loss_train_targets, self.targets_train = loss_returns + if not penultimate: + loss_returns = self.build_loss(self.preds_train, data_ops_train['label'], target_subset) + self.loss_train, self.loss_train_targets, self.targets_train = loss_returns - # optimizer - self.update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS) - self.build_optimizer(self.loss_train) + # optimizer + self.update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS) + self.build_optimizer(self.loss_train) ################################################## # eval @@ -97,8 +99,9 @@ def build_from_data_ops(self, job, data_ops, self.preds_eval = tf.reduce_mean(self.preds_ensemble, axis=0) # eval loss - loss_returns = self.build_loss(self.preds_eval, data_ops['label'], target_subset) - self.loss_eval, self.loss_eval_targets, self.targets_eval = loss_returns + if not penultimate: + loss_returns = self.build_loss(self.preds_eval, data_ops['label'], target_subset) + self.loss_eval, self.loss_eval_targets, self.targets_eval = loss_returns # helper variables self.preds_length = self.preds_train.shape[1] @@ -345,8 +348,6 @@ def build_loss(self, preds, targets, target_subset=None): # slice buffer tstart = self.hp.batch_buffer // self.hp.target_pool tend = (self.hp.seq_length - self.hp.batch_buffer) // self.hp.target_pool - self.target_length = tend - tstart - targets = tf.identity(targets[:, tstart:tend, :], name='targets_op') if target_subset is not None: diff --git a/bin/basenji_sad.py b/bin/basenji_sad.py index 2cd7a088..00537e1d 100755 --- a/bin/basenji_sad.py +++ b/bin/basenji_sad.py @@ -174,8 +174,8 @@ def main(): t0 = time.time() model = basenji.seqnn.SeqNN() # model.build_feed(job, target_subset=target_subset) - model.build_feed(job, target_subset=target_subset, - ensemble_rc=options.rc, ensemble_shifts=options.shifts) + model.build_feed(job, ensemble_rc=options.rc, ensemble_shifts=options.shifts, + target_subset=target_subset, penultimate=options.penultimate) print('Model building time %f' % (time.time() - t0), flush=True) if options.penultimate: @@ -267,8 +267,7 @@ def main(): # batch_preds = model.predict(sess, batcher, # rc=options.rc, shifts=options.shifts, # penultimate=options.penultimate) - batch_preds = model.predict_h5(sess, batcher, - penultimate=options.penultimate) + batch_preds = model.predict_h5(sess, batcher) # normalize batch_preds /= target_norms From 3f35246b6e9c4470e429feb51e79b43bcdd96199 Mon Sep 17 00:00:00 2001 From: David Kelley Date: Mon, 23 Jul 2018 15:48:14 -0700 Subject: [PATCH 60/71] hidden and map --- basenji/seqnn.py | 31 ++++++++++-------- basenji/seqnn_util.py | 31 +++++++++++++----- bin/basenji_hidden.py | 76 +++++++++++++++++++++++-------------------- bin/basenji_map.py | 2 +- 4 files changed, 
83 insertions(+), 57 deletions(-) diff --git a/basenji/seqnn.py b/basenji/seqnn.py index 5dfd39bf..49c2d179 100755 --- a/basenji/seqnn.py +++ b/basenji/seqnn.py @@ -71,7 +71,8 @@ def build_from_data_ops(self, job, data_ops, # compute train representation self.preds_train = self.build_predict(data_ops_train['sequence'], - None, penultimate, target_subset) + None, penultimate, target_subset, + save_reprs=True) self.target_length = self.preds_train.shape[1].value # training losses @@ -103,10 +104,13 @@ def build_from_data_ops(self, job, data_ops, loss_returns = self.build_loss(self.preds_eval, data_ops['label'], target_subset) self.loss_eval, self.loss_eval_targets, self.targets_eval = loss_returns + # update # targets + if target_subset is not None: + self.hp.num_targets = len(target_subset) + # helper variables self.preds_length = self.preds_train.shape[1] - def make_placeholders(self): """Allocates placeholders to be used in place of input data ops.""" # batches @@ -153,7 +157,7 @@ def _make_conv_block_args(self, layer_index, layer_reprs): 'name': 'conv-%d' % layer_index } - def build_predict(self, inputs, reverse_preds=None, penultimate=False, target_subset=None): + def build_predict(self, inputs, reverse_preds=None, penultimate=False, target_subset=None, save_reprs=False): """Construct per-location real-valued predictions.""" assert inputs is not None print('Targets pooled by %d to length %d' % @@ -163,17 +167,20 @@ def build_predict(self, inputs, reverse_preds=None, penultimate=False, target_su # convolution layers ################################################### filter_weights = [] - self.layer_reprs = [inputs] + layer_reprs = [inputs] seqs_repr = inputs for layer_index in range(self.hp.cnn_layers): with tf.variable_scope('cnn%d' % layer_index, reuse=tf.AUTO_REUSE): # convolution block - args_for_block = self._make_conv_block_args(layer_index, self.layer_reprs) + args_for_block = self._make_conv_block_args(layer_index, layer_reprs) seqs_repr = layers.conv_block(seqs_repr=seqs_repr, **args_for_block) # save representation - self.layer_reprs.append(seqs_repr) + layer_reprs.append(seqs_repr) + + if save_reprs: + self.layer_reprs = layer_reprs # final nonlinearity seqs_repr = tf.nn.relu(seqs_repr) @@ -225,9 +232,6 @@ def build_predict(self, inputs, reverse_preds=None, penultimate=False, target_su final_repr = tf.tensordot(seqs_repr, filters_subset, 1) final_repr = tf.nn.bias_add(final_repr, bias_subset) - # update # targets - self.hp.num_targets = len(target_subset) - # expand length back out if self.hp.target_classes > 1: final_repr = tf.reshape(final_repr, @@ -381,12 +385,13 @@ def build_loss(self, preds, targets, target_subset=None): # reduce lossses by batch and position loss_op = tf.reduce_mean(loss_op, axis=[0, 1], name='target_loss') loss_op = tf.check_numerics(loss_op, 'Invalid loss', name='loss_check') - - tf.summary.histogram('target_loss', loss_op) - for ti in np.linspace(0, self.hp.num_targets - 1, 10).astype('int'): - tf.summary.scalar('loss_t%d' % ti, loss_op[ti]) target_losses = loss_op + if target_subset is None: + tf.summary.histogram('target_loss', loss_op) + for ti in np.linspace(0, self.hp.num_targets - 1, 10).astype('int'): + tf.summary.scalar('loss_t%d' % ti, loss_op[ti]) + # fully reduce loss_op = tf.reduce_mean(loss_op, name='loss') diff --git a/basenji/seqnn_util.py b/basenji/seqnn_util.py index 824a7067..ac5bbb09 100644 --- a/basenji/seqnn_util.py +++ b/basenji/seqnn_util.py @@ -27,7 +27,7 @@ def build_grads(self, layers=[0]): self.grad_ops = [] for ti in 
range(self.hp.num_targets): - grad_ti_op = tf.gradients(self.preds_eval[:,:,ti], [self.layer_reprs[li] for li in self.grad_layers]) + grad_ti_op = tf.gradients(self.preds_train[:,:,ti], [self.layer_reprs[li] for li in self.grad_layers]) self.grad_ops.append(grad_ti_op) @@ -224,7 +224,9 @@ def gradients(self, return layer_grads, layer_reprs, preds - def _gradients_ensemble(self, sess, fd, Xb, ensemble_fwdrc, ensemble_shifts, mc_n, return_var=False, return_all=False): + def _gradients_ensemble(self, sess, fd, Xb, + ensemble_fwdrc, ensemble_shifts, mc_n, + return_var=False, return_all=False): """ Compute gradients over an ensemble of input augmentations. In @@ -312,7 +314,7 @@ def _gradients_ensemble(self, sess, fd, Xb, ensemble_fwdrc, ensemble_shifts, mc_ # prediction # predict - preds_ei, layer_reprs_ei = sess.run([self.preds_eval, self.layer_reprs], feed_dict=fd) + preds_ei, layer_reprs_ei = sess.run([self.preds_train, self.layer_reprs], feed_dict=fd) # reverse if ensemble_fwdrc[ei] is False: @@ -451,7 +453,7 @@ def gradients_genes(self, sess, batcher, gene_seqs): fd[self.inputs_ph] = Xb # predict - reprs_batch, _ = sess.run([self.layer_reprs, self.preds_eval], feed_dict=fd) + reprs_batch, _ = sess.run([self.layer_reprs, self.preds_train], feed_dict=fd) # save representations for lii in range(len(self.grad_layers)): @@ -487,8 +489,18 @@ def gradients_genes(self, sess, batcher, gene_seqs): return layer_grads, layer_reprs - def hidden(self, sess, batcher, layers=None): - """ Compute hidden representations for a test set. """ + def hidden(self, sess, batcher, layers=None, test_batches=None): + """ Compute hidden representations for a test set. + + In + sess: TensorFlow session + batcher: Batcher class with sequences. + layers: Layer indexes to return representations. + test_batches: Number of test batches to use. 
+ + Out + preds: S (sequences) x L (unbuffered length) x T (targets) array + """ if layers is None: layers = list(range(self.hp.cnn_layers)) @@ -505,13 +517,15 @@ def hidden(self, sess, batcher, layers=None): # get first batch Xb, _, _, Nb = batcher.next() - while Xb is not None: + batch_num = 0 + while Xb is not None and (test_batches is None or + batch_num < test_batches): # update feed dict fd[self.inputs_ph] = Xb # compute predictions layer_reprs_batch, preds_batch = sess.run( - [self.layer_reprs, self.preds_eval], feed_dict=fd) + [self.layer_reprs, self.preds_train], feed_dict=fd) # accumulate representationsmakes the number of members for self smaller and also for li in layers: @@ -527,6 +541,7 @@ def hidden(self, sess, batcher, layers=None): # next batch Xb, _, _, Nb = batcher.next() + batch_num += 1 # reset batcher batcher.reset() diff --git a/bin/basenji_hidden.py b/bin/basenji_hidden.py index 2fc9bd4e..c7ba14e2 100755 --- a/bin/basenji_hidden.py +++ b/bin/basenji_hidden.py @@ -30,7 +30,10 @@ import statsmodels import tensorflow as tf -import basenji +from basenji import batcher +from basenji import params +from basenji import plots +from basenji import seqnn ################################################################################ # basenji_hidden.py @@ -45,24 +48,15 @@ def main(): usage = 'usage: %prog [options] ' parser = OptionParser(usage) - parser.add_option( - '-l', - dest='layers', - default=None, - help='Comma-separated list of layers to plot') - parser.add_option( - '-n', - dest='num_seqs', - default=None, - type='int', + parser.add_option('-l', dest='layers', + default=None, help='Comma-separated list of layers to plot') + parser.add_option('-n', dest='num_seqs', + default=None, type='int', help='Number of sequences to process') - parser.add_option( - '-o', - dest='out_dir', - default='hidden', - help='Output directory [Default: %default]') - parser.add_option( - '-t', dest='target_indexes', default=None, help='Target indexes to plot') + parser.add_option('-o', dest='out_dir', + default='hidden', help='Output directory [Default: %default]') + parser.add_option('-t', dest='target_indexes', + default=None, help='Paint 2D plots with these target index values.') (options, args) = parser.parse_args() if len(args) != 3: @@ -92,17 +86,16 @@ def main(): ####################################################### # model parameters and placeholders ####################################################### - job = basenji.dna_io.read_job_params(params_file) + job = params.read_job_params(params_file) job['seq_length'] = test_seqs.shape[1] job['seq_depth'] = test_seqs.shape[2] job['num_targets'] = test_targets.shape[2] job['target_pool'] = int(np.array(data_open.get('pool_width', 1))) - job['save_reprs'] = True t0 = time.time() - model = basenji.seqnn.SeqNN() - model.build(job) + model = seqnn.SeqNN() + model.build_feed(job) if options.target_indexes is None: options.target_indexes = range(job['num_targets']) @@ -115,11 +108,11 @@ def main(): # test ####################################################### # initialize batcher - batcher_test = basenji.batcher.Batcher( + batcher_test = batcher.Batcher( test_seqs, test_targets, - batch_size=model.batch_size, - pool_width=model.target_pool) + batch_size=model.hp.batch_size, + pool_width=model.hp.target_pool) # initialize saver saver = tf.train.Saver() @@ -144,12 +137,14 @@ def main(): # sample one nt per sequence ds_indexes = np.arange(0, layer_repr.shape[1], 256) nt_reprs = layer_repr[:, ds_indexes, :].reshape((-1, 
layer_repr.shape[2])) + print('nt_reprs', nt_reprs.shape) ######################################################## # plot raw sns.set(style='ticks', font_scale=1.2) plt.figure() - g = sns.clustermap(nt_reprs, xticklabels=False, yticklabels=False) + g = sns.clustermap(nt_reprs, cmap='RdBu_r', + xticklabels=False, yticklabels=False) g.ax_heatmap.set_xlabel('Representation') g.ax_heatmap.set_ylabel('Sequences') plt.savefig('%s/l%d_reprs.pdf' % (options.out_dir, li)) @@ -182,7 +177,18 @@ def main(): nt_2d = model2.fit_transform(nt_reprs) for ti in options.target_indexes: - nt_targets = np.log2(test_targets[:, ds_indexes, ti].flatten() + 1) + # slice for target + test_targets_ti = test_targets[:,:,ti] + + # repeat to match layer_repr + target_repeat = layer_repr.shape[1] // test_targets.shape[1] + test_targets_ti = np.repeat(test_targets_ti, target_repeat, axis=1) + + # downsample indexes + nt_targets = test_targets_ti[:,ds_indexes].flatten() + + # log transform + nt_targets = np.log1p(nt_targets) plt.figure() plt.scatter( @@ -193,17 +199,17 @@ def main(): plt.savefig('%s/l%d_nt2d_t%d.pdf' % (options.out_dir, li, ti)) plt.close() + ######################################################## # plot neuron-neuron correlations - # mean-normalize representation - nt_reprs_norm = nt_reprs - nt_reprs.mean(axis=0) - - # compute covariance matrix - hidden_cov = np.dot(nt_reprs_norm.T, nt_reprs_norm) + # compute correlation matrix + hidden_cov = np.corrcoef(nt_reprs.T) + print('hidden_cov', hidden_cov.shape) plt.figure() - g = sns.clustermap(hidden_cov, xticklabels=False, yticklabels=False) + g = sns.clustermap(hidden_cov, cmap='RdBu_r', + xticklabels=False, yticklabels=False) plt.savefig('%s/l%d_cov.pdf' % (options.out_dir, li)) plt.close() @@ -258,8 +264,8 @@ def regplot(vals1, vals2, out_pdf, alpha=0.5, x_label=None, y_label=None): 'alpha': alpha}, line_kws={'color': gold}) - xmin, xmax = basenji.plots.scatter_lims(vals1) - ymin, ymax = basenji.plots.scatter_lims(vals2) + xmin, xmax = plots.scatter_lims(vals1) + ymin, ymax = plots.scatter_lims(vals2) ax.set_xlim(xmin, xmax) if x_label is not None: diff --git a/bin/basenji_map.py b/bin/basenji_map.py index e5034947..30fd904d 100755 --- a/bin/basenji_map.py +++ b/bin/basenji_map.py @@ -128,7 +128,7 @@ def main(): # build model model = seqnn.SeqNN() - model.build(job, target_subset=target_subset) + model.build_feed(job, target_subset=target_subset) # determine latest pre-dilated layer cnn_dilation = np.array([cp.dilation for cp in model.hp.cnn_params]) From 2f84d6786bf47039ff2eb1c9776a7548b4a9cc01 Mon Sep 17 00:00:00 2001 From: David Kelley Date: Wed, 25 Jul 2018 17:01:02 -0700 Subject: [PATCH 61/71] align predict_h5 with predict_h5_manual --- basenji/seqnn_util.py | 53 ++++++++++++++++++++++--------------------- 1 file changed, 27 insertions(+), 26 deletions(-) diff --git a/basenji/seqnn_util.py b/basenji/seqnn_util.py index ac5bbb09..192d5d82 100644 --- a/basenji/seqnn_util.py +++ b/basenji/seqnn_util.py @@ -736,7 +736,6 @@ def predict_h5_manual(self, sess, batcher, # while we want more batches while test_batches is None or batch_num < test_batches: - # get batch Xb, _, _, Nb = batcher.next() @@ -793,39 +792,41 @@ def predict_h5(self, sess, batcher, test_batches=None, if return_all: preds_all = [] - # get first batch + # count batches batch_num = 0 - Xb, _, _, Nb = batcher.next() - while Xb is not None and (test_batches is None or - batch_num < test_batches): - # update feed dict - fd[self.inputs_ph] = Xb + # while we want more batches + while 
test_batches is None or batch_num < test_batches: + # get batch + Xb, _, _, Nb = batcher.next() - # make predictions - if return_var or return_all: - preds_batch, preds_ensemble_batch = sess.run([self.preds_eval, self.preds_ensemble], feed_dict=fd) + # verify fidelity + if Xb is None: + break + else: + # update feed dict + fd[self.inputs_ph] = Xb - # move ensemble to back - preds_ensemble_batch = np.moveaxis(preds_ensemble_batch, 0, -1) + # make predictions + if return_var or return_all: + preds_batch, preds_ensemble_batch = sess.run([self.preds_eval, self.preds_ensemble], feed_dict=fd) - else: - preds_batch = sess.run(self.preds_eval, feed_dict=fd) + # move ensemble to back + preds_ensemble_batch = np.moveaxis(preds_ensemble_batch, 0, -1) - # accumulate predictions and targets - preds.append(preds_batch[:Nb]) - if return_var: - preds_var_batch = np.var(preds_ensemble_batch, axis=-1) - preds_var.append(preds_var_batch[:Nb]) - if return_all: - preds_all.append(preds_ensemble_batch[:Nb]) + else: + preds_batch = sess.run(self.preds_eval, feed_dict=fd) - # next batch - batch_num += 1 - Xb, _, _, Nb = batcher.next() + # accumulate predictions and targets + preds.append(preds_batch[:Nb]) + if return_var: + preds_var_batch = np.var(preds_ensemble_batch, axis=-1) + preds_var.append(preds_var_batch[:Nb]) + if return_all: + preds_all.append(preds_ensemble_batch[:Nb]) - # reset batcher - batcher.reset() + # next batch + batch_num += 1 # construct arrays preds = np.concatenate(preds, axis=0) From 7ff7ed2cf2d3d9cab8b944d4194f6a3b180deaab Mon Sep 17 00:00:00 2001 From: David Kelley Date: Mon, 6 Aug 2018 10:05:39 -0700 Subject: [PATCH 62/71] sad tf.data --- basenji/augmentation.py | 12 +- basenji/ops.py | 27 ++- basenji/seqnn.py | 32 ++- bin/basenji_sadq.py | 499 ++++++++++++++++++++++++++++++++++++++++ 4 files changed, 560 insertions(+), 10 deletions(-) create mode 100755 bin/basenji_sadq.py diff --git a/basenji/augmentation.py b/basenji/augmentation.py index 7b7fb437..34dc276b 100644 --- a/basenji/augmentation.py +++ b/basenji/augmentation.py @@ -80,7 +80,11 @@ def augment_deterministic(data_ops, augment_rc=False, augment_shift=0): data_ops: augmented data """ - data_ops_aug = {'label': data_ops['label'], 'na': data_ops['na']} + data_ops_aug = {} + if 'label' in data_ops: + data_ops_aug['label'] = data_ops['label'] + if 'na' in data_ops: + data_ops_aug['na'] = data_ops['na'] if augment_shift == 0: data_ops_aug['sequence'] = data_ops['sequence'] @@ -104,10 +108,8 @@ def augment_deterministic_rc(data_ops): Returns data_ops_aug: augmented data ops """ - seq, label, na = [data_ops[k] for k in ['sequence', 'label', 'na']] - seq, label, na = ops.reverse_complement_transform(seq, label, na) - reverse_preds = tf.ones((), dtype=tf.bool) - data_ops_aug = {'sequence': seq, 'label': label, 'na': na, 'reverse_preds':reverse_preds} + data_ops_aug = ops.reverse_complement_transform(data_ops) + data_ops_aug['reverse_preds'] = tf.ones((), dtype=tf.bool) return data_ops_aug diff --git a/basenji/ops.py b/basenji/ops.py index 8603ac62..9cb808d7 100755 --- a/basenji/ops.py +++ b/basenji/ops.py @@ -41,15 +41,34 @@ def adjust_max(start, stop, start_value, stop_value, name=None): else: return None -def reverse_complement_transform(seq, label, na): +def reverse_complement_transform(data_ops): """Reverse complement of batched onehot seq and corresponding label and na.""" + + # initialize reverse complemented data_ops + data_ops_rc = {} + + # extract sequence from dict + seq = data_ops['sequence'] + + # check rank rank = 
seq.shape.ndims if rank != 3: raise ValueError("input seq must be rank 3.") - complement = tf.gather(seq, [3, 2, 1, 0], axis=-1) - return (tf.reverse(complement, axis=[1]), tf.reverse(label, axis=[1]), - tf.reverse(na, axis=[1])) + # reverse complement sequence + seq_rc = tf.gather(seq, [3, 2, 1, 0], axis=-1) + seq_rc = tf.reverse(seq_rc, axis=[1]) + data_ops_rc['sequence'] = seq_rc + + # reverse labels + if 'label' in data_ops: + data_ops_rc['label'] = tf.reverse(data_ops['label'], axis=[1]) + + # reverse NA + if 'na' in data_ops: + data_ops_rc['na'] = tf.reverse(na_rc, axis=[1]) + + return data_ops_rc def reverse_complement(input_seq, lengths=None): diff --git a/basenji/seqnn.py b/basenji/seqnn.py index 49c2d179..4bdc2eea 100755 --- a/basenji/seqnn.py +++ b/basenji/seqnn.py @@ -96,7 +96,7 @@ def build_from_data_ops(self, job, data_ops, # compute eval representation map_elems_eval = (data_seq_eval, data_rev_eval) build_rep = lambda do: self.build_predict(do[0], do[1], penultimate, target_subset) - self.preds_ensemble = tf.map_fn(build_rep, map_elems_eval, dtype=self.preds_train.dtype) # back_prop=False + self.preds_ensemble = tf.map_fn(build_rep, map_elems_eval, dtype=tf.float32) # back_prop=False self.preds_eval = tf.reduce_mean(self.preds_ensemble, axis=0) # eval loss @@ -111,6 +111,36 @@ def build_from_data_ops(self, job, data_ops, # helper variables self.preds_length = self.preds_train.shape[1] + def build_sad(self, job, data_ops, + ensemble_rc=False, ensemble_shifts=[0], + penultimate=False, target_subset=None): + """Build SAD predict ops.""" + if not self.hparams_set: + self.hp = params.make_hparams(job) + self.hparams_set = True + + # training conditional + self.is_training = tf.placeholder(tf.bool, name='is_training') + + # eval data ops w/ deterministic augmentation + data_ops_eval = augmentation.augment_deterministic_set( + data_ops, ensemble_rc, ensemble_shifts) + data_seq_eval = tf.stack([do['sequence'] for do in data_ops_eval]) + data_rev_eval = tf.stack([do['reverse_preds'] for do in data_ops_eval]) + + # compute eval representation + map_elems_eval = (data_seq_eval, data_rev_eval) + build_rep = lambda do: self.build_predict(do[0], do[1], penultimate, target_subset) + self.preds_ensemble = tf.map_fn(build_rep, map_elems_eval, dtype=tf.float32) # back_prop=False + self.preds_eval = tf.reduce_mean(self.preds_ensemble, axis=0) + + # update # targets + if target_subset is not None: + self.hp.num_targets = len(target_subset) + + # helper variables + self.preds_length = self.preds_eval.shape[1] + def make_placeholders(self): """Allocates placeholders to be used in place of input data ops.""" # batches diff --git a/bin/basenji_sadq.py b/bin/basenji_sadq.py new file mode 100755 index 00000000..f284d7da --- /dev/null +++ b/bin/basenji_sadq.py @@ -0,0 +1,499 @@ +#!/usr/bin/env python +# Copyright 2017 Calico LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# ========================================================================= +from __future__ import print_function + +from optparse import OptionParser +import pickle +import os +import sys +import threading +import time + +import h5py +import numpy as np +import pandas as pd +import pysam +import tensorflow as tf +import zarr + +import basenji.dna_io +import basenji.vcf as bvcf + +from basenji_test import bigwig_open + +''' +basenji_sadq.py + +Compute SNP Activity Difference (SAD) scores for SNPs in a VCF file. +''' + +################################################################################ +# main +################################################################################ +def main(): + usage = 'usage: %prog [options] ' + parser = OptionParser(usage) + parser.add_option('-b',dest='batch_size', + default=256, type='int', + help='Batch size [Default: %default]') + parser.add_option('-c', dest='csv', + default=False, action='store_true', + help='Print table as CSV [Default: %default]') + parser.add_option('-f', dest='genome_fasta', + default='%s/assembly/hg19.fa' % os.environ['HG19'], + help='Genome FASTA for sequences [Default: %default]') + parser.add_option('-g', dest='genome_file', + default='%s/assembly/human.hg19.genome' % os.environ['HG19'], + help='Chromosome lengths file [Default: %default]') + parser.add_option('--h5', dest='out_h5', + default=False, action='store_true', + help='Output stats to sad.h5 [Default: %default]') + parser.add_option('-l', dest='seq_len', + default=131072, type='int', + help='Sequence length provided to the model [Default: %default]') + parser.add_option('--local',dest='local', + default=1024, type='int', + help='Local SAD score [Default: %default]') + parser.add_option('-n', dest='norm_file', + default=None, + help='Normalize SAD scores') + parser.add_option('-o',dest='out_dir', + default='sad', + help='Output directory for tables and plots [Default: %default]') + parser.add_option('-p', dest='processes', + default=None, type='int', + help='Number of processes, passed by multi script') + parser.add_option('--pseudo', dest='log_pseudo', + default=1, type='float', + help='Log2 pseudocount [Default: %default]') + parser.add_option('--rc', dest='rc', + default=False, action='store_true', + help='Average forward and reverse complement predictions [Default: %default]') + parser.add_option('--shifts', dest='shifts', + default='0', type='str', + help='Ensemble prediction shifts [Default: %default]') + parser.add_option('--stats', dest='sad_stats', + default='SAD,xSAR', + help='Comma-separated list of stats to save. 
[Default: %default]') + parser.add_option('-t', dest='targets_file', + default=None, type='str', + help='File specifying target indexes and labels in table format') + parser.add_option('--ti', dest='track_indexes', + default=None, type='str', + help='Comma-separated list of target indexes to output BigWig tracks') + parser.add_option('-u', dest='penultimate', + default=False, action='store_true', + help='Compute SED in the penultimate layer [Default: %default]') + parser.add_option('-z', dest='out_zarr', + default=False, action='store_true', + help='Output stats to sad.zarr [Default: %default]') + (options, args) = parser.parse_args() + + if len(args) == 3: + # single worker + params_file = args[0] + model_file = args[1] + vcf_file = args[2] + + elif len(args) == 5: + # multi worker + options_pkl_file = args[0] + params_file = args[1] + model_file = args[2] + vcf_file = args[3] + worker_index = int(args[4]) + + # load options + options_pkl = open(options_pkl_file, 'rb') + options = pickle.load(options_pkl) + options_pkl.close() + + # update output directory + options.out_dir = '%s/job%d' % (options.out_dir, worker_index) + + else: + parser.error('Must provide parameters and model files and QTL VCF file') + + if not os.path.isdir(options.out_dir): + os.mkdir(options.out_dir) + + if options.track_indexes is None: + options.track_indexes = [] + else: + options.track_indexes = [int(ti) for ti in options.track_indexes.split(',')] + if not os.path.isdir('%s/tracks' % options.out_dir): + os.mkdir('%s/tracks' % options.out_dir) + + options.shifts = [int(shift) for shift in options.shifts.split(',')] + options.sad_stats = options.sad_stats.split(',') + + + ################################################################# + # read parameters + + job = basenji.params.read_job_params(params_file) + job['seq_length'] = options.seq_len + + if 'num_targets' not in job: + print( + "Must specify number of targets (num_targets) in the parameters file.", + file=sys.stderr) + exit(1) + + if 'target_pool' not in job: + print( + "Must specify target pooling (target_pool) in the parameters file.", + file=sys.stderr) + exit(1) + + if options.targets_file is None: + target_ids = ['t%d' % ti for ti in range(job['num_targets'])] + target_labels = ['']*len(target_ids) + target_subset = None + + else: + targets_df = pd.read_table(options.targets_file) + target_ids = targets_df.identifier + target_labels = targets_df.description + target_subset = targets_df.index + if len(target_subset) == job['num_targets']: + target_subset = None + + + ################################################################# + # load SNPs + + snps = bvcf.vcf_snps(vcf_file) + + # filter for worker SNPs + if options.processes is not None: + worker_bounds = np.linspace(0, len(snps), options.processes+1, dtype='int') + snps = snps[worker_bounds[worker_index]:worker_bounds[worker_index+1]] + + num_snps = len(snps) + + # open genome FASTA + genome_open = pysam.Fastafile(options.genome_fasta) + + def snp_gen(): + for snp in snps: + # get SNP sequences + snp_1hot_list = bvcf.snp_seq1(snp, options.seq_len, genome_open) + + for snp_1hot in snp_1hot_list: + yield {'sequnece':snp_1hot} + + snp_types = {'sequence': tf.float32} + snp_shapes = {'sequence': tf.TensorShape([tf.Dimension(options.seq_len), + tf.Dimension(4)])} + + dataset = tf.data.Dataset().from_generator(snp_gen, + output_types=snp_types, + output_shapes=snp_shapes) + dataset = dataset.batch(job['batch_size']) + + iterator = dataset.make_one_shot_iterator() + data_ops = iterator.get_next() + 
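# --- Illustrative sketch (not part of the patch): the snp_gen -> tf.data
# --- pipeline above can be exercised on toy inputs to confirm tensor shapes
# --- before wiring it to the model. Everything below (toy_gen, the toy
# --- sequence length, and the batch size) is an assumption for demonstration
# --- only; it mirrors the pattern, not Basenji's actual data.
import numpy as np
import tensorflow as tf

def toy_gen(num_seqs=8, seq_len=16):
  # yield one-hot encoded random DNA sequences shaped (seq_len, 4)
  for _ in range(num_seqs):
    seq_1hot = np.eye(4, dtype='float32')[np.random.randint(0, 4, seq_len)]
    yield {'sequence': seq_1hot}

toy_types = {'sequence': tf.float32}
toy_shapes = {'sequence': tf.TensorShape([16, 4])}

toy_dataset = tf.data.Dataset.from_generator(toy_gen,
                                             output_types=toy_types,
                                             output_shapes=toy_shapes)
toy_dataset = toy_dataset.batch(4)
toy_ops = toy_dataset.make_one_shot_iterator().get_next()

with tf.Session() as sess:
  print(sess.run(toy_ops)['sequence'].shape)  # expected: (4, 16, 4)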
+ + ################################################################# + # setup model + + # build model + t0 = time.time() + model = basenji.seqnn.SeqNN() + # model.build_feed(job, target_subset=target_subset) + # model.build_feed(job, ensemble_rc=options.rc, ensemble_shifts=options.shifts, + # target_subset=target_subset, penultimate=options.penultimate) + model.build_sad(job, data_ops, + ensemble_rc=options.rc, ensemble_shifts=options.shifts, + penultimate=options.penultimate, target_subset=target_subset) + print('Model building time %f' % (time.time() - t0), flush=True) + + if options.penultimate: + # labels become inappropriate + target_ids = ['']*model.hp.cnn_filters[-1] + target_labels = target_ids + + # read target normalization factors + target_norms = np.ones(len(target_labels)) + if options.norm_file is not None: + ti = 0 + for line in open(options.norm_file): + target_norms[ti] = float(line.strip()) + ti += 1 + + num_targets = len(target_ids) + + ################################################################# + # setup output + + header_cols = ('rsid', 'ref', 'alt', + 'ref_pred', 'alt_pred', 'sad', 'sar', 'geo_sad', + 'ref_lpred', 'alt_lpred', 'lsad', 'lsar', + 'ref_xpred', 'alt_xpred', 'xsad', 'xsar', + 'target_index', 'target_id', 'target_label') + + if options.out_h5: + sad_out = initialize_output_h5(options.out_dir, options.sad_stats, + snps, target_ids, target_labels) + + elif options.out_zarr: + sad_out = initialize_output_zarr(options.out_dir, options.sad_stats, + snps, target_ids, target_labels) + + else: + if options.csv: + sad_out = open('%s/sad_table.csv' % options.out_dir, 'w') + print(','.join(header_cols), file=sad_out) + else: + sad_out = open('%s/sad_table.txt' % options.out_dir, 'w') + print(' '.join(header_cols), file=sad_out) + + + ################################################################# + # process + + # determine local start and end + loc_mid = model.preds_length // 2 + loc_start = loc_mid - (options.local//2) // model.hp.target_pool + loc_end = loc_start + options.local // model.hp.target_pool + + snp_i = 0 + szi = 0 + + sum_write_thread = None + + # initialize saver + saver = tf.train.Saver() + with tf.Session() as sess: + # load variables into session + saver.restore(sess, model_file) + + # construct first batch + batch_1hot, batch_snps, snp_i = snps_next_batch( + snps, snp_i, options.batch_size, options.seq_len, genome_open) + + while len(batch_snps) > 0: + ################################################### + # predict + + # initialize batcher + batcher = basenji.batcher.Batcher(batch_1hot, batch_size=model.hp.batch_size) + + # predict + batch_preds = model.predict_tfr(sess, test_batches=256) + + # normalize + batch_preds /= target_norms + + ################################################### + # collect and print SADs + + # block for last thread + if sum_write_thread is not None: + sum_write_thread.join() + + sum_write_thread = threading.Thread(target=summarize_write, + args=(batch_snps, batch_preds, sad_out, szi, loc_start, loc_end, options.log_pseudo)) + sum_write_thread.start() + szi += len(batch_snps) + + ################################################### + # construct next batch + + batch_1hot, batch_snps, snp_i = snps_next_batch( + snps, snp_i, options.batch_size, options.seq_len, genome_open) + + sum_write_thread.join() + + ################################################### + # compute SAD distributions across variants + + if options.out_h5 or options.out_zarr: + # define percentiles + d_fine = 0.001 + d_coarse = 0.01 + 
percentiles_neg = np.arange(d_fine, 0.1, d_fine) + percentiles_base = np.arange(0.1, 0.9, d_coarse) + percentiles_pos = np.arange(0.9, 1, d_fine) + + percentiles = np.concatenate([percentiles_neg, percentiles_base, percentiles_pos]) + sad_out.create_dataset('percentiles', data=percentiles) + pct_len = len(percentiles) + + for sad_stat in options.sad_stats: + sad_stat_pct = '%s_pct' % sad_stat + + # compute + sad_pct = np.percentile(sad_out[sad_stat], 100*percentiles, axis=0).T + sad_pct = sad_pct.astype('float16') + + # save + sad_out.create_dataset(sad_stat_pct, data=sad_pct, dtype='float16') + + if not options.out_zarr: + sad_out.close() + + +def summarize_write(batch_snps, batch_preds, sad_out, szi, loc_start, loc_end, log_pseudo): + num_targets = batch_preds.shape[-1] + pi = 0 + for snp in batch_snps: + # get reference prediction (LxT) + ref_preds = batch_preds[pi] + pi += 1 + + # sum across length + ref_preds_sum = ref_preds.sum(axis=0, dtype='float64') + + for alt_al in snp.alt_alleles: + # get alternate prediction (LxT) + alt_preds = batch_preds[pi] + pi += 1 + + # sum across length + alt_preds_sum = alt_preds.sum(axis=0, dtype='float64') + + # compare reference to alternative via mean subtraction + sad_vec = alt_preds - ref_preds + sad = alt_preds_sum - ref_preds_sum + + # compare reference to alternative via mean log division + sar = np.log2(alt_preds_sum + log_pseudo) \ + - np.log2(ref_preds_sum + log_pseudo) + + # compare geometric means + sar_vec = np.log2(alt_preds.astype('float64') + log_pseudo) \ + - np.log2(ref_preds.astype('float64') + log_pseudo) + geo_sad = sar_vec.sum(axis=0) + + # sum locally + # ref_preds_loc = ref_preds[loc_start:loc_end,:].sum(axis=0, dtype='float64') + # alt_preds_loc = alt_preds[loc_start:loc_end,:].sum(axis=0, dtype='float64') + + # # compute SAD locally + # sad_loc = alt_preds_loc - ref_preds_loc + # sar_loc = np.log2(alt_preds_loc + log_pseudo) \ + # - np.log2(ref_preds_loc + log_pseudo) + + # compute max difference position + max_li = np.argmax(np.abs(sar_vec), axis=0) + + sad_out['SAD'][szi,:] = sad.astype('float16') + sad_out['xSAR'][szi,:] = np.array([sar_vec[max_li[ti],ti] for ti in range(num_targets)], dtype='float16') + szi += 1 + + +def bigwig_write(snp, seq_len, preds, model, bw_file, genome_file): + bw_open = bigwig_open(bw_file, genome_file) + + seq_chrom = snp.chrom + seq_start = snp.pos - seq_len // 2 + + bw_chroms = [seq_chrom] * len(preds) + bw_starts = [ + int(seq_start + model.hp.batch_buffer + bi * model.hp.target_pool) + for bi in range(len(preds)) + ] + bw_ends = [int(bws + model.hp.target_pool) for bws in bw_starts] + + preds_list = [float(p) for p in preds] + bw_open.addEntries(bw_chroms, bw_starts, ends=bw_ends, values=preds_list) + + bw_open.close() + + +def initialize_output_h5(out_dir, sad_stats, snps, target_ids, target_labels): + """Initialize an output HDF5 file for SAD stats.""" + + num_targets = len(target_ids) + num_snps = len(snps) + + sad_out = h5py.File('%s/sad.h5' % out_dir, 'w') + + # write SNPs + snp_ids = np.array([snp.rsid for snp in snps], 'S') + sad_out.create_dataset('snp', data=snp_ids) + + # write targets + sad_out.create_dataset('target_ids', data=np.array(target_ids, 'S')) + sad_out.create_dataset('target_labels', data=np.array(target_labels, 'S')) + + # initialize SAD stats + for sad_stat in sad_stats: + sad_out.create_dataset(sad_stat, + shape=(num_snps, num_targets), + dtype='float16', + compression=None) + + return sad_out + + +def initialize_output_zarr(out_dir, sad_stats, snps, target_ids, 
target_labels): + """Initialize an output Zarr file for SAD stats.""" + + num_targets = len(target_ids) + num_snps = len(snps) + + sad_out = zarr.open_group('%s/sad.zarr' % out_dir, 'w') + + # write SNPs + sad_out.create_dataset('snp', data=[snp.rsid for snp in snps], chunks=(32768,)) + + # write targets + sad_out.create_dataset('target_ids', data=target_ids, compressor=None) + sad_out.create_dataset('target_labels', data=target_labels, compressor=None) + + # initialize SAD stats + for sad_stat in sad_stats: + sad_out.create_dataset(sad_stat, + shape=(num_snps, num_targets), + chunks=(128, num_targets), + dtype='float16') + + return sad_out + + +def snps_next_batch(snps, snp_i, batch_size, seq_len, genome_open): + """ Load the next batch of SNP sequence 1-hot. """ + + batch_1hot = [] + batch_snps = [] + + while len(batch_1hot) < batch_size and snp_i < len(snps): + # get SNP sequences + snp_1hot = bvcf.snp_seq1(snps[snp_i], seq_len, genome_open) + + # if it was valid + if len(snp_1hot) > 0: + # accumulate + batch_1hot += snp_1hot + batch_snps.append(snps[snp_i]) + + # advance SNP index + snp_i += 1 + + # convert to array + batch_1hot = np.array(batch_1hot) + + return batch_1hot, batch_snps, snp_i + +################################################################################ +# __main__ +################################################################################ +if __name__ == '__main__': + main() From 6b5e5027e4ebc7bb09df3520704f44edb929c86d Mon Sep 17 00:00:00 2001 From: David Kelley Date: Mon, 6 Aug 2018 15:47:25 -0700 Subject: [PATCH 63/71] reverse_complement bug fix --- basenji/augmentation.py | 7 +++---- basenji/ops.py | 2 +- 2 files changed, 4 insertions(+), 5 deletions(-) diff --git a/basenji/augmentation.py b/basenji/augmentation.py index 34dc276b..f6f931cb 100644 --- a/basenji/augmentation.py +++ b/basenji/augmentation.py @@ -121,11 +121,10 @@ def augment_stochastic_rc(data_ops): Returns data_ops_aug: augmented data """ - seq, label, na = [data_ops[k] for k in ['sequence', 'label', 'na']] reverse_preds = tf.random_uniform(shape=[]) > 0.5 - seq, label, na = tf.cond(reverse_preds, lambda: ops.reverse_complement_transform(seq, label, na), - lambda: (seq, label, na)) - data_ops_aug = {'sequence': seq, 'label': label, 'na': na, 'reverse_preds':reverse_preds} + data_ops_aug = tf.cond(reverse_preds, lambda: ops.reverse_complement_transform(data_ops), + lambda: data_ops.copy()) + data_ops_aug['reverse_preds'] = reverse_preds return data_ops_aug diff --git a/basenji/ops.py b/basenji/ops.py index 9cb808d7..021cb19b 100755 --- a/basenji/ops.py +++ b/basenji/ops.py @@ -66,7 +66,7 @@ def reverse_complement_transform(data_ops): # reverse NA if 'na' in data_ops: - data_ops_rc['na'] = tf.reverse(na_rc, axis=[1]) + data_ops_rc['na'] = tf.reverse(data_ops['na'], axis=[1]) return data_ops_rc From 51a36711abdbb9c23a804a61d4be6563935f959a Mon Sep 17 00:00:00 2001 From: David Kelley Date: Mon, 6 Aug 2018 15:51:13 -0700 Subject: [PATCH 64/71] optimizing --- basenji/seqnn_util.py | 25 ++++++---- bin/basenji_sadq.py | 107 ++++++++++++++++-------------------------- 2 files changed, 56 insertions(+), 76 deletions(-) diff --git a/basenji/seqnn_util.py b/basenji/seqnn_util.py index 192d5d82..476ad895 100644 --- a/basenji/seqnn_util.py +++ b/basenji/seqnn_util.py @@ -860,10 +860,8 @@ def predict_tfr(self, sess, test_batches=None, # initialize prediction data structures preds = [] - if return_var: - preds_var = [] - if return_all: - preds_all = [] + preds_var = [] + preds_all = [] # sequence index 
data_available = True @@ -893,12 +891,19 @@ def predict_tfr(self, sess, test_batches=None, except tf.errors.OutOfRangeError: data_available = False - # construct arrays - preds = np.concatenate(preds, axis=0) - if return_var: - preds_var = np.concatenate(preds_var, axis=0) - if return_all: - preds_all = np.concatenate(preds_all, axis=0) + if preds: + # concatenate into arrays + preds = np.concatenate(preds, axis=0) + if return_var and preds_var: + preds_var = np.concatenate(preds_var, axis=0) + if return_all and preds_all: + preds_all = np.concatenate(preds_all, axis=0) + + else: + # return empty array objects + preds = np.array(preds) + preds_var = np.array(preds_var) + preds_all = np.array(preds_all) if return_var: if return_all: diff --git a/bin/basenji_sadq.py b/bin/basenji_sadq.py index f284d7da..25989b3c 100755 --- a/bin/basenji_sadq.py +++ b/bin/basenji_sadq.py @@ -194,7 +194,7 @@ def snp_gen(): snp_1hot_list = bvcf.snp_seq1(snp, options.seq_len, genome_open) for snp_1hot in snp_1hot_list: - yield {'sequnece':snp_1hot} + yield {'sequence':snp_1hot} snp_types = {'sequence': tf.float32} snp_shapes = {'sequence': tf.TensorShape([tf.Dimension(options.seq_len), @@ -204,6 +204,8 @@ def snp_gen(): output_types=snp_types, output_shapes=snp_shapes) dataset = dataset.batch(job['batch_size']) + dataset = dataset.prefetch(2*job['batch_size']) + dataset = dataset.apply(tf.contrib.data.prefetch_to_device('/device:GPU:0')) iterator = dataset.make_one_shot_iterator() data_ops = iterator.get_next() @@ -215,9 +217,6 @@ def snp_gen(): # build model t0 = time.time() model = basenji.seqnn.SeqNN() - # model.build_feed(job, target_subset=target_subset) - # model.build_feed(job, ensemble_rc=options.rc, ensemble_shifts=options.shifts, - # target_subset=target_subset, penultimate=options.penultimate) model.build_sad(job, data_ops, ensemble_rc=options.rc, ensemble_shifts=options.shifts, penultimate=options.penultimate, target_subset=target_subset) @@ -267,56 +266,44 @@ def snp_gen(): ################################################################# # process - # determine local start and end - loc_mid = model.preds_length // 2 - loc_start = loc_mid - (options.local//2) // model.hp.target_pool - loc_end = loc_start + options.local // model.hp.target_pool - - snp_i = 0 szi = 0 - sum_write_thread = None + sw_batch_size = 32 // job['batch_size'] # initialize saver saver = tf.train.Saver() with tf.Session() as sess: + # coordinator + coord = tf.train.Coordinator() + tf.train.start_queue_runners(coord=coord) + # load variables into session saver.restore(sess, model_file) - # construct first batch - batch_1hot, batch_snps, snp_i = snps_next_batch( - snps, snp_i, options.batch_size, options.seq_len, genome_open) + # predict first + batch_preds = model.predict_tfr(sess, test_batches=sw_batch_size) - while len(batch_snps) > 0: - ################################################### - # predict - - # initialize batcher - batcher = basenji.batcher.Batcher(batch_1hot, batch_size=model.hp.batch_size) - - # predict - batch_preds = model.predict_tfr(sess, test_batches=256) + while batch_preds.shape[0] > 0: + # count predicted SNPs + num_snps = batch_preds.shape[0] // 2 # normalize batch_preds /= target_norms - ################################################### - # collect and print SADs - # block for last thread if sum_write_thread is not None: sum_write_thread.join() + # summarize and write sum_write_thread = threading.Thread(target=summarize_write, - args=(batch_snps, batch_preds, sad_out, szi, loc_start, loc_end, 
options.log_pseudo)) + args=(batch_preds, sad_out, szi, options.log_pseudo)) sum_write_thread.start() - szi += len(batch_snps) - ################################################### - # construct next batch + # update SNP index + szi += num_snps - batch_1hot, batch_snps, snp_i = snps_next_batch( - snps, snp_i, options.batch_size, options.seq_len, genome_open) + # predict next + batch_preds = model.predict_tfr(sess, test_batches=sw_batch_size) sum_write_thread.join() @@ -349,53 +336,41 @@ def snp_gen(): sad_out.close() -def summarize_write(batch_snps, batch_preds, sad_out, szi, loc_start, loc_end, log_pseudo): +def summarize_write(batch_preds, sad_out, szi, log_pseudo): num_targets = batch_preds.shape[-1] pi = 0 - for snp in batch_snps: + while pi < batch_preds.shape[0]: # get reference prediction (LxT) ref_preds = batch_preds[pi] pi += 1 + # get alternate prediction (LxT) + alt_preds = batch_preds[pi] + pi += 1 + # sum across length ref_preds_sum = ref_preds.sum(axis=0, dtype='float64') + alt_preds_sum = alt_preds.sum(axis=0, dtype='float64') - for alt_al in snp.alt_alleles: - # get alternate prediction (LxT) - alt_preds = batch_preds[pi] - pi += 1 - - # sum across length - alt_preds_sum = alt_preds.sum(axis=0, dtype='float64') - - # compare reference to alternative via mean subtraction - sad_vec = alt_preds - ref_preds - sad = alt_preds_sum - ref_preds_sum - - # compare reference to alternative via mean log division - sar = np.log2(alt_preds_sum + log_pseudo) \ - - np.log2(ref_preds_sum + log_pseudo) - - # compare geometric means - sar_vec = np.log2(alt_preds.astype('float64') + log_pseudo) \ - - np.log2(ref_preds.astype('float64') + log_pseudo) - geo_sad = sar_vec.sum(axis=0) + # compare reference to alternative via mean subtraction + # sad_vec = alt_preds - ref_preds + sad = alt_preds_sum - ref_preds_sum - # sum locally - # ref_preds_loc = ref_preds[loc_start:loc_end,:].sum(axis=0, dtype='float64') - # alt_preds_loc = alt_preds[loc_start:loc_end,:].sum(axis=0, dtype='float64') + # compare reference to alternative via mean log division + # sar = np.log2(alt_preds_sum + log_pseudo) \ + # - np.log2(ref_preds_sum + log_pseudo) - # # compute SAD locally - # sad_loc = alt_preds_loc - ref_preds_loc - # sar_loc = np.log2(alt_preds_loc + log_pseudo) \ - # - np.log2(ref_preds_loc + log_pseudo) + # compare geometric means + # sar_vec = np.log2(alt_preds.astype('float64') + log_pseudo) \ + # - np.log2(ref_preds.astype('float64') + log_pseudo) + # geo_sad = sar_vec.sum(axis=0) - # compute max difference position - max_li = np.argmax(np.abs(sar_vec), axis=0) + # compute max difference position + # max_li = np.argmax(np.abs(sar_vec), axis=0) - sad_out['SAD'][szi,:] = sad.astype('float16') - sad_out['xSAR'][szi,:] = np.array([sar_vec[max_li[ti],ti] for ti in range(num_targets)], dtype='float16') - szi += 1 + sad_out['SAD'][szi,:] = sad.astype('float16') + # sad_out['xSAR'][szi,:] = np.array([sar_vec[max_li[ti],ti] for ti in range(num_targets)], dtype='float16') + szi += 1 def bigwig_write(snp, seq_len, preds, model, bw_file, genome_file): From 5e19e7efc4025fdb3bc8e56c0cf3dc247e1bcfb0 Mon Sep 17 00:00:00 2001 From: David Kelley Date: Tue, 7 Aug 2018 15:17:11 -0700 Subject: [PATCH 65/71] sadq multi --- bin/basenji_sadq.py | 59 ++------ bin/basenji_sadq_multi.py | 287 ++++++++++++++++++++++++++++++++++++++ 2 files changed, 295 insertions(+), 51 deletions(-) create mode 100755 bin/basenji_sadq_multi.py diff --git a/bin/basenji_sadq.py b/bin/basenji_sadq.py index 25989b3c..55a23aa1 100755 --- 
a/bin/basenji_sadq.py +++ b/bin/basenji_sadq.py @@ -64,7 +64,7 @@ def main(): parser.add_option('-l', dest='seq_len', default=131072, type='int', help='Sequence length provided to the model [Default: %default]') - parser.add_option('--local',dest='local', + parser.add_option('--local', dest='local', default=1024, type='int', help='Local SAD score [Default: %default]') parser.add_option('-n', dest='norm_file', @@ -86,7 +86,7 @@ def main(): default='0', type='str', help='Ensemble prediction shifts [Default: %default]') parser.add_option('--stats', dest='sad_stats', - default='SAD,xSAR', + default='SAD', help='Comma-separated list of stats to save. [Default: %default]') parser.add_option('-t', dest='targets_file', default=None, type='str', @@ -240,12 +240,6 @@ def snp_gen(): ################################################################# # setup output - header_cols = ('rsid', 'ref', 'alt', - 'ref_pred', 'alt_pred', 'sad', 'sar', 'geo_sad', - 'ref_lpred', 'alt_lpred', 'lsad', 'lsar', - 'ref_xpred', 'alt_xpred', 'xsad', 'xsar', - 'target_index', 'target_id', 'target_label') - if options.out_h5: sad_out = initialize_output_h5(options.out_dir, options.sad_stats, snps, target_ids, target_labels) @@ -255,6 +249,12 @@ def snp_gen(): snps, target_ids, target_labels) else: + header_cols = ('rsid', 'ref', 'alt', + 'ref_pred', 'alt_pred', 'sad', 'sar', 'geo_sad', + 'ref_lpred', 'alt_lpred', 'lsad', 'lsar', + 'ref_xpred', 'alt_xpred', 'xsad', 'xsar', + 'target_index', 'target_id', 'target_label') + if options.csv: sad_out = open('%s/sad_table.csv' % options.out_dir, 'w') print(','.join(header_cols), file=sad_out) @@ -373,25 +373,6 @@ def summarize_write(batch_preds, sad_out, szi, log_pseudo): szi += 1 -def bigwig_write(snp, seq_len, preds, model, bw_file, genome_file): - bw_open = bigwig_open(bw_file, genome_file) - - seq_chrom = snp.chrom - seq_start = snp.pos - seq_len // 2 - - bw_chroms = [seq_chrom] * len(preds) - bw_starts = [ - int(seq_start + model.hp.batch_buffer + bi * model.hp.target_pool) - for bi in range(len(preds)) - ] - bw_ends = [int(bws + model.hp.target_pool) for bws in bw_starts] - - preds_list = [float(p) for p in preds] - bw_open.addEntries(bw_chroms, bw_starts, ends=bw_ends, values=preds_list) - - bw_open.close() - - def initialize_output_h5(out_dir, sad_stats, snps, target_ids, target_labels): """Initialize an output HDF5 file for SAD stats.""" @@ -443,30 +424,6 @@ def initialize_output_zarr(out_dir, sad_stats, snps, target_ids, target_labels): return sad_out -def snps_next_batch(snps, snp_i, batch_size, seq_len, genome_open): - """ Load the next batch of SNP sequence 1-hot. 
""" - - batch_1hot = [] - batch_snps = [] - - while len(batch_1hot) < batch_size and snp_i < len(snps): - # get SNP sequences - snp_1hot = bvcf.snp_seq1(snps[snp_i], seq_len, genome_open) - - # if it was valid - if len(snp_1hot) > 0: - # accumulate - batch_1hot += snp_1hot - batch_snps.append(snps[snp_i]) - - # advance SNP index - snp_i += 1 - - # convert to array - batch_1hot = np.array(batch_1hot) - - return batch_1hot, batch_snps, snp_i - ################################################################################ # __main__ ################################################################################ diff --git a/bin/basenji_sadq_multi.py b/bin/basenji_sadq_multi.py new file mode 100755 index 00000000..6a160cd2 --- /dev/null +++ b/bin/basenji_sadq_multi.py @@ -0,0 +1,287 @@ +#!/usr/bin/env python +# Copyright 2017 Calico LLC + +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at + +# https://www.apache.org/licenses/LICENSE-2.0 + +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ========================================================================= + +from optparse import OptionParser +import gc +import glob +import os +import pickle +import shutil +import subprocess +import sys + +import h5py +import numpy as np +import zarr + +import slurm + +""" +basenji_sadq_multi.py + +Compute SNP expression difference scores for variants in a VCF file, +using multiple processes. 
+""" + +################################################################################ +# main +################################################################################ +def main(): + usage = 'usage: %prog [options] ' + parser = OptionParser(usage) + parser.add_option('-b',dest='batch_size', + default=256, type='int', + help='Batch size [Default: %default]') + parser.add_option('-c', dest='csv', + default=False, action='store_true', + help='Print table as CSV [Default: %default]') + parser.add_option('-f', dest='genome_fasta', + default='%s/assembly/hg19.fa' % os.environ['HG19'], + help='Genome FASTA for sequences [Default: %default]') + parser.add_option('-g', dest='genome_file', + default='%s/assembly/human.hg19.genome' % os.environ['HG19'], + help='Chromosome lengths file [Default: %default]') + parser.add_option('--h5', dest='out_h5', + default=False, action='store_true', + help='Output stats to sad.h5 [Default: %default]') + parser.add_option('-l', dest='seq_len', + default=131072, type='int', + help='Sequence length provided to the model [Default: %default]') + parser.add_option('--local',dest='local', + default=1024, type='int', + help='Local SAD score [Default: %default]') + parser.add_option('-n', dest='norm_file', + default=None, + help='Normalize SAD scores') + parser.add_option('-o',dest='out_dir', + default='sad', + help='Output directory for tables and plots [Default: %default]') + parser.add_option('-p', dest='processes', + default=None, type='int', + help='Number of processes, passed by multi script') + parser.add_option('--pseudo', dest='log_pseudo', + default=1, type='float', + help='Log2 pseudocount [Default: %default]') + parser.add_option('-q', dest='queue', + default='k80', + help='SLURM queue on which to run the jobs [Default: %default]') + parser.add_option('-r', dest='restart', + default=False, action='store_true', + help='Restart a partially completed job [Default: %default]') + parser.add_option('--rc', dest='rc', + default=False, action='store_true', + help='Average forward and reverse complement predictions [Default: %default]') + parser.add_option('--shifts', dest='shifts', + default='0', type='str', + help='Ensemble prediction shifts [Default: %default]') + parser.add_option('--stats', dest='sad_stats', + default='SAD', + help='Comma-separated list of stats to save. 
[Default: %default]') + parser.add_option('-t', dest='targets_file', + default=None, type='str', + help='File specifying target indexes and labels in table format') + parser.add_option('--ti', dest='track_indexes', + default=None, type='str', + help='Comma-separated list of target indexes to output BigWig tracks') + parser.add_option('-u', dest='penultimate', + default=False, action='store_true', + help='Compute SED in the penultimate layer [Default: %default]') + parser.add_option('-z', dest='out_zarr', + default=False, action='store_true', + help='Output stats to sad.zarr [Default: %default]') + (options, args) = parser.parse_args() + + if len(args) != 3: + parser.error('Must provide parameters and model files and VCF file') + else: + params_file = args[0] + model_file = args[1] + vcf_file = args[2] + + ####################################################### + # prep work + + # output directory + if not options.restart: + if os.path.isdir(options.out_dir): + print('Please remove %s' % options.out_dir, file=sys.stderr) + exit(1) + os.mkdir(options.out_dir) + + # pickle options + options_pkl_file = '%s/options.pkl' % options.out_dir + options_pkl = open(options_pkl_file, 'wb') + pickle.dump(options, options_pkl) + options_pkl.close() + + ####################################################### + # launch worker threads + jobs = [] + for pi in range(options.processes): + if not options.restart or not job_completed(options, pi): + cmd = 'source activate py3_gpu; basenji_sadq.py %s %s %d' % ( + options_pkl_file, ' '.join(args), pi) + name = 'sad_p%d' % pi + outf = '%s/job%d.out' % (options.out_dir, pi) + errf = '%s/job%d.err' % (options.out_dir, pi) + j = slurm.Job(cmd, name, + outf, errf, + queue=options.queue, gpu=1, + mem=15000, time='7-0:0:0') + jobs.append(j) + + slurm.multi_run(jobs, max_proc=options.processes, verbose=True, + launch_sleep=10, update_sleep=60) + + ####################################################### + # collect output + + if options.out_h5: + collect_h5('sad.h5', options.out_dir, options.processes) + + elif options.out_zarr: + collect_zarr('sad.zarr', options.out_dir, options.processes) + + else: + collect_table('sad_table.txt', options.out_dir, options.processes) + + # for pi in range(options.processes): + # shutil.rmtree('%s/job%d' % (options.out_dir,pi)) + + +def collect_table(file_name, out_dir, num_procs): + os.rename('%s/job0/%s' % (out_dir, file_name), '%s/%s' % (out_dir, file_name)) + for pi in range(1, num_procs): + subprocess.call( + 'tail -n +2 %s/job%d/%s >> %s/%s' % (out_dir, pi, file_name, out_dir, + file_name), + shell=True) + + +def collect_h5(file_name, out_dir, num_procs): + # count variants + num_variants = 0 + for pi in range(num_procs): + # open job + job_h5_file = '%s/job%d/%s' % (out_dir, pi, file_name) + job_h5_open = h5py.File(job_h5_file, 'r') + num_variants += len(job_h5_open['snp']) + job_h5_open.close() + + # initialize final h5 + final_h5_file = '%s/%s' % (out_dir, file_name) + final_h5_open = h5py.File(final_h5_file, 'w') + + job0_h5_file = '%s/job0/%s' % (out_dir, file_name) + job0_h5_open = h5py.File(job0_h5_file, 'r') + for key in job0_h5_open.keys(): + if key in ['percentiles', 'target_ids', 'target_labels']: + # copy + final_h5_open.create_dataset(key, data=job0_h5_open[key]) + + elif key[-4:] == '_pct': + values = np.zeros(job0_h5_open[key].shape) + final_h5_open.create_dataset(key, data=values) + + elif key == 'snp': + final_h5_open.create_dataset(key, shape=(num_variants,), dtype=job0_h5_open[key].dtype) + + else: + num_targets = 
job0_h5_open[key].shape[1] + final_h5_open.create_dataset(key, shape=(num_variants, num_targets), dtype=job0_h5_open[key].dtype) + + job0_h5_open.close() + + # set values + vi = 0 + for pi in range(num_procs): + # open job + job_h5_file = '%s/job%d/%s' % (out_dir, pi, file_name) + job_h5_open = h5py.File(job_h5_file, 'r') + + # append to final + for key in job_h5_open.keys(): + if key in ['percentiles', 'target_ids', 'target_labels']: + # once is enough + pass + + elif key[-4:] == '_pct': + # average + u_k1 = np.array(final_h5_open[key]) + x_k = np.array(job_h5_open[key]) + final_h5_open[key][:] = u_k1 + (x_k - u_k1) / (pi+1) + + else: + job_variants = job_h5_open[key].shape[0] + final_h5_open[key][vi:vi+job_variants] = job_h5_open[key] + + vi += job_variants + job_h5_open.close() + + final_h5_open.close() + + +def collect_zarr(file_name, out_dir, num_procs): + final_zarr_file = '%s/%s' % (out_dir, file_name) + + # seed w/ job0 + job_zarr_file = '%s/job0/%s' % (out_dir, file_name) + shutil.copytree(job_zarr_file, final_zarr_file) + + # open final + final_zarr_open = zarr.open_group(final_zarr_file) + + for pi in range(1, num_procs): + # open job + job_zarr_file = '%s/job%d/%s' % (out_dir, pi, file_name) + job_zarr_open = zarr.open_group(job_zarr_file, 'r') + + # append to final + for key in final_zarr_open.keys(): + if key in ['percentiles', 'target_ids', 'target_labels']: + # once is enough + pass + + elif key[-4:] == '_pct': + # average + u_k1 = np.array(final_zarr_open[key]) + x_k = np.array(job_zarr_open[key]) + final_zarr_open[key] = u_k1 + (x_k - u_k1) / (pi+1) + + else: + # append + final_zarr_open[key].append(job_zarr_open[key]) + + +def job_completed(options, pi): + """Check whether a specific job has generated its + output file.""" + if options.out_h5: + out_file = '%s/job%d/sad.h5' % (options.out_dir, pi) + elif options.out_zarr: + out_file = '%s/job%d/sad.zarr' % (options.out_dir, pi) + elif options.csv: + out_file = '%s/job%d/sad_table.csv' % (options.out_dir, pi) + else: + out_file = '%s/job%d/sad_table.txt' % (options.out_dir, pi) + return os.path.isfile(out_file) or os.path.isdir(out_file) + + +################################################################################ +# __main__ +################################################################################ +if __name__ == '__main__': + main() From 1d6bb21d10bf03449d7751725e0a5ce4c3977eb0 Mon Sep 17 00:00:00 2001 From: David Kelley Date: Thu, 9 Aug 2018 17:01:18 -0700 Subject: [PATCH 66/71] dynamic batch size --- basenji/seqnn.py | 11 ++++++----- 1 file changed, 6 insertions(+), 5 deletions(-) diff --git a/basenji/seqnn.py b/basenji/seqnn.py index 4bdc2eea..3e7acdb1 100755 --- a/basenji/seqnn.py +++ b/basenji/seqnn.py @@ -146,24 +146,24 @@ def make_placeholders(self): # batches self.inputs_ph = tf.placeholder( tf.float32, - shape=(self.hp.batch_size, self.hp.seq_length, self.hp.seq_depth), + shape=(None, self.hp.seq_length, self.hp.seq_depth), name='inputs') if self.hp.target_classes == 1: self.targets_ph = tf.placeholder( tf.float32, - shape=(self.hp.batch_size, self.hp.seq_length // self.hp.target_pool, + shape=(None, self.hp.seq_length // self.hp.target_pool, self.hp.num_targets), name='targets') else: self.targets_ph = tf.placeholder( tf.int32, - shape=(self.hp.batch_size, self.hp.seq_length // self.hp.target_pool, + shape=(None, self.hp.seq_length // self.hp.target_pool, self.hp.num_targets), name='targets') self.targets_na_ph = tf.placeholder(tf.bool, - shape=(self.hp.batch_size, self.hp.seq_length // 
self.hp.target_pool), + shape=(None, self.hp.seq_length // self.hp.target_pool), name='targets_na') data = { @@ -231,6 +231,7 @@ def build_predict(self, inputs, reverse_preds=None, penultimate=False, target_su seq_length = seqs_repr.shape[1] seqs_repr = seqs_repr[:, batch_buffer_pool: seq_length - batch_buffer_pool, :] + seq_length = seqs_repr.shape[1] ################################################### # final layer @@ -265,7 +266,7 @@ def build_predict(self, inputs, reverse_preds=None, penultimate=False, target_su # expand length back out if self.hp.target_classes > 1: final_repr = tf.reshape(final_repr, - (self.hp.batch_size, -1, self.hp.num_targets, + (-1, seq_length, self.hp.num_targets, self.hp.target_classes)) # transform for reverse complement From 74dec8b33a0544448e441f5746c9b9906e11a8a9 Mon Sep 17 00:00:00 2001 From: David Kelley Date: Thu, 9 Aug 2018 20:07:29 -0700 Subject: [PATCH 67/71] full batches unnecessary --- basenji/batcher.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/basenji/batcher.py b/basenji/batcher.py index 6072ab46..782f3956 100755 --- a/basenji/batcher.py +++ b/basenji/batcher.py @@ -84,7 +84,7 @@ def next(self, fwdrc=True, shift=0): # initialize Xb = np.zeros( - (self.batch_size, self.seq_len, self.seq_depth), dtype='float32') + (Nb, self.seq_len, self.seq_depth), dtype='float32') if self.Yf is not None: if self.Yf.dtype == np.uint8: ytype = 'int32' @@ -92,11 +92,11 @@ def next(self, fwdrc=True, shift=0): ytype = 'float32' Yb = np.zeros( - (self.batch_size, self.seq_len // self.pool_width, + (Nb, self.seq_len // self.pool_width, self.num_targets), dtype=ytype) NAb = np.zeros( - (self.batch_size, self.seq_len // self.pool_width), dtype='bool') + (Nb, self.seq_len // self.pool_width), dtype='bool') # copy data for i in range(Nb): From 8264f081334ce3cea4f5f984317a9f3edfd23a40 Mon Sep 17 00:00:00 2001 From: David Kelley Date: Fri, 10 Aug 2018 09:39:42 -0700 Subject: [PATCH 68/71] weighted average losses --- basenji/seqnn.py | 32 +++++++++++++++++++------------- basenji/seqnn_util.py | 30 ++++++++++++++++++------------ 2 files changed, 37 insertions(+), 25 deletions(-) diff --git a/basenji/seqnn.py b/basenji/seqnn.py index 3e7acdb1..fcdb4529 100755 --- a/basenji/seqnn.py +++ b/basenji/seqnn.py @@ -473,6 +473,7 @@ def train_epoch_h5_manual(self, # initialize training loss train_loss = [] + batch_sizes = [] global_step = 0 # setup feed dict @@ -482,9 +483,7 @@ def train_epoch_h5_manual(self, Xb, Yb, NAb, Nb = batcher.next(fwdrc, shift) batch_num = 0 - while Xb is not None and Nb == self.hp.batch_size and ( - epoch_batches is None or batch_num < epoch_batches): - + while Xb is not None and (epoch_batches is None or batch_num < epoch_batches): # update feed dict fd[self.inputs_ph] = Xb fd[self.targets_ph] = Yb @@ -504,9 +503,8 @@ def train_epoch_h5_manual(self, sum_writer.add_summary(summary, global_step) # accumulate loss - # avail_sum = np.logical_not(NAb[:Nb,:]).sum() - # train_loss.append(loss_batch / avail_sum) train_loss.append(loss_batch) + batch_sizes.append(Xb.shape[0]) # next batch Xb, Yb, NAb, Nb = batcher.next(fwdrc, shift) @@ -516,7 +514,9 @@ def train_epoch_h5_manual(self, if epoch_batches is None: batcher.reset() - return np.mean(train_loss), global_step + avg_loss = np.average(train_loss, weights=batch_sizes) + + return avg_loss, global_step def train_epoch_h5(self, sess, @@ -529,6 +529,7 @@ def train_epoch_h5(self, # initialize training loss train_loss = [] + batch_sizes = [] global_step = 0 # setup feed dict @@ -538,9 
+539,7 @@ def train_epoch_h5(self, Xb, Yb, NAb, Nb = batcher.next() batch_num = 0 - while Xb is not None and Nb == self.hp.batch_size and ( - epoch_batches is None or batch_num < epoch_batches): - + while Xb is not None and (epoch_batches is None or batch_num < epoch_batches): # update feed dict fd[self.inputs_ph] = Xb fd[self.targets_ph] = Yb @@ -560,6 +559,7 @@ def train_epoch_h5(self, # accumulate loss train_loss.append(loss_batch) + batch_sizes.append(Nb) # next batch Xb, Yb, NAb, Nb = batcher.next() @@ -569,7 +569,9 @@ def train_epoch_h5(self, if epoch_batches is None: batcher.reset() - return np.mean(train_loss), global_step + avg_loss = np.average(train_loss, weights=batch_sizes) + + return avg_loss, global_step def train_epoch_tfr(self, sess, sum_writer=None, epoch_batches=None): @@ -577,6 +579,7 @@ def train_epoch_tfr(self, sess, sum_writer=None, epoch_batches=None): # initialize training loss train_loss = [] + batch_sizes = [] global_step = 0 # setup feed dict @@ -587,9 +590,9 @@ def train_epoch_tfr(self, sess, sum_writer=None, epoch_batches=None): while data_available and (epoch_batches is None or batch_num < epoch_batches): try: # update_ops won't run - run_ops = [self.merged_summary, self.loss_train, self.global_step, self.step_op] + self.update_ops + run_ops = [self.merged_summary, self.loss_train, self.preds_train, self.global_step, self.step_op] + self.update_ops run_returns = sess.run(run_ops, feed_dict=fd) - summary, loss_batch, global_step = run_returns[:3] + summary, loss_batch, preds, global_step = run_returns[:4] # add summary if sum_writer is not None: @@ -597,6 +600,7 @@ def train_epoch_tfr(self, sess, sum_writer=None, epoch_batches=None): # accumulate loss train_loss.append(loss_batch) + batch_sizes.append(preds.shape[0]) # next batch batch_num += 1 @@ -604,4 +608,6 @@ def train_epoch_tfr(self, sess, sum_writer=None, epoch_batches=None): except tf.errors.OutOfRangeError: data_available = False - return np.mean(train_loss), global_step + avg_loss = np.average(train_loss, weights=batch_sizes) + + return avg_loss, global_step diff --git a/basenji/seqnn_util.py b/basenji/seqnn_util.py index 476ad895..22a3409a 100644 --- a/basenji/seqnn_util.py +++ b/basenji/seqnn_util.py @@ -1002,6 +1002,7 @@ def test_tfr(self, sess, test_batches=None): batch_losses = [] batch_target_losses = [] + batch_sizes = [] # sequence index data_available = True @@ -1022,6 +1023,7 @@ def test_tfr(self, sess, test_batches=None): # accumulate loss batch_losses.append(loss_batch) batch_target_losses.append(target_losses_batch) + batch_sizes.append(preds_batch.shape[0]) batch_num += 1 @@ -1034,8 +1036,10 @@ def test_tfr(self, sess, test_batches=None): targets_na = np.concatenate(targets_na, axis=0) # mean across batches - batch_losses = np.mean(batch_losses, dtype='float64') - batch_target_losses = np.array(batch_target_losses).mean(axis=0, dtype='float64') + batch_losses = np.array(batch_losses, dtype='float64') + batch_losses = np.average(batch_losses, weights=batch_sizes) + batch_target_losses = np.array(batch_target_losses, dtype='float64') + batch_target_losses = np.average(batch_target_losses, axis=0, weights=batch_sizes) # instantiate accuracy object acc = accuracy.Accuracy(targets, preds, targets_na, @@ -1064,6 +1068,7 @@ def test_h5(self, sess, batcher, test_batches=None): batch_losses = [] batch_target_losses = [] + batch_sizes = [] # get first batch batch_num = 0 @@ -1089,6 +1094,7 @@ def test_h5(self, sess, batcher, test_batches=None): # accumulate loss batch_losses.append(loss_batch) 
batch_target_losses.append(target_losses_batch) + batch_sizes.append(Nb) # next batch batch_num += 1 @@ -1103,8 +1109,10 @@ def test_h5(self, sess, batcher, test_batches=None): targets_na = np.concatenate(targets_na, axis=0) # mean across batches - batch_losses = np.mean(batch_losses, dtype='float64') - batch_target_losses = np.array(batch_target_losses).mean(axis=0, dtype='float64') + batch_losses = np.array(batch_losses, dtype='float64') + batch_losses = np.average(batch_losses, weights=batch_sizes) + batch_target_losses = np.array(batch_target_losses, dtype='float64') + batch_target_losses = np.average(batch_target_losses, axis=0, weights=batch_sizes) # instantiate accuracy object acc = accuracy.Accuracy(targets, preds, targets_na, @@ -1163,9 +1171,7 @@ def test_h5_manual(self, batch_losses = [] batch_target_losses = [] - - # sequence index - si = 0 + batch_size = [] # get first batch Xb, Yb, NAb, Nb = batcher.next() @@ -1214,9 +1220,7 @@ def test_h5_manual(self, # accumulate loss batch_losses.append(loss_batch) batch_target_losses.append(target_losses_batch) - - # update sequence index - si += Nb + batch_sizes.append(Nb) # next batch Xb, Yb, NAb, Nb = batcher.next() @@ -1230,8 +1234,10 @@ def test_h5_manual(self, batcher.reset() # mean across batches - batch_losses = np.mean(batch_losses) - batch_target_losses = np.array(batch_target_losses).mean(axis=0) + batch_losses = np.array(batch_losses, dtype='float64') + batch_losses = np.average(batch_losses, weights=batch_sizes) + batch_target_losses = np.array(batch_target_losses, dtype='float64') + batch_target_losses = np.average(batch_target_losses, axis=0, weights=batch_sizes) # instantiate accuracy object acc = accuracy.Accuracy(targets, preds, targets_na, batch_losses, From eb5c2242c9915ab36c58c9d881529c25ee14c5d1 Mon Sep 17 00:00:00 2001 From: David Kelley Date: Tue, 14 Aug 2018 20:41:02 -0700 Subject: [PATCH 69/71] sad optimizations --- basenji/stream.py | 50 +- basenji/vcf.py | 104 ++-- bin/basenji_sadq.py | 9 - bin/basenji_sadq_ref.py | 487 +++++++++++++++++++ tests/data/regulatory_validated_misorder.vcf | 18 + tests/data/regulatory_validated_misref.vcf | 18 + tests/test_sad.py | 134 +++++ 7 files changed, 755 insertions(+), 65 deletions(-) create mode 100755 bin/basenji_sadq_ref.py create mode 100644 tests/data/regulatory_validated_misorder.vcf create mode 100644 tests/data/regulatory_validated_misref.vcf diff --git a/basenji/stream.py b/basenji/stream.py index 1f6f851b..8c4d4f35 100755 --- a/basenji/stream.py +++ b/basenji/stream.py @@ -17,11 +17,49 @@ import basenji - class PredStream: """ Interface to acquire predictions via a buffered stream mechanism rather than getting them all at once and using excessive memory. 
""" + def __init__(self, sess, model, stream_length, verbose=False): + self.sess = sess + self.model = model + self.verbose = verbose + + self.stream_start = 0 + self.stream_end = 0 + + if stream_length % self.model.hp.batch_size != 0: + print( + 'Make the stream length a multiple of the batch size', + file=sys.stderr) + exit(1) + else: + self.stream_batches = stream_length // self.model.hp.batch_size + + def __getitem__(self, i): + # acquire predictions, if needed + if i >= self.stream_end: + # update start + self.stream_start = self.stream_end + + if self.verbose: + print('Predicting from %d' % self.stream_start, flush=True) + + # predict + self.stream_preds = self.model.predict_tfr(self.sess, + test_batches=self.stream_batches) + + # update end + self.stream_end = self.stream_start + self.stream_preds.shape[0] + + return self.stream_preds[i - self.stream_start] + + +class PredStreamFeed: + """ Interface to acquire predictions via a buffered stream mechanism + rather than getting them all at once and using excessive memory. """ + def __init__(self, sess, model, seqs_1hot, stream_length): self.sess = sess self.model = model @@ -32,7 +70,7 @@ def __init__(self, sess, model, seqs_1hot, stream_length): self.stream_start = 0 self.stream_end = 0 - if self.stream_length % self.model.batch_size != 0: + if self.stream_length % self.model.hp.batch_size != 0: print( 'Make the stream length a multiple of the batch size', file=sys.stderr) @@ -50,7 +88,7 @@ def __getitem__(self, i): # initialize batcher batcher = basenji.batcher.Batcher( - stream_seqs_1hot, batch_size=self.model.batch_size) + stream_seqs_1hot, batch_size=self.model.hp.batch_size) # predict self.stream_preds = self.model.predict(self.sess, batcher, rc_avg=False) @@ -62,7 +100,7 @@ class PredGradStream: """ Interface to acquire predictions and gradients via a buffered stream mechanism rather than getting them all at once and using excessive - memory. + memory. 
""" def __init__(self, sess, model, seqs_1hot, stream_length): @@ -75,7 +113,7 @@ def __init__(self, sess, model, seqs_1hot, stream_length): self.stream_start = 0 self.stream_end = 0 - if self.stream_length % self.model.batch_size != 0: + if self.stream_length % self.model.hp.batch_size != 0: print( 'Make the stream length a multiple of the batch size', file=sys.stderr) @@ -93,7 +131,7 @@ def __getitem__(self, i): # initialize batcher batcher = basenji.batcher.Batcher( - stream_seqs_1hot, batch_size=self.model.batch_size) + stream_seqs_1hot, batch_size=self.model.hp.batch_size) # predict self.stream_grads, self.stream_preds = self.model.gradients( diff --git a/basenji/vcf.py b/basenji/vcf.py index 1bc25216..4ee97223 100755 --- a/basenji/vcf.py +++ b/basenji/vcf.py @@ -58,9 +58,9 @@ def intersect_seqs_snps(vcf_file, gene_seqs, vision_p=1): seq_indexes = {} for si in range(len(gene_seqs)): gs = gene_seqs[si] - gene_seq_key = (gs.chrom, gs.start) + gene_seq_key = (gs.chr, gs.start) seq_indexes[gene_seq_key] = si - print('%s\t%d\t%d' % (gs.chrom, gs.start, gs.end), file=seq_bed_out) + print('%s\t%d\t%d' % (gs.chr, gs.start, gs.end), file=seq_bed_out) seq_bed_out.close() # hash SNPs to indexes @@ -206,9 +206,9 @@ def snp_seq1(snp, seq_len, genome_open): # extract sequence as BED style if seq_start < 0: - seq = 'N'*(1-seq_start) + genome_open.fetch(snp.chrom, 0, seq_end).upper() + seq = 'N'*(1-seq_start) + genome_open.fetch(snp.chr, 0, seq_end).upper() else: - seq = genome_open.fetch(snp.chrom, seq_start - 1, seq_end).upper() + seq = genome_open.fetch(snp.chr, seq_start - 1, seq_end).upper() # extend to full length if len(seq) < seq_end - seq_start: @@ -299,10 +299,10 @@ def snps_seq1(snps, seq_len, genome_fasta, return_seqs=False): # extract sequence as BED style if seq_start < 0: - seq = 'N' * (-seq_start) + genome_open.fetch(snp.chrom, 0, + seq = 'N' * (-seq_start) + genome_open.fetch(snp.chr, 0, seq_end).upper() else: - seq = genome_open.fetch(snp.chrom, seq_start - 1, seq_end).upper() + seq = genome_open.fetch(snp.chr, seq_start - 1, seq_end).upper() # extend to full length if len(seq) < seq_end - seq_start: @@ -420,10 +420,10 @@ def snps2_seq1(snps, seq_len, genome1_fasta, genome2_fasta, return_seqs=False): # extract sequence as BED style if seq_start < 0: - seq_ref = 'N' * (-seq_start) + genome1.fetch(snp.chrom, 0, + seq_ref = 'N' * (-seq_start) + genome1.fetch(snp.chr, 0, seq_end).upper() else: - seq_ref = genome1.fetch(snp.chrom, seq_start - 1, seq_end).upper() + seq_ref = genome1.fetch(snp.chr, seq_start - 1, seq_end).upper() # extend to full length if len(seq_ref) < seq_end - seq_start: @@ -442,10 +442,10 @@ def snps2_seq1(snps, seq_len, genome1_fasta, genome2_fasta, return_seqs=False): # extract sequence as BED style if seq_start < 0: - seq_alt = 'N' * (-seq_start) + genome2.fetch(snp.chrom, 0, + seq_alt = 'N' * (-seq_start) + genome2.fetch(snp.chr, 0, seq_end).upper() else: - seq_alt = genome2.fetch(snp.chrom, seq_start - 1, seq_end).upper() + seq_alt = genome2.fetch(snp.chr, seq_start - 1, seq_end).upper() # extend to full length if len(seq_alt) < seq_end - seq_start: @@ -507,32 +507,7 @@ def dna_length_1hot(seq, length): return seq_1hot, seq -def filter_positive(pos_vcf, uneg_vcf, neg_vcf, dist_t=100): - """ Remove SNPs in uneg_vcf within dist_t from SNPs in pos_vcf """ - - neg_out = open(neg_vcf, 'w') - print('##fileformat=VCFv4.0', file=neg_out) - - printed_snps = set() - - p = subprocess.Popen( - 'bedtools closest -d -a %s -b %s' % (uneg_vcf, pos_vcf), - shell=True, - 
stdout=subprocess.PIPE) - for line in p.stdout: - line = line.decode('UTF-8') - a = line.split() - snp_id = a[2] - dist = int(a[-1]) - if dist == -1 or dist > dist_t: - if snp_id not in printed_snps: - print('\t'.join(a[:7]), file=neg_out) - printed_snps.add(snp_id) - - neg_out.close() - - -def vcf_snps(vcf_file, index_snp=False, score=False, pos2=False): +def vcf_snps(vcf_file, require_sorted=False, validate_ref_fasta=None, pos2=False): """ Load SNPs from a VCF file """ if vcf_file[-3:] == '.gz': vcf_in = gzip.open(vcf_file, 'rt') @@ -544,10 +519,46 @@ def vcf_snps(vcf_file, index_snp=False, score=False, pos2=False): while line[0] == '#': line = vcf_in.readline() + # to check sorted + if require_sorted: + seen_chrs = set() + prev_chr = None + prev_pos = -1 + + # to check reference + if validate_ref_fasta is not None: + genome_open = pysam.Fastafile(validate_ref_fasta) + # read in SNPs snps = [] while line: - snps.append(SNP(line, index_snp, score, pos2)) + snps.append(SNP(line, pos2)) + + if require_sorted: + if prev_chr is not None: + # same chromosome + if prev_chr == snps[-1].chr: + if snps[-1].pos < prev_pos: + print('Sorted VCF required. Mis-ordered position: %s' % line.rstrip(), + file=sys.stderr) + exit(1) + elif snps[-1].chr in seen_chrs: + print('Sorted VCF required. Mis-ordered chromosome: %s' % line.rstrip(), + file=sys.stderr) + exit(1) + + seen_chrs.add(snps[-1].chr) + prev_chr = snps[-1].chr + prev_pos = snps[-1].pos + + if validate_ref_fasta is not None: + ref_n = len(snps[-1].ref_allele) + snp_pos = snps[-1].pos-1 + ref_snp = genome_open.fetch(snps[-1].chr, snp_pos, snp_pos+ref_n) + if snps[-1].ref_allele != ref_snp: + print('ERROR: %s does not match reference %s' % (snps[-1], ref_snp), file=sys.stderr) + exit(1) + line = vcf_in.readline() vcf_in.close() @@ -581,27 +592,20 @@ class SNP: vcf_line (str) """ - def __init__(self, vcf_line, index_snp=False, score=False, pos2=False): + def __init__(self, vcf_line, pos2=False): a = vcf_line.split() if a[0].startswith('chr'): - self.chrom = a[0] + self.chr = a[0] else: - self.chrom = 'chr%s' % a[0] + self.chr = 'chr%s' % a[0] self.pos = int(a[1]) self.rsid = a[2] self.ref_allele = a[3] self.alt_alleles = a[4].split(',') + self.alt_allele = self.alt_alleles[0] if self.rsid == '.': - self.rsid = '%s:%d' % (self.chrom, self.pos) - - self.index_snp = '.' 
- if index_snp: - self.index_snp = a[5] - - self.score = None - if score: - self.score = float(a[6]) + self.rsid = '%s:%d' % (self.chr, self.pos) self.pos2 = None if pos2: @@ -617,6 +621,6 @@ def longest_alt(self): return max([len(al) for al in self.alt_alleles]) def __str__(self): - return 'SNP(%s, %s:%d, %s/%s)' % (self.rsid, self.chrom, self.pos, + return 'SNP(%s, %s:%d, %s/%s)' % (self.rsid, self.chr, self.pos, self.ref_allele, ','.join(self.alt_alleles)) diff --git a/bin/basenji_sadq.py b/bin/basenji_sadq.py index 55a23aa1..29074216 100755 --- a/bin/basenji_sadq.py +++ b/bin/basenji_sadq.py @@ -46,9 +46,6 @@ def main(): usage = 'usage: %prog [options] ' parser = OptionParser(usage) - parser.add_option('-b',dest='batch_size', - default=256, type='int', - help='Batch size [Default: %default]') parser.add_option('-c', dest='csv', default=False, action='store_true', help='Print table as CSV [Default: %default]') @@ -153,12 +150,6 @@ def main(): file=sys.stderr) exit(1) - if 'target_pool' not in job: - print( - "Must specify target pooling (target_pool) in the parameters file.", - file=sys.stderr) - exit(1) - if options.targets_file is None: target_ids = ['t%d' % ti for ti in range(job['num_targets'])] target_labels = ['']*len(target_ids) diff --git a/bin/basenji_sadq_ref.py b/bin/basenji_sadq_ref.py new file mode 100755 index 00000000..8b38ad46 --- /dev/null +++ b/bin/basenji_sadq_ref.py @@ -0,0 +1,487 @@ +#!/usr/bin/env python +# Copyright 2017 Calico LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ========================================================================= +from __future__ import print_function + +from optparse import OptionParser +import pdb +import pickle +import os +from queue import Queue +import sys +from threading import Thread +import time + +import h5py +import numpy as np +import pandas as pd +import pysam +import tensorflow as tf + +import basenji.dna_io as dna_io +import basenji.params as params +import basenji.seqnn as seqnn +import basenji.vcf as bvcf +from basenji.stream import PredStream + +''' +basenji_sadq_ref.py + +Compute SNP Activity Difference (SAD) scores for SNPs in a VCF file. +This versions saves computation by clustering nearby SNPs in order to +make a single reference prediction for several SNPs. 
+''' + +################################################################################ +# main +################################################################################ +def main(): + usage = 'usage: %prog [options] ' + parser = OptionParser(usage) + parser.add_option('-c', dest='center_pct', + default=0.25, type='float', + help='Require clustered SNPs lie in center region [Default: %default]') + parser.add_option('-f', dest='genome_fasta', + default='%s/assembly/hg19.fa' % os.environ['HG19'], + help='Genome FASTA for sequences [Default: %default]') + parser.add_option('-g', dest='genome_file', + default='%s/assembly/human.hg19.genome' % os.environ['HG19'], + help='Chromosome lengths file [Default: %default]') + parser.add_option('--h5', dest='out_h5', + default=False, action='store_true', + help='Output stats to sad.h5 [Default: %default]') + parser.add_option('-l', dest='seq_len', + default=131072, type='int', + help='Sequence length provided to the model [Default: %default]') + parser.add_option('--local', dest='local', + default=1024, type='int', + help='Local SAD score [Default: %default]') + parser.add_option('-n', dest='norm_file', + default=None, + help='Normalize SAD scores') + parser.add_option('-o',dest='out_dir', + default='sad', + help='Output directory for tables and plots [Default: %default]') + parser.add_option('-p', dest='processes', + default=None, type='int', + help='Number of processes, passed by multi script') + parser.add_option('--pseudo', dest='log_pseudo', + default=1, type='float', + help='Log2 pseudocount [Default: %default]') + parser.add_option('--rc', dest='rc', + default=False, action='store_true', + help='Average forward and reverse complement predictions [Default: %default]') + parser.add_option('--shifts', dest='shifts', + default='0', type='str', + help='Ensemble prediction shifts [Default: %default]') + parser.add_option('--stats', dest='sad_stats', + default='SAD', + help='Comma-separated list of stats to save. 
[Default: %default]')
+  parser.add_option('-t', dest='targets_file',
+      default=None, type='str',
+      help='File specifying target indexes and labels in table format')
+  parser.add_option('--ti', dest='track_indexes',
+      default=None, type='str',
+      help='Comma-separated list of target indexes to output BigWig tracks')
+  parser.add_option('-u', dest='penultimate',
+      default=False, action='store_true',
+      help='Compute SED in the penultimate layer [Default: %default]')
+  parser.add_option('-z', dest='out_zarr',
+      default=False, action='store_true',
+      help='Output stats to sad.zarr [Default: %default]')
+  (options, args) = parser.parse_args()
+
+  if len(args) == 3:
+    # single worker
+    params_file = args[0]
+    model_file = args[1]
+    vcf_file = args[2]
+
+  elif len(args) == 5:
+    # multi worker
+    options_pkl_file = args[0]
+    params_file = args[1]
+    model_file = args[2]
+    vcf_file = args[3]
+    worker_index = int(args[4])
+
+    # load options
+    options_pkl = open(options_pkl_file, 'rb')
+    options = pickle.load(options_pkl)
+    options_pkl.close()
+
+    # update output directory
+    options.out_dir = '%s/job%d' % (options.out_dir, worker_index)
+
+  else:
+    parser.error('Must provide parameters and model files and QTL VCF file')
+
+  if not os.path.isdir(options.out_dir):
+    os.mkdir(options.out_dir)
+
+  if options.track_indexes is None:
+    options.track_indexes = []
+  else:
+    options.track_indexes = [int(ti) for ti in options.track_indexes.split(',')]
+    if not os.path.isdir('%s/tracks' % options.out_dir):
+      os.mkdir('%s/tracks' % options.out_dir)
+
+  options.shifts = [int(shift) for shift in options.shifts.split(',')]
+  options.sad_stats = options.sad_stats.split(',')
+
+
+  #################################################################
+  # read parameters and collect target information
+
+  job = params.read_job_params(params_file)
+  job['seq_length'] = options.seq_len
+
+  if 'num_targets' not in job:
+    print("Must specify num_targets in the parameters file.", file=sys.stderr)
+    exit(1)
+
+  if options.targets_file is None:
+    target_ids = ['t%d' % ti for ti in range(job['num_targets'])]
+    target_labels = ['']*len(target_ids)
+    target_subset = None
+
+  else:
+    targets_df = pd.read_table(options.targets_file)
+    target_ids = targets_df.identifier
+    target_labels = targets_df.description
+    target_subset = targets_df.index
+    if len(target_subset) == job['num_targets']:
+      target_subset = None
+
+
+  #################################################################
+  # load SNPs
+
+  # read sorted SNPs from VCF
+  snps = bvcf.vcf_snps(vcf_file, require_sorted=True, validate_ref_fasta=options.genome_fasta)
+
+  # filter for worker SNPs
+  if options.processes is not None:
+    worker_bounds = np.linspace(0, len(snps), options.processes+1, dtype='int')
+    snps = snps[worker_bounds[worker_index]:worker_bounds[worker_index+1]]
+
+  num_snps = len(snps)
+
+  # cluster SNPs by position
+  snp_clusters = cluster_snps(snps, options.seq_len, options.center_pct)
+
+  # delimit sequence boundaries
+  [sc.delimit(options.seq_len) for sc in snp_clusters]
+
+  # open genome FASTA
+  genome_open = pysam.Fastafile(options.genome_fasta)
+
+  # make SNP sequence generator
+  def snp_gen():
+    for sc in snp_clusters:
+      snp_1hot_list = sc.get_1hots(genome_open)
+      for snp_1hot in snp_1hot_list:
+        yield {'sequence':snp_1hot}
+
+  snp_types = {'sequence': tf.float32}
+  snp_shapes = {'sequence': tf.TensorShape([tf.Dimension(options.seq_len),
+                                            tf.Dimension(4)])}
+
+  dataset = tf.data.Dataset().from_generator(snp_gen,
+                                             output_types=snp_types,
output_shapes=snp_shapes) + dataset = dataset.batch(job['batch_size']) + dataset = dataset.prefetch(2*job['batch_size']) + # dataset = dataset.apply(tf.contrib.data.prefetch_to_device('/device:GPU:0')) + + iterator = dataset.make_one_shot_iterator() + data_ops = iterator.get_next() + + + ################################################################# + # setup model + + # build model + t0 = time.time() + model = seqnn.SeqNN() + model.build_sad(job, data_ops, + ensemble_rc=options.rc, ensemble_shifts=options.shifts, + penultimate=options.penultimate, target_subset=target_subset) + print('Model building time %f' % (time.time() - t0), flush=True) + + if options.penultimate: + # labels become inappropriate + target_ids = ['']*model.hp.cnn_filters[-1] + target_labels = target_ids + + # read target normalization factors + target_norms = np.ones(len(target_labels)) + if options.norm_file is not None: + ti = 0 + for line in open(options.norm_file): + target_norms[ti] = float(line.strip()) + ti += 1 + + num_targets = len(target_ids) + + ################################################################# + # setup output + + sad_out = initialize_output_h5(options.out_dir, options.sad_stats, + snps, target_ids, target_labels) + + snp_threads = [] + + snp_queue = Queue() + for i in range(1): + sw = SNPWorker(snp_queue, sad_out) + sw.start() + snp_threads.append(sw) + + ################################################################# + # predict SNP scores, write output + + # initialize saver + saver = tf.train.Saver() + with tf.Session() as sess: + # coordinator + coord = tf.train.Coordinator() + tf.train.start_queue_runners(coord=coord) + + # load variables into session + saver.restore(sess, model_file) + + # initialize predictions stream + preds_stream = PredStream(sess, model, 32) + + # predictions index + pi = 0 + + # SNP index + si = 0 + + for snp_cluster in snp_clusters: + ref_preds = preds_stream[pi] + pi += 1 + + for snp in snp_cluster.snps: + # print(snp, flush=True) + + alt_preds = preds_stream[pi] + pi += 1 + + # queue SNP + snp_queue.put((ref_preds, alt_preds, si)) + + # update SNP index + si += 1 + + # finish queue + print('Waiting for threads to finish.', flush=True) + snp_queue.join() + + # close genome + genome_open.close() + + ################################################### + # compute SAD distributions across variants + + # define percentiles + d_fine = 0.001 + d_coarse = 0.01 + percentiles_neg = np.arange(d_fine, 0.1, d_fine) + percentiles_base = np.arange(0.1, 0.9, d_coarse) + percentiles_pos = np.arange(0.9, 1, d_fine) + + percentiles = np.concatenate([percentiles_neg, percentiles_base, percentiles_pos]) + sad_out.create_dataset('percentiles', data=percentiles) + pct_len = len(percentiles) + + for sad_stat in options.sad_stats: + sad_stat_pct = '%s_pct' % sad_stat + + # compute + sad_pct = np.percentile(sad_out[sad_stat], 100*percentiles, axis=0).T + sad_pct = sad_pct.astype('float16') + + # save + sad_out.create_dataset(sad_stat_pct, data=sad_pct, dtype='float16') + + sad_out.close() + + +def cluster_snps(snps, seq_len, center_pct): + """Cluster a sorted list of SNPs into regions that will satisfy + the required center_pct.""" + valid_snp_distance = int(seq_len*center_pct) + + snp_clusters = [] + cluster_chr = None + + for snp in snps: + if snp.chr == cluster_chr and snp.pos < cluster_pos0 + valid_snp_distance: + # append to latest cluster + snp_clusters[-1].add_snp(snp) + else: + # initialize new cluster + snp_clusters.append(SNPCluster()) + snp_clusters[-1].add_snp(snp) + 
cluster_chr = snp.chr + cluster_pos0 = snp.pos + + return snp_clusters + + +def initialize_output_h5(out_dir, sad_stats, snps, target_ids, target_labels): + """Initialize an output HDF5 file for SAD stats.""" + + num_targets = len(target_ids) + num_snps = len(snps) + + sad_out = h5py.File('%s/sad.h5' % out_dir, 'w') + + # write SNPs + snp_ids = np.array([snp.rsid for snp in snps], 'S') + sad_out.create_dataset('snp', data=snp_ids) + + # write targets + sad_out.create_dataset('target_ids', data=np.array(target_ids, 'S')) + sad_out.create_dataset('target_labels', data=np.array(target_labels, 'S')) + + # initialize SAD stats + for sad_stat in sad_stats: + sad_out.create_dataset(sad_stat, + shape=(num_snps, num_targets), + dtype='float16', + compression=None) + + return sad_out + + +class SNPCluster: + def __init__(self): + self.snps = [] + self.chr = None + self.start = None + self.end = None + + def add_snp(self, snp): + self.snps.append(snp) + + def delimit(self, seq_len): + positions = [snp.pos for snp in self.snps] + pos_min = np.min(positions) + pos_max = np.max(positions) + pos_mid = (pos_min + pos_max) // 2 + + self.chr = self.snps[0].chr + self.start = pos_mid - seq_len//2 + self.end = self.start + seq_len + + for snp in self.snps: + snp.seq_pos = snp.pos - 1 - self.start + + def get_1hots(self, genome_open): + seqs1_list = [] + + # extract reference + if self.start < 0: + ref_seq = 'N'*(1-self.start) + genome_open.fetch(self.chr, 0, self.end).upper() + else: + ref_seq = genome_open.fetch(self.chr, self.start, self.end).upper() + + # extend to full length + if len(ref_seq) < self.end - self.start: + ref_seq += 'N'*(self.end-self.start-len(ref_seq)) + + # verify reference alleles + for snp in self.snps: + ref_n = len(snp.ref_allele) + ref_snp = ref_seq[snp.seq_pos:snp.seq_pos+ref_n] + if snp.ref_allele != ref_snp: + print('ERROR: %s does not match reference %s' % (snp, ref_snp), file=sys.stderr) + exit(1) + + # 1 hot code reference sequence + ref_1hot = dna_io.dna_1hot(ref_seq) + seqs1_list = [ref_1hot] + + # make alternative 1 hot coded sequences + # (assuming SNP is 1-based indexed) + for snp in self.snps: + alt_1hot = make_alt_1hot(ref_1hot, snp.seq_pos, snp.ref_allele, snp.alt_allele) + seqs1_list.append(alt_1hot) + + return seqs1_list + + +class SNPWorker(Thread): + """Compute summary statistics and write to HDF.""" + def __init__(self, snp_queue, sad_out): + Thread.__init__(self) + self.queue = snp_queue + self.daemon = True + self.sad_out = sad_out + + def run(self): + while True: + # unload predictions + ref_preds, alt_preds, si = self.queue.get() + + # sum across length + ref_preds_sum = ref_preds.sum(axis=0, dtype='float64') + alt_preds_sum = alt_preds.sum(axis=0, dtype='float64') + + # compare reference to alternative via mean subtraction + sad = alt_preds_sum - ref_preds_sum + + # write to HDF5 + self.sad_out['SAD'][si,:] = sad.astype('float16') + + # communicate finished task + self.queue.task_done() + + +def make_alt_1hot(ref_1hot, snp_seq_pos, ref_allele, alt_allele): + """Return alternative allele one hot coding.""" + ref_n = len(ref_allele) + alt_n = len(alt_allele) + + # copy reference + alt_1hot = np.copy(ref_1hot) + + if alt_n == ref_n: + # SNP + dna_io.hot1_set(alt_1hot, snp_seq_pos, alt_allele) + + elif ref_n > alt_n: + # deletion + delete_len = ref_n - alt_n + assert (ref_allele[0] == alt_allele[0]) + dna_io.hot1_delete(alt_1hot, snp_seq_pos+1, delete_len) + + else: + # insertion + assert (ref_allele[0] == alt_allele[0]) + dna_io.hot1_insert(alt_1hot, 
snp_seq_pos+1, alt_allele[1:]) + + return alt_1hot + + +################################################################################ +# __main__ +################################################################################ +if __name__ == '__main__': + main() diff --git a/tests/data/regulatory_validated_misorder.vcf b/tests/data/regulatory_validated_misorder.vcf new file mode 100644 index 00000000..459ffd1e --- /dev/null +++ b/tests/data/regulatory_validated_misorder.vcf @@ -0,0 +1,18 @@ +##fileformat=VCFv4.2 +#CHROM POS ID REF ALT QUAL FILTER INFO +chr1 109817590 rs12740374 G T 100 . +chr1 56972353 rs72664324 G A 100 . +chr1 226595403 rs144361550 A AGGGCCC 100 . +chr2 60718043 rs1427407 T G 100 . +chr4 90646886 rs356165 G A 100 . +chr6 117210052 rs339331 T C 100 . +chr6 151953765 rs9383590 T C 100 . +chr6 151954834 rs140068132 A G 100 . +chr7 12284008 rs1990620 A G 100 . +chr8 128413305 rs6983267 G T 100 . +chr9 22124477 rs10757278 A G 100 . +chr10 64807993 rs7903145 T C 100 . +chr11 320836 rs34481144 C T 100 . +chr16 52599188 rs4784227 C T 100 . +chr16 53800954 rs1421085 T C 100 . +chr22 21921686 rs140490 G T 100 . diff --git a/tests/data/regulatory_validated_misref.vcf b/tests/data/regulatory_validated_misref.vcf new file mode 100644 index 00000000..2799cb0e --- /dev/null +++ b/tests/data/regulatory_validated_misref.vcf @@ -0,0 +1,18 @@ +##fileformat=VCFv4.2 +#CHROM POS ID REF ALT QUAL FILTER INFO +chr1 56972353 rs72664324 A G 100 . +chr1 109817590 rs12740374 G T 100 . +chr1 226595403 rs144361550 A AGGGCCC 100 . +chr2 60718043 rs1427407 T G 100 . +chr4 90646886 rs356165 G A 100 . +chr6 117210052 rs339331 T C 100 . +chr6 151953765 rs9383590 T C 100 . +chr6 151954834 rs140068132 A G 100 . +chr7 12284008 rs1990620 A G 100 . +chr8 128413305 rs6983267 G T 100 . +chr9 22124477 rs10757278 A G 100 . +chr10 64807993 rs7903145 T C 100 . +chr11 320836 rs34481144 C T 100 . +chr16 52599188 rs4784227 C T 100 . +chr16 53800954 rs1421085 T C 100 . +chr22 21921686 rs140490 G T 100 . 
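
A short aside on the clustering rule that basenji_sadq_ref.py introduces above, before the test changes below. The sketch that follows is illustrative only: ToySNP and cluster_positions() are stand-ins rather than repository code, but the grouping logic mirrors cluster_snps(), and the example positions come from the test VCFs above.

# Minimal sketch of the clustering rule in cluster_snps(): sorted SNPs are
# grouped while they fall within seq_len * center_pct of the first SNP in
# the open cluster, so each group shares one reference prediction
# (1 ref + k alt forward passes per group instead of k ref + k alt).
from collections import namedtuple

ToySNP = namedtuple('ToySNP', ['chr', 'pos'])

def cluster_positions(snps, seq_len, center_pct):
  valid_snp_distance = int(seq_len * center_pct)
  clusters = []
  cluster_chr, cluster_pos0 = None, None
  for snp in snps:
    if snp.chr == cluster_chr and snp.pos < cluster_pos0 + valid_snp_distance:
      clusters[-1].append(snp)   # extend the open cluster
    else:
      clusters.append([snp])     # start a new cluster
      cluster_chr, cluster_pos0 = snp.chr, snp.pos
  return clusters

# the two chr6 SNPs sit ~1 kb apart, well inside 131072 * 0.25 = 32768 bp
snps = [ToySNP('chr6', 151953765), ToySNP('chr6', 151954834),
        ToySNP('chr7', 12284008)]
print([len(c) for c in cluster_positions(snps, 131072, 0.25)])  # [2, 1]
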
diff --git a/tests/test_sad.py b/tests/test_sad.py index c3426d70..00821490 100755 --- a/tests/test_sad.py +++ b/tests/test_sad.py @@ -3,13 +3,20 @@ import os import pdb +import shutil import subprocess import unittest import h5py +import pysam import numpy as np +import basenji.dna_io as dna_io +import basenji.vcf as bvcf +import basenji_sadq_ref + + class TestSAD(unittest.TestCase): @classmethod def setUpClass(cls): @@ -45,6 +52,133 @@ def test_usad(self): np.testing.assert_allclose(this_sad, saved_sad, atol=1e-3, rtol=1e-3) +class TestSadQRef(unittest.TestCase): + @classmethod + def setUpClass(cls): + cls.genome_fasta = '%s/assembly/hg19.fa' % os.environ['HG19'] + cls.params_file = 'data/params.txt' + cls.model_file = 'data/model_best.tf' + cls.vcf_file = 'data/regulatory_validated.vcf' + cls.seq_length = 131072 + + def test_run(self): + sad_opts = '-l %d -o sadq_ref' % self.seq_length + + cmd = 'basenji_sadq_ref.py %s %s %s %s' % \ + (sad_opts, self.params_file, self.model_file, self.vcf_file) + return_code = subprocess.call(cmd, shell=True) + self.assertEqual(return_code, 0) + + if os.path.isdir('sadq_ref'): + shutil.rmtree('sadq_ref') + + def test_misref(self): + sad_opts = '-l %d -o sadq_ref' % self.seq_length + + vcf_misref_file = 'data/regulatory_validated_misref.vcf' + cmd = 'basenji_sadq_ref.py %s %s %s %s' % \ + (sad_opts, self.params_file, self.model_file, vcf_misref_file) + return_code = subprocess.call(cmd, shell=True) + self.assertEqual(return_code, 1) + + if os.path.isdir('sadq_ref'): + shutil.rmtree('sadq_ref') + + def test_misorder(self): + sad_opts = '-l %d -o sadq_ref' % self.seq_length + + vcf_misorder_file = 'data/regulatory_validated_misorder.vcf' + cmd = 'basenji_sadq_ref.py %s %s %s %s' % \ + (sad_opts, self.params_file, self.model_file, vcf_misorder_file) + return_code = subprocess.call(cmd, shell=True) + self.assertEqual(return_code, 1) + + if os.path.isdir('sadq_ref'): + shutil.rmtree('sadq_ref') + + def test_cluster(self): + # read sorted SNPs from VCF + snps = bvcf.vcf_snps(self.vcf_file, require_sorted=True) + + # cluster SNPs by position + snp_clusters = basenji_sadq_ref.cluster_snps(snps, self.seq_length, 0.25) + + # two SNPs should be clustered together + self.assertEqual(len(snps)-1, len(snp_clusters)) + self.assertEqual(len(snp_clusters[0].snps), 1) + self.assertEqual(len(snp_clusters[6].snps), 2) + + def test_get_1hots(self): + # read sorted SNPs from VCF + snps = bvcf.vcf_snps(self.vcf_file, require_sorted=True) + + # cluster SNPs by position + snp_clusters = basenji_sadq_ref.cluster_snps(snps, self.seq_length, 0.25) + + # delimit sequence boundaries + [sc.delimit(self.seq_length) for sc in snp_clusters] + + # open genome FASTA + genome_open = pysam.Fastafile(self.genome_fasta) + + ######################################## + # verify single SNP + + # get 1 hot coded sequences + snp_1hot_list = snp_clusters[0].get_1hots(genome_open) + + self.assertEqual(len(snp_1hot_list), 2) + self.assertEqual(snp_1hot_list[1].shape, (self.seq_length, 4)) + + mid_i = self.seq_length // 2 - 1 + self.assertEqual(mid_i, snps[0].seq_pos) + + ref_nt = dna_io.hot1_get(snp_1hot_list[0], mid_i) + self.assertEqual(ref_nt, snps[0].ref_allele) + + alt_nt = dna_io.hot1_get(snp_1hot_list[1], mid_i) + self.assertEqual(alt_nt, snps[0].alt_allele) + + + ######################################## + # verify multiple SNPs + + # get 1 hot coded sequences + snp_1hot_list = snp_clusters[6].get_1hots(genome_open) + + self.assertEqual(len(snp_1hot_list), 3) + + snp1, snp2 = snps[6:8] + + # 
verify position 1 changes between 0 and 1
+    nt = dna_io.hot1_get(snp_1hot_list[0], snp1.seq_pos)
+    self.assertEqual(nt, snp1.ref_allele)
+
+    nt = dna_io.hot1_get(snp_1hot_list[1], snp1.seq_pos)
+    self.assertEqual(nt, snp1.alt_allele)
+
+    # verify position 2 is unchanged between 0 and 1
+    nt = dna_io.hot1_get(snp_1hot_list[0], snp2.seq_pos)
+    self.assertEqual(nt, snp2.ref_allele)
+
+    nt = dna_io.hot1_get(snp_1hot_list[1], snp2.seq_pos)
+    self.assertEqual(nt, snp2.ref_allele)
+
+    # verify position 1 is unchanged between 0 and 2
+    nt = dna_io.hot1_get(snp_1hot_list[0], snp1.seq_pos)
+    self.assertEqual(nt, snp1.ref_allele)
+
+    nt = dna_io.hot1_get(snp_1hot_list[2], snp1.seq_pos)
+    self.assertEqual(nt, snp1.ref_allele)
+
+    # verify position 2 changes between 0 and 2
+    nt = dna_io.hot1_get(snp_1hot_list[0], snp2.seq_pos)
+    self.assertEqual(nt, snp2.ref_allele)
+
+    nt = dna_io.hot1_get(snp_1hot_list[2], snp2.seq_pos)
+    self.assertEqual(nt, snp2.alt_allele)
+
+
 ################################################################################
 # __main__
 ################################################################################

From 33fa21e01897c8c83ab0b21db9d2b51bf74b89da Mon Sep 17 00:00:00 2001
From: David Kelley
Date: Tue, 14 Aug 2018 20:48:15 -0700
Subject: [PATCH 70/71] increase shuffle buffer

---
 basenji/tfrecord_batcher.py | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/basenji/tfrecord_batcher.py b/basenji/tfrecord_batcher.py
index cd22a14b..94c527e3 100755
--- a/basenji/tfrecord_batcher.py
+++ b/basenji/tfrecord_batcher.py
@@ -23,10 +23,10 @@
 # Multiplier for how many items to have in the shuffle buffer, invariant of how
 # many files we're parallel-interleaving for our input datasets.
-SHUFFLE_BUFFER_DEPTH_PER_FILE = 2
+SHUFFLE_BUFFER_DEPTH_PER_FILE = 128
 # Number of files to concurrently read from, and interleave, for our input
 # datasets.
-NUM_FILES_TO_PARALLEL_INTERLEAVE = 10
+NUM_FILES_TO_PARALLEL_INTERLEAVE = 16

 def tfrecord_dataset(tfr_data_files_pattern, batch_size,

From a4c05a68cc5e7442fe00cf7c1c11ed9d606f1701 Mon Sep 17 00:00:00 2001
From: David Kelley
Date: Fri, 17 Aug 2018 17:23:12 -0700
Subject: [PATCH 71/71] typo and descriptions

---
 basenji/augmentation.py | 5 +++--
 1 file changed, 3 insertions(+), 2 deletions(-)

diff --git a/basenji/augmentation.py b/basenji/augmentation.py
index f6f931cb..4f99e974 100644
--- a/basenji/augmentation.py
+++ b/basenji/augmentation.py
@@ -75,9 +75,10 @@ def augment_deterministic(data_ops, augment_rc=False, augment_shift=0):
   Args:
     data_ops: dict with keys 'sequence,' 'label,' and 'na.'
     augment_rc: Boolean
-    augment_shifts: Int
+    augment_shift: Int
   Returns
-    data_ops: augmented data
+    data_ops: augmented data, with all existing keys transformed
+      and 'reverse_preds' bool added.
   """
   data_ops_aug = {}
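
The expanded augment_deterministic docstring above notes that a 'reverse_preds' boolean now accompanies the transformed data. As a rough sketch of how such a flag can be consumed downstream — an assumption for illustration, not the repository's implementation; the average_ensemble() helper and the batch x length x targets layout are hypothetical — reverse-complement predictions would be flipped back along the length axis before ensemble members are averaged.

# Illustrative only: align and average ensemble predictions, flipping the
# members built from reverse-complemented input. Assumes each preds tensor
# is [batch, length, targets] and each reverse_preds is a scalar boolean
# tensor like the flag described in the docstring above.
import tensorflow as tf

def average_ensemble(preds_list, reverse_preds_list):
  aligned = []
  for preds, reverse_preds in zip(preds_list, reverse_preds_list):
    # flip predictions back to forward-strand orientation when needed
    aligned.append(tf.cond(reverse_preds,
                           lambda p=preds: tf.reverse(p, axis=[1]),
                           lambda p=preds: p))
  # mean over ensemble members
  return tf.reduce_mean(tf.stack(aligned), axis=0)

Flipping before averaging keeps every position aligned to forward-strand coordinates, so the mean is taken over comparable values.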