From 164448cbf9ef24a9d8127a5c4fe869d01d1c3093 Mon Sep 17 00:00:00 2001 From: Mike Walmsley Date: Sun, 11 Feb 2024 11:00:50 +0000 Subject: [PATCH] fix encoder dim --- zoobot/pytorch/estimators/define_model.py | 11 +++++++++-- 1 file changed, 9 insertions(+), 2 deletions(-) diff --git a/zoobot/pytorch/estimators/define_model.py b/zoobot/pytorch/estimators/define_model.py index 04a7d2e4..d35dc151 100755 --- a/zoobot/pytorch/estimators/define_model.py +++ b/zoobot/pytorch/estimators/define_model.py @@ -381,14 +381,15 @@ def dirichlet_loss(preds, labels, question_index_groups, sum_over_questions=Fals # input_size doesn't matter as long as it's large enough to not be pooled to zero # channels doesn't matter at all but has to match encoder channels or shape error def get_encoder_dim(encoder, channels=3): + device = next(encoder.parameters()).device try: - x = torch.randn(2, channels, 224, 224, device=encoder.device) # BCHW + x = torch.randn(2, channels, 224, 224, device=device) # BCHW return encoder(x).shape[-1] except RuntimeError as e: if 'channels instead' in str(e): logging.info('encoder dim search failed on channels, trying with channels=1') channels = 1 - x = torch.randn(2, channels, 224, 224, device=encoder.device) # BCHW + x = torch.randn(2, channels, 224, 224, device=device) # BCHW return encoder(x).shape[-1] else: raise e @@ -474,3 +475,9 @@ def get_pytorch_dirichlet_head(encoder_dim: int, output_dim: int, test_time_drop def schema_to_campaigns(schema): # e.g. [gz2, dr12, ...] return [question.text.split('-')[-1] for question in schema.questions] + + +if __name__ == '__main__': + encoder = get_pytorch_encoder(channels=1) + dim = get_encoder_dim(encoder, channels=1) + print(dim) \ No newline at end of file