Skip to content

Commit

Permalink
Update MAE analysis script (#274)
Browse files Browse the repository at this point in the history
Update script
  • Loading branch information
Sukh-P authored Nov 8, 2024
1 parent 9497430 commit acb6a36
Showing 1 changed file with 28 additions and 14 deletions.
42 changes: 28 additions & 14 deletions experiments/analysis.py → experiments/mae_analysis.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,8 @@
"""
Script to generate a table comparing two runs for MAE values for 48 hour 15 minute forecast
Script to generate analysis of MAE values for multiple model forecasts
Does this for 48 hour horizon forecasts with 15 minute granularity
"""

import argparse
Expand All @@ -10,15 +13,21 @@
import wandb


def main(runs: list[str], run_names: list[str]) -> None:
def main(project: str, runs: list[str], run_names: list[str]) -> None:
"""
Compare two runs for MAE values for 48 hour 15 minute forecast
Compare MAE values for multiple model forecasts for 48 hour horizon with 15 minute granularity
Args:
project: name of W&B project
runs: W&B ids of runs
run_names: user specified names for runs
"""
api = wandb.Api()
dfs = []
epoch_num = []
for run in runs:
run = api.run(f"openclimatefix/PROJECT/{run}")
run = api.run(f"openclimatefix/{project}/{run}")

df = run.history(samples=run.lastHistoryStep + 1)
# Get the columns that are in the format 'MAE_horizon/step_<number>/val'
Expand Down Expand Up @@ -88,36 +97,41 @@ def main(runs: list[str], run_names: list[str]) -> None:
for idx, df in enumerate(dfs):
print(f"{run_names[idx]}: {df.mean()*100:0.3f}")

# Plot the error on per timestep, and all timesteps
# Plot the error per timestep
plt.figure()
for idx, df in enumerate(dfs):
plt.plot(column_timesteps, df, label=f"{run_names[idx]}, epoch: {epoch_num[idx]}")
plt.plot(
column_timesteps, df, label=f"{run_names[idx]}, epoch: {epoch_num[idx]}", linestyle="-"
)
plt.legend()
plt.xlabel("Timestep (minutes)")
plt.ylabel("MAE %")
plt.title("MAE % for each timestep")
plt.savefig("mae_per_timestep.png")
plt.show()

# Plot the error on per timestep, and grouped timesteps
# Plot the error per grouped timestep
plt.figure()
for idx, run_name in enumerate(run_names):
plt.plot(groups_df[run_name], label=f"{run_name}, epoch: {epoch_num[idx]}")
plt.plot(
groups_df[run_name],
label=f"{run_name}, epoch: {epoch_num[idx]}",
marker="o",
linestyle="-",
)
plt.legend()
plt.xlabel("Timestep (minutes)")
plt.ylabel("MAE %")
plt.title("MAE % for each timestep")
plt.savefig("mae_per_timestep.png")
plt.title("MAE % for each grouped timestep")
plt.savefig("mae_per_grouped_timestep.png")
plt.show()


if __name__ == "__main__":
parser = argparse.ArgumentParser()
"5llq8iw6"
parser.add_argument("--first_run", type=str, default="xdlew7ib")
parser.add_argument("--second_run", type=str, default="v3mja33d")
parser.add_argument("--project", type=str, default="")
# Add arguments that are lists of strings
parser.add_argument("--list_of_runs", nargs="+")
parser.add_argument("--run_names", nargs="+")
args = parser.parse_args()
main(args.list_of_runs, args.run_names)
main(args.project, args.list_of_runs, args.run_names)

0 comments on commit acb6a36

Please sign in to comment.