Skip to content

Commit

Permalink
Revert "Add PointRend"
Browse files Browse the repository at this point in the history
  • Loading branch information
LikeLy-Journey authored Mar 28, 2020
1 parent 9a0b148 commit c5b9756
Show file tree
Hide file tree
Showing 11 changed files with 10 additions and 246 deletions.
23 changes: 0 additions & 23 deletions configs/cityscapes_pointrend_deeplabv3_plus.yaml

This file was deleted.

6 changes: 1 addition & 5 deletions segmentron/config/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -80,15 +80,11 @@ def remove_irrelevant_cfg(self):
from ..models.model_zoo import MODEL_REGISTRY
model_list = MODEL_REGISTRY.get_list()
model_list_lower = [x.lower() for x in model_list]

# print('model_list:', model_list)
assert model_name.lower() in model_list_lower, "Expected model name in {}, but received {}"\
.format(model_list, model_name)
pop_keys = []
for key in self.MODEL.keys():
if key.lower() in model_list_lower:
if model_name.lower() == 'pointrend' and \
key.lower() == self.MODEL.POINTREND.BASEMODEL.lower():
continue
if key.lower() in model_list_lower and key.lower() != model_name.lower():
pop_keys.append(key)
for key in pop_keys:
Expand Down
3 changes: 0 additions & 3 deletions segmentron/config/settings.py
Original file line number Diff line number Diff line change
Expand Up @@ -174,9 +174,6 @@
cfg.MODEL.CGNET.STAGE2_BLOCK_NUM = 3
cfg.MODEL.CGNET.STAGE3_BLOCK_NUM = 21

########################## PointRend config ##################################
cfg.MODEL.POINTREND.BASEMODEL = 'DeepLabV3_Plus'

########################## hrnet config ######################################
cfg.MODEL.HRNET.PRETRAINED_LAYERS = ['*']
cfg.MODEL.HRNET.STEM_INPLANES = 64
Expand Down
2 changes: 0 additions & 2 deletions segmentron/data/dataloader/pascal_aug.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,8 +73,6 @@ def __getitem__(self, index):
img, target = self._sync_transform(img, target)
elif self.mode == 'val':
img, target = self._val_sync_transform(img, target)
elif self.mode == 'testval':
img, target = self._val_sync_transform(img, target)
else:
raise RuntimeError('unknown mode for dataloader: {}'.format(self.mode))
# general resize, normalize and toTensor
Expand Down
2 changes: 0 additions & 2 deletions segmentron/data/dataloader/pascal_voc.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
"""Pascal VOC Semantic Segmentation Dataset."""
import os
import torch
import logging
import numpy as np

from PIL import Image
Expand Down Expand Up @@ -85,7 +84,6 @@ def __getitem__(self, index):
img, mask = self._val_sync_transform(img, mask)
else:
assert self.mode == 'testval'
logging.warn("Use mode of testval, you should set batch size=1")
img, mask = self._img_transform(img), self._mask_transform(mask)
# general resize, normalize and toTensor
if self.transform is not None:
Expand Down
1 change: 0 additions & 1 deletion segmentron/models/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,4 +25,3 @@
from .espnetv2 import ESPNetV2
from .enet import ENet
from .edanet import EDANet
from .pointrend import PointRend
10 changes: 1 addition & 9 deletions segmentron/models/backbones/build.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@
import logging
import torch.utils.model_zoo as model_zoo

from ...utils.download import download
from ...utils.registry import Registry
from ...config import cfg

Expand Down Expand Up @@ -43,14 +42,7 @@ def load_backbone_pretrained(model, backbone):
return
else:
logging.info('load backbone pretrained model from url..')
try:
msg = model.load_state_dict(model_zoo.load_url(model_urls[backbone]), strict=False)
except Exception as e:
logging.warning(e)
logging.info('Use torch download failed, try custom method!')

msg = model.load_state_dict(torch.load(download(model_urls[backbone],
path=os.path.join(torch.hub._get_torch_home(), 'checkpoints'))), strict=False)
msg = model.load_state_dict(model_zoo.load_url(model_urls[backbone]), strict=False)
logging.info(msg)


Expand Down
166 changes: 0 additions & 166 deletions segmentron/models/pointrend.py

This file was deleted.

31 changes: 1 addition & 30 deletions segmentron/solver/loss.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,6 @@

from torch.autograd import Variable
from .lovasz_losses import lovasz_softmax
from ..models.pointrend import point_sample
from ..data.dataloader import datasets
from ..config import cfg

Expand Down Expand Up @@ -361,32 +360,6 @@ def forward(self, *inputs):
return dict(loss=self._aux_forward(*inputs))


class PointRendLoss(nn.CrossEntropyLoss):
def __init__(self, aux=True, aux_weight=0.2, ignore_index=-1, **kwargs):
super(PointRendLoss, self).__init__(ignore_index=ignore_index)
self.aux = aux
self.aux_weight = aux_weight
self.ignore_index = ignore_index

def forward(self, *inputs, **kwargs):
result, gt = tuple(inputs)

pred = F.interpolate(result["coarse"], gt.shape[-2:], mode="bilinear", align_corners=True)
seg_loss = F.cross_entropy(pred, gt, ignore_index=self.ignore_index)

gt_points = point_sample(
gt.float().unsqueeze(1),
result["points"],
mode="nearest",
align_corners=False
).squeeze_(1).long()
points_loss = F.cross_entropy(result["rend"], gt_points, ignore_index=self.ignore_index)

loss = seg_loss + points_loss

return dict(loss=loss)


def get_segmentation_loss(model, use_ohem=False, **kwargs):
if use_ohem:
return MixSoftmaxCrossEntropyOHEMLoss(**kwargs)
Expand All @@ -400,13 +373,11 @@ def get_segmentation_loss(model, use_ohem=False, **kwargs):
logging.info('Use dice loss!')
return DiceLoss(**kwargs)


model = model.lower()
if model == 'icnet':
return ICNetLoss(**kwargs)
elif model == 'encnet':
return EncNetLoss(**kwargs)
elif model == 'pointrend':
logging.info('Use pointrend loss!')
return PointRendLoss(**kwargs)
else:
return MixSoftmaxCrossEntropyLoss(**kwargs)
5 changes: 4 additions & 1 deletion segmentron/utils/score.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,10 @@ def update(self, preds, labels):
"""

def reduce_tensor(tensor):
rt = tensor.clone()
if isinstance(tensor, torch.Tensor):
rt = tensor.clone()
else:
rt = copy.deepcopy(tensor)
dist.all_reduce(rt, op=dist.ReduceOp.SUM)
return rt

Expand Down
7 changes: 3 additions & 4 deletions tools/eval.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,10 +47,9 @@ def __init__(self, args):
# create network
self.model = get_segmentation_model().to(self.device)

if hasattr(self.model, 'encoder') and hasattr(self.model.encoder, 'named_modules') and \
cfg.MODEL.BN_EPS_FOR_ENCODER:
logging.info('set bn custom eps for bn in encoder: {}'.format(cfg.MODEL.BN_EPS_FOR_ENCODER))
self.set_batch_norm_attr(self.model.encoder.named_modules(), 'eps', cfg.MODEL.BN_EPS_FOR_ENCODER)
if hasattr(self.model, 'encoder') and cfg.MODEL.BN_EPS_FOR_ENCODER:
logging.info('set bn custom eps for bn in encoder: {}'.format(cfg.MODEL.BN_EPS_FOR_ENCODER))
self.set_batch_norm_attr(self.model.encoder.named_modules(), 'eps', cfg.MODEL.BN_EPS_FOR_ENCODER)

if args.distributed:
self.model = nn.parallel.DistributedDataParallel(self.model,
Expand Down

0 comments on commit c5b9756

Please sign in to comment.