diff --git a/README.md b/README.md index b6e020c..62ce321 100644 --- a/README.md +++ b/README.md @@ -34,7 +34,9 @@ You may also be interested in, ## News -* Jan 26, 2022 (new functionality in v1.0.1) Added support for self-supervised pruning via `use_logits` option in `TransformerPruningConfig`. +* [Mar 4, 2022] We are delighted to announce that TextPruner has been accepted to [ACL 2022 demo](https://2022.aclweb.org). The paper will be available when we finish the camera-ready version. + +* [Jan 26, 2022] (new functionality in v1.0.1) Added support for self-supervised pruning via `use_logits` option in `TransformerPruningConfig`. ## Table of Contents diff --git a/README_ZH.md b/README_ZH.md index 65764e9..8442e33 100644 --- a/README_ZH.md +++ b/README_ZH.md @@ -33,8 +33,9 @@ ## 新闻 +* [Mar 4, 2022] TextPruner论文被[ACL 2022 demo](https://2022.aclweb.org)录用。论文将在camera-ready稿件完成之后放出。 -* Jan 26, 2022 (1.0.1版本功能更新) 添加了对自监督裁剪的支持。通过`TransformerPruningConfig`中的`use_logits`设置。 +* [Jan 26, 2022] (1.0.1版本功能更新) 添加了对自监督裁剪的支持。通过`TransformerPruningConfig`中的`use_logits`设置。 ## 目录 diff --git a/src/textpruner/model_map.py b/src/textpruner/model_map.py index 764a1d7..244b037 100644 --- a/src/textpruner/model_map.py +++ b/src/textpruner/model_map.py @@ -21,5 +21,21 @@ 'xlm-roberta': {'resizer':model_utils.XLMRobertaVocabResizer, 'tokenizer_helper': tokenizer_utils.XLMRSentencepieceTokenizer, - 'structure': model_utils.XLMRobertaStructure} + 'structure': model_utils.XLMRobertaStructure}, + 'xlm': + {'resizer':model_utils.XLMVocabResizer, + 'tokenizer_helper':tokenizer_utils.XLMTokenizer, + 'structure':model_utils.XLMStructure}, + 'bart': + {'resizer' : model_utils.BartVocabResizer, + 'tokenizer_helper' : tokenizer_utils.RobertaGPT2Tokenizer, + 'structure': model_utils.BartStructure}, + 't5': + {'resizer' : model_utils.T5VocabResizer, + 'tokenizer_helper' : tokenizer_utils.T5SentencepieceTokenizer, + 'structure' : model_utils.T5Structure}, + 'mt5': + {'resizer' : model_utils.MT5VocabResizer, + 'tokenizer_helper' : tokenizer_utils.MT5SentencepieceTokenizer, + 'structure' : model_utils.MT5Structure}, } \ No newline at end of file diff --git a/src/textpruner/model_utils/__init__.py b/src/textpruner/model_utils/__init__.py index aec4193..5b5b6e9 100644 --- a/src/textpruner/model_utils/__init__.py +++ b/src/textpruner/model_utils/__init__.py @@ -2,4 +2,8 @@ from .bert import BertVocabResizer, BertStructure from .electra import ElectraVocabResizer, ElectraStructure from .roberta import RobertaVocabResizer, RobertaStructure -from .xlm_roberta import XLMRobertaVocabResizer, XLMRobertaStructure \ No newline at end of file +from .xlm_roberta import XLMRobertaVocabResizer, XLMRobertaStructure +from .xlm import XLMStructure, XLMVocabResizer +from .bart import BartVocabResizer, BartStructure +from .t5 import T5VocabResizer, T5Structure +from .mt5 import MT5Structure, MT5VocabResizer \ No newline at end of file diff --git a/src/textpruner/model_utils/bart.py b/src/textpruner/model_utils/bart.py new file mode 100644 index 0000000..b525d7e --- /dev/null +++ b/src/textpruner/model_utils/bart.py @@ -0,0 +1,62 @@ +from .utils import DefaultModelVocabResizer +from .model_structure import ModelStructure +import torch +from torch import nn +class BartVocabResizer(DefaultModelVocabResizer): + model_name : str = 'bart' + + @classmethod + def set_embeddings(cls, model, token_ids): + def _prun(old_weight, token_ids): + pruned_word_embeddings_weight = torch.index_select( + old_weight, 0, 
index=torch.LongTensor(token_ids).to(old_weight.device)) + return pruned_word_embeddings_weight + + old_word_embeddings_shared, old_word_embeddings_encoder, old_word_embeddings_decoder = \ + model.shared, model.encoder.embed_tokens, model.decoder.embed_tokens + + old_word_embeddings_shared_weight, old_word_embeddings_encoder_weight, old_word_embeddings_decoder_weight = \ + old_word_embeddings_shared.weight, old_word_embeddings_encoder.weight, old_word_embeddings_decoder.weight + + pruned_word_embeddings_shared_weight, pruned_word_embeddings_encoder_weight, pruned_word_embeddings_decoder_weight = \ + _prun(old_word_embeddings_shared_weight, token_ids), _prun(old_word_embeddings_encoder_weight, token_ids), _prun(old_word_embeddings_decoder_weight, token_ids) + + pruned_num_tokens, embedding_dim = pruned_word_embeddings_shared_weight.shape + + pruned_word_embeddings_shared = nn.Embedding( + pruned_num_tokens, embedding_dim).to(old_word_embeddings_shared_weight.device) + pruned_word_embeddings_shared.weight.data[:] = pruned_word_embeddings_shared_weight[:] + + pruned_word_embeddings_encoder = nn.Embedding( + pruned_num_tokens, embedding_dim).to(old_word_embeddings_shared_weight.device) + pruned_word_embeddings_encoder.weight.data[:] = pruned_word_embeddings_encoder_weight[:] + + pruned_word_embeddings_decoder = nn.Embedding( + pruned_num_tokens, embedding_dim).to(old_word_embeddings_shared_weight.device) + pruned_word_embeddings_decoder.weight.data[:] = pruned_word_embeddings_decoder_weight[:] + + model.shared = pruned_word_embeddings_shared + model.encoder.embed_tokens = pruned_word_embeddings_encoder + model.decoder.embed_tokens = pruned_word_embeddings_decoder + +class BartStructure(ModelStructure): + MODEL_PREFIX: str = "model." + ENCODER_PREFIX: str = r"encoder.layers.[0-9]+\." 
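+    # Maps TextPruner's generic layer roles (query/key/value, attention output,
+    # FFN intermediate/output) to the module names used inside each Hugging Face
+    # BART encoder layer.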
+ LAYER_PATTERNS = dict( + query="self_attn.q_proj", + key="self_attn.k_proj", + value="self_attn.v_proj", + att_dense="self_attn.out_proj", + interm_dense="fc1", + output_dense="fc2", + ) + ATTENTION_PREFIX = ("self_attn",) + ATTENTION_LAYERS = ("q_proj", "k_proj", "v_proj") + MHA_LAYERS = ATTENTION_LAYERS + ("att_dense",) + NAME_CONFIG = dict( + hidden_size="d_model", + intermediate_size="encoder_ffn_dim", + num_hidden_layers="encoder_layers", + num_attention_heads="num_attention_heads", + attention_head_size="", + ) \ No newline at end of file diff --git a/src/textpruner/model_utils/mt5.py b/src/textpruner/model_utils/mt5.py new file mode 100644 index 0000000..e5e354f --- /dev/null +++ b/src/textpruner/model_utils/mt5.py @@ -0,0 +1,72 @@ +from .utils import DefaultModelVocabResizer +from .model_structure import ModelStructure +import torch +from torch import nn +class MT5VocabResizer(DefaultModelVocabResizer): + model_name : str = 'mt5' + + @classmethod + def set_embeddings(cls, model, token_ids): + def _prun(old_weight, token_ids): + pruned_word_embeddings_weight = torch.index_select( + old_weight, 0, index=torch.LongTensor(token_ids).to(old_weight.device)) + return pruned_word_embeddings_weight + + + vocab_size = model.shared.weight.shape[0] + max_token_ids = token_ids[-1] + tokens_in_embed_notin_tokenizer_ids = list(range(max_token_ids+1, vocab_size)) + token_ids_temp = token_ids[:] + token_ids_temp.extend(tokens_in_embed_notin_tokenizer_ids) + + + model.config.vocab_size = len(token_ids_temp) + + old_word_embeddings_shared, old_word_embeddings_encoder, old_word_embeddings_decoder = \ + model.shared, model.encoder.embed_tokens, model.decoder.embed_tokens + + old_word_embeddings_shared_weight, old_word_embeddings_encoder_weight, old_word_embeddings_decoder_weight = \ + old_word_embeddings_shared.weight, old_word_embeddings_encoder.weight, old_word_embeddings_decoder.weight + + pruned_word_embeddings_shared_weight, pruned_word_embeddings_encoder_weight, pruned_word_embeddings_decoder_weight = \ + _prun(old_word_embeddings_shared_weight, token_ids_temp), _prun(old_word_embeddings_encoder_weight, token_ids_temp), _prun(old_word_embeddings_decoder_weight, token_ids_temp) + + pruned_num_tokens, embedding_dim = pruned_word_embeddings_shared_weight.shape + + pruned_word_embeddings_shared = nn.Embedding( + pruned_num_tokens, embedding_dim).to(old_word_embeddings_shared_weight.device) + pruned_word_embeddings_shared.weight.data[:] = pruned_word_embeddings_shared_weight[:] + + pruned_word_embeddings_encoder = nn.Embedding( + pruned_num_tokens, embedding_dim).to(old_word_embeddings_shared_weight.device) + pruned_word_embeddings_encoder.weight.data[:] = pruned_word_embeddings_encoder_weight[:] + + pruned_word_embeddings_decoder = nn.Embedding( + pruned_num_tokens, embedding_dim).to(old_word_embeddings_shared_weight.device) + pruned_word_embeddings_decoder.weight.data[:] = pruned_word_embeddings_decoder_weight[:] + + model.shared = pruned_word_embeddings_shared + model.encoder.embed_tokens = pruned_word_embeddings_encoder + model.decoder.embed_tokens = pruned_word_embeddings_decoder + +class MT5Structure(ModelStructure): + MODEL_PREFIX: str = "transformer." + ENCODER_PREFIX: str = r"encoder.block.[0-9]+\.layer." 
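+    # mT5 follows the T5 block layout: inside each encoder block, layer.0 holds
+    # the self-attention module and layer.1 holds the DenseReluDense feed-forward.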
+ LAYER_PATTERNS = dict( + query="0.SelfAttention.q", + key="0.SelfAttention.k", + value="0.SelfAttention.v", + att_dense="0.SelfAttention.o", + interm_dense="1.DenseReluDense.wi", + output_dense="1.DenseReluDense.wo", + ) + ATTENTION_PREFIX = ("0.SelfAttention",) + ATTENTION_LAYERS = ("q", "k", "v") + MHA_LAYERS = ATTENTION_LAYERS + ("att_dense",) + NAME_CONFIG = dict( + hidden_size="d_model", + intermediate_size="d_ff", + num_hidden_layers="num_layers", + num_attention_heads="num_heads", + attention_head_size="", + ) \ No newline at end of file diff --git a/src/textpruner/model_utils/t5.py b/src/textpruner/model_utils/t5.py new file mode 100644 index 0000000..7d91d9c --- /dev/null +++ b/src/textpruner/model_utils/t5.py @@ -0,0 +1,70 @@ +from .utils import DefaultModelVocabResizer +from .model_structure import ModelStructure +import torch +from torch import nn +class T5VocabResizer(DefaultModelVocabResizer): + model_name : str = 't5' + + @classmethod + def set_embeddings(cls, model, token_ids): + def _prun(old_weight, token_ids): + pruned_word_embeddings_weight = torch.index_select( + old_weight, 0, index=torch.LongTensor(token_ids).to(old_weight.device)) + return pruned_word_embeddings_weight + + vocab_size = model.shared.weight.shape[0] + max_token_ids = token_ids[-1] + tokens_in_embed_notin_tokenizer_ids = list(range(max_token_ids+1, vocab_size)) + token_ids_temp = token_ids[:] + token_ids_temp.extend(tokens_in_embed_notin_tokenizer_ids) + + model.config.vocab_size = len(token_ids_temp) + + old_word_embeddings_shared, old_word_embeddings_encoder, old_word_embeddings_decoder = \ + model.shared, model.encoder.embed_tokens, model.decoder.embed_tokens + + old_word_embeddings_shared_weight, old_word_embeddings_encoder_weight, old_word_embeddings_decoder_weight = \ + old_word_embeddings_shared.weight, old_word_embeddings_encoder.weight, old_word_embeddings_decoder.weight + + pruned_word_embeddings_shared_weight, pruned_word_embeddings_encoder_weight, pruned_word_embeddings_decoder_weight = \ + _prun(old_word_embeddings_shared_weight, token_ids_temp), _prun(old_word_embeddings_encoder_weight, token_ids_temp), _prun(old_word_embeddings_decoder_weight, token_ids_temp) + + pruned_num_tokens, embedding_dim = pruned_word_embeddings_shared_weight.shape + + pruned_word_embeddings_shared = nn.Embedding( + pruned_num_tokens, embedding_dim).to(old_word_embeddings_shared_weight.device) + pruned_word_embeddings_shared.weight.data[:] = pruned_word_embeddings_shared_weight[:] + + pruned_word_embeddings_encoder = nn.Embedding( + pruned_num_tokens, embedding_dim).to(old_word_embeddings_shared_weight.device) + pruned_word_embeddings_encoder.weight.data[:] = pruned_word_embeddings_encoder_weight[:] + + pruned_word_embeddings_decoder = nn.Embedding( + pruned_num_tokens, embedding_dim).to(old_word_embeddings_shared_weight.device) + pruned_word_embeddings_decoder.weight.data[:] = pruned_word_embeddings_decoder_weight[:] + + model.shared = pruned_word_embeddings_shared + model.encoder.embed_tokens = pruned_word_embeddings_encoder + model.decoder.embed_tokens = pruned_word_embeddings_decoder + +class T5Structure(ModelStructure): + MODEL_PREFIX: str = "transformer." + ENCODER_PREFIX: str = r"encoder.block.[0-9]+\.layer." 
+ LAYER_PATTERNS = dict( + query="0.SelfAttention.q", + key="0.SelfAttention.k", + value="0.SelfAttention.v", + att_dense="0.SelfAttention.o", + interm_dense="1.DenseReluDense.wi", + output_dense="1.DenseReluDense.wo", + ) + ATTENTION_PREFIX = ("0.SelfAttention",) + ATTENTION_LAYERS = ("q", "k", "v") + MHA_LAYERS = ATTENTION_LAYERS + ("att_dense",) + NAME_CONFIG = dict( + hidden_size="d_model", + intermediate_size="d_ff", + num_hidden_layers="num_layers", + num_attention_heads="num_heads", + attention_head_size="", + ) \ No newline at end of file diff --git a/src/textpruner/model_utils/xlm.py b/src/textpruner/model_utils/xlm.py new file mode 100644 index 0000000..39d3f54 --- /dev/null +++ b/src/textpruner/model_utils/xlm.py @@ -0,0 +1,59 @@ +from .utils import DefaultModelVocabResizer +from .model_structure import ModelStructure +import torch +from torch import nn +class XLMVocabResizer(DefaultModelVocabResizer): + model_name : str = 'xlm' + + @classmethod + def set_embeddings(cls, model, token_ids): + # self.model.get_input_embeddings() + + if hasattr(model.embeddings, 'word_embeddings'): #XLM + old_word_embeddings = model.embeddings.word_embeddings + else: + old_word_embeddings = model.embeddings + + + + # old_word_embeddings = model.embeddings.word_embeddings + old_word_embeddings_weight = old_word_embeddings.weight + + pruned_word_embeddings_weight = torch.index_select( + old_word_embeddings_weight, 0, index=torch.LongTensor(token_ids).to(old_word_embeddings_weight.device)) + pruned_num_tokens, embedding_dim = pruned_word_embeddings_weight.shape + + pruned_word_embeddings = nn.Embedding( + pruned_num_tokens, embedding_dim).to(old_word_embeddings_weight.device) + pruned_word_embeddings.weight.data[:] = pruned_word_embeddings_weight[:] + + + if hasattr(model.embeddings, 'word_embeddings'): + model.embeddings.word_embeddings = pruned_word_embeddings + else: + model.embeddings = pruned_word_embeddings + + + + +class XLMStructure(ModelStructure): + MODEL_PREFIX: str = "transformer." + ENCODER_PREFIX: str = r"attention.[0-9]+\." 
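+    # XLM stores its attention and FFN sub-modules in separate indexed ModuleLists
+    # (attentions, ffns) rather than in per-layer containers, so full regular
+    # expressions are used below instead of simple per-layer suffixes.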
+ LAYER_PATTERNS = dict( + query=r"attentions\.[0-9]+\.q_lin", + key=r"attentions\.[0-9]+\.k_lin", + value=r"attentions\.[0-9]+\.v_lin", + att_dense=r"attentions\.[0-9]+\.out_lin", + interm_dense=r"ffns\.[0-9]+\.lin1", + output_dense=r"ffns\.[0-9]+\.lin2", + ) + ATTENTION_PREFIX = (r"attentions\.[0-9]",) + ATTENTION_LAYERS = ("q_lin", "k_lin", "v_lin") + MHA_LAYERS = ATTENTION_LAYERS + ("att_dense",) + NAME_CONFIG = dict( + hidden_size="emb_dim", + intermediate_size="emb_dim", + num_hidden_layers="n_layers", + num_attention_heads="n_heads", + attention_head_size="attention_head_size", + ) \ No newline at end of file diff --git a/src/textpruner/pruners/vocabulary_pruner.py b/src/textpruner/pruners/vocabulary_pruner.py index 6d19d9c..4aa72a0 100644 --- a/src/textpruner/pruners/vocabulary_pruner.py +++ b/src/textpruner/pruners/vocabulary_pruner.py @@ -89,8 +89,11 @@ def prune(self, dataiter=None, additional_tokens=None, def save_model(self, dir_name = None) -> str: - - vocab_size = len(self.pruned_token_ids) + + if self.model_type.lower() in ['t5', 'mt5']: + vocab_size = self.base_model.shared.weight.shape[0] + else: + vocab_size = len(self.pruned_token_ids) self.base_model.config.vocab_size = vocab_size if dir_name is None: diff --git a/src/textpruner/tokenizer_utils/__init__.py b/src/textpruner/tokenizer_utils/__init__.py index 7129bc7..ea4e224 100644 --- a/src/textpruner/tokenizer_utils/__init__.py +++ b/src/textpruner/tokenizer_utils/__init__.py @@ -2,3 +2,6 @@ from .subword_tokenizer import SubwordTokenizer from .sp_tokenizer import SentencepieceTokenizer from .xlmr_sp_tokenizer import XLMRSentencepieceTokenizer +from .xlm_tokenizer import XLMTokenizer +from .t5_sp_tokenizer import T5SentencepieceTokenizer +from .mt5_sp_tokenizer import MT5SentencepieceTokenizer \ No newline at end of file diff --git a/src/textpruner/tokenizer_utils/mt5_sp_tokenizer.py b/src/textpruner/tokenizer_utils/mt5_sp_tokenizer.py new file mode 100644 index 0000000..4c8ec6c --- /dev/null +++ b/src/textpruner/tokenizer_utils/mt5_sp_tokenizer.py @@ -0,0 +1,67 @@ + +import os +import re +from .utils import count_unique_tokens +import logging +logger = logging.getLogger(__name__) +try: + from sentencepiece import sentencepiece_model_pb2 as sp_pb2_model +except ImportError: + logger.warning("Could not import sentencepiece. 
Pruning embeddings of sentencepiece-based model is not available.") + + +class MT5SentencepieceTokenizer: + additional_special_token_ids = [] + + @classmethod + def find_addition_special_token_ids(cls, tokenizer): + add_spe_bound = ['▁', '▁'] + lower, upper = tokenizer.convert_tokens_to_ids(add_spe_bound) + add_spe_tokens_ids_not_in_tokenizer = list(range(lower, upper+1)) + cls.additional_special_token_ids.extend(add_spe_tokens_ids_not_in_tokenizer) + cls.additional_special_token_ids = sorted(list(set(cls.additional_special_token_ids))) + + @classmethod + def get_token_ids(cls, tokenizer, dataiter=None, additional_tokens=None, additional_token_ids=None, min_count=1): + base_token_ids = list(range(3, 3+256)) + token_ids = [] + special_token_ids = list(tokenizer.all_special_ids) + cls.additional_special_token_ids = tokenizer.additional_special_tokens_ids + if len(cls.additional_special_token_ids) == 0: + cls.find_addition_special_token_ids(tokenizer) + special_token_ids.extend(cls.additional_special_token_ids) + special_token_ids = sorted(list(set(special_token_ids))) + + normal_token_ids = [] + if dataiter is not None: + token_ids_counter = count_unique_tokens(dataiter, tokenizer) + normal_token_ids += [k for k,v in token_ids_counter.items() if v >= min_count] + if additional_tokens is not None and len(additional_tokens) > 0: + normal_token_ids += list( + tokenizer.convert_tokens_to_ids(additional_tokens)) + if additional_token_ids is not None and len(additional_token_ids) > 0: + normal_token_ids += list(additional_token_ids) + normal_token_ids = list(set(normal_token_ids)-set(special_token_ids)) + token_ids = sorted(list(set(special_token_ids + normal_token_ids + base_token_ids))) + + return token_ids + + @classmethod + def save_vocab(cls, tokenizer, token_ids, outdir): + + spm_token_ids = token_ids + + spm_token_ids = sorted(spm_token_ids) + + m = sp_pb2_model.ModelProto() + m.ParseFromString(tokenizer.sp_model.serialized_model_proto()) + spm_tokens = set([m.pieces[i].piece for i in spm_token_ids]) + new_pieces = [p for p in m.pieces if p.piece in spm_tokens] + + del m.pieces[:] + m.pieces.extend(new_pieces) + + pruned_vocab_file = os.path.join(outdir, 'spiece.model') + with open(pruned_vocab_file, 'wb') as f: + f.write(m.SerializeToString()) + print(f"New embedding pruned vocab file has been saved to {pruned_vocab_file}. Reintialize the tokenizer!") \ No newline at end of file diff --git a/src/textpruner/tokenizer_utils/sp_tokenizer.py b/src/textpruner/tokenizer_utils/sp_tokenizer.py index 0f73031..4d17467 100644 --- a/src/textpruner/tokenizer_utils/sp_tokenizer.py +++ b/src/textpruner/tokenizer_utils/sp_tokenizer.py @@ -61,4 +61,4 @@ def save_vocab(tokenizer, token_ids, outdir): pruned_vocab_file = os.path.join(outdir, 'spiece.model') with open(pruned_vocab_file, 'wb') as f: f.write(m.SerializeToString()) - print(f"New embedding size {len(new_pieces)+2} pruned vocab file has been saved to {pruned_vocab_file}. Reintialize the tokenizer!") + print(f"New embedding size {len(new_pieces)+2} pruned vocab file has been saved to {pruned_vocab_file}. 
Reintialize the tokenizer!") \ No newline at end of file diff --git a/src/textpruner/tokenizer_utils/t5_sp_tokenizer.py b/src/textpruner/tokenizer_utils/t5_sp_tokenizer.py new file mode 100644 index 0000000..cc841ad --- /dev/null +++ b/src/textpruner/tokenizer_utils/t5_sp_tokenizer.py @@ -0,0 +1,58 @@ + +import os +import re +from .utils import count_unique_tokens +import logging +logger = logging.getLogger(__name__) +try: + from sentencepiece import sentencepiece_model_pb2 as sp_pb2_model +except ImportError: + logger.warning("Could not import sentencepiece. Pruning embeddings of sentencepiece-based model is not available.") + + +class T5SentencepieceTokenizer: + additional_special_token_ids = [] + + + + @classmethod + def get_token_ids(cls, tokenizer, dataiter=None, additional_tokens=None, additional_token_ids=None, min_count=1): + token_ids = [] + #special_token_ids = list(set(tokenizer.all_special_ids) - set(tokenizer.additional_special_tokens_ids)) + special_token_ids = list(tokenizer.all_special_ids) + cls.additional_special_token_ids = tokenizer.additional_special_tokens_ids + + + normal_token_ids = [] + if dataiter is not None: + token_ids_counter = count_unique_tokens(dataiter, tokenizer) + normal_token_ids += [k for k,v in token_ids_counter.items() if v >= min_count] + if additional_tokens is not None and len(additional_tokens) > 0: + normal_token_ids += list( + tokenizer.convert_tokens_to_ids(additional_tokens)) + if additional_token_ids is not None and len(additional_token_ids) > 0: + normal_token_ids += list(additional_token_ids) + normal_token_ids = list(set(normal_token_ids)-set(special_token_ids)) + token_ids = sorted(special_token_ids + normal_token_ids) + + return token_ids + + @classmethod + def save_vocab(cls, tokenizer, token_ids, outdir): + + + spm_token_ids = list(set(token_ids) - set(cls.additional_special_token_ids)) + m = sp_pb2_model.ModelProto() + m.ParseFromString(tokenizer.sp_model.serialized_model_proto()) + + spm_tokens = set([m.pieces[i].piece for i in spm_token_ids]) + new_pieces = [p for p in m.pieces if p.piece in spm_tokens] + + # delete all + del m.pieces[:] + m.pieces.extend(new_pieces) + + pruned_vocab_file = os.path.join(outdir, 'spiece.model') + with open(pruned_vocab_file, 'wb') as f: + f.write(m.SerializeToString()) + print(f"New embedding pruned vocab file has been saved to {pruned_vocab_file}. 
Reinitialize the tokenizer!") \ No newline at end of file diff --git a/src/textpruner/tokenizer_utils/xlm_tokenizer.py b/src/textpruner/tokenizer_utils/xlm_tokenizer.py new file mode 100644 index 0000000..3e70952 --- /dev/null +++ b/src/textpruner/tokenizer_utils/xlm_tokenizer.py @@ -0,0 +1,58 @@ +import os
+from .utils import count_unique_tokens
+import logging
+import json
+logger = logging.getLogger(__name__)
+
+class XLMTokenizer:
+    @staticmethod
+    def get_token_ids(tokenizer, dataiter=None, additional_tokens=None, additional_token_ids=None, min_count=1):
+        token_ids = []
+        # always keep the special tokens: ids 0-13 are reserved by XLM
+        # (<s>, </s>, <pad>, <unk> and <special0>-<special9>)
+        special_token_ids = list(range(0, 14))
+        normal_token_ids = []
+
+        if dataiter is not None:
+            token_ids_counter = count_unique_tokens(dataiter, tokenizer)
+            normal_token_ids += [k for k, v in token_ids_counter.items() if v >= min_count]
+        if additional_tokens is not None and len(additional_tokens) > 0:
+            normal_token_ids += list(
+                tokenizer.convert_tokens_to_ids(additional_tokens))
+        if additional_token_ids is not None and len(additional_token_ids) > 0:
+            normal_token_ids += list(additional_token_ids)
+
+        normal_token_ids = list(set(normal_token_ids) - set(special_token_ids))
+        token_ids = sorted(special_token_ids + normal_token_ids)
+        return token_ids
+
+    @staticmethod
+    def save_vocab(tokenizer, token_ids, outdir):
+        assert len(token_ids) == len(set(token_ids))
+
+        # map each kept token to its new (contiguous) id
+        tokens = tokenizer.convert_ids_to_tokens(token_ids)
+        token_dict = {}
+        for i, token in enumerate(tokens):
+            token_dict[token] = i
+
+        tokenizer.save_pretrained(outdir)
+        pruned_vocab_file = os.path.join(outdir, 'vocab.json')
+        with open(pruned_vocab_file, 'w', encoding='utf-8') as f:
+            json.dump(token_dict, f)
+
+        print(f"New embedding size {len(token_ids)}. Pruned vocab file has been saved to {pruned_vocab_file}. Reinitialize the tokenizer!")
+
+        # note: merges.txt is rewritten with the full (unpruned) BPE merge table
+        bpe_ranks = sorted(tokenizer.bpe_ranks.items(), key=lambda k: k[1])
+
+        pruned_merges_file = os.path.join(outdir, 'merges.txt')
+        with open(pruned_merges_file, "w", encoding="utf-8") as writer:
+            for bpe_tokens, _ in bpe_ranks:
+                if len(bpe_tokens) != 2:
+                    continue
+                writer.write(bpe_tokens[0] + " " + bpe_tokens[1] + "\n")
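Usage note (not part of the diff above): the sketch below shows how one of the newly supported model families might be pruned, assuming the VocabularyPruner interface described in the project README; the checkpoint name and sample texts are illustrative placeholders only.

    # A minimal sketch, assuming the VocabularyPruner API from the TextPruner README.
    # "google/mt5-small" and the sample texts are placeholders.
    from transformers import MT5ForConditionalGeneration, MT5Tokenizer
    from textpruner import VocabularyPruner

    texts = [
        "TextPruner now supports vocabulary pruning for XLM, BART, T5 and mT5.",
        "Only tokens that appear in the given texts are kept in the embedding matrix.",
    ]

    model = MT5ForConditionalGeneration.from_pretrained("google/mt5-small")
    tokenizer = MT5Tokenizer.from_pretrained("google/mt5-small")

    # Prune the shared/encoder/decoder embeddings; the pruned weights and the
    # reduced sentencepiece vocab (spiece.model) are written to the output directory.
    pruner = VocabularyPruner(model, tokenizer)
    pruner.prune(dataiter=texts, save_model=True)

The same call pattern should apply to the other newly mapped model types ('xlm', 'bart', 't5'), since the matching resizer, tokenizer helper and structure classes are looked up through model_map.py.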