此开源hub是基于Tensorflow2.x的文本分类任务等项目
通过对 Config 文件配置,可支持如下功能:
- Bert/MacBert/RoBerta/DistilBert/AlBert/Electra/XLNet各种预训练模型训练
- 支持二分类和多分类
- 支持单例测试和批量测试
- 保存为 pb 文件可供上线部署
- 支持对抗训练 fgm/pgd
- 支持 label_smoothing
- 支持针对样本不均衡的 loss(包括多分类)
- python 3.7.8
- tensorflow-gpu==2.2.0
- tensorflow-addons==0.15.0
- transformers==4.9.1
- tqdm==4.31.1
- pandas==1.3.5
- scikit-learn==1.0.2
.
├── LICENSE
├── README.md
├── __init__.py
├── config.py 参数配置文件
├── data
│ ├── 提交示例.csv
│ ├── 测试集
│ │ └── test1.csv
│ └── 训练集
│ └── train.csv
├── kernels
│ ├── __init__.py
│ ├── data_processer.py
│ ├── data_statistics.py
│ ├── models
│ │ └── TFPretrainedModel.py
│ ├── predict.py
│ ├── train.py
│ ├── tricks
│ │ ├── adversarial_fgm.py
│ │ ├── adversarial_pgd.py
│ │ └── focal_loss.py
│ └── utils
│ ├── __init__.py
│ ├── cal_metrics.py
│ └── logger.py
├── requirements.txt
└── run.py
7 directories, 21 files
Version | Describe |
---|---|
v1.0.0 | 初始仓库 |
v2.0.0 | 预训练模型基本版 |
v2.1.0 | 添加训练 tricks |
在config.py中配置好各个参数,文件中有详细参数说明
参数配置完后开始模型训练
# [train_classifier, predict_single, predict_test, save_pb_model]
mode = 'train_classifier'
训练好模型直接可以开始测试,支持单例测试和批量测试
- 单例测试
# [train_classifier, predict_single, predict_test, save_pb_model]
mode = 'predict_single'
- 批量测试
# [train_classifier, predict_single, predict_test, save_pb_model]
mode = 'predict_test'
本项目作为笔者在之前工作中项目背景下的抽象出的文本分类实验demo和trick。 源码和数据(实验数据)已经在项目中给出。
如需要更深一步的交流,请发送消息至邮箱 [email protected],或者在 Github 上直接留言。