Skip to content

Commit

Permalink
fixed mnist downloading to temporary directory for surrogate testing
Browse files Browse the repository at this point in the history
  • Loading branch information
Oskar Taubert committed Apr 19, 2024
1 parent e226908 commit f5fbd0b
Showing 1 changed file with 10 additions and 8 deletions.
18 changes: 10 additions & 8 deletions tests/test_surrogate.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
import logging
import os
import random
from functools import partial
from pathlib import Path
from typing import Dict, Generator, Tuple, Union

import pytest
Expand Down Expand Up @@ -134,7 +136,7 @@ def configure_optimizers(self) -> torch.optim.SGD:
return torch.optim.SGD(self.parameters(), lr=self.lr)


def get_data_loaders(batch_size: int) -> Tuple[DataLoader, DataLoader]:
def get_data_loaders(batch_size: int, root=Path) -> Tuple[DataLoader, DataLoader]:
"""
Get MNIST train and validation dataloaders.
Expand Down Expand Up @@ -162,7 +164,7 @@ def get_data_loaders(batch_size: int) -> Tuple[DataLoader, DataLoader]:
if MPI.COMM_WORLD.Get_rank() == 0: # Only root downloads data.
train_loader = DataLoader(
dataset=MNIST(
download=True, root=".", transform=data_transform, train=True
download=True, root=root, transform=data_transform, train=True
), # Use MNIST training dataset.
batch_size=batch_size, # Batch size
shuffle=True, # Shuffle data.
Expand All @@ -176,14 +178,14 @@ def get_data_loaders(batch_size: int) -> Tuple[DataLoader, DataLoader]:
if MPI.COMM_WORLD.Get_rank() != 0:
train_loader = DataLoader(
dataset=MNIST(
download=False, root=".", transform=data_transform, train=True
download=False, root=root, transform=data_transform, train=True
), # Use MNIST training dataset.
batch_size=batch_size, # Batch size
shuffle=True, # Shuffle data.
)
val_loader = DataLoader(
dataset=MNIST(
download=False, root=".", transform=data_transform, train=False
download=False, root=root, transform=data_transform, train=False
), # Use MNIST testing dataset.
batch_size=1, # Batch size
shuffle=False, # Do not shuffle data.
Expand All @@ -192,7 +194,7 @@ def get_data_loaders(batch_size: int) -> Tuple[DataLoader, DataLoader]:


def ind_loss(
params: Dict[str, Union[int, float, str]],
params: Dict[str, Union[int, float, str]], root: Path
) -> Generator[float, None, None]:
"""
Loss function for evolutionary optimization with Propulate. Minimize the model's negative validation accuracy.
Expand Down Expand Up @@ -243,7 +245,7 @@ def ind_loss(
model.best_accuracy = 0.0 # Initialize the model's best validation accuracy.

train_loader, val_loader = get_data_loaders(
batch_size=8
batch_size=8, root=root
) # Get training and validation data loaders.

# Configure optimizer.
Expand Down Expand Up @@ -327,7 +329,7 @@ def test_mnist_static(mpi_tmp_path):
rng=rng, # Random number generator for evolutionary optimizer
)
islands = Islands( # Set up island model.
loss_fn=ind_loss, # Loss function to optimize
loss_fn=partial(ind_loss, root=mpi_tmp_path), # Loss function to optimize
propagator=propagator, # Evolutionary operator
rng=rng, # Random number generator
generations=num_generations, # Number of generations per worker
Expand Down Expand Up @@ -365,7 +367,7 @@ def test_mnist_dynamic(mpi_tmp_path):
rng=rng, # Random number generator for evolutionary optimizer
)
islands = Islands( # Set up island model.
loss_fn=ind_loss, # Loss function to optimize
loss_fn=partial(ind_loss, root=mpi_tmp_path), # Loss function to optimize
propagator=propagator, # Evolutionary operator
rng=rng, # Random number generator
generations=num_generations, # Number of generations per worker
Expand Down

0 comments on commit f5fbd0b

Please sign in to comment.