Commit f354078: try sync again

mwalmsley committed Feb 9, 2024
1 parent 8ae3e5a
Showing 2 changed files with 13 additions and 8 deletions.
19 changes: 12 additions & 7 deletions zoobot/pytorch/estimators/define_model.py
@@ -264,6 +264,7 @@ def __init__(
             self.encoder = torch.compile(self.encoder)
 
         # bit lazy assuming 224 input size
+        # logging.warning(channels)
         self.encoder_dim = get_encoder_dim(self.encoder, channels)
         # typically encoder_dim=1280 for effnetb0
         logging.info('encoder dim: {}'.format(self.encoder_dim))
@@ -378,17 +379,21 @@ def dirichlet_loss(preds, labels, question_index_groups, sum_over_questions=False):
 
 
 # input_size doesn't matter as long as it's large enough to not be pooled to zero
-# channels doesn't matter at all
+# channels doesn't matter at all but has to match encoder channels or shape error
 def get_encoder_dim(encoder, channels=3):
     try:
-        x = torch.randn(1, channels, 224, 224)  # BCHW
+        x = torch.randn(2, channels, 224, 224, device=encoder.device)  # BCHW
         return encoder(x).shape[-1]
-    except RuntimeError:  # tensor might not be on same device as model, just try the only other option
-        x = torch.randn(1, channels, 224, 224).to('cuda')
-        return encoder(x).shape[-1]
-
-
+    except RuntimeError as e:
+        if 'channels instead' in str(e):
+            logging.info('encoder dim search failed on channels, trying with channels=1')
+            channels = 1
+            x = torch.randn(2, channels, 224, 224, device=encoder.device)  # BCHW
+            return encoder(x).shape[-1]
+        else:
+            raise e
+
 
 def get_pytorch_encoder(
     architecture_name='efficientnet_b0',
     channels=1,
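The reworked get_encoder_dim infers the encoder's output width by running a dummy batch through it. Rather than assuming a failed forward pass means a CPU/GPU mismatch and blindly retrying on 'cuda', it now builds the probe tensor on the encoder's own device and inspects the RuntimeError: PyTorch's convolution error for a channel mismatch ends in "... channels instead", and only that case triggers a retry with channels=1. (Probing with a batch of 2 rather than 1 plausibly also keeps any batch-norm layers happy.) Below is a minimal self-contained sketch of the same pattern; the toy grayscale encoder and the device lookup via next(encoder.parameters()).device are illustrative stand-ins for zoobot's encoder and its encoder.device attribute, not zoobot code.

import logging

import torch
from torch import nn

# Toy stand-in for a zoobot encoder: it expects grayscale (1-channel) input,
# so probing with the default channels=3 raises a channel-mismatch RuntimeError.
encoder = nn.Sequential(
    nn.Conv2d(1, 16, kernel_size=3, stride=2),
    nn.AdaptiveAvgPool2d(1),
    nn.Flatten(),
)

def probe_encoder_dim(encoder, channels=3):
    device = next(encoder.parameters()).device  # assumes a parameterised encoder
    try:
        x = torch.randn(2, channels, 224, 224, device=device)  # BCHW
        return encoder(x).shape[-1]
    except RuntimeError as e:
        # a conv channel mismatch reads "...expected input[2, 3, 224, 224]
        # to have 1 channels, but got 3 channels instead"
        if 'channels instead' in str(e):
            logging.info('probe failed on channels, retrying with channels=1')
            x = torch.randn(2, 1, 224, 224, device=device)
            return encoder(x).shape[-1]
        raise

print(probe_encoder_dim(encoder))  # 16, reached via the channels=1 fallback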
2 changes: 1 addition & 1 deletion zoobot/pytorch/training/train_with_pytorch_lightning.py
@@ -288,7 +288,7 @@ def train_default_zoobot_from_scratch(
 
     if sync_batchnorm:
         logging.info('Using sync batchnorm')
-        lightning_model = TorchSyncBatchNorm.apply(lightning_model)
+        lightning_model = TorchSyncBatchNorm().apply(lightning_model)
 
 
     extra_callbacks = extra_callbacks if extra_callbacks else []
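The train_with_pytorch_lightning.py fix is two characters: TorchSyncBatchNorm is a Lightning LayerSync plugin whose apply is an instance method, so calling it on the class binds lightning_model to self and leaves the required model argument unfilled, raising a TypeError before any conversion happens. Instantiating the plugin first restores the intended call; internally, apply runs torch.nn.SyncBatchNorm.convert_sync_batchnorm on the module. A sketch of the difference, assuming a pytorch_lightning 2.x import path and a hypothetical TinyModel:

import pytorch_lightning as pl
from pytorch_lightning.plugins import TorchSyncBatchNorm
from torch import nn

# Hypothetical LightningModule with a BatchNorm layer to convert.
class TinyModel(pl.LightningModule):
    def __init__(self):
        super().__init__()
        self.net = nn.Sequential(nn.Conv2d(1, 8, 3), nn.BatchNorm2d(8))

model = TinyModel()

# Broken: apply() is an instance method, so the class-level call binds
# `model` to `self` and omits the real `model` argument entirely.
# TorchSyncBatchNorm.apply(model)  # TypeError: missing argument 'model'

# Fixed: instantiate the plugin, then apply. BatchNorm2d becomes SyncBatchNorm.
model = TorchSyncBatchNorm().apply(model)
assert isinstance(model.net[1], nn.SyncBatchNorm)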
