Code for the paper "Quantifying Aleatoric Uncertainty of the Treatment Effect: A Novel Orthogonal Learner"

AU-learner with conditional normalizing flows

Quantifying aleatoric uncertainty of the treatment effect with an AU-learner based on conditional normalizing flows (CNFs)


The project is built with the following Python libraries:

  1. Pyro - deep learning and probabilistic models (MDNs, NFs)
  2. Hydra - simplified command-line argument management
  3. MLflow - experiment tracking

Setup

Installations

First, create a virtual environment and install all the requirements:

pip3 install virtualenv
python3 -m virtualenv -p python3 --always-copy venv
source venv/bin/activate
pip3 install -r requirements.txt

MLflow Setup / Connection

To start an experiment tracking server, run:

mlflow server --port=5000

To access the MLflow web UI with all the experiments, connect to the server via an SSH tunnel:

ssh -N -f -L localhost:5000:localhost:5000 <username>@<server-link>

Then, open http://localhost:5000 in a local browser.

Semi-synthetic datasets setup

Before running semi-synthetic experiments, place datasets in the corresponding folders:

  • IHDP100 dataset: ihdp_npci_1-100.test.npz and ihdp_npci_1-100.train.npz to data/ihdp100/
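For example, assuming the two IHDP100 archives have been downloaded to the repository root, placing them could look like this:

```shell
# Create the expected folder and move the downloaded IHDP100 files into it
mkdir -p data/ihdp100
mv ihdp_npci_1-100.train.npz ihdp_npci_1-100.test.npz data/ihdp100/
```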

Real-world case study

We use multi-country data from Banholzer et al. (2021). Access data here.

Experiments

The main training script is universal for different methods and datasets. For details on the mandatory arguments, see the main configuration file config/config.yaml and the other files in the config/ folder.

The generic command with logging and a fixed random seed is the following:

PYTHONPATH=. python3 runnables/train.py +dataset=<dataset> +model=<model> exp.seed=10

Models

One needs to choose a model and then fill in the specific hyperparameters (they are left blank in the configs):

  • AU-CNFs (= AU-learner with CNFs, this paper): +model=dr_cnfs with two variants:
    • CRPS: model.target_mode=cdf
    • $W_2^2$: model.target_mode=icdf
  • CA-CNFs (= CA-learner with CNFs, this paper): +model=ca_cnfs with two variants:
    • CRPS: model.target_mode=cdf
    • $W_2^2$: model.target_mode=icdf
  • IPTW-CNF (= IPTW-learner with CNF, this paper): +model=iptw_plugin_cnfs
  • Conditional Normalizing Flows (CNF, plug-in learner): +model=plugin_cnfs
  • Distributional Kernel Mean Embeddings (DKME): +model=plugin_dkme

The best hyperparameters (for each model and dataset) are already saved in the configs; one can access them via +model/<dataset>_hparams=<model> or +model/<dataset>_hparams/<model>=<dataset_param>. All variants of AU-CNFs, as well as CA-CNFs, IPTW-CNF, and CNF, share the same hyperparameters: +model/<dataset>_hparams=plugin_cnfs.
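For instance, loading the stored hyperparameters for AU-CNFs on IHDP could look as follows (a sketch only: the exact config-group names depend on the files present under config/):

```shell
# Hypothetical example: AU-CNFs (CRPS) on IHDP, reusing the stored CNF hyperparameters
PYTHONPATH=. python3 runnables/train.py +dataset=ihdp +model=dr_cnfs \
    +model/ihdp_hparams=plugin_cnfs model.target_mode=cdf exp.seed=10
```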

To perform manual hyperparameter tuning, set model.tune_hparams=True and adjust the search space in model.hparams_grid.

Datasets

One needs to specify a dataset / dataset generator (and, possibly, additional parameters, e.g. the training size for the synthetic data: dataset.n_samples_train=1000):

  • Synthetic data (adapted from https://arxiv.org/abs/1810.02894): +dataset=sine with 3 settings:
    • Normal: dataset.mode=normal
    • Multi-modal: dataset.mode=multimodal
    • Exponential: dataset.mode=exp
  • IHDP dataset: +dataset=ihdp
  • HC-MNIST dataset: +dataset=hcmnist

Examples

Example of running an experiment with our AU-CNFs (CRPS) on synthetic data in the normal setting with $n_{\text{train}} = 100$ and 3 random seeds:

PYTHONPATH=. python3 runnables/train.py -m +dataset=sine +model=dr_cnfs +model/sine_hparams/plugin_cnfs_normal=\'100\' model.target_mode=cdf model.correction_coeff=0.25 exp.seed=10,101,1010

Example of tuning hyperparameters of the CNF based on HC-MNIST dataset:

PYTHONPATH=. python3 runnables/train.py -m +dataset=hcmnist +model=plugin_cnfs +model/hcmnist_hparams=plugin_cnfs exp.seed=10 model.tune_hparams=True

Project based on the cookiecutter data science project template. #cookiecutterdatascience
