From 0b853b17d106ffd0a5587a0d76aea17139d496bb Mon Sep 17 00:00:00 2001 From: Alexandra Udaltsova <43303448+AUdaltsova@users.noreply.github.com> Date: Wed, 23 Oct 2024 12:10:08 +0100 Subject: [PATCH] Update backtest_sites.py --- scripts/backtest_sites.py | 26 +++++++++++++++++++------- 1 file changed, 19 insertions(+), 7 deletions(-) diff --git a/scripts/backtest_sites.py b/scripts/backtest_sites.py index 3572daa3..e4b63223 100644 --- a/scripts/backtest_sites.py +++ b/scripts/backtest_sites.py @@ -118,21 +118,31 @@ class PadForwardPVIterDataPipe(IterDataPipe): to run out of data to slice for the forecast part. """ - def __init__(self, pv_dp: IterDataPipe, forecast_duration: np.timedelta64): + def __init__(self, pv_dp: IterDataPipe, + forecast_duration: np.timedelta64, + history_duration: np.timedelta64, + time_resolution_minutes: np.timedelta64 + ): """Init""" super().__init__() self.pv_dp = pv_dp self.forecast_duration = forecast_duration + self.history_duration = history_duration + self.time_resolution_minutes = time_resolution_minutes + + self.min_seq_length = history_duration // time_resolution_minutes def __iter__(self): """Iter""" for xr_data in self.pv_dp: - t0 = xr_data.time_utc.data[int(xr_data.attrs["t0_idx"])] - pv_step = np.timedelta64(xr_data.attrs["sample_period_duration"]) - t_end = t0 + self.forecast_duration + pv_step - time_idx = np.arange(xr_data.time_utc.data[0], t_end, pv_step) + t_end = xr_data.time_utc.data[0] + self.history_duration + self.forecast_duration + self.time_resolution_minutes + time_idx = np.arange(xr_data.time_utc.data[0], t_end, self.time_resolution_minutes) + + if len(xr_data.time_utc.data) < self.min_seq_length: + raise ValueError("Not enough PV data to predict") + yield xr_data.reindex(time_utc=time_idx, fill_value=-1) @@ -429,8 +439,10 @@ def get_datapipe(config_path: str) -> NumpyBatch: ) config = load_yaml_configuration(config_path) - data_pipeline["pv"] = data_pipeline["pv"].pad_forward_pv( - forecast_duration=np.timedelta64(config.input_data.pv.forecast_minutes, "m") + data_pipeline['pv'] = data_pipeline['pv'].pad_forward_pv( + forecast_duration=np.timedelta64(config.input_data.pv.forecast_minutes, 'm'), + history_duration=np.timedelta64(config.input_data.pv.history_minutes, 'm'), + time_resolution_minutes=np.timedelta64(config.input_data.pv.time_resolution_minutes, 'm') ) data_pipeline = DictDatasetIterDataPipe(