Yuchi Liu, Yifan Sun, Jingdong Wang, Liang Zheng
This repository contains the implementation for the paper "Assessing Model Generalization in Vicinity".
This paper evaluates the generalization ability of classification models on out-of-distribution test sets without depending on ground truth labels. Common approaches often calculate an unsupervised metric related to a specific model property, like confidence or invariance, which correlates with out-of-distribution accuracy. However, these metrics are typically computed for each test sample individually, leading to potential issues caused by spurious model responses, such as overly high or low confidence. To tackle this challenge, we propose incorporating responses from neighboring test samples into the correctness assessment of each individual sample. In essence, if a model consistently demonstrates high correctness scores for nearby samples, it increases the likelihood of correctly predicting the target sample, and vice versa. The resulting scores are then averaged across all test samples to provide a holistic indication of model accuracy. Developed under the vicinal risk formulation, this approach, named vicinal risk proxy (VRP), computes accuracy without relying on labels. We show that applying the VRP method to existing generalization indicators, such as average confidence and effective invariance, consistently improves over these baselines both methodologically and experimentally. This yields a stronger correlation with model accuracy, especially on challenging out-of-distribution test sets.
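For intuition, below is a minimal sketch of the vicinal risk proxy idea, not the repository's implementation: each sample's correctness score (e.g., max softmax confidence) is smoothed by the scores of its nearest neighbors in feature space, and the smoothed scores are averaged into a single label-free accuracy proxy. The function name, the cosine-similarity neighborhood, and all shapes are illustrative assumptions.

```python
import numpy as np

def vicinal_risk_proxy(features, scores, k=10):
    """Illustrative sketch of a vicinal risk proxy (VRP).

    features: [N, D] test-set feature embeddings (assumed given)
    scores:   [N] per-sample correctness scores, e.g. max softmax confidence
    k:        number of neighbors whose scores are mixed into each sample
    """
    scores = np.asarray(scores, dtype=float)

    # Pairwise cosine similarities between test samples.
    f = features / np.linalg.norm(features, axis=1, keepdims=True)
    sim = f @ f.T

    vicinal = np.empty_like(scores)
    for i in range(len(scores)):
        # Indices of the k most similar samples (the "vicinity" of sample i,
        # which includes sample i itself via its self-similarity of 1).
        nn = np.argsort(-sim[i])[:k]
        # Similarity-weighted average of neighbor scores.
        w = sim[i, nn]
        vicinal[i] = np.sum(w * scores[nn]) / np.sum(w)

    # The mean vicinal score serves as a label-free proxy for accuracy.
    return vicinal.mean()
```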
Before running the scripts, ensure you have Python installed along with the necessary packages. To install the required packages, execute the following command:
pip install -r requirements.txt
For dataset and model details, please check the Experimental Setup section in our paper.
To generate and save model predictions, run:

bash src/test_getOutput.sh
You can change the Python files called in test_getOutput.sh to switch datasets:
- test_savePredictions.py: for the ImageNet setup
- test_savePredictions_cifar.py: for the CIFAR-10 setup
- test_savePredictions_iwilds.py: for the iWildCam setup

A sketch of what such a script does is shown below.
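For reference, here is a minimal sketch of what such a prediction-saving script does under our reading of the layout below: run one pretrained classifier over one test folder and cache its softmax outputs as a .npy file. The model choice, preprocessing, and paths are illustrative assumptions, not the repository's exact code.

```python
import os
import numpy as np
import torch
from torchvision import datasets, models, transforms

# Standard ImageNet preprocessing (an assumption; the scripts may differ).
tf = transforms.Compose([
    transforms.Resize(256),
    transforms.CenterCrop(224),
    transforms.ToTensor(),
    transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
])
dataset = datasets.ImageFolder("data/ImageNet-A", transform=tf)
loader = torch.utils.data.DataLoader(dataset, batch_size=64)

model = models.resnet152(weights="IMAGENET1K_V1").eval()

chunks = []
with torch.no_grad():
    for images, _ in loader:  # labels are ignored: the method is label-free
        chunks.append(model(images).softmax(dim=1).numpy())

# One .npy file per model; this path is illustrative, mirroring the
# modelOutput/ layout shown below.
os.makedirs("modelOutput/imagenet_a_out", exist_ok=True)
np.save("modelOutput/imagenet_a_out/tv_resnet152.npy", np.concatenate(chunks))
```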
To utilize the provided scripts effectively, please organize your data according to the following directory structure:
├── data
│   ├── ImageNet-Val
│   ├── ImageNet-A
│   ├── ImageNet-R
│   └── ...
├── modelOutput
│   ├── imagenet_a_out_colorjitter
│   │   ├── tv_resnet152.npy
│   │   └── ...
│   ├── imagenet_a_out_grey
│   └── ...
├── iwildcam_weights
├── ...
└── src
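As a quick sanity check of the cached outputs (assuming, as the layout suggests, that each .npy file holds one model's predictions with one row per test sample; this is our reading, not a documented format):

```python
import numpy as np

# Hypothetical path taken from the layout above.
preds = np.load("modelOutput/imagenet_a_out_colorjitter/tv_resnet152.npy")
print(preds.shape)  # expected: (num_test_samples, num_classes)
```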
To compute model risk estimates under the different setups, run the following commands:
- ImageNet setup:
  python src/test_metric.py
- CIFAR-10 setup:
  python src/test_metric_cifar.py
- iWildCam setup:
  python src/test_metric_iwilds.py
You can compute the correlation between the estimated model risks and the models' actual accuracies by running the following command:
python src/compute_correlation.py
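For intuition, here is a minimal sketch of such a correlation check; the repository's script may report different statistics, and the numbers below are placeholders, not real results.

```python
import numpy as np
from scipy import stats

# Placeholder inputs: one proxy score and one ground-truth accuracy per model.
proxy_scores = np.array([0.61, 0.55, 0.72, 0.48, 0.66])
accuracies   = np.array([0.58, 0.52, 0.70, 0.45, 0.63])

# Rank correlation (Spearman's rho) and linear correlation (Pearson's r)
# between the label-free estimates and the true accuracies.
rho, _ = stats.spearmanr(proxy_scores, accuracies)
r, _ = stats.pearsonr(proxy_scores, accuracies)
print(f"Spearman's rho: {rho:.3f}, Pearson's r: {r:.3f}")
```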
If you find our code helpful, please consider citing our paper:
@article{liu2024assessing,
  title={Assessing Model Generalization in Vicinity},
  author={Liu, Yuchi and Sun, Yifan and Wang, Jingdong and Zheng, Liang},
  journal={arXiv preprint arXiv:2406.09257},
  year={2024}
}
This project is open source and available under the MIT License.