Skip to content

Commit

Permalink
Update backtest_sites.py
Browse files Browse the repository at this point in the history
  • Loading branch information
AUdaltsova authored Oct 23, 2024
1 parent e5f34ca commit 0b853b1
Showing 1 changed file with 19 additions and 7 deletions.
26 changes: 19 additions & 7 deletions scripts/backtest_sites.py
Original file line number Diff line number Diff line change
Expand Up @@ -118,21 +118,31 @@ class PadForwardPVIterDataPipe(IterDataPipe):
to run out of data to slice for the forecast part.
"""

def __init__(self, pv_dp: IterDataPipe, forecast_duration: np.timedelta64):
def __init__(self, pv_dp: IterDataPipe,
forecast_duration: np.timedelta64,
history_duration: np.timedelta64,
time_resolution_minutes: np.timedelta64
):
"""Init"""

super().__init__()
self.pv_dp = pv_dp
self.forecast_duration = forecast_duration
self.history_duration = history_duration
self.time_resolution_minutes = time_resolution_minutes

self.min_seq_length = history_duration // time_resolution_minutes

def __iter__(self):
"""Iter"""

for xr_data in self.pv_dp:
t0 = xr_data.time_utc.data[int(xr_data.attrs["t0_idx"])]
pv_step = np.timedelta64(xr_data.attrs["sample_period_duration"])
t_end = t0 + self.forecast_duration + pv_step
time_idx = np.arange(xr_data.time_utc.data[0], t_end, pv_step)
t_end = xr_data.time_utc.data[0] + self.history_duration + self.forecast_duration + self.time_resolution_minutes
time_idx = np.arange(xr_data.time_utc.data[0], t_end, self.time_resolution_minutes)

if len(xr_data.time_utc.data) < self.min_seq_length:
raise ValueError("Not enough PV data to predict")

yield xr_data.reindex(time_utc=time_idx, fill_value=-1)


Expand Down Expand Up @@ -429,8 +439,10 @@ def get_datapipe(config_path: str) -> NumpyBatch:
)

config = load_yaml_configuration(config_path)
data_pipeline["pv"] = data_pipeline["pv"].pad_forward_pv(
forecast_duration=np.timedelta64(config.input_data.pv.forecast_minutes, "m")
data_pipeline['pv'] = data_pipeline['pv'].pad_forward_pv(
forecast_duration=np.timedelta64(config.input_data.pv.forecast_minutes, 'm'),
history_duration=np.timedelta64(config.input_data.pv.history_minutes, 'm'),
time_resolution_minutes=np.timedelta64(config.input_data.pv.time_resolution_minutes, 'm')
)

data_pipeline = DictDatasetIterDataPipe(
Expand Down

0 comments on commit 0b853b1

Please sign in to comment.