diff --git a/zoobot/pytorch/estimators/define_model.py b/zoobot/pytorch/estimators/define_model.py index 04a7d2e4..d35dc151 100755 --- a/zoobot/pytorch/estimators/define_model.py +++ b/zoobot/pytorch/estimators/define_model.py @@ -381,14 +381,15 @@ def dirichlet_loss(preds, labels, question_index_groups, sum_over_questions=Fals # input_size doesn't matter as long as it's large enough to not be pooled to zero # channels doesn't matter at all but has to match encoder channels or shape error def get_encoder_dim(encoder, channels=3): + device = next(encoder.parameters()).device try: - x = torch.randn(2, channels, 224, 224, device=encoder.device) # BCHW + x = torch.randn(2, channels, 224, 224, device=device) # BCHW return encoder(x).shape[-1] except RuntimeError as e: if 'channels instead' in str(e): logging.info('encoder dim search failed on channels, trying with channels=1') channels = 1 - x = torch.randn(2, channels, 224, 224, device=encoder.device) # BCHW + x = torch.randn(2, channels, 224, 224, device=device) # BCHW return encoder(x).shape[-1] else: raise e @@ -474,3 +475,9 @@ def get_pytorch_dirichlet_head(encoder_dim: int, output_dim: int, test_time_drop def schema_to_campaigns(schema): # e.g. [gz2, dr12, ...] return [question.text.split('-')[-1] for question in schema.questions] + + +if __name__ == '__main__': + encoder = get_pytorch_encoder(channels=1) + dim = get_encoder_dim(encoder, channels=1) + print(dim) \ No newline at end of file