clean up and sort imports and add early stopping
mcw92 committed Feb 6, 2024
1 parent 7a6a707 commit f4334ff
Showing 1 changed file with 57 additions and 29 deletions.
86 changes: 57 additions & 29 deletions tutorials/torch_example.py
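The headline change is the EarlyStopping callback wired into the Lightning Trainer below. For orientation, a minimal sketch of the pattern, assuming recent Lightning versions where the patience argument defaults to 3 epochs:

from lightning.pytorch import Trainer
from lightning.pytorch.callbacks.early_stopping import EarlyStopping

# Minimal sketch: stop training once the monitored metric stops improving.
early_stopping = EarlyStopping(
    monitor="val_loss",  # Metric key logged by the LightningModule
    mode="min",  # Lower is better for a loss
    patience=3,  # Epochs without improvement before stopping (assumed default)
)
trainer = Trainer(max_epochs=100, callbacks=[early_stopping])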
@@ -1,22 +1,19 @@
#!/usr/bin/env python3

import logging
import os

GitHub Actions / build annotation: Ruff (F401) check failure on line 2, tutorials/torch_example.py:2:8: `os` imported but unused
import random
from typing import Union, Dict, Tuple

import torch
from torch import nn
from torch.utils.data import DataLoader

from pytorch_lightning import LightningModule, Trainer
from lightning.pytorch import loggers
from lightning.pytorch import LightningModule, Trainer, loggers
from lightning.pytorch.callbacks.early_stopping import EarlyStopping
from torchmetrics import Accuracy

from torchvision.datasets import MNIST
from torchvision.transforms import Compose, ToTensor, Normalize

from mpi4py import MPI

from propulate import Islands
from propulate.utils import get_default_propagator
from propulate import Propulator
from propulate.utils import get_default_propagator, set_logger_config


GPUS_PER_NODE: int = 4
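Note on the import changes above: the standalone pytorch_lightning package is now published under the unified lightning namespace, so `from lightning.pytorch import ...` is the current spelling of the same API. A version-tolerant import could look like the following sketch, assuming one of the two package layouts is installed:

# Sketch: prefer the unified "lightning" namespace, fall back to the legacy
# standalone package name if only that one is installed.
try:
    from lightning.pytorch import LightningModule, Trainer, loggers
except ImportError:
    from pytorch_lightning import LightningModule, Trainer, loggers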
@@ -185,29 +182,43 @@ def get_data_loaders(batch_size: int) -> Tuple[DataLoader, DataLoader]:
validation dataloader
"""
data_transform = Compose([ToTensor(), Normalize((0.1307,), (0.3081,))])
num_workers = 2
# num_workers = int(int(os.getenv("SLURM_CPUS_ON_NODE")) / MPI.COMM_WORLD.size)
# Alternatively, use "SLURM_CPUS_PER_GPU".
# Only set if the --cpus-per-gpu option is specified.
print(f"Use {num_workers} workers in dataloader.")

if MPI.COMM_WORLD.Get_rank() == 0: # Only root downloads data.
if MPI.COMM_WORLD.rank == 0: # Only root downloads data.
train_loader = DataLoader(
dataset=MNIST(
download=True, root=".", transform=data_transform, train=True
), # Use MNIST training dataset.
batch_size=batch_size, # Batch size
num_workers=num_workers,
pin_memory=True,
persistent_workers=True,
shuffle=True, # Shuffle data.
)

MPI.COMM_WORLD.Barrier()
if MPI.COMM_WORLD.Get_rank() != 0:
MPI.COMM_WORLD.barrier()
if MPI.COMM_WORLD.rank != 0:
train_loader = DataLoader(
dataset=MNIST(
download=False, root=".", transform=data_transform, train=True
), # Use MNIST training dataset.
batch_size=batch_size, # Batch size
num_workers=num_workers,
pin_memory=True,
persistent_workers=True,
shuffle=True, # Shuffle data.
)
val_loader = DataLoader(
dataset=MNIST(
download=False, root=".", transform=data_transform, train=False
), # Use MNIST testing dataset.
num_workers=num_workers,
pin_memory=True,
persistent_workers=True,
batch_size=1, # Batch size
shuffle=False, # Do not shuffle data.
)
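The commented-out num_workers line above is also why CI flags `import os` as unused (Ruff F401). A sketch of how that variant could be enabled with a safe fallback, assuming SLURM's usual semantics where SLURM_CPUS_ON_NODE is the CPU count allocated on the node and all MPI ranks share that node:

import os

from mpi4py import MPI

# Sketch: derive dataloader workers from the SLURM allocation, falling back
# to a small constant outside of SLURM jobs.
cpus_on_node = os.getenv("SLURM_CPUS_ON_NODE")
if cpus_on_node is not None:
    num_workers = max(1, int(cpus_on_node) // MPI.COMM_WORLD.size)
else:
    num_workers = 2  # Fallback matching the hard-coded tutorial value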
@@ -220,19 +231,20 @@ def ind_loss(params: Dict[str, Union[int, float, str]]) -> float:
Parameters
----------
params: dict[str, int | float | str]]
params : dict[str, int | float | str]]
The hyperparameters to be optimized evolutionarily.
Returns
-------
float
The trained model's negative validation accuracy
The trained model's negative validation accuracy.
"""
# Extract hyperparameter combination to test from input dictionary.
conv_layers = params["conv_layers"] # Number of convolutional layers
activation = params["activation"] # Activation function
lr = params["lr"] # Learning rate

epochs = 2 # Number of epochs to train
epochs = 100

activations = {
"relu": nn.ReLU,
@@ -261,8 +273,9 @@ def ind_loss(params: Dict[str, Union[int, float, str]]) -> float:
trainer = Trainer(
max_epochs=epochs, # Stop training once this number of epochs is reached.
accelerator="gpu", # Pass accelerator type.
devices=[MPI.COMM_WORLD.Get_rank() % GPUS_PER_NODE], # Devices to train on
enable_progress_bar=True, # Disable progress bar.
devices=[MPI.COMM_WORLD.rank % GPUS_PER_NODE], # Devices to train on
callbacks=[EarlyStopping(monitor="val_loss", mode="min")],
enable_progress_bar=False, # Disable progress bar.
logger=tb_logger, # Logger
)
trainer.fit( # Run full model training optimization routine.
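The new EarlyStopping(monitor="val_loss", mode="min") callback also explains the epochs = 100 change above: max_epochs is now only an upper bound, and training ends once the monitored validation loss stops improving. For the callback to find that metric, the LightningModule must log the exact key "val_loss"; a sketch, assuming the tutorial's model defines forward and a loss function elsewhere:

from lightning.pytorch import LightningModule


class Net(LightningModule):
    # Only the hook relevant to early stopping is sketched here; forward,
    # loss_fn, and optimizer configuration are assumed to exist as in the
    # tutorial's full model definition.
    def validation_step(self, batch, batch_idx):
        x, y = batch
        loss = self.loss_fn(self(x), y)
        self.log("val_loss", loss)  # Key must match EarlyStopping's monitor
        return loss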
@@ -275,15 +288,16 @@ def ind_loss(params: Dict[str, Union[int, float, str]]) -> float:


if __name__ == "__main__":
num_generations = 3 # Number of generations
pop_size = 2 * MPI.COMM_WORLD.size # Breeding population size
comm = MPI.COMM_WORLD
num_generations = 10 # Number of generations
pop_size = 2 * comm.size # Breeding population size
limits = {
"conv_layers": (2, 10),
"activation": ("relu", "sigmoid", "tanh"),
"lr": (0.01, 0.0001),
} # Define search space.
rng = random.Random(
MPI.COMM_WORLD.rank
comm.rank
) # Set up separate random number generator for evolutionary optimizer.
propagator = get_default_propagator( # Get default evolutionary operator.
pop_size=pop_size, # Breeding population size
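The limits dictionary above encodes the search space through Propulate's tuple convention: an integer pair spans an integer range, a float pair a continuous interval (presumably order-independent, given that lr is listed as (0.01, 0.0001)), and a tuple of strings a categorical choice. A sketch of how a random configuration could be drawn from such limits; Propulate's own initialization may differ in detail:

import random

limits = {
    "conv_layers": (2, 10),
    "activation": ("relu", "sigmoid", "tanh"),
    "lr": (0.01, 0.0001),
}


def sample_config(limits, rng):
    # Interpret each tuple according to its element type.
    config = {}
    for name, bounds in limits.items():
        if isinstance(bounds[0], str):
            config[name] = rng.choice(bounds)  # Categorical parameter
        elif isinstance(bounds[0], int):
            config[name] = rng.randint(min(bounds), max(bounds))  # Integer range
        else:
            config[name] = rng.uniform(min(bounds), max(bounds))  # Continuous interval
    return config


print(sample_config(limits, random.Random(0)))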
@@ -293,16 +307,30 @@ def ind_loss(params: Dict[str, Union[int, float, str]]) -> float:
random_prob=0.1, # Random-initialization probability
rng=rng, # Random number generator for evolutionary optimizer
)
islands = Islands( # Set up island model.

# Set up separate logger for Propulate optimization.
set_logger_config(
level=logging.INFO, # logging level
log_file="./propulator.log", # logging path
log_to_stdout=True, # Print log on stdout.
log_rank=False, # Do not prepend MPI rank to logging messages.
colors=True, # Use colors.
)

# Set up propulator performing actual optimization.
propulator = Propulator(
loss_fn=ind_loss, # Loss function to optimize
propagator=propagator, # Evolutionary operator
rng=rng, # Random number generator
comm=comm, # Communicator
generations=num_generations, # Number of generations per worker
num_islands=1, # Number of islands
checkpoint_path=log_path,
checkpoint_path=log_path, # Path to save checkpoints to
rng=rng, # Random number generator
)

# Run optimization and print summary of results.
propulator.propulate(
logging_interval=1, debug=2 # Logging interval and verbosity level
)
islands.evolve( # Run evolutionary optimization.
top_n=1, # Print top-n best individuals on each island in summary.
logging_interval=1, # Logging interval
debug=2, # Verbosity level
propulator.summarize(
top_n=1, debug=2 # Print top-n best individuals on each island in summary.
)
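Putting the new pieces together, a minimal sketch of the same Propulator workflow on a toy quadratic loss; the keyword arguments mirror those used above, the remaining get_default_propagator probabilities are assumed to keep their defaults, and the script would be launched across ranks with MPI (for example, mpirun -n 4 python toy.py):

import random
from typing import Dict, Union

from mpi4py import MPI

from propulate import Propulator
from propulate.utils import get_default_propagator


def toy_loss(params: Dict[str, Union[int, float, str]]) -> float:
    # Sphere function; the optimum lies at (0, 0).
    return params["x"] ** 2 + params["y"] ** 2


comm = MPI.COMM_WORLD
limits = {"x": (-5.0, 5.0), "y": (-5.0, 5.0)}
rng = random.Random(comm.rank)  # Separate RNG stream per worker
propagator = get_default_propagator(pop_size=2 * comm.size, limits=limits, rng=rng)
propulator = Propulator(
    loss_fn=toy_loss,
    propagator=propagator,
    comm=comm,
    generations=10,
    checkpoint_path="./toy_checkpoints",
    rng=rng,
)
propulator.propulate(logging_interval=1, debug=2)
propulator.summarize(top_n=1, debug=2)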
