Skip to content

Commit

Permalink
Merge pull request #160 from Helmholtz-AI-Energy/hotfix/blocking_send…
Browse files Browse the repository at this point in the history
…_deadlock

Hotfix/blocking send deadlock
  • Loading branch information
mcw92 authored Sep 13, 2024
2 parents 2cb1ac6 + 6db11b3 commit ff3cb91
Show file tree
Hide file tree
Showing 5 changed files with 69 additions and 129 deletions.
2 changes: 1 addition & 1 deletion coverage.svg
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
6 changes: 4 additions & 2 deletions propulate/propagators/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -537,10 +537,12 @@ def __call__(self, *inds: Individual) -> Individual: # type: ignore[override]
if isinstance(
self.limits[limit][0], int
): # If ordinal trait of type integer.
if len(self.limits[limit]) == 2: # Selecting one value in range of ordinal parameter
if (
len(self.limits[limit]) == 2
): # Selecting one value in range of ordinal parameter
position[limit] = self.rng.randint(*self.limits[limit])
else: # Selecting one distinct value from ordinal parameters
position[limit] = self.rng.choice(self.limits[limit])
position[limit] = self.rng.choice(self.limits[limit]) # type: ignore
elif isinstance(
self.limits[limit][0], float
): # If interval trait of type float.
Expand Down
147 changes: 47 additions & 100 deletions propulate/propulator.py
Original file line number Diff line number Diff line change
Expand Up @@ -143,16 +143,12 @@ def __init__(
if MPI.COMM_WORLD.rank == 0:
log.info("Requested number of generations is zero...[RETURN]")
return
self.generations = (
generations # Number of generations (evaluations per individual)
)
self.generations = generations # Number of generations (evaluations per individual)
self.generation = 0 # Current generation not yet evaluated
self.island_idx = island_idx # Island index
self.island_comm = island_comm # Intra-island communicator
self.propulate_comm = propulate_comm # Propulate world communicator
self.worker_sub_comm = (
worker_sub_comm # Sub communicator for each (multi rank) worker
)
self.worker_sub_comm = worker_sub_comm # Sub communicator for each (multi rank) worker

# Always initialize the ``Surrogate`` as the class attribute has to be set for ``None`` checks later.
self.surrogate = None if surrogate_factory is None else surrogate_factory()
Expand All @@ -164,12 +160,13 @@ def __init__(
self.checkpoint_path.mkdir(parents=True, exist_ok=True)
self.migration_prob = migration_prob # Per-rank migration probability
self.migration_topology = migration_topology # Migration topology
self.island_displs = (
island_displs # Propulate world rank of each island's worker
)
self.island_displs = island_displs # Propulate world rank of each island's worker
self.island_counts = island_counts # Number of workers on each island
self.emigration_propagator = emigration_propagator # Emigration propagator
self.rng = rng
self.rng = rng # Generator for inter-island communication

self.intra_requests: list[MPI.Request] = [] # Keep track of intra-island send requests.
self.intra_buffers: list[Individual] = [] # Send buffers for intra-island communication

# Load initial population of evaluated individuals from checkpoint if exists.
load_ckpt_file = self.checkpoint_path / f"island_{self.island_idx}_ckpt.pickle"
Expand All @@ -181,32 +178,20 @@ def __init__(
try:
self.population = pickle.load(f)
self.generation = (
max(
[
x.generation
for x in self.population
if x.rank == self.island_comm.rank
]
)
+ 1
max([x.generation for x in self.population if x.rank == self.island_comm.rank]) + 1
) # Determine generation to be evaluated next from population checkpoint.
if self.island_comm.rank == 0:
log.info(
"Valid checkpoint file found. "
f"Resuming from generation {self.generation} of loaded population..."
"Valid checkpoint file found. " f"Resuming from generation {self.generation} of loaded population..."
)
except OSError:
self.population = []
if self.island_comm.rank == 0:
log.info(
"No valid checkpoint file. Initializing population randomly..."
)
log.info("No valid checkpoint file. Initializing population randomly...")
else:
self.population = []
if self.island_comm.rank == 0:
log.info(
"No valid checkpoint file given. Initializing population randomly..."
)
log.info("No valid checkpoint file given. Initializing population randomly...")

def _get_active_individuals(self) -> Tuple[List[Individual], int]:
"""
Expand Down Expand Up @@ -236,9 +221,7 @@ def _breed(self) -> Individual:
): # Only processes in the Propulate world communicator, consisting of rank 0 of each worker's sub
# communicator, are involved in the actual optimization routine.
active_pop, _ = self._get_active_individuals()
ind = self.propagator(
active_pop
) # Breed new individual from active population.
ind = self.propagator(active_pop) # Breed new individual from active population.
assert isinstance(ind, Individual)
ind.generation = self.generation # Set generation.
ind.rank = self.island_comm.rank # Set worker rank.
Expand All @@ -250,9 +233,7 @@ def _breed(self) -> Individual:
else: # The other processes do not breed themselves.
ind = None

if (
self.worker_sub_comm != MPI.COMM_SELF
): # Broadcast newly bred individual to all internal ranks of a worker from rank 0,
if self.worker_sub_comm != MPI.COMM_SELF: # Broadcast newly bred individual to all internal ranks of a worker from rank 0,
# which is also part of the Propulate comm.
ind = self.worker_sub_comm.bcast(obj=ind, root=0)

Expand Down Expand Up @@ -310,9 +291,7 @@ def loss_fn(individual: Individual) -> float:
return
ind.evaltime = time.time() # Stop evaluation timer.
ind.evalperiod = ind.evaltime - start_time # Calculate evaluation duration.
self.population.append(
ind
) # Add evaluated individual to worker-local population.
self.population.append(ind) # Add evaluated individual to worker-local population.
log.debug(
f"Island {self.island_idx} Worker {self.island_comm.rank} Generation {self.generation}: BREEDING\n"
f"Bred and evaluated individual {ind}."
Expand All @@ -323,12 +302,11 @@ def loss_fn(individual: Individual) -> float:
ind[SURROGATE_KEY] = self.surrogate.data()

# Tell other workers in own island about results to synchronize their populations.
for r in range(
self.island_comm.size
): # Loop over ranks in intra-island communicator.
for r in range(self.island_comm.size): # Loop over ranks in intra-island communicator.
if r == self.island_comm.rank:
continue # No self-talk.
self.island_comm.send(copy.deepcopy(ind), dest=r, tag=INDIVIDUAL_TAG)
self.intra_buffers.append(copy.deepcopy(ind))
self.intra_requests.append(self.island_comm.isend(self.intra_buffers[-1], dest=r, tag=INDIVIDUAL_TAG))

if self.surrogate is not None:
# Remove data from individual again as ``__eq__`` fails otherwise.
Expand All @@ -342,21 +320,15 @@ def _receive_intra_island_individuals(self) -> None:
)
probe_ind = True
while probe_ind:
stat = (
MPI.Status()
) # Retrieve status of reception operation, including source and tag.
probe_ind = self.island_comm.iprobe(
source=MPI.ANY_SOURCE, tag=INDIVIDUAL_TAG, status=stat
)
stat = MPI.Status() # Retrieve status of reception operation, including source and tag.
probe_ind = self.island_comm.iprobe(source=MPI.ANY_SOURCE, tag=INDIVIDUAL_TAG, status=stat)
# If True, continue checking for incoming messages. Tells whether message corresponding
# to filters passed is waiting for reception via a flag that it sets.
# If no such message has arrived yet, it returns False.
log_string += f"Incoming individual to receive?...{probe_ind}\n"
if probe_ind:
# Receive individual and add it to own population.
ind_temp = self.island_comm.recv(
source=stat.Get_source(), tag=INDIVIDUAL_TAG
)
ind_temp = self.island_comm.recv(source=stat.Get_source(), tag=INDIVIDUAL_TAG)

# Only merge if surrogate model is used.
if SURROGATE_KEY in ind_temp and self.surrogate is not None:
Expand All @@ -365,15 +337,11 @@ def _receive_intra_island_individuals(self) -> None:
if SURROGATE_KEY in ind_temp:
del ind_temp[SURROGATE_KEY]

self.population.append(
ind_temp
) # Add received individual to own worker-local population.
self.population.append(ind_temp) # Add received individual to own worker-local population.

log_string += f"Added individual {ind_temp} from W{stat.Get_source()} to own population.\n"
_, n_active = self._get_active_individuals()
log_string += (
f"After probing within island: {n_active}/{len(self.population)} active."
)
log_string += f"After probing within island: {n_active}/{len(self.population)} active."
log.debug(log_string)

def _send_emigrants(self) -> None:
Expand Down Expand Up @@ -422,9 +390,7 @@ def _get_unique_individuals(self) -> List[Individual]:
unique_inds.append(individual)
return unique_inds

def _check_intra_island_synchronization(
self, populations: List[List[Individual]]
) -> bool:
def _check_intra_island_synchronization(self, populations: List[List[Individual]]) -> bool:
"""
Check synchronization of populations of workers within one island.
Expand All @@ -440,15 +406,10 @@ def _check_intra_island_synchronization(
"""
synchronized = True
for population in populations:
difference = deepdiff.DeepDiff(
population, populations[0], ignore_order=True
)
difference = deepdiff.DeepDiff(population, populations[0], ignore_order=True)
if len(difference) == 0:
continue
log.info(
f"Island {self.island_idx} Worker {self.island_comm.rank}: Population not synchronized:\n"
f"{difference}"
)
log.info(f"Island {self.island_idx} Worker {self.island_comm.rank}: Population not synchronized:\n" f"{difference}")
synchronized = False
return synchronized

Expand Down Expand Up @@ -487,22 +448,21 @@ def propulate(self, logging_interval: int = 10, debug: int = -1) -> None:
# Loop over generations.
while self.generations <= -1 or self.generation < self.generations:
if self.generation % int(logging_interval) == 0:
log.info(
f"Island {self.island_idx} Worker {self.island_comm.rank}: In generation {self.generation}..."
)
log.info(f"Island {self.island_idx} Worker {self.island_comm.rank}: In generation {self.generation}...")

# Breed and evaluate individual.
self._evaluate_individual()

# Check for and possibly receive incoming individuals from other intra-island workers.
self._receive_intra_island_individuals()

# Clean up requests and buffers.
self._intra_send_cleanup()

if dump: # Dump checkpoint.
self._dump_checkpoint()

dump = (
self._determine_worker_dumping_next()
) # Determine worker dumping checkpoint in the next generation.
dump = self._determine_worker_dumping_next() # Determine worker dumping checkpoint in the next generation.

# Go to next generation.
self.generation += 1
Expand All @@ -527,11 +487,18 @@ def propulate(self, logging_interval: int = 10, debug: int = -1) -> None:
_ = self._determine_worker_dumping_next()
self.propulate_comm.barrier()

def _intra_send_cleanup(self) -> None:
"""Delete all send buffers that have been sent."""
# Test for requests to complete.
completed = MPI.Request.Testsome(self.intra_requests)
# Remove requests and buffers of complete send operations.
self.intra_requests = [r for i, r in enumerate(self.intra_requests) if i not in completed]
self.intra_buffers = [b for i, b in enumerate(self.intra_buffers) if i not in completed]

def _dump_checkpoint(self) -> None:
"""Dump checkpoint to file."""
log.debug(
f"Island {self.island_idx} Worker {self.island_comm.rank} Generation {self.generation}: "
f"Dumping checkpoint..."
f"Island {self.island_idx} Worker {self.island_comm.rank} Generation {self.generation}: " f"Dumping checkpoint..."
)
save_ckpt_file = self.checkpoint_path / f"island_{self.island_idx}_ckpt.pickle"
if os.path.isfile(save_ckpt_file):
Expand All @@ -542,20 +509,14 @@ def _dump_checkpoint(self) -> None:
with open(save_ckpt_file, "wb") as f:
pickle.dump(self.population, f)

dest = (
self.island_comm.rank + 1
if self.island_comm.rank + 1 < self.island_comm.size
else 0
)
dest = self.island_comm.rank + 1 if self.island_comm.rank + 1 < self.island_comm.size else 0
self.island_comm.send(True, dest=dest, tag=DUMP_TAG)

def _determine_worker_dumping_next(self) -> bool:
"""Determine the worker who dumps the checkpoint in the next generation."""
dump = False
stat = MPI.Status()
probe_dump = self.island_comm.iprobe(
source=MPI.ANY_SOURCE, tag=DUMP_TAG, status=stat
)
probe_dump = self.island_comm.iprobe(source=MPI.ANY_SOURCE, tag=DUMP_TAG, status=stat)
if probe_dump:
dump = self.island_comm.recv(source=stat.Get_source(), tag=DUMP_TAG)
log.debug(
Expand All @@ -575,9 +536,7 @@ def _dump_final_checkpoint(self) -> None:
with open(save_ckpt_file, "wb") as f:
pickle.dump(self.population, f)

def _check_for_duplicates(
self, active: bool, debug: int = 1
) -> Tuple[List[List[Union[Individual, int]]], List[Individual]]:
def _check_for_duplicates(self, active: bool, debug: int = 1) -> Tuple[List[List[Union[Individual, int]]], List[Individual]]:
"""
Check for duplicates in current population.
Expand Down Expand Up @@ -620,9 +579,7 @@ def _check_for_duplicates(
occurrences.append([individual, num_copies])
return occurrences, unique_inds

def summarize(
self, top_n: int = 1, debug: int = 1
) -> Union[List[Union[List[Individual], Individual]], None]:
def summarize(self, top_n: int = 1, debug: int = 1) -> Union[List[Union[List[Individual], Individual]], None]:
"""
Get top-n results from Propulate optimization.
Expand All @@ -641,15 +598,9 @@ def summarize(
if self.propulate_comm is None:
return None
active_pop, num_active = self._get_active_individuals()
assert np.all(
np.array(self.island_comm.allgather(num_active), dtype=int) == num_active
)
assert np.all(np.array(self.island_comm.allgather(num_active), dtype=int) == num_active)
if self.island_counts is not None:
num_active = int(
self.propulate_comm.allreduce(
num_active / self.island_counts[self.island_idx]
)
)
num_active = int(self.propulate_comm.allreduce(num_active / self.island_counts[self.island_idx]))

self.propulate_comm.barrier()
if self.propulate_comm.rank == 0:
Expand All @@ -664,13 +615,9 @@ def summarize(
occurrences, _ = self._check_for_duplicates(True, debug)
if self.island_comm.rank == 0:
if self._check_intra_island_synchronization(populations):
log.info(
f"Island {self.island_idx}: Populations among workers synchronized."
)
log.info(f"Island {self.island_idx}: Populations among workers synchronized.")
else:
log.info(
f"Island {self.island_idx}: Populations among workers not synchronized:\n{populations}"
)
log.info(f"Island {self.island_idx}: Populations among workers not synchronized:\n{populations}")
log.info(
f"Island {self.island_idx}: {len(active_pop)}/{len(self.population)} "
f"individuals active ({len(occurrences)} unique)"
Expand Down
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -109,7 +109,7 @@ exclude = [
]

# Same as Black.
line-length = 88
line-length = 132
indent-width = 4

# Assume Python 3.9.
Expand Down
Loading

0 comments on commit ff3cb91

Please sign in to comment.