-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Added basic functionality for submitting hyperparameter grid search as…
… sbatch jobs and for logging metrics to csv files stored in a hierarchical folder structure based on argument names
- Loading branch information
Showing
16 changed files
with
746 additions
and
2 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,2 +1,16 @@ | ||
# slune | ||
A super simplistic way to perform hyperparameter tuning on SLURM. Will submit a seperate job script for each run in the tuning and store metrics in a csv and then allow you to easily query for the best hyperparamters based on metric. | ||
# slune (slurm + tune!) | ||
A super simplistic way to perform hyperparameter tuning on a cluster using SLURM. Will submit a separate job script for each run in the tuning and store metrics in a csv and then allow you to easily query for the best hyperparameters based on metric. | ||
|
||
Currently very much in early stages, first things still to do: | ||
- Add ability to read results, currently can only submit jobs and log metrics during tuning. | ||
- Refine class structure, ie. subclassing, making sure classes have essential methods, what are the essential methods and attributes? etc. | ||
- Refine package structure and sort out github actions like test coverage, running tests etc. | ||
- Add interfacing with SLURM to check for and re-submit failed jobs etc. | ||
- Add more tests and documentation. | ||
- Add some more subclasses for saving job results in different ways and for different tuning methods. | ||
Although the idea for this package is to keep it ultra bare-bones and make it easy for the user to mod and add things themselves to their liking. | ||
|
||
To run tests use: | ||
```bash | ||
python -m unittest discover -s . -p 'test_*.py' | ||
``` |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,13 @@ | ||
from setuptools import setup, find_packages | ||
|
||
# Package metadata for the ``slune`` distribution.
setup(
    name='slune',
    version='0.1',
    packages=find_packages(),
    # Only third-party distributions belong in install_requires.
    # ``argparse``, ``subprocess`` and ``os`` are part of the Python standard
    # library, so they must not be listed: pip would try to resolve them as
    # separate distributions instead of using the built-in modules.
    install_requires=[
        "pandas",
    ],
)
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1 @@ | ||
# Names exported by a wildcard import of this package (presumably slune/__init__.py - confirm).
__all__ = ['slune', 'base', 'utils', 'loggers', 'savers', 'searchers' ] |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,51 @@ | ||
|
||
class Searcher():
    """
    Class that creates search space and returns arguments to pass to sbatch script.
    Acts as a thin wrapper that delegates every call to a concrete search strategy.
    Args:
        - searcher (object): Object implementing the search strategy. It must support len() and
            provide a next_tune method. Defaults to None so that Searcher() can still be
            constructed without arguments, but a searcher has to be supplied (or assigned to
            self.searcher) before the other methods are used.
    """
    def __init__(self, searcher=None):
        # Store the wrapped strategy; __len__ and next_tune both read this attribute.
        self.searcher = searcher

    def __len__(self):
        """
        Returns the number of hyperparameter configurations to try.
        """
        return len(self.searcher)

    def next_tune(self, *args, **kwargs):
        """
        Returns the next hyperparameter configuration to try.
        Any positional and keyword arguments are forwarded unchanged to the wrapped searcher,
        so strategies whose next_tune takes no arguments can be wrapped as well.
        """
        return self.searcher.next_tune(*args, **kwargs)
|
||
class Slog():
    """
    Class used to log metrics during tuning run and to save the results.
    Args:
        - params (list): Parameters of the current run, passed on to Saver when it is constructed.
        - Logger (object): Class (or ready-made instance) that handles logging of metrics, including
            the formatting that will be used to save and read the results.
            Must provide a log method and a results attribute.
        - Saver (object): Class that handles saving logs from Logger to storage and fetching correct
            logs from storage to give to Logger to read. It is instantiated as Saver(params).
    """
    def __init__(self, params, Logger, Saver):
        # A logger class has to be instantiated before use, otherwise its methods are unbound and
        # attributes created in its __init__ (such as results) do not exist.
        # An already constructed logger is used as is.
        self.logger = Logger() if isinstance(Logger, type) else Logger
        self.saver = Saver(params)

    def log(self, *args, **kwargs):
        """
        Logs the metric/s for the current hyperparameter configuration,
        stores them in a data frame that we can later save in storage.
        All arguments are forwarded unchanged to the logger's log method.
        """
        self.logger.log(*args, **kwargs)

    def save_collated(self, *args, **kwargs):
        """
        Saves the current results in logger to storage.
        The logger's results are passed first, followed by any extra arguments given here.
        """
        self.saver.save_collated(self.logger.results, *args, **kwargs)

    def read(self, *args, **kwargs):
        """
        Reads results from storage.
        All arguments are forwarded unchanged to the saver's read method.
        """
        return self.saver.read(*args, **kwargs)
|
||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,32 @@ | ||
import pandas as pd | ||
|
||
class LoggerDefault():
    """
    Logs the metric/s for the current hyperparameter configuration,
    stores them in a data frame that we can later save in storage.
    """
    def __init__(self):
        # One row is added per call to log.
        self.results = pd.DataFrame()

    def log(self, metrics):
        """
        Logs the metric/s for the current hyperparameter configuration,
        stores them in a data frame that we can later save in storage.
        All metrics provided will be saved as a row in the results data frame,
        together with a 'time_stamp' entry holding the time at which log is called
        (added after the metrics, so it ends up as the last column for new metric names).
        Args:
            - metrics (dict): Dictionary of metrics to be logged, keys are metric names and values are metric values.
                Each metric should only have one value! So please log as soon as you get a metric.
                The dictionary itself is left untouched.
        """
        # Work on a copy so the caller's dictionary does not silently gain a 'time_stamp' key
        row = dict(metrics)
        # Add the current time stamp to the row
        row['time_stamp'] = pd.Timestamp.now()
        # Convert the row to a single-row dataframe
        metrics_df = pd.DataFrame(row, index=[0])
        # Append the row to the results dataframe
        self.results = pd.concat([self.results, metrics_df], ignore_index=True)

    def read_log(self, args, kwargs):
        """
        Placeholder for reading logged results, not implemented yet.
        Raises:
            - NotImplementedError: Always.
        """
        # TODO: implement this function
        raise NotImplementedError
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,87 @@ | ||
import os | ||
import pandas as pd | ||
from slune.utils import find_directory_path | ||
|
||
class SaverCsv():
    """
    Saves the results of each run in a CSV file in a hierarchical directory structure based on argument names.
    Args:
        - params (list): List of strings containing the arguments used, in form ["--argument_name=argument_value", ...].
        - root_dir (string): Directory under which the directory tree of results is stored.
    """
    def __init__(self, params, root_dir='./tuning_results'):
        self.root_dir = root_dir
        # Path of the csv file that save_collated will write to
        self.current_path = self.get_path(params)

    def strip_params(self, params):
        """
        Strips the argument values from the arguments given by params, keeping only the names.
        eg. ["--argument_name=argument_value", ...] -> ["--argument_name", ...]
        Also gets rid of blank spaces around the names.
        """
        return [p.split('=')[0].strip() for p in params]

    def get_match(self, params):
        """
        Searches the root directory for a directory tree that matches the parameters given.
        If only partial matches are found, returns the deepest matching directory with the missing parameters appended.
        If no matches are found creates a path using the parameters.
        Args:
            - params (list): List of strings in form ["--argument_name=argument_value", ...].
        Returns:
            - string: Path of the directory (below root_dir) where results for params belong.
        """
        # Reduce every parameter to its "--argument_name=" key and remember which full
        # parameter each key came from (the first one wins if a name is repeated)
        stripped_params = [p.split('=')[0].strip() + '=' for p in params]
        param_for_key = {}
        for key, param in zip(stripped_params, params):
            param_for_key.setdefault(key, param)
        # Check if there is a directory with path matching some subset of the arguments
        match = find_directory_path(stripped_params, root_directory=self.root_dir)
        # Only keep the components of the match that correspond to one of our parameters.
        # Everything else belongs to root_dir, which may itself consist of several
        # components (e.g. './tuning_results') and must not be looked up in params.
        matched_keys = []
        for part in match.split('/'):
            key = part.split('=')[0].strip() + '='
            if key in param_for_key:
                matched_keys.append(key)
        # Parameters that are not part of the existing tree are appended in the order given
        missing_keys = [key for key in param_for_key if key not in matched_keys]
        # Add back in the values we stripped out and anchor the result at root_dir
        return os.path.join(self.root_dir, *[param_for_key[key] for key in matched_keys + missing_keys])

    def get_path(self, params):
        """
        Creates a path using the parameters by checking existing directories in the root directory.
        Check get_match for how we create the path, we then check if results files for this path already exist,
        if they do we increment the number of the results file name that we will use.
        TODO: Add option to dictate order of parameters in directory structure.
        TODO: Return warnings if there exist multiple paths that match the parameters but in a different order, or paths that don't go as deep as others.
        Args:
            - params (list): List of strings containing the arguments used, in form ["--argument_name=argument_value", ...].
        Returns:
            - string: Path of a new csv file, 'results_<number>.csv', inside the matching directory.
        Raises:
            - ValueError: If the directory contains a csv file that is not named 'results_<number>.csv'.
        """
        # The root directory has to exist before it can be searched
        os.makedirs(self.root_dir, exist_ok=True)
        # Get path of directory where we should store our csv of results
        dir_path = self.get_match(params)
        # Create path name for a new csv file where we can later store results
        csv_file_number = self._next_csv_number(dir_path)
        return os.path.join(dir_path, f'results_{csv_file_number}.csv')

    @staticmethod
    def _next_csv_number(dir_path):
        """
        Returns the number to use for the next results file in dir_path, 0 if there are none yet.
        """
        if not os.path.isdir(dir_path):
            return 0
        numbers = []
        for name in os.listdir(dir_path):
            if not name.endswith('.csv'):
                continue
            if not name.startswith('results_'):
                raise ValueError('Found csv file in directory that doesn\'t start with "results_"')
            # Compare the numbers themselves, comparing file names as strings would
            # rank 'results_9.csv' above 'results_10.csv'
            numbers.append(int(name[len('results_'):-len('.csv')]))
        return max(numbers) + 1 if numbers else 0

    def save_collated(self, results):
        """
        Saves results to the csv file at self.current_path.
        Args:
            - results (DataFrame): Results to save, added onto the end of the file's current contents if it already exists.
        """
        # get_path only picks the location, so the directory may not exist yet
        target_dir = os.path.dirname(self.current_path)
        if target_dir:
            os.makedirs(target_dir, exist_ok=True)
        # We add results onto the end of the current results in the csv file if it already exists
        if os.path.exists(self.current_path):
            results = pd.concat([pd.read_csv(self.current_path), results])
        results.to_csv(self.current_path, index=False)

    def read(self, args, kwargs):
        """
        Placeholder for reading results from storage, not implemented yet.
        Raises:
            - NotImplementedError: Always.
        """
        # TODO: implement this function
        raise NotImplementedError
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,67 @@ | ||
from slune.utils import dict_to_strings | ||
|
||
class SearcherGrid():
    """
    Given dictionary of hyperparameters and values to try, creates grid of all possible hyperparameter configurations,
    and returns them one by one for each call to next_tune.
    Args:
        - hyperparameters (dict): Dictionary of hyperparameters and values to try.
            Structure of dictionary should be: { "--argument_name" : [Value_1, Value_2, ...], ... }
    TODO: Add extra functionality by using nested dictionaries to specify which hyperparameters to try together.
    """
    def __init__(self, hyperparameters):
        self.hyperparameters = hyperparameters
        self.grid = self.get_grid(hyperparameters)
        # Index of the configuration returned last, None until next_tune has been called
        self.grid_index = None

    def __len__(self):
        """
        Returns the number of hyperparameter configurations to try.
        """
        return len(self.grid)

    def get_grid(self, param_dict):
        """
        Generate all possible combinations of values for each argument in the given dictionary.
        The first argument varies slowest and the last argument varies fastest.
        Args:
            - param_dict (dict): A dictionary where keys are argument names and values are lists of values.
        Returns:
            - list: A list of dictionaries, each containing one combination of argument values.
        """
        # Imported here so the function does not depend on a module level import
        from itertools import product

        param_names = list(param_dict)
        value_lists = [param_dict[name] for name in param_names]
        # product yields one tuple per combination, in the same order as nested loops over value_lists
        return [dict(zip(param_names, values)) for values in product(*value_lists)]

    def next_tune(self):
        """
        Returns the next hyperparameter configuration to try.
        Returns:
            - list: The configuration as a list of strings, as produced by dict_to_strings.
        Raises:
            - IndexError: If every configuration in the grid has already been returned.
        """
        # The first call starts at 0, later calls move on by one
        next_index = 0 if self.grid_index is None else self.grid_index + 1
        # If we have reached the end of the grid, raise an error.
        # The index is only advanced on success so repeated calls keep raising the same error.
        if next_index >= len(self.grid):
            raise IndexError('Reached end of grid, no more hyperparameter configurations to try.')
        self.grid_index = next_index
        # Return the next hyperparameter configuration to try
        return dict_to_strings(self.grid[next_index])
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,39 @@ | ||
from argparse import ArgumentParser | ||
import subprocess | ||
|
||
def submit_job(sh_path, args):
    """
    Submits a job to SLURM by calling sbatch on the Bash script at sh_path.
    Args:
        - sh_path (string): Path to the Bash script that is handed to sbatch.
        - args (list): List of strings containing the arguments to be passed to the Bash script.
    If sbatch exits with a non-zero status the error is printed and not re-raised.
    """
    # subprocess.run needs one flat list of strings: the program, then the script, then its arguments.
    # sbatch takes the script path as its first non-option argument, so it must follow 'sbatch' directly.
    command = ['sbatch', sh_path] + [str(arg) for arg in args]
    try:
        subprocess.run(command, check=True)
    except subprocess.CalledProcessError as e:
        print(f"Error running sbatch: {e}")
|
||
def sbatchit(script_path, template_path, tuning, cargs=None):
    """
    Carries out hyper-parameter tuning by submitting a job for each set of hyper-parameters given by tuning,
    for each job runs the script stored at script_path with selected hyper-parameter values and the arguments given by cargs.
    Uses the template file with path template_path to guide the creation of the sbatch script for each job.
    Args:
        - script_path (string): Path to the script (of the model) to be run for each job.
        - template_path (string): Path to the template file used to create the sbatch script for each job.
        - tuning (Tuning): Tuning object used to select hyper-parameter values for each job.
            Must support len() and provide a next_tune method that returns a list of strings.
        - cargs (list): List of strings containing the arguments to be passed to the script for each job.
            Must be a list even if there is just one argument, default is no extra arguments.
    """
    # None stands in for "no extra arguments" so no list object is shared between calls
    common_args = [] if cargs is None else list(cargs)
    # Submit one job per hyper-parameter configuration
    for _ in range(len(tuning)):
        # Get arguments for this job
        args = tuning.next_tune()
        # Submit job
        submit_job(template_path, [script_path] + common_args + list(args))
    print("Submitted all jobs!")
|
||
# TODO: add functions for reading results |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,49 @@ | ||
import os | ||
|
||
def find_directory_path(strings, root_directory='.'):
    """
    Searches the root directory for a path of directories that matches the strings given in any order.
    Returns the root directory followed by the chain of matched strings, with the values of the
    directory names stripped, eg. 'root/--a=/--b=' for directories 'root/--a=1/--b=2'.
    If only a partial match is found, the returned path stops at the deepest matching directory.
    If no matches are found returns the root directory unchanged.
    Args:
        - strings (list): List of strings to be matched in any order. Each string in list must be in the form '--string='.
        - root_directory (string): Path to the root directory to be searched, default is current working directory.
    Returns:
        - string: root_directory with one '--string=' component appended per matched directory level.
    # TODO: could probably optimize this function
    """
    # The root is kept exactly as given and only the matched levels are appended to it,
    # so a root made of several components (eg. './tuning_results') is never altered
    path = root_directory
    for key in _find_key_chain(strings, root_directory):
        path = path + key if path.endswith('/') else path + '/' + key
    return path


def _find_key_chain(strings, directory):
    """
    Returns the longest chain of strings matched by nested directories below directory, as a list.
    """
    # Only directories can be part of the tree, plain files (eg. results csv files) are ignored
    subdirs = [d for d in os.listdir(directory) if os.path.isdir(os.path.join(directory, d))]
    # Get substring up to and including '=' for each directory name, and strip whitespace
    key_of = {d: d.split('=')[0].strip() + '=' for d in subdirs}
    present = set(key_of.values())
    # Check if any of the strings are in the list of directories
    for string in strings:
        if string in present:
            # There is at least one directory for "--string=" at this level, search each of them
            # recursively and keep the deepest chain found (the first one if there is a tie)
            chains = [
                [string] + _find_key_chain(strings, os.path.join(directory, d))
                for d in subdirs if key_of[d] == string
            ]
            return max(chains, key=len)
    # If no strings are found, the chain ends here
    return []
|
||
|
||
def dict_to_strings(d):
    """
    Turns a dictionary into a list of command line style strings, one '--key=value' entry per item.
    Keys that already start with '--' are used as they are, all other keys get '--' put in front.
    """
    return [
        f'{name}={val}' if name.startswith('--') else f'--{name}={val}'
        for name, val in d.items()
    ]
Oops, something went wrong.