From 57bbd11294937e086d1a328b3ba33f0e087fb48f Mon Sep 17 00:00:00 2001 From: Zhiyuan Chen Date: Mon, 22 Apr 2024 23:41:04 +0800 Subject: [PATCH] improve token initialisation Signed-off-by: Zhiyuan Chen --- .../models/rnabert/convert_checkpoint.py | 38 ++++----- .../models/rnafm/convert_checkpoint.py | 38 ++++----- .../models/rnamsm/convert_checkpoint.py | 39 ++++------ .../models/splicebert/convert_checkpoint.py | 36 ++++----- .../models/utrbert/convert_checkpoint.py | 36 ++++----- .../models/utrlm/convert_checkpoint.py | 38 ++++----- multimolecule/tokenizers/rna/utils.py | 77 +++++++++++++++++++ 7 files changed, 174 insertions(+), 128 deletions(-) diff --git a/multimolecule/models/rnabert/convert_checkpoint.py b/multimolecule/models/rnabert/convert_checkpoint.py index 7205553c..1783a5ce 100644 --- a/multimolecule/models/rnabert/convert_checkpoint.py +++ b/multimolecule/models/rnabert/convert_checkpoint.py @@ -3,11 +3,15 @@ import chanfig import torch -from torch import nn from multimolecule.models.rnabert import RnaBertConfig as Config from multimolecule.models.rnabert import RnaBertForPretraining as Model -from multimolecule.tokenizers.rna.utils import get_special_tokens_map, get_tokenizer_config, get_vocab_list +from multimolecule.tokenizers.rna.utils import ( + convert_word_embeddings, + get_special_tokens_map, + get_tokenizer_config, + get_vocab_list, +) try: from huggingface_hub import HfApi @@ -33,27 +37,17 @@ def _convert_checkpoint(config, original_state_dict, vocab_list, original_vocab_ continue state_dict[key] = value - state_vocab_size = state_dict["rnabert.embeddings.word_embeddings.weight"].size(0) - original_vocab_size = len(original_vocab_list) - if state_vocab_size != original_vocab_size: - raise ValueError( - f"Vocabulary size do not match. Expected to have {original_vocab_size}, but got {state_vocab_size}." 
- ) - word_embed = nn.Embedding(config.vocab_size, config.hidden_size, padding_idx=config.pad_token_id) - word_embed_weight = word_embed.weight.data - predictions_decoder_weight = torch.zeros((config.vocab_size, config.hidden_size)) - predictions_bias = torch.zeros(config.vocab_size) - # nn.init.normal_(pos_embed.weight, std=0.02) - for original_index, original_token in enumerate(original_vocab_list): - new_index = vocab_list.index(original_token) - word_embed_weight[new_index] = state_dict["rnabert.embeddings.word_embeddings.weight"][original_index] - predictions_decoder_weight[new_index] = state_dict["pretrain_head.predictions.decoder.weight"][original_index] - predictions_bias[new_index] = state_dict["pretrain_head.predictions.bias"][original_index] - state_dict["rnabert.embeddings.word_embeddings.weight"] = word_embed_weight - state_dict["pretrain_head.predictions.decoder.weight"] = predictions_decoder_weight - state_dict["pretrain_head.predictions.decoder.bias"] = state_dict["pretrain_head.predictions.bias"] = ( - predictions_bias + word_embed_weight, decoder_weight, decoder_bias = convert_word_embeddings( + state_dict["rnabert.embeddings.word_embeddings.weight"], + state_dict["pretrain_head.predictions.decoder.weight"], + state_dict["pretrain_head.predictions.bias"], + old_vocab=original_vocab_list, + new_vocab=vocab_list, + std=config.initializer_range, ) + state_dict["rnabert.embeddings.word_embeddings.weight"] = word_embed_weight + state_dict["pretrain_head.predictions.decoder.weight"] = decoder_weight + state_dict["pretrain_head.predictions.decoder.bias"] = state_dict["pretrain_head.predictions.bias"] = decoder_bias state_dict["pretrain_head.predictions_ss.decoder.bias"] = state_dict["pretrain_head.predictions_ss.bias"] return state_dict diff --git a/multimolecule/models/rnafm/convert_checkpoint.py b/multimolecule/models/rnafm/convert_checkpoint.py index e03ff378..25499c57 100644 --- a/multimolecule/models/rnafm/convert_checkpoint.py +++ b/multimolecule/models/rnafm/convert_checkpoint.py @@ -3,11 +3,15 @@ import chanfig import torch -from torch import nn from multimolecule.models import RnaFmConfig as Config from multimolecule.models import RnaFmForPretraining as Model -from multimolecule.tokenizers.rna.utils import get_special_tokens_map, get_tokenizer_config, get_vocab_list +from multimolecule.tokenizers.rna.utils import ( + convert_word_embeddings, + get_special_tokens_map, + get_tokenizer_config, + get_vocab_list, +) try: from huggingface_hub import HfApi @@ -45,27 +49,17 @@ def _convert_checkpoint(config, original_state_dict, vocab_list, original_vocab_ key = key.replace("rnafm.encoder.contact_head", "pretrain_head.contact") state_dict[key] = value - state_vocab_size = state_dict["rnafm.embeddings.word_embeddings.weight"].size(0) - original_vocab_size = len(original_vocab_list) - if state_vocab_size != original_vocab_size: - raise ValueError( - f"Vocabulary size do not match. Expected to have {original_vocab_size}, but got {state_vocab_size}." 
- ) - word_embed = nn.Embedding(config.vocab_size, config.hidden_size, padding_idx=config.pad_token_id) - word_embed_weight = word_embed.weight.data - predictions_decoder_weight = torch.zeros((config.vocab_size, config.hidden_size)) - predictions_bias = torch.zeros(config.vocab_size) - # nn.init.normal_(pos_embed.weight, std=0.02) - for original_index, original_token in enumerate(original_vocab_list): - new_index = vocab_list.index(original_token) - word_embed_weight[new_index] = state_dict["rnafm.embeddings.word_embeddings.weight"][original_index] - predictions_decoder_weight[new_index] = state_dict["pretrain_head.predictions.decoder.weight"][original_index] - predictions_bias[new_index] = state_dict["pretrain_head.predictions.bias"][original_index] - state_dict["rnafm.embeddings.word_embeddings.weight"] = word_embed_weight - state_dict["pretrain_head.predictions.decoder.weight"] = predictions_decoder_weight - state_dict["pretrain_head.predictions.decoder.bias"] = state_dict["pretrain_head.predictions.bias"] = ( - predictions_bias + word_embed_weight, decoder_weight, decoder_bias = convert_word_embeddings( + state_dict["rnafm.embeddings.word_embeddings.weight"], + state_dict["pretrain_head.predictions.decoder.weight"], + state_dict["pretrain_head.predictions.bias"], + old_vocab=original_vocab_list, + new_vocab=vocab_list, + std=config.initializer_range, ) + state_dict["rnafm.embeddings.word_embeddings.weight"] = word_embed_weight + state_dict["pretrain_head.predictions.decoder.weight"] = decoder_weight + state_dict["pretrain_head.predictions.decoder.bias"] = state_dict["pretrain_head.predictions.bias"] = decoder_bias return state_dict diff --git a/multimolecule/models/rnamsm/convert_checkpoint.py b/multimolecule/models/rnamsm/convert_checkpoint.py index 94efe36b..0d181763 100644 --- a/multimolecule/models/rnamsm/convert_checkpoint.py +++ b/multimolecule/models/rnamsm/convert_checkpoint.py @@ -3,11 +3,15 @@ import chanfig import torch -from torch import nn from multimolecule.models import RnaMsmConfig as Config from multimolecule.models import RnaMsmForPretraining as Model -from multimolecule.tokenizers.rna.utils import get_special_tokens_map, get_tokenizer_config, get_vocab_list +from multimolecule.tokenizers.rna.utils import ( + convert_word_embeddings, + get_special_tokens_map, + get_tokenizer_config, + get_vocab_list, +) try: from huggingface_hub import HfApi @@ -29,31 +33,22 @@ def _convert_checkpoint(config, original_state_dict, vocab_list, original_vocab_ key = key.replace("regression", "decoder") key = key.replace("contact_head", "pretrain_head.contact") key = key.replace("lm_head", "pretrain_head.predictions") + key = key.replace("pretrain_head.predictions.weight", "pretrain_head.predictions.decoder.weight") key = key.replace("pretrain_head.predictions.dense", "pretrain_head.predictions.transform.dense") key = key.replace("pretrain_head.predictions.layer_norm", "pretrain_head.predictions.transform.layer_norm") state_dict[key] = value - state_vocab_size = state_dict["rnamsm.embeddings.word_embeddings.weight"].size(0) - original_vocab_size = len(original_vocab_list) - if state_vocab_size != original_vocab_size: - raise ValueError( - f"Vocabulary size do not match. Expected to have {original_vocab_size}, but got {state_vocab_size}." 
- ) - word_embed = nn.Embedding(config.vocab_size, config.hidden_size, padding_idx=config.pad_token_id) - word_embed_weight = word_embed.weight.data - predictions_bias = torch.zeros(config.vocab_size) - # nn.init.normal_(pos_embed.weight, std=0.02) - for original_index, original_token in enumerate(original_vocab_list): - new_index = vocab_list.index(original_token) - word_embed_weight[new_index] = state_dict["rnamsm.embeddings.word_embeddings.weight"][original_index] - predictions_bias[new_index] = state_dict["pretrain_head.predictions.bias"][original_index] - state_dict["rnamsm.embeddings.word_embeddings.weight"] = state_dict["pretrain_head.predictions.decoder.weight"] = ( - word_embed_weight - ) - state_dict["pretrain_head.predictions.decoder.bias"] = state_dict["pretrain_head.predictions.bias"] = ( - predictions_bias + word_embed_weight, decoder_weight, decoder_bias = convert_word_embeddings( + state_dict["rnamsm.embeddings.word_embeddings.weight"], + state_dict["pretrain_head.predictions.decoder.weight"], + state_dict["pretrain_head.predictions.bias"], + old_vocab=original_vocab_list, + new_vocab=vocab_list, + std=config.initializer_range, ) - del state_dict["pretrain_head.predictions.weight"] + state_dict["rnamsm.embeddings.word_embeddings.weight"] = word_embed_weight + state_dict["pretrain_head.predictions.decoder.weight"] = decoder_weight + state_dict["pretrain_head.predictions.decoder.bias"] = state_dict["pretrain_head.predictions.bias"] = decoder_bias return state_dict diff --git a/multimolecule/models/splicebert/convert_checkpoint.py b/multimolecule/models/splicebert/convert_checkpoint.py index 36d83fe8..ef3d298f 100644 --- a/multimolecule/models/splicebert/convert_checkpoint.py +++ b/multimolecule/models/splicebert/convert_checkpoint.py @@ -3,11 +3,15 @@ import chanfig import torch -from torch import nn from multimolecule.models import SpliceBertConfig as Config from multimolecule.models import SpliceBertForPretraining as Model -from multimolecule.tokenizers.rna.utils import get_special_tokens_map, get_tokenizer_config, get_vocab_list +from multimolecule.tokenizers.rna.utils import ( + convert_word_embeddings, + get_special_tokens_map, + get_tokenizer_config, + get_vocab_list, +) try: from huggingface_hub import HfApi @@ -32,25 +36,17 @@ def _convert_checkpoint(config, original_state_dict, vocab_list, original_vocab_ continue state_dict[key] = value - state_vocab_size = state_dict["splicebert.embeddings.word_embeddings.weight"].size(0) - original_vocab_size = len(original_vocab_list) - if state_vocab_size != original_vocab_size: - raise ValueError( - f"Vocabulary size do not match. Expected to have {original_vocab_size}, but got {state_vocab_size}." 
-        )
-    word_embed = nn.Embedding(config.vocab_size, config.hidden_size, padding_idx=config.pad_token_id)
-    word_embed_weight = word_embed.weight.data
-    predictions_decoder_weight = torch.zeros((config.vocab_size, config.hidden_size))
-    predictions_bias = torch.zeros(config.vocab_size)
-    # nn.init.normal_(pos_embed.weight, std=0.02)
-    for original_index, original_token in enumerate(original_vocab_list):
-        new_index = vocab_list.index(original_token)
-        word_embed_weight[new_index] = state_dict["splicebert.embeddings.word_embeddings.weight"][original_index]
-        predictions_decoder_weight[new_index] = state_dict["lm_head.decoder.weight"][original_index]
-        predictions_bias[new_index] = state_dict["lm_head.decoder.bias"][original_index]
+    word_embed_weight, decoder_weight, decoder_bias = convert_word_embeddings(
+        state_dict["splicebert.embeddings.word_embeddings.weight"],
+        state_dict["lm_head.decoder.weight"],
+        state_dict["lm_head.decoder.bias"],
+        old_vocab=original_vocab_list,
+        new_vocab=vocab_list,
+        std=config.initializer_range,
+    )
     state_dict["splicebert.embeddings.word_embeddings.weight"] = word_embed_weight
-    state_dict["lm_head.decoder.weight"] = predictions_decoder_weight
-    state_dict["lm_head.decoder.bias"] = state_dict["lm_head.bias"] = predictions_bias
+    state_dict["lm_head.decoder.weight"] = decoder_weight
+    state_dict["lm_head.decoder.bias"] = state_dict["lm_head.bias"] = decoder_bias
     del state_dict["splicebert.embeddings.position_ids"]

     return state_dict
diff --git a/multimolecule/models/utrbert/convert_checkpoint.py b/multimolecule/models/utrbert/convert_checkpoint.py
index 7d3a1ed1..f2c8e7ae 100644
--- a/multimolecule/models/utrbert/convert_checkpoint.py
+++ b/multimolecule/models/utrbert/convert_checkpoint.py
@@ -3,11 +3,15 @@
 import chanfig
 import torch
-from torch import nn

 from multimolecule.models import UtrBertConfig as Config
 from multimolecule.models import UtrBertForPretraining as Model
-from multimolecule.tokenizers.rna.utils import get_special_tokens_map, get_tokenizer_config, get_vocab_list
+from multimolecule.tokenizers.rna.utils import (
+    convert_word_embeddings,
+    get_special_tokens_map,
+    get_tokenizer_config,
+    get_vocab_list,
+)

 try:
     from huggingface_hub import HfApi
@@ -32,25 +36,17 @@ def _convert_checkpoint(config, original_state_dict, vocab_list, original_vocab_
             continue
         state_dict[key] = value

-    state_vocab_size = state_dict["utrbert.embeddings.word_embeddings.weight"].size(0)
-    original_vocab_size = len(original_vocab_list)
-    if state_vocab_size != original_vocab_size:
-        raise ValueError(
-            f"Vocabulary size do not match. Expected to have {original_vocab_size}, but got {state_vocab_size}."
-        )
-    word_embed = nn.Embedding(config.vocab_size, config.hidden_size, padding_idx=config.pad_token_id)
-    word_embed_weight = word_embed.weight.data
-    predictions_decoder_weight = torch.zeros((config.vocab_size, config.hidden_size))
-    predictions_bias = torch.zeros(config.vocab_size)
-    # nn.init.normal_(pos_embed.weight, std=0.02)
-    for original_index, original_token in enumerate(original_vocab_list):
-        new_index = vocab_list.index(original_token)
-        word_embed_weight[new_index] = state_dict["utrbert.embeddings.word_embeddings.weight"][original_index]
-        predictions_decoder_weight[new_index] = state_dict["lm_head.decoder.weight"][original_index]
-        predictions_bias[new_index] = state_dict["lm_head.decoder.bias"][original_index]
+    word_embed_weight, decoder_weight, decoder_bias = convert_word_embeddings(
+        state_dict["utrbert.embeddings.word_embeddings.weight"],
+        state_dict["lm_head.decoder.weight"],
+        state_dict["lm_head.decoder.bias"],
+        old_vocab=original_vocab_list,
+        new_vocab=vocab_list,
+        std=config.initializer_range,
+    )
     state_dict["utrbert.embeddings.word_embeddings.weight"] = word_embed_weight
-    state_dict["lm_head.decoder.weight"] = predictions_decoder_weight
-    state_dict["lm_head.decoder.bias"] = state_dict["lm_head.bias"] = predictions_bias
+    state_dict["lm_head.decoder.weight"] = decoder_weight
+    state_dict["lm_head.decoder.bias"] = state_dict["lm_head.bias"] = decoder_bias

     return state_dict

diff --git a/multimolecule/models/utrlm/convert_checkpoint.py b/multimolecule/models/utrlm/convert_checkpoint.py
index 15636247..7ce148a7 100644
--- a/multimolecule/models/utrlm/convert_checkpoint.py
+++ b/multimolecule/models/utrlm/convert_checkpoint.py
@@ -3,11 +3,15 @@
 import chanfig
 import torch
-from torch import nn

 from multimolecule.models import UtrLmConfig as Config
 from multimolecule.models import UtrLmForPretraining as Model
-from multimolecule.tokenizers.rna.utils import get_special_tokens_map, get_tokenizer_config, get_vocab_list
+from multimolecule.tokenizers.rna.utils import (
+    convert_word_embeddings,
+    get_special_tokens_map,
+    get_tokenizer_config,
+    get_vocab_list,
+)

 try:
     from huggingface_hub import HfApi
@@ -50,27 +54,17 @@ def _convert_checkpoint(config, original_state_dict, vocab_list, original_vocab_
         key = key.replace("utrlm.supervised_linear", "pretrain_head.supervised.decoder")
         state_dict[key] = value

-    state_vocab_size = state_dict["utrlm.embeddings.word_embeddings.weight"].size(0)
-    original_vocab_size = len(original_vocab_list)
-    if state_vocab_size != original_vocab_size:
-        raise ValueError(
-            f"Vocabulary size do not match. Expected to have {original_vocab_size}, but got {state_vocab_size}."
-        )
-    word_embed = nn.Embedding(config.vocab_size, config.hidden_size, padding_idx=config.pad_token_id)
-    word_embed_weight = word_embed.weight.data
-    predictions_decoder_weight = torch.zeros((config.vocab_size, config.hidden_size))
-    predictions_bias = torch.zeros(config.vocab_size)
-    # nn.init.normal_(pos_embed.weight, std=0.02)
-    for original_index, original_token in enumerate(original_vocab_list):
-        new_index = vocab_list.index(original_token)
-        word_embed_weight[new_index] = state_dict["utrlm.embeddings.word_embeddings.weight"][original_index]
-        predictions_decoder_weight[new_index] = state_dict["pretrain_head.predictions.decoder.weight"][original_index]
-        predictions_bias[new_index] = state_dict["pretrain_head.predictions.bias"][original_index]
-    state_dict["utrlm.embeddings.word_embeddings.weight"] = word_embed_weight
-    state_dict["pretrain_head.predictions.decoder.weight"] = predictions_decoder_weight
-    state_dict["pretrain_head.predictions.decoder.bias"] = state_dict["pretrain_head.predictions.bias"] = (
-        predictions_bias
+    word_embed_weight, decoder_weight, decoder_bias = convert_word_embeddings(
+        state_dict["utrlm.embeddings.word_embeddings.weight"],
+        state_dict["pretrain_head.predictions.decoder.weight"],
+        state_dict["pretrain_head.predictions.bias"],
+        old_vocab=original_vocab_list,
+        new_vocab=vocab_list,
+        std=config.initializer_range,
     )
+    state_dict["utrlm.embeddings.word_embeddings.weight"] = word_embed_weight
+    state_dict["pretrain_head.predictions.decoder.weight"] = decoder_weight
+    state_dict["pretrain_head.predictions.decoder.bias"] = state_dict["pretrain_head.predictions.bias"] = decoder_bias

     return state_dict

diff --git a/multimolecule/tokenizers/rna/utils.py b/multimolecule/tokenizers/rna/utils.py
index e070cc30..287ef22d 100755
--- a/multimolecule/tokenizers/rna/utils.py
+++ b/multimolecule/tokenizers/rna/utils.py
@@ -1,5 +1,14 @@
+from __future__ import annotations
+
+from typing import List, Sequence
+
+import torch
+from torch import Tensor
+
 from ..utils import generate_kmer_vocabulary

+torch.manual_seed(1013)
+

 def get_vocab_list(nmers: int = 1, strameline: bool = False):
     vocab_list = STRAMELINE_VOCAB_LIST if strameline else VOCAB_LIST
@@ -8,6 +17,52 @@
     return vocab_list


+def get_vocab_mapping():
+    return VOCAB_MAPPING
+
+
+def convert_word_embeddings(
+    *old_embeddings: Tensor,
+    old_vocab: List[str],
+    new_vocab: List[str],
+    mean: float = 0.0,
+    std: float = 0.02,
+    vocab_mapping: dict[str, str] | None = None,
+) -> Sequence[Tensor]:
+    if old_vocab == new_vocab:
+        return old_embeddings
+    if vocab_mapping is None:
+        vocab_mapping = get_vocab_mapping()
+
+    new_embeddings = []
+    # Initialize the new embeddings
+    for embeddings in old_embeddings:
+        shape = embeddings.shape
+        if shape[0] != len(old_vocab):
+            raise ValueError("The first dimension of the embeddings must match the size of the vocabulary.")
+        if embeddings.ndim == 1:  # Bias
+            new_embeddings.append(torch.zeros(len(new_vocab)))
+        else:
+            new_embeddings.append(torch.normal(size=(len(new_vocab), *shape[1:]), mean=mean, std=std))
+
+    # First Pass, copy the embeddings for the tokens that are in both vocabularies
+    for old_index, old_token in enumerate(old_vocab):
+        new_index = new_vocab.index(old_token)
+        for new_embed, old_embed in zip(new_embeddings, old_embeddings):
+            new_embed[new_index] = old_embed[old_index]
+
+    # Second Pass, average the embeddings and biases for the tokens that are in the new vocabulary but not in the old
+    for token, tokens in vocab_mapping.items():
+        if token not in new_vocab or token in old_vocab or len(tokens) == 1:
+            continue
+        index = new_vocab.index(token)
+        indexes = [new_vocab.index(t) for t in tokens]
+        for embed in new_embeddings:
+            embed[index] = embed[indexes].mean(dim=0)
+
+    return new_embeddings
+
+
 def get_special_tokens_map():
     return SPECIAL_TOKENS_MAP

@@ -63,6 +118,28 @@ def get_tokenizer_config():
     "-",
 ]

+VOCAB_MAPPING = {
+    "A": "A",
+    "C": "C",
+    "G": "G",
+    "U": "U",
+    "N": "N",
+    "X": "ACGU",
+    "V": "ACG",
+    "H": "ACU",
+    "D": "AGU",
+    "B": "CGU",
+    "M": "AC",
+    "R": "AG",
+    "W": "AU",
+    "S": "CG",
+    "Y": "CU",
+    "K": "GU",
+    ".": ".",
+    "*": "*",
+    "-": "-",
+}
+
 SPECIAL_TOKENS_MAP = {
     "pad_token": {
         "content": "",
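
For reference, a minimal sketch of how the new convert_word_embeddings helper behaves. The vocabularies, hidden size, and tensors below are toy values chosen for illustration (they are not the shipped RNA vocabulary); the checkpoint converters above pass the real get_vocab_list() output and the original checkpoint tensors instead.

    import torch

    from multimolecule.tokenizers.rna.utils import convert_word_embeddings

    # Toy vocabularies: "R" (purine) is new and maps to "AG" in VOCAB_MAPPING.
    old_vocab = ["<pad>", "A", "C", "G", "U"]
    new_vocab = ["<pad>", "A", "C", "G", "U", "R"]

    old_embed = torch.randn(len(old_vocab), 8)  # (vocab_size, hidden_size)
    old_bias = torch.zeros(len(old_vocab))      # 1-D inputs are treated as biases

    new_embed, new_bias = convert_word_embeddings(
        old_embed,
        old_bias,
        old_vocab=old_vocab,
        new_vocab=new_vocab,
        std=0.02,
    )

    # Tokens shared by both vocabularies keep their original rows ...
    assert torch.equal(new_embed[new_vocab.index("A")], old_embed[old_vocab.index("A")])
    # ... while the new "R" row is the mean of the "A" and "G" rows.
    indexes = [old_vocab.index("A"), old_vocab.index("G")]
    assert torch.allclose(new_embed[new_vocab.index("R")], old_embed[indexes].mean(dim=0))

Any remaining new rows keep the normal initialisation drawn with the given mean and std (biases start at zero), which is why the converters above pass std=config.initializer_range.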