Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Feature/ddp example #124

Merged
merged 19 commits into from
May 8, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions tests/test_cmaes.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,14 +28,14 @@ def test_cmaes_basic(cma_adapter, mpi_tmp_path: pathlib.Path) -> None:
The temporary checkpoint directory.
"""
rng = random.Random(42) # Separate random number generator for optimization.
function, limits = get_function_search_space("sphere")
benchmark_function, limits = get_function_search_space("sphere")
# Set up evolutionary operator.
adapter = cma_adapter
propagator = CMAPropagator(adapter, limits, rng=rng)

# Set up Propulator performing actual optimization.
propulator = Propulator(
loss_fn=function,
loss_fn=benchmark_function,
propagator=propagator,
rng=rng,
generations=100,
Expand Down
26 changes: 13 additions & 13 deletions tests/test_island.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,15 +20,15 @@ def global_variables():
rng = random.Random(
42 + MPI.COMM_WORLD.rank
) # Set up separate random number generator for optimization.
function, limits = get_function_search_space(
benchmark_function, limits = get_function_search_space(
"sphere"
) # Get function and search space to optimize.
propagator = get_default_propagator(
pop_size=4,
limits=limits,
rng=rng,
) # Set up evolutionary operator.
yield rng, function, limits, propagator
yield rng, benchmark_function, limits, propagator


@pytest.fixture(
Expand Down Expand Up @@ -62,12 +62,12 @@ def test_islands(
mpi_tmp_path : pathlib.Path
The temporary checkpoint directory.
"""
rng, function, limits, propagator = global_variables
rng, benchmark_function, limits, propagator = global_variables
set_logger_config(log_file=mpi_tmp_path / "log.log")

# Set up island model.
islands = Islands(
loss_fn=function,
loss_fn=benchmark_function,
propagator=propagator,
rng=rng,
generations=100,
Expand Down Expand Up @@ -100,12 +100,12 @@ def test_checkpointing_isolated(
mpi_tmp_path : pathlib.Path
The temporary checkpoint directory.
"""
rng, function, limits, propagator = global_variables
rng, benchmark_function, limits, propagator = global_variables
set_logger_config(log_file=mpi_tmp_path / "log.log")

# Set up island model.
islands = Islands(
loss_fn=function,
loss_fn=benchmark_function,
propagator=propagator,
rng=rng,
generations=100,
Expand All @@ -121,7 +121,7 @@ def test_checkpointing_isolated(
del islands

islands = Islands(
loss_fn=function,
loss_fn=benchmark_function,
propagator=propagator,
rng=rng,
generations=100,
Expand Down Expand Up @@ -160,12 +160,12 @@ def test_checkpointing(
mpi_tmp_path : pathlib.Path
The temporary checkpoint directory.
"""
rng, function, limits, propagator = global_variables
rng, benchmark_function, limits, propagator = global_variables
set_logger_config(log_file=mpi_tmp_path / "log.log")

# Set up island model.
islands = Islands(
loss_fn=function,
loss_fn=benchmark_function,
propagator=propagator,
rng=rng,
generations=100,
Expand All @@ -185,7 +185,7 @@ def test_checkpointing(
del islands

islands = Islands(
loss_fn=function,
loss_fn=benchmark_function,
propagator=propagator,
rng=rng,
generations=100,
Expand Down Expand Up @@ -225,12 +225,12 @@ def test_checkpointing_unequal_populations(
mpi_tmp_path : pathlib.Path
The temporary checkpoint directory.
"""
rng, function, limits, propagator = global_variables
rng, benchmark_function, limits, propagator = global_variables
set_logger_config(log_file=mpi_tmp_path / "log.log")

# Set up island model.
islands = Islands(
loss_fn=function,
loss_fn=benchmark_function,
propagator=propagator,
rng=rng,
generations=100,
Expand All @@ -251,7 +251,7 @@ def test_checkpointing_unequal_populations(
del islands

islands = Islands(
loss_fn=function,
loss_fn=benchmark_function,
propagator=propagator,
rng=rng,
generations=100,
Expand Down
10 changes: 5 additions & 5 deletions tests/test_propulator.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,15 +50,15 @@ def test_propulator(function_name: str, mpi_tmp_path: pathlib.Path) -> None:
rng = random.Random(
42 + MPI.COMM_WORLD.rank
) # Random number generator for optimization
function, limits = get_function_search_space(function_name)
benchmark_function, limits = get_function_search_space(function_name)
set_logger_config(log_file=mpi_tmp_path / "log.log")
propagator = get_default_propagator(
pop_size=4,
limits=limits,
rng=rng,
) # Set up evolutionary operator.
propulator = Propulator(
loss_fn=function,
loss_fn=benchmark_function,
propagator=propagator,
rng=rng,
generations=100,
Expand All @@ -81,15 +81,15 @@ def test_propulator_checkpointing(mpi_tmp_path: pathlib.Path) -> None:
rng = random.Random(
42 + MPI.COMM_WORLD.rank
) # Separate random number generator for optimization
function, limits = get_function_search_space("sphere")
benchmark_function, limits = get_function_search_space("sphere")

propagator = get_default_propagator( # Get default evolutionary operator.
pop_size=4, # Breeding pool size
limits=limits, # Search-space limits
rng=rng, # Random number generator
)
propulator = Propulator(
loss_fn=function,
loss_fn=benchmark_function,
propagator=propagator,
generations=100,
checkpoint_path=mpi_tmp_path,
Expand All @@ -105,7 +105,7 @@ def test_propulator_checkpointing(mpi_tmp_path: pathlib.Path) -> None:
MPI.COMM_WORLD.barrier() # Synchronize all processes.

propulator = Propulator(
loss_fn=function,
loss_fn=benchmark_function,
propagator=propagator,
generations=20,
checkpoint_path=mpi_tmp_path,
Expand Down
5 changes: 3 additions & 2 deletions tutorials/cmaes_example.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
"""Simple example script using CMA-ES."""

import pathlib
import random

Expand Down Expand Up @@ -36,7 +37,7 @@
rng = random.Random(
config.seed + comm.rank
) # Separate random number generator for optimization.
function, limits = get_function_search_space(
benchmark_function, limits = get_function_search_space(
config.function
) # Get callable function + search-space limits.

Expand All @@ -52,7 +53,7 @@

# Set up propulator performing actual optimization.
propulator = Propulator(
loss_fn=function,
loss_fn=benchmark_function,
propagator=propagator,
rng=rng,
island_comm=comm,
Expand Down
5 changes: 3 additions & 2 deletions tutorials/islands_example.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
"""Simple island model example script."""

import pathlib
import random

Expand Down Expand Up @@ -31,7 +32,7 @@
rng = random.Random(
config.seed + comm.rank
) # Separate random number generator for optimization.
function, limits = get_function_search_space(
benchmark_function, limits = get_function_search_space(
config.function
) # Get callable function + search-space limits.

Expand All @@ -58,7 +59,7 @@

# Set up island model.
islands = Islands(
loss_fn=function,
loss_fn=benchmark_function,
propagator=propagator,
rng=rng,
generations=config.generations,
Expand Down
5 changes: 3 additions & 2 deletions tutorials/propulator_example.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
"""Simple Propulator example script using the default genetic propagator."""

import pathlib
import random

Expand Down Expand Up @@ -35,7 +36,7 @@
rng = random.Random(
config.seed + comm.rank
) # Separate random number generator for optimization.
function, limits = get_function_search_space(
benchmark_function, limits = get_function_search_space(
config.function
) # Get callable function + search-space limits.
# Set up evolutionary operator.
Expand All @@ -50,7 +51,7 @@

# Set up propulator performing actual optimization.
propulator = Propulator(
loss_fn=function,
loss_fn=benchmark_function,
propagator=propagator,
rng=rng,
island_comm=comm,
Expand Down
5 changes: 3 additions & 2 deletions tutorials/pso_example.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
You can choose between benchmark functions and optimize them. The example shows how to set up Propulate in order to use
it with PSO.
"""

import pathlib
import random

Expand Down Expand Up @@ -47,7 +48,7 @@
rng = random.Random(
config.seed + comm.rank
) # Separate random number generator for optimization.
function, limits = get_function_search_space(
benchmark_function, limits = get_function_search_space(
config.function
) # Get callable function + search-space limits.

Expand Down Expand Up @@ -91,7 +92,7 @@
propagator = Conditional(config.pop_size, pso_propagator, init)

propulator = Propulator(
function,
benchmark_function,
propagator,
rng=rng,
island_comm=comm,
Expand Down
Loading