diff --git a/README.md b/README.md index c29c172..cdfa53b 100644 --- a/README.md +++ b/README.md @@ -12,7 +12,7 @@ Tool box for PyTorch for fast prototyping. * [TTA wrapper](./pytorch_tools/tta_wrapper/) - wrapper for easy test-time augmentation # Installation -Requeres GPU drivers and CUDA already installed. +Requires GPU drivers and CUDA already installed. `pip install -v --no-cache-dir --global-option="--cpp_ext" --global-option="--cuda_ext" git+https://github.com/NVIDIA/apex.git` `pip install git+https://github.com/bonlime/pytorch-tools.git@master` diff --git a/pytorch_tools/__init__.py b/pytorch_tools/__init__.py index 85f6dc8..651659b 100644 --- a/pytorch_tools/__init__.py +++ b/pytorch_tools/__init__.py @@ -1,4 +1,4 @@ -__version__ = "0.1.4" +__version__ = "0.1.5" from . import fit_wrapper from . import losses @@ -9,3 +9,4 @@ from . import segmentation_models from . import tta_wrapper from . import utils +from . import detection_models diff --git a/pytorch_tools/detection_models/__init__.py b/pytorch_tools/detection_models/__init__.py index e69de29..04079d7 100644 --- a/pytorch_tools/detection_models/__init__.py +++ b/pytorch_tools/detection_models/__init__.py @@ -0,0 +1,12 @@ +from .retinanet import RetinaNet +from .retinanet import retinanet_r50_fpn +from .retinanet import retinanet_r101_fpn + +from .efficientdet import EfficientDet +from .efficientdet import efficientdet_d0 +from .efficientdet import efficientdet_d1 +from .efficientdet import efficientdet_d2 +from .efficientdet import efficientdet_d3 +from .efficientdet import efficientdet_d4 +from .efficientdet import efficientdet_d5 +from .efficientdet import efficientdet_d6 diff --git a/pytorch_tools/detection_models/efficientdet.py b/pytorch_tools/detection_models/efficientdet.py new file mode 100644 index 0000000..0446b56 --- /dev/null +++ b/pytorch_tools/detection_models/efficientdet.py @@ -0,0 +1,329 @@ +import logging +from copy import deepcopy +from functools import wraps + +import torch +import torch.nn as nn +from torch.hub import load_state_dict_from_url + +from pytorch_tools.modules import ABN +from pytorch_tools.modules.bifpn import BiFPN +from pytorch_tools.modules import bn_from_name +from pytorch_tools.modules.residual import conv1x1 +from pytorch_tools.modules.residual import conv3x3 +from pytorch_tools.modules.residual import DepthwiseSeparableConv +from pytorch_tools.modules.tf_same_ops import conv_to_same_conv +from pytorch_tools.modules.tf_same_ops import maxpool_to_same_maxpool + +from pytorch_tools.segmentation_models.encoders import get_encoder + +import pytorch_tools.utils.box as box_utils +from pytorch_tools.utils.misc import DEFAULT_IMAGENET_SETTINGS +from pytorch_tools.utils.misc import initialize_iterator + + +def patch_bn(module): + """TF ported weights use slightly different eps in BN. Need to adjust for better performance""" + if isinstance(module, ABN): + module.eps = 1e-3 + module.momentum = 1e-2 + for m in module.children(): + patch_bn(m) + + +class EfficientDet(nn.Module): + """TODO: add docstring""" + + def __init__( + self, + pretrained="coco", # Not used. 
here for proper signature + encoder_name="efficientnet_d0", + encoder_weights="imagenet", + pyramid_channels=64, + num_fpn_layers=3, + num_head_repeats=3, + num_classes=90, + drop_connect_rate=0, + encoder_norm_layer="abn", # TODO: set to frozenabn when ready + encoder_norm_act="swish", + decoder_norm_layer="abn", + decoder_norm_act="swish", + match_tf_same_padding=False, + ): + super().__init__() + self.encoder = get_encoder( + encoder_name, + norm_layer=encoder_norm_layer, + norm_act=encoder_norm_act, + encoder_weights=encoder_weights, + ) + norm_layer = bn_from_name(decoder_norm_layer) + bn_args = dict(norm_layer=norm_layer, norm_act=decoder_norm_act) + self.pyramid6 = nn.Sequential( + conv1x1(self.encoder.out_shapes[0], pyramid_channels, bias=True), + norm_layer(pyramid_channels, activation="identity"), + nn.MaxPool2d(3, stride=2, padding=1), + ) + self.pyramid7 = nn.MaxPool2d(3, stride=2, padding=1) # in EffDet it's a simple maxpool + + self.bifpn = BiFPN( + self.encoder.out_shapes[:-2], + pyramid_channels=pyramid_channels, + num_layers=num_fpn_layers, + **bn_args, + ) + + def make_head(out_size): + layers = [] + for _ in range(num_head_repeats): + # TODO: add drop connect + layers += [DepthwiseSeparableConv(pyramid_channels, pyramid_channels, use_norm=False)] + layers += [DepthwiseSeparableConv(pyramid_channels, out_size, use_norm=False)] + return nn.ModuleList(layers) + + # The convolution layers in the head are shared among all levels, but + # each level has its batch normalization to capture the statistical + # difference among different levels. + def make_head_norm(): + return nn.ModuleList( + [ + nn.ModuleList( + [ + norm_layer(pyramid_channels, activation=decoder_norm_act) + for _ in range(num_head_repeats) + ] + + [nn.Identity()] # no bn after last depthwise conv + ) + for _ in range(5) + ] + ) + + anchors_per_location = 9 # TODO: maybe allow to pass this arg? 
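+        # every spatial position predicts `anchors_per_location` boxes, so the class head below
+        # outputs num_classes * anchors_per_location channels and the box head 4 * anchors_per_location;
+        # `forward` later reshapes them to (BS, num_anchors, num_classes) and (BS, num_anchors, 4)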
+        self.cls_head_convs = make_head(num_classes * anchors_per_location)
+        self.cls_head_norms = make_head_norm()
+        self.box_head_convs = make_head(4 * anchors_per_location)
+        self.box_head_norms = make_head_norm()
+        self.num_classes = num_classes
+        self.num_head_repeats = num_head_repeats
+
+        patch_bn(self)
+        self._initialize_weights()
+        if match_tf_same_padding:
+            conv_to_same_conv(self)
+            maxpool_to_same_maxpool(self)
+
+    # Name from mmdetection for convenience
+    def extract_features(self, x):
+        """Extract features from backbone + enhance with BiFPN"""
+        # don't use p2 and p1
+        p5, p4, p3, _, _ = self.encoder(x)
+        # coarser FPN levels
+        p6 = self.pyramid6(p5)
+        p7 = self.pyramid7(p6)
+        features = [p7, p6, p5, p4, p3]
+        # enhance features
+        features = self.bifpn(features)
+        # want features from lowest OS to highest to align with `generate_anchors_boxes` function
+        features = list(reversed(features))
+        return features
+
+    def forward(self, x):
+        features = self.extract_features(x)
+        class_outputs = []
+        box_outputs = []
+        for feat, (cls_bns, box_bns) in zip(features, zip(self.cls_head_norms, self.box_head_norms)):
+            cls_feat, box_feat = feat, feat
+            # it looks like with drop_connect there is an additional residual here
+            # TODO: need to investigate using pretrained weights
+            for cls_conv, cls_bn in zip(self.cls_head_convs, cls_bns):
+                cls_feat = cls_bn(cls_conv(cls_feat))
+            for box_conv, box_bn in zip(self.box_head_convs, box_bns):
+                box_feat = box_bn(box_conv(box_feat))
+
+            box_feat = box_feat.permute(0, 2, 3, 1)
+            box_outputs.append(box_feat.contiguous().view(box_feat.shape[0], -1, 4))
+
+            cls_feat = cls_feat.permute(0, 2, 3, 1)
+            class_outputs.append(cls_feat.contiguous().view(cls_feat.shape[0], -1, self.num_classes))
+
+        class_outputs = torch.cat(class_outputs, 1)
+        box_outputs = torch.cat(box_outputs, 1)
+        # my anchors are in [x1, y1, x2, y2] format while pretrained weights are in [y1, x1, y2, x2] format
+        # it may be confusing to reorder x and y every time later so I do it once here.
it gives + # compatability with pretrained weigths from Google and doesn't affect training from scratch + # box_outputs = box_outputs[..., [1, 0, 3, 2]] # TODO: return back + return class_outputs, box_outputs + + @torch.no_grad() + def predict(self, x): + """Run forward on given images and decode raw prediction into bboxes + Returns: bboxes, scores, classes + """ + class_outputs, box_outputs = self.forward(x) + anchors = box_utils.generate_anchors_boxes(x.shape[-2:])[0] + return box_utils.decode(class_outputs, box_outputs, anchors) + + def _initialize_weights(self): + # init everything except encoder + no_encoder_m = [m for n, m in self.named_modules() if not "encoder" in n] + initialize_iterator(no_encoder_m) + # need to init last bias so that after sigmoid it's 0.01 + cls_bias_init = -torch.log(torch.tensor((1 - 0.01) / 0.01)) # -4.59 + nn.init.constant_(self.cls_head_convs[-1][1].bias, cls_bias_init) + + +PRETRAIN_SETTINGS = {**DEFAULT_IMAGENET_SETTINGS, "input_size": (512, 512), "crop_pct": 1, "num_classes": 90} + +# fmt: off +CFGS = { + "efficientdet_d0": { + "default": { + "params": { + "encoder_name":"efficientnet_b0", + "pyramid_channels":64, + "num_fpn_layers":3, + "num_head_repeats":3, + }, + **PRETRAIN_SETTINGS, + }, + "coco": {"url": "https://github.com/bonlime/pytorch-tools/releases/download/v0.1.5/efficientdet-d0.pth",}, + }, + "efficientdet_d1": { + "default": { + "params": { + "encoder_name":"efficientnet_b1", + "pyramid_channels":88, + "num_fpn_layers":4, + "num_head_repeats":3, + }, + **PRETRAIN_SETTINGS, + "input_size": (640, 640), + }, + "coco": {"url": "https://github.com/bonlime/pytorch-tools/releases/download/v0.1.5/efficientdet-d1.pth",}, + }, + "efficientdet_d2": { + "default": { + "params": { + "encoder_name":"efficientnet_b2", + "pyramid_channels":112, + "num_fpn_layers":5, + "num_head_repeats":3, + }, + **PRETRAIN_SETTINGS, + "input_size": (768, 768), + }, + "coco": {"url": "https://github.com/bonlime/pytorch-tools/releases/download/v0.1.5/efficientdet-d2.pth",}, + }, + "efficientdet_d3": { + "default": { + "params": { + "encoder_name":"efficientnet_b3", + "pyramid_channels":160, + "num_fpn_layers":6, + "num_head_repeats":4, + }, + **PRETRAIN_SETTINGS, + "input_size": (896, 896), + }, + "coco": {"url": "https://github.com/bonlime/pytorch-tools/releases/download/v0.1.5/efficientdet-d3.pth",}, + }, + "efficientdet_d4": { + "default": { + "params": { + "encoder_name":"efficientnet_b4", + "pyramid_channels":224, + "num_fpn_layers":7, + "num_head_repeats":4, + }, + **PRETRAIN_SETTINGS, + "input_size": (1024, 1024), + }, + "coco": {"url": "https://github.com/bonlime/pytorch-tools/releases/download/v0.1.5/efficientdet-d4.pth",}, + }, + "efficientdet_d5": { + "default": { + "params": { + "encoder_name":"efficientnet_b5", + "pyramid_channels":288, + "num_fpn_layers":7, + "num_head_repeats":4, + }, + **PRETRAIN_SETTINGS, + "input_size": (1280, 1280), + }, + "coco": {"url": "https://github.com/bonlime/pytorch-tools/releases/download/v0.1.5/efficientdet-d5.pth",}, + }, + "efficientdet_d6": { + "default": { + "params": { + "encoder_name":"efficientnet_b6", + "pyramid_channels":384, + "num_fpn_layers":8, + "num_head_repeats":5, + }, + **PRETRAIN_SETTINGS, + "input_size": (1280, 1280), + }, + "coco": {"url": "https://github.com/bonlime/pytorch-tools/releases/download/v0.1.5/efficientdet-d6.pth",}, + }, +} +# fmt: on + + +def _efficientdet(arch, pretrained=None, **kwargs): + cfgs = deepcopy(CFGS) + cfg_settings = cfgs[arch]["default"] + cfg_params = cfg_settings.pop("params") + 
kwargs.update(cfg_params) + model = EfficientDet(**kwargs) + if pretrained: + state_dict = load_state_dict_from_url(cfgs[arch][pretrained]["url"]) + kwargs_cls = kwargs.get("num_classes", None) + if kwargs_cls and kwargs_cls != cfg_settings["num_classes"]: + logging.warning( + f"Using model pretrained for {cfg_settings['num_classes']} classes with {kwargs_cls} classes. Last layer is initialized randomly" + ) + last_conv_name = f"cls_head_convs.{kwargs['num_head_repeats']}.1" + state_dict[f"{last_conv_name}.weight"] = model.state_dict()[f"{last_conv_name}.weight"] + state_dict[f"{last_conv_name}.bias"] = model.state_dict()[f"{last_conv_name}.bias"] + model.load_state_dict(state_dict) + setattr(model, "pretrained_settings", cfg_settings) + return model + + +@wraps(EfficientDet) +def efficientdet_d0(pretrained="coco", **kwargs): + return _efficientdet("efficientdet_d0", pretrained, **kwargs) + + +@wraps(EfficientDet) +def efficientdet_d1(pretrained="coco", **kwargs): + return _efficientdet("efficientdet_d1", pretrained, **kwargs) + + +@wraps(EfficientDet) +def efficientdet_d2(pretrained="coco", **kwargs): + return _efficientdet("efficientdet_d2", pretrained, **kwargs) + + +@wraps(EfficientDet) +def efficientdet_d3(pretrained="coco", **kwargs): + return _efficientdet("efficientdet_d3", pretrained, **kwargs) + + +@wraps(EfficientDet) +def efficientdet_d4(pretrained="coco", **kwargs): + return _efficientdet("efficientdet_d4", pretrained, **kwargs) + + +@wraps(EfficientDet) +def efficientdet_d5(pretrained="coco", **kwargs): + return _efficientdet("efficientdet_d5", pretrained, **kwargs) + + +@wraps(EfficientDet) +def efficientdet_d6(pretrained="coco", **kwargs): + return _efficientdet("efficientdet_d6", pretrained, **kwargs) + + +# No B7 because it's the same model as B6 but with larger input diff --git a/pytorch_tools/detection_models/retinanet.py b/pytorch_tools/detection_models/retinanet.py index 5ebc360..2fbdf74 100644 --- a/pytorch_tools/detection_models/retinanet.py +++ b/pytorch_tools/detection_models/retinanet.py @@ -1,74 +1,188 @@ -# import torch +import logging +from copy import deepcopy +from functools import wraps + +import torch import torch.nn as nn -import torch.nn.functional as F + from pytorch_tools.modules.fpn import FPN -# from pytorch_tools.modules.bifpn import BiFPN from pytorch_tools.modules import bn_from_name -# from pytorch_tools.modules.residual import conv1x1 from pytorch_tools.modules.residual import conv3x3 -# from pytorch_tools.modules.decoder import SegmentationUpsample -# from pytorch_tools.utils.misc import initialize from pytorch_tools.segmentation_models.encoders import get_encoder +import pytorch_tools.utils.box as box_utils +from pytorch_tools.utils.misc import DEFAULT_IMAGENET_SETTINGS +from pytorch_tools.utils.misc import initialize_iterator class RetinaNet(nn.Module): + """RetinaNet + Main difference from other implementations are: + * support of any custom encoder from this repo + * optional normalization layer in box classification head + * ability to freeze batch norm in encoder with one line + + Args: + pretrained (str): one of `coco` or None. if `coco` - load pretrained weights + encoder_name (str): name of classification model (without last dense layers) used as feature + extractor to build detection model + encoder_weights (str): one of ``None`` (random initialization), ``imagenet`` (pre-trained on ImageNet) + pyramid_channels (int): size of features after FPN. 
Default 256 + num_classes (int): a number of classes to predict + class_outputs shape is (BS, *, NUM_CLASSES) where each row in * corresponds to one bbox + encoder_norm_layer (str): Normalization layer to use in encoder. If using pretrained + it should be the same as in pretrained weights + encoder_norm_act (str): Activation for normalization layer in encoder + decoder_norm_layer (str): Normalization to use in head convolutions. Default (none) is not to use normalization. + Current implementation is optimized for `GroupNorm`, not `BatchNorm` check code for details + decoder_norm_act (str): Activation for normalization layer in head convolutions + + Ref: + Focal Loss for Dense Object Detection - https://arxiv.org/abs/1708.02002 + Mmdetection - https://github.com/open-mmlab/mmdetection/ (at commit b9daf23) + TF TPU version - https://github.com/tensorflow/tpu/tree/master/models/official/retinanet + """ + def __init__( - self, - encoder_name="resnet34", - encoder_weights="imagenet", - pyramid_channels=256, + self, + pretrained="coco", # not used here for proper signature + encoder_name="resnet50", + encoder_weights="imagenet", + pyramid_channels=256, num_classes=80, - norm_layer="abn", - norm_act="relu", + # drop_connect_rate=0, # TODO: add + encoder_norm_layer="abn", + encoder_norm_act="relu", + decoder_norm_layer="none", # None by default to match detectron & mmdet versions + decoder_norm_act="relu", **encoder_params, - ): + ): super().__init__() self.encoder = get_encoder( encoder_name, - norm_layer=norm_layer, - norm_act=norm_act, + norm_layer=encoder_norm_layer, + norm_act=encoder_norm_act, encoder_weights=encoder_weights, **encoder_params, ) - norm_layer = bn_from_name(norm_layer) - self.pyramid6 = conv3x3(256, 256, 2, bias=True) - self.pyramid7 = conv3x3(256, 256, 2, bias=True) - self.fpn = FPN( - self.encoder.out_shapes[:-2], - pyramid_channels=pyramid_channels, + norm_layer = bn_from_name(decoder_norm_layer) + self.pyramid6 = nn.Sequential( + conv3x3(self.encoder.out_shapes[0], pyramid_channels, 2, bias=True), + norm_layer(pyramid_channels, activation="identity"), ) + self.pyramid7 = nn.Sequential( + conv3x3(pyramid_channels, pyramid_channels, 2, bias=True), + norm_layer(pyramid_channels, activation="identity"), + ) + self.fpn = FPN(self.encoder.out_shapes[:-2], pyramid_channels=pyramid_channels) - def make_head(out_size): + def make_final_convs(): layers = [] for _ in range(4): - # some implementations don't use BN here but I think it's needed - # TODO: test how it affects results - layers += [nn.Conv2d(256, 256, 3, padding=1), norm_layer(256, activation=norm_act)] - # layers += [nn.Conv2d(256, 256, 3, padding=1), nn.ReLU()] - - layers += [nn.Conv2d(256, out_size, 3, padding=1)] + layers += [conv3x3(pyramid_channels, pyramid_channels, bias=True)] + # Norm here is fine for GroupNorm but for BN it should be implemented the other way + # see EffDet for example. 
Maybe need to change this implementation to align with EffDet + layers += [norm_layer(pyramid_channels, activation=decoder_norm_act)] return nn.Sequential(*layers) - self.ratios = [1.0, 2.0, 0.5] - self.scales = [4 * 2 ** (i / 3) for i in range(3)] - anchors = len(self.ratios) * len(self.scales) # 9 - - self.cls_head = make_head(num_classes * anchors) - self.box_head = make_head(4 * anchors) + anchors_per_location = 9 + self.cls_convs = make_final_convs() + self.cls_head_conv = conv3x3(pyramid_channels, num_classes * anchors_per_location, bias=True) + self.box_convs = make_final_convs() + self.box_head_conv = conv3x3(pyramid_channels, 4 * anchors_per_location, bias=True) + self.num_classes = num_classes + self._initialize_weights() - def forward(self, x): + # Name from mmdetectin for convenience + def extract_features(self, x): + """Extract features from backbone + enchance with FPN""" # don't use p2 and p1 p5, p4, p3, _, _ = self.encoder(x) + # coarser FPN levels + p6 = self.pyramid6(p5) + p7 = self.pyramid7(p6.relu()) # in mmdet there is no relu here. but i think it's needed # enhance features p5, p4, p3 = self.fpn([p5, p4, p3]) - # coarsers FPN levels - p6 = self.pyramid6(p5) - p7 = self.pyramid7(F.relu(p6)) - features = [p7, p6, p5, p4, p3] - # TODO: (18.03.20) TF implementation has additional BN here before class/box outputs - class_outputs = [self.cls_head(f) for f in features] - box_outputs = [self.box_head(f) for f in features] + # want features from lowest OS to highest to align with `generate_anchors_boxes` function + features = [p3, p4, p5, p6, p7] + return features + + def forward(self, x): + features = self.extract_features(x) + class_outputs = [] + box_outputs = [] + for feat in features: + cls_feat = self.cls_head_conv(self.cls_convs(feat)) + box_feat = self.box_head_conv(self.box_convs(feat)) + cls_feat = cls_feat.permute(0, 2, 3, 1).contiguous().view(x.size(0), -1, self.num_classes) + box_feat = box_feat.permute(0, 2, 3, 1).contiguous().view(x.size(0), -1, 4) + class_outputs.append(cls_feat) + box_outputs.append(box_feat) + class_outputs = torch.cat(class_outputs, 1) + box_outputs = torch.cat(box_outputs, 1) return class_outputs, box_outputs - - \ No newline at end of file + @torch.no_grad() + def predict(self, x): + """Run forward on given images and decode raw prediction into bboxes""" + class_outputs, box_outputs = self.forward(x) + anchors = box_utils.generate_anchors_boxes(x.shape[-2:])[0] + return box_utils.decode(class_outputs, box_outputs, anchors) + + def _initialize_weights(self): + # init everything except encoder + no_encoder_m = [m for n, m in self.named_modules() if not "encoder" in n] + initialize_iterator(no_encoder_m) + # need to init last bias so that after sigmoid it's 0.01 + cls_bias_init = -torch.log(torch.tensor((1 - 0.01) / 0.01)) # -4.59 + nn.init.constant_(self.cls_head_conv.bias, cls_bias_init) + + +# Don't really know input size for the models. 
512 is just a guess +PRETRAIN_SETTINGS = {**DEFAULT_IMAGENET_SETTINGS, "input_size": (512, 512), "crop_pct": 1, "num_classes": 80} + +# weights below were ported from caffe to mmdetection and them ported again by @bonlime +# mmdetection resnet is slightly different (stride 2 in conv1x1 instead of conv3x3) +# and order of anchors is also different so it's impossible to do inference using this weights but they work +# much better for transfer learning than starting from imagenet pretrain +# fmt: off +CFGS = { + "retinanet_r50_fpn": { + "default": {"params": {"encoder_name":"resnet50"}, **PRETRAIN_SETTINGS}, + "coco": {"url": "https://github.com/bonlime/pytorch-tools/releases/download/v0.1.5/retinanet_r50_fpn_3x_coco.pth",}, + }, + "retinanet_r101_fpn": { + "default": {"params": {"encoder_name":"resnet101"}, **PRETRAIN_SETTINGS}, + "coco": {"url": "https://github.com/bonlime/pytorch-tools/releases/download/v0.1.5/retinanet_r101_fpn_2x_coco.pth",}, + }, +} +# fmt: on + + +def _retinanet(arch, pretrained=None, **kwargs): + cfgs = deepcopy(CFGS) + cfg_settings = cfgs[arch]["default"] + cfg_params = cfg_settings.pop("params") + kwargs.update(cfg_params) + model = RetinaNet(**kwargs) + if pretrained: + state_dict = torch.hub.load_state_dict_from_url(cfgs[arch][pretrained]["url"]) + kwargs_cls = kwargs.get("num_classes", None) + if kwargs_cls and kwargs_cls != cfg_settings["num_classes"]: + logging.warning( + f"Using model pretrained for {cfg_settings['num_classes']} classes with {kwargs_cls} classes. Last layer is initialized randomly" + ) + state_dict["cls_head_conv.weight"] = model.state_dict()["cls_head_conv.weight"] + state_dict["cls_head_conv.bias"] = model.state_dict()["cls_head_conv.bias"] + model.load_state_dict(state_dict) + setattr(model, "pretrained_settings", cfg_settings) + return model + + +@wraps(RetinaNet) +def retinanet_r50_fpn(pretrained="coco", **kwargs): + return _retinanet("retinanet_r50_fpn", pretrained, **kwargs) + + +@wraps(RetinaNet) +def retinanet_r101_fpn(pretrained="coco", **kwargs): + return _retinanet("retinanet_r101_fpn", pretrained, **kwargs) diff --git a/pytorch_tools/fit_wrapper/README.md b/pytorch_tools/fit_wrapper/README.md index fe2e97b..96e00da 100644 --- a/pytorch_tools/fit_wrapper/README.md +++ b/pytorch_tools/fit_wrapper/README.md @@ -3,6 +3,8 @@ This module contains model runner (very close to `model.fit` in Keras) for **sup `Runner` is used to actually run the train loop calling `Callbacks` at appropriate times. Mixed precision (powered by apex) is supported implicitly. Users are expected to initialize their models before creating runner using `apex.amp.initialize`. +Main idea of this runner is to be as simple as possible. All core functionality is ~100 lines of code. + ## Minimal example This code will run training for 5 epochs. ```python @@ -40,4 +42,11 @@ runner = pt.fit_wrapper.Runner( ] ) runner.fit(train_loader, epochs=5, val_loader=val_loader) -``` \ No newline at end of file +``` + +## How to +### Add custom step logic +Monkey patch `Runner._make_step` function with yours + +### Process multiple inputs/outputs +Instead of modifying Runner move this logic inside Loss function. 
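+
+For example, a minimal sketch of such a loss wrapper (the class and attribute names here are illustrative, not part of the library):
+```python
+import torch.nn as nn
+
+class TupleLoss(nn.Module):
+    """Combines two criterions for a model that returns (mask, label) and a loader that yields (mask_target, label_target)."""
+
+    def __init__(self, seg_loss, cls_loss, cls_weight=0.4):
+        super().__init__()
+        self.seg_loss = seg_loss
+        self.cls_loss = cls_loss
+        self.cls_weight = cls_weight
+
+    def forward(self, outputs, targets):
+        # unpack the multiple outputs/targets here instead of changing Runner._make_step
+        seg_pred, cls_pred = outputs
+        seg_target, cls_target = targets
+        return self.seg_loss(seg_pred, seg_target) + self.cls_weight * self.cls_loss(cls_pred, cls_target)
+```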
\ No newline at end of file diff --git a/pytorch_tools/fit_wrapper/callbacks.py b/pytorch_tools/fit_wrapper/callbacks.py index 1c782cd..70bb146 100644 --- a/pytorch_tools/fit_wrapper/callbacks.py +++ b/pytorch_tools/fit_wrapper/callbacks.py @@ -3,13 +3,14 @@ import logging from tqdm import tqdm from enum import Enum +from copy import deepcopy from collections import OrderedDict from collections import defaultdict import numpy as np import torch from torch.utils.tensorboard import SummaryWriter from .state import RunnerState -from pytorch_tools.utils.misc import listify +import pytorch_tools.utils.misc as utils from pytorch_tools.utils.visualization import plot_confusion_matrix from pytorch_tools.utils.visualization import render_figure_to_tensor @@ -65,7 +66,7 @@ class Callbacks(Callback): def __init__(self, callbacks): super().__init__() - self.callbacks = listify(callbacks) + self.callbacks = utils.listify(callbacks) def set_state(self, state): for callback in self.callbacks: @@ -113,24 +114,29 @@ class Timer(Callback): def __init__(self): super().__init__() self.has_printed = False + self.timer = utils.TimeMeter() def on_batch_begin(self): - self.state.timer.batch_start() + self.timer.batch_start() def on_batch_end(self): - self.state.timer.batch_end() + self.timer.batch_end() + + def on_loader_begin(self): + self.timer.reset() def on_loader_end(self): if not self.has_printed: self.has_printed = True - d_time = self.state.timer.data_time.avg_smooth - b_time = self.state.timer.batch_time.avg_smooth + d_time = self.timer.data_time.avg_smooth + b_time = self.timer.batch_time.avg_smooth print(f"\nTimeMeter profiling. Data time: {d_time:.2E}s. Model time: {b_time:.2E}s \n") class PhasesScheduler(Callback): """ Scheduler that uses `phases` to process updates. + Supported `mode`'s are {`linear`, `cos`, `poly`} Args: phases (List[Dict]): phases @@ -155,9 +161,9 @@ def __init__(self, phases, change_every=50): super(PhasesScheduler, self).__init__() def _format_phase(self, phase): - phase["ep"] = listify(phase["ep"]) - phase["lr"] = listify(phase["lr"]) - phase["mom"] = listify(phase.get("mom", None)) # optional + phase["ep"] = utils.listify(phase["ep"]) + phase["lr"] = utils.listify(phase["lr"]) + phase["mom"] = utils.listify(phase.get("mom", None)) # optional if len(phase["lr"]) == 2 or len(phase["mom"]) == 2: phase["mode"] = phase.get("mode", "linear") assert len(phase["ep"]) == 2, "Linear learning rates must contain end epoch" @@ -170,6 +176,11 @@ def _schedule(start, end, pct, mode): return start + (end - start) * pct elif mode == "cos": return end + (start - end) / 2 * (math.cos(math.pi * pct) + 1) + elif mode == "poly": + gamma = (end / start) ** (1 / 100) + return start * gamma ** (pct * 100) + else: + raise ValueError(f"Mode: `{mode}` is not supported in PhasesScheduler") def _get_lr_mom(self, batch_curr): phase = self.phase @@ -473,7 +484,8 @@ def on_batch_end(self): class FileLogger(Callback): - """Logs loss and metrics every epoch into file + """Logs loss and metrics every epoch into file. + If launched in distributed mode - reduces metrics before logging Args: log_dir (str): path where to store the logs logger (logging.Logger): external logger. 
Default None @@ -488,6 +500,8 @@ def on_epoch_begin(self): self.logger.info(f"Epoch {self.state.epoch_log} | lr {self.current_lr:.3f}") def on_epoch_end(self): + if utils.env_world_size() > 1: + self.reduce_metrics() loss, metrics = self.state.train_loss, self.state.train_metrics self.logger.info("Train " + self._format_meters(loss, metrics)) if self.state.val_loss is not None: @@ -513,6 +527,17 @@ def current_lr(self): def _format_meters(loss, metrics): return f"loss: {loss.avg:.4f} | " + " | ".join(f"{m.name}: {m.avg:.4f}" for m in metrics) + def reduce_metrics(self): + # can't reduce AverageMeter so need to reduce every attribute separately + meters = self.state.train_metrics + [self.state.train_loss,] + meters = meters + self.state.metric_meters + [self.state.loss_meter,] + if self.state.val_loss is not None: + meters = meters + self.state.val_metrics + [self.state.val_loss,] + reduce_attributes = ["val", "avg", "avg_smooth", "sum", "count"] + for meter in meters: + for attr in reduce_attributes: + old_value = utils.to_tensor([getattr(meter, attr)]).float().cuda() + setattr(meter, attr, utils.reduce_tensor(old_value).cpu().numpy()[0]) class Mixup(Callback): """Performs mixup on input. Only for classification. @@ -684,4 +709,68 @@ def on_epoch_end(self): self.state.optimizer.state = defaultdict(dict) if self.verbose: - print("Reseting optimizer") \ No newline at end of file + print("Reseting optimizer") + +# docstring from https://github.com/rwightman/pytorch-image-models +class ModelEma(Callback): + """ Model Exponential Moving Average + Keeps a moving average of everything in the model state_dict (parameters and buffers). + This is intended to allow functionality like + https://www.tensorflow.org/api_docs/python/tf/train/ExponentialMovingAverage + + A smoothed version of the weights is necessary for some training schemes to perform well. + E.g. Google's hyper-params for training MNASNet, MobileNet-V3, EfficientNet, etc that use + RMSprop with a short 2.4-3 epoch decay period and slow LR decay rate of .96-.99 requires EMA + smoothing of weights to match results. + + Current implementation follows TensorFlow and uses the following formula: + ema -= (1 - decay) * (ema - model) + This is mathematically equivalent to the classic formula below but inplace is faster + ema = decay * ema + (1 - decay) * model + + NOTE: Pay attention to the decay constant you are using relative to your update count per epoch. + + NOTE: put this Callback AFTER Checkpoint saver! Otherwise you would validate EMA weights but save + model weights + + NOTE: Need to be used in all process (not only master)! otherwise you would save not the best model bur some random + + NOTE: Pass model to ModelEma after cuda() and AMP but before SyncBN and DDP wrapper + + Args: + model (nn.Module): model after cuda and AMP + decay (float): decay for EMA for every step + decay_every (int): how oftern to really decay weights. Decaying every step produced a + visible training slowdown. Real decay factor is adjusted to match every step update. 
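+
+    Example (a minimal sketch; checkpoint saver, other callbacks and training setup elided):
+        ema_cb = ModelEma(model, decay=0.9999, decay_every=10)
+        runner = pt.fit_wrapper.Runner(model, optimizer, criterion, callbacks=[ema_cb])
+        runner.fit(train_loader, epochs=5, val_loader=val_loader)
+        # during validation the runner's model is temporarily swapped for the EMA copy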
+ """ + def __init__(self, model, decay=0.9999, decay_every=10): + super().__init__() + self.ema = deepcopy(model).eval() + for p in self.ema.parameters(): + p.requires_grad_(False) + self.model_copy = None + self.decay_factor = 1 - decay ** decay_every # simulate every step decay + self.decay_every = decay_every + + def on_batch_end(self): + if not self.state.is_train or (self.state.step % self.decay_every != 0): + return + + with torch.no_grad(): + for (ema_v, m_v) in zip(self.ema.state_dict().values(), self.state.model.state_dict().values()): + if m_v.numel() == 1: # to prevent errors on `num_batches_tracked` in BN + continue + ema_v.sub_(ema_v.sub(m_v), alpha=self.decay_factor) + + def on_loader_begin(self): + if self.state.is_train: + return + # validate on ema model + self.model_copy = self.state.model + self.state.model = self.ema + + def on_epoch_end(self): + if self.state.is_train: + return + # return model back + self.state.model = self.model_copy \ No newline at end of file diff --git a/pytorch_tools/fit_wrapper/state.py b/pytorch_tools/fit_wrapper/state.py index 95ca30c..336ffcd 100644 --- a/pytorch_tools/fit_wrapper/state.py +++ b/pytorch_tools/fit_wrapper/state.py @@ -1,6 +1,5 @@ from ..utils.misc import listify from ..utils.misc import AverageMeter -from ..utils.misc import TimeMeter class RunnerState: @@ -39,7 +38,6 @@ def __init__( self.loss_meter = AverageMeter("loss") # for timer callback - self.timer = TimeMeter() self.__is_frozen = True def __setattr__(self, key, value): diff --git a/pytorch_tools/fit_wrapper/wrapper.py b/pytorch_tools/fit_wrapper/wrapper.py index 97506e1..ddf3764 100644 --- a/pytorch_tools/fit_wrapper/wrapper.py +++ b/pytorch_tools/fit_wrapper/wrapper.py @@ -18,10 +18,11 @@ class Runner: must have `name` attribute. Defaults to None. callbacks (List): List of Callbacks to use. Defaults to ConsoleLogger(). gradient_clip_val (float): Gradient clipping value. 0 means no clip. 
Causes ~5% training slowdown + accumulate_steps (int): if > 1 uses gradient accumulation across iterations to simulate larger batch size """ def __init__( - self, model, optimizer, criterion, metrics=None, callbacks=ConsoleLogger(), gradient_clip_val=0 + self, model, optimizer, criterion, metrics=None, callbacks=ConsoleLogger(), gradient_clip_val=0, accumulate_steps=1, ): super().__init__() @@ -33,6 +34,7 @@ def __init__( self.callbacks = Callbacks(callbacks) self.callbacks.set_state(self.state) self.gradient_clip_val = gradient_clip_val + self.accumulate_steps = accumulate_steps def fit( self, train_loader, steps_per_epoch=None, val_loader=None, val_steps=None, epochs=1, start_epoch=0, @@ -78,22 +80,23 @@ def _make_step(self): self.state.output = output loss = self.state.criterion(output, target) if self.state.is_train: - self.state.optimizer.zero_grad() - with amp.scale_loss(loss, self.state.optimizer) as scaled_loss: + with amp.scale_loss(loss / self.accumulate_steps, self.state.optimizer) as scaled_loss: scaled_loss.backward() if self.gradient_clip_val > 0: - torch.nn.utils.clip_grad_norm_(self.state.model.parameters(), self.gradient_clip_val) - self.state.optimizer.step() + torch.nn.utils.clip_grad_norm_(amp.master_params(self.state.optimizer), self.gradient_clip_val) + if self.state.step % self.accumulate_steps == 0: + self.state.optimizer.step() + self.state.optimizer.zero_grad() torch.cuda.synchronize() # update metrics self.state.loss_meter.update(to_numpy(loss)) - for metric, meter in zip(self.state.metrics, self.state.metric_meters): - meter.update(to_numpy(metric(output, target).squeeze())) + with torch.no_grad(): + for metric, meter in zip(self.state.metrics, self.state.metric_meters): + meter.update(to_numpy(metric(output, target).squeeze())) def _run_loader(self, loader, steps=None): self.state.loss_meter.reset() - self.state.timer.reset() for metric in self.state.metric_meters: metric.reset() self.state.epoch_size = steps or len(loader) # steps overwrites len diff --git a/pytorch_tools/losses/__init__.py b/pytorch_tools/losses/__init__.py index d2c302f..962fa1c 100644 --- a/pytorch_tools/losses/__init__.py +++ b/pytorch_tools/losses/__init__.py @@ -10,6 +10,7 @@ from .vgg_loss import ContentLoss, StyleLoss from .smooth import CrossEntropyLoss from .hinge import BinaryHinge +from .huber import SmoothL1Loss from .functional import focal_loss_with_logits from .functional import soft_dice_score diff --git a/pytorch_tools/losses/huber.py b/pytorch_tools/losses/huber.py new file mode 100644 index 0000000..99d4337 --- /dev/null +++ b/pytorch_tools/losses/huber.py @@ -0,0 +1,33 @@ +from .base import Loss +from .base import Reduction + +class SmoothL1Loss(Loss): + """Huber loss aka Smooth L1 Loss + + loss = 0.5 * x^2 if |x| <= d + loss = 0.5 * d^2 + d * (|x| - d) if |x| > d + + Args: + delta (float): point where the Huber loss function changes from a quadratic to linear + reduction (str): The reduction type to apply to the output. {'none', 'mean', 'sum'}. 
+ 'none' - no reduction will be applied + 'sum' - the output will be summed + 'mean' - the sum of the output will be divided by the number of elements in the output + """ + + def __init__(self, delta=0.1, reduction="none"): + super().__init__() + self.delta = delta + self.reduction = Reduction(reduction) + + def forward(self, pred, target): + x = (pred - target).abs() + l1 = self.delta * (x - 0.5 * self.delta) + l2 = 0.5 * x.pow(2) + + loss = l1.where(x >= self.delta, l2) + if self.reduction == Reduction.MEAN: + loss = loss.mean() + elif self.reduction == Reduction.SUM: + loss = loss.sum() + return loss \ No newline at end of file diff --git a/pytorch_tools/models/README.md b/pytorch_tools/models/README.md index 5d8f823..44994c1 100644 --- a/pytorch_tools/models/README.md +++ b/pytorch_tools/models/README.md @@ -9,9 +9,19 @@ All models have `pretrained_settings` attribute with training size, mean, std an ## Encoders All models from this repo could be used as feature extractors for both object detection and semantic segmentation. Passing `encoder=True` arg will overwrite `forward` method of the model to return features at 5 different resolutions starting from 1/32 to 1/2. +## Features +* Unified API. Create `resnet, efficientnet, hrnet` models using the same code +* Low memory footprint dy to heavy use of inplace operations. Could be reduced even more by using `norm_layer='inplaceabn'` +* Fast models. As of `04.20` Efficient net's in this repo are the fastest available on GitHub (afaik) +* Support for custom number of input channels in pretrained models. Try with `resnet34(pretrained='imagenet', in_channels=7)` +* All core functionality covered with tests + + ## Repositories used * [Torch Vision Main Repo](https://github.com/pytorch/vision) * [Cadene pretrained models](https://github.com/Cadene/pretrained-models.pytorch/) * [Ross Wightman models](https://github.com/rwightman/pytorch-image-models/) * [Inplace ABN](https://github.com/mapillary/inplace_abn) -* [Efficient Densenet](https://github.com/gpleiss/efficient_densenet_pytorch) \ No newline at end of file +* [Efficient Densenet](https://github.com/gpleiss/efficient_densenet_pytorch) +* [Official Efficient Net](https://github.com/tensorflow/tpu/tree/master/models/official/efficientnet) +* [Original HRNet for Classification](https://github.com/HRNet/HRNet-Image-Classification) \ No newline at end of file diff --git a/pytorch_tools/models/__init__.py b/pytorch_tools/models/__init__.py index 3f98c2d..4fb855e 100644 --- a/pytorch_tools/models/__init__.py +++ b/pytorch_tools/models/__init__.py @@ -45,3 +45,10 @@ from .hrnet import hrnet_w44 from .hrnet import hrnet_w48 from .hrnet import hrnet_w64 + +from .bit_resnet import bit_m_50x1 +from .bit_resnet import bit_m_50x3 +from .bit_resnet import bit_m_101x1 +from .bit_resnet import bit_m_101x3 +from .bit_resnet import bit_m_152x2 +from .bit_resnet import bit_m_152x4 diff --git a/pytorch_tools/models/bit_resnet.py b/pytorch_tools/models/bit_resnet.py new file mode 100644 index 0000000..c5acbd5 --- /dev/null +++ b/pytorch_tools/models/bit_resnet.py @@ -0,0 +1,315 @@ +# Copyright 2020 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# Lint as: python3 +"""Bottleneck ResNet v2 with GroupNorm and Weight Standardization.""" +import os +import numpy as np +from copy import deepcopy +from functools import wraps +from urllib.parse import urlparse +from collections import OrderedDict # pylint: disable=g-importing-member + +import torch +import torch.nn as nn +import torch.nn.functional as F + +from pytorch_tools.modules.weight_standartization import WS_Conv2d as StdConv2d + + +def conv3x3(cin, cout, stride=1, groups=1, bias=False): + return StdConv2d(cin, cout, kernel_size=3, stride=stride, padding=1, bias=bias, groups=groups) + + +def conv1x1(cin, cout, stride=1, bias=False): + return StdConv2d(cin, cout, kernel_size=1, stride=stride, padding=0, bias=bias) + + +def tf2th(conv_weights): + """Possibly convert HWIO to OIHW.""" + if conv_weights.ndim == 4: + conv_weights = conv_weights.transpose([3, 2, 0, 1]) + return torch.from_numpy(conv_weights) + + +class PreActBottleneck(nn.Module): + """Pre-activation (v2) bottleneck block. + + Follows the implementation of "Identity Mappings in Deep Residual Networks": + https://github.com/KaimingHe/resnet-1k-layers/blob/master/resnet-pre-act.lua + + Except it puts the stride on 3x3 conv when available. + """ + + def __init__(self, cin, cout=None, cmid=None, stride=1): + super().__init__() + cout = cout or cin + cmid = cmid or cout // 4 + + self.gn1 = nn.GroupNorm(32, cin) + self.conv1 = conv1x1(cin, cmid) + self.gn2 = nn.GroupNorm(32, cmid) + self.conv2 = conv3x3(cmid, cmid, stride) # Original code has it on conv1!! + self.gn3 = nn.GroupNorm(32, cmid) + self.conv3 = conv1x1(cmid, cout) + self.relu = nn.ReLU(inplace=True) + + if stride != 1 or cin != cout: + # Projection also with pre-activation according to paper. + self.downsample = conv1x1(cin, cout, stride) + + def forward(self, x): + out = self.relu(self.gn1(x)) + + # Residual branch + residual = x + if hasattr(self, "downsample"): + residual = self.downsample(out) + + # Unit's branch + out = self.conv1(out) + out = self.conv2(self.relu(self.gn2(out))) + out = self.conv3(self.relu(self.gn3(out))) + + return out + residual + + def load_from(self, weights, prefix=""): + convname = "standardized_conv2d" + with torch.no_grad(): + self.conv1.weight.copy_(tf2th(weights[f"{prefix}a/{convname}/kernel"])) + self.conv2.weight.copy_(tf2th(weights[f"{prefix}b/{convname}/kernel"])) + self.conv3.weight.copy_(tf2th(weights[f"{prefix}c/{convname}/kernel"])) + self.gn1.weight.copy_(tf2th(weights[f"{prefix}a/group_norm/gamma"])) + self.gn2.weight.copy_(tf2th(weights[f"{prefix}b/group_norm/gamma"])) + self.gn3.weight.copy_(tf2th(weights[f"{prefix}c/group_norm/gamma"])) + self.gn1.bias.copy_(tf2th(weights[f"{prefix}a/group_norm/beta"])) + self.gn2.bias.copy_(tf2th(weights[f"{prefix}b/group_norm/beta"])) + self.gn3.bias.copy_(tf2th(weights[f"{prefix}c/group_norm/beta"])) + if hasattr(self, "downsample"): + w = weights[f"{prefix}a/proj/{convname}/kernel"] + self.downsample.weight.copy_(tf2th(w)) + + +# this models are designed for trasfer learning only! 
not for training from scratch +class ResNetV2(nn.Module): + """ + Implementation of Pre-activation (v2) ResNet mode. + Used to create Bit-M-50/101/152x1/2/3/4 models + + Args: + num_classes (int): Number of classification classes. Defaults to 5 + """ + + def __init__( + self, + block_units, + width_factor, + # in_channels=3, # TODO: add later + num_classes=5, # just a random number + # encoder=False, # TODO: add later + ): + super().__init__() + wf = width_factor # shortcut 'cause we'll use it a lot. + + # The following will be unreadable if we split lines. + # pylint: disable=line-too-long + # fmt: off + self.root = nn.Sequential(OrderedDict([ + ('conv', StdConv2d(3, 64*wf, kernel_size=7, stride=2, padding=3, bias=False)), + ('pad', nn.ConstantPad2d(1, 0)), + ('pool', nn.MaxPool2d(kernel_size=3, stride=2, padding=0)), + # The following is subtly not the same! + # ('pool', nn.MaxPool2d(kernel_size=3, stride=2, padding=1)), + ])) + + self.body = nn.Sequential(OrderedDict([ + ('block1', nn.Sequential(OrderedDict( + [('unit01', PreActBottleneck(cin=64*wf, cout=256*wf, cmid=64*wf))] + + [(f'unit{i:02d}', PreActBottleneck(cin=256*wf, cout=256*wf, cmid=64*wf)) for i in range(2, block_units[0] + 1)], + ))), + ('block2', nn.Sequential(OrderedDict( + [('unit01', PreActBottleneck(cin=256*wf, cout=512*wf, cmid=128*wf, stride=2))] + + [(f'unit{i:02d}', PreActBottleneck(cin=512*wf, cout=512*wf, cmid=128*wf)) for i in range(2, block_units[1] + 1)], + ))), + ('block3', nn.Sequential(OrderedDict( + [('unit01', PreActBottleneck(cin=512*wf, cout=1024*wf, cmid=256*wf, stride=2))] + + [(f'unit{i:02d}', PreActBottleneck(cin=1024*wf, cout=1024*wf, cmid=256*wf)) for i in range(2, block_units[2] + 1)], + ))), + ('block4', nn.Sequential(OrderedDict( + [('unit01', PreActBottleneck(cin=1024*wf, cout=2048*wf, cmid=512*wf, stride=2))] + + [(f'unit{i:02d}', PreActBottleneck(cin=2048*wf, cout=2048*wf, cmid=512*wf)) for i in range(2, block_units[3] + 1)], + ))), + ])) + # pylint: enable=line-too-long + + self.head = nn.Sequential(OrderedDict([ + ('gn', nn.GroupNorm(32, 2048*wf)), + ('relu', nn.ReLU(inplace=True)), + ('avg', nn.AdaptiveAvgPool2d(output_size=1)), + ('conv', nn.Conv2d(2048*wf, num_classes, kernel_size=1, bias=True)), + ])) + # fmt: on + + def features(self, x): + return self.body(self.root(x)) + + def logits(self, x): + return self.head(x) + + def forward(self, x): + x = self.logits(self.features(x)) + assert x.shape[-2:] == (1, 1) # We should have no spatial shape left. 
+ return x[..., 0, 0] + + def load_from(self, weights, prefix="resnet/"): + with torch.no_grad(): + self.root.conv.weight.copy_( + tf2th(weights[f"{prefix}root_block/standardized_conv2d/kernel"]) + ) # pylint: disable=line-too-long + self.head.gn.weight.copy_(tf2th(weights[f"{prefix}group_norm/gamma"])) + self.head.gn.bias.copy_(tf2th(weights[f"{prefix}group_norm/beta"])) + # always zero_head + nn.init.zeros_(self.head.conv.weight) + nn.init.zeros_(self.head.conv.bias) + + for bname, block in self.body.named_children(): + for uname, unit in block.named_children(): + unit.load_from(weights, prefix=f"{prefix}{bname}/{uname}/") + + +KNOWN_MODELS = OrderedDict( + [ + ("BiT-M-R50x1", lambda *a, **kw: ResNetV2([3, 4, 6, 3], 1, *a, **kw)), + ("BiT-M-R50x3", lambda *a, **kw: ResNetV2([3, 4, 6, 3], 3, *a, **kw)), + ("BiT-M-R101x1", lambda *a, **kw: ResNetV2([3, 4, 23, 3], 1, *a, **kw)), + ("BiT-M-R101x3", lambda *a, **kw: ResNetV2([3, 4, 23, 3], 3, *a, **kw)), + ("BiT-M-R152x2", lambda *a, **kw: ResNetV2([3, 8, 36, 3], 2, *a, **kw)), + ("BiT-M-R152x4", lambda *a, **kw: ResNetV2([3, 8, 36, 3], 4, *a, **kw)), + ("BiT-S-R50x1", lambda *a, **kw: ResNetV2([3, 4, 6, 3], 1, *a, **kw)), + ("BiT-S-R50x3", lambda *a, **kw: ResNetV2([3, 4, 6, 3], 3, *a, **kw)), + ("BiT-S-R101x1", lambda *a, **kw: ResNetV2([3, 4, 23, 3], 1, *a, **kw)), + ("BiT-S-R101x3", lambda *a, **kw: ResNetV2([3, 4, 23, 3], 3, *a, **kw)), + ("BiT-S-R152x2", lambda *a, **kw: ResNetV2([3, 8, 36, 3], 2, *a, **kw)), + ("BiT-S-R152x4", lambda *a, **kw: ResNetV2([3, 8, 36, 3], 4, *a, **kw)), + ] +) + + +PRETRAIN_SETTINGS = { + "input_space": "RGB", + "input_size": [3, 448, 448], + "input_range": [0, 1], + "mean": [0.5, 0.5, 0.5], + "std": [0.5, 0.5, 0.5], + "num_classes": None, +} + +# fmt: off +CFGS = { + # weights are loaded by default + "bit_m_50x1": { + "default": { + "params": {"block_units": [3, 4, 6, 3], "width_factor": 1}, + "url": "https://storage.googleapis.com/bit_models/BiT-M-R50x1.npz", + **PRETRAIN_SETTINGS + }, + }, + "bit_m_50x3": { + "default": { + "params": {"block_units": [3, 4, 6, 3], "width_factor": 3}, + "url": "https://storage.googleapis.com/bit_models/BiT-M-R50x3.npz", + **PRETRAIN_SETTINGS, + }, + }, + "bit_m_101x1": { + "default": { + "params": {"block_units": [3, 4, 23, 3], "width_factor": 1}, + "url": "https://storage.googleapis.com/bit_models/BiT-M-R101x1.npz", + **PRETRAIN_SETTINGS, + }, + }, + "bit_m_101x3": { + "default": { + "params": {"block_units": [3, 4, 23, 3], "width_factor": 3}, + "url": "https://storage.googleapis.com/bit_models/BiT-M-R101x3.npz", + **PRETRAIN_SETTINGS, + }, + }, + "bit_m_152x2": { + "default": { + "params": {"block_units": [3, 8, 36, 3], "width_factor": 2}, + "url": "https://storage.googleapis.com/bit_models/BiT-M-R152x2.npz", + **PRETRAIN_SETTINGS, + }, + }, + "bit_m_152x4": { + "default": { + "params": {"block_units": [3, 8, 36, 3], "width_factor": 4}, + "url": "https://storage.googleapis.com/bit_models/BiT-M-R152x4.npz", + **PRETRAIN_SETTINGS + }, + }, +} + +# fmt: on +def _bit_resnet(arch, pretrained=None, **kwargs): + cfgs = deepcopy(CFGS) + cfg_settings = cfgs[arch]["default"] + cfg_params = cfg_settings.pop("params") + cfg_url = cfg_settings.pop("url") + kwargs.pop("pretrained", None) + kwargs.update(cfg_params) + model = ResNetV2(**kwargs) + # load weights to torch checkpoints folder + try: + torch.hub.load_state_dict_from_url(cfg_url) + except RuntimeError: + pass # to avoid RuntimeError: Only one file(not dir) is allowed in the zipfile + filename = 
os.path.basename(urlparse(cfg_url).path) + torch_home = torch.hub._get_torch_home() + cached_file = os.path.join(torch_home, "checkpoints", filename) + weights = np.load(cached_file) + model.load_from(weights) + return model + + +# only want M versions of models for fine-tuning +@wraps(ResNetV2) +def bit_m_50x1(**kwargs): + return _bit_resnet("bit_m_50x1", **kwargs) + + +@wraps(ResNetV2) +def bit_m_50x3(**kwargs): + return _bit_resnet("bit_m_50x3", **kwargs) + + +@wraps(ResNetV2) +def bit_m_101x1(**kwargs): + return _bit_resnet("bit_m_101x1", **kwargs) + + +@wraps(ResNetV2) +def bit_m_101x3(**kwargs): + return _bit_resnet("bit_m_101x3", **kwargs) + + +@wraps(ResNetV2) +def bit_m_152x2(**kwargs): + return _bit_resnet("bit_m_152x2", **kwargs) + + +@wraps(ResNetV2) +def bit_m_152x4(**kwargs): + return _bit_resnet("bit_m_152x4", **kwargs) diff --git a/pytorch_tools/models/densenet.py b/pytorch_tools/models/densenet.py index 05993b5..bc27ab3 100644 --- a/pytorch_tools/models/densenet.py +++ b/pytorch_tools/models/densenet.py @@ -327,11 +327,10 @@ def _densenet(arch, pretrained=None, **kwargs): cfg_params.update(pretrained_params) common_args = set(cfg_params.keys()).intersection(set(kwargs.keys())) - assert ( - common_args == set() - ), "Args {} are going to be overwritten by default params for {} weights".format( - common_args, pretrained - ) + if common_args: + logging.warning( + f"Args {common_args} are going to be overwritten by default params for {pretrained} weights" + ) kwargs.update(cfg_params) model = DenseNet(**kwargs) diff --git a/pytorch_tools/models/efficientnet.py b/pytorch_tools/models/efficientnet.py index 8f52ccc..17e67c0 100644 --- a/pytorch_tools/models/efficientnet.py +++ b/pytorch_tools/models/efficientnet.py @@ -24,6 +24,8 @@ from pytorch_tools.modules import bn_from_name from pytorch_tools.modules.residual import InvertedResidual from pytorch_tools.modules.residual import conv1x1, conv3x3 +from pytorch_tools.modules.tf_same_ops import conv_to_same_conv +from pytorch_tools.modules.tf_same_ops import maxpool_to_same_maxpool from pytorch_tools.utils.misc import initialize from pytorch_tools.utils.misc import add_docs_for from pytorch_tools.utils.misc import make_divisible @@ -47,12 +49,12 @@ class EfficientNet(nn.Module): width_multiplier (float): Multiplyer for number of channels in each block. Don't need to be passed manually depth_multiplier (float): - Multiplyer for number of InvertedResiduals in each block + Multiplyer for number of InvertedResiduals in each block. Don't need to be passed manually pretrained (str, optional): - If not, returns a model pre-trained on 'str' dataset. `imagenet` is available for every model. + If not None, returns a model pre-trained on 'str' dataset. `imagenet` is available for every model. NOTE: weights which are loaded into this model were ported from TF. There is a drop in accuracy for Imagenet (~1-2% top1) but they work well for finetuning. - NOTE 2: models were pretrained on very different resolution. take it into account during finetuning + NOTE 2: models were pretrained on very different resolutions. take it into account during finetuning num_classes (int): Number of classification classes. Defaults to 1000. in_channels (int): @@ -70,6 +72,8 @@ class EfficientNet(nn.Module): But increases backward time and doesn't support `swish` activation. Defaults to 'abn'. norm_act (str): Activation for normalizion layer. It's reccomended to use `leacky_relu` with `inplaceabn`. 
Defaults to `swish` + match_tf_same_padding (bool): If True patches Conv and MaxPool to implements tf-like asymmetric padding + Should only be used to validate pretrained weights. Not needed for training. Gives ~10% slowdown """ def __init__( @@ -87,6 +91,7 @@ def __init__( stem_size=32, norm_layer="abn", norm_act="swish", + match_tf_same_padding=False, ): super().__init__() norm_layer = bn_from_name(norm_layer) @@ -127,19 +132,23 @@ def __init__( self.blocks.append(nn.Sequential(*block)) # Head - out_channels = block_arg["out_channels"] - num_features = make_divisible(1280 * width_multiplier) - self.conv_head = conv1x1(out_channels, num_features) - self.bn2 = norm_layer(num_features, activation=norm_act) if encoder: self.forward = self.encoder_features else: + out_channels = block_arg["out_channels"] + num_features = make_divisible(1280 * width_multiplier) + self.conv_head = conv1x1(out_channels, num_features) + self.bn2 = norm_layer(num_features, activation=norm_act) self.global_pool = nn.AdaptiveAvgPool2d(1) self.dropout = nn.Dropout(drop_rate, inplace=True) self.classifier = nn.Linear(num_features, num_classes) + patch_bn(self) # adjust epsilon initialize(self) + if match_tf_same_padding: + conv_to_same_conv(self) + maxpool_to_same_maxpool(self) def encoder_features(self, x): x0 = self.conv_stem(x) @@ -148,11 +157,9 @@ def encoder_features(self, x): x1 = self.blocks[1](x0) x2 = self.blocks[2](x1) x3 = self.blocks[3](x2) - x4 = self.blocks[4](x3) - x4 = self.blocks[5](x4) + x3 = self.blocks[4](x3) + x4 = self.blocks[5](x3) x4 = self.blocks[6](x4) - x4 = self.conv_head(x4) - x4 = self.bn2(x4) return [x4, x3, x2, x1, x0] def features(self, x): @@ -234,7 +241,7 @@ def _decode_block_string(block_string): out_channels=int(options["o"]), dw_kernel_size=int(options["k"]), stride=tuple([options["s"], options["s"]]), - use_se=float(options["se"]) > 0 if "se" in options else False, + attn_type="se" if "se" in options else None, expand_ratio=int(options["e"]), noskip="noskip" in block_string, num_repeat=int(options["r"]), @@ -391,6 +398,7 @@ def patch_bn(module): for m in module.children(): patch_bn(m) + def _efficientnet(arch, pretrained=None, **kwargs): cfgs = deepcopy(CFGS) cfg_settings = cfgs[arch]["default"] @@ -402,12 +410,10 @@ def _efficientnet(arch, pretrained=None, **kwargs): cfg_settings.update(pretrained_settings) cfg_params.update(pretrained_params) common_args = set(cfg_params.keys()).intersection(set(kwargs.keys())) - - assert ( - common_args == set() - ), "Args {} are going to be overwritten by default params for {} weights".format( - common_args, pretrained - ) + if common_args: + logging.warning( + f"Args {common_args} are going to be overwritten by default params for {pretrained} weights" + ) kwargs.update(cfg_params) model = EfficientNet(**kwargs) if pretrained: @@ -421,10 +427,11 @@ def _efficientnet(arch, pretrained=None, **kwargs): ) state_dict["classifier.weight"] = model.state_dict()["classifier.weight"] state_dict["classifier.bias"] = model.state_dict()["classifier.bias"] - if kwargs.get("in_channels", 3) != 3: # support pretrained for custom input channels - state_dict["conv_stem.weight"] = repeat_channels(state_dict["conv_stem.weight"], kwargs["in_channels"]) + if kwargs.get("in_channels", 3) != 3: # support pretrained for custom input channels + state_dict["conv_stem.weight"] = repeat_channels( + state_dict["conv_stem.weight"], kwargs["in_channels"] + ) model.load_state_dict(state_dict) - patch_bn(model) # adjust epsilon setattr(model, "pretrained_settings", 
cfg_settings) return model diff --git a/pytorch_tools/models/hrnet.py b/pytorch_tools/models/hrnet.py index 1f7bde6..cbf5817 100644 --- a/pytorch_tools/models/hrnet.py +++ b/pytorch_tools/models/hrnet.py @@ -43,22 +43,23 @@ def make_layer(inplanes, planes, blocks, norm_layer=ABN, norm_act="relu"): layers = [] layers.append(block(inplanes, planes, downsample=downsample, **bn_args)) inplanes = planes * block.expansion - for i in range(1, blocks): + for _ in range(1, blocks): layers.append(block(inplanes, planes, **bn_args)) return nn.Sequential(*layers) + class HighResolutionModule(nn.Module): def __init__( - self, - num_branches, # number of parallel branches - num_blocks, # number of blocks + self, + num_branches, # number of parallel branches + num_blocks, # number of blocks num_channels, norm_layer=ABN, norm_act="relu", ): super(HighResolutionModule, self).__init__() self.block = BasicBlock - self.num_branches = num_branches # used in forward + self.num_branches = num_branches # used in forward self.num_inchannels = num_channels self.bn_args = {"norm_layer": norm_layer, "norm_act": norm_act} branches = [self._make_branch(n_bl, n_ch) for n_bl, n_ch in zip(num_blocks, num_channels)] @@ -69,6 +70,7 @@ def __init__( def _make_branch(self, b_blocks, b_channels): return nn.Sequential(*[self.block(b_channels, b_channels, **self.bn_args) for _ in range(b_blocks)]) + # fmt: off # don't want to rewrite this piece it's too fragile def _make_fuse_layers(self, norm_layer, norm_act): if self.num_branches == 1: @@ -104,23 +106,24 @@ def _make_fuse_layers(self, norm_layer, norm_act): fuse_layers.append(nn.ModuleList(fuse_layer)) return nn.ModuleList(fuse_layers) - + # fmt: on def forward(self, x): if self.num_branches == 1: return [self.branches[0](x[0])] - + x = [branch(x_i) for branch, x_i in zip(self.branches, x)] x_fuse = [] for i in range(len(self.fuse_layers)): y = x[0] if i == 0 else self.fuse_layers[i][0](x[0]) for j in range(1, self.num_branches): - y = y + self.fuse_layers[i][j](x[j]) + y = y + self.fuse_layers[i][j](x[j]) x_fuse.append(self.relu(y)) return x_fuse + class TransitionBlock(nn.Module): """Transition is where new branches for smaller resolution are born -- ==> -- @@ -129,7 +132,7 @@ class TransitionBlock(nn.Module): \ \=> -- """ - + def __init__(self, prev_channels, current_channels, norm_layer=ABN, norm_act="relu"): super().__init__() transition_layers = [] @@ -140,40 +143,40 @@ def __init__(self, prev_channels, current_channels, norm_layer=ABN, norm_act="re transition_layers.append(nn.Sequential(*layers)) else: transition_layers.append(nn.Identity()) - - if len(current_channels) > len(prev_channels): # only works for ONE extra branch + + if len(current_channels) > len(prev_channels): # only works for ONE extra branch layers = [ - conv3x3(prev_channels[-1], current_channels[-1], 2), - norm_layer(current_channels[-1], activation=norm_act) + conv3x3(prev_channels[-1], current_channels[-1], 2), + norm_layer(current_channels[-1], activation=norm_act), ] transition_layers.append(nn.Sequential(*layers)) self.trans_layers = nn.ModuleList(transition_layers) - - def forward(self, x): # x is actually an array + + def forward(self, x): # x is actually an array out_x = [trans_l(x_i) for x_i, trans_l in zip(x, self.trans_layers)] out_x.append(self.trans_layers[-1](x[-1])) return out_x + class HRClassificationHead(nn.Module): def __init__(self, pre_channels, norm_layer=ABN, norm_act="relu"): super().__init__() head_block = Bottleneck head_channels = [32, 64, 128, 256] - # Increasing the 
#channels on each resolution + # Increasing the #channels on each resolution # from C, 2C, 4C, 8C to 128, 256, 512, 1024 incre_modules = [] for (pre_c, head_c) in zip(pre_channels, head_channels): incre_modules.append(make_layer(pre_c, head_c, 1, norm_layer, norm_act)) self.incre_modules = nn.ModuleList(incre_modules) - + # downsampling modules downsamp_modules = [] - for i in range(len(pre_channels)-1): + for i in range(len(pre_channels) - 1): in_ch = head_channels[i] * head_block.expansion - out_ch = head_channels[i+1] * head_block.expansion + out_ch = head_channels[i + 1] * head_block.expansion downsamp_module = nn.Sequential( - conv3x3(in_ch, out_ch, 2, bias=True), - norm_layer(out_ch, activation=norm_act) + conv3x3(in_ch, out_ch, 2, bias=True), norm_layer(out_ch, activation=norm_act) ) downsamp_modules.append(downsamp_module) self.downsamp_modules = nn.ModuleList(downsamp_modules) @@ -182,13 +185,13 @@ def __init__(self, pre_channels, norm_layer=ABN, norm_act="relu"): conv1x1(head_channels[3] * head_block.expansion, 2048, bias=True), norm_layer(2048, activation=norm_act), ) - + def forward(self, x): - x = [self.incre_modules[i](x[i]) for i in range(4)] + x = [self.incre_modules[i](x[i]) for i in range(4)] for i in range(1, 4): - x[i] = x[i] + self.downsamp_modules[i-1](x[i-1]) + x[i] = x[i] + self.downsamp_modules[i - 1](x[i - 1]) return self.final_layer(x[3]) - + class HighResolutionNet(nn.Module): """HighResolution Nets constructor @@ -219,13 +222,14 @@ class HighResolutionNet(nn.Module): NOTE: HRNet first features have resolution 4x times smaller than input, not 2x as all other models. So it CAN'T be used as encoder in Unet and Linknet models """ - # drop_rate (float): - # Dropout probability before classifier, for training. Defaults to 0. + + # drop_rate (float): + # Dropout probability before classifier, for training. Defaults to 0. def __init__( - self, + self, width=18, small=False, - pretrained=None, # not used. here for proper signature + pretrained=None, # not used. 
here for proper signature num_classes=1000, in_channels=3, norm_layer="abn", @@ -241,27 +245,25 @@ def __init__( self.conv2 = conv3x3(stem_width, stem_width, stride=2) self.bn2 = norm_layer(stem_width, activation=norm_act) - + channels = [width, width * 2, width * 4, width * 8] n_blocks = [2 if small else 4] * 4 - + self.layer1 = make_layer(stem_width, stem_width, n_blocks[0], **bn_args) - + self.transition1 = TransitionBlock([stem_width * Bottleneck.expansion], channels[:2], **bn_args) - self.stage2 = self._make_stage( - n_modules=1, n_branches=2, n_blocks=n_blocks[:2], n_chnls=channels[:2] - ) - + self.stage2 = self._make_stage(n_modules=1, n_branches=2, n_blocks=n_blocks[:2], n_chnls=channels[:2]) + self.transition2 = TransitionBlock(channels[:2], channels[:3], **bn_args) - self.stage3 = self._make_stage( # 3 if small else 4 - n_modules=(4,3)[small], n_branches=3, n_blocks=n_blocks[:3], n_chnls=channels[:3] + self.stage3 = self._make_stage( # 3 if small else 4 + n_modules=(4, 3)[small], n_branches=3, n_blocks=n_blocks[:3], n_chnls=channels[:3] ) - + self.transition3 = TransitionBlock(channels[:3], channels, **bn_args) - self.stage4 = self._make_stage( # 2 if small else 3 - n_modules=(3,2)[small], n_branches=4, n_blocks=n_blocks, n_chnls=channels, + self.stage4 = self._make_stage( # 2 if small else 3 + n_modules=(3, 2)[small], n_branches=4, n_blocks=n_blocks, n_chnls=channels, ) - + self.encoder = encoder if encoder: self.forward = self.encoder_features @@ -276,16 +278,9 @@ def __init__( def _make_stage(self, n_modules, n_branches, n_blocks, n_chnls): modules = [] for i in range(n_modules): - modules.append( - HighResolutionModule( - n_branches, - n_blocks, - n_chnls, - **self.bn_args, - ) - ) + modules.append(HighResolutionModule(n_branches, n_blocks, n_chnls, **self.bn_args,)) return nn.Sequential(*modules) - + def encoder_features(self, x): # stem x = self.conv1(x) @@ -293,46 +288,46 @@ def encoder_features(self, x): x = self.conv2(x) x = self.bn2(x) x = self.layer1(x) - - x = self.transition1([x]) # x is actually a list now + + x = self.transition1([x]) # x is actually a list now x = self.stage2(x) - + x = self.transition2(x) x = self.stage3(x) - + x = self.transition3(x) x = self.stage4(x) - if self.encoder: # want to return from lowest resolution to highest + if self.encoder: # want to return from lowest resolution to highest x = [x[3], x[2], x[1], x[0], x[0]] return x - + def features(self, x): x = self.encoder_features(x) x = self.cls_head(x) return x - + def logits(self, x): x = self.global_pool(x) x = torch.flatten(x, 1) -# x = self.dropout(x) + # x = self.dropout(x) x = self.last_linear(x) return x - + def forward(self, x): x = self.features(x) x = self.logits(x) return x - + def load_state_dict(self, state_dict, **kwargs): self_keys = list(self.state_dict().keys()) sd_keys = list(state_dict.keys()) - sd_keys = [k for k in sd_keys if "num_batches_tracked" not in k] # filter + sd_keys = [k for k in sd_keys if "num_batches_tracked" not in k] # filter new_state_dict = {} for new_key, old_key in zip(self_keys, sd_keys): new_state_dict[new_key] = state_dict[old_key] super().load_state_dict(new_state_dict, **kwargs) - - + + # fmt: off CFGS = { "hrnet_w18_small": { @@ -368,9 +363,10 @@ def load_state_dict(self, state_dict, **kwargs): "imagenet": {"url": None}, }, } - + # fmt:on - + + def _hrnet(arch, pretrained=None, **kwargs): cfgs = deepcopy(CFGS) cfg_settings = cfgs[arch]["default"] @@ -381,11 +377,10 @@ def _hrnet(arch, pretrained=None, **kwargs): 
cfg_settings.update(pretrained_settings) cfg_params.update(pretrained_params) common_args = set(cfg_params.keys()).intersection(set(kwargs.keys())) - assert ( - common_args == set() - ), "Args {} are going to be overwritten by default params for {} weights".format( - common_args, pretrained - ) + if common_args: + logging.warning( + f"Args {common_args} are going to be overwritten by default params for {pretrained} weights" + ) kwargs.update(cfg_params) model = HighResolutionNet(**kwargs) if pretrained: @@ -421,7 +416,7 @@ def hrnet_w18_small(**kwargs): def hrnet_w18(**kwargs): r"""Constructs a HRNetv2-18 model.""" return _hrnet("hrnet_w18", **kwargs) - + @wraps(HighResolutionNet) @add_docs_for(HighResolutionNet) @@ -429,33 +424,37 @@ def hrnet_w30(**kwargs): r"""Constructs a HRNetv2-30 model.""" return _hrnet("hrnet_w30", **kwargs) + @wraps(HighResolutionNet) @add_docs_for(HighResolutionNet) def hrnet_w32(**kwargs): r"""Constructs a HRNetv2-32 model.""" return _hrnet("hrnet_w32", **kwargs) + @wraps(HighResolutionNet) @add_docs_for(HighResolutionNet) def hrnet_w40(**kwargs): r"""Constructs a HRNetv2-40 model.""" return _hrnet("hrnet_w40", **kwargs) + @wraps(HighResolutionNet) @add_docs_for(HighResolutionNet) def hrnet_w44(**kwargs): r"""Constructs a HRNetv2-44 model.""" return _hrnet("hrnet_w44", **kwargs) + @wraps(HighResolutionNet) @add_docs_for(HighResolutionNet) def hrnet_w48(**kwargs): r"""Constructs a HRNetv2-48 model.""" return _hrnet("hrnet_w48", **kwargs) + @wraps(HighResolutionNet) @add_docs_for(HighResolutionNet) def hrnet_w64(**kwargs): r"""Constructs a HRNetv2-64 model.""" return _hrnet("hrnet_w64", **kwargs) - diff --git a/pytorch_tools/models/resnet.py b/pytorch_tools/models/resnet.py index 7fbb4f8..bdc9e57 100644 --- a/pytorch_tools/models/resnet.py +++ b/pytorch_tools/models/resnet.py @@ -12,12 +12,14 @@ import torch import torch.nn as nn -from torchvision.models.utils import load_state_dict_from_url +from torch.hub import load_state_dict_from_url from pytorch_tools.modules import BasicBlock, Bottleneck from pytorch_tools.modules import GlobalPool2d, BlurPool from pytorch_tools.modules.residual import conv1x1, conv3x3 +from pytorch_tools.modules.pooling import FastGlobalAvgPool2d from pytorch_tools.modules import bn_from_name +from pytorch_tools.modules import SpaceToDepth from pytorch_tools.utils.misc import add_docs_for from pytorch_tools.utils.misc import DEFAULT_IMAGENET_SETTINGS from pytorch_tools.utils.misc import repeat_channels @@ -49,14 +51,19 @@ class ResNet(nn.Module): Number of classification classes. Defaults to 1000. in_channels (int): Number of input (color) channels. Defaults to 3. - use_se (bool): - Enable Squeeze-Excitation module in blocks. + attn_type (Union[str, None]): + If given, selects attention type to use in blocks. One of + `se` - Squeeze-Excitation + `eca` - Efficient Channel Attention groups (int): Number of convolution groups for 3x3 conv in Bottleneck. Defaults to 1. base_width (int): Factor determining bottleneck channels. `planes * base_width / 64 * groups`. Defaults to 64. - deep_stem (bool): - Whether to replace the 7x7 conv1 with 3 3x3 convolution layers. Defaults to False. + stem_type (str): + Type on input stem. Supported options are: + '' - default. One 7x7 conv with 64 channels + 'deep' - three 3x3 conv with 32, 32, 64, channels + 'space2depth' - Reshape followed by one convolution. Idea from TResNet paper output_stride (List[8, 16, 32]): Applying dilation strategy to pretrained ResNet. Typically used in Semantic Segmentation. 
Defaults to 32. NOTE: Don't use this arg with `antialias` and `pretrained` together. it may produce weird results @@ -74,8 +81,6 @@ class ResNet(nn.Module): drop_connect_rate (float): Drop rate for StochasticDepth. Randomly removes samples each block. Used as regularization during training. keep prob will be linearly decreased from 1 to 1 - drop_connect_rate each block. Ref: https://arxiv.org/abs/1603.09382 - global_pool (str): - Global pooling type. One of 'avg', 'max', 'avgmax', 'catavgmax'. Defaults to 'avg'. init_bn0 (bool): Zero-initialize the last BN in each residual branch, so that the residual branch starts with zeros, and each residual block behaves like an identity. @@ -89,10 +94,10 @@ def __init__( pretrained=None, # not used. here for proper signature num_classes=1000, in_channels=3, - use_se=False, + attn_type=None, groups=1, base_width=64, - deep_stem=False, + stem_type="", output_stride=32, norm_layer="abn", norm_act="relu", @@ -100,7 +105,6 @@ def __init__( encoder=False, drop_rate=0.0, drop_connect_rate=0.0, - global_pool="avg", init_bn0=True, ): @@ -118,20 +122,9 @@ def __init__( self.drop_connect_rate = drop_connect_rate super(ResNet, self).__init__() - if deep_stem: - self.conv1 = nn.Sequential( - conv3x3(in_channels, stem_width // 2, 2), - norm_layer(stem_width // 2, activation=norm_act), - conv3x3(stem_width // 2, stem_width // 2), - norm_layer(stem_width // 2, activation=norm_act), - conv3x3(stem_width // 2, stem_width), - ) - else: - self.conv1 = nn.Conv2d(in_channels, stem_width, kernel_size=7, stride=2, padding=3, bias=False) - self.bn1 = norm_layer(stem_width, activation=norm_act) - self.maxpool = nn.MaxPool2d( - kernel_size=3, stride=2, padding=0 if use_se else 1, ceil_mode=True if use_se else False, - ) + # move stem creation in separate function for simplicity + self._make_stem(stem_type, stem_width, in_channels, norm_layer, norm_act) + if output_stride not in [8, 16, 32]: raise ValueError("Output stride should be in [8, 16, 32]") if output_stride == 8: @@ -140,17 +133,17 @@ def __init__( stride_3, stride_4, dilation_3, dilation_4 = 2, 1, 1, 2 elif output_stride == 32: stride_3, stride_4, dilation_3, dilation_4 = 2, 2, 1, 1 - largs = dict(use_se=use_se, norm_layer=norm_layer, norm_act=norm_act, antialias=antialias) + largs = dict(attn_type=attn_type, norm_layer=norm_layer, norm_act=norm_act, antialias=antialias) self.layer1 = self._make_layer(64, layers[0], stride=1, **largs) self.layer2 = self._make_layer(128, layers[1], stride=2, **largs) self.layer3 = self._make_layer(256, layers[2], stride=stride_3, dilation=dilation_3, **largs) self.layer4 = self._make_layer(512, layers[3], stride=stride_4, dilation=dilation_4, **largs) - self.global_pool = GlobalPool2d(global_pool) + self.global_pool = FastGlobalAvgPool2d() self.num_features = 512 * self.expansion self.encoder = encoder if not encoder: self.dropout = nn.Dropout(p=drop_rate, inplace=True) - self.last_linear = nn.Linear(self.num_features * self.global_pool.feat_mult(), num_classes) + self.last_linear = nn.Linear(self.num_features, num_classes) else: self.forward = self.encoder_features @@ -162,7 +155,7 @@ def _make_layer( blocks, stride=1, dilation=1, - use_se=None, + attn_type=None, norm_layer=None, norm_act=None, antialias=None, @@ -188,7 +181,7 @@ def _make_layer( downsample=downsample, groups=self.groups, base_width=self.base_width, - use_se=use_se, + attn_type=attn_type, dilation=first_dilation, norm_layer=norm_layer, norm_act=norm_act, @@ -205,7 +198,7 @@ def _make_layer( planes=planes, 
groups=self.groups, base_width=self.base_width, - use_se=use_se, + attn_type=attn_type, dilation=first_dilation, norm_layer=norm_layer, norm_act=norm_act, @@ -215,6 +208,29 @@ def _make_layer( ) return nn.Sequential(*layers) + def _make_stem(self, stem_type, stem_width, in_channels, norm_layer, norm_act): + assert stem_type in {"", "deep", "space2depth"}, f"Stem type {stem_type} is not supported" + if stem_type == "space2depth": + # in the paper they use conv1x1 but in code conv3x3 (which seems better) + self.conv1 = nn.Sequential(SpaceToDepth(), conv3x3(in_channels * 16, stem_width)) + self.bn1 = norm_layer(stem_width, activation=norm_act) + self.maxpool = nn.Identity() # not used but needed for code compatability + else: + if stem_type == "deep": + self.conv1 = nn.Sequential( + conv3x3(in_channels, stem_width // 2, 2), + norm_layer(stem_width // 2, activation=norm_act), + conv3x3(stem_width // 2, stem_width // 2), + norm_layer(stem_width // 2, activation=norm_act), + conv3x3(stem_width // 2, stem_width), + ) + else: + self.conv1 = nn.Conv2d( + in_channels, stem_width, kernel_size=7, stride=2, padding=3, bias=False + ) + self.bn1 = norm_layer(stem_width, activation=norm_act) + self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1) + def _initialize_weights(self, init_bn0=False): for m in self.modules(): if isinstance(m, nn.Conv2d): @@ -282,6 +298,7 @@ def keep_prob(self): self.block_idx += 1 return keep_prob + # fmt: off CFGS = { # RESNET MODELS @@ -307,10 +324,32 @@ def keep_prob(self): "resnet50": { "default": {"params": {"block": Bottleneck, "layers": [3, 4, 6, 3]}, **DEFAULT_IMAGENET_SETTINGS,}, "imagenet": {"url": "https://download.pytorch.org/models/resnet50-19c8e357.pth"}, + # I couldn't validate this weights because they give Acc@1 0.1 maybe a bug somewhere. Still leaving them just in case + # it works better that starting from scratch + "imagenet_gn": { + "url": "https://github.com/bonlime/pytorch-tools/releases/download/v0.1.2/R-101-GN-abf6008e.pth", + "params": {"norm_layer": "agn"} + }, + # Acc@1: 76.33. Acc@5: 93.34. This weights only work with weight standardization! + "imagenet_gn_ws": { + "url": "https://github.com/bonlime/pytorch-tools/releases/download/v0.1.2/R-50-GN-WS-fd84efb6.pth", + "params": {"norm_layer": "agn"} + }, }, "resnet101": { "default": {"params": {"block": Bottleneck, "layers": [3, 4, 23, 3]}, **DEFAULT_IMAGENET_SETTINGS,}, "imagenet": {"url": "https://download.pytorch.org/models/resnet101-5d3b4d8f.pth"}, + # I couldn't validate this weights because they give Acc@1 0.1 maybe a bug somewhere. Still leaving them just in case + # it works better that starting from scratch + "imagenet_gn": { + "url": "https://github.com/bonlime/pytorch-tools/releases/download/v0.1.2/R-101-GN-abf6008e.pth", + "params": {"norm_layer": "agn"} + }, + # Acc@1: 77.85. Acc@5: 93.90. This weights only work with weight standardization! + "imagenet_gn_ws": { + "url": "https://github.com/bonlime/pytorch-tools/releases/download/v0.1.2/R-101-GN-WS-c067a7de.pth", + "params": {"norm_layer": "agn"} + }, }, "resnet152": { "default": {"params": {"block": Bottleneck, "layers": [3, 8, 36, 3]}, **DEFAULT_IMAGENET_SETTINGS,}, @@ -337,25 +376,36 @@ def keep_prob(self): "params": {"block": Bottleneck, "layers": [3, 4, 6, 3], "base_width": 4, "groups": 32,}, **DEFAULT_IMAGENET_SETTINGS, }, - "imagenet": { # Acc@1: 75.80. Acc@5: 92.71. - "url": "https://download.pytorch.org/models/resnext50_32x4d-7cdf4587.pth" - }, + # Acc@1: 75.80. Acc@5: 92.71. 
+ "imagenet": {"url": "https://download.pytorch.org/models/resnext50_32x4d-7cdf4587.pth"}, # weights from rwightman "imagenet2": { "url": "https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/resnext50d_32x4d-103e99f8.pth" }, + # Acc@1: 77.28. Acc@5: 93.61. This weights only work with weight standardization! + "imagenet_gn_ws": { + "url": "https://github.com/bonlime/pytorch-tools/releases/download/v0.1.2/X-50-GN-WS-2dea43a8.pth", + "params": {"norm_layer": "agn"} + }, }, "resnext101_32x4d": { "default": { "params": {"block": Bottleneck, "layers": [3, 4, 23, 3], "base_width": 4, "groups": 32,}, **DEFAULT_IMAGENET_SETTINGS, - }, # No pretrained + }, # No imagenet pretrained + # 78.19. Acc@5: 93.98 This weights only work with weight standardization! + "imagenet_gn_ws": { + "url": "https://github.com/bonlime/pytorch-tools/releases/download/v0.1.2/X-101-GN-WS-eb1224cd.pth", + "params": {"norm_layer": "agn"}, + } + }, "resnext101_32x8d": { "default": { "params": {"block": Bottleneck, "layers": [3, 4, 23, 3], "base_width": 8, "groups": 32,}, **DEFAULT_IMAGENET_SETTINGS, }, + # on 8.05.20 this link was broken. maybe need to fix in the future "imagenet": {"url": "https://download.pytorch.org/models/resnext101_32x8d-8ba56ff5.pth"}, # pretrained on weakly labeled instagram and then tuned on Imagenet "imagenet_ig": {"url": "https://download.pytorch.org/models/ig_resnext101_32x8-c38310e5.pth"}, @@ -386,28 +436,28 @@ def keep_prob(self): # SE RESNET MODELS "se_resnet34": { "default": { - "params": {"block": BasicBlock, "layers": [3, 4, 6, 3], "use_se": True}, + "params": {"block": BasicBlock, "layers": [3, 4, 6, 3], "attn_type": "se"}, **DEFAULT_IMAGENET_SETTINGS, }, # NO WEIGHTS }, "se_resnet50": { "default": { - "params": {"block": Bottleneck, "layers": [3, 4, 6, 3], "use_se": True}, + "params": {"block": Bottleneck, "layers": [3, 4, 6, 3], "attn_type": "se"}, **DEFAULT_IMAGENET_SETTINGS, }, "imagenet": {"url": "http://data.lip6.fr/cadene/pretrainedmodels/se_resnet50-ce0d4300.pth"}, }, "se_resnet101": { "default": { - "params": {"block": Bottleneck, "layers": [3, 4, 23, 3], "use_se": True}, + "params": {"block": Bottleneck, "layers": [3, 4, 23, 3], "attn_type": "se"}, **DEFAULT_IMAGENET_SETTINGS, }, "imagenet": {"url": "http://data.lip6.fr/cadene/pretrainedmodels/se_resnet101-7e38fcc6.pth"}, }, "se_resnet152": { "default": { - "params": {"block": Bottleneck, "layers": [3, 4, 36, 3], "use_se": True}, + "params": {"block": Bottleneck, "layers": [3, 4, 36, 3], "attn_type": "se"}, **DEFAULT_IMAGENET_SETTINGS, }, "imagenet": {"url": "http://data.lip6.fr/cadene/pretrainedmodels/se_resnet152-d17c99b7.pth"}, @@ -420,7 +470,7 @@ def keep_prob(self): "layers": [3, 4, 6, 3], "base_width": 4, "groups": 32, - "use_se": True, + "attn_type": "se", }, **DEFAULT_IMAGENET_SETTINGS, }, @@ -433,7 +483,7 @@ def keep_prob(self): "layers": [3, 4, 23, 3], "base_width": 4, "groups": 32, - "use_se": True, + "attn_type": "se", }, **DEFAULT_IMAGENET_SETTINGS, }, @@ -453,11 +503,10 @@ def _resnet(arch, pretrained=None, **kwargs): cfg_settings.update(pretrained_settings) cfg_params.update(pretrained_params) common_args = set(cfg_params.keys()).intersection(set(kwargs.keys())) - assert ( - common_args == set() - ), "Args {} are going to be overwritten by default params for {} weights".format( - common_args, pretrained - ) + if common_args: + logging.warning( + f"Args {common_args} are going to be overwritten by default params for {pretrained} weights" + ) kwargs.update(cfg_params) model = 
ResNet(**kwargs) if pretrained: diff --git a/pytorch_tools/models/tresnet.py b/pytorch_tools/models/tresnet.py index 35063c9..9be0c36 100644 --- a/pytorch_tools/models/tresnet.py +++ b/pytorch_tools/models/tresnet.py @@ -19,6 +19,7 @@ # avoid overwriting doc string wraps = partial(wraps, assigned=("__module__", "__name__", "__qualname__", "__annotations__")) + class TResNet(ResNet): """TResNet M / TResNet L / XL @@ -71,44 +72,45 @@ def __init__( drop_rate=0.0, drop_connect_rate=0.0, ): - nn.Module.__init__(self) + nn.Module.__init__(self) stem_width = int(64 * width_factor) norm_layer = bn_from_name(norm_layer) self.inplanes = stem_width self.num_classes = num_classes - self.groups = 1 # not really used but needed inside _make_layer - self.base_width = 64 # used inside _make_layer + self.groups = 1 # not really used but needed inside _make_layer + self.base_width = 64 # used inside _make_layer self.norm_act = norm_act self.block_idx = 0 self.num_blocks = sum(layers) self.drop_connect_rate = drop_connect_rate - # in the paper they use conv1x1 but in code conv3x3 (which seems better) - self.conv1 = nn.Sequential(SpaceToDepth(), conv3x3(in_channels * 16, stem_width)) - self.bn1 = norm_layer(stem_width, activation=norm_act) - self.maxpool = nn.Identity() # not used but needed for code compatability + self._make_stem("space2depth", stem_width, in_channels, norm_layer, norm_act) if output_stride not in [8, 16, 32]: raise ValueError("Output stride should be in [8, 16, 32]") # TODO add OS later # if output_stride == 8: - # stride_3, stride_4, dilation_3, dilation_4 = 1, 1, 2, 4 + # stride_3, stride_4, dilation_3, dilation_4 = 1, 1, 2, 4 # elif output_stride == 16: - # stride_3, stride_4, dilation_3, dilation_4 = 2, 1, 1, 2 + # stride_3, stride_4, dilation_3, dilation_4 = 2, 1, 1, 2 # elif output_stride == 32: stride_3, stride_4, dilation_3, dilation_4 = 2, 2, 1, 1 - largs = dict(use_se=True, norm_layer=norm_layer, norm_act=norm_act, antialias=True) + largs = dict(attn_type="se", norm_layer=norm_layer, norm_act=norm_act, antialias=True) self.block = TBasicBlock self.expansion = TBasicBlock.expansion self.layer1 = self._make_layer(stem_width, layers[0], stride=1, **largs) self.layer2 = self._make_layer(stem_width * 2, layers[1], stride=2, **largs) - self.block = TBottleneck # first 2 - Basic, last 2 - Bottleneck + self.block = TBottleneck # first 2 - Basic, last 2 - Bottleneck self.expansion = TBottleneck.expansion - self.layer3 = self._make_layer(stem_width * 4, layers[2], stride=stride_3, dilation=dilation_3, **largs) - largs.update(use_se=False) # no se in last layer - self.layer4 = self._make_layer(stem_width * 8, layers[3], stride=stride_4, dilation=dilation_4, **largs) + self.layer3 = self._make_layer( + stem_width * 4, layers[2], stride=stride_3, dilation=dilation_3, **largs + ) + largs.update(attn_type=None) # no se in last layer + self.layer4 = self._make_layer( + stem_width * 8, layers[3], stride=stride_4, dilation=dilation_4, **largs + ) self.global_pool = FastGlobalAvgPool2d(flatten=True) self.num_features = stem_width * 8 * self.expansion self.encoder = encoder @@ -126,6 +128,7 @@ def load_state_dict(self, state_dict, **kwargs): state_dict.pop("last_linear.bias") nn.Module.load_state_dict(self, state_dict, **kwargs) + # fmt: off # images should be normalized to [0, 1] PRETRAIN_SETTINGS = { @@ -171,12 +174,7 @@ def load_state_dict(self, state_dict, **kwargs): }, } # fmt: on -def patch_blur_pool(module): - """changes `gauss` attribute in blur pool to True""" - if isinstance(module, 
BlurPool): - module.gauss = True - for m in module.children(): - patch_blur_pool(m) + def patch_bn(module): """changes weight from InplaceABN to be compatible with usual ABN""" @@ -185,6 +183,7 @@ def patch_bn(module): for m in module.children(): patch_bn(m) + def _resnet(arch, pretrained=None, **kwargs): cfgs = deepcopy(CFGS) cfg_settings = cfgs[arch]["default"] @@ -195,11 +194,10 @@ def _resnet(arch, pretrained=None, **kwargs): cfg_settings.update(pretrained_settings) cfg_params.update(pretrained_params) common_args = set(cfg_params.keys()).intersection(set(kwargs.keys())) - assert ( - common_args == set() - ), "Args {} are going to be overwritten by default params for {} weights".format( - common_args, pretrained - ) + if common_args: + logging.warning( + f"Args {common_args} are going to be overwritten by default params for {pretrained} weights" + ) kwargs.update(cfg_params) model = TResNet(**kwargs) if pretrained: @@ -214,29 +212,32 @@ def _resnet(arch, pretrained=None, **kwargs): # if there is last_linear in state_dict, it's going to be overwritten state_dict["last_linear.weight"] = model.state_dict()["last_linear.weight"] state_dict["last_linear.bias"] = model.state_dict()["last_linear.bias"] - if kwargs.get("in_channels", 3) != 3: # support pretrained for custom input channels - state_dict["conv1.1.weight"] = repeat_channels(state_dict["conv1.1.weight"], kwargs["in_channels"] * 16, 3 * 16) + if kwargs.get("in_channels", 3) != 3: # support pretrained for custom input channels + state_dict["conv1.1.weight"] = repeat_channels( + state_dict["conv1.1.weight"], kwargs["in_channels"] * 16, 3 * 16 + ) model.load_state_dict(state_dict) - # need to adjust some parameters to be align with original model - patch_blur_pool(model) patch_bn(model) setattr(model, "pretrained_settings", cfg_settings) return model + @wraps(TResNet) @add_docs_for(TResNet) def tresnetm(**kwargs): r"""Constructs a TResnetM model.""" return _resnet("tresnetm", **kwargs) + @wraps(TResNet) @add_docs_for(TResNet) def tresnetl(**kwargs): r"""Constructs a TResnetL model.""" return _resnet("tresnetl", **kwargs) + @wraps(TResNet) @add_docs_for(TResNet) def tresnetxl(**kwargs): r"""Constructs a TResnetXL model.""" - return _resnet("tresnetxl", **kwargs) \ No newline at end of file + return _resnet("tresnetxl", **kwargs) diff --git a/pytorch_tools/models/vgg.py b/pytorch_tools/models/vgg.py index 41004cf..0f17b85 100644 --- a/pytorch_tools/models/vgg.py +++ b/pytorch_tools/models/vgg.py @@ -90,7 +90,6 @@ def forward(self, x): x = self.logits(x) return x - def _make_layers(self, cfg): layers = [] in_channels = self.in_channels @@ -170,11 +169,10 @@ def _vgg(arch, pretrained=None, **kwargs): cfg_settings.update(pretrained_settings) cfg_params.update(pretrained_params) common_args = set(cfg_params.keys()).intersection(set(kwargs.keys())) - assert ( - common_args == set() - ), "Args {} are going to be overwritten by default params for {} weights".format( - common_args, pretrained or "default" - ) + if common_args: + logging.warning( + f"Args {common_args} are going to be overwritten by default params for {pretrained} weights" + ) kwargs.update(cfg_params) model = VGG(**kwargs) if pretrained: diff --git a/pytorch_tools/modules/__init__.py b/pytorch_tools/modules/__init__.py index 86a776f..3a436df 100644 --- a/pytorch_tools/modules/__init__.py +++ b/pytorch_tools/modules/__init__.py @@ -13,6 +13,7 @@ from .residual import SEModule # from .residual import Transition, DenseLayer +from .weight_standartization import conv_to_ws_conv from 
.activations import ACT_DICT from .activations import ACT_FUNC_DICT @@ -21,8 +22,10 @@ from .activated_batch_norm import ABN from .activated_group_norm import AGN +from .activated_no_norm import NoNormAct from inplace_abn import InPlaceABN, InPlaceABNSync + def bn_from_name(norm_name): norm_name = norm_name.lower() if norm_name == "abn": @@ -35,5 +38,7 @@ def bn_from_name(norm_name): return partial(ABN, frozen=True) elif norm_name in ("agn", "groupnorm", "group_norm"): return AGN + elif norm_name in ("none",): + return NoNormAct else: raise ValueError(f"Normalization {norm_name} not supported") diff --git a/pytorch_tools/modules/activated_batch_norm.py b/pytorch_tools/modules/activated_batch_norm.py index 8a32930..586ae0d 100644 --- a/pytorch_tools/modules/activated_batch_norm.py +++ b/pytorch_tools/modules/activated_batch_norm.py @@ -45,6 +45,8 @@ def __init__( self.momentum = momentum self.activation = ACT(activation) self.activation_param = activation_param + self.frozen = frozen + if frozen: self.register_buffer("weight", torch.ones(num_features)) self.register_buffer("bias", torch.zeros(num_features)) @@ -101,4 +103,6 @@ def extra_repr(self): rep = "{num_features}, eps={eps}, momentum={momentum}, affine={affine}, activation={activation}" if self.activation in ["leaky_relu", "elu"]: rep += "[{activation_param}]" + if self.frozen: + rep += ", frozen=True" return rep.format(**self.__dict__) diff --git a/pytorch_tools/modules/activated_group_norm.py b/pytorch_tools/modules/activated_group_norm.py index 516b8b8..6b06743 100644 --- a/pytorch_tools/modules/activated_group_norm.py +++ b/pytorch_tools/modules/activated_group_norm.py @@ -7,6 +7,7 @@ from .activations import ACT from .activations import ACT_FUNC_DICT + class AGN(nn.Module): """Activated Group Normalization This gathers a GroupNorm and an activation function in a single module @@ -27,13 +28,7 @@ class AGN(nn.Module): """ def __init__( - self, - num_features, - num_groups=32, - eps=1e-5, - affine=True, - activation="relu", - activation_param=0.01, + self, num_features, num_groups=32, eps=1e-5, affine=True, activation="relu", activation_param=0.01, ): super(AGN, self).__init__() self.num_features = num_features diff --git a/pytorch_tools/modules/activated_no_norm.py b/pytorch_tools/modules/activated_no_norm.py new file mode 100644 index 0000000..7271287 --- /dev/null +++ b/pytorch_tools/modules/activated_no_norm.py @@ -0,0 +1,37 @@ +import torch +import torch.nn as nn +import torch.nn.functional as F +from torch.nn import init +from torch.nn.parameter import Parameter + +from .activations import ACT +from .activations import ACT_FUNC_DICT + + +class NoNormAct(nn.Module): + """Activated No Normalization + This is just an activation wrapped in class to allow easy swaping with BN and GN + Args: + num_features (int): Not used. here for compatability + activation (str): Name of the activation functions + activation_param (float): Negative slope for the `leaky_relu` activation. 
+ """ + + def __init__(self, num_features, activation="relu", activation_param=0.01): + super().__init__() + self.num_features = num_features + self.activation = ACT(activation) + self.activation_param = activation_param + + def forward(self, x): + func = ACT_FUNC_DICT[self.activation] + if self.activation == ACT.LEAKY_RELU: + return func(x, inplace=True, negative_slope=self.activation_param) + elif self.activation == ACT.ELU: + return func(x, inplace=True, alpha=self.activation_param) + else: + return func(x, inplace=True) + + def extra_repr(self): + rep = "activation={activation}" + return rep.format(**self.__dict__) diff --git a/pytorch_tools/modules/bifpn.py b/pytorch_tools/modules/bifpn.py index 2f603c2..9b73309 100644 --- a/pytorch_tools/modules/bifpn.py +++ b/pytorch_tools/modules/bifpn.py @@ -2,17 +2,23 @@ import torch.nn as nn import torch.nn.functional as F +from .activations import activation_from_name from .residual import DepthwiseSeparableConv +from .residual import conv1x1 +from . import ABN + class FastNormalizedFusion(nn.Module): """Combines 2 or 3 feature maps into one with weights. Args: input_num (int): 2 for intermediate features, 3 for output features """ - def __init__(self, in_nodes): + + def __init__(self, in_nodes, activation="relu"): super().__init__() self.weights = nn.Parameter(torch.ones(in_nodes, dtype=torch.float32)) - self.register_buffer("eps", torch.tensor(0.0001)) + self.eps = 1e-4 + self.act = activation_from_name(activation) def forward(self, *features): # Assure that weights are positive (see paper) @@ -20,103 +26,135 @@ def forward(self, *features): # Normalize weights weights /= weights.sum() + self.eps fused_features = sum([p * w for p, w in zip(features, weights)]) - return fused_features + return self.act(fused_features) +# need to create weights to allow loading anyway. So inherit from FastNormalizedFusion for simplicity +class SumFusion(FastNormalizedFusion): + def forward(self, *features): + return self.act(sum(features)) -# close to one in the paper class BiFPNLayer(nn.Module): r"""Builds one layer of Bi-directional Feature Pyramid Network Args: channels (int): Number of channels in each feature map after BiFPN. Defaults to 64. - downsample_by_stride (bool): If True, use convolution layer with stride=2 instead of F.interpolate - upsample_mode (str): how to upsample low resolution features during top_down pathway. - See F.interpolate mode for details. - + Input: - features (List): 5 feature maps from encoder with resolution from 1/32 to 1/2 + features (List): 5 feature maps from encoder with resolution from 1/128 to 1/8 Returns: p_out: features processed by 1 layer of BiFPN """ - def __init__(self, channels=64, output_stride=32, upsample_mode="nearest", **bn_args): - super(BiFPNLayer, self).__init__() + def __init__(self, channels=64, norm_layer=ABN, norm_act="relu"): + super().__init__() + + self.up = nn.Upsample(scale_factor=2, mode="nearest") + self.down = nn.MaxPool2d(3, stride=2, padding=1) # padding=1 TODO: change back - self.up = nn.Upsample(scale_factor=2, mode=upsample_mode) - self.first_up = self.up if output_stride == 32 else nn.Identity() - last_stride = 2 if output_stride == 32 else 1 - self.down_p2 = DepthwiseSeparableConv(channels, channels, stride=2, **bn_args) - self.down_p3 = DepthwiseSeparableConv(channels, channels, stride=2, **bn_args) - self.down_p4 = DepthwiseSeparableConv(channels, channels, stride=last_stride, **bn_args) + # disable attention for large models. This is very dirty way to check that it's B6 & B7. 
But i don't care + Fusion = SumFusion if channels > 288 else FastNormalizedFusion - ## TODO (jamil) 11.02.2020 Rewrite this using list comprehensions - self.fuse_p4_td = FastNormalizedFusion(in_nodes=2) - self.fuse_p3_td = FastNormalizedFusion(in_nodes=2) - self.fuse_p2_td = FastNormalizedFusion(in_nodes=2) - self.fuse_p1_td = FastNormalizedFusion(in_nodes=2) + # There is no activation in SeparableConvs, instead activation is in fusion layer + self.fuse_p6_up = Fusion(in_nodes=2, activation=norm_act) + self.fuse_p5_up = Fusion(in_nodes=2, activation=norm_act) + self.fuse_p4_up = Fusion(in_nodes=2, activation=norm_act) - # Top-down pathway, no block for P1 layer - self.p4_td = DepthwiseSeparableConv(channels, channels, **bn_args) - self.p3_td = DepthwiseSeparableConv(channels, channels, **bn_args) - self.p2_td = DepthwiseSeparableConv(channels, channels, **bn_args) + self.fuse_p3_out = Fusion(in_nodes=2, activation=norm_act) + self.fuse_p4_out = Fusion(in_nodes=3, activation=norm_act) + self.fuse_p5_out = Fusion(in_nodes=3, activation=norm_act) + self.fuse_p6_out = Fusion(in_nodes=3, activation=norm_act) + self.fuse_p7_out = Fusion(in_nodes=2, activation=norm_act) - # Bottom-up pathway - self.fuse_p2_out = FastNormalizedFusion(in_nodes=3) - self.fuse_p3_out = FastNormalizedFusion(in_nodes=3) - self.fuse_p4_out = FastNormalizedFusion(in_nodes=3) - self.fuse_p5_out = FastNormalizedFusion(in_nodes=2) + bn_args = dict(norm_layer=norm_layer, norm_act="identity") + # Top-down pathway, no block for P7 layer + self.p6_up = DepthwiseSeparableConv(channels, channels, **bn_args) + self.p5_up = DepthwiseSeparableConv(channels, channels, **bn_args) + self.p4_up = DepthwiseSeparableConv(channels, channels, **bn_args) - self.p5_out = DepthwiseSeparableConv(channels, channels, **bn_args) - self.p4_out = DepthwiseSeparableConv(channels, channels, **bn_args) + # Bottom-up pathway self.p3_out = DepthwiseSeparableConv(channels, channels, **bn_args) - - + self.p4_out = DepthwiseSeparableConv(channels, channels, **bn_args) + self.p5_out = DepthwiseSeparableConv(channels, channels, **bn_args) + self.p6_out = DepthwiseSeparableConv(channels, channels, **bn_args) + self.p7_out = DepthwiseSeparableConv(channels, channels, **bn_args) + def forward(self, features): - p5_inp, p4_inp, p3_inp, p2_inp = features - - # Top-down pathway - p4_td = self.p4_td(self.fuse_p4_td(p4_inp, self.first_up(p5_inp))) - p3_td = self.p3_td(self.fuse_p3_td(p3_inp, self.up(p4_td))) - p2_out = self.p2_td(self.fuse_p2_td(p2_inp, self.up(p3_td))) - # Calculate Bottom-Up Pathway - p3_out = self.p3_out(self.fuse_p3_out(p3_inp, p3_td, self.down_p2(p2_out))) - p4_out = self.p4_out(self.fuse_p4_out(p4_inp, p4_td, self.down_p3(p3_out))) - p5_out = self.p5_out(self.fuse_p5_out(p5_inp, self.down_p4(p4_out))) + # p7, p6, p5, p4, p3 + p7_in, p6_in, p5_in, p4_in, p3_in = features + + # Top-down pathway (from low res to high res) + p6_up = self.p6_up(self.fuse_p6_up(p6_in, self.up(p7_in))) + p5_up = self.p5_up(self.fuse_p5_up(p5_in, self.up(p6_up))) + p4_up = self.p4_up(self.fuse_p4_up(p4_in, self.up(p5_up))) + p3_out = self.p3_out(self.fuse_p3_out(p3_in, self.up(p4_up))) + + # Bottom-Up Pathway (from high res to low res) + p4_out = self.p4_out(self.fuse_p4_out(p4_in, p4_up, self.down(p3_out))) + p5_out = self.p5_out(self.fuse_p5_out(p5_in, p5_up, self.down(p4_out))) + p6_out = self.p6_out(self.fuse_p6_out(p6_in, p6_up, self.down(p5_out))) + p7_out = self.p7_out(self.fuse_p7_out(p7_in, self.down(p6_out))) + + return p7_out, p6_out, p5_out, p4_out, p3_out 
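For reference, every fusion node used in the layer above (`FastNormalizedFusion`) clamps its learned weights to be non-negative, normalizes them so they roughly sum to one, and takes the weighted sum of the incoming feature maps before applying the activation. A minimal standalone sketch of that computation (the function name and tensors below are illustrative, not part of this diff):

```python
import torch
import torch.nn.functional as F


def fast_normalized_fusion(features, raw_weights, eps=1e-4):
    """Fast normalized fusion from EfficientDet: w_i = relu(w_i) / (sum_j relu(w_j) + eps)."""
    w = F.relu(raw_weights)        # keep weights non-negative, as in the paper
    w = w / (w.sum() + eps)        # normalize without a softmax
    return sum(w_i * f_i for w_i, f_i in zip(w, features))


# e.g. fusing P6 with the upsampled P7 map, both already projected to 64 channels
p6 = torch.randn(1, 64, 16, 16)
p7_up = torch.randn(1, 64, 16, 16)
fused = fast_normalized_fusion([p6, p7_up], torch.ones(2))  # equal weights -> plain average
```

The EfficientDet paper prefers this over a softmax across the weights because it behaves similarly while being noticeably cheaper, which matters when the fusion runs at every node of every BiFPN layer.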
+ + +# additionally downsamples the input +class FirstBiFPNLayer(BiFPNLayer): + def __init__(self, encoder_channels, channels=64, norm_layer=ABN, norm_act="relu"): + super().__init__(channels=channels, norm_layer=norm_layer, norm_act=norm_act) + + # TODO: later remove bias from downsample + self.p5_downsample_1 = nn.Sequential( + conv1x1(encoder_channels[0], channels, bias=True), norm_layer(channels, activation="identity") + ) + self.p4_downsample_1 = nn.Sequential( + conv1x1(encoder_channels[1], channels, bias=True), norm_layer(channels, activation="identity") + ) + self.p3_downsample_1 = nn.Sequential( + conv1x1(encoder_channels[2], channels, bias=True), norm_layer(channels, activation="identity") + ) + + # Devil is in the details. In original repo they use 2 different downsamples from encoder channels + # it makes sense to preseve more information, but most of implementations in the internet + # use output of the first downsample + self.p4_downsample_2 = nn.Sequential( + conv1x1(encoder_channels[1], channels, bias=True), norm_layer(channels, activation="identity") + ) + self.p5_downsample_2 = nn.Sequential( + conv1x1(encoder_channels[0], channels, bias=True), norm_layer(channels, activation="identity") + ) + # only one downsample for p3 - return p5_out, p4_out, p3_out, p2_out + def forward(self, features): -# very simplified -class SimpleBiFPNLayer(nn.Module): - def __init__(self, channels=64, **bn_args): - super(SimpleBiFPNLayer, self).__init__() + # p7, p6, p5, p4, p3 + p7_in, p6_in, p5_in, p4_in, p3_in = features - self.up = nn.Upsample(scale_factor=2, mode="nearest") - self.down_p2 = DepthwiseSeparableConv(channels, channels, stride=2) - self.down_p3 = DepthwiseSeparableConv(channels, channels, stride=2) - self.down_p4 = DepthwiseSeparableConv(channels, channels, stride=2) + # downsample input's convs + p5_in_down1 = self.p5_downsample_1(p5_in) + p5_in_down2 = self.p5_downsample_2(p5_in) + p4_in_down1 = self.p4_downsample_1(p4_in) + p4_in_down2 = self.p4_downsample_2(p4_in) + p3_in_down1 = self.p3_downsample_1(p3_in) - self.fuse = sum + # Top-down pathway (from low res to high res) + p6_up = self.p6_up(self.fuse_p6_up(p6_in, self.up(p7_in))) + p5_up = self.p5_up(self.fuse_p5_up(p5_in_down1, self.up(p6_up))) + p4_up = self.p4_up(self.fuse_p4_up(p4_in_down1, self.up(p5_up))) + p3_out = self.p3_out(self.fuse_p3_out(p3_in_down1, self.up(p4_up))) - def forward(self, features): - p5_inp, p4_inp, p3_inp, p2_inp = features - - # Top-down pathway - p4_td = self.fuse(p4_inp, self.up(p5_inp)) - p3_td = self.fuse(p3_inp, self.up(p4_td)) - p2_out = self.fuse(p2_inp, self.up(p3_td)) - - # Calculate Bottom-Up Pathway - p3_out = self.fuse(p3_inp, p3_td, self.down_p2(p2_out)) - p4_out = self.fuse(p4_inp, p4_td, self.down_p3(p3_out)) - p5_out = self.fuse(p5_inp, self.down_p4(p4_out)) + # Bottom-Up Pathway (from high res to low res) + p4_out = self.p4_out(self.fuse_p4_out(p4_in_down2, p4_up, self.down(p3_out))) + p5_out = self.p5_out(self.fuse_p5_out(p5_in_down2, p5_up, self.down(p4_out))) + p6_out = self.p6_out(self.fuse_p6_out(p6_in, p6_up, self.down(p5_out))) + p7_out = self.p7_out(self.fuse_p7_out(p7_in, self.down(p6_out))) - return p5_out, p4_out, p3_out, p2_out + return p7_out, p6_out, p5_out, p4_out, p3_out -class BiFPN(nn.Module): +class BiFPN(nn.Sequential): """ Implementation of Bi-directional Feature Pyramid Network @@ -131,28 +169,10 @@ class BiFPN(nn.Module): https://arxiv.org/pdf/1911.09070.pdf """ - def __init__( - self, - encoder_channels, - pyramid_channels=64, - num_layers=1, - 
output_stride=32, - **bn_args, - ): - super(BiFPN, self).__init__() - - self.input_convs = nn.ModuleList([nn.Conv2d(in_ch, pyramid_channels, 1) for in_ch in encoder_channels]) - - bifpns = [] - for _ in range(num_layers): - bifpns.append(BiFPNLayer(pyramid_channels, output_stride, **bn_args)) - self.bifpn = nn.Sequential(*bifpns) - - def forward(self, features): - - # Preprocces raw encoder features - p5, p4, p3, p2 = [inp_conv(feature) for inp_conv, feature in zip(self.input_convs, features)] - + def __init__(self, encoder_channels, pyramid_channels=64, num_layers=1, **bn_args): + # First layer preprocesses raw encoder features + bifpns = [FirstBiFPNLayer(encoder_channels, pyramid_channels, **bn_args)] # Apply BiFPN block `num_layers` times - p5_out, p4_out, p3_out, p2_out = self.bifpn([p5, p4, p3, p2]) - return p5_out, p4_out, p3_out, p2_out + for _ in range(num_layers - 1): + bifpns.append(BiFPNLayer(pyramid_channels, **bn_args)) + super().__init__(*bifpns) diff --git a/pytorch_tools/modules/decoder.py b/pytorch_tools/modules/decoder.py index 731cdb5..af9412b 100644 --- a/pytorch_tools/modules/decoder.py +++ b/pytorch_tools/modules/decoder.py @@ -7,7 +7,7 @@ class UnetDecoderBlock(nn.Module): - def __init__(self, in_channels, out_channels, norm_layer=ABN, norm_act="relu"): + def __init__(self, in_channels, out_channels, norm_layer=ABN, norm_act="relu", upsample=True): super(UnetDecoderBlock, self).__init__() conv1 = conv3x3(in_channels, out_channels) @@ -15,10 +15,11 @@ def __init__(self, in_channels, out_channels, norm_layer=ABN, norm_act="relu"): abn1 = norm_layer(out_channels, activation=norm_act) abn2 = norm_layer(out_channels, activation=norm_act) self.block = nn.Sequential(conv1, abn1, conv2, abn2) + self.upsample = nn.Upsample(scale_factor=2, mode="bilinear") if upsample else nn.Identity() def forward(self, x): x, skip = x - x = F.interpolate(x, scale_factor=2, mode="nearest") + x = self.upsample(x) if skip is not None: x = torch.cat([x, skip], dim=1) x = self.block(x) diff --git a/pytorch_tools/modules/fpn.py b/pytorch_tools/modules/fpn.py index 2b6ef39..726f7ff 100644 --- a/pytorch_tools/modules/fpn.py +++ b/pytorch_tools/modules/fpn.py @@ -4,6 +4,7 @@ import torch.nn.functional as F from .residual import conv1x1, conv3x3 + class MergeBlock(nn.Module): def forward(self, x): x, skip = x @@ -11,6 +12,7 @@ def forward(self, x): x += skip return x + class FPN(nn.Module): """Feature Pyramid Network for enhancing high-resolution feature maps with semantic meaning from low resolution maps @@ -25,11 +27,11 @@ class FPN(nn.Module): """ def __init__( - self, - encoder_channels, - pyramid_channels=256, - num_layers=1, - **bn_args, # for compatability only. Not used + self, + encoder_channels, + pyramid_channels=256, + num_layers=1, + **bn_args, # for compatability only. 
Not used ): super().__init__() assert num_layers == 1, "More that 1 layer is not supported in FPN" @@ -52,4 +54,4 @@ def forward(self, features): pyramid_features[idx] = self.merge_block([pyramid_features[idx - 1], pyramid_features[idx]]) # smooth them after merging pyramid_features = [s_conv(feature) for s_conv, feature in zip(self.smooth_convs, pyramid_features)] - return pyramid_features \ No newline at end of file + return pyramid_features diff --git a/pytorch_tools/modules/pooling.py b/pytorch_tools/modules/pooling.py index 34ee24d..6a359e6 100644 --- a/pytorch_tools/modules/pooling.py +++ b/pytorch_tools/modules/pooling.py @@ -64,6 +64,7 @@ def feat_mult(self): def __repr__(self): return self.__class__.__name__ + " (" + ", pool_type=" + self.pool_type + ")" + # https://github.com/mrT23/TResNet/ class FastGlobalAvgPool2d(nn.Module): def __init__(self, flatten=False): @@ -81,30 +82,24 @@ def forward(self, x): class BlurPool(nn.Module): """ Idea from https://arxiv.org/abs/1904.11486 - Efficient implementation of Rect-3 using AvgPool + Efficient implementation of Rect-3 Args: - channels (int): numbers of channels. needed for gaussian blur - gauss (bool): flag to use Gaussian Blur instead of Average Blur. Uses more memory + channels (int): numbers of input channels. needed to construct gauss kernel """ - def __init__(self, channels=0, gauss=False): + def __init__(self, channels=0): super(BlurPool, self).__init__() - self.gauss = gauss self.channels = channels - # init both options to be able to switch - a = torch.tensor([1., 2., 1.]) - filt = (a[:, None] * a[None, :]).clone().detach() + filt = torch.tensor([1.0, 2.0, 1.0]) + filt = filt[:, None] * filt[None, :] filt = filt / torch.sum(filt) filt = filt[None, None, :, :].repeat((self.channels, 1, 1, 1)) self.register_buffer("filt", filt) - self.pool = nn.AvgPool2d(3, stride=2, padding=1) def forward(self, inp): - if self.gauss: - inp_pad = F.pad(inp, (1, 1, 1, 1), 'reflect') - return F.conv2d(inp_pad, self.filt, stride=2, padding=0, groups=inp.shape[1]) - else: - return self.pool(inp) + inp_pad = F.pad(inp, (1, 1, 1, 1), "reflect") + return F.conv2d(inp_pad, self.filt, stride=2, padding=0, groups=inp.shape[1]) + # from https://github.com/mrT23/TResNet/ class SpaceToDepth(nn.Module): @@ -114,4 +109,4 @@ def forward(self, x): x = x.view(N, C, H // 4, 4, W // 4, 4) # (N, C, H//bs, bs, W//bs, bs) x = x.permute(0, 3, 5, 1, 2, 4).contiguous() # (N, bs, bs, C, H//bs, W//bs) x = x.view(N, C * 16, H // 4, W // 4) # (N, C*bs^2, H//bs, W//bs) - return x \ No newline at end of file + return x diff --git a/pytorch_tools/modules/residual.py b/pytorch_tools/modules/residual.py index c8da1f4..6b865d9 100644 --- a/pytorch_tools/modules/residual.py +++ b/pytorch_tools/modules/residual.py @@ -30,14 +30,69 @@ def conv1x1(in_planes, out_planes, stride=1, bias=False): return nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=bias) +class SEModule(nn.Module): + def __init__(self, channels, reduction_channels, norm_act="relu"): + super(SEModule, self).__init__() + + self.pool = FastGlobalAvgPool2d() + # authors of original paper DO use bias + self.fc1 = nn.Conv2d(channels, reduction_channels, kernel_size=1, stride=1, bias=True) + self.act1 = activation_from_name(norm_act) + self.fc2 = nn.Conv2d(reduction_channels, channels, kernel_size=1, stride=1, bias=True) + + def forward(self, x): + x_se = self.pool(x) + x_se = self.fc1(x_se) + x_se = self.act1(x_se) + x_se = self.fc2(x_se) + return x * x_se.sigmoid() + + +class ECAModule(nn.Module): + 
"""Efficient Channel Attention + This implementation is different from the paper. I've removed all hyperparameters and + use fixed kernel size of 3. If you think it may be better to use different k_size - feel free to open an issue. + + Ref: ECA-Net: Efficient Channel Attention for Deep Convolutional Neural Networks + https://arxiv.org/abs/1910.03151 + + """ + + def __init__(self, *args, **kwargs): + super().__init__() + self.pool = FastGlobalAvgPool2d() + self.conv = nn.Conv1d(1, 1, kernel_size=3, padding=1, bias=False) + + def forward(self, x): + x_s = self.pool(x) + x_s = self.conv(x_s.view(x.size(0), 1, -1)) + x_s = x_s.view(x.size(0), -1, 1, 1).sigmoid() + return x * x_s.expand_as(x) + + +def get_attn(attn_type): + """Args: attn_type (Uniont[str, None]): Attention type. Supported: + `se` - Squeeze and Excitation + `eca` - Efficient Channel Attention + None - not attention + """ + ATT_TO_MODULE = {"se": SEModule, "eca": ECAModule} + if attn_type is None: + return nn.Identity + else: + return ATT_TO_MODULE[attn_type.lower()] + + class DepthwiseSeparableConv(nn.Sequential): """Depthwise separable conv with BN after depthwise & pointwise.""" - def __init__(self, in_channels, out_channels, stride=1, dilation=1, norm_layer=ABN, norm_act="relu"): + def __init__( + self, in_channels, out_channels, stride=1, dilation=1, norm_layer=ABN, norm_act="relu", use_norm=True + ): modules = [ conv3x3(in_channels, in_channels, stride=stride, groups=in_channels, dilation=dilation), - conv1x1(in_channels, out_channels), - norm_layer(out_channels, activation=norm_act), + conv1x1(in_channels, out_channels, bias=True), # in efficient det they for some reason add bias + norm_layer(out_channels, activation=norm_act) if use_norm else nn.Identity(), ] super().__init__(*modules) @@ -50,7 +105,7 @@ def __init__( dw_kernel_size=3, stride=1, dilation=1, - use_se=False, + attn_type=None, expand_ratio=1.0, # expansion keep_prob=1, # drop connect param noskip=False, @@ -77,7 +132,7 @@ def __init__( ) self.bn2 = norm_layer(mid_chs, activation=norm_act) # some models like MobileNet use mid_chs here instead of in_channels. 
But I don't care for now - self.se = SEModule(mid_chs, in_channels // 4, norm_act) if use_se else nn.Identity() + self.se = get_attn(attn_type)(mid_chs, in_channels // 4, norm_act) self.conv_pw1 = conv1x1(mid_chs, out_channels) self.bn3 = norm_layer(out_channels, activation="identity") self.drop_connect = DropConnect(keep_prob) if keep_prob < 1 else nn.Identity() @@ -117,24 +172,6 @@ def forward(self, x): return output -class SEModule(nn.Module): - def __init__(self, channels, reduction_channels, norm_act="relu"): - super(SEModule, self).__init__() - - self.pool = FastGlobalAvgPool2d() - # authors of original paper DO use bias - self.fc1 = nn.Conv2d(channels, reduction_channels, kernel_size=1, stride=1, bias=True) - self.act1 = activation_from_name(norm_act) - self.fc2 = nn.Conv2d(reduction_channels, channels, kernel_size=1, stride=1, bias=True) - - def forward(self, x): - x_se = self.pool(x) - x_se = self.fc1(x_se) - x_se = self.act1(x_se) - x_se = self.fc2(x_se) - return x * x_se.sigmoid() - - class BasicBlock(nn.Module): expansion = 1 @@ -146,7 +183,7 @@ def __init__( downsample=None, groups=1, base_width=64, - use_se=False, + attn_type=None, dilation=1, norm_layer=ABN, norm_act="relu", @@ -163,7 +200,7 @@ def __init__( self.bn1 = norm_layer(planes, activation=norm_act) self.conv2 = conv3x3(planes, outplanes) self.bn2 = norm_layer(outplanes, activation="identity") - self.se_module = SEModule(outplanes, planes // 4) if use_se else None + self.se_module = get_attn(attn_type)(outplanes, planes // 4) self.final_act = activation_from_name(norm_act) self.downsample = downsample self.blurpool = BlurPool(channels=planes) if antialias else nn.Identity() @@ -183,10 +220,7 @@ def forward(self, x): out = self.blurpool(out) out = self.conv2(out) # avoid 2 inplace ops by chaining into one long op. 
Needed for inplaceabn - if self.se_module is not None: - out = self.drop_connect(self.se_module(self.bn2(out))) + residual - else: - out = self.drop_connect(self.bn2(out)) + residual + out = self.drop_connect(self.se_module(self.bn2(out))) + residual return self.final_act(out) @@ -201,7 +235,7 @@ def __init__( downsample=None, groups=1, base_width=64, - use_se=False, + attn_type=None, dilation=1, norm_layer=ABN, norm_act="relu", @@ -220,7 +254,7 @@ def __init__( self.bn2 = norm_layer(width, activation=norm_act) self.conv3 = conv1x1(width, outplanes) self.bn3 = norm_layer(outplanes, activation="identity") - self.se_module = SEModule(outplanes, planes // 4) if use_se else None + self.se_module = get_attn(attn_type)(outplanes, planes // 4) self.final_act = activation_from_name(norm_act) self.downsample = downsample self.blurpool = BlurPool(channels=width) if antialias else nn.Identity() @@ -244,32 +278,32 @@ def forward(self, x): out = self.conv3(out) # avoid 2 inplace ops by chaining into one long op - if self.se_module is not None: - out = self.drop_connect(self.se_module(self.bn3(out))) + residual - else: - out = self.drop_connect(self.bn3(out)) + residual + out = self.drop_connect(self.se_module(self.bn3(out))) + residual return self.final_act(out) + # TResnet models use slightly modified versions of BasicBlock and Bottleneck # need to adjust for it + class TBasicBlock(BasicBlock): def __init__(self, **kwargs): super().__init__(**kwargs) self.final_act = nn.ReLU(inplace=True) self.bn1.activation_param = 1e-3 # needed for loading weights - if not kwargs.get("use_se"): + if not kwargs.get("attn_type") == "se": return planes = kwargs["planes"] self.se_module = SEModule(planes, max(planes // 4, 64)) + class TBottleneck(Bottleneck): def __init__(self, **kwargs): super().__init__(**kwargs) self.final_act = nn.ReLU(inplace=True) - self.bn1.activation_param = 1e-3 # needed for loading weights + self.bn1.activation_param = 1e-3 # needed for loading weights self.bn2.activation_param = 1e-3 - if not kwargs["use_se"]: + if not kwargs.get("attn_type") == "se": return planes = kwargs["planes"] reduce_planes = max(planes * self.expansion // 8, 64) @@ -291,10 +325,9 @@ def forward(self, x): if self.antialias: out = self.blurpool(out) - if self.se_module is not None: - out = self.se_module(out) + out = self.se_module(out) out = self.conv3(out) # avoid 2 inplace ops by chaining into one long op out = self.drop_connect(self.bn3(out)) + residual - return self.final_act(out) \ No newline at end of file + return self.final_act(out) diff --git a/pytorch_tools/modules/spatial_ocr_block.py b/pytorch_tools/modules/spatial_ocr_block.py index cfe5be4..d419f28 100644 --- a/pytorch_tools/modules/spatial_ocr_block.py +++ b/pytorch_tools/modules/spatial_ocr_block.py @@ -5,7 +5,7 @@ ## Copyright (c) 2019 ## ## This source code is licensed under the MIT-style license found in the -## LICENSE file in the root directory of this source tree +## LICENSE file in the root directory of this source tree ##+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++ ## modified and simplified by @bonlime @@ -24,18 +24,20 @@ class SpatialOCR_Gather(nn.Module): Returns: torch.Tensor (B x C_2 x C_1 x 1) """ + def forward(self, feats, probs): # C_1 is number of final classes. 
C_2 in number of features in `feats` - probs = probs.view(probs.size(0), probs.size(1), -1) # B x C_1 x H x W => B x C_1 x HW - feats = feats.view(feats.size(0), feats.size(1), -1) # B x C_2 x H x W => B x C_2 x HW - feats = feats.permute(0, 2, 1) # B x HW x C_2 - probs = probs.softmax(dim=2) # B x C_1 x HW + probs = probs.view(probs.size(0), probs.size(1), -1) # B x C_1 x H x W => B x C_1 x HW + feats = feats.view(feats.size(0), feats.size(1), -1) # B x C_2 x H x W => B x C_2 x HW + feats = feats.permute(0, 2, 1) # B x HW x C_2 + probs = probs.softmax(dim=2) # B x C_1 x HW # B x C_1 x HW @ B x HW x C_2 => B x C_1 x C_2 => B x C_2 x C_1 => B x C_2 x C_1 x 1 ocr_context = torch.matmul(probs, feats).permute(0, 2, 1).unsqueeze(3) return ocr_context + # class ObjectAttentionBlock2D(nn.Module): -''' +""" The basic implementation for object context block Input: N X C X H X W @@ -44,7 +46,8 @@ def forward(self, feats, probs): key_channels : the dimension after the key/query transform Return: N X C X H X W -''' +""" + class SpatialOCR(nn.Module): """ @@ -57,14 +60,8 @@ class SpatialOCR(nn.Module): norm_layer (): Normalization layer to use norm_act (str): activation to use in `norm_layer` """ - def __init__( - self, - in_channels, - key_channels, - out_channels, - norm_layer=ABN, - norm_act="relu" - ): + + def __init__(self, in_channels, key_channels, out_channels, norm_layer=ABN, norm_act="relu"): super().__init__() self.in_channels = in_channels @@ -72,28 +69,25 @@ def __init__( self.f_pixel = nn.Sequential( conv1x1(in_channels, key_channels, bias=True), - norm_layer(key_channels, activation=norm_act), + norm_layer(key_channels, activation=norm_act), conv1x1(key_channels, key_channels, bias=True), norm_layer(key_channels, activation=norm_act), ) self.f_object = nn.Sequential( conv1x1(in_channels, key_channels, bias=True), - norm_layer(key_channels, activation=norm_act), + norm_layer(key_channels, activation=norm_act), conv1x1(key_channels, key_channels, bias=True), norm_layer(key_channels, activation=norm_act), ) self.f_down = nn.Sequential( - conv1x1(in_channels, key_channels, bias=True), - norm_layer(key_channels, activation=norm_act), + conv1x1(in_channels, key_channels, bias=True), norm_layer(key_channels, activation=norm_act), ) self.f_up = nn.Sequential( - conv1x1(key_channels, in_channels, bias=True), - norm_layer(in_channels, activation=norm_act), + conv1x1(key_channels, in_channels, bias=True), norm_layer(in_channels, activation=norm_act), ) self.conv_bn = nn.Sequential( - conv1x1(2 * in_channels, out_channels, bias=True), - norm_layer(out_channels, activation=norm_act), + conv1x1(2 * in_channels, out_channels, bias=True), norm_layer(out_channels, activation=norm_act), ) def forward(self, feats, proxy_feats): @@ -112,8 +106,8 @@ def forward(self, feats, proxy_feats): # sim_map.shape = B x H*W//16 x 256 @ B x 256 x C => B x H*W//16 x C sim_map = torch.matmul(query, key) - sim_map = (self.key_channels**-.5) * sim_map - sim_map = sim_map.softmax(dim=-1) + sim_map = (self.key_channels ** -0.5) * sim_map + sim_map = sim_map.softmax(dim=-1) # add bg context ... 
# context.shape = B x H*W//16 x C @ B x C x 256 => B x H*W//16 x 256 @@ -124,4 +118,4 @@ def forward(self, feats, proxy_feats): context = self.f_up(context) # concat and project output = self.conv_bn(torch.cat([context, feats], 1)) - return output \ No newline at end of file + return output diff --git a/pytorch_tools/modules/tf_same_ops.py b/pytorch_tools/modules/tf_same_ops.py new file mode 100644 index 0000000..cfb8026 --- /dev/null +++ b/pytorch_tools/modules/tf_same_ops.py @@ -0,0 +1,91 @@ +"""Implementations of Conv2d and MaxPool which match Tensorflow `same` padding""" +import math +import torch +import torch.nn as nn +import torch.nn.functional as F +from torch.nn.modules.utils import _pair + + +def pad_same(x, k, s, d, value=0): + # type: (Tensor, int, int, int, float)->Tensor + # x - input tensor, s - stride, k - kernel_size, d - dilation + ih, iw = x.size()[-2:] + pad_h = max((math.ceil(ih / s) - 1) * s + (k - 1) * d + 1 - ih, 0) + pad_w = max((math.ceil(iw / s) - 1) * s + (k - 1) * d + 1 - iw, 0) + if pad_h > 0 or pad_w > 0: + x = F.pad(x, [pad_w // 2, pad_w - pad_w // 2, pad_h // 2, pad_h - pad_h // 2], value=value) + return x + + +# current implementation is only for symmetric case. But there are no non symmetric cases +def conv2d_same(x, weight, bias=None, stride=(1, 1), dilation=(1, 1), groups=1): + # type: (Tensor, Tensor, Optional[torch.Tensor], Tuple[int, int], Tuple[int, int], int)->Tensor + x = pad_same(x, weight.shape[-1], stride[0], dilation[0]) + return F.conv2d(x, weight, bias, stride, (0, 0), dilation, groups) + + +def maxpool2d_same(x, kernel_size, stride): + # type: (Tensor, Tuple[int, int], Tuple[int, int])->Tensor + x = pad_same(x, kernel_size[0], stride[0], 1, value=-float("inf")) + return F.max_pool2d(x, kernel_size, stride, (0, 0)) + + +class Conv2dSamePadding(nn.Conv2d): + """Assymetric padding matching TensorFlow `same`""" + + def forward(self, x): + return conv2d_same(x, self.weight, self.bias, self.stride, self.dilation, self.groups) + + +# as of 1.5 there is no _pair in MaxPool. 
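# Worked example of the `pad_same` arithmetic above (sizes are arbitrary, for illustration only):
# for a 16x16 input with kernel_size=3, stride=2, dilation=1
#     pad = max((ceil(16 / 2) - 1) * 2 + (3 - 1) * 1 + 1 - 16, 0) = 1
# which is applied as 0 px on the top/left and 1 px on the bottom/right (the asymmetric split
# TF uses), so the output is ceil(16 / 2) = 8x8:
#     conv = Conv2dSamePadding(8, 8, kernel_size=3, stride=2, bias=False)
#     assert conv(torch.randn(1, 8, 16, 16)).shape[-2:] == (8, 8)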
Remove when this is fixed +class MaxPool2dSamePadding(nn.MaxPool2d): + """Assymetric padding matching TensorFlow `same`""" + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + self.kernel_size = _pair(self.kernel_size) + self.stride = _pair(self.stride) + + def forward(self, x): + return maxpool2d_same(x, self.kernel_size, self.stride) + + +def conv_to_same_conv(module): + """Turn All Conv2d into SameConv2d to match TF padding""" + module_output = module + # skip 1x1 convs + if isinstance(module, nn.Conv2d) and module.kernel_size[0] != 1: + module_output = Conv2dSamePadding( + in_channels=module.in_channels, + out_channels=module.out_channels, + kernel_size=module.kernel_size, + stride=module.stride, + padding=0, # explicitly set to 0 + dilation=module.dilation, + groups=module.groups, + bias=module.bias is not None, + ) + with torch.no_grad(): + module_output.weight.copy_(module.weight) + module_output.weight.requires_grad = module.weight.requires_grad + if module.bias is not None: + module_output.bias.copy_(module.bias) + module_output.bias.requires_grad = module.bias.requires_grad + + for name, child in module.named_children(): + module_output.add_module(name, conv_to_same_conv(child)) + del module + return module_output + + +def maxpool_to_same_maxpool(module): + """Turn All MaxPool2d into SameMaxPool2d to match TF padding""" + module_output = module + if isinstance(module, nn.MaxPool2d): + module_output = MaxPool2dSamePadding( + kernel_size=module.kernel_size, stride=module.stride, padding=0, # explicitly set to 0 + ) + for name, child in module.named_children(): + module_output.add_module(name, maxpool_to_same_maxpool(child)) + del module + return module_output diff --git a/pytorch_tools/modules/weight_standartization.py b/pytorch_tools/modules/weight_standartization.py index 50e8631..3e0edd7 100644 --- a/pytorch_tools/modules/weight_standartization.py +++ b/pytorch_tools/modules/weight_standartization.py @@ -1,37 +1,41 @@ +import torch from torch import nn import torch.nn.functional as F # implements idea from `Weight Standardization` paper https://arxiv.org/abs/1903.10520 -# eps is inside sqrt to avoid overflow Idea from https://arxiv.org/abs/1911.05920 +# eps is inside sqrt to avoid overflow Idea from https://arxiv.org/abs/1911.05920 class WS_Conv2d(nn.Conv2d): - def __init__(self, in_channels, out_channels, kernel_size, stride=1, - padding=0, dilation=1, groups=1, bias=True): - super().__init__(in_channels, out_channels, kernel_size, stride, padding, dilation, groups, bias) - def forward(self, x): weight = self.weight - weight = weight.sub(weight.mean(dim=(1, 2, 3), keepdim=True)) - std = weight.var(dim=(1, 2, 3), keepdim=True).add_(1e-7).sqrt_() - weight = weight.div(std.expand_as(weight)) + var, mean = torch.var_mean(weight, dim=[1, 2, 3], keepdim=True, unbiased=False) + weight = (weight - mean) / torch.sqrt(var + 1e-7) return F.conv2d(x, weight, self.bias, self.stride, self.padding, self.dilation, self.groups) -# code from random issue on github. 
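# Rough intuition for the standardization above (informal note): the weight actually used in the
# convolution is `(weight - mean) / sqrt(var + 1e-7)`, so every output filter has ~zero mean and
# ~unit variance over its (in_channels, kH, kW) elements on each forward pass. The raw
# `self.weight` parameter is left untouched, which is why pretrained weights can simply be copied
# into models patched with `conv_to_ws_conv` below.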
-def convertConv2WeightStand(module, nextChild=None): - mod = module - norm_list = [torch.nn.modules.batchnorm.BatchNorm1d, torch.nn.modules.batchnorm.BatchNorm2d, torch.nn.modules.batchnorm.BatchNorm3d, torch.nn.GroupNorm, torch.nn.LayerNorm] - conv_list = [torch.nn.Conv1d, torch.nn.Conv2d, torch.nn.Conv3d, torch.nn.ConvTranspose1d, torch.nn.ConvTranspose2d, torch.nn.ConvTranspose3d] - for norm in norm_list: - for conv in conv_list: - if isinstance(mod, conv) and isinstance(nextChild, norm): - mod = Conv2d(mod.in_channels, mod.out_channels, mod.kernel_size, mod.stride, - mod.padding, mod.dilation, mod.groups, mod.bias!=None) - moduleChildList = list(module.named_children()) - for index, [name, child] in enumerate(moduleChildList): - nextChild = None - if index < len(moduleChildList) -1: - nextChild = moduleChildList[index+1][1] - mod.add_module(name, convertConv2WeightStand(child, nextChild)) - - return mod +# code from SyncBatchNorm in pytorch +def conv_to_ws_conv(module): + module_output = module + if isinstance(module, torch.nn.Conv2d): + module_output = WS_Conv2d( + in_channels=module.in_channels, + out_channels=module.out_channels, + kernel_size=module.kernel_size, + stride=module.stride, + padding=module.padding, + dilation=module.dilation, + # groups are also present in DepthWiseConvs which we don't want to patch + # TODO: fix this + groups=module.groups, + bias=module.bias is not None, + ) + with torch.no_grad(): # not sure if torch.no_grad is needed. but just in case + module_output.weight.copy_(module.weight) + module_output.weight.requires_grad = module.weight.requires_grad + if module.bias is not None: + module_output.bias.copy_(module.bias) + module_output.bias.requires_grad = module.bias.requires_grad + for name, child in module.named_children(): + module_output.add_module(name, conv_to_ws_conv(child)) + del module + return module_output diff --git a/pytorch_tools/optim/README.md b/pytorch_tools/optim/README.md index ca95eee..c1bd7ca 100644 --- a/pytorch_tools/optim/README.md +++ b/pytorch_tools/optim/README.md @@ -1,2 +1,2 @@ -# Custom optimizers and utils +# PyTorch Optimizers and Utils Self-explanatory \ No newline at end of file diff --git a/pytorch_tools/optim/__init__.py b/pytorch_tools/optim/__init__.py index fc06bbe..798ccd9 100644 --- a/pytorch_tools/optim/__init__.py +++ b/pytorch_tools/optim/__init__.py @@ -6,6 +6,7 @@ from .radam import RAdam, PlainRAdam from .sgdw import SGDW from .schedulers import LinearLR, ExponentialLR +from .rmsprop import RMSprop from .lookahead import Lookahead from torch import optim @@ -25,7 +26,8 @@ def optimizer_from_name(optim_name): # in this implementation eps in inside sqrt so it can be smaller return partial(AdamW_my, center=True, eps=1e-7) elif optim_name == "rmsprop": - return partial(optim.RMSprop, 2e-5) + # in this implementation eps in inside sqrt so it can be smaller + return partial(RMSprop, eps=1e-7) elif optim_name == "radam": return partial(RAdam, eps=2e-5) elif optim_name in ["fused_sgd", "fusedsgd"]: diff --git a/pytorch_tools/optim/rmsprop.py b/pytorch_tools/optim/rmsprop.py new file mode 100644 index 0000000..b1fb033 --- /dev/null +++ b/pytorch_tools/optim/rmsprop.py @@ -0,0 +1,61 @@ +"""Implementation of TF-like RMSprop with epsilong inside sqrt. The only difference is at line 52""" +import torch +from torch.optim import RMSprop as _RMSprop + +class RMSprop(_RMSprop): + @torch.no_grad() + def step(self, closure=None): + """Performs a single optimization step. 
+ Arguments: + closure (callable, optional): A closure that reevaluates the model + and returns the loss. + """ + loss = None + if closure is not None: + with torch.enable_grad(): + loss = closure() + + for group in self.param_groups: + for p in group['params']: + if p.grad is None: + continue + grad = p.grad + if grad.is_sparse: + raise RuntimeError('RMSprop does not support sparse gradients') + state = self.state[p] + + # State initialization + if len(state) == 0: + state['step'] = 0 + state['square_avg'] = torch.zeros_like(p, memory_format=torch.preserve_format) + if group['momentum'] > 0: + state['momentum_buffer'] = torch.zeros_like(p, memory_format=torch.preserve_format) + if group['centered']: + state['grad_avg'] = torch.zeros_like(p, memory_format=torch.preserve_format) + + square_avg = state['square_avg'] + alpha = group['alpha'] + + state['step'] += 1 + + if group['weight_decay'] != 0: + grad = grad.add(p, alpha=group['weight_decay']) + + square_avg.mul_(alpha).addcmul_(grad, grad, value=1 - alpha) + + if group['centered']: + grad_avg = state['grad_avg'] + grad_avg.mul_(alpha).add_(grad, alpha=1 - alpha) + avg = square_avg.addcmul(grad_avg, grad_avg, value=-1).sqrt_().add_(group['eps']) + else: + # avg = square_avg.sqrt().add_(group['eps']) + avg = square_avg.add_(group['eps']).sqrt() + + if group['momentum'] > 0: + buf = state['momentum_buffer'] + buf.mul_(group['momentum']).addcdiv_(grad, avg) + p.add_(buf, alpha=-group['lr']) + else: + p.addcdiv_(grad, avg, value=-group['lr']) + + return loss \ No newline at end of file diff --git a/pytorch_tools/segmentation_models/README.md b/pytorch_tools/segmentation_models/README.md index 30404ce..b9ca9f7 100644 --- a/pytorch_tools/segmentation_models/README.md +++ b/pytorch_tools/segmentation_models/README.md @@ -1 +1,22 @@ -TODO \ No newline at end of file +# PyTorch Segmentation Models Zoo +All models here were either written from scratch or refactored from open-source implementations. +All models here use `Activated Normalization` layers instead of traditional `Normalization` followed by `Activation`. It makes changing activation function and normalization layer easy and convenient. It also allows using [Inplace Activated Batch Norm](https://github.com/mapillary/inplace_abn) from the box, which is essential for reducing memory footprint in segmentation tasks. + + +## Encoders +All [models](../models/) could be used as feature extractors (aka backbones) for segmentation architectures. Almost all combinations of backbones and segm.model are supported. + + +## Features +* Unified API. 
Create `Unet, SegmentaionFPN, HRnet` models using the same code +* Support for custom number of input channels in pretrained encoders +* All core functionality covered with tests + + +## Repositories used +* [Torch Vision Main Repo](https://github.com/pytorch/vision) +* [Cadene pretrained models](https://github.com/Cadene/pretrained-models.pytorch/) +* [Ross Wightman models](https://github.com/rwightman/pytorch-image-models/) +* [Pytorch Toolbelt by @BloodAxe](https://github.com/BloodAxe/pytorch-toolbelt) +* [Segmentation Models py @qubvel](https://github.com/qubvel/segmentation_models.pytorch) +* [Original HRNet for Segmentation](https://github.com/HRNet/HRNet-Semantic-Segmentation) \ No newline at end of file diff --git a/pytorch_tools/segmentation_models/encoders.py b/pytorch_tools/segmentation_models/encoders.py index 4f11acc..088a48e 100644 --- a/pytorch_tools/segmentation_models/encoders.py +++ b/pytorch_tools/segmentation_models/encoders.py @@ -23,14 +23,14 @@ "densenet201": (1920, 1792, 512, 256, 64), #'densenet161': (2208, 2112, 768, 384, 96), - "efficientnet_b0": (1280, 80, 40, 24, 16), - "efficientnet_b1": (1280, 80, 40, 24, 16), - "efficientnet_b2": (1408, 88, 48, 24, 16), - "efficientnet_b3": (1536, 96, 48, 32, 24), - "efficientnet_b4": (1792, 112, 56, 32, 24), - "efficientnet_b5": (2048, 128, 64, 40, 24), - "efficientnet_b6": (2304, 144, 72, 40, 32), - "efficientnet_b7": (2560, 160, 80, 48, 32), + "efficientnet_b0": (320, 112, 40, 24, 16), + "efficientnet_b1": (320, 112, 40, 24, 16), + "efficientnet_b2": (352, 120, 48, 24, 16), + "efficientnet_b3": (384, 136, 48, 32, 24), + "efficientnet_b4": (448, 160, 56, 32, 24), + "efficientnet_b5": (512, 176, 64, 40, 24), + "efficientnet_b6": (576, 200, 72, 40, 32), + "efficientnet_b7": (640, 224, 80, 48, 32), # this models return feature maps at OS= 32, 16, 8, 4, 4 # they CAN'T be used as encoders in Unet and Linknet diff --git a/pytorch_tools/segmentation_models/hrnet.py b/pytorch_tools/segmentation_models/hrnet.py index a763849..8e7c3dd 100644 --- a/pytorch_tools/segmentation_models/hrnet.py +++ b/pytorch_tools/segmentation_models/hrnet.py @@ -138,12 +138,15 @@ def forward(self, x): return x def _init_weights(self): - # init all weights except encoder (to allow pretrain) - initialize(self.head) + # it works better if we only init last bias not whole decoder part + # set last layer bias for better convergence with sigmoid loss + # -4.59 = -np.log((1 - 0.01) / 0.01) if self.OCR: - initialize(self.aux_head) - initialize(self.ocr_distri_head) - initialize(self.conv3x3) + nn.init.constant_(self.head.bias, -4.59) + nn.init.constant_(self.aux_head[2].bias, -4.59) + else: + nn.init.constant_(self.head[2].bias, -4.59) + # fmt: off SETTINGS = { diff --git a/pytorch_tools/segmentation_models/linknet.py b/pytorch_tools/segmentation_models/linknet.py index 53a8e99..e51636d 100644 --- a/pytorch_tools/segmentation_models/linknet.py +++ b/pytorch_tools/segmentation_models/linknet.py @@ -22,7 +22,8 @@ def __init__( self.dropout = nn.Dropout2d(drop_rate, inplace=True) self.final_conv = conv1x1(prefinal_channels, final_channels) - initialize(self) + # it works much better without initializing decoder. maybe need to investigate into this issue + # initialize(self) def forward(self, x): encoder_head = x[0] @@ -51,7 +52,7 @@ class Linknet(EncoderDecoder): drop_rate (float): Probability of spatial dropout on last feature map norm_layer (str): Normalization layer to use. One of 'abn', 'inplaceabn'. The inplace version lowers memory footprint. 
But increases backward time. Defaults to 'abn'. - norm_act (str): Activation for normalizion layer. 'inplaceabn' doesn't support `ReLU` activation. + norm_act (str): Activation for normalization layer. 'inplaceabn' doesn't support `ReLU` activation. Returns: ``torch.nn.Module``: **Linknet** .. _Linknet: diff --git a/pytorch_tools/segmentation_models/segm_fpn.py b/pytorch_tools/segmentation_models/segm_fpn.py index 46cd152..bc73518 100644 --- a/pytorch_tools/segmentation_models/segm_fpn.py +++ b/pytorch_tools/segmentation_models/segm_fpn.py @@ -6,7 +6,6 @@ from pytorch_tools.modules.residual import conv1x1 from pytorch_tools.modules.residual import conv3x3 from pytorch_tools.modules.decoder import SegmentationUpsample -from pytorch_tools.utils.misc import initialize from .encoders import get_encoder @@ -114,9 +113,6 @@ def __init__( self.segm_head = conv1x1(segmentation_channels, num_classes) self.upsample = nn.Upsample(scale_factor=4, mode="bilinear") if last_upsample else nn.Identity() self.name = f"segm-fpn-{encoder_name}" - initialize(self.fpn) - initialize(self.decoder) - initialize(self.segm_head) def forward(self, x): x = self.encoder(x) # returns 5 features maps diff --git a/pytorch_tools/segmentation_models/unet.py b/pytorch_tools/segmentation_models/unet.py index d784e37..5490a6e 100644 --- a/pytorch_tools/segmentation_models/unet.py +++ b/pytorch_tools/segmentation_models/unet.py @@ -1,6 +1,7 @@ import torch.nn as nn from pytorch_tools.modules import bn_from_name from pytorch_tools.modules.residual import conv1x1 +from pytorch_tools.modules.residual import conv3x3 from pytorch_tools.modules.decoder import UnetDecoderBlock from pytorch_tools.utils.misc import initialize from .base import EncoderDecoder @@ -20,6 +21,7 @@ def __init__( final_channels=1, center=False, drop_rate=0, + output_stride=32, **bn_params, # norm layer, norm_act ): @@ -32,16 +34,13 @@ def __init__( in_channels = self.compute_channels(encoder_channels, decoder_channels) out_channels = decoder_channels - - self.layer1 = UnetDecoderBlock(in_channels[0], out_channels[0], **bn_params) - self.layer2 = UnetDecoderBlock(in_channels[1], out_channels[1], **bn_params) + self.layer1 = UnetDecoderBlock(in_channels[0], out_channels[0], upsample=output_stride == 32, **bn_params) + self.layer2 = UnetDecoderBlock(in_channels[1], out_channels[1], upsample=not output_stride == 8, **bn_params) self.layer3 = UnetDecoderBlock(in_channels[2], out_channels[2], **bn_params) self.layer4 = UnetDecoderBlock(in_channels[3], out_channels[3], **bn_params) self.layer5 = UnetDecoderBlock(in_channels[4], out_channels[4], **bn_params) self.dropout = nn.Dropout2d(drop_rate, inplace=False) # inplace=True raises a backprop error - self.final_conv = conv1x1(out_channels[4], final_channels) - - initialize(self) + self.final_conv = conv1x1(out_channels[4], final_channels, bias=True) def compute_channels(self, encoder_channels, decoder_channels): channels = [ @@ -97,11 +96,14 @@ def __init__( decoder_channels=(256, 128, 64, 32, 16), num_classes=1, center=False, # usefull for VGG models + output_stride=32, drop_rate=0, norm_layer="abn", norm_act="relu", **encoder_params, - ): + ): + if output_stride != 32: + encoder_params["output_stride"] = output_stride encoder = get_encoder( encoder_name, norm_layer=norm_layer, @@ -115,9 +117,13 @@ def __init__( final_channels=num_classes, center=center, drop_rate=drop_rate, + output_stride=output_stride, norm_layer=bn_from_name(norm_layer), norm_act=norm_act, ) super().__init__(encoder, decoder) self.name = 
f"u-{encoder_name}" + # set last layer bias for better convergence with sigmoid loss + # -4.59 = -np.log((1 - 0.01) / 0.01) + nn.init.constant_(self.decoder.final_conv.bias, -4.59) \ No newline at end of file diff --git a/pytorch_tools/utils/box.py b/pytorch_tools/utils/box.py new file mode 100644 index 0000000..16b349e --- /dev/null +++ b/pytorch_tools/utils/box.py @@ -0,0 +1,373 @@ +"""Various functions to help with bboxes for object detection""" +import torch +import numpy as np +from functools import wraps + + +def box2delta(boxes, anchors): + # type: (Tensor, Tensor)->Tensor + """Convert boxes to deltas from anchors. Boxes are expected in 'ltrb' format + Args: + boxes (torch.Tensor): shape [N, 4] or [BS, N, 4] + anchors (torch.Tensor): shape [N, 4] or [BS, N, 4] + Returns: + deltas (torch.Tensor): shape [N, 4] or [BS, N, 4] + offset_x, offset_y, scale_x, scale_y + """ + + anchors_wh = anchors[..., 2:] - anchors[..., :2] + anchors_ctr = anchors[..., :2] + 0.5 * anchors_wh + boxes_wh = boxes[..., 2:] - boxes[..., :2] + boxes_ctr = boxes[..., :2] + 0.5 * boxes_wh + offset_delta = (boxes_ctr - anchors_ctr) / anchors_wh + scale_delta = torch.log(boxes_wh / anchors_wh) + return torch.cat([offset_delta, scale_delta], -1) + + +def delta2box(deltas, anchors): + # type: (Tensor, Tensor)->Tensor + """Convert anchors to boxes using deltas. Boxes are expected in 'ltrb' format + Args: + deltas (torch.Tensor): shape [N, 4] or [BS, N, 4] + anchors (torch.Tensor): shape [N, 4] or [BS, N, 4] + Returns: + bboxes (torch.Tensor): bboxes obtained from anchors by regression + Output shape is [N, 4] or [BS, N, 4] depending on input + """ + anchors_wh = anchors[..., 2:] - anchors[..., :2] + ctr = anchors[..., :2] + 0.5 * anchors_wh + pred_ctr = deltas[..., :2] * anchors_wh + ctr + + # Value for clamping large dw and dh predictions. The heuristic is that we clamp + # such that dw and dh are no larger than what would transform a 16px box into a + # 1000px box (based on a small anchor, 16px, and a typical image size, 1000px). + SCALE_CLAMP = 4.135 # ~= np.log(1000. / 16.) + deltas[..., 2:] = deltas[..., 2:].clamp(min=-SCALE_CLAMP, max=SCALE_CLAMP) + + pred_wh = deltas[..., 2:].exp() * anchors_wh + return torch.cat([pred_ctr - 0.5 * pred_wh, pred_ctr + 0.5 * pred_wh], -1) + + +def box_area(box): + """Args: + box (torch.Tensor): shape [N, 4] or [BS, N, 4] in 'ltrb' format + """ + return (box[..., 2] - box[..., 0]) * (box[..., 3] - box[..., 1]) + + +def clip_bboxes(bboxes, size): + """Args: + bboxes (torch.Tensor): in `ltrb` format. Shape [N, 4] + size (Union[torch.Size, tuple]): (H, W). Shape [2,]""" + bboxes[:, 0::2] = bboxes[:, 0::2].clamp(0, size[1]) + bboxes[:, 1::2] = bboxes[:, 1::2].clamp(0, size[0]) + return bboxes + + +def clip_bboxes_batch(bboxes, size): + # type: (Tensor, Tensor)->Tensor + """Args: + bboxes (torch.Tensor): in `ltrb` format. Shape [BS, N, 4] + size (torch.Tensor): (H, W). Shape [BS, 2] """ + size = size.to(bboxes) + h_size = size[..., 0].view(-1, 1, 1) # .float() + w_size = size[..., 1].view(-1, 1, 1) # .float() + h_bboxes = bboxes[..., 1::2] + w_bboxes = bboxes[..., 0::2] + zeros = torch.zeros_like(h_bboxes) + bboxes[..., 1::2] = h_bboxes.where(h_bboxes > 0, zeros).where(h_bboxes < h_size, h_size) + bboxes[..., 0::2] = w_bboxes.where(w_bboxes > 0, zeros).where(w_bboxes < w_size, w_size) + # FIXME: I'm using where to support passing tensor. 
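# Round-trip sanity check for the delta transforms above (illustrative values):
#     anchor = torch.tensor([[0., 0., 10., 10.]])
#     box = torch.tensor([[2., 2., 8., 8.]])
#     deltas = box2delta(box, anchor)  # zero offsets, log(6 / 10) scales
#     assert torch.allclose(delta2box(deltas, anchor), box)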
change to `clamp` when PR #32587 is resolved + # bboxes[:, 0::2] = bboxes[:, 0::2].clamp(0, size[1].item()) + # bboxes[:, 1::2] = bboxes[:, 1::2].clamp(0, size[0].item()) + return bboxes + + +# implementation from https://github.com/kuangliu/torchcv/blob/master/torchcv/utils/box.py +# with slight modifications +def box_iou(boxes1, boxes2): + # type: (Tensor, Tensor)->Tensor + """ + Return intersection-over-union (Jaccard index) of boxes. + Both sets of boxes are expected to be in `ltrb`: (x1, y1, x2, y2) format. + Arguments: + boxes1 (Tensor[N, 4]) + boxes2 (Tensor[M, 4]) + Returns: + iou (Tensor[N, M]): the NxM matrix containing the pairwise + IoU values for every element in boxes1 and boxes2 + """ + area1 = box_area(boxes1) + area2 = box_area(boxes2) + + lt = torch.max(boxes1[:, None, :2], boxes2[:, :2]) # [N,M,2] + rb = torch.min(boxes1[:, None, 2:], boxes2[:, 2:]) # [N,M,2] + + wh = (rb - lt).clamp(min=0) # [N,M,2] + inter = wh[:, :, 0] * wh[:, :, 1] # [N,M] + + iou = inter / (area1[:, None] + area2 - inter) + return iou + + +# based on https://github.com/NVIDIA/retinanet-examples/ +# and on https://github.com/google/automl/ +def generate_anchors_boxes( + image_size, num_scales=3, aspect_ratios=(1.0, 2.0, 0.5), pyramid_levels=[3, 4, 5, 6, 7], anchor_scale=4, +): + """Generates multiscale anchor boxes + Minimum object size which could be detected is anchor_scale * 2**pyramid_levels[0]. By default it's 32px + Maximum object size which could be detected is anchor_scale * 2**pyramid_levels[-1]. By default it's 512px + + Args: + image_size (int or (int, int)): shape of the image + num_scales (int): integer number representing intermediate scales added on each level. For instances, + num_scales=3 adds three additional anchor scales [2^0, 2^0.33, 2^0.66] on each level. + aspect_ratios (List[int]): Aspect ratios of anchor boxes + pyramid_levels (List[int]): Levels from which features are taken. Needed to calculate stride + anchor_scale (float): scale of size of the base anchor. Lower values allows detection of smaller objects. + + Returns: + anchor_boxes (torch.Tensor): stacked anchor boxes on all feature levels. shape [N, 4]. + boxes are in 'ltrb' format + num_anchors (int): number of anchors per location + """ + + if isinstance(image_size, int): + image_size = (image_size, image_size) + scale_vals = [anchor_scale * 2 ** (i / num_scales) for i in range(num_scales)] + # from lowest stride to largest. Anchors from models should be in the same order! + strides = [2 ** i for i in pyramid_levels] + + # get offsets for anchor boxes for one pixel + # can rewrite in pure Torch but using np is more convenient. 
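# Quick count with the defaults above (illustrative): image_size=512 and pyramid levels 3..7 give
# feature maps of 64, 32, 16, 8 and 4 px, with 3 scales * 3 aspect ratios = 9 anchors per location,
# so `all_anchors` ends up with (64**2 + 32**2 + 16**2 + 8**2 + 4**2) * 9 = 49104 rows.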
This function usually should only be called once + num_anchors = len(scale_vals) * len(aspect_ratios) + ratio_vals_sq = np.sqrt(np.tile(aspect_ratios, len(scale_vals))) + scale_vals = np.repeat(scale_vals, len(aspect_ratios))[:, np.newaxis] + wh = np.stack([np.ones(num_anchors) * ratio_vals_sq, np.ones(num_anchors) / ratio_vals_sq], axis=1) + lt = -0.5 * wh * scale_vals + rb = 0.5 * wh * scale_vals + base_offsets = torch.from_numpy(np.hstack([lt, rb])).float() # [num_anchors, 4] + base_offsets = base_offsets.view(-1, 1, 1, 4) # [num_anchors, 1, 1, 4] + # generate anchor boxes for all given strides + all_anchors = [] + for stride in strides: + y, x = torch.meshgrid([torch.arange(stride / 2, image_size[i], stride) for i in range(2)]) + xyxy = torch.stack((x, y, x, y), 2).unsqueeze(0) + # permute to match TF EffDet anchors order after reshape + anchors = (xyxy + base_offsets * stride).permute(1, 2, 0, 3).reshape(-1, 4) + all_anchors.append(anchors) + all_anchors = torch.cat(all_anchors) + # clip boxes to image. Not sure if we really need to clip them + # clip_bboxes(all_anchors, image_size) + return all_anchors, num_anchors + + +def generate_targets(anchors, batch_gt_boxes, num_classes, matched_iou=0.5, unmatched_iou=0.4): + """Generate targets for regression and classification + + Based on IoU between anchor and true bounding box there are three types of anchor boxes + 1) IoU >= matched_iou: Highest similarity. Matched/Positive. Mask value is 1 + 2) matched_iou > IoU >= unmatched_iou: Medium similarity. Ignored. Mask value is -1 + 3) unmatched_iou > IoU: Lowest similarity. Unmatched/Negative. Mask value is 0 + + Args: + anchors (torch.Tensor): all anchors on a single image. shape [N, 4] + batch_gt_boxes (torch.Tensor): all ground truth bounding boxes and classes for the batch. shape [BS, N, 5] + classes are expected to be in the last column. + bboxes are in `ltrb` format! + num_classes (int): number of classes. needed for one-hot encoding labels + matched_iou (float): + unmatched_iou (float): + + Returns: + box_target, cls_target, matches_mask + + """ + + def _generate_single_targets(gt_boxes): + gt_boxes, gt_classes = gt_boxes.split(4, dim=1) + overlap = box_iou(anchors, gt_boxes) + + # Keep best box per anchor + overlap, indices = overlap.max(1) + box_target = box2delta(gt_boxes[indices], anchors) + + # There are three types of anchors. 
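# With the default thresholds this means, for example: an anchor whose best IoU is 0.55 gets
# mask value 1 (positive), one at 0.45 gets -1 (ignored) and one at 0.30 gets 0 (background);
# the mask is what a detection loss would typically use to pick which anchors contribute.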
+ # matched (with objects), unmatched (with background), and in between (which should be ignored) + IGNORED_VALUE = -1 + UNMATCHED_VALUE = 0 + matches_mask = torch.ones_like(overlap) * IGNORED_VALUE + matches_mask[overlap < unmatched_iou] = UNMATCHED_VALUE # background + matches_mask[overlap >= matched_iou] = 1 + + # Generate one-hot-encoded target classes + cls_target = torch.zeros( + (anchors.size(0), num_classes + 1), device=gt_classes.device, dtype=gt_classes.dtype + ) + gt_classes = gt_classes[indices].long() + gt_classes[overlap < unmatched_iou] = num_classes # background has no class + cls_target.scatter_(1, gt_classes, 1) + cls_target = cls_target[:, :num_classes] # remove background class from one-hot + + return cls_target, box_target, matches_mask + + anchors = anchors.to(batch_gt_boxes) # change device & type if needed + batch_results = ([], [], []) + for single_gt_boxes in batch_gt_boxes: + single_target_results = _generate_single_targets(single_gt_boxes) + for batch_res, single_res in zip(batch_results, single_target_results): + batch_res.append(single_res) + b_cls_target, b_box_target, b_matches_mask = [torch.stack(targets) for targets in batch_results] + return b_cls_target, b_box_target, b_matches_mask + + +# copied from torchvision +def batched_nms(boxes, scores, idxs, iou_threshold): + # type: (Tensor, Tensor, Tensor, float)->Tensor + """ + Performs non-maximum suppression in a batched fashion. + Each index value correspond to a category, and NMS + will not be applied between elements of different categories. + Parameters + ---------- + boxes : Tensor[N, 4] + boxes where NMS will be performed. They + are expected to be in (x1, y1, x2, y2) format + scores : Tensor[N] + scores for each one of the boxes + idxs : Tensor[N] + indices of the categories for each one of the boxes. + iou_threshold : float + discards all overlapping boxes + with IoU > iou_threshold + Returns + ------- + keep : Tensor + int64 tensor with the indices of + the elements that have been kept by NMS, sorted + in decreasing order of scores + """ + if boxes.numel() == 0: + return torch.empty((0,), dtype=torch.int64, device=boxes.device) + # strategy: in order to perform NMS independently per class. + # we add an offset to all the boxes. The offset is dependent + # only on the class idx, and is large enough so that boxes + # from different classes do not overlap + max_coordinate = boxes.max() + offsets = idxs.to(boxes) * (max_coordinate + 1) + boxes_for_nms = boxes + offsets[:, None] + keep = torch.ops.torchvision.nms(boxes_for_nms, scores, iou_threshold) + return keep + + +# jit actually makes it slower for fp16 and results are different! +# FIXME: check it after 1.6 release. maybe they will fix JIT by that time +# @torch.jit.script +def decode( + batch_cls_head, + batch_box_head, + anchors, + img_shapes=None, + img_scales=None, + threshold=0.05, + max_detection_points=5000, + max_detection_per_image=100, + iou_threshold=0.5, +): + # type: (Tensor, Tensor, Tensor, Tensor, Tensor, float, int, int, float)->Tensor + """ + Decodes raw outputs of a model for easy visualization of bboxes + + Args: + batch_cls_head (torch.Tensor): shape [BS, *, NUM_CLASSES] + batch_box_head (torch.Tensor): shape [BS, *, 4] + anchors (torch.Tensor): shape [*, 4] + img_shapes (torch.Tensor): if given clips predicted bboxes to img height and width. Shape [BS, 2] or [2,] + img_scales (torch.Tensor): if given used to rescale img_shapes. 
Shape [BS,] + threshold (float): minimum score threshold to consider object detected + max_detection_points (int): Maximum number of bboxes to consider for NMS for one image + max_detection_per_image (int): Maximum number of bboxes to return per image + iou_threshold (float): iou_threshold for Non Maximum Supression + + Returns: + torch.Tensor with bboxes, scores and classes + shape [BS, MAX_DETECTION_PER_IMAGE, 6]. + bboxes in 'ltrb' format. If img_shape is not given they are NOT CLIPPED (!) + """ + + batch_size = batch_cls_head.size(0) + num_classes = batch_cls_head.size(-1) + + anchors = anchors.to(batch_cls_head).unsqueeze(0).expand(batch_size, -1, -1) # [N, 4] -> [BS, N, 4] + # it has to be raw logits but check anyway to avoid applying sigmoid twice + if batch_cls_head.min() < 0 or batch_cls_head.max() > 1: + batch_cls_head = batch_cls_head.sigmoid() + + # It's much faster to calculate topk once for full batch here rather than doing it inside loop + # In TF The same bbox may belong to two different objects + # select `max_detection_points` scores and corresponding bboxes + scores_topk_all, cls_topk_indices_all = torch.topk( + batch_cls_head.view(batch_size, -1), k=max_detection_points + ) + indices_all = cls_topk_indices_all / num_classes + classes_all = cls_topk_indices_all % num_classes + + # Gather corresponding bounding boxes & anchors, then regress and clip + box_topk_all = torch.gather(batch_box_head, 1, indices_all.unsqueeze(2).expand(-1, -1, 4)) + anchors_topk_all = torch.gather(anchors, 1, indices_all.unsqueeze(2).expand(-1, -1, 4)) + regressed_boxes_all = delta2box(box_topk_all, anchors_topk_all) + if img_shapes is not None: + if img_scales is not None: + img_shapes = img_shapes / img_scales.unsqueeze(1) + regressed_boxes_all = clip_bboxes_batch(regressed_boxes_all, img_shapes) + + # prepare output tensors + out_scores = torch.zeros((batch_size, max_detection_per_image)).to(batch_cls_head) + out_boxes = torch.zeros((batch_size, max_detection_per_image, 4)).to(batch_cls_head) + out_classes = torch.zeros((batch_size, max_detection_per_image)).to(batch_cls_head) + + for batch in range(batch_size): + scores_topk = scores_topk_all[batch] # , cls_topk_indices_all[batch] + classes = classes_all[batch] # cls_topk_indices % num_classes + regressed_boxes = regressed_boxes_all[batch] # delta2box(box_topk, anchor_topk) + + # apply NMS + nms_idx = batched_nms(regressed_boxes, scores_topk, classes, iou_threshold) + nms_idx = nms_idx[: min(len(nms_idx), max_detection_per_image)] + # select suppressed bboxes + im_scores = scores_topk[nms_idx] + im_classes = classes[nms_idx] + im_bboxes = regressed_boxes[nms_idx] + im_classes += 1 # back to class idx with background class = 0 + + out_scores[batch, : im_scores.size(0)] = im_scores + out_classes[batch, : im_classes.size(0)] = im_classes + out_boxes[batch, : im_bboxes.size(0)] = im_bboxes + # no need to pad because it's already padded with 0's + + ## old way ## + # get regressed bboxes + # all_img_bboxes = delta2box(batch_box_head[batch], anchors) + # if img_shape: # maybe clip + # all_img_bboxes = clip_bboxes(all_img_bboxes, img_shape) + # select at most `top_n` bboxes and from them select with score > threshold + # max_cls_score, max_cls_idx = batch_cls_head[batch].max(1) + # top_cls_score, top_cls_idx = max_cls_score.topk(top_n) + # top_cls_idx = top_cls_idx[top_cls_score > threshold] + + # im_scores = max_cls_score[top_cls_idx] + # im_classes = max_cls_idx[top_cls_idx] + # im_bboxes = all_img_bboxes[top_cls_idx] + + # apply NMS + # nms_idx 
= batched_nms(im_bboxes, im_scores, im_classes, iou_threshold) + # im_scores = im_scores[nms_idx] + # im_classes = im_classes[nms_idx] + # im_bboxes = im_bboxes[nms_idx] + + # out_scores[batch, :im_scores.size(0)] = im_scores + # out_classes[batch, :im_classes.size(0)] = im_classes + # out_boxes[batch, :im_bboxes.size(0)] = im_bboxes + + return torch.cat([out_boxes, out_scores.unsqueeze(-1), out_classes.unsqueeze(-1)], dim=2) diff --git a/pytorch_tools/utils/misc.py b/pytorch_tools/utils/misc.py index 72b5fca..733a22e 100644 --- a/pytorch_tools/utils/misc.py +++ b/pytorch_tools/utils/misc.py @@ -5,21 +5,37 @@ import random import collections import numpy as np +from functools import partial import torch.nn as nn import torch.nn.functional as F import torch.distributed as dist -from functools import partial -def initialize(model): - for m in model.modules(): - if isinstance(m, nn.Conv2d): - nn.init.kaiming_normal_(m.weight, mode="fan_out", nonlinearity="relu") - elif isinstance(m, nn.BatchNorm2d): - nn.init.constant_(m.weight, 1) +def initialize_fn(m): + """m (nn.Module): module""" + if isinstance(m, nn.Conv2d): + # nn.init.kaiming_uniform_ doesn't take into account groups + # remove when https://github.com/pytorch/pytorch/issues/23854 is resolved + # this is needed for proper init of EffNet models + fan_out = m.kernel_size[0] * m.kernel_size[1] * m.out_channels + fan_out //= m.groups + m.weight.data.normal_(0, math.sqrt(2.0 / fan_out)) + if m.bias is not None: nn.init.constant_(m.bias, 0) - elif isinstance(m, nn.Linear): - nn.init.kaiming_uniform_(m.weight, mode="fan_in", nonlinearity="linear") + # No check for BN because in PyTorch it is initialized with 1 & 0 by default + elif isinstance(m, nn.Linear): + nn.init.kaiming_uniform_(m.weight, mode="fan_out", nonlinearity="linear") + nn.init.constant_(m.bias, 0) + + +def initialize(module): + for m in module.modules(): + initialize_fn(m) + + +def initialize_iterator(module_iterator): + for m in module_iterator: + initialize_fn(m) def set_random_seed(seed): @@ -205,6 +221,7 @@ def make_divisible(v, divisor=8): new_v += divisor return new_v + def repeat_channels(conv_weights, new_channels, old_channels=3): """Repeat channels to match new number of input channels Args: @@ -214,5 +231,5 @@ def repeat_channels(conv_weights, new_channels, old_channels=3): """ rep_times = math.ceil(new_channels / old_channels) new_weights = conv_weights.repeat(1, rep_times, 1, 1)[:, :new_channels, :, :] - new_weights *= old_channels / new_channels # to keep the same output amplitude - return new_weights \ No newline at end of file + new_weights *= old_channels / new_channels # to keep the same output amplitude + return new_weights diff --git a/pytorch_tools/utils/visualization.py b/pytorch_tools/utils/visualization.py index 761554b..592bec3 100644 --- a/pytorch_tools/utils/visualization.py +++ b/pytorch_tools/utils/visualization.py @@ -4,7 +4,8 @@ import numpy as np -def tensor_from_rgb_image(image: np.ndarray) -> torch.Tensor: +def tensor_from_rgb_image(image): + """Args: image (np.array): Input image in HxWxC format""" image = np.moveaxis(image, -1, 0) image = np.ascontiguousarray(image) image = torch.from_numpy(image) diff --git a/requirements.txt b/requirements.txt index 3a84f92..39dc412 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,2 +1,2 @@ -torch>=1.1 +torch>=1.4 inplace-abn \ No newline at end of file diff --git a/tests/benchmarking/README.md b/tests/benchmarking/README.md index b5fe0be..37fd0a0 100644 --- a/tests/benchmarking/README.md +++ 
b/tests/benchmarking/README.md @@ -54,6 +54,8 @@ Mean of 10 runs 10 iters each BS=64: 54.97+-0.01 msecs Forward. 185.83+-5.13 mse **Resnet50 Baseline: 25.56M params** Mean of 10 runs 10 iters each BS=64: 59.65+-0.07 msecs Forward. 164.39+-2.58 msecs Backward. Max memory: 5935.15Mb +**Resnet50 AMP** 25.56M params +Mean of 10 runs 10 iters each BS=256: 111.44+-0.03 msecs Forward. 357.80+-1.61 msecs Backward. Max memory: 11226.14Mb. 545.56 imgs/sec **Resnet34 Leaky ReLU 21.82M params** Mean of 10 runs 10 iters each BS=64: 30.81+-0.18 msecs Forward. 103.95+-1.05 msecs Backward. Max memory: 2766.59Mb diff --git a/tests/benchmarking/memory_test.py b/tests/benchmarking/memory_test.py index d2c3f52..c355f3f 100644 --- a/tests/benchmarking/memory_test.py +++ b/tests/benchmarking/memory_test.py @@ -40,11 +40,13 @@ def test_model(model, forward_only=False): def run_once(): start.record() output = model(INP) + if isinstance(output, tuple) and len(output) > 1: + output = output[0] f_end.record() if forward_only: torch.cuda.synchronize() return start.elapsed_time(f_end), start.elapsed_time(f_end) - loss = criterion(output, TARGET) + loss = output.mean() optimizer.zero_grad() loss.backward() optimizer.step() @@ -70,10 +72,11 @@ def run_once(): f_times = np.array(f_times) fb_times = np.array(fb_times) print( - "Mean of {} runs {} iters each BS={}:\n\t {:.2f}+-{:.2f} msecs Forward. {:.2f}+-{:.2f} msecs Backward. Max memory: {:.2f}Mb. {:.2f} imgs/sec".format( + "Mean of {} runs {} iters each BS={}, SZ={}:\n\t {:.2f}+-{:.2f} msecs Forward. {:.2f}+-{:.2f} msecs Backward. Max memory: {:.2f}Mb. {:.2f} imgs/sec".format( N_RUNS, RUN_ITERS, BS, + SZ, f_times.mean(), f_times.std(), (fb_times - f_times).mean(), @@ -95,7 +98,10 @@ def run_once(): "--amp", action="store_true", help="Measure speed using apex mixed precision" ) parser.add_argument( - "--bs", default=64, type=int, help="BS for classification", + "--bs", default=64, type=int, help="BS for benchmarking", + ) + parser.add_argument( + "--sz", default=224, type=int, help="Size of images for benchmarking", ) args = parser.parse_args() # all models are first init to cpu memory to find errors earlier @@ -134,11 +140,10 @@ def run_once(): print("Initialized models") BS = args.bs + SZ = args.sz N_RUNS = 10 RUN_ITERS = 10 - INP = torch.ones((BS, 3, 224, 224), requires_grad=not args.forward).cuda(0) - TARGET = torch.ones(BS).long().cuda(0) - criterion = torch.nn.CrossEntropyLoss().cuda(0) + INP = torch.ones((BS, 3, SZ, SZ), requires_grad=not args.forward).cuda(0) for name, model in models_dict.items(): print(f"{name} {count_parameters(model) / 1e6:.2f}M params") model = model.cuda(0) @@ -150,13 +155,16 @@ def run_once(): test_model(model, forward_only=args.forward) # now test segmentation models - BS = 16 - INP = torch.ones((BS, 3, 224, 224), requires_grad=True).cuda(0) - TARGET = torch.ones((BS, 1, 224, 224)).cuda(0) - criterion = pt.losses.JaccardLoss().cuda(0) + INP = torch.ones((BS, 3, SZ, SZ), requires_grad=True).cuda(0) for name, model in segm_models_dict.items(): enc_params = count_parameters(model.encoder) / 1e6 total_params = count_parameters(model) / 1e6 print(f"{name}. Encoder {enc_params:.2f}M. Decoder {total_params - enc_params:.2f}M. 
Total {total_params:.2f}M params") + model = model.cuda(0) + if args.amp: + model = amp.initialize(model, verbosity=0, opt_level="O1") + INP = INP.half() + if args.forward: + model.eval() test_model(model, forward_only=args.forward) diff --git a/tests/detection_models/test_det_models.py b/tests/detection_models/test_det_models.py new file mode 100644 index 0000000..0ab1ff2 --- /dev/null +++ b/tests/detection_models/test_det_models.py @@ -0,0 +1,68 @@ +import torch +import pytest +import numpy as np +from PIL import Image +from pytorch_tools.utils.preprocessing import get_preprocessing_fn +from pytorch_tools.utils.visualization import tensor_from_rgb_image + +import pytorch_tools as pt +import pytorch_tools.detection_models as pt_det + + +# all weights were tested on 05.2020. for now only leave one model for faster tests +MODEL_NAMES = [ + "efficientdet_d0", + "retinanet_r50_fpn", + # "efficientdet_d1", + # "efficientdet_d2", + # "efficientdet_d3", + # "efficientdet_d4", + # "efficientdet_d5", + # "efficientdet_d6", + # "retinanet_r101_fpn", +] + +# format "coco image class: PIL Image" +IMGS = { + 17: Image.open("tests/imgs/dog.jpg"), +} + +INP = torch.ones(1, 3, 512, 512).cuda() + + +@torch.no_grad() +def _test_forward(model): + return model(INP) + + +@pytest.mark.parametrize("arch", MODEL_NAMES) +def test_coco_pretrain(arch): + # want TF same padding for better results + kwargs = {} + if "eff" in arch: + kwargs["match_tf_same_padding"] = True + m = pt_det.__dict__[arch](pretrained="coco", **kwargs).cuda() + m.eval() + # get size of the images used for pretraining + inp_size = m.pretrained_settings["input_size"][-1] + # get preprocessing fn according to pretrained settings + preprocess_fn = get_preprocessing_fn(m.pretrained_settings) + for im_cls, im in IMGS.items(): + im = np.array(im.resize((inp_size, inp_size))) + im_t = tensor_from_rgb_image(preprocess_fn(im)).unsqueeze(0).float().cuda() + boxes_scores_classes = m.predict(im_t) + # check that most confident bbox is close to correct class. 
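# Minimal inference sketch distilled from the test above (the 512 px size and the flag values are
# illustrative; the real input size comes from `pretrained_settings`):
#     model = pt_det.efficientdet_d0(pretrained="coco", match_tf_same_padding=True).eval().cuda()
#     img = np.array(Image.open("tests/imgs/dog.jpg").resize((512, 512)))
#     img_t = tensor_from_rgb_image(get_preprocessing_fn(model.pretrained_settings)(img))
#     boxes_scores_classes = model.predict(img_t.unsqueeze(0).float().cuda())  # [BS, N, 6] = ltrb, score, class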
The reason for such strange test is + # because in different models class mappings are shifted by +- 1 + assert (boxes_scores_classes[0, 0, 5] - im_cls) < 2 + + +@pytest.mark.parametrize("arch", MODEL_NAMES[:2]) +def test_pretrain_custom_num_classes(arch): + m = pt_det.__dict__[arch](pretrained="coco", num_classes=80).eval().cuda() + _test_forward(m) + + +@pytest.mark.parametrize("arch", MODEL_NAMES[:2]) +def test_encoder_frozenabn(arch): + m = pt_det.__dict__[arch](encoder_norm_layer="frozenabn").eval().cuda() + _test_forward(m) diff --git a/tests/fit_wrapper/test_runner.py b/tests/fit_wrapper/test_runner.py index 6f16881..29232a4 100644 --- a/tests/fit_wrapper/test_runner.py +++ b/tests/fit_wrapper/test_runner.py @@ -100,14 +100,35 @@ def test_val_loader(): runner.fit(TEST_LOADER, epochs=2, steps_per_epoch=100, val_loader=TEST_LOADER, val_steps=200) - def test_grad_clip_loader(): runner = Runner( model=TEST_MODEL, optimizer=TEST_OPTIMIZER, criterion=TEST_CRITERION, metrics=TEST_METRIC, - gradient_clip_val=1.0 + gradient_clip_val=1.0, + ) + runner.fit(TEST_LOADER, epochs=2) + + +def test_accumulate_steps(): + runner = Runner( + model=TEST_MODEL, + optimizer=TEST_OPTIMIZER, + criterion=TEST_CRITERION, + metrics=TEST_METRIC, + accumulate_steps=10, + ) + runner.fit(TEST_LOADER, epochs=2) + + +def test_ModelEma_callback(): + runner = Runner( + model=TEST_MODEL, + optimizer=TEST_OPTIMIZER, + criterion=TEST_CRITERION, + metrics=TEST_METRIC, + callbacks=pt_clb.ModelEma(TEST_MODEL), ) runner.fit(TEST_LOADER, epochs=2) @@ -123,9 +144,7 @@ def test_grad_clip_loader(): pt_clb.Timer(), pt_clb.ReduceLROnPlateau(), pt_clb.CheckpointSaver(TMP_PATH, save_name="model.chpn"), - pt_clb.CheckpointSaver( - TMP_PATH, save_name="model.chpn", monitor=TEST_METRIC.name, mode="max" - ), + pt_clb.CheckpointSaver(TMP_PATH, save_name="model.chpn", monitor=TEST_METRIC.name, mode="max"), pt_clb.TensorBoard(log_dir=TMP_PATH), pt_clb.TensorBoardWithCM(log_dir=TMP_PATH), pt_clb.ConsoleLogger(), @@ -151,9 +170,17 @@ def test_callback(callback): ) def test_segm_callback(callback): runner = Runner( - model=TEST_SEGM_MODEL, - optimizer=TEST_SEGM_OPTIMZER, + model=TEST_SEGM_MODEL, optimizer=TEST_SEGM_OPTIMZER, criterion=TEST_CRITERION, callbacks=callback, + ) + runner.fit(TEST_SEGM_LOADER, epochs=2) + + +def test_invalid_phases_scheduler_mode(): + runner = Runner( + model=TEST_MODEL, + optimizer=TEST_OPTIMIZER, criterion=TEST_CRITERION, - callbacks=callback, + callbacks=pt_clb.PhasesScheduler([{"ep": [0, 1], "lr": [0, 1], "mode": "new_mode"},]), ) - runner.fit(TEST_SEGM_LOADER, epochs=2) \ No newline at end of file + with pytest.raises(ValueError): + runner.fit(TEST_LOADER, epochs=2) diff --git a/tests/imgs/cityscapes_sample.jpg b/tests/imgs/cityscapes_sample.jpg new file mode 100644 index 0000000..9e401d8 Binary files /dev/null and b/tests/imgs/cityscapes_sample.jpg differ diff --git a/tests/imgs/cityscapes_sample2.jpg b/tests/imgs/cityscapes_sample2.jpg new file mode 100644 index 0000000..c6942a4 Binary files /dev/null and b/tests/imgs/cityscapes_sample2.jpg differ diff --git a/tests/models/imgs/dog.jpg b/tests/imgs/dog.jpg similarity index 100% rename from tests/models/imgs/dog.jpg rename to tests/imgs/dog.jpg diff --git a/tests/models/imgs/helmet.jpeg b/tests/imgs/helmet.jpeg similarity index 100% rename from tests/models/imgs/helmet.jpeg rename to tests/imgs/helmet.jpeg diff --git a/tests/losses/test_losses.py b/tests/losses/test_losses.py index 74df11b..2f7a741 100644 --- a/tests/losses/test_losses.py +++ 
b/tests/losses/test_losses.py @@ -52,9 +52,7 @@ def test_focal_loss_fn_basic(): @pytest.mark.parametrize("reduction", ["sum", "mean", "none"]) def test_focal_loss_fn_reduction(reduction): - torch_ce = F.binary_cross_entropy_with_logits( - INP_BINARY, TARGET_BINARY.float(), reduction=reduction - ) + torch_ce = F.binary_cross_entropy_with_logits(INP_BINARY, TARGET_BINARY.float(), reduction=reduction) my_ce = pt_F.focal_loss_with_logits(INP_BINARY, TARGET_BINARY, alpha=0.5, gamma=0, reduction=reduction) assert torch.allclose(torch_ce, my_ce * 2) @@ -108,6 +106,7 @@ def test_focal_loss(): fl_i = losses.FocalLoss(mode="binary", reduction="sum", ignore_label=-100)(INP_IMG_BINARY, y_true) assert torch.allclose(fl.sum() - loss_diff, fl_i) + @pytest.mark.parametrize( ["y_true", "y_pred", "expected"], [ @@ -333,9 +332,7 @@ def test_binary_cross_entropy(reduction): assert torch.allclose(torch_ce, my_ce) # test for images - torch_ce = F.binary_cross_entropy_with_logits( - INP_IMG_BINARY, TARGET_IMG_BINARY, reduction=reduction - ) + torch_ce = F.binary_cross_entropy_with_logits(INP_IMG_BINARY, TARGET_IMG_BINARY, reduction=reduction) my_ce = my_ce_loss(INP_IMG_BINARY, TARGET_IMG_BINARY) assert torch.allclose(torch_ce, my_ce) @@ -386,3 +383,9 @@ def test_multiclass_multilabel_lovasz(): def test_binary_hinge(): assert losses.BinaryHinge()(INP_IMG_BINARY, TARGET_IMG_BINARY) + + +@pytest.mark.parametrize("reduction", ["sum", "mean", "none"]) +def test_smoothl1(reduction): + loss_my = losses.SmoothL1Loss(delta=1, reduction=reduction)(INP, TARGET_MULTILABEL) + loss_torch = F.smooth_l1_loss(INP, TARGET_MULTILABEL, reduction=reduction) diff --git a/tests/models/test_models.py b/tests/models/test_models.py index b251ba9..0042a85 100644 --- a/tests/models/test_models.py +++ b/tests/models/test_models.py @@ -23,7 +23,14 @@ HRNET_NAMES = [name for name in ALL_MODEL_NAMES if "hrnet" in name] # test only part of the models -TEST_MODEL_NAMES = DENSENET_NAMES[:1] + EFFNET_NAMES[:1] + VGG_NAMES[:1] + RESNET_NAMES[:1] + TRESNET_NAMES[:1] + HRNET_NAMES[:1] +TEST_MODEL_NAMES = ( + DENSENET_NAMES[:1] + + EFFNET_NAMES[:1] + + VGG_NAMES[:1] + + RESNET_NAMES[:1] + + TRESNET_NAMES[:1] + + HRNET_NAMES[:1] +) # TEST_MODEL_NAMES = HRNET_NAMES[:1] INP = torch.ones(2, 3, 128, 128) @@ -52,6 +59,7 @@ def test_custom_in_channels(arch): with torch.no_grad(): m(torch.ones(2, 5, 128, 128)) + @pytest.mark.parametrize("arch", EFFNET_NAMES[:2] + RESNET_NAMES[:2]) def test_pretrained_custom_in_channels(arch): m = models.__dict__[arch](in_channels=5, pretrained="imagenet") @@ -82,11 +90,13 @@ def test_dilation(arch, output_stride): W, H = INP.shape[-2:] assert res.shape[-2:] == (W // output_stride, H // output_stride) + @pytest.mark.parametrize("arch", EFFNET_NAMES[:2] + RESNET_NAMES[:2]) def test_drop_connect(arch): m = models.__dict__[arch](drop_connect_rate=0.2) _test_forward(m) + NUM_PARAMS = { "tresnetm": 31389032, "tresnetl": 55989256, @@ -96,8 +106,16 @@ def test_drop_connect(arch): "efficientnet_b2": 9109994, "efficientnet_b3": 12233232, } -@pytest.mark.parametrize('name_num_params', zip(NUM_PARAMS.items())) + + +@pytest.mark.parametrize("name_num_params", zip(NUM_PARAMS.items())) def test_num_parameters(name_num_params): name, num_params = name_num_params[0] m = models.__dict__[name]() - assert pt.utils.misc.count_parameters(m)[0] == num_params \ No newline at end of file + assert pt.utils.misc.count_parameters(m)[0] == num_params + + +@pytest.mark.parametrize("stem_type", ["", "deep", "space2depth"]) +def 
test_resnet_stem_type(stem_type): + m = models.resnet50(stem_type=stem_type) + _test_forward(m) diff --git a/tests/models/test_weights.py b/tests/models/test_weights.py index 7687dac..c68cd72 100644 --- a/tests/models/test_weights.py +++ b/tests/models/test_weights.py @@ -18,8 +18,8 @@ # tests are made to be run from root project directory # format "imagenet_image_class: PIL Image" IMGS = { - 560: Image.open("tests/models/imgs/helmet.jpeg"), - 207: Image.open("tests/models/imgs/dog.jpg"), + 560: Image.open("tests/imgs/helmet.jpeg"), + 207: Image.open("tests/imgs/dog.jpg"), } # временная заглушка. TODO: убрать @@ -55,6 +55,7 @@ def test_imagenet_pretrain(arch): pred_cls = m(im).argmax() assert pred_cls == im_cls + # test that output mean for fixed input is the same MODEL_NAMES2 = [ "resnet34", @@ -68,6 +69,7 @@ def test_imagenet_pretrain(arch): "efficientnet_b0": 0.0070, } + @pytest.mark.parametrize("arch", MODEL_NAMES2) def test_output_mean(arch): m = models.__dict__[arch](pretrained="imagenet") @@ -75,4 +77,4 @@ def test_output_mean(arch): inp = torch.ones(1, 3, 256, 256) with torch.no_grad(): out = m(inp).mean().numpy() - assert np.allclose(out, MODEL_MEAN[arch], rtol=1e-4, atol=1e-4) \ No newline at end of file + assert np.allclose(out, MODEL_MEAN[arch], rtol=1e-4, atol=1e-4) diff --git a/tests/modules/test_activations.py b/tests/modules/test_activations.py deleted file mode 100644 index 1f2a9a9..0000000 --- a/tests/modules/test_activations.py +++ /dev/null @@ -1,19 +0,0 @@ -import torch -import pytest -import pytorch_tools.modules as modules - -activations_name = ["Swish", "Swish_Naive", "Mish", "Mish_naive"] - - -@pytest.mark.parametrize("activation", activations_name) -def test_activations_init(activation): - inp = torch.ones(10) - act = modules.activation_from_name(activation) - res = act(inp) - assert res.mean() - -def test_frozen_abn(): - l = modules.bn_from_name("frozen_abn")(10) - assert list(l.parameters()) == [] - l = modules.ABN(10, frozen=True) - assert list(l.parameters()) == [] \ No newline at end of file diff --git a/tests/modules/test_modules.py b/tests/modules/test_modules.py new file mode 100644 index 0000000..1be5837 --- /dev/null +++ b/tests/modules/test_modules.py @@ -0,0 +1,31 @@ +import torch +import pytest +import pytorch_tools as pt +import pytorch_tools.modules as modules + +activations_name = ["Swish", "Swish_Naive", "Mish", "Mish_naive"] + + +@pytest.mark.parametrize("activation", activations_name) +def test_activations_init(activation): + inp = torch.ones(10) + act = modules.activation_from_name(activation) + res = act(inp) + assert res.mean() + + +def test_frozen_abn(): + l = modules.bn_from_name("frozen_abn")(10) + assert list(l.parameters()) == [] + l = modules.ABN(10, frozen=True) + assert list(l.parameters()) == [] + + +# need to test and resnet and vgg because in resnet there are no Convs with bias +# and in VGG there are no Convs without bias +@pytest.mark.parametrize("norm_layer", ["abn", "agn"]) +@pytest.mark.parametrize("arch", ["resnet18", "vgg11_bn"]) +def test_weight_standardization(norm_layer, arch): + m = pt.models.__dict__[arch](norm_layer=norm_layer) + ws_m = modules.weight_standartization.conv_to_ws_conv(m) + out = ws_m(torch.ones(2, 3, 224, 224)) diff --git a/tests/segmentation_models/test_segm_models.py b/tests/segmentation_models/test_segm_models.py index d454815..cb6378d 100644 --- a/tests/segmentation_models/test_segm_models.py +++ b/tests/segmentation_models/test_segm_models.py @@ -7,12 +7,13 @@ INP = torch.ones(2, 3, 64, 64) ENCODERS 
= ["resnet34", "se_resnet50", "efficientnet_b1", "densenet121"] -SEGM_ARCHS = [pt_sm.Unet, pt_sm.Linknet, pt_sm.DeepLabV3, pt_sm.SegmentationFPN, pt_sm.SegmentationBiFPN] +SEGM_ARCHS = [pt_sm.Unet, pt_sm.Linknet, pt_sm.DeepLabV3, pt_sm.SegmentationFPN] # pt_sm.SegmentationBiFPN # this lines are usefull for quick tests # ENCODERS = ["se_resnet50"] # SEGM_ARCHS = [pt_sm.SegmentationFPN, pt_sm.SegmentationFPN] + def _test_forward(model): with torch.no_grad(): return model(INP) @@ -47,21 +48,24 @@ def test_num_classes(encoder_name, model_class): out = _test_forward(m) assert out.size(1) == 5 + @pytest.mark.parametrize("encoder_name", ENCODERS) @pytest.mark.parametrize("model_class", SEGM_ARCHS) def test_drop_rate(encoder_name, model_class): m = model_class(encoder_name=encoder_name, drop_rate=0.2) _test_forward(m) + @pytest.mark.parametrize("encoder_name", ENCODERS) @pytest.mark.parametrize("model_class", [pt_sm.DeepLabV3]) # pt_sm.Unet, pt_sm.Linknet @pytest.mark.parametrize("output_stride", [32, 16, 8]) def test_dilation(encoder_name, model_class, output_stride): if output_stride == 8 and model_class != pt_sm.DeepLabV3: - return None # OS=8 only supported for Deeplab + return None # OS=8 only supported for Deeplab m = model_class(encoder_name=encoder_name, output_stride=output_stride) _test_forward(m) + @pytest.mark.parametrize("model_class", [pt_sm.DeepLabV3, pt_sm.SegmentationFPN]) def test_deeplab_last_upsample(model_class): m = model_class(last_upsample=True) @@ -74,7 +78,8 @@ def test_deeplab_last_upsample(model_class): # should be 4 times smaller assert tuple(out.shape[-2:]) == (W // 4, H // 4) + @pytest.mark.parametrize("merge_policy", ["add", "cat"]) def test_merge_policy(merge_policy): m = pt_sm.SegmentationFPN(merge_policy=merge_policy) - _test_forward(m) \ No newline at end of file + _test_forward(m) diff --git a/tests/utils/test_utils.py b/tests/utils/test_utils.py new file mode 100644 index 0000000..e3257f4 --- /dev/null +++ b/tests/utils/test_utils.py @@ -0,0 +1,135 @@ +import torch +import pytest +import pytorch_tools as pt + + +def random_boxes(mean_box, stdev, N): + return torch.rand(N, 4) * stdev + torch.tensor(mean_box, dtype=torch.float) + + +# fmt: off +DEVICE_DTYPE = [ + ("cpu", torch.float), + ("cuda", torch.float), + ("cuda", torch.half) +] +# fmt: on +# check that it works for all combinations of dtype and device +@pytest.mark.parametrize("device_dtype", DEVICE_DTYPE) +def test_clip_bboxes(device_dtype): + device, dtype = device_dtype + # fmt: off + bboxes = torch.tensor( + [ + [-5, -10, 50, 100], + [10, 15, 20, 25], + ], + device=device, + dtype=dtype, + ) + expected_bboxes = torch.tensor( + [ + [0, 0, 40, 60], + [10, 15, 20, 25], + ], + device=device, + dtype=dtype, + ) + # fmt: on + size = (60, 40) + # test single bbox clip + res1 = pt.utils.box.clip_bboxes(bboxes, size) + assert torch.allclose(res1, expected_bboxes) + # test single bbox clip passing torch.Size + res2 = pt.utils.box.clip_bboxes(bboxes, torch.Size(size)) + assert torch.allclose(res2, expected_bboxes) + + BS = 4 + batch_bboxes = bboxes.unsqueeze(0).expand(BS, -1, -1) + batch_expected = expected_bboxes.unsqueeze(0).expand(BS, -1, -1) + batch_sizes = torch.tensor(size).repeat(BS, 1) + # test batch clipping + res3 = pt.utils.box.clip_bboxes_batch(batch_bboxes.clone(), batch_sizes) + assert torch.allclose(res3, batch_expected) + + # check that even in batch mode we can pass single size + res4 = pt.utils.box.clip_bboxes_batch(batch_bboxes.clone(), torch.tensor(size)) + assert torch.allclose(res4, 
batch_expected) + + jit_clip = torch.jit.script(pt.utils.box.clip_bboxes_batch) + # check that function is JIT script friendly + res5 = jit_clip(batch_bboxes.clone(), batch_sizes) + assert torch.allclose(res5, batch_expected) + + +@pytest.mark.parametrize("device_dtype", DEVICE_DTYPE) +def test_delta2box(device_dtype): + device, dtype = device_dtype + # fmt: off + anchors = torch.tensor( + [ + [ 0., 0., 1., 1.], + [ 0., 0., 1., 1.], + [ 0., 0., 1., 1.], + [ 5., 5., 5., 5.] + ], + device=device, + dtype=dtype, + ) + deltas = torch.tensor( + [ + [ 0., 0., 0., 0.], + [ 1., 1., 1., 1.], + [ 0., 0., 2., -1.], + [ 0.7, -1.9, -0.5, 0.3] + ], + device=device, + dtype=dtype, + ) + # by default we don't expect results to be clipped + expected_res = torch.tensor( + [ + [0.0000, 0.0000, 1.0000, 1.0000], + [0.1409, 0.1409, 2.8591, 2.8591], + [-3.1945, 0.3161, 4.1945, 0.6839], + [5.0000, 5.0000, 5.0000, 5.0000], + ], + device=device, + dtype=dtype, + ) + # fmt: on + res1 = pt.utils.box.delta2box(deltas, anchors) + assert torch.allclose(res1, expected_res, atol=3e-4) + + BS = 4 + batch_anchors = anchors.unsqueeze(0).expand(BS, -1, -1) + batch_deltas = deltas.unsqueeze(0).expand(BS, -1, -1) + batch_expected = expected_res.unsqueeze(0).expand(BS, -1, -1) + + # test applying to batch + res2 = pt.utils.box.delta2box(batch_deltas.clone(), batch_anchors) + assert torch.allclose(res2, batch_expected, atol=3e-4) + + # check that function is JIT script friendly + jit_func = torch.jit.script(pt.utils.box.delta2box) + res3 = jit_func(batch_deltas.clone(), batch_anchors) + assert torch.allclose(res3, batch_expected, atol=3e-4) + + +@pytest.mark.parametrize("device_dtype", DEVICE_DTYPE) +def test_box2delta(device_dtype): + ## this test only checks that encoding and decoding gives the same result + device, dtype = device_dtype + boxes = random_boxes([10, 10, 20, 20], 10, 10).to(device).to(dtype) + anchors = random_boxes([10, 10, 20, 20], 10, 10).to(device).to(dtype) + deltas = pt.utils.box.box2delta(boxes, anchors) + boxes_reconstructed = pt.utils.box.delta2box(deltas, anchors) + atol = 2e-2 if dtype == torch.half else 1e-6 # for fp16 sometimes error is large + assert torch.allclose(boxes, boxes_reconstructed, atol=atol) + + # check that it's jit friendly + jit_box2delta = torch.jit.script(pt.utils.box.box2delta) + jit_delta2box = torch.jit.script(pt.utils.box.delta2box) + deltas2 = jit_box2delta(boxes, anchors) + boxes_reconstructed2 = jit_delta2box(deltas2, anchors) + assert torch.allclose(boxes, boxes_reconstructed2, atol=atol)
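

# A rough end-to-end sketch of how the box utilities above fit together when building training
# targets. Illustrative only: the 512 px image size and class id 17 are arbitrary, the asserts
# assume the default anchor settings, and this is a sketch rather than a test of any model code.
def _box_utils_end_to_end_sketch():
    anchors, num_per_location = pt.utils.box.generate_anchors_boxes(512)
    assert num_per_location == 9 and anchors.shape[-1] == 4
    # one image with a single ground-truth box in `ltrb` format plus its class id
    gt = torch.tensor([[[10.0, 10.0, 100.0, 100.0, 17.0]]])
    cls_t, box_t, mask = pt.utils.box.generate_targets(anchors, gt, num_classes=90)
    assert cls_t.shape == (1, anchors.size(0), 90)  # one-hot classes, background column removed
    assert box_t.shape == (1, anchors.size(0), 4)  # box2delta regression targets
    # each anchor is marked as matched (1), background (0) or ignored (-1)
    assert set(mask.unique().tolist()) <= {1.0, 0.0, -1.0}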