From fe2e2e5c339a1896d6768739729267087eb58a2d Mon Sep 17 00:00:00 2001 From: Jacob Bieker Date: Tue, 5 Dec 2023 13:21:47 +0000 Subject: [PATCH] Debugging --- pvnet/models/base_model.py | 6 +++--- pvnet/models/multimodal/site_encoders/encoders.py | 5 ++--- 2 files changed, 5 insertions(+), 6 deletions(-) diff --git a/pvnet/models/base_model.py b/pvnet/models/base_model.py index 412f04c4..aac061f2 100644 --- a/pvnet/models/base_model.py +++ b/pvnet/models/base_model.py @@ -450,9 +450,9 @@ def validation_step(self, batch: dict, batch_idx): """Run validation step""" y_hat = self(batch) print(f"y_hat.shape: {y_hat.shape}") - print(f"{batch[self._target_key].shape}") - print(f"{batch[self._target_key][:, -self.forecast_len_30 :, 0].shape}") - print(f"{self.forecast_len_30}") + print(f"{batch[self._target_key].shape=}") + print(f"{batch[self._target_key][:, -self.forecast_len_30 :, 0].shape=}") + print(f"{self.forecast_len_30=}") y = batch[self._target_key][:, -self.forecast_len_30 :, 0] losses = self._calculate_common_losses(y, y_hat) diff --git a/pvnet/models/multimodal/site_encoders/encoders.py b/pvnet/models/multimodal/site_encoders/encoders.py index ae607e78..9f9f9a95 100644 --- a/pvnet/models/multimodal/site_encoders/encoders.py +++ b/pvnet/models/multimodal/site_encoders/encoders.py @@ -335,7 +335,7 @@ def _encode_query(self, x): def _encode_key(self, x): # Shape: [batch size, sequence length, PV site] - sensor_site_seqs = x[BatchKey.sensor].float() + sensor_site_seqs = x[BatchKey.sensor][:, : self.sequence_length].float() batch_size = sensor_site_seqs.shape[0] print(f"{sensor_site_seqs.shape=}") @@ -355,7 +355,7 @@ def _encode_key(self, x): def _encode_value(self, x): # Shape: [batch size, sequence length, PV site] - sensor_site_seqs = x[BatchKey.sensor].float() + sensor_site_seqs = x[BatchKey.sensor][:, : self.sequence_length].float() batch_size = sensor_site_seqs.shape[0] if self.use_sensor_id_in_value: @@ -394,7 +394,6 @@ def _attention_forward(self, x, average_attn_weights=True): def forward(self, x): """Run model forward""" # Do slicing here to only get history - x[BatchKey.sensor] = x[BatchKey.sensor][:, : self.sequence_length].float() attn_output, attn_output_weights = self._attention_forward(x) # Reshape from [batch_size, 1, vdim] to [batch_size, vdim]