Skip to content

Commit

Permalink
Merge branch 'develop' into no_path
Browse files Browse the repository at this point in the history
  • Loading branch information
dweindl authored Nov 27, 2024
2 parents e869409 + d39b959 commit fa0bd83
Show file tree
Hide file tree
Showing 7 changed files with 167 additions and 43 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/ci.yml
Original file line number Diff line number Diff line change
Expand Up @@ -157,7 +157,7 @@ jobs:

- name: Run tests
timeout-minutes: 35
run: tox -e petab
run: tox -e petab && tox e -e petab -- pip uninstall -y amici
env:
CC: clang
CXX: clang++
Expand Down
2 changes: 2 additions & 0 deletions pypesto/optimize/ess/ess.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,8 @@ class ESSExitFlag(int, enum.Enum):
MAX_EVAL = -2
# Exited after exhausting wall-time budget
MAX_TIME = -3
# Terminated for a reason other than the exit criteria
ERROR = -99


class OptimizerFactory(Protocol):
Expand Down
158 changes: 120 additions & 38 deletions pypesto/optimize/ess/sacess.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
import multiprocessing
import os
import time
from contextlib import suppress
from dataclasses import dataclass
from math import ceil, sqrt
from multiprocessing import get_context
Expand All @@ -20,6 +21,7 @@

import pypesto

from ... import MemoryHistory
from ...startpoint import StartpointMethod
from ...store.read_from_hdf5 import read_result
from ...store.save_to_hdf5 import write_result
Expand Down Expand Up @@ -331,12 +333,18 @@ def minimize(
n_eval_total = sum(
worker_result.n_eval for worker_result in self.worker_results
)
logger.info(
f"{self.__class__.__name__} stopped after {walltime:3g}s "
f"and {n_eval_total} objective evaluations "
f"with global best {result.optimize_result[0].fval}."
)

if len(result.optimize_result):
logger.info(
f"{self.__class__.__name__} stopped after {walltime:3g}s "
f"and {n_eval_total} objective evaluations "
f"with global best {result.optimize_result[0].fval}."
)
else:
logger.error(
f"{self.__class__.__name__} stopped after {walltime:3g}s "
f"and {n_eval_total} objective evaluations without producing "
"a result."
)
return result

def _create_result(self, problem: Problem) -> pypesto.Result:
Expand All @@ -345,25 +353,40 @@ def _create_result(self, problem: Problem) -> pypesto.Result:
Creates an overall Result object from the results saved by the workers.
"""
# gather results from workers and delete temporary result files
result = None
result = pypesto.Result()
retry_after_sleep = True
for worker_idx in range(self.num_workers):
tmp_result_filename = SacessWorker.get_temp_result_filename(
worker_idx, self._tmpdir
)
tmp_result = None
try:
tmp_result = read_result(
tmp_result_filename, problem=False, optimize=True
)
except FileNotFoundError:
# wait and retry, maybe the file wasn't found due to some filesystem latency issues
time.sleep(5)
tmp_result = read_result(
tmp_result_filename, problem=False, optimize=True
)
if retry_after_sleep:
time.sleep(5)
# waiting once is enough - don't wait for every worker
retry_after_sleep = False

try:
tmp_result = read_result(
tmp_result_filename, problem=False, optimize=True
)
except FileNotFoundError:
logger.error(
f"Worker {worker_idx} did not produce a result."
)
continue
else:
logger.error(
f"Worker {worker_idx} did not produce a result."
)
continue

if result is None:
result = tmp_result
else:
if tmp_result:
result.optimize_result.append(
tmp_result.optimize_result,
sort=False,
Expand All @@ -375,7 +398,8 @@ def _create_result(self, problem: Problem) -> pypesto.Result:
filename = SacessWorker.get_temp_result_filename(
worker_idx, self._tmpdir
)
os.remove(filename)
with suppress(FileNotFoundError):
os.remove(filename)
# delete tmpdir if empty
try:
self._tmpdir.rmdir()
Expand All @@ -397,6 +421,7 @@ class SacessManager:
Attributes
----------
_dim: Dimension of the optimization problem
_num_workers: Number of workers
_ess_options: ESS options for each worker
_best_known_fx: Best objective value encountered so far
Expand All @@ -410,6 +435,7 @@ class SacessManager:
_rejection_threshold: Threshold for relative objective improvements that
incoming solutions have to pass to be accepted
_lock: Lock for accessing shared state.
_terminate: Flag to signal termination of the SACESS run to workers
_logger: A logger instance
_options: Further optimizer hyperparameters.
"""
Expand All @@ -421,6 +447,7 @@ def __init__(
dim: int,
options: SacessOptions = None,
):
self._dim = dim
self._options = options or SacessOptions()
self._num_workers = len(ess_options)
self._ess_options = [shmem_manager.dict(o) for o in ess_options]
Expand All @@ -440,6 +467,7 @@ def __init__(
self._worker_scores = shmem_manager.Array(
"d", range(self._num_workers)
)
self._terminate = shmem_manager.Value("b", False)
self._worker_comms = shmem_manager.Array("i", [0] * self._num_workers)
self._lock = shmem_manager.RLock()
self._logger = logging.getLogger()
Expand Down Expand Up @@ -550,6 +578,16 @@ def submit_solution(
)
self._rejections.value = 0

def abort(self):
"""Abort the SACESS run."""
with self._lock:
self._terminate.value = True

def aborted(self) -> bool:
"""Whether this run was aborted."""
with self._lock:
return self._terminate.value


class SacessWorker:
"""A SACESS worker.
Expand Down Expand Up @@ -641,7 +679,7 @@ def run(
ess = self._setup_ess(startpoint_method)

# run ESS until exit criteria are met, but start at least one iteration
while self._keep_going() or ess.n_iter == 0:
while self._keep_going(ess) or ess.n_iter == 0:
# perform one ESS iteration
ess._do_iteration()

Expand All @@ -667,19 +705,42 @@ def run(
f"(best: {self._best_known_fx}, "
f"n_eval: {ess.evaluator.n_eval})."
)

ess.history.finalize(exitflag=ess.exit_flag.name)
worker_result = SacessWorkerResult(
x=ess.x_best,
fx=ess.fx_best,
history=ess.history,
n_eval=ess.evaluator.n_eval,
n_iter=ess.n_iter,
exit_flag=ess.exit_flag,
)
self._finalize(ess)

def _finalize(self, ess: ESSOptimizer = None):
    """Finalize the worker and hand a result to the manager.

    A ``SacessWorkerResult`` is always put on the manager's result queue.
    If ``ess`` is ``None`` or collecting its results raises, a placeholder
    result (NaN values, empty ``MemoryHistory``, ``ESSExitFlag.ERROR``)
    is queued instead.

    Parameters
    ----------
    ess:
        The ESS optimizer of this worker, or ``None`` if no optimizer
        state is available.
    """
    # No matter what goes wrong below, something has to be put on the
    # queue before returning -- otherwise we risk a deadlock.
    outcome = None

    if ess is not None:
        try:
            ess.history.finalize(exitflag=ess.exit_flag.name)
            ess._report_final()
            outcome = SacessWorkerResult(
                x=ess.x_best,
                fx=ess.fx_best,
                history=ess.history,
                n_eval=ess.evaluator.n_eval,
                n_iter=ess.n_iter,
                exit_flag=ess.exit_flag,
            )
        except Exception as e:
            self._logger.exception(
                f"Worker {self._worker_idx} failed to finalize: {e}"
            )

    if outcome is None:
        # Nothing usable was collected: queue a placeholder result
        outcome = SacessWorkerResult(
            x=np.full(self._manager._dim, np.nan),
            fx=np.nan,
            history=MemoryHistory(),
            n_eval=0,
            n_iter=0,
            exit_flag=ESSExitFlag.ERROR,
        )

    self._manager._result_queue.put(outcome)

    self._logger.debug(f"Final configuration: {self._ess_kwargs}")
ess._report_final()

def _setup_ess(self, startpoint_method: StartpointMethod) -> ESSOptimizer:
"""Run ESS."""
Expand Down Expand Up @@ -821,7 +882,7 @@ def replace_solution(refset: RefSet, x: np.ndarray, fx: float):
fx=fx,
)

def _keep_going(self):
def _keep_going(self, ess: ESSOptimizer) -> bool:
"""Check exit criteria.
Returns
Expand All @@ -830,14 +891,26 @@ def _keep_going(self):
"""
# elapsed time
if time.time() - self._start_time >= self._max_walltime_s:
self.exit_flag = ESSExitFlag.MAX_TIME
ess.exit_flag = ESSExitFlag.MAX_TIME
self._logger.debug(
f"Max walltime ({self._max_walltime_s}s) exceeded."
)
return False

# other reasons for termination (some worker failed, ...)
if self._manager.aborted():
ess.exit_flag = ESSExitFlag.ERROR
self._logger.debug("Manager requested termination.")
return False
return True

def abort(self):
"""Send signal to abort."""
self._logger.error(f"Worker {self._worker_idx} aborting.")
# signal to manager
self._manager.abort()

self._finalize(None)

@staticmethod
def get_temp_result_filename(worker_idx: int, tmpdir: str | Path) -> str:
return str(Path(tmpdir, f"sacess-{worker_idx:02d}_tmp.h5").absolute())
Expand All @@ -853,15 +926,24 @@ def _run_worker(
Helper function as entrypoint for sacess worker processes.
"""
# different random seeds per process
np.random.seed((os.getpid() * int(time.time() * 1000)) % 2**32)

# Forward log messages to the logging process
h = logging.handlers.QueueHandler(log_process_queue)
worker._logger = logging.getLogger(multiprocessing.current_process().name)
worker._logger.addHandler(h)
try:
# different random seeds per process
np.random.seed((os.getpid() * int(time.time() * 1000)) % 2**32)

# Forward log messages to the logging process
h = logging.handlers.QueueHandler(log_process_queue)
worker._logger = logging.getLogger(
multiprocessing.current_process().name
)
worker._logger.addHandler(h)

return worker.run(problem=problem, startpoint_method=startpoint_method)
return worker.run(problem=problem, startpoint_method=startpoint_method)
except Exception as e:
with suppress(Exception):
worker._logger.exception(
f"Worker {worker._worker_idx} failed: {e}"
)
worker.abort()


def get_default_ess_options(
Expand Down
2 changes: 1 addition & 1 deletion pypesto/optimize/optimizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -1169,7 +1169,7 @@ def __repr__(self) -> str:
if self.options is not None:
rep += f" options={self.options}"
if self.local_options is not None:
rep += f" local_options={self.local_methods}"
rep += f" local_options={self.local_options}"
return rep + ">"

@minimize_decorator_collection
Expand Down
1 change: 1 addition & 0 deletions pytest.ini
Original file line number Diff line number Diff line change
Expand Up @@ -2,3 +2,4 @@
addopts = "--doctest-modules"
filterwarnings =
ignore:.*inspect.getargspec\(\) is deprecated.*:DeprecationWarning
norecursedirs = amici_models
35 changes: 35 additions & 0 deletions test/optimize/test_optimize.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@

import pypesto
import pypesto.optimize as optimize
from pypesto import Objective
from pypesto.optimize.ess import (
ESSOptimizer,
FunctionEvaluatorMP,
Expand Down Expand Up @@ -577,6 +578,40 @@ def test_ess_refset_repr():
)


class FunctionOrError:
    """Callable wrapper that raises an error on every n-th invocation.

    Each call increments ``counter``. Whenever the counter is a multiple of
    ``error_frequency``, a ``RuntimeError`` is raised; otherwise the call is
    forwarded to ``fun`` and its return value is passed through.

    Parameters
    ----------
    fun:
        The callable to wrap.
    error_frequency:
        Raise on every ``error_frequency``-th call. Must be at least 1.

    Raises
    ------
    ValueError
        If ``error_frequency`` is smaller than 1.
    """

    def __init__(self, fun, error_frequency=100):
        # Fail early: a frequency of 0 would otherwise only surface later as
        # a ZeroDivisionError from the modulo in ``__call__``.
        if error_frequency < 1:
            raise ValueError(
                f"error_frequency must be >= 1, got {error_frequency}."
            )
        self.counter = 0
        self.error_frequency = error_frequency
        self.fun = fun

    def __call__(self, *args, **kwargs):
        # Failing calls are counted as well
        self.counter += 1
        if self.counter % self.error_frequency == 0:
            raise RuntimeError("Intentional error.")
        return self.fun(*args, **kwargs)


def test_sacess_worker_error(capsys):
    """Ensure SacessOptimizer returns instead of hanging on a worker error.

    The objective raises on a regular basis; the optimizer must still
    return a ``pypesto.Result`` and the error message must show up on
    stderr.
    """
    failing_rosen = FunctionOrError(sp.optimize.rosen)
    problem = pypesto.Problem(
        objective=Objective(fun=failing_rosen, grad=sp.optimize.rosen_der),
        lb=0 * np.ones((1, 2)),
        ub=1 * np.ones((1, 2)),
    )

    optimizer = SacessOptimizer(
        num_workers=2,
        max_walltime_s=2,
        sacess_loglevel=logging.DEBUG,
        ess_loglevel=logging.DEBUG,
    )
    result = optimizer.minimize(problem)

    assert isinstance(result, pypesto.Result)
    captured = capsys.readouterr()
    assert "Intentional error." in captured.err


def test_scipy_integrated_grad():
integrated = True
obj = rosen_for_sensi(max_sensi_order=2, integrated=integrated)["obj"]
Expand Down
10 changes: 7 additions & 3 deletions tox.ini
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ envlist =
# Base-environment

[testenv]
passenv = AMICI_PARALLEL_COMPILE,CC,CXX,MPLBACKEND
passenv = AMICI_PARALLEL_COMPILE,CC,CXX,MPLBACKEND,BNGPATH

# Sub-environments
# inherit settings defined in the base
Expand Down Expand Up @@ -75,10 +75,14 @@ description =
Test basic functionality on Windows

[testenv:petab]
extras = test,amici,petab,pyswarm,roadrunner
extras = test,petab,pyswarm,roadrunner
deps =
git+https://github.com/Benchmarking-Initiative/Benchmark-Models-PEtab.git@master\#subdirectory=src/python
git+https://github.com/AMICI-dev/amici.git@develop\#egg=amici&subdirectory=python/sdist
# always install amici from develop branch, avoid caching
# to skip re-installation, run `tox -e petab --override testenv:petab.commands_pre=`
commands_pre =
python3 -m pip uninstall -y amici
python3 -m pip install git+https://github.com/AMICI-dev/amici.git@develop\#egg=amici&subdirectory=python/sdist
commands =
python3 -m pip install git+https://github.com/PEtab-dev/petab_test_suite@main
python3 -m pip install git+https://github.com/pysb/pysb@master
Expand Down

0 comments on commit fa0bd83

Please sign in to comment.