- Python 3.8.0
- PyTorch 2.0.0
- CUDA 11.7
- Transformers 4.27.4
Prepare the Anaconda environment:

```shell
conda create -n t2t python=3.8.0
conda activate t2t
pip install -r requirements.txt
```
In this experiment, we use banking_data for intent classification. The data is already included in the repository.
Run the following command to preprocess the data:

```shell
python data_process.py
```

It processes the original `train.csv` and `test.csv` into two parts: one for text-to-text generation, where the labels are the original category names, and the other for text-to-indices classification, where the labels are converted to integer indices. Add `--split` to randomly hold out 10% of the training set as a validation set for hyperparameter tuning (we have already done this and fixed all the hyperparameters).
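The core of the text-to-indices preprocessing can be sketched as a label-to-index conversion. This is only an illustration of the idea, not the actual `data_process.py` code; the function and variable names below are made up:

```python
# Minimal sketch of the label-to-index conversion used for the
# text-to-indices variant (illustrative; the real data_process.py may differ).

def build_label_map(rows):
    """Map each distinct category name to a stable integer index."""
    labels = sorted({label for _, label in rows})
    return {label: idx for idx, label in enumerate(labels)}

def to_indices(rows, label_map):
    """Replace category-name labels with their integer indices."""
    return [(text, label_map[label]) for text, label in rows]

# Toy example with three banking-style intents.
rows = [
    ("I lost my card", "card_lost"),
    ("How do I top up?", "top_up"),
    ("My card never arrived", "card_arrival"),
]
label_map = build_label_map(rows)
indexed = to_indices(rows, label_map)
print(label_map)  # {'card_arrival': 0, 'card_lost': 1, 'top_up': 2}
print(indexed)
```

Sorting the label set before enumeration keeps the index assignment deterministic across runs, which matters when train and test files are processed separately.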
For text-to-text generation, run the following commands:

```shell
# standard training
sh scripts/run_sen_gen.sh [GPU] [batch_size] [model_name] [learning_rate]
# for example
sh scripts/run_sen_gen.sh 7 32 t5-small 5e-4
sh scripts/run_sen_gen.sh 7 4 t5-3b 5e-4

# few-shot training
sh scripts/run_sen_gen_few-shot.sh [GPU] [batch_size] [shot] [model_name] [learning_rate]
# for example
sh scripts/run_sen_gen_few-shot.sh 7 32 3 t5-small 5e-4
sh scripts/run_sen_gen_few-shot.sh 7 32 5 t5-small 5e-4
```
For text-to-indices classification, run the following commands:

```shell
# standard training
sh scripts/run_sen_cls.sh [GPU] [batch_size] [model_name] [learning_rate]
# for example
sh scripts/run_sen_cls.sh 7 32 t5-small 5e-4
sh scripts/run_sen_cls.sh 7 4 t5-3b 1e-4

# few-shot training
sh scripts/run_sen_cls_few-shot.sh [GPU] [batch_size] [shot] [model_name] [learning_rate]
# for example
sh scripts/run_sen_cls_few-shot.sh 7 32 3 t5-small 5e-4
sh scripts/run_sen_cls_few-shot.sh 7 32 5 t5-small 5e-4
```
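In the few-shot scripts, `[shot]` is the number of training examples kept per intent class; with 77 intent classes this matches the instance counts reported below (77 × 3 = 231, 77 × 5 = 385). A minimal sketch of such per-class sampling, assuming simple `(text, label)` pairs (the function and names here are illustrative, not the repository's code):

```python
import random

def sample_k_shot(rows, k, seed=42):
    """Keep at most k randomly chosen examples per label."""
    by_label = {}
    for text, label in rows:
        by_label.setdefault(label, []).append(text)
    rng = random.Random(seed)  # fixed seed for a reproducible subset
    subset = []
    for label, texts in sorted(by_label.items()):
        chosen = rng.sample(texts, min(k, len(texts)))
        subset.extend((t, label) for t in chosen)
    return subset

# Toy check: 4 classes, 3 shots each -> 12 training instances.
rows = [(f"utterance {i}", f"intent_{i % 4}") for i in range(40)]
few = sample_k_shot(rows, k=3)
print(len(few))  # 12
```

Seeding the sampler keeps the few-shot subset identical across runs, so results for different models remain comparable.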
The results will be saved in `./out` or `./out/k-shot`.
- The standard classification results are shown in the following table:

| Model | Text-to-Indices accuracy | Text-to-Text accuracy | Text-to-Text in-distribution ratio |
| --- | --- | --- | --- |
| T5-small (60M) | 91.3961 | 91.0065 | 99.8701 |
| T5-base (220M) | 93.9935 | 93.7013 | 99.9675 |
| T5-large (770M) | 93.2143 | 93.7662 | 99.9351 |
| T5-3B | 94.4156 | 93.7987 | 99.9351 |
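The tables report an in-distribution ratio for the text-to-text models; a natural reading is the percentage of generated label strings that exactly match one of the valid category names. A minimal sketch under that assumption (the function name and definition are illustrative, not necessarily how the repository computes it):

```python
def in_distribution_ratio(predictions, valid_labels):
    """Percentage of generated label strings that exactly match a known category.

    Assumed definition: a generation is 'in distribution' if it is one of
    the valid category names; the repository may define it differently.
    """
    valid = set(valid_labels)
    hits = sum(1 for p in predictions if p in valid)
    return 100.0 * hits / len(predictions)

valid_labels = ["card_lost", "top_up", "card_arrival"]
predictions = ["card_lost", "top up", "card_arrival", "card_lost"]
print(in_distribution_ratio(predictions, valid_labels))  # 75.0
```

Under this reading, a low ratio (as in the few-shot tables below) means the model often generates strings outside the label set, which caps the achievable accuracy.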
- 3-shot results (231 training instances):

| Model | Text-to-Indices accuracy | Text-to-Text accuracy | Text-to-Text in-distribution ratio |
| --- | --- | --- | --- |
| T5-small (60M) | 4.3651 | 13.4298 | 6.6371 |
| T5-base (220M) | 7.8198 | 39.9471 | 54.1783 |
| T5-large (770M) | 2.2954 | 58.5590 | 79.2017 |
| T5-3B | 6.3414 | 66.1920 | 85.7688 |
- 5-shot results (385 training instances):

| Model | Text-to-Indices accuracy | Text-to-Text accuracy | Text-to-Text in-distribution ratio |
| --- | --- | --- | --- |
| T5-small (60M) | 5.9379 | 12.6634 | 13.1359 |
| T5-base (220M) | 9.8834 | 58.6077 | 80.9025 |
| T5-large (770M) | 8.0091 | 71.0112 | 90.6442 |
| T5-3B | 12.687 | 73.7833 | 94.5031 |