From cc167198c5999e0fc5f8df7d627736f6c9665928 Mon Sep 17 00:00:00 2001 From: James Fulton Date: Mon, 8 Jan 2024 16:50:44 +0000 Subject: [PATCH 1/4] make weighted loss general to device --- pvnet/models/utils.py | 13 ++++++------- 1 file changed, 6 insertions(+), 7 deletions(-) diff --git a/pvnet/models/utils.py b/pvnet/models/utils.py index 30a107ba..a28e0dd6 100644 --- a/pvnet/models/utils.py +++ b/pvnet/models/utils.py @@ -148,26 +148,25 @@ def __init__(self, decay_rate: Optional[int] = None, forecast_length: int = 6): # normalized the weights, so there mean is 1. # To calculate the loss, we times the weights by the differences between truth # and predictions and then take the mean across all forecast horizons and the batch - self.weights = weights / weights.sum() * len(weights) + self.weights = weights / weights.mean() - # move weights to gpu is needed - device = torch.device("cuda" if torch.cuda.is_available() else "cpu") - self.weights = self.weights.to(device) def get_mse_exp(self, output, target): """Loss function weighted MSE""" + weights = self.weights.to(target.device) # get the differences weighted by the forecast horizon weights - diff_with_weights = self.weights * ((output - target) ** 2) + diff_with_weights = weights * ((output - target) ** 2) # average across batches return torch.mean(diff_with_weights) def get_mae_exp(self, output, target): """Loss function weighted MAE""" - + + weights = self.weights.to(target.device) # get the differences weighted by the forecast horizon weights - diff_with_weights = self.weights * torch.abs(output - target) + diff_with_weights = weights * torch.abs(output - target) # average across batches return torch.mean(diff_with_weights) From 245751fa6decee77550ce606b558d08a8af77dda Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Mon, 8 Jan 2024 16:55:57 +0000 Subject: [PATCH 2/4] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- pvnet/models/utils.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/pvnet/models/utils.py b/pvnet/models/utils.py index a28e0dd6..d8f645db 100644 --- a/pvnet/models/utils.py +++ b/pvnet/models/utils.py @@ -150,7 +150,6 @@ def __init__(self, decay_rate: Optional[int] = None, forecast_length: int = 6): # and predictions and then take the mean across all forecast horizons and the batch self.weights = weights / weights.mean() - def get_mse_exp(self, output, target): """Loss function weighted MSE""" @@ -163,7 +162,7 @@ def get_mse_exp(self, output, target): def get_mae_exp(self, output, target): """Loss function weighted MAE""" - + weights = self.weights.to(target.device) # get the differences weighted by the forecast horizon weights diff_with_weights = weights * torch.abs(output - target) From 7c28bf434ba183016427699e0cf8bed5ae762e88 Mon Sep 17 00:00:00 2001 From: James Fulton Date: Mon, 8 Jan 2024 18:04:21 +0000 Subject: [PATCH 3/4] fix batches for datapipes update --- .../test_data/sample_batches/train/000000.pt | Bin 4221950 -> 4221950 bytes .../test_data/sample_batches/train/000001.pt | Bin 4193726 -> 4193726 bytes 2 files changed, 0 insertions(+), 0 deletions(-) diff --git a/tests/test_data/sample_batches/train/000000.pt b/tests/test_data/sample_batches/train/000000.pt index a291529fa339e6df350cdf64acc60fa8bc565b80..cecfa1877cb769c3a9f5b07bdec967604b216701 100644 GIT binary patch delta 427 zcmW;IJ4?e*7{+lNzDQQd zIVm`&w*QdWoBXFt_eh>^RcMxM)UPV^SRfJp_ek?nwB&phoNrsiFvhp5G)GeWrj;~* zuhLzTz+6uO`rdO(6+=!1S306z@E5DY^A zMj!~IFb3lgf(e*}FigQT%)l(n!8|O$B19kxF^EF~l8}Nl$e@4<8e~9+B{0B*Wmth# zSc7%Q)*EiF_8O86MK=vCBdexvDyFLFy1Zk1eIdjD%b$u?u>FMkThY4ri0xzX$a?q( DZ#u!D delta 427 zcmW;IOG^S#7{+mqjp?cEVmHn4+FiRi&b6Bkv}*{;c1BAfXf?DbS`{1#iI#C_J(A9@ zMP$%dsAcpS3fJwF9$ozS!^^XJJKr~+4qP+BJgkoi%**50%juD;#2TnYm)2tyEnVHkl?7=s{;!vsvi6ih=1 z!Y~675FrXNh(iLBkb+s5gLznhMOcDmNJ9ozU=`M29kL)n4rEY31r2mCU;{Q`3$|eg z^39f8XuOM(rptz8D2A-6nxZMXqN)d$&tFKD!u+{xhI&tEew57zuh2V|&&3~-9E79dg4?LLi!8%31G61Q856XcMzJ%oC$V8dp{2zkCd5*6VobaW zBZZL*urPWTDi=UujDAo&d6Pe1^>=>$NCz%C;Q>CxB_8C%JjBC~^j#-q@uAbF$X?Sw zulg5Wv_`xGabwV9QW2j9EtCCr!laD=Ns4=$RwAjU|5f$BJqg3K=-SjJIiWBlFU}Y} zB3bdm=m}ATTtYIU(T^qG`msgIp{G)&mg`-$910^a3S$s~ahQNfn1U!w!wk&A9L$3Z z3$O@F5QAk{fmK+8b=ZJS*n&80LjsbJf;1?Qfh^=84=QL-fFkH%fC(1Zz@P*U?7%LR y+bzd^eU7PS(NGP>REz1VS7oF2bmI;DvIdh_D{Ju*WQ`EQ#edik*qQiH#cyEiHzaprvM_ehnjq zk;K-*=m)U(sFlJPeW3W|pS*dhxBc@++HlMX_whb1aX;_p0Uo@e&l*w1i$-rlrG|G{ z^N!r`9&wpC*XcGf#G6hl#BAjZTKAEtxMp-C6l-|zHSfcfFpP^&MjetCvV~;CzD2i4 zO59s?ha^O#j3h<9n}T@iMizOS?n>3{*TYHErZ515Fa#kOh7lNrF$lvrOu!^e!8Al* z24*1&F^EF~l8}NlWMB^FAqxwTgGIoy&idft2 Date: Mon, 8 Jan 2024 18:07:07 +0000 Subject: [PATCH 4/4] update reqs --- requirements.txt | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/requirements.txt b/requirements.txt index 7775997a..68d72982 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,4 +1,4 @@ -ocf_datapipes>=3.0.0 +ocf_datapipes>=3.1.5 ocf_ml_metrics>=0.0.11 numpy pandas