Skip to content

Commit

Permalink
clean up propulator test
Browse files Browse the repository at this point in the history
  • Loading branch information
mcw92 committed Apr 23, 2024
1 parent 0481142 commit 4a2cf57
Showing 1 changed file with 38 additions and 61 deletions.
99 changes: 38 additions & 61 deletions tests/test_propulator.py
Original file line number Diff line number Diff line change
@@ -1,88 +1,64 @@
import copy
import logging
import pathlib
import random
from typing import Tuple

import deepdiff
import pytest
from mpi4py import MPI

from propulate import Propulator
from propulate.utils import get_default_propagator, set_logger_config
from propulate.utils.benchmark_functions import get_function_search_space


@pytest.fixture(
params=[
("rosenbrock", 0.0, 0.1),
("step", -25.0, 2.0),
("quartic", 0.0, 1000.0),
("rastrigin", 0.0, 1000.0),
("griewank", 0.0, 10000.0),
("schwefel", 0.0, 10000.0),
("bisphere", 0.0, 1000.0),
("birastrigin", 0.0, 1000.0),
("bukin", 0.0, 100.0),
("eggcrate", -1.0, 10.0),
("himmelblau", 0.0, 1.0),
("keane", 0.6736675, 1.0),
("leon", 0.0, 10.0),
("sphere", 0.0, 0.01), # (fname, expected, abs)
]
)
def function_parameters(request):
"""Define benchmark function parameter sets as used in tests."""
return request.param


def test_propulator(function_parameters, mpi_tmp_path) -> None:
def test_propulator(
function_parameters: Tuple[str, float], mpi_tmp_path: pathlib.Path
) -> None:
"""
Test single worker using Propulator to optimize a benchmark function using the default genetic propagator.
Test standard Propulator to optimize the benchmark functions using the default genetic propagator.
This test is run both sequentially and in parallel.
Parameters
----------
function_parameters : tuple
The tuple containing (fname, expected, abs).
function_parameters : Tuple
The tuple containing each function name along with its global minimum.
mpi_tmp_path : pathlib.Path
The temporary checkpoint directory.
"""
fname, expected, abs_tolerance = function_parameters
rng = random.Random(
42 + MPI.COMM_WORLD.rank
) # Separate random number generator for optimization
function, limits = get_function_search_space(fname)
) # Random number generator for optimization
function, limits = get_function_search_space(function_parameters[0])
set_logger_config(
level=logging.INFO,
log_file=mpi_tmp_path / "propulate.log",
log_to_stdout=True,
log_rank=False,
colors=True,
log_file=mpi_tmp_path / "log.log",
)
# Set up evolutionary operator.
propagator = get_default_propagator(
pop_size=4,
limits=limits,
crossover_prob=0.7,
mutation_prob=0.9,
random_init_prob=0.1,
rng=rng,
)

# Set up propulator performing actual optimization.
) # Set up evolutionary operator.
propulator = Propulator(
loss_fn=function,
propagator=propagator,
rng=rng,
generations=100,
checkpoint_path=mpi_tmp_path,
)
) # Set up propulator performing actual optimization.
propulator.propulate() # Run optimization and print summary of results.

# Run optimization and print summary of results.
propulator.propulate()
# assert propulator.summarize(top_n=1, debug=2)[0][0].loss == pytest.approx(
# expected=expected, abs=abs_tolerance
# )

def test_propulator_checkpointing(mpi_tmp_path: pathlib.Path) -> None:
"""
Test standard Propulator checkpointing for the sphere benchmark function.
This test is run both sequentially and in parallel.
def test_propulator_checkpointing(mpi_tmp_path) -> None:
"""Test single worker Propulator checkpointing."""
Parameters
----------
mpi_tmp_path : pathlib.Path
The temporary checkpoint directory.
"""
rng = random.Random(
42 + MPI.COMM_WORLD.rank
) # Separate random number generator for optimization
Expand All @@ -91,9 +67,6 @@ def test_propulator_checkpointing(mpi_tmp_path) -> None:
propagator = get_default_propagator( # Get default evolutionary operator.
pop_size=4, # Breeding pool size
limits=limits, # Search-space limits
crossover_prob=0.7, # Crossover probability
mutation_prob=0.9, # Mutation probability
random_init_prob=0.1, # Random-initialization probability
rng=rng, # Random number generator
)
propulator = Propulator(
Expand All @@ -102,22 +75,26 @@ def test_propulator_checkpointing(mpi_tmp_path) -> None:
generations=1000,
checkpoint_path=mpi_tmp_path,
rng=rng,
)
) # Set up propulator performing actual optimization.

propulator.propulate()
propulator.propulate() # Run optimization and print summary of results.

old_population = copy.deepcopy(propulator.population)
del propulator
MPI.COMM_WORLD.barrier()
old_population = copy.deepcopy(
propulator.population
) # Save population list from the last run.
del propulator # Delete propulator object.
MPI.COMM_WORLD.barrier() # Synchronize all processes.

propulator = Propulator(
loss_fn=function,
propagator=propagator,
generations=20,
checkpoint_path=mpi_tmp_path,
rng=rng,
)
) # Set up new propulator starting from checkpoint.

# As the number of requested generations is smaller than the number of generations from the run before,
# no new evaluations are performed. Thus, the length of both Propulators' populations must be equal.
assert (
len(deepdiff.DeepDiff(old_population, propulator.population, ignore_order=True))
== 0
Expand Down

0 comments on commit 4a2cf57

Please sign in to comment.