
Can it work with a single image? #17

Open
aligoglos opened this issue Jan 2, 2021 · 4 comments

Comments

@aligoglos

I wrote simple code to run the model on a single image, but the result is still gray!
Minimal demo:

import torch
import cv2
from PIL import Image
import numpy as np
import os
import glob
import utils


def main():
	device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
	disable_colorization = False
	# Load remaster network
	modelR = __import__('model.remasternet', fromlist=['NetworkR']).NetworkR()
	state_dict = torch.load('remasternet.pth')
	modelR.load_state_dict(state_dict['modelR'])
	modelR = modelR.to(device)
	modelR.eval()
	if not disable_colorization:
		modelC = __import__('model.remasternet', fromlist=['NetworkC']).NetworkC()
		modelC.load_state_dict(state_dict['modelC'])
		modelC = modelC.to(device)
		modelC.eval()
	paths = sorted(glob.glob('./inputs/*'))
	for path in paths:
		image = cv2.imread(path)
		if image is None:  # note: `~(image is None)` is a bitwise NOT and is always truthy
			continue
		name = os.path.basename(path)  # portable, unlike path.split('\\')
		print(name)
		refimgs = cv2.imread(F"./references/{name}")
		with torch.no_grad():
			gray = cv2.cvtColor(image, cv2.COLOR_BGR2GRAY)  # cv2.imread returns BGR, not RGB
			frame_l = torch.from_numpy(gray).view(gray.shape[0], gray.shape[1], 1)
			frame_l = frame_l.permute(2, 0, 1).float()  # HWC to CHW
			frame_l /= 255.
			# shape: (batch, channels, time, height, width)
			frame_l = frame_l.view(1, frame_l.size(0), 1, frame_l.size(1), frame_l.size(2))
			input = frame_l.to(device)
			output_l = modelR(input)
			if refimgs is None:
				output_ab = modelC(output_l)
			else:
				refimgs = cv2.cvtColor(refimgs, cv2.COLOR_BGR2RGB)
				refimgs = torch.from_numpy(refimgs)
				refimgs = refimgs.permute(2, 0, 1).float().unsqueeze(0).unsqueeze(0)
				refimgs /= 255.
				refimgs = refimgs.to(device)
				output_ab = modelC(output_l, refimgs)

			output_l = output_l.detach().cpu()
			output_ab = output_ab.detach().cpu()
			out_l = output_l[0, :, 0, :, :]
			out_c = output_ab[0, :, 0, :, :]
			output = torch.cat((out_l, out_c), dim=0).numpy().transpose((1, 2, 0))
			output = Image.fromarray(np.uint8(utils.convertLAB2RGB(output) * 255))
			output.save(F"./results/{name}")


if __name__ == "__main__":
	main()

input image: (attachment)

output: (attachment, still grayscale)

Note: the reference image is identical to the input.

@aligoglos aligoglos changed the title Could it work with a single image? Can it work with a single image? Jan 2, 2021
@zhaoyuzhi

I also encounter the same issue (the output is grayscale when a single frame is given as input). Have you addressed it?

@hermosayhl

I encounter this issue too. Has anyone solved it?

@hermosayhl

> I also encounter the same issue (output grayscale when single frame as input). Have you addressed this issue?

Dr zhao, have you overcome this problem?

@Dawars

Dawars commented Mar 17, 2024

You have to emulate multiple frames by duplicating the image to make the temporal convolutions work:

input = torch.tile(input, (1, 1, 5, 1, 1))

The network still isn't able to use colors from the reference images if they are significantly different from the gray image.
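For anyone landing here, a minimal sketch of that tiling trick (shapes and the modelR/modelC names are assumed from the demo above; the 64×64 size is just a placeholder):

```python
import torch

# Emulate a short clip from one image so the 3D temporal convolutions in
# NetworkR/NetworkC see more than a single frame.
# Tensor layout follows the demo above: (batch, channels, time, height, width).
frame_l = torch.rand(1, 1, 1, 64, 64)        # one grayscale frame, T = 1
clip = torch.tile(frame_l, (1, 1, 5, 1, 1))  # repeat along the time axis, T = 5

# Feed `clip` to modelR/modelC instead of the single frame, then keep only
# the middle frame of each output, e.g.:
#   out_l = output_l[0, :, 2, :, :]
#   out_c = output_ab[0, :, 2, :, :]
print(clip.shape)  # torch.Size([1, 1, 5, 64, 64])
```

Every time slice is a copy of the same frame, so slicing out the middle frame after inference is safe.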
