- Abstract
- Installation
- Datasets
- Training
- Pretrained Models
- Evaluation
- Updates
- Acknowledgements
- Citation
- License
Neural networks often learn spurious correlations from biased datasets, leading to poor generalization. DeNetDM introduces a novel debiasing method that leverages network depth: shallow networks prioritize core attributes, while deeper ones emphasize biases. We create biased and debiased branches, distill knowledge from both, and train a model that outperforms existing methods on multiple datasets without requiring bias labels or explicit augmentation techniques. Our approach improves performance by approximately 5% across three datasets (synthetic and real-world) and achieves results comparable to supervised debiasing approaches, all without bias annotations.
First, clone the repository and set up the environment:
git clone https://github.com/kadarsh22/DeNetDM # Clone the project
cd DeNetDM # Navigate into the project directory
conda env create -f denetdm.yml # Create a conda environment with dependencies
conda activate denetdm # Activate the environment
You can generate or download the necessary datasets as described below:
- ColoredMNIST: Follow the instructions from the Learning from Failure repository to generate the dataset.
- CorruptedCIFAR10: Use the instructions from the same repository for CorruptedCIFAR10.
- BFFHQ: Download the dataset from BFFHQ.
Once the datasets are ready, place them in a directory accessible to your project and update the config.py file with the corresponding path.
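One possible layout is a single root directory with one subfolder per dataset. The sketch below is only an illustration: the `datasets` root and the subfolder names are assumptions, so use whatever path you then record in config.py.

```python
from pathlib import Path

# Hypothetical layout: one root folder with a subfolder per dataset.
# The names here are assumptions -- match them to what config.py expects.
root = Path("datasets")
for name in ("coloredmnist", "corruptedcifar10", "bffhq"):
    (root / name).mkdir(parents=True, exist_ok=True)

# List the prepared dataset folders
print(sorted(p.name for p in root.iterdir()))
```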
Run the following command to start training on the dataset of your choice:
bash scripts/$DATASET.sh
Replace $DATASET with one of the following options:
- coloredmnist
- corruptedcifar10
- bffhq
For example, to run the training on ColoredMNIST:
bash scripts/coloredmnist.sh
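Since the per-dataset scripts follow a single naming pattern, running all three is just a loop. This dry-run sketch only prints the commands rather than launching them:

```python
# Dry-run sketch: build the training command for each supported dataset.
# Nothing is executed here; pipe the output to bash (or call subprocess)
# once you are ready to train.
datasets = ["coloredmnist", "corruptedcifar10", "bffhq"]
commands = [f"bash scripts/{d}.sh" for d in datasets]
print("\n".join(commands))
```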
Pretrained models can be downloaded from this Google Drive link.
- Download the models and place them in the pretrained_models/ directory.
- Update the config.py file of the corresponding dataset to point to the pretrained model.
- Set train = False in main.py.
- Run the evaluation (see the next section for details).
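A quick pre-flight check can catch a misplaced download before evaluation fails partway through. The checkpoint filenames below are hypothetical: substitute the actual names of the files you downloaded from Google Drive.

```python
from pathlib import Path

# Pre-flight sketch: verify each expected checkpoint is present in
# pretrained_models/ before setting train = False and evaluating.
# The .pt filenames are assumptions, not the release's actual names.
expected = [Path("pretrained_models") / f"{name}.pt"
            for name in ("coloredmnist", "corruptedcifar10", "bffhq")]
missing = [str(p) for p in expected if not p.exists()]
print("missing checkpoints:", missing)
```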
Once you've set up the pretrained models, you can evaluate them using the following commands:
python main.py with colored_mnist skewed1 severity4 # Evaluate on ColoredMNIST
python main.py with corrupted_cifar10 skewed1 severity4 # Evaluate on CorruptedCIFAR10
python main.py with bffhq                          # Evaluate on BFFHQ
Ensure the path to the pretrained model is correctly set in the config.py file for each dataset before running the evaluation.
- September 27, 2024: Paper accepted to NeurIPS 2024.
- October 24, 2024: arXiv version posted, code released.
- December 12, 2024: Poster presentation at NeurIPS 2024.
This code is partly based on the open-source implementations from the following projects:
If you find this code or idea useful, please cite our work:
@misc{sreelatha2024denetdmdebiasingnetworkdepth,
title={DeNetDM: Debiasing by Network Depth Modulation},
author={Silpa Vadakkeeveetil Sreelatha and Adarsh Kappiyath and Abhra Chaudhuri and Anjan Dutta},
year={2024},
eprint={2403.19863},
archivePrefix={arXiv},
primaryClass={cs.LG},
url={https://arxiv.org/abs/2403.19863},
}
This project is licensed under the MIT License - see the LICENSE file for details.