Skip to content

Commit

Permalink
refactor argument parser in separate function
Browse files Browse the repository at this point in the history
  • Loading branch information
mcw92 committed Mar 13, 2024
1 parent 8b95599 commit caff589
Show file tree
Hide file tree
Showing 6 changed files with 166 additions and 294 deletions.
51 changes: 7 additions & 44 deletions tutorials/cmaes_example.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,13 @@
"""Simple example script using CMA-ES"""
import random
import argparse
import logging
import pathlib

from mpi4py import MPI

from propulate import Propulator
from propulate.propagators import BasicCMA, ActiveCMA, CMAPropagator
from propulate.utils import set_logger_config
from function_benchmark import get_function_search_space
from function_benchmark import get_function_search_space, parse_arguments


if __name__ == "__main__":
Expand All @@ -21,49 +20,11 @@
"#################################################\n"
)

parser = argparse.ArgumentParser(
prog="Simple Propulator example",
description="Set up and run a basic Propulator optimization of mathematical functions.",
)
parser.add_argument( # Function to optimize
"--function",
type=str,
choices=[
"bukin",
"eggcrate",
"himmelblau",
"keane",
"leon",
"rastrigin",
"schwefel",
"sphere",
"step",
"rosenbrock",
"quartic",
"bisphere",
"birastrigin",
"griewank",
],
default="sphere",
)
parser.add_argument(
"--generations", type=int, default=1000
) # Number of generations
parser.add_argument(
"--seed", type=int, default=0
) # Seed for Propulate random number generator
parser.add_argument("--adapter", type=str, default="basic")
parser.add_argument("--verbosity", type=int, default=1) # Verbosity level
parser.add_argument(
"--checkpoint", type=str, default="./"
) # Path for loading and writing checkpoints.
parser.add_argument("--top_n", type=int, default=1)
parser.add_argument("--logging_int", type=int, default=10)
config = parser.parse_args()
config, _ = parse_arguments(comm)

# Set up separate logger for Propulate optimization.
set_logger_config(
level=logging.INFO, # Logging level
level=config.logging_level, # Logging level
log_file=f"{config.checkpoint}/{pathlib.Path(__file__).stem}.log", # Logging path
log_to_stdout=True, # Print log on stdout.
log_rank=False, # Do not prepend MPI rank to logging messages.
Expand Down Expand Up @@ -98,5 +59,7 @@
)

# Run optimization and print summary of results.
propulator.propulate(logging_interval=config.logging_int, debug=config.verbosity)
propulator.propulate(
logging_interval=config.logging_interval, debug=config.verbosity
)
propulator.summarize(top_n=config.top_n, debug=config.verbosity)
128 changes: 128 additions & 0 deletions tutorials/function_benchmark.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,9 @@
"""Benchmark function module."""
import argparse
import logging
from typing import Callable, Dict, Tuple

from mpi4py import MPI
import numpy as np


Expand Down Expand Up @@ -597,3 +601,127 @@ def get_function_search_space(
ValueError(f"Function {fname} undefined...exiting")

return function, limits


def parse_arguments(
propulate_comm: MPI.Comm = MPI.COMM_WORLD,
) -> Tuple[argparse.Namespace, Dict[str, bool]]:
"""
Set up argument parser for Propulate optimization of simple mathematical functions.
Parameters
----------
propulate_comm : MPI.Comm, optional
The communicator used to run the Propulate optimization. Default is ``MPI.COMM_WORLD``.
Returns
-------
Namespace
The namespace of all parsed arguments.
Dict[str, bool]
A dictionary logging if one of the PSO hyperparameters was actually set. Only relevant for PSO.
"""
parser = argparse.ArgumentParser(
prog="Simple Propulator example",
description="Set up and run a basic Propulator optimization of mathematical functions.",
)
parser.add_argument( # Function to optimize
"--function",
type=str,
choices=[
"bukin",
"eggcrate",
"himmelblau",
"keane",
"leon",
"rastrigin",
"schwefel",
"sphere",
"step",
"rosenbrock",
"quartic",
"bisphere",
"birastrigin",
"griewank",
],
default="sphere",
)
parser.add_argument(
"--generations", type=int, default=1000
) # Number of generations
parser.add_argument(
"--seed", type=int, default=0
) # Seed for Propulate random number generator
parser.add_argument("--verbosity", type=int, default=1) # Verbosity level
parser.add_argument(
"--checkpoint", type=str, default="./"
) # Path for loading and writing checkpoints.
parser.add_argument(
"--pop_size", type=int, default=2 * propulate_comm.size
) # Breeding pool size
parser.add_argument(
"--crossover_probability", type=float, default=0.7
) # Crossover probability
parser.add_argument(
"--mutation_probability", type=float, default=0.4
) # Mutation probability
parser.add_argument("--random_init_probability", type=float, default=0.1)
parser.add_argument("--top_n", type=int, default=1)
parser.add_argument("--logging_interval", type=int, default=10)
parser.add_argument("--logging_level", type=int, default=logging.INFO)

# -------- Island-model specific arguments (ignored if not needed) --------
parser.add_argument(
"--num_islands", type=int, default=2
) # Number of separate evolutionary islands
parser.add_argument(
"--migration_probability", type=float, default=0.9
) # Migration probability
parser.add_argument("--num_migrants", type=int, default=1)
parser.add_argument("--pollination", action="store_true")

# -------- PSO-specific arguments (ignored if not needed) --------
parser.add_argument(
"--variant",
type=str,
choices=["Basic", "VelocityClamping", "Constriction", "Canonical"],
default="Basic",
) # PSO variant to run
hp_set: Dict[str, bool] = {
"inertia": False,
"cognitive": False,
"social": False,
}

class ParamSettingCatcher(argparse.Action):
"""
This class extends ``argparse``'s ``Action`` class in order to allow for an action that logs if one of the PSO
hyperparameters was actually set.
"""

def __call__(self, parser, namespace, values, option_string=None):
hp_set[self.dest] = True
super().__call__(parser, namespace, values, option_string)

parser.add_argument(
"--inertia", type=float, default=0.729, action=ParamSettingCatcher
) # Inertia weight
parser.add_argument(
"--cognitive", type=float, default=1.49445, action=ParamSettingCatcher
) # Cognitive factor
parser.add_argument(
"--social", type=float, default=1.49445, action=ParamSettingCatcher
) # Social factor
parser.add_argument(
"--clamping_factor", type=float, default=0.6
) # Velocity clamping factor

# -------- CMA-ES specific arguments (ignored if not needed)
parser.add_argument("--adapter", type=str, default="basic")

# -------- Multi-rank worker specific arguments (ignored if not needed)
parser.add_argument(
"--ranks_per_worker", type=int, default=2
) # Number of sub ranks that each worker will use

return parser.parse_args(), hp_set
76 changes: 10 additions & 66 deletions tutorials/islands_example.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
#!/usr/bin/env python3
import argparse
import logging
"""Simple island model example script."""
import pathlib
import random

import numpy as np
Expand All @@ -9,75 +8,18 @@
from propulate import Islands
from propulate.propagators import SelectMin, SelectMax
from propulate.utils import get_default_propagator, set_logger_config
from function_benchmark import get_function_search_space
from function_benchmark import get_function_search_space, parse_arguments

if __name__ == "__main__":
comm = MPI.COMM_WORLD

parser = argparse.ArgumentParser(
prog="Simple Propulator example",
description="Set up and run a basic Propulator optimization of mathematical functions.",
)
parser.add_argument( # Function to optimize
"-f",
"--function",
type=str,
choices=[
"bukin",
"eggcrate",
"himmelblau",
"keane",
"leon",
"rastrigin",
"schwefel",
"sphere",
"step",
"rosenbrock",
"quartic",
"bisphere",
"birastrigin",
"griewank",
],
default="sphere",
)
parser.add_argument(
"-g", "--generations", type=int, default=1000
) # Number of generations
parser.add_argument(
"-s", "--seed", type=int, default=0
) # Seed for Propulate random number generator
parser.add_argument("-v", "--verbosity", type=int, default=1) # Verbosity level
parser.add_argument(
"-ckpt", "--checkpoint", type=str, default="./"
) # Path for loading and writing checkpoints.
parser.add_argument(
"-p", "--pop_size", type=int, default=2 * comm.size
) # Breeding pool size
parser.add_argument(
"-cp", "--crossover_probability", type=float, default=0.7
) # Crossover probability
parser.add_argument(
"-mp", "--mutation_probability", type=float, default=0.4
) # Mutation probability
parser.add_argument("-rp", "--random_init_probability", type=float, default=0.1)
parser.add_argument(
"-i", "--num_islands", type=int, default=2
) # Number of separate evolutionary islands
parser.add_argument(
"-migp", "--migration_probability", type=float, default=0.9
) # Migration probability
parser.add_argument("-m", "--num_migrants", type=int, default=1)
parser.add_argument("-pln", "--pollination", action="store_true")
parser.add_argument(
"-t", "--top_n", type=int, default=1
) # Print top-n best individuals on each island in summary.
parser.add_argument("-l", "--logging_int", type=int, default=10) # Logging interval
config = parser.parse_args()
# Parse command-line arguments.
config, _ = parse_arguments(comm)

# Set up separate logger for Propulate optimization.
set_logger_config(
level=logging.INFO, # logging level
log_file=f"{config.checkpoint}/islands.log", # logging path
level=config.logging_level, # Logging level
log_file=f"{config.checkpoint}/{pathlib.Path(__file__).stem}.log", # Logging path
log_to_stdout=True, # Print log on stdout.
log_rank=False, # Do not prepend MPI rank to logging messages.
colors=True, # Use colors.
Expand Down Expand Up @@ -128,5 +70,7 @@

# Run actual optimization.
islands.evolve(
top_n=config.top_n, logging_interval=config.logging_int, debug=config.verbosity
top_n=config.top_n,
logging_interval=config.logging_interval,
debug=config.verbosity,
)
Loading

0 comments on commit caff589

Please sign in to comment.