diff --git a/pvnet/models/multimodal/site_encoders/encoders.py b/pvnet/models/multimodal/site_encoders/encoders.py index 7ac7b686..c7951b86 100644 --- a/pvnet/models/multimodal/site_encoders/encoders.py +++ b/pvnet/models/multimodal/site_encoders/encoders.py @@ -140,7 +140,7 @@ def __init__( n_kv_res_blocks: Number of residual blocks to use in the key and value encoders. kv_res_block_layers: Number of fully-connected layers used in each residual block within the key and value encoders. - use_pv_id_in_value: Whether to use a PV ID embedding in network used to produce the + use_pv_id_in_value: Whether to use a PV ID embedding in network used to produce the value for the attention layer. """ @@ -232,8 +232,8 @@ def _attention_forward(self, x, average_attn_weights=True): attn_output, attn_weights = self.multihead_attn( query, key, value, average_attn_weights=average_attn_weights - ) - + ) + return attn_output, attn_weights def forward(self, x): diff --git a/requirements.txt b/requirements.txt index e6b5c6ea..3323c1d1 100644 --- a/requirements.txt +++ b/requirements.txt @@ -24,4 +24,4 @@ omegaconf hydra-core python-dotenv hydra-optuna-sweeper -rich \ No newline at end of file +rich diff --git a/tests/conftest.py b/tests/conftest.py index 721b1bd3..7f7aa142 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -5,7 +5,7 @@ import numpy as np import xarray as xr import torch -import hydra +import hydra from ocf_datapipes.utils.consts import BatchKey from datetime import timedelta @@ -144,62 +144,57 @@ def encoder_model_kwargs(): ) return kwargs + @pytest.fixture() def site_encoder_model_kwargs(): # Used to test site encoder model on PV data kwargs = dict( - sequence_length=180 // 5 +1, + sequence_length=180 // 5 + 1, num_sites=349, out_features=128, ) return kwargs + @pytest.fixture() def multimodal_model_kwargs(model_minutes_kwargs): - kwargs = dict( - sat_encoder=dict( - _target_=pvnet.models.multimodal.encoders.encoders3d.DefaultPVNet, - _partial_=True, - in_channels=11, - out_features=128, - number_of_conv3d_layers=6, - conv3d_channels=32, - image_size_pixels=24, + _target_=pvnet.models.multimodal.encoders.encoders3d.DefaultPVNet, + _partial_=True, + in_channels=11, + out_features=128, + number_of_conv3d_layers=6, + conv3d_channels=32, + image_size_pixels=24, ), - nwp_encoder=dict( - _target_=pvnet.models.multimodal.encoders.encoders3d.DefaultPVNet, - _partial_=True, - in_channels=2, - out_features=128, - number_of_conv3d_layers=6, - conv3d_channels=32, - image_size_pixels=24, + _target_=pvnet.models.multimodal.encoders.encoders3d.DefaultPVNet, + _partial_=True, + in_channels=2, + out_features=128, + number_of_conv3d_layers=6, + conv3d_channels=32, + image_size_pixels=24, ), - add_image_embedding_channel=True, - pv_encoder=dict( - _target_=pvnet.models.multimodal.site_encoders.encoders.SingleAttentionNetwork, - _partial_=True, - num_sites=349, - out_features=40, - num_heads=4, - kdim=40, - pv_id_embed_dim=20, + _target_=pvnet.models.multimodal.site_encoders.encoders.SingleAttentionNetwork, + _partial_=True, + num_sites=349, + out_features=40, + num_heads=4, + kdim=40, + pv_id_embed_dim=20, ), - output_network=dict( - _target_=pvnet.models.multimodal.linear_networks.networks.ResFCNet2, - _partial_=True, - fc_hidden_features=128, - n_res_blocks=6, - res_block_layers=2, - dropout_frac=0.0, + _target_=pvnet.models.multimodal.linear_networks.networks.ResFCNet2, + _partial_=True, + fc_hidden_features=128, + n_res_blocks=6, + res_block_layers=2, + dropout_frac=0.0, ), - embedding_dim=16, include_sun=True, include_gsp_yield_history=True, @@ -208,10 +203,9 @@ def multimodal_model_kwargs(model_minutes_kwargs): nwp_forecast_minutes=480, pv_history_minutes=180, min_sat_delay_minutes=30, - ) - + kwargs = hydra.utils.instantiate(kwargs) - + kwargs.update(model_minutes_kwargs) return kwargs diff --git a/tests/models/multimodal/site_encoders/test_encoders.py b/tests/models/multimodal/site_encoders/test_encoders.py index fa8e38fe..c3040173 100644 --- a/tests/models/multimodal/site_encoders/test_encoders.py +++ b/tests/models/multimodal/site_encoders/test_encoders.py @@ -26,7 +26,7 @@ def _test_model_backward(batch, model_class, kwargs): # Test model forward on all models def test_simplelearnedaggregator_forward(sample_batch, site_encoder_model_kwargs): _test_model_forward(sample_batch, SimpleLearnedAggregator, site_encoder_model_kwargs) - + def test_singleattentionnetwork_forward(sample_batch, site_encoder_model_kwargs): _test_model_forward(sample_batch, SingleAttentionNetwork, site_encoder_model_kwargs) @@ -35,7 +35,7 @@ def test_singleattentionnetwork_forward(sample_batch, site_encoder_model_kwargs) # Test model backward on all models def test_simplelearnedaggregator_backward(sample_batch, site_encoder_model_kwargs): _test_model_backward(sample_batch, SimpleLearnedAggregator, site_encoder_model_kwargs) - - + + def test_singleattentionnetwork_backward(sample_batch, site_encoder_model_kwargs): _test_model_backward(sample_batch, SingleAttentionNetwork, site_encoder_model_kwargs) diff --git a/tests/models/multimodal/test_deep_supervision.py b/tests/models/multimodal/test_deep_supervision.py index 87c36a34..e56c01b3 100644 --- a/tests/models/multimodal/test_deep_supervision.py +++ b/tests/models/multimodal/test_deep_supervision.py @@ -8,6 +8,7 @@ def deepsupervision_model(multimodal_model_kwargs): model = Model(**multimodal_model_kwargs) return model + @pytest.mark.skip(reason="This model is no longer maintained") def test_model_forward(deepsupervision_model, sample_batch): y = deepsupervision_model(sample_batch) @@ -16,6 +17,7 @@ def test_model_forward(deepsupervision_model, sample_batch): # batch size=2, forecast_len=15 assert tuple(y.shape) == (2, 16), y.shape + @pytest.mark.skip(reason="This model is no longer maintained") def test_model_backwards(deepsupervision_model, sample_batch): opt = SGD(deepsupervision_model.parameters(), lr=0.001) diff --git a/tests/models/multimodal/test_nwp_weighting.py b/tests/models/multimodal/test_nwp_weighting.py index 22d1d08f..4353ba46 100644 --- a/tests/models/multimodal/test_nwp_weighting.py +++ b/tests/models/multimodal/test_nwp_weighting.py @@ -12,6 +12,7 @@ def nwp_weighting_model(model_minutes_kwargs): ) return model + @pytest.mark.skip(reason="This model is no longer maintained") def test_model_forward(nwp_weighting_model, sample_batch): y = nwp_weighting_model(sample_batch) @@ -20,6 +21,7 @@ def test_model_forward(nwp_weighting_model, sample_batch): # batch size=2, forecast_len=15 assert tuple(y.shape) == (2, 16), y.shape + @pytest.mark.skip(reason="This model is no longer maintained") def test_model_backwards(nwp_weighting_model, sample_batch): opt = SGD(nwp_weighting_model.parameters(), lr=0.001) diff --git a/tests/models/multimodal/test_weather_residual.py b/tests/models/multimodal/test_weather_residual.py index cc36386e..6573a40d 100644 --- a/tests/models/multimodal/test_weather_residual.py +++ b/tests/models/multimodal/test_weather_residual.py @@ -8,6 +8,7 @@ def weather_residual_model(multimodal_model_kwargs): model = Model(**multimodal_model_kwargs) return model + @pytest.mark.skip(reason="This model is no longer maintained") def test_model_forward(weather_residual_model, sample_batch): y = weather_residual_model(sample_batch) @@ -16,6 +17,7 @@ def test_model_forward(weather_residual_model, sample_batch): # batch size=2, forecast_len=15 assert tuple(y.shape) == (2, 16), y.shape + @pytest.mark.skip(reason="This model is no longer maintained") def test_model_backwards(weather_residual_model, sample_batch): opt = SGD(weather_residual_model.parameters(), lr=0.001)