This repo implements an end-to-end classifier for the Traffic Light Dataset, built on PyTorch.
- pytorch: 0.4.0
- torchsummary:
  $ pip install torchsummary
- cv2:
  $ pip install opencv-python
- matplotlib
- numpy
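Before going further, you may want to confirm that these packages are importable. Below is a minimal sketch that uses only the standard library. The helper name `missing_packages` is hypothetical and is not part of this repo. Note that import names differ from pip names: pytorch is imported as `torch`, and opencv-python as `cv2`.

```python
import importlib.util

# Import names for the dependencies listed above (pip names differ:
# opencv-python is imported as cv2, pytorch as torch).
REQUIRED = ["torch", "torchsummary", "cv2", "matplotlib", "numpy"]


def missing_packages(names):
    """Return the top-level module names the import system cannot find (hypothetical helper)."""
    return [name for name in names if importlib.util.find_spec(name) is None]


if __name__ == "__main__":
    missing = missing_packages(REQUIRED)
    if missing:
        print("Missing:", ", ".join(missing))
    else:
        print("All dependencies found.")
```

This only checks that each package can be found. It does not verify the pinned pytorch 0.4.0 version.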
First, you should clone this repo:
$ git clone https://github.com/FangYang970206/TL_Dataset_Classification
Second, download the dataset from OneDrive, BaiduYun, or Google Drive. Move TL_Dataset.zip into TL_Dataset_Classification/, then unzip TL_Dataset.zip.
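If the `unzip` tool is not available (for example, on Windows), the standard-library `zipfile` module can handle the same step. Below is a minimal sketch. It assumes the archive sits in the repo root under the name TL_Dataset.zip, as described above. The function name `extract_dataset` is hypothetical.

```python
import zipfile
from pathlib import Path


def extract_dataset(zip_path="TL_Dataset.zip", dest="."):
    """Extract the dataset archive into dest and return the archive's member names."""
    zip_path = Path(zip_path)
    if not zip_path.is_file():
        raise FileNotFoundError(f"{zip_path} not found; download it first")
    with zipfile.ZipFile(zip_path) as archive:
        # testzip() returns the name of the first corrupted member, or None if all CRCs match.
        bad = archive.testzip()
        if bad is not None:
            raise zipfile.BadZipFile(f"corrupted member: {bad}")
        archive.extractall(dest)
        return archive.namelist()


if __name__ == "__main__":
    # Run from inside TL_Dataset_Classification/.
    names = extract_dataset()
    print(f"Extracted {len(names)} entries")
```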
Third, start training.
$ python main.py
or, with custom arguments:
$ python main.py --img_resize_shape tuple --batch_size int --lr float --num_workers int --epochs int --val_size float --save_model bool --save_path str
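The parsing code in main.py is not shown here. The sketch below is a hypothetical argparse setup showing one way these flags could be parsed. argparse has no native tuple type, and `type=bool` would treat the string "False" as True. For that reason, the converters `parse_shape` and `str2bool` are assumptions, as is the `64,64` shape syntax. Apart from val_size=0.3, the repo's defaults are not documented here, so the sketch leaves the others as None.

```python
import argparse


def parse_shape(text):
    """Parse '64,64' or '(64, 64)' into a tuple of ints (hypothetical input format)."""
    parts = text.strip("()").split(",")
    return tuple(int(p) for p in parts if p.strip())


def str2bool(text):
    """Interpret common true/false spellings; plain type=bool would make 'False' truthy."""
    value = text.strip().lower()
    if value in ("true", "1", "yes", "y"):
        return True
    if value in ("false", "0", "no", "n"):
        return False
    raise argparse.ArgumentTypeError(f"expected a boolean, got {text!r}")


def build_parser():
    parser = argparse.ArgumentParser(description="Traffic light classifier training (sketch)")
    # Only val_size has a documented default (0.3); the others are placeholders.
    parser.add_argument("--img_resize_shape", type=parse_shape, default=None)
    parser.add_argument("--batch_size", type=int, default=None)
    parser.add_argument("--lr", type=float, default=None)
    parser.add_argument("--num_workers", type=int, default=None)
    parser.add_argument("--epochs", type=int, default=None)
    parser.add_argument("--val_size", type=float, default=0.3)
    parser.add_argument("--save_model", type=str2bool, default=None)
    parser.add_argument("--save_path", type=str, default=None)
    return parser


if __name__ == "__main__":
    args = build_parser().parse_args()
    print(vars(args))
```

Under this sketch, a call would look like `$ python main.py --img_resize_shape 64,64 --save_model false`. If you use the parenthesized form, quote it in the shell.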
val_size (default: 0.3) is the fraction of the whole dataset held out for validation.
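To make val_size concrete: with the default of 0.3, 30% of the images go to validation and the remaining 70% go to training. Below is a minimal, framework-free sketch of such a split. The function name `split_indices` and the seeded-shuffle approach are assumptions, not necessarily what main.py does. In PyTorch, `torch.utils.data.random_split` serves a similar purpose.

```python
import random


def split_indices(num_samples, val_size=0.3, seed=0):
    """Shuffle indices 0..num_samples-1 and split them into (train, val) lists.

    val_size is the fraction of the whole dataset held out for validation.
    """
    if not 0.0 <= val_size <= 1.0:
        raise ValueError("val_size must be between 0 and 1")
    indices = list(range(num_samples))
    random.Random(seed).shuffle(indices)  # seeded so the split is reproducible
    num_val = int(round(num_samples * val_size))
    return indices[num_val:], indices[:num_val]


if __name__ == "__main__":
    train_idx, val_idx = split_indices(1000, val_size=0.3)
    print(len(train_idx), len(val_idx))  # 700 300
```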
[Figure: learning curve]
On the test set, the model achieves 97.425% accuracy (still improving).
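Test accuracy is presumably the percentage of test images whose predicted class matches the label. Below is a minimal sketch of that computation. The function name `accuracy_percent`, and the assumption that the reported figure is top-1 accuracy, are illustrative and not confirmed by the repo.

```python
def accuracy_percent(predictions, labels):
    """Top-1 accuracy in percent: share of predictions equal to their labels."""
    if len(predictions) != len(labels):
        raise ValueError("predictions and labels must have the same length")
    if not labels:
        raise ValueError("cannot compute accuracy on an empty set")
    correct = sum(p == y for p, y in zip(predictions, labels))
    return 100.0 * correct / len(labels)


if __name__ == "__main__":
    preds = [0, 1, 2, 1]
    labels = [0, 1, 1, 1]
    print(f"{accuracy_percent(preds, labels):.3f}%")  # 75.000%
```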