Skip to content

Commit

Permalink
Fix requirements.txt
Browse files Browse the repository at this point in the history
  • Loading branch information
Valentyn1997 committed Aug 20, 2024
1 parent 834d742 commit db657b7
Show file tree
Hide file tree
Showing 2 changed files with 27 additions and 2 deletions.
6 changes: 4 additions & 2 deletions requirements.txt
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
# local package
-e .

importlib-metadata<5.0
pytest
click
Sphinx
coverage
Expand All @@ -17,6 +16,7 @@ hydra_colorlog

seaborn
torch
torchvision
pyro-ppl
pytorch-lightning
mlflow
Expand All @@ -27,3 +27,5 @@ omegaconf
tqdm
ray[tune]
econml
POT
xbcausalforest
23 changes: 23 additions & 0 deletions tests/test_tarnet.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
from hydra.experimental import initialize, compose
import logging

from runnables.train import main

logging.basicConfig(level='info')


class TestTarNET:
def test_tarnet(self):
with initialize(config_path="../config"):
args = compose(config_name="config.yaml", overrides=["+dataset=synthetic",
"+repr_net=tarnet",
"+repr_net/synthetic_hparams/tarnet/n500='1.0'",
"exp.seed=10",
"exp.logging=False",
"exp.device=cpu",
"repr_net.num_epochs=10",
"prop_net_cov.num_epochs=10",
"cnf_repr.num_epochs=10",
"prop_net_repr.num_epochs=10"])
results_1, results_2 = main(args), main(args)
assert results_1 == results_2

0 comments on commit db657b7

Please sign in to comment.