diff --git a/configs/cityscapes_pointrend_deeplabv3_plus.yaml b/configs/cityscapes_pointrend_deeplabv3_plus.yaml new file mode 100644 index 0000000..a818a6b --- /dev/null +++ b/configs/cityscapes_pointrend_deeplabv3_plus.yaml @@ -0,0 +1,23 @@ +DATASET: + NAME: "cityscape" + MEAN: [0.5, 0.5, 0.5] + STD: [0.5, 0.5, 0.5] +TRAIN: + EPOCHS: 400 + BATCH_SIZE: 2 + CROP_SIZE: 768 +TEST: + BATCH_SIZE: 2 + CROP_SIZE: (1024, 2048) +# TEST_MODEL_PATH: trained_models/deeplabv3_plus_xception_segmentron.pth + +SOLVER: + LR: 0.01 + +MODEL: + MODEL_NAME: "PointRend" + BACKBONE: "xception65" + BN_EPS_FOR_ENCODER: 1e-3 + DEEPLABV3_PLUS: + ENABLE_DECODER: False + diff --git a/segmentron/config/config.py b/segmentron/config/config.py index 0fa905c..a78e3b4 100644 --- a/segmentron/config/config.py +++ b/segmentron/config/config.py @@ -80,11 +80,15 @@ def remove_irrelevant_cfg(self): from ..models.model_zoo import MODEL_REGISTRY model_list = MODEL_REGISTRY.get_list() model_list_lower = [x.lower() for x in model_list] - # print('model_list:', model_list) + assert model_name.lower() in model_list_lower, "Expected model name in {}, but received {}"\ .format(model_list, model_name) pop_keys = [] for key in self.MODEL.keys(): + if key.lower() in model_list_lower: + if model_name.lower() == 'pointrend' and \ + key.lower() == self.MODEL.POINTREND.BASEMODEL.lower(): + continue if key.lower() in model_list_lower and key.lower() != model_name.lower(): pop_keys.append(key) for key in pop_keys: diff --git a/segmentron/config/settings.py b/segmentron/config/settings.py index 4a9030f..46769a9 100644 --- a/segmentron/config/settings.py +++ b/segmentron/config/settings.py @@ -174,6 +174,9 @@ cfg.MODEL.CGNET.STAGE2_BLOCK_NUM = 3 cfg.MODEL.CGNET.STAGE3_BLOCK_NUM = 21 +########################## PointRend config ################################## +cfg.MODEL.POINTREND.BASEMODEL = 'DeepLabV3_Plus' + ########################## hrnet config ###################################### cfg.MODEL.HRNET.PRETRAINED_LAYERS = ['*'] cfg.MODEL.HRNET.STEM_INPLANES = 64 diff --git a/segmentron/data/dataloader/pascal_aug.py b/segmentron/data/dataloader/pascal_aug.py index 69509fe..3e6b68b 100644 --- a/segmentron/data/dataloader/pascal_aug.py +++ b/segmentron/data/dataloader/pascal_aug.py @@ -73,6 +73,9 @@ def __getitem__(self, index): img, target = self._sync_transform(img, target) elif self.mode == 'val': img, target = self._val_sync_transform(img, target) + elif self.mode == 'testval': + logging.warn("Use mode of testval, you should set batch size=1") + img, target = self._img_transform(img), self._mask_transform(target) else: raise RuntimeError('unknown mode for dataloader: {}'.format(self.mode)) # general resize, normalize and toTensor diff --git a/segmentron/models/__init__.py b/segmentron/models/__init__.py index 7c5166d..1361794 100644 --- a/segmentron/models/__init__.py +++ b/segmentron/models/__init__.py @@ -25,3 +25,4 @@ from .espnetv2 import ESPNetV2 from .enet import ENet from .edanet import EDANet +from .pointrend import PointRend diff --git a/segmentron/models/backbones/build.py b/segmentron/models/backbones/build.py index 0d60a40..4723572 100644 --- a/segmentron/models/backbones/build.py +++ b/segmentron/models/backbones/build.py @@ -3,6 +3,7 @@ import logging import torch.utils.model_zoo as model_zoo +from ...utils.download import download from ...utils.registry import Registry from ...config import cfg @@ -42,7 +43,14 @@ def load_backbone_pretrained(model, backbone): return else: logging.info('load backbone pretrained model from url..') - msg = model.load_state_dict(model_zoo.load_url(model_urls[backbone]), strict=False) + try: + msg = model.load_state_dict(model_zoo.load_url(model_urls[backbone]), strict=False) + except Exception as e: + logging.warning(e) + logging.info('Use torch download failed, try custom method!') + + msg = model.load_state_dict(torch.load(download(model_urls[backbone], + path=os.path.join(torch.hub._get_torch_home(), 'checkpoints'))), strict=False) logging.info(msg) diff --git a/segmentron/models/pointrend.py b/segmentron/models/pointrend.py new file mode 100644 index 0000000..f57b5a3 --- /dev/null +++ b/segmentron/models/pointrend.py @@ -0,0 +1,166 @@ +import torch +import torch.nn as nn +import torch.nn.functional as F + +from torchvision.models._utils import IntermediateLayerGetter +from .model_zoo import MODEL_REGISTRY +from .segbase import SegBaseModel +from ..config import cfg + + +@MODEL_REGISTRY.register(name='PointRend') +class PointRend(SegBaseModel): + def __init__(self): + super(PointRend, self).__init__(need_backbone=False) + model_name = cfg.MODEL.POINTREND.BASEMODEL + self.backbone = MODEL_REGISTRY.get(model_name)() + + self.head = PointHead(num_classes=self.nclass) + + def forward(self, x): + c1, _, _, c4 = self.backbone.encoder(x) + + out = self.backbone.head(c4, c1) + + result = {'res2': c1, 'coarse': out} + result.update(self.head(x, result["res2"], result["coarse"])) + if not self.training: + return (result['fine'],) + return result + + +class PointHead(nn.Module): + def __init__(self, in_c=275, num_classes=19, k=3, beta=0.75): + super().__init__() + self.mlp = nn.Conv1d(in_c, num_classes, 1) + self.k = k + self.beta = beta + + def forward(self, x, res2, out): + """ + 1. Fine-grained features are interpolated from res2 for DeeplabV3 + 2. During training we sample as many points as there are on a stride 16 feature map of the input + 3. To measure prediction uncertainty + we use the same strategy during training and inference: the difference between the most + confident and second most confident class probabilities. + """ + if not self.training: + return self.inference(x, res2, out) + + points = sampling_points(out, x.shape[-1] // 16, self.k, self.beta) + + coarse = point_sample(out, points, align_corners=False) + fine = point_sample(res2, points, align_corners=False) + + feature_representation = torch.cat([coarse, fine], dim=1) + + rend = self.mlp(feature_representation) + + return {"rend": rend, "points": points} + + @torch.no_grad() + def inference(self, x, res2, out): + """ + During inference, subdivision uses N=8096 + (i.e., the number of points in the stride 16 map of a 1024×2048 image) + """ + num_points = 8096 + + while out.shape[-1] != x.shape[-1]: + out = F.interpolate(out, scale_factor=2, mode="bilinear", align_corners=True) + + points_idx, points = sampling_points(out, num_points, training=self.training) + + coarse = point_sample(out, points, align_corners=False) + fine = point_sample(res2, points, align_corners=False) + + feature_representation = torch.cat([coarse, fine], dim=1) + + rend = self.mlp(feature_representation) + + B, C, H, W = out.shape + points_idx = points_idx.unsqueeze(1).expand(-1, C, -1) + out = (out.reshape(B, C, -1) + .scatter_(2, points_idx, rend) + .view(B, C, H, W)) + + return {"fine": out} + + +def point_sample(input, point_coords, **kwargs): + """ + From Detectron2, point_features.py#19 + A wrapper around :function:`torch.nn.functional.grid_sample` to support 3D point_coords tensors. + Unlike :function:`torch.nn.functional.grid_sample` it assumes `point_coords` to lie inside + [0, 1] x [0, 1] square. + Args: + input (Tensor): A tensor of shape (N, C, H, W) that contains features map on a H x W grid. + point_coords (Tensor): A tensor of shape (N, P, 2) or (N, Hgrid, Wgrid, 2) that contains + [0, 1] x [0, 1] normalized point coordinates. + Returns: + output (Tensor): A tensor of shape (N, C, P) or (N, C, Hgrid, Wgrid) that contains + features for points in `point_coords`. The features are obtained via bilinear + interplation from `input` the same way as :function:`torch.nn.functional.grid_sample`. + """ + add_dim = False + if point_coords.dim() == 3: + add_dim = True + point_coords = point_coords.unsqueeze(2) + output = F.grid_sample(input, 2.0 * point_coords - 1.0)#, **kwargs) + if add_dim: + output = output.squeeze(3) + return output + + +@torch.no_grad() +def sampling_points(mask, N, k=3, beta=0.75, training=True): + """ + Follows 3.1. Point Selection for Inference and Training + In Train:, `The sampling strategy selects N points on a feature map to train on.` + In Inference, `then selects the N most uncertain points` + Args: + mask(Tensor): [B, C, H, W] + N(int): `During training we sample as many points as there are on a stride 16 feature map of the input` + k(int): Over generation multiplier + beta(float): ratio of importance points + training(bool): flag + Return: + selected_point(Tensor) : flattened indexing points [B, num_points, 2] + """ + assert mask.dim() == 4, "Dim must be N(Batch)CHW" + device = mask.device + B, _, H, W = mask.shape + mask, _ = mask.sort(1, descending=True) + + if not training: + H_step, W_step = 1 / H, 1 / W + N = min(H * W, N) + uncertainty_map = -1 * (mask[:, 0] - mask[:, 1]) + _, idx = uncertainty_map.view(B, -1).topk(N, dim=1) + + points = torch.zeros(B, N, 2, dtype=torch.float, device=device) + points[:, :, 0] = W_step / 2.0 + (idx % W).to(torch.float) * W_step + points[:, :, 1] = H_step / 2.0 + (idx // W).to(torch.float) * H_step + return idx, points + + # Official Comment : point_features.py#92 + # It is crucial to calculate uncertanty based on the sampled prediction value for the points. + # Calculating uncertainties of the coarse predictions first and sampling them for points leads + # to worse results. To illustrate the difference: a sampled point between two coarse predictions + # with -1 and 1 logits has 0 logit prediction and therefore 0 uncertainty value, however, if one + # calculates uncertainties for the coarse predictions first (-1 and -1) and sampe it for the + # center point, they will get -1 unceratinty. + + over_generation = torch.rand(B, k * N, 2, device=device) + over_generation_map = point_sample(mask, over_generation, align_corners=False) + + uncertainty_map = -1 * (over_generation_map[:, 0] - over_generation_map[:, 1]) + _, idx = uncertainty_map.topk(int(beta * N), -1) + + shift = (k * N) * torch.arange(B, dtype=torch.long, device=device) + + idx += shift[:, None] + + importance = over_generation.view(-1, 2)[idx.view(-1), :].view(B, int(beta * N), 2) + coverage = torch.rand(B, N - int(beta * N), 2, device=device) + return torch.cat([importance, coverage], 1).to(device) \ No newline at end of file diff --git a/segmentron/solver/loss.py b/segmentron/solver/loss.py index 9da2ee4..c1c5029 100644 --- a/segmentron/solver/loss.py +++ b/segmentron/solver/loss.py @@ -6,6 +6,7 @@ from torch.autograd import Variable from .lovasz_losses import lovasz_softmax +from ..models.pointrend import point_sample from ..data.dataloader import datasets from ..config import cfg @@ -360,6 +361,32 @@ def forward(self, *inputs): return dict(loss=self._aux_forward(*inputs)) +class PointRendLoss(nn.CrossEntropyLoss): + def __init__(self, aux=True, aux_weight=0.2, ignore_index=-1, **kwargs): + super(PointRendLoss, self).__init__(ignore_index=ignore_index) + self.aux = aux + self.aux_weight = aux_weight + self.ignore_index = ignore_index + + def forward(self, *inputs, **kwargs): + result, gt = tuple(inputs) + + pred = F.interpolate(result["coarse"], gt.shape[-2:], mode="bilinear", align_corners=True) + seg_loss = F.cross_entropy(pred, gt, ignore_index=self.ignore_index) + + gt_points = point_sample( + gt.float().unsqueeze(1), + result["points"], + mode="nearest", + align_corners=False + ).squeeze_(1).long() + points_loss = F.cross_entropy(result["rend"], gt_points, ignore_index=self.ignore_index) + + loss = seg_loss + points_loss + + return dict(loss=loss) + + def get_segmentation_loss(model, use_ohem=False, **kwargs): if use_ohem: return MixSoftmaxCrossEntropyOHEMLoss(**kwargs) @@ -373,11 +400,13 @@ def get_segmentation_loss(model, use_ohem=False, **kwargs): logging.info('Use dice loss!') return DiceLoss(**kwargs) - model = model.lower() if model == 'icnet': return ICNetLoss(**kwargs) elif model == 'encnet': return EncNetLoss(**kwargs) + elif model == 'pointrend': + logging.info('Use pointrend loss!') + return PointRendLoss(**kwargs) else: return MixSoftmaxCrossEntropyLoss(**kwargs) diff --git a/segmentron/utils/score.py b/segmentron/utils/score.py index 09cf09d..a1f145b 100644 --- a/segmentron/utils/score.py +++ b/segmentron/utils/score.py @@ -30,10 +30,7 @@ def update(self, preds, labels): """ def reduce_tensor(tensor): - if isinstance(tensor, torch.Tensor): - rt = tensor.clone() - else: - rt = copy.deepcopy(tensor) + rt = tensor.clone() dist.all_reduce(rt, op=dist.ReduceOp.SUM) return rt diff --git a/tools/eval.py b/tools/eval.py index bff9b50..85d1b14 100644 --- a/tools/eval.py +++ b/tools/eval.py @@ -47,9 +47,10 @@ def __init__(self, args): # create network self.model = get_segmentation_model().to(self.device) - if hasattr(self.model, 'encoder') and cfg.MODEL.BN_EPS_FOR_ENCODER: - logging.info('set bn custom eps for bn in encoder: {}'.format(cfg.MODEL.BN_EPS_FOR_ENCODER)) - self.set_batch_norm_attr(self.model.encoder.named_modules(), 'eps', cfg.MODEL.BN_EPS_FOR_ENCODER) + if hasattr(self.model, 'encoder') and hasattr(self.model.encoder, 'named_modules') and \ + cfg.MODEL.BN_EPS_FOR_ENCODER: + logging.info('set bn custom eps for bn in encoder: {}'.format(cfg.MODEL.BN_EPS_FOR_ENCODER)) + self.set_batch_norm_attr(self.model.encoder.named_modules(), 'eps', cfg.MODEL.BN_EPS_FOR_ENCODER) if args.distributed: self.model = nn.parallel.DistributedDataParallel(self.model,