diff --git a/requirements.txt b/requirements.txt index 46bd159..77a2565 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,7 +1,6 @@ # local package --e . - importlib-metadata<5.0 +pytest click Sphinx coverage @@ -17,6 +16,7 @@ hydra_colorlog seaborn torch +torchvision pyro-ppl pytorch-lightning mlflow @@ -27,3 +27,5 @@ omegaconf tqdm ray[tune] econml +POT +xbcausalforest \ No newline at end of file diff --git a/tests/test_tarnet.py b/tests/test_tarnet.py new file mode 100644 index 0000000..3d978e3 --- /dev/null +++ b/tests/test_tarnet.py @@ -0,0 +1,23 @@ +from hydra.experimental import initialize, compose +import logging + +from runnables.train import main + +logging.basicConfig(level='info') + + +class TestTarNET: + def test_tarnet(self): + with initialize(config_path="../config"): + args = compose(config_name="config.yaml", overrides=["+dataset=synthetic", + "+repr_net=tarnet", + "+repr_net/synthetic_hparams/tarnet/n500='1.0'", + "exp.seed=10", + "exp.logging=False", + "exp.device=cpu", + "repr_net.num_epochs=10", + "prop_net_cov.num_epochs=10", + "cnf_repr.num_epochs=10", + "prop_net_repr.num_epochs=10"]) + results_1, results_2 = main(args), main(args) + assert results_1 == results_2 \ No newline at end of file