-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathutils.py
89 lines (76 loc) · 2.76 KB
/
utils.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
import cv2
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
from PIL import Image
def gradient(img):
height = img.size(2)
width = img.size(3)
gradient_h = (img[:,:,2:,:]-img[:,:,:height-2,:]).abs()
gradient_w = (img[:, :, :, 2:] - img[:, :, :, :width-2]).abs()
return gradient_h, gradient_w
def tv_loss(illumination):
gradient_illu_h, gradient_illu_w = gradient(illumination)
loss_h = gradient_illu_h
loss_w = gradient_illu_w
loss = loss_h.mean() + loss_w.mean()
return loss
def C_loss(R1, R2):
loss = torch.nn.MSELoss()(R1, R2)
return loss
def R_loss(L1, R1, im1, X1):
max_rgb1, _ = torch.max(im1, 1)
max_rgb1 = max_rgb1.unsqueeze(1)
loss1 = torch.nn.MSELoss()(L1*R1, X1) + torch.nn.MSELoss()(R1, X1/L1.detach())
loss2 = torch.nn.MSELoss()(L1, max_rgb1) + tv_loss(L1)
return loss1 + loss2
def P_loss(im1, X1):
loss = torch.nn.MSELoss()(im1, X1)
return loss
def joint_RGB_horizontal(im1, im2):
if im1.size==im2.size:
w, h = im1.size
result = Image.new('RGB',(w*2, h))
result.paste(im1, box=(0,0))
result.paste(im2, box=(w,0))
return result
def joint_L_horizontal(im1, im2):
if im1.size==im2.size:
w, h = im1.size
result = Image.new('L',(w*2, h))
result.paste(im1, box=(0,0))
result.paste(im2, box=(w,0))
return result
class Fusionloss(nn.Module):
def __init__(self):
super(Fusionloss, self).__init__()
self.sobelconv=Sobelxy()
def forward(self,image_vis,image_ir,generate_img):
image_y=image_vis[:,:1,:,:]
x_in_max=torch.max(image_y,image_ir)
loss_in=F.l1_loss(x_in_max,generate_img)
y_grad=self.sobelconv(image_y)
ir_grad=self.sobelconv(image_ir)
generate_img_grad=self.sobelconv(generate_img)
x_grad_joint=torch.max(y_grad,ir_grad)
loss_grad=F.l1_loss(x_grad_joint,generate_img_grad)
loss_total=loss_in+10*loss_grad
return loss_total,loss_in,loss_grad
class Sobelxy(nn.Module):
def __init__(self):
super(Sobelxy, self).__init__()
kernelx = [[-1, 0, 1],
[-2,0 , 2],
[-1, 0, 1]]
kernely = [[1, 2, 1],
[0,0 , 0],
[-1, -2, -1]]
kernelx = torch.FloatTensor(kernelx).unsqueeze(0).unsqueeze(0)
kernely = torch.FloatTensor(kernely).unsqueeze(0).unsqueeze(0)
self.weightx = nn.Parameter(data=kernelx, requires_grad=False).cuda()
self.weighty = nn.Parameter(data=kernely, requires_grad=False).cuda()
def forward(self,x):
sobelx=F.conv2d(x, self.weightx, padding=1)
sobely=F.conv2d(x, self.weighty, padding=1)
return torch.abs(sobelx)+torch.abs(sobely)