diff --git a/pvnet_app/app.py b/pvnet_app/app.py
index 6b6a397..0dd96ea 100644
--- a/pvnet_app/app.py
+++ b/pvnet_app/app.py
@@ -8,31 +8,20 @@
 import logging
 import os
-import yaml
 import tempfile
 import warnings
-from datetime import datetime, timedelta, timezone
+from datetime import timedelta
+
-import fsspec
 import numpy as np
 import pandas as pd
 import torch
 import typer
 import xarray as xr
-import xesmf as xe
+import dask

 from nowcasting_datamodel.connection import DatabaseConnection
-from nowcasting_datamodel.models import (
-    ForecastSQL,
-    ForecastValue,
-)
-from nowcasting_datamodel.read.read import (
-    get_latest_input_data_last_updated,
-    get_location,
-    get_model,
-)
 from nowcasting_datamodel.save.save import save as save_sql_forecasts
 from nowcasting_datamodel.read.read_gsp import get_latest_gsp_capacities
-from nowcasting_datamodel.connection import DatabaseConnection
 from nowcasting_datamodel.models.base import Base_Forecast
 from ocf_datapipes.load import OpenGSPFromDatabase
 from ocf_datapipes.training.pvnet import construct_sliced_data_pipeline
@@ -40,9 +29,8 @@
 from ocf_datapipes.utils.consts import BatchKey
 from ocf_datapipes.utils.utils import stack_np_examples_into_batch
 from pvnet_summation.models.base_model import BaseModel as SummationBaseModel
-from sqlalchemy.orm import Session
-from torchdata.dataloader2 import DataLoader2, MultiProcessingReadingService
-from torchdata.datapipes.iter import IterableWrapper
+from torch.utils.data import DataLoader
+from torch.utils.data.datapipes.iter import IterableWrapper

 import pvnet
 from pvnet.data.datamodule import batch_to_tensor, copy_batch_to_device
@@ -50,13 +38,14 @@
 from pvnet.utils import GSPLocationLookup

 import pvnet_app
+from pvnet_app.utils import (
+    worker_init_fn, populate_data_config_sources, convert_dataarray_to_forecasts, preds_to_dataarray
+)
+from pvnet_app.data import regrid_nwp_data, download_sat_data, download_nwp_data

 # ---------------------------------------------------------------------------
 # GLOBAL SETTINGS

-# TODO: Host data config alongside model?
-this_dir = os.path.dirname(os.path.abspath(__file__))
-
 # Model will use GPU if available
 device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

@@ -76,7 +65,7 @@
 # Huggingfacehub model repo and commit for PVNet summation (GSP sum to national model)
 # If summation_model_name is set to None, a simple sum is computed instead
 default_summation_model_name = "openclimatefix/pvnet_v2_summation"
-default_summation_model_version = "01393d6e4a036103f9c7111cba6f03d5c19beb54"
+default_summation_model_version = "6c5361101b461ae991662bdff05f7a0b77b4040b"

 model_name_ocf_db = "pvnet_v2"
 use_adjuster = os.getenv("USE_ADJUSTER", "True").lower() == "true"
@@ -104,168 +93,7 @@
 sql_logger.addHandler(logging.NullHandler())

 # ---------------------------------------------------------------------------
-# HELPER FUNCTIONS
-
-def regrid_nwp_data(nwp_path):
-    """This function loads the NWP data, then regrids and saves it back out if the data is not on
-    the same grid as expected. The data is resaved in-place.
- """ - ds_raw = xr.open_zarr(nwp_path) - - # These are the coords we are aiming for - ds_target_coords = xr.load_dataset(f"{this_dir}/../data/nwp_target_coords.nc") - - # Check if regridding step needs to be done - needs_regridding = not ( - ds_raw.latitude.equals(ds_target_coords.latitude) and - ds_raw.longitude.equals(ds_target_coords.longitude) - - ) - - if not needs_regridding: - logger.info("No NWP regridding required - skipping this step") - return - - logger.info("Regridding NWP to expected grid") - - # Pull the raw data into RAM - ds_raw = ds_raw.compute() - - # Regrid in RAM efficient way by chunking first. Each step is regridded separately - regridder = xe.Regridder(ds_raw, ds_target_coords, method="bilinear") - ds_regridded = regridder( - ds_raw.chunk(dict(x=-1, y=-1, step=1)) - ).compute(scheduler="single-threaded") - - # Re-save - including rechunking - os.system(f"rm -fr {nwp_path}") - ds_regridded["variable"] = ds_regridded["variable"].astype(str) - ds_regridded.chunk(dict(step=12, x=100, y=100)).to_zarr(nwp_path) - - return - - -def populate_data_config_sources(input_path, output_path): - """Resave the data config and replace the source filepaths - - Args: - input_path: Path to input datapipes configuration file - output_path: Location to save the output configuration file - """ - with open(input_path) as infile: - config = yaml.load(infile, Loader=yaml.FullLoader) - - production_paths = { - "gsp": os.environ["DB_URL"], - "nwp": "nwp.zarr", - "satellite": "sat.zarr.zip", - # TODO: include hrvsatellite - } - - # Replace data sources - for source in ["gsp", "nwp", "satellite", "hrvsatellite"]: - if source in config["input_data"]: - # If not empty - i.e. if used - if config["input_data"][source][f"{source}_zarr_path"]!="": - assert source in production_paths, f"Missing production path: {source}" - config["input_data"][source][f"{source}_zarr_path"] = production_paths[source] - - # We do not need to set PV path right now. This currently done through datapipes - # TODO - Move the PV path to here - - with open(output_path, 'w') as outfile: - yaml.dump(config, outfile, default_flow_style=False) - - -def convert_dataarray_to_forecasts( - forecast_values_dataarray: xr.DataArray, session: Session, model_name: str, version: str -) -> list[ForecastSQL]: - """ - Make a ForecastSQL object from a DataArray. - - Args: - forecast_values_dataarray: Dataarray of forecasted values. Must have `target_datetime_utc` - `gsp_id`, and `output_label` coords. The `output_label` coords must have `"forecast_mw"` - as an element. 
-        session: database session
-        model_name: the name of the model
-        version: the version of the model
-    Return:
-        List of ForecastSQL objects
-    """
-    logger.debug("Converting DataArray to list of ForecastSQL")
-
-    assert "target_datetime_utc" in forecast_values_dataarray.coords
-    assert "gsp_id" in forecast_values_dataarray.coords
-    assert "forecast_mw" in forecast_values_dataarray.output_label
-
-    # get last input data
-    input_data_last_updated = get_latest_input_data_last_updated(session=session)
-
-    # get model name
-    model = get_model(name=model_name, version=version, session=session)
-
-    forecasts = []
-
-    for gsp_id in forecast_values_dataarray.gsp_id.values:
-        gsp_id = int(gsp_id)
-        # make forecast values
-        forecast_values = []
-
-        # get location
-        location = get_location(session=session, gsp_id=gsp_id)
-
-        gsp_forecast_values_da = forecast_values_dataarray.sel(gsp_id=gsp_id)
-
-        for target_time in pd.to_datetime(gsp_forecast_values_da.target_datetime_utc.values):
-            # add timezone
-            target_time_utc = target_time.replace(tzinfo=timezone.utc)
-            this_da = gsp_forecast_values_da.sel(target_datetime_utc=target_time)
-
-            forecast_value_sql = ForecastValue(
-                target_time=target_time_utc,
-                expected_power_generation_megawatts=(
-                    this_da.sel(output_label="forecast_mw").item()
-                ),
-            ).to_orm()
-
-            forecast_value_sql.adjust_mw = 0.0
-
-            properties = {}
-
-            if "forecast_mw_plevel_10" in gsp_forecast_values_da.output_label:
-                val = this_da.sel(output_label="forecast_mw_plevel_10").item()
-                # `val` can be NaN if PVNet has probabilistic outputs and PVNet_summation doesn't,
-                # or if PVNet_summation has probabilistic outputs and PVNet doesn't.
-                # Do not log the value if NaN
-                if not np.isnan(val):
-                    properties["10"] = val
-
-            if "forecast_mw_plevel_90" in gsp_forecast_values_da.output_label:
-                val = this_da.sel(output_label="forecast_mw_plevel_90").item()
-
-                if not np.isnan(val):
-                    properties["90"] = val
-
-            if len(properties)>0:
-                forecast_value_sql.properties = properties
-
-            forecast_values.append(forecast_value_sql)
-
-        # make forecast object
-        forecast = ForecastSQL(
-            model=model,
-            forecast_creation_time=datetime.now(tz=timezone.utc),
-            location=location,
-            input_data_last_updated=input_data_last_updated,
-            forecast_values=forecast_values,
-            historic=False,
-        )
-
-        forecasts.append(forecast)
-
-    return forecasts
-
+# APP MAIN

 def app(
     t0=None,
@@ -293,6 +121,9 @@ def app(
     if num_workers == -1:
        num_workers = os.cpu_count() - 1
+    if num_workers > 0:
+        # Without this line the dataloader will hang if multiple workers are used
+        dask.config.set(scheduler='single-threaded')

     logger.info(f"Using `pvnet` library version: {pvnet.__version__}")
     logger.info(f"Using {num_workers} workers")
@@ -343,23 +174,15 @@ def app(
     gsp_id_to_loc = GSPLocationLookup(ds_gsp.x_osgb, ds_gsp.y_osgb)

     # Download satellite data
-    logger.info("Downloading zipped satellite data")
-    fs = fsspec.open(os.environ["SATELLITE_ZARR_PATH"]).fs
-    fs.get(os.environ["SATELLITE_ZARR_PATH"], "sat.zarr.zip")
-
-    # Also download 15-minute satellite if it exists
-    sat_latest_15 = os.environ["SATELLITE_ZARR_PATH"].replace(".zarr.zip", "_15.zarr.zip")
-    if fs.exists(sat_latest_15):
-        logger.info("Downloading 15-minute satellite data")
-        fs.get(sat_latest_15, "sat_15.zarr.zip")
-
-    # Download nwp data
-    logger.info("Downloading nwp data")
-    fs = fsspec.open(os.environ["NWP_ZARR_PATH"]).fs
-    fs.get(os.environ["NWP_ZARR_PATH"], "nwp.zarr", recursive=True)
+    logger.info("Downloading satellite data")
+    download_sat_data()
+
+    # Download NWP data
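+    # The NWP zarr is copied into the local working directory; the populated data config
+    # points the datapipes at this local copy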
logger.info("Downloading NWP data") + download_nwp_data() - # Regrid the nwp data if needed - regrid_nwp_data("nwp.zarr") + # Regrid the NWP data if needed + regrid_nwp_data() # --------------------------------------------------------------------------- # 2. Set up data loader @@ -373,6 +196,7 @@ def app( # Populate the data config with production data paths temp_dir = tempfile.TemporaryDirectory() populated_data_config_filename = f"{temp_dir.name}/data_config.yaml" + populate_data_config_sources(data_config_filename, populated_data_config_filename) # Location and time datapipes @@ -396,12 +220,22 @@ def app( ) # Set up dataloader for parallel loading - rs = MultiProcessingReadingService( + dataloader_kwargs = dict( + shuffle=False, + batch_size=None, # batched in datapipe step + sampler=None, + batch_sampler=None, num_workers=num_workers, - multiprocessing_context="spawn", - worker_prefetch_cnt=0 if num_workers == 0 else 2, + collate_fn=None, + pin_memory=False, + drop_last=False, + timeout=0, + worker_init_fn=worker_init_fn, + prefetch_factor=None if num_workers == 0 else 2, + persistent_workers=False, ) - dataloader = DataLoader2(batch_datapipe, reading_service=rs) + + dataloader = DataLoader(batch_datapipe, **dataloader_kwargs) # --------------------------------------------------------------------------- # 3. set up model @@ -469,6 +303,10 @@ def app( sun_down_masks = np.concatenate(sun_down_masks) gsp_ids_all_batches = np.concatenate(gsp_ids_each_batch).squeeze() + + n_times = normed_preds.shape[1] + + valid_times = pd.to_datetime([t0 + timedelta(minutes=30 * (i + 1)) for i in range(n_times)]) # Reorder GSP order which ends up shuffled if multiprocessing is used inds = gsp_ids_all_batches.argsort() @@ -483,36 +321,14 @@ def app( # 5. Merge batch results to xarray DataArray logger.info("Processing raw predictions to DataArray") - n_times = normed_preds.shape[1] - - if model.use_quantile_regression: - output_labels = model.output_quantiles - output_labels = [f"forecast_mw_plevel_{int(q*100):02}" for q in model.output_quantiles] - output_labels[output_labels.index("forecast_mw_plevel_50")] = "forecast_mw" - else: - output_labels = ["forecast_mw"] - normed_preds = normed_preds[..., np.newaxis] - - da_normed = xr.DataArray( - data=normed_preds, - dims=["gsp_id", "target_datetime_utc", "output_label"], - coords=dict( - gsp_id=gsp_ids_all_batches, - target_datetime_utc=pd.to_datetime( - [t0 + timedelta(minutes=30 * (i + 1)) for i in range(n_times)], - ), - output_label=output_labels, - ), - ) + da_normed = preds_to_dataarray(normed_preds, model, valid_times, gsp_ids_all_batches) da_sundown_mask = xr.DataArray( data=sun_down_masks, dims=["gsp_id", "target_datetime_utc"], coords=dict( gsp_id=gsp_ids_all_batches, - target_datetime_utc=pd.to_datetime( - [t0 + timedelta(minutes=30 * (i + 1)) for i in range(n_times)], - ), + target_datetime_utc=valid_times, ), ) @@ -545,23 +361,11 @@ def app( normed_national = summation_model(inputs).detach().squeeze().cpu().numpy() # Convert national predictions to DataArray - if summation_model.use_quantile_regression: - sum_output_labels = summation_model.output_quantiles - sum_output_labels = [ - f"forecast_mw_plevel_{int(q*100):02}" for q in summation_model.output_quantiles - ] - sum_output_labels[sum_output_labels.index("forecast_mw_plevel_50")] = "forecast_mw" - else: - sum_output_labels = ["forecast_mw"] - - da_normed_national = xr.DataArray( - data=normed_national[np.newaxis], - dims=["gsp_id", "target_datetime_utc", "output_label"], - coords=dict( - 
+        prefetch_factor=None if num_workers == 0 else 2,
+        persistent_workers=False,
     )
-    dataloader = DataLoader2(batch_datapipe, reading_service=rs)
+
+    dataloader = DataLoader(batch_datapipe, **dataloader_kwargs)

     # ---------------------------------------------------------------------------
     # 3. set up model
@@ -469,6 +303,10 @@ def app(
     sun_down_masks = np.concatenate(sun_down_masks)

     gsp_ids_all_batches = np.concatenate(gsp_ids_each_batch).squeeze()
+
+    n_times = normed_preds.shape[1]
+
+    valid_times = pd.to_datetime([t0 + timedelta(minutes=30 * (i + 1)) for i in range(n_times)])

     # Reorder GSP order which ends up shuffled if multiprocessing is used
     inds = gsp_ids_all_batches.argsort()
@@ -483,36 +321,14 @@ def app(
     # 5. Merge batch results to xarray DataArray
     logger.info("Processing raw predictions to DataArray")

-    n_times = normed_preds.shape[1]
-
-    if model.use_quantile_regression:
-        output_labels = model.output_quantiles
-        output_labels = [f"forecast_mw_plevel_{int(q*100):02}" for q in model.output_quantiles]
-        output_labels[output_labels.index("forecast_mw_plevel_50")] = "forecast_mw"
-    else:
-        output_labels = ["forecast_mw"]
-        normed_preds = normed_preds[..., np.newaxis]
-
-    da_normed = xr.DataArray(
-        data=normed_preds,
-        dims=["gsp_id", "target_datetime_utc", "output_label"],
-        coords=dict(
-            gsp_id=gsp_ids_all_batches,
-            target_datetime_utc=pd.to_datetime(
-                [t0 + timedelta(minutes=30 * (i + 1)) for i in range(n_times)],
-            ),
-            output_label=output_labels,
-        ),
-    )
+    da_normed = preds_to_dataarray(normed_preds, model, valid_times, gsp_ids_all_batches)

     da_sundown_mask = xr.DataArray(
         data=sun_down_masks,
         dims=["gsp_id", "target_datetime_utc"],
         coords=dict(
             gsp_id=gsp_ids_all_batches,
-            target_datetime_utc=pd.to_datetime(
-                [t0 + timedelta(minutes=30 * (i + 1)) for i in range(n_times)],
-            ),
+            target_datetime_utc=valid_times,
         ),
     )
@@ -545,23 +361,11 @@ def app(
         normed_national = summation_model(inputs).detach().squeeze().cpu().numpy()

     # Convert national predictions to DataArray
-    if summation_model.use_quantile_regression:
-        sum_output_labels = summation_model.output_quantiles
-        sum_output_labels = [
-            f"forecast_mw_plevel_{int(q*100):02}" for q in summation_model.output_quantiles
-        ]
-        sum_output_labels[sum_output_labels.index("forecast_mw_plevel_50")] = "forecast_mw"
-    else:
-        sum_output_labels = ["forecast_mw"]
-
-    da_normed_national = xr.DataArray(
-        data=normed_national[np.newaxis],
-        dims=["gsp_id", "target_datetime_utc", "output_label"],
-        coords=dict(
-            gsp_id=[0],
-            target_datetime_utc=da_abs.target_datetime_utc,
-            output_label=sum_output_labels,
-        ),
+    da_normed_national = preds_to_dataarray(
+        normed_national[np.newaxis],
+        summation_model,
+        valid_times,
+        gsp_ids=[0]
     )

     # Multiply normalised forecasts by capacities and clip negatives
diff --git a/pvnet_app/consts.py b/pvnet_app/consts.py
new file mode 100644
index 0000000..a4b52fb
--- /dev/null
+++ b/pvnet_app/consts.py
@@ -0,0 +1,3 @@
+sat_path = "sat.zarr"
+sat_15_path = "sat_15.zarr"
+nwp_path = "nwp.zarr"
\ No newline at end of file
diff --git a/pvnet_app/data.py b/pvnet_app/data.py
new file mode 100644
index 0000000..fdc8f3f
--- /dev/null
+++ b/pvnet_app/data.py
@@ -0,0 +1,71 @@
+import xarray as xr
+import xesmf as xe
+import logging
+import os
+import fsspec
+
+from pvnet_app.consts import sat_path, sat_15_path, nwp_path
+
+logger = logging.getLogger(__name__)
+
+this_dir = os.path.dirname(os.path.abspath(__file__))
+
+
+def download_sat_data():
+    """Download the satellite data"""
+    fs = fsspec.open(os.environ["SATELLITE_ZARR_PATH"]).fs
+    fs.get(os.environ["SATELLITE_ZARR_PATH"], "sat.zarr.zip")
+    os.system(f"rm -fr {sat_path}")
+    os.system(f"unzip sat.zarr.zip -d {sat_path}")
+
+    # Also download 15-minute satellite if it exists
+    sat_latest_15 = os.environ["SATELLITE_ZARR_PATH"].replace("sat.zarr", "sat_15.zarr")
+    if fs.exists(sat_latest_15):
+        logger.info("Downloading 15-minute satellite data")
+        fs.get(sat_latest_15, "sat_15.zarr.zip")
+        os.system(f"unzip sat_15.zarr.zip -d {sat_15_path}")
+
+
+def download_nwp_data():
+    """Download the NWP data"""
+    fs = fsspec.open(os.environ["NWP_ZARR_PATH"]).fs
+    fs.get(os.environ["NWP_ZARR_PATH"], nwp_path, recursive=True)
+
+
+def regrid_nwp_data():
+    """This function loads the NWP data, then regrids and saves it back out if the data is not on
+    the same grid as expected. The data is resaved in-place.
+    """
+    ds_raw = xr.open_zarr(nwp_path)
+
+    # These are the coords we are aiming for
+    ds_target_coords = xr.load_dataset(f"{this_dir}/../data/nwp_target_coords.nc")
+
+    # Check if regridding step needs to be done
+    needs_regridding = not (
+        ds_raw.latitude.equals(ds_target_coords.latitude) and
+        ds_raw.longitude.equals(ds_target_coords.longitude)
+    )
+
+    if not needs_regridding:
+        logger.info("No NWP regridding required - skipping this step")
+        return
+
+    logger.info("Regridding NWP to expected grid")
+
+    # Pull the raw data into RAM
+    ds_raw = ds_raw.compute()
+
+    # Regrid in RAM efficient way by chunking first. Each step is regridded separately
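+    # (xe.Regridder computes the bilinear weights once here; applying it to the
+    # chunked dataset then regrids one step at a time)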
+    regridder = xe.Regridder(ds_raw, ds_target_coords, method="bilinear")
+    ds_regridded = regridder(
+        ds_raw.chunk(dict(x=-1, y=-1, step=1))
+    ).compute(scheduler="single-threaded")
+
+    # Re-save - including rechunking
+    os.system(f"rm -fr {nwp_path}")
+    ds_regridded["variable"] = ds_regridded["variable"].astype(str)
+    ds_regridded.chunk(dict(step=12, x=100, y=100)).to_zarr(nwp_path)
+
+    return
\ No newline at end of file
diff --git a/pvnet_app/utils.py b/pvnet_app/utils.py
new file mode 100644
index 0000000..2d5e42b
--- /dev/null
+++ b/pvnet_app/utils.py
@@ -0,0 +1,185 @@
+import fsspec.asyn
+import yaml
+import os
+import xarray as xr
+import numpy as np
+import pandas as pd
+from sqlalchemy.orm import Session
+import logging
+from nowcasting_datamodel.models import (
+    ForecastSQL,
+    ForecastValue,
+)
+from nowcasting_datamodel.read.read import (
+    get_latest_input_data_last_updated,
+    get_location,
+    get_model,
+)
+
+from datetime import timezone, datetime
+
+from pvnet_app.consts import sat_path, nwp_path
+
+
+logger = logging.getLogger(__name__)
+
+
+def worker_init_fn(worker_id):
+    """
+    Clear reference to the loop and thread.
+    This is a nasty hack that was suggested but NOT recommended by the lead fsspec developer!
+    This appears necessary, otherwise gcsfs hangs when used after forking multiple worker processes.
+    Only required for fsspec >= 0.9.0
+    See:
+    - https://github.com/fsspec/gcsfs/issues/379#issuecomment-839929801
+    - https://github.com/fsspec/filesystem_spec/pull/963#issuecomment-1131709948
+    TODO: Try deleting these two lines to make sure this is still relevant.
+    """
+    fsspec.asyn.iothread[0] = None
+    fsspec.asyn.loop[0] = None
+
+
+def populate_data_config_sources(input_path, output_path):
+    """Resave the data config and replace the source filepaths
+
+    Args:
+        input_path: Path to input datapipes configuration file
+        output_path: Location to save the output configuration file
+    """
+    with open(input_path) as infile:
+        config = yaml.load(infile, Loader=yaml.FullLoader)
+
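+    # These paths must match where download_sat_data and download_nwp_data place the files
+    # (see consts.py)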
+    production_paths = {
+        "gsp": os.environ["DB_URL"],
+        "nwp": nwp_path,
+        "satellite": sat_path,
+        # TODO: include hrvsatellite
+    }
+
+    # Replace data sources
+    for source in ["gsp", "nwp", "satellite", "hrvsatellite"]:
+        if source in config["input_data"]:
+            # If not empty - i.e. if used
+            if config["input_data"][source][f"{source}_zarr_path"] != "":
+                assert source in production_paths, f"Missing production path: {source}"
+                config["input_data"][source][f"{source}_zarr_path"] = production_paths[source]
+
+    # We do not need to set the PV path right now. This is currently done through datapipes
+    # TODO - Move the PV path to here
+
+    with open(output_path, 'w') as outfile:
+        yaml.dump(config, outfile, default_flow_style=False)
+
+
+def preds_to_dataarray(preds, model, valid_times, gsp_ids):
+    """Put a numpy array of predictions into a DataArray"""
+
+    if model.use_quantile_regression:
+        output_labels = [f"forecast_mw_plevel_{int(q*100):02}" for q in model.output_quantiles]
+        output_labels[output_labels.index("forecast_mw_plevel_50")] = "forecast_mw"
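+        # e.g. quantiles [0.1, 0.5, 0.9] -> ["forecast_mw_plevel_10", "forecast_mw", "forecast_mw_plevel_90"]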
+    else:
+        output_labels = ["forecast_mw"]
+        preds = preds[..., np.newaxis]
+
+    da = xr.DataArray(
+        data=preds,
+        dims=["gsp_id", "target_datetime_utc", "output_label"],
+        coords=dict(
+            gsp_id=gsp_ids,
+            target_datetime_utc=valid_times,
+            output_label=output_labels,
+        ),
+    )
+    return da
+
+
+def convert_dataarray_to_forecasts(
+    forecast_values_dataarray: xr.DataArray, session: Session, model_name: str, version: str
+) -> list[ForecastSQL]:
+    """
+    Make a list of ForecastSQL objects from a DataArray.
+
+    Args:
+        forecast_values_dataarray: Dataarray of forecasted values. Must have `target_datetime_utc`,
+            `gsp_id`, and `output_label` coords. The `output_label` coords must have `"forecast_mw"`
+            as an element.
+        session: database session
+        model_name: the name of the model
+        version: the version of the model
+    Return:
+        List of ForecastSQL objects
+    """
+    logger.debug("Converting DataArray to list of ForecastSQL")
+
+    assert "target_datetime_utc" in forecast_values_dataarray.coords
+    assert "gsp_id" in forecast_values_dataarray.coords
+    assert "forecast_mw" in forecast_values_dataarray.output_label
+
+    # get last input data
+    input_data_last_updated = get_latest_input_data_last_updated(session=session)
+
+    # get model name
+    model = get_model(name=model_name, version=version, session=session)
+
+    forecasts = []
+
+    for gsp_id in forecast_values_dataarray.gsp_id.values:
+        gsp_id = int(gsp_id)
+        # make forecast values
+        forecast_values = []
+
+        # get location
+        location = get_location(session=session, gsp_id=gsp_id)
+
+        gsp_forecast_values_da = forecast_values_dataarray.sel(gsp_id=gsp_id)
+
+        for target_time in pd.to_datetime(gsp_forecast_values_da.target_datetime_utc.values):
+            # add timezone
+            target_time_utc = target_time.replace(tzinfo=timezone.utc)
+            this_da = gsp_forecast_values_da.sel(target_datetime_utc=target_time)
+
+            forecast_value_sql = ForecastValue(
+                target_time=target_time_utc,
+                expected_power_generation_megawatts=(
+                    this_da.sel(output_label="forecast_mw").item()
+                ),
+            ).to_orm()
+
+            forecast_value_sql.adjust_mw = 0.0
+
+            properties = {}
+
+            if "forecast_mw_plevel_10" in gsp_forecast_values_da.output_label:
+                val = this_da.sel(output_label="forecast_mw_plevel_10").item()
+                # `val` can be NaN if PVNet has probabilistic outputs and PVNet_summation doesn't,
+                # or if PVNet_summation has probabilistic outputs and PVNet doesn't.
+                # Do not log the value if NaN
+                if not np.isnan(val):
+                    properties["10"] = val
+
+            if "forecast_mw_plevel_90" in gsp_forecast_values_da.output_label:
+                val = this_da.sel(output_label="forecast_mw_plevel_90").item()
+
+                if not np.isnan(val):
+                    properties["90"] = val
+
+            if len(properties) > 0:
+                forecast_value_sql.properties = properties
+
+            forecast_values.append(forecast_value_sql)
+
+        # make forecast object
+        forecast = ForecastSQL(
+            model=model,
+            forecast_creation_time=datetime.now(tz=timezone.utc),
+            location=location,
+            input_data_last_updated=input_data_last_updated,
+            forecast_values=forecast_values,
+            historic=False,
+        )
+
+        forecasts.append(forecast)
+
+    return forecasts
\ No newline at end of file
diff --git a/requirements.txt b/requirements.txt
index f95ae07..0a153e5 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -1,7 +1,7 @@
-torch[cpu]>=2.0
-PVNet-summation>=0.0.8
-pvnet>=2.3.0
-ocf_datapipes>=2.0.6
+torch[cpu]>=2.1.1
+PVNet-summation>=0.0.9
+pvnet>=2.4.0
+ocf_datapipes>=2.2.2
 nowcasting_datamodel>=1.5.22
 fsspec[s3]
 xarray
@@ -9,7 +9,6 @@
 zarr
 numpy
 pandas
 sqlalchemy
-torchdata
 pytest
 pytest-cov
 typer
diff --git a/tests/test_app.py b/tests/test_app.py
index 6cd6a24..41b735b 100644
--- a/tests/test_app.py
+++ b/tests/test_app.py
@@ -1,6 +1,7 @@
 import tempfile
 import zarr
 import os
+import logging

 from nowcasting_datamodel.models.forecast import (
     ForecastSQL,
@@ -38,7 +39,7 @@ def test_app(db_session, nwp_data, sat_data, gsp_yields_and_systems, me_latest):
     # Run prediction
     # This import needs to come after the environ vars have been set
     from pvnet_app.app import app
-    app(gsp_ids=list(range(1, 318)))
+    app(gsp_ids=list(range(1, 318)), num_workers=2)

     # Check forecasts have been made
     # (317 GSPs + 1 National + GSP-sum) = 319 forecasts