Skip to content

Commit

Permalink
fix encoder dim
Browse files Browse the repository at this point in the history
  • Loading branch information
mwalmsley committed Feb 11, 2024
1 parent f354078 commit 164448c
Showing 1 changed file with 9 additions and 2 deletions.
11 changes: 9 additions & 2 deletions zoobot/pytorch/estimators/define_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -381,14 +381,15 @@ def dirichlet_loss(preds, labels, question_index_groups, sum_over_questions=Fals
# input_size doesn't matter as long as it's large enough to not be pooled to zero
# channels doesn't matter at all but has to match encoder channels or shape error
def get_encoder_dim(encoder, channels=3):
device = next(encoder.parameters()).device
try:
x = torch.randn(2, channels, 224, 224, device=encoder.device) # BCHW
x = torch.randn(2, channels, 224, 224, device=device) # BCHW
return encoder(x).shape[-1]
except RuntimeError as e:
if 'channels instead' in str(e):
logging.info('encoder dim search failed on channels, trying with channels=1')
channels = 1
x = torch.randn(2, channels, 224, 224, device=encoder.device) # BCHW
x = torch.randn(2, channels, 224, 224, device=device) # BCHW
return encoder(x).shape[-1]
else:
raise e
Expand Down Expand Up @@ -474,3 +475,9 @@ def get_pytorch_dirichlet_head(encoder_dim: int, output_dim: int, test_time_drop
def schema_to_campaigns(schema):
# e.g. [gz2, dr12, ...]
return [question.text.split('-')[-1] for question in schema.questions]


if __name__ == '__main__':
encoder = get_pytorch_encoder(channels=1)
dim = get_encoder_dim(encoder, channels=1)
print(dim)

0 comments on commit 164448c

Please sign in to comment.