
Code for the paper "Partial Counterfactual Identification of Continuous Outcomes with a Curvature Sensitivity Model"


Valentyn1997/CSM-APID


Partial Counterfactual Identification


Partial counterfactual identification for continuous outcomes with a Curvature Sensitivity Model


The project is built with the following Python libraries:

  1. Pyro - deep learning and probabilistic modelling with residual normalizing flows and variational augmentations
  2. Hydra - simplified management of command-line arguments and configuration
  3. MLflow - experiment tracking
  4. POT - Python Optimal Transport library, used for computing the Wasserstein distance
  5. normflows - PyTorch implementation of normalizing flows, used here for residual normalizing flows
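For intuition about the quantity POT is used for: in one dimension, the Wasserstein-1 distance between two empirical distributions with the same number of equally weighted samples equals the mean absolute difference between the sorted samples. The following is a stdlib-only sketch under exactly that assumption (equal sizes, uniform weights); the function name `w1_sorted_samples` is illustrative and is not part of POT, which handles general weights and provides full optimal-transport solvers.

```python
from typing import Sequence


def w1_sorted_samples(xs: Sequence[float], ys: Sequence[float]) -> float:
    """Wasserstein-1 distance between two 1-D empirical distributions.

    Assumes both samples are non-empty, have the same size, and carry uniform
    weights. In that case the optimal coupling matches the i-th smallest x with
    the i-th smallest y, so the distance is the mean absolute gap between the
    sorted samples.
    """
    if len(xs) != len(ys) or len(xs) == 0:
        raise ValueError("samples must be non-empty and of equal size")
    gaps = (abs(x - y) for x, y in zip(sorted(xs), sorted(ys)))
    return sum(gaps) / len(xs)


# Shifting every point of a sample by +1 moves it by exactly 1 in W1.
print(w1_sorted_samples([0.0, 1.0, 2.0], [3.0, 1.0, 2.0]))
```

Sorting is what makes the 1-D case cheap: no transport plan has to be solved explicitly, which is why general-purpose libraries such as POT are only needed for weighted or higher-dimensional problems.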

Installation

First, create a virtual environment and install all the requirements:

pip3 install virtualenv
python3 -m virtualenv -p python3 --always-copy venv
source venv/bin/activate
pip3 install -r requirements.txt

MLflow Setup / Connection

To start an experiment-tracking server, run:

mlflow server --port=5000

To access the MLflow web UI with all the experiments, forward the server port via SSH:

ssh -N -f -L localhost:5000:localhost:5000 <username>@<server-link>

Then open http://localhost:5000 in a local browser.
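Once the tunnel is up, a plain HTTP request is a quick way to confirm that something answers on the forwarded port before opening the browser. This is a minimal stdlib sketch; the helper name `mlflow_ui_reachable` is hypothetical, and it only checks that some HTTP server responds at the URL, not that the server is specifically MLflow.

```python
import urllib.request


def mlflow_ui_reachable(url: str = "http://localhost:5000", timeout: float = 2.0) -> bool:
    """Return True if an HTTP server answers at `url` with a non-error status."""
    try:
        with urllib.request.urlopen(url, timeout=timeout) as response:
            return 200 <= response.status < 400
    except OSError:
        # URLError, HTTPError and socket timeouts are all OSError subclasses:
        # connection refused, unresolved host, timeout, or an HTTP error status.
        return False


if __name__ == "__main__":
    print("MLflow UI reachable:", mlflow_ui_reachable())
```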

Real-world case study

We use multi-country data from Banholzer et al. (2021). Access data here.

Experiments

The main training script of the Augmented Pseudo-Invertible Decoder (APID) is shared across all datasets. For details on the mandatory arguments, see the main configuration file config/config.yaml and the other files in the configs/ folder.
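Hydra composes the run configuration from config/config.yaml plus command-line overrides: `+dataset=<dataset>` adds a config group option that is not in the defaults list, while dotted keys such as `exp.seed=10` set nested values. The sketch below is a deliberately simplified, stdlib-only illustration of how dotted `key=value` overrides map onto a nested dictionary. It is not Hydra's parser: it ignores config-group loading (a `+group=option` override is stored only as a name), type checking against the schema, and sweeps, and the `base` dictionary is a hypothetical stand-in for the real defaults.

```python
import ast
from typing import Any, Dict, List


def parse_value(raw: str) -> Any:
    """Interpret a literal such as 10, 5.0 or True; fall back to the raw string."""
    try:
        return ast.literal_eval(raw)
    except (ValueError, SyntaxError):
        return raw


def apply_overrides(config: Dict[str, Any], overrides: List[str]) -> Dict[str, Any]:
    """Set nested keys from 'a.b.c=value' strings (simplified, not Hydra's grammar)."""
    for override in overrides:
        key, _, raw = override.lstrip("+").partition("=")
        *parents, leaf = key.split(".")
        node = config
        for name in parents:
            node = node.setdefault(name, {})  # create intermediate dicts as needed
        node[leaf] = parse_value(raw)
    return config


# Hypothetical base config; the real defaults live in the repository's config files.
base = {"exp": {"seed": 0, "logging": False}, "model": {"curv_coeff": 1.0}}
cfg = apply_overrides(base, ["exp.seed=10", "exp.logging=True", "model.curv_coeff=5.0"])
print(cfg)
```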

A generic command with logging enabled and a fixed random seed:

PYTHONPATH=.  python3 runnables/train_apid.py +dataset=<dataset> +model=apid exp.seed=10 exp.logging=True

Example of running APID on the multi-modal dataset with curvature loss coefficient $\lambda_\kappa = 5.0$ and factual outcomes $y' \in \{0.0, 0.5, 1.0, 1.5, 2.0\}$:

PYTHONPATH=.  python3 runnables/train_apid.py +dataset=multi_modal +model=apid exp.seed=10 exp.logging=True model.curv_coeff=5.0 dataset.Y_f=0.0,0.5,1.0,1.5,2.0
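To launch the same commands from Python, for instance to loop over several curvature coefficients, the argument list can be assembled programmatically. This is a sketch assuming it is executed from the repository root; the helpers `build_apid_command` and `run_apid` are hypothetical, `sys.executable` stands in for `python3`, and override values are passed through verbatim, exactly as written in the command above.

```python
import os
import subprocess
import sys
from typing import Dict, List, Optional


def build_apid_command(dataset: str, seed: int = 10, logging: bool = True,
                       overrides: Optional[Dict[str, str]] = None) -> List[str]:
    """Assemble the argv for runnables/train_apid.py with Hydra-style overrides."""
    argv = [sys.executable, "runnables/train_apid.py",
            f"+dataset={dataset}", "+model=apid",
            f"exp.seed={seed}", f"exp.logging={logging}"]
    for key, value in (overrides or {}).items():
        argv.append(f"{key}={value}")  # passed through unchanged
    return argv


def run_apid(argv: List[str]) -> int:
    """Run the script, mirroring the `PYTHONPATH=.` prefix of the commands above."""
    env = dict(os.environ, PYTHONPATH=".")
    return subprocess.run(argv, env=env, check=False).returncode


if __name__ == "__main__":
    # Hypothetical loop over the curvature loss coefficient lambda_kappa.
    for curv_coeff in (1.0, 5.0):
        cmd = build_apid_command(
            "multi_modal",
            overrides={"model.curv_coeff": str(curv_coeff),
                       "dataset.Y_f": "0.0,0.5,1.0,1.5,2.0"},
        )
        print(" ".join(cmd))
        # run_apid(cmd)  # uncomment inside a checkout of the repository
```

Passing an argument list to `subprocess.run` rather than a single shell string avoids shell quoting issues with override values that contain commas or brackets.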

Project based on the cookiecutter data science project template. #cookiecutterdatascience
