From 06563950e2b65a7ffbe861dd6e87dd1965de2fb5 Mon Sep 17 00:00:00 2001 From: Marie Weiel Date: Wed, 13 Mar 2024 12:22:32 +0100 Subject: [PATCH] add minimum working example --- tutorials/minimum_working_example.py | 37 ++++++++++++++++++++++++++++ 1 file changed, 37 insertions(+) create mode 100644 tutorials/minimum_working_example.py diff --git a/tutorials/minimum_working_example.py b/tutorials/minimum_working_example.py new file mode 100644 index 00000000..3a961176 --- /dev/null +++ b/tutorials/minimum_working_example.py @@ -0,0 +1,37 @@ +import propulate +from mpi4py import MPI +import random + +# Set the communicator and the optimization parameters +comm = MPI.COMM_WORLD +rng = random.Random(MPI.COMM_WORLD.rank) +population_size = comm.size * 2 +generations = 100 +checkpoint = "./propulate_checkpoints" +propulate.utils.set_logger_config() + + +# Define the function to minimize and the search space, e.g., a 2D sphere function on (-5.12, 5.12)^2. +def loss_fn(params): + """Loss function to minimize.""" + return params["x"] ** 2 + params["y"] ** 2 + + +limits = {"x": (-5.12, 5.12), "y": (-5.12, 5.12)} + +# Initialize the propagator and propulator with default parameters. +propagator = propulate.utils.get_default_propagator( + pop_size=population_size, limits=limits, rng=rng +) +propulator = propulate.Propulator( + loss_fn=loss_fn, + propagator=propagator, + rng=rng, + island_comm=comm, + generations=generations, + checkpoint_path=checkpoint, +) + +# Run optimization and get summary of results. +propulator.propulate() +propulator.summarize()