diff --git a/scripts/backtest_sites.py b/scripts/backtest_sites.py index e4b63223..65d63be2 100644 --- a/scripts/backtest_sites.py +++ b/scripts/backtest_sites.py @@ -118,11 +118,13 @@ class PadForwardPVIterDataPipe(IterDataPipe): to run out of data to slice for the forecast part. """ - def __init__(self, pv_dp: IterDataPipe, - forecast_duration: np.timedelta64, - history_duration: np.timedelta64, - time_resolution_minutes: np.timedelta64 - ): + def __init__( + self, + pv_dp: IterDataPipe, + forecast_duration: np.timedelta64, + history_duration: np.timedelta64, + time_resolution_minutes: np.timedelta64, + ): """Init""" super().__init__() @@ -137,7 +139,12 @@ def __iter__(self): """Iter""" for xr_data in self.pv_dp: - t_end = xr_data.time_utc.data[0] + self.history_duration + self.forecast_duration + self.time_resolution_minutes + t_end = ( + xr_data.time_utc.data[0] + + self.history_duration + + self.forecast_duration + + self.time_resolution_minutes + ) time_idx = np.arange(xr_data.time_utc.data[0], t_end, self.time_resolution_minutes) if len(xr_data.time_utc.data) < self.min_seq_length: @@ -439,10 +446,10 @@ def get_datapipe(config_path: str) -> NumpyBatch: ) config = load_yaml_configuration(config_path) - data_pipeline['pv'] = data_pipeline['pv'].pad_forward_pv( - forecast_duration=np.timedelta64(config.input_data.pv.forecast_minutes, 'm'), - history_duration=np.timedelta64(config.input_data.pv.history_minutes, 'm'), - time_resolution_minutes=np.timedelta64(config.input_data.pv.time_resolution_minutes, 'm') + data_pipeline["pv"] = data_pipeline["pv"].pad_forward_pv( + forecast_duration=np.timedelta64(config.input_data.pv.forecast_minutes, "m"), + history_duration=np.timedelta64(config.input_data.pv.history_minutes, "m"), + time_resolution_minutes=np.timedelta64(config.input_data.pv.time_resolution_minutes, "m"), ) data_pipeline = DictDatasetIterDataPipe(