From f4334ff99206949a3832dd6b1fdbe27b60a79e5a Mon Sep 17 00:00:00 2001
From: Marie Weiel
Date: Tue, 6 Feb 2024 11:44:47 +0100
Subject: [PATCH] clean up and sort imports and add early stopping

---
 tutorials/torch_example.py | 86 +++++++++++++++++++++++++-------------
 1 file changed, 57 insertions(+), 29 deletions(-)

diff --git a/tutorials/torch_example.py b/tutorials/torch_example.py
index 6343d614..fb80f1ac 100755
--- a/tutorials/torch_example.py
+++ b/tutorials/torch_example.py
@@ -1,22 +1,19 @@
-#!/usr/bin/env python3
-
+import logging
+import os
 import random
 from typing import Union, Dict, Tuple
+
 import torch
 from torch import nn
 from torch.utils.data import DataLoader
-
-from pytorch_lightning import LightningModule, Trainer
-from lightning.pytorch import loggers
+from lightning.pytorch import LightningModule, Trainer, loggers
+from lightning.pytorch.callbacks.early_stopping import EarlyStopping
 from torchmetrics import Accuracy
-
 from torchvision.datasets import MNIST
 from torchvision.transforms import Compose, ToTensor, Normalize
-
 from mpi4py import MPI
-
-from propulate import Islands
-from propulate.utils import get_default_propagator
+from propulate import Propulator
+from propulate.utils import get_default_propagator, set_logger_config
 
 
 GPUS_PER_NODE: int = 4
@@ -185,29 +182,43 @@ def get_data_loaders(batch_size: int) -> Tuple[DataLoader, DataLoader]:
         validation dataloader
     """
     data_transform = Compose([ToTensor(), Normalize((0.1307,), (0.3081,))])
+    num_workers = 2
+    # num_workers = int(int(os.getenv("SLURM_CPUS_ON_NODE")) / MPI.COMM_WORLD.size)
+    # Alternatively, use "SLURM_CPUS_PER_GPU".
+    # Only set if the --cpus-per-gpu option is specified.
+    print(f"Use {num_workers} workers in dataloader.")
 
-    if MPI.COMM_WORLD.Get_rank() == 0:  # Only root downloads data.
+    if MPI.COMM_WORLD.rank == 0:  # Only root downloads data.
         train_loader = DataLoader(
             dataset=MNIST(
                 download=True, root=".", transform=data_transform, train=True
             ),  # Use MNIST training dataset.
             batch_size=batch_size,  # Batch size
+            num_workers=num_workers,
+            pin_memory=True,
+            persistent_workers=True,
             shuffle=True,  # Shuffle data.
         )
-    MPI.COMM_WORLD.Barrier()
-    if MPI.COMM_WORLD.Get_rank() != 0:
+    MPI.COMM_WORLD.barrier()
+    if MPI.COMM_WORLD.rank != 0:
         train_loader = DataLoader(
             dataset=MNIST(
                 download=False, root=".", transform=data_transform, train=True
             ),  # Use MNIST training dataset.
             batch_size=batch_size,  # Batch size
+            num_workers=num_workers,
+            pin_memory=True,
+            persistent_workers=True,
             shuffle=True,  # Shuffle data.
         )
     val_loader = DataLoader(
         dataset=MNIST(
             download=False, root=".", transform=data_transform, train=False
         ),  # Use MNIST testing dataset.
+        num_workers=num_workers,
+        pin_memory=True,
+        persistent_workers=True,
         batch_size=1,  # Batch size
         shuffle=False,  # Do not shuffle data.
     )
@@ -220,19 +231,20 @@ def ind_loss(params: Dict[str, Union[int, float, str]]) -> float:
 
     Parameters
     ----------
-    params: dict[str, int | float | str]]
+    params : dict[str, int | float | str]
+        The hyperparameters to be optimized evolutionarily.
 
     Returns
     -------
     float
-        The trained model's negative validation accuracy
+        The trained model's negative validation accuracy.
     """
     # Extract hyperparameter combination to test from input dictionary.
     conv_layers = params["conv_layers"]  # Number of convolutional layers
     activation = params["activation"]  # Activation function
     lr = params["lr"]  # Learning rate
-    epochs = 2  # Number of epochs to train
+    epochs = 100  # Maximum number of epochs to train
 
     activations = {
         "relu": nn.ReLU,
@@ -261,8 +273,9 @@ def ind_loss(params: Dict[str, Union[int, float, str]]) -> float:
     trainer = Trainer(
         max_epochs=epochs,  # Stop training once this number of epochs is reached.
         accelerator="gpu",  # Pass accelerator type.
-        devices=[MPI.COMM_WORLD.Get_rank() % GPUS_PER_NODE],  # Devices to train on
-        enable_progress_bar=True,  # Disable progress bar.
+        devices=[MPI.COMM_WORLD.rank % GPUS_PER_NODE],  # Devices to train on
+        callbacks=[EarlyStopping(monitor="val_loss", mode="min")],
+        enable_progress_bar=False,  # Disable progress bar.
         logger=tb_logger,  # Logger
     )
     trainer.fit(  # Run full model training optimization routine.
@@ -275,15 +288,16 @@ def ind_loss(params: Dict[str, Union[int, float, str]]) -> float:
 
 
 if __name__ == "__main__":
-    num_generations = 3  # Number of generations
-    pop_size = 2 * MPI.COMM_WORLD.size  # Breeding population size
+    comm = MPI.COMM_WORLD
+    num_generations = 10  # Number of generations
+    pop_size = 2 * comm.size  # Breeding population size
     limits = {
         "conv_layers": (2, 10),
         "activation": ("relu", "sigmoid", "tanh"),
         "lr": (0.01, 0.0001),
     }  # Define search space.
     rng = random.Random(
-        MPI.COMM_WORLD.rank
+        comm.rank
     )  # Set up separate random number generator for evolutionary optimizer.
     propagator = get_default_propagator(  # Get default evolutionary operator.
         pop_size=pop_size,  # Breeding population size
@@ -293,16 +307,30 @@ def ind_loss(params: Dict[str, Union[int, float, str]]) -> float:
         random_prob=0.1,  # Random-initialization probability
         rng=rng,  # Random number generator for evolutionary optimizer
     )
-    islands = Islands(  # Set up island model.
+
+    # Set up separate logger for Propulate optimization.
+    set_logger_config(
+        level=logging.INFO,  # Logging level
+        log_file="./propulator.log",  # Logging path
+        log_to_stdout=True,  # Print log on stdout.
+        log_rank=False,  # Do not prepend MPI rank to logging messages.
+        colors=True,  # Use colors.
+    )
+
+    # Set up propulator performing actual optimization.
+    propulator = Propulator(
         loss_fn=ind_loss,  # Loss function to optimize
         propagator=propagator,  # Evolutionary operator
-        rng=rng,  # Random number generator
+        comm=comm,  # Communicator
         generations=num_generations,  # Number of generations per worker
-        num_islands=1,  # Number of islands
-        checkpoint_path=log_path,
+        checkpoint_path=log_path,  # Path to save checkpoints to
+        rng=rng,  # Random number generator
+    )
+
+    # Run optimization and print summary of results.
+    propulator.propulate(
+        logging_interval=1, debug=2  # Logging interval and verbosity level
    )
-    islands.evolve(  # Run evolutionary optimization.
-        top_n=1,  # Print top-n best individuals on each island in summary.
-        logging_interval=1,  # Logging interval
-        debug=2,  # Verbosity level
+    propulator.summarize(
+        top_n=1, debug=2  # Print top-n best individuals in summary.
     )
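
Note on the early-stopping change: Lightning's EarlyStopping callback acts on a
logged metric, so the EarlyStopping(monitor="val_loss", mode="min") added above
only takes effect if the LightningModule logs "val_loss" via self.log during
validation, which the tutorial's model is assumed to do. The following minimal
sketch, a hypothetical toy module using the standard lightning.pytorch API and
not part of this patch, isolates that pattern:

    import torch
    from torch import nn
    from torch.utils.data import DataLoader, TensorDataset
    from lightning.pytorch import LightningModule, Trainer
    from lightning.pytorch.callbacks.early_stopping import EarlyStopping

    class ToyModel(LightningModule):  # Hypothetical stand-in for the tutorial's network.
        def __init__(self):
            super().__init__()
            self.layer = nn.Linear(8, 1)

        def training_step(self, batch, batch_idx):
            x, y = batch
            return nn.functional.mse_loss(self.layer(x), y)

        def validation_step(self, batch, batch_idx):
            x, y = batch
            # EarlyStopping reads exactly this logged key at the end of each validation epoch.
            self.log("val_loss", nn.functional.mse_loss(self.layer(x), y))

        def configure_optimizers(self):
            return torch.optim.Adam(self.parameters())

    data = TensorDataset(torch.randn(64, 8), torch.randn(64, 1))
    trainer = Trainer(
        max_epochs=100,  # Upper bound; early stopping usually ends training sooner.
        callbacks=[EarlyStopping(monitor="val_loss", mode="min", patience=3)],
        enable_progress_bar=False,
    )
    trainer.fit(ToyModel(), DataLoader(data, batch_size=16), DataLoader(data, batch_size=16))

With mode="min", training stops once "val_loss" has not improved for patience
consecutive validation epochs, which is what makes the raised epochs = 100
budget above an upper bound rather than a fixed training length.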