
Commit

improve token initialisation
Signed-off-by: Zhiyuan Chen <[email protected]>
ZhiyuanChen committed Apr 22, 2024
1 parent 5370c57 commit 57bbd11
Showing 7 changed files with 174 additions and 128 deletions.
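
The change is the same across all six conversion scripts shown below: the hand-rolled loop that remapped embedding rows from each model's original vocabulary to the new one is replaced by a shared convert_word_embeddings helper from multimolecule.tokenizers.rna.utils, called with std=config.initializer_range, so rows for tokens that exist only in the new vocabulary are presumably initialised from a normal distribution rather than relying on the ad-hoc defaults of the removed code. The helper itself is not part of the hunks on this page. The following is a minimal sketch of what such a helper plausibly does, with the signature inferred from the call sites; it follows the flat three-value unpacking used by rnabert/rnafm/rnamsm (splicebert/utrbert/utrlm unpack the result as two groups instead, which would imply a different return structure). It is not the actual multimolecule implementation.

# Minimal sketch only: the real multimolecule implementation is not shown in this diff,
# and its behaviour and signature may differ.
from typing import List, Tuple

import torch
from torch import Tensor


def convert_word_embeddings_sketch(
    *old_embeddings: Tensor,
    old_vocab: List[str],
    new_vocab: List[str],
    std: float = 0.02,
) -> Tuple[Tensor, ...]:
    """Reorder rows from old_vocab order into new_vocab order.

    Rows for tokens that exist only in new_vocab are drawn from N(0, std**2) for
    2-D weights and left at zero for 1-D biases, instead of the ad-hoc defaults
    (nn.Embedding init, torch.zeros) used by the loops this commit removes.
    """
    converted = []
    for old in old_embeddings:
        if old.size(0) != len(old_vocab):
            raise ValueError(f"Expected {len(old_vocab)} rows, but got {old.size(0)}.")
        new = torch.zeros((len(new_vocab),) + tuple(old.shape[1:]), dtype=old.dtype)
        if new.dim() > 1:
            torch.nn.init.normal_(new, std=std)  # fresh rows for newly added tokens
        for new_index, token in enumerate(new_vocab):
            if token in old_vocab:  # keep the pretrained row where possible
                new[new_index] = old[old_vocab.index(token)]
        converted.append(new)
    return tuple(converted)


# Toy usage: remap a 5-token embedding table onto an 8-token vocabulary (tokens are illustrative).
old_vocab = ["<pad>", "A", "C", "G", "U"]
new_vocab = ["<pad>", "<cls>", "<eos>", "A", "C", "G", "U", "N"]
(word_embeddings,) = convert_word_embeddings_sketch(
    torch.randn(len(old_vocab), 16), old_vocab=old_vocab, new_vocab=new_vocab, std=0.02
)
print(word_embeddings.shape)  # torch.Size([8, 16])
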
38 changes: 16 additions & 22 deletions multimolecule/models/rnabert/convert_checkpoint.py
@@ -3,11 +3,15 @@
 import chanfig
 import torch
-from torch import nn

 from multimolecule.models.rnabert import RnaBertConfig as Config
 from multimolecule.models.rnabert import RnaBertForPretraining as Model
-from multimolecule.tokenizers.rna.utils import get_special_tokens_map, get_tokenizer_config, get_vocab_list
+from multimolecule.tokenizers.rna.utils import (
+    convert_word_embeddings,
+    get_special_tokens_map,
+    get_tokenizer_config,
+    get_vocab_list,
+)

 try:
     from huggingface_hub import HfApi

@@ -33,27 +37,17 @@ def _convert_checkpoint(config, original_state_dict, vocab_list, original_vocab_list):
             continue
         state_dict[key] = value

-    state_vocab_size = state_dict["rnabert.embeddings.word_embeddings.weight"].size(0)
-    original_vocab_size = len(original_vocab_list)
-    if state_vocab_size != original_vocab_size:
-        raise ValueError(
-            f"Vocabulary size do not match. Expected to have {original_vocab_size}, but got {state_vocab_size}."
-        )
-    word_embed = nn.Embedding(config.vocab_size, config.hidden_size, padding_idx=config.pad_token_id)
-    word_embed_weight = word_embed.weight.data
-    predictions_decoder_weight = torch.zeros((config.vocab_size, config.hidden_size))
-    predictions_bias = torch.zeros(config.vocab_size)
-    # nn.init.normal_(pos_embed.weight, std=0.02)
-    for original_index, original_token in enumerate(original_vocab_list):
-        new_index = vocab_list.index(original_token)
-        word_embed_weight[new_index] = state_dict["rnabert.embeddings.word_embeddings.weight"][original_index]
-        predictions_decoder_weight[new_index] = state_dict["pretrain_head.predictions.decoder.weight"][original_index]
-        predictions_bias[new_index] = state_dict["pretrain_head.predictions.bias"][original_index]
-    state_dict["rnabert.embeddings.word_embeddings.weight"] = word_embed_weight
-    state_dict["pretrain_head.predictions.decoder.weight"] = predictions_decoder_weight
-    state_dict["pretrain_head.predictions.decoder.bias"] = state_dict["pretrain_head.predictions.bias"] = (
-        predictions_bias
+    word_embed_weight, decoder_weight, decoder_bias = convert_word_embeddings(
+        state_dict["rnabert.embeddings.word_embeddings.weight"],
+        state_dict["pretrain_head.predictions.decoder.weight"],
+        state_dict["pretrain_head.predictions.bias"],
+        old_vocab=original_vocab_list,
+        new_vocab=vocab_list,
+        std=config.initializer_range,
     )
+    state_dict["rnabert.embeddings.word_embeddings.weight"] = word_embed_weight
+    state_dict["pretrain_head.predictions.decoder.weight"] = decoder_weight
+    state_dict["pretrain_head.predictions.decoder.bias"] = state_dict["pretrain_head.predictions.bias"] = decoder_bias
     state_dict["pretrain_head.predictions_ss.decoder.bias"] = state_dict["pretrain_head.predictions_ss.bias"]
     return state_dict

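A side note on the double assignment of pretrain_head.predictions.decoder.bias and pretrain_head.predictions.bias above (and of lm_head.decoder.bias and lm_head.bias in the scripts further down): BERT-style masked-LM prediction heads commonly register a standalone bias parameter and tie it to the decoder's bias, so the converted checkpoint writes both keys to match either registration. A minimal sketch of that pattern, assuming the MultiMolecule heads follow the usual Hugging Face layout (the head modules themselves are not part of this diff):

import torch
from torch import nn


class PredictionHeadSketch(nn.Module):
    """Toy masked-LM prediction head illustrating the decoder/bias tying."""

    def __init__(self, hidden_size: int, vocab_size: int) -> None:
        super().__init__()
        self.decoder = nn.Linear(hidden_size, vocab_size, bias=False)
        self.bias = nn.Parameter(torch.zeros(vocab_size))
        # Tie the two so that "decoder.bias" and "bias" name the same tensor in the
        # state dict, mirroring the double assignment in the conversion scripts.
        self.decoder.bias = self.bias

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        return self.decoder(hidden_states)


head = PredictionHeadSketch(hidden_size=64, vocab_size=26)
print(sorted(head.state_dict().keys()))  # ['bias', 'decoder.bias', 'decoder.weight']
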
38 changes: 16 additions & 22 deletions multimolecule/models/rnafm/convert_checkpoint.py
@@ -3,11 +3,15 @@
 import chanfig
 import torch
-from torch import nn

 from multimolecule.models import RnaFmConfig as Config
 from multimolecule.models import RnaFmForPretraining as Model
-from multimolecule.tokenizers.rna.utils import get_special_tokens_map, get_tokenizer_config, get_vocab_list
+from multimolecule.tokenizers.rna.utils import (
+    convert_word_embeddings,
+    get_special_tokens_map,
+    get_tokenizer_config,
+    get_vocab_list,
+)

 try:
     from huggingface_hub import HfApi

@@ -45,27 +49,17 @@ def _convert_checkpoint(config, original_state_dict, vocab_list, original_vocab_list):
         key = key.replace("rnafm.encoder.contact_head", "pretrain_head.contact")
         state_dict[key] = value

-    state_vocab_size = state_dict["rnafm.embeddings.word_embeddings.weight"].size(0)
-    original_vocab_size = len(original_vocab_list)
-    if state_vocab_size != original_vocab_size:
-        raise ValueError(
-            f"Vocabulary size do not match. Expected to have {original_vocab_size}, but got {state_vocab_size}."
-        )
-    word_embed = nn.Embedding(config.vocab_size, config.hidden_size, padding_idx=config.pad_token_id)
-    word_embed_weight = word_embed.weight.data
-    predictions_decoder_weight = torch.zeros((config.vocab_size, config.hidden_size))
-    predictions_bias = torch.zeros(config.vocab_size)
-    # nn.init.normal_(pos_embed.weight, std=0.02)
-    for original_index, original_token in enumerate(original_vocab_list):
-        new_index = vocab_list.index(original_token)
-        word_embed_weight[new_index] = state_dict["rnafm.embeddings.word_embeddings.weight"][original_index]
-        predictions_decoder_weight[new_index] = state_dict["pretrain_head.predictions.decoder.weight"][original_index]
-        predictions_bias[new_index] = state_dict["pretrain_head.predictions.bias"][original_index]
-    state_dict["rnafm.embeddings.word_embeddings.weight"] = word_embed_weight
-    state_dict["pretrain_head.predictions.decoder.weight"] = predictions_decoder_weight
-    state_dict["pretrain_head.predictions.decoder.bias"] = state_dict["pretrain_head.predictions.bias"] = (
-        predictions_bias
+    word_embed_weight, decoder_weight, decoder_bias = convert_word_embeddings(
+        state_dict["rnafm.embeddings.word_embeddings.weight"],
+        state_dict["pretrain_head.predictions.decoder.weight"],
+        state_dict["pretrain_head.predictions.bias"],
+        old_vocab=original_vocab_list,
+        new_vocab=vocab_list,
+        std=config.initializer_range,
     )
+    state_dict["rnafm.embeddings.word_embeddings.weight"] = word_embed_weight
+    state_dict["pretrain_head.predictions.decoder.weight"] = decoder_weight
+    state_dict["pretrain_head.predictions.decoder.bias"] = state_dict["pretrain_head.predictions.bias"] = decoder_bias
     return state_dict


39 changes: 17 additions & 22 deletions multimolecule/models/rnamsm/convert_checkpoint.py
@@ -3,11 +3,15 @@
 import chanfig
 import torch
-from torch import nn

 from multimolecule.models import RnaMsmConfig as Config
 from multimolecule.models import RnaMsmForPretraining as Model
-from multimolecule.tokenizers.rna.utils import get_special_tokens_map, get_tokenizer_config, get_vocab_list
+from multimolecule.tokenizers.rna.utils import (
+    convert_word_embeddings,
+    get_special_tokens_map,
+    get_tokenizer_config,
+    get_vocab_list,
+)

 try:
     from huggingface_hub import HfApi

@@ -29,31 +33,22 @@ def _convert_checkpoint(config, original_state_dict, vocab_list, original_vocab_list):
         key = key.replace("regression", "decoder")
         key = key.replace("contact_head", "pretrain_head.contact")
         key = key.replace("lm_head", "pretrain_head.predictions")
         key = key.replace("pretrain_head.predictions.weight", "pretrain_head.predictions.decoder.weight")
         key = key.replace("pretrain_head.predictions.dense", "pretrain_head.predictions.transform.dense")
         key = key.replace("pretrain_head.predictions.layer_norm", "pretrain_head.predictions.transform.layer_norm")
         state_dict[key] = value

-    state_vocab_size = state_dict["rnamsm.embeddings.word_embeddings.weight"].size(0)
-    original_vocab_size = len(original_vocab_list)
-    if state_vocab_size != original_vocab_size:
-        raise ValueError(
-            f"Vocabulary size do not match. Expected to have {original_vocab_size}, but got {state_vocab_size}."
-        )
-    word_embed = nn.Embedding(config.vocab_size, config.hidden_size, padding_idx=config.pad_token_id)
-    word_embed_weight = word_embed.weight.data
-    predictions_bias = torch.zeros(config.vocab_size)
-    # nn.init.normal_(pos_embed.weight, std=0.02)
-    for original_index, original_token in enumerate(original_vocab_list):
-        new_index = vocab_list.index(original_token)
-        word_embed_weight[new_index] = state_dict["rnamsm.embeddings.word_embeddings.weight"][original_index]
-        predictions_bias[new_index] = state_dict["pretrain_head.predictions.bias"][original_index]
-    state_dict["rnamsm.embeddings.word_embeddings.weight"] = state_dict["pretrain_head.predictions.decoder.weight"] = (
-        word_embed_weight
-    )
-    state_dict["pretrain_head.predictions.decoder.bias"] = state_dict["pretrain_head.predictions.bias"] = (
-        predictions_bias
+    word_embed_weight, decoder_weight, decoder_bias = convert_word_embeddings(
+        state_dict["rnamsm.embeddings.word_embeddings.weight"],
+        state_dict["pretrain_head.predictions.decoder.weight"],
+        state_dict["pretrain_head.predictions.bias"],
+        old_vocab=original_vocab_list,
+        new_vocab=vocab_list,
+        std=config.initializer_range,
     )
     del state_dict["pretrain_head.predictions.weight"]
+    state_dict["rnamsm.embeddings.word_embeddings.weight"] = word_embed_weight
+    state_dict["pretrain_head.predictions.decoder.weight"] = decoder_weight
+    state_dict["pretrain_head.predictions.decoder.bias"] = state_dict["pretrain_head.predictions.bias"] = decoder_bias
     return state_dict


36 changes: 16 additions & 20 deletions multimolecule/models/splicebert/convert_checkpoint.py
@@ -3,11 +3,15 @@
 import chanfig
 import torch
-from torch import nn

 from multimolecule.models import SpliceBertConfig as Config
 from multimolecule.models import SpliceBertForPretraining as Model
-from multimolecule.tokenizers.rna.utils import get_special_tokens_map, get_tokenizer_config, get_vocab_list
+from multimolecule.tokenizers.rna.utils import (
+    convert_word_embeddings,
+    get_special_tokens_map,
+    get_tokenizer_config,
+    get_vocab_list,
+)

 try:
     from huggingface_hub import HfApi

@@ -32,25 +36,17 @@ def _convert_checkpoint(config, original_state_dict, vocab_list, original_vocab_list):
            continue
         state_dict[key] = value

-    state_vocab_size = state_dict["splicebert.embeddings.word_embeddings.weight"].size(0)
-    original_vocab_size = len(original_vocab_list)
-    if state_vocab_size != original_vocab_size:
-        raise ValueError(
-            f"Vocabulary size do not match. Expected to have {original_vocab_size}, but got {state_vocab_size}."
-        )
-    word_embed = nn.Embedding(config.vocab_size, config.hidden_size, padding_idx=config.pad_token_id)
-    word_embed_weight = word_embed.weight.data
-    predictions_decoder_weight = torch.zeros((config.vocab_size, config.hidden_size))
-    predictions_bias = torch.zeros(config.vocab_size)
-    # nn.init.normal_(pos_embed.weight, std=0.02)
-    for original_index, original_token in enumerate(original_vocab_list):
-        new_index = vocab_list.index(original_token)
-        word_embed_weight[new_index] = state_dict["splicebert.embeddings.word_embeddings.weight"][original_index]
-        predictions_decoder_weight[new_index] = state_dict["lm_head.decoder.weight"][original_index]
-        predictions_bias[new_index] = state_dict["lm_head.decoder.bias"][original_index]
+    [word_embed_weight, decoder_weight], [decoder_bias] = convert_word_embeddings(
+        state_dict["splicebert.embeddings.word_embeddings.weight"],
+        state_dict["lm_head.decoder.weight"],
+        state_dict["lm_head.decoder.bias"],
+        old_vocab=original_vocab_list,
+        new_vocab=vocab_list,
+        std=config.initializer_range,
+    )
     state_dict["splicebert.embeddings.word_embeddings.weight"] = word_embed_weight
-    state_dict["lm_head.decoder.weight"] = predictions_decoder_weight
-    state_dict["lm_head.decoder.bias"] = state_dict["lm_head.bias"] = predictions_bias
+    state_dict["lm_head.decoder.weight"] = decoder_weight
+    state_dict["lm_head.decoder.bias"] = state_dict["lm_head.bias"] = decoder_bias
     del state_dict["splicebert.embeddings.position_ids"]
     return state_dict

36 changes: 16 additions & 20 deletions multimolecule/models/utrbert/convert_checkpoint.py
@@ -3,11 +3,15 @@
 import chanfig
 import torch
-from torch import nn

 from multimolecule.models import UtrBertConfig as Config
 from multimolecule.models import UtrBertForPretraining as Model
-from multimolecule.tokenizers.rna.utils import get_special_tokens_map, get_tokenizer_config, get_vocab_list
+from multimolecule.tokenizers.rna.utils import (
+    convert_word_embeddings,
+    get_special_tokens_map,
+    get_tokenizer_config,
+    get_vocab_list,
+)

 try:
     from huggingface_hub import HfApi

@@ -32,25 +36,17 @@ def _convert_checkpoint(config, original_state_dict, vocab_list, original_vocab_list):
            continue
         state_dict[key] = value

-    state_vocab_size = state_dict["utrbert.embeddings.word_embeddings.weight"].size(0)
-    original_vocab_size = len(original_vocab_list)
-    if state_vocab_size != original_vocab_size:
-        raise ValueError(
-            f"Vocabulary size do not match. Expected to have {original_vocab_size}, but got {state_vocab_size}."
-        )
-    word_embed = nn.Embedding(config.vocab_size, config.hidden_size, padding_idx=config.pad_token_id)
-    word_embed_weight = word_embed.weight.data
-    predictions_decoder_weight = torch.zeros((config.vocab_size, config.hidden_size))
-    predictions_bias = torch.zeros(config.vocab_size)
-    # nn.init.normal_(pos_embed.weight, std=0.02)
-    for original_index, original_token in enumerate(original_vocab_list):
-        new_index = vocab_list.index(original_token)
-        word_embed_weight[new_index] = state_dict["utrbert.embeddings.word_embeddings.weight"][original_index]
-        predictions_decoder_weight[new_index] = state_dict["lm_head.decoder.weight"][original_index]
-        predictions_bias[new_index] = state_dict["lm_head.decoder.bias"][original_index]
+    [word_embed_weight, decoder_weight], [decoder_bias] = convert_word_embeddings(
+        state_dict["utrbert.embeddings.word_embeddings.weight"],
+        state_dict["lm_head.decoder.weight"],
+        state_dict["lm_head.decoder.bias"],
+        old_vocab=original_vocab_list,
+        new_vocab=vocab_list,
+        std=config.initializer_range,
+    )
     state_dict["utrbert.embeddings.word_embeddings.weight"] = word_embed_weight
-    state_dict["lm_head.decoder.weight"] = predictions_decoder_weight
-    state_dict["lm_head.decoder.bias"] = state_dict["lm_head.bias"] = predictions_bias
+    state_dict["lm_head.decoder.weight"] = decoder_weight
+    state_dict["lm_head.decoder.bias"] = state_dict["lm_head.bias"] = decoder_bias
     return state_dict


38 changes: 16 additions & 22 deletions multimolecule/models/utrlm/convert_checkpoint.py
@@ -3,11 +3,15 @@
 import chanfig
 import torch
-from torch import nn

 from multimolecule.models import UtrLmConfig as Config
 from multimolecule.models import UtrLmForPretraining as Model
-from multimolecule.tokenizers.rna.utils import get_special_tokens_map, get_tokenizer_config, get_vocab_list
+from multimolecule.tokenizers.rna.utils import (
+    convert_word_embeddings,
+    get_special_tokens_map,
+    get_tokenizer_config,
+    get_vocab_list,
+)

 try:
     from huggingface_hub import HfApi

@@ -50,27 +54,17 @@ def _convert_checkpoint(config, original_state_dict, vocab_list, original_vocab_list):
         key = key.replace("utrlm.supervised_linear", "pretrain_head.supervised.decoder")
         state_dict[key] = value

-    state_vocab_size = state_dict["utrlm.embeddings.word_embeddings.weight"].size(0)
-    original_vocab_size = len(original_vocab_list)
-    if state_vocab_size != original_vocab_size:
-        raise ValueError(
-            f"Vocabulary size do not match. Expected to have {original_vocab_size}, but got {state_vocab_size}."
-        )
-    word_embed = nn.Embedding(config.vocab_size, config.hidden_size, padding_idx=config.pad_token_id)
-    word_embed_weight = word_embed.weight.data
-    predictions_decoder_weight = torch.zeros((config.vocab_size, config.hidden_size))
-    predictions_bias = torch.zeros(config.vocab_size)
-    # nn.init.normal_(pos_embed.weight, std=0.02)
-    for original_index, original_token in enumerate(original_vocab_list):
-        new_index = vocab_list.index(original_token)
-        word_embed_weight[new_index] = state_dict["utrlm.embeddings.word_embeddings.weight"][original_index]
-        predictions_decoder_weight[new_index] = state_dict["pretrain_head.predictions.decoder.weight"][original_index]
-        predictions_bias[new_index] = state_dict["pretrain_head.predictions.bias"][original_index]
-    state_dict["utrlm.embeddings.word_embeddings.weight"] = word_embed_weight
-    state_dict["pretrain_head.predictions.decoder.weight"] = predictions_decoder_weight
-    state_dict["pretrain_head.predictions.decoder.bias"] = state_dict["pretrain_head.predictions.bias"] = (
-        predictions_bias
+    [word_embed_weight, decoder_weight], [decoder_bias] = convert_word_embeddings(
+        state_dict["utrlm.embeddings.word_embeddings.weight"],
+        state_dict["pretrain_head.predictions.decoder.weight"],
+        state_dict["pretrain_head.predictions.bias"],
+        old_vocab=original_vocab_list,
+        new_vocab=vocab_list,
+        std=config.initializer_range,
     )
+    state_dict["utrlm.embeddings.word_embeddings.weight"] = word_embed_weight
+    state_dict["pretrain_head.predictions.decoder.weight"] = decoder_weight
+    state_dict["pretrain_head.predictions.decoder.bias"] = state_dict["pretrain_head.predictions.bias"] = decoder_bias
     return state_dict


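Each script also guards from huggingface_hub import HfApi with a try block, so pushing the converted checkpoint to the Hub is optional. The driver code that saves and uploads the result sits outside the hunks shown above; purely as an illustration, the tail of such a script might look like the sketch below (the function and argument names here are hypothetical and not taken from the repository).

# Hypothetical sketch; output_dir / repo_id and the save/upload flow are illustrative only.
from typing import Optional

from huggingface_hub import HfApi


def export_checkpoint(model, output_dir: str, repo_id: Optional[str] = None) -> None:
    # Assumes the model subclasses transformers.PreTrainedModel, so save_pretrained
    # writes config.json plus the converted weights into output_dir.
    model.save_pretrained(output_dir)
    if repo_id is not None:
        api = HfApi()
        api.create_repo(repo_id, exist_ok=True)
        api.upload_folder(repo_id=repo_id, folder_path=output_dir)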