Added basic functionality for submitting hyperparameter grid searches as sbatch jobs and for logging metrics to csv files stored in a hierarchical folder structure based on argument names
h-0-0 committed Sep 26, 2023
1 parent 42d49f2 commit cd5e554
Showing 16 changed files with 746 additions and 2 deletions.
18 changes: 16 additions & 2 deletions README.md
@@ -1,2 +1,16 @@
# slune
A super simplistic way to perform hyperparameter tuning on SLURM. Will submit a separate job script for each run in the tuning and store metrics in a csv and then allow you to easily query for the best hyperparameters based on metric.
# slune (slurm + tune!)
A super simplistic way to perform hyperparameter tuning on a cluster using SLURM. Submits a separate job script for each run in the tuning, stores metrics in a csv, and then lets you easily query for the best hyperparameters based on a metric.

Currently very much in the early stages; first things still to do:
- Add the ability to read results; currently it can only submit jobs and log metrics during tuning.
- Refine the class structure, i.e. subclassing, making sure classes have the essential methods, deciding what the essential methods and attributes are, etc.
- Refine the package structure and sort out GitHub actions such as test coverage, running tests, etc.
- Add interfacing with SLURM to check for and re-submit failed jobs, etc.
- Add more tests and documentation.
- Add some more subclasses for saving job results in different ways and for different tuning methods.

That said, the idea for this package is to keep it ultra bare-bones and make it easy for users to mod and extend it to their liking.

To run tests use:
```bash
python -m unittest discover -s . -p 'test_*.py'
```
13 changes: 13 additions & 0 deletions setup.py
@@ -0,0 +1,13 @@
from setuptools import setup, find_packages

setup(
    name='slune',
    version='0.1',
    packages=find_packages(),
    install_requires=[
        # argparse, subprocess and os ship with Python and must not be listed as pip requirements
        "pandas",
    ],
)
1 change: 1 addition & 0 deletions slune/__init__.py
@@ -0,0 +1 @@
__all__ = ['slune', 'base', 'utils', 'loggers', 'savers', 'searchers' ]
51 changes: 51 additions & 0 deletions slune/base.py
@@ -0,0 +1,51 @@

class Searcher():
    """
    Base class that creates the search space and returns arguments to pass to the sbatch script.
    """
    def __init__(self):
        pass

    def __len__(self):
        """
        Returns the number of hyperparameter configurations to try.
        """
        return len(self.searcher)

    def next_tune(self, *args, **kwargs):
        """
        Returns the next hyperparameter configuration to try.
        """
        return self.searcher.next_tune(*args, **kwargs)

class Slog():
    """
    Class used to log metrics during a tuning run and to save the results.
    Args:
        - params (list): Arguments of the current run, passed to Saver to determine where results are stored.
        - Logger (object): Instance that handles logging of metrics, including the formatting used to save and read the results.
        - Saver (class): Class that handles saving logs from Logger to storage and fetching the correct logs from storage for Logger to read.
    """
    def __init__(self, params, Logger, Saver):
        self.logger = Logger
        self.saver = Saver(params)

    def log(self, *args, **kwargs):
        """
        Logs the metric/s for the current hyperparameter configuration,
        storing them in a data frame that we can later save to storage.
        """
        self.logger.log(*args, **kwargs)

    def save_collated(self, *args, **kwargs):
        """
        Saves the current results in the logger to storage.
        """
        self.saver.save_collated(self.logger.results, *args, **kwargs)

    def read(self, *args, **kwargs):
        """
        Reads results from storage.
        """
        return self.saver.read(*args, **kwargs)


32 changes: 32 additions & 0 deletions slune/loggers.py
@@ -0,0 +1,32 @@
import pandas as pd

class LoggerDefault():
    """
    Logs the metric/s for the current hyperparameter configuration,
    storing them in a data frame that we can later save to storage.
    """
    def __init__(self):
        self.results = pd.DataFrame()

    def log(self, metrics):
        """
        Logs the metric/s for the current hyperparameter configuration,
        storing them in a data frame that we can later save to storage.
        All metrics provided are saved as a row in the results data frame;
        one column always holds the time stamp at which log is called.
        Args:
            - metrics (dict): Dictionary of metrics to be logged; keys are metric names and values are metric values.
                              Each metric should only have one value, so please log as soon as you get a metric.
        """
        # Get current time stamp
        time_stamp = pd.Timestamp.now()
        # Add time stamp to metrics dictionary
        metrics['time_stamp'] = time_stamp
        # Convert metrics dictionary to a dataframe
        metrics_df = pd.DataFrame(metrics, index=[0])
        # Append metrics dataframe to results dataframe
        self.results = pd.concat([self.results, metrics_df], ignore_index=True)

    def read_log(self, *args, **kwargs):
        # TODO: implement this function
        raise NotImplementedError
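The append pattern used by LoggerDefault can be reproduced standalone: each call turns the metrics dict into a one-row DataFrame with a time-stamp column and concatenates it onto the accumulated results. A minimal sketch of that loop, outside the packaged class:

```python
import pandas as pd

results = pd.DataFrame()
for metrics in ({"loss": 0.9}, {"loss": 0.5}):
    row = dict(metrics)
    row["time_stamp"] = pd.Timestamp.now()  # stamp the moment of logging
    # one row per log call, appended onto the accumulated results
    results = pd.concat([results, pd.DataFrame(row, index=[0])], ignore_index=True)
# results now holds two rows with columns 'loss' and 'time_stamp'
```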
87 changes: 87 additions & 0 deletions slune/savers.py
@@ -0,0 +1,87 @@
import os
import pandas as pd
from slune.utils import find_directory_path

class SaverCsv():
    """
    Saves the results of each run in a CSV file in a hierarchical directory structure based on argument names.
    """
    def __init__(self, params, root_dir='./tuning_results'):
        self.root_dir = root_dir
        self.current_path = self.get_path(params)

    def strip_params(self, params):
        """
        Strips the argument values from the arguments given by params,
        eg. ["--argument_name=argument_value", ...] -> ["--argument_name", ...].
        Also gets rid of blank spaces.
        """
        return [p.split('=')[0].strip() for p in params]

    def get_match(self, params):
        """
        Searches the root directory for a directory tree that matches the parameters given.
        If only partial matches are found, returns the deepest matching directory with the missing parameters appended.
        If no matches are found, creates a path from the parameters.
        """
        # First check if there is a directory whose path matches some subset of the arguments
        stripped_params = [p.split('=')[0].strip() + '=' for p in params]  # Strip the params of whitespace and everything after the '='
        match = find_directory_path(stripped_params, root_directory=self.root_dir)
        # Check which arguments are missing from the path
        missing_params = [[p for p in params if sp in p][0] for sp in stripped_params if sp not in match]
        # Now we add back in the values we stripped out
        match = match.split('/')
        match = [match[0]] + [[p for p in params if m in p][0] for m in match[1:]]
        match = '/'.join(match)
        # If there are missing arguments, add them to the path
        if len(missing_params) > 0:
            match = os.path.join(match, *missing_params)
        return match

    def get_path(self, params):
        """
        Creates a path using the parameters by checking existing directories in the root directory.
        See get_match for how the directory path is chosen; we then check whether results files already exist
        at this path, and if they do we increment the number used in the new results file name.
        TODO: Add option to dictate the order of parameters in the directory structure.
        TODO: Return warnings if there exist multiple paths that match the parameters but in a different order, or paths that don't go as deep as others.
        Args:
            - params (list): List of strings containing the arguments used, in the form ["--argument_name=argument_value", ...].
        """
        # Check if root directory exists, if not create it
        if not os.path.exists(self.root_dir):
            os.makedirs(self.root_dir)
        # Get path of directory where we should store our csv of results
        dir_path = self.get_match(params)
        # Check if directory exists, if not create it
        if not os.path.exists(dir_path):
            csv_file_number = 0
        # If it does exist, check if there is already a csv file with results;
        # if there is, find the highest-numbered csv file and increment the number
        else:
            csv_files = [f for f in os.listdir(dir_path) if f.endswith('.csv')]
            if len(csv_files) > 0:
                # Check that every csv file follows the "results_<number>.csv" naming scheme
                if not all(f.startswith('results_') for f in csv_files):
                    raise ValueError('Found csv file in directory that doesn\'t start with "results_"')
                # Compare the numbers as integers: a lexicographic max would rank "results_9" above "results_10"
                csv_file_number = max(int(f.split('_')[1][:-4]) for f in csv_files) + 1
            else:
                csv_file_number = 0
        # Create path name for a new csv file where we can later store results
        csv_file_path = os.path.join(dir_path, f'results_{csv_file_number}.csv')
        return csv_file_path

    def save_collated(self, results):
        # We append results onto the end of the current results in the csv file if it already exists;
        # if not, we create a new csv file and save the results there
        if os.path.exists(self.current_path):
            results = pd.concat([pd.read_csv(self.current_path), results])
            results.to_csv(self.current_path, mode='w', index=False)
        else:
            results.to_csv(self.current_path, index=False)

    def read(self, *args, **kwargs):
        # TODO: implement this function
        raise NotImplementedError
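The results-file numbering in get_path can be isolated into a small pure function for illustration (next_results_file is a hypothetical name, not part of slune). Note that the indices must be compared as integers, since a plain string max would rank "results_2" above "results_10":

```python
def next_results_file(existing):
    """Pick the next results_<n>.csv name given the files already in a directory."""
    # Collect the integer suffix of every file matching the results_<n>.csv scheme
    numbers = [int(f[len("results_"):-len(".csv")]) for f in existing
               if f.startswith("results_") and f.endswith(".csv")]
    # Numeric max avoids the lexicographic pitfall where "results_2" > "results_10"
    return f"results_{max(numbers) + 1 if numbers else 0}.csv"

name = next_results_file(["results_0.csv", "results_10.csv", "results_2.csv"])
# name == "results_11.csv"
```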

67 changes: 67 additions & 0 deletions slune/searchers.py
@@ -0,0 +1,67 @@
from slune.utils import dict_to_strings

class SearcherGrid():
    """
    Given a dictionary of hyperparameters and values to try, creates a grid of all possible hyperparameter configurations
    and returns them one by one with each call to next_tune.
    Args:
        - hyperparameters (dict): Dictionary of hyperparameters and values to try.
          The dictionary should have the structure: { "--argument_name" : [value_1, value_2, ...], ... }
    TODO: Add extra functionality by using nested dictionaries to specify which hyperparameters to try together.
    """
    def __init__(self, hyperparameters):
        self.hyperparameters = hyperparameters
        self.grid = self.get_grid(hyperparameters)
        self.grid_index = None

    def __len__(self):
        """
        Returns the number of hyperparameter configurations to try.
        """
        return len(self.grid)

    def get_grid(self, param_dict):
        """
        Generates all possible combinations of values for each argument in the given dictionary using recursion.
        Args:
            param_dict (dict): A dictionary where keys are argument names and values are lists of values.
        Returns:
            list: A list of dictionaries, each containing one combination of argument values.
        """
        # Helper function to recursively generate combinations
        def generate_combinations(param_names, current_combination, all_combinations):
            if not param_names:
                # If there are no more parameters to combine, add the current combination to the result
                all_combinations.append(dict(current_combination))
                return

            param_name = param_names[0]
            param_values = param_dict[param_name]

            for value in param_values:
                current_combination[param_name] = value
                # Recursively generate combinations for the remaining parameters
                generate_combinations(param_names[1:], current_combination, all_combinations)

        # Start with an empty combination and generate all combinations
        all_combinations = []
        generate_combinations(list(param_dict.keys()), {}, all_combinations)

        return all_combinations

    def next_tune(self):
        """
        Returns the next hyperparameter configuration to try.
        """
        # If this is the first call to next_tune, set grid_index to 0
        if self.grid_index is None:
            self.grid_index = 0
        else:
            self.grid_index += 1
        # If we have reached the end of the grid, raise an error
        if self.grid_index == len(self.grid):
            raise IndexError('Reached end of grid, no more hyperparameter configurations to try.')
        # Return the next hyperparameter configuration to try
        return dict_to_strings(self.grid[self.grid_index])
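The recursive grid construction above computes the Cartesian product of the value lists. A minimal self-contained equivalent using itertools.product (make_grid is an illustrative name, not part of slune):

```python
from itertools import product

def make_grid(param_dict):
    """One dict per combination: the Cartesian product of the value lists."""
    names = list(param_dict.keys())
    return [dict(zip(names, values))
            for values in product(*(param_dict[n] for n in names))]

grid = make_grid({"--lr": [0.1, 0.01], "--batch_size": [32, 64]})
# grid holds 4 configurations, e.g. {"--lr": 0.1, "--batch_size": 32}
```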
39 changes: 39 additions & 0 deletions slune/slune.py
@@ -0,0 +1,39 @@
import subprocess

def submit_job(sh_path, args):
    """
    Submits a job via sbatch using the script at sh_path;
    args is a list of strings containing the arguments to be passed to the script.
    """
    try:
        # Pass sbatch a single flat list: the script path followed by its arguments
        command = ['sbatch', sh_path] + args
        subprocess.run(command, check=True)
    except subprocess.CalledProcessError as e:
        print(f"Error running sbatch: {e}")

def sbatchit(script_path, template_path, tuning, cargs=[]):
    """
    Carries out hyperparameter tuning by submitting a job for each set of hyperparameters given by tuning;
    each job runs the script stored at script_path with the selected hyperparameter values and the arguments given by cargs.
    Uses the template file at template_path to guide the creation of the sbatch script for each job.
    Args:
        - script_path (string): Path to the script (of the model) to be run for each job.
        - template_path (string): Path to the template file used to create the sbatch script for each job.
        - tuning (Searcher): Searcher object used to select hyperparameter values for each job.
        - cargs (list): List of strings containing the arguments to be passed to the script for each job.
                        Must be a list even if there is just one argument; default is an empty list.
    """
    # Create and submit an sbatch job for each hyperparameter configuration
    for i in range(len(tuning)):
        # Get the arguments for this job
        args = tuning.next_tune()
        # Submit the job
        submit_job(template_path, [script_path] + cargs + args)
    print("Submitted all jobs!")

# TODO: add functions for reading results
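For clarity, the command submitted for each run is just the template script followed by the training script, any fixed arguments, and that run's hyperparameter strings. A small sketch of the list construction (function and file names are illustrative, not part of slune):

```python
def build_sbatch_command(template_path, script_path, cargs, tune_args):
    # sbatch receives the template script plus everything it forwards to the training script
    return ["sbatch", template_path, script_path] + cargs + tune_args

cmd = build_sbatch_command("template.sh", "train.py", ["--data=mnist"], ["--lr=0.1"])
# cmd == ["sbatch", "template.sh", "train.py", "--data=mnist", "--lr=0.1"]
```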
49 changes: 49 additions & 0 deletions slune/utils.py
@@ -0,0 +1,49 @@
import os

def find_directory_path(strings, root_directory='.'):
    """
    Searches the root directory for a path of directories that matches the strings given, in any order.
    If only a partial match is found, returns the deepest matching directory with the missing strings appended.
    If no matches are found, returns the strings as a path.
    Args:
        - strings (list): List of strings to be matched in any order. Each string in the list must be in the form '--string='.
        - root_directory (string): Path to the root directory to be searched; default is the current working directory.
    TODO: could probably optimize this function.
    """
    # Get list of directories in root directory
    dir_list = os.listdir(root_directory)
    # Get the substring up to and including '=' for each directory name in dir_list, and strip whitespace
    stripped_dir_list = [d.split('=')[0].strip() + "=" for d in dir_list]
    # Get rid of duplicates
    stripped_dir_list = list(set(stripped_dir_list))
    # Check if any of the strings are in the list of directories
    for string in strings:
        if string in stripped_dir_list:
            # A match means that at the current root there is a directory starting "--string=";
            # we now find all directories in the root directory that start with "--string=" and search them recursively,
            # then return the path to the deepest directory found
            dir_list = [d for d in dir_list if d.startswith(string)]
            # Recursively search each directory starting with string
            paths = []
            for d in dir_list:
                paths.append(find_directory_path(strings, os.path.join(root_directory, d)))
            # Return the deepest directory found, i.e. the path with the most '/'s
            return max(paths, key=lambda x: x.count('/'))
    # If no strings are found, return the root directory
    dirs = root_directory.split('/')
    dirs = [dirs[0]] + [d.split('=')[0].strip() + "=" for d in dirs[1:]]
    root_directory = '/'.join(dirs)
    return root_directory


def dict_to_strings(d):
    """
    Converts a dictionary into a list of strings in the form '--key=value'.
    """
    out = []
    for key, value in d.items():
        if key.startswith('--'):
            out.append('{}={}'.format(key, value))
        else:
            out.append('--{}={}'.format(key, value))
    return out
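dict_to_strings simply joins each key and value with '=' and prepends '--' where it is missing. A standalone re-implementation, for illustration only:

```python
def dict_to_strings(d):
    # '--key=value' for each item, adding the '--' prefix only when absent
    return [f"{k}={v}" if k.startswith("--") else f"--{k}={v}"
            for k, v in d.items()]

args = dict_to_strings({"--lr": 0.1, "batch_size": 32})
# args == ["--lr=0.1", "--batch_size=32"]
```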