Skip to content
This repository has been archived by the owner on Nov 21, 2023. It is now read-only.

Commit

Permalink
Separating the refinement network's definition from the coarse network's
Browse files Browse the repository at this point in the history
  • Loading branch information
mkaic committed May 20, 2021
1 parent d3e1988 commit 95e23bf
Show file tree
Hide file tree
Showing 7 changed files with 117 additions and 54 deletions.
File renamed without changes.
51 changes: 28 additions & 23 deletions dali_dataloader.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@
print('loading libraries...')

import torch
from nvidia.dali.pipeline import Pipeline
Expand All @@ -14,7 +13,6 @@
from PIL import Image
import numpy as np
import time
import cupy as cp
import imageio

#print('Initializing Dataset...')
Expand Down Expand Up @@ -136,10 +134,10 @@ def define_graph(self):
bprime = fn.copy(bg)
alpha = fn.decoders.image(alpha_read, device = 'mixed')

bg = fn.resize(bg, size = [1920, 1080])
fg = fn.resize(fg, size = [1920, 1080])
bprime = fn.resize(bprime, size = [1920, 1080])
alpha = fn.resize(alpha, size = [1920, 1080])
bg = fn.resize(bg, resize_x = 1920, resize_y = 1080)
fg = fn.resize(fg, resize_x = 1920, resize_y = 1080)
bprime = fn.resize(bprime, resize_x = 1920, resize_y = 1080)
alpha = fn.resize(alpha, resize_x = 1920, resize_y = 1080)


bg_rot = fn.random.uniform(range = [-8.0, 9.0])
Expand Down Expand Up @@ -172,26 +170,26 @@ def define_graph(self):

bg = fn.color_twist(bg, brightness = bg_brightness, saturation = bg_saturation, hue = bg_hue)
bg = fn.contrast(bg, contrast = bg_contrast)
bg = fn.rotate(bg, angle = bg_rot, size = [1920, 1080])
bg = fn.rotate(bg, angle = bg_rot, keep_size = True)
bg = fn.gaussian_blur(bg, sigma = bg_blur, window_size = 3)
bg = fn.flip(bg, horizontal = bg_flip_chance)
bg = DALImath.clamp(bg + fn.random.normal(bg) * 5, 0, 256)
bg = DALImath.clamp(bg + fn.random.normal(bg) * 10, 0, 256)

fg = fn.color_twist(fg, brightness = fg_brightness, saturation = fg_saturation, hue = fg_hue)
fg = fn.contrast(fg, contrast = fg_contrast)
fg = fn.rotate(fg, angle = fg_rot, size = [1920, 1080])
fg = fn.rotate(fg, angle = fg_rot, keep_size = True)
fg = fn.gaussian_blur(fg, sigma = fg_blur, window_size = 3)
fg = fn.flip(fg, horizontal = fg_flip_chance)
bg = DALImath.clamp(fg + fn.random.normal(fg) * 5, 0, 256)
fg = DALImath.clamp(fg + fn.random.normal(fg) * 10, 0, 256)

bprime = fn.color_twist(bprime, brightness = bprime_brightness, saturation = bprime_saturation, hue = bprime_hue)
bprime = fn.contrast(bprime, contrast = bprime_contrast)
bprime = fn.rotate(bprime, angle = bprime_rot, size = [1920, 1080])
bprime = fn.rotate(bprime, angle = bprime_rot, keep_size = True)
bprime = fn.gaussian_blur(bprime, sigma = bprime_blur, window_size = 3)
bprime = fn.flip(bprime, horizontal = bg_flip_chance)
bg = DALImath.clamp(bprime + fn.random.normal(bprime) * 5, 0, 256)
bprime = DALImath.clamp(bprime + fn.random.normal(bprime) * 10, 0, 256)

alpha = fn.rotate(alpha, angle = alpha_rot, size = [1920, 1080])
alpha = fn.rotate(alpha, angle = alpha_rot, keep_size = True)
alpha = fn.gaussian_blur(alpha, sigma = alpha_blur, window_size = 3)
alpha = fn.flip(alpha, horizontal = fg_flip_chance)

Expand All @@ -212,19 +210,24 @@ def __next__(self):

tensor_dict = next(self.loader)[0]

fg = tensor_dict['fg'].permute(0,3,2,1)
bg = tensor_dict['bg'].permute(0,3,2,1)
bprime = tensor_dict['bprime'].permute(0,3,2,1)
alpha = tensor_dict['alpha'].permute(0,3,2,1)
fg = tensor_dict['fg'].permute(0,3,1,2)
bg = tensor_dict['bg'].permute(0,3,1,2)
bprime = tensor_dict['bprime'].permute(0,3,1,2)
alpha = tensor_dict['alpha'].permute(0,3,1,2)
alpha = alpha[:, :1]

png = torch.cat([fg, alpha], 1)
bg = bg.float()/256
fg = fg.float()/256
bprime = bprime.float()/256
alpha = alpha.float()/256

png = torch.cat([fg, alpha], dim = 1)

bg_trans_x = np.random.randint(-100, 100)
bg_trans_y = np.random.randint(-100, 100)
bg_shear_x = np.random.randint(-5, 6)
bg_shear_y = np.random.randint(-5, 6)
bg_scale = np.random.randint(8, 13) / 10
bg_scale = np.random.randint(10, 13) / 10

aug_bg_params = {

Expand Down Expand Up @@ -277,22 +280,24 @@ def __next__(self):
}
aug_png_tensor = TF.affine(**aug_png_params)

aug_fg_tensor = aug_png_tensor[:, :3, :, :]
aug_alpha_tensor = aug_png_tensor[:, 3:4, :, :]
aug_fg_tensor = aug_png_tensor[:, :3]
aug_alpha_tensor = aug_png_tensor[:, 3:4]


if(np.random.randint(0, 10) > 6):
shadow_x = np.random.randint(0, 200)
shadow_y = np.random.randint(0, 200)
shadow_shear = np.random.randint(-30, 30)
shadow_rotation = np.random.randint(-30, 30)
shadow_strength = np.random.randint(10, 90) / 100
shadow_strength = np.random.randint(20, 80) / 100
shadow_blur = np.random.randint(2, 16) * 2 + 1

shadow_stamp = TF.affine(aug_alpha_tensor, translate = [shadow_x, shadow_y], shear = shadow_shear, angle = shadow_rotation, scale = 1)
shadow_stamp = TF.gaussian_blur(shadow_stamp, shadow_blur)
shadow_stamp = shadow_stamp * shadow_strength

aug_bg_tensor = aug_bg_tensor - aug_bg_tensor * shadow_stamp
aug_bg_tensor = aug_bg_tensor - (aug_bg_tensor * shadow_stamp)


return (aug_bg_tensor, aug_fg_tensor, aug_bprime_tensor, aug_alpha_tensor)

Expand Down
29 changes: 20 additions & 9 deletions inference.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,8 @@

import os

from model_definition import *
from coarse_definition import CoarseMatteGenerator
from refine_definition import RefinementNetwork
from train_utils import get_image_patches, replace_image_patches, color_ramp

device = "cuda"
Expand Down Expand Up @@ -100,9 +101,12 @@

class UserInputDataset(Dataset):

def __init__(self):
def __init__(self, depth = 5, stride = 3):
super().__init__()

self.depth = depth
self.stride = stride

return

def __len__(self):
Expand All @@ -120,12 +124,15 @@ def __getitem__(self, source_idx):

#always temporally center our search
background_start_idx = int(source_idx / source_len * background_len)
search_start_idx = max(0, background_start_idx - search_width)
search_end_idx = min(background_len, background_start_idx + search_width)
search_start_idx = max(0, background_start_idx - self.depth*self.stride)
search_end_idx = min(background_len, background_start_idx + self.depth*self.stride)
best_background = np.zeros_like(source_img)

for background_idx, background_name in enumerate(background_list\
[search_start_idx: search_end_idx], start = search_start_idx):
#getting strided files is a hassle, because I can't use enumerate to get the indices.

for background_idx, background_name in zip(range(search_start_idx, search_end_idx, self.stride),\
(background_list\
[search_start_idx : search_end_idx : self.stride], start = search_start_idx)):

matches = matcher.match(background_des_list[background_idx], source_des_list[source_idx], None)

Expand Down Expand Up @@ -209,10 +216,14 @@ def __getitem__(self, source_idx):
#Now, feed the outputs of the coarse generator into the refinement network, which will refine patches.
fake_refined_patches = Refine(start_patches, middle_patches)

mega_upscaled_fake_coarse_alpha = F.interpolate(fake_coarse_alpha.detach(), size = [input_tensor.shape[-2], input_tensor.shape[-1]])
fake_refined_alpha = replace_image_patches(images = mega_upscaled_fake_coarse_alpha, patches = fake_refined_patches, indices = indices)
fake_refined_patches = refine(start_patches, middle_patches)

mega_upscaled_fake_coarse = F.interpolate(fake_coarse[:, :4].detach(), size = input_tensor.shape[-2:])
fake_refined = replace_image_patches(images = mega_upscaled_fake_coarse, patches = fake_refined_patches, indices = indices)
fake_refined_alpha = color_ramp(0.05, 0.95, torch.clamp(fake_refined[:, 0:1], 0, 1))
fake_refined_foreground = torch.clamp(fake_refined[:, 1:4] + composite_tensor, 0, 1)

RGBA = torch.cat([input_tensor[:, :3], fake_refined_alpha], 1)
RGBA = torch.cat([fake_refined_foreground, fake_refined_alpha], 1)

for j in range(input_tensor.shape[0]):
image = transforms.ToPILImage()(RGBA[j])
Expand Down
47 changes: 47 additions & 0 deletions refine_definition.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,47 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
import time
from tqdm import tqdm

class RefinementNetwork(nn.Module):

def __init__(self, coarse_channels = 37, input_channels = 6, patch_size = 8):
super().__init__()

self.input_channels = input_channels
self.coarse_channels = coarse_channels
self.concat_channels = self.input_channels + self.coarse_channels

#subtracting one from the input channels here because we aren't using the error map as an input
self.conv1 = nn.Conv2d(self.concat_channels, 24, kernel_size = 3)
self.conv2 = nn.Conv2d(24, 16, kernel_size = 3)
self.conv3 = nn.Conv2d(16 + self.input_channels, 12, kernel_size = 3)
self.conv4 = nn.Conv2d(12, 4, kernel_size = 3)

self.bn1 = nn.BatchNorm2d(24)
self.bn2 = nn.BatchNorm2d(16)
self.bn3 = nn.BatchNorm2d(12)

self.activation = nn.ReLU()


def forward(self, start_patches, middle_patches):

z1 = self.conv1(start_patches)
z1 = self.bn1(z1)
x1 = self.activation(z1)

z2 = self.conv2(x1)
z2 = self.bn2(z2)
x2 = self.activation(z2)
x2 = F.interpolate(x2, size = middle_patches.shape[-2:])

z3 = torch.cat([x2, middle_patches], 1)
z3 = self.conv3(z3)
z3 = self.bn3(z3)
x3 = self.activation(z3)

z4 = self.conv4(x3)

return z4
1 change: 0 additions & 1 deletion requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -6,5 +6,4 @@ numpy
kornia
more-itertools
Pillow
cupy-cuda101
random
24 changes: 12 additions & 12 deletions train_both.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,7 @@

}

batch_size = 4
batch_size = 2

#Initialize the dataset and the loader to feed the data to the network.

Expand All @@ -62,7 +62,7 @@

Pipeline = AugmentationPipeline(\
dataset = ImageFeeder, \
num_threads = 4, \
num_threads = 8, \
device_id = 0,
batch_size = batch_size)

Expand All @@ -78,7 +78,7 @@
num_hidden_channels = 32

#Initialize the network which will produce a coarse alpha (1 chan), foreground (3 chan) confidence map (1 chan), and a number of hidden channels (num_hidden_channels chan)...
coarse = torch.load("model_saves/coarse_generator_network_epoch_7.zip").train().to(device)
coarse = torch.load("model_saves/proper_coarse_epoch195000.zip").train().to(device)
refine = RefinementNetwork(coarse_channels = 5 + num_hidden_channels).train().to(device)

use_amp = False
Expand All @@ -100,10 +100,6 @@
with torch.cuda.amp.autocast(enabled = use_amp):

real_background, real_foreground, real_bprime, real_alpha = next(DALIDataloader)
real_background = real_background.float()/256
real_foreground = real_foreground.float()/256
real_bprime = real_bprime.float()/256
real_alpha = real_alpha.float()/256

"""
real_foreground = real_foreground.to(device)
Expand Down Expand Up @@ -191,7 +187,7 @@
iteration += 1


if(iteration % 500 == 0):
if(iteration % 1000 == 0):
image = fake_coarse_alpha[0]
image = transforms.ToPILImage()(image)
image.save(f'outputs7/{iteration}C_fake_coarse_alpha.jpg')
Expand All @@ -215,18 +211,22 @@
image = fake_refined_foreground[0]
image = transforms.ToPILImage()(image)
image.save(f'outputs7/{iteration}F_refined_foreground.jpg')

image = real_coarse_composite[0]
image = transforms.ToPILImage()(image)
image.save(f'outputs7/{iteration}E_coarse_composite.jpg')


if(iteration % 500 == 0):
if(iteration % 1000 == 0):

print(coarse_loss)
print(refine_loss)


if(iteration % 10000 == 0):
if(iteration % 15000 == 0):

torch.save(coarse, f"./model_saves/coarse_generator_network_epoch_{epoch}.zip")
torch.save(refine, f"./model_saves/refinement_network_epoch_{epoch}.zip")
torch.save(coarse, f"./model_saves/coarse_generator_network_epoch_{iteration}.zip")
torch.save(refine, f"./model_saves/refinement_network_epoch_{iteration}.zip")


print('\nTraining completed successfully.')
Expand Down
19 changes: 10 additions & 9 deletions train_coarse.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,7 @@

Pipeline = AugmentationPipeline(\
dataset = ImageFeeder, \
num_threads = 4, \
num_threads = 8, \
device_id = 0,
batch_size = batch_size)

Expand Down Expand Up @@ -94,10 +94,6 @@
with torch.cuda.amp.autocast(enabled = use_amp):

real_background, real_foreground, real_bprime, real_alpha = next(DALIDataloader)
real_background = real_background.float()/256
real_foreground = real_foreground.float()/256
real_bprime = real_bprime.float()/256
real_alpha = real_alpha.float()/256

"""
real_foreground = real_foreground.to(device)
Expand All @@ -108,6 +104,7 @@

#Composite the augmented foreground onto the augmented background according to the augmented alpha.
composite_tensor = composite(real_background, real_foreground, real_alpha)


#return the input tensor (composite plus b-prime) and the alpha_tensor. The input tensor is just a bunch of channels, the real_alpha is the central (singular) alpha
#corresponding to the target frame.
Expand All @@ -131,7 +128,7 @@
fake_coarse_hidden_channels = torch.relu(fake_coarse[:,5:])

#The real error map is calculated as the squared difference between the real alpha and the fake alpha.
real_coarse_error = torch.abs(real_coarse_alpha.detach()-fake_coarse_alpha.detach())
real_coarse_error = torch.clamp(torch.abs(real_coarse_alpha.detach()-fake_coarse_alpha.detach()), 0, 1)

#construct the fake foreground
#fake_coarse_foreground = torch.clamp(real_coarse_composite[:, dataset_params["comp_context_depth"]*3:dataset_params["comp_context_depth"]*3 + 3] + fake_coarse_foreground_residual, 0, 1)
Expand All @@ -157,7 +154,7 @@
iteration += 1


if(iteration % 500 == 0):
if(iteration % 1000 == 0):
image = fake_coarse_alpha[0]
image = transforms.ToPILImage()(image)
image.save(f'outputs7/{iteration}C_fake_coarse_alpha.jpg')
Expand All @@ -173,13 +170,17 @@
image = fake_coarse_error[0]
image = transforms.ToPILImage()(image)
image.save(f'outputs7/{iteration}D_fake_error.jpg')

image = real_coarse_composite[0]
image = transforms.ToPILImage()(image)
image.save(f'outputs7/{iteration}E_coarse_composite.jpg')


if(iteration % 500 == 0):
if(iteration % 1000 == 0):

print(coarse_loss)

if(iteration % 8000 == 0):
if(iteration % 15000 == 0):

torch.save(coarse, f"./model_saves/proper_coarse_epoch{iteration}.zip")

Expand Down

0 comments on commit 95e23bf

Please sign in to comment.