
Commit

update lora ia3
hy395 committed Oct 3, 2023
1 parent ae31d2f commit 0ba9122
Showing 3 changed files with 478 additions and 157 deletions.
48 changes: 46 additions & 2 deletions src/baskerville/layers.py
@@ -26,18 +26,60 @@
#####################
# transfer learning #
#####################
class IA3(tf.keras.layers.Layer):
# activation-rescale adapter:
# https://arxiv.org/pdf/2205.05638.pdf

def __init__(self,
original_layer,
trainable=False,
**kwargs):

# keep the name of this layer the same as the original dense layer.
original_layer_config = original_layer.get_config()
name = original_layer_config["name"]
kwargs.pop("name", None)
super().__init__(name=name, trainable=trainable, **kwargs)

self.output_dim = original_layer_config["units"]

self.original_layer = original_layer
self.original_layer.trainable = False

# IA3 weights. Make it a dense layer to control trainable
self._ia3_layer = tf.keras.layers.Dense(
units=self.output_dim,
use_bias=False,
kernel_initializer=tf.keras.initializers.Ones(),
trainable=True,
name="ia3"
)

def call(self, inputs):
original_output = self.original_layer(inputs)
# pass a constant 1 through the ones-initialized dense layer to read out
# its kernel as a learned per-unit scaling vector
scaler = self._ia3_layer(tf.constant([[1]], dtype='float64'))[0]
return original_output * scaler

def get_config(self):
config = super().get_config().copy()
config.update(
{
"size": self.output_dim,
}
)
return config
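A minimal usage sketch for the IA3 wrapper above; the toy Dense layer, input shape, and variable names are illustrative and not part of this file:

import tensorflow as tf
from baskerville import layers

x = tf.random.normal((2, 16))                    # toy batch
dense = tf.keras.layers.Dense(8)
_ = dense(x)                                     # build the original layer
ia3_dense = layers.IA3(dense, trainable=True)    # frozen dense + trainable per-unit scaler
y = ia3_dense(x)                                 # shape (2, 8); equals dense(x) at initialization
print([v.shape for v in ia3_dense.trainable_variables])  # only the IA3 kernel, shape (1, 8)

Because the IA3 kernel is initialized to ones, the wrapped layer reproduces the original output until training updates the scaling vector.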

class Lora(tf.keras.layers.Layer):
# adapted from:
# https://arxiv.org/abs/2106.09685
# https://keras.io/examples/nlp/parameter_efficient_finetuning_of_gpt2_with_lora/
# https://github.com/Elvenson/stable-diffusion-keras-ft/blob/main/layers.py

def __init__(self,
original_layer,
rank=8,
alpha=16,
trainable=False,
**kwargs):

# keep the name of this layer the same as the original dense layer.
@@ -554,12 +596,14 @@ def __init__(
shape=[1, self._num_heads, 1, self._key_size],
initializer=self._initializer,
dtype=tf.float32,
trainable=True,
)
self._r_r_bias = self.add_weight(
"%s/r_r_bias" % self.name,
shape=[1, self._num_heads, 1, self._key_size],
initializer=self._initializer,
dtype=tf.float32,
trainable=True,
)

def _multihead_output(self, linear_layer, inputs):
169 changes: 14 additions & 155 deletions src/baskerville/scripts/hound_train.py
@@ -17,7 +17,6 @@
import json
import os
import shutil
import re

import numpy as np
import pandas as pd
@@ -27,14 +26,14 @@
from baskerville import dataset
from baskerville import seqnn
from baskerville import trainer
from baskerville import layers

"""
hound_train.py
Train Hound model using given parameters and data.
"""


def main():
parser = argparse.ArgumentParser(description="Train a model.")
parser.add_argument(
@@ -68,17 +67,6 @@ def main():
default=False,
help="Restore only model trunk [Default: %(default)s]",
)
parser.add_argument(
"--transfer_mode",
default="full",
help="transfer method. [full, linear, adapterHoulsby, lora, lora_full]",
)
parser.add_argument(
"--latent",
type=int,
default=16,
help="adapter latent size.",
)
parser.add_argument(
"--tfr_train",
default=None,
@@ -143,83 +131,38 @@ def main():
tfr_pattern=args.tfr_eval,
)
)

params_model["strand_pair"] = strand_pairs

if args.mixed_precision:
mixed_precision.set_global_policy("mixed_float16")
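# Note (illustrative aside): under the "mixed_float16" policy, layer
# computations run in float16 while variables are kept in float32, e.g.
#   from tensorflow.keras import mixed_precision
#   mixed_precision.set_global_policy("mixed_float16")
#   mixed_precision.global_policy().compute_dtype    # "float16" - activations
#   mixed_precision.global_policy().variable_dtype   # "float32" - weights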

if params_train.get("num_gpu", 1) == 1:
########################################
# one GPU

# initialize model
seqnn_model = seqnn.SeqNN(params_model)

# restore
if args.restore:
seqnn_model.restore(args.restore, trunk=args.trunk)

# transfer learning strategies
if args.transfer_mode=='full':
seqnn_model.model.trainable=True

elif args.transfer_mode=='batch_norm':
seqnn_model.model_trunk.trainable=False
for l in seqnn_model.model.layers:
if l.name.startswith("batch_normalization"):
l.trainable=True
seqnn_model.model.summary()

elif args.transfer_mode=='linear':
seqnn_model.model_trunk.trainable=False
seqnn_model.model.summary()

elif args.transfer_mode=='adapterHoulsby':
seqnn_model.model_trunk.trainable=False
strand_pair = strand_pairs[0]
adapter_model = make_adapter_model(seqnn_model.model, strand_pair, args.latent)
seqnn_model.model = adapter_model
seqnn_model.models[0] = seqnn_model.model
seqnn_model.model_trunk = None
seqnn_model.model.summary()

elif args.transfer_mode=='lora':
seqnn_model.model_trunk.trainable=False
add_lora(seqnn_model.model, rank=args.latent, mode='default')
seqnn_model.model.summary()

elif args.transfer_mode=='lora_full':
seqnn_model.model_trunk.trainable=False
add_lora(seqnn_model.model, rank=args.latent, mode='full')
seqnn_model.model.summary()

# initialize trainer
seqnn_trainer = trainer.Trainer(
params_train, train_data, eval_data, args.out_dir
)

# compile model
seqnn_trainer.compile(seqnn_model)

# train model
if args.keras_fit:
seqnn_trainer.fit_keras(seqnn_model)
else:
if len(args.data_dirs) == 1:
seqnn_trainer.fit_tape(seqnn_model)
else:
seqnn_trainer.fit2(seqnn_model)

else:
########################################
# multi GPU

strategy = tf.distribute.MirroredStrategy()

with strategy.scope():

if not args.keras_fit:
# distribute data
for di in range(len(args.data_dirs)):
@@ -247,102 +190,18 @@ def main():
# compile model
seqnn_trainer.compile(seqnn_model)

# train model
if args.keras_fit:
seqnn_trainer.fit_keras(seqnn_model)
else:
if len(args.data_dirs) == 1:
seqnn_trainer.fit_tape(seqnn_model)
else:
seqnn_trainer.fit2(seqnn_model)

def make_adapter_model(input_model, strand_pair, latent_size=16):
# take seqnn_model as input
# output a new seqnn_model object
# only the adapter and layer_norm layers are trainable

model = tf.keras.Model(inputs=input_model.input,
outputs=input_model.layers[-2].output) # remove the switch_reverse layer

# save current graph
layer_parent_dict_old = {} # the parent layers of each layer in the old graph
for layer in model.layers:
for node in layer._outbound_nodes:
layer_name = node.outbound_layer.name
if layer_name not in layer_parent_dict_old:
layer_parent_dict_old.update({layer_name: [layer.name]})
else:
if layer.name not in layer_parent_dict_old[layer_name]:
layer_parent_dict_old[layer_name].append(layer.name)

layer_output_dict_new = {} # the output tensor of each layer in the new graph
layer_output_dict_new.update({model.layers[0].name: model.input})

# remove switch_reverse
to_fix = [i for i in layer_parent_dict_old if re.match('switch_reverse', i)]
for i in to_fix:
del layer_parent_dict_old[i]

# Iterate over all layers after the input
model_outputs = []
reverse_bool = None

for layer in model.layers[1:]:

# parent layers
parent_layers = layer_parent_dict_old[layer.name]

# layer inputs
layer_input = [layer_output_dict_new[parent] for parent in parent_layers]
if len(layer_input) == 1: layer_input = layer_input[0]

if re.match('stochastic_reverse_complement', layer.name):
x, reverse_bool = layer(layer_input)

# insert adapter:
elif re.match('add', layer.name):
if any([re.match('dropout', i) for i in parent_layers]):
print('adapter added before:%s'%layer.name)
x = layers.AdapterHoulsby(latent_size=latent_size)(layer_input[1])
x = layer([layer_input[0], x])
else:
x = layer(layer_input)

# all other layers pass through unchanged
else:
x = layer(layer_input)

# save the output tensor of every layer
layer_output_dict_new.update({layer.name: x})

final = layers.SwitchReverse(strand_pair)([layer_output_dict_new[model.layers[-1].name], reverse_bool])
model_adapter = tf.keras.Model(inputs=model.inputs, outputs=final)

# set layer_norm layers to trainable
for l in model_adapter.layers:
if re.match('layer_normalization', l.name): l.trainable = True

return model_adapter
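The AdapterHoulsby layer inserted above is defined in baskerville.layers and is not shown in this diff. A minimal sketch of a Houlsby-style bottleneck adapter, assuming the usual down-project / nonlinearity / up-project / residual pattern (the class name, activation, and initializers here are illustrative):

import tensorflow as tf

class BottleneckAdapter(tf.keras.layers.Layer):
    """Houlsby-style adapter: a small residual bottleneck inserted after a sub-block."""

    def __init__(self, latent_size=16, **kwargs):
        super().__init__(**kwargs)
        self.latent_size = latent_size

    def build(self, input_shape):
        width = int(input_shape[-1])
        # project down to the adapter latent size, then back up to the original width
        self.down = tf.keras.layers.Dense(self.latent_size, activation="relu")
        # zero-initialized up-projection, so the adapter starts as an identity mapping
        self.up = tf.keras.layers.Dense(width, kernel_initializer="zeros")

    def call(self, inputs):
        return inputs + self.up(self.down(inputs))

Starting from an identity mapping keeps the frozen model's behavior unchanged at the beginning of fine-tuning, so only the adapter and layer-norm parameters need to move.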

def add_lora(input_model, rank=8, alpha=16, mode='default'):
# take seqnn.model as input
# replace _q_layer, _v_layer in multihead_attention
# optionally replace _k_layer, _embedding_layer
if mode not in ['default','full']:
raise ValueError("mode must be default or full")

for layer in input_model.layers:
if re.match('multihead_attention', layer.name):
# default LoRA
layer._q_layer = layers.Lora(layer._q_layer, rank=rank, alpha=alpha)
layer._v_layer = layers.Lora(layer._v_layer, rank=rank, alpha=alpha)
# full LoRA
if mode=='full':
layer._k_layer = layers.Lora(layer._k_layer, rank=rank, alpha=alpha)
layer._embedding_layer = layers.Lora(layer._embedding_layer, rank=rank, alpha=alpha)
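The Lora wrapper used here lives in baskerville.layers and only part of its definition appears in this diff. A minimal sketch of the standard low-rank update it is presumed to implement, where the frozen dense output is augmented by a trainable rank-r path scaled by alpha/rank (the class and attribute names below are illustrative):

import tensorflow as tf

class LoraDense(tf.keras.layers.Layer):
    """Frozen dense layer plus a trainable low-rank correction: W x + (alpha / rank) * B(A(x))."""

    def __init__(self, original_layer, rank=8, alpha=16, **kwargs):
        super().__init__(**kwargs)
        self.original_layer = original_layer
        self.original_layer.trainable = False
        units = original_layer.get_config()["units"]
        self.scale = alpha / rank
        # A: random down-projection to the rank; B: zero-initialized up-projection,
        # so the wrapped layer starts out identical to the original
        self.down = tf.keras.layers.Dense(rank, use_bias=False,
                                          kernel_initializer=tf.keras.initializers.HeUniform())
        self.up = tf.keras.layers.Dense(units, use_bias=False, kernel_initializer="zeros")

    def call(self, inputs):
        return self.original_layer(inputs) + self.up(self.down(inputs)) * self.scale

Only the two low-rank kernels are trainable, so for a square d-by-d projection the added parameter count is 2 * rank * d rather than d * d.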


################################################################################
# __main__
################################################################################
if __name__ == "__main__":
main()
