From f5fbd0b58ba4f42eac6492d5404e665b31f52f09 Mon Sep 17 00:00:00 2001 From: Oskar Taubert Date: Fri, 19 Apr 2024 13:58:51 +0200 Subject: [PATCH] fixed mnist downloading to temporary directory for surrogate testing --- tests/test_surrogate.py | 18 ++++++++++-------- 1 file changed, 10 insertions(+), 8 deletions(-) diff --git a/tests/test_surrogate.py b/tests/test_surrogate.py index da883a48..9f5b2fa1 100644 --- a/tests/test_surrogate.py +++ b/tests/test_surrogate.py @@ -1,6 +1,8 @@ import logging import os import random +from functools import partial +from pathlib import Path from typing import Dict, Generator, Tuple, Union import pytest @@ -134,7 +136,7 @@ def configure_optimizers(self) -> torch.optim.SGD: return torch.optim.SGD(self.parameters(), lr=self.lr) -def get_data_loaders(batch_size: int) -> Tuple[DataLoader, DataLoader]: +def get_data_loaders(batch_size: int, root=Path) -> Tuple[DataLoader, DataLoader]: """ Get MNIST train and validation dataloaders. @@ -162,7 +164,7 @@ def get_data_loaders(batch_size: int) -> Tuple[DataLoader, DataLoader]: if MPI.COMM_WORLD.Get_rank() == 0: # Only root downloads data. train_loader = DataLoader( dataset=MNIST( - download=True, root=".", transform=data_transform, train=True + download=True, root=root, transform=data_transform, train=True ), # Use MNIST training dataset. batch_size=batch_size, # Batch size shuffle=True, # Shuffle data. @@ -176,14 +178,14 @@ def get_data_loaders(batch_size: int) -> Tuple[DataLoader, DataLoader]: if MPI.COMM_WORLD.Get_rank() != 0: train_loader = DataLoader( dataset=MNIST( - download=False, root=".", transform=data_transform, train=True + download=False, root=root, transform=data_transform, train=True ), # Use MNIST training dataset. batch_size=batch_size, # Batch size shuffle=True, # Shuffle data. ) val_loader = DataLoader( dataset=MNIST( - download=False, root=".", transform=data_transform, train=False + download=False, root=root, transform=data_transform, train=False ), # Use MNIST testing dataset. batch_size=1, # Batch size shuffle=False, # Do not shuffle data. @@ -192,7 +194,7 @@ def get_data_loaders(batch_size: int) -> Tuple[DataLoader, DataLoader]: def ind_loss( - params: Dict[str, Union[int, float, str]], + params: Dict[str, Union[int, float, str]], root: Path ) -> Generator[float, None, None]: """ Loss function for evolutionary optimization with Propulate. Minimize the model's negative validation accuracy. @@ -243,7 +245,7 @@ def ind_loss( model.best_accuracy = 0.0 # Initialize the model's best validation accuracy. train_loader, val_loader = get_data_loaders( - batch_size=8 + batch_size=8, root=root ) # Get training and validation data loaders. # Configure optimizer. @@ -327,7 +329,7 @@ def test_mnist_static(mpi_tmp_path): rng=rng, # Random number generator for evolutionary optimizer ) islands = Islands( # Set up island model. - loss_fn=ind_loss, # Loss function to optimize + loss_fn=partial(ind_loss, root=mpi_tmp_path), # Loss function to optimize propagator=propagator, # Evolutionary operator rng=rng, # Random number generator generations=num_generations, # Number of generations per worker @@ -365,7 +367,7 @@ def test_mnist_dynamic(mpi_tmp_path): rng=rng, # Random number generator for evolutionary optimizer ) islands = Islands( # Set up island model. - loss_fn=ind_loss, # Loss function to optimize + loss_fn=partial(ind_loss, root=mpi_tmp_path), # Loss function to optimize propagator=propagator, # Evolutionary operator rng=rng, # Random number generator generations=num_generations, # Number of generations per worker