From f04f2a6131163d800131ec94f6603fb8cf9a52a9 Mon Sep 17 00:00:00 2001
From: Carlos Hernandez <cxh@stanford.edu>
Date: Wed, 22 Jun 2016 14:09:45 -0700
Subject: [PATCH] Add dataset loader tests; HDF5DatasetLoader takes a glob

---
 devtools/conda-recipe/meta.yaml          |  1 +
 osprey/dataset_loaders.py                |  6 +-
 osprey/tests/test_cli_worker_and_dump.py |  2 +-
 osprey/tests/test_dataset_loader.py      | 87 +++++++++++++++++++++++-
 osprey/utils.py                          |  1 +
 5 files changed, 90 insertions(+), 7 deletions(-)

diff --git a/devtools/conda-recipe/meta.yaml b/devtools/conda-recipe/meta.yaml
index dae40ba..4b9f0bb 100644
--- a/devtools/conda-recipe/meta.yaml
+++ b/devtools/conda-recipe/meta.yaml
@@ -31,6 +31,7 @@ test:
     - nose-timer
     - gpy
     - msmbuilder
+    - mdtraj
     - coverage
     - python-coveralls
     - matplotlib
diff --git a/osprey/dataset_loaders.py b/osprey/dataset_loaders.py
index 9e2fdb1..44efd31 100644
--- a/osprey/dataset_loaders.py
+++ b/osprey/dataset_loaders.py
@@ -48,8 +48,8 @@ def load(self):
 class HDF5DatasetLoader(BaseDatasetLoader):
     short_name = 'hdf5'
 
-    def __init__(self, filename, stride=1):
-        self.filename = filename
+    def __init__(self, filenames, stride=1):
+        self.filenames = filenames
         self.stride = stride
 
     def transform(self, arr):
@@ -59,7 +59,7 @@ def load(self):
         from mdtraj import io
         X = []
         y = []
-        filenames = [fn.strip() for fn in self.filename.split(',')]
+        filenames = sorted(glob.glob(expand_path(self.filenames)))
         for i, fn in enumerate(filenames):
             dataset = io.loadh(fn)
             for key in dataset.iterkeys():
diff --git a/osprey/tests/test_cli_worker_and_dump.py b/osprey/tests/test_cli_worker_and_dump.py
index 932dcc7..45259f2 100644
--- a/osprey/tests/test_cli_worker_and_dump.py
+++ b/osprey/tests/test_cli_worker_and_dump.py
@@ -73,7 +73,7 @@ def _test_dump_1():
 
 
 def _test_plot_1():
-    out = subprocess.check_output(
+    _ = subprocess.check_output(
         [OSPREY_BIN, 'plot', 'config.yaml', '--no-browser'])
     if not os.path.isfile('./plot.html'):
         raise ValueError('Plot not created')
diff --git a/osprey/tests/test_dataset_loader.py b/osprey/tests/test_dataset_loader.py
index f0a3d6a..08b91f6 100644
--- a/osprey/tests/test_dataset_loader.py
+++ b/osprey/tests/test_dataset_loader.py
@@ -7,9 +7,10 @@
 import sklearn.datasets
 from sklearn.externals.joblib import dump
 
-from osprey.dataset_loaders import FilenameDatasetLoader
-from osprey.dataset_loaders import JoblibDatasetLoader
-from osprey.dataset_loaders import SklearnDatasetLoader
+from osprey.dataset_loaders import (FilenameDatasetLoader, JoblibDatasetLoader,
+                                    HDF5DatasetLoader, MDTrajDatasetLoader,
+                                    MSMBuilderDatasetLoader,
+                                    NumpyDatasetLoader, SklearnDatasetLoader)
 
 
 def test_FilenameDatasetLoader_1():
@@ -70,6 +71,86 @@ def test_JoblibDatasetLoader_1():
         shutil.rmtree(dirname)
 
 
+def test_HDF5DatasetLoader_1():
+    # Exercise the loader on one file, then on a glob matching two files.
+    from mdtraj import io
+
+    assert HDF5DatasetLoader.short_name == 'hdf5'
+
+    zeros, ones = np.zeros((10, 2)), np.ones((10, 2))
+    original_dir = os.path.abspath(os.curdir)
+    scratch_dir = tempfile.mkdtemp()
+    try:
+        os.chdir(scratch_dir)
+
+        # a literal filename matches exactly one file
+        io.saveh('f1.h5', test=zeros)
+        X, y = HDF5DatasetLoader('f1.h5').load()
+        assert np.all(X == zeros)
+        assert y[0] == 0
+
+        # a glob pattern picks up both files, in sorted order
+        io.saveh('f2.h5', test=ones)
+        X, y = HDF5DatasetLoader('f*.h5').load()
+        assert isinstance(X, list)
+        assert np.all(X[0] == zeros)
+        assert np.all(X[1] == ones)
+        assert y[1] == 1
+
+    finally:
+        os.chdir(original_dir)
+        shutil.rmtree(scratch_dir)
+
+
+def test_MDTrajDatasetLoader_1():
+    from mdtraj.testing import get_fn
+    traj_file = get_fn('legacy_msmbuilder_trj0.lh5')  # located by mdtraj
+
+    trajectories, labels = MDTrajDatasetLoader(traj_file).load()
+    assert trajectories[0].n_frames == 501
+    assert labels is None  # no targets expected from this loader
+
+
+def test_MSMBuilderDatasetLoader_1():
+    from msmbuilder.dataset import dataset
+
+    path = tempfile.mkdtemp()
+    shutil.rmtree(path)  # free the name; 'w' mode presumably recreates it
+    try:
+        x = np.random.randn(10, 2)
+        ds = dataset(path, 'w', 'dir-npy')
+        ds[0] = x
+
+        loader = MSMBuilderDatasetLoader(path, fmt='dir-npy')
+        X, y = loader.load()
+
+        assert np.all(X[0] == x)
+        assert y is None
+    finally:
+        # Tolerate a missing path so a setup failure is not masked here.
+        shutil.rmtree(path, ignore_errors=True)
+
+
+def test_NumpyDatasetLoader_1():
+    # Round-trip one array through np.save and the loader; no labels expected.
+    cwd = os.path.abspath(os.curdir)
+    dirname = tempfile.mkdtemp()
+    try:
+        os.chdir(dirname)
+        x = np.random.randn(10, 2)
+        np.save('f1.npy', x)
+
+        loader = NumpyDatasetLoader('f1.npy')
+        X, y = loader.load()
+
+        assert np.all(X[0] == x)
+        assert y is None
+
+    finally:
+        os.chdir(cwd)
+        shutil.rmtree(dirname)
+
+
 def test_SklearnDatasetLoader_1():
     assert SklearnDatasetLoader.short_name == 'sklearn_dataset'
     X, y = SklearnDatasetLoader('load_iris').load()
diff --git a/osprey/utils.py b/osprey/utils.py
index f67afcd..8a5213f 100644
--- a/osprey/utils.py
+++ b/osprey/utils.py
@@ -10,6 +10,7 @@
 
 from .eval_scopes import import_all_estimators
 
+
 def dict_merge(base, top):
     """Recursively merge two dictionaries, with the elements from `top`
     taking precedence over elements from `top`.