diff --git a/README.md b/README.md
index 5ef10ab..813a461 100644
--- a/README.md
+++ b/README.md
@@ -1,2 +1,16 @@
-# slune
-A super simplistic way to perform hyperparameter tuning on SLURM. Will submit a seperate job script for each run in the tuning and store metrics in a csv and then allow you to easily query for the best hyperparamters based on metric.
+# slune (slurm + tune!)
+A super simplistic way to perform hyperparameter tuning on a cluster using SLURM. Submits a separate job script for each run in the tuning, stores metrics in a CSV and then lets you easily query for the best hyperparameters based on a metric.
+
+Currently very much in the early stages, first things still to do:
+- Add the ability to read results; currently we can only submit jobs and log metrics during tuning.
+- Refine the class structure, i.e. subclassing and making sure classes have the essential methods. What are the essential methods and attributes?
+- Refine the package structure and sort out GitHub Actions, e.g. test coverage and running tests.
+- Add interfacing with SLURM to check for and re-submit failed jobs etc.
+- Add more tests and documentation.
+- Add some more subclasses for saving job results in different ways and for different tuning methods.
+That said, the idea for this package is to keep it ultra bare-bones and make it easy for users to modify and extend it to their liking.
+
+To run the tests use:
+```bash
+python -m unittest discover -s . -p 'test_*.py'
+```
\ No newline at end of file
diff --git a/setup.py b/setup.py
new file mode 100644
index 0000000..043dbe6
--- /dev/null
+++ b/setup.py
@@ -0,0 +1,12 @@
+from setuptools import setup, find_packages
+
+setup(
+    name='slune',
+    version='0.1',
+    packages=find_packages(),
+    install_requires=[
+        # argparse, subprocess and os are part of the standard library,
+        # so pandas is the only dependency we need to install
+        "pandas",
+    ],
+)
diff --git a/slune/__init__.py b/slune/__init__.py
new file mode 100644
index 0000000..a2397b5
--- /dev/null
+++ b/slune/__init__.py
@@ -0,0 +1 @@
+__all__ = ['slune', 'base', 'utils', 'loggers', 'savers', 'searchers']
\ No newline at end of file
diff --git a/slune/base.py b/slune/base.py
new file mode 100644
index 0000000..27d6fc5
--- /dev/null
+++ b/slune/base.py
@@ -0,0 +1,51 @@
+
+class Searcher():
+    """
+    Base class that wraps a search space and returns the arguments to pass to the sbatch script.
+    Args:
+        - searcher (object): Concrete searcher (eg. SearcherGrid) that generates the hyperparameter configurations.
+    """
+    def __init__(self, searcher):
+        self.searcher = searcher
+
+    def __len__(self):
+        """
+        Returns the number of hyperparameter configurations to try.
+        """
+        return len(self.searcher)
+
+    def next_tune(self, *args, **kwargs):
+        """
+        Returns the next hyperparameter configuration to try.
+        """
+        return self.searcher.next_tune(*args, **kwargs)
+
+class Slog():
+    """
+    Class used to log metrics during a tuning run and to save the results.
+    Args:
+        - Logger (object): Instance that handles logging of metrics, including the formatting that will be used to save and read the results.
+        - Saver (class): Class that handles saving logs from the Logger to storage and fetching the correct logs from storage for the Logger to read; it is instantiated with params.
+    """
+    def __init__(self, params, Logger, Saver):
+        self.logger = Logger
+        self.saver = Saver(params)
+
+    def log(self, *args, **kwargs):
+        """
+        Logs the metric/s for the current hyperparameter configuration,
+        stores them in a data frame that we can later save in storage.
+        """
+        self.logger.log(*args, **kwargs)
+
+    def save_collated(self, *args, **kwargs):
+        """
+        Saves the current results in the logger to storage.
+        """
+        self.saver.save_collated(self.logger.results, *args, **kwargs)
+
+    def read(self, *args, **kwargs):
+        """
+        Reads results from storage.
+        """
+        return self.saver.read(*args, **kwargs)
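For illustration, here is roughly how `Slog` and its components are meant to compose (a minimal sketch; the parameter names and values are made up, and note the asymmetry: `Slog` takes a `Logger` instance but a `Saver` class, which it instantiates with `params` itself):

```python
from slune.base import Slog
from slune.loggers import LoggerDefault
from slune.savers import SaverCsv

# Parameters for this run, in the "--name=value" form used throughout the package
params = ["--learning_rate=0.01", "--batch_size=32"]

# Logger is passed as an instance, Saver as a class that Slog instantiates with params
slog = Slog(params, Logger=LoggerDefault(), Saver=SaverCsv)

slog.log({"train_loss": 0.42})  # collect metrics in the logger's results dataframe
slog.save_collated()            # flush the collated results to a CSV on disk
```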
+ """ + return self.saver.read(args, kwargs) + + diff --git a/slune/loggers.py b/slune/loggers.py new file mode 100644 index 0000000..4629e1d --- /dev/null +++ b/slune/loggers.py @@ -0,0 +1,32 @@ +import pandas as pd + +class LoggerDefault(): + """ + Logs the metric/s for the current hyperparameter configuration, + stores them in a data frame that we can later save in storage. + """ + def __init__(self): + self.results = pd.DataFrame() + + def log(self, metrics): + """ + Logs the metric/s for the current hyperparameter configuration, + stores them in a data frame that we can later save in storage. + All metrics provided will be saved as a row in the results data frame, + the first column is always the time stamp at which log is called. + Args: + - metrics (dict): Dictionary of metrics to be logged, keys are metric names and values are metric values. + Each metric should only have one value! So please log as soon as you get a metric + """ + # Get current time stamp + time_stamp = pd.Timestamp.now() + # Add time stamp to metrics dictionary + metrics['time_stamp'] = time_stamp + # Convert metrics dictionary to a dataframe + metrics_df = pd.DataFrame(metrics, index=[0]) + # Append metrics dataframe to results dataframe + self.results = pd.concat([self.results, metrics_df], ignore_index=True) + + def read_log(self, args, kwargs): + # TODO: implement this function + raise NotImplementedError \ No newline at end of file diff --git a/slune/savers.py b/slune/savers.py new file mode 100644 index 0000000..0b8c9f9 --- /dev/null +++ b/slune/savers.py @@ -0,0 +1,87 @@ +import os +import pandas as pd +from slune.utils import find_directory_path + +class SaverCsv(): + """ + Saves the results of each run in a CSV file in a hierarchical directory structure based on argument names. + """ + def __init__(self, params, root_dir='./tuning_results'): + self.root_dir = root_dir + self.current_path = self.get_path(params) + + def strip_params(self, params): + """ + Strips the argument names from the arguments given by args. + eg. ["--argument_name=argument_value", ...] -> ["--argument_name=", ...] + Also gets rid of blank spaces + """ + return [p.split('=')[0].strip() for p in params] + + def get_match(self, params): + """ + Searches the root directory for a directory tree that matches the parameters given. + If only partial matches are found, returns the deepest matching directory with the missing parameters appended. + If no matches are found creates a path using the parameters. + """ + # First check if there is a directory with path matching some subset of the arguments + stripped_params = [p.split('=')[0].strip() +'=' for p in params] # Strip the params of whitespace and everything after the '=' + match = find_directory_path(stripped_params, root_directory=self.root_dir) + # Check which arguments are missing from the path + missing_params = [[p for p in params if sp in p][0] for sp in stripped_params if sp not in match] + # Now we add back in the values we stripped out + match = match.split('/') + match = [match[0]] + [[p for p in params if m in p][0] for m in match[1:]] + match = '/'.join(match) + # If there are missing arguments, add them to the path + if len(missing_params) > 0: + match = os.path.join(match, *missing_params) + return match + + def get_path(self, params): + """ + Creates a path using the parameters by checking existing directories in the root directory. 
diff --git a/slune/savers.py b/slune/savers.py
new file mode 100644
index 0000000..0b8c9f9
--- /dev/null
+++ b/slune/savers.py
@@ -0,0 +1,89 @@
+import os
+import pandas as pd
+from slune.utils import find_directory_path
+
+class SaverCsv():
+    """
+    Saves the results of each run in a CSV file in a hierarchical directory structure based on argument names.
+    """
+    def __init__(self, params, root_dir='./tuning_results'):
+        self.root_dir = root_dir
+        self.current_path = self.get_path(params)
+
+    def strip_params(self, params):
+        """
+        Strips the argument values from the arguments given by params.
+        eg. ["--argument_name=argument_value", ...] -> ["--argument_name=", ...]
+        Also gets rid of blank spaces.
+        """
+        return [p.split('=')[0].strip() + '=' for p in params]
+
+    def get_match(self, params):
+        """
+        Searches the root directory for a directory tree that matches the parameters given.
+        If only partial matches are found, returns the deepest matching directory with the missing parameters appended.
+        If no matches are found, creates a path using the parameters.
+        """
+        # First check if there is a directory with a path matching some subset of the arguments
+        stripped_params = self.strip_params(params)
+        match = find_directory_path(stripped_params, root_directory=self.root_dir)
+        # Check which arguments are missing from the path
+        missing_params = [[p for p in params if sp in p][0] for sp in stripped_params if sp not in match]
+        # Now we add back in the values we stripped out, leaving non-parameter components (eg. the root) untouched
+        match = match.split('/')
+        match = [[p for p in params if m in p][0] if m in stripped_params else m for m in match]
+        match = '/'.join(match)
+        # If there are missing arguments, add them to the path
+        if len(missing_params) > 0:
+            match = os.path.join(match, *missing_params)
+        return match
+
+    def get_path(self, params):
+        """
+        Creates a path using the parameters by checking existing directories in the root directory.
+        See get_match for how the directory path is built; we then check whether results files already exist
+        at that path, and if they do we increment the number used in the new results file name.
+        TODO: Add option to dictate the order of parameters in the directory structure.
+        TODO: Return warnings if there exist multiple paths that match the parameters but in a different order, or paths that don't go as deep as others.
+        Args:
+            - params (list): List of strings containing the arguments used, in the form ["--argument_name=argument_value", ...].
+        """
+        # Check if root directory exists, if not create it
+        if not os.path.exists(self.root_dir):
+            os.makedirs(self.root_dir)
+        # Get path of directory where we should store our csv of results
+        dir_path = self.get_match(params)
+        # If the directory doesn't exist yet, start numbering results files from 0
+        if not os.path.exists(dir_path):
+            csv_file_number = 0
+        # If it does exist, check if there is already a csv file with results,
+        # if there is, find the highest numbered results file and increment its number
+        else:
+            csv_files = [f for f in os.listdir(dir_path) if f.endswith('.csv')]
+            if len(csv_files) > 0:
+                # Check that every csv file follows the 'results_{number}.csv' naming scheme
+                if not all(f.startswith('results_') for f in csv_files):
+                    raise ValueError('Found csv file in directory that doesn\'t start with "results_"')
+                # Compare file numbers numerically, not lexicographically, so eg. results_10.csv > results_2.csv
+                last_csv_file = max(csv_files, key=lambda f: int(f.split('_')[1][:-4]))
+                csv_file_number = int(last_csv_file.split('_')[1][:-4]) + 1
+            else:
+                csv_file_number = 0
+        # Create path name for a new csv file where we can later store results
+        csv_file_path = os.path.join(dir_path, f'results_{csv_file_number}.csv')
+        return csv_file_path
+
+    def save_collated(self, results):
+        # We append results onto the end of the current results in the csv file if it already exists,
+        # if not then we create a new csv file and save the results there
+        if os.path.exists(self.current_path):
+            results = pd.concat([pd.read_csv(self.current_path), results])
+            results.to_csv(self.current_path, mode='w', index=False)
+        else:
+            # Make sure the directory we want to save to exists
+            os.makedirs(os.path.dirname(self.current_path), exist_ok=True)
+            results.to_csv(self.current_path, index=False)
+
+    def read(self, *args, **kwargs):
+        # TODO: implement this function
+        raise NotImplementedError
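To make the directory scheme concrete, here is roughly what `SaverCsv` does with a fresh root directory (the parameter names are made up; the default root `./tuning_results` is created if missing):

```python
from slune.savers import SaverCsv

saver = SaverCsv(["--learning_rate=0.01", "--batch_size=32"])
print(saver.current_path)
# ./tuning_results/--learning_rate=0.01/--batch_size=32/results_0.csv
```

Once results have been saved at that path, constructing another `SaverCsv` with the same parameters yields `results_1.csv` in the same directory, while a different value (eg. `--batch_size=64`) branches into a new leaf directory.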
+ """ + # Helper function to recursively generate combinations + def generate_combinations(param_names, current_combination, all_combinations): + if not param_names: + # If there are no more parameters to combine, add the current combination to the result + all_combinations.append(dict(current_combination)) + return + + param_name = param_names[0] + param_values = param_dict[param_name] + + for value in param_values: + current_combination[param_name] = value + # Recursively generate combinations for the remaining parameters + generate_combinations(param_names[1:], current_combination, all_combinations) + + # Start with an empty combination and generate all combinations + all_combinations = [] + generate_combinations(list(param_dict.keys()), {}, all_combinations) + + return all_combinations + + def next_tune(self): + """ + Returns the next hyperparameter configuration to try. + """ + # If this is the first call to next_tune, set grid_index to 0 + if self.grid_index is None: + self.grid_index = 0 + else: + self.grid_index += 1 + # If we have reached the end of the grid, raise an error + if self.grid_index == len(self.grid): + raise IndexError('Reached end of grid, no more hyperparameter configurations to try.') + # Return the next hyperparameter configuration to try + return dict_to_strings(self.grid[self.grid_index]) diff --git a/slune/slune.py b/slune/slune.py new file mode 100644 index 0000000..9f23bd1 --- /dev/null +++ b/slune/slune.py @@ -0,0 +1,39 @@ +from argparse import ArgumentParser +import subprocess + +def submit_job(sh_path, args): + """ + Submits a job using the Bash script at sh_path, + args is a list of strings containing the arguments to be passed to the Bash script. + """ + try: + # Run the Bash script using subprocess + command = ['bash', sh_path] + args + subprocess.run(['sbatch', command], check=True) + except subprocess.CalledProcessError as e: + print(f"Error running sbatch: {e}") + +def sbatchit(script_path, template_path, tuning, cargs=[]): + """ + Carries out hyper-parameter tuning by submitting a job for each set of hyper-parameters given by tune_control, + for each job runs the script stored at script_path with selected hyper-parameter values and the arguments given by cargs. + Uses the template file with path template_path to guide the creation of the sbatch script for each job. + Args: + - script_path (string): Path to the script (of the model) to be run for each job. + + - template_path (string): Path to the template file used to create the sbatch script for each job. + + - cargs (list): List of strings containing the arguments to be passed to the script for each job. + Must be a list even if there is just one argument, default is empty list. + + - tuning (Tuning): Tuning object used to select hyper-parameter values for each job. + """ + # Create sbatch script for each job + for i in range(len(tuning)): + # Get argument for this job + args = tuning.next_tune() + # Submit job + submit_job(template_path, [script_path] + cargs + args) + print("Submitted all jobs!") + +# TODO: add functions for reading results diff --git a/slune/utils.py b/slune/utils.py new file mode 100644 index 0000000..6a16fb1 --- /dev/null +++ b/slune/utils.py @@ -0,0 +1,49 @@ +import os + +def find_directory_path(strings, root_directory='.'): + """ + Searches the root directory for a path of directories that matches the strings given in any order. + If only a partial match is found, returns the deepest matching directory with the missing strings appended. 
diff --git a/slune/slune.py b/slune/slune.py
new file mode 100644
index 0000000..9f23bd1
--- /dev/null
+++ b/slune/slune.py
@@ -0,0 +1,38 @@
+import subprocess
+
+def submit_job(sh_path, args):
+    """
+    Submits a job using the sbatch script at sh_path,
+    args is a list of strings containing the arguments to be passed to the sbatch script.
+    """
+    try:
+        # Submit the job script with its arguments via sbatch
+        command = ['sbatch', sh_path] + args
+        subprocess.run(command, check=True)
+    except subprocess.CalledProcessError as e:
+        print(f"Error running sbatch: {e}")
+
+def sbatchit(script_path, template_path, tuning, cargs=[]):
+    """
+    Carries out hyperparameter tuning by submitting a job for each set of hyperparameters given by tuning;
+    each job runs the script stored at script_path with the selected hyperparameter values and the arguments given by cargs.
+    Uses the template file at template_path to guide the creation of the sbatch script for each job.
+    Args:
+        - script_path (string): Path to the script (of the model) to be run for each job.
+
+        - template_path (string): Path to the template file used to create the sbatch script for each job.
+
+        - tuning (Searcher): Searcher object used to select the hyperparameter values for each job.
+
+        - cargs (list): List of strings containing the arguments to be passed to the script for each job.
+          Must be a list even if there is just one argument, default is an empty list.
+    """
+    # Submit a job for each hyperparameter configuration
+    for i in range(len(tuning)):
+        # Get the arguments for this job
+        args = tuning.next_tune()
+        # Submit job
+        submit_job(template_path, [script_path] + cargs + args)
+    print("Submitted all jobs!")
+
+# TODO: add functions for reading results
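Putting it together, a tuning run might look like the following sketch (`train.py` and `--num_epochs` are placeholders, and this of course only works on a machine where `sbatch` is available):

```python
from slune.slune import sbatchit
from slune.searchers import SearcherGrid

searcher = SearcherGrid({"--learning_rate": [0.1, 0.01], "--batch_size": [32, 64]})

# Submits four jobs, one per configuration; each runs eg.
#   sbatch templates/cpu_template.sh train.py --num_epochs=10 --learning_rate=0.1 --batch_size=32
sbatchit('train.py', 'templates/cpu_template.sh', searcher, cargs=['--num_epochs=10'])
```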
diff --git a/slune/utils.py b/slune/utils.py
new file mode 100644
index 0000000..6a16fb1
--- /dev/null
+++ b/slune/utils.py
@@ -0,0 +1,49 @@
+import os
+
+def find_directory_path(strings, root_directory='.'):
+    """
+    Searches the root directory for a path of directories that matches the strings given, in any order.
+    If only a partial match is found, returns the deepest matching directory path.
+    If no matches are found, returns the root directory.
+    Args:
+        - strings (list): List of strings to be matched in any order. Each string in the list must be in the form '--string='.
+        - root_directory (string): Path to the root directory to be searched, default is the current working directory.
+    TODO: could probably optimize this function
+    """
+    # Get list of directories in root directory
+    dir_list = os.listdir(root_directory)
+    # Get the substring up to and including '=' for each directory name in dir_list, and strip whitespace
+    stripped_dir_list = [d.split('=')[0].strip() + "=" for d in dir_list]
+    # Get rid of duplicates
+    stripped_dir_list = list(set(stripped_dir_list))
+    # Check if any of the strings are in the list of directories
+    for string in strings:
+        if string in stripped_dir_list:
+            # If a string is found it means that at the current root there is a directory starting "--string=",
+            # we now want to find all directories in the root directory that start with "--string=" and search them recursively,
+            # then we return the path to the deepest directory found
+            dir_list = [d for d in dir_list if d.startswith(string)]
+            # Recursively search each directory starting with string
+            paths = []
+            for d in dir_list:
+                paths.append(find_directory_path(strings, os.path.join(root_directory, d)))
+            # Return the deepest directory found, ie. the path with the most '/'s
+            return max(paths, key=lambda x: x.count('/'))
+    # If no strings are found, return the path walked so far with the parameter values stripped out,
+    # only stripping components that match one of the given strings (so a root like './tuning_results' is left intact)
+    dirs = root_directory.split('/')
+    dirs = [d.split('=')[0].strip() + "=" if d.split('=')[0].strip() + "=" in strings else d for d in dirs]
+    return '/'.join(dirs)
+
+
+def dict_to_strings(d):
+    """
+    Converts a dictionary into a list of strings in the form '--key=value'.
+    """
+    out = []
+    for key, value in d.items():
+        if key.startswith('--'):
+            out.append('{}={}'.format(key, value))
+        else:
+            out.append('--{}={}'.format(key, value))
+    return out
\ No newline at end of file
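Two small usage sketches (the directory tree in the second example is assumed to exist):

```python
from slune.utils import dict_to_strings, find_directory_path

# dict_to_strings adds the '--' prefix only where it is missing
print(dict_to_strings({'learning_rate': 0.01, '--batch_size': 32}))
# ['--learning_rate=0.01', '--batch_size=32']

# find_directory_path returns matched components in stripped '--name=' form;
# given a tree tuning_results/--a=1/--b=2 this prints 'tuning_results/--a=/--b='
print(find_directory_path(['--a=', '--b='], root_directory='tuning_results'))
```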
diff --git a/templates/cpu_template.sh b/templates/cpu_template.sh
new file mode 100644
index 0000000..df9f649
--- /dev/null
+++ b/templates/cpu_template.sh
@@ -0,0 +1,39 @@
+#!/bin/bash
+#SBATCH --job-name=my_job_name        # Job name
+#SBATCH --output=my_job_output.log    # Output file (stdout)
+#SBATCH --error=my_job_error.log      # Error file (stderr)
+#SBATCH --partition=cpu               # Specify the partition/queue name
+#SBATCH --nodes=1                     # Number of nodes
+#SBATCH --ntasks=1                    # Number of tasks (cores)
+#SBATCH --cpus-per-task=1             # Number of CPU cores per task
+#SBATCH --mem=1G                      # Memory per node (in GB)
+#SBATCH --time=01:00:00               # Wall time (hh:mm:ss)
+#SBATCH --mail-user=your@email.com    # Email address for job notifications
+#SBATCH --mail-type=ALL               # Email notifications (BEGIN, END, FAIL)
+
+# Define executable
+export EXE=/bin/hostname
+
+# Optional: Load necessary modules or set environment variables
+# module load your_module
+# export YOUR_VARIABLE=value
+
+# Change to your working directory
+cd "${SLURM_SUBMIT_DIR}"
+
+# Execute code
+${EXE}
+
+# Print some useful stuff!
+echo JOB ID: ${SLURM_JOBID}
+echo Working Directory: $(pwd)
+echo Start Time: $(date)
+
+# Activate virtual environment (if you have one), change the path to match the location of your virtual environment
+source ../pyvenv/bin/activate
+
+# Run the training script: the first argument is the python script to run, the rest are passed to it
+python $1 ${@:2}
+
+# End of job script, let's print the time at which we finished
+echo End Time: $(date)
\ No newline at end of file
diff --git a/templates/gpu_template.sh b/templates/gpu_template.sh
new file mode 100644
index 0000000..be54506
--- /dev/null
+++ b/templates/gpu_template.sh
@@ -0,0 +1,45 @@
+#!/bin/bash
+#SBATCH --job-name=my_job_name        # Job name
+#SBATCH --output=my_job_output.log    # Output file (stdout)
+#SBATCH --error=my_job_error.log      # Error file (stderr)
+#SBATCH --partition=gpu               # Specify the partition/queue name
+#SBATCH --nodes=1                     # Number of nodes
+#SBATCH --ntasks=1                    # Number of tasks (cores)
+#SBATCH --cpus-per-task=1             # Number of CPU cores per task
+#SBATCH --gres=gpu:1                  # Define number of GPUs per node, can also define type of GPU eg. gpu:tesla, gpu:k80, gpu:p100, gpu:v100
+#SBATCH --mem-per-gpu=1G              # Define memory per GPU
+#SBATCH --time=01:00:00               # Wall time (hh:mm:ss)
+#SBATCH --mail-user=your@email.com    # Email address for job notifications
+#SBATCH --mail-type=ALL               # Email notifications (BEGIN, END, FAIL)
+
+# Define executable
+export EXE=/bin/hostname
+
+# Optional: Load necessary modules or set environment variables
+# module load your_module
+# export YOUR_VARIABLE=value
+
+# Change to your working directory
+cd "${SLURM_SUBMIT_DIR}"
+
+# Execute code
+${EXE}
+
+# Print some useful stuff!
+echo JOB ID: ${SLURM_JOBID}
+echo Working Directory: $(pwd)
+echo Start Time: $(date)
+
+# Print GPU information
+nvidia-smi --query-gpu=name --format=csv,noheader
+
+# Activate virtual environment (if you have one), change the path to match the location of your virtual environment
+source ../pyvenv/bin/activate
+
+# Where we run the script to perform the training run with the model,
+# the first argument to this job script will be the python script to run,
+# the rest of the arguments passed to the job script will be passed as arguments to the python script
+python $1 ${@:2}
+
+# End of job script, let's print the time at which we finished
+echo End Time: $(date)
\ No newline at end of file
diff --git a/tests/__init__.py b/tests/__init__.py
new file mode 100644
index 0000000..4eb03b6
--- /dev/null
+++ b/tests/__init__.py
@@ -0,0 +1 @@
+# Empty!
\ No newline at end of file
diff --git a/tests/test_loggers.py b/tests/test_loggers.py
new file mode 100644
index 0000000..9163006
--- /dev/null
+++ b/tests/test_loggers.py
@@ -0,0 +1,47 @@
+import unittest
+import pandas as pd
+from slune.loggers import LoggerDefault
+
+class TestLoggerDefault(unittest.TestCase):
+    def setUp(self):
+        self.logger = LoggerDefault()
+
+    def tearDown(self):
+        # Clean up any resources if needed
+        pass
+
+    def test_initial_dataframe_empty(self):
+        self.assertIsInstance(self.logger.results, pd.DataFrame)
+        self.assertTrue(self.logger.results.empty)
+
+    def test_log_method_adds_metrics(self):
+        metrics = {'metric1': 42, 'metric2': 99}
+        self.logger.log(metrics)
+
+        self.assertEqual(len(self.logger.results), 1)
+        # log adds a time_stamp column alongside the logged metrics
+        self.assertSetEqual(set(self.logger.results.columns), set(metrics.keys()) | {'time_stamp'})
+
+    def test_log_method_adds_timestamp(self):
+        metrics = {'metric1': 42}
+        self.logger.log(metrics)
+
+        self.assertEqual(len(self.logger.results), 1)
+        self.assertTrue('time_stamp' in self.logger.results.columns)
+
+    def test_log_method_adds_correct_values(self):
+        metrics = {'metric1': 42, 'metric2': 99}
+
+        # The time stamp is taken inside log, so it must fall between these two readings
+        before = pd.Timestamp.now()
+        self.logger.log(metrics)
+        after = pd.Timestamp.now()
+
+        row = self.logger.results.iloc[0]
+        self.assertEqual(row['metric1'], 42)
+        self.assertEqual(row['metric2'], 99)
+        self.assertTrue(before <= row['time_stamp'] <= after)
+
+
+if __name__ == '__main__':
+    unittest.main()
diff --git a/tests/test_savers.py b/tests/test_savers.py
new file mode 100644
index 0000000..d4653c3
--- /dev/null
+++ b/tests/test_savers.py
@@ -0,0 +1,107 @@
+import unittest
+import os
+import shutil
+import pandas as pd
+from slune.savers import SaverCsv
+
+class TestSaverCsv(unittest.TestCase):
+    def setUp(self):
+        # Create a temporary directory structure for testing
+        self.test_dir = 'test_directory'
+        os.makedirs(os.path.join(self.test_dir, '--folder1=0.1', '--folder2=0.2', '--folder3=0.3'))
+        os.makedirs(os.path.join(self.test_dir, '--folder1=0.1', '--folder2=0.22', '--folder3=0.33', '--folder4=0.4'))
+        os.makedirs(os.path.join(self.test_dir, '--folder1=0.1', '--folder2=0.2', 'another_folder'))
+        os.makedirs(os.path.join(self.test_dir, '--folder1=0.1', '--folder5=0.5', '--folder6=0.6'))
+        # Add a results file at --folder1=0.1/--folder2=0.2/--folder3=0.3
+        with open(os.path.join(self.test_dir, '--folder1=0.1', '--folder2=0.2', '--folder3=0.3', 'results_0.csv'), 'w') as f:
+            f.write('')
+
+    def tearDown(self):
+        # Remove the temporary directory and all of its contents after testing
+        shutil.rmtree(self.test_dir)
+
+    def test_get_match_full_match(self):
+        # Create a SaverCsv instance
+        saver = SaverCsv(["--folder3=0.3", "--folder2=0.2", "--folder1=0.1"], root_dir=self.test_dir)
+
+        # Test if get_match finds the correct match and builds the correct directory path using the parameters
+        matching_dir = saver.get_match(["--folder3=0.3", "--folder2=0.2", "--folder1=0.1"])
+        self.assertEqual(matching_dir, os.path.join(self.test_dir, "--folder1=0.1/--folder2=0.2/--folder3=0.3"))
+
+    def test_get_match_partial_match(self):
+        # Create a SaverCsv instance
+        saver = SaverCsv(["--folder2=0.2", "--folder1=0.1"], root_dir=self.test_dir)
+
+        # Test if get_match finds the correct match and builds the correct directory path using the parameters
+        matching_dir = saver.get_match(["--folder2=0.2", "--folder1=0.1"])
+        self.assertEqual(matching_dir, os.path.join(self.test_dir, "--folder1=0.1/--folder2=0.2"))
+
+    def test_get_match_different_values(self):
+        # Create a SaverCsv instance
+        saver = SaverCsv(["--folder2=2.2", "--folder1=1.1"], root_dir=self.test_dir)
+
+        # Test if get_match finds the correct match and builds the correct directory path using the parameters
+        matching_dir = saver.get_match(["--folder2=2.2", "--folder1=1.1"])
+        self.assertEqual(matching_dir, os.path.join(self.test_dir, "--folder1=1.1/--folder2=2.2"))
+
+    def test_get_match_too_deep(self):
+        # Create a SaverCsv instance
+        saver = SaverCsv(["--folder2=0.2", "--folder3=0.3"], root_dir=self.test_dir)
+
+        # Test if get_match finds the correct match and builds the correct directory path using the parameters
+        matching_dir = saver.get_match(["--folder2=0.2", "--folder3=0.3"])
+        self.assertEqual(matching_dir, os.path.join(self.test_dir, "--folder2=0.2/--folder3=0.3"))
+
+    def test_get_match_no_match(self):
+        # Create a SaverCsv instance
+        saver = SaverCsv(["--folder_not_there=0", "--folder_also_not_there=0.1"], root_dir=self.test_dir)
+
+        # Test if get_match finds the correct match and builds the correct directory path using the parameters
+        matching_dir = saver.get_match(["--folder_not_there=0", "--folder_also_not_there=0.1"])
+        self.assertEqual(matching_dir, os.path.join(self.test_dir, "--folder_not_there=0/--folder_also_not_there=0.1"))
+
+    def test_get_path_no_results(self):
+        # Create a SaverCsv instance
+        saver = SaverCsv(["--folder5=0.5", "--folder1=0.1", "--folder6=0.6"], root_dir=self.test_dir)
+
+        # Test if get_path gets the correct path
+        path = saver.current_path
+        self.assertEqual(path, os.path.join(self.test_dir, "--folder1=0.1/--folder5=0.5/--folder6=0.6/results_0.csv"))
+
+    def test_get_path_already_results(self):
+        # Create a SaverCsv instance
+        saver = SaverCsv(["--folder3=0.3", "--folder2=0.2", "--folder1=0.1"], root_dir=self.test_dir)
+
+        # Test if get_path gets the correct path
+        path = saver.current_path
+        self.assertEqual(path, os.path.join(self.test_dir, "--folder1=0.1/--folder2=0.2/--folder3=0.3/results_1.csv"))
+
+    def test_save_collated(self):
+        # Create a SaverCsv instance
+        saver = SaverCsv(["--folder3=0.3", "--folder2=0.2", "--folder1=0.1"], root_dir=self.test_dir)
+        # Create a data frame with some results
+        results = pd.DataFrame({'a': [1,2,3], 'b': [4,5,6]})
+        # Save the results
+        saver.save_collated(results)
+        # Check if the results were saved correctly
+        read_results = pd.read_csv(os.path.join(self.test_dir, "--folder1=0.1/--folder2=0.2/--folder3=0.3/results_1.csv"))
+        self.assertEqual(read_results.shape, (3,2))
+        self.assertEqual(results.columns.tolist(), read_results.columns.tolist())
+        self.assertEqual(results.values.tolist(), read_results.values.tolist())
+        # Create another data frame with more results, sharing only column 'a'
+        results = pd.DataFrame({'a': [7,8,9], 'd': [10,11,12]})
+        # Save the results
+        saver.save_collated(results)
+        # Check if the results were appended correctly, with missing values becoming NaN
+        read_results = pd.read_csv(os.path.join(self.test_dir, "--folder1=0.1/--folder2=0.2/--folder3=0.3/results_1.csv"))
+        results = pd.concat([pd.DataFrame({'a': [1,2,3], 'b': [4,5,6]}), results], ignore_index=True)
+        self.assertEqual(read_results.shape, (6,3))
+        self.assertEqual(results.columns.tolist(), read_results.columns.tolist())
+        # Compare ignoring the NaN padding introduced by the differing columns
+        read_values = [[j for j in i if str(j) != 'nan'] for i in read_results.values.tolist()]
+        values = [[j for j in i if str(j) != 'nan'] for i in results.values.tolist()]
+        self.assertEqual(read_values, values)
+
+
+if __name__ == "__main__":
+    unittest.main()
diff --git a/tests/test_searchers.py b/tests/test_searchers.py
new file mode 100644
index 0000000..485af35
--- /dev/null
+++ b/tests/test_searchers.py
@@ -0,0 +1,66 @@
+import unittest
+
+from slune.searchers import SearcherGrid
+
+class TestSearcherGrid(unittest.TestCase):
+
+    def test_get_grid(self):
+        # Test that get_grid returns the expected list of dictionaries
+
+        # Create an instance of SearcherGrid with sample hyperparameters
+        hyperparameters = {
+            "--param1": [1, 2],
+            "--param2": ["a", "b"]
+        }
+        searcher = SearcherGrid(hyperparameters)
+
+        # Get the grid of hyperparameters
+        grid = searcher.grid
+
+        # Check if the length of the grid is as expected
+        self.assertEqual(len(grid), 4)  # 2 values for param1 x 2 values for param2
+
+        # Check if the grid contains the expected dictionaries
+        expected_grid = [
+            {"--param1": 1, "--param2": "a"},
+            {"--param1": 1, "--param2": "b"},
+            {"--param1": 2, "--param2": "a"},
+            {"--param1": 2, "--param2": "b"}
+        ]
+        self.assertEqual(grid, expected_grid)
+
+    def test_next_tune(self):
+        # Test that next_tune returns the expected combinations of hyperparameters
+
+        # Create an instance of SearcherGrid with sample hyperparameters
+        hyperparameters = {
+            "--param1": [1, 2],
+            "--param2": ["a", "b"]
+        }
+        searcher = SearcherGrid(hyperparameters)
+
+        # Test the first few calls to next_tune
+        self.assertEqual(searcher.next_tune(), ["--param1=1", "--param2=a"])
+        self.assertEqual(searcher.next_tune(), ["--param1=1", "--param2=b"])
+        self.assertEqual(searcher.next_tune(), ["--param1=2", "--param2=a"])
+        self.assertEqual(searcher.next_tune(), ["--param1=2", "--param2=b"])
+
+        # Test that it raises IndexError when all combinations are exhausted
+        with self.assertRaises(IndexError):
+            searcher.next_tune()
+
+    def test__len__(self):
+        # Test that __len__ returns the expected number of hyperparameter combinations
+
+        # Create an instance of SearcherGrid with sample hyperparameters
+        hyperparameters = {
+            "--param1": [1, 2],
+            "--param2": ["a", "b"]
+        }
+        searcher = SearcherGrid(hyperparameters)
+
+        # Check that the length is as expected
+        self.assertEqual(len(searcher), 4)
+
+if __name__ == '__main__':
+    unittest.main()
diff --git a/tests/test_utils.py b/tests/test_utils.py
new file mode 100644
index 0000000..5896569
--- /dev/null
+++ b/tests/test_utils.py
@@ -0,0 +1,60 @@
+import unittest
+import os
+import shutil
+from slune.utils import find_directory_path, dict_to_strings
+
+class TestFindDirectoryPath(unittest.TestCase):
+
+    def setUp(self):
+        # Create a temporary directory structure for testing
+        self.test_dir = 'test_directory'
+        os.makedirs(os.path.join(self.test_dir, '--folder1=0.1', '--folder2=0.2', '--folder3=0.3'))
+        os.makedirs(os.path.join(self.test_dir, '--folder1=0.1', '--folder2=0.22', '--folder3=0.33', '--folder4=0.4'))
+        os.makedirs(os.path.join(self.test_dir, '--folder1=0.1', '--folder2=0.2', 'another_folder'))
+        os.makedirs(os.path.join(self.test_dir, '--folder1=0.1', '--folder5=0.5', '--folder6=0.6'))
+
+    def tearDown(self):
+        # Clean up the temporary directory structure
+        shutil.rmtree(self.test_dir)
+
+    def test_matching_path(self):
+        search_strings = ['--folder1=', '--folder2=', '--folder3=']
+        result = find_directory_path(search_strings, root_directory=self.test_dir)
+        self.assertEqual(result, os.path.join(self.test_dir, '--folder1=', '--folder2=', '--folder3='))
+
+    def test_matching_path_diff_order(self):
+        search_strings = ['--folder2=', '--folder3=', '--folder1=']
+        result = find_directory_path(search_strings, root_directory=self.test_dir)
+        self.assertEqual(result, os.path.join(self.test_dir, '--folder1=', '--folder2=', '--folder3='))
+
+    def test_partial_match(self):
+        search_strings = ['--folder1=', '--folder2=', '--missing_folder=']
+        result = find_directory_path(search_strings, root_directory=self.test_dir)
+        self.assertEqual(result, os.path.join(self.test_dir, '--folder1=', '--folder2='))
+
+    def test_partial_match_diff_order(self):
+        search_strings = ['--folder2=', '--missing_folder=', '--folder1=']
+        result = find_directory_path(search_strings, root_directory=self.test_dir)
+        self.assertEqual(result, os.path.join(self.test_dir, '--folder1=', '--folder2='))
+
+    def test_no_match(self):
+        search_strings = ['--nonexistent_folder1=', '--nonexistent_folder2=']
+        result = find_directory_path(search_strings, root_directory=self.test_dir)
+        self.assertEqual(result, self.test_dir)
+
+    def test_deepest(self):
+        search_strings = ['--folder1=', '--folder2=', '--folder3=', '--folder4=']
+        result = find_directory_path(search_strings, root_directory=self.test_dir)
+        self.assertEqual(result, os.path.join(self.test_dir, '--folder1=', '--folder2=', '--folder3=', '--folder4='))
+
+
+class TestDictToStrings(unittest.TestCase):
+
+    def test_dict_to_strings(self):
+        d = {'arg1': 1, 'arg2': 2}
+        result = dict_to_strings(d)
+        self.assertEqual(result, ['--arg1=1', '--arg2=2'])
+
+
+if __name__ == '__main__':
+    unittest.main()