Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

add pv padding and hf model handling to backtest_sites.py #262

Merged
merged 11 commits into from
Oct 18, 2024
89 changes: 83 additions & 6 deletions scripts/backtest_sites.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
except RuntimeError:
pass

import json
import logging
import os
import sys
Expand All @@ -32,6 +33,8 @@
import pandas as pd
import torch
import xarray as xr
from huggingface_hub import hf_hub_download
from huggingface_hub.constants import CONFIG_NAME, PYTORCH_WEIGHTS_NAME
from ocf_datapipes.batch import (
BatchKey,
NumpyBatch,
Expand All @@ -50,7 +53,7 @@
)
from ocf_datapipes.utils.consts import ELEVATION_MEAN, ELEVATION_STD
from omegaconf import DictConfig
from torch.utils.data import DataLoader
from torch.utils.data import DataLoader, IterDataPipe, functional_datapipe
from torch.utils.data.datapipes.iter import IterableWrapper
from tqdm import tqdm

Expand All @@ -67,6 +70,10 @@
# checkpoint on the val set
model_chckpoint_dir = "PLACEHOLDER"

hf_revision = None
hf_token = None
hf_model_id = None

# Forecasts will be made for all available init times between these
start_datetime = "2022-05-08 00:00"
end_datetime = "2022-05-08 00:30"
Expand Down Expand Up @@ -101,11 +108,70 @@
# FUNCTIONS


@functional_datapipe("pad_forward_pv")
class PadForwardPVIterDataPipe(IterDataPipe):
"""
Pads forecast pv.

Sun position is calculated based off of pv time index
and for t0's close to end of pv data can have wrong shape as pv starts
to run out of data to slice for the forecast part.
"""

def __init__(self, pv_dp: IterDataPipe, forecast_duration: np.timedelta64):
"""Init"""

super().__init__()
self.pv_dp = pv_dp
self.forecast_duration = forecast_duration

def __iter__(self):
"""Iter"""

for xr_data in self.pv_dp:
t0 = xr_data.time_utc.data[int(xr_data.attrs["t0_idx"])]
pv_step = np.timedelta64(xr_data.attrs["sample_period_duration"])
t_end = t0 + self.forecast_duration + pv_step
time_idx = np.arange(xr_data.time_utc.data[0], t_end, pv_step)
yield xr_data.reindex(time_utc=time_idx, fill_value=-1)


def load_model_from_hf(model_id: str, revision: str, token: str):
AUdaltsova marked this conversation as resolved.
Show resolved Hide resolved
"""
Loads model from HuggingFace
"""

model_file = hf_hub_download(
repo_id=model_id,
filename=PYTORCH_WEIGHTS_NAME,
revision=revision,
token=token,
)

# load config file
config_file = hf_hub_download(
repo_id=model_id,
filename=CONFIG_NAME,
revision=revision,
token=token,
)

with open(config_file, "r", encoding="utf-8") as f:
config = json.load(f)

model = hydra.utils.instantiate(config)

state_dict = torch.load(model_file, map_location=torch.device("cuda"))
model.load_state_dict(state_dict) # type: ignore
model.eval() # type: ignore

return model


def preds_to_dataarray(preds, model, valid_times, site_ids):
"""Put numpy array of predictions into a dataarray"""

if model.use_quantile_regression:
output_labels = model.output_quantiles
output_labels = [f"forecast_mw_plevel_{int(q*100):02}" for q in model.output_quantiles]
output_labels[output_labels.index("forecast_mw_plevel_50")] = "forecast_mw"
else:
Expand Down Expand Up @@ -333,7 +399,7 @@ def predict_batch(self, batch: NumpyBatch) -> xr.Dataset:
da_abs_site = da_abs_site.where(~da_sundown_mask).fillna(0.0)

da_abs_site = da_abs_site.expand_dims(dim="init_time_utc", axis=0).assign_coords(
init_time_utc=[t0]
init_time_utc=np.array([t0], dtype="datetime64[ns]")
)

return da_abs_site
Expand Down Expand Up @@ -362,6 +428,11 @@ def get_datapipe(config_path: str) -> NumpyBatch:
t0_datapipe,
)

config = load_yaml_configuration(config_path)
data_pipeline["pv"] = data_pipeline["pv"].pad_forward_pv(
forecast_duration=np.timedelta64(config.input_data.pv.forecast_minutes, "m")
)

data_pipeline = DictDatasetIterDataPipe(
{k: v for k, v in data_pipeline.items() if k != "config"},
).map(split_dataset_dict_dp)
Expand Down Expand Up @@ -412,7 +483,13 @@ def main(config: DictConfig):
# Create a dataloader for the concurrent batches and use multiprocessing
dataloader = DataLoader(batch_pipe, **dataloader_kwargs)
# Load the PVNet model
model, *_ = get_model_from_checkpoints([model_chckpoint_dir], val_best=True)
if model_chckpoint_dir:
model, *_ = get_model_from_checkpoints([model_chckpoint_dir], val_best=True)
elif hf_model_id:
model = load_model_from_hf(hf_model_id, hf_revision, hf_token)
else:
raise ValueError("Provide a model checkpoint or a HuggingFace model")

model = model.eval().to(device)

# Create object to make predictions for each input batch
Expand All @@ -426,13 +503,13 @@ def main(config: DictConfig):

t0 = ds_abs_all.init_time_utc.values[0]

# Save the predictioons
# Save the predictions
filename = f"{output_dir}/{t0}.nc"
ds_abs_all.to_netcdf(filename)

pbar.update()
except Exception as e:
print(f"Exception {e} at {i}")
print(f"Exception {e} at batch {i}")
pass

# Close down
Expand Down
Loading