From ffe1809dacd89b940aab14551e647ba03dd7e0fb Mon Sep 17 00:00:00 2001 From: Jacob Bieker Date: Tue, 5 Dec 2023 11:57:17 +0000 Subject: [PATCH] Debugging --- pvnet/models/multimodal/multimodal.py | 5 +++-- pvnet/models/multimodal/site_encoders/encoders.py | 4 +++- 2 files changed, 6 insertions(+), 3 deletions(-) diff --git a/pvnet/models/multimodal/multimodal.py b/pvnet/models/multimodal/multimodal.py index 58bc067f..cf44ac41 100644 --- a/pvnet/models/multimodal/multimodal.py +++ b/pvnet/models/multimodal/multimodal.py @@ -171,7 +171,7 @@ def __init__( sensor_history_minutes = history_minutes self.sensor_encoder = sensor_encoder( - sequence_length=sensor_history_minutes // 30 + 1, + sequence_length=self.history_len_30#sensor_history_minutes // 30 + 1, # Sensors are currently resampled to 30min ) @@ -249,7 +249,8 @@ def forward(self, x): # *********************** Sensor Data ************************************ # add sensor yield history if self.include_sensor: - modes["sensor"] = self.sensor_encoder(x) + sensor_history = x[BatchKey.sensor][:, : self.history_len_30].float() + modes["sensor"] = self.sensor_encoder(sensor_history) if self.include_sun: sun = torch.cat( diff --git a/pvnet/models/multimodal/site_encoders/encoders.py b/pvnet/models/multimodal/site_encoders/encoders.py index 126086d9..4ab8bf99 100644 --- a/pvnet/models/multimodal/site_encoders/encoders.py +++ b/pvnet/models/multimodal/site_encoders/encoders.py @@ -334,14 +334,16 @@ def _encode_key(self, x): # Shape: [batch size, sequence length, PV site] sensor_site_seqs = x[BatchKey.sensor].float() batch_size = sensor_site_seqs.shape[0] + print(f"{sensor_site_seqs.shape=}") # Sensor ID embeddings are the same for each sample sensor_id_embed = torch.tile(self.pv_id_embedding(self._sensor_ids), (batch_size, 1, 1)) - + print(f"{sensor_id_embed.shape=}") # Each concated (Sensor sequence, Sensor ID embedding) is processed with encoder x_seq_in = torch.cat((sensor_site_seqs.swapaxes(1, 2), sensor_id_embed), dim=2).flatten( 0, 1 ) + print(f"{x_seq_in.shape=}") key = self._key_encoder(x_seq_in) # Reshape to [batch size, PV site, kdim]