diff --git a/src/baskerville/adapters.py b/src/baskerville/adapters.py new file mode 100644 index 0000000..f05a307 --- /dev/null +++ b/src/baskerville/adapters.py @@ -0,0 +1,301 @@ +# Copyright 2023 Calico LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ========================================================================= +import pdb +import sys +from typing import Optional, List + +import numpy as np +import tensorflow as tf + +gpu_devices = tf.config.experimental.list_physical_devices("GPU") +for device in gpu_devices: + tf.config.experimental.set_memory_growth(device, True) + +##################### +# transfer learning # +##################### +class IA3(tf.keras.layers.Layer): + # https://arxiv.org/pdf/2205.05638.pdf + # ia3 module for attention layer, scale output. + + def __init__(self, + original_layer, + trainable=False, + **kwargs): + + # keep the name of this layer the same as the original dense layer. + original_layer_config = original_layer.get_config() + name = original_layer_config["name"] + kwargs.pop("name", None) + super().__init__(name=name, trainable=trainable, **kwargs) + + self.output_dim = original_layer_config["units"] + + self.original_layer = original_layer + self.original_layer.trainable = False + + # IA3 weights. Make it a dense layer to control trainable + self._ia3_layer = tf.keras.layers.Dense( + units=self.output_dim, + use_bias=False, + kernel_initializer=tf.keras.initializers.Ones(), + trainable=True, + name="ia3" + ) + + def call(self, inputs): + original_output = self.original_layer(inputs) + scaler = self._ia3_layer(tf.constant([[1]], dtype='float64'))[0] + return original_output * scaler + + def get_config(self): + config = super().get_config().copy() + config.update( + { + "size": self.output_dim, + } + ) + return config + +class IA3_ff(tf.keras.layers.Layer): + # https://arxiv.org/pdf/2205.05638.pdf + # ia3 module for down-projection ff layer, scale input. + + def __init__(self, + original_layer, + trainable=False, + **kwargs): + + # keep the name of this layer the same as the original dense layer. + original_layer_config = original_layer.get_config() + name = original_layer_config["name"] + kwargs.pop("name", None) + super().__init__(name=name, trainable=trainable, **kwargs) + + self.input_dim = original_layer.input_shape[-1] + + self.original_layer = original_layer + self.original_layer.trainable = False + + # IA3 weights. 
Make it a dense layer to control trainable + self._ia3_layer = tf.keras.layers.Dense( + units=self.input_dim, + use_bias=False, + kernel_initializer=tf.keras.initializers.Ones(), + trainable=True, + name="ia3_ff" + ) + + def call(self, inputs): + scaler = self._ia3_layer(tf.constant([[1]], dtype='float64'))[0] + return self.original_layer(inputs * scaler) + + def get_config(self): + config = super().get_config().copy() + config.update( + { + "size": self.input_dim + } + ) + return config + +class Lora(tf.keras.layers.Layer): + # adapted from: + # https://arxiv.org/abs/2106.09685 + # https://keras.io/examples/nlp/parameter_efficient_finetuning_of_gpt2_with_lora/ + # https://github.com/Elvenson/stable-diffusion-keras-ft/blob/main/layers.py + + def __init__(self, + original_layer, + rank=8, + alpha=16, + trainable=False, + **kwargs): + + # keep the name of this layer the same as the original dense layer. + original_layer_config = original_layer.get_config() + name = original_layer_config["name"] + kwargs.pop("name", None) + super().__init__(name=name, trainable=trainable, **kwargs) + + self.output_dim = original_layer_config["units"] + + if rank > self.output_dim: + raise ValueError(f"LoRA rank {rank} must be less or equal than {self.output_dim}") + + self.rank = rank + self.alpha = alpha + self.scale = alpha / rank + self.original_layer = original_layer + self.original_layer.trainable = False + + # Note: the original paper mentions that normal distribution was + # used for initialization. However, the official LoRA implementation + # uses "Kaiming/He Initialization". + self.down_layer = tf.keras.layers.Dense( + units=rank, + use_bias=False, + kernel_initializer=tf.keras.initializers.HeUniform(), + #kernel_initializer=tf.keras.initializers.RandomNormal(stddev=1 / self.rank), + trainable=True, + name="lora_a" + ) + + self.up_layer = tf.keras.layers.Dense( + units=self.output_dim, + use_bias=False, + kernel_initializer=tf.keras.initializers.Zeros(), + trainable=True, + name="lora_b" + ) + + def call(self, inputs): + original_output = self.original_layer(inputs) + lora_output = self.up_layer(self.down_layer(inputs)) * self.scale + return original_output + lora_output + + def get_config(self): + config = super().get_config().copy() + config.update( + { + "rank": self.rank, + "alpha": self.alpha + } + ) + return config + +class Locon(tf.keras.layers.Layer): + # LoRA for conv-layer, adapted from: + # https://arxiv.org/pdf/2309.14859#page=23.84 + # https://github.com/KohakuBlueleaf/LyCORIS/blob/main/lycoris/modules/locon.py + # use default alpha and rank for locon + + def __init__(self, + original_layer, + rank=4, + alpha=1, + trainable=False, + **kwargs): + + # keep the name of this layer the same as the original conv layer. 
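+        # (Descriptive note derived from the code below: Locon mirrors Lora for
+        #  Conv1D — `locon_down` is a conv with the original kernel size that
+        #  projects channels down to `rank`, and `locon_up` is a 1x1 conv that
+        #  projects back to `filters`, so the trainable update costs roughly
+        #  rank * (kernel_size * C_in + C_out) weights.)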
+ original_layer_config = original_layer.get_config() + name = original_layer_config["name"] + kwargs.pop("name", None) + super().__init__(name=name, trainable=trainable, **kwargs) + + self.input_dim = original_layer.input_shape[-1] + self.output_dim = original_layer_config["filters"] + + if rank > self.output_dim: + raise ValueError(f"LoRA rank {rank} must be less or equal than {self.output_dim}") + + self.rank = rank + self.alpha = alpha + self.scale = alpha / rank + self.original_layer = original_layer + self.original_layer.trainable = False + + input_dim = original_layer.input_shape[-1] + output_dim = original_layer_config["filters"] + kernel_size = original_layer_config['kernel_size'][0] + stride = original_layer_config['strides'][0] + dilation_rate = original_layer_config["dilation_rate"][0] + + # Note: the original paper mentions that normal distribution was + # used for initialization. However, the official LoRA implementation + # uses "Kaiming/He Initialization". + + self.down_layer = tf.keras.layers.Conv1D( + filters=rank, + kernel_size=kernel_size, + strides=stride, + padding="same", + use_bias=False, + dilation_rate=dilation_rate, + kernel_initializer=tf.keras.initializers.HeUniform(), + name='locon_down' + ) + + self.up_layer = tf.keras.layers.Conv1D( + filters=output_dim, + kernel_size=1, + strides=stride, + padding="same", + use_bias=False, + kernel_initializer=tf.keras.initializers.Zeros(), + name='locon_up' + ) + + def call(self, inputs): + original_output = self.original_layer(inputs) + lora_output = self.up_layer(self.down_layer(inputs)) * self.scale + return original_output + lora_output + + def get_config(self): + config = super().get_config().copy() + config.update( + { + "rank": self.rank, + "alpha": self.alpha + } + ) + return config + +class AdapterHoulsby(tf.keras.layers.Layer): + # https://arxiv.org/abs/1902.00751 + # adapted from: https://github.com/jain-harshil/Adapter-BERT + + def __init__( + self, + latent_size, + activation=tf.keras.layers.ReLU(), + **kwargs): + super(AdapterHoulsby, self).__init__(**kwargs) + self.latent_size = latent_size + self.activation = activation + + def build(self, input_shape): + self.down_project = tf.keras.layers.Dense( + units=self.latent_size, + activation="linear", + kernel_initializer=tf.keras.initializers.TruncatedNormal(stddev=1e-3), + bias_initializer="zeros", + name='adapter_down' + ) + + self.up_project = tf.keras.layers.Dense( + units=input_shape[-1], + activation="linear", + kernel_initializer=tf.keras.initializers.TruncatedNormal(stddev=1e-3), + bias_initializer="zeros", + name='adapter_up' + ) + + def call(self, inputs): + projected_down = self.down_project(inputs) + activated = self.activation(projected_down) + projected_up = self.up_project(activated) + output = projected_up + inputs + return output + + def get_config(self): + config = super().get_config().copy() + config.update( + { + "latent_size": self.latent_size, + "activation": self.activation + } + ) + return config diff --git a/src/baskerville/helpers/transfer.py b/src/baskerville/helpers/transfer.py new file mode 100644 index 0000000..25b80f5 --- /dev/null +++ b/src/baskerville/helpers/transfer.py @@ -0,0 +1,636 @@ +import argparse +import json +import os +import shutil +import re +import h5py + +import numpy as np +import pandas as pd +import tensorflow as tf +from tensorflow.keras import mixed_precision + +from baskerville import dataset +from baskerville import seqnn +from baskerville import trainer +from baskerville import layers +from baskerville 
import adapters + +def param_count(layer, type='all'): + if type not in ['all','trainable','non_trainable']: + raise ValueError("TYPE must be one of all, trainable, non_trainable") + output = 0 + if type=='all': + output = int(sum(tf.keras.backend.count_params(w) for w in layer.weights)) + elif type=='trainable': + output = int(sum(tf.keras.backend.count_params(w) for w in layer.trainable_weights)) + else: + output = int(sum(tf.keras.backend.count_params(w) for w in layer.non_trainable_weights)) + return output + +def param_summary(model): + trainable = param_count(model, type='trainable') + non_trainable = param_count(model, type='non_trainable') + print('total params:%d' %(trainable + non_trainable)) + print('trainable params:%d' %trainable) + print('non-trainable params:%d' %non_trainable) + +def keras2dict(model): + layer_parent_dict = {} # the parent layers of each layer in the old graph + for layer in model.layers: + for node in layer._outbound_nodes: + layer_name = node.outbound_layer.name + if layer_name not in layer_parent_dict: + layer_parent_dict.update({layer_name: [layer.name]}) + else: + if layer.name not in layer_parent_dict[layer_name]: + layer_parent_dict[layer_name].append(layer.name) + return layer_parent_dict + +# lora requires change model.h5 weight order. +# locon and ia3 don't modify model in place. +def var_reorder(weight_h5): + # assumes weight_h5 model saved with seqnn_model.save() + # [i.name for i in model.layers[30].weights] to check for multihead_attention layer weights order. + # model.load_weights() load weights sequencially, assuming h5 weights are in the right order. + # When inserting lora, multihead_attention layer weights order changed. + # multihead_attention layer weights order is saved inside f['model_weights']['multihead_attention'].attrs + # After saving the weight_merged model, we need to go into the weights.h5, and change the attrs in multihead attention. + var_init_order = ['r_w_bias:0:0', + 'r_r_bias:0:0', + 'q_layer/kernel:0', + 'k_layer/kernel:0', + 'v_layer/kernel:0', + 'embedding_layer/kernel:0', + 'embedding_layer/bias:0', + 'r_k_layer/kernel:0'] + + f = h5py.File(weight_h5, 'r+') + layers = [i for i in list(f['model_weights'].keys()) if 'multihead_attention' in i] + for l_name in layers: + new_name_order = [l_name+'/'+i for i in var_init_order] + f['model_weights'][l_name].attrs.modify(name='weight_names', value=new_name_order) + f.close() + + +# houlsby requires architecture change. +# thus we need to modify json. 
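+# Usage sketch (paths and the se_rank/conv_select values are illustrative only):
+#   modify_json("params.json", "params_houlsby.json", adapter="adapterHoulsby", latent=8)
+#   modify_json("params.json", "params_se.json", adapter="houlsby_se",
+#               latent=8, se_rank=16, conv_select=4)
+# i.e. the "model" block gains "adapter"/"adapter_latent" (plus "se_rank" and
+# "conv_select" for houlsby_se); everything else is copied through unchanged.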
+def modify_json(input_json, output_json, adapter, latent=8, se_rank=None, conv_select=None): + + with open(input_json) as params_open: + params = json.load(params_open) + + # houlsby + if adapter=='adapterHoulsby': + params["model"]['adapter']= 'houlsby' + params["model"]['adapter_latent']= latent + + # houlsby_se + elif adapter=='houlsby_se': + params["model"]['adapter']= 'houlsby_se' + params["model"]['adapter_latent']= latent + params["model"]['se_rank']= se_rank + params["model"]['conv_select']= conv_select + + else: + raise ValueError("adapter must be adapterHoulsby or houlsby_se") + + ### output + with open(output_json, 'w') as params_open: + json.dump(params, params_open, indent=4) + +###################### +# add houlsby layers # +###################### +def add_houlsby(input_model, strand_pair, latent_size=8): + # take seqnn_model as input + # output a new seqnn_model object + # only the adapter, and layer_norm are trainable + + ################## + # houlsby layers # + ################## + houlsby_layers = [] + for i in range(len(input_model.layers)-1): + layer = input_model.layers[i] + next_layer = input_model.layers[i+1] + if re.match('dropout', layer.name) and re.match('add', next_layer.name): + houlsby_layers += [next_layer.name] + + ################### + # construct model # + ################### + layer_parent_dict_old = keras2dict(input_model) + # remove switch_reverse_layer + to_fix = [i for i in layer_parent_dict_old if re.match('switch_reverse', i)] + for i in to_fix: + del layer_parent_dict_old[i] + # create new graph + layer_output_dict_new = {} # the output tensor of each layer in the new graph + layer_output_dict_new.update({input_model.layers[0].name: input_model.input}) + # Iterate over all layers after the input + model_outputs = [] + reverse_bool = None + + for layer in input_model.layers[1:-1]: + + # parent layers + parent_layers = layer_parent_dict_old[layer.name] + + # layer inputs + layer_input = [layer_output_dict_new[parent] for parent in parent_layers] + if len(layer_input) == 1: layer_input = layer_input[0] + + if re.match('stochastic_reverse_complement', layer.name): + x, reverse_bool = layer(layer_input) + + # insert houlsby layer: + elif layer.name in houlsby_layers: + print('adapter added before:%s'%layer.name) + x = adapters.AdapterHoulsby(latent_size=latent_size)(layer_input[1]) + x = layer([layer_input[0], x]) + + else: + x = layer(layer_input) + + # save the output tensor of every layer + layer_output_dict_new.update({layer.name: x}) + + final = layers.SwitchReverse(strand_pair)([layer_output_dict_new[input_model.layers[-2].name], reverse_bool]) + model_adapter = tf.keras.Model(inputs=input_model.inputs, outputs=final) + + # set trainable # + for l in model_adapter.layers[:-2]: # trunk + if re.match('layer_normalization|adapter_houlsby', l.name): + l.trainable = True + else: + l.trainable = False + + # expected number of trainable params added/unfrozen: + params_added = 0 + for l in model_adapter.layers: + if l.name.startswith("adapter_houlsby"): + params_added += param_count(l) + elif l.name.startswith("layer_normalization"): + params_added += param_count(l, type='trainable') + print('params added/unfrozen by adapter_houlsby: %d'%params_added) + + return model_adapter + +############### +# lora layers # +############### +def add_lora(input_model, rank=8, alpha=16, mode='default', report_param=True): + # take seqnn.model as input + # replace _q_layer, _v_layer in multihead_attention + # optionally replace _k_layer, _embedding_layer + if mode not 
in ['default','full']: + raise ValueError("mode must be default or full") + + for layer in input_model.layers: + if re.match('multihead_attention', layer.name): + # default loRA + layer._q_layer = adapters.Lora(layer._q_layer, rank=rank, alpha=alpha, trainable=True) + layer._v_layer = adapters.Lora(layer._v_layer, rank=rank, alpha=alpha, trainable=True) + # full loRA + if mode=='full': + layer._k_layer = adapters.Lora(layer._k_layer, rank=rank, alpha=alpha, trainable=True) + layer._embedding_layer = adapters.Lora(layer._embedding_layer, rank=rank, alpha=alpha, trainable=True) + + input_model(input_model.input) # initialize new variables + + # freeze all params but lora + for layer in input_model._flatten_layers(): + lst_of_sublayers = list(layer._flatten_layers()) + if len(lst_of_sublayers) == 1: + if layer.name in ["lora_a", "lora_b"]: + layer.trainable = True + else: + layer.trainable = False + + ### bias terms need to be frozen separately + for layer in input_model.layers: + if re.match('multihead_attention', layer.name): + layer._r_w_bias = tf.Variable(layer._r_w_bias, trainable=False, name=layer._r_w_bias.name) + layer._r_r_bias = tf.Variable(layer._r_r_bias, trainable=False, name=layer._r_r_bias.name) + + # set final head to be trainable + input_model.layers[-2].trainable=True + + # expected number of trainable params added/unfrozen: + params_added = 0 + for l in input_model.layers: + if re.match('multihead_attention', l.name): + params_added += param_count(l._q_layer.down_layer) + params_added += param_count(l._q_layer.up_layer) + params_added += param_count(l._v_layer.down_layer) + params_added += param_count(l._v_layer.up_layer) + if mode=='full': + params_added += param_count(l._k_layer.down_layer) + params_added += param_count(l._k_layer.up_layer) + params_added += param_count(l._embedding_layer.down_layer) + params_added += param_count(l._embedding_layer.up_layer) + + if report_param: + print('params added/unfrozen by lora: %d'%params_added) + +############### +# lora layers # +############### +def add_lora_conv(input_model, conv_select=None): + + # add lora layers + add_lora(input_model, rank=8, alpha=16, mode='default', report_param=False) + + # list all conv layers + conv_layers = [] + for layer in input_model.layers: + if re.match('conv1d', layer.name): + conv_layers += [layer.name] + if conv_select is None: + conv_select = len(conv_layers) + if conv_select > len(conv_layers): + raise ValueError("conv_select must be less than number of conv layers %d."%len(conv_layers)) + + # set conv layers trainable + trainable_conv = conv_layers[-conv_select:] + for layer in input_model.layers: + if layer.name in trainable_conv: + layer.trainable=True + + # expected number of trainable params added/unfrozen: + params_added = 0 + for l in input_model.layers: + if re.match('multihead_attention', l.name): + params_added += param_count(l._q_layer.down_layer) + params_added += param_count(l._q_layer.up_layer) + params_added += param_count(l._v_layer.down_layer) + params_added += param_count(l._v_layer.up_layer) + elif l.name in trainable_conv: + params_added += param_count(l) + + print('params added/unfrozen by lora_conv: %d'%params_added) + +# merge lora weights # +def merge_lora_layer(lora_layer): + down_weights = lora_layer.down_layer.kernel + up_weights = lora_layer.up_layer.kernel + increment_weights = tf.einsum("ab,bc->ac", down_weights, up_weights) * lora_layer.scale + lora_layer.original_layer.kernel.assign_add(increment_weights) + return lora_layer.original_layer + +def 
merge_lora(input_model): + for layer in input_model.layers: + if 'multihead_attention' in layer.name: + if isinstance(layer._q_layer, adapters.Lora): + layer._q_layer = merge_lora_layer(layer._q_layer) + if isinstance(layer._v_layer, adapters.Lora): + layer._v_layer = merge_lora_layer(layer._v_layer) + if isinstance(layer._k_layer, adapters.Lora): + layer._k_layer = merge_lora_layer(layer._k_layer) + if isinstance(layer._embedding_layer, adapters.Lora): + layer._embedding_layer = merge_lora_layer(layer._embedding_layer) + input_model(input_model.input) + + +############## +# IA3 layers # +############## +def add_ia3(input_model, strand_pair): + + # add to kv layers # + for layer in input_model.layers: + if re.match('multihead_attention', layer.name): + layer._k_layer = adapters.IA3(layer._k_layer, trainable=True) + layer._v_layer = adapters.IA3(layer._v_layer, trainable=True) + + # add to ff layer # + # save old graph to dictionary + layer_parent_dict_old = keras2dict(input_model) + + # remove switch_reverse_layer + to_fix = [i for i in layer_parent_dict_old if re.match('switch_reverse', i)] + for i in to_fix: + del layer_parent_dict_old[i] + + # create new graph + layer_output_dict_new = {} # the output tensor of each layer in the new graph + layer_output_dict_new.update({input_model.layers[0].name: input_model.input}) + + # Iterate over all layers after the input + model_outputs = [] + reverse_bool = None + for layer in input_model.layers[1:-1]: + + # get layer inputs + parent_layers = layer_parent_dict_old[layer.name] + layer_input = [layer_output_dict_new[parent] for parent in parent_layers] + if len(layer_input) == 1: layer_input = layer_input[0] + + # construct + if re.match('stochastic_reverse_complement', layer.name): + x, reverse_bool = layer(layer_input) + # transformer ff down-project layer (1536 -> 768): + elif re.match('dense', layer.name) and layer.input_shape[-1]==1536: + x = adapters.IA3_ff(layer, trainable=True)(layer_input) + else: + x = layer(layer_input) + + # save layers to dictionary + layer_output_dict_new.update({layer.name: x}) + + final = layers.SwitchReverse(strand_pair)([layer_output_dict_new[input_model.layers[-2].name], reverse_bool]) + model_adapter = tf.keras.Model(inputs=input_model.inputs, outputs=final) + + # set trainable # + for layer in model_adapter._flatten_layers(): + lst_of_sublayers = list(layer._flatten_layers()) + if len(lst_of_sublayers) == 1: + if layer.name in ['ia3', 'ia3_ff']: + layer.trainable = True + else: + layer.trainable = False + + ### bias terms need to be frozen separately + for layer in model_adapter.layers: + if re.match('multihead_attention', layer.name): + layer._r_w_bias = tf.Variable(layer._r_w_bias, trainable=False, name=layer._r_w_bias.name) + layer._r_r_bias = tf.Variable(layer._r_r_bias, trainable=False, name=layer._r_r_bias.name) + + # set final head to be trainable + model_adapter.layers[-2].trainable=True + + # expected number of trainable params added/unfrozen: + params_added = 0 + for l in model_adapter.layers: + if re.match('multihead_attention', l.name): # kv layers + params_added += param_count(l._k_layer._ia3_layer) + params_added += param_count(l._v_layer._ia3_layer) + elif re.match('dense', l.name) and l.input_shape[-1]==1536: # ff layers + params_added += param_count(l._ia3_layer) + + print('params added/unfrozen by ia3: %d'%params_added) + + return model_adapter + +def merge_ia3(original_model, ia3_model): + # original model contains pre-trained weights + # ia3 model is the fine-tuned ia3 model + for i, layer 
in enumerate(original_model.layers): + # attention layers + if re.match('multihead_attention', layer.name): + # scale k + k_scaler = ia3_model.layers[i]._k_layer._ia3_layer.kernel[0] + layer._k_layer.kernel.assign(layer._k_layer.kernel * k_scaler) + # scale v + v_scaler = ia3_model.layers[i]._v_layer._ia3_layer.kernel[0] + layer._v_layer.kernel.assign(layer._v_layer.kernel * v_scaler) + # ff layers + elif re.match('dense', layer.name) and layer.input_shape[-1]==1536: + ff_scaler = tf.expand_dims(ia3_model.layers[i]._ia3_layer.kernel[0], 1) + layer.kernel.assign(layer.kernel * ff_scaler) + # other layers + else: + layer.set_weights(ia3_model.layers[i].get_weights()) + +############# +# add locon # +############# +def add_locon(input_model, strand_pair, conv_select=None, rank=4, alpha=1): + + # first add lora to attention + add_lora(input_model, report_param=False) + + # decide: + # 1. whether conv1 is trainable + # 2. which conv layers to add loRA + + # all conv layers + conv_layers = [] + for layer in input_model.layers: + if re.match('conv1d', layer.name): + conv_layers += [layer.name] + + if conv_select is None: + conv_select = len(conv_layers) + + if conv_select > len(conv_layers): + raise ValueError("conv_select must be less than number of conv layers %d."%len(conv_layers)) + + locon_layers = [] + conv1_tune = False + if conv_select == len(conv_layers): + locon_layers = conv_layers[1:] + conv1_tune = True + else: + locon_layers = conv_layers[-conv_select:] + + layer_parent_dict_old = keras2dict(input_model) + + # remove switch_reverse_layer + to_fix = [i for i in layer_parent_dict_old if re.match('switch_reverse', i)] + for i in to_fix: + del layer_parent_dict_old[i] + + # create new graph + layer_output_dict_new = {} # the output tensor of each layer in the new graph + layer_output_dict_new.update({input_model.layers[0].name: input_model.input}) + + # Iterate over all layers after the input + model_outputs = [] + reverse_bool = None + for layer in input_model.layers[1:-1]: + + # get layer inputs + parent_layers = layer_parent_dict_old[layer.name] + layer_input = [layer_output_dict_new[parent] for parent in parent_layers] + if len(layer_input) == 1: layer_input = layer_input[0] + + # construct + if re.match('stochastic_reverse_complement', layer.name): + x, reverse_bool = layer(layer_input) + elif layer.name in locon_layers: + x = adapters.Locon(layer, trainable=True, rank=rank, alpha=alpha)(layer_input) + else: + x = layer(layer_input) + + # save layers to dictionary + layer_output_dict_new.update({layer.name: x}) + + final = layers.SwitchReverse(strand_pair)([layer_output_dict_new[input_model.layers[-2].name], reverse_bool]) + model_adapter = tf.keras.Model(inputs=input_model.inputs, outputs=final) + + if conv1_tune: + model_adapter.get_layer(name=conv_layers[0]).trainable = True + + # expected number of trainable params added/unfrozen: + params_added = 0 + if conv1_tune: + params_added += param_count(model_adapter.get_layer(name=conv_layers[0])) + for l in model_adapter.layers: + if re.match('multihead_attention', l.name): + params_added += param_count(l._q_layer.down_layer) + params_added += param_count(l._q_layer.up_layer) + params_added += param_count(l._v_layer.down_layer) + params_added += param_count(l._v_layer.up_layer) + if l.name in locon_layers: + params_added += param_count(l.down_layer) + params_added += param_count(l.up_layer) + + print('params added/unfrozen by lora: %d'%params_added) + + return model_adapter + +#### functions to merge locon +def lora_increment(layer): 
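+    # Computes dW = (A @ B) * (alpha/rank): the (d_in, rank) down kernel times
+    # the (rank, d_out) up kernel, i.e. the dense-kernel increment that
+    # merge_locon later adds onto the frozen attention layer via assign_add.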
+ down_weights = layer.down_layer.kernel + up_weights = layer.up_layer.kernel + increment_weights = tf.einsum("ab,bc->ac", down_weights, up_weights) * layer.scale + return increment_weights + +def locon_increment(layer): + down_weights = layer.down_layer.kernel + up_weights = layer.up_layer.kernel[0] + increment_weights = tf.einsum("abc,cd->abd", down_weights, up_weights) * layer.scale + return increment_weights + +def merge_locon(original_model, locon_model): + # original model contains pre-trained weights + for i, layer in enumerate(original_model.layers): + + # lora layers + if re.match('multihead_attention', layer.name): + q = locon_model.layers[i]._q_layer + k = locon_model.layers[i]._k_layer + v = locon_model.layers[i]._v_layer + e = locon_model.layers[i]._embedding_layer + if isinstance(q, adapters.Lora): + increment_weights = lora_increment(q) + layer._q_layer.kernel.assign_add(increment_weights) + if isinstance(v, adapters.Lora): + increment_weights = lora_increment(v) + layer._v_layer.kernel.assign_add(increment_weights) + if isinstance(k, adapters.Lora): + increment_weights = lora_increment(k) + layer._k_layer.kernel.assign_add(increment_weights) + if isinstance(e, adapters.Lora): + increment_weights = lora_increment(e) + layer._embedding_layer.kernel.assign_add(increment_weights) + + # locon layers + elif isinstance(locon_model.layers[i], adapters.Locon): + increment_weights = locon_increment(locon_model.layers[i]) + layer.kernel.assign_add(increment_weights) + + else: + layer.set_weights(locon_model.layers[i].get_weights()) + + +############## +# houlsby_se # +############## +def add_houlsby_se(input_model, strand_pair, houlsby_latent=8, conv_select=None, se_rank=16): + # add squeeze-excitation blocks after conv + # input_model should be properly frozen + # pre_att: add se_block to pre-attention conv1d + # all: add se_block to pre-attention conv1d and post-attention separable_conv1d + + ################## + # houlsby layers # + ################## + houlsby_layers = [] + for i in range(len(input_model.layers)-1): + layer = input_model.layers[i] + next_layer = input_model.layers[i+1] + if re.match('dropout', layer.name) and re.match('add', next_layer.name): + houlsby_layers += [next_layer.name] + + ############# + # SE layers # + ############# + conv_layers = [] + for layer in input_model.layers: + if re.match('conv1d', layer.name): + conv_layers += [layer.name] + if conv_select is None: + se_layers = conv_layers[1:] + if conv_select >= len(conv_layers): + raise ValueError("conv_select must be less than number of conv layers %d."%len(conv_layers)) + se_layers = conv_layers[-conv_select:] + + ################### + # construct model # + ################### + layer_parent_dict_old = keras2dict(input_model) + # remove switch_reverse_layer + to_fix = [i for i in layer_parent_dict_old if re.match('switch_reverse', i)] + for i in to_fix: + del layer_parent_dict_old[i] + # create new graph + layer_output_dict_new = {} # the output tensor of each layer in the new graph + layer_output_dict_new.update({input_model.layers[0].name: input_model.input}) + # Iterate over all layers after the input + model_outputs = [] + reverse_bool = None + + for layer in input_model.layers[1:-1]: + + # parent layers + parent_layers = layer_parent_dict_old[layer.name] + + # layer inputs + layer_input = [layer_output_dict_new[parent] for parent in parent_layers] + if len(layer_input) == 1: layer_input = layer_input[0] + + if layer.name.startswith("stochastic_reverse_complement"): + x, reverse_bool = 
layer(layer_input) + + # insert houlsby layer: + elif layer.name in houlsby_layers: + print('adapter added before:%s'%layer.name) + x = adapters.AdapterHoulsby(latent_size=houlsby_latent)(layer_input[1]) + x = layer([layer_input[0], x]) + + # insert squeeze-excite layer: + elif layer.name in se_layers: + se_layer = layers.SqueezeExcite( + activation=None, # no activation before squeezing + additive=False, # use sigmoid multiplicative scaling + rank=se_rank, # bottleneck ratio + use_bias=False, # ignore bias + kernel_initializer=tf.keras.initializers.TruncatedNormal(stddev=1e-3), # near-zero weight initialization + scale_fun='tanh' + ) + x = layer(layer_input) + x = x + se_layer(x) + + else: + x = layer(layer_input) + + # save the output tensor of every layer + layer_output_dict_new.update({layer.name: x}) + + final = layers.SwitchReverse(strand_pair)([layer_output_dict_new[input_model.layers[-2].name], reverse_bool]) + model_final = tf.keras.Model(inputs=input_model.inputs, outputs=final) + + # set trainable + for l in model_final.layers[:-2]: # trunk + if re.match('layer_normalization|adapter_houlsby', l.name): + l.trainable = True + else: + l.trainable = False + + for l in model_final.layers: # set trunk + if l.name.startswith("squeeze_excite"): l.trainable = True + + # expected number of trainable params added/unfrozen: + params_added = 0 + for l in model_final.layers: + if re.match('squeeze_excite|adapter_houlsby', l.name): + params_added += param_count(l) + elif l.name.startswith("layer_normalization"): + params_added += param_count(l, type='trainable') + print('params added/unfrozen by houlsby_se: %d'%params_added) + + return model_final + diff --git a/src/baskerville/pygene.py b/src/baskerville/pygene.py deleted file mode 100755 index 86cae4f..0000000 --- a/src/baskerville/pygene.py +++ /dev/null @@ -1,324 +0,0 @@ -#!/usr/bin/env python -from optparse import OptionParser - -import gzip -import pdb - -''' -pygene - -Classes and methods to manage genes in GTF format. 
-''' - -################################################################################ -# Classes -################################################################################ -class GenomicInterval: - def __init__(self, start, end, chrom=None, strand=None): - self.start = start - self.end = end - self.chrom = chrom - self.strand = strand - - def __eq__(self, other): - return self.start == other.start - - def __lt__(self, other): - return self.start < other.start - - def __cmp__(self, x): - if self.start < x.start: - return -1 - elif self.start > x.start: - return 1 - else: - return 0 - - def __str__(self): - if self.chrom is None: - label = '[%d-%d]' % (self.start, self.end) - else: - label = '%s:%d-%d' % (self.chrom, self.start, self.end) - return label - - -class Transcript: - def __init__(self, chrom, strand, kv): - self.chrom = chrom - self.strand = strand - self.kv = kv - self.exons = [] - self.cds = [] - self.utrs3 = [] - self.utrs5 = [] - self.sorted = False - self.utrs_defined = False - - def add_cds(self, start, end): - self.cds.append(GenomicInterval(start,end)) - - def add_exon(self, start, end): - self.exons.append(GenomicInterval(start,end)) - - def define_utrs(self): - self.utrs_defined = True - - if len(self.cds) == 0: - self.utrs3 = self.exons - - else: - assert(self.sorted) - - # reset UTR lists - self.utrs5 = [] - self.utrs3 = [] - - # match up exons and CDS - ci = 0 - for ei in range(len(self.exons)): - # left initial - if self.exons[ei].end < self.cds[ci].start: - utr = GenomicInterval(self.exons[ei].start, self.exons[ei].end) - if self.strand == '+': - self.utrs5.append(utr) - else: - self.utrs3.append(utr) - - # right initial - elif self.cds[ci].end < self.exons[ei].start: - utr = GenomicInterval(self.exons[ei].start, self.exons[ei].end) - if self.strand == '+': - self.utrs3.append(utr) - else: - self.utrs5.append(utr) - - # overlap - else: - # left overlap - if self.exons[ei].start < self.cds[ci].start: - utr = GenomicInterval(self.exons[ei].start, self.cds[ci].start-1) - if self.strand == '+': - self.utrs5.append(utr) - else: - self.utrs3.append(utr) - - # right overlap - if self.cds[ci].end < self.exons[ei].end: - utr = GenomicInterval(self.cds[ci].end+1, self.exons[ei].end) - if self.strand == '+': - self.utrs3.append(utr) - else: - self.utrs5.append(utr) - - # increment up to last - ci = min(ci+1, len(self.cds)-1) - - def fasta_cds(self, fasta_open, stranded=False): - assert(self.sorted) - gene_seq = '' - for exon in self.cds: - exon_seq = fasta_open.fetch(self.chrom, exon.start-1, exon.end) - gene_seq += exon_seq - if stranded and self.strand == '-': - gene_seq = rc(gene_seq) - return gene_seq - - def fasta_exons(self, fasta_open, stranded=False): - assert(self.sorted) - gene_seq = '' - for exon in self.exons: - exon_seq = fasta_open.fetch(self.chrom, exon.start-1, exon.end) - gene_seq += exon_seq - if stranded and self.strand == '-': - gene_seq = rc(gene_seq) - return gene_seq - - def sort_exons(self): - self.sorted = True - if len(self.exons) > 1: - self.exons.sort() - if len(self.cds) > 1: - self.cds.sort() - - def span(self): - exon_starts = [exon.start for exon in self.exons] - exon_ends = [exon.end for exon in self.exons] - return min(exon_starts), max(exon_ends) - - def tss(self): - if self.strand == '-': - return self.exons[-1].end - else: - return self.exons[0].start - - def write_gtf(self, gtf_out, write_cds=False, write_utrs=False): - for ex in self.exons: - cols = [self.chrom, 'pygene', 'exon', str(ex.start), str(ex.end)] - cols += ['.', 
self.strand, '.', kv_gtf(self.kv)] - print('\t'.join(cols), file=gtf_out) - if write_cds: - for cds in self.cds: - cols = [self.chrom, 'pygene', 'CDS', str(cds.start), str(cds.end)] - cols += ['.', self.strand, '.', kv_gtf(self.kv)] - print('\t'.join(cols), file=gtf_out) - if write_utrs: - assert(self.utrs_defined) - for utr in self.utrs5: - cols = [self.chrom, 'pygene', '5\'UTR', str(utr.start), str(utr.end)] - cols += ['.', self.strand, '.', kv_gtf(self.kv)] - print('\t'.join(cols), file=gtf_out) - for utr in self.utrs3: - cols = [self.chrom, 'pygene', '3\'UTR', str(utr.start), str(utr.end)] - cols += ['.', self.strand, '.', kv_gtf(self.kv)] - print('\t'.join(cols), file=gtf_out) - - def __str__(self): - return '%s %s %s %s' % (self.chrom, self.strand, kv_gtf(self.kv), ','.join([ex.__str__() for ex in self.exons])) - - -class Gene: - def __init__(self): - self.transcripts = {} - self.chrom = None - self.strand = None - self.start = None - self.end = None - - def add_transcript(self, tx_id, tx): - self.transcripts[tx_id] = tx - self.chrom = tx.chrom - self.strand = tx.strand - self.kv = tx.kv - - def span(self): - tx_spans = [tx.span() for tx in self.transcripts.values()] - tx_starts, tx_ends = zip(*tx_spans) - self.start = min(tx_starts) - self.end = max(tx_ends) - return self.start, self.end - - -class GTF: - def __init__(self, gtf_file, trim_dot=False): - self.gtf_file = gtf_file - self.genes = {} - self.transcripts = {} - self.utrs_defined = False - self.trim_dot = trim_dot - - self.read_gtf() - - def define_utrs(self): - self.utrs_defined = True - for tx in self.transcripts.values(): - tx.define_utrs() - - def read_gtf(self): - if self.gtf_file[-3:] == '.gz': - gtf_in = gzip.open(self.gtf_file, 'rt') - else: - gtf_in = open(self.gtf_file) - - # ignore header - line = gtf_in.readline() - while line[0] == '#': - line = gtf_in.readline() - - while line: - a = line.split('\t') - if a[2] in ['exon','CDS']: - chrom = a[0] - interval_type = a[2] - start = int(a[3]) - end = int(a[4]) - strand = a[6] - kv = gtf_kv(a[8]) - - # add/get transcript - tx_id = kv['transcript_id'] - if self.trim_dot: - tx_id = trim_dot(tx_id) - if not tx_id in self.transcripts: - self.transcripts[tx_id] = Transcript(chrom, strand, kv) - tx = self.transcripts[tx_id] - - # add/get gene - gene_id = kv['gene_id'] - if self.trim_dot: - gene_id = trim_dot(gene_id) - if not gene_id in self.genes: - self.genes[gene_id] = Gene() - self.genes[gene_id].add_transcript(tx_id, tx) - - # add exons - if interval_type == 'exon': - tx.add_exon(start, end) - elif interval_type == 'CDS': - tx.add_cds(start, end) - - line = gtf_in.readline() - - gtf_in.close() - - # sort transcript exons - for tx in self.transcripts.values(): - tx.sort_exons() - - def write_gtf(self, out_gtf_file, write_cds=False, write_utrs=False): - if write_utrs and not self.utrs_defined: - self.define_utrs() - - gtf_out = open(out_gtf_file, 'w') - for tx in self.transcripts.values(): - tx.write_gtf(gtf_out, write_cds, write_utrs) - gtf_out.close() - - -################################################################################ -# Methods -################################################################################ -def gtf_kv(s): - """Convert the last gtf section of key/value pairs into a dict.""" - d = {} - - a = s.split(';') - for key_val in a: - if key_val.strip(): - eq_i = key_val.find('=') - if eq_i != -1 and key_val[eq_i-1] != '"': - kvs = key_val.split('=') - else: - kvs = key_val.split() - - key = kvs[0] - if kvs[1][0] == '"' and kvs[-1][-1] == '"': 
- val = (' '.join(kvs[1:]))[1:-1].strip() - else: - val = (' '.join(kvs[1:])).strip() - - d[key] = val - - return d - -def kv_gtf(d): - """Convert a kv hash to str gtf representation.""" - s = '' - - if 'gene_id' in d.keys(): - s += '%s "%s"; ' % ('gene_id',d['gene_id']) - - if 'transcript_id' in d.keys(): - s += '%s "%s"; ' % ('transcript_id',d['transcript_id']) - - for key in sorted(d.keys()): - if key not in ['gene_id','transcript_id']: - s += '%s "%s"; ' % (key,d[key]) - - return s - -def trim_dot(gene_id): - """Trim the final dot suffix off a gene_id.""" - dot_i = gene_id.rfind('.') - if dot_i != -1: - gene_id = gene_id[:dot_i] - return gene_id \ No newline at end of file diff --git a/src/baskerville/scripts/borzoi_test_genes.py b/src/baskerville/scripts/borzoi_test_genes.py deleted file mode 100755 index 83f1dec..0000000 --- a/src/baskerville/scripts/borzoi_test_genes.py +++ /dev/null @@ -1,569 +0,0 @@ -#!/usr/bin/env python -# Copyright 2021 Calico LLC -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# https://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ========================================================================= -from optparse import OptionParser -import gc -import json -import os -import time - -from intervaltree import IntervalTree -import numpy as np -import pandas as pd -import pybedtools -import pyranges as pr -from qnorm import quantile_normalize -from scipy.stats import pearsonr -from sklearn.metrics import explained_variance_score -from tensorflow.keras import mixed_precision - -from baskerville import pygene -from baskerville import dataset -from baskerville import seqnn - -""" -borzoi_test_genes.py - -Measure accuracy at gene-level. 
-""" - -################################################################################ -# main -################################################################################ -def main(): - usage = "usage: %prog [options] <params_file> <model_file> <data_dir> <genes_gtf>" - parser = OptionParser(usage) - parser.add_option( - "--head", - dest="head_i", - default=0, - type="int", - help="Parameters head [Default: %default]", - ) - parser.add_option( - "-o", - dest="out_dir", - default="testg_out", - help="Output directory for predictions [Default: %default]", - ) - parser.add_option( - "--rc", - dest="rc", - default=False, - action="store_true", - help="Average the fwd and rc predictions [Default: %default]", - ) - parser.add_option( - "--shifts", - dest="shifts", - default="0", - help="Ensemble prediction shifts [Default: %default]", - ) - parser.add_option( - "--span", - dest="span", - default=False, - action="store_true", - help="Aggregate entire gene span [Default: %default]", - ) - parser.add_option( - "--f16", - dest="f16", - default=False, - action="store_true", - help="use mixed precision for inference", - ) - parser.add_option( - "-t", - dest="targets_file", - default=None, - type="str", - help="File specifying target indexes and labels in table format", - ) - parser.add_option( - "--split", - dest="split_label", - default="test", - help="Dataset split label for eg TFR pattern [Default: %default]", - ) - parser.add_option( - "--tfr", - dest="tfr_pattern", - default=None, - help="TFR pattern string appended to data_dir/tfrecords for subsetting [Default: %default]", - ) - parser.add_option( - "-u", - dest="untransform_old", - default=False, - action="store_true", - help="Untransform old models [Default: %default]", - ) - (options, args) = parser.parse_args() - - if len(args) != 4: - parser.error("Must provide parameters, model, data directory, and genes GTF") - else: - params_file = args[0] - model_file = args[1] - data_dir = args[2] - genes_gtf_file = args[3] - - if not os.path.isdir(options.out_dir): - os.mkdir(options.out_dir) - - # parse shifts to integers - options.shifts = [int(shift) for shift in options.shifts.split(",")] - - ####################################################### - # inputs - - # read targets - if options.targets_file is None: - options.targets_file = "%s/targets.txt" % data_dir - targets_df = pd.read_csv(options.targets_file, index_col=0, sep="\t") - - # prep strand - targets_strand_df = dataset.targets_prep_strand(targets_df) - num_targets = targets_df.shape[0] - num_targets_strand = targets_strand_df.shape[0] - - # read model parameters - with open(params_file) as params_open: - params = json.load(params_open) - params_model = params["model"] - params_train = params["train"] - - # set strand pairs (using new indexing) - orig_new_index = dict(zip(targets_df.index, np.arange(targets_df.shape[0]))) - targets_strand_pair = np.array( - [orig_new_index[ti] for ti in targets_df.strand_pair] - ) - params_model["strand_pair"] = [targets_strand_pair] - - # construct eval data - eval_data = dataset.SeqDataset( - data_dir, - split_label=options.split_label, - batch_size=params_train["batch_size"], - mode="eval", - tfr_pattern=options.tfr_pattern, - ) - - # initialize model - ################### - # mixed precision # - ################### - if options.f16: - mixed_precision.set_global_policy('mixed_float16') # first set global policy - seqnn_model = seqnn.SeqNN(params_model) # then create model - seqnn_model.restore(model_file, options.head_i) - 
seqnn_model.append_activation() # add additional activation to cast float16 output to float32 - else: - # initialize model - seqnn_model = seqnn.SeqNN(params_model) - seqnn_model.restore(model_file, options.head_i) - - seqnn_model.build_slice(targets_df.index) - seqnn_model.build_ensemble(options.rc, options.shifts) - - ####################################################### - # sequence intervals - - # read data parameters - with open("%s/statistics.json" % data_dir) as data_open: - data_stats = json.load(data_open) - crop_bp = data_stats["crop_bp"] - pool_width = data_stats["pool_width"] - - # read sequence positions - seqs_df = pd.read_csv( - "%s/sequences.bed" % data_dir, - sep="\t", - names=["Chromosome", "Start", "End", "Name"], - ) - seqs_df = seqs_df[seqs_df.Name == options.split_label] - seqs_pr = pr.PyRanges(seqs_df) - - ####################################################### - # make gene BED - - t0 = time.time() - print("Making gene BED...", end="") - genes_bed_file = "%s/genes.bed" % options.out_dir - if options.span: - make_genes_span(genes_bed_file, genes_gtf_file, options.out_dir) - else: - make_genes_exon(genes_bed_file, genes_gtf_file, options.out_dir) - - genes_pr = pr.read_bed(genes_bed_file) - print("DONE in %ds" % (time.time() - t0)) - - # count gene normalization lengths - gene_lengths = {} - gene_strand = {} - for line in open(genes_bed_file): - a = line.rstrip().split("\t") - gene_id = a[3] - gene_seg_len = int(a[2]) - int(a[1]) - gene_lengths[gene_id] = gene_lengths.get(gene_id, 0) + gene_seg_len - gene_strand[gene_id] = a[5] - - ####################################################### - # intersect genes w/ preds, targets - - # intersect seqs, genes - t0 = time.time() - print("Intersecting sequences w/ genes...", end="") - seqs_genes_pr = seqs_pr.join(genes_pr) - print("DONE in %ds" % (time.time() - t0), flush=True) - - # hash preds/targets by gene_id - gene_preds_dict = {} - gene_targets_dict = {} - - si = 0 - for x, y in eval_data.dataset: - # predict only if gene overlaps - yh = None - y = y.numpy()[..., targets_df.index] - - t0 = time.time() - print("Sequence %d..." % si, end="") - for bsi in range(x.shape[0]): - seq = seqs_df.iloc[si + bsi] - - cseqs_genes_df = seqs_genes_pr[seq.Chromosome].df - if cseqs_genes_df.shape[0] == 0: - # empty. 
no genes on this chromosome - seq_genes_df = cseqs_genes_df - else: - seq_genes_df = cseqs_genes_df[cseqs_genes_df.Start == seq.Start] - - for _, seq_gene in seq_genes_df.iterrows(): - gene_id = seq_gene.Name_b - gene_start = seq_gene.Start_b - gene_end = seq_gene.End_b - seq_start = seq_gene.Start - - # clip boundaries - gene_seq_start = max(0, gene_start - seq_start) - gene_seq_end = max(0, gene_end - seq_start) - - # requires >50% overlap - bin_start = int(np.round(gene_seq_start / pool_width)) - bin_end = int(np.round(gene_seq_end / pool_width)) - - # predict - if yh is None: - yh = seqnn_model(x) - - # slice gene region - yhb = yh[bsi, bin_start:bin_end].astype("float16") - yb = y[bsi, bin_start:bin_end].astype("float16") - - if len(yb) > 0: - gene_preds_dict.setdefault(gene_id, []).append(yhb) - gene_targets_dict.setdefault(gene_id, []).append(yb) - - # advance sequence table index - si += x.shape[0] - print("DONE in %ds" % (time.time() - t0), flush=True) - if si % 128 == 0: - gc.collect() - - # aggregate gene bin values into arrays - gene_targets = [] - gene_preds = [] - gene_ids = sorted(gene_targets_dict.keys()) - gene_within = [] - gene_wvar = [] - - for gene_id in gene_ids: - gene_preds_gi = np.concatenate(gene_preds_dict[gene_id], axis=0).astype( - "float32" - ) - gene_targets_gi = np.concatenate(gene_targets_dict[gene_id], axis=0).astype( - "float32" - ) - - # slice strand - if gene_strand[gene_id] == "+": - gene_strand_mask = (targets_df.strand != "-").to_numpy() - else: - gene_strand_mask = (targets_df.strand != "+").to_numpy() - gene_preds_gi = gene_preds_gi[:, gene_strand_mask] - gene_targets_gi = gene_targets_gi[:, gene_strand_mask] - - if gene_targets_gi.shape[0] == 0: - print(gene_id, gene_targets_gi.shape, gene_preds_gi.shape) - - # untransform - if options.untransform_old: - gene_preds_gi = dataset.untransform_preds1(gene_preds_gi, targets_strand_df) - gene_targets_gi = dataset.untransform_preds1(gene_targets_gi, targets_strand_df) - else: - gene_preds_gi = dataset.untransform_preds(gene_preds_gi, targets_strand_df) - gene_targets_gi = dataset.untransform_preds(gene_targets_gi, targets_strand_df) - - # compute within gene correlation before dropping length axis - gene_corr_gi = np.zeros(num_targets_strand) - for ti in range(num_targets_strand): - if ( - gene_preds_gi[:, ti].var() > 1e-6 - and gene_targets_gi[:, ti].var() > 1e-6 - ): - preds_log = np.log2(gene_preds_gi[:, ti] + 1) - targets_log = np.log2(gene_targets_gi[:, ti] + 1) - gene_corr_gi[ti] = pearsonr(preds_log, targets_log)[0] - # gene_corr_gi[ti] = pearsonr(gene_preds_gi[:,ti], gene_targets_gi[:,ti])[0] - else: - gene_corr_gi[ti] = np.nan - gene_within.append(gene_corr_gi) - gene_wvar.append(gene_targets_gi.var(axis=0)) - - # TEMP: save gene preds/targets - # os.makedirs('%s/gene_within' % options.out_dir, exist_ok=True) - # np.save('%s/gene_within/%s_preds.npy' % (options.out_dir, gene_id), gene_preds_gi.astype('float16')) - # np.save('%s/gene_within/%s_targets.npy' % (options.out_dir, gene_id), gene_targets_gi.astype('float16')) - - # mean coverage - gene_preds_gi = gene_preds_gi.mean(axis=0) - gene_targets_gi = gene_targets_gi.mean(axis=0) - - # scale by gene length - gene_preds_gi *= gene_lengths[gene_id] - gene_targets_gi *= gene_lengths[gene_id] - - gene_preds.append(gene_preds_gi) - gene_targets.append(gene_targets_gi) - - gene_targets = np.array(gene_targets) - gene_preds = np.array(gene_preds) - gene_within = np.array(gene_within) - gene_wvar = np.array(gene_wvar) - - # log2 transform - 
gene_targets = np.log2(gene_targets + 1) - gene_preds = np.log2(gene_preds + 1) - - # save values - genes_targets_df = pd.DataFrame( - gene_targets, index=gene_ids, columns=targets_strand_df.identifier - ) - genes_targets_df.to_csv("%s/gene_targets.tsv" % options.out_dir, sep="\t") - genes_preds_df = pd.DataFrame( - gene_preds, index=gene_ids, columns=targets_strand_df.identifier - ) - genes_preds_df.to_csv("%s/gene_preds.tsv" % options.out_dir, sep="\t") - genes_within_df = pd.DataFrame( - gene_within, index=gene_ids, columns=targets_strand_df.identifier - ) - genes_within_df.to_csv("%s/gene_within.tsv" % options.out_dir, sep="\t") - genes_var_df = pd.DataFrame( - gene_wvar, index=gene_ids, columns=targets_strand_df.identifier - ) - genes_var_df.to_csv("%s/gene_var.tsv" % options.out_dir, sep="\t") - - # quantile and mean normalize - gene_targets_norm = quantile_normalize(gene_targets, ncpus=2) - gene_targets_norm = gene_targets_norm - gene_targets_norm.mean( - axis=-1, keepdims=True - ) - gene_preds_norm = quantile_normalize(gene_preds, ncpus=2) - gene_preds_norm = gene_preds_norm - gene_preds_norm.mean(axis=-1, keepdims=True) - - ####################################################### - # accuracy stats - - wvar_t = np.percentile(gene_wvar, 80, axis=0) - - acc_pearsonr = [] - acc_r2 = [] - acc_npearsonr = [] - acc_nr2 = [] - acc_wpearsonr = [] - for ti in range(num_targets_strand): - r_ti = pearsonr(gene_targets[:, ti], gene_preds[:, ti])[0] - acc_pearsonr.append(r_ti) - r2_ti = explained_variance_score(gene_targets[:, ti], gene_preds[:, ti]) - acc_r2.append(r2_ti) - nr_ti = pearsonr(gene_targets_norm[:, ti], gene_preds_norm[:, ti])[0] - acc_npearsonr.append(nr_ti) - nr2_ti = explained_variance_score( - gene_targets_norm[:, ti], gene_preds_norm[:, ti] - ) - acc_nr2.append(nr2_ti) - var_mask = gene_wvar[:, ti] > wvar_t[ti] - wr_ti = gene_within[var_mask].mean() - acc_wpearsonr.append(wr_ti) - - acc_df = pd.DataFrame( - { - "identifier": targets_strand_df.identifier, - "pearsonr": acc_pearsonr, - "r2": acc_r2, - "pearsonr_norm": acc_npearsonr, - "r2_norm": acc_nr2, - "pearsonr_gene": acc_wpearsonr, - "description": targets_strand_df.description, - } - ) - acc_df.to_csv("%s/acc.txt" % options.out_dir, sep="\t") - - print("%d genes" % gene_targets.shape[0]) - print("Overall PearsonR: %.4f" % np.mean(acc_df.pearsonr)) - print("Overall R2: %.4f" % np.mean(acc_df.r2)) - print("Normalized PearsonR: %.4f" % np.mean(acc_df.pearsonr_norm)) - print("Normalized R2: %.4f" % np.mean(acc_df.r2_norm)) - print("Within-gene PearsonR: %.4f" % np.mean(acc_df.pearsonr_gene)) - - -def genes_aggregate(genes_bed_file, values_bedgraph): - """Aggregate values across genes. - - Args: - genes_bed_file (str): BED file of genes. - values_bedgraph (str): BedGraph file of values. - - Returns: - gene_values (dict): Dictionary of gene values. - """ - values_bt = pybedtools.BedTool(values_bedgraph) - genes_bt = pybedtools.BedTool(genes_bed_file) - - gene_values = {} - - for overlap in genes_bt.intersect(values_bt, wo=True): - gene_id = overlap[3] - value = overlap[7] - gene_values[gene_id] = gene_values.get(gene_id, 0) + value - - return gene_values - - -def make_genes_exon(genes_bed_file: str, genes_gtf_file: str, out_dir: str): - """Make a BED file with each genes' exons, excluding exons overlapping - across genes. - - Args: - genes_bed_file (str): Output BED file of genes. - genes_gtf_file (str): Input GTF file of genes. - out_dir (str): Output directory for temporary files. 
- """ - # read genes - genes_gtf = pygene.GTF(genes_gtf_file) - - # write gene exons - agenes_bed_file = "%s/genes_all.bed" % out_dir - agenes_bed_out = open(agenes_bed_file, "w") - for gene_id, gene in genes_gtf.genes.items(): - # collect exons - gene_intervals = IntervalTree() - for tx_id, tx in gene.transcripts.items(): - for exon in tx.exons: - gene_intervals[exon.start - 1 : exon.end] = True - - # union - gene_intervals.merge_overlaps() - - # write - for interval in sorted(gene_intervals): - cols = [ - gene.chrom, - str(interval.begin), - str(interval.end), - gene_id, - ".", - gene.strand, - ] - print("\t".join(cols), file=agenes_bed_out) - agenes_bed_out.close() - - # find overlapping exons - genes1_bt = pybedtools.BedTool(agenes_bed_file) - genes2_bt = pybedtools.BedTool(agenes_bed_file) - overlapping_exons = set() - for overlap in genes1_bt.intersect(genes2_bt, s=True, wo=True): - gene1_id = overlap[3] - gene1_start = int(overlap[1]) - gene1_end = int(overlap[2]) - overlapping_exons.add((gene1_id, gene1_start, gene1_end)) - - gene2_id = overlap[9] - gene2_start = int(overlap[7]) - gene2_end = int(overlap[8]) - overlapping_exons.add((gene2_id, gene2_start, gene2_end)) - - # filter for nonoverlapping exons - genes_bed_out = open(genes_bed_file, "w") - for line in open(agenes_bed_file): - a = line.split() - start = int(a[1]) - end = int(a[2]) - gene_id = a[-1] - if (gene_id, start, end) not in overlapping_exons: - print(line, end="", file=genes_bed_out) - genes_bed_out.close() - - -def make_genes_span( - genes_bed_file: str, genes_gtf_file: str, out_dir: str, stranded: bool = True -): - """Make a BED file with the span of each gene. - - Args: - genes_bed_file (str): Output BED file of genes. - genes_gtf_file (str): Input GTF file of genes. - out_dir (str): Output directory for temporary files. - stranded (bool): Perform stranded intersection. - """ - # read genes - genes_gtf = pygene.GTF(genes_gtf_file) - - # write all gene spans - agenes_bed_file = "%s/genes_all.bed" % out_dir - agenes_bed_out = open(agenes_bed_file, "w") - for gene_id, gene in genes_gtf.genes.items(): - start, end = gene.span() - cols = [gene.chrom, str(start - 1), str(end), gene_id, ".", gene.strand] - print("\t".join(cols), file=agenes_bed_out) - agenes_bed_out.close() - - # find overlapping genes - genes1_bt = pybedtools.BedTool(agenes_bed_file) - genes2_bt = pybedtools.BedTool(agenes_bed_file) - overlapping_genes = set() - for overlap in genes1_bt.intersect(genes2_bt, s=stranded, wo=True): - gene1_id = overlap[3] - gene2_id = overlap[7] - if gene1_id != gene2_id: - overlapping_genes.add(gene1_id) - overlapping_genes.add(gene2_id) - - # filter for nonoverlapping genes - genes_bed_out = open(genes_bed_file, "w") - for line in open(agenes_bed_file): - gene_id = line.split()[-1] - if gene_id not in overlapping_genes: - print(line, end="", file=genes_bed_out) - genes_bed_out.close() - - -################################################################################ -# __main__ -################################################################################ -if __name__ == "__main__": - main()