From 31581ac901dda0c66ee91b0bdba94003c66386cd Mon Sep 17 00:00:00 2001
From: AUdaltsova
Date: Thu, 17 Oct 2024 15:48:46 +0100
Subject: [PATCH] added pv padding and hf model handling to backtest_sites.py

---
 scripts/backtest_sites.py | 84 ++++++++++++++++++++++++++++++++++++---
 1 file changed, 78 insertions(+), 6 deletions(-)

diff --git a/scripts/backtest_sites.py b/scripts/backtest_sites.py
index e764abf8..b2482466 100644
--- a/scripts/backtest_sites.py
+++ b/scripts/backtest_sites.py
@@ -50,13 +50,17 @@
 )
 from ocf_datapipes.utils.consts import ELEVATION_MEAN, ELEVATION_STD
 from omegaconf import DictConfig
-from torch.utils.data import DataLoader
+from torch.utils.data import DataLoader, IterDataPipe, functional_datapipe
 from torch.utils.data.datapipes.iter import IterableWrapper
 from tqdm import tqdm
 
 from pvnet.load_model import get_model_from_checkpoints
 from pvnet.utils import SiteLocationLookup
 
+import json
+
+from huggingface_hub import hf_hub_download
+from huggingface_hub.constants import CONFIG_NAME, PYTORCH_WEIGHTS_NAME
+
 # ------------------------------------------------------------------
 # USER CONFIGURED VARIABLES TO RUN THE SCRIPT
 
@@ -67,6 +71,10 @@
 # checkpoint on the val set
 model_chckpoint_dir = "PLACEHOLDER"
 
+revision = None
+token = None
+model_id = None
+
 # Forecasts will be made for all available init times between these
 start_datetime = "2022-05-08 00:00"
 end_datetime = "2022-05-08 00:30"
@@ -101,11 +109,64 @@
 
 # FUNCTIONS
 
+@functional_datapipe("pad_forward_pv")
+class PadForwardPVIterDataPipe(IterDataPipe):
+    """
+    Pads the forecast PV data. Sun position is calculated from the PV time
+    index, and for t0s close to the end of the PV data it can have the wrong
+    shape as the PV runs out of data to slice for the forecast part.
+    """
+
+    def __init__(self, pv_dp: IterDataPipe, forecast_duration: np.timedelta64):
+        """Init"""
+
+        super().__init__()
+        self.pv_dp = pv_dp
+        self.forecast_duration = forecast_duration
+
+    def __iter__(self):
+        """Iter"""
+
+        for xr_data in self.pv_dp:
+            t0 = xr_data.time_utc.data[int(xr_data.attrs["t0_idx"])]
+            pv_step = np.timedelta64(xr_data.attrs["sample_period_duration"])
+            t_end = t0 + self.forecast_duration + pv_step
+            time_idx = np.arange(xr_data.time_utc.data[0], t_end, pv_step)
+            yield xr_data.reindex(time_utc=time_idx, fill_value=-1)
+
+
+def load_model_from_hf(model_id: str, revision: str, token: str):
+    model_file = hf_hub_download(
+        repo_id=model_id,
+        filename=PYTORCH_WEIGHTS_NAME,
+        revision=revision,
+        token=token,
+    )
+
+    # Load config file
+    config_file = hf_hub_download(
+        repo_id=model_id,
+        filename=CONFIG_NAME,
+        revision=revision,
+        token=token,
+    )
+
+    with open(config_file, "r", encoding="utf-8") as f:
+        config = json.load(f)
+
+    model = hydra.utils.instantiate(config)
+
+    state_dict = torch.load(model_file, map_location=torch.device("cuda"))
+    model.load_state_dict(state_dict)  # type: ignore
+    model.eval()  # type: ignore
+
+    return model
+
+
 def preds_to_dataarray(preds, model, valid_times, site_ids):
     """Put numpy array of predictions into a dataarray"""
 
     if model.use_quantile_regression:
-        output_labels = model.output_quantiles
         output_labels = [f"forecast_mw_plevel_{int(q*100):02}" for q in model.output_quantiles]
         output_labels[output_labels.index("forecast_mw_plevel_50")] = "forecast_mw"
     else:
@@ -333,7 +394,7 @@ def predict_batch(self, batch: NumpyBatch) -> xr.Dataset:
         da_abs_site = da_abs_site.where(~da_sundown_mask).fillna(0.0)
 
         da_abs_site = da_abs_site.expand_dims(dim="init_time_utc", axis=0).assign_coords(
-            init_time_utc=[t0]
+            init_time_utc=np.array([t0], dtype="datetime64[ns]")
         )
 
         return da_abs_site
@@ -362,6 +423,11 @@ def get_datapipe(config_path: str) -> NumpyBatch:
         t0_datapipe,
     )
 
+    config = load_yaml_configuration(config_path)
+    data_pipeline["pv"] = data_pipeline["pv"].pad_forward_pv(
+        forecast_duration=np.timedelta64(config.input_data.pv.forecast_minutes, "m")
+    )
+
     data_pipeline = DictDatasetIterDataPipe(
         {k: v for k, v in data_pipeline.items() if k != "config"},
     ).map(split_dataset_dict_dp)
@@ -412,7 +478,13 @@ def main(config: DictConfig):
     # Create a dataloader for the concurrent batches and use multiprocessing
     dataloader = DataLoader(batch_pipe, **dataloader_kwargs)
 
     # Load the PVNet model
-    model, *_ = get_model_from_checkpoints([model_chckpoint_dir], val_best=True)
+    if model_chckpoint_dir is not None:
+        model, *_ = get_model_from_checkpoints([model_chckpoint_dir], val_best=True)
+    elif model_id is not None:
+        model = load_model_from_hf(model_id, revision, token)
+    else:
+        raise ValueError("Provide a model checkpoint or a HuggingFace model")
+
     model = model.eval().to(device)
 
     # Create object to make predictions for each input batch
@@ -426,13 +498,13 @@
 
             t0 = ds_abs_all.init_time_utc.values[0]
 
-            # Save the predictioons
+            # Save the predictions
             filename = f"{output_dir}/{t0}.nc"
             ds_abs_all.to_netcdf(filename)
 
             pbar.update()
         except Exception as e:
-            print(f"Exception {e} at {i}")
+            print(f"Exception {e} at batch {i}")
             pass
 
     # Close down
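
For reference, a minimal sketch (not part of the patch) of the reindex that the
new pad_forward_pv step performs, on toy data. The times, values, and 30-minute
step are made up for illustration; only the t0_idx and sample_period_duration
attrs mirror what PadForwardPVIterDataPipe reads above:

    import numpy as np
    import xarray as xr

    # Toy PV series: five 30-minute steps with t0 at index 2, so only two
    # steps of data remain after t0 even though the forecast needs more.
    times = np.arange(
        np.datetime64("2022-05-08T00:00"),
        np.datetime64("2022-05-08T02:30"),
        np.timedelta64(30, "m"),
    )
    da = xr.DataArray(np.ones(len(times)), coords={"time_utc": times}, dims="time_utc")
    da.attrs = {"t0_idx": 2, "sample_period_duration": np.timedelta64(30, "m")}

    # Same arithmetic as PadForwardPVIterDataPipe.__iter__: extend the time
    # index out to t0 + forecast_duration (+ one step) and fill the missing
    # tail with -1.
    forecast_duration = np.timedelta64(120, "m")
    t0 = da.time_utc.data[int(da.attrs["t0_idx"])]
    pv_step = np.timedelta64(da.attrs["sample_period_duration"])
    time_idx = np.arange(da.time_utc.data[0], t0 + forecast_duration + pv_step, pv_step)
    padded = da.reindex(time_utc=time_idx, fill_value=-1)

    print(len(da.time_utc), len(padded.time_utc))  # 5 -> 7, tail padded with -1

This keeps the sun-position features (computed from the pv time index) at the
shape the model expects even for init times near the end of the PV data.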
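Likewise, a hypothetical call of the new HuggingFace path in main(), assuming
the script's user-configured variables. The repo id below is a placeholder, not
a real checkpoint, and note that load_model_from_hf as written maps the weights
to "cuda", so it assumes a GPU host:

    # Placeholder values; model_id must point at a repo containing the files
    # named by PYTORCH_WEIGHTS_NAME and CONFIG_NAME.
    model = load_model_from_hf(
        model_id="openclimatefix/example-site-model",  # hypothetical repo id
        revision="main",
        token=None,  # only needed for private repos
    )
    model = model.eval().to(device)  # `device` as configured in the script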