Skip to content

Commit

Permalink
Merge pull request #117 from openclimatefix/fix_tests
Browse files Browse the repository at this point in the history
Fix test batches and make weighted loss general to device
  • Loading branch information
dfulu authored Jan 9, 2024
2 parents 3b74b1d + b2d0e30 commit 0fc19bc
Show file tree
Hide file tree
Showing 4 changed files with 6 additions and 8 deletions.
12 changes: 5 additions & 7 deletions pvnet/models/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -148,26 +148,24 @@ def __init__(self, decay_rate: Optional[int] = None, forecast_length: int = 6):
# normalized the weights, so there mean is 1.
# To calculate the loss, we times the weights by the differences between truth
# and predictions and then take the mean across all forecast horizons and the batch
self.weights = weights / weights.sum() * len(weights)

# move weights to gpu is needed
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
self.weights = self.weights.to(device)
self.weights = weights / weights.mean()

def get_mse_exp(self, output, target):
"""Loss function weighted MSE"""

weights = self.weights.to(target.device)
# get the differences weighted by the forecast horizon weights
diff_with_weights = self.weights * ((output - target) ** 2)
diff_with_weights = weights * ((output - target) ** 2)

# average across batches
return torch.mean(diff_with_weights)

def get_mae_exp(self, output, target):
"""Loss function weighted MAE"""

weights = self.weights.to(target.device)
# get the differences weighted by the forecast horizon weights
diff_with_weights = self.weights * torch.abs(output - target)
diff_with_weights = weights * torch.abs(output - target)

# average across batches
return torch.mean(diff_with_weights)
2 changes: 1 addition & 1 deletion requirements.txt
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
ocf_datapipes>=3.0.0
ocf_datapipes>=3.1.5
ocf_ml_metrics>=0.0.11
numpy
pandas
Expand Down
Binary file modified tests/test_data/sample_batches/train/000000.pt
Binary file not shown.
Binary file modified tests/test_data/sample_batches/train/000001.pt
Binary file not shown.

0 comments on commit 0fc19bc

Please sign in to comment.