From f04f2a6131163d800131ec94f6603fb8cf9a52a9 Mon Sep 17 00:00:00 2001
From: Carlos Hernandez <cxh@stanford.edu>
Date: Wed, 22 Jun 2016 14:09:45 -0700
Subject: [PATCH] add tests

Add unit tests for the HDF5, MDTraj, MSMBuilder and Numpy dataset loaders.

Supporting changes:
- HDF5DatasetLoader takes `filenames` (previously `filename`) and resolves
  it with sorted(glob.glob(expand_path(...))) instead of splitting a
  comma-separated string.
- Add mdtraj to the conda recipe's test section.
- Bind the unused check_output() result in _test_plot_1 to `_`.
- Add a blank line before dict_merge in osprey/utils.py.
---
Review notes (this section is ignored by `git am`):
- Amended: the `finally` clause of test_MSMBuilderDatasetLoader_1 now calls
  shutil.rmtree(path, ignore_errors=True). The directory is removed before
  the `try`, so if the body fails before the directory exists again, a
  plain rmtree raises inside `finally` and hides the real failure. The
  post-image blob id on that file's `index` line predates this amendment.
- The `filename` -> `filenames` rename breaks callers that pass the argument
  by keyword; a comma-separated value is now handed to glob as one pattern
  (assuming expand_path does not split it -- TODO confirm).
- No import of `glob` or `expand_path` is added to
  osprey/dataset_loaders.py; presumably both are already in scope -- verify.
- Whitespace-only context lines and indentation were reconstructed from
  the hunk line counts; check against the target tree before applying.

 devtools/conda-recipe/meta.yaml          |  1 +
 osprey/dataset_loaders.py                |  6 +-
 osprey/tests/test_cli_worker_and_dump.py |  2 +-
 osprey/tests/test_dataset_loader.py      | 87 +++++++++++++++++++++++-
 osprey/utils.py                          |  1 +
 5 files changed, 90 insertions(+), 7 deletions(-)

diff --git a/devtools/conda-recipe/meta.yaml b/devtools/conda-recipe/meta.yaml
index dae40ba..4b9f0bb 100644
--- a/devtools/conda-recipe/meta.yaml
+++ b/devtools/conda-recipe/meta.yaml
@@ -31,6 +31,7 @@ test:
     - nose-timer
     - gpy
     - msmbuilder
+    - mdtraj
     - coverage
     - python-coveralls
     - matplotlib
diff --git a/osprey/dataset_loaders.py b/osprey/dataset_loaders.py
index 9e2fdb1..44efd31 100644
--- a/osprey/dataset_loaders.py
+++ b/osprey/dataset_loaders.py
@@ -48,8 +48,8 @@ def load(self):
 class HDF5DatasetLoader(BaseDatasetLoader):
     short_name = 'hdf5'
 
-    def __init__(self, filename, stride=1):
-        self.filename = filename
+    def __init__(self, filenames, stride=1):
+        self.filenames = filenames
         self.stride = stride
 
     def transform(self, arr):
@@ -59,7 +59,7 @@ def load(self):
         from mdtraj import io
         X = []
         y = []
-        filenames = [fn.strip() for fn in self.filename.split(',')]
+        filenames = sorted(glob.glob(expand_path(self.filenames)))
         for i, fn in enumerate(filenames):
             dataset = io.loadh(fn)
             for key in dataset.iterkeys():
diff --git a/osprey/tests/test_cli_worker_and_dump.py b/osprey/tests/test_cli_worker_and_dump.py
index 932dcc7..45259f2 100644
--- a/osprey/tests/test_cli_worker_and_dump.py
+++ b/osprey/tests/test_cli_worker_and_dump.py
@@ -73,7 +73,7 @@ def _test_dump_1():
 
 
 def _test_plot_1():
-    out = subprocess.check_output(
+    _ = subprocess.check_output(
         [OSPREY_BIN, 'plot', 'config.yaml', '--no-browser'])
     if not os.path.isfile('./plot.html'):
         raise ValueError('Plot not created')
diff --git a/osprey/tests/test_dataset_loader.py b/osprey/tests/test_dataset_loader.py
index f0a3d6a..08b91f6 100644
--- a/osprey/tests/test_dataset_loader.py
+++ b/osprey/tests/test_dataset_loader.py
@@ -7,9 +7,10 @@
 import sklearn.datasets
 from sklearn.externals.joblib import dump
 
-from osprey.dataset_loaders import FilenameDatasetLoader
-from osprey.dataset_loaders import JoblibDatasetLoader
-from osprey.dataset_loaders import SklearnDatasetLoader
+from osprey.dataset_loaders import (FilenameDatasetLoader, JoblibDatasetLoader,
+                                    HDF5DatasetLoader, MDTrajDatasetLoader,
+                                    MSMBuilderDatasetLoader,
+                                    NumpyDatasetLoader, SklearnDatasetLoader)
 
 
 def test_FilenameDatasetLoader_1():
@@ -70,6 +71,86 @@ def test_JoblibDatasetLoader_1():
         shutil.rmtree(dirname)
 
 
+def test_HDF5DatasetLoader_1():
+    from mdtraj import io
+
+    assert HDF5DatasetLoader.short_name == 'hdf5'
+
+    cwd = os.path.abspath(os.curdir)
+    dirname = tempfile.mkdtemp()
+    try:
+        os.chdir(dirname)
+
+        # one file
+        io.saveh('f1.h5', **{'test': np.zeros((10, 2))})
+        loader = HDF5DatasetLoader('f1.h5')
+        X, y = loader.load()
+        assert np.all(X == np.zeros((10, 2)))
+        assert y[0] == 0
+
+        # two files
+        io.saveh('f2.h5', **{'test': np.ones((10, 2))})
+        loader = HDF5DatasetLoader('f*.h5')
+        X, y = loader.load()
+        assert isinstance(X, list)
+        assert np.all(X[0] == np.zeros((10, 2)))
+        assert np.all(X[1] == np.ones((10, 2)))
+        assert y[1] == 1
+
+    finally:
+        os.chdir(cwd)
+        shutil.rmtree(dirname)
+
+
+def test_MDTrajDatasetLoader_1():
+    from mdtraj.testing import get_fn
+
+    loader = MDTrajDatasetLoader(get_fn('legacy_msmbuilder_trj0.lh5'))
+    X, y = loader.load()
+    assert X[0].n_frames == 501
+    assert y is None
+
+
+def test_MSMBuilderDatasetLoader_1():
+    from msmbuilder.dataset import dataset
+
+    path = tempfile.mkdtemp()
+    shutil.rmtree(path)
+    try:
+        x = np.random.randn(10, 2)
+        ds = dataset(path, 'w', 'dir-npy')
+        ds[0] = x
+
+        loader = MSMBuilderDatasetLoader(path, fmt='dir-npy')
+        X, y = loader.load()
+
+        assert np.all(X[0] == x)
+        assert y is None
+
+    finally:
+        shutil.rmtree(path, ignore_errors=True)
+
+
+def test_NumpyDatasetLoader_1():
+    cwd = os.path.abspath(os.curdir)
+    dirname = tempfile.mkdtemp()
+    try:
+        os.chdir(dirname)
+
+        x = np.random.randn(10, 2)
+        np.save('f1.npy', x)
+
+        loader = NumpyDatasetLoader('f1.npy')
+        X, y = loader.load()
+
+        assert np.all(X[0] == x)
+        assert y is None
+
+    finally:
+        os.chdir(cwd)
+        shutil.rmtree(dirname)
+
+
 def test_SklearnDatasetLoader_1():
     assert SklearnDatasetLoader.short_name == 'sklearn_dataset'
     X, y = SklearnDatasetLoader('load_iris').load()
diff --git a/osprey/utils.py b/osprey/utils.py
index f67afcd..8a5213f 100644
--- a/osprey/utils.py
+++ b/osprey/utils.py
@@ -10,6 +10,7 @@ from .eval_scopes import import_all_estimators
 
 
 
+
 def dict_merge(base, top):
     """Recursively merge two dictionaries, with the elements from `top`
     taking precedence over elements from `top`.