add comments and docstring about system specifications
mcw92 committed Mar 12, 2024
1 parent 464ed73 commit cb87967
Showing 1 changed file with 17 additions and 6 deletions.
23 changes: 17 additions & 6 deletions tutorials/torch_example.py
@@ -1,3 +1,10 @@
"""
Toy example for hyperparameter optimization / NAS in Propulate, using a simple convolutional network trained on the
MNIST dataset.
This script was tested on a single compute node with 4 GPUs. Note that you need to adapt ``GPUS_PER_NODE`` (see ll. 25).
and
"""
import logging
import random
from typing import Union, Dict, Tuple
@@ -15,8 +22,12 @@
from propulate.utils import get_default_propagator, set_logger_config


-GPUS_PER_NODE: int = 4
+GPUS_PER_NODE: int = 4  # This example script was tested on a single node with 4 GPUs.
+NUM_WORKERS: int = (
+    2  # Set this to the recommended number of workers in the PyTorch dataloader.
+)
log_path = "torch_ckpts"
+log = logging.getLogger(__name__)  # Get logger instance.


class Net(LightningModule):
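The new ``NUM_WORKERS`` constant is a per-system knob: 2 is only a placeholder for the dataloader's recommended worker count. On SLURM clusters, the commented-out alternative in a later hunk derives it from the node's CPU count instead; a minimal sketch of that calculation, assuming ``SLURM_CPUS_ON_NODE`` is exported and the node's CPUs are shared evenly among the MPI ranks:

import os

from mpi4py import MPI

# Sketch only: derive the dataloader worker count from the SLURM allocation,
# assuming ``SLURM_CPUS_ON_NODE`` is set and CPUs split evenly across ranks.
cpus_on_node = os.getenv("SLURM_CPUS_ON_NODE")
if cpus_on_node is not None:
    num_workers = max(1, int(cpus_on_node) // MPI.COMM_WORLD.size)
else:
    num_workers = 2  # Fall back to the tutorial's hard-coded default.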
@@ -39,11 +50,11 @@ def __init__(
        activation: torch.nn.modules.activation
            The activation function to use.
        lr: float
-            learning rate
+            The learning rate.
        loss_fn: torch.nn.modules.loss
            The loss function.
        """
-        super(Net, self).__init__()
+        super().__init__()

        self.lr = lr  # Set learning rate.
        self.loss_fn = loss_fn  # Set the loss function used for training the model.
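Only ``activation``, ``lr``, and ``loss_fn`` are visible in this hunk of the constructor's docstring; the rest of the signature is collapsed. A hypothetical instantiation using just these documented parameters (any further required arguments are not shown in this diff and are omitted here):

import torch

# Hypothetical sketch: constructor arguments beyond the three documented
# above are unknown from this diff, so only those three are passed.
model = Net(
    activation=torch.nn.ReLU,  # The activation function to use.
    lr=0.01,  # The learning rate.
    loss_fn=torch.nn.CrossEntropyLoss(),  # The loss function.
)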
@@ -180,11 +191,11 @@ def get_data_loaders(batch_size: int) -> Tuple[DataLoader, DataLoader]:
        The validation dataloader.
    """
    data_transform = Compose([ToTensor(), Normalize((0.1307,), (0.3081,))])
-    num_workers = 2
+    num_workers = NUM_WORKERS
    # num_workers = int(int(os.getenv("SLURM_CPUS_ON_NODE")) / MPI.COMM_WORLD.size)
    # Alternatively, use "SLURM_CPUS_PER_GPU".
    # Only set if the --cpus-per-gpu option is specified.
-    print(f"Use {num_workers} workers in dataloader.")
+    log.info(f"Use {num_workers} workers in dataloader.")

    if MPI.COMM_WORLD.rank == 0:  # Only root downloads data.
        train_loader = DataLoader(
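The remainder of the dataloader construction is collapsed above. A minimal sketch of the root-only download pattern these lines begin, assuming torchvision's ``MNIST`` dataset and a ``"data"`` root directory (both assumptions; the collapsed arguments are not shown):

from mpi4py import MPI
from torch.utils.data import DataLoader
from torchvision.datasets import MNIST

if MPI.COMM_WORLD.rank == 0:  # Only root downloads data.
    MNIST(root="data", train=True, download=True, transform=data_transform)
MPI.COMM_WORLD.barrier()  # Remaining ranks wait until root has the data.
train_dataset = MNIST(root="data", train=True, download=False, transform=data_transform)
train_loader = DataLoader(
    train_dataset, batch_size=batch_size, shuffle=True, num_workers=num_workers
)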
@@ -229,7 +240,7 @@ def ind_loss(params: Dict[str, Union[int, float, str]]) -> float:
    Parameters
    ----------
-    param s: dict[str, int | float | str]]
+    params : Dict[str, int | float | str]
        The hyperparameters to be optimized evolutionarily.

    Returns
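Per the corrected annotation, ``ind_loss`` maps one hyperparameter dictionary to a single float that Propulate minimizes. An illustrative call (the key names below are assumptions; only the signature ``Dict[str, int | float | str] -> float`` is given in the diff):

# Illustrative only: these hyperparameter names are assumptions, not the
# tutorial's exact search space.
example_params = {"activation": "relu", "lr": 0.01}
validation_loss = ind_loss(example_params)  # Lower is better.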
