From 82a42526e0d16d5a572ea307ea9a9430ae56abbf Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Lam=20Nguyen=20T=C3=B9ng=20Lam?= Date: Wed, 8 Mar 2023 18:15:23 +0100 Subject: [PATCH] fixed resnet --- skeleton/layers/residual_block.py | 4 ++-- skeleton/layers/resnet.py | 13 +++++-------- 2 files changed, 7 insertions(+), 10 deletions(-) diff --git a/skeleton/layers/residual_block.py b/skeleton/layers/residual_block.py index 57ab480..a07cd62 100644 --- a/skeleton/layers/residual_block.py +++ b/skeleton/layers/residual_block.py @@ -3,11 +3,11 @@ from torch.nn import functional as F class ResidualBlock(nn.Module): - def __init__(self, out_channels, use_1x1conv=False, strides=1, kernel_size=3, padding=1): + def __init__(self,out_channels, use_1x1conv=False, strides=1, kernel_size=3, padding=1): super().__init__() self.conv1 = nn.LazyConv1d(out_channels, kernel_size=kernel_size, padding=padding, stride=strides) - self.conv2 = nn.Conv1d(out_channels, out_channels, kernel_size=kernel_size, padding=padding) + self.conv2 = nn.LazyConv1d(out_channels, kernel_size=kernel_size, padding=padding) if use_1x1conv: self.conv3 = nn.LazyConv1d(out_channels, kernel_size=1, stride=strides) else: diff --git a/skeleton/layers/resnet.py b/skeleton/layers/resnet.py index bbecdf1..8fd088e 100644 --- a/skeleton/layers/resnet.py +++ b/skeleton/layers/resnet.py @@ -9,9 +9,9 @@ def __init__(self, triples): super(ResNet, self).__init__() modules = [] modules.append(self.starting_block(128)) - for i, triple in enumerate(triples): + for _, triple in enumerate(triples): num_residuals, out_channels = triple[0], triple[1] - block = self.block(num_residuals, out_channels, first_block=(i==0)) + block = self.block(num_residuals,out_channels) modules.append(block) modules.append(nn.Sequential(nn.ReLU(), nn.AdaptiveAvgPool1d(3))) @@ -32,13 +32,10 @@ def starting_block(self, input_channels): nn.ReLU(), nn.MaxPool1d(kernel_size=3, stride=2, padding=1)) - def block(self, num_residuals, out_channels, first_block = False): + def block(self, num_residuals, out_channels): blk = [] - for i in range(num_residuals): - if i == 0 and not first_block: - blk.append(ResidualBlock(out_channels, use_1x1conv=True, strides=2)) - else: - blk.append(ResidualBlock(out_channels)) + for _ in range(num_residuals): + blk.append(ResidualBlock(out_channels, use_1x1conv=True)) return nn.Sequential(*blk) def forward(self, x):