From b860dc7896aac136d40f755ac628bbd2e8fb0a8b Mon Sep 17 00:00:00 2001
From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Date: Mon, 27 Nov 2023 13:59:22 +0000
Subject: [PATCH] [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci
---
 pvnet/models/multimodal/site_encoders/encoders.py | 12 +++++++++---
 1 file changed, 9 insertions(+), 3 deletions(-)

diff --git a/pvnet/models/multimodal/site_encoders/encoders.py b/pvnet/models/multimodal/site_encoders/encoders.py
index 22139941..126086d9 100644
--- a/pvnet/models/multimodal/site_encoders/encoders.py
+++ b/pvnet/models/multimodal/site_encoders/encoders.py
@@ -339,7 +339,9 @@ def _encode_key(self, x):
         sensor_id_embed = torch.tile(self.pv_id_embedding(self._sensor_ids), (batch_size, 1, 1))

         # Each concated (Sensor sequence, Sensor ID embedding) is processed with encoder
-        x_seq_in = torch.cat((sensor_site_seqs.swapaxes(1, 2), sensor_id_embed), dim=2).flatten(0, 1)
+        x_seq_in = torch.cat((sensor_site_seqs.swapaxes(1, 2), sensor_id_embed), dim=2).flatten(
+            0, 1
+        )
         key = self._key_encoder(x_seq_in)

         # Reshape to [batch size, PV site, kdim]
@@ -353,9 +355,13 @@ def _encode_value(self, x):

         if self.use_sensor_id_in_value:
             # Sensor ID embeddings are the same for each sample
-            sensor_id_embed = torch.tile(self.value_sensor_id_embedding(self._sensor_ids), (batch_size, 1, 1))
+            sensor_id_embed = torch.tile(
+                self.value_sensor_id_embedding(self._sensor_ids), (batch_size, 1, 1)
+            )
             # Each concated (Sensor sequence, Sensor ID embedding) is processed with encoder
-            x_seq_in = torch.cat((sensor_site_seqs.swapaxes(1, 2), sensor_id_embed), dim=2).flatten(0, 1)
+            x_seq_in = torch.cat((sensor_site_seqs.swapaxes(1, 2), sensor_id_embed), dim=2).flatten(
+                0, 1
+            )
         else:
             # Encode each PV sequence independently
             x_seq_in = sensor_site_seqs.swapaxes(1, 2).flatten(0, 1)
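
Note: the hunks above are pure line-length rewrapping by the pre-commit hooks
(black-style); behavior is unchanged. For context, below is a minimal runnable
sketch of the concat-and-flatten pattern being rewrapped. The tensor shapes
are assumptions inferred from the diff's comments ("Reshape to [batch size,
PV site, kdim]"), not taken from the real encoders.py; the names used here
(id_embedding, sensor_ids, the dimension sizes) are hypothetical stand-ins.

    import torch

    # Assumed shapes: sensor_site_seqs is [batch, seq_len, n_sites] and the
    # per-site ID embedding table yields [n_sites, embed_dim].
    batch_size, seq_len, n_sites, embed_dim = 4, 12, 10, 8

    sensor_site_seqs = torch.randn(batch_size, seq_len, n_sites)
    id_embedding = torch.nn.Embedding(n_sites, embed_dim)
    sensor_ids = torch.arange(n_sites)

    # Tile the per-site embeddings so every sample in the batch shares them,
    # matching the torch.tile(..., (batch_size, 1, 1)) calls in the diff
    sensor_id_embed = torch.tile(id_embedding(sensor_ids), (batch_size, 1, 1))
    assert sensor_id_embed.shape == (batch_size, n_sites, embed_dim)

    # swapaxes(1, 2): [batch, seq_len, n_sites] -> [batch, n_sites, seq_len],
    # then concatenate each site's sequence with its ID embedding along dim=2
    x_seq_in = torch.cat(
        (sensor_site_seqs.swapaxes(1, 2), sensor_id_embed), dim=2
    ).flatten(0, 1)

    # flatten(0, 1) merges the batch and site dims so the key/value encoder
    # processes one (sequence + ID embedding) vector per (sample, site) pair
    assert x_seq_in.shape == (batch_size * n_sites, seq_len + embed_dim)

After the encoder runs over x_seq_in, the merged leading dimension can be
split back apart, e.g. with key.unflatten(0, (batch_size, n_sites)), which is
what the "Reshape to [batch size, PV site, kdim]" comment refers to.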