Partial counterfactual identification for continuous outcomes with a Curvature Sensitivity Model
The project is built with the following Python libraries:
- Pyro - deep learning and probabilistic modelling with residual normalizing flows and variational augmentations
- Hydra - simplified command-line argument management
- MLflow - experiment tracking
- POT - PyTorch implementation of the Wasserstein distance
- normflows - PyTorch implementation of residual normalizing flows
First, create a virtual environment and install all the requirements:
pip3 install virtualenv
python3 -m virtualenv -p python3 --always-copy venv
source venv/bin/activate
pip3 install -r requirements.txt
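Before installing or training, it can help to confirm that the interpreter in use really is the one inside venv. The following is a minimal sketch, assuming a recent virtualenv release (20 or later) or the standard venv module, both of which expose `sys.base_prefix`. The helper name `in_virtual_environment` is hypothetical and not part of this project.

```python
import sys


def in_virtual_environment(prefix=None, base_prefix=None):
    """Return True when the interpreter prefix differs from the base installation."""
    # Default to the running interpreter; explicit arguments are only for testing.
    prefix = sys.prefix if prefix is None else prefix
    if base_prefix is None:
        # Assumption: sys.base_prefix is available (Python 3.3+, virtualenv 20+).
        base_prefix = getattr(sys, "base_prefix", sys.prefix)
    return prefix != base_prefix


if __name__ == "__main__":
    print("virtual environment active:", in_virtual_environment())
```

If this reports that no virtual environment is active, re-run `source venv/bin/activate` in the current shell before installing the requirements.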
To start an experiment tracking server, run:
mlflow server --port=5000
To access the MLflow web UI with all the experiments, forward the server port over SSH:
ssh -N -f -L localhost:5000:localhost:5000 <username>@<server-link>
Then open http://localhost:5000 in a local browser.
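If the page does not load, a quick way to tell whether the forwarded port is reachable at all is a plain TCP connection attempt. The sketch below uses only the standard library. The function name `is_port_open` is hypothetical, and a successful connection only shows that something is listening on the port, not that it is MLflow.

```python
import socket


def is_port_open(host="localhost", port=5000, timeout=1.0):
    """Return True if a TCP connection to host:port succeeds within the timeout."""
    try:
        with socket.create_connection((host, port), timeout=timeout):
            return True
    except OSError:
        # Covers refused connections, timeouts, and name-resolution failures.
        return False


if __name__ == "__main__":
    print("port 5000 reachable:", is_port_open())
```

A result of False usually means either that the server is not running or that the SSH tunnel was not established.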
We use multi-country data from Banholzer et al. (2021). Access data here.
The main training script of the Augmented Pseudo-Invertible Decoder (APID) is shared across datasets. For details on mandatory arguments, see the main configuration file config/config.yaml and the other files in the configs/ folder.
A generic invocation with logging and a fixed random seed looks as follows:
PYTHONPATH=. python3 runnables/train_apid.py +dataset=<dataset> +model=apid exp.seed=10 exp.logging=True
Example of running APID on the multi-modal dataset with a curvature loss coefficient of 5.0:
PYTHONPATH=. python3 runnables/train_apid.py +dataset=multi_modal +model=apid exp.seed=10 exp.logging=True model.curv_coeff=5.0 dataset.Y_f=0.0,0.5,1.0,1.5,2.0
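When launching many runs, it may be convenient to assemble these command lines programmatically. The sketch below composes the same command from Hydra-style overrides. The helper `build_train_command` and its parameters are hypothetical and not part of the repository; only the script path and override keys are taken from the commands above.

```python
import shlex


def build_train_command(dataset, seed=10, logging=True, extra_overrides=None):
    """Compose the training command line from Hydra-style key=value overrides."""
    overrides = [
        f"+dataset={dataset}",  # the "+" prefix adds an entry not already in the config
        "+model=apid",
        f"exp.seed={seed}",
        f"exp.logging={logging}",
    ]
    # Extra overrides are appended in insertion order as key=value pairs.
    for key, value in (extra_overrides or {}).items():
        overrides.append(f"{key}={value}")
    args = ["python3", "runnables/train_apid.py", *overrides]
    # shlex.join (Python 3.8+) quotes any argument that needs shell quoting.
    return "PYTHONPATH=. " + shlex.join(args)


if __name__ == "__main__":
    print(build_train_command(
        "multi_modal",
        extra_overrides={"model.curv_coeff": 5.0, "dataset.Y_f": "0.0,0.5,1.0,1.5,2.0"},
    ))
```

Run as a script, this prints the multi-modal example command, which can then be pasted into a shell or passed to a job scheduler.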
Project based on the cookiecutter data science project template. #cookiecutterdatascience