diff --git a/experiments/analysis.py b/experiments/mae_analysis.py
similarity index 74%
rename from experiments/analysis.py
rename to experiments/mae_analysis.py
index bb119664..ac01aed2 100644
--- a/experiments/analysis.py
+++ b/experiments/mae_analysis.py
@@ -1,5 +1,8 @@
 """
-Script to generate a table comparing two run for MAE values for 48 hour 15 minute forecast
+Script to generate analysis of MAE values for multiple model forecasts
+
+Does this for 48 hour horizon forecasts with 15 minute granularity
+
 """
 
 import argparse
@@ -10,15 +13,21 @@
 import wandb
 
 
-def main(runs: list[str], run_names: list[str]) -> None:
+def main(project: str, runs: list[str], run_names: list[str]) -> None:
     """
-    Compare two runs for MAE values for 48 hour 15 minute forecast
+    Compare MAE values for multiple model forecasts for 48 hour horizon with 15 minute granularity
+
+    Args:
+            project: name of W&B project
+            runs: W&B ids of runs
+            run_names: user specified names for runs
+
     """
     api = wandb.Api()
     dfs = []
     epoch_num = []
     for run in runs:
-        run = api.run(f"openclimatefix/PROJECT/{run}")
+        run = api.run(f"openclimatefix/{project}/{run}")
         df = run.history(samples=run.lastHistoryStep + 1)
 
         # Get the columns that are in the format 'MAE_horizon/step_/val`
@@ -88,10 +97,12 @@ def main(runs: list[str], run_names: list[str]) -> None:
     for idx, df in enumerate(dfs):
         print(f"{run_names[idx]}: {df.mean()*100:0.3f}")
 
-    # Plot the error on per timestep, and all timesteps
+    # Plot the error per timestep
     plt.figure()
     for idx, df in enumerate(dfs):
-        plt.plot(column_timesteps, df, label=f"{run_names[idx]}, epoch: {epoch_num[idx]}")
+        plt.plot(
+            column_timesteps, df, label=f"{run_names[idx]}, epoch: {epoch_num[idx]}", linestyle="-"
+        )
     plt.legend()
     plt.xlabel("Timestep (minutes)")
     plt.ylabel("MAE %")
@@ -99,25 +110,28 @@ def main(runs: list[str], run_names: list[str]) -> None:
     plt.savefig("mae_per_timestep.png")
     plt.show()
 
-    # Plot the error on per timestep, and grouped timesteps
+    # Plot the error per grouped timestep
     plt.figure()
     for idx, run_name in enumerate(run_names):
-        plt.plot(groups_df[run_name], label=f"{run_name}, epoch: {epoch_num[idx]}")
+        plt.plot(
+            groups_df[run_name],
+            label=f"{run_name}, epoch: {epoch_num[idx]}",
+            marker="o",
+            linestyle="-",
+        )
     plt.legend()
     plt.xlabel("Timestep (minutes)")
     plt.ylabel("MAE %")
-    plt.title("MAE % for each timestep")
-    plt.savefig("mae_per_timestep.png")
+    plt.title("MAE % for each grouped timestep")
+    plt.savefig("mae_per_grouped_timestep.png")
     plt.show()
 
 
 if __name__ == "__main__":
     parser = argparse.ArgumentParser()
-    "5llq8iw6"
-    parser.add_argument("--first_run", type=str, default="xdlew7ib")
-    parser.add_argument("--second_run", type=str, default="v3mja33d")
+    parser.add_argument("--project", type=str, default="")
     # Add arguments that is a list of strings
     parser.add_argument("--list_of_runs", nargs="+")
     parser.add_argument("--run_names", nargs="+")
     args = parser.parse_args()
-    main(args.list_of_runs, args.run_names)
+    main(args.project, args.list_of_runs, args.run_names)