diff --git a/experiments/analysis.py b/experiments/analysis.py index 75b00cd9..bb119664 100644 --- a/experiments/analysis.py +++ b/experiments/analysis.py @@ -16,10 +16,11 @@ def main(runs: list[str], run_names: list[str]) -> None: """ api = wandb.Api() dfs = [] + epoch_num = [] for run in runs: - run = api.run(f"openclimatefix/india/{run}") + run = api.run(f"openclimatefix/PROJECT/{run}") - df = run.history() + df = run.history(samples=run.lastHistoryStep + 1) # Get the columns that are in the format 'MAE_horizon/step_/val` mae_cols = [col for col in df.columns if "MAE_horizon/step_" in col and "val" in col] # Sort them @@ -40,6 +41,7 @@ def main(runs: list[str], run_names: list[str]) -> None: # Get the step from the column name column_timesteps = [int(col.split("_")[-1].split("/")[0]) * 15 for col in mae_cols] dfs.append(df) + epoch_num.append(min_row_idx) # Get the timedelta for each group groupings = [ [0, 0], @@ -89,7 +91,7 @@ def main(runs: list[str], run_names: list[str]) -> None: # Plot the error on per timestep, and all timesteps plt.figure() for idx, df in enumerate(dfs): - plt.plot(column_timesteps, df, label=run_names[idx]) + plt.plot(column_timesteps, df, label=f"{run_names[idx]}, epoch: {epoch_num[idx]}") plt.legend() plt.xlabel("Timestep (minutes)") plt.ylabel("MAE %") @@ -99,8 +101,8 @@ def main(runs: list[str], run_names: list[str]) -> None: # Plot the error on per timestep, and grouped timesteps plt.figure() - for run_name in run_names: - plt.plot(groups_df[run_name], label=run_name) + for idx, run_name in enumerate(run_names): + plt.plot(groups_df[run_name], label=f"{run_name}, epoch: {epoch_num[idx]}") plt.legend() plt.xlabel("Timestep (minutes)") plt.ylabel("MAE %")