From 95e23bf8be81d0f3f9cbaf8e1fb02ed7ad3a9f93 Mon Sep 17 00:00:00 2001
From: Kai
Date: Thu, 20 May 2021 14:04:21 -0700
Subject: [PATCH] Separating the refinement network's definition from the
 coarse network's

---
 model_definition.py => coarse_definition.py |  0
 dali_dataloader.py                          | 51 +++++++++++----------
 inference.py                                | 29 ++++++++----
 refine_definition.py                        | 47 +++++++++++++++++++
 requirements.txt                            |  1 -
 train_both.py                               | 24 +++++-----
 train_coarse.py                             | 19 ++++----
 7 files changed, 117 insertions(+), 54 deletions(-)
 rename model_definition.py => coarse_definition.py (100%)
 create mode 100644 refine_definition.py

diff --git a/model_definition.py b/coarse_definition.py
similarity index 100%
rename from model_definition.py
rename to coarse_definition.py
diff --git a/dali_dataloader.py b/dali_dataloader.py
index 7390f5f..54eb783 100644
--- a/dali_dataloader.py
+++ b/dali_dataloader.py
@@ -1,4 +1,3 @@
-print('loading libraries...')
 import torch
 
 from nvidia.dali.pipeline import Pipeline
@@ -14,7 +13,6 @@ from PIL import Image
 
 import numpy as np
 import time
-import cupy as cp
 import imageio
 
 #print('Initializing Dataset...')
@@ -136,10 +134,10 @@ def define_graph(self):
         bprime = fn.copy(bg)
         alpha = fn.decoders.image(alpha_read, device = 'mixed')
 
-        bg = fn.resize(bg, size = [1920, 1080])
-        fg = fn.resize(fg, size = [1920, 1080])
-        bprime = fn.resize(bprime, size = [1920, 1080])
-        alpha = fn.resize(alpha, size = [1920, 1080])
+        bg = fn.resize(bg, resize_x = 1920, resize_y = 1080)
+        fg = fn.resize(fg, resize_x = 1920, resize_y = 1080)
+        bprime = fn.resize(bprime, resize_x = 1920, resize_y = 1080)
+        alpha = fn.resize(alpha, resize_x = 1920, resize_y = 1080)
 
         bg_rot = fn.random.uniform(range = [-8.0, 9.0])
@@ -172,26 +170,26 @@ def define_graph(self):
         bg = fn.color_twist(bg, brightness = bg_brightness, saturation = bg_saturation, hue = bg_hue)
         bg = fn.contrast(bg, contrast = bg_contrast)
-        bg = fn.rotate(bg, angle = bg_rot, size = [1920, 1080])
+        bg = fn.rotate(bg, angle = bg_rot, keep_size = True)
         bg = fn.gaussian_blur(bg, sigma = bg_blur, window_size = 3)
         bg = fn.flip(bg, horizontal = bg_flip_chance)
-        bg = DALImath.clamp(bg + fn.random.normal(bg) * 5, 0, 256)
+        bg = DALImath.clamp(bg + fn.random.normal(bg) * 10, 0, 256)
 
         fg = fn.color_twist(fg, brightness = fg_brightness, saturation = fg_saturation, hue = fg_hue)
         fg = fn.contrast(fg, contrast = fg_contrast)
-        fg = fn.rotate(fg, angle = fg_rot, size = [1920, 1080])
+        fg = fn.rotate(fg, angle = fg_rot, keep_size = True)
         fg = fn.gaussian_blur(fg, sigma = fg_blur, window_size = 3)
         fg = fn.flip(fg, horizontal = fg_flip_chance)
-        bg = DALImath.clamp(fg + fn.random.normal(fg) * 5, 0, 256)
+        fg = DALImath.clamp(fg + fn.random.normal(fg) * 10, 0, 256)
 
         bprime = fn.color_twist(bprime, brightness = bprime_brightness, saturation = bprime_saturation, hue = bprime_hue)
         bprime = fn.contrast(bprime, contrast = bprime_contrast)
-        bprime = fn.rotate(bprime, angle = bprime_rot, size = [1920, 1080])
+        bprime = fn.rotate(bprime, angle = bprime_rot, keep_size = True)
         bprime = fn.gaussian_blur(bprime, sigma = bprime_blur, window_size = 3)
         bprime = fn.flip(bprime, horizontal = bg_flip_chance)
-        bg = DALImath.clamp(bprime + fn.random.normal(bprime) * 5, 0, 256)
+        bprime = DALImath.clamp(bprime + fn.random.normal(bprime) * 10, 0, 256)
 
-        alpha = fn.rotate(alpha, angle = alpha_rot, size = [1920, 1080])
+        alpha = fn.rotate(alpha, angle = alpha_rot, keep_size = True)
         alpha = fn.gaussian_blur(alpha, sigma = alpha_blur, window_size = 3)
         alpha = fn.flip(alpha, horizontal = fg_flip_chance)
@@ -212,19 +210,24 @@ def __next__(self):
 
         tensor_dict = next(self.loader)[0]
 
-        fg = tensor_dict['fg'].permute(0,3,2,1)
-        bg = tensor_dict['bg'].permute(0,3,2,1)
-        bprime = tensor_dict['bprime'].permute(0,3,2,1)
-        alpha = tensor_dict['alpha'].permute(0,3,2,1)
+        fg = tensor_dict['fg'].permute(0,3,1,2)
+        bg = tensor_dict['bg'].permute(0,3,1,2)
+        bprime = tensor_dict['bprime'].permute(0,3,1,2)
+        alpha = tensor_dict['alpha'].permute(0,3,1,2)
 
         alpha = alpha[:, :1]
 
-        png = torch.cat([fg, alpha], 1)
+        bg = bg.float()/256
+        fg = fg.float()/256
+        bprime = bprime.float()/256
+        alpha = alpha.float()/256
+
+        png = torch.cat([fg, alpha], dim = 1)
 
         bg_trans_x = np.random.randint(-100, 100)
         bg_trans_y = np.random.randint(-100, 100)
         bg_shear_x = np.random.randint(-5, 6)
         bg_shear_y = np.random.randint(-5, 6)
-        bg_scale = np.random.randint(8, 13) / 10
+        bg_scale = np.random.randint(10, 13) / 10
 
 
         aug_bg_params = {
@@ -277,22 +280,24 @@ def __next__(self):
         }
 
         aug_png_tensor = TF.affine(**aug_png_params)
-        aug_fg_tensor = aug_png_tensor[:, :3, :, :]
-        aug_alpha_tensor = aug_png_tensor[:, 3:4, :, :]
+        aug_fg_tensor = aug_png_tensor[:, :3]
+        aug_alpha_tensor = aug_png_tensor[:, 3:4]
+
 
         if(np.random.randint(0, 10) > 6):
 
             shadow_x = np.random.randint(0, 200)
             shadow_y = np.random.randint(0, 200)
             shadow_shear = np.random.randint(-30, 30)
             shadow_rotation = np.random.randint(-30, 30)
-            shadow_strength = np.random.randint(10, 90) / 100
+            shadow_strength = np.random.randint(20, 80) / 100
             shadow_blur = np.random.randint(2, 16) * 2 + 1
 
             shadow_stamp = TF.affine(aug_alpha_tensor, translate = [shadow_x, shadow_y], shear = shadow_shear, angle = shadow_rotation, scale = 1)
             shadow_stamp = TF.gaussian_blur(shadow_stamp, shadow_blur)
             shadow_stamp = shadow_stamp * shadow_strength
 
-            aug_bg_tensor = aug_bg_tensor - aug_bg_tensor * shadow_stamp
+            aug_bg_tensor = aug_bg_tensor - (aug_bg_tensor * shadow_stamp)
+
 
         return (aug_bg_tensor, aug_fg_tensor, aug_bprime_tensor, aug_alpha_tensor)
diff --git a/inference.py b/inference.py
index 8dc115a..1131b7a 100644
--- a/inference.py
+++ b/inference.py
@@ -19,7 +19,8 @@
 import os
 
-from model_definition import *
+from coarse_definition import CoarseMatteGenerator
+from refine_definition import RefinementNetwork
 from train_utils import get_image_patches, replace_image_patches, color_ramp
 
 device = "cuda"
@@ -100,9 +101,12 @@ class UserInputDataset(Dataset):
 
-    def __init__(self):
+    def __init__(self, depth = 5, stride = 3):
 
         super().__init__()
 
+        self.depth = depth
+        self.stride = stride
+
         return
 
     def __len__(self):
@@ -120,12 +124,15 @@ def __getitem__(self, source_idx):
 
         #always temporally center our search
         background_start_idx = int(source_idx / source_len * background_len)
-        search_start_idx = max(0, background_start_idx - search_width)
-        search_end_idx = min(background_len, background_start_idx + search_width)
+        search_start_idx = max(0, background_start_idx - self.depth*self.stride)
+        search_end_idx = min(background_len, background_start_idx + self.depth*self.stride)
 
         best_background = np.zeros_like(source_img)
 
-        for background_idx, background_name in enumerate(background_list\
-            [search_start_idx: search_end_idx], start = search_start_idx):
+        #getting strided files is a hassle, because I can't use enumerate to get the indices.
+
+        for background_idx, background_name in zip(\
+                range(search_start_idx, search_end_idx, self.stride),\
+                background_list[search_start_idx : search_end_idx : self.stride]):
 
             matches = matcher.match(background_des_list[background_idx], source_des_list[source_idx], None)
@@ -209,10 +216,14 @@ def __getitem__(self, source_idx):
 
     #Now, feed the outputs of the coarse generator into the refinement network, which will refine patches.
-    fake_refined_patches = Refine(start_patches, middle_patches)
-    mega_upscaled_fake_coarse_alpha = F.interpolate(fake_coarse_alpha.detach(), size = [input_tensor.shape[-2], input_tensor.shape[-1]])
-    fake_refined_alpha = replace_image_patches(images = mega_upscaled_fake_coarse_alpha, patches = fake_refined_patches, indices = indices)
+    fake_refined_patches = refine(start_patches, middle_patches)
+
+    mega_upscaled_fake_coarse = F.interpolate(fake_coarse[:, :4].detach(), size = input_tensor.shape[-2:])
+    fake_refined = replace_image_patches(images = mega_upscaled_fake_coarse, patches = fake_refined_patches, indices = indices)
+    fake_refined_alpha = color_ramp(0.05, 0.95, torch.clamp(fake_refined[:, 0:1], 0, 1))
+    fake_refined_foreground = torch.clamp(fake_refined[:, 1:4] + composite_tensor, 0, 1)
 
-    RGBA = torch.cat([input_tensor[:, :3], fake_refined_alpha], 1)
+    RGBA = torch.cat([fake_refined_foreground, fake_refined_alpha], 1)
 
     for j in range(input_tensor.shape[0]):
         image = transforms.ToPILImage()(RGBA[j])
diff --git a/refine_definition.py b/refine_definition.py
new file mode 100644
index 0000000..f15660a
--- /dev/null
+++ b/refine_definition.py
@@ -0,0 +1,47 @@
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+import time
+from tqdm import tqdm
+
+class RefinementNetwork(nn.Module):
+
+    def __init__(self, coarse_channels = 37, input_channels = 6, patch_size = 8):
+        super().__init__()
+
+        self.input_channels = input_channels
+        self.coarse_channels = coarse_channels
+        self.concat_channels = self.input_channels + self.coarse_channels
+
+        #subtracting one from the input channels here because we aren't using the error map as an input
+        self.conv1 = nn.Conv2d(self.concat_channels, 24, kernel_size = 3)
+        self.conv2 = nn.Conv2d(24, 16, kernel_size = 3)
+        self.conv3 = nn.Conv2d(16 + self.input_channels, 12, kernel_size = 3)
+        self.conv4 = nn.Conv2d(12, 4, kernel_size = 3)
+
+        self.bn1 = nn.BatchNorm2d(24)
+        self.bn2 = nn.BatchNorm2d(16)
+        self.bn3 = nn.BatchNorm2d(12)
+
+        self.activation = nn.ReLU()
+
+
+    def forward(self, start_patches, middle_patches):
+
+        z1 = self.conv1(start_patches)
+        z1 = self.bn1(z1)
+        x1 = self.activation(z1)
+
+        z2 = self.conv2(x1)
+        z2 = self.bn2(z2)
+        x2 = self.activation(z2)
+        x2 = F.interpolate(x2, size = middle_patches.shape[-2:])
+
+        z3 = torch.cat([x2, middle_patches], 1)
+        z3 = self.conv3(z3)
+        z3 = self.bn3(z3)
+        x3 = self.activation(z3)
+
+        z4 = self.conv4(x3)
+
+        return z4
\ No newline at end of file
diff --git a/requirements.txt b/requirements.txt
index b4c22bb..23b224f 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -6,5 +6,4 @@ numpy
 kornia
 more-itertools
 Pillow
-cupy-cuda101
 random
\ No newline at end of file
diff --git a/train_both.py b/train_both.py
index 5c1a612..147068f 100644
--- a/train_both.py
+++ b/train_both.py
@@ -50,7 +50,7 @@
 }
 
-batch_size = 4
+batch_size = 2
 
 
 #Initialize the dataset and the loader to feed the data to the network.
@@ -62,7 +62,7 @@
 
 Pipeline = AugmentationPipeline(\
     dataset = ImageFeeder, \
-    num_threads = 4, \
+    num_threads = 8, \
     device_id = 0,
     batch_size = batch_size)
 
@@ -78,7 +78,7 @@
 num_hidden_channels = 32
 
 #Initialize the network which will produce a coarse alpha (1 chan), foreground (3 chan), confidence map (1 chan), and a number of hidden channels (num_hidden_channels chan)...
-coarse = torch.load("model_saves/coarse_generator_network_epoch_7.zip").train().to(device)
+coarse = torch.load("model_saves/proper_coarse_epoch195000.zip").train().to(device)
 refine = RefinementNetwork(coarse_channels = 5 + num_hidden_channels).train().to(device)
 
 use_amp = False
@@ -100,10 +100,6 @@
     with torch.cuda.amp.autocast(enabled = use_amp):
 
         real_background, real_foreground, real_bprime, real_alpha = next(DALIDataloader)
-        real_background = real_background.float()/256
-        real_foreground = real_foreground.float()/256
-        real_bprime = real_bprime.float()/256
-        real_alpha = real_alpha.float()/256
 
         """
         real_foreground = real_foreground.to(device)
@@ -191,7 +187,7 @@
 
     iteration += 1
 
-    if(iteration % 500 == 0):
+    if(iteration % 1000 == 0):
 
         image = fake_coarse_alpha[0]
         image = transforms.ToPILImage()(image)
        image.save(f'outputs7/{iteration}C_fake_coarse_alpha.jpg')
@@ -215,18 +211,22 @@
        image = fake_refined_foreground[0]
        image = transforms.ToPILImage()(image)
        image.save(f'outputs7/{iteration}F_refined_foreground.jpg')
+
+        image = real_coarse_composite[0]
+        image = transforms.ToPILImage()(image)
+        image.save(f'outputs7/{iteration}E_coarse_composite.jpg')
 
-    if(iteration % 500 == 0):
+    if(iteration % 1000 == 0):
 
         print(coarse_loss)
         print(refine_loss)
 
-    if(iteration % 10000 == 0):
+    if(iteration % 15000 == 0):
 
-        torch.save(coarse, f"./model_saves/coarse_generator_network_epoch_{epoch}.zip")
-        torch.save(refine, f"./model_saves/refinement_network_epoch_{epoch}.zip")
+        torch.save(coarse, f"./model_saves/coarse_generator_network_epoch_{iteration}.zip")
+        torch.save(refine, f"./model_saves/refinement_network_epoch_{iteration}.zip")
 
 print('\nTraining completed successfully.')
diff --git a/train_coarse.py b/train_coarse.py
index 16c73d0..5e747f1 100644
--- a/train_coarse.py
+++ b/train_coarse.py
@@ -61,7 +61,7 @@
 
 Pipeline = AugmentationPipeline(\
     dataset = ImageFeeder, \
-    num_threads = 4, \
+    num_threads = 8, \
     device_id = 0,
     batch_size = batch_size)
 
@@ -94,10 +94,6 @@
     with torch.cuda.amp.autocast(enabled = use_amp):
 
         real_background, real_foreground, real_bprime, real_alpha = next(DALIDataloader)
-        real_background = real_background.float()/256
-        real_foreground = real_foreground.float()/256
-        real_bprime = real_bprime.float()/256
-        real_alpha = real_alpha.float()/256
 
         """
         real_foreground = real_foreground.to(device)
@@ -108,6 +104,7 @@
 
         #Composite the augmented foreground onto the augmented background according to the augmented alpha.
        composite_tensor = composite(real_background, real_foreground, real_alpha)
+
 
        #return the input tensor (composite plus b-prime) and the alpha_tensor. The input tensor is just a bunch of channels, the real_alpha is the central (singular) alpha
        #corresponding to the target frame.
@@ -131,7 +128,7 @@
        fake_coarse_hidden_channels = torch.relu(fake_coarse[:,5:])
 
        #The real error map is calculated as the absolute difference between the real alpha and the fake alpha.
-        real_coarse_error = torch.abs(real_coarse_alpha.detach()-fake_coarse_alpha.detach())
+        real_coarse_error = torch.clamp(torch.abs(real_coarse_alpha.detach()-fake_coarse_alpha.detach()), 0, 1)
 
        #construct the fake foreground
        #fake_coarse_foreground = torch.clamp(real_coarse_composite[:, dataset_params["comp_context_depth"]*3:dataset_params["comp_context_depth"]*3 + 3] + fake_coarse_foreground_residual, 0, 1)
@@ -157,7 +154,7 @@
 
    iteration += 1
 
-    if(iteration % 500 == 0):
+    if(iteration % 1000 == 0):
 
        image = fake_coarse_alpha[0]
        image = transforms.ToPILImage()(image)
        image.save(f'outputs7/{iteration}C_fake_coarse_alpha.jpg')
@@ -173,13 +170,17 @@
 
        image = fake_coarse_error[0]
        image = transforms.ToPILImage()(image)
        image.save(f'outputs7/{iteration}D_fake_error.jpg')
+
+        image = real_coarse_composite[0]
+        image = transforms.ToPILImage()(image)
+        image.save(f'outputs7/{iteration}E_coarse_composite.jpg')
 
-    if(iteration % 500 == 0):
+    if(iteration % 1000 == 0):
 
        print(coarse_loss)
 
-    if(iteration % 8000 == 0):
+    if(iteration % 15000 == 0):
 
        torch.save(coarse, f"./model_saves/proper_coarse_epoch{iteration}.zip")
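
Note: below is a minimal, self-contained sketch (not part of the patch itself) showing how the newly
split-out RefinementNetwork can be smoke-tested in isolation. The channel counts follow train_both.py
(coarse_channels = 5 + 32 hidden channels, input_channels = 6); the patch spatial sizes are illustrative
assumptions, since the real crops come from get_image_patches / replace_image_patches in train_utils.py,
which this diff does not touch.

    import torch
    from refine_definition import RefinementNetwork

    #channel layout assumed from train_both.py: 5 (alpha, foreground, error) + 32 hidden channels
    refine = RefinementNetwork(coarse_channels = 5 + 32, input_channels = 6)

    #start_patches: crops of the coarse output (37 channels) stacked with the matching crops of the
    #6-channel input (composite + b-prime), i.e. 37 + 6 = 43 channels. The 8 px / 12 px patch sizes
    #are guesses for illustration only.
    start_patches = torch.randn(16, 43, 8, 8)
    #middle_patches: the same locations cropped from the full-resolution 6-channel input.
    middle_patches = torch.randn(16, 6, 12, 12)

    refined = refine(start_patches, middle_patches)
    print(refined.shape)  #torch.Size([16, 4, 8, 8]): one alpha channel plus three foreground channels

Downstream, inference.py reads those four channels as a refined alpha (channel 0, passed through
color_ramp) and a foreground correction (channels 1-3, added to the composite), then pastes the refined
patches back into the upscaled coarse output with replace_image_patches.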