Skip to content

Commit

Permalink
Merge remote-tracking branch 'origin/jacob/windnet' into jacob/windnet
Browse files Browse the repository at this point in the history
  • Loading branch information
jacobbieker committed Nov 27, 2023
2 parents 03d5b78 + b860dc7 commit 6dccd2f
Showing 1 changed file with 9 additions and 3 deletions.
12 changes: 9 additions & 3 deletions pvnet/models/multimodal/site_encoders/encoders.py
Original file line number Diff line number Diff line change
Expand Up @@ -339,7 +339,9 @@ def _encode_key(self, x):
sensor_id_embed = torch.tile(self.pv_id_embedding(self._sensor_ids), (batch_size, 1, 1))

# Each concated (Sensor sequence, Sensor ID embedding) is processed with encoder
x_seq_in = torch.cat((sensor_site_seqs.swapaxes(1, 2), sensor_id_embed), dim=2).flatten(0, 1)
x_seq_in = torch.cat((sensor_site_seqs.swapaxes(1, 2), sensor_id_embed), dim=2).flatten(
0, 1
)
key = self._key_encoder(x_seq_in)

# Reshape to [batch size, PV site, kdim]
Expand All @@ -353,9 +355,13 @@ def _encode_value(self, x):

if self.use_sensor_id_in_value:
# Sensor ID embeddings are the same for each sample
sensor_id_embed = torch.tile(self.value_sensor_id_embedding(self._sensor_ids), (batch_size, 1, 1))
sensor_id_embed = torch.tile(
self.value_sensor_id_embedding(self._sensor_ids), (batch_size, 1, 1)
)
# Each concated (Sensor sequence, Sensor ID embedding) is processed with encoder
x_seq_in = torch.cat((sensor_site_seqs.swapaxes(1, 2), sensor_id_embed), dim=2).flatten(0, 1)
x_seq_in = torch.cat((sensor_site_seqs.swapaxes(1, 2), sensor_id_embed), dim=2).flatten(
0, 1
)
else:
# Encode each PV sequence independently
x_seq_in = sensor_site_seqs.swapaxes(1, 2).flatten(0, 1)
Expand Down

0 comments on commit 6dccd2f

Please sign in to comment.