Skip to content

Commit

Permalink
add sync batchnorm option
Browse files Browse the repository at this point in the history
  • Loading branch information
mwalmsley committed Feb 8, 2024
1 parent 3095b5a commit 8ae3e5a
Showing 1 changed file with 7 additions and 0 deletions.
7 changes: 7 additions & 0 deletions zoobot/pytorch/training/train_with_pytorch_lightning.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@

import torch
import pytorch_lightning as pl
from pytorch_lightning.plugins import TorchSyncBatchNorm
from pytorch_lightning.strategies.ddp import DDPStrategy
from pytorch_lightning.callbacks import ModelCheckpoint
from pytorch_lightning.callbacks.early_stopping import EarlyStopping
Expand Down Expand Up @@ -49,6 +50,7 @@ def train_default_zoobot_from_scratch(
# hardware parameters
nodes=1,
gpus=2,
sync_batchnorm=False,
num_workers=4,
prefetch_factor=4,
mixed_precision=False,
Expand Down Expand Up @@ -283,6 +285,11 @@ def train_default_zoobot_from_scratch(
weight_decay=weight_decay,
scheduler_params=scheduler_params
)

if sync_batchnorm:
logging.info('Using sync batchnorm')
lightning_model = TorchSyncBatchNorm.apply(lightning_model)


extra_callbacks = extra_callbacks if extra_callbacks else []

Expand Down

0 comments on commit 8ae3e5a

Please sign in to comment.