From 979b8b81ad15401c8b9dce6bfdf768c13fcd1bbb Mon Sep 17 00:00:00 2001 From: Oskar Taubert Date: Wed, 15 May 2024 02:16:08 +0200 Subject: [PATCH] fixed mistake in checkpoint reading, fixed comm, used individual island origin, implemented active_on_island --- propulate/propulator.py | 110 ++++++++++++++++++++++------------------ 1 file changed, 61 insertions(+), 49 deletions(-) diff --git a/propulate/propulator.py b/propulate/propulator.py index 1e505938..cdd593de 100644 --- a/propulate/propulator.py +++ b/propulate/propulator.py @@ -167,7 +167,10 @@ def __init__( self.island_displs = ( island_displs # Propulate world rank of each island's worker ) - self.island_counts = island_counts # Number of workers on each island + if island_counts is None: + self.island_counts = np.array([self.island_comm.Get_size()]) + else: + self.island_counts = island_counts # Number of workers on each island self.emigration_propagator = emigration_propagator # Emigration propagator self.rng = rng @@ -197,51 +200,57 @@ def load_checkpoint(self): # TODO check that the island and worker setup is the same as in the checkpoint # NOTE each individual is only stored once at the position given by its origin island and worker, the modifications have to be put in the checkpoint file during migration TODO test if this works as intended reliably # TODO get the started but not yet completed ones from the difference in start time and evaltime + # TODO only load an incomplete one, if you're then going to evaluate it with h5py.File(self.checkpoint_path, "r", driver=None) as f: - islandgroup = f[f"{self.island_idx}"] - # NOTE check limits are consistent limitsgroup = f["limits"] - for key in self.propagator.limits: - if set(limitsgroup.attrs.keys()) != set(self.propagator.limits): - raise RuntimeError("Limits inconsistent with checkpoint") + if set(limitsgroup.attrs.keys()) != set(self.propagator.limits): + raise RuntimeError("Limits inconsistent with checkpoint") # TODO check island sizes are consistent self.generation = int(f["generations"][self.propulate_comm.Get_rank()]) if ( - islandgroup[f"{self.island_comm.Get_rank()}"]["evalperiod"][ + f[f"{self.island_idx}"][f"{self.island_comm.Get_rank()}"]["evalperiod"][ self.generation ] > 0.0 ): self.generation += 1 # NOTE load individuals, since they might have migrated, every worker has to check each dataset - for rank in range(self.propulate_comm.size): - # for generation in range(len(islandgroup[f"{rank}"])): - for generation in range(f["generations"][rank] + 1): - if islandgroup[f"{rank}"]["current"][generation] == self.island_idx: - ind = Individual( - islandgroup[f"{rank}"]["x"][generation, 0], - self.propagator.limits, - ) - ind.rank = rank - ind.island = self.island_idx - ind.current = islandgroup[f"{rank}"]["current"][generation] - # TODO velocity loading - # if len(group[f"{rank}"].shape) > 1: - # ind.velocity = islandgroup[f"{rank}"]["x"][generation, 1] - ind.loss = islandgroup[f"{rank}"]["loss"][generation] - ind.startime = islandgroup[f"{rank}"]["starttime"][generation] - ind.evaltime = islandgroup[f"{rank}"]["evaltime"][generation] - ind.evalperiod = islandgroup[f"{rank}"]["evalperiod"][ - generation - ] - ind.generation = generation - self.population.append(ind) - if ind.loss is None: - # TODO resume evaluation - raise + num_islands = len(self.island_counts) + for i in range(num_islands): + islandgroup = f[f"{i}"] + for rank in range(self.island_counts[i]): + for generation in range(f["generations"][rank] + 1): + if islandgroup[f"{rank}"]["active_on_island"][generation][ + self.island_idx + ]: + ind = Individual( + islandgroup[f"{rank}"]["x"][generation, 0], + self.propagator.limits, + ) + # TODO check what rank was used for, i think it was the rank in the propulator_comm + ind.rank = rank + ind.island = self.island_idx + ind.current = islandgroup[f"{rank}"]["current"][generation] + # TODO velocity loading + # if len(group[f"{rank}"].shape) > 1: + # ind.velocity = islandgroup[f"{rank}"]["x"][generation, 1] + ind.loss = islandgroup[f"{rank}"]["loss"][generation] + # ind.startime = islandgroup[f"{rank}"]["starttime"][generation] + ind.evaltime = islandgroup[f"{rank}"]["evaltime"][ + generation + ] + ind.evalperiod = islandgroup[f"{rank}"]["evalperiod"][ + generation + ] + ind.generation = generation + ind.island_rank = rank + self.population.append(ind) + if ind.loss is None: + # TODO resume evaluation on this individual + raise def set_up_checkpoint(self): """Initialize checkpoint file or check consistenct with an existing one.""" @@ -252,10 +261,9 @@ def set_up_checkpoint(self): else: limit_dim += 1 - num_islands = 1 - if self.island_counts is not None: - num_islands = len(self.island_counts) + num_islands = len(self.island_counts) + # TODO this can probably be done without mpi just on rank 0 with h5py.File( self.checkpoint_path, "a", driver="mpio", comm=self.propulate_comm ) as f: @@ -286,17 +294,17 @@ def set_up_checkpoint(self): # population for i in range(num_islands): f.require_group(f"{i}") - for worker_idx in range(self.propulate_comm.Get_size()): + for worker_idx in range(self.island_counts[i]): group = f[f"{i}"].require_group(f"{worker_idx}") if oldgenerations < self.generations: group["x"].resize(self.generations, axis=0) group["loss"].resize(self.generations, axis=0) - group["active"].resize(self.generations, axis=0) group["current"].resize(self.generations, axis=0) group["migration_steps"].resize(self.generations, axis=0) group["starttime"].resize(self.generations, axis=0) group["evaltime"].resize(self.generations, axis=0) group["evalperiod"].resize(self.generations, axis=0) + group["active_on_island"].resize(self.generations, axis=0) group.require_dataset( "x", @@ -312,13 +320,6 @@ def set_up_checkpoint(self): chunks=True, maxshape=(None,), ) - group.require_dataset( - "active", - (self.generations,), - np.bool_, - chunks=True, - maxshape=(None,), - ) group.require_dataset( "current", (self.generations,), @@ -355,6 +356,14 @@ def set_up_checkpoint(self): maxshape=(None,), data=-1 * np.ones((self.generations,)), ) + group.require_dataset( + "active_on_island", + (self.generations, num_islands), + dtype=bool, + chunks=True, + maxshape=(None, None), + data=np.zeros((self.generations, num_islands), dtype=bool), + ) def propulate(self, logging_interval: int = 10, debug: int = 1) -> None: """ @@ -423,28 +432,31 @@ def _breed(self) -> Individual: def _evaluate_individual(self, hdf5_checkpoint) -> None: """Breed and evaluate individual.""" ind = self._breed() # Breed new individual. + ind.island_rank = self.island_comm.Get_rank() start_time = time.time_ns() - self.start_time # Start evaluation timer. ind.starttime = start_time ckpt_idx = ind.generation hdf5_checkpoint["generations"][self.propulate_comm.Get_rank()] = ind.generation - group = hdf5_checkpoint[f"{self.island_idx}"][ - f"{self.propulate_comm.Get_rank()}" - ] + group = hdf5_checkpoint[f"{self.island_idx}"][f"{self.island_comm.Get_rank()}"] # save candidate group["x"][ckpt_idx, 0, :] = ind.position[:] if ind.velocity is not None: group["x"][ckpt_idx, 1, :] = ind.velocity[:] group["starttime"][ckpt_idx] = start_time + group["current"][ckpt_idx] = ind.current ind.loss = self.loss_fn(ind) # Evaluate its loss. ind.evaltime = time.time_ns() - self.start_time # Stop evaluation timer. ind.evalperiod = ind.evaltime - start_time # Calculate evaluation duration. + # TODO fix evalperiod for resumed from checkpoint individuals + # TODO somehow store migration history, maybe just as islands_visited # save result for candidate group["loss"][ckpt_idx] = ind.loss group["evaltime"][ckpt_idx] = ind.evaltime group["evalperiod"][ckpt_idx] = ind.evalperiod + group["active_on_island"][ckpt_idx, self.island_idx] = True # Signal start of run to surrogate model. if self.surrogate is not None: self.surrogate.start_run(ind) @@ -472,7 +484,7 @@ def loss_fn(individual): self.surrogate.update(ind.loss) if self.propulate_comm is None: return - ind.evaltime = time.time() # Stop evaluation timer. + ind.evaltime = time.time_ns() - self.start_time # Stop evaluation timer. ind.evalperiod = ind.evaltime - start_time # Calculate evaluation duration. self.population.append( ind @@ -648,7 +660,7 @@ def _work(self, logging_interval: int = 10, debug: int = -1): # Loop over generations. with h5py.File( - self.checkpoint_path, "a", driver="mpio", comm=MPI.COMM_WORLD + self.checkpoint_path, "a", driver="mpio", comm=self.propulate_comm ) as f: while self.generation < self.generations: if self.generation % int(logging_interval) == 0: