Skip to content

Commit

Permalink
[pre-commit.ci] auto fixes from pre-commit.com hooks
Browse files Browse the repository at this point in the history
for more information, see https://pre-commit.ci
  • Loading branch information
pre-commit-ci[bot] committed Oct 24, 2023
1 parent 0497e8b commit c88505b
Show file tree
Hide file tree
Showing 10 changed files with 52 additions and 52 deletions.
2 changes: 1 addition & 1 deletion pvnet/models/base_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -135,7 +135,7 @@ def save_pretrained(
# saving model and data config
if isinstance(config, dict):
(save_directory / CONFIG_NAME).write_text(json.dumps(config, indent=4))

if data_config is not None:
shutil.copyfile(data_config, save_directory / "data_config.yaml")

Expand Down
6 changes: 3 additions & 3 deletions pvnet/models/multimodal/site_encoders/encoders.py
Original file line number Diff line number Diff line change
Expand Up @@ -140,7 +140,7 @@ def __init__(
n_kv_res_blocks: Number of residual blocks to use in the key and value encoders.
kv_res_block_layers: Number of fully-connected layers used in each residual block within
the key and value encoders.
use_pv_id_in_value: Whether to use a PV ID embedding in network used to produce the
use_pv_id_in_value: Whether to use a PV ID embedding in network used to produce the
value for the attention layer.
"""
Expand Down Expand Up @@ -232,8 +232,8 @@ def _attention_forward(self, x, average_attn_weights=True):

attn_output, attn_weights = self.multihead_attn(
query, key, value, average_attn_weights=average_attn_weights
)
)

return attn_output, attn_weights

def forward(self, x):
Expand Down
6 changes: 3 additions & 3 deletions pvnet/training.py
Original file line number Diff line number Diff line change
Expand Up @@ -109,14 +109,14 @@ def train(config: DictConfig) -> Optional[float]:
# Also save model config here - this makes for easy model push to huggingface
os.makedirs(callback.dirpath, exist_ok=True)
OmegaConf.save(config.model, f"{callback.dirpath}/model_config.yaml")

# Similarly save the data config
data_config = config.datamodule.configuration
if data_config is None:
# Data config can be none if using presaved batches. We go to the presaved
# batches to get the data config
# batches to get the data config
data_config = f"{config.datamodule.batch_dir}/data_configuration.yaml"

assert os.path.isfile(data_config), f"Data config file not found: {data_config}"
shutil.copyfile(data_config, f"{callback.dirpath}/data_config.yaml")
break
Expand Down
2 changes: 1 addition & 1 deletion requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -24,4 +24,4 @@ omegaconf
hydra-core
python-dotenv
hydra-optuna-sweeper
rich
rich
4 changes: 2 additions & 2 deletions scripts/checkpoint_to_huggingface.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,8 +58,8 @@ def push_to_huggingface(
checkpoint = torch.load(f"{checkpoint_dir_path}/last.ckpt")

model.load_state_dict(state_dict=checkpoint["state_dict"])
# Check for optional data config

# Check for optional data config
data_config = f"{checkpoint_dir_path}/data_config.yaml"
data_config = data_config if os.path.isfile(data_config) else None

Expand Down
72 changes: 33 additions & 39 deletions tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
import numpy as np
import xarray as xr
import torch
import hydra
import hydra

from ocf_datapipes.utils.consts import BatchKey
from datetime import timedelta
Expand Down Expand Up @@ -144,62 +144,57 @@ def encoder_model_kwargs():
)
return kwargs


@pytest.fixture()
def site_encoder_model_kwargs():
# Used to test site encoder model on PV data
kwargs = dict(
sequence_length=180 // 5 +1,
sequence_length=180 // 5 + 1,
num_sites=349,
out_features=128,
)
return kwargs


@pytest.fixture()
def multimodal_model_kwargs(model_minutes_kwargs):

kwargs = dict(

sat_encoder=dict(
_target_=pvnet.models.multimodal.encoders.encoders3d.DefaultPVNet,
_partial_=True,
in_channels=11,
out_features=128,
number_of_conv3d_layers=6,
conv3d_channels=32,
image_size_pixels=24,
_target_=pvnet.models.multimodal.encoders.encoders3d.DefaultPVNet,
_partial_=True,
in_channels=11,
out_features=128,
number_of_conv3d_layers=6,
conv3d_channels=32,
image_size_pixels=24,
),

nwp_encoder=dict(
_target_=pvnet.models.multimodal.encoders.encoders3d.DefaultPVNet,
_partial_=True,
in_channels=2,
out_features=128,
number_of_conv3d_layers=6,
conv3d_channels=32,
image_size_pixels=24,
_target_=pvnet.models.multimodal.encoders.encoders3d.DefaultPVNet,
_partial_=True,
in_channels=2,
out_features=128,
number_of_conv3d_layers=6,
conv3d_channels=32,
image_size_pixels=24,
),

add_image_embedding_channel=True,

pv_encoder=dict(
_target_=pvnet.models.multimodal.site_encoders.encoders.SingleAttentionNetwork,
_partial_=True,
num_sites=349,
out_features=40,
num_heads=4,
kdim=40,
pv_id_embed_dim=20,
_target_=pvnet.models.multimodal.site_encoders.encoders.SingleAttentionNetwork,
_partial_=True,
num_sites=349,
out_features=40,
num_heads=4,
kdim=40,
pv_id_embed_dim=20,
),

output_network=dict(
_target_=pvnet.models.multimodal.linear_networks.networks.ResFCNet2,
_partial_=True,
fc_hidden_features=128,
n_res_blocks=6,
res_block_layers=2,
dropout_frac=0.0,
_target_=pvnet.models.multimodal.linear_networks.networks.ResFCNet2,
_partial_=True,
fc_hidden_features=128,
n_res_blocks=6,
res_block_layers=2,
dropout_frac=0.0,
),

embedding_dim=16,
include_sun=True,
include_gsp_yield_history=True,
Expand All @@ -208,10 +203,9 @@ def multimodal_model_kwargs(model_minutes_kwargs):
nwp_forecast_minutes=480,
pv_history_minutes=180,
min_sat_delay_minutes=30,

)

kwargs = hydra.utils.instantiate(kwargs)

kwargs.update(model_minutes_kwargs)
return kwargs
6 changes: 3 additions & 3 deletions tests/models/multimodal/site_encoders/test_encoders.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ def _test_model_backward(batch, model_class, kwargs):
# Test model forward on all models
def test_simplelearnedaggregator_forward(sample_batch, site_encoder_model_kwargs):
_test_model_forward(sample_batch, SimpleLearnedAggregator, site_encoder_model_kwargs)


def test_singleattentionnetwork_forward(sample_batch, site_encoder_model_kwargs):
_test_model_forward(sample_batch, SingleAttentionNetwork, site_encoder_model_kwargs)
Expand All @@ -35,7 +35,7 @@ def test_singleattentionnetwork_forward(sample_batch, site_encoder_model_kwargs)
# Test model backward on all models
def test_simplelearnedaggregator_backward(sample_batch, site_encoder_model_kwargs):
_test_model_backward(sample_batch, SimpleLearnedAggregator, site_encoder_model_kwargs)


def test_singleattentionnetwork_backward(sample_batch, site_encoder_model_kwargs):
_test_model_backward(sample_batch, SingleAttentionNetwork, site_encoder_model_kwargs)
2 changes: 2 additions & 0 deletions tests/models/multimodal/test_deep_supervision.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ def deepsupervision_model(multimodal_model_kwargs):
model = Model(**multimodal_model_kwargs)
return model


@pytest.mark.skip(reason="This model is no longer maintained")
def test_model_forward(deepsupervision_model, sample_batch):
y = deepsupervision_model(sample_batch)
Expand All @@ -16,6 +17,7 @@ def test_model_forward(deepsupervision_model, sample_batch):
# batch size=2, forecast_len=15
assert tuple(y.shape) == (2, 16), y.shape


@pytest.mark.skip(reason="This model is no longer maintained")
def test_model_backwards(deepsupervision_model, sample_batch):
opt = SGD(deepsupervision_model.parameters(), lr=0.001)
Expand Down
2 changes: 2 additions & 0 deletions tests/models/multimodal/test_nwp_weighting.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ def nwp_weighting_model(model_minutes_kwargs):
)
return model


@pytest.mark.skip(reason="This model is no longer maintained")
def test_model_forward(nwp_weighting_model, sample_batch):
y = nwp_weighting_model(sample_batch)
Expand All @@ -20,6 +21,7 @@ def test_model_forward(nwp_weighting_model, sample_batch):
# batch size=2, forecast_len=15
assert tuple(y.shape) == (2, 16), y.shape


@pytest.mark.skip(reason="This model is no longer maintained")
def test_model_backwards(nwp_weighting_model, sample_batch):
opt = SGD(nwp_weighting_model.parameters(), lr=0.001)
Expand Down
2 changes: 2 additions & 0 deletions tests/models/multimodal/test_weather_residual.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ def weather_residual_model(multimodal_model_kwargs):
model = Model(**multimodal_model_kwargs)
return model


@pytest.mark.skip(reason="This model is no longer maintained")
def test_model_forward(weather_residual_model, sample_batch):
y = weather_residual_model(sample_batch)
Expand All @@ -16,6 +17,7 @@ def test_model_forward(weather_residual_model, sample_batch):
# batch size=2, forecast_len=15
assert tuple(y.shape) == (2, 16), y.shape


@pytest.mark.skip(reason="This model is no longer maintained")
def test_model_backwards(weather_residual_model, sample_batch):
opt = SGD(weather_residual_model.parameters(), lr=0.001)
Expand Down

0 comments on commit c88505b

Please sign in to comment.