Skip to content

Commit

Permalink
Fix history bug and add bells in analysis.py
Browse files (browse the repository at this point in the history)
  • Loading branch information
AUdaltsova authored Aug 5, 2024
1 parent fcca76c commit 7daa19b
Showing 1 changed file with 16 additions and 13 deletions.
29 changes: 16 additions & 13 deletions experiments/analysis.py
Original file line number | Diff line number | Diff line change
Expand Up @@ -16,10 +16,11 @@ def main(runs: list[str], run_names: list[str]) -> None:
"""
api = wandb.Api()
dfs = []
epoch_num = []
for run in runs:
run = api.run(f"openclimatefix/india/{run}")
run = api.run(f"openclimatefix/PROJECT/{run}")

df = run.history()
df = run.history(samples=run.lastHistoryStep+1)
# Get the columns that are in the format `MAE_horizon/step_<number>/val`
mae_cols = [col for col in df.columns if "MAE_horizon/step_" in col and "val" in col]
# Sort them
Expand All @@ -40,6 +41,7 @@ def main(runs: list[str], run_names: list[str]) -> None:
# Get the step from the column name
column_timesteps = [int(col.split("_")[-1].split("/")[0]) * 15 for col in mae_cols]
dfs.append(df)
epoch_num.append(min_row_idx)
# Get the timedelta for each group
groupings = [
[0, 0],
Expand Down Expand Up @@ -89,22 +91,22 @@ def main(runs: list[str], run_names: list[str]) -> None:
# Plot the error per timestep, for all individual timesteps
plt.figure()
for idx, df in enumerate(dfs):
plt.plot(column_timesteps, df, label=run_names[idx])
plt.legend()
plt.xlabel("Timestep (minutes)")
plt.ylabel("MAE %")
plt.title("MAE % for each timestep")
plt.plot(column_timesteps, df, label=f"{run_names[idx]}, epoch: {epoch_num[idx]}")
plt.legend(fontsize=18)
plt.xlabel("Timestep (minutes)", fontsize=18)
plt.ylabel("MAE %", fontsize=18)
plt.title("MAE % for each timestep", fontsize=24)
plt.savefig("mae_per_timestep.png")
plt.show()

# Plot the error per timestep, for grouped timesteps
plt.figure()
for run_name in run_names:
plt.plot(groups_df[run_name], label=run_name)
plt.legend()
plt.xlabel("Timestep (minutes)")
plt.ylabel("MAE %")
plt.title("MAE % for each timestep")
for idx, run_name in enumerate(run_names):
plt.plot(groups_df[run_name], label=f"{run_name}, epoch: {epoch_num[idx]}")
plt.legend(fontsize=18)
plt.xlabel("Timestep (minutes)", fontsize=18)
plt.ylabel("MAE %", fontsize=18)
plt.title("MAE % for each timestep", fontsize=24)
plt.savefig("mae_per_timestep.png")
plt.show()

Expand All @@ -119,3 +121,4 @@ def main(runs: list[str], run_names: list[str]) -> None:
parser.add_argument("--run_names", nargs="+")
args = parser.parse_args()
main(args.list_of_runs, args.run_names)

0 comments on commit 7daa19b

Please sign in to comment.