
Commit

transfer_learn
hy395 committed Sep 19, 2023
1 parent cf08b90 commit 337b053
Showing 6 changed files with 853 additions and 40 deletions.
44 changes: 22 additions & 22 deletions setup.cfg
@@ -19,31 +19,31 @@ package_dir =
packages = find:
python_requires = >=3.8, <3.11
install_requires =
h5py~=3.7.0
intervaltree~=3.1.0
joblib~=1.1.1
matplotlib~=3.7.1
google-cloud-storage~=2.0.0
natsort~=7.1.1
networkx~=2.8.4
numpy~=1.24.3
pandas~=1.5.3
pybigwig~=0.3.18
pysam~=0.21.0
pybedtools~=0.9.0
qnorm~=0.8.1
seaborn~=0.12.2
scikit-learn~=1.2.2
scipy~=1.9.1
statsmodels~=0.13.5
tabulate~=0.8.10
tensorflow~=2.12.0
tqdm~=4.65.0
h5py>=3.7.0
intervaltree>=3.1.0
joblib>=1.1.1
matplotlib>=3.7.1
google-cloud-storage>=2.0.0
natsort>=7.1.1
networkx>=2.8.4
numpy>=1.24.3
pandas>=1.5.3
pybigwig>=0.3.18
pysam>=0.21.0
pybedtools>=0.9.0
qnorm>=0.8.1
seaborn>=0.12.2
scikit-learn>=1.2.2
scipy>=1.9.1
statsmodels>=0.13.5
tabulate>=0.8.10
tensorflow>=2.12.0
tqdm>=4.65.0

[options.extras_require]
dev =
black==22.3.0
pytest==7.1.2
black>=22.3.0
pytest>=7.1.2

[options.packages.find]
where = src
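With the pins relaxed from ~= to >=, the package accepts newer minor releases of its dependencies. A minimal way to try the updated requirements plus the dev extras (standard setuptools usage, not specific to this commit):

pip install -e ".[dev]"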
74 changes: 74 additions & 0 deletions src/baskerville/HY_helper.py
@@ -0,0 +1,74 @@
import numpy as np
from basenji import dna_io
import pysam
import pyBigWig

def make_seq_1hot(genome_open, chrm, start, end, seq_len):
    if start < 0:
        seq_dna = 'N'*(-start) + genome_open.fetch(chrm, 0, end)
    else:
        seq_dna = genome_open.fetch(chrm, start, end)

    # Extend to full length
    if len(seq_dna) < seq_len:
        seq_dna += 'N'*(seq_len-len(seq_dna))

    seq_1hot = dna_io.dna_1hot(seq_dna)
    return seq_1hot

# Helper function to get (padded) one-hot
def process_sequence(fasta_file, chrom, start, end, seq_len=524288):

    fasta_open = pysam.Fastafile(fasta_file)
    seq_len_actual = end - start

    # Pad sequence to input window size
    start -= (seq_len - seq_len_actual) // 2
    end += (seq_len - seq_len_actual) // 2

    # Get one-hot
    sequence_one_hot = make_seq_1hot(fasta_open, chrom, start, end, seq_len)

    return sequence_one_hot.astype('float32')

def compute_cov(seqnn_model, chr, start, end):
    seq_len = seqnn_model.model.layers[0].input.shape[1]
    seq1hot = process_sequence('/home/yuanh/programs/genomes/hg38/hg38.fa', chr, start, end, seq_len=seq_len)
    out = seqnn_model.model(seq1hot[None, ])
    return out.numpy()

def write_bw(bw_file, chr, start, end, values, span=32):
    bw_out = pyBigWig.open(bw_file, 'w')
    header = []
    header.append((chr, end+1))
    bw_out.addHeader(header)
    bw_out.addEntries(chr, start, values=values, span=span, step=span)
    bw_out.close()

def transform(seq_cov, clip=384, clip_soft=320, scale=0.3):
    seq_cov = scale * seq_cov  # scale
    seq_cov = -1 + np.sqrt(1+seq_cov)  # variance stabilize
    clip_mask = (seq_cov > clip_soft)  # soft clip
    seq_cov[clip_mask] = clip_soft-1 + np.sqrt(seq_cov[clip_mask] - clip_soft+1)
    seq_cov = np.clip(seq_cov, -clip, clip)  # hard clip
    return seq_cov

def untransform(cov, scale=0.3, clip_soft=320, pool_width=32):

    # undo clip_soft
    cov_unclipped = (cov - clip_soft + 1)**2 + clip_soft - 1
    unclip_mask = (cov > clip_soft)
    cov[unclip_mask] = cov_unclipped[unclip_mask]

    # undo sqrt
    cov = (cov + 1)**2 - 1

    # undo scale
    cov = cov / scale

    # undo sum
    cov = cov / pool_width

    return cov
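A minimal sketch (editor's example, not part of the commit) of how transform and untransform relate, assuming the new module is importable as baskerville.HY_helper. Note that untransform also divides by pool_width, so round-tripping a raw 32-bp-summed track returns per-bp means rather than the original sums:

import numpy as np
from baskerville.HY_helper import transform, untransform

raw = np.array([0., 10., 5000., 1e6], dtype='float32')  # toy summed coverage bins
squashed = transform(raw.copy(), clip=384, clip_soft=320, scale=0.3)  # scale, sqrt, soft/hard clip
recovered = untransform(squashed.copy(), scale=0.3, clip_soft=320, pool_width=32)
print(squashed)
print(recovered)  # ~ raw / 32 for values under the hard clip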


49 changes: 49 additions & 0 deletions src/baskerville/layers.py
@@ -23,6 +23,55 @@
for device in gpu_devices:
    tf.config.experimental.set_memory_growth(device, True)

#####################
# transfer learning #
#####################
class AdapterHoulsby(tf.keras.layers.Layer):
    ### Houlsby et al. 2019 implementation

    def __init__(
        self,
        latent_size,
        activation=tf.keras.layers.ReLU(),
        **kwargs):
        super(AdapterHoulsby, self).__init__(**kwargs)
        self.latent_size = latent_size
        self.activation = activation

    def build(self, input_shape):
        self.down_project = tf.keras.layers.Dense(
            units=self.latent_size,
            activation="linear",
            kernel_initializer=tf.keras.initializers.TruncatedNormal(stddev=1e-3),
            bias_initializer="zeros",
            name='adapter_down'
        )

        self.up_project = tf.keras.layers.Dense(
            units=input_shape[-1],
            activation="linear",
            kernel_initializer=tf.keras.initializers.TruncatedNormal(stddev=1e-3),
            bias_initializer="zeros",
            name='adapter_up'
        )

    def call(self, inputs):
        projected_down = self.down_project(inputs)
        activated = self.activation(projected_down)
        projected_up = self.up_project(activated)
        output = projected_up + inputs
        return output

    def get_config(self):
        config = super().get_config().copy()
        config.update(
            {
                "latent_size": self.latent_size,
                "activation": self.activation
            }
        )
        return config
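A quick smoke test (editor's sketch, not in the commit): the adapter preserves the input shape and adds only a small bottleneck's worth of weights relative to a frozen trunk.

import tensorflow as tf
from baskerville.layers import AdapterHoulsby

x = tf.random.normal((2, 1024, 768))  # (batch, length, channels)
adapter = AdapterHoulsby(latent_size=16)
y = adapter(x)                        # down-project, ReLU, up-project, residual add
assert y.shape == x.shape
print(adapter.count_params())         # two bottleneck Dense layers, ~25k weights for 768 channels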

############################################################
# Basic
############################################################
138 changes: 125 additions & 13 deletions src/baskerville/scripts/hound_train.py
@@ -17,6 +17,7 @@
import json
import os
import shutil
import re

import numpy as np
import pandas as pd
@@ -26,14 +27,14 @@
from baskerville import dataset
from baskerville import seqnn
from baskerville import trainer
from baskerville import layers

"""
hound_train.py
Train Hound model using given parameters and data.
"""


def main():
    parser = argparse.ArgumentParser(description="Train a model.")
    parser.add_argument(
@@ -67,6 +68,17 @@ def main():
        default=False,
        help="Restore only model trunk [Default: %(default)s]",
    )
    parser.add_argument(
        "--transfer_mode",
        default="full",
        help="transfer method. [full, linear, batch_norm, adapterHoulsby]",
    )
    parser.add_argument(
        "--latent",
        type=int,
        default=16,
        help="adapter latent size.",
    )
    parser.add_argument(
        "--tfr_train",
        default=None,
@@ -131,38 +143,73 @@ def main():
                tfr_pattern=args.tfr_eval,
            )
        )

    params_model["strand_pair"] = strand_pairs

    if args.mixed_precision:
        mixed_precision.set_global_policy("mixed_float16")

        policy = mixed_precision.Policy('mixed_float16')
        mixed_precision.set_global_policy(policy)

    if params_train.get("num_gpu", 1) == 1:
        ########################################
        # one GPU

        # initialize model
        seqnn_model = seqnn.SeqNN(params_model)

        # restore
        if args.restore:
            seqnn_model.restore(args.restore, trunk=args.trunk)

        # transfer learning strategies
        if args.transfer_mode=='full':
            seqnn_model.model.trainable=True

        elif args.transfer_mode=='batch_norm':
            seqnn_model.model_trunk.trainable=False
            for l in seqnn_model.model.layers:
                if l.name.startswith("batch_normalization"):
                    l.trainable=True
            seqnn_model.model.summary()

        elif args.transfer_mode=='linear':
            seqnn_model.model_trunk.trainable=False
            seqnn_model.model.summary()

        elif args.transfer_mode=='adapterHoulsby':
            seqnn_model.model_trunk.trainable=False
            strand_pair = strand_pairs[0]
            adapter_model = make_adapter_model(seqnn_model.model, strand_pair, args.latent)
            seqnn_model.model = adapter_model
            seqnn_model.models[0] = seqnn_model.model
            seqnn_model.model_trunk = None
            seqnn_model.model.summary()
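        # (editor's note, not part of the commit) whichever branch ran, a quick
        # sanity check of what each mode leaves trainable before compiling:
        #   n_train = sum(int(tf.size(v)) for v in seqnn_model.model.trainable_variables)
        #   n_total = sum(int(tf.size(v)) for v in seqnn_model.model.variables)
        #   print(f"trainable {n_train:,} / {n_total:,} weights")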

        # initialize trainer
        seqnn_trainer = trainer.Trainer(
            params_train, train_data, eval_data, args.out_dir
        )

        # compile model
        seqnn_trainer.compile(seqnn_model)

        # train model
        if args.keras_fit:
            seqnn_trainer.fit_keras(seqnn_model)
        else:
            if len(args.data_dirs) == 1:
                seqnn_trainer.fit_tape(seqnn_model)
            else:
                seqnn_trainer.fit2(seqnn_model)

    else:
        ########################################
        # multi GPU

        strategy = tf.distribute.MirroredStrategy()

        with strategy.scope():

            if not args.keras_fit:
                # distribute data
                for di in range(len(args.data_dirs)):
@@ -190,16 +237,81 @@ def main():
            # compile model
            seqnn_trainer.compile(seqnn_model)

# train model
if args.keras_fit:
seqnn_trainer.fit_keras(seqnn_model)
else:
if len(args.data_dirs) == 1:
seqnn_trainer.fit_tape(seqnn_model)
# train model
if args.keras_fit:
seqnn_trainer.fit_keras(seqnn_model)
else:
seqnn_trainer.fit2(seqnn_model)
if len(args.data_dirs) == 1:
seqnn_trainer.fit_tape(seqnn_model)
else:
seqnn_trainer.fit2(seqnn_model)

def make_adapter_model(input_model, strand_pair, latent_size=16):
    # take seqnn_model as input
    # output a new seqnn_model object
    # only the adapter, and layer_norm are trainable

    model = tf.keras.Model(inputs=input_model.input,
                           outputs=input_model.layers[-2].output)  # remove the switch_reverse layer

    # save current graph
    layer_parent_dict_old = {}  # the parent layers of each layer in the old graph
    for layer in model.layers:
        for node in layer._outbound_nodes:
            layer_name = node.outbound_layer.name
            if layer_name not in layer_parent_dict_old:
                layer_parent_dict_old.update({layer_name: [layer.name]})
            else:
                if layer.name not in layer_parent_dict_old[layer_name]:
                    layer_parent_dict_old[layer_name].append(layer.name)

    layer_output_dict_new = {}  # the output tensor of each layer in the new graph
    layer_output_dict_new.update({model.layers[0].name: model.input})

    # remove switch_reverse
    to_fix = [i for i in layer_parent_dict_old if re.match('switch_reverse', i)]
    for i in to_fix:
        del layer_parent_dict_old[i]

    # Iterate over all layers after the input
    model_outputs = []
    reverse_bool = None

    for layer in model.layers[1:]:

        # parent layers
        parent_layers = layer_parent_dict_old[layer.name]

        # layer inputs
        layer_input = [layer_output_dict_new[parent] for parent in parent_layers]
        if len(layer_input) == 1:
            layer_input = layer_input[0]

        if re.match('stochastic_reverse_complement', layer.name):
            x, reverse_bool = layer(layer_input)

        # insert adapter:
        elif re.match('add', layer.name):
            if any([re.match('dropout', i) for i in parent_layers]):
                print('adapter added before:%s' % layer.name)
                x = layers.AdapterHoulsby(latent_size=latent_size)(layer_input[1])
                x = layer([layer_input[0], x])
            else:
                x = layer(layer_input)

        else:
            x = layer(layer_input)

        # save the output tensor of every layer
        layer_output_dict_new.update({layer.name: x})

    final = layers.SwitchReverse(strand_pair)([layer_output_dict_new[model.layers[-1].name], reverse_bool])
    model_adapter = tf.keras.Model(inputs=model.inputs, outputs=final)

    # set layer_norm layers to trainable
    for l in model_adapter.layers:
        if re.match('layer_normalization', l.name):
            l.trainable = True

    return model_adapter
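For context, a hedged sketch of invoking the new mode (editor's example; the positional params/data arguments and the -o/--restore flags follow the script's existing interface, and the paths are placeholders):

python hound_train.py -o train_out --restore pretrained/model0_best.h5 --trunk --transfer_mode adapterHoulsby --latent 16 params.json data_human

In this mode the trunk is frozen and make_adapter_model rebuilds the graph, inserting an AdapterHoulsby bottleneck before each residual add whose parents include a dropout layer, then re-enabling training for the adapters and the layer_normalization layers; the final head layers, which the frozen trunk does not cover, also remain trainable.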
################################################################################
# __main__
################################################################################