Merge pull request #63 from h-0-0/h-0-0-dev
Fixes #62 #59
h-0-0 authored Dec 22, 2024
2 parents 0f7dd63 + 4ea2b15 commit 6753381
Showing 2 changed files with 165 additions and 76 deletions.
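The headline change is to SaverCsv.read: the boolean avg argument is replaced by a string collate_by argument ('mean' or 'all'), and read now returns lists of parameter sets and collated values instead of a single pair. A minimal call-site migration sketch, assuming the imports used in the tests below; the root directory, parameter values and metric name are illustrative:

from slune.savers.csv import SaverCsv
from slune.loggers.default import LoggerDefault

saver = SaverCsv(LoggerDefault(), root_dir='results')
params = {'param1': 1, 'param2': True, 'param3': 3}

# Before this commit: a single (best_params, best_value) pair, averaged over runs when avg=True.
# best_params, best_value = saver.read(params, 'a', select_by='max', avg=True)

# After this commit: one entry per matching parameter combination.
best_params, best_values = saver.read(params, 'a', select_by='max', collate_by='mean')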
46 changes: 25 additions & 21 deletions src/slune/savers/csv.py
@@ -81,15 +81,15 @@ def save_collated(self):

self.save_collated_from_results(self.logger.results)

def read(self, params: dict, metric_name: str, select_by: str ='max', avg: bool =True) -> Tuple[dict, float]:
def read(self, params: dict, metric_name: str, select_by: str ='max', collate_by: str ='mean') -> Tuple[dict, float]:
""" Finds the min/max value of a metric from all csv files in the root directory that match the parameters given.
Args:
- params (dict): Contains (parameter,value) pairs we would like in the run.
If None or empty dict, we will search through all csv files in the root directory.
- metric_name (string): Name of the metric to be read.
- select_by (string, optional): How to select the 'best' value for the metric from a log file, currently can select by 'min' or 'max'.
- avg (bool, optional): Whether to average the metric over all runs, default is True.
- collate_by (string, optional): How to collate the metric values selected from runs with the same parameters, currently can be 'mean' or 'all'; default is 'mean'.
Returns:
- best_params (dict): Contains the arguments used to get the 'best' value of the metric (determined by select_by).
@@ -99,10 +99,13 @@ def read(self, params: dict, metric_name: str, select_by: str ='max', avg: bool

# Get all paths that match the parameters given
paths = get_all_paths('.csv', dict_to_strings(params), root_directory=self.root_dir)
# If no paths found, return None
if paths == []:
return None, None
# Read the metric from each path
values = {}
# Do averaging for different runs of same params if avg is True, otherwise just read the metric from each path
if avg:
if collate_by == 'mean':
paths_same_params = set([os.path.join(*p.split(os.path.sep)[:-1]) for p in paths])
for path in paths_same_params:
runs = get_all_paths('.csv', path.split(os.path.sep), root_directory=self.root_dir)
@@ -112,26 +115,27 @@ def read(self, params: dict, metric_name: str, select_by: str ='max', avg: bool
cumsum += self.read_log(df, metric_name, select_by)
avg_of_runs = cumsum / len(runs)
values[path] = avg_of_runs
else:
elif collate_by == 'all':
for path in paths:
df = pd.read_csv(path)
# values[os.path.join(*path.split(os.path.sep)[:-1])] = self.read_log(df, metric_name, select_by)
values[path] = self.read_log(df, metric_name, select_by)

# Get the key of the min/max value
if select_by == 'min':
best_params = min(values, key=values.get)
elif select_by == 'max':
best_params = max(values, key=values.get)
else:
raise ValueError(f"select_by must be 'min' or 'max', got {select_by}")
# Find the best value of the metric from the key
best_value = values[best_params]
# Format the path into a list of arguments
best_params = best_params.replace(self.root_dir, '')
if best_params.startswith(os.path.sep):
best_params = best_params[1:]
best_params = best_params.split(os.path.sep)
if best_params[-1].startswith('results_'):
best_params = best_params[:-1]
return best_params, best_value
raise ValueError(f"collate_by must be 'mean' or 'all', got {collate_by}")

# Format the path into a list of arguments
out_params, out_values = [], []
for key in values.keys():
value = values[key]
key = key.replace(self.root_dir, '')
if key.startswith(os.path.sep):
key = key[1:]
key = key.split(os.path.sep)
if key[-1].startswith('results_'):
# key = key[:-1]
# if has .csv, remove it
if key[-1].endswith('.csv'):
key[-1] = key[-1][:-4]
out_params.append(key)
out_values.append(value)
return out_params, out_values
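For reference, a sketch of consuming the new return shape, assuming a results tree like the one the tests below construct (directory names are illustrative). Each entry of the returned parameter list is a run path split into 'name=value' parts, paired with the collated metric for that path:

from slune.savers.csv import SaverCsv
from slune.loggers.default import LoggerDefault

saver = SaverCsv(LoggerDefault(), root_dir='results')
out_params, out_values = saver.read({'param1': 1}, 'a', select_by='max', collate_by='mean')
# e.g. out_params == [['param1=1', 'param2=True', 'param3=3'],
#                     ['param1=1', 'param2=False', 'param3=3']]
# and  out_values == [3.5, 5.0]
for run_params, value in zip(out_params, out_values):
    print(run_params, value)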
195 changes: 140 additions & 55 deletions tests/test_savers_csv.py
@@ -3,6 +3,7 @@
import pandas as pd
from slune.savers.csv import SaverCsv
from slune.loggers.default import LoggerDefault
import numpy as np

class TestSaverCsvGetMatch(unittest.TestCase):
def setUp(self):
@@ -414,23 +415,30 @@ def tearDown(self):
os.rmdir(os.path.join(root, name))
os.rmdir(self.test_dir)

def test_max_metric(self):
def test_select_by_max(self):
# Create some params to use for testing
params = {'param1':1, 'param2':True, 'param3':3}
# Create an instance of SaverCsv
saver = SaverCsv(LoggerDefault(), root_dir=self.test_dir)

# Call the read method to get max values
max_param_a, max_value_a = saver.read(params, 'a', select_by='max')
max_param_b, max_value_b = saver.read(params, 'b', select_by='max')

max_param_a, max_value_a = saver.read(params, 'a', select_by='max', collate_by='mean')
max_param_b, max_value_b = saver.read(params, 'b', select_by='max', collate_by='mean')
self.assertEqual(1, len(max_param_a))
self.assertEqual(1, len(max_param_b))
self.assertEqual(1, len(max_value_a))
self.assertEqual(1, len(max_value_b))
max_param_a = max_param_a[0]
max_param_b = max_param_b[0]
max_value_a = max_value_a[0]
max_value_b = max_value_b[0]
# Perform assertions based on your expectations
self.assertEqual(max_param_a, ['param1=1', 'param2=True', 'param3=3'])
self.assertEqual(max_param_b, ['param1=1', 'param2=True', 'param3=3'])
self.assertEqual(max_value_a, 3.5)
self.assertEqual(max_value_b, 6.5)

def test_min_metric(self):
def test_select_by_min(self):
# Create some params to use for testing
params = {'param1':1, 'param2':True, 'param3':3}
# Create an instance of SaverCsv
@@ -439,13 +447,102 @@ def test_min_metric(self):
# Call the read method to get min values
min_param_a, min_value_a = saver.read(params, 'a', select_by='min')
min_param_b, min_value_b = saver.read(params, 'b', select_by='min')
self.assertEqual(1, len(min_param_a))
self.assertEqual(1, len(min_param_b))
self.assertEqual(1, len(min_value_a))
self.assertEqual(1, len(min_value_b))
min_param_a = min_param_a[0]
min_param_b = min_param_b[0]
min_value_a = min_value_a[0]
min_value_b = min_value_b[0]

# Perform assertions based on your expectations
self.assertEqual(min_param_a, ['param1=1', 'param2=True', 'param3=3'])
self.assertEqual(min_param_b, ['param1=1', 'param2=True', 'param3=3'])
self.assertEqual(min_value_a, 1.5)
self.assertEqual(min_value_b, 4.5)

def test_select_by_all(self):
# Create some params to use for testing
params = {'param1':1, 'param2':True, 'param3':3}
# Create an instance of SaverCsv
saver = SaverCsv(LoggerDefault(), root_dir=self.test_dir)

# Call the read method to get all values matching the params
params, values = saver.read(params, 'a', select_by='all', collate_by='mean')

# Perform assertions based on your expectations
self.assertEqual(params, [['param1=1', 'param2=True', 'param3=3']])
eq = np.array_equal(values, [[1.5, 2.5, 3.5]])
self.assertEqual(True, eq)

# Now we check that we get an error if there is a mismatch in the number of metrics returned by logger
file_path = os.path.join(self.test_dir, 'param1=1', 'param2=True', 'param3=3', 'results_2.csv')
os.makedirs(os.path.dirname(file_path), exist_ok=True)
# Create a data frame with different values for each CSV file
results = pd.DataFrame({'a': [1,2,3,4,5], 'b': [4,5,6,7,8]})
# Save the results
results.to_csv(file_path, mode='w', index=False)

# Create some params to use for testing
params = {'param1':1, 'param2':True, 'param3':3}
with self.assertRaises(ValueError):
saver.read(params, 'a', select_by='all', collate_by='mean')
# Remove the results file
os.remove(file_path)

def test_select_by_last(self):
# Create some params to use for testing
params = {'param1':1, 'param2':True, 'param3':3}
# Create an instance of SaverCsv
saver = SaverCsv(LoggerDefault(), root_dir=self.test_dir)

# Call the read method to get last values
param, value = saver.read(params, 'a', select_by='last', collate_by='mean')

# Perform assertions based on your expectations
self.assertEqual(param, [['param1=1', 'param2=True', 'param3=3']])
self.assertEqual(value, [3.5])

def test_select_by_first(self):
# Create some params to use for testing
params = {'param1':1, 'param2':True, 'param3':3}
# Create an instance of SaverCsv
saver = SaverCsv(LoggerDefault(), root_dir=self.test_dir)

# Call the read method to get first values
param, value = saver.read(params, 'a', select_by='first', collate_by='mean')

# Perform assertions based on your expectations
self.assertEqual(param, [['param1=1', 'param2=True', 'param3=3']])
self.assertEqual(value, [1.5])

def test_select_by_mean(self):
# Create some params to use for testing
params = {'param1':1, 'param2':True, 'param3':3}
# Create an instance of SaverCsv
saver = SaverCsv(LoggerDefault(), root_dir=self.test_dir)

# Call the read method to get mean values
param, value = saver.read(params, 'a', select_by='mean', collate_by='mean')

# Perform assertions based on your expectations
self.assertEqual(param, [['param1=1', 'param2=True', 'param3=3']])
self.assertEqual(value, [2.5])

def test_select_by_median(self):
# Create some params to use for testing
params = {'param1':1, 'param2':True, 'param3':3}
# Create an instance of SaverCsv
saver = SaverCsv(LoggerDefault(), root_dir=self.test_dir)

# Call the read method to get median values
param, value = saver.read(params, 'a', select_by='median', collate_by='mean')

# Perform assertions based on your expectations
self.assertEqual(param, [['param1=1', 'param2=True', 'param3=3']])
self.assertEqual(value, [2.5])

def test_nonexistent_metric(self):
# Create some params to use for testing
params = {'param1':1, 'param2':True, 'param3':3}
@@ -466,26 +563,42 @@ def test_value_exists_both_str_float(self):
param, value = saver.read(params, 'a', select_by='min')

# Perform assertions based on your expectations
self.assertEqual(param, ['param1=string', 'param2=1', 'param3=3'])
self.assertEqual(value, 4)
self.assertEqual(param, [['param1=string', 'param2=1', 'param3=3']])
self.assertEqual(value, [4])

def test_not_given_params(self):
# Create an instance of SaverCsv
saver = SaverCsv(LoggerDefault(), root_dir=self.test_dir)

# Call the read method to get min values, averaged over each run
params, values = saver.read({}, 'a', select_by='min', avg=True)
params, values = saver.read({}, 'a', select_by='min', collate_by='mean')

# Perform assertions based on your expectations
self.assertEqual(params, ['param1=1', 'param2=True', 'param3=3'])
self.assertEqual(values, 1.5)
self.assertEqual(3, len(params))
self.assertEqual(True, ['param1=string', 'param2=1', 'param3=3'] in params)
self.assertEqual(True, ['param1=1', 'param2=True', 'param3=3'] in params)
self.assertEqual(True, ['param1=1', 'param2=False', 'param3=3'] in params)
self.assertEqual(3, len(values))
self.assertEqual(True, 4 in values)
self.assertEqual(True, 1.5 in values)
self.assertEqual(True, 3 in values)

# Now not averaging and None instead of empty list
params, values = saver.read(None, 'a', select_by='min', avg=False)
params, values = saver.read(None, 'a', select_by='all', collate_by='all')

# Perform assertions based on your expectations
self.assertEqual(params, ['param1=1', 'param2=True', 'param3=3'])
self.assertEqual(values, 1)
self.assertEqual(4, len(params))
self.assertEqual(True, ['param1=string', 'param2=1', 'param3=3', 'results_0'] in params)
self.assertEqual(True, ['param1=1', 'param2=True', 'param3=3', 'results_0'] in params)
self.assertEqual(True, ['param1=1', 'param2=True', 'param3=3', 'results_1'] in params)
self.assertEqual(True, ['param1=1', 'param2=False', 'param3=3', 'results_0'] in params)
# Turn values into a list of lists (from a list of numpy arrays)
values = [list(v) for v in values]
self.assertEqual(4, len(values))
self.assertEqual(True, [1, 2, 3] in values)
self.assertEqual(True, [2, 3, 4] in values)
self.assertEqual(True, [3, 4, 5] in values)
self.assertEqual(True, [4, 5, 6] in values)

def test_multi_matching_paths(self):
# Create some params to use for testing
@@ -494,11 +607,15 @@ def test_multi_matching_paths(self):
saver = SaverCsv(LoggerDefault(), root_dir=self.test_dir)

# Call the read method to get max value and params
param, value = saver.read(params, 'a', select_by='max')
param, value = saver.read(params, 'a', select_by='max', collate_by='mean')

# Check results are as expected
self.assertEqual(param, ['param1=1', 'param2=False', 'param3=3'])
self.assertEqual(value, 5)
self.assertEqual(2, len(param))
self.assertEqual(True, ['param1=1', 'param2=False', 'param3=3'] in param)
self.assertEqual(True, ['param1=1', 'param2=True', 'param3=3'] in param)
self.assertEqual(2, len(value))
self.assertEqual(True, 5 in value)
self.assertEqual(True, 3.5 in value)

def test_no_matching_paths(self):
# Create some params to use for testing
@@ -507,8 +624,9 @@ def test_no_matching_paths(self):
saver = SaverCsv(LoggerDefault(), root_dir=self.test_dir)

# Call the read method to get max value and params
with self.assertRaises(ValueError):
saver.read(params, 'a', select_by='max')
param, value = saver.read(params, 'a', select_by='max')
self.assertEqual(param, None)
self.assertEqual(value, None)

def test_multi_matching_paths_and_missing_metrics(self):
# Create some params to use for testing
@@ -520,51 +638,18 @@ def test_multi_matching_paths_and_missing_metrics(self):
with self.assertRaises(KeyError):
saver.read(params, 'c', select_by='max')

def test_results_avg(self):
def test_collate_by_mean_single(self):
# Create some params to use for testing
params = {'param1':1, 'param2':False, 'param3':3}
# Create an instance of SaverCsv
saver = SaverCsv(LoggerDefault(), root_dir=self.test_dir)

# Call the read method to get max value and params
param, value = saver.read(params, 'a', avg=True, select_by='max')
param, value = saver.read(params, 'a', select_by='max', collate_by='mean')

# Check results are as expected
self.assertEqual(param, ['param1=1', 'param2=False', 'param3=3'])
self.assertEqual(value, 5)

def test_multi_results_avg(self):
# Create another results file with different values
results = pd.DataFrame({'a': [7,8,9], 'd': [10,11,12]})
results.to_csv(os.path.join(self.test_dir, 'param1=1','param2=True','param3=3','more_results.csv'), mode='w', index=False)
# Create some params to use for testing
params = {'param1':1, 'param2':True, 'param3':3}
# Create an instance of SaverCsv
saver = SaverCsv(LoggerDefault(), root_dir=self.test_dir)

# Call the read method to get max value and params
param, value = saver.read(params, 'a', avg=True, select_by='max')

# Check results are as expected
self.assertEqual(param, ['param1=1', 'param2=True', 'param3=3'])
self.assertEqual(value, (9+4+3)/3)

# Remove the results file
os.remove(os.path.join(self.test_dir, 'param1=1','param2=True','param3=3','more_results.csv'))

def test_no_matching_results(self):
# Create some params to use for testing
params = {'param1':2, 'param2':True, 'param3':3}
# Create an instance of SaverCsv
saver = SaverCsv(LoggerDefault(), root_dir=self.test_dir)

# Call the read method to get max value and params
with self.assertRaises(ValueError):
saver.read(params, 'a', avg=True, select_by='max')

# Now the same but with avg=False
with self.assertRaises(ValueError):
saver.read(params, 'a', avg=False, select_by='max')
self.assertEqual(param,[['param1=1', 'param2=False', 'param3=3']])
self.assertEqual(value, [5])


class TestSaverCsvGetSetCurrentPath(unittest.TestCase):
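One behavioural change surfaced by the updated tests: when no csv files match the given parameters, read now returns (None, None) instead of raising ValueError, so callers should guard for it. A minimal sketch, with an illustrative root directory:

from slune.savers.csv import SaverCsv
from slune.loggers.default import LoggerDefault

saver = SaverCsv(LoggerDefault(), root_dir='results')
params, values = saver.read({'param1': 2, 'param2': True, 'param3': 3}, 'a', select_by='max')
if params is None:
    print('No runs matched the given parameters.')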
