modified trainer and training script to use pytorch lightning
ramanakumars committed Nov 21, 2023
1 parent 5d7866b commit e31cab6
Showing 2 changed files with 130 additions and 305 deletions.
95 changes: 52 additions & 43 deletions patchgan/train.py
@@ -3,8 +3,10 @@
 from patchgan.unet import UNet
 from patchgan.disc import Discriminator
 from patchgan.io import COCOStuffDataset
-from patchgan.trainer import Trainer
+from patchgan.trainer import PatchGAN
 from torch.utils.data import DataLoader, random_split
+from lightning.pytorch import Trainer
+from lightning.pytorch.callbacks import ModelCheckpoint, LearningRateMonitor
 import yaml
 import importlib.machinery
 import argparse
@@ -80,48 +82,55 @@ def patchgan_train():
         dloader_kwargs['persistent_workers'] = True

     train_data = DataLoader(train_datagen, batch_size=args.batch_size, shuffle=True, pin_memory=True, **dloader_kwargs)
-    val_data = DataLoader(val_datagen, batch_size=args.batch_size, shuffle=True, pin_memory=True, **dloader_kwargs)
-
-    model_params = config['model_params']
-    generator_config = model_params['generator']
-    discriminator_config = model_params['discriminator']
-
-    # create the generator
-    gen_filts = generator_config['filters']
-    activation = generator_config['activation']
-    use_dropout = generator_config.get('use_dropout', True)
-    final_activation = generator_config.get('final_activation', 'sigmoid')
-    generator = UNet(in_channels, out_channels, gen_filts, use_dropout=use_dropout, activation=activation, final_act=final_activation).to(device)
-
-    # create the discriminator
-    disc_filts = discriminator_config['filters']
-    disc_norm = discriminator_config.get('norm', False)
-    n_disc_layers = discriminator_config['n_layers']
-    discriminator = Discriminator(in_channels + out_channels, disc_filts, norm=disc_norm, n_layers=n_disc_layers).to(device)
-
-    if args.summary:
-        summary(generator, [1, in_channels, size, size], depth=4)
-        summary(discriminator, [1, in_channels + out_channels, size, size])
+    val_data = DataLoader(val_datagen, batch_size=args.batch_size, pin_memory=True, **dloader_kwargs)

     checkpoint_path = config.get('checkpoint_path', './checkpoints/')

-    trainer = Trainer(generator, discriminator, savefolder=checkpoint_path)
-
+    model = None
     if config.get('load_last_checkpoint', False):
-        trainer.load_last_checkpoint()
-    elif config.get('transfer_learn', {}).get('generator_checkpoint', None) is not None:
-        gen_checkpoint = config['transfer_learn']['generator_checkpoint']
-        dsc_checkpoint = config['transfer_learn']['discriminator_checkpoint']
-        generator.load_transfer_data(torch.load(gen_checkpoint, map_location=device))
-        discriminator.load_transfer_data(torch.load(dsc_checkpoint, map_location=device))
-
-    train_params = config['train_params']
-
-    trainer.loss_type = train_params['loss_type']
-    trainer.seg_alpha = train_params['seg_alpha']
-
-    trainer.train(train_data, val_data, args.n_epochs,
-                  dsc_learning_rate=train_params['disc_learning_rate'],
-                  gen_learning_rate=train_params['gen_learning_rate'],
-                  lr_decay=train_params.get('decay_rate', None),
-                  save_freq=train_params.get('save_freq', 10))
+        model = PatchGAN.load_last_checkpoint(checkpoint_path)
+
+    if model is None:
+        model_params = config['model_params']
+        generator_config = model_params['generator']
+        discriminator_config = model_params['discriminator']
+
+        # get the discriminator and generator configs
+        gen_filts = generator_config['filters']
+        activation = generator_config['activation']
+        use_dropout = generator_config.get('use_dropout', True)
+        final_activation = generator_config.get('final_activation', 'sigmoid')
+        disc_filts = discriminator_config['filters']
+        disc_norm = discriminator_config.get('norm', False)
+        n_disc_layers = discriminator_config['n_layers']
+
+        # and the training parameters
+        train_params = config['train_params']
+        loss_type = train_params['loss_type']
+        seg_alpha = train_params['seg_alpha']
+        dsc_learning_rate = train_params['disc_learning_rate']
+        gen_learning_rate = train_params['gen_learning_rate']
+        lr_decay = train_params.get('decay_rate', 0.98)
+        decay_freq = train_params.get('decay_freq', 5)
+        save_freq = train_params.get('save_freq', 10)
+        model = PatchGAN(in_channels, out_channels, gen_filts, disc_filts, final_activation, n_disc_layers, use_dropout,
+                         activation, disc_norm, gen_learning_rate, dsc_learning_rate, lr_decay, decay_freq,
+                         loss_type=loss_type, seg_alpha=seg_alpha)

+    if config.get('transfer_learn', {}).get('checkpoint', None) is not None:
+        checkpoint = torch.load(config['transfer_learn']['checkpoint'], map_location=device)
+        model.generator.load_transfer_data({key: value for key, value in checkpoint['state_dict'].items() if 'generator' in key})
+        model.discriminator.load_transfer_data({key: value for key, value in checkpoint['state_dict'].items() if 'discriminator' in key})
+
+    if args.summary:
+        summary(model.generator, [1, in_channels, size, size], depth=4)
+        summary(model.discriminator, [1, in_channels + out_channels, size, size])
+
+    checkpoint_callback = ModelCheckpoint(dirpath=checkpoint_path,
+                                          filename='patchgan_{epoch:03d}',
+                                          save_top_k=-1,
+                                          every_n_epochs=save_freq,
+                                          verbose=True)
+    lr_monitor = LearningRateMonitor(logging_interval='epoch')
+    trainer = Trainer(accelerator=device, max_epochs=args.n_epochs, callbacks=[checkpoint_callback, lr_monitor])
+
+    trainer.fit(model, train_data, val_data)
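
For reference, the configuration keys read in this hunk imply a YAML file shaped roughly as follows. This is a hypothetical sketch: the values are placeholders, keys read elsewhere in train.py (dataset paths, channel counts, image size) are omitted, and the set of valid loss_type names is defined by the trainer, which is not rendered on this page.

# Hypothetical config sketch covering only the keys visible in this hunk;
# values are illustrative placeholders, not the repository's defaults.
checkpoint_path: ./checkpoints/
load_last_checkpoint: false

model_params:
  generator:
    filters: 32
    activation: relu
    use_dropout: true
    final_activation: sigmoid
  discriminator:
    filters: 32
    norm: false
    n_layers: 3

train_params:
  loss_type: weighted_bce    # assumption: valid names are defined by the trainer
  seg_alpha: 200
  gen_learning_rate: 1.0e-3
  disc_learning_rate: 1.0e-3
  decay_rate: 0.98
  decay_freq: 5
  save_freq: 10

# optional: warm-start generator/discriminator weights from an existing
# Lightning checkpoint (hypothetical placeholder path)
# transfer_learn:
#   checkpoint: ./checkpoints/patchgan_epoch=010.ckpt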
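The other changed file, patchgan/trainer.py, is not rendered above. Judging only by how train.py drives it, the commit replaces the hand-rolled Trainer with a PatchGAN LightningModule; a GAN module used this way would typically run manual optimization with separate generator and discriminator optimizers. The sketch below is an illustrative assumption, not the commit's actual code: the class name, loss wiring, constructor arguments, and (image, mask) batch format are all simplified guesses.

# Minimal sketch of a GAN LightningModule with two optimizers (manual
# optimization). All names and loss wiring are illustrative assumptions.
import torch
from lightning.pytorch import LightningModule


class PatchGANSketch(LightningModule):
    def __init__(self, generator, discriminator, gen_lr=1e-3, dsc_lr=1e-3, lr_decay=0.98, seg_alpha=200.):
        super().__init__()
        self.automatic_optimization = False  # GANs step two optimizers per batch
        self.generator = generator
        self.discriminator = discriminator
        self.gen_lr, self.dsc_lr = gen_lr, dsc_lr
        self.lr_decay = lr_decay
        self.seg_alpha = seg_alpha
        self.bce = torch.nn.BCELoss()  # assumes sigmoid outputs; else use BCEWithLogitsLoss

    def training_step(self, batch, batch_idx):
        imgs, masks = batch  # assumes the DataLoader yields (image, mask) pairs
        opt_g, opt_d = self.optimizers()

        # discriminator: real (image, mask) pairs vs. generated pairs
        fake = self.generator(imgs)
        d_real = self.discriminator(torch.cat([imgs, masks], dim=1))
        d_fake = self.discriminator(torch.cat([imgs, fake.detach()], dim=1))
        d_loss = 0.5 * (self.bce(d_real, torch.ones_like(d_real))
                        + self.bce(d_fake, torch.zeros_like(d_fake)))
        opt_d.zero_grad()
        self.manual_backward(d_loss)
        opt_d.step()

        # generator: fool the discriminator and match the target mask
        d_fake = self.discriminator(torch.cat([imgs, fake], dim=1))
        g_loss = (self.bce(d_fake, torch.ones_like(d_fake))
                  + self.seg_alpha * self.bce(fake, masks))
        opt_g.zero_grad()
        self.manual_backward(g_loss)
        opt_g.step()

        self.log_dict({'d_loss': d_loss, 'g_loss': g_loss}, prog_bar=True)

    def on_train_epoch_end(self):
        # with manual optimization, LR schedulers must be stepped by hand;
        # the LearningRateMonitor callback above then logs the decayed rates
        for sch in self.lr_schedulers():
            sch.step()

    def configure_optimizers(self):
        opt_g = torch.optim.Adam(self.generator.parameters(), lr=self.gen_lr)
        opt_d = torch.optim.Adam(self.discriminator.parameters(), lr=self.dsc_lr)
        sch_g = torch.optim.lr_scheduler.ExponentialLR(opt_g, gamma=self.lr_decay)
        sch_d = torch.optim.lr_scheduler.ExponentialLR(opt_d, gamma=self.lr_decay)
        return [opt_g, opt_d], [sch_g, sch_d]

Under the same assumptions, PatchGAN.load_last_checkpoint as called in train.py would be a thin wrapper around LightningModule.load_from_checkpoint pointed at the newest file in the checkpoint directory.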
