Skip to content

Commit

Permalink
[pre-commit.ci] auto fixes from pre-commit.com hooks
Browse files Browse the repository at this point in the history
for more information, see https://pre-commit.ci
  • Loading branch information
pre-commit-ci[bot] committed Oct 17, 2024
1 parent 31581ac commit 3acf267
Showing 1 changed file with 8 additions and 9 deletions.
17 changes: 8 additions & 9 deletions scripts/backtest_sites.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
except RuntimeError:
pass

import json
import logging
import os
import sys
Expand All @@ -32,6 +33,8 @@
import pandas as pd
import torch
import xarray as xr
from huggingface_hub import hf_hub_download
from huggingface_hub.constants import CONFIG_NAME, PYTORCH_WEIGHTS_NAME
from ocf_datapipes.batch import (
BatchKey,
NumpyBatch,
Expand All @@ -57,10 +60,6 @@
from pvnet.load_model import get_model_from_checkpoints
from pvnet.utils import SiteLocationLookup

import json
from huggingface_hub import hf_hub_download
from huggingface_hub.constants import CONFIG_NAME, PYTORCH_WEIGHTS_NAME

# ------------------------------------------------------------------
# USER CONFIGURED VARIABLES TO RUN THE SCRIPT

Expand Down Expand Up @@ -109,7 +108,7 @@
# FUNCTIONS


@functional_datapipe('pad_forward_pv')
@functional_datapipe("pad_forward_pv")
class PadForwardPVIterDataPipe(IterDataPipe):
"""
Pads forecast pv. Sun position is calculated based off of pv time index
Expand All @@ -128,8 +127,8 @@ def __iter__(self):
"""Iter"""

for xr_data in self.pv_dp:
t0 = xr_data.time_utc.data[int(xr_data.attrs['t0_idx'])]
pv_step = np.timedelta64(xr_data.attrs['sample_period_duration'])
t0 = xr_data.time_utc.data[int(xr_data.attrs["t0_idx"])]
pv_step = np.timedelta64(xr_data.attrs["sample_period_duration"])
t_end = t0 + self.forecast_duration + pv_step
time_idx = np.arange(xr_data.time_utc.data[0], t_end, pv_step)
yield xr_data.reindex(time_utc=time_idx, fill_value=-1)
Expand Down Expand Up @@ -424,8 +423,8 @@ def get_datapipe(config_path: str) -> NumpyBatch:
)

config = load_yaml_configuration(config_path)
data_pipeline['pv'] = data_pipeline['pv'].pad_forward_pv(
forecast_duration=np.timedelta64(config.input_data.pv.forecast_minutes, 'm')
data_pipeline["pv"] = data_pipeline["pv"].pad_forward_pv(
forecast_duration=np.timedelta64(config.input_data.pv.forecast_minutes, "m")
)

data_pipeline = DictDatasetIterDataPipe(
Expand Down

0 comments on commit 3acf267

Please sign in to comment.