Quantifying aleatoric uncertainty of the treatment effect with an AU-learner based on conditional normalizing flows (CNFs)
The project is built with the following Python libraries:
- Pyro - deep learning and probabilistic models (MDNs, NFs); a minimal usage sketch follows the list
- Hydra - simplified command-line argument management
- MLflow - experiment tracking
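For orientation, here is a minimal, hypothetical sketch of a conditional normalizing flow in Pyro; the variable names and dimensions are illustrative and not taken from this repository's code:

```python
import torch
import pyro.distributions as dist
import pyro.distributions.transforms as T

# Illustrative dimensions (assumptions): 1D outcome, 5D context (covariates + treatment)
y_dim, context_dim = 1, 5

# Base density + conditional spline transform = conditional normalizing flow
base = dist.Normal(torch.zeros(y_dim), torch.ones(y_dim))
transform = T.conditional_spline(y_dim, context_dim, count_bins=8)
cnf = dist.ConditionalTransformedDistribution(base, [transform])

# Maximum-likelihood objective: negative conditional log-density of outcomes
context = torch.randn(16, context_dim)
y = torch.randn(16, y_dim)
loss = -cnf.condition(context).log_prob(y).mean()
```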
First, create a virtual environment and install all the requirements:
pip3 install virtualenv
python3 -m virtualenv -p python3 --always-copy venv
source venv/bin/activate
pip3 install -r requirements.txt
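To check that the environment was set up correctly, one can optionally verify that the main dependencies import without errors (a simple sanity check, not part of the original workflow):
python3 -c "import torch, pyro, hydra, mlflow"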
To start an experiment-tracking server, run:
mlflow server --port=5000
To access the MLflow web UI with all the experiments, connect via SSH:
ssh -N -f -L localhost:5000:localhost:5000 <username>@<server-link>
Then, one can open http://localhost:5000 in a local browser.
Before running semi-synthetic experiments, place datasets in the corresponding folders:
- IHDP100 dataset: place ihdp_npci_1-100.test.npz and ihdp_npci_1-100.train.npz in data/ihdp100/
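For example, assuming the IHDP100 files were downloaded to the current directory, they can be placed as follows:
mkdir -p data/ihdp100
cp ihdp_npci_1-100.train.npz ihdp_npci_1-100.test.npz data/ihdp100/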
We use multi-country data from Banholzer et al. (2021). Access data here.
The main training script is universal for different methods and datasets. For details on mandatory arguments, see the main configuration file config/config.yaml and the other files in the config/ folder.
A generic run with logging and a fixed random seed looks as follows:
PYTHONPATH=. python3 runnables/train.py +dataset=<dataset> +model=<model> exp.seed=10
One needs to choose a model and then fill in the specific hyperparameters (they are left blank in the configs):
- AU-CNFs (= AU-learner with CNFs, this paper): +model=dr_cnfs with two variants:
  - CRPS: model.target_mode=cdf
  - $W_2^2$: model.target_mode=icdf
- CA-CNFs (= CA-learner with CNFs, this paper): +model=ca_cnfs with two variants:
  - CRPS: model.target_mode=cdf
  - $W_2^2$: model.target_mode=icdf
- IPTW-CNF (= IPTW-learner with CNF, this paper): +model=iptw_plugin_cnfs
- Conditional Normalizing Flows (CNF, plug-in learner): +model=plugin_cnfs
- Distributional Kernel Mean Embeddings (DKME): +model=plugin_dkme
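As background (standard definitions, not specific to this repository): the two target modes correspond to the CRPS and the squared 2-Wasserstein distance. The CRPS compares cumulative distribution functions (hence target_mode=cdf), whereas $W_2^2$ between two distributions on the real line compares inverse CDFs, i.e., quantile functions (hence target_mode=icdf):

$$\mathrm{CRPS}(F, y) = \int_{-\infty}^{\infty} \big(F(x) - \mathbb{1}\{y \le x\}\big)^2 \, dx, \qquad W_2^2(F, G) = \int_0^1 \big(F^{-1}(u) - G^{-1}(u)\big)^2 \, du.$$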
Models already have the best hyperparameters saved (for each model and dataset); one can access them via +model/<dataset>_hparams=<model> or +model/<dataset>_hparams/<model>=<dataset_param>. Hyperparameters for the three variants of AU-CNFs, CA-CNFs, IPTW-CNF, and CNF are the same: +model/<dataset>_hparams=plugin_cnfs.
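For instance, saved hyperparameters can be loaded together with a model and a dataset as follows (an illustrative combination of the flags above; it assumes the corresponding config group exists):
PYTHONPATH=. python3 runnables/train.py +dataset=ihdp +model=plugin_cnfs +model/ihdp_hparams=plugin_cnfs exp.seed=10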
To perform manual hyperparameter tuning, use the flag model.tune_hparams=True, and then see model.hparams_grid.
One needs to specify a dataset/dataset generator (and some additional parameters, e.g., the train size for synthetic data: dataset.n_samples_train=1000):
- Synthetic data (adapted from https://arxiv.org/abs/1810.02894): +dataset=sine with three settings:
  - Normal: dataset.mode=normal
  - Multi-modal: dataset.mode=multimodal
  - Exponential: dataset.mode=exp
- IHDP dataset: +dataset=ihdp
- HC-MNIST dataset: +dataset=hcmnist
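For instance, the synthetic-data generator can be combined with one of its settings and the train-size parameter mentioned above (an illustrative run using only the flags documented here):
PYTHONPATH=. python3 runnables/train.py +dataset=sine +model=plugin_cnfs dataset.mode=multimodal dataset.n_samples_train=1000 exp.seed=10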
Example of running an experiment with our AU-CNFs (CRPS) on synthetic data in the normal setting, with the correction coefficient 0.25 and three random seeds:
PYTHONPATH=. python3 runnables/train.py -m +dataset=sine +model=dr_cnfs +model/sine_hparams/plugin_cnfs_normal=\'100\' model.target_mode=cdf model.correction_coeff=0.25 exp.seed=10,101,1010
Example of tuning the hyperparameters of the plug-in CNF on the HC-MNIST dataset:
PYTHONPATH=. python3 runnables/train.py -m +dataset=hcmnist +model=plugin_cnfs +model/hcmnist_hparams=plugin_cnfs exp.seed=10 model.tune_hparams=True
Project based on the cookiecutter data science project template. #cookiecutterdatascience