diff --git a/src/baskerville/adapters.py b/src/baskerville/adapters.py
new file mode 100644
index 0000000..f05a307
--- /dev/null
+++ b/src/baskerville/adapters.py
@@ -0,0 +1,301 @@
+# Copyright 2023 Calico LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     https://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# =========================================================================
+
+import numpy as np
+import tensorflow as tf
+
+gpu_devices = tf.config.experimental.list_physical_devices("GPU")
+for device in gpu_devices:
+    tf.config.experimental.set_memory_growth(device, True)
+
+#####################
+# transfer learning #
+#####################
+class IA3(tf.keras.layers.Layer):
+    # https://arxiv.org/pdf/2205.05638.pdf
+    # IA3 module for attention projection layers: rescales the layer output.
+    
+    def __init__(self, 
+                 original_layer, 
+                 trainable=False, 
+                 **kwargs):
+        
+        # keep the name of this layer the same as the original dense layer.
+        original_layer_config = original_layer.get_config()
+        name = original_layer_config["name"]
+        kwargs.pop("name", None)
+        super().__init__(name=name, trainable=trainable, **kwargs)
+
+        self.output_dim = original_layer_config["units"]
+        
+        self.original_layer = original_layer
+        self.original_layer.trainable = False
+
+        # IA3 weights. Make it a dense layer to control trainable
+        self._ia3_layer = tf.keras.layers.Dense(
+            units=self.output_dim,
+            use_bias=False,
+            kernel_initializer=tf.keras.initializers.Ones(),
+            trainable=True,
+            name="ia3"
+        )
+
+    def call(self, inputs):
+        original_output = self.original_layer(inputs)
+        scaler = self._ia3_layer(tf.constant([[1]], dtype='float64'))[0]
+        return original_output * scaler
+
+    def get_config(self):
+        config = super().get_config().copy()
+        config.update(
+            {
+                "size": self.output_dim,
+            }
+        )
+        return config
+
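+# Example usage (a minimal, hedged sketch; `dense` is a hypothetical frozen Dense layer):
+#   dense = tf.keras.layers.Dense(units=64, name="k_layer")
+#   dense.build((None, 32))
+#   ia3_dense = IA3(dense, trainable=True)
+#   y = ia3_dense(tf.ones((1, 32)))  # original output rescaled per-unit by the
+#                                    # ones-initialized IA3 vector
+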
+class IA3_ff(tf.keras.layers.Layer):
+    # https://arxiv.org/pdf/2205.05638.pdf
+    # IA3 module for the feed-forward down-projection layer: rescales the layer input.
+    
+    def __init__(self, 
+                 original_layer, 
+                 trainable=False, 
+                 **kwargs):
+        
+        # keep the name of this layer the same as the original dense layer.
+        original_layer_config = original_layer.get_config()
+        name = original_layer_config["name"]
+        kwargs.pop("name", None)
+        super().__init__(name=name, trainable=trainable, **kwargs)
+
+        self.input_dim = original_layer.input_shape[-1]
+        
+        self.original_layer = original_layer
+        self.original_layer.trainable = False
+
+        # IA3 weights. Make it a dense layer to control trainable
+        self._ia3_layer = tf.keras.layers.Dense(
+            units=self.input_dim,
+            use_bias=False,
+            kernel_initializer=tf.keras.initializers.Ones(),
+            trainable=True,
+            name="ia3_ff"
+        )
+
+    def call(self, inputs):
+        scaler = self._ia3_layer(tf.constant([[1]], dtype='float64'))[0]
+        return self.original_layer(inputs * scaler)
+
+    def get_config(self):
+        config = super().get_config().copy()
+        config.update(
+            {
+                "size": self.input_dim
+            }
+        )
+        return config
+        
+class Lora(tf.keras.layers.Layer):
+    # adapted from: 
+    # https://arxiv.org/abs/2106.09685
+    # https://keras.io/examples/nlp/parameter_efficient_finetuning_of_gpt2_with_lora/
+    # https://github.com/Elvenson/stable-diffusion-keras-ft/blob/main/layers.py
+    
+    def __init__(self, 
+                 original_layer, 
+                 rank=8, 
+                 alpha=16, 
+                 trainable=False, 
+                 **kwargs):
+        
+        # keep the name of this layer the same as the original dense layer.
+        original_layer_config = original_layer.get_config()
+        name = original_layer_config["name"]
+        kwargs.pop("name", None)
+        super().__init__(name=name, trainable=trainable, **kwargs)
+
+        self.output_dim = original_layer_config["units"]
+        
+        if rank > self.output_dim:
+            raise ValueError(f"LoRA rank {rank} must be less than or equal to {self.output_dim}")
+
+        self.rank = rank
+        self.alpha = alpha
+        self.scale = alpha / rank
+        self.original_layer = original_layer
+        self.original_layer.trainable = False
+
+        # Note: the original paper mentions that normal distribution was
+        # used for initialization. However, the official LoRA implementation
+        # uses "Kaiming/He Initialization".
+        self.down_layer = tf.keras.layers.Dense(
+            units=rank,
+            use_bias=False,
+            kernel_initializer=tf.keras.initializers.HeUniform(),
+            #kernel_initializer=tf.keras.initializers.RandomNormal(stddev=1 / self.rank),
+            trainable=True,
+            name="lora_a"
+        )
+
+        self.up_layer = tf.keras.layers.Dense(
+            units=self.output_dim,
+            use_bias=False,
+            kernel_initializer=tf.keras.initializers.Zeros(),
+            trainable=True,
+            name="lora_b"
+        )
+
+    def call(self, inputs):
+        original_output = self.original_layer(inputs)
+        lora_output = self.up_layer(self.down_layer(inputs)) * self.scale
+        return original_output + lora_output
+
+    def get_config(self):
+        config = super().get_config().copy()
+        config.update(
+            {
+                "rank": self.rank,
+                "alpha": self.alpha
+            }
+        )
+        return config
+
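+# Example usage (a minimal, hedged sketch; `dense` is a hypothetical frozen Dense layer):
+#   dense = tf.keras.layers.Dense(units=64, name="q_layer")
+#   dense.build((None, 64))
+#   lora_dense = Lora(dense, rank=8, alpha=16, trainable=True)
+#   y = lora_dense(tf.ones((1, 64)))  # original output plus scale * (x @ A @ B)
+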
+class Locon(tf.keras.layers.Layer):
+    # LoRA for conv-layer, adapted from: 
+    # https://arxiv.org/pdf/2309.14859#page=23.84
+    # https://github.com/KohakuBlueleaf/LyCORIS/blob/main/lycoris/modules/locon.py
+    # use default alpha and rank for locon
+    
+    def __init__(self, 
+                 original_layer, 
+                 rank=4, 
+                 alpha=1, 
+                 trainable=False,
+                 **kwargs):
+        
+        # keep the name of this layer the same as the original conv layer.
+        original_layer_config = original_layer.get_config()
+        name = original_layer_config["name"]
+        kwargs.pop("name", None)
+        super().__init__(name=name, trainable=trainable, **kwargs)
+
+        self.input_dim = original_layer.input_shape[-1]
+        self.output_dim = original_layer_config["filters"]
+        
+        if rank > self.output_dim:
+            raise ValueError(f"LoRA rank {rank} must be less than or equal to {self.output_dim}")
+
+        self.rank = rank
+        self.alpha = alpha
+        self.scale = alpha / rank
+        self.original_layer = original_layer
+        self.original_layer.trainable = False
+
+        input_dim = original_layer.input_shape[-1]
+        output_dim = original_layer_config["filters"]
+        kernel_size = original_layer_config['kernel_size'][0]
+        stride = original_layer_config['strides'][0]
+        dilation_rate = original_layer_config["dilation_rate"][0]
+
+        # Note: the original paper mentions that normal distribution was
+        # used for initialization. However, the official LoRA implementation
+        # uses "Kaiming/He Initialization".
+        
+        self.down_layer = tf.keras.layers.Conv1D(
+            filters=rank,
+            kernel_size=kernel_size,
+            strides=stride,
+            padding="same",
+            use_bias=False,
+            dilation_rate=dilation_rate,
+            kernel_initializer=tf.keras.initializers.HeUniform(),
+            name='locon_down'
+        )
+        
+        self.up_layer = tf.keras.layers.Conv1D(
+            filters=output_dim,
+            kernel_size=1,
+            strides=stride,
+            padding="same",
+            use_bias=False,
+            kernel_initializer=tf.keras.initializers.Zeros(),
+            name='locon_up'
+        )
+
+    def call(self, inputs):
+        original_output = self.original_layer(inputs)
+        lora_output = self.up_layer(self.down_layer(inputs)) * self.scale
+        return original_output + lora_output
+
+    def get_config(self):
+        config = super().get_config().copy()
+        config.update(
+            {
+                "rank": self.rank,
+                "alpha": self.alpha
+            }
+        )
+        return config
+
+class AdapterHoulsby(tf.keras.layers.Layer):
+    # https://arxiv.org/abs/1902.00751
+    # adapted from: https://github.com/jain-harshil/Adapter-BERT
+    
+    def __init__(
+        self, 
+        latent_size, 
+        activation=tf.keras.layers.ReLU(),
+        **kwargs):
+        super(AdapterHoulsby, self).__init__(**kwargs)
+        self.latent_size = latent_size
+        self.activation = activation
+
+    def build(self, input_shape):
+        self.down_project = tf.keras.layers.Dense(
+            units=self.latent_size, 
+            activation="linear", 
+            kernel_initializer=tf.keras.initializers.TruncatedNormal(stddev=1e-3),
+            bias_initializer="zeros",
+            name='adapter_down'
+        )
+        
+        self.up_project = tf.keras.layers.Dense(
+            units=input_shape[-1], 
+            activation="linear",
+            kernel_initializer=tf.keras.initializers.TruncatedNormal(stddev=1e-3),
+            bias_initializer="zeros",
+            name='adapter_up'
+        )
+
+    def call(self, inputs):
+        projected_down = self.down_project(inputs)
+        activated = self.activation(projected_down)
+        projected_up = self.up_project(activated)
+        output = projected_up + inputs
+        return output
+
+    def get_config(self):
+        config = super().get_config().copy()
+        config.update(
+            {
+                "latent_size": self.latent_size,
+                "activation": self.activation
+            }
+        )
+        return config
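+
+# Example usage (a minimal, hedged sketch): a bottleneck adapter on a hypothetical
+# 768-dim residual stream; the residual addition happens inside call().
+#   adapter = AdapterHoulsby(latent_size=8)
+#   y = adapter(tf.ones((1, 196, 768)))  # same shape as the input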
diff --git a/src/baskerville/helpers/transfer.py b/src/baskerville/helpers/transfer.py
new file mode 100644
index 0000000..25b80f5
--- /dev/null
+++ b/src/baskerville/helpers/transfer.py
@@ -0,0 +1,636 @@
+import argparse
+import json
+import os
+import shutil
+import re
+import h5py
+
+import numpy as np
+import pandas as pd
+import tensorflow as tf
+from tensorflow.keras import mixed_precision
+
+from baskerville import dataset
+from baskerville import seqnn
+from baskerville import trainer
+from baskerville import layers
+from baskerville import adapters
+
+def param_count(layer, type='all'):
+    if type not in ['all','trainable','non_trainable']:
+        raise ValueError("type must be one of 'all', 'trainable', 'non_trainable'")
+    output = 0
+    if type=='all':
+        output = int(sum(tf.keras.backend.count_params(w) for w in layer.weights))
+    elif type=='trainable':
+        output = int(sum(tf.keras.backend.count_params(w) for w in layer.trainable_weights))
+    else:
+        output = int(sum(tf.keras.backend.count_params(w) for w in layer.non_trainable_weights))
+    return output
+
+def param_summary(model):
+    trainable = param_count(model, type='trainable')
+    non_trainable = param_count(model, type='non_trainable')
+    print('total params: %d' % (trainable + non_trainable))
+    print('trainable params: %d' % trainable)
+    print('non-trainable params: %d' % non_trainable)
+
+def keras2dict(model):
+    layer_parent_dict = {} # the parent layers of each layer in the old graph 
+    for layer in model.layers:
+        for node in layer._outbound_nodes:
+            layer_name = node.outbound_layer.name
+            if layer_name not in layer_parent_dict:
+                layer_parent_dict.update({layer_name: [layer.name]})
+            else:
+                if layer.name not in layer_parent_dict[layer_name]:
+                    layer_parent_dict[layer_name].append(layer.name)
+    return layer_parent_dict
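+
+# Example (a hedged sketch): for a functional graph in -> conv -> dropout -> add, where
+# `add` also consumes `conv` via a skip connection, keras2dict returns
+#   {"conv": ["in"], "dropout": ["conv"], "add": ["conv", "dropout"]}
+# i.e. each layer name maps to the names of its parent layers.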
+
+# LoRA requires changing the model.h5 weight order;
+# LoCon and IA3 do not modify the model in place.
+def var_reorder(weight_h5):
+    # Assumes weight_h5 was saved with seqnn_model.save().
+    # Check [i.name for i in model.layers[30].weights] for the multihead_attention weight order.
+    # model.load_weights() loads weights sequentially, assuming the h5 weights are in the right order.
+    # Inserting LoRA changes the weight order of the multihead_attention layers.
+    # That order is stored in f['model_weights']['multihead_attention'].attrs, so after saving the
+    # merged model we need to open the weights h5 and fix the attrs of each multihead_attention group.
+    var_init_order = ['r_w_bias:0:0',
+                      'r_r_bias:0:0', 
+                      'q_layer/kernel:0', 
+                      'k_layer/kernel:0',
+                      'v_layer/kernel:0',
+                      'embedding_layer/kernel:0',
+                      'embedding_layer/bias:0',
+                      'r_k_layer/kernel:0']
+
+    f = h5py.File(weight_h5, 'r+')
+    layers = [i for i in list(f['model_weights'].keys()) if 'multihead_attention' in i]
+    for l_name in layers:
+        new_name_order = [l_name+'/'+i for i in var_init_order]
+        f['model_weights'][l_name].attrs.modify(name='weight_names', value=new_name_order)
+    f.close()
+
+
+# Houlsby adapters require an architecture change,
+# so we also need to modify the params json.
+def modify_json(input_json, output_json, adapter, latent=8, se_rank=None, conv_select=None):
+
+    with open(input_json) as params_open:
+        params = json.load(params_open)
+
+    # houlsby
+    if adapter=='adapterHoulsby':
+        params["model"]['adapter']= 'houlsby'
+        params["model"]['adapter_latent']= latent
+
+    # houlsby_se
+    elif adapter=='houlsby_se':
+        params["model"]['adapter']= 'houlsby_se'
+        params["model"]['adapter_latent']= latent
+        params["model"]['se_rank']= se_rank
+        params["model"]['conv_select']= conv_select
+
+    else:
+        raise ValueError("adapter must be adapterHoulsby or houlsby_se")
+        
+    ### output
+    with open(output_json, 'w') as params_open:
+        json.dump(params, params_open, indent=4)
+    
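+# Example (a hedged sketch, with hypothetical file names):
+#   modify_json("params.json", "params_houlsby.json", adapter="adapterHoulsby", latent=8)
+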
+######################
+# add houlsby layers #
+######################
+def add_houlsby(input_model, strand_pair, latent_size=8):
+    # takes a seqnn Keras model as input
+    # returns a new Keras model with Houlsby adapters inserted
+    # only the adapters and layer_norm layers are left trainable
+
+    ##################
+    # houlsby layers #
+    ##################
+    houlsby_layers = []
+    for i in range(len(input_model.layers)-1):
+        layer = input_model.layers[i]
+        next_layer = input_model.layers[i+1]
+        if re.match('dropout', layer.name) and re.match('add', next_layer.name):
+            houlsby_layers += [next_layer.name]
+
+    ###################
+    # construct model #
+    ################### 
+    layer_parent_dict_old = keras2dict(input_model)
+    # remove switch_reverse_layer
+    to_fix = [i for i in layer_parent_dict_old if re.match('switch_reverse', i)]
+    for i in to_fix:
+        del layer_parent_dict_old[i]
+    # create new graph
+    layer_output_dict_new = {} # the output tensor of each layer in the new graph
+    layer_output_dict_new.update({input_model.layers[0].name: input_model.input})    
+    # Iterate over all layers after the input
+    model_outputs = []
+    reverse_bool = None
+
+    for layer in input_model.layers[1:-1]:
+    
+        # parent layers
+        parent_layers = layer_parent_dict_old[layer.name]
+    
+        # layer inputs
+        layer_input = [layer_output_dict_new[parent] for parent in parent_layers]
+        if len(layer_input) == 1: layer_input = layer_input[0]
+    
+        if re.match('stochastic_reverse_complement', layer.name):
+            x, reverse_bool  = layer(layer_input)
+        
+        # insert houlsby layer:
+        elif layer.name in houlsby_layers:
+            print('adapter added before:%s'%layer.name)
+            x = adapters.AdapterHoulsby(latent_size=latent_size)(layer_input[1])
+            x = layer([layer_input[0], x])
+        
+        else:
+            x = layer(layer_input)
+    
+        # save the output tensor of every layer
+        layer_output_dict_new.update({layer.name: x})
+    
+    final = layers.SwitchReverse(strand_pair)([layer_output_dict_new[input_model.layers[-2].name], reverse_bool])
+    model_adapter = tf.keras.Model(inputs=input_model.inputs, outputs=final)
+
+    # set trainable #
+    for l in model_adapter.layers[:-2]: # trunk
+        if re.match('layer_normalization|adapter_houlsby', l.name): 
+            l.trainable = True
+        else:
+            l.trainable = False
+
+    # expected number of trainable params added/unfrozen:
+    params_added = 0
+    for l in model_adapter.layers:
+        if l.name.startswith("adapter_houlsby"): 
+            params_added += param_count(l)
+        elif l.name.startswith("layer_normalization"): 
+            params_added += param_count(l, type='trainable')
+    print('params added/unfrozen by adapter_houlsby: %d'%params_added)
+
+    return model_adapter
+
+###############
+# lora layers #
+###############
+def add_lora(input_model, rank=8, alpha=16, mode='default', report_param=True):
+    # takes seqnn.model as input and modifies it in place
+    # replaces _q_layer and _v_layer in each multihead_attention layer
+    # in 'full' mode, also replaces _k_layer and _embedding_layer
+    if mode not in ['default','full']:
+        raise ValueError("mode must be default or full")
+    
+    for layer in input_model.layers:
+        if re.match('multihead_attention', layer.name):
+            # default loRA
+            layer._q_layer = adapters.Lora(layer._q_layer, rank=rank, alpha=alpha, trainable=True)
+            layer._v_layer = adapters.Lora(layer._v_layer, rank=rank, alpha=alpha, trainable=True)
+            # full loRA
+            if mode=='full':
+                layer._k_layer = adapters.Lora(layer._k_layer, rank=rank, alpha=alpha, trainable=True)
+                layer._embedding_layer = adapters.Lora(layer._embedding_layer, rank=rank, alpha=alpha, trainable=True)
+    
+    input_model(input_model.input) # initialize new variables 
+
+    # freeze all params but lora
+    for layer in input_model._flatten_layers():
+        lst_of_sublayers = list(layer._flatten_layers())
+        if len(lst_of_sublayers) == 1: 
+            if layer.name in ["lora_a", "lora_b"]:
+                layer.trainable = True
+            else:
+                layer.trainable = False
+
+    ### bias terms need to be frozen separately 
+    for layer in input_model.layers:
+        if re.match('multihead_attention', layer.name):
+            layer._r_w_bias = tf.Variable(layer._r_w_bias, trainable=False, name=layer._r_w_bias.name)
+            layer._r_r_bias = tf.Variable(layer._r_r_bias, trainable=False, name=layer._r_r_bias.name)
+
+    # set final head to be trainable
+    input_model.layers[-2].trainable=True
+
+    # expected number of trainable params added/unfrozen:
+    params_added = 0
+    for l in input_model.layers:
+        if re.match('multihead_attention', l.name):
+            params_added += param_count(l._q_layer.down_layer)
+            params_added += param_count(l._q_layer.up_layer)
+            params_added += param_count(l._v_layer.down_layer)
+            params_added += param_count(l._v_layer.up_layer)
+            if mode=='full':
+                params_added += param_count(l._k_layer.down_layer)
+                params_added += param_count(l._k_layer.up_layer)
+                params_added += param_count(l._embedding_layer.down_layer)
+                params_added += param_count(l._embedding_layer.up_layer)
+
+    if report_param:
+        print('params added/unfrozen by lora: %d'%params_added)
+
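+# Example usage (a hedged sketch; `seqnn_model` is a hypothetical baskerville seqnn.SeqNN
+# whose .model attribute is the underlying Keras model):
+#   add_lora(seqnn_model.model, rank=8, alpha=16, mode='default')
+#   # ... fine-tune ...
+#   merge_lora(seqnn_model.model)  # fold the LoRA updates back into the kernels
+#   # after saving, var_reorder(<weights_h5_path>) fixes the attention weight order
+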
+###############
+# lora layers #
+###############
+def add_lora_conv(input_model, conv_select=None):
+
+    # add lora layers
+    add_lora(input_model, rank=8, alpha=16, mode='default', report_param=False)
+
+    # list all conv layers
+    conv_layers = []
+    for layer in input_model.layers:
+        if re.match('conv1d', layer.name):
+            conv_layers += [layer.name]
+    if conv_select is None: 
+        conv_select = len(conv_layers)
+    if conv_select > len(conv_layers):
+        raise ValueError("conv_select must be less than or equal to the number of conv layers (%d)." % len(conv_layers))
+
+    # set conv layers trainable
+    trainable_conv = conv_layers[-conv_select:]    
+    for layer in input_model.layers:
+        if layer.name in trainable_conv:
+            layer.trainable=True
+    
+    # expected number of trainable params added/unfrozen:
+    params_added = 0
+    for l in input_model.layers:
+        if re.match('multihead_attention', l.name):
+            params_added += param_count(l._q_layer.down_layer)
+            params_added += param_count(l._q_layer.up_layer)
+            params_added += param_count(l._v_layer.down_layer)
+            params_added += param_count(l._v_layer.up_layer)
+        elif l.name in trainable_conv:
+            params_added += param_count(l)
+
+    print('params added/unfrozen by lora_conv: %d'%params_added)
+
+# merge lora weights #
+def merge_lora_layer(lora_layer):
+    down_weights = lora_layer.down_layer.kernel
+    up_weights = lora_layer.up_layer.kernel
+    increment_weights = tf.einsum("ab,bc->ac", down_weights, up_weights) * lora_layer.scale
+    lora_layer.original_layer.kernel.assign_add(increment_weights)
+    return lora_layer.original_layer
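+
+# Merging note (a hedged sketch of the math): for a Dense layer with kernel W of shape
+# (in, out), merge_lora_layer folds the adapter into the original weights in place:
+#   W_merged = W + (alpha / rank) * A @ B
+# where A = down_layer.kernel (in x rank) and B = up_layer.kernel (rank x out),
+# so the merged model needs no extra layers at inference time.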
+
+def merge_lora(input_model):
+    for layer in input_model.layers:
+        if 'multihead_attention' in layer.name:
+            if isinstance(layer._q_layer, adapters.Lora):
+                layer._q_layer = merge_lora_layer(layer._q_layer)
+            if isinstance(layer._v_layer, adapters.Lora):                
+                layer._v_layer = merge_lora_layer(layer._v_layer)
+            if isinstance(layer._k_layer, adapters.Lora):                
+                layer._k_layer = merge_lora_layer(layer._k_layer)
+            if isinstance(layer._embedding_layer, adapters.Lora):                
+                layer._embedding_layer = merge_lora_layer(layer._embedding_layer)
+    input_model(input_model.input)
+
+
+##############
+# IA3 layers #
+##############
+def add_ia3(input_model, strand_pair):
+    
+    # add to kv layers #
+    for layer in input_model.layers:
+        if re.match('multihead_attention', layer.name):
+            layer._k_layer = adapters.IA3(layer._k_layer, trainable=True)
+            layer._v_layer = adapters.IA3(layer._v_layer, trainable=True)
+    
+    # add to ff layer #
+    # save old graph to dictionary
+    layer_parent_dict_old = keras2dict(input_model)
+    
+    # remove switch_reverse_layer
+    to_fix = [i for i in layer_parent_dict_old if re.match('switch_reverse', i)]
+    for i in to_fix:
+        del layer_parent_dict_old[i]
+
+    # create new graph
+    layer_output_dict_new = {} # the output tensor of each layer in the new graph
+    layer_output_dict_new.update({input_model.layers[0].name: input_model.input})
+    
+    # Iterate over all layers after the input
+    model_outputs = []
+    reverse_bool = None
+    for layer in input_model.layers[1:-1]:
+    
+        # get layer inputs
+        parent_layers = layer_parent_dict_old[layer.name]
+        layer_input = [layer_output_dict_new[parent] for parent in parent_layers]
+        if len(layer_input) == 1: layer_input = layer_input[0]
+
+        # construct
+        if re.match('stochastic_reverse_complement', layer.name):
+            x, reverse_bool  = layer(layer_input)
+        # transformer ff down-project layer (1536 -> 768):
+        elif re.match('dense', layer.name) and layer.input_shape[-1]==1536:
+            x = adapters.IA3_ff(layer, trainable=True)(layer_input)
+        else:
+            x = layer(layer_input)
+    
+        # save layers to dictionary
+        layer_output_dict_new.update({layer.name: x})
+    
+    final = layers.SwitchReverse(strand_pair)([layer_output_dict_new[input_model.layers[-2].name], reverse_bool])
+    model_adapter = tf.keras.Model(inputs=input_model.inputs, outputs=final)
+
+    # set trainable #
+    for layer in model_adapter._flatten_layers():
+        lst_of_sublayers = list(layer._flatten_layers())
+        if len(lst_of_sublayers) == 1: 
+            if layer.name in ['ia3', 'ia3_ff']:
+                layer.trainable = True
+            else:
+                layer.trainable = False
+            
+    ### bias terms need to be frozen separately 
+    for layer in model_adapter.layers:
+        if re.match('multihead_attention', layer.name):
+            layer._r_w_bias = tf.Variable(layer._r_w_bias, trainable=False, name=layer._r_w_bias.name)
+            layer._r_r_bias = tf.Variable(layer._r_r_bias, trainable=False, name=layer._r_r_bias.name)
+
+    # set final head to be trainable
+    model_adapter.layers[-2].trainable=True
+
+    # expected number of trainable params added/unfrozen:
+    params_added = 0
+    for l in model_adapter.layers:
+        if re.match('multihead_attention', l.name): # kv layers
+            params_added += param_count(l._k_layer._ia3_layer)
+            params_added += param_count(l._v_layer._ia3_layer)
+        elif re.match('dense', l.name) and l.input_shape[-1]==1536: # ff layers
+            params_added += param_count(l._ia3_layer)
+    
+    print('params added/unfrozen by ia3: %d'%params_added)
+    
+    return model_adapter
+
+def merge_ia3(original_model, ia3_model):
+    # original model contains pre-trained weights
+    # ia3 model is the fine-tuned ia3 model
+    for i, layer in enumerate(original_model.layers):
+        # attention layers
+        if re.match('multihead_attention', layer.name):
+            # scale k
+            k_scaler = ia3_model.layers[i]._k_layer._ia3_layer.kernel[0]
+            layer._k_layer.kernel.assign(layer._k_layer.kernel * k_scaler)
+            # scale v
+            v_scaler = ia3_model.layers[i]._v_layer._ia3_layer.kernel[0]
+            layer._v_layer.kernel.assign(layer._v_layer.kernel * v_scaler)
+        # ff layers
+        elif re.match('dense', layer.name) and layer.input_shape[-1]==1536:
+            ff_scaler = tf.expand_dims(ia3_model.layers[i]._ia3_layer.kernel[0], 1)
+            layer.kernel.assign(layer.kernel * ff_scaler)
+        # other layers
+        else:
+            layer.set_weights(ia3_model.layers[i].get_weights())
+
+#############
+# add locon #
+#############
+def add_locon(input_model, strand_pair, conv_select=None, rank=4, alpha=1):
+
+    # first add lora to attention
+    add_lora(input_model, report_param=False)
+    
+    # decide:
+    # 1. whether conv1 is trainable
+    # 2. which conv layers to add loRA
+    
+    # all conv layers
+    conv_layers = []
+    for layer in input_model.layers:
+        if re.match('conv1d', layer.name):
+            conv_layers += [layer.name]
+
+    if conv_select is None: 
+        conv_select = len(conv_layers)
+        
+    if conv_select > len(conv_layers):
+        raise ValueError("conv_select must be less than or equal to the number of conv layers (%d)." % len(conv_layers))
+
+    locon_layers = []
+    conv1_tune = False
+    if conv_select == len(conv_layers):
+        locon_layers = conv_layers[1:]
+        conv1_tune = True
+    else:
+        locon_layers = conv_layers[-conv_select:]
+        
+    layer_parent_dict_old = keras2dict(input_model)
+    
+    # remove switch_reverse_layer
+    to_fix = [i for i in layer_parent_dict_old if re.match('switch_reverse', i)]
+    for i in to_fix:
+        del layer_parent_dict_old[i]
+
+    # create new graph
+    layer_output_dict_new = {} # the output tensor of each layer in the new graph
+    layer_output_dict_new.update({input_model.layers[0].name: input_model.input})
+    
+    # Iterate over all layers after the input
+    model_outputs = []
+    reverse_bool = None
+    for layer in input_model.layers[1:-1]:
+    
+        # get layer inputs
+        parent_layers = layer_parent_dict_old[layer.name]
+        layer_input = [layer_output_dict_new[parent] for parent in parent_layers]
+        if len(layer_input) == 1: layer_input = layer_input[0]
+
+        # construct
+        if re.match('stochastic_reverse_complement', layer.name):
+            x, reverse_bool  = layer(layer_input)
+        elif layer.name in locon_layers:
+            x = adapters.Locon(layer, trainable=True, rank=rank, alpha=alpha)(layer_input)
+        else:
+            x = layer(layer_input)
+    
+        # save layers to dictionary
+        layer_output_dict_new.update({layer.name: x})
+    
+    final = layers.SwitchReverse(strand_pair)([layer_output_dict_new[input_model.layers[-2].name], reverse_bool])
+    model_adapter = tf.keras.Model(inputs=input_model.inputs, outputs=final)
+
+    if conv1_tune:
+        model_adapter.get_layer(name=conv_layers[0]).trainable = True
+
+    # expected number of trainable params added/unfrozen:
+    params_added = 0
+    if conv1_tune:
+        params_added += param_count(model_adapter.get_layer(name=conv_layers[0]))
+    for l in model_adapter.layers:
+        if re.match('multihead_attention', l.name):
+            params_added += param_count(l._q_layer.down_layer)
+            params_added += param_count(l._q_layer.up_layer)
+            params_added += param_count(l._v_layer.down_layer)
+            params_added += param_count(l._v_layer.up_layer)
+        if l.name in locon_layers:
+            params_added += param_count(l.down_layer)
+            params_added += param_count(l.up_layer)
+
+    print('params added/unfrozen by locon: %d'%params_added)
+
+    return model_adapter
+
+#### functions to merge locon
+def lora_increment(layer):
+    down_weights = layer.down_layer.kernel
+    up_weights = layer.up_layer.kernel
+    increment_weights = tf.einsum("ab,bc->ac", down_weights, up_weights) * layer.scale
+    return increment_weights
+
+def locon_increment(layer):
+    down_weights = layer.down_layer.kernel
+    up_weights = layer.up_layer.kernel[0]
+    increment_weights = tf.einsum("abc,cd->abd", down_weights, up_weights) * layer.scale
+    return increment_weights
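+
+# Shape note (hedged): Conv1D kernels have shape (kernel_size, in_channels, filters), so
+# down_layer.kernel is (kernel_size, in, rank) and up_layer.kernel is (1, rank, out) because
+# the up-projection uses kernel_size=1. Taking up_layer.kernel[0] gives a (rank, out) matrix,
+# and the einsum yields a (kernel_size, in, out) increment matching the original conv kernel.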
+
+def merge_locon(original_model, locon_model):
+    # original model contains pre-trained weights
+    for i, layer in enumerate(original_model.layers):
+        
+        # lora layers
+        if re.match('multihead_attention', layer.name):
+            q = locon_model.layers[i]._q_layer
+            k = locon_model.layers[i]._k_layer
+            v = locon_model.layers[i]._v_layer
+            e = locon_model.layers[i]._embedding_layer                
+            if isinstance(q, adapters.Lora):
+                increment_weights = lora_increment(q)
+                layer._q_layer.kernel.assign_add(increment_weights)
+            if isinstance(v, adapters.Lora):
+                increment_weights = lora_increment(v)
+                layer._v_layer.kernel.assign_add(increment_weights)
+            if isinstance(k, adapters.Lora):
+                increment_weights = lora_increment(k)
+                layer._k_layer.kernel.assign_add(increment_weights)
+            if isinstance(e, adapters.Lora):
+                increment_weights = lora_increment(e)
+                layer._embedding_layer.kernel.assign_add(increment_weights)
+        
+        # locon layers
+        elif isinstance(locon_model.layers[i], adapters.Locon):
+            increment_weights = locon_increment(locon_model.layers[i])
+            layer.kernel.assign_add(increment_weights)
+            
+        else:
+            layer.set_weights(locon_model.layers[i].get_weights())
+
+            
+##############
+# houlsby_se #
+##############
+def add_houlsby_se(input_model, strand_pair, houlsby_latent=8, conv_select=None, se_rank=16):
+    # add Houlsby adapters and squeeze-excitation blocks after the conv layers
+    # input_model should be properly frozen
+    # conv_select: number of conv1d layers (counting from the last) to receive an SE block
+    # se_rank: bottleneck rank passed to the SqueezeExcite block
+
+    ##################
+    # houlsby layers #
+    ##################
+    houlsby_layers = []
+    for i in range(len(input_model.layers)-1):
+        layer = input_model.layers[i]
+        next_layer = input_model.layers[i+1]
+        if re.match('dropout', layer.name) and re.match('add', next_layer.name):
+            houlsby_layers += [next_layer.name]
+
+    #############
+    # SE layers #
+    #############
+    conv_layers = []
+    for layer in input_model.layers:
+        if re.match('conv1d', layer.name):
+            conv_layers += [layer.name]
+    if conv_select is None:
+        conv_select = len(conv_layers) - 1  # default: all conv layers except the first
+    if conv_select >= len(conv_layers):
+        raise ValueError("conv_select must be less than the number of conv layers (%d)." % len(conv_layers))
+    se_layers = conv_layers[-conv_select:]
+
+    ###################
+    # construct model #
+    ###################
+    layer_parent_dict_old = keras2dict(input_model)
+    # remove switch_reverse_layer
+    to_fix = [i for i in layer_parent_dict_old if re.match('switch_reverse', i)]
+    for i in to_fix:
+        del layer_parent_dict_old[i]
+    # create new graph
+    layer_output_dict_new = {} # the output tensor of each layer in the new graph
+    layer_output_dict_new.update({input_model.layers[0].name: input_model.input})    
+    # Iterate over all layers after the input
+    model_outputs = []
+    reverse_bool = None
+    
+    for layer in input_model.layers[1:-1]:
+    
+        # parent layers
+        parent_layers = layer_parent_dict_old[layer.name]
+    
+        # layer inputs
+        layer_input = [layer_output_dict_new[parent] for parent in parent_layers]
+        if len(layer_input) == 1: layer_input = layer_input[0]
+
+        if layer.name.startswith("stochastic_reverse_complement"):
+            x, reverse_bool  = layer(layer_input)
+        
+        # insert houlsby layer:
+        elif layer.name in houlsby_layers:
+            print('adapter added before:%s'%layer.name)
+            x = adapters.AdapterHoulsby(latent_size=houlsby_latent)(layer_input[1])
+            x = layer([layer_input[0], x])
+
+        # insert squeeze-excite layer:
+        elif layer.name in se_layers:
+            se_layer = layers.SqueezeExcite(
+                activation=None, # no activation before squeezing
+                additive=False, # use sigmoid multiplicative scaling
+                rank=se_rank, # bottleneck ratio
+                use_bias=False, # ignore bias
+                kernel_initializer=tf.keras.initializers.TruncatedNormal(stddev=1e-3), # near-zero weight initialization
+                scale_fun='tanh'
+            )
+            x = layer(layer_input)
+            x = x + se_layer(x)
+                
+        else:
+            x = layer(layer_input)
+    
+        # save the output tensor of every layer
+        layer_output_dict_new.update({layer.name: x})
+    
+    final = layers.SwitchReverse(strand_pair)([layer_output_dict_new[input_model.layers[-2].name], reverse_bool])
+    model_final = tf.keras.Model(inputs=input_model.inputs, outputs=final)
+
+    # set trainable
+    for l in model_final.layers[:-2]: # trunk
+        if re.match('layer_normalization|adapter_houlsby', l.name): 
+            l.trainable = True
+        else:
+            l.trainable = False
+
+    for l in model_final.layers: # set SE blocks trainable
+        if l.name.startswith("squeeze_excite"): l.trainable = True
+
+    # expected number of trainable params added/unfrozen:
+    params_added = 0
+    for l in model_final.layers:
+        if  re.match('squeeze_excite|adapter_houlsby', l.name):
+            params_added += param_count(l)
+        elif l.name.startswith("layer_normalization"): 
+            params_added += param_count(l, type='trainable')
+    print('params added/unfrozen by houlsby_se: %d'%params_added)
+    
+    return model_final
+
diff --git a/src/baskerville/pygene.py b/src/baskerville/pygene.py
deleted file mode 100755
index 86cae4f..0000000
--- a/src/baskerville/pygene.py
+++ /dev/null
@@ -1,324 +0,0 @@
-#!/usr/bin/env python
-from optparse import OptionParser
-
-import gzip
-import pdb
-
-'''
-pygene
-
-Classes and methods to manage genes in GTF format.
-'''
-
-################################################################################
-# Classes
-################################################################################
-class GenomicInterval:
-  def __init__(self, start, end, chrom=None, strand=None):
-    self.start = start
-    self.end = end
-    self.chrom = chrom
-    self.strand = strand
-
-  def __eq__(self, other):
-    return self.start == other.start
-
-  def __lt__(self, other):
-    return self.start < other.start
-
-  def __cmp__(self, x):
-    if self.start < x.start:
-      return -1
-    elif self.start > x.start:
-      return 1
-    else:
-      return 0
-
-  def __str__(self):
-    if self.chrom is None:
-      label = '[%d-%d]' % (self.start, self.end)
-    else:
-      label =  '%s:%d-%d' % (self.chrom, self.start, self.end)
-    return label
-
-
-class Transcript:
-  def __init__(self, chrom, strand, kv):
-    self.chrom = chrom
-    self.strand = strand
-    self.kv = kv
-    self.exons = []
-    self.cds = []
-    self.utrs3 = []
-    self.utrs5 = []
-    self.sorted = False
-    self.utrs_defined = False
-
-  def add_cds(self, start, end):
-    self.cds.append(GenomicInterval(start,end))
-
-  def add_exon(self, start, end):
-    self.exons.append(GenomicInterval(start,end))
-
-  def define_utrs(self):
-    self.utrs_defined = True
-
-    if len(self.cds) == 0:
-      self.utrs3 = self.exons
-
-    else:
-      assert(self.sorted)
-
-      # reset UTR lists
-      self.utrs5 = []
-      self.utrs3 = []
-
-      # match up exons and CDS
-      ci = 0
-      for ei in range(len(self.exons)):
-        # left initial
-        if self.exons[ei].end < self.cds[ci].start:
-          utr = GenomicInterval(self.exons[ei].start, self.exons[ei].end)
-          if self.strand == '+':
-            self.utrs5.append(utr)
-          else:
-            self.utrs3.append(utr)
-
-        # right initial
-        elif self.cds[ci].end < self.exons[ei].start:
-          utr = GenomicInterval(self.exons[ei].start, self.exons[ei].end)
-          if self.strand == '+':
-            self.utrs3.append(utr)
-          else:
-            self.utrs5.append(utr)
-
-        # overlap
-        else:
-          # left overlap
-          if self.exons[ei].start < self.cds[ci].start:
-            utr = GenomicInterval(self.exons[ei].start, self.cds[ci].start-1)
-            if self.strand == '+':
-              self.utrs5.append(utr)
-            else:
-              self.utrs3.append(utr)
-
-          # right overlap
-          if self.cds[ci].end < self.exons[ei].end:
-            utr = GenomicInterval(self.cds[ci].end+1, self.exons[ei].end)
-            if self.strand == '+':
-              self.utrs3.append(utr)
-            else:
-              self.utrs5.append(utr)
-
-          # increment up to last
-          ci = min(ci+1, len(self.cds)-1)
-
-  def fasta_cds(self, fasta_open, stranded=False):
-    assert(self.sorted)
-    gene_seq = ''
-    for exon in self.cds:
-      exon_seq = fasta_open.fetch(self.chrom, exon.start-1, exon.end)
-      gene_seq += exon_seq
-    if stranded and self.strand == '-':
-      gene_seq = rc(gene_seq)
-    return gene_seq
-
-  def fasta_exons(self, fasta_open, stranded=False):
-    assert(self.sorted)
-    gene_seq = ''
-    for exon in self.exons:
-      exon_seq = fasta_open.fetch(self.chrom, exon.start-1, exon.end)
-      gene_seq += exon_seq
-    if stranded and self.strand == '-':
-      gene_seq = rc(gene_seq)
-    return gene_seq
-
-  def sort_exons(self):
-    self.sorted = True
-    if len(self.exons) > 1:
-      self.exons.sort()
-    if len(self.cds) > 1:
-      self.cds.sort()
-
-  def span(self):
-    exon_starts = [exon.start for exon in self.exons]
-    exon_ends = [exon.end for exon in self.exons]
-    return min(exon_starts), max(exon_ends)
-
-  def tss(self):
-    if self.strand == '-':
-      return self.exons[-1].end
-    else:
-      return self.exons[0].start
-
-  def write_gtf(self, gtf_out, write_cds=False, write_utrs=False):
-    for ex in self.exons:
-      cols = [self.chrom, 'pygene', 'exon', str(ex.start), str(ex.end)]
-      cols += ['.', self.strand, '.', kv_gtf(self.kv)]
-      print('\t'.join(cols), file=gtf_out)
-    if write_cds:
-      for cds in self.cds:
-        cols = [self.chrom, 'pygene', 'CDS', str(cds.start), str(cds.end)]
-        cols += ['.', self.strand, '.', kv_gtf(self.kv)]
-        print('\t'.join(cols), file=gtf_out)
-    if write_utrs:
-      assert(self.utrs_defined)
-      for utr in self.utrs5:
-        cols = [self.chrom, 'pygene', '5\'UTR', str(utr.start), str(utr.end)]
-        cols += ['.', self.strand, '.', kv_gtf(self.kv)]
-        print('\t'.join(cols), file=gtf_out)
-      for utr in self.utrs3:
-        cols = [self.chrom, 'pygene', '3\'UTR', str(utr.start), str(utr.end)]
-        cols += ['.', self.strand, '.', kv_gtf(self.kv)]
-        print('\t'.join(cols), file=gtf_out)
-
-  def __str__(self):
-    return '%s %s %s %s' % (self.chrom, self.strand, kv_gtf(self.kv), ','.join([ex.__str__() for ex in self.exons]))
-
-
-class Gene:
-  def __init__(self):
-    self.transcripts = {}
-    self.chrom = None
-    self.strand = None
-    self.start = None
-    self.end = None
-
-  def add_transcript(self, tx_id, tx):
-    self.transcripts[tx_id] = tx
-    self.chrom = tx.chrom
-    self.strand = tx.strand
-    self.kv = tx.kv
-
-  def span(self):
-    tx_spans = [tx.span() for tx in self.transcripts.values()]
-    tx_starts, tx_ends = zip(*tx_spans)
-    self.start = min(tx_starts)
-    self.end = max(tx_ends)
-    return self.start, self.end
-
-
-class GTF:
-  def __init__(self, gtf_file, trim_dot=False):
-    self.gtf_file = gtf_file
-    self.genes = {}
-    self.transcripts = {}
-    self.utrs_defined = False
-    self.trim_dot = trim_dot
-
-    self.read_gtf()
-
-  def define_utrs(self):
-    self.utrs_defined = True
-    for tx in self.transcripts.values():
-      tx.define_utrs()
-
-  def read_gtf(self):
-    if self.gtf_file[-3:] == '.gz':
-      gtf_in = gzip.open(self.gtf_file, 'rt')
-    else:   
-      gtf_in = open(self.gtf_file)
-
-    # ignore header
-    line = gtf_in.readline()
-    while line[0] == '#':
-        line = gtf_in.readline()
-
-    while line:
-      a = line.split('\t')
-      if a[2] in ['exon','CDS']:
-        chrom = a[0]
-        interval_type = a[2]
-        start = int(a[3])
-        end = int(a[4])
-        strand = a[6]
-        kv = gtf_kv(a[8])
-
-        # add/get transcript
-        tx_id = kv['transcript_id']
-        if self.trim_dot:
-          tx_id = trim_dot(tx_id)
-        if not tx_id in self.transcripts:
-            self.transcripts[tx_id] = Transcript(chrom, strand, kv)
-        tx = self.transcripts[tx_id]
-
-        # add/get gene
-        gene_id = kv['gene_id']
-        if self.trim_dot:
-          gene_id = trim_dot(gene_id)
-        if not gene_id in self.genes:
-          self.genes[gene_id] = Gene()
-        self.genes[gene_id].add_transcript(tx_id, tx)
-
-        # add exons
-        if interval_type == 'exon':
-          tx.add_exon(start, end)
-        elif interval_type == 'CDS':
-          tx.add_cds(start, end)
-
-      line = gtf_in.readline()
-
-    gtf_in.close()
-
-    # sort transcript exons
-    for tx in self.transcripts.values():
-      tx.sort_exons()
-
-  def write_gtf(self, out_gtf_file, write_cds=False, write_utrs=False):
-    if write_utrs and not self.utrs_defined:
-      self.define_utrs()
-
-    gtf_out = open(out_gtf_file, 'w')
-    for tx in self.transcripts.values():
-      tx.write_gtf(gtf_out, write_cds, write_utrs)
-    gtf_out.close()
-
-
-################################################################################
-# Methods
-################################################################################
-def gtf_kv(s):
-  """Convert the last gtf section of key/value pairs into a dict."""
-  d = {}
-
-  a = s.split(';')
-  for key_val in a:
-    if key_val.strip():
-      eq_i = key_val.find('=')
-      if eq_i != -1 and key_val[eq_i-1] != '"':
-        kvs = key_val.split('=')
-      else:
-        kvs = key_val.split()
-
-      key = kvs[0]
-      if kvs[1][0] == '"' and kvs[-1][-1] == '"':
-        val = (' '.join(kvs[1:]))[1:-1].strip()
-      else:
-        val = (' '.join(kvs[1:])).strip()
-
-      d[key] = val
-
-  return d
-
-def kv_gtf(d):
-  """Convert a kv hash to str gtf representation."""
-  s = ''
-
-  if 'gene_id' in d.keys():
-    s += '%s "%s"; ' % ('gene_id',d['gene_id'])
-
-  if 'transcript_id' in d.keys():
-    s += '%s "%s"; ' % ('transcript_id',d['transcript_id'])
-
-  for key in sorted(d.keys()):
-    if key not in ['gene_id','transcript_id']:
-      s += '%s "%s"; ' % (key,d[key])
-
-  return s
-
-def trim_dot(gene_id):
-  """Trim the final dot suffix off a gene_id."""
-  dot_i = gene_id.rfind('.')
-  if dot_i != -1:
-    gene_id = gene_id[:dot_i]
-  return gene_id
\ No newline at end of file
diff --git a/src/baskerville/scripts/borzoi_test_genes.py b/src/baskerville/scripts/borzoi_test_genes.py
deleted file mode 100755
index 83f1dec..0000000
--- a/src/baskerville/scripts/borzoi_test_genes.py
+++ /dev/null
@@ -1,569 +0,0 @@
-#!/usr/bin/env python
-# Copyright 2021 Calico LLC
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-#     https://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-# =========================================================================
-from optparse import OptionParser
-import gc
-import json
-import os
-import time
-
-from intervaltree import IntervalTree
-import numpy as np
-import pandas as pd
-import pybedtools
-import pyranges as pr
-from qnorm import quantile_normalize
-from scipy.stats import pearsonr
-from sklearn.metrics import explained_variance_score
-from tensorflow.keras import mixed_precision
-
-from baskerville import pygene
-from baskerville import dataset
-from baskerville import seqnn
-
-"""
-borzoi_test_genes.py
-
-Measure accuracy at gene-level.
-"""
-
-################################################################################
-# main
-################################################################################
-def main():
-    usage = "usage: %prog [options] <params_file> <model_file> <data_dir> <genes_gtf>"
-    parser = OptionParser(usage)
-    parser.add_option(
-        "--head",
-        dest="head_i",
-        default=0,
-        type="int",
-        help="Parameters head [Default: %default]",
-    )
-    parser.add_option(
-        "-o",
-        dest="out_dir",
-        default="testg_out",
-        help="Output directory for predictions [Default: %default]",
-    )
-    parser.add_option(
-        "--rc",
-        dest="rc",
-        default=False,
-        action="store_true",
-        help="Average the fwd and rc predictions [Default: %default]",
-    )
-    parser.add_option(
-        "--shifts",
-        dest="shifts",
-        default="0",
-        help="Ensemble prediction shifts [Default: %default]",
-    )
-    parser.add_option(
-        "--span",
-        dest="span",
-        default=False,
-        action="store_true",
-        help="Aggregate entire gene span [Default: %default]",
-    )
-    parser.add_option(
-        "--f16",
-        dest="f16",        
-        default=False,
-        action="store_true",
-        help="use mixed precision for inference",
-    )
-    parser.add_option(
-        "-t",
-        dest="targets_file",
-        default=None,
-        type="str",
-        help="File specifying target indexes and labels in table format",
-    )
-    parser.add_option(
-        "--split",
-        dest="split_label",
-        default="test",
-        help="Dataset split label for eg TFR pattern [Default: %default]",
-    )
-    parser.add_option(
-        "--tfr",
-        dest="tfr_pattern",
-        default=None,
-        help="TFR pattern string appended to data_dir/tfrecords for subsetting [Default: %default]",
-    )
-    parser.add_option(
-        "-u",
-        dest="untransform_old",
-        default=False,
-        action="store_true",
-        help="Untransform old models [Default: %default]",
-    )
-    (options, args) = parser.parse_args()
-
-    if len(args) != 4:
-        parser.error("Must provide parameters, model, data directory, and genes GTF")
-    else:
-        params_file = args[0]
-        model_file = args[1]
-        data_dir = args[2]
-        genes_gtf_file = args[3]
-
-    if not os.path.isdir(options.out_dir):
-        os.mkdir(options.out_dir)
-
-    # parse shifts to integers
-    options.shifts = [int(shift) for shift in options.shifts.split(",")]
-
-    #######################################################
-    # inputs
-
-    # read targets
-    if options.targets_file is None:
-        options.targets_file = "%s/targets.txt" % data_dir
-    targets_df = pd.read_csv(options.targets_file, index_col=0, sep="\t")
-
-    # prep strand
-    targets_strand_df = dataset.targets_prep_strand(targets_df)
-    num_targets = targets_df.shape[0]
-    num_targets_strand = targets_strand_df.shape[0]
-
-    # read model parameters
-    with open(params_file) as params_open:
-        params = json.load(params_open)
-    params_model = params["model"]
-    params_train = params["train"]
-
-    # set strand pairs (using new indexing)
-    orig_new_index = dict(zip(targets_df.index, np.arange(targets_df.shape[0])))
-    targets_strand_pair = np.array(
-        [orig_new_index[ti] for ti in targets_df.strand_pair]
-    )
-    params_model["strand_pair"] = [targets_strand_pair]
-
-    # construct eval data
-    eval_data = dataset.SeqDataset(
-        data_dir,
-        split_label=options.split_label,
-        batch_size=params_train["batch_size"],
-        mode="eval",
-        tfr_pattern=options.tfr_pattern,
-    )
-
-    # initialize model
-    ###################
-    # mixed precision #
-    ###################
-    if options.f16:
-        mixed_precision.set_global_policy('mixed_float16') # first set global policy
-        seqnn_model = seqnn.SeqNN(params_model) # then create model
-        seqnn_model.restore(model_file, options.head_i)
-        seqnn_model.append_activation() # add additional activation to cast float16 output to float32
-    else:
-        # initialize model
-        seqnn_model = seqnn.SeqNN(params_model)
-        seqnn_model.restore(model_file, options.head_i)
-    
-    seqnn_model.build_slice(targets_df.index)
-    seqnn_model.build_ensemble(options.rc, options.shifts)
-
-    #######################################################
-    # sequence intervals
-
-    # read data parameters
-    with open("%s/statistics.json" % data_dir) as data_open:
-        data_stats = json.load(data_open)
-        crop_bp = data_stats["crop_bp"]
-        pool_width = data_stats["pool_width"]
-
-    # read sequence positions
-    seqs_df = pd.read_csv(
-        "%s/sequences.bed" % data_dir,
-        sep="\t",
-        names=["Chromosome", "Start", "End", "Name"],
-    )
-    seqs_df = seqs_df[seqs_df.Name == options.split_label]
-    seqs_pr = pr.PyRanges(seqs_df)
-
-    #######################################################
-    # make gene BED
-
-    t0 = time.time()
-    print("Making gene BED...", end="")
-    genes_bed_file = "%s/genes.bed" % options.out_dir
-    if options.span:
-        make_genes_span(genes_bed_file, genes_gtf_file, options.out_dir)
-    else:
-        make_genes_exon(genes_bed_file, genes_gtf_file, options.out_dir)
-
-    genes_pr = pr.read_bed(genes_bed_file)
-    print("DONE in %ds" % (time.time() - t0))
-
-    # count gene normalization lengths
-    gene_lengths = {}
-    gene_strand = {}
-    for line in open(genes_bed_file):
-        a = line.rstrip().split("\t")
-        gene_id = a[3]
-        gene_seg_len = int(a[2]) - int(a[1])
-        gene_lengths[gene_id] = gene_lengths.get(gene_id, 0) + gene_seg_len
-        gene_strand[gene_id] = a[5]
-
-    #######################################################
-    # intersect genes w/ preds, targets
-
-    # intersect seqs, genes
-    t0 = time.time()
-    print("Intersecting sequences w/ genes...", end="")
-    seqs_genes_pr = seqs_pr.join(genes_pr)
-    print("DONE in %ds" % (time.time() - t0), flush=True)
-
-    # hash preds/targets by gene_id
-    gene_preds_dict = {}
-    gene_targets_dict = {}
-
-    si = 0
-    for x, y in eval_data.dataset:
-        # predict only if gene overlaps
-        yh = None
-        y = y.numpy()[..., targets_df.index]
-
-        t0 = time.time()
-        print("Sequence %d..." % si, end="")
-        for bsi in range(x.shape[0]):
-            seq = seqs_df.iloc[si + bsi]
-
-            cseqs_genes_df = seqs_genes_pr[seq.Chromosome].df
-            if cseqs_genes_df.shape[0] == 0:
-                # empty. no genes on this chromosome
-                seq_genes_df = cseqs_genes_df
-            else:
-                seq_genes_df = cseqs_genes_df[cseqs_genes_df.Start == seq.Start]
-
-            for _, seq_gene in seq_genes_df.iterrows():
-                gene_id = seq_gene.Name_b
-                gene_start = seq_gene.Start_b
-                gene_end = seq_gene.End_b
-                seq_start = seq_gene.Start
-
-                # clip boundaries
-                gene_seq_start = max(0, gene_start - seq_start)
-                gene_seq_end = max(0, gene_end - seq_start)
-
-                # requires >50% overlap
-                bin_start = int(np.round(gene_seq_start / pool_width))
-                bin_end = int(np.round(gene_seq_end / pool_width))
-
-                # predict
-                if yh is None:
-                    yh = seqnn_model(x)
-
-                # slice gene region
-                yhb = yh[bsi, bin_start:bin_end].astype("float16")
-                yb = y[bsi, bin_start:bin_end].astype("float16")
-
-                if len(yb) > 0:
-                    gene_preds_dict.setdefault(gene_id, []).append(yhb)
-                    gene_targets_dict.setdefault(gene_id, []).append(yb)
-
-        # advance sequence table index
-        si += x.shape[0]
-        print("DONE in %ds" % (time.time() - t0), flush=True)
-        if si % 128 == 0:
-            gc.collect()
-
-    # aggregate gene bin values into arrays
-    gene_targets = []
-    gene_preds = []
-    gene_ids = sorted(gene_targets_dict.keys())
-    gene_within = []
-    gene_wvar = []
-
-    for gene_id in gene_ids:
-        gene_preds_gi = np.concatenate(gene_preds_dict[gene_id], axis=0).astype(
-            "float32"
-        )
-        gene_targets_gi = np.concatenate(gene_targets_dict[gene_id], axis=0).astype(
-            "float32"
-        )
-
-        # slice strand
-        if gene_strand[gene_id] == "+":
-            gene_strand_mask = (targets_df.strand != "-").to_numpy()
-        else:
-            gene_strand_mask = (targets_df.strand != "+").to_numpy()
-        gene_preds_gi = gene_preds_gi[:, gene_strand_mask]
-        gene_targets_gi = gene_targets_gi[:, gene_strand_mask]
-
-        if gene_targets_gi.shape[0] == 0:
-            print(gene_id, gene_targets_gi.shape, gene_preds_gi.shape)
-
-        # untransform
-        if options.untransform_old:
-            gene_preds_gi = dataset.untransform_preds1(gene_preds_gi, targets_strand_df)
-            gene_targets_gi = dataset.untransform_preds1(gene_targets_gi, targets_strand_df)
-        else:
-            gene_preds_gi = dataset.untransform_preds(gene_preds_gi, targets_strand_df)
-            gene_targets_gi = dataset.untransform_preds(gene_targets_gi, targets_strand_df)
-
-        # compute within gene correlation before dropping length axis
-        gene_corr_gi = np.zeros(num_targets_strand)
-        for ti in range(num_targets_strand):
-            if (
-                gene_preds_gi[:, ti].var() > 1e-6
-                and gene_targets_gi[:, ti].var() > 1e-6
-            ):
-                preds_log = np.log2(gene_preds_gi[:, ti] + 1)
-                targets_log = np.log2(gene_targets_gi[:, ti] + 1)
-                gene_corr_gi[ti] = pearsonr(preds_log, targets_log)[0]
-                # gene_corr_gi[ti] = pearsonr(gene_preds_gi[:,ti], gene_targets_gi[:,ti])[0]
-            else:
-                gene_corr_gi[ti] = np.nan
-        gene_within.append(gene_corr_gi)
-        gene_wvar.append(gene_targets_gi.var(axis=0))
-
-        # TEMP: save gene preds/targets
-        # os.makedirs('%s/gene_within' % options.out_dir, exist_ok=True)
-        # np.save('%s/gene_within/%s_preds.npy' % (options.out_dir, gene_id), gene_preds_gi.astype('float16'))
-        # np.save('%s/gene_within/%s_targets.npy' % (options.out_dir, gene_id), gene_targets_gi.astype('float16'))
-
-        # mean coverage
-        gene_preds_gi = gene_preds_gi.mean(axis=0)
-        gene_targets_gi = gene_targets_gi.mean(axis=0)
-
-        # scale by gene length
-        gene_preds_gi *= gene_lengths[gene_id]
-        gene_targets_gi *= gene_lengths[gene_id]
-
-        gene_preds.append(gene_preds_gi)
-        gene_targets.append(gene_targets_gi)
-
-    gene_targets = np.array(gene_targets)
-    gene_preds = np.array(gene_preds)
-    gene_within = np.array(gene_within)
-    gene_wvar = np.array(gene_wvar)
-
-    # log2 transform
-    gene_targets = np.log2(gene_targets + 1)
-    gene_preds = np.log2(gene_preds + 1)
-
-    # save values
-    genes_targets_df = pd.DataFrame(
-        gene_targets, index=gene_ids, columns=targets_strand_df.identifier
-    )
-    genes_targets_df.to_csv("%s/gene_targets.tsv" % options.out_dir, sep="\t")
-    genes_preds_df = pd.DataFrame(
-        gene_preds, index=gene_ids, columns=targets_strand_df.identifier
-    )
-    genes_preds_df.to_csv("%s/gene_preds.tsv" % options.out_dir, sep="\t")
-    genes_within_df = pd.DataFrame(
-        gene_within, index=gene_ids, columns=targets_strand_df.identifier
-    )
-    genes_within_df.to_csv("%s/gene_within.tsv" % options.out_dir, sep="\t")
-    genes_var_df = pd.DataFrame(
-        gene_wvar, index=gene_ids, columns=targets_strand_df.identifier
-    )
-    genes_var_df.to_csv("%s/gene_var.tsv" % options.out_dir, sep="\t")
-
-    # quantile and mean normalize
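-    # quantile normalization places targets and predictions on a common distribution,
-    # and subtracting each gene's mean across targets removes overall expression level,
-    # so the "_norm" metrics below emphasize target-specific variation within each gene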
-    gene_targets_norm = quantile_normalize(gene_targets, ncpus=2)
-    gene_targets_norm = gene_targets_norm - gene_targets_norm.mean(
-        axis=-1, keepdims=True
-    )
-    gene_preds_norm = quantile_normalize(gene_preds, ncpus=2)
-    gene_preds_norm = gene_preds_norm - gene_preds_norm.mean(axis=-1, keepdims=True)
-
-    #######################################################
-    # accuracy stats
-
-    wvar_t = np.percentile(gene_wvar, 80, axis=0)
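-    # per-target 80th-percentile threshold on within-gene target variance; the
-    # within-gene correlation below is averaged only over genes above it, i.e.
-    # genes with enough positional signal for the correlation to be meaningful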
-
-    acc_pearsonr = []
-    acc_r2 = []
-    acc_npearsonr = []
-    acc_nr2 = []
-    acc_wpearsonr = []
-    for ti in range(num_targets_strand):
-        r_ti = pearsonr(gene_targets[:, ti], gene_preds[:, ti])[0]
-        acc_pearsonr.append(r_ti)
-        r2_ti = explained_variance_score(gene_targets[:, ti], gene_preds[:, ti])
-        acc_r2.append(r2_ti)
-        nr_ti = pearsonr(gene_targets_norm[:, ti], gene_preds_norm[:, ti])[0]
-        acc_npearsonr.append(nr_ti)
-        nr2_ti = explained_variance_score(
-            gene_targets_norm[:, ti], gene_preds_norm[:, ti]
-        )
-        acc_nr2.append(nr2_ti)
-        var_mask = gene_wvar[:, ti] > wvar_t[ti]
-        # average this target's within-gene correlations over high-variance genes,
-        # ignoring genes whose correlation was set to NaN above
-        wr_ti = np.nanmean(gene_within[var_mask, ti])
-        acc_wpearsonr.append(wr_ti)
-
-    acc_df = pd.DataFrame(
-        {
-            "identifier": targets_strand_df.identifier,
-            "pearsonr": acc_pearsonr,
-            "r2": acc_r2,
-            "pearsonr_norm": acc_npearsonr,
-            "r2_norm": acc_nr2,
-            "pearsonr_gene": acc_wpearsonr,
-            "description": targets_strand_df.description,
-        }
-    )
-    acc_df.to_csv("%s/acc.txt" % options.out_dir, sep="\t")
-
-    print("%d genes" % gene_targets.shape[0])
-    print("Overall PearsonR:     %.4f" % np.mean(acc_df.pearsonr))
-    print("Overall R2:           %.4f" % np.mean(acc_df.r2))
-    print("Normalized PearsonR:  %.4f" % np.mean(acc_df.pearsonr_norm))
-    print("Normalized R2:        %.4f" % np.mean(acc_df.r2_norm))
-    print("Within-gene PearsonR: %.4f" % np.mean(acc_df.pearsonr_gene))
-
-
-def genes_aggregate(genes_bed_file, values_bedgraph):
-    """Aggregate values across genes.
-
-    Args:
-      genes_bed_file (str): BED file of genes.
-      values_bedgraph (str): BedGraph file of values.
-
-    Returns:
-      gene_values (dict): Dictionary of gene values.
-    """
-    values_bt = pybedtools.BedTool(values_bedgraph)
-    genes_bt = pybedtools.BedTool(genes_bed_file)
-
-    gene_values = {}
-
-    for overlap in genes_bt.intersect(values_bt, wo=True):
-        gene_id = overlap[3]
-        value = float(overlap[7])  # pybedtools fields are strings
-        gene_values[gene_id] = gene_values.get(gene_id, 0.0) + value
-
-    return gene_values
-
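-# Hypothetical usage sketch (placeholder file names): summed bedGraph coverage per gene.
-#   gene_cov = genes_aggregate("genes.bed", "coverage.bedgraph")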
-
-def make_genes_exon(genes_bed_file: str, genes_gtf_file: str, out_dir: str):
-    """Make a BED file with each genes' exons, excluding exons overlapping
-      across genes.
-
-    Args:
-      genes_bed_file (str): Output BED file of genes.
-      genes_gtf_file (str): Input GTF file of genes.
-      out_dir (str): Output directory for temporary files.
-    """
-    # read genes
-    genes_gtf = pygene.GTF(genes_gtf_file)
-
-    # write gene exons
-    agenes_bed_file = "%s/genes_all.bed" % out_dir
-    agenes_bed_out = open(agenes_bed_file, "w")
-    for gene_id, gene in genes_gtf.genes.items():
-        # collect exons
-        gene_intervals = IntervalTree()
-        for tx_id, tx in gene.transcripts.items():
-            for exon in tx.exons:
-                gene_intervals[exon.start - 1 : exon.end] = True
-
-        # union
-        gene_intervals.merge_overlaps()
-
-        # write
-        for interval in sorted(gene_intervals):
-            cols = [
-                gene.chrom,
-                str(interval.begin),
-                str(interval.end),
-                gene_id,
-                ".",
-                gene.strand,
-            ]
-            print("\t".join(cols), file=agenes_bed_out)
-    agenes_bed_out.close()
-
-    # find overlapping exons
-    genes1_bt = pybedtools.BedTool(agenes_bed_file)
-    genes2_bt = pybedtools.BedTool(agenes_bed_file)
-    overlapping_exons = set()
-    for overlap in genes1_bt.intersect(genes2_bt, s=True, wo=True):
-        gene1_id = overlap[3]
-        gene1_start = int(overlap[1])
-        gene1_end = int(overlap[2])
-
-        gene2_id = overlap[9]
-        gene2_start = int(overlap[7])
-        gene2_end = int(overlap[8])
-
-        # skip self-overlaps; only mark exons shared across different genes
-        if gene1_id != gene2_id:
-            overlapping_exons.add((gene1_id, gene1_start, gene1_end))
-            overlapping_exons.add((gene2_id, gene2_start, gene2_end))
-
-    # filter for nonoverlapping exons
-    genes_bed_out = open(genes_bed_file, "w")
-    for line in open(agenes_bed_file):
-        a = line.split()
-        start = int(a[1])
-        end = int(a[2])
-        gene_id = a[3]  # BED6 name column
-        if (gene_id, start, end) not in overlapping_exons:
-            print(line, end="", file=genes_bed_out)
-    genes_bed_out.close()
-
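-# Hypothetical usage sketch (placeholder paths): write a BED6 file of per-gene merged
-# exons, keeping only exons that do not overlap exons of other genes.
-#   make_genes_exon("genes_exons.bed", "gencode.gtf", out_dir="eval_out")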
-
-def make_genes_span(
-    genes_bed_file: str, genes_gtf_file: str, out_dir: str, stranded: bool = True
-):
-    """Make a BED file with the span of each gene.
-
-    Args:
-      genes_bed_file (str): Output BED file of genes.
-      genes_gtf_file (str): Input GTF file of genes.
-      out_dir (str): Output directory for temporary files.
-      stranded (bool): Perform stranded intersection.
-    """
-    # read genes
-    genes_gtf = pygene.GTF(genes_gtf_file)
-
-    # write all gene spans
-    agenes_bed_file = "%s/genes_all.bed" % out_dir
-    agenes_bed_out = open(agenes_bed_file, "w")
-    for gene_id, gene in genes_gtf.genes.items():
-        start, end = gene.span()
-        cols = [gene.chrom, str(start - 1), str(end), gene_id, ".", gene.strand]
-        print("\t".join(cols), file=agenes_bed_out)
-    agenes_bed_out.close()
-
-    # find overlapping genes
-    genes1_bt = pybedtools.BedTool(agenes_bed_file)
-    genes2_bt = pybedtools.BedTool(agenes_bed_file)
-    overlapping_genes = set()
-    for overlap in genes1_bt.intersect(genes2_bt, s=stranded, wo=True):
-        gene1_id = overlap[3]
-        gene2_id = overlap[9]  # name column of the second BED6 record
-        if gene1_id != gene2_id:
-            overlapping_genes.add(gene1_id)
-            overlapping_genes.add(gene2_id)
-
-    # filter for nonoverlapping genes
-    genes_bed_out = open(genes_bed_file, "w")
-    for line in open(agenes_bed_file):
-        gene_id = line.split()[3]  # BED6 name column
-        if gene_id not in overlapping_genes:
-            print(line, end="", file=genes_bed_out)
-    genes_bed_out.close()
-
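-# Hypothetical usage sketch (placeholder paths): write a BED6 file of gene spans,
-# keeping only genes whose spans do not overlap another gene (same strand when
-# stranded=True).
-#   make_genes_span("genes_span.bed", "gencode.gtf", out_dir="eval_out", stranded=True)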
-
-################################################################################
-# __main__
-################################################################################
-if __name__ == "__main__":
-    main()