Skip to content

Commit

Permalink
feat: Added support for PyTorch Lightning in the DDP backend. (#162)
Browse files Browse the repository at this point in the history
  • Loading branch information
gopaljigaur authored Jan 17, 2025
2 parents 5d8ce6d + f4de5c9 commit 031f151
Show file tree
Hide file tree
Showing 4 changed files with 187 additions and 0 deletions.
75 changes: 75 additions & 0 deletions neps/runtime.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,32 @@ def _default_worker_name() -> str:
return f"{os.getpid()}-{isoformat}"


# Name of the environment variable that carries the id of the trial being evaluated.
# `_set_ddp_env_var` writes it and `_launch_ddp_runtime` reads it back.
_DDP_ENV_VAR_NAME = "NEPS_DDP_TRIAL_ID"


def _is_ddp_and_not_rank_zero() -> bool:
    """Return whether this process is part of a DDP job with a non-zero rank.

    A DDP job is assumed when either the `torch.distributed` process group is
    initialized or all of `WORLD_SIZE`, `MASTER_ADDR` and `MASTER_PORT` are present
    in the environment. The rank is read from the first of `RANK`, `LOCAL_RANK`,
    `SLURM_PROCID` and `JSM_NAMESPACE_RANK` that is set; if none is set but the
    process group is initialized, the rank reported by `torch.distributed` is used.

    Returns:
        `True` if a DDP setup was detected and the rank is not 0, `False` otherwise.

    Raises:
        ValueError: If the rank environment variable does not hold an integer.
    """
    # Rank as reported by an initialized process group; `None` when torch is not
    # installed, distributed support is unavailable, or no group was initialized.
    dist_rank: int | None = None
    try:
        import torch.distributed as dist
    except ImportError:
        # Without torch only the environment variables can indicate a DDP setup.
        pass
    else:
        if dist.is_available() and dist.is_initialized():
            dist_rank = dist.get_rank()

    # Environment variables typically set by DDP launchers
    ddp_env_vars = ("WORLD_SIZE", "MASTER_ADDR", "MASTER_PORT")
    rank_env_vars = ("RANK", "LOCAL_RANK", "SLURM_PROCID", "JSM_NAMESPACE_RANK")

    if dist_rank is None and not all(var in os.environ for var in ddp_env_vars):
        return False

    # The environment takes precedence, the first rank variable that is set wins.
    for var in rank_env_vars:
        rank = os.environ.get(var)
        if rank is not None:
            return int(rank) != 0

    # No rank variable was exported: fall back to the process group, if any.
    return dist_rank is not None and dist_rank != 0


def _set_ddp_env_var(trial_id: str) -> None:
    """Publish the id of the current trial through the process environment.

    The id is stored in `os.environ` under `_DDP_ENV_VAR_NAME`, which is where
    `_launch_ddp_runtime` looks it up.
    """
    os.environ.update({_DDP_ENV_VAR_NAME: trial_id})


Loc = TypeVar("Loc")

# NOTE: As each NEPS process is only ever evaluating a single trial, this global can
Expand Down Expand Up @@ -119,6 +145,7 @@ def _set_global_trial(trial: Trial) -> Iterator[None]:
"\n\nThis is most likely a bug and should be reported to NePS!"
)
_CURRENTLY_RUNNING_TRIAL_IN_PROCESS = trial
_set_ddp_env_var(trial.id)
yield

# This is mostly for `tblogger`
Expand Down Expand Up @@ -608,6 +635,47 @@ def run(self) -> None: # noqa: C901, PLR0912, PLR0915
)


def _launch_ddp_runtime(
    *,
    evaluation_fn: Callable[..., float | Mapping[str, Any]],
    optimization_dir: Path,
) -> None:
    """Follow the trials being evaluated and run `evaluation_fn` on them.

    The NePS state in `optimization_dir` is loaded (never created) and polled for
    trials in the evaluating state. The first trial to run is identified by the id
    stored in the `_DDP_ENV_VAR_NAME` environment variable. Every later trial is
    the first evaluating trial that has the same evaluating worker id as the
    previously run trial but a different trial id. The return value of
    `evaluation_fn` is discarded.

    This function loops forever and only stops by raising or when the process is
    terminated.

    Args:
        evaluation_fn: Called with the config of each selected trial as keyword
            arguments.
        optimization_dir: Directory holding the NePS state to load.

    Raises:
        RuntimeError: If a trial is being evaluated before any trial was run here
            and the `_DDP_ENV_VAR_NAME` environment variable is not set.
    """
    import time

    # Back-off between polls so that an idle loop does not hammer the trial lock.
    poll_interval_seconds = 0.1

    neps_state = NePSState.create_or_load(
        path=optimization_dir,
        load_only=True,
    )
    prev_trial = None
    while True:
        current_eval_trials = neps_state.lock_and_get_current_evaluating_trials()
        current_trial = None
        if current_eval_trials:
            if prev_trial is None:
                # In the beginning, we simply read the current trial from the
                # environment variable
                current_id = os.environ.get(_DDP_ENV_VAR_NAME)
                if current_id is None:
                    raise RuntimeError(
                        "In a pytorch-lightning DDP setup, the environment variable"
                        f" '{_DDP_ENV_VAR_NAME}' was not set. This is probably a bug in"
                        " NePS and should be reported."
                    )
                current_trial = neps_state.lock_and_get_trial_by_id(current_id)
            else:
                # Only pick up a trial if it is evaluated by the same worker as the
                # previous one and is not the previous trial itself.
                for trial in current_eval_trials:  # type: ignore[unreachable]
                    if (
                        trial.metadata.evaluating_worker_id
                        == prev_trial.metadata.evaluating_worker_id
                    ) and (trial.id != prev_trial.id):
                        current_trial = trial
                        break

        if current_trial is None:
            # Nothing new to evaluate yet, wait a little before polling again.
            time.sleep(poll_interval_seconds)
            continue

        evaluation_fn(**current_trial.config)
        prev_trial = current_trial


# TODO: This should be done directly in `api.run` at some point to make it clearer at the
# entry point how the worker is set up to run if someone reads the entry point code.
def _launch_runtime( # noqa: PLR0913
Expand All @@ -626,6 +694,13 @@ def _launch_runtime( # noqa: PLR0913
max_evaluations_for_worker: int | None,
sample_batch_size: int | None,
) -> None:
if _is_ddp_and_not_rank_zero():
# Do not launch a new worker if we are in a DDP setup and not rank 0
_launch_ddp_runtime(
evaluation_fn=evaluation_fn, optimization_dir=optimization_dir
)
return

if overwrite_optimization_dir and optimization_dir.exists():
logger.info(
f"Overwriting optimization directory '{optimization_dir}' as"
Expand Down
10 changes: 10 additions & 0 deletions neps/state/neps_state.py
Original file line number Diff line number Diff line change
Expand Up @@ -530,6 +530,16 @@ def lock_and_get_next_pending_trial(
return pendings[0] if pendings else None
return pendings[:n]

def lock_and_get_current_evaluating_trials(self) -> list[Trial]:
    """Return the trials whose state is currently `EVALUATING`.

    The latest trials are read from the trial repo while the trial lock is held.
    """

    def _is_evaluating(candidate: Trial) -> bool:
        return candidate.metadata.state == Trial.State.EVALUATING

    with self._trial_lock.lock():
        latest = self._trial_repo.latest()
        return list(filter(_is_evaluating, latest.values()))

@classmethod
def create_or_load(
cls,
Expand Down
1 change: 1 addition & 0 deletions neps_examples/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
"multi_fidelity",
"multi_fidelity_and_expert_priors",
"pytorch_native_ddp",
"pytorch_lightning_ddp",
],
}

Expand Down
101 changes: 101 additions & 0 deletions neps_examples/efficiency/pytorch_lightning_ddp.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,101 @@
import lightning as L
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader, random_split
import neps
import logging

# Number of GPUs to use for DDP; passed as `devices` to the Lightning `Trainer`
# in `evaluate_pipeline`.
NUM_GPU = 8

class ToyModel(nn.Module):
    """Small two-layer network mapping 10 input features to 5 outputs.

    Taken from https://pytorch.org/tutorials/intermediate/ddp_tutorial.html
    """

    def __init__(self):
        super().__init__()
        self.net1 = nn.Linear(10, 10)
        self.relu = nn.ReLU()
        self.net2 = nn.Linear(10, 5)

    def forward(self, x):
        # Linear -> ReLU -> Linear
        hidden = self.net1(x)
        activated = self.relu(hidden)
        return self.net2(activated)

class LightningModel(L.LightningModule):
    """Lightning module that trains a `ToyModel` with an MSE loss and SGD."""

    def __init__(self, lr):
        super().__init__()
        self.lr = lr
        self.model = ToyModel()

    def _mse_step(self, batch, metric_name):
        """Compute the MSE loss of `batch`, log it as `metric_name` and return it."""
        inputs, targets = batch
        predictions = self.model(inputs)
        loss = F.mse_loss(predictions, targets)
        self.log(metric_name, loss, prog_bar=True, sync_dist=True)
        return loss

    def training_step(self, batch, batch_idx):
        return self._mse_step(batch, "train_loss")

    def validation_step(self, batch, batch_idx):
        return self._mse_step(batch, "val_loss")

    def test_step(self, batch, batch_idx):
        return self._mse_step(batch, "test_loss")

    def configure_optimizers(self):
        return torch.optim.SGD(self.parameters(), lr=self.lr)

def evaluate_pipeline(lr=0.1, epoch=20):
    """Train the toy model with Lightning DDP and return its validation loss.

    Args:
        lr: Learning rate handed to `LightningModel`.
        epoch: Maximum number of training epochs.

    Returns:
        The logged "val_loss" metric as a plain Python float.
    """
    L.seed_everything(42)
    # Model
    model = LightningModel(lr=lr)

    # Generate random tensors for data and labels
    data = torch.rand((1000, 10))
    labels = torch.rand((1000, 5))

    dataset = list(zip(data, labels))

    train_dataset, val_dataset, test_dataset = random_split(dataset, [600, 200, 200])

    # Define simple data loaders; only the training split is shuffled
    train_dataloader = DataLoader(train_dataset, batch_size=20, shuffle=True)
    val_dataloader = DataLoader(val_dataset, batch_size=20, shuffle=False)
    test_dataloader = DataLoader(test_dataset, batch_size=20, shuffle=False)

    # Trainer with DDP Strategy
    trainer = L.Trainer(
        gradient_clip_val=0.25,
        max_epochs=epoch,
        fast_dev_run=False,
        strategy="ddp",
        devices=NUM_GPU,
    )
    trainer.fit(model, train_dataloader, val_dataloader)
    # NOTE(review): `validate` presumably runs `validation_step` on the held-out
    # split, so the "val_loss" read below would come from that split -- confirm.
    trainer.validate(model, test_dataloader)

    # Hand the optimizer a plain float instead of the logged (tensor) metric.
    return float(trainer.logged_metrics["val_loss"])

# Search space: a log-scaled learning rate with a prior, and the number of epochs
# acting as the fidelity.
pipeline_space = {
    "lr": neps.Float(lower=0.001, upper=0.1, log=True, prior=0.01),
    "epoch": neps.Integer(lower=1, upper=3, is_fidelity=True),
}

logging.basicConfig(level=logging.INFO)
neps.run(
    evaluate_pipeline=evaluate_pipeline,
    pipeline_space=pipeline_space,
    root_directory="results/pytorch_lightning_ddp",
    max_evaluations_total=5,
)

0 comments on commit 031f151

Please sign in to comment.