From e31cab681534af1e39503946df3fa201a938e881 Mon Sep 17 00:00:00 2001 From: Ramanakumar Sankar Date: Mon, 20 Nov 2023 21:09:25 -0600 Subject: [PATCH 01/17] modified trainer and training script to use pytorch lightning --- patchgan/train.py | 95 +++++++------ patchgan/trainer.py | 340 ++++++++++---------------------------------- 2 files changed, 130 insertions(+), 305 deletions(-) diff --git a/patchgan/train.py b/patchgan/train.py index 7b846a5..00e7d03 100644 --- a/patchgan/train.py +++ b/patchgan/train.py @@ -3,8 +3,10 @@ from patchgan.unet import UNet from patchgan.disc import Discriminator from patchgan.io import COCOStuffDataset -from patchgan.trainer import Trainer +from patchgan.trainer import PatchGAN from torch.utils.data import DataLoader, random_split +from lightning.pytorch import Trainer +from lightning.pytorch.callbacks import ModelCheckpoint, LearningRateMonitor import yaml import importlib.machinery import argparse @@ -80,48 +82,55 @@ def patchgan_train(): dloader_kwargs['persistent_workers'] = True train_data = DataLoader(train_datagen, batch_size=args.batch_size, shuffle=True, pin_memory=True, **dloader_kwargs) - val_data = DataLoader(val_datagen, batch_size=args.batch_size, shuffle=True, pin_memory=True, **dloader_kwargs) - - model_params = config['model_params'] - generator_config = model_params['generator'] - discriminator_config = model_params['discriminator'] - - # create the generator - gen_filts = generator_config['filters'] - activation = generator_config['activation'] - use_dropout = generator_config.get('use_dropout', True) - final_activation = generator_config.get('final_activation', 'sigmoid') - generator = UNet(in_channels, out_channels, gen_filts, use_dropout=use_dropout, activation=activation, final_act=final_activation).to(device) - - # create the discriminator - disc_filts = discriminator_config['filters'] - disc_norm = discriminator_config.get('norm', False) - n_disc_layers = discriminator_config['n_layers'] - discriminator = Discriminator(in_channels + out_channels, disc_filts, norm=disc_norm, n_layers=n_disc_layers).to(device) - - if args.summary: - summary(generator, [1, in_channels, size, size], depth=4) - summary(discriminator, [1, in_channels + out_channels, size, size]) + val_data = DataLoader(val_datagen, batch_size=args.batch_size, pin_memory=True, **dloader_kwargs) checkpoint_path = config.get('checkpoint_path', './checkpoints/') - - trainer = Trainer(generator, discriminator, savefolder=checkpoint_path) - + model = None if config.get('load_last_checkpoint', False): - trainer.load_last_checkpoint() - elif config.get('transfer_learn', {}).get('generator_checkpoint', None) is not None: - gen_checkpoint = config['transfer_learn']['generator_checkpoint'] - dsc_checkpoint = config['transfer_learn']['discriminator_checkpoint'] - generator.load_transfer_data(torch.load(gen_checkpoint, map_location=device)) - discriminator.load_transfer_data(torch.load(dsc_checkpoint, map_location=device)) - - train_params = config['train_params'] - - trainer.loss_type = train_params['loss_type'] - trainer.seg_alpha = train_params['seg_alpha'] - - trainer.train(train_data, val_data, args.n_epochs, - dsc_learning_rate=train_params['disc_learning_rate'], - gen_learning_rate=train_params['gen_learning_rate'], - lr_decay=train_params.get('decay_rate', None), - save_freq=train_params.get('save_freq', 10)) + model = PatchGAN.load_last_checkpoint(checkpoint_path) + + if model is None: + model_params = config['model_params'] + generator_config = model_params['generator'] + discriminator_config = model_params['discriminator'] + + # get the discriminator and generator configs + gen_filts = generator_config['filters'] + activation = generator_config['activation'] + use_dropout = generator_config.get('use_dropout', True) + final_activation = generator_config.get('final_activation', 'sigmoid') + disc_filts = discriminator_config['filters'] + disc_norm = discriminator_config.get('norm', False) + n_disc_layers = discriminator_config['n_layers'] + + # and the training parameters + train_params = config['train_params'] + loss_type = train_params['loss_type'] + seg_alpha = train_params['seg_alpha'] + dsc_learning_rate = train_params['disc_learning_rate'] + gen_learning_rate = train_params['gen_learning_rate'] + lr_decay = train_params.get('decay_rate', 0.98) + decay_freq = train_params.get('decay_freq', 5) + save_freq = train_params.get('save_freq', 10) + model = PatchGAN(in_channels, out_channels, gen_filts, disc_filts, final_activation, n_disc_layers, use_dropout, + activation, disc_norm, gen_learning_rate, dsc_learning_rate, lr_decay, decay_freq, + loss_type=loss_type, seg_alpha=seg_alpha) + + if config.get('transfer_learn', {}).get('checkpoint', None) is not None: + checkpoint = torch.load(config['transfer_learn']['checkpoint'], map_location=device) + model.generator.load_transfer_data({key: value for key, value in checkpoint['state_dict'].items() if 'generator' in key}) + model.discriminator.load_transfer_data({key: value for key, value in checkpoint['state_dict'].items() if 'discriminator' in key}) + + if args.summary: + summary(model.generator, [1, in_channels, size, size], depth=4) + summary(model.discriminator, [1, in_channels + out_channels, size, size]) + + checkpoint_callback = ModelCheckpoint(dirpath=checkpoint_path, + filename='patchgan_{epoch:03d}', + save_top_k=-1, + every_n_epochs=save_freq, + verbose=True) + lr_monitor = LearningRateMonitor(logging_interval='epoch') + trainer = Trainer(accelerator=device, max_epochs=args.n_epochs, callbacks=[checkpoint_callback, lr_monitor]) + + trainer.fit(model, train_data, val_data) diff --git a/patchgan/trainer.py b/patchgan/trainer.py index 394d44b..2a7d6fd 100644 --- a/patchgan/trainer.py +++ b/patchgan/trainer.py @@ -8,56 +8,58 @@ from .losses import fc_tversky, bce_loss, MAE_loss from torch.nn.functional import binary_cross_entropy from collections import defaultdict +from .unet import UNet +from .disc import Discriminator +import lightning as L device = 'cuda' if torch.cuda.is_available() else 'cpu' -class Trainer: - ''' - Trainer module which contains both the full training driver - which calls the train_batch method - ''' +class PatchGAN(L.LightningModule): + def __init__(self, input_channels: int, output_channels: int, gen_filts: int, disc_filts: int, final_activation: str, + n_disc_layers: int = 5, use_gen_dropout: bool = True, gen_activation: str = 'leakyrelu', + disc_norm: bool = False, gen_lr: float = 1.e-3, dsc_lr: float = 1.e-3, lr_decay: float = 0.98, + decay_freq: int = 5, adam_b1: float = 0.5, adam_b2: float = 0.999, seg_alpha: float = 200, + loss_type: str = 'tversky', tversky_beta: float = 0.75, tversky_gamma: float = 0.75): + super().__init__() + self.save_hyperparameters() + self.automatic_optimization = False - seg_alpha = 200 - loss_type = 'tversky' - tversky_beta = 0.75 - tversky_gamma = 0.75 + self.generator = UNet(input_channels, output_channels, gen_filts, use_dropout=use_gen_dropout, + activation=gen_activation, final_act=final_activation) + self.discriminator = Discriminator(input_channels + output_channels, disc_filts, + norm=disc_norm, n_layers=n_disc_layers) - neptune_config = None + def forward(self, img, return_hidden=False): + return self.generator(img, return_hidden) - def __init__(self, generator, discriminator, savefolder, device='cuda'): + def training_step(self, batch): ''' - Store the generator and discriminator info + Train the generator and discriminator on a single batch ''' + optimizer_g, optimizer_d = self.optimizers() - generator.apply(weights_init) - discriminator.apply(weights_init) - - self.generator = generator - self.discriminator = discriminator - self.device = device + mean_loss = self.batch_step(batch, True, optimizer_g, optimizer_d) - if savefolder[-1] != '/': - savefolder += '/' + sch_g, sch_d = self.lr_schedulers() + if self.trainer.is_last_batch and (self.trainer.current_epoch + 1) % self.hparams.decay_freq == 0: + sch_g.step() + sch_d.step() - self.savefolder = savefolder - if not os.path.exists(savefolder): - os.mkdir(savefolder) + for key, val in mean_loss.items(): + self.log(key, val, prog_bar=True, on_epoch=True, reduce_fx=torch.mean) - self.start = 1 + def validation_step(self, batch): + mean_loss = self.batch_step(batch, False) - def batch(self, x, y, train=False): - ''' - Train the generator and discriminator on a single batch - ''' + for key, val in mean_loss.items(): + self.log(key, val, prog_bar=True, on_epoch=True, reduce_fx=torch.mean) - if not isinstance(x, torch.Tensor): - input_tensor = torch.as_tensor(x, dtype=torch.float).to(self.device) - target_tensor = torch.as_tensor(y, dtype=torch.float).to(self.device) - else: - input_tensor = x.to(self.device, non_blocking=True) - target_tensor = y.to(self.device, non_blocking=True) + def batch_step(self, batch: torch.Tensor | tuple[torch.Tensor], train: bool, + optimizer_g: torch.optim.Optimizer | None = None, + optimizer_d: torch.optim.Optimizer | None = None): + input_tensor, target_tensor = batch # train the generator gen_img = self.generator(input_tensor) @@ -68,30 +70,32 @@ def batch(self, x, y, train=False): labels_real = torch.full(disc_fake.shape, 1, dtype=torch.float, device=device) labels_fake = torch.full(disc_fake.shape, 0, dtype=torch.float, device=device) - if self.loss_type == 'tversky': + if self.hparams.loss_type == 'tversky': gen_loss = fc_tversky(target_tensor, gen_img, - beta=self.tversky_beta, - gamma=self.tversky_gamma) * self.seg_alpha - elif self.loss_type == 'weighted_bce': + beta=self.hparams.tversky_beta, + gamma=self.hparams.tversky_gamma) * self.hparams.seg_alpha + elif self.hparams.loss_type == 'weighted_bce': if gen_img.shape[1] > 1: weight = 1 - torch.sum(target_tensor, dim=(2, 3), keepdim=True) / torch.sum(target_tensor) else: weight = torch.ones_like(target_tensor) - gen_loss = binary_cross_entropy(gen_img, target_tensor, weight=weight) * self.seg_alpha - elif self.loss_type == 'MAE': - gen_loss = MAE_loss(gen_img, target_tensor) * self.seg_alpha + gen_loss = binary_cross_entropy(gen_img, target_tensor, weight=weight) * self.hparams.seg_alpha + elif self.hparams.loss_type == 'MAE': + gen_loss = MAE_loss(gen_img, target_tensor) * self.hparams.seg_alpha gen_loss_disc = bce_loss(disc_fake, labels_real) gen_loss = gen_loss + gen_loss_disc if train: - self.generator.zero_grad() - gen_loss.backward() - self.gen_optimizer.step() + self.toggle_optimizer(optimizer_g) + optimizer_g.zero_grad() + self.manual_backward(gen_loss) + optimizer_g.step() + self.untoggle_optimizer(optimizer_g) # Train the discriminator if train: - self.discriminator.zero_grad() + self.toggle_optimizer(optimizer_d) disc_inp_real = torch.cat((input_tensor, target_tensor), 1) disc_real = self.discriminator(disc_inp_real) @@ -103,8 +107,10 @@ def batch(self, x, y, train=False): disc_loss = (loss_fake + loss_real) / 2. if train: - disc_loss.backward() - self.disc_optimizer.step() + optimizer_d.zero_grad() + self.manual_backward(disc_loss) + optimizer_d.step() + self.untoggle_optimizer(optimizer_d) keys = ['gen', 'gen_loss', 'gdisc', 'discr', 'discf', 'disc'] mean_loss_i = [gen_loss.item(), gen_loss.item(), gen_loss_disc.item(), @@ -114,230 +120,40 @@ def batch(self, x, y, train=False): return loss - def train(self, train_data, val_data, epochs, dsc_learning_rate=1.e-3, - gen_learning_rate=1.e-3, save_freq=10, lr_decay=None, decay_freq=5, - reduce_on_plateau=False): - ''' - Training driver which loads the optimizer and calls the - `train_batch` method. Also handles checkpoint saving - Inputs - ------ - train_data : DataLoader object - Training data that is mapped using the DataLoader or - MmapDataLoader object defined in io.py - val_data : DataLoader object - Validation data loaded in using the DataLoader or - MmapDataLoader object - epochs : int - Number of epochs to run the model - dsc_learning_rate : float [default: 1e-4] - Initial learning rate for the discriminator - gen_learning_rate : float [default: 1e-3] - Initial learning rate for the generator - save_freq : int [default: 10] - Frequency at which to save checkpoints to the save folder - lr_decay : float [default: None] - Learning rate decay rate (ratio of new learning rate - to previous). A value of 0.95, for example, would set the - new LR to 95% of the previous value - decay_freq : int [default: 5] - Frequency at which to decay the learning rate. For example, - a value of for decay_freq and 0.95 for lr_decay would decay - the learning to 95% of the current value every 5 epochs. - Outputs - ------- - G_loss_plot : numpy.ndarray - Generator loss history as a function of the epochs - D_loss_plot : numpy.ndarray - Discriminator loss history as a function of the epochs - ''' - - if (lr_decay is not None) and not reduce_on_plateau: - gen_lr = gen_learning_rate * (lr_decay)**((self.start - 1) / (decay_freq)) - dsc_lr = dsc_learning_rate * (lr_decay)**((self.start - 1) / (decay_freq)) - else: - gen_lr = gen_learning_rate - dsc_lr = dsc_learning_rate - - if self.neptune_config is not None: - self.neptune_config['model/parameters/gen_learning_rate'] = gen_lr - self.neptune_config['model/parameters/dsc_learning_rate'] = dsc_lr - self.neptune_config['model/parameters/start'] = self.start - self.neptune_config['model/parameters/n_epochs'] = epochs - - # create the Adam optimzers - self.gen_optimizer = optim.Adam( - self.generator.parameters(), lr=gen_lr, betas=(0.9, 0.999)) - self.disc_optimizer = optim.Adam( - self.discriminator.parameters(), lr=dsc_lr, betas=(0.9, 0.999)) - - # set up the learning rate scheduler with exponential lr decay - if reduce_on_plateau: - gen_scheduler = ReduceLROnPlateau(self.gen_optimizer, verbose=True) - dsc_scheduler = ReduceLROnPlateau(self.disc_optimizer, verbose=True) - self.neptune_config['model/parameters/scheduler'] = 'ReduceLROnPlateau' - elif lr_decay is not None: - gen_scheduler = ExponentialLR(self.gen_optimizer, gamma=lr_decay) - dsc_scheduler = ExponentialLR(self.disc_optimizer, gamma=lr_decay) - if self.neptune_config is not None: - self.neptune_config['model/parameters/scheduler'] = 'ExponentialLR' - self.neptune_config['model/parameters/decay_freq'] = decay_freq - self.neptune_config['model/parameters/lr_decay'] = lr_decay - else: - gen_scheduler = None - dsc_scheduler = None - - # empty lists for storing epoch loss data - D_loss_ep, G_loss_ep = [], [] - for epoch in range(self.start, epochs + 1): - if isinstance(gen_scheduler, ExponentialLR): - gen_lr = gen_scheduler.get_last_lr()[0] - dsc_lr = dsc_scheduler.get_last_lr()[0] - else: - gen_lr = gen_learning_rate - dsc_lr = dsc_learning_rate - - print(f"Epoch {epoch} -- lr: {gen_lr:5.3e}, {dsc_lr:5.3e}") - print("-------------------------------------------------------") - - # batch loss data - pbar = tqdm.tqdm(train_data, desc='Training: ', dynamic_ncols=True) - - if hasattr(train_data, 'shuffle'): - train_data.shuffle() - - # set to training mode - self.generator.train() - self.discriminator.train() - - losses = defaultdict(list) - # loop through the training data - for i, (input_img, target_mask) in enumerate(pbar): - - # train on this batch - batch_loss = self.batch(input_img, target_mask, train=True) + def configure_optimizers(self): + gen_lr = self.hparams.gen_lr + dsc_lr = self.hparams.dsc_lr - # append the current batch loss - loss_mean = {} - for key, value in batch_loss.items(): - losses[key].append(value) - loss_mean[key] = np.mean(losses[key], axis=0) + opt_g = torch.optim.Adam(self.generator.parameters(), lr=gen_lr, betas=(self.hparams.adam_b1, self.hparams.adam_b2)) + opt_d = torch.optim.Adam(self.discriminator.parameters(), lr=dsc_lr, betas=(self.hparams.adam_b1, self.hparams.adam_b2)) - loss_str = " ".join([f"{key}: {value:.2e}" for key, value in loss_mean.items()]) + gen_lr_scheduler = ExponentialLR(opt_g, gamma=self.hparams.lr_decay) + dsc_lr_scheduler = ExponentialLR(opt_d, gamma=self.hparams.lr_decay) - pbar.set_postfix_str(loss_str) + gen_lr_scheduler_config = {"scheduler": gen_lr_scheduler, + "interval": "epoch", + "frequency": self.hparams.decay_freq} - # update the epoch loss - D_loss_ep.append(loss_mean['disc']) - G_loss_ep.append(loss_mean['gen']) + dsc_lr_scheduler_config = {"scheduler": dsc_lr_scheduler, + "interval": "epoch", + "frequency": self.hparams.decay_freq} - if self.neptune_config is not None: - self.neptune_config['train/gen_loss'].append(loss_mean['gen']) - self.neptune_config['train/disc_loss'].append(loss_mean['disc']) + return [{"optimizer": opt_g, "lr_scheduler": gen_lr_scheduler_config}, + {"optimizer": opt_d, "lr_scheduler": dsc_lr_scheduler_config}] - # validate every `validation_freq` epochs - self.discriminator.eval() - self.generator.eval() - pbar = tqdm.tqdm(val_data, desc='Validation: ') + def load_last_checkpoint(self, checkpoint_path): + checkpoints = sorted(glob.glob(os.path.join(checkpoint_path, "patchgan_*.pth"))) - if hasattr(val_data, 'shuffle'): - val_data.shuffle() - - losses = defaultdict(list) - # loop through the training data - for i, (input_img, target_mask) in enumerate(pbar): - # validate on this batch - batch_loss = self.batch(input_img, target_mask, train=False) - - loss_mean = {} - for key, value in batch_loss.items(): - losses[key].append(value) - loss_mean[key] = np.mean(losses[key], axis=0) - - loss_str = " ".join([f"{key}: {value:.2e}" for key, value in loss_mean.items()]) - - pbar.set_postfix_str(loss_str) - - if self.neptune_config is not None: - self.neptune_config['eval/gen_loss'].append(loss_mean['gen']) - self.neptune_config['eval/disc_loss'].append(loss_mean['disc']) - - # apply learning rate decay - if (gen_scheduler is not None) & (dsc_scheduler is not None): - if isinstance(gen_scheduler, ExponentialLR): - if epoch % decay_freq == 0: - gen_scheduler.step() - dsc_scheduler.step() - else: - gen_scheduler.step(loss_mean['gen']) - dsc_scheduler.step(loss_mean['disc']) - - # save checkpoints - if epoch % save_freq == 0: - self.save(epoch) - - return G_loss_ep, D_loss_ep - - def save(self, epoch): - gen_savefile = f'{self.savefolder}/generator_ep_{epoch:03d}.pth' - disc_savefile = f'{self.savefolder}/discriminator_ep_{epoch:03d}.pth' - - print(f"Saving to {gen_savefile} and {disc_savefile}") - torch.save(self.generator.state_dict(), gen_savefile) - torch.save(self.discriminator.state_dict(), disc_savefile) - - def load_last_checkpoint(self): - gen_checkpoints = sorted( - glob.glob(self.savefolder + "generator_ep*.pth")) - disc_checkpoints = sorted( - glob.glob(self.savefolder + "discriminator_ep*.pth")) - - gen_epochs = set([int(ch.split( - '/')[-1].replace('generator_ep_', '')[:-4]) for - ch in gen_checkpoints]) - dsc_epochs = set([int(ch.split( - '/')[-1].replace('discriminator_ep_', '')[:-4]) for - ch in disc_checkpoints]) + epochs = set([ + int(ch.split('/')[-1].replace('patchgan_', '')[:-4]) for ch in checkpoints + ]) try: - assert len(gen_epochs) > 0, "No checkpoints found!" - - start = max(gen_epochs.union(dsc_epochs)) - self.load(f"{self.savefolder}/generator_ep_{start:03d}.pth", - f"{self.savefolder}/discriminator_ep_{start:03d}.pth") - self.start = start + 1 + assert len(epochs) > 0, "No checkpoints found!" + start = max(epochs) + checkpoint_file = os.path.join(checkpoint_path, f"patchgan_{start:03d}.pth") + print(f"Loading from {checkpoint_file}") + return self.load_from_checkpoint(checkpoint_file) except Exception as e: print(e) print("Checkpoints not loaded") - - def load(self, generator_save, discriminator_save): - print(generator_save, discriminator_save) - self.generator.load_state_dict(torch.load(generator_save)) - self.discriminator.load_state_dict(torch.load(discriminator_save)) - - gfname = generator_save.split('/')[-1] - dfname = discriminator_save.split('/')[-1] - print( - f"Loaded checkpoints from {gfname} and {dfname}") - -# custom weights initialization called on generator and discriminator -# scaling here means std - - -def weights_init(net, init_type='normal', scaling=0.02): - """Initialize network weights. - Parameters: - net (network) -- network to be initialized - init_type (str) -- the name of an initialization method: normal | xavier | kaiming | orthogonal - init_gain (float) -- scaling factor for normal, xavier and orthogonal. - We use 'normal' in the original pix2pix and CycleGAN paper. But xavier and kaiming might - work better for some applications. Feel free to try yourself. - """ - def init_func(m): # define the initialization function - classname = m.__class__.__name__ - if hasattr(m, 'weight') and (classname.find('Conv')) != -1: - torch.nn.init.xavier_uniform_(m.weight.data) - # BatchNorm Layer's weight is not a matrix; only normal distribution applies. - elif classname.find('InstanceNorm') != -1: - torch.nn.init.xavier_uniform_(m.weight.data, 1.0) - torch.nn.init.constant_(m.bias.data, 0.0) From 245f0b348d6a69486dd7b269b7b4eb1b79befc63 Mon Sep 17 00:00:00 2001 From: Ramanakumar Sankar Date: Mon, 20 Nov 2023 21:09:57 -0600 Subject: [PATCH 02/17] cleaned up formatting. moving weights_init to unet.py --- patchgan/disc.py | 3 +++ patchgan/unet.py | 29 +++++++++++++++++++++++++---- 2 files changed, 28 insertions(+), 4 deletions(-) diff --git a/patchgan/disc.py b/patchgan/disc.py index 83f48ae..c00b1b9 100644 --- a/patchgan/disc.py +++ b/patchgan/disc.py @@ -1,5 +1,6 @@ from torch import nn from .transfer import Transferable +from .unet import weights_init class Discriminator(nn.Module, Transferable): @@ -46,6 +47,8 @@ def __init__(self, input_nc, ndf=64, n_layers=3, norm=False, norm_layer=nn.Insta stride=1, padding=padw), nn.Sigmoid()] self.model = nn.Sequential(*sequence) + self.apply(weights_init) + def forward(self, input): """Standard forward.""" return self.model(input) diff --git a/patchgan/unet.py b/patchgan/unet.py index 4805ce1..7df92c9 100755 --- a/patchgan/unet.py +++ b/patchgan/unet.py @@ -24,8 +24,7 @@ def __init__(self, input_filt, output_filt, activation, norm_layer, layer, use_d (f'DownAct{layer}', activation), ]) if use_dropout: - enc_sub = OrderedDict(chain(enc_sub.items(), - [(f'DownDropout{layer}', nn.Dropout(0.2))])) + enc_sub = OrderedDict(chain(enc_sub.items(), [(f'DownDropout{layer}', nn.Dropout(0.2))])) self.model = nn.Sequential(enc_sub) @@ -61,8 +60,7 @@ def __init__(self, input_filt, output_filt, activation, norm_layer, layer, batch dec_sub = OrderedDict([(f'UpConv{layer}', upconv), (f'UpAct{layer}', activation)]) if use_dropout: - dec_sub = OrderedDict(chain(dec_sub.items(), - [(f'UpDropout{layer}', nn.Dropout(0.2))])) + dec_sub = OrderedDict(chain(dec_sub.items(), [(f'UpDropout{layer}', nn.Dropout(0.2))])) self.model = nn.Sequential(dec_sub) @@ -109,6 +107,8 @@ def __init__(self, input_nc, output_nc, nf=64, self.encoder = nn.ModuleList(encoder_layers) self.decoder = nn.ModuleList(decoder_layers) + self.apply(weights_init) + def forward(self, x, return_hidden=False): xencs = [] @@ -132,3 +132,24 @@ def forward(self, x, return_hidden=False): return x, hidden else: return x + + +# custom weights initialization called on generator and discriminator +# scaling here means std +def weights_init(net, init_type='normal', scaling=0.02): + """Initialize network weights. + Parameters: + net (network) -- network to be initialized + init_type (str) -- the name of an initialization method: normal | xavier | kaiming | orthogonal + init_gain (float) -- scaling factor for normal, xavier and orthogonal. + We use 'normal' in the original pix2pix and CycleGAN paper. But xavier and kaiming might + work better for some applications. Feel free to try yourself. + """ + def init_func(m): # define the initialization function + classname = m.__class__.__name__ + if hasattr(m, 'weight') and (classname.find('Conv')) != -1: + torch.nn.init.xavier_uniform_(m.weight.data) + # BatchNorm Layer's weight is not a matrix; only normal distribution applies. + elif classname.find('InstanceNorm') != -1: + torch.nn.init.xavier_uniform_(m.weight.data, 1.0) + torch.nn.init.constant_(m.bias.data, 0.0) From 23fc791ae8686093afb455b0fc397ec7d93e7800 Mon Sep 17 00:00:00 2001 From: Ramanakumar Sankar Date: Mon, 20 Nov 2023 21:13:54 -0600 Subject: [PATCH 03/17] fixing flake issues --- patchgan/train.py | 2 -- patchgan/trainer.py | 6 +----- 2 files changed, 1 insertion(+), 7 deletions(-) diff --git a/patchgan/train.py b/patchgan/train.py index 00e7d03..24b304e 100644 --- a/patchgan/train.py +++ b/patchgan/train.py @@ -1,7 +1,5 @@ import torch from torchinfo import summary -from patchgan.unet import UNet -from patchgan.disc import Discriminator from patchgan.io import COCOStuffDataset from patchgan.trainer import PatchGAN from torch.utils.data import DataLoader, random_split diff --git a/patchgan/trainer.py b/patchgan/trainer.py index 2a7d6fd..00b5ae0 100644 --- a/patchgan/trainer.py +++ b/patchgan/trainer.py @@ -1,13 +1,9 @@ import torch import os -import tqdm import glob -import numpy as np -from torch import optim -from torch.optim.lr_scheduler import ExponentialLR, ReduceLROnPlateau +from torch.optim.lr_scheduler import ExponentialLR from .losses import fc_tversky, bce_loss, MAE_loss from torch.nn.functional import binary_cross_entropy -from collections import defaultdict from .unet import UNet from .disc import Discriminator import lightning as L From fe8f43d07310d432ba691197eeadf9e12d6fca0d Mon Sep 17 00:00:00 2001 From: Ramanakumar Sankar Date: Thu, 25 Jan 2024 22:49:05 -0800 Subject: [PATCH 04/17] fixing issues for old python version --- patchgan/trainer.py | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/patchgan/trainer.py b/patchgan/trainer.py index 00b5ae0..a7a13b6 100644 --- a/patchgan/trainer.py +++ b/patchgan/trainer.py @@ -6,6 +6,7 @@ from torch.nn.functional import binary_cross_entropy from .unet import UNet from .disc import Discriminator +from typing import Union, Optional import lightning as L @@ -52,9 +53,9 @@ def validation_step(self, batch): for key, val in mean_loss.items(): self.log(key, val, prog_bar=True, on_epoch=True, reduce_fx=torch.mean) - def batch_step(self, batch: torch.Tensor | tuple[torch.Tensor], train: bool, - optimizer_g: torch.optim.Optimizer | None = None, - optimizer_d: torch.optim.Optimizer | None = None): + def batch_step(self, batch: Union[torch.Tensor, tuple[torch.Tensor]], train: bool, + optimizer_g: Optional[torch.optim.Optimizer] = None, + optimizer_d: Optional[torch.optim.Optimizer] = None): input_tensor, target_tensor = batch # train the generator From b0fa99ffd7b0b9f89e976d359bd502019b03bb85 Mon Sep 17 00:00:00 2001 From: Ramanakumar Sankar Date: Thu, 25 Jan 2024 22:49:15 -0800 Subject: [PATCH 05/17] converting to pyproject.toml --- patchgan/__init__.py | 4 ++-- pyproject.toml | 38 ++++++++++++++++++++++++++++++++++++++ setup.py | 44 -------------------------------------------- 3 files changed, 40 insertions(+), 46 deletions(-) create mode 100644 pyproject.toml delete mode 100644 setup.py diff --git a/patchgan/__init__.py b/patchgan/__init__.py index fb52b0a..3fbf9c8 100644 --- a/patchgan/__init__.py +++ b/patchgan/__init__.py @@ -1,8 +1,8 @@ from .unet import UNet from .disc import Discriminator -from .trainer import Trainer +from .trainer import PatchGAN from .version import __version__ __all__ = [ - 'UNet', 'Discriminator', 'Trainer', '__version__' + 'UNet', 'Discriminator', 'PatchGAN', '__version__' ] diff --git a/pyproject.toml b/pyproject.toml new file mode 100644 index 0000000..f2501c6 --- /dev/null +++ b/pyproject.toml @@ -0,0 +1,38 @@ +[build-system] +requires = ["setuptools >= 61.0", "wheel"] +build-backend = "setuptools.build_meta" + +[project] +name = "patchgan" +dynamic = ["version"] +license = { file = 'LICENSE' } +description = 'patchGAN image segmentation model in PyTorch' +requires-python = ">=3.8" +dependencies = [ + 'numpy>=1.21.0,<1.25.2', + 'torch>=2.1', + 'matplotlib>3.5.0', + 'torchvision>=0.14.0', + 'tqdm>=4.62.3', + 'torchinfo>=1.5.0,', + 'pyyaml', + 'patchify', + 'einops', +] +keywords = ["generative model", "deep learning", "U-Net", "image segmentation"] +authors = [ + { name = 'Ramanakumar Sankar', email = 'ramanakumar.sankar@berkeley.edu' }, + { name = 'Kameswara Mantha', email = 'manth145@umn.edu' }, + { name = 'Lucy Fortson', email = 'lfortson@umn.edu' }, +] +readme = "README.md" + +[tool.setuptools.dynamic] +version = { attr = "patchgan.version.__version__" } + +[project.scripts] +patchgan_train = 'patchgan.train:patchgan_train' +patchgan_infer = 'patchgan.infer:patchgan_infer' + +[project.urls] +repository = "https://www.github.com/ramanakumars/patchGAN" \ No newline at end of file diff --git a/setup.py b/setup.py deleted file mode 100644 index 99f8a84..0000000 --- a/setup.py +++ /dev/null @@ -1,44 +0,0 @@ -from setuptools import setup, find_packages -import os - -here = os.path.abspath(os.path.dirname(__file__)) - -try: - with open(os.path.join(here, 'README.md'), 'r') as fh: - long_description = fh.read() -except FileNotFoundError: - long_description = '' - -version = {} -with open(os.path.join(here, 'patchgan/version.py')) as ver_file: - exec(ver_file.read(), version) - -setup( - name='patchGAN', - version=version['__version__'], - description='patchGAN image segmentation model in PyTorch', - long_description=long_description, - long_description_content_type='text/markdown', - license='GNU General Public License v3', - url='https://github.com/ramanakumars/patchGAN', - author='Kameswara Mantha, Ramanakumar Sankar, Lucy Fortson', - author_email='manth145@umn.edu, rsankar@umn.edu, lfortson@umn.edu', - packages=find_packages(), - entry_points={ - 'console_scripts': [ - 'patchgan_train = patchgan.train:patchgan_train', - 'patchgan_infer = patchgan.infer:patchgan_infer' - ] - }, - install_requires=[ - 'numpy>=1.21.0,<1.25.2', - 'torch>=1.13.0', - 'matplotlib>3.5.0', - 'torchvision>=0.14.0', - 'tqdm>=4.62.3', - 'torchinfo>=1.5.0,', - 'pyyaml', - 'patchify', - 'einops' - ] -) From 270df4a46c061ae49b688ff3c0127aa84bb2df67 Mon Sep 17 00:00:00 2001 From: Ramanakumar Sankar Date: Tue, 30 Jan 2024 19:33:20 -0800 Subject: [PATCH 06/17] fixing inference script to use torch's patch generation functions --- patchgan/infer.py | 176 +++++++++++++++++++++------------------------- 1 file changed, 80 insertions(+), 96 deletions(-) diff --git a/patchgan/infer.py b/patchgan/infer.py index d8fd15d..aecb8fe 100644 --- a/patchgan/infer.py +++ b/patchgan/infer.py @@ -1,71 +1,80 @@ import torch from torchinfo import summary -from patchgan.unet import UNet -from patchgan.disc import Discriminator -from patchgan.io import COCOStuffDataset +from .io import COCOStuffDataset +from .patchgan import PatchGAN import yaml import tqdm import os -import numpy as np import importlib.machinery import argparse - - -def n_crop(image, size, overlap): - c, height, width = image.shape - - effective_size = int(overlap * size) - - ncropsy = int(np.ceil(height / effective_size)) - ncropsx = int(np.ceil(width / effective_size)) - - crops = torch.zeros((ncropsx * ncropsy, c, size, size), device=image.device) - - for j in range(ncropsy): - for i in range(ncropsx): - starty = j * effective_size - startx = i * effective_size - - starty -= max([starty + size - height, 0]) - startx -= max([startx + size - width, 0]) - - crops[j * ncropsy + i, :] = image[:, starty:starty + size, startx:startx + size] - - return crops - - -def build_mask(masks, crop_size, image_size, threshold, overlap): - n, c, height, width = masks.shape - image_height, image_width = image_size - mask = np.zeros((c, *image_size)) - count = np.zeros((c, *image_size)) - - effective_size = int(overlap * crop_size) - - ncropsy = int(np.ceil(image_height / effective_size)) - ncropsx = int(np.ceil(image_width / effective_size)) - - for j in range(ncropsy): - for i in range(ncropsx): - starty = j * effective_size - startx = i * effective_size - starty -= max([starty + crop_size - image_height, 0]) - startx -= max([startx + crop_size - image_width, 0]) - endy = starty + crop_size - endx = startx + crop_size - - mask[:, starty:endy, startx:endx] += masks[j * ncropsy + i, :] - count[:, starty:endy, startx:endx] += 1 - mask = mask / count - - if threshold > 0: - mask[mask >= threshold] = 1 - mask[mask < threshold] = 0 - - if c > 1: - return np.argmax(mask, axis=0) - else: - return mask[0] +from torch import nn +from einops import rearrange +import lightning as L + + +class InferenceModel(L.LightningModule): + ''' + Wrapper for running PatchGAN with a crop inference mode, + where the input images are cropped with overlap into (patch_size x patch_size) + ''' + + def __init__(self, model: PatchGAN, patch_size: int): + super().__init__() + self.save_hyperparameters() + self.model = model + + def forward(self, img: torch.Tensor) -> torch.Tensor: + ''' + run the image forward through the patch generation stage + and then through the patchGAN + ''' + C, H, W = img.shape + + assert H == W, "PatchGAN currently only supports square images" + image_size = H + + patch_size = self.hparams.patch_size + + # find the optimal stride + # this stride should cover the whole image with minimal overlap + # essentially solving (n - 1) * (kernel_size + 1) >= image_size + # for the n = number of overlapping patches in each dimension + for i in range(2, 10): + n = (image_size - patch_size) // patch_size + i + stride = (image_size - patch_size) // (n - 1) + if stride * (n - 1) + patch_size == image_size: + break + + # check to make sure we got convergence + if stride * (n - 1) + patch_size != image_size: + raise ValueError(f"Could fit {image_size} into window of size {patch_size}") + + # create the parameter dict for torch's fold and unfold functions + fold_params = {'kernel_size': patch_size, 'stride': stride, 'dilation': 1, 'padding': 0} + + # we will also apply this to a ones array to get the number of patches that cover each pixel + # we will divide the final mask by this count to normalize the number of predictions per pixel + count = self.fold(self.unfold(torch.ones_like(img), fold_params), fold_params, image_size) + masks = self.fold(self.model(self.unfold(img, fold_params)), fold_params, image_size) + + return masks / count + + def fold(self, x, fold_params, image_size): + ''' + Folding function. Given an input of (l, channels, patch_size, patch_size), returns the + reconstructed image of (channels, image_size, image_size) + ''' + x = rearrange(x, 'l c h w -> (c h w) l') + return nn.functional.fold(x, output_size=(image_size, image_size), **fold_params) + + def unfold(self, x, fold_params): + ''' + Unfolding function. Given an input of (channels, image_size, image_size) returns the set + of overlapping patches of size (l, channels, patch_size, patch_size) + ''' + x = nn.functional.unfold(x, **fold_params) + return rearrange(x, '(c h w) l -> l c h w', c=self.model.hparams.input_channels, + h=self.hparams.patch_size, w=self.hparams.patch_size) def patchgan_infer(): @@ -94,14 +103,12 @@ def patchgan_infer(): dataset_params = config['dataset'] dataset_path = dataset_params['dataset_path'] - size = dataset_params.get('size', 256) + patch_size = dataset_params.get('patch_size', 256) dataset_kwargs = {} if dataset_params['type'] == 'COCOStuff': Dataset = COCOStuffDataset - in_channels = 3 labels = dataset_params.get('labels', [1]) - out_channels = len(labels) dataset_kwargs['labels'] = labels else: try: @@ -113,8 +120,6 @@ def patchgan_infer(): except (ImportError, ModuleNotFoundError): print(f"io.py does not contain {dataset_params['type']}") raise - in_channels = dataset_params.get('in_channels', 3) - out_channels = dataset_params.get('out_channels', 1) assert hasattr(Dataset, 'get_filename') and callable(Dataset.get_filename), \ f"Dataset class {Dataset.__name__} must have the get_filename method which returns the image filename for a given index" @@ -124,26 +129,10 @@ def patchgan_infer(): datagen = Dataset(dataset_path, **dataset_kwargs) - model_params = config['model_params'] - gen_filts = model_params['gen_filts'] - disc_filts = model_params['disc_filts'] - n_disc_layers = model_params['n_disc_layers'] - activation = model_params['activation'] - final_activation = model_params.get('final_activation', 'sigmoid') + model_checkpoint = config['model_checkpoint'] - # create the generator - generator = UNet(in_channels, out_channels, gen_filts, activation=activation, final_act=final_activation).to(device) - - # create the discriminator - discriminator = Discriminator(in_channels + out_channels, disc_filts, n_layers=n_disc_layers).to(device) - - if args.summary: - summary(generator, [1, in_channels, size, size], device=device) - summary(discriminator, [1, in_channels + out_channels, size, size], device=device) - - checkpoint_paths = config['checkpoint_paths'] - gen_checkpoint = checkpoint_paths['generator'] - dsc_checkpoint = checkpoint_paths['discriminator'] + # create the patchGAN + model = PatchGAN.load_from_checkpoint(model_checkpoint) infer_params = config.get('infer_params', {}) output_path = infer_params.get('output_path', 'predictions/') @@ -152,23 +141,18 @@ def patchgan_infer(): os.makedirs(output_path) print(f"Created folder {output_path}") - generator.eval() - discriminator.eval() + model.eval() - generator.load_state_dict(torch.load(gen_checkpoint, map_location=device)) - discriminator.load_state_dict(torch.load(dsc_checkpoint, map_location=device)) + inferencemodel = InferenceModel(model, patch_size) - threshold = infer_params.get('threshold', 0) - overlap = infer_params.get('overlap', 0.9) + if args.summary: + summary(inferencemodel, datagen[0].shape, device=device) for i, data in enumerate(tqdm.tqdm(datagen, desc='Predicting', dynamic_ncols=True, ascii=True)): - imgs = n_crop(data, size, overlap) out_fname, _ = os.path.splitext(datagen.get_filename(i)) with torch.no_grad(): - img_tensor = torch.Tensor(imgs).to(device) - masks = generator(img_tensor).cpu().numpy() - - mask = build_mask(masks, size, data.shape[1:], threshold, overlap) + img_tensor = torch.Tensor(data).to(device) + mask = inferencemodel(img_tensor).cpu().numpy() Dataset.save_mask(mask, output_path, out_fname) From 01312f83b5a7d2c6ab4330fa64e4e61833b4bd68 Mon Sep 17 00:00:00 2001 From: Ramanakumar Sankar Date: Tue, 30 Jan 2024 19:33:54 -0800 Subject: [PATCH 07/17] moving patchGAN from trainer.py to patchgan.py --- patchgan/__init__.py | 2 +- patchgan/patchgan.py | 137 +++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 138 insertions(+), 1 deletion(-) create mode 100644 patchgan/patchgan.py diff --git a/patchgan/__init__.py b/patchgan/__init__.py index 3fbf9c8..f4e08d6 100644 --- a/patchgan/__init__.py +++ b/patchgan/__init__.py @@ -1,6 +1,6 @@ from .unet import UNet from .disc import Discriminator -from .trainer import PatchGAN +from .patchgan import PatchGAN from .version import __version__ __all__ = [ diff --git a/patchgan/patchgan.py b/patchgan/patchgan.py new file mode 100644 index 0000000..5e0c357 --- /dev/null +++ b/patchgan/patchgan.py @@ -0,0 +1,137 @@ +import torch +from torch.optim.lr_scheduler import ExponentialLR +from .losses import fc_tversky, bce_loss, MAE_loss +from torch.nn.functional import binary_cross_entropy +from .unet import UNet +from .disc import Discriminator +from typing import Union, Optional +import lightning as L + + +device = 'cuda' if torch.cuda.is_available() else 'cpu' + + +class PatchGAN(L.LightningModule): + def __init__(self, input_channels: int, output_channels: int, gen_filts: int, disc_filts: int, final_activation: str, + n_disc_layers: int = 5, use_gen_dropout: bool = True, gen_activation: str = 'leakyrelu', + disc_norm: bool = False, gen_lr: float = 1.e-3, dsc_lr: float = 1.e-3, lr_decay: float = 0.98, + decay_freq: int = 5, adam_b1: float = 0.5, adam_b2: float = 0.999, seg_alpha: float = 200, + loss_type: str = 'tversky', tversky_beta: float = 0.75, tversky_gamma: float = 0.75): + super().__init__() + self.save_hyperparameters() + self.automatic_optimization = False + + self.generator = UNet(input_channels, output_channels, gen_filts, use_dropout=use_gen_dropout, + activation=gen_activation, final_act=final_activation) + self.discriminator = Discriminator(input_channels + output_channels, disc_filts, + norm=disc_norm, n_layers=n_disc_layers) + + def forward(self, img, return_hidden=False): + return self.generator(img, return_hidden) + + def training_step(self, batch): + ''' + Train the generator and discriminator on a single batch + ''' + optimizer_g, optimizer_d = self.optimizers() + + mean_loss = self.batch_step(batch, True, optimizer_g, optimizer_d) + + sch_g, sch_d = self.lr_schedulers() + if self.trainer.is_last_batch and (self.trainer.current_epoch + 1) % self.hparams.decay_freq == 0: + sch_g.step() + sch_d.step() + + for key, val in mean_loss.items(): + self.log(key, val, prog_bar=True, on_epoch=True, reduce_fx=torch.mean) + + def validation_step(self, batch): + mean_loss = self.batch_step(batch, False) + + for key, val in mean_loss.items(): + self.log(key, val, prog_bar=True, on_epoch=True, reduce_fx=torch.mean) + + def batch_step(self, batch: Union[torch.Tensor, tuple[torch.Tensor]], train: bool, + optimizer_g: Optional[torch.optim.Optimizer] = None, + optimizer_d: Optional[torch.optim.Optimizer] = None): + input_tensor, target_tensor = batch + + # train the generator + gen_img = self.generator(input_tensor) + + disc_inp_fake = torch.cat((input_tensor, gen_img), 1) + disc_fake = self.discriminator(disc_inp_fake) + + labels_real = torch.full(disc_fake.shape, 1, dtype=torch.float, device=device) + labels_fake = torch.full(disc_fake.shape, 0, dtype=torch.float, device=device) + + if self.hparams.loss_type == 'tversky': + gen_loss = fc_tversky(target_tensor, gen_img, + beta=self.hparams.tversky_beta, + gamma=self.hparams.tversky_gamma) * self.hparams.seg_alpha + elif self.hparams.loss_type == 'weighted_bce': + if gen_img.shape[1] > 1: + weight = 1 - torch.sum(target_tensor, dim=(2, 3), keepdim=True) / torch.sum(target_tensor) + else: + weight = torch.ones_like(target_tensor) + gen_loss = binary_cross_entropy(gen_img, target_tensor, weight=weight) * self.hparams.seg_alpha + elif self.hparams.loss_type == 'MAE': + gen_loss = MAE_loss(gen_img, target_tensor) * self.hparams.seg_alpha + + gen_loss_disc = bce_loss(disc_fake, labels_real) + gen_loss = gen_loss + gen_loss_disc + + if train: + self.toggle_optimizer(optimizer_g) + optimizer_g.zero_grad() + self.manual_backward(gen_loss) + optimizer_g.step() + self.untoggle_optimizer(optimizer_g) + + # Train the discriminator + if train: + self.toggle_optimizer(optimizer_d) + + disc_inp_real = torch.cat((input_tensor, target_tensor), 1) + disc_real = self.discriminator(disc_inp_real) + disc_inp_fake = torch.cat((input_tensor, gen_img.detach()), 1) + disc_fake = self.discriminator(disc_inp_fake) + + loss_real = bce_loss(disc_real, labels_real.detach()) + loss_fake = bce_loss(disc_fake, labels_fake) + disc_loss = (loss_fake + loss_real) / 2. + + if train: + optimizer_d.zero_grad() + self.manual_backward(disc_loss) + optimizer_d.step() + self.untoggle_optimizer(optimizer_d) + + keys = ['gen', 'gen_loss', 'gdisc', 'discr', 'discf', 'disc'] + mean_loss_i = [gen_loss.item(), gen_loss.item(), gen_loss_disc.item(), + loss_real.item(), loss_fake.item(), disc_loss.item()] + + loss = {key: val for key, val in zip(keys, mean_loss_i)} + + return loss + + def configure_optimizers(self): + gen_lr = self.hparams.gen_lr + dsc_lr = self.hparams.dsc_lr + + opt_g = torch.optim.Adam(self.generator.parameters(), lr=gen_lr, betas=(self.hparams.adam_b1, self.hparams.adam_b2)) + opt_d = torch.optim.Adam(self.discriminator.parameters(), lr=dsc_lr, betas=(self.hparams.adam_b1, self.hparams.adam_b2)) + + gen_lr_scheduler = ExponentialLR(opt_g, gamma=self.hparams.lr_decay) + dsc_lr_scheduler = ExponentialLR(opt_d, gamma=self.hparams.lr_decay) + + gen_lr_scheduler_config = {"scheduler": gen_lr_scheduler, + "interval": "epoch", + "frequency": self.hparams.decay_freq} + + dsc_lr_scheduler_config = {"scheduler": dsc_lr_scheduler, + "interval": "epoch", + "frequency": self.hparams.decay_freq} + + return [{"optimizer": opt_g, "lr_scheduler": gen_lr_scheduler_config}, + {"optimizer": opt_d, "lr_scheduler": dsc_lr_scheduler_config}] From 4973efb5f33ea134395f6ed7e3dfde9228289112 Mon Sep 17 00:00:00 2001 From: Ramanakumar Sankar Date: Tue, 30 Jan 2024 19:34:13 -0800 Subject: [PATCH 08/17] fixing training script to have better transfer learning capabilities --- patchgan/train.py | 19 ++++++++++++------- 1 file changed, 12 insertions(+), 7 deletions(-) diff --git a/patchgan/train.py b/patchgan/train.py index 24b304e..92482ea 100644 --- a/patchgan/train.py +++ b/patchgan/train.py @@ -1,7 +1,8 @@ import torch from torchinfo import summary -from patchgan.io import COCOStuffDataset -from patchgan.trainer import PatchGAN +from .io import COCOStuffDataset +from .patchgan import PatchGAN +import os from torch.utils.data import DataLoader, random_split from lightning.pytorch import Trainer from lightning.pytorch.callbacks import ModelCheckpoint, LearningRateMonitor @@ -84,8 +85,9 @@ def patchgan_train(): checkpoint_path = config.get('checkpoint_path', './checkpoints/') model = None - if config.get('load_last_checkpoint', False): - model = PatchGAN.load_last_checkpoint(checkpoint_path) + checkpoint_file = config.get('load_from_checkpoint', '') + if os.path.isfile(checkpoint_file): + model = PatchGAN.load_from_checkpoint(checkpoint_file) if model is None: model_params = config['model_params'] @@ -116,8 +118,8 @@ def patchgan_train(): if config.get('transfer_learn', {}).get('checkpoint', None) is not None: checkpoint = torch.load(config['transfer_learn']['checkpoint'], map_location=device) - model.generator.load_transfer_data({key: value for key, value in checkpoint['state_dict'].items() if 'generator' in key}) - model.discriminator.load_transfer_data({key: value for key, value in checkpoint['state_dict'].items() if 'discriminator' in key}) + model.generator.load_transfer_data({key.replace('PatchGAN.', ''): value for key, value in checkpoint['state_dict'].items() if 'generator' in key}) + model.discriminator.load_transfer_data({key.replace('PatchGAN.', ''): value for key, value in checkpoint['state_dict'].items() if 'discriminator' in key}) if args.summary: summary(model.generator, [1, in_channels, size, size], depth=4) @@ -131,4 +133,7 @@ def patchgan_train(): lr_monitor = LearningRateMonitor(logging_interval='epoch') trainer = Trainer(accelerator=device, max_epochs=args.n_epochs, callbacks=[checkpoint_callback, lr_monitor]) - trainer.fit(model, train_data, val_data) + if os.path.isfile(checkpoint_file): + trainer.fit(model, train_data, val_data, ckpt_path=checkpoint_file) + else: + trainer.fit(model, train_data, val_data) From ce4d6464d9052a663a9ccab08d3ada78cbdba2cd Mon Sep 17 00:00:00 2001 From: Ramanakumar Sankar Date: Tue, 30 Jan 2024 19:34:27 -0800 Subject: [PATCH 09/17] removing trainer.py --- patchgan/trainer.py | 156 -------------------------------------------- 1 file changed, 156 deletions(-) delete mode 100644 patchgan/trainer.py diff --git a/patchgan/trainer.py b/patchgan/trainer.py deleted file mode 100644 index a7a13b6..0000000 --- a/patchgan/trainer.py +++ /dev/null @@ -1,156 +0,0 @@ -import torch -import os -import glob -from torch.optim.lr_scheduler import ExponentialLR -from .losses import fc_tversky, bce_loss, MAE_loss -from torch.nn.functional import binary_cross_entropy -from .unet import UNet -from .disc import Discriminator -from typing import Union, Optional -import lightning as L - - -device = 'cuda' if torch.cuda.is_available() else 'cpu' - - -class PatchGAN(L.LightningModule): - def __init__(self, input_channels: int, output_channels: int, gen_filts: int, disc_filts: int, final_activation: str, - n_disc_layers: int = 5, use_gen_dropout: bool = True, gen_activation: str = 'leakyrelu', - disc_norm: bool = False, gen_lr: float = 1.e-3, dsc_lr: float = 1.e-3, lr_decay: float = 0.98, - decay_freq: int = 5, adam_b1: float = 0.5, adam_b2: float = 0.999, seg_alpha: float = 200, - loss_type: str = 'tversky', tversky_beta: float = 0.75, tversky_gamma: float = 0.75): - super().__init__() - self.save_hyperparameters() - self.automatic_optimization = False - - self.generator = UNet(input_channels, output_channels, gen_filts, use_dropout=use_gen_dropout, - activation=gen_activation, final_act=final_activation) - self.discriminator = Discriminator(input_channels + output_channels, disc_filts, - norm=disc_norm, n_layers=n_disc_layers) - - def forward(self, img, return_hidden=False): - return self.generator(img, return_hidden) - - def training_step(self, batch): - ''' - Train the generator and discriminator on a single batch - ''' - optimizer_g, optimizer_d = self.optimizers() - - mean_loss = self.batch_step(batch, True, optimizer_g, optimizer_d) - - sch_g, sch_d = self.lr_schedulers() - if self.trainer.is_last_batch and (self.trainer.current_epoch + 1) % self.hparams.decay_freq == 0: - sch_g.step() - sch_d.step() - - for key, val in mean_loss.items(): - self.log(key, val, prog_bar=True, on_epoch=True, reduce_fx=torch.mean) - - def validation_step(self, batch): - mean_loss = self.batch_step(batch, False) - - for key, val in mean_loss.items(): - self.log(key, val, prog_bar=True, on_epoch=True, reduce_fx=torch.mean) - - def batch_step(self, batch: Union[torch.Tensor, tuple[torch.Tensor]], train: bool, - optimizer_g: Optional[torch.optim.Optimizer] = None, - optimizer_d: Optional[torch.optim.Optimizer] = None): - input_tensor, target_tensor = batch - - # train the generator - gen_img = self.generator(input_tensor) - - disc_inp_fake = torch.cat((input_tensor, gen_img), 1) - disc_fake = self.discriminator(disc_inp_fake) - - labels_real = torch.full(disc_fake.shape, 1, dtype=torch.float, device=device) - labels_fake = torch.full(disc_fake.shape, 0, dtype=torch.float, device=device) - - if self.hparams.loss_type == 'tversky': - gen_loss = fc_tversky(target_tensor, gen_img, - beta=self.hparams.tversky_beta, - gamma=self.hparams.tversky_gamma) * self.hparams.seg_alpha - elif self.hparams.loss_type == 'weighted_bce': - if gen_img.shape[1] > 1: - weight = 1 - torch.sum(target_tensor, dim=(2, 3), keepdim=True) / torch.sum(target_tensor) - else: - weight = torch.ones_like(target_tensor) - gen_loss = binary_cross_entropy(gen_img, target_tensor, weight=weight) * self.hparams.seg_alpha - elif self.hparams.loss_type == 'MAE': - gen_loss = MAE_loss(gen_img, target_tensor) * self.hparams.seg_alpha - - gen_loss_disc = bce_loss(disc_fake, labels_real) - gen_loss = gen_loss + gen_loss_disc - - if train: - self.toggle_optimizer(optimizer_g) - optimizer_g.zero_grad() - self.manual_backward(gen_loss) - optimizer_g.step() - self.untoggle_optimizer(optimizer_g) - - # Train the discriminator - if train: - self.toggle_optimizer(optimizer_d) - - disc_inp_real = torch.cat((input_tensor, target_tensor), 1) - disc_real = self.discriminator(disc_inp_real) - disc_inp_fake = torch.cat((input_tensor, gen_img.detach()), 1) - disc_fake = self.discriminator(disc_inp_fake) - - loss_real = bce_loss(disc_real, labels_real.detach()) - loss_fake = bce_loss(disc_fake, labels_fake) - disc_loss = (loss_fake + loss_real) / 2. - - if train: - optimizer_d.zero_grad() - self.manual_backward(disc_loss) - optimizer_d.step() - self.untoggle_optimizer(optimizer_d) - - keys = ['gen', 'gen_loss', 'gdisc', 'discr', 'discf', 'disc'] - mean_loss_i = [gen_loss.item(), gen_loss.item(), gen_loss_disc.item(), - loss_real.item(), loss_fake.item(), disc_loss.item()] - - loss = {key: val for key, val in zip(keys, mean_loss_i)} - - return loss - - def configure_optimizers(self): - gen_lr = self.hparams.gen_lr - dsc_lr = self.hparams.dsc_lr - - opt_g = torch.optim.Adam(self.generator.parameters(), lr=gen_lr, betas=(self.hparams.adam_b1, self.hparams.adam_b2)) - opt_d = torch.optim.Adam(self.discriminator.parameters(), lr=dsc_lr, betas=(self.hparams.adam_b1, self.hparams.adam_b2)) - - gen_lr_scheduler = ExponentialLR(opt_g, gamma=self.hparams.lr_decay) - dsc_lr_scheduler = ExponentialLR(opt_d, gamma=self.hparams.lr_decay) - - gen_lr_scheduler_config = {"scheduler": gen_lr_scheduler, - "interval": "epoch", - "frequency": self.hparams.decay_freq} - - dsc_lr_scheduler_config = {"scheduler": dsc_lr_scheduler, - "interval": "epoch", - "frequency": self.hparams.decay_freq} - - return [{"optimizer": opt_g, "lr_scheduler": gen_lr_scheduler_config}, - {"optimizer": opt_d, "lr_scheduler": dsc_lr_scheduler_config}] - - def load_last_checkpoint(self, checkpoint_path): - checkpoints = sorted(glob.glob(os.path.join(checkpoint_path, "patchgan_*.pth"))) - - epochs = set([ - int(ch.split('/')[-1].replace('patchgan_', '')[:-4]) for ch in checkpoints - ]) - - try: - assert len(epochs) > 0, "No checkpoints found!" - start = max(epochs) - checkpoint_file = os.path.join(checkpoint_path, f"patchgan_{start:03d}.pth") - print(f"Loading from {checkpoint_file}") - return self.load_from_checkpoint(checkpoint_file) - except Exception as e: - print(e) - print("Checkpoints not loaded") From 1caecc4bbd36618ac265510c9dd3e65928610000 Mon Sep 17 00:00:00 2001 From: Ramanakumar Sankar Date: Wed, 31 Jan 2024 22:51:14 -0800 Subject: [PATCH 10/17] refactored code so that patchgan architecture is mutable to conditioning --- patchgan/__init__.py | 4 +- patchgan/{unet.py => conv_layers.py} | 109 +++++++++++---------------- patchgan/disc.py | 2 +- patchgan/patchgan.py | 101 ++++++++++++++++++++++--- patchgan/point_encoder.py | 12 +++ 5 files changed, 150 insertions(+), 78 deletions(-) rename patchgan/{unet.py => conv_layers.py} (74%) mode change 100755 => 100644 create mode 100644 patchgan/point_encoder.py diff --git a/patchgan/__init__.py b/patchgan/__init__.py index f4e08d6..93e8d1a 100644 --- a/patchgan/__init__.py +++ b/patchgan/__init__.py @@ -1,8 +1,6 @@ -from .unet import UNet -from .disc import Discriminator from .patchgan import PatchGAN from .version import __version__ __all__ = [ - 'UNet', 'Discriminator', 'PatchGAN', '__version__' + 'PatchGAN', '__version__' ] diff --git a/patchgan/unet.py b/patchgan/conv_layers.py old mode 100755 new mode 100644 similarity index 74% rename from patchgan/unet.py rename to patchgan/conv_layers.py index 7df92c9..9acea18 --- a/patchgan/unet.py +++ b/patchgan/conv_layers.py @@ -2,7 +2,6 @@ from torch import nn from collections import OrderedDict from itertools import chain -from .transfer import Transferable class DownSampleBlock(nn.Module): @@ -70,86 +69,68 @@ def forward(self, x): return x -class UNet(nn.Module, Transferable): - def __init__(self, input_nc, output_nc, nf=64, - norm_layer=nn.InstanceNorm2d, use_dropout=False, - activation='tanh', final_act='softmax'): - super(UNet, self).__init__() +# custom weights initialization called on generator and discriminator +# scaling here means std +def weights_init(net, init_type='normal', scaling=0.02): + """Initialize network weights. + Parameters: + net (network) -- network to be initialized + init_type (str) -- the name of an initialization method: normal | xavier | kaiming | orthogonal + init_gain (float) -- scaling factor for normal, xavier and orthogonal. + We use 'normal' in the original pix2pix and CycleGAN paper. But xavier and kaiming might + work better for some applications. Feel free to try yourself. + """ + def init_func(m): # define the initialization function + classname = m.__class__.__name__ + if hasattr(m, 'weight') and (classname.find('Conv')) != -1: + torch.nn.init.xavier_uniform_(m.weight.data) + # BatchNorm Layer's weight is not a matrix; only normal distribution applies. + elif classname.find('InstanceNorm') != -1: + torch.nn.init.xavier_uniform_(m.weight.data, 1.0) + torch.nn.init.constant_(m.bias.data, 0.0) + +class Encoder(nn.ModuleList): + def __init__(self, input_channels: int, gen_filts: int, gen_activation: str, use_gen_dropout: bool): kernel_size = 4 padding = 1 - conv_filts = [nf, nf * 2, nf * 4, nf * 8, nf * 8, nf * 8, nf * 8] + conv_filts = [gen_filts, gen_filts * 2, gen_filts * 4, gen_filts * 8, gen_filts * 8, gen_filts * 8, gen_filts * 8] encoder_layers = [] - prev_filt = input_nc + prev_filt = input_channels for i, filt in enumerate(conv_filts): - encoder_layers.append(DownSampleBlock(prev_filt, filt, activation, norm_layer, layer=i, - use_dropout=use_dropout, kernel_size=kernel_size, stride=2, - padding=padding, bias=False)) - prev_filt = filt - - decoder_layers = [] - for i, filt in enumerate(conv_filts[:-1][::-1]): - if i == 0: - decoder_layers.append(UpSampleBlock(prev_filt, filt, activation, norm_layer, layer=i, batch_norm=False, - kernel_size=kernel_size, stride=2, padding=padding, bias=False)) - else: - decoder_layers.append(UpSampleBlock(prev_filt * 2, filt, activation, norm_layer, layer=i, use_dropout=use_dropout, - batch_norm=True, kernel_size=kernel_size, stride=2, padding=padding, bias=False)) - + encoder_layers.append(DownSampleBlock(prev_filt, filt, gen_activation, nn.InstanceNorm2d, layer=i, + use_dropout=use_gen_dropout, kernel_size=kernel_size, stride=2, + padding=padding)) prev_filt = filt - decoder_layers.append(UpSampleBlock(nf * 2, output_nc, final_act, norm_layer, layer=i + 1, batch_norm=False, - kernel_size=kernel_size, stride=2, padding=padding, bias=False)) - - self.encoder = nn.ModuleList(encoder_layers) - self.decoder = nn.ModuleList(decoder_layers) - + super().__init__(encoder_layers) self.apply(weights_init) - def forward(self, x, return_hidden=False): - xencs = [] - for i, layer in enumerate(self.encoder): - x = layer(x) - xencs.append(x) - - hidden = xencs[-1] +class Decoder(nn.ModuleList): + def __init__(self, output_channels: int, gen_filts: int, gen_activation: str, final_activation: str, use_gen_dropout: bool): + kernel_size = 4 + padding = 1 - xencs = xencs[::-1] + conv_filts = [gen_filts, gen_filts * 2, gen_filts * 4, gen_filts * 8, gen_filts * 8, gen_filts * 8, gen_filts * 8] - for i, layer in enumerate(self.decoder): + prev_filt = conv_filts[-1] + decoder_layers = [] + for i, filt in enumerate(conv_filts[:-1][::-1]): if i == 0: - xinp = hidden + decoder_layers.append(UpSampleBlock(prev_filt, filt, gen_activation, nn.InstanceNorm2d, layer=i, batch_norm=False, + kernel_size=kernel_size, stride=2, padding=padding)) else: - xinp = torch.cat([x, xencs[i]], dim=1) - - x = layer(xinp) + decoder_layers.append(UpSampleBlock(prev_filt * 2, filt, gen_activation, nn.InstanceNorm2d, layer=i, use_dropout=use_gen_dropout, + batch_norm=True, kernel_size=kernel_size, stride=2, padding=padding)) - if return_hidden: - return x, hidden - else: - return x + prev_filt = filt + decoder_layers.append(UpSampleBlock(gen_filts * 2, output_channels, final_activation, nn.InstanceNorm2d, layer=i + 1, batch_norm=False, + kernel_size=kernel_size, stride=2, padding=padding)) -# custom weights initialization called on generator and discriminator -# scaling here means std -def weights_init(net, init_type='normal', scaling=0.02): - """Initialize network weights. - Parameters: - net (network) -- network to be initialized - init_type (str) -- the name of an initialization method: normal | xavier | kaiming | orthogonal - init_gain (float) -- scaling factor for normal, xavier and orthogonal. - We use 'normal' in the original pix2pix and CycleGAN paper. But xavier and kaiming might - work better for some applications. Feel free to try yourself. - """ - def init_func(m): # define the initialization function - classname = m.__class__.__name__ - if hasattr(m, 'weight') and (classname.find('Conv')) != -1: - torch.nn.init.xavier_uniform_(m.weight.data) - # BatchNorm Layer's weight is not a matrix; only normal distribution applies. - elif classname.find('InstanceNorm') != -1: - torch.nn.init.xavier_uniform_(m.weight.data, 1.0) - torch.nn.init.constant_(m.bias.data, 0.0) + super().__init__(decoder_layers) + self.apply(weights_init) diff --git a/patchgan/disc.py b/patchgan/disc.py index c00b1b9..d8d15ef 100644 --- a/patchgan/disc.py +++ b/patchgan/disc.py @@ -1,6 +1,6 @@ from torch import nn from .transfer import Transferable -from .unet import weights_init +from .conv_layers import weights_init class Discriminator(nn.Module, Transferable): diff --git a/patchgan/patchgan.py b/patchgan/patchgan.py index 5e0c357..06944b7 100644 --- a/patchgan/patchgan.py +++ b/patchgan/patchgan.py @@ -1,9 +1,12 @@ import torch +from typing import Iterable +from torch import nn from torch.optim.lr_scheduler import ExponentialLR from .losses import fc_tversky, bce_loss, MAE_loss from torch.nn.functional import binary_cross_entropy -from .unet import UNet from .disc import Discriminator +from .conv_layers import Encoder, Decoder +from .point_encoder import PointEncoder from typing import Union, Optional import lightning as L @@ -21,13 +24,44 @@ def __init__(self, input_channels: int, output_channels: int, gen_filts: int, di self.save_hyperparameters() self.automatic_optimization = False - self.generator = UNet(input_channels, output_channels, gen_filts, use_dropout=use_gen_dropout, - activation=gen_activation, final_act=final_activation) + self.encoder = Encoder(input_channels, gen_filts, gen_activation, use_gen_dropout) + self.decoder = Decoder(output_channels, gen_filts, gen_activation, final_activation, use_gen_dropout) + self.discriminator = Discriminator(input_channels + output_channels, disc_filts, norm=disc_norm, n_layers=n_disc_layers) - def forward(self, img, return_hidden=False): - return self.generator(img, return_hidden) + @classmethod + def load_transfer_data(cls, checkpoint_path: str, input_channels: int, output_channels: int): + checkpoint = torch.load(checkpoint_path) + model_kwargs = checkpoint['hyperparameters'] + model_kwargs['input_channels'] = input_channels + model_kwargs['output_channels'] = output_channels + obj = cls(**model_kwargs) + + raise ValueError("weights loading not implemented!") + + return obj + + def forward(self, x): + xencs = [] + + for i, layer in enumerate(self.encoder): + x = layer(x) + xencs.append(x) + + hidden = xencs[-1] + + xencs = xencs[::-1] + + for i, layer in enumerate(self.decoder): + if i == 0: + xinp = hidden + else: + xinp = torch.cat([x, xencs[i]], dim=1) + + x = layer(xinp) + + return x def training_step(self, batch): ''' @@ -51,13 +85,15 @@ def validation_step(self, batch): for key, val in mean_loss.items(): self.log(key, val, prog_bar=True, on_epoch=True, reduce_fx=torch.mean) + def forward_batch(self, batch): + input_tensor, target_tensor = batch + return self(input_tensor), input_tensor, target_tensor + def batch_step(self, batch: Union[torch.Tensor, tuple[torch.Tensor]], train: bool, optimizer_g: Optional[torch.optim.Optimizer] = None, optimizer_d: Optional[torch.optim.Optimizer] = None): - input_tensor, target_tensor = batch - # train the generator - gen_img = self.generator(input_tensor) + gen_img, input_tensor, target_tensor = self.forward_batch(batch) disc_inp_fake = torch.cat((input_tensor, gen_img), 1) disc_fake = self.discriminator(disc_inp_fake) @@ -115,12 +151,17 @@ def batch_step(self, batch: Union[torch.Tensor, tuple[torch.Tensor]], train: boo return loss + def get_parameters(self) -> tuple[Iterable[nn.Parameter]]: + return list(self.encoder.parameters()) + list(self.decoder.parameters()), self.discriminator.parameters() + def configure_optimizers(self): gen_lr = self.hparams.gen_lr dsc_lr = self.hparams.dsc_lr - opt_g = torch.optim.Adam(self.generator.parameters(), lr=gen_lr, betas=(self.hparams.adam_b1, self.hparams.adam_b2)) - opt_d = torch.optim.Adam(self.discriminator.parameters(), lr=dsc_lr, betas=(self.hparams.adam_b1, self.hparams.adam_b2)) + generator_params, discriminator_params = self.get_parameters() + + opt_g = torch.optim.Adam(generator_params, lr=gen_lr, betas=(self.hparams.adam_b1, self.hparams.adam_b2)) + opt_d = torch.optim.Adam(discriminator_params, lr=dsc_lr, betas=(self.hparams.adam_b1, self.hparams.adam_b2)) gen_lr_scheduler = ExponentialLR(opt_g, gamma=self.hparams.lr_decay) dsc_lr_scheduler = ExponentialLR(opt_d, gamma=self.hparams.lr_decay) @@ -135,3 +176,43 @@ def configure_optimizers(self): return [{"optimizer": opt_g, "lr_scheduler": gen_lr_scheduler_config}, {"optimizer": opt_d, "lr_scheduler": dsc_lr_scheduler_config}] + + +class PatchGANPoint(PatchGAN): + def __init__(self, *patchgan_args, **patchgan_kwargs): + super().__init__(*patchgan_args, **patchgan_kwargs) + + # create the point encoder to attach to the latent block + # by default use the final number of filters in the UNet (hard-coded to gen_filts * 8) + self.point_encoder = PointEncoder(self.hparams.gen_filts * 8) + + def forward(self, x, point): + xencs = [] + + for i, layer in enumerate(self.encoder): + x = layer(x) + xencs.append(x) + + hidden = xencs[-1] + + _, c, h, w = hidden.shape + + z_point = self.point_encoder(point).unsqueeze(2).unsqueeze(3).repeat(1, 1, h, w) + + hidden = hidden + z_point + + xencs = xencs[::-1] + + for i, layer in enumerate(self.decoder): + if i == 0: + xinp = hidden + else: + xinp = torch.cat([x, xencs[i]], dim=1) + + x = layer(xinp) + + return x + + def forward_batch(self, batch): + input_tensor, point, target_tensor = batch + return self(input_tensor, point), input_tensor, target_tensor diff --git a/patchgan/point_encoder.py b/patchgan/point_encoder.py new file mode 100644 index 0000000..5e0da40 --- /dev/null +++ b/patchgan/point_encoder.py @@ -0,0 +1,12 @@ +from torch import nn + + +class PointEncoder(nn.Sequential): + def __init__(self, filt): + super().__init__( + nn.Linear(2, 16), + nn.LeakyReLU(0.2, True), + nn.LayerNorm(16), + nn.Linear(16, filt), + nn.LeakyReLU(0.2, True), + ) From 34ebeb11f41fb308032a55db9961350e284d6eb9 Mon Sep 17 00:00:00 2001 From: Ramanakumar Sankar Date: Wed, 31 Jan 2024 22:51:28 -0800 Subject: [PATCH 11/17] removed default true for inference summary --- patchgan/infer.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/patchgan/infer.py b/patchgan/infer.py index aecb8fe..13d1540 100644 --- a/patchgan/infer.py +++ b/patchgan/infer.py @@ -86,7 +86,7 @@ def patchgan_infer(): parser.add_argument('-c', '--config_file', required=True, type=str, help='Location of the config YAML file') parser.add_argument('--dataloader_workers', default=4, type=int, help='Number of workers to use with dataloader (set to 0 to disable multithreading)') parser.add_argument('-d', '--device', default='auto', help='Device to use to train the model (CUDA=GPU)') - parser.add_argument('--summary', default=True, action='store_true', help="Print summary of the models") + parser.add_argument('--summary', action='store_true', help="Print summary of the models") args = parser.parse_args() From fb5d42ccfe01d7d2f5739f91eec54c6bd3298c33 Mon Sep 17 00:00:00 2001 From: Ramanakumar Sankar Date: Wed, 31 Jan 2024 22:51:46 -0800 Subject: [PATCH 12/17] added point conditioning training to patchgan --- patchgan/io.py | 37 ++++++++++++++++++++++++++++++++++--- patchgan/train.py | 43 ++++++++++++++++++++++++++++++------------- 2 files changed, 64 insertions(+), 16 deletions(-) diff --git a/patchgan/io.py b/patchgan/io.py index c49fa4f..932ea10 100644 --- a/patchgan/io.py +++ b/patchgan/io.py @@ -10,7 +10,7 @@ class COCOStuffDataset(Dataset): augmentation = None - def __init__(self, imgfolder, maskfolder, labels=[1], size=256, augmentation='resize'): + def __init__(self, imgfolder, maskfolder, labels=[1], size=256, augmentation='randomcrop'): self.images = np.asarray(sorted(glob.glob(os.path.join(imgfolder, "*.jpg")))) self.masks = np.asarray(sorted(glob.glob(os.path.join(maskfolder, "*.png")))) self.size = size @@ -22,10 +22,10 @@ def __init__(self, imgfolder, maskfolder, labels=[1], size=256, augmentation='re assert np.all(self.image_ids == self.mask_ids), "Image IDs and Mask IDs do not match!" if augmentation == 'randomcrop': - self.augmentation = transforms.Resize(size=(size, size), antialias=None) + self.augmentation = transforms.RandomCrop(size=(size, size), pad_if_needed=True) elif augmentation == 'randomcrop+flip': self.augmentation = transforms.Compose([ - transforms.Resize(size=(size, size), antialias=None), + transforms.RandomCrop(size=(size, size), pad_if_needed=True), transforms.RandomHorizontalFlip(0.25), transforms.RandomVerticalFlip(0.25), ]) @@ -56,3 +56,34 @@ def __getitem__(self, index): mask[i, labels == label] = 1 return img, mask + + +class COCOStuffPointDataset(COCOStuffDataset): + augmentation = None + + def __getitem__(self, index): + image_file = self.images[index] + mask_file = self.masks[index] + + img = read_image(image_file, ImageReadMode.RGB) / 255. + labels = read_image(mask_file, ImageReadMode.GRAY) + 1 + + # add the mask so we can crop it + data_stacked = torch.cat((img, labels), dim=0) + + point = torch.rand(2) + + if self.augmentation is not None: + data_stacked = self.augmentation(data_stacked) + + point[0] = torch.floor(point[0] * data_stacked.shape[1]) + point[1] = torch.floor(point[1] * data_stacked.shape[2]) + + img = data_stacked[:3, :] + labels = data_stacked[3, :] + + mask = torch.zeros((1, labels.shape[0], labels.shape[1])) + label = labels[int(point[1]), int(point[0])] + mask[0, labels == label] = 1 + + return img, point, mask diff --git a/patchgan/train.py b/patchgan/train.py index 92482ea..4c5bf27 100644 --- a/patchgan/train.py +++ b/patchgan/train.py @@ -1,7 +1,7 @@ import torch from torchinfo import summary -from .io import COCOStuffDataset -from .patchgan import PatchGAN +from .io import COCOStuffDataset, COCOStuffPointDataset +from .patchgan import PatchGAN, PatchGANPoint import os from torch.utils.data import DataLoader, random_split from lightning.pytorch import Trainer @@ -22,7 +22,7 @@ def patchgan_train(): parser.add_argument('--dataloader_workers', default=4, type=int, help='Number of workers to use with dataloader (set to 0 to disable multithreading)') parser.add_argument('-n', '--n_epochs', required=True, type=int, help='Number of epochs to train the model') parser.add_argument('-d', '--device', default='auto', help='Device to use to train the model (CUDA=GPU)') - parser.add_argument('--summary', default=True, action='store_true', help="Print summary of the models") + parser.add_argument('--summary', action='store_true', help="Print summary of the models") args = parser.parse_args() @@ -45,16 +45,33 @@ def patchgan_train(): else: raise AttributeError("Please provide either the training and validation data paths or a train/val split!") + model_type = config.get('model_type', 'patchgan') + + if model_type == 'patchgan': + patchgan_model = PatchGAN + elif model_type == 'patchgan_point': + patchgan_model = PatchGANPoint + else: + raise ValueError(f"{model_type} not supported!") + size = dataset_params.get('size', 256) augmentation = dataset_params.get('augmentation', 'randomcrop') dataset_kwargs = {} if dataset_params['type'] == 'COCOStuff': + assert model_type == 'patchgan', "model_type should be set to 'patchgan' to use the COCOStuff dataset. Did you mean COCOStuffPoint?" Dataset = COCOStuffDataset in_channels = 3 labels = dataset_params.get('labels', [1]) out_channels = len(labels) dataset_kwargs['labels'] = labels + elif dataset_params['type'] == 'COCOStuffPoint': + assert model_type == 'patchgan_point', "model_type should be set to 'patchgan_point' to use the COCOStuffPoint dataset. Did you mean COCOStuff?" + Dataset = COCOStuffPointDataset + in_channels = 3 + labels = dataset_params.get('labels', [1]) + out_channels = 1 + dataset_kwargs['labels'] = labels else: try: spec = importlib.machinery.SourceFileLoader('io', 'io.py') @@ -87,7 +104,9 @@ def patchgan_train(): model = None checkpoint_file = config.get('load_from_checkpoint', '') if os.path.isfile(checkpoint_file): - model = PatchGAN.load_from_checkpoint(checkpoint_file) + model = patchgan_model.load_from_checkpoint(checkpoint_file) + elif config.get('transfer_learn', {}).get('checkpoint', None) is not None: + model = patchgan_model.load_transfer_data(config['transfer_learn']['checkpoint'], in_channels, out_channels) if model is None: model_params = config['model_params'] @@ -112,17 +131,15 @@ def patchgan_train(): lr_decay = train_params.get('decay_rate', 0.98) decay_freq = train_params.get('decay_freq', 5) save_freq = train_params.get('save_freq', 10) - model = PatchGAN(in_channels, out_channels, gen_filts, disc_filts, final_activation, n_disc_layers, use_dropout, - activation, disc_norm, gen_learning_rate, dsc_learning_rate, lr_decay, decay_freq, - loss_type=loss_type, seg_alpha=seg_alpha) - - if config.get('transfer_learn', {}).get('checkpoint', None) is not None: - checkpoint = torch.load(config['transfer_learn']['checkpoint'], map_location=device) - model.generator.load_transfer_data({key.replace('PatchGAN.', ''): value for key, value in checkpoint['state_dict'].items() if 'generator' in key}) - model.discriminator.load_transfer_data({key.replace('PatchGAN.', ''): value for key, value in checkpoint['state_dict'].items() if 'discriminator' in key}) + model = patchgan_model(in_channels, out_channels, gen_filts, disc_filts, final_activation, n_disc_layers, use_dropout, + activation, disc_norm, gen_learning_rate, dsc_learning_rate, lr_decay, decay_freq, + loss_type=loss_type, seg_alpha=seg_alpha) if args.summary: - summary(model.generator, [1, in_channels, size, size], depth=4) + if model_type == 'patchgan': + summary(model, [1, in_channels, size, size], depth=4) + elif model_type == 'patchgan_point': + summary(model, [[1, in_channels, size, size], [1, 2]], depth=4) summary(model.discriminator, [1, in_channels + out_channels, size, size]) checkpoint_callback = ModelCheckpoint(dirpath=checkpoint_path, From 40056ac903a28b923d89bc2171dec0fcd832c757 Mon Sep 17 00:00:00 2001 From: Ramanakumar Sankar Date: Thu, 1 Feb 2024 11:08:55 -0800 Subject: [PATCH 13/17] ignoring pytorch lightning --- .gitignore | 2 ++ 1 file changed, 2 insertions(+) diff --git a/.gitignore b/.gitignore index 5b65a82..634c882 100644 --- a/.gitignore +++ b/.gitignore @@ -132,3 +132,5 @@ dmypy.json *.pth *.npz *.png +*.ckpt +lightning_logs/ From 853b48b247482a2f56ee77fe05a92e3cceec764a Mon Sep 17 00:00:00 2001 From: Ramanakumar Sankar Date: Thu, 1 Feb 2024 11:09:09 -0800 Subject: [PATCH 14/17] converted disc to sequential model --- patchgan/disc.py | 9 ++------- 1 file changed, 2 insertions(+), 7 deletions(-) diff --git a/patchgan/disc.py b/patchgan/disc.py index d8d15ef..fdfaec6 100644 --- a/patchgan/disc.py +++ b/patchgan/disc.py @@ -3,7 +3,7 @@ from .conv_layers import weights_init -class Discriminator(nn.Module, Transferable): +class Discriminator(nn.Sequential, Transferable): """Defines a PatchGAN discriminator""" def __init__(self, input_nc, ndf=64, n_layers=3, norm=False, norm_layer=nn.InstanceNorm2d): @@ -14,7 +14,6 @@ def __init__(self, input_nc, ndf=64, n_layers=3, norm=False, norm_layer=nn.Insta n_layers (int) -- the number of conv layers in the discriminator norm_layer -- normalization layer """ - super(Discriminator, self).__init__() kw = 4 padw = 1 sequence = [nn.Conv2d(input_nc, ndf, kernel_size=kw, @@ -45,10 +44,6 @@ def __init__(self, input_nc, ndf=64, n_layers=3, norm=False, norm_layer=nn.Insta # output 1 channel prediction map sequence += [nn.Conv2d(ndf * nf_mult, 1, kernel_size=kw, stride=1, padding=padw), nn.Sigmoid()] - self.model = nn.Sequential(*sequence) + super().__init__(*sequence) self.apply(weights_init) - - def forward(self, input): - """Standard forward.""" - return self.model(input) From 8831e37a5e3e2247f1e9f5c8c17b3114e5a4d186 Mon Sep 17 00:00:00 2001 From: Ramanakumar Sankar Date: Thu, 1 Feb 2024 11:09:23 -0800 Subject: [PATCH 15/17] switching point encoder back to MLP --- patchgan/patchgan.py | 30 +++++++++++++++++------------- 1 file changed, 17 insertions(+), 13 deletions(-) diff --git a/patchgan/patchgan.py b/patchgan/patchgan.py index 06944b7..92a0a3a 100644 --- a/patchgan/patchgan.py +++ b/patchgan/patchgan.py @@ -3,10 +3,10 @@ from torch import nn from torch.optim.lr_scheduler import ExponentialLR from .losses import fc_tversky, bce_loss, MAE_loss -from torch.nn.functional import binary_cross_entropy +from torch.nn.functional import binary_cross_entropy, one_hot from .disc import Discriminator -from .conv_layers import Encoder, Decoder from .point_encoder import PointEncoder +from .conv_layers import Encoder, Decoder from typing import Union, Optional import lightning as L @@ -18,7 +18,7 @@ class PatchGAN(L.LightningModule): def __init__(self, input_channels: int, output_channels: int, gen_filts: int, disc_filts: int, final_activation: str, n_disc_layers: int = 5, use_gen_dropout: bool = True, gen_activation: str = 'leakyrelu', disc_norm: bool = False, gen_lr: float = 1.e-3, dsc_lr: float = 1.e-3, lr_decay: float = 0.98, - decay_freq: int = 5, adam_b1: float = 0.5, adam_b2: float = 0.999, seg_alpha: float = 200, + decay_freq: int = 5, adam_b1: float = 0.9, adam_b2: float = 0.999, seg_alpha: float = 200, loss_type: str = 'tversky', tversky_beta: float = 0.75, tversky_gamma: float = 0.75): super().__init__() self.save_hyperparameters() @@ -101,18 +101,20 @@ def batch_step(self, batch: Union[torch.Tensor, tuple[torch.Tensor]], train: boo labels_real = torch.full(disc_fake.shape, 1, dtype=torch.float, device=device) labels_fake = torch.full(disc_fake.shape, 0, dtype=torch.float, device=device) + target_tensor_full = one_hot(target_tensor, self.hparams.output_channels).permute(0, 3, 1, 2).to(torch.float) + if self.hparams.loss_type == 'tversky': - gen_loss = fc_tversky(target_tensor, gen_img, + gen_loss = fc_tversky(target_tensor_full, gen_img, beta=self.hparams.tversky_beta, gamma=self.hparams.tversky_gamma) * self.hparams.seg_alpha elif self.hparams.loss_type == 'weighted_bce': if gen_img.shape[1] > 1: - weight = 1 - torch.sum(target_tensor, dim=(2, 3), keepdim=True) / torch.sum(target_tensor) + weight = 1 - torch.sum(target_tensor_full, dim=(2, 3), keepdim=True) / (torch.sum(target_tensor_full) + 1.e-6) else: - weight = torch.ones_like(target_tensor) - gen_loss = binary_cross_entropy(gen_img, target_tensor, weight=weight) * self.hparams.seg_alpha + weight = torch.ones_like(target_tensor_full) + gen_loss = binary_cross_entropy(gen_img, target_tensor_full, weight=weight) * self.hparams.seg_alpha elif self.hparams.loss_type == 'MAE': - gen_loss = MAE_loss(gen_img, target_tensor) * self.hparams.seg_alpha + gen_loss = MAE_loss(gen_img, target_tensor_full) * self.hparams.seg_alpha gen_loss_disc = bce_loss(disc_fake, labels_real) gen_loss = gen_loss + gen_loss_disc @@ -128,7 +130,7 @@ def batch_step(self, batch: Union[torch.Tensor, tuple[torch.Tensor]], train: boo if train: self.toggle_optimizer(optimizer_d) - disc_inp_real = torch.cat((input_tensor, target_tensor), 1) + disc_inp_real = torch.cat((input_tensor, target_tensor_full), 1) disc_real = self.discriminator(disc_inp_real) disc_inp_fake = torch.cat((input_tensor, gen_img.detach()), 1) disc_fake = self.discriminator(disc_inp_fake) @@ -186,6 +188,9 @@ def __init__(self, *patchgan_args, **patchgan_kwargs): # by default use the final number of filters in the UNet (hard-coded to gen_filts * 8) self.point_encoder = PointEncoder(self.hparams.gen_filts * 8) + def get_parameters(self): + return list(self.encoder.parameters()) + list(self.decoder.parameters()) + list(self.point_encoder.parameters()), self.discriminator.parameters() + def forward(self, x, point): xencs = [] @@ -195,11 +200,10 @@ def forward(self, x, point): hidden = xencs[-1] - _, c, h, w = hidden.shape - - z_point = self.point_encoder(point).unsqueeze(2).unsqueeze(3).repeat(1, 1, h, w) + _, _, h, w = hidden.shape + point_mask = self.point_encoder(point).unsqueeze(2).unsqueeze(3).repeat(1, 1, h, w) - hidden = hidden + z_point + hidden = hidden * point_mask xencs = xencs[::-1] From 97c2f7f4544ac44b06922faaca6eb6309f6dc8a8 Mon Sep 17 00:00:00 2001 From: Ramanakumar Sankar Date: Thu, 1 Feb 2024 11:09:43 -0800 Subject: [PATCH 16/17] fixing setuptools to find data products --- pyproject.toml | 9 +++++++++ 1 file changed, 9 insertions(+) diff --git a/pyproject.toml b/pyproject.toml index f2501c6..4d148b7 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -2,6 +2,15 @@ requires = ["setuptools >= 61.0", "wheel"] build-backend = "setuptools.build_meta" +[tool.setuptools.packages.find] +where = ["."] +include = ["patchgan"] +exclude = [] +namespaces = false + +[tool.setuptools.package-data] +patchgan = ["*.yaml"] + [project] name = "patchgan" dynamic = ["version"] From 8b85f5b576a083747dacf8168b1bc5eb003ee1bc Mon Sep 17 00:00:00 2001 From: Ramanakumar Sankar Date: Thu, 1 Feb 2024 11:10:15 -0800 Subject: [PATCH 17/17] fixing COCO dataset. converting mask input to be one-hot decoded --- patchgan/io.py | 17 ++-- patchgan/labels.yaml | 182 +++++++++++++++++++++++++++++++++++++++++++ patchgan/train.py | 16 +++- 3 files changed, 205 insertions(+), 10 deletions(-) create mode 100644 patchgan/labels.yaml diff --git a/patchgan/io.py b/patchgan/io.py index 932ea10..5b79eb3 100644 --- a/patchgan/io.py +++ b/patchgan/io.py @@ -10,7 +10,7 @@ class COCOStuffDataset(Dataset): augmentation = None - def __init__(self, imgfolder, maskfolder, labels=[1], size=256, augmentation='randomcrop'): + def __init__(self, imgfolder, maskfolder, labels='all', size=256, augmentation='randomcrop'): self.images = np.asarray(sorted(glob.glob(os.path.join(imgfolder, "*.jpg")))) self.masks = np.asarray(sorted(glob.glob(os.path.join(maskfolder, "*.png")))) self.size = size @@ -51,11 +51,11 @@ def __getitem__(self, index): img = data_stacked[:3, :] labels = data_stacked[3, :] - mask = torch.zeros((len(self.labels), labels.shape[0], labels.shape[1])) - for i, label in enumerate(self.labels): - mask[i, labels == label] = 1 + mask = torch.ones(labels.shape, dtype=torch.long) + for label in self.labels: + mask[labels == label] = self.labels.index(label) - return img, mask + return img, mask.to(torch.long) class COCOStuffPointDataset(COCOStuffDataset): @@ -82,8 +82,9 @@ def __getitem__(self, index): img = data_stacked[:3, :] labels = data_stacked[3, :] - mask = torch.zeros((1, labels.shape[0], labels.shape[1])) + mask = torch.zeros((labels.shape[0], labels.shape[1])) label = labels[int(point[1]), int(point[0])] - mask[0, labels == label] = 1 + if label in self.labels: + mask[labels == label] = self.labels.index(label) - return img, point, mask + return img, point, mask.to(torch.long) diff --git a/patchgan/labels.yaml b/patchgan/labels.yaml new file mode 100644 index 0000000..6a905f5 --- /dev/null +++ b/patchgan/labels.yaml @@ -0,0 +1,182 @@ +1: person +2: bicycle +3: car +4: motorcycle +5: airplane +6: bus +7: train +8: truck +9: boat +10: traffic light +11: fire hydrant +12: street sign +13: stop sign +14: parking meter +15: bench +16: bird +17: cat +18: dog +19: horse +20: sheep +21: cow +22: elephant +23: bear +24: zebra +25: giraffe +26: hat +27: backpack +28: umbrella +29: shoe +30: eye glasses +31: handbag +32: tie +33: suitcase +34: frisbee +35: skis +36: snowboard +37: sports ball +38: kite +39: baseball bat +40: baseball glove +41: skateboard +42: surfboard +43: tennis racket +44: bottle +45: plate +46: wine glass +47: cup +48: fork +49: knife +50: spoon +51: bowl +52: banana +53: apple +54: sandwich +55: orange +56: broccoli +57: carrot +58: hot dog +59: pizza +60: donut +61: cake +62: chair +63: couch +64: potted plant +65: bed +66: mirror +67: dining table +68: window +69: desk +70: toilet +71: door +72: tv +73: laptop +74: mouse +75: remote +76: keyboard +77: cell phone +78: microwave +79: oven +80: toaster +81: sink +82: refrigerator +83: blender +84: book +85: clock +86: vase +87: scissors +88: teddy bear +89: hair drier +90: toothbrush +91: hair brush +92: banner +93: blanket +94: branch +95: bridge +96: building-other +97: bush +98: cabinet +99: cage +100: cardboard +101: carpet +102: ceiling-other +103: ceiling-tile +104: cloth +105: clothes +106: clouds +107: counter +108: cupboard +109: curtain +110: desk-stuff +111: dirt +112: door-stuff +113: fence +114: floor-marble +115: floor-other +116: floor-stone +117: floor-tile +118: floor-wood +119: flower +120: fog +121: food-other +122: fruit +123: furniture-other +124: grass +125: gravel +126: ground-other +127: hill +128: house +129: leaves +130: light +131: mat +132: metal +133: mirror-stuff +134: moss +135: mountain +136: mud +137: napkin +138: net +139: paper +140: pavement +141: pillow +142: plant-other +143: plastic +144: platform +145: playingfield +146: railing +147: railroad +148: river +149: road +150: rock +151: roof +152: rug +153: salad +154: sand +155: sea +156: shelf +157: sky-other +158: skyscraper +159: snow +160: solid-other +161: stairs +162: stone +163: straw +164: structural-other +165: table +166: tent +167: textile-other +168: towel +169: tree +170: vegetable +171: wall-brick +172: wall-concrete +173: wall-other +174: wall-panel +175: wall-stone +176: wall-tile +177: wall-wood +178: water-other +179: waterdrops +180: window-blind +181: window-other +182: wood \ No newline at end of file diff --git a/patchgan/train.py b/patchgan/train.py index 4c5bf27..a299b49 100644 --- a/patchgan/train.py +++ b/patchgan/train.py @@ -11,6 +11,10 @@ import argparse +with open(os.path.join(os.path.split(__file__)[0], 'labels.yaml'), 'r') as infile: + coco_labels = yaml.safe_load(infile) + + def patchgan_train(): parser = argparse.ArgumentParser( prog='PatchGAN', @@ -63,14 +67,22 @@ def patchgan_train(): Dataset = COCOStuffDataset in_channels = 3 labels = dataset_params.get('labels', [1]) - out_channels = len(labels) + if isinstance(labels, list): + labels = sorted(labels) + elif labels == 'all': + labels = sorted(coco_labels.keys()) + out_channels = len(labels) + 1 # include a background channel dataset_kwargs['labels'] = labels elif dataset_params['type'] == 'COCOStuffPoint': assert model_type == 'patchgan_point', "model_type should be set to 'patchgan_point' to use the COCOStuffPoint dataset. Did you mean COCOStuff?" Dataset = COCOStuffPointDataset in_channels = 3 labels = dataset_params.get('labels', [1]) - out_channels = 1 + if isinstance(labels, list): + labels = sorted(labels) + elif labels == 'all': + labels = sorted(coco_labels.keys()) + out_channels = len(labels) + 1 # include a background channel dataset_kwargs['labels'] = labels else: try: