Attribute classification on the CelebA dataset: a multi-label classification task

Authors: Apavou Clément & Belkada Younes

Built with Python, PyTorch, and PyTorch Lightning.

🔎 Introduction:

This repository is related to a project for the Introduction to Numerical Imaging course (Introduction à l'Imagerie Numérique in French), given as part of the MVA Master's program at ENS Paris-Saclay.

It was built entirely from scratch and contains PyTorch Lightning code to train and then use a neural network for image classification. We used it to build a classifier for semantic attribute classification of faces on the CelebA dataset.

Some images from the CelebA dataset with their attribute annotations.

📈 Experiments:

The CelebA dataset contains approximately 200,000 images of celebrity faces, each annotated with 40 binary semantic attributes such as smiling 😁 / not smiling 😐 or bald 👴 / not bald 👨. All attributes are listed here.
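
As a quick reference, here is a minimal sketch of loading CelebA with its attribute labels through the stock torchvision API. Note that this repository ships its own `datasets/celeba.py` to work around issues with the built-in dataset, so this is only an illustration, not the repository's datamodule:

```python
# Minimal sketch: loading CelebA with its 40 binary attribute labels via torchvision.
import torchvision.transforms as T
from torchvision.datasets import CelebA

transform = T.Compose([T.Resize((224, 224)), T.ToTensor()])

# target_type="attr" returns a tensor of 40 binary attribute labels per image
train_set = CelebA(root="assets", split="train", target_type="attr",
                   transform=transform, download=True)

image, attributes = train_set[0]
print(attributes.shape)  # torch.Size([40]), one 0/1 entry per attribute
```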

We fine-tuned two classifiers: a ResNet-50 and a ViT-Small with 16x16 patches. Since the training set contains about 200,000 images, a single epoch is enough to fine-tune the models for attribute classification on CelebA.
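
The sketch below condenses the fine-tuning idea: a pretrained timm backbone with a fresh 40-way head, trained with a binary cross-entropy loss over the attributes. The repository's actual module lives in lightningmodules/classification.py and may differ in details:

```python
# Condensed multi-label fine-tuning sketch (assumes a timm backbone and
# BCEWithLogitsLoss over the 40 CelebA attributes).
import pytorch_lightning as pl
import timm
import torch
import torch.nn as nn

class AttributeClassifier(pl.LightningModule):
    def __init__(self, model_name="resnet50", num_attributes=40, lr=1e-4):
        super().__init__()
        # Pretrained backbone with a new head: one logit per attribute
        self.model = timm.create_model(model_name, pretrained=True,
                                       num_classes=num_attributes)
        self.criterion = nn.BCEWithLogitsLoss()  # independent sigmoid per attribute
        self.lr = lr

    def training_step(self, batch, batch_idx):
        images, attributes = batch
        logits = self.model(images)
        loss = self.criterion(logits, attributes.float())
        self.log("train/loss", loss)
        return loss

    def configure_optimizers(self):
        return torch.optim.AdamW(self.parameters(), lr=self.lr)

# One epoch is enough given the ~200k training images:
# trainer = pl.Trainer(max_epochs=1, accelerator="auto")
# trainer.fit(AttributeClassifier(), train_dataloader)
```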

Experiments are available on wandb.

🔎 Results:

| Model                 | Accuracy | Weights       | Run |
|-----------------------|----------|---------------|-----|
| vit_small_patch16_224 | 0.7622   | here          |     |
| resnet50              | 0.8055   | not available |     |

🎉 Features:

For the models, we used the timm library, which provides many image classification models. Any image classification model from this library can be used.
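
For example, swapping the backbone only requires changing the model name passed to timm (the names below are the ones timm uses for the two models reported above):

```python
# Discover and instantiate timm backbones by name.
import timm

print(timm.list_models("vit_small_patch16*")[:5])  # list matching model names
model = timm.create_model("vit_small_patch16_224", pretrained=True, num_classes=40)
```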

The code provides the following features:

  • Training of a neural network for image classification on any dataset (just add your custom dataset in the datasets folder)
  • Visualisation with the Weights and Biases library of several classification metrics such as losses, accuracy, precision and recall. Predictions on some training and validation images are also logged to wandb, so the progress of the training can be followed in real time.
  • Inference of the model on your own images: just add the images you want the model to run on and specify the folder path in the config file (a minimal sketch of such an inference dataset follows this list)
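
As a rough idea of the inference feature, here is a hedged sketch of a dataset over a folder of your own images; the repository's own implementation is datasets/inference_dataset.py and its exact interface may differ:

```python
# Illustrative folder dataset for inference on your own images.
from pathlib import Path
from PIL import Image
from torch.utils.data import Dataset

class FolderInferenceDataset(Dataset):
    def __init__(self, folder, transform=None):
        # Keep only common image files, in a deterministic order
        self.paths = sorted(p for p in Path(folder).iterdir()
                            if p.suffix.lower() in {".jpg", ".jpeg", ".png"})
        self.transform = transform

    def __len__(self):
        return len(self.paths)

    def __getitem__(self, idx):
        image = Image.open(self.paths[idx]).convert("RGB")
        return self.transform(image) if self.transform else image
```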

The first goal of this repository was to use the InterFaceGAN method. The script train_svm.py therefore trains an SVM for each semantic attribute to obtain boundaries, which can then be used to control face generation with the InterFaceGAN code from this repository.
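
The boundary-fitting idea behind train_svm.py follows the InterFaceGAN recipe: fit a linear SVM per attribute in the generator's latent space, then use the unit normal of the separating hyperplane as the editing direction. A sketch, with illustrative array and file names:

```python
# Fit one semantic boundary per attribute (InterFaceGAN-style); names are illustrative.
import numpy as np
from sklearn.svm import LinearSVC

latent_codes = np.load("latent_codes.npy")        # (N, latent_dim) generator latents
attribute_labels = np.load("smiling_labels.npy")  # (N,) 0/1 scores from the classifier

svm = LinearSVC(C=1.0, max_iter=10000)
svm.fit(latent_codes, attribute_labels)

# Unit normal of the hyperplane = semantic direction for this attribute
boundary = svm.coef_.reshape(-1)
boundary /= np.linalg.norm(boundary)
```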

The whole codebase can be used just by modifying the config file (config/hparams.py). You can launch classifier training, run inference with a classifier (using the weights of a trained classifier), and train SVMs to create boundaries.

🎯 Code structure:

The structure of the repository is the following:

├── assets                      # Put database here
├── datamodules
|   |
|   ├── celebadatamodule.py     # PyTorch Lightning datamodule for the CelebA dataset
|         
├── datasets
|   ├── celeba.py                # fixes an issue with the PyTorch CelebA dataset
|   ├── inference_dataset.py     # custom PyTorch dataset for inference
|          
├── lightningmodules
|   ├── classification.py        # lightning module for image classification (multi-label)
| 
├── utils                        # utils functions
|   ├── boundary_creator.py
|   ├── callbacks.py
|   ├── constant.py
|   ├── utils_functions.py
|
├── weights                     # put models weights here
|
├── analyse_score_latent_space.ipynb  # notebook to analyse predicted scores
|
├── hparams.py                   # configuration file
|
├── main.py                      # main script to launch training or inference
| 
├── train_svm.py                 # script to create boundaries for InterFaceGAN
|
└── README.md

🔨 Usage:

Train a classifier

Parameters to put in hparams.py:

    train: bool = True
    predict: bool = False

Then adjust Hparams, TrainParams, DatasetParams and CallBackParams according to your needs.

python main.py
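
For orientation, the kind of settings typically adjusted in TrainParams might look like the following. The field names here are hypothetical; the real dataclasses live in hparams.py and may differ:

```python
# Illustrative only: hypothetical mirror of TrainParams, not the actual file.
from dataclasses import dataclass

@dataclass
class TrainParams:
    model_name: str = "resnet50"   # any timm classification model (hypothetical field)
    lr: float = 3e-4               # learning rate (hypothetical field)
    batch_size: int = 64           # batch size (hypothetical field)
    max_epochs: int = 1            # one epoch is enough on CelebA (hypothetical field)
```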

Predict with the classifier

Parameters to put in hparams.py:

    train: bool = False
    predict: bool = True

Then adjust Hparams, InferenceParams and DatasetParams according to your needs.

python main.py

Train SVM for InterFaceGAN

Modify SVMParams in hparams.py according to your needs.

python train_svm.py