From 183fc4df5bc0d678b6919d4051ff30c9aad8bd5d Mon Sep 17 00:00:00 2001 From: James Fulton Date: Wed, 28 Feb 2024 11:46:34 +0000 Subject: [PATCH 1/3] add ensemble --- pvnet/models/ensemble.py | 69 +++++++++++++++++++++++++++++++++++ tests/models/test_ensemble.py | 37 +++++++++++++++++++ 2 files changed, 106 insertions(+) create mode 100644 pvnet/models/ensemble.py create mode 100644 tests/models/test_ensemble.py diff --git a/pvnet/models/ensemble.py b/pvnet/models/ensemble.py new file mode 100644 index 00000000..c820cd2b --- /dev/null +++ b/pvnet/models/ensemble.py @@ -0,0 +1,69 @@ +"""Model which uses mutliple prediction heads""" +from typing import Optional +import torch +from torch import nn +from pvnet.models.base_model import BaseModel + + +class Ensemble(BaseModel): + """Ensemble of PVNet models""" + def __init__( + self, + model_list: list[BaseModel], + weights: Optional[list[float]] = None, + ): + """Ensemble of PVNet models + + Args: + model_list: A list of PVNet models to ensemble + weights: A list of weighting to apply to each model. If None, the models are weighted + equally. + """ + + # Surface check all the models are compatible + output_quantiles = [] + history_minutes = [] + forecast_minutes = [] + target_key = [] + interval_minutes = [] + + # Get some model properties from each model + for model in model_list: + output_quantiles.append(model.output_quantiles) + history_minutes.append(model.history_minutes) + forecast_minutes.append(model.forecast_minutes) + target_key.append(model._target_key_name) + interval_minutes.append(model.interval_minutes) + + # Check these properties are all the same + for param_list in [ + output_quantiles, history_minutes, forecast_minutes, target_key, interval_minutes + ]: + assert all([p==param_list[0] for p in param_list]), param_list + + + super().__init__( + history_minutes=history_minutes[0], + forecast_minutes=forecast_minutes[0], + optimizer=None, + output_quantiles=output_quantiles[0], + target_key=target_key[0], + interval_minutes=interval_minutes[0], + ) + + self.model_list = nn.ModuleList(model_list) + + if weights is None: + weights = torch.ones(len(model_list))/len(model_list) + else: + assert len(weights)==len(model_list) + weights = torch.Tensor(weights)/sum(weights) + self.weights = nn.Parameter(weights, requires_grad=False) + + def forward(self, batch): + """Run the model forward""" + y_hat = 0 + for weight, model in zip(self.weights, self.model_list): + y_hat = model(batch)*weight + y_hat + return y_hat + \ No newline at end of file diff --git a/tests/models/test_ensemble.py b/tests/models/test_ensemble.py new file mode 100644 index 00000000..b75a0e19 --- /dev/null +++ b/tests/models/test_ensemble.py @@ -0,0 +1,37 @@ +from pvnet.models.ensemble import Ensemble + + +def test_model_init(multimodal_model): + ensemble_model = Ensemble( + model_list=[multimodal_model]*3, + weights=None, + ) + + ensemble_model = Ensemble( + model_list=[multimodal_model]*3, + weights=[1,2,3], + ) + +def test_model_forward(multimodal_model, sample_batch): + ensemble_model = Ensemble( + model_list=[multimodal_model]*3, + ) + + y = ensemble_model(sample_batch) + + # check output is the correct shape + # batch size=2, forecast_len=15 + assert tuple(y.shape) == (2, 16), y.shape + + +def test_quantile_model_forward(multimodal_quantile_model, sample_batch): + ensemble_model = Ensemble( + model_list=[multimodal_quantile_model]*3, + ) + + y_quantiles = ensemble_model(sample_batch) + + # check output is the correct shape + # batch size=2, forecast_len=15, num_quantiles=3 + assert tuple(y_quantiles.shape) == (2, 16, 3), y_quantiles.shape + From 27a68337499dcd3d76053f6e12f6ea37f96a24e9 Mon Sep 17 00:00:00 2001 From: James Fulton Date: Wed, 28 Feb 2024 11:47:12 +0000 Subject: [PATCH 2/3] add support for pushing ensemble to huggingface --- pvnet/models/base_model.py | 15 +++- pvnet/models/model_card_template.md | 5 +- scripts/checkpoint_to_huggingface.py | 106 +++++++++++++++++---------- 3 files changed, 84 insertions(+), 42 deletions(-) diff --git a/pvnet/models/base_model.py b/pvnet/models/base_model.py index f903c05d..5941971b 100644 --- a/pvnet/models/base_model.py +++ b/pvnet/models/base_model.py @@ -168,7 +168,7 @@ def save_pretrained( data_config: Union[str, Path], repo_id: Optional[str] = None, push_to_hub: bool = False, - wandb_model_code: Optional[str] = None, + wandb_ids: Optional[Union[list[str], str]] = None, card_template_path=None, **kwargs, ) -> Optional[str]: @@ -187,7 +187,7 @@ def save_pretrained( the folder name if not provided. push_to_hub (`bool`, *optional*, defaults to `False`): Whether or not to push your model to the Huggingface Hub after saving it. - wandb_model_code: Identifier of the model on wandb. + wandb_ids: Identifier(s) of the model on wandb. card_template_path: Path to the huggingface model card template. Defaults to card in PVNet library if set to None. kwargs: @@ -214,11 +214,19 @@ def save_pretrained( card_template_path = ( f"{os.path.dirname(os.path.abspath(__file__))}/model_card_template.md" ) + + if isinstance(wandb_ids, str): + wandb_ids = [wandb_ids] + + wandb_links = "" + for wandb_id in wandb_ids: + link = f"https://wandb.ai/openclimatefix/pvnet2.1/runs/{wandb_id}" + wandb_links += f" - [{link}]({link})\n" card = ModelCard.from_template( card_data, template_path=card_template_path, - wandb_model_code=wandb_model_code, + wandb_links=wandb_links, ) (save_directory / "README.md").write_text(str(card)) @@ -280,6 +288,7 @@ def __init__( self.history_minutes = history_minutes self.forecast_minutes = forecast_minutes self.output_quantiles = output_quantiles + self.interval_minutes = interval_minutes # Number of timestemps for 30 minutely data self.history_len = history_minutes // interval_minutes diff --git a/pvnet/models/model_card_template.md b/pvnet/models/model_card_template.md index ae4ed480..6ff5c42a 100644 --- a/pvnet/models/model_card_template.md +++ b/pvnet/models/model_card_template.md @@ -26,7 +26,7 @@ This model class uses satellite data, numericl weather predictions, and recent G -The model is trained on data from 2017-2020 and validated on data from 2021. See experimental notes in the [the google doc](https://docs.google.com/document/d/1fbkfkBzp16WbnCg7RDuRDvgzInA6XQu3xh4NCjV-WDA/edit?usp=sharing) for more details. +The model is trained on data from 2019-2022 and validated on data from 2022-2023. See experimental notes in the [the google doc](https://docs.google.com/document/d/1fbkfkBzp16WbnCg7RDuRDvgzInA6XQu3xh4NCjV-WDA/edit?usp=sharing) for more details. ### Preprocessing @@ -36,7 +36,8 @@ Data is prepared with the `ocf_datapipes.training.pvnet` datapipe [2]. ## Results -The training logs for the current model can be found [here on wandb](https://wandb.ai/openclimatefix/pvnet2.1/runs/{{ wandb_model_code }}). +The training logs for the current model can be found here: +{{ wandb_links }} The training logs for all model runs of PVNet2 can be found [here](https://wandb.ai/openclimatefix/pvnet2.1). diff --git a/scripts/checkpoint_to_huggingface.py b/scripts/checkpoint_to_huggingface.py index 0fe51434..c1478594 100644 --- a/scripts/checkpoint_to_huggingface.py +++ b/scripts/checkpoint_to_huggingface.py @@ -8,65 +8,97 @@ import glob import os import tempfile -from typing import Optional +from typing import Optional, Union import hydra import torch import typer import wandb from pyaml_env import parse_config +import yaml from pvnet.models.multimodal.unimodal_teacher import Model as UMTModel - +from pvnet.models.ensemble import Ensemble def push_to_huggingface( - checkpoint_dir_path: str, + checkpoint_dir_paths: list[str], val_best: bool = True, - wandb_id: Optional[str] = None, + wandb_ids: Optional[list[str]] = None, local_path: Optional[str] = None, push_to_hub: bool = True, ): """Push a local model to pvnet_v2 huggingface model repo - - checkpoint_dir_path (str): Path of the chekpoint directory - val_best (bool): Use best model according to val loss, else last saved model - wandb_id (str): The wandb ID code - local_path (str): Where to save the local copy of the model - push_to_hub (bool): Whether to push the model to the hub or just create local version. + + Args: + checkpoint_dir_paths: Path(s) of the checkpoint directory(ies) + val_best: Use best model according to val loss, else last saved model + wandb_ids: The wandb ID code(s) + local_path: Where to save the local copy of the model + push_to_hub: Whether to push the model to the hub or just create local version. """ assert push_to_hub or local_path is not None os.path.dirname(os.path.abspath(__file__)) - + + is_ensemble = len(checkpoint_dir_paths) + # Check if checkpoint dir name is wandb run ID - if wandb_id is None: + if wandb_ids==[]: all_wandb_ids = [run.id for run in wandb.Api().runs(path="openclimatefix/pvnet2.1")] - dirname = checkpoint_dir_path.split("/")[-1] - if dirname in all_wandb_ids: - wandb_id = dirname - - # Load the model - model_config = parse_config(f"{checkpoint_dir_path}/model_config.yaml") - - model = hydra.utils.instantiate(model_config) - - if val_best: - # Only one epoch (best) saved per model - files = glob.glob(f"{checkpoint_dir_path}/epoch*.ckpt") - assert len(files) == 1 - checkpoint = torch.load(files[0], map_location="cpu") + for path in checkpoint_dir_paths: + dirname = path.split("/")[-1] + if dirname in all_wandb_ids: + wandb_ids.append(dirname) + else: + wandb_ids.append(None) + + model_configs = [] + models = [] + data_configs = [] + + for path in checkpoint_dir_paths: + # Load the model + model_config = parse_config(f"{path}/model_config.yaml") + + model = hydra.utils.instantiate(model_config) + + if val_best: + # Only one epoch (best) saved per model + files = glob.glob(f"{path}/epoch*.ckpt") + assert len(files) == 1 + checkpoint = torch.load(files[0], map_location="cpu") + else: + checkpoint = torch.load(f"{path}/last.ckpt", map_location="cpu") + + model.load_state_dict(state_dict=checkpoint["state_dict"]) + + if isinstance(model, UMTModel): + model, model_config = model.convert_to_multimodal_model(model_config) + + # Check for data config + data_config = f"{path}/data_config.yaml" + assert os.path.isfile(data_config) + + model_configs.append(model_config) + models.append(model) + data_configs.append(data_config) + + if is_ensemble: + model_config = { + "_target_": "pvnet.models.ensemble.Ensemble", + "model_list": model_configs, + } + model = Ensemble( + model_list = models + ) + data_config = data_configs[0] + else: - checkpoint = torch.load(f"{checkpoint_dir_path}/last.ckpt", map_location="cpu") - - model.load_state_dict(state_dict=checkpoint["state_dict"]) - - if isinstance(model, UMTModel): - model, model_config = model.convert_to_multimodal_model(model_config) - - # Check for data config - data_config = f"{checkpoint_dir_path}/data_config.yaml" - assert os.path.isfile(data_config) + model_config = model_configs[0] + model = models[0] + data_config = data_configs[0] + wandb_ids = wandb_ids[0] # Push to hub if local_path is None: @@ -79,7 +111,7 @@ def push_to_huggingface( model_output_dir, config=model_config, data_config=data_config, - wandb_model_code=wandb_id, + wandb_ids=wandb_ids, push_to_hub=push_to_hub, repo_id="openclimatefix/pvnet_v2" if push_to_hub else None, ) From 6471e8702ccfc1e6bc4c614f92fb910bc2021318 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Wed, 28 Feb 2024 11:49:20 +0000 Subject: [PATCH 3/3] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- pvnet/models/base_model.py | 4 +-- pvnet/models/ensemble.py | 39 ++++++++++++++++------------ pvnet/models/model_card_template.md | 2 +- scripts/checkpoint_to_huggingface.py | 30 ++++++++++----------- tests/models/test_ensemble.py | 18 ++++++------- 5 files changed, 48 insertions(+), 45 deletions(-) diff --git a/pvnet/models/base_model.py b/pvnet/models/base_model.py index 5941971b..dc18b5b1 100644 --- a/pvnet/models/base_model.py +++ b/pvnet/models/base_model.py @@ -214,10 +214,10 @@ def save_pretrained( card_template_path = ( f"{os.path.dirname(os.path.abspath(__file__))}/model_card_template.md" ) - + if isinstance(wandb_ids, str): wandb_ids = [wandb_ids] - + wandb_links = "" for wandb_id in wandb_ids: link = f"https://wandb.ai/openclimatefix/pvnet2.1/runs/{wandb_id}" diff --git a/pvnet/models/ensemble.py b/pvnet/models/ensemble.py index c820cd2b..ef90d8fb 100644 --- a/pvnet/models/ensemble.py +++ b/pvnet/models/ensemble.py @@ -1,32 +1,35 @@ """Model which uses mutliple prediction heads""" from typing import Optional + import torch from torch import nn + from pvnet.models.base_model import BaseModel class Ensemble(BaseModel): """Ensemble of PVNet models""" + def __init__( self, model_list: list[BaseModel], weights: Optional[list[float]] = None, - ): + ): """Ensemble of PVNet models - + Args: model_list: A list of PVNet models to ensemble weights: A list of weighting to apply to each model. If None, the models are weighted equally. """ - + # Surface check all the models are compatible output_quantiles = [] history_minutes = [] forecast_minutes = [] target_key = [] interval_minutes = [] - + # Get some model properties from each model for model in model_list: output_quantiles.append(model.output_quantiles) @@ -34,14 +37,17 @@ def __init__( forecast_minutes.append(model.forecast_minutes) target_key.append(model._target_key_name) interval_minutes.append(model.interval_minutes) - + # Check these properties are all the same for param_list in [ - output_quantiles, history_minutes, forecast_minutes, target_key, interval_minutes + output_quantiles, + history_minutes, + forecast_minutes, + target_key, + interval_minutes, ]: - assert all([p==param_list[0] for p in param_list]), param_list - - + assert all([p == param_list[0] for p in param_list]), param_list + super().__init__( history_minutes=history_minutes[0], forecast_minutes=forecast_minutes[0], @@ -50,20 +56,19 @@ def __init__( target_key=target_key[0], interval_minutes=interval_minutes[0], ) - + self.model_list = nn.ModuleList(model_list) - + if weights is None: - weights = torch.ones(len(model_list))/len(model_list) + weights = torch.ones(len(model_list)) / len(model_list) else: - assert len(weights)==len(model_list) - weights = torch.Tensor(weights)/sum(weights) + assert len(weights) == len(model_list) + weights = torch.Tensor(weights) / sum(weights) self.weights = nn.Parameter(weights, requires_grad=False) - + def forward(self, batch): """Run the model forward""" y_hat = 0 for weight, model in zip(self.weights, self.model_list): - y_hat = model(batch)*weight + y_hat + y_hat = model(batch) * weight + y_hat return y_hat - \ No newline at end of file diff --git a/pvnet/models/model_card_template.md b/pvnet/models/model_card_template.md index 6ff5c42a..8ae10af6 100644 --- a/pvnet/models/model_card_template.md +++ b/pvnet/models/model_card_template.md @@ -36,7 +36,7 @@ Data is prepared with the `ocf_datapipes.training.pvnet` datapipe [2]. ## Results -The training logs for the current model can be found here: +The training logs for the current model can be found here: {{ wandb_links }} The training logs for all model runs of PVNet2 can be found [here](https://wandb.ai/openclimatefix/pvnet2.1). diff --git a/scripts/checkpoint_to_huggingface.py b/scripts/checkpoint_to_huggingface.py index c1478594..4b0c4695 100644 --- a/scripts/checkpoint_to_huggingface.py +++ b/scripts/checkpoint_to_huggingface.py @@ -8,17 +8,17 @@ import glob import os import tempfile -from typing import Optional, Union +from typing import Optional import hydra import torch import typer import wandb from pyaml_env import parse_config -import yaml -from pvnet.models.multimodal.unimodal_teacher import Model as UMTModel from pvnet.models.ensemble import Ensemble +from pvnet.models.multimodal.unimodal_teacher import Model as UMTModel + def push_to_huggingface( checkpoint_dir_paths: list[str], @@ -28,7 +28,7 @@ def push_to_huggingface( push_to_hub: bool = True, ): """Push a local model to pvnet_v2 huggingface model repo - + Args: checkpoint_dir_paths: Path(s) of the checkpoint directory(ies) val_best: Use best model according to val loss, else last saved model @@ -40,11 +40,11 @@ def push_to_huggingface( assert push_to_hub or local_path is not None os.path.dirname(os.path.abspath(__file__)) - + is_ensemble = len(checkpoint_dir_paths) - + # Check if checkpoint dir name is wandb run ID - if wandb_ids==[]: + if wandb_ids == []: all_wandb_ids = [run.id for run in wandb.Api().runs(path="openclimatefix/pvnet2.1")] for path in checkpoint_dir_paths: dirname = path.split("/")[-1] @@ -52,11 +52,11 @@ def push_to_huggingface( wandb_ids.append(dirname) else: wandb_ids.append(None) - + model_configs = [] models = [] data_configs = [] - + for path in checkpoint_dir_paths: # Load the model model_config = parse_config(f"{path}/model_config.yaml") @@ -75,25 +75,23 @@ def push_to_huggingface( if isinstance(model, UMTModel): model, model_config = model.convert_to_multimodal_model(model_config) - + # Check for data config data_config = f"{path}/data_config.yaml" assert os.path.isfile(data_config) - + model_configs.append(model_config) models.append(model) data_configs.append(data_config) - + if is_ensemble: model_config = { "_target_": "pvnet.models.ensemble.Ensemble", "model_list": model_configs, } - model = Ensemble( - model_list = models - ) + model = Ensemble(model_list=models) data_config = data_configs[0] - + else: model_config = model_configs[0] model = models[0] diff --git a/tests/models/test_ensemble.py b/tests/models/test_ensemble.py index b75a0e19..048b4479 100644 --- a/tests/models/test_ensemble.py +++ b/tests/models/test_ensemble.py @@ -3,20 +3,21 @@ def test_model_init(multimodal_model): ensemble_model = Ensemble( - model_list=[multimodal_model]*3, + model_list=[multimodal_model] * 3, weights=None, ) ensemble_model = Ensemble( - model_list=[multimodal_model]*3, - weights=[1,2,3], + model_list=[multimodal_model] * 3, + weights=[1, 2, 3], ) + def test_model_forward(multimodal_model, sample_batch): ensemble_model = Ensemble( - model_list=[multimodal_model]*3, + model_list=[multimodal_model] * 3, ) - + y = ensemble_model(sample_batch) # check output is the correct shape @@ -26,12 +27,11 @@ def test_model_forward(multimodal_model, sample_batch): def test_quantile_model_forward(multimodal_quantile_model, sample_batch): ensemble_model = Ensemble( - model_list=[multimodal_quantile_model]*3, + model_list=[multimodal_quantile_model] * 3, ) - + y_quantiles = ensemble_model(sample_batch) - + # check output is the correct shape # batch size=2, forecast_len=15, num_quantiles=3 assert tuple(y_quantiles.shape) == (2, 16, 3), y_quantiles.shape -