diff --git a/pvnet/models/utils.py b/pvnet/models/utils.py index 30a107ba..d8f645db 100644 --- a/pvnet/models/utils.py +++ b/pvnet/models/utils.py @@ -148,17 +148,14 @@ def __init__(self, decay_rate: Optional[int] = None, forecast_length: int = 6): # normalized the weights, so there mean is 1. # To calculate the loss, we times the weights by the differences between truth # and predictions and then take the mean across all forecast horizons and the batch - self.weights = weights / weights.sum() * len(weights) - - # move weights to gpu is needed - device = torch.device("cuda" if torch.cuda.is_available() else "cpu") - self.weights = self.weights.to(device) + self.weights = weights / weights.mean() def get_mse_exp(self, output, target): """Loss function weighted MSE""" + weights = self.weights.to(target.device) # get the differences weighted by the forecast horizon weights - diff_with_weights = self.weights * ((output - target) ** 2) + diff_with_weights = weights * ((output - target) ** 2) # average across batches return torch.mean(diff_with_weights) @@ -166,8 +163,9 @@ def get_mse_exp(self, output, target): def get_mae_exp(self, output, target): """Loss function weighted MAE""" + weights = self.weights.to(target.device) # get the differences weighted by the forecast horizon weights - diff_with_weights = self.weights * torch.abs(output - target) + diff_with_weights = weights * torch.abs(output - target) # average across batches return torch.mean(diff_with_weights) diff --git a/requirements.txt b/requirements.txt index 7775997a..68d72982 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,4 +1,4 @@ -ocf_datapipes>=3.0.0 +ocf_datapipes>=3.1.5 ocf_ml_metrics>=0.0.11 numpy pandas diff --git a/tests/test_data/sample_batches/train/000000.pt b/tests/test_data/sample_batches/train/000000.pt index a291529f..cecfa187 100644 Binary files a/tests/test_data/sample_batches/train/000000.pt and b/tests/test_data/sample_batches/train/000000.pt differ diff --git a/tests/test_data/sample_batches/train/000001.pt b/tests/test_data/sample_batches/train/000001.pt index 5031a53b..83fb056c 100644 Binary files a/tests/test_data/sample_batches/train/000001.pt and b/tests/test_data/sample_batches/train/000001.pt differ