Pytorch datapipes #29

Merged · 5 commits · Nov 24, 2023
292 changes: 48 additions & 244 deletions pvnet_app/app.py
@@ -8,55 +8,44 @@

import logging
import os
import yaml
import tempfile
import warnings
from datetime import datetime, timedelta, timezone
from datetime import timedelta


import fsspec
import numpy as np
import pandas as pd
import torch
import typer
import xarray as xr
import xesmf as xe
import dask
from nowcasting_datamodel.connection import DatabaseConnection
from nowcasting_datamodel.models import (
ForecastSQL,
ForecastValue,
)
from nowcasting_datamodel.read.read import (
get_latest_input_data_last_updated,
get_location,
get_model,
)
from nowcasting_datamodel.save.save import save as save_sql_forecasts
from nowcasting_datamodel.read.read_gsp import get_latest_gsp_capacities
from nowcasting_datamodel.connection import DatabaseConnection
from nowcasting_datamodel.models.base import Base_Forecast
from ocf_datapipes.load import OpenGSPFromDatabase
from ocf_datapipes.training.pvnet import construct_sliced_data_pipeline
from ocf_datapipes.transform.numpy.batch.sun_position import ELEVATION_MEAN, ELEVATION_STD
from ocf_datapipes.utils.consts import BatchKey
from ocf_datapipes.utils.utils import stack_np_examples_into_batch
from pvnet_summation.models.base_model import BaseModel as SummationBaseModel
from sqlalchemy.orm import Session
from torchdata.dataloader2 import DataLoader2, MultiProcessingReadingService
from torchdata.datapipes.iter import IterableWrapper
from torch.utils.data import DataLoader
from torch.utils.data.datapipes.iter import IterableWrapper

import pvnet
from pvnet.data.datamodule import batch_to_tensor, copy_batch_to_device
from pvnet.models.base_model import BaseModel as PVNetBaseModel
from pvnet.utils import GSPLocationLookup

import pvnet_app
from pvnet_app.utils import (
worker_init_fn, populate_data_config_sources, convert_dataarray_to_forecasts, preds_to_dataarray
)
from pvnet_app.data import regrid_nwp_data, download_sat_data, download_nwp_data

# ---------------------------------------------------------------------------
# GLOBAL SETTINGS

# TODO: Host data config alongside model?
this_dir = os.path.dirname(os.path.abspath(__file__))

# Model will use GPU if available
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

@@ -76,7 +65,7 @@
# Hugging Face Hub model repo and commit for PVNet summation (sums GSP-level forecasts to a national forecast)
# If summation_model_name is set to None, a simple sum is computed instead
default_summation_model_name = "openclimatefix/pvnet_v2_summation"
default_summation_model_version = "01393d6e4a036103f9c7111cba6f03d5c19beb54"
default_summation_model_version = "6c5361101b461ae991662bdff05f7a0b77b4040b"
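# A minimal sketch of the simple-sum fallback (assumed behaviour; the fallback
# itself is not shown in this hunk): with no summation model configured, the
# national forecast (gsp_id 0) is just the sum of the per-GSP absolute
# forecasts, e.g.
#     da_abs_national = da_abs.sum(dim="gsp_id")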

model_name_ocf_db = "pvnet_v2"
use_adjuster = os.getenv("USE_ADJUSTER", "True").lower() == "true"
@@ -104,168 +93,7 @@
sql_logger.addHandler(logging.NullHandler())

# ---------------------------------------------------------------------------
# HELPER FUNCTIONS

def regrid_nwp_data(nwp_path):
"""This function loads the NWP data, then regrids and saves it back out if the data is not on
the same grid as expected. The data is resaved in-place.
"""
ds_raw = xr.open_zarr(nwp_path)

# These are the coords we are aiming for
ds_target_coords = xr.load_dataset(f"{this_dir}/../data/nwp_target_coords.nc")

# Check if regridding step needs to be done
needs_regridding = not (
ds_raw.latitude.equals(ds_target_coords.latitude) and
ds_raw.longitude.equals(ds_target_coords.longitude)
)

if not needs_regridding:
logger.info("No NWP regridding required - skipping this step")
return

logger.info("Regridding NWP to expected grid")

# Pull the raw data into RAM
ds_raw = ds_raw.compute()

# Regrid in a RAM-efficient way by chunking first; each step is regridded separately
regridder = xe.Regridder(ds_raw, ds_target_coords, method="bilinear")
ds_regridded = regridder(
ds_raw.chunk(dict(x=-1, y=-1, step=1))
).compute(scheduler="single-threaded")

# Re-save - including rechunking
os.system(f"rm -fr {nwp_path}")
ds_regridded["variable"] = ds_regridded["variable"].astype(str)
ds_regridded.chunk(dict(step=12, x=100, y=100)).to_zarr(nwp_path)

return


def populate_data_config_sources(input_path, output_path):
"""Resave the data config and replace the source filepaths

Args:
input_path: Path to input datapipes configuration file
output_path: Location to save the output configuration file
"""
with open(input_path) as infile:
config = yaml.load(infile, Loader=yaml.FullLoader)

production_paths = {
"gsp": os.environ["DB_URL"],
"nwp": "nwp.zarr",
"satellite": "sat.zarr.zip",
# TODO: include hrvsatellite
}

# Replace data sources
for source in ["gsp", "nwp", "satellite", "hrvsatellite"]:
if source in config["input_data"]:
# If not empty - i.e. if used
if config["input_data"][source][f"{source}_zarr_path"]!="":
assert source in production_paths, f"Missing production path: {source}"
config["input_data"][source][f"{source}_zarr_path"] = production_paths[source]

# We do not need to set the PV path right now; this is currently done through datapipes
# TODO - Move the PV path to here

with open(output_path, 'w') as outfile:
yaml.dump(config, outfile, default_flow_style=False)
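
# Illustrative effect on a hypothetical config (keys assumed): an input entry
#
#     input_data:
#       nwp:
#         nwp_zarr_path: "s3://some-bucket/nwp.zarr"
#
# is resaved with the production path substituted:
#
#     input_data:
#       nwp:
#         nwp_zarr_path: "nwp.zarr"
#
# while the gsp source is pointed at the DB_URL environment variable.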


def convert_dataarray_to_forecasts(
forecast_values_dataarray: xr.DataArray, session: Session, model_name: str, version: str
) -> list[ForecastSQL]:
"""
Make a ForecastSQL object from a DataArray.

Args:
forecast_values_dataarray: DataArray of forecasted values. Must have `target_datetime_utc`,
`gsp_id`, and `output_label` coords. The `output_label` coords must have `"forecast_mw"`
as an element.
session: database session
model_name: the name of the model
version: the version of the model
Returns:
List of ForecastSQL objects
"""
logger.debug("Converting DataArray to list of ForecastSQL")

assert "target_datetime_utc" in forecast_values_dataarray.coords
assert "gsp_id" in forecast_values_dataarray.coords
assert "forecast_mw" in forecast_values_dataarray.output_label

# get last input data
input_data_last_updated = get_latest_input_data_last_updated(session=session)

# get model name
model = get_model(name=model_name, version=version, session=session)

forecasts = []

for gsp_id in forecast_values_dataarray.gsp_id.values:
gsp_id = int(gsp_id)
# make forecast values
forecast_values = []

# get location
location = get_location(session=session, gsp_id=gsp_id)

gsp_forecast_values_da = forecast_values_dataarray.sel(gsp_id=gsp_id)

for target_time in pd.to_datetime(gsp_forecast_values_da.target_datetime_utc.values):
# add timezone
target_time_utc = target_time.replace(tzinfo=timezone.utc)
this_da = gsp_forecast_values_da.sel(target_datetime_utc=target_time)

forecast_value_sql = ForecastValue(
target_time=target_time_utc,
expected_power_generation_megawatts=(
this_da.sel(output_label="forecast_mw").item()
),
).to_orm()

forecast_value_sql.adjust_mw = 0.0

properties = {}

if "forecast_mw_plevel_10" in gsp_forecast_values_da.output_label:
val = this_da.sel(output_label="forecast_mw_plevel_10").item()
# `val` can be NaN if PVNet has probabilistic outputs and PVNet_summation doesn't,
# or if PVNet_summation has probabilistic outputs and PVNet doesn't.
# Do not log the value if NaN
if not np.isnan(val):
properties["10"] = val

if "forecast_mw_plevel_90" in gsp_forecast_values_da.output_label:
val = this_da.sel(output_label="forecast_mw_plevel_90").item()

if not np.isnan(val):
properties["90"] = val

if len(properties) > 0:
forecast_value_sql.properties = properties

forecast_values.append(forecast_value_sql)

# make forecast object
forecast = ForecastSQL(
model=model,
forecast_creation_time=datetime.now(tz=timezone.utc),
location=location,
input_data_last_updated=input_data_last_updated,
forecast_values=forecast_values,
historic=False,
)

forecasts.append(forecast)

return forecasts
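
# Example (illustrative only, not part of this PR): a minimal DataArray that
# satisfies the asserts above - one GSP, two target times, median output only.
example_da = xr.DataArray(
np.zeros((1, 2, 1)),
dims=["gsp_id", "target_datetime_utc", "output_label"],
coords=dict(
gsp_id=[1],
target_datetime_utc=pd.to_datetime(["2023-11-24 12:30", "2023-11-24 13:00"]),
output_label=["forecast_mw"],
),
)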

# APP MAIN

def app(
t0=None,
@@ -293,6 +121,9 @@ def app(

if num_workers == -1:
num_workers = os.cpu_count() - 1
if num_workers > 0:
# Without this line the dataloader will hang if multiple workers are used
dask.config.set(scheduler='single-threaded')

logger.info(f"Using `pvnet` library version: {pvnet.__version__}")
logger.info(f"Using {num_workers} workers")
@@ -343,23 +174,15 @@ def app(
gsp_id_to_loc = GSPLocationLookup(ds_gsp.x_osgb, ds_gsp.y_osgb)

# Download satellite data
logger.info("Downloading zipped satellite data")
fs = fsspec.open(os.environ["SATELLITE_ZARR_PATH"]).fs
fs.get(os.environ["SATELLITE_ZARR_PATH"], "sat.zarr.zip")

# Also download 15-minute satellite if it exists
sat_latest_15 = os.environ["SATELLITE_ZARR_PATH"].replace(".zarr.zip", "_15.zarr.zip")
if fs.exists(sat_latest_15):
logger.info("Downloading 15-minute satellite data")
fs.get(sat_latest_15, "sat_15.zarr.zip")

# Download nwp data
logger.info("Downloading nwp data")
fs = fsspec.open(os.environ["NWP_ZARR_PATH"]).fs
fs.get(os.environ["NWP_ZARR_PATH"], "nwp.zarr", recursive=True)
logger.info("Downloading satellite data")
download_sat_data()

# Download NWP data
logger.info("Downloading NWP data")
download_nwp_data()
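# (Assumed behaviour: download_sat_data and download_nwp_data wrap the removed
# fsspec logic above, fetching SATELLITE_ZARR_PATH to "sat.zarr.zip" and
# NWP_ZARR_PATH to "nwp.zarr" respectively.)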

# Regrid the nwp data if needed
regrid_nwp_data("nwp.zarr")
# Regrid the NWP data if needed
regrid_nwp_data()

# ---------------------------------------------------------------------------
# 2. Set up data loader
@@ -373,6 +196,7 @@
# Populate the data config with production data paths
temp_dir = tempfile.TemporaryDirectory()
populated_data_config_filename = f"{temp_dir.name}/data_config.yaml"

populate_data_config_sources(data_config_filename, populated_data_config_filename)

# Location and time datapipes
@@ -396,12 +220,22 @@
)

# Set up dataloader for parallel loading
rs = MultiProcessingReadingService(
dataloader_kwargs = dict(
shuffle=False,
batch_size=None, # batched in datapipe step
sampler=None,
batch_sampler=None,
num_workers=num_workers,
multiprocessing_context="spawn",
worker_prefetch_cnt=0 if num_workers == 0 else 2,
collate_fn=None,
pin_memory=False,
drop_last=False,
timeout=0,
worker_init_fn=worker_init_fn,
prefetch_factor=None if num_workers == 0 else 2,
persistent_workers=False,
)
dataloader = DataLoader2(batch_datapipe, reading_service=rs)

dataloader = DataLoader(batch_datapipe, **dataloader_kwargs)
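# These kwargs are intended to mirror the removed MultiProcessingReadingService
# setup: spawn-based workers and a prefetch of 2 when num_workers > 0, plus the
# worker_init_fn imported from pvnet_app.utils.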

# ---------------------------------------------------------------------------
# 3. set up model
@@ -469,6 +303,10 @@
sun_down_masks = np.concatenate(sun_down_masks)

gsp_ids_all_batches = np.concatenate(gsp_ids_each_batch).squeeze()

n_times = normed_preds.shape[1]

valid_times = pd.to_datetime([t0 + timedelta(minutes=30 * (i + 1)) for i in range(n_times)])
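# e.g. with t0 = 2023-11-24 12:00 and n_times = 2 this gives the half-hourly
# valid times 12:30 and 13:00 (the first step is 30 minutes after t0)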

# Reorder GSP order which ends up shuffled if multiprocessing is used
inds = gsp_ids_all_batches.argsort()
@@ -483,36 +321,14 @@
# 5. Merge batch results to xarray DataArray
logger.info("Processing raw predictions to DataArray")

n_times = normed_preds.shape[1]

if model.use_quantile_regression:
output_labels = model.output_quantiles
output_labels = [f"forecast_mw_plevel_{int(q*100):02}" for q in model.output_quantiles]
output_labels[output_labels.index("forecast_mw_plevel_50")] = "forecast_mw"
else:
output_labels = ["forecast_mw"]
normed_preds = normed_preds[..., np.newaxis]
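# e.g. output_quantiles = [0.1, 0.5, 0.9] in the branch above yields
# ["forecast_mw_plevel_10", "forecast_mw", "forecast_mw_plevel_90"]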

da_normed = xr.DataArray(
data=normed_preds,
dims=["gsp_id", "target_datetime_utc", "output_label"],
coords=dict(
gsp_id=gsp_ids_all_batches,
target_datetime_utc=pd.to_datetime(
[t0 + timedelta(minutes=30 * (i + 1)) for i in range(n_times)],
),
output_label=output_labels,
),
)
da_normed = preds_to_dataarray(normed_preds, model, valid_times, gsp_ids_all_batches)
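# preds_to_dataarray (imported from pvnet_app.utils) is assumed to encapsulate
# the removed label logic above, returning a DataArray with gsp_id,
# target_datetime_utc and output_label coords.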

da_sundown_mask = xr.DataArray(
data=sun_down_masks,
dims=["gsp_id", "target_datetime_utc"],
coords=dict(
gsp_id=gsp_ids_all_batches,
target_datetime_utc=pd.to_datetime(
[t0 + timedelta(minutes=30 * (i + 1)) for i in range(n_times)],
),
target_datetime_utc=valid_times,
),
)

@@ -545,23 +361,11 @@
normed_national = summation_model(inputs).detach().squeeze().cpu().numpy()

# Convert national predictions to DataArray
if summation_model.use_quantile_regression:
sum_output_labels = summation_model.output_quantiles
sum_output_labels = [
f"forecast_mw_plevel_{int(q*100):02}" for q in summation_model.output_quantiles
]
sum_output_labels[sum_output_labels.index("forecast_mw_plevel_50")] = "forecast_mw"
else:
sum_output_labels = ["forecast_mw"]

da_normed_national = xr.DataArray(
data=normed_national[np.newaxis],
dims=["gsp_id", "target_datetime_utc", "output_label"],
coords=dict(
gsp_id=[0],
target_datetime_utc=da_abs.target_datetime_utc,
output_label=sum_output_labels,
),
da_normed_national = preds_to_dataarray(
normed_national[np.newaxis],
summation_model,
valid_times,
gsp_ids=[0]
)

# Multiply normalised forecasts by capacities and clip negatives