refactor: Favour larger, longer locks with lockf
eddiebergman committed Dec 2, 2024
1 parent 20b492a commit 9d9651d
Showing 17 changed files with 792 additions and 1,922 deletions.
37 changes: 9 additions & 28 deletions neps/env.py
@@ -33,7 +33,6 @@ def is_nullable(e: str) -> bool:
parse=str,
default="lockf",
)
assert LINUX_FILELOCK_FUNCTION in ("lockf", "flock")

MAX_RETRIES_GET_NEXT_TRIAL = get_env(
"NEPS_MAX_RETRIES_GET_NEXT_TRIAL",
@@ -66,35 +65,17 @@ def is_nullable(e: str) -> bool:
default=120,
)

SEED_SNAPSHOT_FILELOCK_POLL = get_env(
"NEPS_SEED_SNAPSHOT_FILELOCK_POLL",
parse=float,
default=0.05,
)
SEED_SNAPSHOT_FILELOCK_TIMEOUT = get_env(
"NEPS_SEED_SNAPSHOT_FILELOCK_TIMEOUT",
parse=lambda e: None if is_nullable(e) else float(e),
default=120,
)

OPTIMIZER_INFO_FILELOCK_POLL = get_env(
"NEPS_OPTIMIZER_INFO_FILELOCK_POLL",
parse=float,
default=0.05,
)
OPTIMIZER_INFO_FILELOCK_TIMEOUT = get_env(
"NEPS_OPTIMIZER_INFO_FILELOCK_TIMEOUT",
parse=lambda e: None if is_nullable(e) else float(e),
default=120,
)

OPTIMIZER_STATE_FILELOCK_POLL = get_env(
"NEPS_OPTIMIZER_STATE_FILELOCK_POLL",
parse=float,
default=0.05,
)
OPTIMIZER_STATE_FILELOCK_TIMEOUT = get_env(
"NEPS_OPTIMIZER_STATE_FILELOCK_TIMEOUT",
parse=lambda e: None if is_nullable(e) else float(e),
default=120,
)

# NOTE: We want this to be greater than the trials filelock, so that
# anything requesting to just update the trials is more likely to obtain it
# as those operations tend to be faster than something that requires optimizer
# state.
STATE_FILELOCK_POLL = get_env(
"NEPS_STATE_FILELOCK_POLL",
parse=float,
default=0.20,
)
STATE_FILELOCK_TIMEOUT = get_env(
"NEPS_STATE_FILELOCK_TIMEOUT",
parse=lambda e: None if is_nullable(e) else float(e),
default=120,
)
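Taken together, this file's change collapses the separate seed-snapshot, optimizer-info and optimizer-state lock settings into the single NEPS_STATE_FILELOCK_POLL / NEPS_STATE_FILELOCK_TIMEOUT pair. A minimal usage sketch, assuming these are read by the module-level get_env calls at import time, and assuming the lock-function variable is exposed as NEPS_LINUX_FILELOCK_FUNCTION (the name follows the NEPS_ pattern of the other settings but is not visible in this hunk):

```python
import os

# Override the consolidated lock settings before importing neps; the
# module-level get_env calls above are assumed to run at import time.
os.environ["NEPS_STATE_FILELOCK_POLL"] = "0.5"     # seconds between lock attempts
os.environ["NEPS_STATE_FILELOCK_TIMEOUT"] = "300"  # seconds before giving up
# Variable name assumed from the NEPS_ pattern; accepted values per this commit: "lockf" or "flock".
os.environ["NEPS_LINUX_FILELOCK_FUNCTION"] = "flock"

import neps  # noqa: E402
```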
2 changes: 1 addition & 1 deletion neps/plot/tensorboard_eval.py
@@ -94,7 +94,7 @@ def _initiate_internal_configurations() -> None:
register_notify_trial_end("NEPS_TBLOGGER", tblogger.end_of_config)

# We are assuming that neps state is all filebased here
root_dir = Path(neps_state.location)
root_dir = Path(neps_state.path)
assert root_dir.exists()

tblogger.config_working_directory = Path(trial.metadata.location)
148 changes: 89 additions & 59 deletions neps/runtime.py
@@ -33,14 +33,13 @@
WorkerRaiseError,
)
from neps.state._eval import evaluate_trial
from neps.state.filebased import create_or_load_filebased_neps_state
from neps.state.neps_state import NePSState
from neps.state.optimizer import BudgetInfo, OptimizationState, OptimizerInfo
from neps.state.settings import DefaultReportValues, OnErrorPossibilities, WorkerSettings
from neps.state.trial import Trial

if TYPE_CHECKING:
from neps.optimizers.base_optimizer import BaseOptimizer
from neps.state.neps_state import NePSState
from neps.state.trial import Trial

logger = logging.getLogger(__name__)

@@ -64,7 +63,7 @@ def _default_worker_name() -> str:


# TODO: This only works with a filebased nepsstate
def get_workers_neps_state() -> NePSState[Path]:
def get_workers_neps_state() -> NePSState:
"""Get the worker's NePS state."""
if _WORKER_NEPS_STATE is None:
raise RuntimeError(
Expand All @@ -76,7 +75,7 @@ def get_workers_neps_state() -> NePSState[Path]:
return _WORKER_NEPS_STATE


def _set_workers_neps_state(state: NePSState[Path]) -> None:
def _set_workers_neps_state(state: NePSState) -> None:
global _WORKER_NEPS_STATE # noqa: PLW0603
_WORKER_NEPS_STATE = state

Expand Down Expand Up @@ -177,36 +176,14 @@ def new(
_pre_sample_hooks=_pre_sample_hooks,
)

def _get_next_trial_from_state(self) -> Trial:
nxt_trial = self.state.get_next_pending_trial()

# If we have a trial, we will use it
if nxt_trial is not None:
logger.info(
f"Worker '{self.worker_id}' got previosly sampled trial: {nxt_trial}"
)

# Otherwise sample a new one
else:
nxt_trial = self.state.sample_trial(
worker_id=self.worker_id,
optimizer=self.optimizer,
_sample_hooks=self._pre_sample_hooks,
)
logger.info(f"Worker '{self.worker_id}' sampled a new trial: {nxt_trial}")

return nxt_trial

def _check_if_should_stop( # noqa: C901, PLR0912, PLR0911
def _check_worker_local_settings(
self,
*,
time_monotonic_start: float,
error_from_this_worker: Exception | None,
) -> str | Literal[False]:
# NOTE: Sorry this code is kind of ugly but it's pretty straightforward, just a
# lot of conditional checking and making sure to check cheaper conditions first.
# It would look a little nicer with a match statement but we've got to wait
# for python 3.10 for that.

# First check for stopping criterion for this worker in particular as it's
# cheaper and doesn't require anything from the state.
@@ -280,13 +257,16 @@ def _check_if_should_stop( # noqa: C901, PLR0912, PLR0911
f", given by `{self.settings.max_evaluation_time_for_worker_seconds=}`."
)

return False

def _check_shared_error_stopping_criterion(self) -> str | Literal[False]:
# We check this global error stopping criterion as it's much
# cheaper than sweeping the state from all trials.
if self.settings.on_error in (
OnErrorPossibilities.RAISE_ANY_ERROR,
OnErrorPossibilities.STOP_ANY_ERROR,
):
err = self.state._shared_errors.synced().latest_err_as_raisable()
err = self.state.lock_and_get_errors().latest_err_as_raisable()
if err is not None:
msg = (
"An error occurred in another worker and this worker is set to stop"
@@ -306,20 +286,12 @@ def _check_if_should_stop( # noqa: C901, PLR0912, PLR0911

return msg

# If there is no global stopping criterion, we can just return early.
if (
self.settings.max_evaluations_total is None
and self.settings.max_cost_total is None
and self.settings.max_evaluation_time_total_seconds is None
):
return False

# At this point, if we have some global stopping criterion, we need to sweep
# the current state of trials to determine if we should stop
# NOTE: If these `sum` turn out to somehow be a bottleneck, these could
# be precomputed and accumulated over time. This would have to be handled
# in the `NePSState` class.
trials = self.state.get_all_trials()
return False

def _check_global_stopping_criterion(
self,
trials: Mapping[str, Trial],
) -> str | Literal[False]:
if self.settings.max_evaluations_total is not None:
if self.settings.include_in_progress_evaluations_towards_maximum:
# NOTE: We can just use the sum of trials in this case as they
@@ -368,6 +340,8 @@ def _check_if_should_stop( # noqa: C901, PLR0912, PLR0911

return False
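As an aside, the counting the comments above describe can be pictured with a short hypothetical helper; the report attribute (assumed to be None until a trial has been evaluated) and the helper itself are illustrative assumptions, not the actual NePS implementation:

```python
from collections.abc import Mapping

from neps.state.trial import Trial


def count_evaluations(trials: Mapping[str, Trial], *, include_in_progress: bool) -> int:
    # Hypothetical sketch: if in-progress evaluations count towards the
    # maximum, every sampled trial counts; otherwise only trials that have
    # produced a report are tallied.
    if include_in_progress:
        return len(trials)
    return sum(1 for trial in trials.values() if trial.report is not None)
```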

# Forgive me lord, for I have sinned, this function is atrocious but complicated
# due to locking.
def run(self) -> None: # noqa: C901, PLR0915, PLR0912
"""Run the worker.
@@ -385,18 +359,27 @@ def run(self) -> None: # noqa: C901, PLR0915, PLR0912
n_failed_set_trial_state = 0
n_repeated_failed_check_should_stop = 0
while True:
# NOTE: We rely on this function to do logging and raise errors if it should
try:
should_stop = self._check_if_should_stop(
# First check local worker settings
should_stop = self._check_worker_local_settings(
time_monotonic_start=_time_monotonic_start,
error_from_this_worker=_error_from_evaluation,
)
if should_stop is not False:
logger.info(should_stop)
break

# Next check for global errors having occurred
should_stop = self._check_shared_error_stopping_criterion()
if should_stop is not False:
logger.info(should_stop)
break

except WorkerRaiseError as e:
# If we raise a specific error, we should stop the worker
raise e
except Exception as e:
# An unknown exception, check our retry count
n_repeated_failed_check_should_stop += 1
if (
n_repeated_failed_check_should_stop
@@ -415,8 +398,48 @@ def run(self) -> None: # noqa: C901, PLR0915, PLR0912
time.sleep(1) # Help stagger retries
continue

# From here, we now begin sampling or getting the next pending trial.
# As the global stopping criterion requires us to check all trials, it
# needs to be done in lock-step with sampling.
try:
trial_to_eval = self._get_next_trial_from_state()
# If there is no global stopping criterion, we can just return early.
with self.state.lock_for_sampling():
trials = self.state._trials.latest()

requires_checking_global_stopping_criterion = (
self.settings.max_evaluations_total is not None
or self.settings.max_cost_total is not None
or self.settings.max_evaluation_time_total_seconds is not None
)
if requires_checking_global_stopping_criterion:
should_stop = self._check_global_stopping_criterion(trials)
if should_stop is not False:
logger.info(should_stop)
break

pending_trials = [
trial
for trial in trials.values()
if trial.state == Trial.State.PENDING
]
if len(pending_trials) > 0:
earliest_pending = sorted(
pending_trials,
key=lambda t: t.metadata.time_sampled,
)[0]
earliest_pending.set_evaluating(
time_started=time.time(),
worker_id=self.worker_id,
)
self.state._trials.update_trial(earliest_pending)
trial_to_eval = earliest_pending
else:
sampled_trial = self.state._sample_trial(
optimizer=self.optimizer,
worker_id=self.worker_id,
)
trial_to_eval = sampled_trial

_repeated_fail_get_next_trial_count = 0
except Exception as e:
_repeated_fail_get_next_trial_count += 1
@@ -439,11 +462,6 @@ def run(self) -> None: # noqa: C901, PLR0915, PLR0912

# If we can't set this working to evaluating, then just retry the loop
try:
trial_to_eval.set_evaluating(
time_started=time.time(),
worker_id=self.worker_id,
)
self.state.put_updated_trial(trial_to_eval)
n_failed_set_trial_state = 0
except VersionMismatchError:
n_failed_set_trial_state += 1
@@ -512,11 +530,12 @@ def run(self) -> None: # noqa: C901, PLR0915, PLR0912
# We do not retry this, as if some other worker has
# managed to manipulate this trial in the meantime,
# then something has gone wrong
self.state.report_trial_evaluation(
trial=evaluated_trial,
report=report,
worker_id=self.worker_id,
)
with self.state.lock_trials():
self.state._report_trial_evaluation(
trial=evaluated_trial,
report=report,
worker_id=self.worker_id,
)

logger.debug("Config %s: %s", evaluated_trial.id, evaluated_trial.config)
logger.debug("Loss %s: %s", evaluated_trial.id, report.loss)
@@ -553,8 +572,9 @@ def _launch_runtime( # noqa: PLR0913

for _retry_count in range(MAX_RETRIES_CREATE_LOAD_STATE):
try:
neps_state = create_or_load_filebased_neps_state(
directory=optimization_dir,
neps_state = NePSState.create_or_load(
path=optimization_dir,
load_only=False,
optimizer_info=OptimizerInfo(optimizer_info),
optimizer_state=OptimizationState(
budget=(
@@ -613,7 +633,17 @@
# it's not directly advertised as a parameter/env variable or otherwise.
import portalocker.portalocker as portalocker_lock_module

setattr(portalocker_lock_module, "LOCKER", LINUX_FILELOCK_FUNCTION)
try:
import fcntl

if LINUX_FILELOCK_FUNCTION.lower() == "flock":
setattr(portalocker_lock_module, "LOCKER", fcntl.flock)
elif LINUX_FILELOCK_FUNCTION.lower() == "lockf":
setattr(portalocker_lock_module, "LOCKER", fcntl.lockf)
else:
pass
except ImportError:
pass

worker = DefaultWorker.new(
state=neps_state,
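For readers unfamiliar with the two primitives the try/except above points portalocker at, the following standalone, POSIX-only sketch (not NePS code) shows them used directly. fcntl.lockf places POSIX record locks, while fcntl.flock places BSD-style whole-file locks; the two differ in subtle ways, for example over NFS.

```python
import fcntl

# Standalone illustration only: take and release an exclusive lock on a
# throwaway lock file with lockf, the primitive this commit favours.
with open("example.lock", "w") as lock_file:
    fcntl.lockf(lock_file, fcntl.LOCK_EX)  # blocking exclusive lock
    try:
        print("lock held; critical section goes here")
    finally:
        fcntl.lockf(lock_file, fcntl.LOCK_UN)  # release explicitly
    # fcntl.flock(lock_file, fcntl.LOCK_EX) would be the flock equivalent.
```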
12 changes: 0 additions & 12 deletions neps/state/__init__.py
@@ -1,23 +1,11 @@
from neps.state.optimizer import BudgetInfo, OptimizationState, OptimizerInfo
from neps.state.protocols import (
Locker,
ReaderWriter,
Synced,
VersionedResource,
Versioner,
)
from neps.state.seed_snapshot import SeedSnapshot
from neps.state.trial import Trial

__all__ = [
"Locker",
"SeedSnapshot",
"Synced",
"BudgetInfo",
"OptimizationState",
"OptimizerInfo",
"Trial",
"ReaderWriter",
"Versioner",
"VersionedResource",
]
(Diffs for the remaining changed files were not loaded.)
