From c570c9351a2ea343190a4805884055563b1c2434 Mon Sep 17 00:00:00 2001 From: Jacob Bieker Date: Thu, 15 Feb 2024 12:33:13 +0000 Subject: [PATCH] Add better unsqueezing --- pvnet/models/multimodal/site_encoders/encoders.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/pvnet/models/multimodal/site_encoders/encoders.py b/pvnet/models/multimodal/site_encoders/encoders.py index 2d8b3713..9ac0b839 100644 --- a/pvnet/models/multimodal/site_encoders/encoders.py +++ b/pvnet/models/multimodal/site_encoders/encoders.py @@ -200,7 +200,7 @@ def _encode_query(self, x): else: ids = x[BatchKey[f"{self.input_key_to_use}_id"]][:, 0] ids = ids.squeeze().int() - if len(ids.shape) == 2: # Batch was squeezed down to nothing + if len(ids.shape) == 0: # Batch was squeezed down to nothing ids = ids.unsqueeze(0) query = self.target_id_embedding(ids).unsqueeze(1) return query @@ -244,7 +244,6 @@ def _attention_forward(self, x, average_attn_weights=True): query = self._encode_query(x) key = self._encode_key(x) value = self._encode_value(x) - attn_output, attn_weights = self.multihead_attn( query, key, value, average_attn_weights=average_attn_weights ) @@ -258,5 +257,7 @@ def forward(self, x): # Reshape from [batch_size, 1, vdim] to [batch_size, vdim] x_out = attn_output.squeeze() + if len(x_out.shape) == 1: + x_out = x_out.unsqueeze(0) return x_out