Another attempt to implement Yann LeCun's LeNet-5 network using PyTorch and test it on the MNIST dataset.
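For readers unfamiliar with the architecture, here is a minimal sketch of a LeNet-5-style module in PyTorch. It is not the code from lenet5.py. It assumes 1x32x32 inputs and 10 classes and keeps the paper's tanh activations and average pooling. It also simplifies two parts of the original: the partial C3 connection table becomes a full convolution, and the RBF output layer becomes a plain linear layer.

```python
import torch
import torch.nn as nn


class LeNet5(nn.Module):
    """LeNet-5-style network for 1x32x32 inputs (simplified sketch, see lead-in)."""

    def __init__(self, num_classes: int = 10):
        super().__init__()
        self.features = nn.Sequential(
            nn.Conv2d(1, 6, kernel_size=5),     # C1: 1x32x32 -> 6x28x28
            nn.Tanh(),
            nn.AvgPool2d(kernel_size=2),        # S2: 6x28x28 -> 6x14x14
            nn.Conv2d(6, 16, kernel_size=5),    # C3: 6x14x14 -> 16x10x10 (full connections, not the paper's table)
            nn.Tanh(),
            nn.AvgPool2d(kernel_size=2),        # S4: 16x10x10 -> 16x5x5
            nn.Conv2d(16, 120, kernel_size=5),  # C5: 16x5x5 -> 120x1x1
            nn.Tanh(),
        )
        self.classifier = nn.Sequential(
            nn.Flatten(),                       # 120x1x1 -> 120
            nn.Linear(120, 84),                 # F6
            nn.Tanh(),
            nn.Linear(84, num_classes),         # logits (replaces the paper's RBF layer)
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.classifier(self.features(x))
```

With these simplifications the network has 61,706 trainable parameters, and a batch of shape (N, 1, 32, 32) produces logits of shape (N, 10).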
- Install the torch and torchvision libraries.
- Configure the experiment in the main.py file. Do not forget to set proper values for the DATA_PATH (absolute path to the data/datasets/ folder) and MODEL_STORE_PATH (absolute path to the data/models/ folder) variables!
- Run: python main.py
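One way to avoid hard-coding the two absolute paths in main.py is to derive them from the project directory. The sketch below is an assumption, not the repository's code; the helper name make_experiment_paths is hypothetical. It also creates both folders so that the dataset download and model saving do not fail on a fresh checkout.

```python
import os
from typing import Tuple


def make_experiment_paths(project_root: str) -> Tuple[str, str]:
    """Return absolute (DATA_PATH, MODEL_STORE_PATH) under project_root (hypothetical helper)."""
    root = os.path.abspath(project_root)
    data_path = os.path.join(root, "data", "datasets")
    model_store_path = os.path.join(root, "data", "models")
    # Create the folders if they are missing; existing folders are left untouched.
    os.makedirs(data_path, exist_ok=True)
    os.makedirs(model_store_path, exist_ok=True)
    return data_path, model_store_path


# In main.py this could be called with the directory that contains the script:
# DATA_PATH, MODEL_STORE_PATH = make_experiment_paths(os.path.dirname(os.path.abspath(__file__)))
```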
- lenet5.py - contains the LeNet5 class that implements the LeNet-5 network.
- lenet5_trainer.py - contains the LeNet5Trainer class. Its method train(train_data, learning_rate, num_epochs, print_every=100) trains the network and then returns it.
- net_test.py - contains the NetTest class. Its method test(test_data) tests the network and returns the number of correct predictions and the total number of test samples.
- utils.py - contains the function prepare_mnist_data(data_path, batch_size) that prepares the data for the experiment.
- main.py - the main file of the experiment.
- data/datasets/ - the MNIST dataset is downloaded here on the first run.
- data/models/ - folder where trained models are saved.
- I'm new to Python, so be lenient ;)
- The network implementation is very rough, so be very lenient ;)