diff --git a/neps/runtime.py b/neps/runtime.py
index 92d3b824..7b2167e5 100644
--- a/neps/runtime.py
+++ b/neps/runtime.py
@@ -55,6 +55,32 @@ def _default_worker_name() -> str:
     return f"{os.getpid()}-{isoformat}"
 
 
+_DDP_ENV_VAR_NAME = "NEPS_DDP_TRIAL_ID"
+
+
+def _is_ddp_and_not_rank_zero() -> bool:
+    import torch.distributed as dist
+
+    # Check for environment variables typically set by DDP launchers
+    ddp_env_vars = ["WORLD_SIZE", "MASTER_ADDR", "MASTER_PORT"]
+    rank_env_vars = ["RANK", "LOCAL_RANK", "SLURM_PROCID", "JSM_NAMESPACE_RANK"]
+
+    # Check if PyTorch distributed is initialized, or if a DDP launch is
+    # detectable from the environment
+    if (dist.is_available() and dist.is_initialized()) or all(
+        var in os.environ for var in ddp_env_vars
+    ):
+        for var in rank_env_vars:
+            rank = os.environ.get(var)
+            if rank is not None:
+                return int(rank) != 0
+    return False
+
+
+def _set_ddp_env_var(trial_id: str) -> None:
+    """Set an environment variable with the current trial id in a DDP setup."""
+    os.environ[_DDP_ENV_VAR_NAME] = trial_id
+
+
 Loc = TypeVar("Loc")
 
 # NOTE: As each NEPS process is only ever evaluating a single trial, this global can
@@ -119,6 +145,7 @@ def _set_global_trial(trial: Trial) -> Iterator[None]:
             "\n\nThis is most likely a bug and should be reported to NePS!"
         )
     _CURRENTLY_RUNNING_TRIAL_IN_PROCESS = trial
+    _set_ddp_env_var(trial.id)
     yield
 
     # This is mostly for `tblogger`
@@ -608,6 +635,51 @@ def run(self) -> None:  # noqa: C901, PLR0912, PLR0915
             )
 
 
+def _launch_ddp_runtime(
+    *,
+    evaluation_fn: Callable[..., float | Mapping[str, Any]],
+    optimization_dir: Path,
+) -> None:
+    neps_state = NePSState.create_or_load(
+        path=optimization_dir,
+        load_only=True,
+    )
+    prev_trial = None
+    while True:
+        current_eval_trials = neps_state.lock_and_get_current_evaluating_trials()
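+        # Hand-off protocol: rank zero runs the normal NePS worker and
+        # publishes the id of the trial it is evaluating via NEPS_DDP_TRIAL_ID,
+        # which spawned ranks inherit. The non-zero ranks poll the shared state
+        # here and re-run `evaluation_fn` with the same config so that all
+        # ranks join the same DDP process group in the user's training code.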
+        # Only evaluate a trial if it was claimed by the same worker that
+        # evaluated the previous one.
+        if len(current_eval_trials) > 0:
+            current_trial = None
+            if prev_trial is None:
+                # On the first iteration, read the trial id published by the
+                # rank zero process.
+                current_id = os.environ.get(_DDP_ENV_VAR_NAME)
+                if current_id is None:
+                    raise RuntimeError(
+                        "In a pytorch-lightning DDP setup, the environment variable"
+                        f" '{_DDP_ENV_VAR_NAME}' was not set. This is probably a bug in"
+                        " NePS and should be reported."
+                    )
+                current_trial = neps_state.lock_and_get_trial_by_id(current_id)
+            else:
+                for trial in current_eval_trials:  # type: ignore[unreachable]
+                    if (
+                        trial.metadata.evaluating_worker_id
+                        == prev_trial.metadata.evaluating_worker_id
+                    ) and (trial.id != prev_trial.id):
+                        current_trial = trial
+                        break
+            if current_trial:
+                evaluation_fn(**current_trial.config)
+                prev_trial = current_trial
+
+
 # TODO: This should be done directly in `api.run` at some point to make it clearer at an
 # entry point how the worker is set up to run if someone reads the entry point code.
 def _launch_runtime(  # noqa: PLR0913
@@ -626,6 +698,13 @@ def _launch_runtime(  # noqa: PLR0913
     max_evaluations_for_worker: int | None,
     sample_batch_size: int | None,
 ) -> None:
+    if _is_ddp_and_not_rank_zero():
+        # Do not launch a new worker if we are in a DDP setup and not rank 0
+        _launch_ddp_runtime(
+            evaluation_fn=evaluation_fn, optimization_dir=optimization_dir
+        )
+        return
+
     if overwrite_optimization_dir and optimization_dir.exists():
         logger.info(
             f"Overwriting optimization directory '{optimization_dir}' as"
diff --git a/neps/state/neps_state.py b/neps/state/neps_state.py
index a14c0fc7..1ed8948f 100644
--- a/neps/state/neps_state.py
+++ b/neps/state/neps_state.py
@@ -530,6 +530,16 @@ def lock_and_get_next_pending_trial(
             return pendings[0] if pendings else None
         return pendings[:n]
 
+    def lock_and_get_current_evaluating_trials(self) -> list[Trial]:
+        """Get the trials that are currently being evaluated."""
+        with self._trial_lock.lock():
+            trials = self._trial_repo.latest()
+            return [
+                trial
+                for trial in trials.values()
+                if trial.metadata.state == Trial.State.EVALUATING
+            ]
+
     @classmethod
     def create_or_load(
         cls,
diff --git a/neps_examples/__init__.py b/neps_examples/__init__.py
index 6647aa39..f1c8f463 100644
--- a/neps_examples/__init__.py
+++ b/neps_examples/__init__.py
@@ -18,6 +18,7 @@
         "multi_fidelity",
         "multi_fidelity_and_expert_priors",
         "pytorch_native_ddp",
+        "pytorch_lightning_ddp",
     ],
 }
 
diff --git a/neps_examples/efficiency/pytorch_lightning_ddp.py b/neps_examples/efficiency/pytorch_lightning_ddp.py
new file mode 100644
index 00000000..4b387ed4
--- /dev/null
+++ b/neps_examples/efficiency/pytorch_lightning_ddp.py
@@ -0,0 +1,111 @@
+import lightning as L
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from torch.utils.data import DataLoader, random_split
+import neps
+import logging
+
+NUM_GPU = 8  # Number of GPUs to use for DDP
+
+class ToyModel(nn.Module):
+    """Taken from https://pytorch.org/tutorials/intermediate/ddp_tutorial.html"""
+    def __init__(self):
+        super(ToyModel, self).__init__()
+        self.net1 = nn.Linear(10, 10)
+        self.relu = nn.ReLU()
+        self.net2 = nn.Linear(10, 5)
+
+    def forward(self, x):
+        return self.net2(self.relu(self.net1(x)))
+
+class LightningModel(L.LightningModule):
+    def __init__(self, lr):
+        super().__init__()
+        self.lr = lr
+        self.model = ToyModel()
+
+    def training_step(self, batch, batch_idx):
+        x, y = batch
+        y_hat = self.model(x)
+        loss = F.mse_loss(y_hat, y)
+        self.log("train_loss", loss, prog_bar=True, sync_dist=True)
+        return loss
+
+    def validation_step(self, batch, batch_idx):
+        x, y = batch
+        y_hat = self.model(x)
+        loss = F.mse_loss(y_hat, y)
+        self.log("val_loss", loss, prog_bar=True, sync_dist=True)
+        return loss
+
+    def test_step(self, batch, batch_idx):
+        x, y = batch
+        y_hat = self.model(x)
+        loss = F.mse_loss(y_hat, y)
+        self.log("test_loss", loss, prog_bar=True, sync_dist=True)
+        return loss
+
+    def configure_optimizers(self):
+        return torch.optim.SGD(self.parameters(), lr=self.lr)
+
+def evaluate_pipeline(lr=0.1, epoch=20):
+    L.seed_everything(42)
+    # Model
+    model = LightningModel(lr=lr)
+
+    # Generate random tensors for data and labels
+    data = torch.rand((1000, 10))
+    labels = torch.rand((1000, 5))
+
+    dataset = list(zip(data, labels))
+
+    train_dataset, val_dataset, test_dataset = random_split(dataset, [600, 200, 200])
+
+    # Define simple data loaders
+    train_dataloader = DataLoader(train_dataset, batch_size=20, shuffle=True)
+    val_dataloader = DataLoader(val_dataset, batch_size=20, shuffle=False)
+    test_dataloader = DataLoader(test_dataset, batch_size=20, shuffle=False)
+
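+    # Note: when this script runs under NePS, Lightning's DDP strategy
+    # re-launches it for each additional device. NePS detects those non-zero
+    # ranks (_is_ddp_and_not_rank_zero) and routes them into
+    # _launch_ddp_runtime to wait for the trial id published by rank zero.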
+    # Trainer with DDP Strategy
+    trainer = L.Trainer(
+        gradient_clip_val=0.25,
+        max_epochs=epoch,
+        fast_dev_run=False,
+        strategy="ddp",
+        devices=NUM_GPU,
+    )
+    trainer.fit(model, train_dataloader, val_dataloader)
+    trainer.validate(model, test_dataloader)
+    return trainer.logged_metrics["val_loss"]
+
+pipeline_space = dict(
+    lr=neps.Float(
+        lower=0.001,
+        upper=0.1,
+        log=True,
+        prior=0.01,
+    ),
+    epoch=neps.Integer(
+        lower=1,
+        upper=3,
+        is_fidelity=True,
+    ),
+)
+
+logging.basicConfig(level=logging.INFO)
+neps.run(
+    evaluate_pipeline=evaluate_pipeline,
+    pipeline_space=pipeline_space,
+    root_directory="results/pytorch_lightning_ddp",
+    max_evaluations_total=5,
+)
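+
+# This example is meant to be launched on a machine with NUM_GPU GPUs, e.g.:
+#   python neps_examples/efficiency/pytorch_lightning_ddp.py
+# Every rank then evaluates the same sampled config; only the rank zero
+# worker samples new trials and reports results back to the NePS state.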