Skip to content

Commit

Permalink
fix 60 min hardcode NWP resolution in multimodal.py
Browse files Browse the repository at this point in the history
NWP resolution fixed to pull form config; if config values not available, default 60 min resolution supplied for backward compatibility
  • Loading branch information
AUdaltsova authored Jun 3, 2024
1 parent b9be06c commit dbc9931
Showing 1 changed file with 19 additions and 13 deletions.
32 changes: 19 additions & 13 deletions pvnet/models/multimodal/multimodal.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

from collections import OrderedDict
from typing import Optional
from omegaconf import DictConfig

import torch
from ocf_datapipes.batch import BatchKey, NWPBatchKey
Expand Down Expand Up @@ -52,15 +53,16 @@ def __init__(
history_minutes: int = 60,
sat_history_minutes: Optional[int] = None,
min_sat_delay_minutes: Optional[int] = 30,
nwp_forecast_minutes: Optional[int] = None,
nwp_history_minutes: Optional[int] = None,
nwp_forecast_minutes: Optional[DictConfig] = None,
nwp_history_minutes: Optional[DictConfig] = None,
pv_history_minutes: Optional[int] = None,
wind_history_minutes: Optional[int] = None,
sensor_history_minutes: Optional[int] = None,
sensor_forecast_minutes: Optional[int] = None,
optimizer: AbstractOptimizer = pvnet.optimizers.Adam(),
target_key: str = "gsp",
interval_minutes: int = 30,
nwp_interval_minutes: Optional[DictConfig] = None,
pv_interval_minutes: int = 5,
sat_interval_minutes: int = 5,
sensor_interval_minutes: int = 30,
Expand All @@ -80,16 +82,16 @@ def __init__(
- for example if `m` is a regular function.
Args:
output_network: A partially instatiated pytorch Module class used to combine the 1D
output_network: A partially instantiated pytorch Module class used to combine the 1D
features to produce the forecast.
output_quantiles: A list of float (0.0, 1.0) quantiles to predict values for. If set to
None the output is a single value.
nwp_encoders_dict: A dictionary of partially instatiated pytorch Module class used to
encode the NWP data from 4D into an 1D feature vector from different sources.
sat_encoder: A partially instatiated pytorch Module class used to encode the satellite
data from 4D into an 1D feature vector.
pv_encoder: A partially instatiated pytorch Module class used to encode the site-level
PV data from 2D into an 1D feature vector.
nwp_encoders_dict: A dictionary of partially instantiated pytorch Module class used to
encode the NWP data from 4D into a 1D feature vector from different sources.
sat_encoder: A partially instantiated pytorch Module class used to encode the satellite
data from 4D into a 1D feature vector.
pv_encoder: A partially instantiated pytorch Module class used to encode the site-level
PV data from 2D into a 1D feature vector.
add_image_embedding_channel: Add a channel to the NWP and satellite data with the
embedding of the GSP ID.
include_gsp_yield_history: Include GSP yield data.
Expand All @@ -106,14 +108,15 @@ def __init__(
`forecast_minutes` if not provided.
nwp_history_minutes: Period of historical NWP forecast used as input. Defaults to
`history_minutes` if not provided.
pv_history_minutes: Length of recent site-level PV data data used as
pv_history_minutes: Length of recent site-level PV data used as
input. Defaults to `history_minutes` if not provided.
optimizer: Optimizer factory function used for network.
target_key: The key of the target variable in the batch.
interval_minutes: The interval between each sample of the target data
wind_interval_minutes: The interval between each sample of the wind data
wind_encoder: Encoder for wind data
wind_history_minutes: Length of recent wind data used as input.
nwp_interval_minutes: DIctionary of the intervals between each sample of the NWP data for each source
pv_interval_minutes: The interval between each sample of the PV data
sat_interval_minutes: The interval between each sample of the satellite data
sensor_interval_minutes: The interval between each sample of the sensor data
Expand Down Expand Up @@ -153,7 +156,7 @@ def __init__(
)

# Number of features expected by the output_network
# Add to this as network pices are constructed
# Add to this as network pieces are constructed
fusion_input_features = 0

if self.include_sat:
Expand Down Expand Up @@ -186,14 +189,17 @@ def __init__(
assert set(nwp_encoders_dict.keys()) == set(nwp_forecast_minutes.keys())
assert set(nwp_encoders_dict.keys()) == set(nwp_history_minutes.keys())

if nwp_interval_minutes is None:
nwp_interval_minutes = dict.fromkeys(nwp_encoders_dict.keys(), 60)

self.nwp_encoders_dict = torch.nn.ModuleDict()
if add_image_embedding_channel:
self.nwp_embed_dict = torch.nn.ModuleDict()

for nwp_source in nwp_encoders_dict.keys():
nwp_sequence_len = (
nwp_history_minutes[nwp_source] // 60
+ nwp_forecast_minutes[nwp_source] // 60
nwp_history_minutes[nwp_source] // nwp_interval_minutes[nwp_source]
+ nwp_forecast_minutes[nwp_source] // nwp_interval_minutes[nwp_source]
+ 1
)

Expand Down

0 comments on commit dbc9931

Please sign in to comment.