Merge pull request #25 from slacgismo/runner-refactor
Refactored the runner and added a class for DB insertion
Thistleman authored Aug 7, 2023
2 parents 3d39943 + 92a06fb commit a2395e3
Showing 7 changed files with 1,091 additions and 71 deletions.
@@ -0,0 +1,40 @@
{
"category_name": "time_shifts",
"function_name": "detect_time_shifts",
"comparison_type": "time_series",
"performance_metrics": [ "runtime", "mean_absolute_error" ],
"allowable_kwargs": [ "latitude", "longitude", "data_sampling_frequency" ],
"ground_truth_compare": [ "time_series" ],
"public_results_table": "time-shift-public-metrics.json",
"private_results_columns": [
"system_id",
"file_name",
"run_time",
"data_requirements",
"mean_absolute_error_time_series",
"data_sampling_frequency",
"issue"
],
"plots": [
{
"type": "histogram",
"x_val": "mean_absolute_error_time_series",
"color_code": "issue",
"title": "Time Series MAE Distribution by Issue",
"save_file_path": "mean_absolute_error_time_series_dist.png"
},
{
"type": "histogram",
"x_val": "mean_absolute_error_time_series",
"color_code": "data_sampling_frequency",
"title": "Time Series MAE Distribution by Sampling Frequency",
"save_file_path": "mean_absolute_error_time_series_dist.png"
},
{
"type": "histogram",
"x_val": "run_time",
"title": "Run Time Distribution",
"save_file_path": "run_time_dist.png"
}
]
}
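
As a quick illustration of how a runner can consume a configuration like the one above, the sketch below loads the JSON and checks that the keys shown in this config are present. The load_config helper and REQUIRED_KEYS list are illustrative only; they are not part of this commit.

import json

# Illustrative helper, not part of this commit: load a category config and
# verify that the keys appearing in the config above are present.
REQUIRED_KEYS = [
    "category_name", "function_name", "comparison_type",
    "performance_metrics", "allowable_kwargs", "ground_truth_compare",
    "public_results_table", "private_results_columns", "plots",
]

def load_config(config_file_path):
    with open(config_file_path) as f:
        config_data = json.load(f)
    missing = [key for key in REQUIRED_KEYS if key not in config_data]
    if missing:
        raise KeyError("Config is missing required keys: " + ", ".join(missing))
    return config_data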
@@ -118,7 +118,36 @@ def get_module_name(module_dir):
return get_module_file_name(module_dir)[:-3]


def run(module_to_import_s3_path, optional_result_data_dir=None):
def generate_histogram(dataframe, x_axis, title, color_code = None,
number_bins = 30):
"""
Generate a histogram for a distribution. Option to color code the
histogram by the color_code column parameter.
"""
sns.displot(dataframe,
x=x_axis,
hue=color_code,
multiple="stack",
bins=number_bins)
plt.title(title)
plt.tight_layout()
return plt


def generate_scatter_plot(dataframe, x_axis, y_axis, title):
"""
Generate a scatterplot between an x- and a y-variable.
"""
sns.scatterplot(data=dataframe,
x=x_axis,
y=y_axis)
plt.title(title)
plt.tight_layout()
return plt


def run(module_to_import_s3_path, config_file_path,
optional_result_data_dir=None):
# If a path is provided, set the directories to that path, otherwise use default
if optional_result_data_dir is not None:
results_dir = optional_result_data_dir + "/results" if not optional_result_data_dir.endswith('/') else optional_result_data_dir + "results"
@@ -216,11 +245,16 @@ def run(module_to_import_s3_path, optional_result_data_dir=None):

# Convert the list of file metadata to a DataFrame
file_metadata = pd.DataFrame(file_metadata_list)

# Read in the configuration JSON for the particular run
with open(config_file_path) as f:
config_data = json.load(f)

# Link the above tables together to get all of the files associated
# with the time_shift category in the validation_tests table.
time_shift_test_information = dict(validation_tests[
validation_tests['category_name'] == 'time_shifts'].iloc[0])
validation_tests['category_name'] ==
config_data['category_name']].iloc[0])
# Get the associated metrics we're supposed to calculate
performance_metrics = ast.literal_eval(time_shift_test_information[
'performance_metrics'])
@@ -251,49 +285,72 @@ def run(module_to_import_s3_path, optional_result_data_dir=None):
# any necessary arguments
associated_metadata = dict(system_metadata[
system_metadata['system_id'] == system_id].iloc[0])
# Get the ground truth scalars that we will compare to
ground_truth_dict = dict()
if config_data['comparison_type'] == 'scalar':
for val in config_data['ground_truth_compare']:
ground_truth_dict[val] = associated_metadata[val]
if config_data['comparison_type'] == 'time_series':
ground_truth_series = pd.read_csv(
os.path.join(data_dir + "/validation_data/", file_name),
index_col=0,
parse_dates=True).squeeze()
ground_truth_dict["time_series"] = ground_truth_series
# Create master dictionary of all possible function kwargs
kwargs_dict = dict(ChainMap(dict(row), associated_metadata))
# Filter out to only allowable args for the function
kwargs_dict = {key:kwargs_dict[key] for key in
config_data['allowable_kwargs']}
# Now that we've collected all of the information associated with the
# test, let's read in the file as a pandas dataframe (this data
# would most likely be stored in an S3 bucket)
time_series = pd.read_csv(os.path.join(data_dir + "/file_data/", file_name),
index_col=0,
parse_dates=True).squeeze()

time_series = time_series.asfreq(
str(row['data_sampling_frequency']) + "T")
# Read in the associated validation time series (this would act as a
# fixture or similar, and validation data would be stored in an
# associated folder on S3 or similar)
ground_truth_series = pd.read_csv(
os.path.join(data_dir + "/validation_data/", file_name),
index_col=0,
parse_dates=True).squeeze()

# Filter the kwargs dictionary based on required function params
kwargs = dict((k, kwargs_dict[k]) for k in function_parameters
if k in kwargs_dict)
# Time function execution if 'run_time' is in performance metrics
# list
if 'run_time' in performance_metrics:
start_time = time.time()
time_shift_series = function(time_series, **kwargs)
end_time = time.time()
function_run_time = (end_time - start_time)
else:
time_shift_series = function(time_series, **kwargs)
# Get the performance metrics that we want to quantify
performance_metrics = config_data['performance_metrics']
# Run the routine (timed)
start_time = time.time()
data_outputs = function(time_series, **kwargs)
end_time = time.time()
function_run_time = (end_time - start_time)
# Convert the data outputs to a dictionary identical to the
# ground truth dictionary
output_dictionary = dict()
if config_data['comparison_type'] == 'scalar':
for idx in range(len(config_data['ground_truth_compare'])):
output_dictionary[config_data['ground_truth_compare'
][idx]] = data_outputs[idx]
if config_data['comparison_type'] == 'time_series':
output_dictionary['time_series'] = data_outputs
# Run routine for all of the performance metrics and append
# results to the dictionary
results_dictionary = dict()
results_dictionary['file_name'] = file_name
# Set the runtime in the results dictionary
results_dictionary['run_time'] = function_run_time
# Set the data requirements in the dictionary
results_dictionary['data_requirements'] = function_parameters
# Loop through the rest of the performance metrics and calculate them
# (this predominantly applies to error metrics)
for metric in performance_metrics:
if metric == 'run_time':
results_dictionary[metric] = function_run_time
if metric == 'mean_absolute_error':
mae = np.mean(np.abs(ground_truth_series - time_shift_series))
results_dictionary[metric] = mae
if metric == 'data_requirements':
results_dictionary[metric] = function_parameters
if metric == 'absolute_error':
# Loop through the input and the output dictionaries,
# and calculate the absolute error
for val in config_data['ground_truth_compare']:
error = np.abs(output_dictionary[val] -
ground_truth_dict[val])
results_dictionary[metric + "_" + val] = error
elif metric == 'mean_absolute_error':
for val in config_data['ground_truth_compare']:
error = np.mean(np.abs(output_dictionary[val] -
ground_truth_dict[val]))
results_dictionary[metric + "_" + val] = error
results_list.append(results_dictionary)
# Convert the results to a pandas dataframe and perform all of the
# post-processing in the script
@@ -305,58 +362,71 @@ def run(module_to_import_s3_path, optional_result_data_dir=None):
# be saved to a public metrics dictionary)
public_metrics_dict = dict()
public_metrics_dict['module'] = module_name
# Get the mean and median run times
public_metrics_dict['mean_run_time'] = results_df['run_time'].mean()
public_metrics_dict['median_run_time'] = results_df['run_time'].median()
public_metrics_dict['function_parameters'] = function_parameters
for metric in performance_metrics:
if metric != 'data_requirements':
mean_value = results_df[metric].mean()
public_metrics_dict['mean_' + metric] = mean_value
else:
public_metrics_dict[metric] = function_parameters
# TODO: Write public metric information to a public results table. here we
# just write a json to illustrate that final outputs.
with open(results_dir + '/time-shift-public-metrics.json', 'w') as fp:
if 'absolute_error' in metric:
for val in config_data['ground_truth_compare']:
public_metrics_dict['mean_' + metric + '_' + val] = \
results_df[metric + "_" + val].mean()
public_metrics_dict['median_' + metric + '_' + val] = \
results_df[metric + "_" + val].median()
# Write public metric information to a public results table.
with open(os.path.join(results_dir, config_data['public_results_table']),
'w') as fp:
json.dump(public_metrics_dict, fp)
# Now generate private results. These will be more specific to the
# type of analysis being run as results will be color-coded by certain
# parameters. These params will be available as columns in the
# 'associated_files' dataframe
color_code_params = ['data_sampling_frequency', 'issue']
results_df_private = pd.merge(results_df,
associated_files[['file_name'] +
color_code_params],
associated_files,
on='file_name')
for param in color_code_params:
# Mean absolute error histogram
sns.displot(results_df_private,
x='mean_absolute_error', hue=param,
multiple="stack", bins=30)
plt.gca().set_yscale('log')
plt.title('MAE by ' + str(param))
# Save to a folder
plt.savefig(os.path.join(results_dir, str(param) + '_mean_absolute_error.png'))
plt.close()
plt.clf()
# Generate stratified table for private reports
stratified_mae_table = pd.DataFrame(results_df_private.groupby(param)[
'mean_absolute_error'].mean())
stratified_mae_table.to_csv(
os.path.join(results_dir, str(param) + '_mean_absolute_error_results.csv'))
# Run time histogram
sns.displot(results_df_private,
x='run_time', hue=param,
multiple="stack", bins=30)
plt.title('Run time (s) by ' + str(param))
# Save to a folder
plt.savefig(os.path.join(results_dir, str(param) + '_run_time.png'))


plt.close()
plt.clf()
# Generate stratified table for private reports
stratified_mae_table = pd.DataFrame(results_df_private.groupby(param)[
'run_time'].mean())
stratified_mae_table.to_csv(
os.path.join(results_dir, str(param) + '_run_time_results.csv'))

# Filter to only the necessary columns (available via the config)
results_df_private = results_df_private[config_data
["private_results_columns"]]
results_df_private.to_csv(
os.path.join(results_dir,
module_name + "_full_results.csv"))
# Loop through all of the plot dictionaries and generate plots and
# associated tables for reporting
for plot in config_data['plots']:
if plot['type'] == 'histogram':
if 'color_code' in plot:
color_code = plot['color_code']
else:
color_code = None
gen_plot = generate_histogram(results_df_private,
plot['x_val'],
plot['title'],
color_code)
# Save the plot
gen_plot.savefig(os.path.join(results_dir,
plot['save_file_path']))
plt.close()
plt.clf()
# Write the stratified results to a table for private reporting
# (if color_code param is not None)
if color_code:
stratified_results_tbl = pd.DataFrame(
results_df_private.groupby(color_code)[
plot['x_val']].mean())
stratified_results_tbl.to_csv(
os.path.join(results_dir,
module_name + '_' + str(color_code) +
'_' + plot['x_val'] + '.csv'))
if plot['type'] == 'scatter_plot':
gen_plot = generate_scatter_plot(results_df_private,
plot['x_val'],
plot['y_val'],
plot['title'])
# Save the plot
gen_plot.savefig(os.path.join(results_dir,
plot['save_file_path']))
plt.close()
plt.clf()
return public_metrics_dict


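For reference, the kwarg-filtering and timing pattern used inside run() can be exercised on its own. The submitted function, metadata row, and allowable_kwargs below are invented stand-ins rather than objects from this repository; this is a sketch of the pattern, not the runner itself.

import inspect
import time
from collections import ChainMap

import pandas as pd

# Invented stand-in for a submitted analysis function and its metadata.
def detect_time_shifts(time_series, latitude, longitude,
                       data_sampling_frequency=60):
    return time_series  # placeholder body for illustration

row = {"file_name": "1_ac_power.csv", "data_sampling_frequency": 60}
associated_metadata = {"system_id": 1, "latitude": 39.7, "longitude": -105.2}
allowable_kwargs = ["latitude", "longitude", "data_sampling_frequency"]

# Merge every candidate kwarg, keep only what the config allows, then keep
# only what the submitted function's signature actually accepts.
kwargs_dict = dict(ChainMap(dict(row), associated_metadata))
kwargs_dict = {key: kwargs_dict[key] for key in allowable_kwargs}
function_parameters = list(inspect.signature(detect_time_shifts).parameters)
kwargs = {k: kwargs_dict[k] for k in function_parameters if k in kwargs_dict}

# Time the call, as the runner does for the run_time metric.
time_series = pd.Series([0.0, 1.0, 2.0])
start_time = time.time()
data_outputs = detect_time_shifts(time_series, **kwargs)
function_run_time = time.time() - start_time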
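The key generalization in the refactored run() is that both the ground truth and the submission output are packed into dictionaries keyed by the config's ground_truth_compare entries, so one error loop serves scalar categories (azimuth/tilt) and time-series categories (time shifts) alike. A self-contained sketch of that shape, with made-up numbers rather than data from this repository:

import numpy as np
import pandas as pd

# Made-up values standing in for a scalar-comparison category (azimuth/tilt).
ground_truth_dict = {"azimuth": 180.0, "tilt": 20.0}
output_dictionary = {"azimuth": 176.5, "tilt": 21.0}

results_dictionary = {}
for val in ["azimuth", "tilt"]:  # the config's ground_truth_compare entries
    results_dictionary["absolute_error_" + val] = np.abs(
        output_dictionary[val] - ground_truth_dict[val])

# The same loop shape covers a time-series category: the compared values are
# simply pandas Series instead of scalars.
idx = pd.date_range("2021-01-01", periods=4, freq="D")
ground_truth_dict = {"time_series": pd.Series([0, 0, 60, 60], index=idx)}
output_dictionary = {"time_series": pd.Series([0, 0, 0, 60], index=idx)}
results_dictionary["mean_absolute_error_time_series"] = np.mean(
    np.abs(output_dictionary["time_series"] - ground_truth_dict["time_series"]))

# -> absolute errors of 3.5 (azimuth) and 1.0 (tilt), time-series MAE of 15.0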
@@ -0,0 +1,46 @@
{
"category_name": "az_tilt_estimation",
"function_name": "estimate_az_tilt",
"comparison_type": "scalar",
"performance_metrics": [ "runtime", "absolute_error" ],
"allowable_kwargs": [ "latitude", "longitude" ],
"ground_truth_compare": [ "azimuth", "tilt" ],
"public_results_table": "./results/az-tilt-public-metrics.json",
"private_results_columns": [
"system_id",
"file_name",
"run_time",
"data_requirements",
"absolute_error_azimuth",
"absolute_error_tilt",
"issue",
"number_days"
],
"plots": [
{
"type": "histogram",
"x_val": "absolute_error_azimuth",
"title": "Azimuth Absolute Error Distribution",
"save_file_path": "./results/absolute_error_az_dist.png"
},
{
"type": "histogram",
"x_val": "absolute_error_tilt",
"title": "Tilt Absolute Error Distribution",
"save_file_path": "./results/absolute_error_tilt_dist.png"
},
{
"type": "histogram",
"x_val": "run_time",
"title": "Run Time Distribution",
"save_file_path": "./results/run_time_dist.png"
},
{
"type": "scatter_plot",
"x_val": "run_time",
"y_val": "number_days",
"title": "Run Time vs. Number Days in Data Set",
"save_file_path": "./results/run_time_number_days.png"
}
]
}
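
Each entry in the plots list above maps onto one of the new plotting helpers in the runner (histogram and scatter_plot types). As a standalone illustration, the sketch below renders the scatter-plot entry against an invented results frame; the data values and the headless-backend setup are assumptions for the example, not part of this commit.

import os

import matplotlib
matplotlib.use("Agg")  # headless backend so the sketch runs without a display
import matplotlib.pyplot as plt
import pandas as pd
import seaborn as sns

# Invented stand-in for results_df_private.
results_df_private = pd.DataFrame({
    "run_time": [1.2, 0.8, 2.5, 1.9],
    "number_days": [30, 45, 365, 180],
})

# One entry from the "plots" list in the config above.
plot = {
    "type": "scatter_plot",
    "x_val": "run_time",
    "y_val": "number_days",
    "title": "Run Time vs. Number Days in Data Set",
    "save_file_path": "./results/run_time_number_days.png",
}

os.makedirs("./results", exist_ok=True)
if plot["type"] == "scatter_plot":
    sns.scatterplot(data=results_df_private, x=plot["x_val"], y=plot["y_val"])
    plt.title(plot["title"])
    plt.tight_layout()
    plt.savefig(plot["save_file_path"])
    plt.close()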