Skip to content

Commit

Permalink
add tests
Browse files Browse the repository at this point in the history
  • Loading branch information
Carlos Hernandez committed Jun 22, 2016
1 parent 2fb7b06 commit f04f2a6
Show file tree
Hide file tree
Showing 5 changed files with 90 additions and 7 deletions.
1 change: 1 addition & 0 deletions devtools/conda-recipe/meta.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@ test:
- nose-timer
- gpy
- msmbuilder
- mdtraj
- coverage
- python-coveralls
- matplotlib
Expand Down
6 changes: 3 additions & 3 deletions osprey/dataset_loaders.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,8 +48,8 @@ def load(self):
class HDF5DatasetLoader(BaseDatasetLoader):
short_name = 'hdf5'

def __init__(self, filename, stride=1):
self.filename = filename
def __init__(self, filenames, stride=1):
self.filenames = filenames
self.stride = stride

def transform(self, arr):
Expand All @@ -59,7 +59,7 @@ def load(self):
from mdtraj import io
X = []
y = []
filenames = [fn.strip() for fn in self.filename.split(',')]
filenames = sorted(glob.glob(expand_path(self.filenames)))
for i, fn in enumerate(filenames):
dataset = io.loadh(fn)
for key in dataset.iterkeys():
Expand Down
2 changes: 1 addition & 1 deletion osprey/tests/test_cli_worker_and_dump.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,7 +73,7 @@ def _test_dump_1():


def _test_plot_1():
out = subprocess.check_output(
_ = subprocess.check_output(
[OSPREY_BIN, 'plot', 'config.yaml', '--no-browser'])
if not os.path.isfile('./plot.html'):
raise ValueError('Plot not created')
87 changes: 84 additions & 3 deletions osprey/tests/test_dataset_loader.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,9 +7,10 @@
import sklearn.datasets
from sklearn.externals.joblib import dump

from osprey.dataset_loaders import FilenameDatasetLoader
from osprey.dataset_loaders import JoblibDatasetLoader
from osprey.dataset_loaders import SklearnDatasetLoader
from osprey.dataset_loaders import (FilenameDatasetLoader, JoblibDatasetLoader,
HDF5DatasetLoader, MDTrajDatasetLoader,
MSMBuilderDatasetLoader,
NumpyDatasetLoader, SklearnDatasetLoader)


def test_FilenameDatasetLoader_1():
Expand Down Expand Up @@ -70,6 +71,86 @@ def test_JoblibDatasetLoader_1():
shutil.rmtree(dirname)


def test_HDF5DatasetLoader_1():
from mdtraj import io

assert HDF5DatasetLoader.short_name == 'hdf5'

cwd = os.path.abspath(os.curdir)
dirname = tempfile.mkdtemp()
try:
os.chdir(dirname)

# one file
io.saveh('f1.h5', **{'test': np.zeros((10, 2))})
loader = HDF5DatasetLoader('f1.h5')
X, y = loader.load()
assert np.all(X == np.zeros((10, 2)))
assert y[0] == 0

# two files
io.saveh('f2.h5', **{'test': np.ones((10, 2))})
loader = HDF5DatasetLoader('f*.h5')
X, y = loader.load()
assert isinstance(X, list)
assert np.all(X[0] == np.zeros((10, 2)))
assert np.all(X[1] == np.ones((10, 2)))
assert y[1] == 1

finally:
os.chdir(cwd)
shutil.rmtree(dirname)


def test_MDTrajDatasetLoader_1():
from mdtraj.testing import get_fn

loader = MDTrajDatasetLoader(get_fn('legacy_msmbuilder_trj0.lh5'))
X, y = loader.load()
assert X[0].n_frames == 501
assert y is None


def test_MSMBuilderDatasetLoader_1():
from msmbuilder.dataset import dataset

path = tempfile.mkdtemp()
shutil.rmtree(path)
try:
x = np.random.randn(10, 2)
ds = dataset(path, 'w', 'dir-npy')
ds[0] = x

loader = MSMBuilderDatasetLoader(path, fmt='dir-npy')
X, y = loader.load()

assert np.all(X[0] == x)
assert y is None

finally:
shutil.rmtree(path)


def test_NumpyDatasetLoader_1():
cwd = os.path.abspath(os.curdir)
dirname = tempfile.mkdtemp()
try:
os.chdir(dirname)

x = np.random.randn(10, 2)
np.save('f1.npy', x)

loader = NumpyDatasetLoader('f1.npy')
X, y = loader.load()

assert np.all(X[0] == x)
assert y is None

finally:
os.chdir(cwd)
shutil.rmtree(dirname)


def test_SklearnDatasetLoader_1():
assert SklearnDatasetLoader.short_name == 'sklearn_dataset'
X, y = SklearnDatasetLoader('load_iris').load()
Expand Down
1 change: 1 addition & 0 deletions osprey/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@

from .eval_scopes import import_all_estimators


def dict_merge(base, top):
"""Recursively merge two dictionaries, with the elements from `top`
taking precedence over elements from `top`.
Expand Down

0 comments on commit f04f2a6

Please sign in to comment.