diff --git a/pvnet/models/multimodal/multimodal.py b/pvnet/models/multimodal/multimodal.py index 8167ac9f..5af9ae2c 100644 --- a/pvnet/models/multimodal/multimodal.py +++ b/pvnet/models/multimodal/multimodal.py @@ -2,6 +2,7 @@ from collections import OrderedDict from typing import Optional +from omegaconf import DictConfig import torch from ocf_datapipes.batch import BatchKey, NWPBatchKey @@ -52,8 +53,8 @@ def __init__( history_minutes: int = 60, sat_history_minutes: Optional[int] = None, min_sat_delay_minutes: Optional[int] = 30, - nwp_forecast_minutes: Optional[int] = None, - nwp_history_minutes: Optional[int] = None, + nwp_forecast_minutes: Optional[DictConfig] = None, + nwp_history_minutes: Optional[DictConfig] = None, pv_history_minutes: Optional[int] = None, wind_history_minutes: Optional[int] = None, sensor_history_minutes: Optional[int] = None, @@ -61,6 +62,7 @@ def __init__( optimizer: AbstractOptimizer = pvnet.optimizers.Adam(), target_key: str = "gsp", interval_minutes: int = 30, + nwp_interval_minutes: Optional[DictConfig] = None, pv_interval_minutes: int = 5, sat_interval_minutes: int = 5, sensor_interval_minutes: int = 30, @@ -80,16 +82,16 @@ def __init__( - for example if `m` is a regular function. Args: - output_network: A partially instatiated pytorch Module class used to combine the 1D + output_network: A partially instantiated pytorch Module class used to combine the 1D features to produce the forecast. output_quantiles: A list of float (0.0, 1.0) quantiles to predict values for. If set to None the output is a single value. - nwp_encoders_dict: A dictionary of partially instatiated pytorch Module class used to - encode the NWP data from 4D into an 1D feature vector from different sources. - sat_encoder: A partially instatiated pytorch Module class used to encode the satellite - data from 4D into an 1D feature vector. - pv_encoder: A partially instatiated pytorch Module class used to encode the site-level - PV data from 2D into an 1D feature vector. + nwp_encoders_dict: A dictionary of partially instantiated pytorch Module class used to + encode the NWP data from 4D into a 1D feature vector from different sources. + sat_encoder: A partially instantiated pytorch Module class used to encode the satellite + data from 4D into a 1D feature vector. + pv_encoder: A partially instantiated pytorch Module class used to encode the site-level + PV data from 2D into a 1D feature vector. add_image_embedding_channel: Add a channel to the NWP and satellite data with the embedding of the GSP ID. include_gsp_yield_history: Include GSP yield data. @@ -106,7 +108,7 @@ def __init__( `forecast_minutes` if not provided. nwp_history_minutes: Period of historical NWP forecast used as input. Defaults to `history_minutes` if not provided. - pv_history_minutes: Length of recent site-level PV data data used as + pv_history_minutes: Length of recent site-level PV data used as input. Defaults to `history_minutes` if not provided. optimizer: Optimizer factory function used for network. target_key: The key of the target variable in the batch. @@ -114,6 +116,7 @@ def __init__( wind_interval_minutes: The interval between each sample of the wind data wind_encoder: Encoder for wind data wind_history_minutes: Length of recent wind data used as input. + nwp_interval_minutes: DIctionary of the intervals between each sample of the NWP data for each source pv_interval_minutes: The interval between each sample of the PV data sat_interval_minutes: The interval between each sample of the satellite data sensor_interval_minutes: The interval between each sample of the sensor data @@ -153,7 +156,7 @@ def __init__( ) # Number of features expected by the output_network - # Add to this as network pices are constructed + # Add to this as network pieces are constructed fusion_input_features = 0 if self.include_sat: @@ -186,14 +189,17 @@ def __init__( assert set(nwp_encoders_dict.keys()) == set(nwp_forecast_minutes.keys()) assert set(nwp_encoders_dict.keys()) == set(nwp_history_minutes.keys()) + if nwp_interval_minutes is None: + nwp_interval_minutes = dict.fromkeys(nwp_encoders_dict.keys(), 60) + self.nwp_encoders_dict = torch.nn.ModuleDict() if add_image_embedding_channel: self.nwp_embed_dict = torch.nn.ModuleDict() for nwp_source in nwp_encoders_dict.keys(): nwp_sequence_len = ( - nwp_history_minutes[nwp_source] // 60 - + nwp_forecast_minutes[nwp_source] // 60 + nwp_history_minutes[nwp_source] // nwp_interval_minutes[nwp_source] + + nwp_forecast_minutes[nwp_source] // nwp_interval_minutes[nwp_source] + 1 )