-
-
Notifications
You must be signed in to change notification settings - Fork 4
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge pull request #148 from openclimatefix/ensemble
Model ensemble
- Loading branch information
Showing
5 changed files
with
188 additions
and
37 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,74 @@ | ||
"""Model which uses mutliple prediction heads""" | ||
from typing import Optional | ||
|
||
import torch | ||
from torch import nn | ||
|
||
from pvnet.models.base_model import BaseModel | ||
|
||
|
||
class Ensemble(BaseModel): | ||
"""Ensemble of PVNet models""" | ||
|
||
def __init__( | ||
self, | ||
model_list: list[BaseModel], | ||
weights: Optional[list[float]] = None, | ||
): | ||
"""Ensemble of PVNet models | ||
Args: | ||
model_list: A list of PVNet models to ensemble | ||
weights: A list of weighting to apply to each model. If None, the models are weighted | ||
equally. | ||
""" | ||
|
||
# Surface check all the models are compatible | ||
output_quantiles = [] | ||
history_minutes = [] | ||
forecast_minutes = [] | ||
target_key = [] | ||
interval_minutes = [] | ||
|
||
# Get some model properties from each model | ||
for model in model_list: | ||
output_quantiles.append(model.output_quantiles) | ||
history_minutes.append(model.history_minutes) | ||
forecast_minutes.append(model.forecast_minutes) | ||
target_key.append(model._target_key_name) | ||
interval_minutes.append(model.interval_minutes) | ||
|
||
# Check these properties are all the same | ||
for param_list in [ | ||
output_quantiles, | ||
history_minutes, | ||
forecast_minutes, | ||
target_key, | ||
interval_minutes, | ||
]: | ||
assert all([p == param_list[0] for p in param_list]), param_list | ||
|
||
super().__init__( | ||
history_minutes=history_minutes[0], | ||
forecast_minutes=forecast_minutes[0], | ||
optimizer=None, | ||
output_quantiles=output_quantiles[0], | ||
target_key=target_key[0], | ||
interval_minutes=interval_minutes[0], | ||
) | ||
|
||
self.model_list = nn.ModuleList(model_list) | ||
|
||
if weights is None: | ||
weights = torch.ones(len(model_list)) / len(model_list) | ||
else: | ||
assert len(weights) == len(model_list) | ||
weights = torch.Tensor(weights) / sum(weights) | ||
self.weights = nn.Parameter(weights, requires_grad=False) | ||
|
||
def forward(self, batch): | ||
"""Run the model forward""" | ||
y_hat = 0 | ||
for weight, model in zip(self.weights, self.model_list): | ||
y_hat = model(batch) * weight + y_hat | ||
return y_hat |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,37 @@ | ||
from pvnet.models.ensemble import Ensemble | ||
|
||
|
||
def test_model_init(multimodal_model): | ||
ensemble_model = Ensemble( | ||
model_list=[multimodal_model] * 3, | ||
weights=None, | ||
) | ||
|
||
ensemble_model = Ensemble( | ||
model_list=[multimodal_model] * 3, | ||
weights=[1, 2, 3], | ||
) | ||
|
||
|
||
def test_model_forward(multimodal_model, sample_batch): | ||
ensemble_model = Ensemble( | ||
model_list=[multimodal_model] * 3, | ||
) | ||
|
||
y = ensemble_model(sample_batch) | ||
|
||
# check output is the correct shape | ||
# batch size=2, forecast_len=15 | ||
assert tuple(y.shape) == (2, 16), y.shape | ||
|
||
|
||
def test_quantile_model_forward(multimodal_quantile_model, sample_batch): | ||
ensemble_model = Ensemble( | ||
model_list=[multimodal_quantile_model] * 3, | ||
) | ||
|
||
y_quantiles = ensemble_model(sample_batch) | ||
|
||
# check output is the correct shape | ||
# batch size=2, forecast_len=15, num_quantiles=3 | ||
assert tuple(y_quantiles.shape) == (2, 16, 3), y_quantiles.shape |