diff --git a/pvnet/models/multimodal/site_encoders/encoders.py b/pvnet/models/multimodal/site_encoders/encoders.py
index 22139941..126086d9 100644
--- a/pvnet/models/multimodal/site_encoders/encoders.py
+++ b/pvnet/models/multimodal/site_encoders/encoders.py
@@ -339,7 +339,9 @@ def _encode_key(self, x):
         sensor_id_embed = torch.tile(self.pv_id_embedding(self._sensor_ids), (batch_size, 1, 1))
 
         # Each concated (Sensor sequence, Sensor ID embedding) is processed with encoder
-        x_seq_in = torch.cat((sensor_site_seqs.swapaxes(1, 2), sensor_id_embed), dim=2).flatten(0, 1)
+        x_seq_in = torch.cat((sensor_site_seqs.swapaxes(1, 2), sensor_id_embed), dim=2).flatten(
+            0, 1
+        )
         key = self._key_encoder(x_seq_in)
 
         # Reshape to [batch size, PV site, kdim]
@@ -353,9 +355,13 @@ def _encode_value(self, x):
 
         if self.use_sensor_id_in_value:
             # Sensor ID embeddings are the same for each sample
-            sensor_id_embed = torch.tile(self.value_sensor_id_embedding(self._sensor_ids), (batch_size, 1, 1))
+            sensor_id_embed = torch.tile(
+                self.value_sensor_id_embedding(self._sensor_ids), (batch_size, 1, 1)
+            )
             # Each concated (Sensor sequence, Sensor ID embedding) is processed with encoder
-            x_seq_in = torch.cat((sensor_site_seqs.swapaxes(1, 2), sensor_id_embed), dim=2).flatten(0, 1)
+            x_seq_in = torch.cat((sensor_site_seqs.swapaxes(1, 2), sensor_id_embed), dim=2).flatten(
+                0, 1
+            )
         else:
             # Encode each PV sequence independently
             x_seq_in = sensor_site_seqs.swapaxes(1, 2).flatten(0, 1)
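
Note on the reformatted lines: the wrapping above only changes formatting, not behavior. For review context, here is a minimal shape sketch of what the wrapped expression does; the sizes (batch_size, seq_len, n_sites, embed_dim) are hypothetical stand-ins, not values from the PR.

import torch

# Hypothetical sizes for illustration only.
# sensor_site_seqs: [batch, seq_len, n_sites]; ID embedding: [n_sites, embed_dim]
batch_size, seq_len, n_sites, embed_dim = 4, 12, 10, 8

sensor_site_seqs = torch.randn(batch_size, seq_len, n_sites)
id_embedding = torch.nn.Embedding(n_sites, embed_dim)
sensor_ids = torch.arange(n_sites)

# The per-site ID embeddings are identical for every sample, so tile them
# across the batch: [batch, n_sites, embed_dim]
sensor_id_embed = torch.tile(id_embedding(sensor_ids), (batch_size, 1, 1))

# swapaxes(1, 2) puts sites before time: [batch, n_sites, seq_len].
# cat(dim=2) appends each site's ID embedding to its sequence, and
# flatten(0, 1) merges the batch and site dims so the encoder sees one
# sequence per (sample, site) pair.
x_seq_in = torch.cat(
    (sensor_site_seqs.swapaxes(1, 2), sensor_id_embed), dim=2
).flatten(0, 1)

assert x_seq_in.shape == (batch_size * n_sites, seq_len + embed_dim)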