This repo contains code for the paper Dynamic Context Pruning for Efficient and Interpretable Autoregressive Transformers.
Adaptively Sparse Attention adapts autoregressive Transformers, giving them the ability to drop irrelevant tokens in the sequence.
We can thus enhance interpretability by highlighting the most relevant tokens in the sequence, while also reducing the computational cost of the model during inference.
An example of the dropping mechanism is shown below:
Tokens are generated one by one and arrows indicate the previous tokens dropped by the model.
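The dropping behaviour described above can be illustrated with a toy sketch. This is not the code in this repo: the model learns its drop decisions through the sparse attention mechanism, whereas the sketch below uses a fixed threshold on made-up scores. The function name `simulate_dropping`, the `keep_scores` matrix and the `threshold` parameter are hypothetical, and the sketch assumes that a token, once dropped, stays dropped for all later generation steps.

```python
import numpy as np


def simulate_dropping(keep_scores, threshold=0.0):
    """Toy illustration of context pruning (hypothetical helper, not the repo's API).

    keep_scores[i, j] is a made-up score for keeping token j in the context
    while generating token i (only entries with j <= i are used).
    Returns a boolean matrix alive[i, j]: True if token j is still in the
    context at step i.
    """
    n = keep_scores.shape[0]
    causal = np.tril(np.ones((n, n), dtype=bool))

    # Per-step decision: keep a token if its score exceeds the threshold.
    # Entries above the diagonal are set to True so they do not affect
    # the cumulative product below.
    keep = (keep_scores > threshold) | ~causal

    # The current token is always part of its own context.
    np.fill_diagonal(keep, True)

    # Assumption: dropping is permanent, so token j is alive at step i
    # only if it was kept at every step from j to i.
    alive = np.cumprod(keep, axis=0).astype(bool) & causal
    return alive


scores = np.array([
    [1.0, 0.0, 0.0],
    [-1.0, 1.0, 0.0],   # token 0 is dropped while generating token 1
    [1.0, 1.0, 1.0],    # a positive score later does not bring it back
])
alive = simulate_dropping(scores)
context_sizes = alive.sum(axis=1)  # number of tokens attended to at each step
```

Each row of `alive` is the context available at that generation step, so the row sums show how the effective context shrinks as tokens are dropped.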
To install the requirements, run:
pip install -r requirements.txt
To finetune a pretrained model, run:
python train.py --log_dir gpt2-sparse --batch_size 6 --entmax_weight_scheduler cosine_1_8_25000 --sparse_attention true --pretrained gpt2 --int_n_embd 64
This will generate a folder gpt2-sparse with tensorboard logs under the logs folder. Try python train.py --help for more options.
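The value cosine_1_8_25000 passed to --entmax_weight_scheduler is not explained here. One plausible reading, which is an assumption and not documented behaviour, is a cosine schedule that moves a weight from 1 to 8 over 25000 steps. The sketch below shows what such a schedule would compute; `parse_schedule` and `cosine_value` are hypothetical helpers and are not part of train.py. Check python train.py --help for the actual semantics.

```python
import math


def parse_schedule(spec):
    """Split a spec such as 'cosine_1_8_25000' into (start, end, total_steps).

    The '<kind>_<start>_<end>_<steps>' layout is an assumed interpretation.
    """
    kind, start, end, steps = spec.split("_")
    if kind != "cosine":
        raise ValueError(f"unsupported schedule kind: {kind}")
    return float(start), float(end), int(steps)


def cosine_value(step, start, end, total_steps):
    """Cosine interpolation from start (step 0) to end (step total_steps)."""
    # Clamp so the value stays at `end` once the schedule has finished.
    t = min(max(step, 0), total_steps) / total_steps
    return start + (end - start) * 0.5 * (1.0 - math.cos(math.pi * t))


start, end, total_steps = parse_schedule("cosine_1_8_25000")
first = cosine_value(0, start, end, total_steps)
middle = cosine_value(total_steps // 2, start, end, total_steps)
last = cosine_value(total_steps, start, end, total_steps)
```

Under this reading the weight changes slowly at the beginning and end of the schedule and fastest around the midpoint.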
To generate text from a finetuned model, run:
python generate.py --model_path <folder-under-logs> --checkpoint <specify-checkpoint>
If you find this repo useful, please cite our paper:
@article{anagnostidis2023dynamic,
title={Dynamic Context Pruning for Efficient and Interpretable Autoregressive Transformers},
author={Anagnostidis, Sotiris and Pavllo, Dario and Biggio, Luca and Noci, Lorenzo and Lucchi, Aurelien and Hofmann, Thomas},
journal={arXiv preprint arXiv:2305.15805},
year={2023}
}
This repo was heavily inspired by the following: