Skip to content

Commit

Permalink
Merge pull request #148 from openclimatefix/ensemble
Browse files Browse the repository at this point in the history
Model ensemble
  • Loading branch information
dfulu authored Feb 29, 2024
2 parents 9194555 + 6471e87 commit 17fef47
Show file tree
Hide file tree
Showing 5 changed files with 188 additions and 37 deletions.
15 changes: 12 additions & 3 deletions pvnet/models/base_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -168,7 +168,7 @@ def save_pretrained(
data_config: Union[str, Path],
repo_id: Optional[str] = None,
push_to_hub: bool = False,
wandb_model_code: Optional[str] = None,
wandb_ids: Optional[Union[list[str], str]] = None,
card_template_path=None,
**kwargs,
) -> Optional[str]:
Expand All @@ -187,7 +187,7 @@ def save_pretrained(
the folder name if not provided.
push_to_hub (`bool`, *optional*, defaults to `False`):
Whether or not to push your model to the Huggingface Hub after saving it.
wandb_model_code: Identifier of the model on wandb.
wandb_ids: Identifier(s) of the model on wandb.
card_template_path: Path to the huggingface model card template. Defaults to card in
PVNet library if set to None.
kwargs:
Expand Down Expand Up @@ -215,10 +215,18 @@ def save_pretrained(
f"{os.path.dirname(os.path.abspath(__file__))}/model_card_template.md"
)

if isinstance(wandb_ids, str):
wandb_ids = [wandb_ids]

wandb_links = ""
for wandb_id in wandb_ids:
link = f"https://wandb.ai/openclimatefix/pvnet2.1/runs/{wandb_id}"
wandb_links += f" - [{link}]({link})\n"

card = ModelCard.from_template(
card_data,
template_path=card_template_path,
wandb_model_code=wandb_model_code,
wandb_links=wandb_links,
)

(save_directory / "README.md").write_text(str(card))
Expand Down Expand Up @@ -280,6 +288,7 @@ def __init__(
self.history_minutes = history_minutes
self.forecast_minutes = forecast_minutes
self.output_quantiles = output_quantiles
self.interval_minutes = interval_minutes

# Number of timestemps for 30 minutely data
self.history_len = history_minutes // interval_minutes
Expand Down
74 changes: 74 additions & 0 deletions pvnet/models/ensemble.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,74 @@
"""Model which uses mutliple prediction heads"""
from typing import Optional

import torch
from torch import nn

from pvnet.models.base_model import BaseModel


class Ensemble(BaseModel):
"""Ensemble of PVNet models"""

def __init__(
self,
model_list: list[BaseModel],
weights: Optional[list[float]] = None,
):
"""Ensemble of PVNet models
Args:
model_list: A list of PVNet models to ensemble
weights: A list of weighting to apply to each model. If None, the models are weighted
equally.
"""

# Surface check all the models are compatible
output_quantiles = []
history_minutes = []
forecast_minutes = []
target_key = []
interval_minutes = []

# Get some model properties from each model
for model in model_list:
output_quantiles.append(model.output_quantiles)
history_minutes.append(model.history_minutes)
forecast_minutes.append(model.forecast_minutes)
target_key.append(model._target_key_name)
interval_minutes.append(model.interval_minutes)

# Check these properties are all the same
for param_list in [
output_quantiles,
history_minutes,
forecast_minutes,
target_key,
interval_minutes,
]:
assert all([p == param_list[0] for p in param_list]), param_list

super().__init__(
history_minutes=history_minutes[0],
forecast_minutes=forecast_minutes[0],
optimizer=None,
output_quantiles=output_quantiles[0],
target_key=target_key[0],
interval_minutes=interval_minutes[0],
)

self.model_list = nn.ModuleList(model_list)

if weights is None:
weights = torch.ones(len(model_list)) / len(model_list)
else:
assert len(weights) == len(model_list)
weights = torch.Tensor(weights) / sum(weights)
self.weights = nn.Parameter(weights, requires_grad=False)

def forward(self, batch):
"""Run the model forward"""
y_hat = 0
for weight, model in zip(self.weights, self.model_list):
y_hat = model(batch) * weight + y_hat
return y_hat
5 changes: 3 additions & 2 deletions pvnet/models/model_card_template.md
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ This model class uses satellite data, numericl weather predictions, and recent G

<!-- This should link to a Data Card, perhaps with a short stub of information on what the training data is all about as well as documentation related to data pre-processing or additional filtering. -->

The model is trained on data from 2017-2020 and validated on data from 2021. See experimental notes in the [the google doc](https://docs.google.com/document/d/1fbkfkBzp16WbnCg7RDuRDvgzInA6XQu3xh4NCjV-WDA/edit?usp=sharing) for more details.
The model is trained on data from 2019-2022 and validated on data from 2022-2023. See experimental notes in the [the google doc](https://docs.google.com/document/d/1fbkfkBzp16WbnCg7RDuRDvgzInA6XQu3xh4NCjV-WDA/edit?usp=sharing) for more details.


### Preprocessing
Expand All @@ -36,7 +36,8 @@ Data is prepared with the `ocf_datapipes.training.pvnet` datapipe [2].

## Results

The training logs for the current model can be found [here on wandb](https://wandb.ai/openclimatefix/pvnet2.1/runs/{{ wandb_model_code }}).
The training logs for the current model can be found here:
{{ wandb_links }}

The training logs for all model runs of PVNet2 can be found [here](https://wandb.ai/openclimatefix/pvnet2.1).

Expand Down
94 changes: 62 additions & 32 deletions scripts/checkpoint_to_huggingface.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,57 +16,87 @@
import wandb
from pyaml_env import parse_config

from pvnet.models.ensemble import Ensemble
from pvnet.models.multimodal.unimodal_teacher import Model as UMTModel


def push_to_huggingface(
checkpoint_dir_path: str,
checkpoint_dir_paths: list[str],
val_best: bool = True,
wandb_id: Optional[str] = None,
wandb_ids: Optional[list[str]] = None,
local_path: Optional[str] = None,
push_to_hub: bool = True,
):
"""Push a local model to pvnet_v2 huggingface model repo
checkpoint_dir_path (str): Path of the chekpoint directory
val_best (bool): Use best model according to val loss, else last saved model
wandb_id (str): The wandb ID code
local_path (str): Where to save the local copy of the model
push_to_hub (bool): Whether to push the model to the hub or just create local version.
Args:
checkpoint_dir_paths: Path(s) of the checkpoint directory(ies)
val_best: Use best model according to val loss, else last saved model
wandb_ids: The wandb ID code(s)
local_path: Where to save the local copy of the model
push_to_hub: Whether to push the model to the hub or just create local version.
"""

assert push_to_hub or local_path is not None

os.path.dirname(os.path.abspath(__file__))

is_ensemble = len(checkpoint_dir_paths)

# Check if checkpoint dir name is wandb run ID
if wandb_id is None:
if wandb_ids == []:
all_wandb_ids = [run.id for run in wandb.Api().runs(path="openclimatefix/pvnet2.1")]
dirname = checkpoint_dir_path.split("/")[-1]
if dirname in all_wandb_ids:
wandb_id = dirname

# Load the model
model_config = parse_config(f"{checkpoint_dir_path}/model_config.yaml")

model = hydra.utils.instantiate(model_config)
for path in checkpoint_dir_paths:
dirname = path.split("/")[-1]
if dirname in all_wandb_ids:
wandb_ids.append(dirname)
else:
wandb_ids.append(None)

model_configs = []
models = []
data_configs = []

for path in checkpoint_dir_paths:
# Load the model
model_config = parse_config(f"{path}/model_config.yaml")

model = hydra.utils.instantiate(model_config)

if val_best:
# Only one epoch (best) saved per model
files = glob.glob(f"{path}/epoch*.ckpt")
assert len(files) == 1
checkpoint = torch.load(files[0], map_location="cpu")
else:
checkpoint = torch.load(f"{path}/last.ckpt", map_location="cpu")

model.load_state_dict(state_dict=checkpoint["state_dict"])

if isinstance(model, UMTModel):
model, model_config = model.convert_to_multimodal_model(model_config)

# Check for data config
data_config = f"{path}/data_config.yaml"
assert os.path.isfile(data_config)

model_configs.append(model_config)
models.append(model)
data_configs.append(data_config)

if is_ensemble:
model_config = {
"_target_": "pvnet.models.ensemble.Ensemble",
"model_list": model_configs,
}
model = Ensemble(model_list=models)
data_config = data_configs[0]

if val_best:
# Only one epoch (best) saved per model
files = glob.glob(f"{checkpoint_dir_path}/epoch*.ckpt")
assert len(files) == 1
checkpoint = torch.load(files[0], map_location="cpu")
else:
checkpoint = torch.load(f"{checkpoint_dir_path}/last.ckpt", map_location="cpu")

model.load_state_dict(state_dict=checkpoint["state_dict"])

if isinstance(model, UMTModel):
model, model_config = model.convert_to_multimodal_model(model_config)

# Check for data config
data_config = f"{checkpoint_dir_path}/data_config.yaml"
assert os.path.isfile(data_config)
model_config = model_configs[0]
model = models[0]
data_config = data_configs[0]
wandb_ids = wandb_ids[0]

# Push to hub
if local_path is None:
Expand All @@ -79,7 +109,7 @@ def push_to_huggingface(
model_output_dir,
config=model_config,
data_config=data_config,
wandb_model_code=wandb_id,
wandb_ids=wandb_ids,
push_to_hub=push_to_hub,
repo_id="openclimatefix/pvnet_v2" if push_to_hub else None,
)
Expand Down
37 changes: 37 additions & 0 deletions tests/models/test_ensemble.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,37 @@
from pvnet.models.ensemble import Ensemble


def test_model_init(multimodal_model):
ensemble_model = Ensemble(
model_list=[multimodal_model] * 3,
weights=None,
)

ensemble_model = Ensemble(
model_list=[multimodal_model] * 3,
weights=[1, 2, 3],
)


def test_model_forward(multimodal_model, sample_batch):
ensemble_model = Ensemble(
model_list=[multimodal_model] * 3,
)

y = ensemble_model(sample_batch)

# check output is the correct shape
# batch size=2, forecast_len=15
assert tuple(y.shape) == (2, 16), y.shape


def test_quantile_model_forward(multimodal_quantile_model, sample_batch):
ensemble_model = Ensemble(
model_list=[multimodal_quantile_model] * 3,
)

y_quantiles = ensemble_model(sample_batch)

# check output is the correct shape
# batch size=2, forecast_len=15, num_quantiles=3
assert tuple(y_quantiles.shape) == (2, 16, 3), y_quantiles.shape

0 comments on commit 17fef47

Please sign in to comment.