Skip to content

Commit

Permalink
typo
Browse files Browse the repository at this point in the history
  • Loading branch information
mwalmsley committed Feb 5, 2024
1 parent efb39ef commit adcce7d
Showing 1 changed file with 4 additions and 4 deletions.
8 changes: 4 additions & 4 deletions zoobot/pytorch/estimators/define_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -264,7 +264,7 @@ def __init__(
self.encoder = torch.compile(self.encoder)

# bit lazy assuming 224 input size
self.encoder_dim = get_encoder_dim(self.encoder)
self.encoder_dim = get_encoder_dim(self.encoder, channels)
# typically encoder_dim=1280 for effnetb0
logging.info('encoder dim: {}'.format(self.encoder_dim))

Expand Down Expand Up @@ -379,12 +379,12 @@ def dirichlet_loss(preds, labels, question_index_groups, sum_over_questions=Fals

# input_size doesn't matter as long as it's large enough to not be pooled to zero
# channels doesn't matter at all
def get_encoder_dim(encoder):
def get_encoder_dim(encoder, channels=3):
try:
x = torch.randn(1, 3, 224, 224) # BCHW
x = torch.randn(1, channels, 224, 224) # BCHW
return encoder(x).shape[-1]
except RuntimeError: # tensor might not be on same device as model, just try the only other option
x = torch.randn(1, 3, 224, 224).to('cuda')
x = torch.randn(1, channels, 224, 224).to('cuda')
return encoder(x).shape[-1]


Expand Down

0 comments on commit adcce7d

Please sign in to comment.