added point conditioning training to patchgan
ramanakumars committed Feb 1, 2024
1 parent 34ebeb1 commit fb5d42c
Showing 2 changed files with 64 additions and 16 deletions.
37 changes: 34 additions & 3 deletions patchgan/io.py
@@ -10,7 +10,7 @@
 class COCOStuffDataset(Dataset):
     augmentation = None
 
-    def __init__(self, imgfolder, maskfolder, labels=[1], size=256, augmentation='resize'):
+    def __init__(self, imgfolder, maskfolder, labels=[1], size=256, augmentation='randomcrop'):
         self.images = np.asarray(sorted(glob.glob(os.path.join(imgfolder, "*.jpg"))))
         self.masks = np.asarray(sorted(glob.glob(os.path.join(maskfolder, "*.png"))))
         self.size = size
@@ -22,10 +22,10 @@ def __init__(self, imgfolder, maskfolder, labels=[1], size=256, augmentation='re
         assert np.all(self.image_ids == self.mask_ids), "Image IDs and Mask IDs do not match!"
 
         if augmentation == 'randomcrop':
-            self.augmentation = transforms.Resize(size=(size, size), antialias=None)
+            self.augmentation = transforms.RandomCrop(size=(size, size), pad_if_needed=True)
         elif augmentation == 'randomcrop+flip':
             self.augmentation = transforms.Compose([
-                transforms.Resize(size=(size, size), antialias=None),
+                transforms.RandomCrop(size=(size, size), pad_if_needed=True),
                 transforms.RandomHorizontalFlip(0.25),
                 transforms.RandomVerticalFlip(0.25),
             ])
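
Illustration only (not part of the commit): how the swapped transform behaves. The old Resize rescales the whole frame to size x size, while the new RandomCrop takes a random size x size window and zero-pads inputs smaller than the crop. The dummy tensor shape below is arbitrary.

    import torch
    from torchvision import transforms

    img = torch.rand(3, 300, 500)  # dummy RGB tensor, H=300, W=500

    old_aug = transforms.Resize(size=(256, 256), antialias=None)          # rescales, changes aspect ratio
    new_aug = transforms.RandomCrop(size=(256, 256), pad_if_needed=True)  # random window, pads if too small

    print(old_aug(img).shape, new_aug(img).shape)  # both: torch.Size([3, 256, 256])
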
@@ -56,3 +56,34 @@ def __getitem__(self, index):
             mask[i, labels == label] = 1
 
         return img, mask
+
+
+class COCOStuffPointDataset(COCOStuffDataset):
+    augmentation = None
+
+    def __getitem__(self, index):
+        image_file = self.images[index]
+        mask_file = self.masks[index]
+
+        img = read_image(image_file, ImageReadMode.RGB) / 255.
+        labels = read_image(mask_file, ImageReadMode.GRAY) + 1
+
+        # add the mask so we can crop it
+        data_stacked = torch.cat((img, labels), dim=0)
+
+        point = torch.rand(2)
+
+        if self.augmentation is not None:
+            data_stacked = self.augmentation(data_stacked)
+
+        point[0] = torch.floor(point[0] * data_stacked.shape[1])
+        point[1] = torch.floor(point[1] * data_stacked.shape[2])
+
+        img = data_stacked[:3, :]
+        labels = data_stacked[3, :]
+
+        mask = torch.zeros((1, labels.shape[0], labels.shape[1]))
+        label = labels[int(point[1]), int(point[0])]
+        mask[0, labels == label] = 1
+
+        return img, point, mask
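
A minimal usage sketch of the new dataset (not part of the commit): the directory paths are placeholders for a COCO-Stuff style image/mask layout, and the batch size is arbitrary.

    from torch.utils.data import DataLoader
    from patchgan.io import COCOStuffPointDataset

    dataset = COCOStuffPointDataset('data/images', 'data/masks', size=256, augmentation='randomcrop')
    loader = DataLoader(dataset, batch_size=4, shuffle=True)

    img, point, mask = next(iter(loader))
    # img:   (4, 3, 256, 256) RGB crops scaled to [0, 1]
    # point: (4, 2) pixel coordinates sampled uniformly inside each crop
    # mask:  (4, 1, 256, 256) binary mask of whichever label sits under the sampled point
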
43 changes: 30 additions & 13 deletions patchgan/train.py
@@ -1,7 +1,7 @@
 import torch
 from torchinfo import summary
-from .io import COCOStuffDataset
-from .patchgan import PatchGAN
+from .io import COCOStuffDataset, COCOStuffPointDataset
+from .patchgan import PatchGAN, PatchGANPoint
 import os
 from torch.utils.data import DataLoader, random_split
 from lightning.pytorch import Trainer
@@ -22,7 +22,7 @@ def patchgan_train():
     parser.add_argument('--dataloader_workers', default=4, type=int, help='Number of workers to use with dataloader (set to 0 to disable multithreading)')
     parser.add_argument('-n', '--n_epochs', required=True, type=int, help='Number of epochs to train the model')
     parser.add_argument('-d', '--device', default='auto', help='Device to use to train the model (CUDA=GPU)')
-    parser.add_argument('--summary', default=True, action='store_true', help="Print summary of the models")
+    parser.add_argument('--summary', action='store_true', help="Print summary of the models")
 
     args = parser.parse_args()
 
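The --summary change fixes the flag semantics: with action='store_true' the destination already defaults to False when the flag is omitted, so the earlier default=True made args.summary unconditionally True. A small standalone illustration (not from the repository):

    import argparse

    parser = argparse.ArgumentParser()
    parser.add_argument('--summary', action='store_true', help="Print summary of the models")

    print(parser.parse_args([]).summary)             # False: summary is now opt-in
    print(parser.parse_args(['--summary']).summary)  # True
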
@@ -45,16 +45,33 @@ def patchgan_train():
     else:
         raise AttributeError("Please provide either the training and validation data paths or a train/val split!")
 
+    model_type = config.get('model_type', 'patchgan')
+
+    if model_type == 'patchgan':
+        patchgan_model = PatchGAN
+    elif model_type == 'patchgan_point':
+        patchgan_model = PatchGANPoint
+    else:
+        raise ValueError(f"{model_type} not supported!")
+
     size = dataset_params.get('size', 256)
     augmentation = dataset_params.get('augmentation', 'randomcrop')
 
     dataset_kwargs = {}
     if dataset_params['type'] == 'COCOStuff':
+        assert model_type == 'patchgan', "model_type should be set to 'patchgan' to use the COCOStuff dataset. Did you mean COCOStuffPoint?"
         Dataset = COCOStuffDataset
         in_channels = 3
         labels = dataset_params.get('labels', [1])
         out_channels = len(labels)
         dataset_kwargs['labels'] = labels
+    elif dataset_params['type'] == 'COCOStuffPoint':
+        assert model_type == 'patchgan_point', "model_type should be set to 'patchgan_point' to use the COCOStuffPoint dataset. Did you mean COCOStuff?"
+        Dataset = COCOStuffPointDataset
+        in_channels = 3
+        labels = dataset_params.get('labels', [1])
+        out_channels = 1
+        dataset_kwargs['labels'] = labels
     else:
         try:
             spec = importlib.machinery.SourceFileLoader('io', 'io.py')
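
A hypothetical sketch (not part of the commit) of the two pairings the dispatch above accepts. The import path patchgan.patchgan follows the relative imports in train.py, PatchGANPoint is assumed to come from a companion change to patchgan/patchgan.py that is not shown in this diff, and the dict literal stands in for whatever configuration file the script actually reads.

    from patchgan.io import COCOStuffDataset, COCOStuffPointDataset
    from patchgan.patchgan import PatchGAN, PatchGANPoint

    # 'patchgan' pairs with the 'COCOStuff' dataset (one output channel per label);
    # 'patchgan_point' pairs with 'COCOStuffPoint' (a single point-conditioned mask).
    model_classes = {'patchgan': PatchGAN, 'patchgan_point': PatchGANPoint}
    dataset_classes = {'COCOStuff': COCOStuffDataset, 'COCOStuffPoint': COCOStuffPointDataset}

    model_type = 'patchgan_point'
    dataset_params = {'type': 'COCOStuffPoint', 'size': 256, 'augmentation': 'randomcrop', 'labels': [1]}

    patchgan_model = model_classes[model_type]
    Dataset = dataset_classes[dataset_params['type']]
    out_channels = len(dataset_params['labels']) if model_type == 'patchgan' else 1
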
@@ -87,7 +104,9 @@ def patchgan_train():
     model = None
     checkpoint_file = config.get('load_from_checkpoint', '')
     if os.path.isfile(checkpoint_file):
-        model = PatchGAN.load_from_checkpoint(checkpoint_file)
+        model = patchgan_model.load_from_checkpoint(checkpoint_file)
+    elif config.get('transfer_learn', {}).get('checkpoint', None) is not None:
+        model = patchgan_model.load_transfer_data(config['transfer_learn']['checkpoint'], in_channels, out_channels)
 
     if model is None:
         model_params = config['model_params']
@@ -112,17 +131,15 @@ def patchgan_train():
         lr_decay = train_params.get('decay_rate', 0.98)
         decay_freq = train_params.get('decay_freq', 5)
         save_freq = train_params.get('save_freq', 10)
-        model = PatchGAN(in_channels, out_channels, gen_filts, disc_filts, final_activation, n_disc_layers, use_dropout,
-                         activation, disc_norm, gen_learning_rate, dsc_learning_rate, lr_decay, decay_freq,
-                         loss_type=loss_type, seg_alpha=seg_alpha)
-
-    if config.get('transfer_learn', {}).get('checkpoint', None) is not None:
-        checkpoint = torch.load(config['transfer_learn']['checkpoint'], map_location=device)
-        model.generator.load_transfer_data({key.replace('PatchGAN.', ''): value for key, value in checkpoint['state_dict'].items() if 'generator' in key})
-        model.discriminator.load_transfer_data({key.replace('PatchGAN.', ''): value for key, value in checkpoint['state_dict'].items() if 'discriminator' in key})
+        model = patchgan_model(in_channels, out_channels, gen_filts, disc_filts, final_activation, n_disc_layers, use_dropout,
+                               activation, disc_norm, gen_learning_rate, dsc_learning_rate, lr_decay, decay_freq,
+                               loss_type=loss_type, seg_alpha=seg_alpha)
 
     if args.summary:
-        summary(model.generator, [1, in_channels, size, size], depth=4)
+        if model_type == 'patchgan':
+            summary(model, [1, in_channels, size, size], depth=4)
+        elif model_type == 'patchgan_point':
+            summary(model, [[1, in_channels, size, size], [1, 2]], depth=4)
         summary(model.discriminator, [1, in_channels + out_channels, size, size])
 
     checkpoint_callback = ModelCheckpoint(dirpath=checkpoint_path,