From c5b97562a9ab0b33a007e70389dbd8793161a4b0 Mon Sep 17 00:00:00 2001 From: LikeLy-Journey <56780417+LikeLy-Journey@users.noreply.github.com> Date: Sat, 28 Mar 2020 12:20:09 +0800 Subject: [PATCH] Revert "Add PointRend" --- .../cityscapes_pointrend_deeplabv3_plus.yaml | 23 --- segmentron/config/config.py | 6 +- segmentron/config/settings.py | 3 - segmentron/data/dataloader/pascal_aug.py | 2 - segmentron/data/dataloader/pascal_voc.py | 2 - segmentron/models/__init__.py | 1 - segmentron/models/backbones/build.py | 10 +- segmentron/models/pointrend.py | 166 ------------------ segmentron/solver/loss.py | 31 +--- segmentron/utils/score.py | 5 +- tools/eval.py | 7 +- 11 files changed, 10 insertions(+), 246 deletions(-) delete mode 100644 configs/cityscapes_pointrend_deeplabv3_plus.yaml delete mode 100644 segmentron/models/pointrend.py diff --git a/configs/cityscapes_pointrend_deeplabv3_plus.yaml b/configs/cityscapes_pointrend_deeplabv3_plus.yaml deleted file mode 100644 index a818a6b..0000000 --- a/configs/cityscapes_pointrend_deeplabv3_plus.yaml +++ /dev/null @@ -1,23 +0,0 @@ -DATASET: - NAME: "cityscape" - MEAN: [0.5, 0.5, 0.5] - STD: [0.5, 0.5, 0.5] -TRAIN: - EPOCHS: 400 - BATCH_SIZE: 2 - CROP_SIZE: 768 -TEST: - BATCH_SIZE: 2 - CROP_SIZE: (1024, 2048) -# TEST_MODEL_PATH: trained_models/deeplabv3_plus_xception_segmentron.pth - -SOLVER: - LR: 0.01 - -MODEL: - MODEL_NAME: "PointRend" - BACKBONE: "xception65" - BN_EPS_FOR_ENCODER: 1e-3 - DEEPLABV3_PLUS: - ENABLE_DECODER: False - diff --git a/segmentron/config/config.py b/segmentron/config/config.py index a78e3b4..0fa905c 100644 --- a/segmentron/config/config.py +++ b/segmentron/config/config.py @@ -80,15 +80,11 @@ def remove_irrelevant_cfg(self): from ..models.model_zoo import MODEL_REGISTRY model_list = MODEL_REGISTRY.get_list() model_list_lower = [x.lower() for x in model_list] - + # print('model_list:', model_list) assert model_name.lower() in model_list_lower, "Expected model name in {}, but received {}"\ .format(model_list, model_name) pop_keys = [] for key in self.MODEL.keys(): - if key.lower() in model_list_lower: - if model_name.lower() == 'pointrend' and \ - key.lower() == self.MODEL.POINTREND.BASEMODEL.lower(): - continue if key.lower() in model_list_lower and key.lower() != model_name.lower(): pop_keys.append(key) for key in pop_keys: diff --git a/segmentron/config/settings.py b/segmentron/config/settings.py index 46769a9..4a9030f 100644 --- a/segmentron/config/settings.py +++ b/segmentron/config/settings.py @@ -174,9 +174,6 @@ cfg.MODEL.CGNET.STAGE2_BLOCK_NUM = 3 cfg.MODEL.CGNET.STAGE3_BLOCK_NUM = 21 -########################## PointRend config ################################## -cfg.MODEL.POINTREND.BASEMODEL = 'DeepLabV3_Plus' - ########################## hrnet config ###################################### cfg.MODEL.HRNET.PRETRAINED_LAYERS = ['*'] cfg.MODEL.HRNET.STEM_INPLANES = 64 diff --git a/segmentron/data/dataloader/pascal_aug.py b/segmentron/data/dataloader/pascal_aug.py index 71bc2e6..69509fe 100644 --- a/segmentron/data/dataloader/pascal_aug.py +++ b/segmentron/data/dataloader/pascal_aug.py @@ -73,8 +73,6 @@ def __getitem__(self, index): img, target = self._sync_transform(img, target) elif self.mode == 'val': img, target = self._val_sync_transform(img, target) - elif self.mode == 'testval': - img, target = self._val_sync_transform(img, target) else: raise RuntimeError('unknown mode for dataloader: {}'.format(self.mode)) # general resize, normalize and toTensor diff --git a/segmentron/data/dataloader/pascal_voc.py b/segmentron/data/dataloader/pascal_voc.py index c26b2e9..c184e4d 100644 --- a/segmentron/data/dataloader/pascal_voc.py +++ b/segmentron/data/dataloader/pascal_voc.py @@ -1,7 +1,6 @@ """Pascal VOC Semantic Segmentation Dataset.""" import os import torch -import logging import numpy as np from PIL import Image @@ -85,7 +84,6 @@ def __getitem__(self, index): img, mask = self._val_sync_transform(img, mask) else: assert self.mode == 'testval' - logging.warn("Use mode of testval, you should set batch size=1") img, mask = self._img_transform(img), self._mask_transform(mask) # general resize, normalize and toTensor if self.transform is not None: diff --git a/segmentron/models/__init__.py b/segmentron/models/__init__.py index 1361794..7c5166d 100644 --- a/segmentron/models/__init__.py +++ b/segmentron/models/__init__.py @@ -25,4 +25,3 @@ from .espnetv2 import ESPNetV2 from .enet import ENet from .edanet import EDANet -from .pointrend import PointRend diff --git a/segmentron/models/backbones/build.py b/segmentron/models/backbones/build.py index 4723572..0d60a40 100644 --- a/segmentron/models/backbones/build.py +++ b/segmentron/models/backbones/build.py @@ -3,7 +3,6 @@ import logging import torch.utils.model_zoo as model_zoo -from ...utils.download import download from ...utils.registry import Registry from ...config import cfg @@ -43,14 +42,7 @@ def load_backbone_pretrained(model, backbone): return else: logging.info('load backbone pretrained model from url..') - try: - msg = model.load_state_dict(model_zoo.load_url(model_urls[backbone]), strict=False) - except Exception as e: - logging.warning(e) - logging.info('Use torch download failed, try custom method!') - - msg = model.load_state_dict(torch.load(download(model_urls[backbone], - path=os.path.join(torch.hub._get_torch_home(), 'checkpoints'))), strict=False) + msg = model.load_state_dict(model_zoo.load_url(model_urls[backbone]), strict=False) logging.info(msg) diff --git a/segmentron/models/pointrend.py b/segmentron/models/pointrend.py deleted file mode 100644 index f57b5a3..0000000 --- a/segmentron/models/pointrend.py +++ /dev/null @@ -1,166 +0,0 @@ -import torch -import torch.nn as nn -import torch.nn.functional as F - -from torchvision.models._utils import IntermediateLayerGetter -from .model_zoo import MODEL_REGISTRY -from .segbase import SegBaseModel -from ..config import cfg - - -@MODEL_REGISTRY.register(name='PointRend') -class PointRend(SegBaseModel): - def __init__(self): - super(PointRend, self).__init__(need_backbone=False) - model_name = cfg.MODEL.POINTREND.BASEMODEL - self.backbone = MODEL_REGISTRY.get(model_name)() - - self.head = PointHead(num_classes=self.nclass) - - def forward(self, x): - c1, _, _, c4 = self.backbone.encoder(x) - - out = self.backbone.head(c4, c1) - - result = {'res2': c1, 'coarse': out} - result.update(self.head(x, result["res2"], result["coarse"])) - if not self.training: - return (result['fine'],) - return result - - -class PointHead(nn.Module): - def __init__(self, in_c=275, num_classes=19, k=3, beta=0.75): - super().__init__() - self.mlp = nn.Conv1d(in_c, num_classes, 1) - self.k = k - self.beta = beta - - def forward(self, x, res2, out): - """ - 1. Fine-grained features are interpolated from res2 for DeeplabV3 - 2. During training we sample as many points as there are on a stride 16 feature map of the input - 3. To measure prediction uncertainty - we use the same strategy during training and inference: the difference between the most - confident and second most confident class probabilities. - """ - if not self.training: - return self.inference(x, res2, out) - - points = sampling_points(out, x.shape[-1] // 16, self.k, self.beta) - - coarse = point_sample(out, points, align_corners=False) - fine = point_sample(res2, points, align_corners=False) - - feature_representation = torch.cat([coarse, fine], dim=1) - - rend = self.mlp(feature_representation) - - return {"rend": rend, "points": points} - - @torch.no_grad() - def inference(self, x, res2, out): - """ - During inference, subdivision uses N=8096 - (i.e., the number of points in the stride 16 map of a 1024×2048 image) - """ - num_points = 8096 - - while out.shape[-1] != x.shape[-1]: - out = F.interpolate(out, scale_factor=2, mode="bilinear", align_corners=True) - - points_idx, points = sampling_points(out, num_points, training=self.training) - - coarse = point_sample(out, points, align_corners=False) - fine = point_sample(res2, points, align_corners=False) - - feature_representation = torch.cat([coarse, fine], dim=1) - - rend = self.mlp(feature_representation) - - B, C, H, W = out.shape - points_idx = points_idx.unsqueeze(1).expand(-1, C, -1) - out = (out.reshape(B, C, -1) - .scatter_(2, points_idx, rend) - .view(B, C, H, W)) - - return {"fine": out} - - -def point_sample(input, point_coords, **kwargs): - """ - From Detectron2, point_features.py#19 - A wrapper around :function:`torch.nn.functional.grid_sample` to support 3D point_coords tensors. - Unlike :function:`torch.nn.functional.grid_sample` it assumes `point_coords` to lie inside - [0, 1] x [0, 1] square. - Args: - input (Tensor): A tensor of shape (N, C, H, W) that contains features map on a H x W grid. - point_coords (Tensor): A tensor of shape (N, P, 2) or (N, Hgrid, Wgrid, 2) that contains - [0, 1] x [0, 1] normalized point coordinates. - Returns: - output (Tensor): A tensor of shape (N, C, P) or (N, C, Hgrid, Wgrid) that contains - features for points in `point_coords`. The features are obtained via bilinear - interplation from `input` the same way as :function:`torch.nn.functional.grid_sample`. - """ - add_dim = False - if point_coords.dim() == 3: - add_dim = True - point_coords = point_coords.unsqueeze(2) - output = F.grid_sample(input, 2.0 * point_coords - 1.0)#, **kwargs) - if add_dim: - output = output.squeeze(3) - return output - - -@torch.no_grad() -def sampling_points(mask, N, k=3, beta=0.75, training=True): - """ - Follows 3.1. Point Selection for Inference and Training - In Train:, `The sampling strategy selects N points on a feature map to train on.` - In Inference, `then selects the N most uncertain points` - Args: - mask(Tensor): [B, C, H, W] - N(int): `During training we sample as many points as there are on a stride 16 feature map of the input` - k(int): Over generation multiplier - beta(float): ratio of importance points - training(bool): flag - Return: - selected_point(Tensor) : flattened indexing points [B, num_points, 2] - """ - assert mask.dim() == 4, "Dim must be N(Batch)CHW" - device = mask.device - B, _, H, W = mask.shape - mask, _ = mask.sort(1, descending=True) - - if not training: - H_step, W_step = 1 / H, 1 / W - N = min(H * W, N) - uncertainty_map = -1 * (mask[:, 0] - mask[:, 1]) - _, idx = uncertainty_map.view(B, -1).topk(N, dim=1) - - points = torch.zeros(B, N, 2, dtype=torch.float, device=device) - points[:, :, 0] = W_step / 2.0 + (idx % W).to(torch.float) * W_step - points[:, :, 1] = H_step / 2.0 + (idx // W).to(torch.float) * H_step - return idx, points - - # Official Comment : point_features.py#92 - # It is crucial to calculate uncertanty based on the sampled prediction value for the points. - # Calculating uncertainties of the coarse predictions first and sampling them for points leads - # to worse results. To illustrate the difference: a sampled point between two coarse predictions - # with -1 and 1 logits has 0 logit prediction and therefore 0 uncertainty value, however, if one - # calculates uncertainties for the coarse predictions first (-1 and -1) and sampe it for the - # center point, they will get -1 unceratinty. - - over_generation = torch.rand(B, k * N, 2, device=device) - over_generation_map = point_sample(mask, over_generation, align_corners=False) - - uncertainty_map = -1 * (over_generation_map[:, 0] - over_generation_map[:, 1]) - _, idx = uncertainty_map.topk(int(beta * N), -1) - - shift = (k * N) * torch.arange(B, dtype=torch.long, device=device) - - idx += shift[:, None] - - importance = over_generation.view(-1, 2)[idx.view(-1), :].view(B, int(beta * N), 2) - coverage = torch.rand(B, N - int(beta * N), 2, device=device) - return torch.cat([importance, coverage], 1).to(device) \ No newline at end of file diff --git a/segmentron/solver/loss.py b/segmentron/solver/loss.py index c1c5029..9da2ee4 100644 --- a/segmentron/solver/loss.py +++ b/segmentron/solver/loss.py @@ -6,7 +6,6 @@ from torch.autograd import Variable from .lovasz_losses import lovasz_softmax -from ..models.pointrend import point_sample from ..data.dataloader import datasets from ..config import cfg @@ -361,32 +360,6 @@ def forward(self, *inputs): return dict(loss=self._aux_forward(*inputs)) -class PointRendLoss(nn.CrossEntropyLoss): - def __init__(self, aux=True, aux_weight=0.2, ignore_index=-1, **kwargs): - super(PointRendLoss, self).__init__(ignore_index=ignore_index) - self.aux = aux - self.aux_weight = aux_weight - self.ignore_index = ignore_index - - def forward(self, *inputs, **kwargs): - result, gt = tuple(inputs) - - pred = F.interpolate(result["coarse"], gt.shape[-2:], mode="bilinear", align_corners=True) - seg_loss = F.cross_entropy(pred, gt, ignore_index=self.ignore_index) - - gt_points = point_sample( - gt.float().unsqueeze(1), - result["points"], - mode="nearest", - align_corners=False - ).squeeze_(1).long() - points_loss = F.cross_entropy(result["rend"], gt_points, ignore_index=self.ignore_index) - - loss = seg_loss + points_loss - - return dict(loss=loss) - - def get_segmentation_loss(model, use_ohem=False, **kwargs): if use_ohem: return MixSoftmaxCrossEntropyOHEMLoss(**kwargs) @@ -400,13 +373,11 @@ def get_segmentation_loss(model, use_ohem=False, **kwargs): logging.info('Use dice loss!') return DiceLoss(**kwargs) + model = model.lower() if model == 'icnet': return ICNetLoss(**kwargs) elif model == 'encnet': return EncNetLoss(**kwargs) - elif model == 'pointrend': - logging.info('Use pointrend loss!') - return PointRendLoss(**kwargs) else: return MixSoftmaxCrossEntropyLoss(**kwargs) diff --git a/segmentron/utils/score.py b/segmentron/utils/score.py index a1f145b..09cf09d 100644 --- a/segmentron/utils/score.py +++ b/segmentron/utils/score.py @@ -30,7 +30,10 @@ def update(self, preds, labels): """ def reduce_tensor(tensor): - rt = tensor.clone() + if isinstance(tensor, torch.Tensor): + rt = tensor.clone() + else: + rt = copy.deepcopy(tensor) dist.all_reduce(rt, op=dist.ReduceOp.SUM) return rt diff --git a/tools/eval.py b/tools/eval.py index 85d1b14..bff9b50 100644 --- a/tools/eval.py +++ b/tools/eval.py @@ -47,10 +47,9 @@ def __init__(self, args): # create network self.model = get_segmentation_model().to(self.device) - if hasattr(self.model, 'encoder') and hasattr(self.model.encoder, 'named_modules') and \ - cfg.MODEL.BN_EPS_FOR_ENCODER: - logging.info('set bn custom eps for bn in encoder: {}'.format(cfg.MODEL.BN_EPS_FOR_ENCODER)) - self.set_batch_norm_attr(self.model.encoder.named_modules(), 'eps', cfg.MODEL.BN_EPS_FOR_ENCODER) + if hasattr(self.model, 'encoder') and cfg.MODEL.BN_EPS_FOR_ENCODER: + logging.info('set bn custom eps for bn in encoder: {}'.format(cfg.MODEL.BN_EPS_FOR_ENCODER)) + self.set_batch_norm_attr(self.model.encoder.named_modules(), 'eps', cfg.MODEL.BN_EPS_FOR_ENCODER) if args.distributed: self.model = nn.parallel.DistributedDataParallel(self.model,