Merge branch 'main' of https://github.com/openclimatefix/PVNet into main
dfulu committed Sep 16, 2024
2 parents 6e046b9 + 76651e9 commit 1ec1561
Showing 6 changed files with 85 additions and 15 deletions.
2 changes: 1 addition & 1 deletion .bumpversion.cfg
@@ -1,7 +1,7 @@
 [bumpversion]
 commit = True
 tag = True
-current_version = 3.0.53
+current_version = 3.0.56
 message = Bump version: {current_version} → {new_version} [skip ci]

 [bumpversion:file:pvnet/__init__.py]
5 changes: 4 additions & 1 deletion README.md
@@ -1,6 +1,7 @@
 # PVNet 2.1

-[![Python Bump Version & release](https://github.com/openclimatefix/PVNet/actions/workflows/release.yml/badge.svg)](https://github.com/openclimatefix/PVNet/actions/workflows/release.yml)
+[![Python Bump Version & release](https://github.com/openclimatefix/PVNet/actions/workflows/release.yml/badge.svg)](https://github.com/openclimatefix/PVNet/actions/workflows/release.yml) [![ease of contribution: hard](https://img.shields.io/badge/ease%20of%20contribution:%20hard-bb2629)](https://github.com/openclimatefix/ocf-meta-repo?tab=readme-ov-file#overview-of-ocfs-nowcasting-repositories)
+

 This project is used for training PVNet and running PVNet on live data.
@@ -85,6 +86,8 @@
 OCF maintains a Zarr formatted version of the German Weather Service's (DWD)
 ICON-EU NWP model here:
 https://huggingface.co/datasets/openclimatefix/dwd-icon-eu which includes the UK

+Please note that the current version of the [ICON loader](https://github.com/openclimatefix/ocf_datapipes/blob/9ec252eeee44937c12ab52699579bdcace76e72f/ocf_datapipes/load/nwp/providers/icon.py#L9-L30) supports a different format. If you want to use our ICON-EU dataset or your own NWP source, you can create a loader for it using [the instructions here](https://github.com/openclimatefix/ocf_datapipes/tree/main/ocf_datapipes/load#nwp).
+
 **PV**\
 OCF maintains a dataset of PV generation from 1311 private PV installations
 here: https://huggingface.co/datasets/openclimatefix/uk_pv
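For readers following the note added to the README above: in ocf_datapipes, a provider loader is essentially a function that opens the raw data and returns a standardised `xarray.DataArray`. The sketch below is illustrative only — the function name `open_mynwp` and the raw coordinate names are assumptions, not the library's API; the linked instructions are authoritative.

```python
import xarray as xr


def open_mynwp(zarr_path: str) -> xr.DataArray:
    """Hypothetical loader for a custom NWP source (illustrative sketch)."""
    ds = xr.open_zarr(zarr_path)
    # Rename the raw coordinates to standardised names expected downstream
    # (the raw names "time" and "prediction_timedelta" are assumptions).
    ds = ds.rename({"time": "init_time_utc", "prediction_timedelta": "step"})
    # Collapse the data variables into a single channel dimension.
    da = ds.to_array(dim="channel")
    # Return with a consistent dimension order.
    return da.transpose("init_time_utc", "step", "channel", "latitude", "longitude")
```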
12 changes: 7 additions & 5 deletions experiments/analysis.py
@@ -16,10 +16,11 @@ def main(runs: list[str], run_names: list[str]) -> None:
"""
     api = wandb.Api()
     dfs = []
+    epoch_num = []
     for run in runs:
-        run = api.run(f"openclimatefix/india/{run}")
+        run = api.run(f"openclimatefix/PROJECT/{run}")

-        df = run.history()
+        df = run.history(samples=run.lastHistoryStep + 1)
         # Get the columns that are in the format `MAE_horizon/step_<number>/val`
         mae_cols = [col for col in df.columns if "MAE_horizon/step_" in col and "val" in col]
         # Sort them
@@ -40,6 +41,7 @@ def main(runs: list[str], run_names: list[str]) -> None:
         # Get the step from the column name
         column_timesteps = [int(col.split("_")[-1].split("/")[0]) * 15 for col in mae_cols]
         dfs.append(df)
+        epoch_num.append(min_row_idx)
     # Get the timedelta for each group
     groupings = [
         [0, 0],
@@ -89,7 +91,7 @@ def main(runs: list[str], run_names: list[str]) -> None:
     # Plot the error on per timestep, and all timesteps
     plt.figure()
     for idx, df in enumerate(dfs):
-        plt.plot(column_timesteps, df, label=run_names[idx])
+        plt.plot(column_timesteps, df, label=f"{run_names[idx]}, epoch: {epoch_num[idx]}")
     plt.legend()
     plt.xlabel("Timestep (minutes)")
     plt.ylabel("MAE %")
@@ -99,8 +101,8 @@ def main(runs: list[str], run_names: list[str]) -> None:

     # Plot the error on per timestep, and grouped timesteps
     plt.figure()
-    for run_name in run_names:
-        plt.plot(groups_df[run_name], label=run_name)
+    for idx, run_name in enumerate(run_names):
+        plt.plot(groups_df[run_name], label=f"{run_name}, epoch: {epoch_num[idx]}")
     plt.legend()
     plt.xlabel("Timestep (minutes)")
     plt.ylabel("MAE %")
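A note on the `run.history(samples=run.lastHistoryStep + 1)` change above: by default, `run.history()` returns a sampled subset of the logged rows (wandb samples down to roughly 500), which can drop the epoch with the true minimum validation MAE. Passing `lastHistoryStep + 1` as `samples` requests every row. A minimal sketch — the entity/project path and run id are placeholders:

```python
import wandb

api = wandb.Api()
run = api.run("openclimatefix/PROJECT/abc123")  # placeholder run path

# Default call: wandb returns a sampled history, which may skip rows.
sampled_df = run.history()

# Requesting lastHistoryStep + 1 samples returns every logged row.
full_df = run.history(samples=run.lastHistoryStep + 1)
print(len(sampled_df), len(full_df))
```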
2 changes: 1 addition & 1 deletion pvnet/__init__.py
@@ -1,2 +1,2 @@
"""PVNet"""
-__version__ = "3.0.53"
+__version__ = "3.0.56"
76 changes: 72 additions & 4 deletions pvnet/models/base_model.py
@@ -2,6 +2,7 @@
 import json
 import logging
 import os
+import tempfile
 from pathlib import Path
 from typing import Dict, Optional, Union

@@ -13,7 +14,7 @@
 import torch.nn.functional as F
 import wandb
 import yaml
-from huggingface_hub import ModelCard, ModelCardData
+from huggingface_hub import ModelCard, ModelCardData, PyTorchModelHubMixin
 from huggingface_hub.constants import CONFIG_NAME, PYTORCH_WEIGHTS_NAME
 from huggingface_hub.file_download import hf_hub_download
 from huggingface_hub.hf_api import HfApi
@@ -140,7 +141,7 @@ def minimize_data_config(input_path, output_path, model):
         yaml.dump(config, outfile, default_flow_style=False)


-class PVNetModelHubMixin:
+class PVNetModelHubMixin(PyTorchModelHubMixin):
     """
     Implementation of [`PyTorchModelHubMixin`] to provide model Hub upload/download capabilities.
     """
@@ -410,6 +411,9 @@ def __init__(
         else:
             self.num_output_features = self.forecast_len

+        # Save all validation results to a list so they can be logged to Weights & Biases
+        self.validation_epoch_results = []
+
     def _quantiles_to_prediction(self, y_quantiles):
         """
         Convert network prediction into a point prediction.
@@ -609,12 +613,61 @@ def _log_forecast_plot(self, batch, y_hat, accum_batch_num, timesteps_to_plot, p
             print(e)
         plt.close(fig)

+    def _log_validation_results(self, batch, y_hat, accum_batch_num):
+        """Append validation results to self.validation_epoch_results"""
+
+        # Get truth values, shape (b, forecast_len)
+        y = batch[self._target_key][:, -self.forecast_len :, 0]
+        y = y.detach().cpu().numpy()
+        batch_size = y.shape[0]
+
+        # Get prediction values, shape (b, forecast_len, quantiles?)
+        y_hat = y_hat.detach().cpu().numpy()
+
+        # Get time_utc, shape (b, forecast_len)
+        time_utc_key = BatchKey[f"{self._target_key_name}_time_utc"]
+        time_utc = batch[time_utc_key][:, -self.forecast_len :].detach().cpu().numpy()
+
+        # Get target id and change from (b, 1) to (b,)
+        id_key = BatchKey[f"{self._target_key_name}_id"]
+        target_id = batch[id_key].detach().cpu().numpy()
+        target_id = target_id.squeeze()
+
+        for i in range(batch_size):
+            y_i = y[i]
+            y_hat_i = y_hat[i]
+            time_utc_i = time_utc[i]
+            target_id_i = target_id[i]
+
+            results_dict = {
+                "y": y_i,
+                "time_utc": time_utc_i,
+            }
+            if self.use_quantile_regression:
+                results_dict.update(
+                    {f"y_quantile_{q}": y_hat_i[:, q_idx] for q_idx, q in enumerate(self.output_quantiles)}
+                )
+            else:
+                results_dict["y_hat"] = y_hat_i
+
+            results_df = pd.DataFrame(results_dict)
+            results_df["id"] = target_id_i
+            results_df["batch_idx"] = accum_batch_num
+            results_df["example_idx"] = i
+
+            self.validation_epoch_results.append(results_df)
+
     def validation_step(self, batch: dict, batch_idx):
         """Run validation step"""
+
+        accum_batch_num = batch_idx // self.trainer.accumulate_grad_batches
+
         y_hat = self(batch)
         # Sensor seems to be in batch, station, time order
         y = batch[self._target_key][:, -self.forecast_len :, 0]

+        self._log_validation_results(batch, y_hat, accum_batch_num)
+
         # Expand persistence to be the same shape as y
         losses = self._calculate_common_losses(y, y_hat)
         losses.update(self._calculate_val_losses(y, y_hat))
@@ -632,8 +685,6 @@ def validation_step(self, batch: dict, batch_idx):
                 on_epoch=True,
             )

-        accum_batch_num = batch_idx // self.trainer.accumulate_grad_batches
-
         # Make plots only if using wandb logger
         if isinstance(self.logger, pl.loggers.WandbLogger) and accum_batch_num in [0, 1]:
             # Store these temporarily under self
@@ -675,6 +726,23 @@ def validation_step(self, batch: dict, batch_idx):
     def on_validation_epoch_end(self):
         """Run on epoch end"""

+        try:
+            # Join together validation results and save them to wandb
+            validation_results_df = pd.concat(self.validation_epoch_results)
+            with tempfile.TemporaryDirectory() as tempdir:
+                filename = os.path.join(tempdir, f"validation_results_{self.current_epoch}.csv")
+                validation_results_df.to_csv(filename, index=False)
+
+                # Make and log a wandb artifact
+                validation_artifact = wandb.Artifact(
+                    f"validation_results_epoch={self.current_epoch}", type="dataset"
+                )
+                validation_artifact.add_file(filename)
+                wandb.log_artifact(validation_artifact)
+        except Exception as e:
+            print("Failed to log validation results to wandb")
+            print(e)
+
         horizon_maes_dict = self._horizon_maes.flush()

         # Create the horizon accuracy curve
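The validation results artifact logged in `on_validation_epoch_end` above can be retrieved from wandb after training. A minimal sketch, assuming access to the run's project — the `entity/project` path and epoch number are placeholders; the artifact and file names mirror the f-strings in the diff:

```python
import pandas as pd
import wandb

api = wandb.Api()
# Placeholder path; the artifact name matches f"validation_results_epoch={epoch}".
artifact = api.artifact("entity/project/validation_results_epoch=0:latest", type="dataset")
artifact_dir = artifact.download()

# The CSV name matches f"validation_results_{epoch}.csv" from the diff.
df = pd.read_csv(f"{artifact_dir}/validation_results_0.csv")
print(df.head())
```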
3 changes: 0 additions & 3 deletions tests/test_data/sample_wind_batches/data_configuration.yaml
@@ -5,7 +5,6 @@ general:
 input_data:
   default_forecast_minutes: 2880
   default_history_minutes: 60
-  data_source_which_defines_geospatial_locations: "wind"
   nwp:
     ecmwf:
       # Path to ECMWF NWP data in zarr format
@@ -64,15 +63,13 @@ input_data:
       - label: india
         wind_filename: /mnt/storage_ssd_4tb/india_wind_data.nc
         wind_metadata_filename: /mnt/storage_ssd_4tb/india_wind_metadata.csv
-        get_center: true
         n_wind_systems_per_example: 1
         #start_datetime: "2021-01-01 00:00:00"
         #end_datetime: "2024-01-01 00:00:00"
   sensor:
     #sensor_files_groups:
     # - label: meteomatics
     sensor_filename: "/mnt/storage_b/nwp/meteomatics/nw_india/wind*.zarr.zip"
-    get_center: false
     history_minutes: 60
     forecast_minutes: 2880
     #n_sensor_systems_per_example: 26
