added pv padding and hf model handling to backtest_sites.py
AUdaltsova committed Oct 17, 2024
1 parent aa76d4a commit 31581ac
Showing 1 changed file with 78 additions and 6 deletions.
84 changes: 78 additions & 6 deletions scripts/backtest_sites.py
@@ -50,13 +50,17 @@
)
from ocf_datapipes.utils.consts import ELEVATION_MEAN, ELEVATION_STD
from omegaconf import DictConfig
from torch.utils.data import DataLoader
from torch.utils.data import DataLoader, IterDataPipe, functional_datapipe
from torch.utils.data.datapipes.iter import IterableWrapper
from tqdm import tqdm

from pvnet.load_model import get_model_from_checkpoints
from pvnet.utils import SiteLocationLookup

import json
from huggingface_hub import hf_hub_download
from huggingface_hub.constants import CONFIG_NAME, PYTORCH_WEIGHTS_NAME

# ------------------------------------------------------------------
# USER CONFIGURED VARIABLES TO RUN THE SCRIPT

@@ -67,6 +71,10 @@
# checkpoint on the val set
model_chckpoint_dir = "PLACEHOLDER"

revision = None
token = None
model_id = None
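# For example, to load a model from HuggingFace Hub instead of a local checkpoint,
# set model_chckpoint_dir = None and fill these in (illustrative values only, not
# from this commit):
# model_id = "openclimatefix/pvnet_example"
# revision = "main"
# token = None  # only needed for private repos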

# Forecasts will be made for all available init times between these
start_datetime = "2022-05-08 00:00"
end_datetime = "2022-05-08 00:30"
@@ -101,11 +109,64 @@
# FUNCTIONS


@functional_datapipe("pad_forward_pv")
class PadForwardPVIterDataPipe(IterDataPipe):
    """
    Pads the PV data forward in time to cover the forecast horizon. Sun position is
    calculated from the PV time index, so for t0s close to the end of the PV data it
    can end up with the wrong shape as the PV data runs out of values to slice for
    the forecast part.
    """

    def __init__(self, pv_dp: IterDataPipe, forecast_duration: np.timedelta64):
        """Init"""

        super().__init__()
        self.pv_dp = pv_dp
        self.forecast_duration = forecast_duration

    def __iter__(self):
        """Iter"""

        for xr_data in self.pv_dp:
            t0 = xr_data.time_utc.data[int(xr_data.attrs["t0_idx"])]
            pv_step = np.timedelta64(xr_data.attrs["sample_period_duration"])
            t_end = t0 + self.forecast_duration + pv_step
            # Extend the time index out to the end of the forecast horizon; PV values
            # that are not available are filled with -1
            time_idx = np.arange(xr_data.time_utc.data[0], t_end, pv_step)
            yield xr_data.reindex(time_utc=time_idx, fill_value=-1)

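# ------------------------------------------------------------------
# Illustrative sketch only (not part of this commit): how the padding behaves on toy
# half-hourly PV data. Assumes `numpy as np` and `xarray as xr` are imported elsewhere
# in this script; `IterableWrapper` is imported above. The values below are made up.
#
#   times = np.arange(
#       np.datetime64("2022-05-08T00:00"),
#       np.datetime64("2022-05-08T02:00"),
#       np.timedelta64(30, "m"),
#   )
#   da = xr.DataArray(
#       np.ones(len(times)),
#       coords={"time_utc": times},
#       dims="time_utc",
#       attrs={"t0_idx": 1, "sample_period_duration": np.timedelta64(30, "m")},
#   )
#   dp = IterableWrapper([da]).pad_forward_pv(forecast_duration=np.timedelta64(180, "m"))
#   padded = next(iter(dp))
#   # padded.time_utc now extends out to t0 + forecast_duration; the new steps hold -1
# ------------------------------------------------------------------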

def load_model_from_hf(model_id: str, revision: str, token: str):
    """Load a PVNet model from a HuggingFace Hub repo at a given revision."""

    # download model weights file
    model_file = hf_hub_download(
        repo_id=model_id,
        filename=PYTORCH_WEIGHTS_NAME,
        revision=revision,
        token=token,
    )

    # load config file
    config_file = hf_hub_download(
        repo_id=model_id,
        filename=CONFIG_NAME,
        revision=revision,
        token=token,
    )

    with open(config_file, "r", encoding="utf-8") as f:
        config = json.load(f)

    model = hydra.utils.instantiate(config)

    state_dict = torch.load(model_file, map_location=torch.device("cuda"))
    model.load_state_dict(state_dict)  # type: ignore
    model.eval()  # type: ignore

    return model
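# Illustrative usage only (not part of this commit; the repo id is hypothetical):
#   model = load_model_from_hf("openclimatefix/pvnet_example", revision="main", token=None)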


def preds_to_dataarray(preds, model, valid_times, site_ids):
    """Put numpy array of predictions into a dataarray"""

    if model.use_quantile_regression:
        output_labels = model.output_quantiles
        output_labels = [f"forecast_mw_plevel_{int(q*100):02}" for q in model.output_quantiles]
        output_labels[output_labels.index("forecast_mw_plevel_50")] = "forecast_mw"
    else:
@@ -333,7 +394,7 @@ def predict_batch(self, batch: NumpyBatch) -> xr.Dataset:
        da_abs_site = da_abs_site.where(~da_sundown_mask).fillna(0.0)

        da_abs_site = da_abs_site.expand_dims(dim="init_time_utc", axis=0).assign_coords(
            init_time_utc=[t0]
            init_time_utc=np.array([t0], dtype="datetime64[ns]")
        )

        return da_abs_site
@@ -362,6 +423,11 @@ def get_datapipe(config_path: str) -> NumpyBatch:
        t0_datapipe,
    )

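    # Pad the PV data forward so the sun-position features keep a consistent shape for
    # t0s near the end of the available PV data (see PadForwardPVIterDataPipe above)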
    config = load_yaml_configuration(config_path)
    data_pipeline["pv"] = data_pipeline["pv"].pad_forward_pv(
        forecast_duration=np.timedelta64(config.input_data.pv.forecast_minutes, "m")
    )

    data_pipeline = DictDatasetIterDataPipe(
        {k: v for k, v in data_pipeline.items() if k != "config"},
    ).map(split_dataset_dict_dp)
@@ -412,7 +478,13 @@ def main(config: DictConfig):
    # Create a dataloader for the concurrent batches and use multiprocessing
    dataloader = DataLoader(batch_pipe, **dataloader_kwargs)
    # Load the PVNet model
    model, *_ = get_model_from_checkpoints([model_chckpoint_dir], val_best=True)
    if model_chckpoint_dir is not None:
        model, *_ = get_model_from_checkpoints([model_chckpoint_dir], val_best=True)
    elif model_id is not None:
        model = load_model_from_hf(model_id, revision, token)
    else:
        raise ValueError("Provide a model checkpoint or a HuggingFace model")

    model = model.eval().to(device)

    # Create object to make predictions for each input batch
@@ -426,13 +498,13 @@

            t0 = ds_abs_all.init_time_utc.values[0]

            # Save the predictioons
            # Save the predictions
            filename = f"{output_dir}/{t0}.nc"
            ds_abs_all.to_netcdf(filename)

            pbar.update()
        except Exception as e:
            print(f"Exception {e} at {i}")
            print(f"Exception {e} at batch {i}")
            pass

    # Close down