Skip to content

Commit

Permalink
fixed mistake in checkpoint reading, fixed comm, used individual isla…
Browse files Browse the repository at this point in the history
…nd origin, implemented active_on_island
  • Loading branch information
Oskar Taubert committed May 15, 2024
1 parent a321284 commit 979b8b8
Showing 1 changed file with 61 additions and 49 deletions.
110 changes: 61 additions & 49 deletions propulate/propulator.py
Original file line number Diff line number Diff line change
Expand Up @@ -167,7 +167,10 @@ def __init__(
self.island_displs = (
island_displs # Propulate world rank of each island's worker
)
self.island_counts = island_counts # Number of workers on each island
if island_counts is None:
self.island_counts = np.array([self.island_comm.Get_size()])
else:
self.island_counts = island_counts # Number of workers on each island
self.emigration_propagator = emigration_propagator # Emigration propagator
self.rng = rng

Expand Down Expand Up @@ -197,51 +200,57 @@ def load_checkpoint(self):
# TODO check that the island and worker setup is the same as in the checkpoint
# NOTE each individual is only stored once at the position given by its origin island and worker, the modifications have to be put in the checkpoint file during migration TODO test if this works as intended reliably
# TODO get the started but not yet completed ones from the difference in start time and evaltime
# TODO only load an incomplete one, if you're then going to evaluate it

with h5py.File(self.checkpoint_path, "r", driver=None) as f:
islandgroup = f[f"{self.island_idx}"]

# NOTE check limits are consistent
limitsgroup = f["limits"]
for key in self.propagator.limits:
if set(limitsgroup.attrs.keys()) != set(self.propagator.limits):
raise RuntimeError("Limits inconsistent with checkpoint")
if set(limitsgroup.attrs.keys()) != set(self.propagator.limits):
raise RuntimeError("Limits inconsistent with checkpoint")
# TODO check island sizes are consistent

self.generation = int(f["generations"][self.propulate_comm.Get_rank()])
if (
islandgroup[f"{self.island_comm.Get_rank()}"]["evalperiod"][
f[f"{self.island_idx}"][f"{self.island_comm.Get_rank()}"]["evalperiod"][
self.generation
]
> 0.0
):
self.generation += 1
# NOTE load individuals, since they might have migrated, every worker has to check each dataset
for rank in range(self.propulate_comm.size):
# for generation in range(len(islandgroup[f"{rank}"])):
for generation in range(f["generations"][rank] + 1):
if islandgroup[f"{rank}"]["current"][generation] == self.island_idx:
ind = Individual(
islandgroup[f"{rank}"]["x"][generation, 0],
self.propagator.limits,
)
ind.rank = rank
ind.island = self.island_idx
ind.current = islandgroup[f"{rank}"]["current"][generation]
# TODO velocity loading
# if len(group[f"{rank}"].shape) > 1:
# ind.velocity = islandgroup[f"{rank}"]["x"][generation, 1]
ind.loss = islandgroup[f"{rank}"]["loss"][generation]
ind.startime = islandgroup[f"{rank}"]["starttime"][generation]
ind.evaltime = islandgroup[f"{rank}"]["evaltime"][generation]
ind.evalperiod = islandgroup[f"{rank}"]["evalperiod"][
generation
]
ind.generation = generation
self.population.append(ind)
if ind.loss is None:
# TODO resume evaluation
raise
num_islands = len(self.island_counts)
for i in range(num_islands):
islandgroup = f[f"{i}"]
for rank in range(self.island_counts[i]):
for generation in range(f["generations"][rank] + 1):
if islandgroup[f"{rank}"]["active_on_island"][generation][
self.island_idx
]:
ind = Individual(
islandgroup[f"{rank}"]["x"][generation, 0],
self.propagator.limits,
)
# TODO check what rank was used for, i think it was the rank in the propulator_comm
ind.rank = rank
ind.island = self.island_idx
ind.current = islandgroup[f"{rank}"]["current"][generation]
# TODO velocity loading
# if len(group[f"{rank}"].shape) > 1:
# ind.velocity = islandgroup[f"{rank}"]["x"][generation, 1]
ind.loss = islandgroup[f"{rank}"]["loss"][generation]
# ind.startime = islandgroup[f"{rank}"]["starttime"][generation]
ind.evaltime = islandgroup[f"{rank}"]["evaltime"][
generation
]
ind.evalperiod = islandgroup[f"{rank}"]["evalperiod"][
generation
]
ind.generation = generation
ind.island_rank = rank
self.population.append(ind)
if ind.loss is None:
# TODO resume evaluation on this individual
raise

def set_up_checkpoint(self):
"""Initialize checkpoint file or check consistenct with an existing one."""
Expand All @@ -252,10 +261,9 @@ def set_up_checkpoint(self):
else:
limit_dim += 1

num_islands = 1
if self.island_counts is not None:
num_islands = len(self.island_counts)
num_islands = len(self.island_counts)

# TODO this can probably be done without mpi just on rank 0
with h5py.File(
self.checkpoint_path, "a", driver="mpio", comm=self.propulate_comm
) as f:
Expand Down Expand Up @@ -286,17 +294,17 @@ def set_up_checkpoint(self):
# population
for i in range(num_islands):
f.require_group(f"{i}")
for worker_idx in range(self.propulate_comm.Get_size()):
for worker_idx in range(self.island_counts[i]):
group = f[f"{i}"].require_group(f"{worker_idx}")
if oldgenerations < self.generations:
group["x"].resize(self.generations, axis=0)
group["loss"].resize(self.generations, axis=0)
group["active"].resize(self.generations, axis=0)
group["current"].resize(self.generations, axis=0)
group["migration_steps"].resize(self.generations, axis=0)
group["starttime"].resize(self.generations, axis=0)
group["evaltime"].resize(self.generations, axis=0)
group["evalperiod"].resize(self.generations, axis=0)
group["active_on_island"].resize(self.generations, axis=0)

group.require_dataset(
"x",
Expand All @@ -312,13 +320,6 @@ def set_up_checkpoint(self):
chunks=True,
maxshape=(None,),
)
group.require_dataset(
"active",
(self.generations,),
np.bool_,
chunks=True,
maxshape=(None,),
)
group.require_dataset(
"current",
(self.generations,),
Expand Down Expand Up @@ -355,6 +356,14 @@ def set_up_checkpoint(self):
maxshape=(None,),
data=-1 * np.ones((self.generations,)),
)
group.require_dataset(
"active_on_island",
(self.generations, num_islands),
dtype=bool,
chunks=True,
maxshape=(None, None),
data=np.zeros((self.generations, num_islands), dtype=bool),
)

def propulate(self, logging_interval: int = 10, debug: int = 1) -> None:
"""
Expand Down Expand Up @@ -423,28 +432,31 @@ def _breed(self) -> Individual:
def _evaluate_individual(self, hdf5_checkpoint) -> None:
"""Breed and evaluate individual."""
ind = self._breed() # Breed new individual.
ind.island_rank = self.island_comm.Get_rank()
start_time = time.time_ns() - self.start_time # Start evaluation timer.
ind.starttime = start_time
ckpt_idx = ind.generation
hdf5_checkpoint["generations"][self.propulate_comm.Get_rank()] = ind.generation

group = hdf5_checkpoint[f"{self.island_idx}"][
f"{self.propulate_comm.Get_rank()}"
]
group = hdf5_checkpoint[f"{self.island_idx}"][f"{self.island_comm.Get_rank()}"]
# save candidate
group["x"][ckpt_idx, 0, :] = ind.position[:]
if ind.velocity is not None:
group["x"][ckpt_idx, 1, :] = ind.velocity[:]
group["starttime"][ckpt_idx] = start_time
group["current"][ckpt_idx] = ind.current

ind.loss = self.loss_fn(ind) # Evaluate its loss.
ind.evaltime = time.time_ns() - self.start_time # Stop evaluation timer.
ind.evalperiod = ind.evaltime - start_time # Calculate evaluation duration.
# TODO fix evalperiod for resumed from checkpoint individuals
# TODO somehow store migration history, maybe just as islands_visited

# save result for candidate
group["loss"][ckpt_idx] = ind.loss
group["evaltime"][ckpt_idx] = ind.evaltime
group["evalperiod"][ckpt_idx] = ind.evalperiod
group["active_on_island"][ckpt_idx, self.island_idx] = True
# Signal start of run to surrogate model.
if self.surrogate is not None:
self.surrogate.start_run(ind)
Expand Down Expand Up @@ -472,7 +484,7 @@ def loss_fn(individual):
self.surrogate.update(ind.loss)
if self.propulate_comm is None:
return
ind.evaltime = time.time() # Stop evaluation timer.
ind.evaltime = time.time_ns() - self.start_time # Stop evaluation timer.
ind.evalperiod = ind.evaltime - start_time # Calculate evaluation duration.
self.population.append(
ind
Expand Down Expand Up @@ -648,7 +660,7 @@ def _work(self, logging_interval: int = 10, debug: int = -1):

# Loop over generations.
with h5py.File(
self.checkpoint_path, "a", driver="mpio", comm=MPI.COMM_WORLD
self.checkpoint_path, "a", driver="mpio", comm=self.propulate_comm
) as f:
while self.generation < self.generations:
if self.generation % int(logging_interval) == 0:
Expand Down

0 comments on commit 979b8b8

Please sign in to comment.