diff --git a/tutorials/torch_example.py b/tutorials/torch_example.py
index bdc0cd4f..59148778 100755
--- a/tutorials/torch_example.py
+++ b/tutorials/torch_example.py
@@ -1,3 +1,10 @@
+"""
+Toy example for hyperparameter optimization / NAS in Propulate, using a simple convolutional network trained on the
+MNIST dataset.
+
+This script was tested on a single compute node with 4 GPUs. Note that you need to adapt ``GPUS_PER_NODE`` (see ll. 25)
+and ``NUM_WORKERS`` accordingly.
+"""
 import logging
 import random
 from typing import Union, Dict, Tuple
@@ -15,8 +22,12 @@ from propulate.utils import get_default_propagator, set_logger_config
 
-GPUS_PER_NODE: int = 4
+GPUS_PER_NODE: int = 4  # This example script was tested on a single node with 4 GPUs.
+NUM_WORKERS: int = (
+    2  # Set this to the recommended number of workers in the PyTorch dataloader.
+)
 
 
 log_path = "torch_ckpts"
+log = logging.getLogger(__name__)  # Get logger instance.
 
 
 class Net(LightningModule):
@@ -39,11 +50,11 @@ def __init__(
         activation: torch.nn.modules.activation
             The activation function to use.
         lr: float
-            learning rate
+            The learning rate.
         loss_fn: torch.nn.modules.loss
             The loss function.
         """
-        super(Net, self).__init__()
+        super().__init__()
         self.lr = lr  # Set learning rate.
         self.loss_fn = loss_fn  # Set the loss function used for training the model.
 
@@ -180,11 +191,11 @@ def get_data_loaders(batch_size: int) -> Tuple[DataLoader, DataLoader]:
         The validation dataloader.
     """
     data_transform = Compose([ToTensor(), Normalize((0.1307,), (0.3081,))])
-    num_workers = 2
+    num_workers = NUM_WORKERS
     # num_workers = int(int(os.getenv("SLURM_CPUS_ON_NODE")) / MPI.COMM_WORLD.size)
     # Alternatively, use "SLURM_CPUS_PER_GPU".
     # Only set if the --cpus-per-gpu option is specified.
-    print(f"Use {num_workers} workers in dataloader.")
+    log.info(f"Use {num_workers} workers in dataloader.")
 
     if MPI.COMM_WORLD.rank == 0:  # Only root downloads data.
         train_loader = DataLoader(
@@ -229,7 +240,7 @@ def ind_loss(params: Dict[str, Union[int, float, str]]) -> float:
 
     Parameters
    ----------
-    param s: dict[str, int | float | str]]
+    params : Dict[str, int | float | str]
        The hyperparameters to be optimized evolutionarily.
 
     Returns
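
Side note on the commented-out alternative in get_data_loaders: deriving the worker count from the SLURM allocation instead of hard-coding NUM_WORKERS could look roughly like the sketch below. This is a minimal illustration only; the helper name get_num_workers and the fallback behavior are assumptions, not part of the tutorial.

import os

from mpi4py import MPI

NUM_WORKERS: int = 2  # Tutorial default, assumed here as the fallback for a self-contained sketch.


def get_num_workers(comm: MPI.Comm = MPI.COMM_WORLD) -> int:
    """Split the node's CPUs evenly over the MPI ranks; fall back to NUM_WORKERS outside SLURM."""
    cpus_on_node = os.getenv("SLURM_CPUS_ON_NODE")  # Set inside a SLURM job.
    if cpus_on_node is None:
        return NUM_WORKERS
    return max(1, int(cpus_on_node) // comm.size)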