diff --git a/pvnet/models/multimodal/site_encoders/encoders.py b/pvnet/models/multimodal/site_encoders/encoders.py
index 2d8b3713..9ac0b839 100644
--- a/pvnet/models/multimodal/site_encoders/encoders.py
+++ b/pvnet/models/multimodal/site_encoders/encoders.py
@@ -200,7 +200,7 @@ def _encode_query(self, x):
         else:
             ids = x[BatchKey[f"{self.input_key_to_use}_id"]][:, 0]
         ids = ids.squeeze().int()
-        if len(ids.shape) == 2:  # Batch was squeezed down to nothing
+        if len(ids.shape) == 0:  # Batch was squeezed down to nothing
             ids = ids.unsqueeze(0)
         query = self.target_id_embedding(ids).unsqueeze(1)
         return query
@@ -244,7 +244,6 @@ def _attention_forward(self, x, average_attn_weights=True):
         query = self._encode_query(x)
         key = self._encode_key(x)
         value = self._encode_value(x)
-
         attn_output, attn_weights = self.multihead_attn(
             query, key, value, average_attn_weights=average_attn_weights
         )
@@ -258,5 +257,7 @@ def forward(self, x):

         # Reshape from [batch_size, 1, vdim] to [batch_size, vdim]
         x_out = attn_output.squeeze()
+        if len(x_out.shape) == 1:
+            x_out = x_out.unsqueeze(0)
         return x_out