Skip to content

Commit

Permalink
Merge pull request #68 from openclimatefix/less_sat_delay
Browse files Browse the repository at this point in the history
Update app to use model which uses less satellite delay
  • Loading branch information
peterdudfield authored Apr 18, 2024
2 parents 2d4d017 + 95f16a2 commit 1679e06
Show file tree
Hide file tree
Showing 5 changed files with 305 additions and 186 deletions.
93 changes: 46 additions & 47 deletions pvnet_app/app.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,13 +27,14 @@
from ocf_datapipes.load import OpenGSPFromDatabase
from ocf_datapipes.training.pvnet import construct_sliced_data_pipeline
from ocf_datapipes.utils.consts import ELEVATION_MEAN, ELEVATION_STD
from ocf_datapipes.batch import BatchKey, stack_np_examples_into_batch
from ocf_datapipes.batch import (
BatchKey, stack_np_examples_into_batch, batch_to_tensor, copy_batch_to_device
)
from pvnet_summation.models.base_model import BaseModel as SummationBaseModel
from torch.utils.data import DataLoader
from torch.utils.data.datapipes.iter import IterableWrapper

import pvnet
from pvnet.data.utils import batch_to_tensor, copy_batch_to_device
from pvnet.models.base_model import BaseModel as PVNetBaseModel
from pvnet.utils import GSPLocationLookup

Expand Down Expand Up @@ -61,8 +62,8 @@
batch_size = 10

# Huggingfacehub model repo and commit for PVNet (GSP-level model)
default_model_name = "openclimatefix/pvnet_v2"
default_model_version = "5ed2b179974993d8804a1e60fdc850dc547e9025"
default_model_name = "openclimatefix/pvnet_uk_region"
default_model_version = "9cc2bf5859e129b3816041b657c8875d31ced0d6"

# Huggingfacehub model repo and commit for PVNet summation (GSP sum to national model)
# If summation_model_name is set to None, a simple sum is computed instead
Expand All @@ -72,12 +73,6 @@
model_name_ocf_db = "pvnet_v2"
use_adjuster = os.getenv("USE_ADJUSTER", "True").lower() == "true"

# If environmental variable is true, the sum-of-GSPs will be computed and saved under a different
# model name. This can be useful to compare against the summation model and therefore monitor its
# performance in production
save_gsp_sum = os.getenv("SAVE_GSP_SUM", "False").lower() == "true"
gsp_sum_model_name_ocf_db = "pvnet_gsp_sum"

# ---------------------------------------------------------------------------
# LOGGER

Expand Down Expand Up @@ -137,6 +132,12 @@ def app(
# Without this line the dataloader will hang if multiple workers are used
dask.config.set(scheduler='single-threaded')

# If environmental variable is true, the sum-of-GSPs will be computed and saved under a different
# model name. This can be useful to compare against the summation model and therefore monitor its
# performance in production
gsp_sum_model_name_ocf_db = "pvnet_gsp_sum"
save_gsp_sum = os.getenv("SAVE_GSP_SUM", "False").lower() == "true"

logger.info(f"Using `pvnet` library version: {pvnet.__version__}")
logger.info(f"Using {num_workers} workers")
logger.info(f"Using adjduster: {use_adjuster}")
Expand Down Expand Up @@ -164,10 +165,42 @@ def app(
logger.info(f"Making forecast for GSP IDs: {gsp_ids}")

# ---------------------------------------------------------------------------
# 1. Prepare data sources
# 1. set up model
logger.info(f"Loading model: {model_name} - {model_version}")

# Make pands Series of most recent GSP effective capacities
model = PVNetBaseModel.from_pretrained(
model_name,
revision=model_version,
).to(device)

if summation_model_name is not None:
summation_model = SummationBaseModel.from_pretrained(
summation_model_name,
revision=summation_model_version,
).to(device)

if (
summation_model.pvnet_model_name != model_name
or summation_model.pvnet_model_version != model_version
):
warnings.warn(
f"The PVNet version running in this app is {model_name}/{model_version}. "
"The summation model running in this app was trained on outputs from PVNet version "
f"{summation_model.pvnet_model_name}/{summation_model.pvnet_model_version}. "
"Combining these models may lead to an error if the shape of PVNet output doesn't "
"match the expected shape of the summation model. Combining may lead to unreliable "
"results even if the shapes match."
)
# ---------------------------------------------------------------------------
# 2. Prepare data sources

# Pull the data config from huggingface
data_config_filename = PVNetBaseModel.get_data_config(
model_name,
revision=model_version,
)

# Make pands Series of most recent GSP effective capacities
logger.info("Loading GSP metadata")

ds_gsp = next(iter(OpenGSPFromDatabase()))
Expand All @@ -190,7 +223,7 @@ def app(
download_all_sat_data()

# Process the 5/15 minutely satellite data
preprocess_sat_data(t0)
preprocess_sat_data(t0, data_config_filename)

# Download NWP data
logger.info("Downloading NWP data")
Expand All @@ -202,12 +235,6 @@ def app(
# ---------------------------------------------------------------------------
# 2. Set up data loader
logger.info("Creating DataLoader")

# Pull the data config from huggingface
data_config_filename = PVNetBaseModel.get_data_config(
model_name,
revision=model_version,
)

# Populate the data config with production data paths
temp_dir = tempfile.TemporaryDirectory()
Expand All @@ -229,7 +256,6 @@ def app(
location_pipe=location_pipe,
t0_datapipe=t0_datapipe,
production=True,
check_satellite_no_zeros=True,
)
.batch(batch_size)
.map(stack_np_examples_into_batch)
Expand All @@ -253,33 +279,7 @@ def app(

dataloader = DataLoader(batch_datapipe, **dataloader_kwargs)

# ---------------------------------------------------------------------------
# 3. set up model
logger.info(f"Loading model: {model_name} - {model_version}")

model = PVNetBaseModel.from_pretrained(
model_name,
revision=model_version,
).to(device)

if summation_model_name is not None:
summation_model = SummationBaseModel.from_pretrained(
summation_model_name,
revision=summation_model_version,
).to(device)

if (
summation_model.pvnet_model_name != model_name
or summation_model.pvnet_model_version != model_version
):
warnings.warn(
f"The PVNet version running in this app is {model_name}/{model_version}. "
"The summation model running in this app was trained on outputs from PVNet version "
f"{summation_model.pvnet_model_name}/{summation_model.pvnet_model_version}. "
"Combining these models may lead to an error if the shape of PVNet output doesn't "
"match the expected shape of the summation model. Combining may lead to unreliable "
"results even if the shapes match."
)

# 4. Make prediction
logger.info("Processing batches")
Expand Down Expand Up @@ -429,7 +429,6 @@ def app(
sql_forecasts = convert_dataarray_to_forecasts(
da_abs_all, session, model_name=model_name_ocf_db, version=pvnet_app.__version__
)

save_sql_forecasts(
forecasts=sql_forecasts,
session=session,
Expand Down
124 changes: 100 additions & 24 deletions pvnet_app/data.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
import fsspec
from datetime import timedelta
import ocf_blosc2
from ocf_datapipes.config.load import load_yaml_configuration

from pvnet_app.consts import sat_path, nwp_ukv_path, nwp_ecmwf_path

Expand Down Expand Up @@ -38,47 +39,122 @@ def download_all_sat_data():
logger.info(f"Downloading 15-minute satellite data {sat_15_dl_path}")
fs.get(sat_15_dl_path, "sat_15_min.zarr.zip")
os.system(f"unzip sat_15_min.zarr.zip -d {sat_15_path}")


def _get_latest_time_and_mins_delay(sat_zarr_path, t0):
ds_sat = xr.open_zarr(sat_zarr_path)
latest_time = pd.to_datetime(ds_sat.time.max().item())
delay = t0 - latest_time
delay_mins = int(delay.total_seconds() / 60)
return latest_time, delay_mins


def preprocess_sat_data(t0):
def combine_5_and_15_sat_data(t0, max_sat_delay_allowed_mins):
"""Select and/or combine the 5 and 15-minutely satellite data"""

use_15_minute = False
if not os.path.exists(sat_5_path):
use_15_minute = True
logger.debug(f"5-minute satellite data not found at {sat_5_path}. "
f"Using 15-minute data.")
use_5_minute = os.path.exists(sat_5_path)
if not use_5_minute:
logger.info(f"5-minute satellite data not found at {sat_5_path}. Trying 15-minute data.")
else:
ds_sat_5 = xr.open_zarr(sat_5_path)
latest_time_5 = pd.to_datetime(ds_sat_5.time.max().values)
sat_delay_5 = t0 - latest_time_5 #Timedelta for the delay
sat_delay_minutes = int(sat_delay_5.total_seconds() / 60) #To see the timedelta in minutes
latest_time_5, delay_mins_5 = _get_latest_time_and_mins_delay(sat_5_path, t0)
logger.info(f"Latest 5-minute timestamp is {latest_time_5} for t0 time {t0}.")

if sat_delay_minutes < 60:
logger.info(f"5-min satellite delay is only {sat_delay_minutes} minutes - Using 5-minutely data.")

if delay_mins_5 <= max_sat_delay_allowed_mins:
logger.info(
f"5-min satellite delay is only {delay_mins_5} minutes. "
f"Maximum delay for this model is {max_sat_delay_allowed_mins} minutes - "
"Using 5-minutely data."
)
os.system(f"mv {sat_5_path} {sat_path}")
else:
use_15_minute = True
logger.info(f"5-min satellite delay is {sat_delay_minutes} minutes - "
f"Switching to 15-minutely data.")
logger.info(
f"5-min satellite delay is {delay_mins_5} minutes. "
f"Maximum delay for this model is {max_sat_delay_allowed_mins} minutes - "
"Trying 15-minutely data."
)
use_5_minute = False

if use_15_minute:
logger.info(f"Using 15-minute satellite data")
if not use_5_minute:
# Make sure the 15-minute data is actually there
if not os.path.exists(sat_15_path):
raise ValueError(f"5-minute satellite data not found at {sat_15_path}")

ds_sat_15 = xr.open_zarr(sat_15_path)
latest_time_15 = pd.to_datetime(ds_sat_15.time.max().values)
latest_time_15, delay_mins_15 = _get_latest_time_and_mins_delay(sat_15_path, t0)
logger.info(f"Latest 15-minute timestamp is {latest_time_15} for t0 time {t0}.")

logger.debug("Resampling 15 minute data to 5 mins")
# If the 15-minute satellite data is too delayed the run fails
if delay_mins_15 > max_sat_delay_allowed_mins:
raise ValueError(
f"15-min satellite delay is {delay_mins_15} minutes. "
f"Maximum delay for this model is {max_sat_delay_allowed_mins} minutes"
)

ds_sat_15 = xr.open_zarr(sat_15_path)

#logger.debug("Resampling 15 minute data to 5 mins")
#ds_sat_15.resample(time="5T").interpolate("linear").to_zarr(sat_path)
ds_sat_15.attrs["source"] = "15-minute"

logger.debug(f"Saving 15 minute data to {sat_path}")
ds_sat_15.to_zarr(sat_path)


def extend_satellite_data_with_nans(t0, min_sat_delay_used_mins):
"""Fill the satellite data with NaNs if needed by the model"""

# Check how the expected satellite delay compares with the satellite data available and fill
# if required
latest_time, delay_mins = _get_latest_time_and_mins_delay(sat_path, t0)

if min_sat_delay_used_mins < delay_mins:
fill_mins = delay_mins - min_sat_delay_used_mins
logger.info(f"Filling most recent {fill_mins} mins with NaNs")

# Load into memory so we can delete it on disk
ds_sat = xr.open_zarr(sat_path).compute()

# Pad with zeros
fill_times = pd.date_range(
latest_time+timedelta(minutes=5),
latest_time+timedelta(minutes=fill_mins),
freq="5min"
)


ds_sat = ds_sat.reindex(time=np.concatenate([ds_sat.time, fill_times]), fill_value=np.nan)

# Re-save inplace
os.system(f"rm -rf {sat_path}")
ds_sat.to_zarr(sat_path)


def preprocess_sat_data(t0, data_config_filename):

# Find the max delay w.r.t t0 that this model was trained with
data_config = load_yaml_configuration(data_config_filename)

# Take into account how recently the model tries to slice data from
max_sat_delay_allowed_mins = data_config.input_data.satellite.live_delay_minutes

# Take into account the dropout the model was trained with, if any
if data_config.input_data.satellite.dropout_fraction>0:
max_sat_delay_allowed_mins = max(
max_sat_delay_allowed_mins,
np.abs(data_config.input_data.satellite.dropout_timedeltas_minutes).max()
)

# The model will not ever try to use data more recent than this
min_sat_delay_used_mins = data_config.input_data.satellite.live_delay_minutes

# Deal with switching between the 5 and 15 minutely satellite data
combine_5_and_15_sat_data(t0, max_sat_delay_allowed_mins)

# Extend the satellite data with NaNs if needed by the model
extend_satellite_data_with_nans(t0, min_sat_delay_used_mins)

ds_sat = xr.open_zarr(sat_path)
ds_sat.data.isnull().mean().compute()
#assert False

return use_15_minute


def _download_nwp_data(source, destination):
fs = fsspec.open(source).fs
Expand Down
8 changes: 4 additions & 4 deletions requirements.txt
Original file line number Diff line number Diff line change
@@ -1,10 +1,10 @@
pydantic
pytorch-lightning==2.1.3
torch[cpu]==2.2.0
PVNet-summation==0.1.3
pvnet==3.0.11
ocf_datapipes==3.2.11
nowcasting_datamodel>=1.5.30
PVNet-summation==0.1.4
pvnet==3.0.25
ocf_datapipes==3.3.19
nowcasting_datamodel>=1.5.39
fsspec[s3]
xarray
zarr
Expand Down
Loading

0 comments on commit 1679e06

Please sign in to comment.