Skip to content

Commit

Permalink
Upd df loading
Browse files Browse the repository at this point in the history
  • Loading branch information
krasheninnikov committed Apr 4, 2024
1 parent 64755ac commit aabb306
Showing 1 changed file with 16 additions and 1 deletion.
17 changes: 16 additions & 1 deletion utils/aggregation_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -134,7 +134,7 @@ def prettify_label(label):

def make_experiment_plot(exp_name, stage_paths, thruncate_stages_after_epoch=None, eval_each_epochs_per_stage=None,
tags=['eval/d1consis_EM', 'eval/d2consis_EM'], os_list=None, ylabel='Value', title='',
figsize=(5.7,4), legend_loc='best', colors=None):
figsize=(5.7,4), legend_loc='best', colors=None, return_df_only=False):
"""
exp_name - name of the experiment (top level folder name)
stage_paths - list of strings that are the starts to paths to stages,
Expand Down Expand Up @@ -178,6 +178,17 @@ def make_experiment_plot(exp_name, stage_paths, thruncate_stages_after_epoch=Non
for stage_path, thruncate_after_epoch, eval_each_epochs in zip(stage_paths, thruncate_stages_after_epoch, eval_each_epochs_per_stage):
curr_stage_exp_names = [x for x in os_list if x.startswith(stage_path)]

# try alternate paths where stage2 logs might be saved if no logs were found (cpt is only added when save_each_epochs!=0)
if not curr_stage_exp_names and 'cpt20' in stage_path:
# remove cpt20_ from the path
alt_stage_path = stage_path.replace('cpt20_', '')
curr_stage_exp_names = [x for x in os_list if x.startswith(alt_stage_path)]
print(f'No experiments found for {stage_path}, trying {alt_stage_path}:\nLoaded{curr_stage_exp_names}')
if not curr_stage_exp_names:
raise ValueError(f'No experiments found for {stage_path} or {alt_stage_path}')
# remove any found "experiments" that are actually files
curr_stage_exp_names = [x for x in curr_stage_exp_names if os.path.isdir(os.path.join(exp_folder, x))]

# take only seed_stage2 = 0 experiments
# if 's2stage' in curr_stage_exp_names[0]:
# curr_stage_exp_names = [x for x in curr_stage_exp_names if 's2stage0' in x]
Expand Down Expand Up @@ -223,6 +234,10 @@ def make_experiment_plot(exp_name, stage_paths, thruncate_stages_after_epoch=Non
# add a column with log of value
# df['log_value'] = np.log(df['value'])
df['tag'] = df['tag'].apply(lambda x: x.replace('eval/', '').replace('train_', '').replace('_EM', '').replace('_loss', ''))

if return_df_only:
return df

tags = [x.replace('eval/', '').replace('train_', '').replace('_EM', '').replace('_loss', '') for x in tags]
linestyles = ["--" if "defs_" in tag else "-" for tag in tags]
markers = ["*" if "defs_" in tag else "o" for tag in tags]
Expand Down

0 comments on commit aabb306

Please sign in to comment.