From d8d7b4b6161374d00124ade0d468497aaf328a2b Mon Sep 17 00:00:00 2001
From: Emil Zakirov
Date: Wed, 8 Apr 2020 19:01:22 +0300
Subject: [PATCH] add stochastic depth to resnet

---
 pytorch_tools/models/resnet.py            | 14 +++++++++++++
 pytorch_tools/models/tresnet.py           |  6 ++++++
 pytorch_tools/modules/residual.py         | 16 +++++++++------
 pytorch_tools/segmentation_models/unet.py |  2 +-
 tests/models/test_models.py               |  5 +++++
 tests/models/test_weights.py              | 25 ++++++++++++++++++++++-
 6 files changed, 60 insertions(+), 8 deletions(-)

diff --git a/pytorch_tools/models/resnet.py b/pytorch_tools/models/resnet.py
index 11d8ce8..9dba9f0 100644
--- a/pytorch_tools/models/resnet.py
+++ b/pytorch_tools/models/resnet.py
@@ -70,6 +70,9 @@ class ResNet(nn.Module):
             Flag to overwrite forward pass to return 5 tensors with different resolutions. Defaults to False.
         drop_rate (float):
             Dropout probability before classifier, for training. Defaults to 0.0.
+        drop_connect_rate (float):
+            Drop rate for StochasticDepth. Randomly zeroes the residual branch for some samples in each block. Used as regularization during training.
+            The keep probability is linearly decreased from 1 to 1 - drop_connect_rate over the blocks. Ref: https://arxiv.org/abs/1603.09382
         global_pool (str):
             Global pooling type. One of 'avg', 'max', 'avgmax', 'catavgmax'. Defaults to 'avg'.
         init_bn0 (bool):
@@ -95,6 +98,7 @@ def __init__(
         antialias=False,
         encoder=False,
         drop_rate=0.0,
+        drop_connect_rate=0.0,
         global_pool="avg",
         init_bn0=True,
     ):
@@ -108,6 +112,9 @@
         self.block = block
         self.expansion = block.expansion
         self.norm_act = norm_act
+        self.block_idx = 0
+        self.num_blocks = sum(layers)
+        self.drop_connect_rate = drop_connect_rate
         super(ResNet, self).__init__()

         if deep_stem:
@@ -185,6 +192,7 @@ def _make_layer(
                 norm_layer=norm_layer,
                 norm_act=norm_act,
                 antialias=antialias,
+                keep_prob=self.keep_prob,
             )
         ]

@@ -201,6 +209,7 @@ def _make_layer(
                     norm_layer=norm_layer,
                     norm_act=norm_act,
                     antialias=antialias,
+                    keep_prob=self.keep_prob,
                 )
             )
         return nn.Sequential(*layers)
@@ -266,6 +275,11 @@ def load_state_dict(self, state_dict, **kwargs):
                 state_dict[k.replace("layer0.", "")] = state_dict.pop(k)
         super().load_state_dict(state_dict, **kwargs)

+    @property
+    def keep_prob(self):
+        keep_prob = 1 - self.drop_connect_rate * self.block_idx / self.num_blocks
+        self.block_idx += 1
+        return keep_prob

 # fmt: off
 CFGS = {
diff --git a/pytorch_tools/models/tresnet.py b/pytorch_tools/models/tresnet.py
index 05a6e26..5d2dfb4 100644
--- a/pytorch_tools/models/tresnet.py
+++ b/pytorch_tools/models/tresnet.py
@@ -51,6 +51,8 @@ class TResNet(ResNet):
             Flag to overwrite forward pass to return 5 tensors with different resolutions. Defaults to False.
         drop_rate (float):
             Dropout probability before classifier, for training. Defaults to 0.0.
+        drop_connect_rate (float):
+            Drop rate for StochasticDepth. Randomly zeroes the residual branch for some samples in each block. Used as regularization during training. Ref: https://arxiv.org/abs/1603.09382
     """

     def __init__(
@@ -65,6 +67,7 @@ def __init__(
         norm_act="leaky_relu",
         encoder=False,
         drop_rate=0.0,
+        drop_connect_rate=0.0,
     ):
         nn.Module.__init__(self)
         stem_width = int(64 * width_factor)
@@ -74,6 +77,9 @@ def __init__(
         self.groups = 1  # not really used but needed inside _make_layer
         self.base_width = 64  # used inside _make_layer
         self.norm_act = norm_act
+        self.block_idx = 0
+        self.num_blocks = sum(layers)
+        self.drop_connect_rate = drop_connect_rate

         # in the paper they use conv1x1 but in code conv3x3 (which seems better)
         self.conv1 = nn.Sequential(SpaceToDepth(), conv3x3(in_channels * 16, stem_width))
diff --git a/pytorch_tools/modules/residual.py b/pytorch_tools/modules/residual.py
index eb85394..c8da1f4 100644
--- a/pytorch_tools/modules/residual.py
+++ b/pytorch_tools/modules/residual.py
@@ -151,6 +151,7 @@ def __init__(
         norm_layer=ABN,
         norm_act="relu",
         antialias=False,
+        keep_prob=1,
     ):
         super(BasicBlock, self).__init__()
         antialias = antialias and stride == 2
@@ -167,6 +168,7 @@ def __init__(
         self.downsample = downsample
         self.blurpool = BlurPool(channels=planes) if antialias else nn.Identity()
         self.antialias = antialias
+        self.drop_connect = DropConnect(keep_prob) if keep_prob < 1 else nn.Identity()

     def forward(self, x):
         residual = x
@@ -180,11 +182,11 @@ def forward(self, x):
         if self.antialias:
             out = self.blurpool(out)
         out = self.conv2(out)
-        # avoid 2 inplace ops by chaining into one long op. Neede for inplaceabn
+        # avoid 2 inplace ops by chaining into one long op. Needed for inplaceabn
         if self.se_module is not None:
-            out = self.se_module(self.bn2(out)) + residual
+            out = self.drop_connect(self.se_module(self.bn2(out))) + residual
         else:
-            out = self.bn2(out) + residual
+            out = self.drop_connect(self.bn2(out)) + residual
         return self.final_act(out)


@@ -204,6 +206,7 @@ def __init__(
         norm_layer=ABN,
         norm_act="relu",
         antialias=False,
+        keep_prob=1,  # for drop connect
     ):
         super(Bottleneck, self).__init__()
         antialias = antialias and stride == 2
@@ -222,6 +225,7 @@ def __init__(
         self.downsample = downsample
         self.blurpool = BlurPool(channels=width) if antialias else nn.Identity()
         self.antialias = antialias
+        self.drop_connect = DropConnect(keep_prob) if keep_prob < 1 else nn.Identity()

     def forward(self, x):
         residual = x
@@ -241,9 +245,9 @@ def forward(self, x):
         out = self.conv3(out)
         # avoid 2 inplace ops by chaining into one long op
         if self.se_module is not None:
-            out = self.se_module(self.bn3(out)) + residual
+            out = self.drop_connect(self.se_module(self.bn3(out))) + residual
         else:
-            out = self.bn3(out) + residual
+            out = self.drop_connect(self.bn3(out)) + residual
         return self.final_act(out)

 # TResnet models use slightly modified versions of BasicBlock and Bottleneck
@@ -292,5 +296,5 @@ def forward(self, x):
         out = self.conv3(out)

         # avoid 2 inplace ops by chaining into one long op
-        out = self.bn3(out) + residual
+        out = self.drop_connect(self.bn3(out)) + residual
         return self.final_act(out)
\ No newline at end of file
diff --git a/pytorch_tools/segmentation_models/unet.py b/pytorch_tools/segmentation_models/unet.py
index 3778428..d784e37 100644
--- a/pytorch_tools/segmentation_models/unet.py
+++ b/pytorch_tools/segmentation_models/unet.py
@@ -38,7 +38,7 @@ def __init__(
         self.layer3 = UnetDecoderBlock(in_channels[2], out_channels[2], **bn_params)
         self.layer4 = UnetDecoderBlock(in_channels[3], out_channels[3], **bn_params)
         self.layer5 = UnetDecoderBlock(in_channels[4], out_channels[4], **bn_params)
-        self.dropout = nn.Dropout2d(drop_rate, inplace=True)
+        self.dropout = nn.Dropout2d(drop_rate, inplace=False)  # inplace=True raises a backprop error
         self.final_conv = conv1x1(out_channels[4], final_channels)

         initialize(self)
diff --git a/tests/models/test_models.py b/tests/models/test_models.py
index 3dd58ff..0cce9db 100644
--- a/tests/models/test_models.py
+++ b/tests/models/test_models.py
@@ -73,6 +73,11 @@ def test_dilation(arch, output_stride):
     W, H = INP.shape[-2:]
     assert res.shape[-2:] == (W // output_stride, H // output_stride)

+@pytest.mark.parametrize("arch", TEST_MODEL_NAMES)
+def test_drop_connect(arch):
+    m = models.__dict__[arch](drop_connect_rate=0.2)
+    _test_forward(m)
+
 NUM_PARAMS = {
     "tresnetm": 31389032,
     "tresnetl": 55989256,
diff --git a/tests/models/test_weights.py b/tests/models/test_weights.py
index 321bc4b..7687dac 100644
--- a/tests/models/test_weights.py
+++ b/tests/models/test_weights.py
@@ -1,8 +1,9 @@
 ## test that imagenet pretrained weights are valid and able to classify correctly the cat and dog
+import torch
+import pytest
 import numpy as np
 from PIL import Image
-import pytest

 from pytorch_tools.utils.preprocessing import get_preprocessing_fn
 from pytorch_tools.utils.visualization import tensor_from_rgb_image

@@ -53,3 +54,25 @@ def test_imagenet_pretrain(arch):
     im = im.view(1, *im.shape).float()
     pred_cls = m(im).argmax()
     assert pred_cls == im_cls
+
+# test that the output mean for a fixed input stays the same
+MODEL_NAMES2 = [
+    "resnet34",
+    "se_resnet50",
+    "efficientnet_b0",
+]
+
+MODEL_MEAN = {
+    "resnet34": 7.6799e-06,
+    "se_resnet50": -2.6095e-06,
+    "efficientnet_b0": 0.0070,
+}
+
+@pytest.mark.parametrize("arch", MODEL_NAMES2)
+def test_output_mean(arch):
+    m = models.__dict__[arch](pretrained="imagenet")
+    m.eval()
+    inp = torch.ones(1, 3, 256, 256)
+    with torch.no_grad():
+        out = m(inp).mean().numpy()
+    assert np.allclose(out, MODEL_MEAN[arch], rtol=1e-4, atol=1e-4)
\ No newline at end of file
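
--
Reviewer note: this patch passes keep_prob into DropConnect, but the DropConnect
module itself lives in pytorch_tools/modules and is not part of this diff. For
context, here is a minimal sketch of the standard stochastic-depth formulation
from the referenced paper (https://arxiv.org/abs/1603.09382); the class name and
details are illustrative and may differ from the repo's actual implementation.

    import torch
    from torch import nn

    class DropConnect(nn.Module):
        """Randomly zeroes the whole residual branch for some samples in a batch.

        During training each sample's branch output is kept with probability
        keep_prob and rescaled by 1 / keep_prob so its expected value stays
        unchanged; in eval mode the module is an identity op.
        """

        def __init__(self, keep_prob):
            super().__init__()
            self.keep_prob = keep_prob

        def forward(self, x):
            if not self.training:
                return x
            # one Bernoulli draw per sample, broadcast over all remaining dims
            mask_shape = (x.shape[0],) + (1,) * (x.ndim - 1)
            mask = torch.rand(mask_shape, dtype=x.dtype, device=x.device) < self.keep_prob
            return x / self.keep_prob * mask.to(x.dtype)

Schedule check: with drop_connect_rate=0.2 and a resnet50 layout (layers=[3, 4, 6, 3],
so 16 blocks), the keep_prob property yields 1.0 for the first block, 0.9875 for the
second, and so on down to 0.8125 for the last one, i.e. the linear decay described
in the new docstring.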