INFs

Code for the paper "Normalizing Flows for Interventional Density Estimation"

The project is built with the following Python libraries:

  1. Pyro - deep learning and probabilistic models (MDNs, NFs)
  2. Hydra - simplified command-line argument management
  3. MlFlow - experiment tracking

Installation

First, one needs to create a virtual environment and install all the requirements:

pip3 install virtualenv
python3 -m virtualenv -p python3 --always-copy venv
source venv/bin/activate
pip3 install -r requirements.txt

MlFlow Setup / Connection

To start an experiment tracking server, run:

mlflow server --port=5000

To access the MlFlow web UI with all the experiments, connect via ssh:

ssh -N -f -L localhost:5000:localhost:5000 <username>@<server-link>

Then, one can open http://localhost:5000 in a local browser.
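
By default, mlflow server stores the tracking data in the local ./mlruns directory; if the experiments are stored elsewhere, the location can be passed explicitly (a standard MlFlow option, not specific to this repository):

mlflow server --port=5000 --backend-store-uri ./mlruns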

Experiments

The main training script is universal for different methods and datasets. For details on mandatory arguments, see the main configuration file config/config.yaml and the other files in the config/ folder.

A generic run with logging and a fixed random seed looks as follows:

PYTHONPATH=. python3 runnables/train.py +dataset=<dataset> +model=<model> exp.seed=10
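
For instance, a single run of the plug-in MDN baseline on the IHDP dataset (both options are listed in the sections below) could look like this; note that the model-specific hyperparameters still need to be provided, as described next:

PYTHONPATH=. python3 runnables/train.py +dataset=ihdp +model=mdn_plugin exp.seed=10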

Models (baselines)

One needs to choose a model and then fill in its specific hyperparameters (they are left blank in the configs).

Models already have the best hyperparameters saved (for each model and dataset); one can access them via +model/<dataset>_hparams=<model> or +model/<dataset>_hparams/<model>=<dataset_param>. The hyperparameters for the three variants of INFs are the same: +model/<dataset>_hparams=infs.
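
For instance, assuming the IHDP group is named ihdp_hparams (by analogy with the hcmnist_hparams example below), a run of INFs with the saved hyperparameters could look like:

PYTHONPATH=. python3 runnables/train.py +dataset=ihdp +model=infs_aiptw +model/ihdp_hparams=infs exp.seed=10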

To perform manual hyperparameter tuning, use the flag model.tune_hparams=True, and see model.hparams_grid for the search grid.
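
For instance, a tuning run of the plug-in MDN on IHDP (illustrative; the searched values are taken from model.hparams_grid of the corresponding config) could be launched as:

PYTHONPATH=. python3 runnables/train.py +dataset=ihdp +model=mdn_plugin model.tune_hparams=True exp.seed=10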

Datasets

Before running semi-synthetic experiments, place the datasets in the data/ folder:

data/
├── acic_2016
│   ├── synth_outcomes
│   │   ├── zymu_<id0>.csv
│   │   ├── ...
│   │   └── zymu_<id14>.csv
│   ├── ids.csv
│   └── x.csv
└── acic_2018
    ├── scaling
    │   ├── <id0>.csv
    │   ├── <id0>_cf.csv
    │   ├── ...
    │   ├── <id23>.csv
    │   └── <id23>_cf.csv
    ├── ids.csv
    └── x.csv

We also provide the ids of the random datasets (ids.csv) used in the experiments for ACIC 2016 and ACIC 2018.

One needs to specify a dataset / dataset generator (and some additional parameters, e.g., set b for polynomial_normal with dataset.cov_shift=3.0, or the dataset index for ACIC with dataset.dataset_ix=0; see the example after the list):

  • Synthetic data using the SCM: +dataset=polynomial_normal
  • IHDP dataset: +dataset=ihdp
  • ACIC 2016 & 2018 datasets: +dataset=acic_2016 / +dataset=acic_2018
  • HC-MNIST dataset: +dataset=hcmnist
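
For example, a single synthetic run with a covariate shift of b = 3.0 (the choice of INFs here is illustrative) could look like:

PYTHONPATH=. python3 runnables/train.py +dataset=polynomial_normal +model=infs_aiptw dataset.cov_shift=3.0 exp.seed=10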

Examples

Example of a 10-fold run with INFs (main) on synthetic data with b = [2.0, 3.0, 4.0]:

PYTHONPATH=. python3 runnables/train.py -m +dataset=polynomial_normal +model=infs_aiptw +model/polynomial_normal_hparams/infs='2.0','3.0','4.0' exp.seed=10

Example of 5 runs with random splits with MDNs on the first subset of ACIC 2016 (with tuning based on the first split):

PYTHONPATH=. python3 runnables/train.py -m +dataset=acic_2016 +model=mdn_plugin model.tune_hparams=True exp.seed=10 dataset.n_shuffle_splits=5 dataset.dataset_ix=0

Example of 10 runs with random splits with INFs (main) on HC-MNIST:

PYTHONPATH=. python3 runnables/train.py -m +dataset=hcmnist +model=infs_aiptw +model/hcmnist_hparams=infs model.target_count_bins=10 exp.seed=101 model.num_epochs=15000 model.target_num_epochs=5000 model.nuisance_hid_dim_multiplier=30 dataset.n_shuffle_splits=10
