Skip to content

Commit

Permalink
add the gradient cal
Browse files Browse the repository at this point in the history
  • Loading branch information
lruizcalico committed Oct 24, 2023
1 parent d717f2f commit 8974ac3
Show file tree
Hide file tree
Showing 3 changed files with 125 additions and 87 deletions.
2 changes: 1 addition & 1 deletion src/baskerville/helpers/h5_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -145,4 +145,4 @@ def collect_h5_borzoi(out_dir, num_procs, sad_stat) -> None:
for key in final_strings:
final_h5_open.create_dataset(key, data=np.array(final_strings[key], dtype="S"))

final_h5_open.close()
final_h5_open.close()
2 changes: 1 addition & 1 deletion src/baskerville/helpers/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,4 +17,4 @@ def load_extra_options(options_pkl_file, options):
for attr_name, attr_value in new_option_attrs.items():
setattr(options, attr_name, attr_value)
options_pkl.close()
return options
return options
208 changes: 123 additions & 85 deletions src/baskerville/seqnn.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
# Copyright 2019 Calico LLC
# Copyright 2023 Calico LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
Expand All @@ -12,12 +12,9 @@
# See the License for the specific language governing permissions and
# limitations under the License.
# =========================================================================
from __future__ import print_function

import pdb
import sys
import time
import gc

from natsort import natsorted
import numpy as np
Expand All @@ -29,26 +26,38 @@


class SeqNN:
def __init__(self, params):
"""Sequence neural network model.
Args:
params (dict): Model specification and parameters.
"""

def __init__(self, params: dict):
self.set_defaults()
for key, value in params.items():
self.__setattr__(key, value)
self.build_model()
self.ensemble = None

def set_defaults(self):
# only necessary for my bespoke parameters
# others are best defaulted closer to the source
"""Set default parameters.
Only necessary for my bespoke parameters.
Others are best defaulted closer to the source.
"""
self.augment_rc = False
self.augment_shift = [0]
self.strand_pair = []
self.verbose = True

def build_block(self, current, block_params):
"""Construct a SeqNN block.
Args:
current: Current Tensor.
block_params (dict): Block parameters.
Returns:
current
current: New current Tensor.
"""
block_args = {}

Expand Down Expand Up @@ -117,10 +126,11 @@ def build_block(self, current, block_params):

return current

def build_model(self, save_reprs=True):
def build_model(self, save_reprs: bool = True):
"""Build the model."""

###################################################
# inputs
###################################################
sequence = tf.keras.Input(shape=(self.seq_length, 4), name="sequence")
current = sequence

Expand All @@ -133,7 +143,6 @@ def build_model(self, save_reprs=True):

###################################################
# build convolution blocks
###################################################
self.reprs = []
for bi, block_params in enumerate(self.trunk):
current = self.build_block(current, block_params)
Expand All @@ -149,7 +158,6 @@ def build_model(self, save_reprs=True):

###################################################
# heads
###################################################
head_keys = natsorted([v for v in vars(self) if v.startswith("head")])
self.heads = [getattr(self, hk) for hk in head_keys]

Expand Down Expand Up @@ -184,49 +192,18 @@ def build_model(self, save_reprs=True):

###################################################
# compile model(s)
###################################################
self.models = []
for ho in self.head_output:
self.models.append(tf.keras.Model(inputs=sequence, outputs=ho))
self.model = self.models[0]
if self.verbose:
print(self.model.summary())

###################################################
# track pooling/striding and cropping
###################################################
self.model_strides = []
self.target_lengths = []
self.target_crops = []
for model in self.models:
# determine model stride
self.model_strides.append(1)
for layer in self.model.layers:
if hasattr(layer, "strides") or hasattr(layer, "size"):
stride_factor = layer.input_shape[1] / layer.output_shape[1]
self.model_strides[-1] *= stride_factor
self.model_strides[-1] = int(self.model_strides[-1])
self.track_sequence(sequence)

# determine predictions length before cropping
if type(sequence.shape[1]) == tf.compat.v1.Dimension:
target_full_length = sequence.shape[1].value // self.model_strides[-1]
else:
target_full_length = sequence.shape[1] // self.model_strides[-1]

# determine predictions length after cropping
self.target_lengths.append(model.outputs[0].shape[1])
if type(self.target_lengths[-1]) == tf.compat.v1.Dimension:
self.target_lengths[-1] = self.target_lengths[-1].value
self.target_crops.append(
(target_full_length - self.target_lengths[-1]) // 2
)

if self.verbose:
print("model_strides", self.model_strides)
print("target_lengths", self.target_lengths)
print("target_crops", self.target_crops)

def build_embed(self, conv_layer_i, batch_norm=True):
def build_embed(self, conv_layer_i: int, batch_norm: bool = True):
"""Build model to embed sequences into specific layer."""
if conv_layer_i == -1:
self.model = self.model_trunk

Expand All @@ -240,14 +217,15 @@ def build_embed(self, conv_layer_i, batch_norm=True):
inputs=self.model.inputs, outputs=conv_layer.output
)

def build_ensemble(self, ensemble_rc=False, ensemble_shifts=[0]):
def build_ensemble(self, ensemble_rc: bool = False, ensemble_shifts=[0]):
"""Build ensemble of models computing on augmented input sequences."""
if ensemble_rc or len(ensemble_shifts) > 1:
shift_bool = len(ensemble_shifts) > 1 or ensemble_shifts[0] != 0
if ensemble_rc or shift_bool:
# sequence input
sequence = tf.keras.Input(shape=(self.seq_length, 4), name="sequence")
sequences = [sequence]

if len(ensemble_shifts) > 1:
if shift_bool:
# generate shifted sequences
sequences = layers.EnsembleShift(ensemble_shifts)(sequences)

Expand Down Expand Up @@ -283,6 +261,7 @@ def build_ensemble(self, ensemble_rc=False, ensemble_shifts=[0]):
self.ensemble = tf.keras.Model(inputs=sequence, outputs=preds_avg)

def build_sad(self):
"""Sum across length axis, in graph."""
# sequence input
sequence = tf.keras.Input(shape=(self.seq_length, 4), name="sequence")

Expand All @@ -296,7 +275,7 @@ def build_sad(self):
# replace model
self.model = tf.keras.Model(inputs=sequence, outputs=sad)

def build_slice(self, target_slice=None, target_sum=False):
def build_slice(self, target_slice=None, target_sum: bool = False):
"""Slice and/or sum across tasks, in graph."""
if target_slice is not None or target_sum:
# sequence input
Expand Down Expand Up @@ -348,7 +327,9 @@ def downcast(self, dtype=tf.float16, head_i=None):
else:
self.model = model_down

def evaluate(self, seq_data, head_i=None, loss_label="poisson", loss_fn=None):
def evaluate(
self, seq_data, head_i=None, loss_label: str = "poisson", loss_fn=None
):
"""Evaluate model on SeqDataset."""
# choose model
if self.ensemble is not None:
Expand Down Expand Up @@ -403,7 +384,7 @@ def get_conv_layer(self, conv_layer_i=0):
return conv_layers[conv_layer_i]

def get_dense_layer(self, layer_i=0):
"""Return specified convolution layer."""
"""Return specified dense layer."""
dense_layers = [
layer for layer in self.model.layers if layer.name.startswith("dense")
]
Expand Down Expand Up @@ -812,9 +793,19 @@ def gradients_func(
return grads

def gradients_orig(
self, seq_1hot, head_i=None, pos_slice=None, batch_size=2, dtype="float16"
self, seq_1hot, head_i=None, pos_slice=None, batch_size=8, dtype="float16"
):
"""Compute input gradients sequence (original version of code)."""
"""Compute input gradients for each task.
Args:
seq_1hot (np.array): 1-hot encoded sequence.
head_i (int): Model head index.
pos_slice ([int]): Sequence positions to consider.
batch_size (int): number of tasks to compute gradients for at once.
dtype: Returned data type.
Returns:
Gradients for each task.
"""
# choose model
if self.ensemble is not None:
model = self.ensemble
Expand Down Expand Up @@ -850,30 +841,9 @@ def gradients_orig(
# replace model
model_batch = tf.keras.Model(inputs=sequence, outputs=predictions_slice)

# compute gradients
t0 = time.time()

# with tf.GradientTape() as tape:
# tape.watch(seq_1hot)

# # predict
# preds = model_batch(seq_1hot, training=False)

# if pos_slice is not None:
# # slice specified positions
# preds = tf.gather(preds, pos_slice, axis=-2)

# # sum across positions
# preds = tf.reduce_sum(preds, axis=-2)

# # compute jacboian
# grads_batch = tape.jacobian(preds, seq_1hot)
# grads_batch = tf.squeeze(grads_batch)
# grads_batch = tf.transpose(grads_batch, [1,2,0])

# # zero mean each position
# grads_batch = grads_batch - tf.reduce_mean(grads_batch, axis=-2, keepdims=True)

grads_batch = self.gradients_func_orig(model_batch, seq_1hot, pos_slice)
grads_batch = self.gradients_func(model_batch, seq_1hot, pos_slice)
print("Batch gradient computation in %ds" % (time.time() - t0))

# convert numpy dtype
Expand All @@ -890,6 +860,16 @@ def gradients_orig(

@tf.function
def gradients_func_orig(self, model, seq_1hot, pos_slice):
"""Compute input gradients for each task.
Args:
model (tf.keras.Model): Model to compute gradients for.
seq_1hot (tf.Tensor): 1-hot encoded sequence.
pos_slice ([int]): Sequence positions to consider.
Returns:
grads (tf.Tensor): Gradients for each task.
"""
with tf.GradientTape() as tape:
tape.watch(seq_1hot)

Expand All @@ -914,13 +894,14 @@ def gradients_func_orig(self, model, seq_1hot, pos_slice):
return grads

def num_targets(self, head_i=None):
"""Return number of targets."""
if head_i is None:
return self.model.output_shape[-1]
else:
return self.models[head_i].output_shape[-1]

def __call__(self, x, head_i=None, dtype="float32"):
"""Predict targets for SeqDataset."""
"""Predict targets for single batch."""
# choose model
if self.ensemble is not None:
model = self.ensemble
Expand All @@ -934,14 +915,23 @@ def __call__(self, x, head_i=None, dtype="float32"):
def predict(
self,
seq_data,
head_i=None,
generator=False,
stream=False,
step=1,
dtype="float32",
head_i: int = None,
generator: bool = False,
stream: bool = False,
step: int = 1,
dtype: str = "float32",
**kwargs,
):
"""Predict targets for SeqDataset."""
"""Predict targets for SeqDataset, with more options.
Args:
seq_data (SeqDataset): Dataset to predict on.
head_i (int): Model head index.
generator (bool): Use generator to predict on dataset.
stream (bool): Stream predictions from dataset.
step (int): Step size.
dtype (str): Data type to return.
"""
# choose model
if self.ensemble is not None:
model = self.ensemble
Expand Down Expand Up @@ -986,13 +976,24 @@ def restore(self, model_file, head_i=0, trunk=False):
self.model = self.models[head_i]

def save(self, model_file, trunk=False):
"""Save model weights to file.
Args:
model_file (str): Path to save model weights.
trunk (bool): Save trunk weights only.
"""
if trunk:
self.model_trunk.save(model_file, include_optimizer=False)
else:
self.model.save(model_file, include_optimizer=False)

def step(self, step=2, head_i=None):
"""Step positions across sequence."""
"""Create new model to step positions across sequence.
Args:
step (int): Step size.
head_i (int): Model head index.
"""
# choose model
if self.ensemble is not None:
model = self.ensemble
Expand All @@ -1017,3 +1018,40 @@ def step(self, step=2, head_i=None):
self.models[head_i] = model_step
else:
self.model = model_step

def track_sequence(self, sequence):
"""Track pooling, striding, and cropping of sequence.
Args:
sequence (tf.Tensor): Sequence input.
"""
self.model_strides = []
self.target_lengths = []
self.target_crops = []
for model in self.models:
# determine model stride
self.model_strides.append(1)
for layer in self.model.layers:
if hasattr(layer, "strides") or hasattr(layer, "size"):
stride_factor = layer.input_shape[1] / layer.output_shape[1]
self.model_strides[-1] *= stride_factor
self.model_strides[-1] = int(self.model_strides[-1])

# determine predictions length before cropping
if type(sequence.shape[1]) == tf.compat.v1.Dimension:
target_full_length = sequence.shape[1].value // self.model_strides[-1]
else:
target_full_length = sequence.shape[1] // self.model_strides[-1]

# determine predictions length after cropping
self.target_lengths.append(model.outputs[0].shape[1])
if type(self.target_lengths[-1]) == tf.compat.v1.Dimension:
self.target_lengths[-1] = self.target_lengths[-1].value
self.target_crops.append(
(target_full_length - self.target_lengths[-1]) // 2
)

if self.verbose:
print("model_strides", self.model_strides)
print("target_lengths", self.target_lengths)
print("target_crops", self.target_crops)

0 comments on commit 8974ac3

Please sign in to comment.