diff --git a/tests/test_ensemble_boot.py b/tests/test_ensemble_boot.py index b1592eb9..b656420a 100644 --- a/tests/test_ensemble_boot.py +++ b/tests/test_ensemble_boot.py @@ -1,5 +1,5 @@ from easyvvuq.analysis.ensemble_boot import confidence_interval, bootstrap -from easyvvuq.analysis.ensemble_boot import ensemble_bootstrap, EnsembleBoot +from easyvvuq.analysis.ensemble_boot import ensemble_bootstrap, EnsembleBoot, EnsembleBootMultiple import os import numpy as np import pandas as pd @@ -82,3 +82,19 @@ def test_ensemble_boot(): 'b': ['group1'] * VALUES.shape[0] + ['group2'] * VALUES.shape[0]}) results = analysis.analyse(df) assert (not results.empty) + +def test_ensemble_boot_multiple(): + analysis = EnsembleBootMultiple() + assert (analysis.element_name() == 'ensemble_boot_multiple') + assert (analysis.element_version() == '0.1') + with pytest.raises(RuntimeError): + analysis.analyse() + with pytest.raises(RuntimeError): + analysis.analyse(pd.DataFrame({})) + analysis = EnsembleBootMultiple(groupby=['b'], qoi_cols=['a'], stat_func=[np.mean, np.var, np.median]) + df = pd.DataFrame({ + 'a': np.concatenate((VALUES, VALUES)), + 'b': ['group1'] * VALUES.shape[0] + ['group2'] * VALUES.shape[0]}) + results = analysis.analyse(df) + assert (not results.empty) + assert (results.values.shape == (2, 9))