Showing 10 changed files with 250 additions and 3 deletions.
@@ -0,0 +1,6 @@
Download [XNLI](https://github.com/facebookresearch/XNLI) and put `multinli.train.en.tsv`, `multinli.train.zh.tsv`, `xnli.dev.tsv`, and `xnli.test.tsv` here.

Concatenate the train files:
```bash
cat multinli.train.en.tsv multinli.train.zh.tsv > multinli.train.en_zh.tsv
```
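Optionally, sanity-check the concatenation with a minimal sketch (not part of the original instructions; it only assumes the files are plain UTF-8 text):

```python
# Line counts: the concatenated file should contain the lines of both inputs.
from pathlib import Path

def count_lines(path):
    with Path(path).open(encoding="utf-8") as f:
        return sum(1 for _ in f)

n_en = count_lines("multinli.train.en.tsv")
n_zh = count_lines("multinli.train.zh.tsv")
n_all = count_lines("multinli.train.en_zh.tsv")
assert n_all == n_en + n_zh, (n_en, n_zh, n_all)
```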
@@ -0,0 +1 @@
Put `pytorch_model.bin`, `config.json`, and `sentencepiece.bpe.model` here.
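To confirm the files are in place, a minimal sketch (it assumes this directory is the `../models/xlmr_xnli` path that the example scripts reference; adjust the path to your working directory):

```python
# Check that the fine-tuned checkpoint and tokenizer load from this directory.
from transformers import XLMRobertaForSequenceClassification, XLMRobertaTokenizer

model_dir = "../models/xlmr_xnli"  # adjust relative to where you run this
model = XLMRobertaForSequenceClassification.from_pretrained(model_dir)
tokenizer = XLMRobertaTokenizer.from_pretrained(model_dir)
print(model.config.num_labels)  # XNLI has 3 labels
```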
@@ -0,0 +1,25 @@
# Pruning the Classification Model

These scripts perform transformer pruning **in a self-supervised way** on the classification model (`XLMRobertaForSequenceClassification`) and evaluate its performance.

Download the fine-tuned model or train your own model on the XNLI dataset, and save the files to `../models/xlmr_xnli`.

Download link:
* [Hugging Face Models](https://huggingface.co/ziqingyang/XLMRobertaBaseForXNLI-en/tree/main)

See the README in `../datasets/xnli` for how to construct the dataset.

* Pruning with the Python script (a condensed sketch of the core calls it makes appears after this list):
```bash
MODEL_PATH=../models/xlmr_xnli
python transformer_pruning_selfsupervised.py $MODEL_PATH
```

* Evaluate the model:

Set `$PRUNED_MODEL_PATH` to the directory where the pruned model is stored.

```bash
cp $MODEL_PATH/sentencepiece.bpe.model $PRUNED_MODEL_PATH
python measure_performance.py $PRUNED_MODEL_PATH
```
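For reference, a condensed sketch of what `transformer_pruning_selfsupervised.py` does (the full script is included in this commit; the unlabeled dataloader comes from `classification_utils/dataloader_script_xnli`):

```python
# Condensed from transformer_pruning_selfsupervised.py: self-supervised
# transformer pruning driven by the model's own logits on unlabeled batches.
import sys, os
sys.path.insert(0, os.path.abspath('..'))  # so classification_utils is importable

from transformers import XLMRobertaForSequenceClassification
from textpruner import TransformerPruner, TransformerPruningConfig
from classification_utils.dataloader_script_xnli import dataloader

model = XLMRobertaForSequenceClassification.from_pretrained("../models/xlmr_xnli")

config = TransformerPruningConfig(
    target_ffn_size=1536,       # XLM-R base starts with FFN size 3072
    target_num_of_heads=6,      # XLM-R base starts with 12 heads per layer
    pruning_method='iterative', n_iters=8,
    use_logits=True,            # the adaptor below returns logits
    head_even_masking=False, ffn_even_masking=False)
pruner = TransformerPruner(model, transformer_pruning_config=config)
# The adaptor tells the pruner how to extract logits from the model outputs.
pruner.prune(dataloader=dataloader, adaptor=lambda out: out.logits, save_model=False)
```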
@@ -0,0 +1,29 @@
import logging
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(name)s - %(levelname)s - %(message)s')
logger = logging.getLogger(__name__)

from transformers import XLMRobertaForSequenceClassification, XLMRobertaTokenizer
import sys, os

sys.path.insert(0, os.path.abspath('..'))

from classification_utils.my_dataset import MultilingualNLIDataset
from classification_utils.predict_function import predict

# Evaluation settings
model_path = sys.argv[1]
taskname = 'xnli'
data_dir = '../datasets/xnli'
split = 'test'
max_seq_length = 128
eval_langs = ['en', 'zh']
batch_size = 32
device = 'cuda'

# Load the (pruned) model and re-initialize the tokenizer from the same directory
model = XLMRobertaForSequenceClassification.from_pretrained(model_path).to(device)
tokenizer = XLMRobertaTokenizer.from_pretrained(model_path)
eval_dataset = MultilingualNLIDataset(
    task=taskname, data_dir=data_dir, split=split, prefix='xlmr',
    max_seq_length=max_seq_length, langs=eval_langs, tokenizer=tokenizer)
eval_datasets = [eval_dataset.lang_datasets[lang] for lang in eval_langs]
predict(model, eval_datasets, eval_langs, device, batch_size)
examples/transformer_pruning_xnli/transformer_pruning_selfsupervised.py (67 additions and 0 deletions)
@@ -0,0 +1,67 @@
import logging
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(name)s - %(levelname)s - %(message)s')
logger = logging.getLogger(__name__)

from transformers import XLMRobertaForSequenceClassification, XLMRobertaTokenizer
from textpruner import summary, TransformerPruner, TransformerPruningConfig
import sys, os

sys.path.insert(0, os.path.abspath('..'))

from classification_utils.dataloader_script_xnli import dataloader, eval_langs, batch_size, MultilingualNLIDataset
from classification_utils.predict_function import predict

model_path = sys.argv[1]
model = XLMRobertaForSequenceClassification.from_pretrained(model_path)
tokenizer = XLMRobertaTokenizer.from_pretrained(model_path)

print("Before pruning:")
print(summary(model))

def adaptor(model_outputs):
    logits = model_outputs.logits
    return logits  # entropy(logits)


transformer_pruning_config = TransformerPruningConfig(
    target_ffn_size=1536, target_num_of_heads=6,
    pruning_method='iterative', n_iters=8, use_logits=True,
    head_even_masking=False, ffn_even_masking=False)
pruner = TransformerPruner(model, transformer_pruning_config=transformer_pruning_config)
pruner.prune(dataloader=dataloader, save_model=False, adaptor=adaptor)

# save the tokenizer to the same place
#tokenizer.save_pretrained(pruner.save_dir)

print("After pruning:")
print(summary(model))

# Show the pruned FFN and attention weight shapes of each layer
for i in range(12):
    print((model.base_model.encoder.layer[i].intermediate.dense.weight.shape,
           model.base_model.encoder.layer[i].intermediate.dense.bias.shape,
           model.base_model.encoder.layer[i].attention.self.key.weight.shape))


print("Measure performance")
taskname = 'xnli'
data_dir = '../datasets/xnli'
split = 'dev'
max_seq_length = 128
eval_langs = ['en', 'zh']
batch_size = 32
device = model.device
eval_dataset = MultilingualNLIDataset(
    task=taskname, data_dir=data_dir, split=split, prefix='xlmr',
    max_seq_length=max_seq_length, langs=eval_langs, tokenizer=tokenizer)
eval_datasets = [eval_dataset.lang_datasets[lang] for lang in eval_langs]
print("dev")
predict(model, eval_datasets, eval_langs, device, batch_size)

split = "test"
print("test")
eval_dataset = MultilingualNLIDataset(
    task=taskname, data_dir=data_dir, split=split, prefix='xlmr',
    max_seq_length=max_seq_length, langs=eval_langs, tokenizer=tokenizer)
eval_datasets = [eval_dataset.lang_datasets[lang] for lang in eval_langs]
predict(model, eval_datasets, eval_langs, device, batch_size)

print(transformer_pruning_config)
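The script above keeps the pruned model in memory (`save_model=False`). A hedged sketch of persisting it for `measure_performance.py`: call `prune` with `save_model=True` (assuming this populates `pruner.save_dir`, as `vocabulary_pruning.py` relies on for the vocabulary pruner) and place the sentencepiece model next to the saved weights, mirroring the README's `cp` step:

```python
# Sketch, not part of the original script: assumes prune(..., save_model=True)
# was used so that pruner.save_dir points at the pruned checkpoint directory.
import os, shutil

tokenizer.save_pretrained(pruner.save_dir)  # mirrors the commented-out line above
# Equivalent to the README step:
#   cp $MODEL_PATH/sentencepiece.bpe.model $PRUNED_MODEL_PATH
shutil.copy(os.path.join(model_path, 'sentencepiece.bpe.model'), pruner.save_dir)
```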
@@ -0,0 +1,27 @@
# Pruning the Classification Model

These scripts perform vocabulary pruning on the classification model (`XLMRobertaForSequenceClassification`) and evaluate its performance.

We use the concatenated English and Chinese training set as the vocabulary file.

Download the fine-tuned model or train your own model on the XNLI dataset, and save the files to `../models/xlmr_xnli`.

Download link:
* [Hugging Face Models](https://huggingface.co/ziqingyang/XLMRobertaBaseForXNLI-en/tree/main)

See the README in `../datasets/xnli` for how to construct the dataset.

* Pruning with the Python script (a condensed sketch of the core calls it makes appears after this list):
```bash
VOCABULARY_FILE=../datasets/xnli/multinli.train.en_zh.tsv
MODEL_PATH=../models/xlmr_xnli
python vocabulary_pruning.py $MODEL_PATH $VOCABULARY_FILE
```

* Evaluate the model:

Set `$PRUNED_MODEL_PATH` to the directory where the pruned model is stored.

```bash
python measure_performance.py $PRUNED_MODEL_PATH
```
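For reference, a condensed sketch of what `vocabulary_pruning.py` does (the full script is included in this commit):

```python
# Condensed from vocabulary_pruning.py: drop vocabulary entries that never
# appear in the given text file, shrinking the embedding matrix.
from transformers import XLMRobertaForSequenceClassification, XLMRobertaTokenizer
from textpruner import VocabularyPruner
from textpruner.commands.utils import read_file_line_by_line

model = XLMRobertaForSequenceClassification.from_pretrained("../models/xlmr_xnli")
tokenizer = XLMRobertaTokenizer.from_pretrained("../models/xlmr_xnli")
texts, _ = read_file_line_by_line("../datasets/xnli/multinli.train.en_zh.tsv")

pruner = VocabularyPruner(model, tokenizer)
pruner.prune(dataiter=texts, save_model=True)  # pruned model written to pruner.save_dir
```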
@@ -0,0 +1,29 @@
import logging
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(name)s - %(levelname)s - %(message)s')
logger = logging.getLogger(__name__)

from transformers import XLMRobertaForSequenceClassification, XLMRobertaTokenizer
import sys, os

sys.path.insert(0, os.path.abspath('..'))

from classification_utils.my_dataset import MultilingualNLIDataset
from classification_utils.predict_function import predict

# Evaluation settings
model_path = sys.argv[1]
taskname = 'xnli'
data_dir = '../datasets/xnli'
split = 'test'
max_seq_length = 128
eval_langs = ['en']
batch_size = 32
device = 'cuda'

# Load the (pruned) model and re-initialize the tokenizer from the same directory
model = XLMRobertaForSequenceClassification.from_pretrained(model_path).to(device)
tokenizer = XLMRobertaTokenizer.from_pretrained(model_path)
eval_dataset = MultilingualNLIDataset(
    task=taskname, data_dir=data_dir, split=split, prefix='xlmr',
    max_seq_length=max_seq_length, langs=eval_langs, tokenizer=tokenizer)
eval_datasets = [eval_dataset.lang_datasets[lang] for lang in eval_langs]
predict(model, eval_datasets, eval_langs, device, batch_size)
@@ -0,0 +1,56 @@
import logging
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(name)s - %(levelname)s - %(message)s')
logger = logging.getLogger(__name__)
from transformers import XLMRobertaForSequenceClassification, XLMRobertaTokenizer
from textpruner import summary, VocabularyPruner
from textpruner.commands.utils import read_file_line_by_line
import sys, os

sys.path.insert(0, os.path.abspath('..'))
from classification_utils.my_dataset import MultilingualNLIDataset
from classification_utils.predict_function import predict

# Initialize your model and load data
model_path = sys.argv[1]
vocabulary = sys.argv[2]
model = XLMRobertaForSequenceClassification.from_pretrained(model_path)
tokenizer = XLMRobertaTokenizer.from_pretrained(model_path)
texts, _ = read_file_line_by_line(vocabulary)

print("Before pruning:")
print(summary(model))

pruner = VocabularyPruner(model, tokenizer)
pruner.prune(dataiter=texts, save_model=True)

print("After pruning:")
print(summary(model))


print("Measure performance")

taskname = 'xnli'
data_dir = '../datasets/xnli'
split = 'dev'
max_seq_length = 128
eval_langs = ['zh', 'en']
batch_size = 32
device = model.device

# Re-initialize the tokenizer from the pruned model directory
tokenizer = XLMRobertaTokenizer.from_pretrained(pruner.save_dir)
eval_dataset = MultilingualNLIDataset(
    task=taskname, data_dir=data_dir, split=split, prefix='xlmr',
    max_seq_length=max_seq_length, langs=eval_langs, tokenizer=tokenizer)
eval_datasets = [eval_dataset.lang_datasets[lang] for lang in eval_langs]
print("dev")
predict(model, eval_datasets, eval_langs, device, batch_size)

split = "test"
print("test")
eval_dataset = MultilingualNLIDataset(
    task=taskname, data_dir=data_dir, split=split, prefix='xlmr',
    max_seq_length=max_seq_length, langs=eval_langs, tokenizer=tokenizer)
eval_datasets = [eval_dataset.lang_datasets[lang] for lang in eval_langs]

predict(model, eval_datasets, eval_langs, device, batch_size)