Skip to content

Commit

Permalink
Merge pull request #150 from Helmholtz-AI-Energy/hotfix/test_surrogate
Browse files Browse the repository at this point in the history
Add missing decorator and island test for dynamic surrogate
  • Loading branch information
oskar-taubert authored Aug 15, 2024
2 parents 34f232f + 2e742d9 commit 98ab80d
Show file tree
Hide file tree
Showing 10 changed files with 70 additions and 22 deletions.
3 changes: 1 addition & 2 deletions .github/workflows/python-test.yml
Original file line number Diff line number Diff line change
Expand Up @@ -18,8 +18,7 @@ jobs:
- name: Install dependencies
run: |
python -m pip install --upgrade pip
pip install ."[dev]"
pip install genbadge"[coverage]"
pip install ."[test]"
- name: Lint with ruff
run: |
# Stop the build if there are Python syntax errors or undefined names.
Expand Down
2 changes: 1 addition & 1 deletion coverage.svg
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
9 changes: 9 additions & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,15 @@ dev = [
"torchmetrics",
]

test = [
"coverage",
"genbadge[coverage]",
"ruff",
"pytest",
"pytest-cov",
"pytest-mpi",
]

tutorials = [
"torch",
"torchvision",
Expand Down
2 changes: 1 addition & 1 deletion tests/test_cmaes.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@ def test_cmaes_basic(cma_adapter: CMAAdapter, mpi_tmp_path: pathlib.Path) -> Non
loss_fn=benchmark_function,
propagator=propagator,
rng=rng,
generations=100,
generations=10,
checkpoint_path=mpi_tmp_path,
)
# Run optimization and print summary of results.
Expand Down
14 changes: 7 additions & 7 deletions tests/test_island.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,7 +70,7 @@ def test_islands(
loss_fn=benchmark_function,
propagator=propagator,
rng=rng,
generations=100,
generations=10,
num_islands=2,
migration_probability=0.9,
pollination=pollination,
Expand Down Expand Up @@ -109,7 +109,7 @@ def test_checkpointing_isolated(
loss_fn=benchmark_function,
propagator=propagator,
rng=rng,
generations=100,
generations=10,
num_islands=2,
migration_probability=0.0,
checkpoint_path=mpi_tmp_path,
Expand All @@ -126,7 +126,7 @@ def test_checkpointing_isolated(
loss_fn=benchmark_function,
propagator=propagator,
rng=rng,
generations=100,
generations=10,
num_islands=2,
migration_probability=0.0,
checkpoint_path=mpi_tmp_path,
Expand Down Expand Up @@ -170,7 +170,7 @@ def test_checkpointing(
loss_fn=benchmark_function,
propagator=propagator,
rng=rng,
generations=100,
generations=10,
num_islands=2,
migration_probability=0.9,
pollination=pollination,
Expand All @@ -188,7 +188,7 @@ def test_checkpointing(
loss_fn=benchmark_function,
propagator=propagator,
rng=rng,
generations=100,
generations=10,
num_islands=2,
migration_probability=0.9,
pollination=pollination,
Expand Down Expand Up @@ -233,7 +233,7 @@ def test_checkpointing_unequal_populations(
loss_fn=benchmark_function,
propagator=propagator,
rng=rng,
generations=100,
generations=10,
num_islands=2,
island_sizes=np.array([3, 5]),
migration_probability=0.9,
Expand All @@ -252,7 +252,7 @@ def test_checkpointing_unequal_populations(
loss_fn=benchmark_function,
propagator=propagator,
rng=rng,
generations=100,
generations=10,
num_islands=2,
island_sizes=np.array([3, 5]),
migration_probability=0.9,
Expand Down
2 changes: 1 addition & 1 deletion tests/test_multi_rank_workers.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,7 +69,7 @@ def test_multi_rank_workers(mpi_tmp_path: pathlib.Path) -> None:
loss_fn=parallel_sphere, # Loss function to be minimized
propagator=propagator, # Propagator, i.e., evolutionary operator to be used
rng=rng, # Separate random number generator for Propulate optimization
generations=100, # Overall number of generations
generations=10, # Overall number of generations
num_islands=2, # Number of islands
migration_probability=0.9, # Migration probability
pollination=False, # Whether to use pollination or migration
Expand Down
2 changes: 1 addition & 1 deletion tests/test_nm.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,7 @@ def test_cmaes(
loss_fn=function,
propagator=propagator,
rng=rng,
generations=100,
generations=10,
checkpoint_path=mpi_tmp_path,
)
# Run optimization and print summary of results.
Expand Down
6 changes: 3 additions & 3 deletions tests/test_propulator.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,7 @@ def test_propulator(function_name: str, mpi_tmp_path: pathlib.Path) -> None:
loss_fn=benchmark_function,
propagator=propagator,
rng=rng,
generations=100,
generations=10,
checkpoint_path=mpi_tmp_path,
) # Set up propulator performing actual optimization.
propulator.propulate() # Run optimization and print summary of results.
Expand Down Expand Up @@ -91,7 +91,7 @@ def test_propulator_checkpointing(mpi_tmp_path: pathlib.Path) -> None:
propulator = Propulator(
loss_fn=benchmark_function,
propagator=propagator,
generations=100,
generations=10,
checkpoint_path=mpi_tmp_path,
rng=rng,
) # Set up propulator performing actual optimization.
Expand All @@ -107,7 +107,7 @@ def test_propulator_checkpointing(mpi_tmp_path: pathlib.Path) -> None:
propulator = Propulator(
loss_fn=benchmark_function,
propagator=propagator,
generations=20,
generations=5,
checkpoint_path=mpi_tmp_path,
rng=rng,
) # Set up new propulator starting from checkpoint.
Expand Down
2 changes: 1 addition & 1 deletion tests/test_pso.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,7 +77,7 @@ def test_pso(pso_propagator: Propagator, mpi_tmp_path: pathlib.Path) -> None:
loss_fn=sphere,
propagator=propagator,
rng=rng,
generations=100,
generations=10,
checkpoint_path=mpi_tmp_path,
)

Expand Down
50 changes: 45 additions & 5 deletions tests/test_surrogate.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,11 @@
from propulate import Islands, Propulator, surrogate
from propulate.utils import get_default_propagator, set_logger_config

pytestmark = pytest.mark.filterwarnings(
"ignore::DeprecationWarning",
match="Assigning the 'data' attribute is an inherently unsafe operation and will be removed in the future.",
)

log = logging.getLogger(__name__) # Get logger instance.
set_logger_config(level=logging.DEBUG)

Expand Down Expand Up @@ -51,7 +56,7 @@ def test_static(mpi_tmp_path: Path) -> None:
rng = random.Random(
MPI.COMM_WORLD.rank + 100
) # Set up separate random number generator for evolutionary optimizer.
num_generations = 50
num_generations = 4

propagator = get_default_propagator( # Get default evolutionary operator.
pop_size=pop_size, # Breeding population size
Expand All @@ -74,6 +79,7 @@ def test_static(mpi_tmp_path: Path) -> None:
MPI.COMM_WORLD.barrier()


@pytest.mark.mpi(min_size=8)
def test_static_island(mpi_tmp_path: Path) -> None:
"""Test static surrogate using a dummy function."""
pop_size = 2 * MPI.COMM_WORLD.size # Breeding population size
Expand All @@ -84,7 +90,7 @@ def test_static_island(mpi_tmp_path: Path) -> None:
rng = random.Random(
MPI.COMM_WORLD.rank + 100
) # Set up separate random number generator for evolutionary optimizer.
num_generations = 50
num_generations = 4

propagator = get_default_propagator( # Get default evolutionary operator.
pop_size=pop_size, # Breeding population size
Expand All @@ -111,13 +117,47 @@ def test_static_island(mpi_tmp_path: Path) -> None:
MPI.COMM_WORLD.barrier()


def test_dynamic(mpi_tmp_path: Path) -> None:
"""Test static surrogate using a dummy function."""
pop_size = 2 * MPI.COMM_WORLD.size # Breeding population size
limits: Dict[str, Union[Tuple[int, int], Tuple[float, float], Tuple[str, ...]]] = {
"start": (0.1, 7.0),
"limit": (-1.0, 1.0),
} # Define search space.
rng = random.Random(
MPI.COMM_WORLD.rank + 100
) # Set up separate random number generator for evolutionary optimizer.
num_generations = 4

propagator = get_default_propagator( # Get default evolutionary operator.
pop_size=pop_size, # Breeding population size
limits=limits, # Search space
crossover_prob=0.7, # Crossover probability
mutation_prob=0.4, # Mutation probability
random_init_prob=0.1, # Random-initialization probability
rng=rng, # Random number generator for evolutionary optimizer
)
propulator = Propulator(
loss_fn=ind_loss,
propagator=propagator,
generations=num_generations,
checkpoint_path=mpi_tmp_path,
rng=rng,
surrogate_factory=lambda: surrogate.DynamicSurrogate(limits),
) # Set up propulator performing actual optimization.

propulator.propulate() # Run optimization and print summary of results.
MPI.COMM_WORLD.barrier()


@pytest.mark.mpi(min_size=8)
@pytest.mark.filterwarnings(
"ignore::DeprecationWarning",
match="Assigning the 'data' attribute is an inherently unsafe operation and will be removed in the future.",
)
def test_dynamic(mpi_tmp_path: Path) -> None:
def test_dynamic_island(mpi_tmp_path: Path) -> None:
"""Test dynamic surrogate using a dummy function."""
num_generations = 10 # Number of generations
num_generations = 4 # Number of generations
pop_size = 2 * MPI.COMM_WORLD.size # Breeding population size
limits: Dict[str, Union[Tuple[int, int], Tuple[float, float], Tuple[str, ...]]] = {
"start": (0.1, 7.0),
Expand All @@ -137,7 +177,7 @@ def test_dynamic(mpi_tmp_path: Path) -> None:
propagator=propagator, # Evolutionary operator
rng=rng, # Random number generator
generations=num_generations, # Number of generations per worker
num_islands=1, # Number of islands
num_islands=2, # Number of islands
checkpoint_path=mpi_tmp_path,
surrogate_factory=lambda: surrogate.DynamicSurrogate(limits),
)
Expand Down

0 comments on commit 98ab80d

Please sign in to comment.