Skip to content

Commit

Permalink
add missing decorator and island test for dynamic surrogate
Browse files Browse the repository at this point in the history
  • Loading branch information
mcw92 committed Aug 15, 2024
1 parent 34f232f commit 8d0a5f4
Showing 1 changed file with 45 additions and 5 deletions.
50 changes: 45 additions & 5 deletions tests/test_surrogate.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,11 @@
from propulate import Islands, Propulator, surrogate
from propulate.utils import get_default_propagator, set_logger_config

pytestmark = pytest.mark.filterwarnings(
"ignore::DeprecationWarning",
match="Assigning the 'data' attribute is an inherently unsafe operation and will be removed in the future.",
)

log = logging.getLogger(__name__) # Get logger instance.
set_logger_config(level=logging.DEBUG)

Expand Down Expand Up @@ -51,7 +56,7 @@ def test_static(mpi_tmp_path: Path) -> None:
rng = random.Random(
MPI.COMM_WORLD.rank + 100
) # Set up separate random number generator for evolutionary optimizer.
num_generations = 50
num_generations = 4

propagator = get_default_propagator( # Get default evolutionary operator.
pop_size=pop_size, # Breeding population size
Expand All @@ -74,6 +79,7 @@ def test_static(mpi_tmp_path: Path) -> None:
MPI.COMM_WORLD.barrier()


@pytest.mark.mpi(min_size=8)
def test_static_island(mpi_tmp_path: Path) -> None:
"""Test static surrogate using a dummy function."""
pop_size = 2 * MPI.COMM_WORLD.size # Breeding population size
Expand All @@ -84,7 +90,7 @@ def test_static_island(mpi_tmp_path: Path) -> None:
rng = random.Random(
MPI.COMM_WORLD.rank + 100
) # Set up separate random number generator for evolutionary optimizer.
num_generations = 50
num_generations = 4

propagator = get_default_propagator( # Get default evolutionary operator.
pop_size=pop_size, # Breeding population size
Expand All @@ -111,13 +117,47 @@ def test_static_island(mpi_tmp_path: Path) -> None:
MPI.COMM_WORLD.barrier()


def test_dynamic(mpi_tmp_path: Path) -> None:
"""Test static surrogate using a dummy function."""
pop_size = 2 * MPI.COMM_WORLD.size # Breeding population size
limits: Dict[str, Union[Tuple[int, int], Tuple[float, float], Tuple[str, ...]]] = {
"start": (0.1, 7.0),
"limit": (-1.0, 1.0),
} # Define search space.
rng = random.Random(
MPI.COMM_WORLD.rank + 100
) # Set up separate random number generator for evolutionary optimizer.
num_generations = 4

propagator = get_default_propagator( # Get default evolutionary operator.
pop_size=pop_size, # Breeding population size
limits=limits, # Search space
crossover_prob=0.7, # Crossover probability
mutation_prob=0.4, # Mutation probability
random_init_prob=0.1, # Random-initialization probability
rng=rng, # Random number generator for evolutionary optimizer
)
propulator = Propulator(
loss_fn=ind_loss,
propagator=propagator,
generations=num_generations,
checkpoint_path=mpi_tmp_path,
rng=rng,
surrogate_factory=lambda: surrogate.DynamicSurrogate(limits),
) # Set up propulator performing actual optimization.

propulator.propulate() # Run optimization and print summary of results.
MPI.COMM_WORLD.barrier()


# @pytest.mark.mpi(min_size=8)
@pytest.mark.filterwarnings(
"ignore::DeprecationWarning",
match="Assigning the 'data' attribute is an inherently unsafe operation and will be removed in the future.",
)
def test_dynamic(mpi_tmp_path: Path) -> None:
def test_dynamic_island(mpi_tmp_path: Path) -> None:
"""Test dynamic surrogate using a dummy function."""
num_generations = 10 # Number of generations
num_generations = 4 # Number of generations
pop_size = 2 * MPI.COMM_WORLD.size # Breeding population size
limits: Dict[str, Union[Tuple[int, int], Tuple[float, float], Tuple[str, ...]]] = {
"start": (0.1, 7.0),
Expand All @@ -137,7 +177,7 @@ def test_dynamic(mpi_tmp_path: Path) -> None:
propagator=propagator, # Evolutionary operator
rng=rng, # Random number generator
generations=num_generations, # Number of generations per worker
num_islands=1, # Number of islands
num_islands=2, # Number of islands
checkpoint_path=mpi_tmp_path,
surrogate_factory=lambda: surrogate.DynamicSurrogate(limits),
)
Expand Down

0 comments on commit 8d0a5f4

Please sign in to comment.