Skip to content

Commit

Permalink
Merge pull request #138 from openclimatefix/jacob/unsqeeze2
Browse files Browse the repository at this point in the history
Add better unsqeezing
  • Loading branch information
jacobbieker authored Feb 15, 2024
2 parents 78008fc + c570c93 commit d07cfbf
Showing 1 changed file with 3 additions and 2 deletions.
5 changes: 3 additions & 2 deletions pvnet/models/multimodal/site_encoders/encoders.py
Original file line number Diff line number Diff line change
Expand Up @@ -200,7 +200,7 @@ def _encode_query(self, x):
else:
ids = x[BatchKey[f"{self.input_key_to_use}_id"]][:, 0]
ids = ids.squeeze().int()
if len(ids.shape) == 2: # Batch was squeezed down to nothing
if len(ids.shape) == 0: # Batch was squeezed down to nothing
ids = ids.unsqueeze(0)
query = self.target_id_embedding(ids).unsqueeze(1)
return query
Expand Down Expand Up @@ -244,7 +244,6 @@ def _attention_forward(self, x, average_attn_weights=True):
query = self._encode_query(x)
key = self._encode_key(x)
value = self._encode_value(x)

attn_output, attn_weights = self.multihead_attn(
query, key, value, average_attn_weights=average_attn_weights
)
Expand All @@ -258,5 +257,7 @@ def forward(self, x):

# Reshape from [batch_size, 1, vdim] to [batch_size, vdim]
x_out = attn_output.squeeze()
if len(x_out.shape) == 1:
x_out = x_out.unsqueeze(0)

return x_out

0 comments on commit d07cfbf

Please sign in to comment.