Skip to content

Commit

Permalink
Merge pull request #118 from openclimatefix/ocf-data-sampler
Browse files Browse the repository at this point in the history
Ocf data sampler
  • Loading branch information
peterdudfield authored Sep 16, 2024
2 parents 6ff0f78 + 5f62f1f commit 175615f
Show file tree
Hide file tree
Showing 7 changed files with 194 additions and 115 deletions.
147 changes: 57 additions & 90 deletions pvnet_app/app.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,51 +11,42 @@
"""

import logging
import warnings
import os
import tempfile
import warnings
from datetime import timedelta


import numpy as np
import dask
import pandas as pd
import pvnet
import torch
import typer
import dask
from nowcasting_datamodel.connection import DatabaseConnection
from nowcasting_datamodel.save.save import save as save_sql_forecasts
from nowcasting_datamodel.read.read_gsp import get_latest_gsp_capacities
from nowcasting_datamodel.models.base import Base_Forecast
from ocf_datapipes.load import OpenGSPFromDatabase
from ocf_datapipes.training.pvnet import construct_sliced_data_pipeline
from ocf_datapipes.batch import stack_np_examples_into_batch, batch_to_tensor, copy_batch_to_device

from torch.utils.data import DataLoader
from torch.utils.data.datapipes.iter import IterableWrapper

import pvnet
from nowcasting_datamodel.read.read_gsp import get_latest_gsp_capacities
from nowcasting_datamodel.save.save import save as save_sql_forecasts
from ocf_datapipes.batch import batch_to_tensor, copy_batch_to_device
from pvnet.models.base_model import BaseModel as PVNetBaseModel
from pvnet.utils import GSPLocationLookup
import sentry_sdk


import pvnet_app
from pvnet_app.utils import (
worker_init_fn,
populate_data_config_sources,
convert_dataarray_to_forecasts,
find_min_satellite_delay_config,
save_yaml_config,
)
from pvnet_app.data.nwp import download_all_nwp_data, preprocess_nwp_data
from pvnet_app.data.satellite import (
download_all_sat_data,
preprocess_sat_data,
check_model_inputs_available,
)
from pvnet_app.data.nwp import (
download_all_nwp_data,
preprocess_nwp_data,
)
from pvnet_app.forecast_compiler import ForecastCompiler
from pvnet_app.utils import (
populate_data_config_sources,
convert_dataarray_to_forecasts,
find_min_satellite_delay_config,
save_yaml_config,
)

from pvnet_app.dataloader import get_legacy_dataloader, get_dataloader


# sentry
sentry_sdk.init(
Expand Down Expand Up @@ -84,11 +75,12 @@
# - Batches are prepared only once, so the extra models must be able to run on the batches created
# to run the pvnet_v2 model
models_dict = {

"pvnet_v2": {
# Huggingfacehub model repo and commit for PVNet (GSP-level model)
"pvnet": {
"name": "openclimatefix/pvnet_uk_region",
"version": os.getenv('PVNET_V2_VERSION', "62e5e20ab793cee7cf94eadac870d2199501a730"),
"version": os.getenv('PVNET_V2_VERSION', "ae0b8006841ac6227db873a1fc7f7331dc7dadb5"),
# We should only set PVNET_V2_VERSION as a short-term solution,
# as it's difficult to track which model is being used
},
Expand All @@ -107,11 +99,12 @@
"verbose": True,
"save_gsp_to_forecast_value_last_seven_days": True,
},

# Extra models which will be run on dev only
"pvnet_v2-sat0min-v12-batches": {
"pvnet_v2-sat0-samples-v1": {
"pvnet": {
"name": "openclimatefix/pvnet_uk_region",
"version": "dce387462ee08401355f33f53e86461dd59663e2",
"version": "8a7cc21b64d25ce1add7a8547674be3143b2e650",
},
"summation": {
"name": "openclimatefix/pvnet_v2_summation",
Expand All @@ -122,11 +115,12 @@
"verbose": False,
"save_gsp_to_forecast_value_last_seven_days": False,
},

# single source models
"pvnet_v2-sat_delay0_only-v12-batches": {
"pvnet_v2-sat0-only-samples-v1": {
"pvnet": {
"name": "openclimatefix/pvnet_uk_region",
"version": "ea6ad2cf84152969c768788586df227976890f31",
"version": "d7ab648942c85b6788adcdbed44c91c4e1c5604a",
},
"summation": {
"name": "openclimatefix/pvnet_v2_summation",
Expand All @@ -138,10 +132,10 @@
"save_gsp_to_forecast_value_last_seven_days": False,
},

"pvnet_v2-ukv_only-v12-batches": {
"pvnet_v2-ukv-only-samples-v1": {
"pvnet": {
"name": "openclimatefix/pvnet_uk_region",
"version": "35d55181a82440bdd087f380d650bfd0b64bd322",
"version": "eb73bf9a176a108f2e33b809f1f6993f893a4df9",
},
"summation": {
"name": "openclimatefix/pvnet_v2_summation",
Expand All @@ -153,10 +147,10 @@
"save_gsp_to_forecast_value_last_seven_days": False,
},

"pvnet_v2-ecmwf_only-v12-batches": {
"pvnet_v2-ecmwf-only-samples-v1": {
"pvnet": {
"name": "openclimatefix/pvnet_uk_region",
"version": "c14f7427d9854d63430aa936ce45f55d3818d033",
"version": "0bc344fafb2232fb0b6bb0bf419f0449fe11c643",
},
"summation": {
"name": "openclimatefix/pvnet_v2_summation",
Expand All @@ -169,6 +163,8 @@
},
}

# The day ahead model has not yet been re-trained with data-sampler.
# It will be run with the legacy dataloader using ocf_datapipes
day_ahead_model_dict = {
"pvnet_day_ahead": {
# Huggingfacehub model repo and commit for PVNet day ahead models
Expand Down Expand Up @@ -257,19 +253,19 @@ def app(
# Without this line the dataloader will hang if multiple workers are used
dask.config.set(scheduler="single-threaded")

day_ahead_model_used = os.getenv("DAY_AHEAD_MODEL", "false").lower() == "true"
use_day_ahead_model = os.getenv("DAY_AHEAD_MODEL", "false").lower() == "true"
use_satellite = os.getenv("USE_SATELLITE", "true").lower() == "true"
logger.info(f"Using satellite data: {use_satellite}")
logger.info(f"Using day ahead model: {day_ahead_model_used}")
logger.info(f"Using day ahead model: {use_day_ahead_model}")

if day_ahead_model_used:
if use_day_ahead_model:
logger.info(f"Using day ahead PVNet model")

logger.info(f"Using `pvnet` library version: {pvnet.__version__}")
logger.info(f"Using `pvnet_app` library version: {pvnet_app.__version__}")
logger.info(f"Using {num_workers} workers")

if day_ahead_model_used:
if use_day_ahead_model:
logger.info(f"Using adjduster: {day_ahead_model_dict['pvnet_day_ahead']['use_adjuster']}")
logger.info(f"Saving GSP sum: {day_ahead_model_dict['pvnet_day_ahead']['save_gsp_sum']}")

Expand All @@ -296,26 +292,19 @@ def app(
# ---------------------------------------------------------------------------
# 1. Prepare data sources

# Make pandas Series of most recent GSP effective capacities
logger.info("Loading GSP metadata")

ds_gsp = next(iter(OpenGSPFromDatabase()))

# Get capacities from the database
db_connection = DatabaseConnection(url=os.getenv("DB_URL"), base=Base_Forecast, echo=False)
with db_connection.get_session() as session:
# Pandas Series of most recent GSP capacities
now_minis_two_days = pd.Timestamp.now(tz="UTC") - timedelta(days=2)
gsp_capacities = get_latest_gsp_capacities(
session=session, gsp_ids=gsp_ids, datetime_utc=now_minis_two_days
session=session, gsp_ids=gsp_ids, datetime_utc=t0-timedelta(days=2)
)

# National capacity is needed if using summation model
national_capacity = get_latest_gsp_capacities(session, [0])[0]

# Set up ID location query object
gsp_id_to_loc = GSPLocationLookup(ds_gsp.x_osgb, ds_gsp.y_osgb)

# Download satellite data
if use_satellite:
logger.info("Downloading satellite data")
Expand All @@ -325,7 +314,7 @@ def app(
all_satellite_datetimes, data_freq_minutes = preprocess_sat_data(t0)
else:
all_satellite_datetimes = []
data_freq_minutes = 5
data_freq_minutes = None

# Download NWP data
logger.info("Downloading NWP data")
Expand All @@ -337,7 +326,7 @@ def app(
# ---------------------------------------------------------------------------
# 2. Set up models

if day_ahead_model_used:
if use_day_ahead_model:
model_to_run_dict = {"pvnet_day_ahead": day_ahead_model_dict["pvnet_day_ahead"]}
# Remove extra models if not configured to run them
elif os.getenv("RUN_EXTRA_MODELS", "false").lower() == "false":
Expand Down Expand Up @@ -393,47 +382,25 @@ def app(
# Set up data loader
logger.info("Creating DataLoader")

# Populate the data config with production data paths
populated_data_config_filename = f"{temp_dir.name}/data_config.yaml"

populate_data_config_sources(common_config_path, populated_data_config_filename)

# Location and time datapipes
location_pipe = IterableWrapper([gsp_id_to_loc(gsp_id) for gsp_id in gsp_ids])
t0_datapipe = IterableWrapper([t0]).repeat(len(location_pipe))

location_pipe = location_pipe.sharding_filter()
t0_datapipe = t0_datapipe.sharding_filter()

# Batch datapipe
batch_datapipe = (
construct_sliced_data_pipeline(
config_filename=populated_data_config_filename,
location_pipe=location_pipe,
t0_datapipe=t0_datapipe,
production=True,
if use_day_ahead_model:
# The current day ahead model uses the legacy dataloader
dataloader = get_legacy_dataloader(
config_filename=common_config_path,
t0=t0,
gsp_ids=gsp_ids,
batch_size=batch_size,
num_workers=num_workers,
)
.batch(batch_size)
.map(stack_np_examples_into_batch)
)

# Set up dataloader for parallel loading
dataloader_kwargs = dict(
shuffle=False,
batch_size=None, # batched in datapipe step
sampler=None,
batch_sampler=None,
num_workers=num_workers,
collate_fn=None,
pin_memory=False,
drop_last=False,
timeout=0,
worker_init_fn=worker_init_fn,
prefetch_factor=None if num_workers == 0 else 2,
persistent_workers=False,
)

dataloader = DataLoader(batch_datapipe, **dataloader_kwargs)

else:
dataloader = get_dataloader(
config_filename=common_config_path,
t0=t0,
gsp_ids=gsp_ids,
batch_size=batch_size,
num_workers=num_workers,
)


# ---------------------------------------------------------------------------
# Make predictions
Expand All @@ -459,7 +426,7 @@ def app(
# Escape clause for making predictions locally
if not write_predictions:
temp_dir.cleanup()
if not day_ahead_model_used:
if not use_day_ahead_model:
return forecast_compilers["pvnet_v2"].da_abs_all
return forecast_compilers["pvnet_day_ahead"].da_abs_all

Expand Down
Loading

0 comments on commit 175615f

Please sign in to comment.