Debugging
jacobbieker committed Dec 5, 2023
1 parent 2d11c8b commit fe2e2e5
Showing 2 changed files with 5 additions and 6 deletions.
6 changes: 3 additions & 3 deletions pvnet/models/base_model.py
@@ -450,9 +450,9 @@ def validation_step(self, batch: dict, batch_idx):
         """Run validation step"""
         y_hat = self(batch)
         print(f"y_hat.shape: {y_hat.shape}")
-        print(f"{batch[self._target_key].shape}")
-        print(f"{batch[self._target_key][:, -self.forecast_len_30 :, 0].shape}")
-        print(f"{self.forecast_len_30}")
+        print(f"{batch[self._target_key].shape=}")
+        print(f"{batch[self._target_key][:, -self.forecast_len_30 :, 0].shape=}")
+        print(f"{self.forecast_len_30=}")
         y = batch[self._target_key][:, -self.forecast_len_30 :, 0]
 
         losses = self._calculate_common_losses(y, y_hat)
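The change above replaces plain f-string interpolation with the f-string "=" specifier (available since Python 3.8), which prints the expression text alongside its value, so each debug line labels itself. A minimal sketch of the difference, using a stand-in tensor rather than the real batch (the names below are hypothetical, not taken from PVNet):

    import torch

    target = torch.zeros(4, 16, 1)  # [batch, time steps, features]
    forecast_len_30 = 8             # hypothetical stand-in for self.forecast_len_30

    # Old style prints only the value, e.g. "torch.Size([4, 16, 1])"
    print(f"{target.shape}")

    # New style prints "target.shape=torch.Size([4, 16, 1])"
    print(f"{target.shape=}")

    # The slice keeps the last forecast_len_30 steps of feature 0:
    print(f"{target[:, -forecast_len_30:, 0].shape=}")  # torch.Size([4, 8])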
5 changes: 2 additions & 3 deletions pvnet/models/multimodal/site_encoders/encoders.py
@@ -335,7 +335,7 @@ def _encode_query(self, x):
 
     def _encode_key(self, x):
         # Shape: [batch size, sequence length, PV site]
-        sensor_site_seqs = x[BatchKey.sensor].float()
+        sensor_site_seqs = x[BatchKey.sensor][:, : self.sequence_length].float()
         batch_size = sensor_site_seqs.shape[0]
         print(f"{sensor_site_seqs.shape=}")

@@ -355,7 +355,7 @@ def _encode_key(self, x):
 
     def _encode_value(self, x):
         # Shape: [batch size, sequence length, PV site]
-        sensor_site_seqs = x[BatchKey.sensor].float()
+        sensor_site_seqs = x[BatchKey.sensor][:, : self.sequence_length].float()
         batch_size = sensor_site_seqs.shape[0]
 
         if self.use_sensor_id_in_value:
@@ -394,7 +394,6 @@ def _attention_forward(self, x, average_attn_weights=True):
     def forward(self, x):
         """Run model forward"""
         # Do slicing here to only get history
-        x[BatchKey.sensor] = x[BatchKey.sensor][:, : self.sequence_length].float()
         attn_output, attn_output_weights = self._attention_forward(x)
 
         # Reshape from [batch_size, 1, vdim] to [batch_size, vdim]
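This file's change moves the history slice out of forward() and into _encode_key and _encode_value. The removed line reassigned x[BatchKey.sensor] in place, mutating the caller's batch dict as a side effect; slicing locally inside each encode method leaves the batch untouched. A minimal sketch of the pattern (TinyEncoder and the plain "sensor" key are hypothetical stand-ins for the PVNet encoder and BatchKey.sensor):

    import torch

    class TinyEncoder:
        """Hypothetical stand-in for the site encoder."""

        def __init__(self, sequence_length: int):
            self.sequence_length = sequence_length

        def _encode_key(self, x: dict) -> torch.Tensor:
            # Slice to the history window locally; x is not mutated.
            return x["sensor"][:, : self.sequence_length].float()

        def _encode_value(self, x: dict) -> torch.Tensor:
            return x["sensor"][:, : self.sequence_length].float()

    batch = {"sensor": torch.randn(2, 10, 5)}   # [batch, time, sites]
    enc = TinyEncoder(sequence_length=6)
    key, value = enc._encode_key(batch), enc._encode_value(batch)
    print(key.shape, value.shape)               # torch.Size([2, 6, 5]) twice
    assert batch["sensor"].shape == (2, 10, 5)  # caller's batch left intact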
