Skip to content

Commit

Permalink
update backtest_sites.py (#266)
Browse files Browse the repository at this point in the history
* Update backtest_sites.py

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
  • Loading branch information
AUdaltsova and pre-commit-ci[bot] authored Oct 23, 2024
1 parent e5f34ca commit 05abb8b
Showing 1 changed file with 25 additions and 6 deletions.
31 changes: 25 additions & 6 deletions scripts/backtest_sites.py
Original file line number Diff line number Diff line change
Expand Up @@ -118,21 +118,38 @@ class PadForwardPVIterDataPipe(IterDataPipe):
to run out of data to slice for the forecast part.
"""

def __init__(self, pv_dp: IterDataPipe, forecast_duration: np.timedelta64):
def __init__(
self,
pv_dp: IterDataPipe,
forecast_duration: np.timedelta64,
history_duration: np.timedelta64,
time_resolution_minutes: np.timedelta64,
):
"""Init"""

super().__init__()
self.pv_dp = pv_dp
self.forecast_duration = forecast_duration
self.history_duration = history_duration
self.time_resolution_minutes = time_resolution_minutes

self.min_seq_length = history_duration // time_resolution_minutes

def __iter__(self):
"""Iter"""

for xr_data in self.pv_dp:
t0 = xr_data.time_utc.data[int(xr_data.attrs["t0_idx"])]
pv_step = np.timedelta64(xr_data.attrs["sample_period_duration"])
t_end = t0 + self.forecast_duration + pv_step
time_idx = np.arange(xr_data.time_utc.data[0], t_end, pv_step)
t_end = (
xr_data.time_utc.data[0]
+ self.history_duration
+ self.forecast_duration
+ self.time_resolution_minutes
)
time_idx = np.arange(xr_data.time_utc.data[0], t_end, self.time_resolution_minutes)

if len(xr_data.time_utc.data) < self.min_seq_length:
raise ValueError("Not enough PV data to predict")

yield xr_data.reindex(time_utc=time_idx, fill_value=-1)


Expand Down Expand Up @@ -430,7 +447,9 @@ def get_datapipe(config_path: str) -> NumpyBatch:

config = load_yaml_configuration(config_path)
data_pipeline["pv"] = data_pipeline["pv"].pad_forward_pv(
forecast_duration=np.timedelta64(config.input_data.pv.forecast_minutes, "m")
forecast_duration=np.timedelta64(config.input_data.pv.forecast_minutes, "m"),
history_duration=np.timedelta64(config.input_data.pv.history_minutes, "m"),
time_resolution_minutes=np.timedelta64(config.input_data.pv.time_resolution_minutes, "m"),
)

data_pipeline = DictDatasetIterDataPipe(
Expand Down

0 comments on commit 05abb8b

Please sign in to comment.