Merge pull request #155 from openclimatefix/fix_umt_conversion
port all weights when converting to multimodal model object
dfulu authored Apr 10, 2024
2 parents f474265 + 226bf26 commit 3370db0
Showing 5 changed files with 64 additions and 31 deletions.
5 changes: 3 additions & 2 deletions pvnet/models/base_model.py
@@ -165,7 +165,7 @@ def save_pretrained(
self,
save_directory: Union[str, Path],
config: dict,
data_config: Union[str, Path],
data_config: Optional[Union[str, Path]],
repo_id: Optional[str] = None,
push_to_hub: bool = False,
wandb_ids: Optional[Union[list[str], str]] = None,
@@ -206,7 +206,8 @@ def save_pretrained(
(save_directory / CONFIG_NAME).write_text(json.dumps(config, indent=4))

# Save cleaned datapipes configuration file
make_clean_data_config(data_config, save_directory / DATA_CONFIG_NAME)
if data_config is not None:
make_clean_data_config(data_config, save_directory / DATA_CONFIG_NAME)

# Creating and saving model card.
card_data = ModelCardData(language="en", license="mit", library_name="pytorch")
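For illustration, a minimal sketch of how the relaxed signature can be used (the paths and `model_config` dict below are made up, not part of this PR): passing `data_config=None` now simply skips writing the cleaned data configuration file.

```python
# Hypothetical usage only: `model` is a trained PVNet model, and the save
# directory / config dict are illustrative. With data_config=None the
# make_clean_data_config step is skipped and no data config file is written.
model.save_pretrained(
    save_directory="checkpoints/example_model",
    config=model_config,
    data_config=None,
    push_to_hub=False,
)
```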
11 changes: 10 additions & 1 deletion pvnet/models/multimodal/unimodal_teacher.py
@@ -369,7 +369,10 @@ def training_step(self, batch, batch_idx):
def convert_to_multimodal_model(self, config):
"""Convert the model into a multimodal model class whilst preserving weights"""
config = config.copy()
del config["cold_start"]

if "cold_start" in config:
del config["cold_start"]

config["_target_"] = "pvnet.models.multimodal.multimodal.Model"

sources = []
@@ -416,4 +419,10 @@ def convert_to_multimodal_model(self, config):

multimodal_model.output_network.load_state_dict(self.output_network.state_dict())

if self.embedding_dim:
multimodal_model.embed.load_state_dict(self.embed.state_dict())

if self.include_sun:
multimodal_model.sun_fc1.load_state_dict(self.sun_fc1.state_dict())

return multimodal_model, config
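As a hedged usage sketch (variable names such as `um_model`, `um_config`, and `batch` are placeholders; the calls mirror the test added below), the conversion now carries every learned weight across, so the two models should agree exactly:

```python
# Convert a trained unimodal-teacher model into the plain multimodal Model
# class. With this commit the embedding and sun-position weights are ported
# too, so predictions on the same batch should be identical.
mm_model, mm_config = um_model.convert_to_multimodal_model(um_config)

y_um = um_model(batch, return_modes=False)
y_mm = mm_model(batch)
assert (y_um == y_mm).all()
```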
1 change: 0 additions & 1 deletion pvnet/optimizers.py
@@ -185,7 +185,6 @@ def __call__(self, model):
if not isinstance(self._lr, float):
return self._call_multi(model)
else:
assert False
default_lr = self._lr if model.lr is None else model.lr
opt = torch.optim.AdamW(model.parameters(), lr=default_lr, **self.opt_kwargs)
sch = torch.optim.lr_scheduler.ReduceLROnPlateau(
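The removed `assert False` was blocking the single learning-rate branch below it. A minimal sketch of what that branch sets up, with illustrative stand-ins for the wrapper's attributes (`self._lr` and `self.opt_kwargs` are not redefined here):

```python
# Sketch only: a plain AdamW optimiser paired with a ReduceLROnPlateau
# scheduler, as the now-reachable else-branch constructs them.
import torch
import torch.nn as nn

model = nn.Linear(4, 1)   # stand-in for the PVNet model
default_lr = 1e-4         # stands in for self._lr / model.lr

opt = torch.optim.AdamW(model.parameters(), lr=default_lr)
sch = torch.optim.lr_scheduler.ReduceLROnPlateau(opt, mode="min", patience=5)
```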
64 changes: 37 additions & 27 deletions scripts/save_concurrent_batches.py
@@ -13,6 +13,13 @@
```
"""
# This is needed to get multiprocessing/multiple workers to behave
try:
import torch.multiprocessing as mp

mp.set_start_method("spawn", force=True)
except RuntimeError:
pass

import logging
import os
@@ -157,9 +164,6 @@ def main(config: DictConfig):

shutil.copyfile(config_dm.configuration, f"{config.batch_output_dir}/data_configuration.yaml")

os.mkdir(f"{config.batch_output_dir}/train")
os.mkdir(f"{config.batch_output_dir}/val")

dataloader_kwargs = dict(
shuffle=False,
batch_size=None, # batched in datapipe step
@@ -175,35 +179,41 @@
persistent_workers=False,
)

print("----- Saving val batches -----")
if config.num_val_batches > 0:
print("----- Saving val batches -----")

val_batch_pipe = _get_datapipe(
config_dm.configuration,
*config_dm.val_period,
config.num_val_batches,
)
os.mkdir(f"{config.batch_output_dir}/val")

_save_batches_with_dataloader(
batch_pipe=val_batch_pipe,
batch_dir=f"{config.batch_output_dir}/val",
num_batches=config.num_val_batches,
dataloader_kwargs=dataloader_kwargs,
)
val_batch_pipe = _get_datapipe(
config_dm.configuration,
*config_dm.val_period,
config.num_val_batches,
)

_save_batches_with_dataloader(
batch_pipe=val_batch_pipe,
batch_dir=f"{config.batch_output_dir}/val",
num_batches=config.num_val_batches,
dataloader_kwargs=dataloader_kwargs,
)

print("----- Saving train batches -----")
if config.num_train_batches > 0:
print("----- Saving train batches -----")

train_batch_pipe = _get_datapipe(
config_dm.configuration,
*config_dm.train_period,
config.num_train_batches,
)
os.mkdir(f"{config.batch_output_dir}/train")

_save_batches_with_dataloader(
batch_pipe=train_batch_pipe,
batch_dir=f"{config.batch_output_dir}/train",
num_batches=config.num_train_batches,
dataloader_kwargs=dataloader_kwargs,
)
train_batch_pipe = _get_datapipe(
config_dm.configuration,
*config_dm.train_period,
config.num_train_batches,
)

_save_batches_with_dataloader(
batch_pipe=train_batch_pipe,
batch_dir=f"{config.batch_output_dir}/train",
num_batches=config.num_train_batches,
dataloader_kwargs=dataloader_kwargs,
)

print("done")

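A hedged illustration of the new guards (the values below are invented): a split is only prepared and written when its requested batch count is positive, so a skipped split no longer gets an empty output directory.

```python
import os

# Illustrative stand-ins for the script's Hydra config values.
batch_output_dir = "batches/example_run"
num_batches = {"val": 0, "train": 250}  # val split is skipped entirely

for split, n in num_batches.items():
    if n > 0:
        print(f"----- Saving {split} batches -----")
        os.makedirs(f"{batch_output_dir}/{split}")
        # ...build the datapipe and call _save_batches_with_dataloader as above
```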
14 changes: 14 additions & 0 deletions tests/models/multimodal/test_unimodal_teacher.py
@@ -49,6 +49,7 @@ def unimodal_model_kwargs(teacher_dir, model_minutes_kwargs):
res_block_layers=2,
dropout_frac=0.0,
),
cold_start=True,
)

# Get the teacher model save directories
@@ -89,3 +90,16 @@ def test_model_backward(unimodal_teacher_model, sample_batch):

# Backwards on sum drives sum to zero
y.sum().backward()


def test_model_conversion(unimodal_model_kwargs, sample_batch):
# Create the unimodal model
um_model = Model(**unimodal_model_kwargs)
# Convert to the equivalent multimodal model
mm_model, _ = um_model.convert_to_multimodal_model(unimodal_model_kwargs)

# If the model has been successfully converted the predictions should be identical
y_um = um_model(sample_batch, return_modes=False)
y_mm = mm_model(sample_batch)

assert (y_um == y_mm).all()
