
Commit

Merge branch 'main' of https://github.com/airaria/TextPruner
airaria committed Mar 21, 2022
2 parents 786a201 + 0952aaa commit 2c38776
Showing 14 changed files with 482 additions and 7 deletions.
4 changes: 3 additions & 1 deletion README.md
@@ -34,7 +34,9 @@ You may also be interested in,

## News

* Jan 26, 2022 (new functionality in v1.0.1) Added support for self-supervised pruning via `use_logits` option in `TransformerPruningConfig`.
* [Mar 4, 2022] We are delighted to announce that TextPruner has been accepted to [ACL 2022 demo](https://2022.aclweb.org). The paper will be available when we finish the camera-ready version.

* [Jan 26, 2022] (new functionality in v1.0.1) Added support for self-supervised pruning via `use_logits` option in `TransformerPruningConfig`.

## Table of Contents

3 changes: 2 additions & 1 deletion README_ZH.md
@@ -33,8 +33,9 @@

## News

* [Mar 4, 2022] The TextPruner paper has been accepted by the [ACL 2022 demo](https://2022.aclweb.org) track. The paper will be released once the camera-ready version is finished.

* Jan 26, 2022 (new functionality in v1.0.1) Added support for self-supervised pruning, configured via the `use_logits` option in `TransformerPruningConfig`.
* [Jan 26, 2022] (new functionality in v1.0.1) Added support for self-supervised pruning, configured via the `use_logits` option in `TransformerPruningConfig`.

## Table of Contents

18 changes: 17 additions & 1 deletion src/textpruner/model_map.py
@@ -21,5 +21,21 @@
    'xlm-roberta':
        {'resizer': model_utils.XLMRobertaVocabResizer,
         'tokenizer_helper': tokenizer_utils.XLMRSentencepieceTokenizer,
         'structure': model_utils.XLMRobertaStructure}
         'structure': model_utils.XLMRobertaStructure},
    'xlm':
        {'resizer': model_utils.XLMVocabResizer,
         'tokenizer_helper': tokenizer_utils.XLMTokenizer,
         'structure': model_utils.XLMStructure},
    'bart':
        {'resizer': model_utils.BartVocabResizer,
         'tokenizer_helper': tokenizer_utils.RobertaGPT2Tokenizer,
         'structure': model_utils.BartStructure},
    't5':
        {'resizer': model_utils.T5VocabResizer,
         'tokenizer_helper': tokenizer_utils.T5SentencepieceTokenizer,
         'structure': model_utils.T5Structure},
    'mt5':
        {'resizer': model_utils.MT5VocabResizer,
         'tokenizer_helper': tokenizer_utils.MT5SentencepieceTokenizer,
         'structure': model_utils.MT5Structure},
}
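
With these new entries, the xlm, bart, t5, and mt5 model types resolve to their own resizer, tokenizer helper, and structure description, so they can go through the same vocabulary-pruning path as the existing models. A minimal usage sketch, assuming the `VocabularyPruner` interface shown in the project README; the checkpoint name and texts are placeholders:

```python
from transformers import T5ForConditionalGeneration, T5TokenizerFast
from textpruner import VocabularyPruner

# Placeholders: any of the newly mapped types (xlm, bart, t5, mt5) is resolved
# through the mapping in model_map.py in the same way.
model = T5ForConditionalGeneration.from_pretrained("t5-small")
tokenizer = T5TokenizerFast.from_pretrained("t5-small")

texts = ["translate English to German: The house is wonderful.",
         "summarize: TextPruner removes vocabulary entries unused by the target corpus."]

pruner = VocabularyPruner(model, tokenizer)
pruner.prune(dataiter=texts, save_model=True)  # keep only tokens that appear in texts
```
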
6 changes: 5 additions & 1 deletion src/textpruner/model_utils/__init__.py
@@ -2,4 +2,8 @@
from .bert import BertVocabResizer, BertStructure
from .electra import ElectraVocabResizer, ElectraStructure
from .roberta import RobertaVocabResizer, RobertaStructure
from .xlm_roberta import XLMRobertaVocabResizer, XLMRobertaStructure
from .xlm_roberta import XLMRobertaVocabResizer, XLMRobertaStructure
from .xlm import XLMStructure, XLMVocabResizer
from .bart import BartVocabResizer, BartStructure
from .t5 import T5VocabResizer, T5Structure
from .mt5 import MT5Structure, MT5VocabResizer
62 changes: 62 additions & 0 deletions src/textpruner/model_utils/bart.py
@@ -0,0 +1,62 @@
from .utils import DefaultModelVocabResizer
from .model_structure import ModelStructure
import torch
from torch import nn


class BartVocabResizer(DefaultModelVocabResizer):
    model_name: str = 'bart'

    @classmethod
    def set_embeddings(cls, model, token_ids):
        # Keep only the embedding rows whose ids are in token_ids.
        def _prun(old_weight, token_ids):
            pruned_word_embeddings_weight = torch.index_select(
                old_weight, 0, index=torch.LongTensor(token_ids).to(old_weight.device))
            return pruned_word_embeddings_weight

        old_word_embeddings_shared, old_word_embeddings_encoder, old_word_embeddings_decoder = \
            model.shared, model.encoder.embed_tokens, model.decoder.embed_tokens

        old_word_embeddings_shared_weight, old_word_embeddings_encoder_weight, old_word_embeddings_decoder_weight = \
            old_word_embeddings_shared.weight, old_word_embeddings_encoder.weight, old_word_embeddings_decoder.weight

        pruned_word_embeddings_shared_weight, pruned_word_embeddings_encoder_weight, pruned_word_embeddings_decoder_weight = \
            _prun(old_word_embeddings_shared_weight, token_ids), _prun(old_word_embeddings_encoder_weight, token_ids), _prun(old_word_embeddings_decoder_weight, token_ids)

        pruned_num_tokens, embedding_dim = pruned_word_embeddings_shared_weight.shape

        # Build smaller embedding modules and copy the selected rows into them.
        pruned_word_embeddings_shared = nn.Embedding(
            pruned_num_tokens, embedding_dim).to(old_word_embeddings_shared_weight.device)
        pruned_word_embeddings_shared.weight.data[:] = pruned_word_embeddings_shared_weight[:]

        pruned_word_embeddings_encoder = nn.Embedding(
            pruned_num_tokens, embedding_dim).to(old_word_embeddings_shared_weight.device)
        pruned_word_embeddings_encoder.weight.data[:] = pruned_word_embeddings_encoder_weight[:]

        pruned_word_embeddings_decoder = nn.Embedding(
            pruned_num_tokens, embedding_dim).to(old_word_embeddings_shared_weight.device)
        pruned_word_embeddings_decoder.weight.data[:] = pruned_word_embeddings_decoder_weight[:]

        # Swap the shared, encoder, and decoder embeddings for the pruned versions.
        model.shared = pruned_word_embeddings_shared
        model.encoder.embed_tokens = pruned_word_embeddings_encoder
        model.decoder.embed_tokens = pruned_word_embeddings_decoder


class BartStructure(ModelStructure):
    MODEL_PREFIX: str = "model."
    ENCODER_PREFIX: str = r"encoder.layers.[0-9]+\."
    LAYER_PATTERNS = dict(
        query="self_attn.q_proj",
        key="self_attn.k_proj",
        value="self_attn.v_proj",
        att_dense="self_attn.out_proj",
        interm_dense="fc1",
        output_dense="fc2",
    )
    ATTENTION_PREFIX = ("self_attn",)
    ATTENTION_LAYERS = ("q_proj", "k_proj", "v_proj")
    MHA_LAYERS = ATTENTION_LAYERS + ("att_dense",)
    NAME_CONFIG = dict(
        hidden_size="d_model",
        intermediate_size="encoder_ffn_dim",
        num_hidden_layers="encoder_layers",
        num_attention_heads="num_attention_heads",
        attention_head_size="",
    )
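
The resizer above is built around a single tensor operation. A toy sketch (sizes and ids are illustrative only) of how `torch.index_select` keeps the chosen embedding rows:

```python
import torch
from torch import nn

old = nn.Embedding(6, 4)       # toy embedding: 6 tokens, dimension 4
keep_ids = [0, 2, 5]           # token ids retained after vocabulary pruning

pruned_weight = torch.index_select(old.weight, 0, torch.LongTensor(keep_ids))
new = nn.Embedding(len(keep_ids), 4)
new.weight.data[:] = pruned_weight   # rows 0, 2 and 5 of `old`, in that order

assert new.weight.shape == (3, 4)
```
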
72 changes: 72 additions & 0 deletions src/textpruner/model_utils/mt5.py
@@ -0,0 +1,72 @@
from .utils import DefaultModelVocabResizer
from .model_structure import ModelStructure
import torch
from torch import nn


class MT5VocabResizer(DefaultModelVocabResizer):
    model_name: str = 'mt5'

    @classmethod
    def set_embeddings(cls, model, token_ids):
        # Keep only the embedding rows whose ids are in token_ids.
        def _prun(old_weight, token_ids):
            pruned_word_embeddings_weight = torch.index_select(
                old_weight, 0, index=torch.LongTensor(token_ids).to(old_weight.device))
            return pruned_word_embeddings_weight

        # The mT5 embedding matrix has more rows than the tokenizer vocabulary;
        # also keep every row above the largest kept id so those trailing rows survive.
        vocab_size = model.shared.weight.shape[0]
        max_token_ids = token_ids[-1]
        tokens_in_embed_notin_tokenizer_ids = list(range(max_token_ids + 1, vocab_size))
        token_ids_temp = token_ids[:]
        token_ids_temp.extend(tokens_in_embed_notin_tokenizer_ids)

        model.config.vocab_size = len(token_ids_temp)

        old_word_embeddings_shared, old_word_embeddings_encoder, old_word_embeddings_decoder = \
            model.shared, model.encoder.embed_tokens, model.decoder.embed_tokens

        old_word_embeddings_shared_weight, old_word_embeddings_encoder_weight, old_word_embeddings_decoder_weight = \
            old_word_embeddings_shared.weight, old_word_embeddings_encoder.weight, old_word_embeddings_decoder.weight

        pruned_word_embeddings_shared_weight, pruned_word_embeddings_encoder_weight, pruned_word_embeddings_decoder_weight = \
            _prun(old_word_embeddings_shared_weight, token_ids_temp), _prun(old_word_embeddings_encoder_weight, token_ids_temp), _prun(old_word_embeddings_decoder_weight, token_ids_temp)

        pruned_num_tokens, embedding_dim = pruned_word_embeddings_shared_weight.shape

        pruned_word_embeddings_shared = nn.Embedding(
            pruned_num_tokens, embedding_dim).to(old_word_embeddings_shared_weight.device)
        pruned_word_embeddings_shared.weight.data[:] = pruned_word_embeddings_shared_weight[:]

        pruned_word_embeddings_encoder = nn.Embedding(
            pruned_num_tokens, embedding_dim).to(old_word_embeddings_shared_weight.device)
        pruned_word_embeddings_encoder.weight.data[:] = pruned_word_embeddings_encoder_weight[:]

        pruned_word_embeddings_decoder = nn.Embedding(
            pruned_num_tokens, embedding_dim).to(old_word_embeddings_shared_weight.device)
        pruned_word_embeddings_decoder.weight.data[:] = pruned_word_embeddings_decoder_weight[:]

        # Swap the shared, encoder, and decoder embeddings for the pruned versions.
        model.shared = pruned_word_embeddings_shared
        model.encoder.embed_tokens = pruned_word_embeddings_encoder
        model.decoder.embed_tokens = pruned_word_embeddings_decoder


class MT5Structure(ModelStructure):
    MODEL_PREFIX: str = "transformer."
    ENCODER_PREFIX: str = r"encoder.block.[0-9]+\.layer."
    LAYER_PATTERNS = dict(
        query="0.SelfAttention.q",
        key="0.SelfAttention.k",
        value="0.SelfAttention.v",
        att_dense="0.SelfAttention.o",
        interm_dense="1.DenseReluDense.wi",
        output_dense="1.DenseReluDense.wo",
    )
    ATTENTION_PREFIX = ("0.SelfAttention",)
    ATTENTION_LAYERS = ("q", "k", "v")
    MHA_LAYERS = ATTENTION_LAYERS + ("att_dense",)
    NAME_CONFIG = dict(
        hidden_size="d_model",
        intermediate_size="d_ff",
        num_hidden_layers="num_layers",
        num_attention_heads="num_heads",
        attention_head_size="",
    )
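
The id bookkeeping at the top of `set_embeddings` can be traced on a toy example (numbers are illustrative, not real mT5 sizes): every embedding row above the largest kept tokenizer id is appended to the keep-list, so it survives even though the tokenizer never produces it.

```python
# Toy illustration of the token_ids extension used by MT5VocabResizer/T5VocabResizer.
vocab_size = 10                 # rows in the (toy) shared embedding matrix
token_ids = [0, 1, 5, 7]        # sorted ids kept after scanning the corpus

max_token_ids = token_ids[-1]                             # 7
extra_rows = list(range(max_token_ids + 1, vocab_size))   # [8, 9]
token_ids_temp = token_ids[:] + extra_rows

assert token_ids_temp == [0, 1, 5, 7, 8, 9]
# config.vocab_size becomes len(token_ids_temp) == 6, not len(token_ids) == 4.
```
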
70 changes: 70 additions & 0 deletions src/textpruner/model_utils/t5.py
@@ -0,0 +1,70 @@
from .utils import DefaultModelVocabResizer
from .model_structure import ModelStructure
import torch
from torch import nn


class T5VocabResizer(DefaultModelVocabResizer):
    model_name: str = 't5'

    @classmethod
    def set_embeddings(cls, model, token_ids):
        # Keep only the embedding rows whose ids are in token_ids.
        def _prun(old_weight, token_ids):
            pruned_word_embeddings_weight = torch.index_select(
                old_weight, 0, index=torch.LongTensor(token_ids).to(old_weight.device))
            return pruned_word_embeddings_weight

        # The T5 embedding matrix has more rows than the tokenizer vocabulary;
        # also keep every row above the largest kept id so those trailing rows survive.
        vocab_size = model.shared.weight.shape[0]
        max_token_ids = token_ids[-1]
        tokens_in_embed_notin_tokenizer_ids = list(range(max_token_ids + 1, vocab_size))
        token_ids_temp = token_ids[:]
        token_ids_temp.extend(tokens_in_embed_notin_tokenizer_ids)

        model.config.vocab_size = len(token_ids_temp)

        old_word_embeddings_shared, old_word_embeddings_encoder, old_word_embeddings_decoder = \
            model.shared, model.encoder.embed_tokens, model.decoder.embed_tokens

        old_word_embeddings_shared_weight, old_word_embeddings_encoder_weight, old_word_embeddings_decoder_weight = \
            old_word_embeddings_shared.weight, old_word_embeddings_encoder.weight, old_word_embeddings_decoder.weight

        pruned_word_embeddings_shared_weight, pruned_word_embeddings_encoder_weight, pruned_word_embeddings_decoder_weight = \
            _prun(old_word_embeddings_shared_weight, token_ids_temp), _prun(old_word_embeddings_encoder_weight, token_ids_temp), _prun(old_word_embeddings_decoder_weight, token_ids_temp)

        pruned_num_tokens, embedding_dim = pruned_word_embeddings_shared_weight.shape

        pruned_word_embeddings_shared = nn.Embedding(
            pruned_num_tokens, embedding_dim).to(old_word_embeddings_shared_weight.device)
        pruned_word_embeddings_shared.weight.data[:] = pruned_word_embeddings_shared_weight[:]

        pruned_word_embeddings_encoder = nn.Embedding(
            pruned_num_tokens, embedding_dim).to(old_word_embeddings_shared_weight.device)
        pruned_word_embeddings_encoder.weight.data[:] = pruned_word_embeddings_encoder_weight[:]

        pruned_word_embeddings_decoder = nn.Embedding(
            pruned_num_tokens, embedding_dim).to(old_word_embeddings_shared_weight.device)
        pruned_word_embeddings_decoder.weight.data[:] = pruned_word_embeddings_decoder_weight[:]

        # Swap the shared, encoder, and decoder embeddings for the pruned versions.
        model.shared = pruned_word_embeddings_shared
        model.encoder.embed_tokens = pruned_word_embeddings_encoder
        model.decoder.embed_tokens = pruned_word_embeddings_decoder


class T5Structure(ModelStructure):
    MODEL_PREFIX: str = "transformer."
    ENCODER_PREFIX: str = r"encoder.block.[0-9]+\.layer."
    LAYER_PATTERNS = dict(
        query="0.SelfAttention.q",
        key="0.SelfAttention.k",
        value="0.SelfAttention.v",
        att_dense="0.SelfAttention.o",
        interm_dense="1.DenseReluDense.wi",
        output_dense="1.DenseReluDense.wo",
    )
    ATTENTION_PREFIX = ("0.SelfAttention",)
    ATTENTION_LAYERS = ("q", "k", "v")
    MHA_LAYERS = ATTENTION_LAYERS + ("att_dense",)
    NAME_CONFIG = dict(
        hidden_size="d_model",
        intermediate_size="d_ff",
        num_hidden_layers="num_layers",
        num_attention_heads="num_heads",
        attention_head_size="",
    )
59 changes: 59 additions & 0 deletions src/textpruner/model_utils/xlm.py
@@ -0,0 +1,59 @@
from .utils import DefaultModelVocabResizer
from .model_structure import ModelStructure
import torch
from torch import nn


class XLMVocabResizer(DefaultModelVocabResizer):
    model_name: str = 'xlm'

    @classmethod
    def set_embeddings(cls, model, token_ids):
        # XLM exposes its word embeddings either as model.embeddings.word_embeddings
        # or directly as model.embeddings, depending on the model class.
        if hasattr(model.embeddings, 'word_embeddings'):
            old_word_embeddings = model.embeddings.word_embeddings
        else:
            old_word_embeddings = model.embeddings

        old_word_embeddings_weight = old_word_embeddings.weight

        # Keep only the embedding rows whose ids are in token_ids.
        pruned_word_embeddings_weight = torch.index_select(
            old_word_embeddings_weight, 0, index=torch.LongTensor(token_ids).to(old_word_embeddings_weight.device))
        pruned_num_tokens, embedding_dim = pruned_word_embeddings_weight.shape

        pruned_word_embeddings = nn.Embedding(
            pruned_num_tokens, embedding_dim).to(old_word_embeddings_weight.device)
        pruned_word_embeddings.weight.data[:] = pruned_word_embeddings_weight[:]

        # Put the pruned embedding back in the same place it was taken from.
        if hasattr(model.embeddings, 'word_embeddings'):
            model.embeddings.word_embeddings = pruned_word_embeddings
        else:
            model.embeddings = pruned_word_embeddings


class XLMStructure(ModelStructure):
    MODEL_PREFIX: str = "transformer."
    ENCODER_PREFIX: str = r"attention.[0-9]+\."
    LAYER_PATTERNS = dict(
        query=r"attentions\.[0-9]+\.q_lin",
        key=r"attentions\.[0-9]+\.k_lin",
        value=r"attentions\.[0-9]+\.v_lin",
        att_dense=r"attentions\.[0-9]+\.out_lin",
        interm_dense=r"ffns\.[0-9]+\.lin1",
        output_dense=r"ffns\.[0-9]+\.lin2",
    )
    ATTENTION_PREFIX = (r"attentions\.[0-9]",)
    ATTENTION_LAYERS = ("q_lin", "k_lin", "v_lin")
    MHA_LAYERS = ATTENTION_LAYERS + ("att_dense",)
    NAME_CONFIG = dict(
        hidden_size="emb_dim",
        intermediate_size="emb_dim",
        num_hidden_layers="n_layers",
        num_attention_heads="n_heads",
        attention_head_size="attention_head_size",
    )
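
Unlike the BART and T5 structures, `XLMStructure` uses regular expressions because XLM stores its attention and feed-forward sublayers in flat module lists (`attentions`, `ffns`) rather than per-layer blocks. A small sketch of how the `query` pattern is expected to match a Hugging Face XLM parameter name (the example name assumes the `attentions.<i>.q_lin` layout):

```python
import re

query_pattern = r"attentions\.[0-9]+\.q_lin"
param_name = "transformer.attentions.3.q_lin.weight"   # assumed XLM parameter name

match = re.search(query_pattern, param_name)
assert match is not None and match.group(0) == "attentions.3.q_lin"
```
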
7 changes: 5 additions & 2 deletions src/textpruner/pruners/vocabulary_pruner.py
@@ -89,8 +89,11 @@ def prune(self, dataiter=None, additional_tokens=None,


    def save_model(self, dir_name = None) -> str:

        vocab_size = len(self.pruned_token_ids)

        if self.model_type.lower() in ['t5', 'mt5']:
            # T5/mT5: the pruned shared embedding keeps trailing rows that the
            # tokenizer never produces, so take vocab_size from the embedding itself.
            vocab_size = self.base_model.shared.weight.shape[0]
        else:
            vocab_size = len(self.pruned_token_ids)
        self.base_model.config.vocab_size = vocab_size

        if dir_name is None:
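
A hedged follow-up to the T5 sketch shown after the model_map.py changes (same placeholder `pruner` and `model` objects), illustrating what the branch above implies for the saved configuration:

```python
# For T5/mT5 the saved config.vocab_size follows the pruned shared embedding, which
# keeps trailing rows beyond the largest kept tokenizer id, so it can exceed the
# number of pruned tokenizer ids. The output directory name is a placeholder.
saved_dir = pruner.save_model("pruned_t5")
assert model.config.vocab_size == model.shared.weight.shape[0]
assert model.config.vocab_size >= len(pruner.pruned_token_ids)
```
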
3 changes: 3 additions & 0 deletions src/textpruner/tokenizer_utils/__init__.py
@@ -2,3 +2,6 @@
from .subword_tokenizer import SubwordTokenizer
from .sp_tokenizer import SentencepieceTokenizer
from .xlmr_sp_tokenizer import XLMRSentencepieceTokenizer
from .xlm_tokenizer import XLMTokenizer
from .t5_sp_tokenizer import T5SentencepieceTokenizer
from .mt5_sp_tokenizer import MT5SentencepieceTokenizer
