-
Notifications
You must be signed in to change notification settings - Fork 37
/
resnet_blocks.py
38 lines (25 loc) · 998 Bytes
/
resnet_blocks.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
import torch.nn as nn
import torch.nn.functional as F
class IdentityLayer(nn.Module):
    """No-op module whose ``forward`` returns its argument unchanged.

    Holds no parameters. It lets a slot that expects an ``nn.Module``
    (such as a residual shortcut that needs no projection) do nothing.
    """

    def forward(self, x):
        # Hand back the very same object: no copy, no clone, no reshape.
        passthrough = x
        return passthrough
class ResNetBlock(nn.Module):
    """Two-convolution residual block whose conv weights are passed to ``forward``.

    The block owns only its batch-norm layers and, when needed, a 1x1
    projection on the shortcut. The weights of the two main-path
    convolutions are not parameters of this module. The caller supplies
    them on every call and they are applied with ``F.conv2d``.

    Args:
        in_size: Number of input channels.
        out_size: Number of output channels.
        downsample: If True, the first convolution and the shortcut use
            stride 2, reducing the spatial resolution.
    """

    def __init__(self, in_size=16, out_size=16, downsample=False):
        super().__init__()
        self.out_size = out_size
        self.in_size = in_size
        self.stride1 = 2 if downsample else 1
        # The shortcut output must match the main path's shape before the
        # addition in forward(). A change in stride OR channel count needs a
        # 1x1 projection. Previously only the stride case was handled, so
        # in_size != out_size without downsampling crashed at `out += residual`.
        # Note: the projection is a plain Conv2d (bias=True by default, no BN).
        if downsample or in_size != out_size:
            self.reslayer = nn.Conv2d(
                in_channels=self.in_size,
                out_channels=self.out_size,
                stride=self.stride1,
                kernel_size=1,
            )
        else:
            self.reslayer = IdentityLayer()
        self.bn1 = nn.BatchNorm2d(out_size)
        self.bn2 = nn.BatchNorm2d(out_size)

    def forward(self, x, conv1_w, conv2_w):
        """Apply conv-BN-ReLU, conv-BN, add the shortcut, then ReLU.

        Args:
            x: Input feature map with ``in_size`` channels, (N, C, H, W) as
                required by ``BatchNorm2d``.
            conv1_w: Weight for the first convolution, laid out as
                (out_size, in_size, kH, kW) per ``F.conv2d``. It is applied with
                stride ``self.stride1`` and padding 1.
            conv2_w: Weight for the second convolution, laid out as
                (out_size, out_size, kH, kW). It is applied with stride 1 and padding 1.

        Returns:
            ``relu(bn2(conv2(relu(bn1(conv1(x))))) + reslayer(x))``.

        NOTE(review): padding=1 keeps the main path's spatial size in line
        with the 1x1 shortcut only for 3x3 kernels. The kernel size is assumed
        to be 3x3; confirm with the caller that produces the weights.
        """
        residual = self.reslayer(x)
        out = F.relu(
            self.bn1(F.conv2d(x, conv1_w, stride=self.stride1, padding=1)),
            inplace=True,
        )
        out = self.bn2(F.conv2d(out, conv2_w, padding=1))
        out += residual
        return F.relu(out)