dbpns.py
import os
import torch
import torch.nn as nn
import torch.optim as optim
from base_networks import *
from torchvision.transforms import *


class Net(nn.Module):
    def __init__(self, num_channels, base_filter, feat, num_stages, scale_factor):
        super(Net, self).__init__()

        # (De)convolution hyper-parameters for each supported scale factor
        if scale_factor == 2:
            kernel = 6
            stride = 2
            padding = 2
        elif scale_factor == 4:
            kernel = 8
            stride = 4
            padding = 2
        elif scale_factor == 8:
            kernel = 12
            stride = 8
            padding = 2

        # Initial feature extraction
        self.feat0 = ConvBlock(num_channels, feat, 3, 1, 1, activation='prelu', norm=None)
        self.feat1 = ConvBlock(feat, base_filter, 1, 1, 0, activation='prelu', norm=None)
        # Back-projection stages
        self.up1 = UpBlock(base_filter, kernel, stride, padding)
        self.down1 = DownBlock(base_filter, kernel, stride, padding)
        self.up2 = UpBlock(base_filter, kernel, stride, padding)
        # Reconstruction: concatenated HR feature maps -> output image
        self.output_conv = ConvBlock(num_stages * base_filter, num_channels, 3, 1, 1, activation=None, norm=None)

        # Kaiming initialization for all convolution and transposed-convolution layers
        for m in self.modules():
            classname = m.__class__.__name__
            if classname.find('Conv2d') != -1:
                torch.nn.init.kaiming_normal_(m.weight)
                if m.bias is not None:
                    m.bias.data.zero_()
            elif classname.find('ConvTranspose2d') != -1:
                torch.nn.init.kaiming_normal_(m.weight)
                if m.bias is not None:
                    m.bias.data.zero_()

    def forward(self, x):
        x = self.feat0(x)
        x = self.feat1(x)
        h1 = self.up1(x)
        h2 = self.up2(self.down1(h1))
        x = self.output_conv(torch.cat((h2, h1), 1))
        return x
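

# Minimal usage sketch (not part of the original file): shows how the module
# above can be instantiated and run on a dummy low-resolution tensor. The
# hyper-parameter values below (base_filter=64, feat=256, num_stages=2) are
# illustrative assumptions, not taken from this repository's training scripts;
# num_stages must match the number of feature maps concatenated in forward()
# (two here: h1 and h2).
if __name__ == '__main__':
    model = Net(num_channels=3, base_filter=64, feat=256, num_stages=2, scale_factor=2)
    lr_image = torch.randn(1, 3, 32, 32)   # dummy low-resolution input batch
    sr_image = model(lr_image)              # forward pass through the back-projection stages
    print(sr_image.shape)                   # should be torch.Size([1, 3, 64, 64]) for scale_factor=2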