Skip to content

Commit

Permalink
Merge pull request #241 from openclimatefix/analysis-bug
Browse files Browse the repository at this point in the history
Fix history bug and add bells in analysis.py
  • Loading branch information
AUdaltsova authored Aug 12, 2024
2 parents c838156 + f928e80 commit eaefc7b
Showing 1 changed file with 7 additions and 5 deletions.
12 changes: 7 additions & 5 deletions experiments/analysis.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,10 +16,11 @@ def main(runs: list[str], run_names: list[str]) -> None:
"""
api = wandb.Api()
dfs = []
epoch_num = []
for run in runs:
run = api.run(f"openclimatefix/india/{run}")
run = api.run(f"openclimatefix/PROJECT/{run}")

df = run.history()
df = run.history(samples=run.lastHistoryStep + 1)
# Get the columns that are in the format 'MAE_horizon/step_<number>/val`
mae_cols = [col for col in df.columns if "MAE_horizon/step_" in col and "val" in col]
# Sort them
Expand All @@ -40,6 +41,7 @@ def main(runs: list[str], run_names: list[str]) -> None:
# Get the step from the column name
column_timesteps = [int(col.split("_")[-1].split("/")[0]) * 15 for col in mae_cols]
dfs.append(df)
epoch_num.append(min_row_idx)
# Get the timedelta for each group
groupings = [
[0, 0],
Expand Down Expand Up @@ -89,7 +91,7 @@ def main(runs: list[str], run_names: list[str]) -> None:
# Plot the error on per timestep, and all timesteps
plt.figure()
for idx, df in enumerate(dfs):
plt.plot(column_timesteps, df, label=run_names[idx])
plt.plot(column_timesteps, df, label=f"{run_names[idx]}, epoch: {epoch_num[idx]}")
plt.legend()
plt.xlabel("Timestep (minutes)")
plt.ylabel("MAE %")
Expand All @@ -99,8 +101,8 @@ def main(runs: list[str], run_names: list[str]) -> None:

# Plot the error on per timestep, and grouped timesteps
plt.figure()
for run_name in run_names:
plt.plot(groups_df[run_name], label=run_name)
for idx, run_name in enumerate(run_names):
plt.plot(groups_df[run_name], label=f"{run_name}, epoch: {epoch_num[idx]}")
plt.legend()
plt.xlabel("Timestep (minutes)")
plt.ylabel("MAE %")
Expand Down

0 comments on commit eaefc7b

Please sign in to comment.