Skip to content

Commit

Permalink
Debugging
Browse files Browse the repository at this point in the history
  • Loading branch information
jacobbieker committed Dec 5, 2023
1 parent d7111db commit ffe1809
Show file tree
Hide file tree
Showing 2 changed files with 6 additions and 3 deletions.
5 changes: 3 additions & 2 deletions pvnet/models/multimodal/multimodal.py
Original file line number Diff line number Diff line change
Expand Up @@ -171,7 +171,7 @@ def __init__(
sensor_history_minutes = history_minutes

self.sensor_encoder = sensor_encoder(
sequence_length=sensor_history_minutes // 30 + 1,
sequence_length=self.history_len_30#sensor_history_minutes // 30 + 1,
# Sensors are currently resampled to 30min
)

Expand Down Expand Up @@ -249,7 +249,8 @@ def forward(self, x):
# *********************** Sensor Data ************************************
# add sensor yield history
if self.include_sensor:
modes["sensor"] = self.sensor_encoder(x)
sensor_history = x[BatchKey.sensor][:, : self.history_len_30].float()
modes["sensor"] = self.sensor_encoder(sensor_history)

if self.include_sun:
sun = torch.cat(
Expand Down
4 changes: 3 additions & 1 deletion pvnet/models/multimodal/site_encoders/encoders.py
Original file line number Diff line number Diff line change
Expand Up @@ -334,14 +334,16 @@ def _encode_key(self, x):
# Shape: [batch size, sequence length, PV site]
sensor_site_seqs = x[BatchKey.sensor].float()
batch_size = sensor_site_seqs.shape[0]
print(f"{sensor_site_seqs.shape=}")

# Sensor ID embeddings are the same for each sample
sensor_id_embed = torch.tile(self.pv_id_embedding(self._sensor_ids), (batch_size, 1, 1))

print(f"{sensor_id_embed.shape=}")
# Each concated (Sensor sequence, Sensor ID embedding) is processed with encoder
x_seq_in = torch.cat((sensor_site_seqs.swapaxes(1, 2), sensor_id_embed), dim=2).flatten(
0, 1
)
print(f"{x_seq_in.shape=}")
key = self._key_encoder(x_seq_in)

# Reshape to [batch size, PV site, kdim]
Expand Down

0 comments on commit ffe1809

Please sign in to comment.