add support for pretrained + custom in channels. add tresnet encoders
bonlime committed Apr 9, 2020
1 parent d8d7b4b commit 1e86b17
Showing 6 changed files with 41 additions and 4 deletions.
3 changes: 3 additions & 0 deletions pytorch_tools/models/efficientnet.py
@@ -28,6 +28,7 @@
from pytorch_tools.utils.misc import add_docs_for
from pytorch_tools.utils.misc import make_divisible
from pytorch_tools.utils.misc import DEFAULT_IMAGENET_SETTINGS
from pytorch_tools.utils.misc import repeat_channels

# avoid overwriting doc string
wraps = partial(wraps, assigned=("__module__", "__name__", "__qualname__", "__annotations__"))
@@ -420,6 +421,8 @@ def _efficientnet(arch, pretrained=None, **kwargs):
)
state_dict["classifier.weight"] = model.state_dict()["classifier.weight"]
state_dict["classifier.bias"] = model.state_dict()["classifier.bias"]
if kwargs.get("in_channels", 3) != 3: # support pretrained for custom input channels
state_dict["conv_stem.weight"] = repeat_channels(state_dict["conv_stem.weight"], kwargs["in_channels"])
model.load_state_dict(state_dict)
patch_bn(model) # adjust epsilon
setattr(model, "pretrained_settings", cfg_settings)
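
With this change a pretrained EfficientNet accepts a non-RGB stem out of the box: the 3-channel `conv_stem` weights are repeated and rescaled by the new `repeat_channels` helper. A minimal usage sketch, mirroring the test added in this commit (assumes `efficientnet_b0` weights download normally):

import torch
from pytorch_tools import models

# stem weights are repeated from 3 to 5 channels and rescaled on load
m = models.efficientnet_b0(in_channels=5, pretrained="imagenet")
m.eval()
with torch.no_grad():
    out = m(torch.ones(2, 5, 128, 128))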
7 changes: 7 additions & 0 deletions pytorch_tools/models/resnet.py
@@ -20,6 +20,7 @@
from pytorch_tools.modules import bn_from_name
from pytorch_tools.utils.misc import add_docs_for
from pytorch_tools.utils.misc import DEFAULT_IMAGENET_SETTINGS
from pytorch_tools.utils.misc import repeat_channels

# avoid overwriting doc string
wraps = partial(wraps, assigned=("__module__", "__name__", "__qualname__", "__annotations__"))
@@ -471,6 +472,12 @@ def _resnet(arch, pretrained=None, **kwargs):
# if there is last_linear in state_dict, it's going to be overwritten
state_dict["fc.weight"] = model.state_dict()["last_linear.weight"]
state_dict["fc.bias"] = model.state_dict()["last_linear.bias"]
# support pretrained for custom input channels
# layer0. is needed to support se_resne(x)t weights
if kwargs.get("in_channels", 3) != 3:
old_weights = state_dict.get("conv1.weight")
old_weights = state_dict.get("layer0.conv1.weight") if old_weights is None else old_weights
state_dict["layer0.conv1.weight"] = repeat_channels(old_weights, kwargs["in_channels"])
model.load_state_dict(state_dict)
setattr(model, "pretrained_settings", cfg_settings)
return model
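
The same mechanism covers the ResNet family; the extra `layer0.` lookup is needed because se_resne(x)t checkpoints store the stem conv as `layer0.conv1.weight` while plain resnets use `conv1.weight`. A hedged sketch (the se_resnext constructor name is an assumption, exposed like the resnets):

from pytorch_tools import models

m1 = models.resnet34(in_channels=2, pretrained="imagenet")            # stem under "conv1.weight"
m2 = models.se_resnext50_32x4d(in_channels=2, pretrained="imagenet")  # stem under "layer0.conv1.weight"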
8 changes: 8 additions & 0 deletions pytorch_tools/models/tresnet.py
@@ -15,6 +15,7 @@
from pytorch_tools.modules import bn_from_name
from pytorch_tools.modules import ABN
from pytorch_tools.utils.misc import add_docs_for
from pytorch_tools.utils.misc import repeat_channels

# avoid overwriting doc string
wraps = partial(wraps, assigned=("__module__", "__name__", "__qualname__", "__annotations__"))
@@ -49,6 +50,8 @@ class TResNet(ResNet):
Activation for normalization layer. It's recommended to use `leaky_relu` with `inplaceabn`.
encoder (bool):
Flag to overwrite forward pass to return 5 tensors with different resolutions. Defaults to False.
NOTE: TResNet's first feature map is 4x smaller than the input, not 2x as in all other models,
so it CAN'T be used as an encoder in Unet and Linknet models.
drop_rate (float):
Dropout probability before classifier, for training. Defaults to 0.0.
drop_connect_rate (float):
@@ -119,6 +122,9 @@ def __init__(
self._initialize_weights(init_bn0=True)

def load_state_dict(self, state_dict, **kwargs):
if self.encoder:
state_dict.pop("last_linear.weight")
state_dict.pop("last_linear.bias")
nn.Module.load_state_dict(self, state_dict, **kwargs)

# fmt: off
@@ -209,6 +215,8 @@ def _resnet(arch, pretrained=None, **kwargs):
# if there is last_linear in state_dict, it's going to be overwritten
state_dict["last_linear.weight"] = model.state_dict()["last_linear.weight"]
state_dict["last_linear.bias"] = model.state_dict()["last_linear.bias"]
if kwargs.get("in_channels", 3) != 3: # support pretrained for custom input channels
state_dict["conv1.1.weight"] = repeat_channels(state_dict["conv1.1.weight"], kwargs["in_channels"] * 16, 3 * 16)
model.load_state_dict(state_dict)
# need to adjust some parameters to be align with original model
patch_blur_pool(model)
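
Two details here are easy to miss. First, the encoder variant pops `last_linear.*` before loading because an encoder has no classifier head. Second, the `* 16` factor in the `repeat_channels` call comes from TResNet's stem: a SpaceToDepth rearrangement with block size 4 runs before the first conv, so `conv1.1` sees `in_channels * 16` channels (48 for RGB). A standalone shape check of that rearrangement (a sketch of the idea, not the library's exact implementation):

import torch

def space_to_depth(x, block=4):
    # [N, C, H, W] -> [N, C * block**2, H // block, W // block]
    n, c, h, w = x.shape
    x = x.view(n, c, h // block, block, w // block, block)
    return x.permute(0, 1, 3, 5, 2, 4).reshape(n, c * block ** 2, h // block, w // block)

x = torch.randn(2, 5, 224, 224)   # a 5-channel input
print(space_to_depth(x).shape)    # torch.Size([2, 80, 56, 56]) -> conv1.1 needs 5 * 16 = 80 channels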
5 changes: 3 additions & 2 deletions pytorch_tools/segmentation_models/encoders.py
@@ -28,15 +28,16 @@
"efficientnet_b5": (2048, 128, 64, 40, 24),
"efficientnet_b6": (2304, 144, 72, 40, 32),
"efficientnet_b7": (2560, 160, 80, 48, 32),
"tresnetm": (2048, 1024, 128, 64, 64),
"tresnetl": (2432, 1216, 152, 76, 76),
"tresnetxl": (2656, 1328, 166, 83, 83),
}


def get_encoder(name, **kwargs):
if name not in models.__dict__:
raise ValueError(f"No such encoder: {name}")
kwargs["encoder"] = True
# if 'resne' in name:
# kwargs['dilated'] = True # dilate resnets for better performance
kwargs["pretrained"] = kwargs.pop("encoder_weights")
m = models.__dict__[name](**kwargs)
m.out_shapes = ENCODER_SHAPES[name]
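
The three new shape tuples register TResNet as a segmentation encoder; per the NOTE added to tresnet.py it still can't back Unet or Linknet, because its first feature map is 4x downsampled. A usage sketch based on this file's `get_encoder`:

from pytorch_tools.segmentation_models.encoders import get_encoder

enc = get_encoder("tresnetm", encoder_weights="imagenet")
print(enc.out_shapes)  # (2048, 1024, 128, 64, 64)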
13 changes: 13 additions & 0 deletions pytorch_tools/utils/misc.py
@@ -1,4 +1,5 @@
import os
import math
import time
import torch
import random
@@ -203,3 +204,15 @@ def make_divisible(v, divisor=8):
if new_v < 0.9 * v: # ensure round down does not go down by more than 10%.
new_v += divisor
return new_v

def repeat_channels(conv_weights, new_channels, old_channels=3):
"""Repeat channels to match new number of input channels
Args:
conv_weights (torch.Tensor): shape [*, old_channels, *, *]
new_channels (int): desired number of channels
old_channels (int): original number of channels
"""
rep_times = math.ceil(new_channels / old_channels)
new_weights = conv_weights.repeat(1, rep_times, 1, 1)[:, :new_channels, :, :]
new_weights *= old_channels / new_channels # to keep the same output amplitude
return new_weights
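
The `old_channels / new_channels` rescale keeps each filter's total weight magnitude roughly constant, so activations stay on the same scale as with the original pretrained stem; for filters that are uniform across channels the preservation is exact. A quick standalone check mirroring the function above:

import math
import torch

old = torch.ones(64, 3, 7, 7)                     # stand-in for pretrained RGB filters
new_c = 5
rep = math.ceil(new_c / 3)                        # 2 copies, then crop to 5 channels
new = old.repeat(1, rep, 1, 1)[:, :new_c] * (3 / new_c)

print(old[0].sum().item(), new[0].sum().item())   # 147.0 147.0 -> same total weight mass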
9 changes: 7 additions & 2 deletions tests/models/test_models.py
@@ -49,6 +49,12 @@ def test_custom_in_channels(arch):
with torch.no_grad():
m(torch.ones(2, 5, 128, 128))

@pytest.mark.parametrize("arch", EFFNET_NAMES[:2] + RESNET_NAMES[:2])
def test_pretrained_custom_in_channels(arch):
m = models.__dict__[arch](in_channels=5, pretrained="imagenet")
with torch.no_grad():
m(torch.ones(2, 5, 128, 128))


@pytest.mark.parametrize("arch", TEST_MODEL_NAMES)
def test_inplace_abn(arch):
@@ -73,7 +79,7 @@ def test_dilation(arch, output_stride):
W, H = INP.shape[-2:]
assert res.shape[-2:] == (W // output_stride, H // output_stride)

@pytest.mark.parametrize("arch", TEST_MODEL_NAMES)
@pytest.mark.parametrize("arch", EFFNET_NAMES[:2] + RESNET_NAMES[:2])
def test_drop_connect(arch):
m = models.__dict__[arch](drop_connect_rate=0.2)
_test_forward(m)
@@ -87,7 +93,6 @@ def test_drop_connect(arch):
"efficientnet_b2": 9109994,
"efficientnet_b3": 12233232,
}
# @pytest.mark.parametrize('name, num_params', NUM_PARAMS.values(), ids=list(NUM_PARAMS.keys()))
@pytest.mark.parametrize('name_num_params', zip(NUM_PARAMS.items()))
def test_num_parameters(name_num_params):
name, num_params = name_num_params[0]