diff --git a/experiments/analysis.py b/experiments/analysis.py index 75b00cd9..9d7d35f9 100644 --- a/experiments/analysis.py +++ b/experiments/analysis.py @@ -16,10 +16,11 @@ def main(runs: list[str], run_names: list[str]) -> None: """ api = wandb.Api() dfs = [] + epoch_num = [] for run in runs: - run = api.run(f"openclimatefix/india/{run}") + run = api.run(f"openclimatefix/PROJECT/{run}") - df = run.history() + df = run.history(samples=run.lastHistoryStep+1) # Get the columns that are in the format 'MAE_horizon/step_/val` mae_cols = [col for col in df.columns if "MAE_horizon/step_" in col and "val" in col] # Sort them @@ -40,6 +41,7 @@ def main(runs: list[str], run_names: list[str]) -> None: # Get the step from the column name column_timesteps = [int(col.split("_")[-1].split("/")[0]) * 15 for col in mae_cols] dfs.append(df) + epoch_num.append(min_row_idx) # Get the timedelta for each group groupings = [ [0, 0], @@ -89,22 +91,22 @@ def main(runs: list[str], run_names: list[str]) -> None: # Plot the error on per timestep, and all timesteps plt.figure() for idx, df in enumerate(dfs): - plt.plot(column_timesteps, df, label=run_names[idx]) - plt.legend() - plt.xlabel("Timestep (minutes)") - plt.ylabel("MAE %") - plt.title("MAE % for each timestep") + plt.plot(column_timesteps, df, label=f"{run_names[idx]}, epoch: {epoch_num[idx]}") + plt.legend(fontsize=18) + plt.xlabel("Timestep (minutes)", fontsize=18) + plt.ylabel("MAE %", fontsize=18) + plt.title("MAE % for each timestep", fontsize=24) plt.savefig("mae_per_timestep.png") plt.show() # Plot the error on per timestep, and grouped timesteps plt.figure() - for run_name in run_names: - plt.plot(groups_df[run_name], label=run_name) - plt.legend() - plt.xlabel("Timestep (minutes)") - plt.ylabel("MAE %") - plt.title("MAE % for each timestep") + for idx, run_name in enumerate(run_names): + plt.plot(groups_df[run_name], label=f"{run_name}, epoch: {epoch_num[idx]}") + plt.legend(fontsize=18) + plt.xlabel("Timestep (minutes)", fontsize=18) + plt.ylabel("MAE %", fontsize=18) + plt.title("MAE % for each timestep", fontsize=24) plt.savefig("mae_per_timestep.png") plt.show() @@ -119,3 +121,4 @@ def main(runs: list[str], run_names: list[str]) -> None: parser.add_argument("--run_names", nargs="+") args = parser.parse_args() main(args.list_of_runs, args.run_names) +