diff --git a/pvnet/models/multimodal/site_encoders/encoders.py b/pvnet/models/multimodal/site_encoders/encoders.py index 97a96c5e..ae607e78 100644 --- a/pvnet/models/multimodal/site_encoders/encoders.py +++ b/pvnet/models/multimodal/site_encoders/encoders.py @@ -327,7 +327,7 @@ def __init__( def _encode_query(self, x): # Select the first one - gsp_ids = x[BatchKey.sensor_id][:,0].squeeze().int() + gsp_ids = x[BatchKey.sensor_id][:, 0].squeeze().int() print(f"{gsp_ids.shape=}") query = self.sensor_id_embedding(gsp_ids).unsqueeze(1) print(f"{query.shape=}")