feat: Added support for PyTorch Lightning in the DDP backend. #162

Merged · 7 commits · Jan 17, 2025
75 changes: 75 additions & 0 deletions neps/runtime.py
@@ -55,6 +55,32 @@ def _default_worker_name() -> str:
    return f"{os.getpid()}-{isoformat}"


_DDP_ENV_VAR_NAME = "NEPS_DDP_TRIAL_ID"


def _is_ddp_and_not_rank_zero() -> bool:
    import torch.distributed as dist

    # Check for environment variables typically set by DDP
    ddp_env_vars = ["WORLD_SIZE", "MASTER_ADDR", "MASTER_PORT"]
    rank_env_vars = ["RANK", "LOCAL_RANK", "SLURM_PROCID", "JSM_NAMESPACE_RANK"]

    # Check if PyTorch distributed is initialized
    if (dist.is_available() and dist.is_initialized()) or all(
        var in os.environ for var in ddp_env_vars
    ):
        for var in rank_env_vars:
            rank = os.environ.get(var)
            if rank is not None:
                return int(rank) != 0
    return False
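For intuition, here is a minimal sketch (not part of this diff; the values are illustrative) of what `_is_ddp_and_not_rank_zero` sees under a typical `torchrun` launch, which exports `WORLD_SIZE`, `MASTER_ADDR`, `MASTER_PORT`, and a per-process `RANK`/`LOCAL_RANK`:

```python
import os

# Simulate the environment torchrun would set up for a rank-1 process.
os.environ.update({
    "WORLD_SIZE": "2",
    "MASTER_ADDR": "127.0.0.1",
    "MASTER_PORT": "29500",
    "RANK": "1",  # not rank 0
})

ddp_env_vars = ["WORLD_SIZE", "MASTER_ADDR", "MASTER_PORT"]
rank_env_vars = ["RANK", "LOCAL_RANK", "SLURM_PROCID", "JSM_NAMESPACE_RANK"]

in_ddp = all(var in os.environ for var in ddp_env_vars)
rank = next((os.environ[v] for v in rank_env_vars if v in os.environ), None)
print(in_ddp, rank)  # True 1 -> this process would skip launching a worker
```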


def _set_ddp_env_var(trial_id: str) -> None:
    """Sets an environment variable with the current trial_id in a DDP setup."""
    os.environ[_DDP_ENV_VAR_NAME] = trial_id


Loc = TypeVar("Loc")

# NOTE: As each NEPS process is only ever evaluating a single trial, this global can
@@ -119,6 +145,7 @@ def _set_global_trial(trial: Trial) -> Iterator[None]:
        "\n\nThis is most likely a bug and should be reported to NePS!"
    )
    _CURRENTLY_RUNNING_TRIAL_IN_PROCESS = trial
    _set_ddp_env_var(trial.id)
    yield

    # This is mostly for `tblogger`
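Why an environment variable is a workable hand-off channel here: Lightning's `ddp` strategy re-launches the training script as subprocesses, and child processes inherit the parent's environment, so a variable set by rank 0 before evaluation starts is visible to every rank. A self-contained sketch (the `CHILD` flag and the trial id are made up for illustration):

```python
import os
import subprocess
import sys

if os.environ.get("CHILD") != "1":
    # What rank 0 effectively does before evaluation starts.
    os.environ["NEPS_DDP_TRIAL_ID"] = "demo-trial-1"
    subprocess.run([sys.executable, sys.argv[0]], env={**os.environ, "CHILD": "1"})
else:
    # What a freshly spawned non-zero rank observes.
    print(os.environ["NEPS_DDP_TRIAL_ID"])  # -> demo-trial-1
```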
@@ -608,6 +635,47 @@ def run(self) -> None:  # noqa: C901, PLR0912, PLR0915
)


def _launch_ddp_runtime(
    *,
    evaluation_fn: Callable[..., float | Mapping[str, Any]],
    optimization_dir: Path,
) -> None:
    neps_state = NePSState.create_or_load(
        path=optimization_dir,
        load_only=True,
    )
    prev_trial = None
    while True:
        current_eval_trials = neps_state.lock_and_get_current_evaluating_trials()
        # Only evaluate a trial here if it was claimed by the same worker that
        # evaluated the previous trial and it is a new trial.
        if len(current_eval_trials) > 0:
            current_trial = None
            if prev_trial is None:
                # In the beginning, we simply read the current trial id from
                # the environment variable.
                current_id = os.getenv(_DDP_ENV_VAR_NAME)
                if current_id is None:
                    raise RuntimeError(
                        "In a pytorch-lightning DDP setup, the environment variable"
                        f" '{_DDP_ENV_VAR_NAME}' was not set. This is probably a bug"
                        " in NePS and should be reported."
                    )
                current_trial = neps_state.lock_and_get_trial_by_id(current_id)
            else:
                for trial in current_eval_trials:  # type: ignore[unreachable]
                    if (
                        trial.metadata.evaluating_worker_id
                        == prev_trial.metadata.evaluating_worker_id
                    ) and (trial.id != prev_trial.id):
                        current_trial = trial
                        break
            if current_trial:
                evaluation_fn(**current_trial.config)
                prev_trial = current_trial
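The matching rule above can be exercised in isolation. A sketch with dummy data (`FakeTrial` is a hypothetical stand-in for NePS' `Trial`): after the first trial, a non-zero rank picks whichever `EVALUATING` trial was claimed by the same worker but carries a new id.

```python
from dataclasses import dataclass

@dataclass
class FakeTrial:  # hypothetical stand-in for neps' Trial/metadata
    id: str
    evaluating_worker_id: str

prev = FakeTrial("trial_1", "worker-A")
evaluating = [
    FakeTrial("trial_1", "worker-A"),  # still the old trial: skipped
    FakeTrial("trial_2", "worker-A"),  # same worker, new id: picked
    FakeTrial("trial_3", "worker-B"),  # different worker: ignored
]

nxt = next(
    (t for t in evaluating
     if t.evaluating_worker_id == prev.evaluating_worker_id and t.id != prev.id),
    None,
)
print(nxt.id)  # -> trial_2
```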


# TODO: This should be done directly in `api.run` at some point to make it clearer at an
# entry point how the worker is set up to run if someone reads the entry point code.
def _launch_runtime( # noqa: PLR0913
@@ -626,6 +694,13 @@
    max_evaluations_for_worker: int | None,
    sample_batch_size: int | None,
) -> None:
    if _is_ddp_and_not_rank_zero():
        # Do not launch a new worker if we are in a DDP setup and not rank 0
        _launch_ddp_runtime(
            evaluation_fn=evaluation_fn, optimization_dir=optimization_dir
        )
        return

    if overwrite_optimization_dir and optimization_dir.exists():
        logger.info(
            f"Overwriting optimization directory '{optimization_dir}' as"
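Net effect of the dispatch above: in a DDP launch only rank 0 ever owns optimization state (sampling new configs, recording results), while every other rank becomes a passive evaluator that replays rank 0's trials from a read-only view of the state. This keeps the on-disk NePS state single-writer even though all ranks execute `neps.run`.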
10 changes: 10 additions & 0 deletions neps/state/neps_state.py
@@ -530,6 +530,16 @@ def lock_and_get_next_pending_trial(
            return pendings[0] if pendings else None
        return pendings[:n]

    def lock_and_get_current_evaluating_trials(self) -> list[Trial]:
        """Get all trials currently in the `EVALUATING` state."""
        with self._trial_lock.lock():
            trials = self._trial_repo.latest()
            return [
                trial
                for trial in trials.values()
                if trial.metadata.state == Trial.State.EVALUATING
            ]

    @classmethod
    def create_or_load(
        cls,
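A usage sketch for the new query (the path below is hypothetical; the API calls are the ones visible in this diff): a read-only `NePSState` can list the trials currently being evaluated.

```python
from pathlib import Path

from neps.state.neps_state import NePSState

state = NePSState.create_or_load(
    path=Path("results/pytorch_lightning_ddp"),
    load_only=True,
)
for trial in state.lock_and_get_current_evaluating_trials():
    print(trial.id, trial.metadata.evaluating_worker_id)
```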
1 change: 1 addition & 0 deletions neps_examples/__init__.py
@@ -18,6 +18,7 @@
        "multi_fidelity",
        "multi_fidelity_and_expert_priors",
        "pytorch_native_ddp",
        "pytorch_lightning_ddp",
    ],
}

101 changes: 101 additions & 0 deletions neps_examples/efficiency/pytorch_lightning_ddp.py
@@ -0,0 +1,101 @@
import lightning as L
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader, random_split
import neps
import logging

NUM_GPU = 8  # Number of GPUs to use for DDP

class ToyModel(nn.Module):
    """Taken from https://pytorch.org/tutorials/intermediate/ddp_tutorial.html"""
    def __init__(self):
        super(ToyModel, self).__init__()
        self.net1 = nn.Linear(10, 10)
        self.relu = nn.ReLU()
        self.net2 = nn.Linear(10, 5)

    def forward(self, x):
        return self.net2(self.relu(self.net1(x)))

class LightningModel(L.LightningModule):
    def __init__(self, lr):
        super().__init__()
        self.lr = lr
        self.model = ToyModel()

    def training_step(self, batch, batch_idx):
        x, y = batch
        y_hat = self.model(x)
        loss = F.mse_loss(y_hat, y)
        self.log("train_loss", loss, prog_bar=True, sync_dist=True)
        return loss

    def validation_step(self, batch, batch_idx):
        x, y = batch
        y_hat = self.model(x)
        loss = F.mse_loss(y_hat, y)
        self.log("val_loss", loss, prog_bar=True, sync_dist=True)
        return loss

    def test_step(self, batch, batch_idx):
        x, y = batch
        y_hat = self.model(x)
        loss = F.mse_loss(y_hat, y)
        self.log("test_loss", loss, prog_bar=True, sync_dist=True)
        return loss

    def configure_optimizers(self):
        return torch.optim.SGD(self.parameters(), lr=self.lr)

def evaluate_pipeline(lr=0.1, epoch=20):
    L.seed_everything(42)
    # Model
    model = LightningModel(lr=lr)

    # Generate random tensors for data and labels
    data = torch.rand((1000, 10))
    labels = torch.rand((1000, 5))

    dataset = list(zip(data, labels))

    train_dataset, val_dataset, test_dataset = random_split(dataset, [600, 200, 200])

    # Simple data loaders over the random splits
    train_dataloader = DataLoader(train_dataset, batch_size=20, shuffle=True)
    val_dataloader = DataLoader(val_dataset, batch_size=20, shuffle=False)
    test_dataloader = DataLoader(test_dataset, batch_size=20, shuffle=False)

    # Trainer with DDP strategy
    trainer = L.Trainer(
        gradient_clip_val=0.25,
        max_epochs=epoch,
        fast_dev_run=False,
        strategy="ddp",
        devices=NUM_GPU,
    )
    trainer.fit(model, train_dataloader, val_dataloader)
    trainer.validate(model, test_dataloader)
    return trainer.logged_metrics["val_loss"]

pipeline_space = dict(
    lr=neps.Float(
        lower=0.001,
        upper=0.1,
        log=True,
        prior=0.01,
    ),
    epoch=neps.Integer(
        lower=1,
        upper=3,
        is_fidelity=True,
    ),
)

logging.basicConfig(level=logging.INFO)
neps.run(
    evaluate_pipeline=evaluate_pipeline,
    pipeline_space=pipeline_space,
    root_directory="results/pytorch_lightning_ddp",
    max_evaluations_total=5,
)
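Usage note: running this example as a plain script (`python pytorch_lightning_ddp.py`) should be enough. With `strategy="ddp"` and `devices=NUM_GPU`, Lightning re-launches the script once per device; each copy reaches `neps.run`, and the rank check added in this PR ensures only rank 0 acts as an optimization worker while the remaining ranks wait on `NEPS_DDP_TRIAL_ID`. Adjust `NUM_GPU` to the number of GPUs actually available on your machine.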