diff --git a/.github/workflows/python-test.yml b/.github/workflows/python-test.yml
index 55ffd4e1..4e571c73 100644
--- a/.github/workflows/python-test.yml
+++ b/.github/workflows/python-test.yml
@@ -18,8 +18,7 @@ jobs:
- name: Install dependencies
run: |
python -m pip install --upgrade pip
- pip install ."[dev]"
- pip install genbadge"[coverage]"
+ pip install ."[test]"
- name: Lint with ruff
run: |
# Stop the build if there are Python syntax errors or undefined names.
diff --git a/coverage.svg b/coverage.svg
index a62cc51c..ef04c8f7 100644
--- a/coverage.svg
+++ b/coverage.svg
@@ -1 +1 @@
-
\ No newline at end of file
+
\ No newline at end of file
diff --git a/pyproject.toml b/pyproject.toml
index c4852d2d..70553c77 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -47,6 +47,15 @@ dev = [
"torchmetrics",
]
+test = [
+ "coverage",
+ "genbadge[coverage]",
+ "ruff",
+ "pytest",
+ "pytest-cov",
+ "pytest-mpi",
+]
+
tutorials = [
"torch",
"torchvision",
diff --git a/tests/test_cmaes.py b/tests/test_cmaes.py
index b227add3..8eef4e70 100644
--- a/tests/test_cmaes.py
+++ b/tests/test_cmaes.py
@@ -38,7 +38,7 @@ def test_cmaes_basic(cma_adapter: CMAAdapter, mpi_tmp_path: pathlib.Path) -> Non
loss_fn=benchmark_function,
propagator=propagator,
rng=rng,
- generations=100,
+ generations=10,
checkpoint_path=mpi_tmp_path,
)
# Run optimization and print summary of results.
diff --git a/tests/test_island.py b/tests/test_island.py
index 720c711e..5639a9a4 100644
--- a/tests/test_island.py
+++ b/tests/test_island.py
@@ -70,7 +70,7 @@ def test_islands(
loss_fn=benchmark_function,
propagator=propagator,
rng=rng,
- generations=100,
+ generations=10,
num_islands=2,
migration_probability=0.9,
pollination=pollination,
@@ -109,7 +109,7 @@ def test_checkpointing_isolated(
loss_fn=benchmark_function,
propagator=propagator,
rng=rng,
- generations=100,
+ generations=10,
num_islands=2,
migration_probability=0.0,
checkpoint_path=mpi_tmp_path,
@@ -126,7 +126,7 @@ def test_checkpointing_isolated(
loss_fn=benchmark_function,
propagator=propagator,
rng=rng,
- generations=100,
+ generations=10,
num_islands=2,
migration_probability=0.0,
checkpoint_path=mpi_tmp_path,
@@ -170,7 +170,7 @@ def test_checkpointing(
loss_fn=benchmark_function,
propagator=propagator,
rng=rng,
- generations=100,
+ generations=10,
num_islands=2,
migration_probability=0.9,
pollination=pollination,
@@ -188,7 +188,7 @@ def test_checkpointing(
loss_fn=benchmark_function,
propagator=propagator,
rng=rng,
- generations=100,
+ generations=10,
num_islands=2,
migration_probability=0.9,
pollination=pollination,
@@ -233,7 +233,7 @@ def test_checkpointing_unequal_populations(
loss_fn=benchmark_function,
propagator=propagator,
rng=rng,
- generations=100,
+ generations=10,
num_islands=2,
island_sizes=np.array([3, 5]),
migration_probability=0.9,
@@ -252,7 +252,7 @@ def test_checkpointing_unequal_populations(
loss_fn=benchmark_function,
propagator=propagator,
rng=rng,
- generations=100,
+ generations=10,
num_islands=2,
island_sizes=np.array([3, 5]),
migration_probability=0.9,
diff --git a/tests/test_multi_rank_workers.py b/tests/test_multi_rank_workers.py
index 7bb99aa0..ca7b524c 100644
--- a/tests/test_multi_rank_workers.py
+++ b/tests/test_multi_rank_workers.py
@@ -69,7 +69,7 @@ def test_multi_rank_workers(mpi_tmp_path: pathlib.Path) -> None:
loss_fn=parallel_sphere, # Loss function to be minimized
propagator=propagator, # Propagator, i.e., evolutionary operator to be used
rng=rng, # Separate random number generator for Propulate optimization
- generations=100, # Overall number of generations
+ generations=10, # Overall number of generations
num_islands=2, # Number of islands
migration_probability=0.9, # Migration probability
pollination=False, # Whether to use pollination or migration
diff --git a/tests/test_nm.py b/tests/test_nm.py
index e8336912..7a3def00 100644
--- a/tests/test_nm.py
+++ b/tests/test_nm.py
@@ -61,7 +61,7 @@ def test_cmaes(
loss_fn=function,
propagator=propagator,
rng=rng,
- generations=100,
+ generations=10,
checkpoint_path=mpi_tmp_path,
)
# Run optimization and print summary of results.
diff --git a/tests/test_propulator.py b/tests/test_propulator.py
index fc9f6ec3..16c4241e 100644
--- a/tests/test_propulator.py
+++ b/tests/test_propulator.py
@@ -61,7 +61,7 @@ def test_propulator(function_name: str, mpi_tmp_path: pathlib.Path) -> None:
loss_fn=benchmark_function,
propagator=propagator,
rng=rng,
- generations=100,
+ generations=10,
checkpoint_path=mpi_tmp_path,
) # Set up propulator performing actual optimization.
propulator.propulate() # Run optimization and print summary of results.
@@ -91,7 +91,7 @@ def test_propulator_checkpointing(mpi_tmp_path: pathlib.Path) -> None:
propulator = Propulator(
loss_fn=benchmark_function,
propagator=propagator,
- generations=100,
+ generations=10,
checkpoint_path=mpi_tmp_path,
rng=rng,
) # Set up propulator performing actual optimization.
@@ -107,7 +107,7 @@ def test_propulator_checkpointing(mpi_tmp_path: pathlib.Path) -> None:
propulator = Propulator(
loss_fn=benchmark_function,
propagator=propagator,
- generations=20,
+ generations=5,
checkpoint_path=mpi_tmp_path,
rng=rng,
) # Set up new propulator starting from checkpoint.
diff --git a/tests/test_pso.py b/tests/test_pso.py
index 6a10adec..f6fb5e3c 100644
--- a/tests/test_pso.py
+++ b/tests/test_pso.py
@@ -77,7 +77,7 @@ def test_pso(pso_propagator: Propagator, mpi_tmp_path: pathlib.Path) -> None:
loss_fn=sphere,
propagator=propagator,
rng=rng,
- generations=100,
+ generations=10,
checkpoint_path=mpi_tmp_path,
)
diff --git a/tests/test_surrogate.py b/tests/test_surrogate.py
index 7df0136d..80bc0578 100644
--- a/tests/test_surrogate.py
+++ b/tests/test_surrogate.py
@@ -10,6 +10,11 @@
from propulate import Islands, Propulator, surrogate
from propulate.utils import get_default_propagator, set_logger_config
+pytestmark = pytest.mark.filterwarnings(
+ "ignore::DeprecationWarning",
+ match="Assigning the 'data' attribute is an inherently unsafe operation and will be removed in the future.",
+)
+
log = logging.getLogger(__name__) # Get logger instance.
set_logger_config(level=logging.DEBUG)
@@ -51,7 +56,7 @@ def test_static(mpi_tmp_path: Path) -> None:
rng = random.Random(
MPI.COMM_WORLD.rank + 100
) # Set up separate random number generator for evolutionary optimizer.
- num_generations = 50
+ num_generations = 4
propagator = get_default_propagator( # Get default evolutionary operator.
pop_size=pop_size, # Breeding population size
@@ -74,6 +79,7 @@ def test_static(mpi_tmp_path: Path) -> None:
MPI.COMM_WORLD.barrier()
+@pytest.mark.mpi(min_size=8)
def test_static_island(mpi_tmp_path: Path) -> None:
"""Test static surrogate using a dummy function."""
pop_size = 2 * MPI.COMM_WORLD.size # Breeding population size
@@ -84,7 +90,7 @@ def test_static_island(mpi_tmp_path: Path) -> None:
rng = random.Random(
MPI.COMM_WORLD.rank + 100
) # Set up separate random number generator for evolutionary optimizer.
- num_generations = 50
+ num_generations = 4
propagator = get_default_propagator( # Get default evolutionary operator.
pop_size=pop_size, # Breeding population size
@@ -111,13 +117,47 @@ def test_static_island(mpi_tmp_path: Path) -> None:
MPI.COMM_WORLD.barrier()
+def test_dynamic(mpi_tmp_path: Path) -> None:
+ """Test static surrogate using a dummy function."""
+ pop_size = 2 * MPI.COMM_WORLD.size # Breeding population size
+ limits: Dict[str, Union[Tuple[int, int], Tuple[float, float], Tuple[str, ...]]] = {
+ "start": (0.1, 7.0),
+ "limit": (-1.0, 1.0),
+ } # Define search space.
+ rng = random.Random(
+ MPI.COMM_WORLD.rank + 100
+ ) # Set up separate random number generator for evolutionary optimizer.
+ num_generations = 4
+
+ propagator = get_default_propagator( # Get default evolutionary operator.
+ pop_size=pop_size, # Breeding population size
+ limits=limits, # Search space
+ crossover_prob=0.7, # Crossover probability
+ mutation_prob=0.4, # Mutation probability
+ random_init_prob=0.1, # Random-initialization probability
+ rng=rng, # Random number generator for evolutionary optimizer
+ )
+ propulator = Propulator(
+ loss_fn=ind_loss,
+ propagator=propagator,
+ generations=num_generations,
+ checkpoint_path=mpi_tmp_path,
+ rng=rng,
+ surrogate_factory=lambda: surrogate.DynamicSurrogate(limits),
+ ) # Set up propulator performing actual optimization.
+
+ propulator.propulate() # Run optimization and print summary of results.
+ MPI.COMM_WORLD.barrier()
+
+
+@pytest.mark.mpi(min_size=8)
@pytest.mark.filterwarnings(
"ignore::DeprecationWarning",
match="Assigning the 'data' attribute is an inherently unsafe operation and will be removed in the future.",
)
-def test_dynamic(mpi_tmp_path: Path) -> None:
+def test_dynamic_island(mpi_tmp_path: Path) -> None:
"""Test dynamic surrogate using a dummy function."""
- num_generations = 10 # Number of generations
+ num_generations = 4 # Number of generations
pop_size = 2 * MPI.COMM_WORLD.size # Breeding population size
limits: Dict[str, Union[Tuple[int, int], Tuple[float, float], Tuple[str, ...]]] = {
"start": (0.1, 7.0),
@@ -137,7 +177,7 @@ def test_dynamic(mpi_tmp_path: Path) -> None:
propagator=propagator, # Evolutionary operator
rng=rng, # Random number generator
generations=num_generations, # Number of generations per worker
- num_islands=1, # Number of islands
+ num_islands=2, # Number of islands
checkpoint_path=mpi_tmp_path,
surrogate_factory=lambda: surrogate.DynamicSurrogate(limits),
)