Skip to content

Commit

Permalink
fixed resnet
Browse files Browse the repository at this point in the history
  • Loading branch information
Lam Nguyen Tùng Lam committed Mar 8, 2023
1 parent 31ef7a2 commit 82a4252
Show file tree
Hide file tree
Showing 2 changed files with 7 additions and 10 deletions.
4 changes: 2 additions & 2 deletions skeleton/layers/residual_block.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,11 +3,11 @@
from torch.nn import functional as F

class ResidualBlock(nn.Module):
def __init__(self, out_channels, use_1x1conv=False, strides=1, kernel_size=3, padding=1):
def __init__(self,out_channels, use_1x1conv=False, strides=1, kernel_size=3, padding=1):
super().__init__()
self.conv1 = nn.LazyConv1d(out_channels, kernel_size=kernel_size, padding=padding,
stride=strides)
self.conv2 = nn.Conv1d(out_channels, out_channels, kernel_size=kernel_size, padding=padding)
self.conv2 = nn.LazyConv1d(out_channels, kernel_size=kernel_size, padding=padding)
if use_1x1conv:
self.conv3 = nn.LazyConv1d(out_channels, kernel_size=1, stride=strides)
else:
Expand Down
13 changes: 5 additions & 8 deletions skeleton/layers/resnet.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,9 +9,9 @@ def __init__(self, triples):
super(ResNet, self).__init__()
modules = []
modules.append(self.starting_block(128))
for i, triple in enumerate(triples):
for _, triple in enumerate(triples):
num_residuals, out_channels = triple[0], triple[1]
block = self.block(num_residuals, out_channels, first_block=(i==0))
block = self.block(num_residuals,out_channels)
modules.append(block)
modules.append(nn.Sequential(nn.ReLU(), nn.AdaptiveAvgPool1d(3)))

Expand All @@ -32,13 +32,10 @@ def starting_block(self, input_channels):
nn.ReLU(),
nn.MaxPool1d(kernel_size=3, stride=2, padding=1))

def block(self, num_residuals, out_channels, first_block = False):
def block(self, num_residuals, out_channels):
blk = []
for i in range(num_residuals):
if i == 0 and not first_block:
blk.append(ResidualBlock(out_channels, use_1x1conv=True, strides=2))
else:
blk.append(ResidualBlock(out_channels))
for _ in range(num_residuals):
blk.append(ResidualBlock(out_channels, use_1x1conv=True))
return nn.Sequential(*blk)

def forward(self, x):
Expand Down

0 comments on commit 82a4252

Please sign in to comment.