
Commit

New GPU-friendly gradient function in seqnn.py; Initial commit of various Borzoi benchmarking- and large-scale attribution scripts.
Johannes Linder committed Jul 21, 2023
1 parent 5871257 commit 91f0d3f
Showing 13 changed files with 6,086 additions and 4 deletions.
264 changes: 260 additions & 4 deletions basenji/seqnn.py
@@ -381,9 +381,265 @@ def get_conv_weights(self, conv_layer_i=0):
    weights = np.transpose(weights, [2,1,0])
    return weights

  def gradients(self, seq_1hot, head_i=None, target_slice=None, pos_slice=None, pos_mask=None, pos_slice_denom=None, pos_mask_denom=None, chunk_size=None, batch_size=1, track_scale=1., track_transform=1., clip_soft=None, pseudo_count=0., no_transform=False, use_mean=False, use_ratio=False, use_logodds=False, subtract_avg=True, input_gate=True, smooth_grad=False, n_samples=5, sample_prob=0.875, dtype='float16'):
    """ Compute input gradients for sequences (GPU-friendly). """

    # start time
    t0 = time.time()

    # choose model
    if self.ensemble is not None:
      model = self.ensemble
    elif head_i is not None:
      model = self.models[head_i]
    else:
      model = self.model

    # verify tensor shape(s)
    seq_1hot = seq_1hot.astype('float32')
    target_slice = np.array(target_slice).astype('int32')
    pos_slice = np.array(pos_slice).astype('int32')

    # convert constants to tf tensors
    track_scale = tf.constant(track_scale, dtype=tf.float32)
    track_transform = tf.constant(track_transform, dtype=tf.float32)
    if clip_soft is not None :
      clip_soft = tf.constant(clip_soft, dtype=tf.float32)
    pseudo_count = tf.constant(pseudo_count, dtype=tf.float32)

    if pos_mask is not None :
      pos_mask = np.array(pos_mask).astype('float32')

    if use_ratio and pos_slice_denom is not None :
      pos_slice_denom = np.array(pos_slice_denom).astype('int32')

      if pos_mask_denom is not None :
        pos_mask_denom = np.array(pos_mask_denom).astype('float32')

    if len(seq_1hot.shape) < 3:
      seq_1hot = seq_1hot[None, ...]

    if len(target_slice.shape) < 2:
      target_slice = target_slice[None, ...]

    if len(pos_slice.shape) < 2:
      pos_slice = pos_slice[None, ...]

    if pos_mask is not None and len(pos_mask.shape) < 2:
      pos_mask = pos_mask[None, ...]

    if use_ratio and pos_slice_denom is not None and len(pos_slice_denom.shape) < 2:
      pos_slice_denom = pos_slice_denom[None, ...]

    if pos_mask_denom is not None and len(pos_mask_denom.shape) < 2:
      pos_mask_denom = pos_mask_denom[None, ...]

    # chunk parameters
    num_chunks = 1
    if chunk_size is None :
      chunk_size = seq_1hot.shape[0]
    else :
      num_chunks = int(np.ceil(seq_1hot.shape[0] / chunk_size))

    # loop over chunks
    grad_chunks = []
    for ci in range(num_chunks) :

      # collect chunk
      seq_1hot_chunk = seq_1hot[ci * chunk_size:(ci+1) * chunk_size, ...]
      target_slice_chunk = target_slice[ci * chunk_size:(ci+1) * chunk_size, ...]
      pos_slice_chunk = pos_slice[ci * chunk_size:(ci+1) * chunk_size, ...]

      pos_mask_chunk = None
      if pos_mask is not None :
        pos_mask_chunk = pos_mask[ci * chunk_size:(ci+1) * chunk_size, ...]

      pos_slice_denom_chunk = None
      pos_mask_denom_chunk = None
      if use_ratio and pos_slice_denom is not None :
        pos_slice_denom_chunk = pos_slice_denom[ci * chunk_size:(ci+1) * chunk_size, ...]

        if pos_mask_denom is not None :
          pos_mask_denom_chunk = pos_mask_denom[ci * chunk_size:(ci+1) * chunk_size, ...]

      actual_chunk_size = seq_1hot_chunk.shape[0]

      # sample noisy (discrete) perturbations of the input pattern chunk
      if smooth_grad :
        seq_1hot_chunk_corrupted = np.repeat(np.copy(seq_1hot_chunk), n_samples, axis=0)

        for example_ix in range(seq_1hot_chunk.shape[0]) :
          for sample_ix in range(n_samples) :
            corrupt_index = np.nonzero(np.random.rand(seq_1hot_chunk.shape[1]) >= sample_prob)[0]

            rand_nt_index = np.random.choice([0, 1, 2, 3], size=(corrupt_index.shape[0],))

            seq_1hot_chunk_corrupted[example_ix * n_samples + sample_ix, corrupt_index, :] = 0.
            seq_1hot_chunk_corrupted[example_ix * n_samples + sample_ix, corrupt_index, rand_nt_index] = 1.

        seq_1hot_chunk = seq_1hot_chunk_corrupted
        target_slice_chunk = np.repeat(np.copy(target_slice_chunk), n_samples, axis=0)
        pos_slice_chunk = np.repeat(np.copy(pos_slice_chunk), n_samples, axis=0)

        if pos_mask is not None :
          pos_mask_chunk = np.repeat(np.copy(pos_mask_chunk), n_samples, axis=0)

        if use_ratio and pos_slice_denom is not None :
          pos_slice_denom_chunk = np.repeat(np.copy(pos_slice_denom_chunk), n_samples, axis=0)

          if pos_mask_denom is not None :
            pos_mask_denom_chunk = np.repeat(np.copy(pos_mask_denom_chunk), n_samples, axis=0)

      # convert to tf tensors
      seq_1hot_chunk = tf.convert_to_tensor(seq_1hot_chunk, dtype=tf.float32)
      target_slice_chunk = tf.convert_to_tensor(target_slice_chunk, dtype=tf.int32)
      pos_slice_chunk = tf.convert_to_tensor(pos_slice_chunk, dtype=tf.int32)

      if pos_mask is not None :
        pos_mask_chunk = tf.convert_to_tensor(pos_mask_chunk, dtype=tf.float32)

      if use_ratio and pos_slice_denom is not None :
        pos_slice_denom_chunk = tf.convert_to_tensor(pos_slice_denom_chunk, dtype=tf.int32)

        if pos_mask_denom is not None :
          pos_mask_denom_chunk = tf.convert_to_tensor(pos_mask_denom_chunk, dtype=tf.float32)

      # batching parameters
      num_batches = int(np.ceil(actual_chunk_size * (n_samples if smooth_grad else 1) / batch_size))

      # loop over batches
      grad_batches = []
      for bi in range(num_batches) :

        # collect batch
        seq_1hot_batch = seq_1hot_chunk[bi * batch_size:(bi+1) * batch_size, ...]
        target_slice_batch = target_slice_chunk[bi * batch_size:(bi+1) * batch_size, ...]
        pos_slice_batch = pos_slice_chunk[bi * batch_size:(bi+1) * batch_size, ...]

        pos_mask_batch = None
        if pos_mask is not None :
          pos_mask_batch = pos_mask_chunk[bi * batch_size:(bi+1) * batch_size, ...]

        pos_slice_denom_batch = None
        pos_mask_denom_batch = None
        if use_ratio and pos_slice_denom is not None :
          pos_slice_denom_batch = pos_slice_denom_chunk[bi * batch_size:(bi+1) * batch_size, ...]

          if pos_mask_denom is not None :
            pos_mask_denom_batch = pos_mask_denom_chunk[bi * batch_size:(bi+1) * batch_size, ...]

        grad_batch = self.gradients_func(model, seq_1hot_batch, target_slice_batch, pos_slice_batch, pos_mask_batch, pos_slice_denom_batch, pos_mask_denom_batch, track_scale, track_transform, clip_soft, pseudo_count, no_transform, use_mean, use_ratio, use_logodds, subtract_avg, input_gate).numpy().astype(dtype)

        grad_batches.append(grad_batch)

      # concat gradient batches
      grads = np.concatenate(grad_batches, axis=0)

      # aggregate noisy gradient perturbations
      if smooth_grad :
        grads_smoothed = np.zeros((grads.shape[0] // n_samples, grads.shape[1], grads.shape[2]), dtype='float32')

        for example_ix in range(grads_smoothed.shape[0]) :
          for sample_ix in range(n_samples) :
            grads_smoothed[example_ix, ...] += grads[example_ix * n_samples + sample_ix, ...]

        grads = grads_smoothed / float(n_samples)
        grads = grads.astype(dtype)

      grad_chunks.append(grads)

      # collect garbage
      gc.collect()

    # concat gradient chunks
    grads = np.concatenate(grad_chunks, axis=0)

    # aggregate and broadcast to original input pattern
    if input_gate :
      grads = np.sum(grads, axis=-1, keepdims=True) * seq_1hot

    print('Completed gradient computation in %ds' % (time.time()-t0))

    return grads
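For orientation only (not part of this commit): a minimal sketch of how the new gradients method might be called on a trained model. The params/checkpoint paths, target and position indices, and transform constants below are hypothetical placeholders; track_scale, track_transform, and clip_soft should match whatever squashing was applied to the model's training targets.

# Minimal usage sketch (hypothetical paths and indices; adjust to your model).
import json
import numpy as np
from basenji import dna_io, seqnn

with open('params.json') as params_open:          # hypothetical params file
  params_model = json.load(params_open)['model']

seqnn_model = seqnn.SeqNN(params_model)
seqnn_model.restore('model_best.h5', head_i=0)    # hypothetical checkpoint

# one-hot encode an input sequence of the model's expected length
seq_1hot = dna_io.dna_1hot('ACGT' * (params_model['seq_length'] // 4))

grads = seqnn_model.gradients(
  seq_1hot,
  head_i=0,
  target_slice=np.array([[0, 1, 2]]),       # tracks to attribute (placeholder indices)
  pos_slice=np.array([[500, 501, 502]]),    # output bins to aggregate over (placeholders)
  track_scale=1.,                           # match the training-time squashing
  track_transform=1.,
  clip_soft=None,
  batch_size=1,
)

# with the default input_gate=True this is gradient-times-input, shape (1, seq_length, 4)
print(grads.shape)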

  @tf.function
  def gradients_func(self, model, seq_1hot, target_slice, pos_slice, pos_mask=None, pos_slice_denom=None, pos_mask_denom=True, track_scale=1., track_transform=1., clip_soft=None, pseudo_count=0., no_transform=False, use_mean=False, use_ratio=False, use_logodds=False, subtract_avg=True, input_gate=True):

    with tf.GradientTape() as tape:
      tape.watch(seq_1hot)

      # predict
      preds = tf.gather(model(seq_1hot, training=False), target_slice, axis=-1, batch_dims=1)

      if not no_transform :

        # undo scale
        preds = preds / track_scale

        # undo soft_clip
        if clip_soft is not None :
          preds = tf.where(preds > clip_soft, (preds - clip_soft)**2 + clip_soft, preds)

        # undo sqrt
        preds = preds**(1. / track_transform)

      # aggregate over tracks (average)
      preds = tf.reduce_mean(preds, axis=-1)

      # slice specified positions
      preds_slice = tf.gather(preds, pos_slice, axis=-1, batch_dims=1)
      if pos_mask is not None :
        preds_slice = preds_slice * pos_mask

      # slice denominator positions
      if use_ratio and pos_slice_denom is not None:
        preds_slice_denom = tf.gather(preds, pos_slice_denom, axis=-1, batch_dims=1)
        if pos_mask_denom is not None :
          preds_slice_denom = preds_slice_denom * pos_mask_denom

      # aggregate over positions
      if not use_mean :
        preds_agg = tf.reduce_sum(preds_slice, axis=-1)
        if use_ratio and pos_slice_denom is not None:
          preds_agg_denom = tf.reduce_sum(preds_slice_denom, axis=-1)
      else :
        if pos_mask is not None :
          preds_agg = tf.reduce_sum(preds_slice, axis=-1) / tf.reduce_sum(pos_mask, axis=-1)
        else :
          preds_agg = tf.reduce_mean(preds_slice, axis=-1)

        if use_ratio and pos_slice_denom is not None:
          if pos_mask_denom is not None :
            preds_agg_denom = tf.reduce_sum(preds_slice_denom, axis=-1) / tf.reduce_sum(pos_mask_denom, axis=-1)
          else :
            preds_agg_denom = tf.reduce_mean(preds_slice_denom, axis=-1)

      # compute final statistic to take gradient of
      if no_transform :
        score_ratios = preds_agg
      elif not use_ratio :
        score_ratios = tf.math.log(preds_agg + pseudo_count + 1e-6)
      else :
        if not use_logodds :
          score_ratios = tf.math.log((preds_agg + pseudo_count) / (preds_agg_denom + pseudo_count) + 1e-6)
        else :
          score_ratios = tf.math.log(((preds_agg + pseudo_count) / (preds_agg_denom + pseudo_count)) / (1. - ((preds_agg + pseudo_count) / (preds_agg_denom + pseudo_count))) + 1e-6)

    # compute gradient
    grads = tape.gradient(score_ratios, seq_1hot)

    # zero mean each position
    if subtract_avg :
      grads = grads - tf.reduce_mean(grads, axis=-1, keepdims=True)

    # multiply by input
    if input_gate :
      grads = grads * seq_1hot

    return grads
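As an aside (not part of the commit): the de-transformation at the top of gradients_func reverses the squashing typically applied to training targets, in order: divide out track_scale, reverse the soft clip, then invert the power transform. A rough NumPy equivalent of that inverse, under the same assumptions, is:

import numpy as np

def undo_squashing(preds, track_scale=1., track_transform=1., clip_soft=None):
  """NumPy sketch of the inverse squashing used above (scale, soft clip, power)."""
  preds = preds / track_scale
  if clip_soft is not None:
    # values soft-clipped above clip_soft are expanded back quadratically
    preds = np.where(preds > clip_soft, (preds - clip_soft) ** 2 + clip_soft, preds)
  return preds ** (1. / track_transform)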

  def gradients(self, seq_1hot, head_i=None, pos_slice=None, batch_size=8, dtype='float16'):
    """ Compute input gradients sequence. """
  def gradients_orig(self, seq_1hot, head_i=None, pos_slice=None, batch_size=2, dtype='float16'):
    """ Compute input gradients sequence (original version of code). """
    # choose model
    if self.ensemble is not None:
      model = self.ensemble
@@ -443,7 +699,7 @@ def gradients(self, seq_1hot, head_i=None, pos_slice=None, batch_size=8, dtype='
      # grads_batch = grads_batch - tf.reduce_mean(grads_batch, axis=-2, keepdims=True)


      grads_batch = self.gradients_func(model_batch, seq_1hot, pos_slice)
      grads_batch = self.gradients_func_orig(model_batch, seq_1hot, pos_slice)
      print('Batch gradient computation in %ds' % (time.time()-t0))

      # convert numpy dtype
@@ -459,7 +715,7 @@ def gradients(self, seq_1hot, head_i=None, pos_slice=None, batch_size=8, dtype='
    return grads

  @tf.function
  def gradients_func(self, model, seq_1hot, pos_slice):
  def gradients_func_orig(self, model, seq_1hot, pos_slice):
    with tf.GradientTape() as tape:
      tape.watch(seq_1hot)
