From 9677877d086378ddedac29003bbcb0b0914937f3 Mon Sep 17 00:00:00 2001 From: Alexandra Udaltsova <43303448+AUdaltsova@users.noreply.github.com> Date: Tue, 6 Aug 2024 16:59:55 +0100 Subject: [PATCH] fix issue with time features missing in forward multimodal.py --- pvnet/models/multimodal/multimodal.py | 13 +++++++++++++ 1 file changed, 13 insertions(+) diff --git a/pvnet/models/multimodal/multimodal.py b/pvnet/models/multimodal/multimodal.py index 0eef7585..1541f69c 100644 --- a/pvnet/models/multimodal/multimodal.py +++ b/pvnet/models/multimodal/multimodal.py @@ -401,6 +401,19 @@ def forward(self, x): sun = self.sun_fc1(sun) modes["sun"] = sun + if self.include_time: + time = torch.cat( + ( + x[BatchKey[f"{self._target_key_name}_date_sin"]], + x[BatchKey[f"{self._target_key_name}_date_cos"]], + x[BatchKey[f"{self._target_key_name}_time_sin"]], + x[BatchKey[f"{self._target_key_name}_time_cos"]], + ), + dim=1, + ).float() + time = self.time_fc1(time) + modes["time"] = time + out = self.output_network(modes) if self.use_quantile_regression: