Skip to content

Commit

Permalink
[pre-commit.ci] auto fixes from pre-commit.com hooks
Browse files Browse the repository at this point in the history
for more information, see https://pre-commit.ci
  • Loading branch information
pre-commit-ci[bot] committed Oct 23, 2024
1 parent 0b853b1 commit fdda5e3
Showing 1 changed file with 17 additions and 10 deletions.
27 changes: 17 additions & 10 deletions scripts/backtest_sites.py
Original file line number Diff line number Diff line change
Expand Up @@ -118,11 +118,13 @@ class PadForwardPVIterDataPipe(IterDataPipe):
to run out of data to slice for the forecast part.
"""

def __init__(self, pv_dp: IterDataPipe,
forecast_duration: np.timedelta64,
history_duration: np.timedelta64,
time_resolution_minutes: np.timedelta64
):
def __init__(
self,
pv_dp: IterDataPipe,
forecast_duration: np.timedelta64,
history_duration: np.timedelta64,
time_resolution_minutes: np.timedelta64,
):
"""Init"""

super().__init__()
Expand All @@ -137,7 +139,12 @@ def __iter__(self):
"""Iter"""

for xr_data in self.pv_dp:
t_end = xr_data.time_utc.data[0] + self.history_duration + self.forecast_duration + self.time_resolution_minutes
t_end = (
xr_data.time_utc.data[0]
+ self.history_duration
+ self.forecast_duration
+ self.time_resolution_minutes
)
time_idx = np.arange(xr_data.time_utc.data[0], t_end, self.time_resolution_minutes)

if len(xr_data.time_utc.data) < self.min_seq_length:
Expand Down Expand Up @@ -439,10 +446,10 @@ def get_datapipe(config_path: str) -> NumpyBatch:
)

config = load_yaml_configuration(config_path)
data_pipeline['pv'] = data_pipeline['pv'].pad_forward_pv(
forecast_duration=np.timedelta64(config.input_data.pv.forecast_minutes, 'm'),
history_duration=np.timedelta64(config.input_data.pv.history_minutes, 'm'),
time_resolution_minutes=np.timedelta64(config.input_data.pv.time_resolution_minutes, 'm')
data_pipeline["pv"] = data_pipeline["pv"].pad_forward_pv(
forecast_duration=np.timedelta64(config.input_data.pv.forecast_minutes, "m"),
history_duration=np.timedelta64(config.input_data.pv.history_minutes, "m"),
time_resolution_minutes=np.timedelta64(config.input_data.pv.time_resolution_minutes, "m"),
)

data_pipeline = DictDatasetIterDataPipe(
Expand Down

0 comments on commit fdda5e3

Please sign in to comment.