Skip to content

Commit

Permalink
Add more tests
Browse files Browse the repository at this point in the history
  • Loading branch information
Valentyn1997 committed Aug 20, 2024
1 parent a96b1ae commit 6b698f7
Show file tree
Hide file tree
Showing 4 changed files with 149 additions and 19 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/python-app.yml
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ jobs:
runs-on: ubuntu-latest
strategy:
matrix:
python-version: ['3.8', '3.9', '3.10', '3.11']
python-version: ['3.8', '3.9', '3.10', '3.11', '3.12']

steps:
- uses: actions/checkout@v2
Expand Down
9 changes: 4 additions & 5 deletions runnables/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -195,11 +195,10 @@ def main(args: DictConfig):

mlflow_logger.experiment.set_terminated(mlflow_logger.run_id) if args.exp.logging else None

return results_in, results_out
# return {
# 'models': {'repr_net': repr_net, 'cnf_repr': cnf_repr, 'prop_net_repr': prop_net_repr, 'prop_net_cov': prop_net_cov},
# 'data_dicts': {'train_data_dict': train_data_dict, 'test_data_dict': test_data_dict}
# }
return {
'models': {'repr_net': repr_net, 'cnf_repr': cnf_repr, 'prop_net_repr': prop_net_repr, 'prop_net_cov': prop_net_cov},
'data_dicts': {'train_data_dict': train_data_dict, 'test_data_dict': test_data_dict}
}


if __name__ == "__main__":
Expand Down
28 changes: 15 additions & 13 deletions tests/test_tarnet.py → tests/test_classical.py
Original file line number Diff line number Diff line change
@@ -1,26 +1,28 @@
from hydra.experimental import initialize, compose
import logging
import torch
import numpy as np
import random

from runnables.train import main
from runnables.train_classic import main

logging.basicConfig(level='info')


class TestTarNET:
def test_tarnet(self):
class TestClassicalMethods:
def test_knn(self):
with initialize(config_path="../config"):
args = compose(config_name="config.yaml", overrides=["+dataset=synthetic",
"+repr_net=tarnet",
"+repr_net/synthetic_hparams/tarnet/n500='1.0'",
"+classic_cate=knn",
"exp.seed=10",
"exp.logging=False",
"exp.device=cpu",
"repr_net.num_epochs=5",
"prop_net_cov.num_epochs=5",
"cnf_repr.num_epochs=5",
"prop_net_repr.num_epochs=5"])
"dataset.n_samples_train=500"])
results = main(args)
assert results is not None

def test_causal_forest(self):
with initialize(config_path="../config"):
args = compose(config_name="config.yaml", overrides=["+dataset=synthetic",
"+classic_cate=causal_forest",
"exp.seed=10",
"exp.logging=False",
"dataset.n_samples_train=500"])
results = main(args)
assert results is not None
129 changes: 129 additions & 0 deletions tests/test_repr.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,129 @@
from hydra.experimental import initialize, compose
import logging

from runnables.train import main

logging.basicConfig(level='info')


class TestReprNETs:
def test_tarnet(self):
with initialize(config_path="../config"):
args = compose(config_name="config.yaml", overrides=["+dataset=synthetic",
"+repr_net=tarnet",
"+repr_net/synthetic_hparams/tarnet/n500='1.0'",
"exp.seed=10",
"exp.logging=False",
"exp.device=cpu",
"repr_net.num_epochs=5",
"prop_net_cov.num_epochs=5",
"cnf_repr.num_epochs=5",
"prop_net_repr.num_epochs=5"])
results = main(args)
assert results is not None

def test_cfrnet_mmd(self):
with initialize(config_path="../config"):
args = compose(config_name="config.yaml", overrides=["+dataset=synthetic",
"+repr_net=cfrnet",
"+repr_net/synthetic_hparams/cfrnet/n500/mmd_0_5='1.0'",
"exp.seed=10",
"exp.logging=False",
"exp.device=cpu",
"repr_net.num_epochs=5",
"prop_net_cov.num_epochs=5",
"cnf_repr.num_epochs=5",
"prop_net_repr.num_epochs=5"])
results = main(args)
assert results is not None

def test_cfrnet_wass(self):
with initialize(config_path="../config"):
args = compose(config_name="config.yaml", overrides=["+dataset=synthetic",
"+repr_net=cfrnet",
"+repr_net/synthetic_hparams/cfrnet/n500/wass_2_0='1.0'",
"exp.seed=10",
"exp.logging=False",
"exp.device=cpu",
"repr_net.num_epochs=5",
"prop_net_cov.num_epochs=5",
"cnf_repr.num_epochs=5",
"prop_net_repr.num_epochs=5"])
results = main(args)
assert results is not None

def test_rcfrnet(self):
with initialize(config_path="../config"):
args = compose(config_name="config.yaml", overrides=["+dataset=synthetic",
"+repr_net=rcfrnet",
"+repr_net/synthetic_hparams/rcfrnet/n500='1.0'",
"exp.seed=10",
"exp.logging=False",
"exp.device=cpu",
"repr_net.num_epochs=5",
"prop_net_cov.num_epochs=5",
"cnf_repr.num_epochs=5",
"prop_net_repr.num_epochs=5"])
results = main(args)
assert results is not None

def test_cfrisw(self):
with initialize(config_path="../config"):
args = compose(config_name="config.yaml", overrides=["+dataset=synthetic",
"+repr_net=cfrisw",
"+repr_net/synthetic_hparams/cfrisw/n500='1.0'",
"exp.seed=10",
"exp.logging=False",
"exp.device=cpu",
"repr_net.num_epochs=5",
"prop_net_cov.num_epochs=5",
"cnf_repr.num_epochs=5",
"prop_net_repr.num_epochs=5"])
results = main(args)
assert results is not None

def test_bwcfr(self):
with initialize(config_path="../config"):
args = compose(config_name="config.yaml", overrides=["+dataset=synthetic",
"+repr_net=bwcfr",
"+repr_net/synthetic_hparams/bwcfr/n500='1.0'",
"exp.seed=10",
"exp.logging=False",
"exp.device=cpu",
"repr_net.num_epochs=5",
"prop_net_cov.num_epochs=5",
"cnf_repr.num_epochs=5",
"prop_net_repr.num_epochs=5"])
results = main(args)
assert results is not None

def test_bnnet(self):
with initialize(config_path="../config"):
args = compose(config_name="config.yaml", overrides=["+dataset=synthetic",
"+repr_net=bnnet",
"+repr_net/synthetic_hparams/bnnet/n500='1.0'",
"exp.seed=10",
"exp.logging=False",
"exp.device=cpu",
"repr_net.num_epochs=5",
"prop_net_cov.num_epochs=5",
"cnf_repr.num_epochs=5",
"prop_net_repr.num_epochs=5"])
results = main(args)
assert results is not None

def test_inv_tarnet(self):
with initialize(config_path="../config"):
args = compose(config_name="config.yaml", overrides=["+dataset=synthetic",
"+repr_net=inv_tarnet",
"+repr_net/synthetic_hparams/inv_tarnet/n500='1.0'",
"exp.seed=10",
"exp.logging=False",
"exp.device=cpu",
"repr_net.num_epochs=5",
"prop_net_cov.num_epochs=5",
"cnf_repr.num_epochs=5",
"prop_net_repr.num_epochs=5"])

results = main(args)
assert results is not None

0 comments on commit 6b698f7

Please sign in to comment.