Merge pull request #34 from LikeLy-Journey/PointRend
Add PointRend
LikeLy-Journey authored Mar 28, 2020
2 parents 7e6fbd2 + a2108d3 commit ad44e77
Showing 10 changed files with 245 additions and 10 deletions.
23 changes: 23 additions & 0 deletions configs/cityscapes_pointrend_deeplabv3_plus.yaml
@@ -0,0 +1,23 @@
DATASET:
  NAME: "cityscape"
  MEAN: [0.5, 0.5, 0.5]
  STD: [0.5, 0.5, 0.5]
TRAIN:
  EPOCHS: 400
  BATCH_SIZE: 2
  CROP_SIZE: 768
TEST:
  BATCH_SIZE: 2
  CROP_SIZE: (1024, 2048)
  # TEST_MODEL_PATH: trained_models/deeplabv3_plus_xception_segmentron.pth

SOLVER:
  LR: 0.01

MODEL:
  MODEL_NAME: "PointRend"
  BACKBONE: "xception65"
  BN_EPS_FOR_ENCODER: 1e-3
  DEEPLABV3_PLUS:
    ENABLE_DECODER: False

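For orientation, a minimal sketch of loading this config through the repo's global cfg object; update_from_file is the loader used by the tools scripts, so treat the exact call as an assumption:

from segmentron.config import cfg

# hypothetical usage sketch: load the YAML above into the global config
cfg.update_from_file('configs/cityscapes_pointrend_deeplabv3_plus.yaml')
print(cfg.MODEL.MODEL_NAME)            # 'PointRend'
print(cfg.MODEL.POINTREND.BASEMODEL)   # 'DeepLabV3_Plus' (default added in settings.py below)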
6 changes: 5 additions & 1 deletion segmentron/config/config.py
@@ -80,11 +80,15 @@ def remove_irrelevant_cfg(self):
from ..models.model_zoo import MODEL_REGISTRY
model_list = MODEL_REGISTRY.get_list()
model_list_lower = [x.lower() for x in model_list]

assert model_name.lower() in model_list_lower, "Expected model name in {}, but received {}"\
.format(model_list, model_name)
pop_keys = []
for key in self.MODEL.keys():
if key.lower() in model_list_lower:
if model_name.lower() == 'pointrend' and \
key.lower() == self.MODEL.POINTREND.BASEMODEL.lower():
continue
if key.lower() in model_list_lower and key.lower() != model_name.lower():
pop_keys.append(key)
for key in pop_keys:
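The net effect of this hunk: when MODEL_NAME is PointRend, the base model's sub-config is no longer pruned away. A standalone sketch of the rule (names illustrative, not the repo's API):

def prune_model_cfg(model_cfg_keys, registered_models, model_name, base_model):
    """Return the MODEL sub-config keys that survive pruning."""
    registered = [m.lower() for m in registered_models]
    keep = []
    for key in model_cfg_keys:
        # PointRend wraps a base model, so keep the base model's sub-config too
        if model_name.lower() == 'pointrend' and key.lower() == base_model.lower():
            keep.append(key)
        elif key.lower() in registered and key.lower() != model_name.lower():
            continue  # sub-config of some other registered model: drop it
        else:
            keep.append(key)
    return keep

print(prune_model_cfg(['POINTREND', 'DEEPLABV3_PLUS', 'HRNET'],
                      ['PointRend', 'DeepLabV3_Plus', 'HRNET'],
                      'PointRend', 'DeepLabV3_Plus'))
# -> ['POINTREND', 'DEEPLABV3_PLUS']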
3 changes: 3 additions & 0 deletions segmentron/config/settings.py
@@ -174,6 +174,9 @@
cfg.MODEL.CGNET.STAGE2_BLOCK_NUM = 3
cfg.MODEL.CGNET.STAGE3_BLOCK_NUM = 21

########################## PointRend config ##################################
cfg.MODEL.POINTREND.BASEMODEL = 'DeepLabV3_Plus'

########################## hrnet config ######################################
cfg.MODEL.HRNET.PRETRAINED_LAYERS = ['*']
cfg.MODEL.HRNET.STEM_INPLANES = 64
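To pair PointRend with a different registered base model, this default can be overridden from a config file; a hypothetical fragment in the same YAML style as the config above (the value shown is just the default made explicit):

MODEL:
  MODEL_NAME: "PointRend"
  POINTREND:
    BASEMODEL: "DeepLabV3_Plus"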
3 changes: 3 additions & 0 deletions segmentron/data/dataloader/pascal_aug.py
@@ -73,6 +73,9 @@ def __getitem__(self, index):
img, target = self._sync_transform(img, target)
elif self.mode == 'val':
img, target = self._val_sync_transform(img, target)
elif self.mode == 'testval':
            logging.warning("Using 'testval' mode; batch size should be set to 1")
img, target = self._img_transform(img), self._mask_transform(target)
else:
raise RuntimeError('unknown mode for dataloader: {}'.format(self.mode))
# general resize, normalize and toTensor
1 change: 1 addition & 0 deletions segmentron/models/__init__.py
@@ -25,3 +25,4 @@
from .espnetv2 import ESPNetV2
from .enet import ENet
from .edanet import EDANet
from .pointrend import PointRend
10 changes: 9 additions & 1 deletion segmentron/models/backbones/build.py
@@ -3,6 +3,7 @@
import logging
import torch.utils.model_zoo as model_zoo

from ...utils.download import download
from ...utils.registry import Registry
from ...config import cfg

@@ -42,7 +43,14 @@ def load_backbone_pretrained(model, backbone):
return
else:
logging.info('load backbone pretrained model from url..')
msg = model.load_state_dict(model_zoo.load_url(model_urls[backbone]), strict=False)
try:
msg = model.load_state_dict(model_zoo.load_url(model_urls[backbone]), strict=False)
except Exception as e:
logging.warning(e)
            logging.info('torch download failed, trying custom download method!')

msg = model.load_state_dict(torch.load(download(model_urls[backbone],
path=os.path.join(torch.hub._get_torch_home(), 'checkpoints'))), strict=False)
logging.info(msg)


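Note the fallback references torch and os directly, which assumes both are imported at the top of build.py (outside this hunk). A self-contained sketch of the same retry logic with the imports spelled out; the download argument stands in for the repo's segmentron.utils.download helper:

import logging
import os

import torch
import torch.utils.model_zoo as model_zoo

def load_url_with_fallback(model, url, download):
    # try the standard torch.hub cache first, then fall back to the custom downloader
    try:
        msg = model.load_state_dict(model_zoo.load_url(url), strict=False)
    except Exception as e:
        logging.warning(e)
        logging.info('torch download failed, trying custom download method!')
        ckpt_dir = os.path.join(torch.hub._get_torch_home(), 'checkpoints')
        msg = model.load_state_dict(torch.load(download(url, path=ckpt_dir)), strict=False)
    logging.info(msg)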
166 changes: 166 additions & 0 deletions segmentron/models/pointrend.py
@@ -0,0 +1,166 @@
import torch
import torch.nn as nn
import torch.nn.functional as F

from torchvision.models._utils import IntermediateLayerGetter
from .model_zoo import MODEL_REGISTRY
from .segbase import SegBaseModel
from ..config import cfg


@MODEL_REGISTRY.register(name='PointRend')
class PointRend(SegBaseModel):
def __init__(self):
super(PointRend, self).__init__(need_backbone=False)
model_name = cfg.MODEL.POINTREND.BASEMODEL
self.backbone = MODEL_REGISTRY.get(model_name)()

self.head = PointHead(num_classes=self.nclass)

def forward(self, x):
c1, _, _, c4 = self.backbone.encoder(x)

out = self.backbone.head(c4, c1)

result = {'res2': c1, 'coarse': out}
result.update(self.head(x, result["res2"], result["coarse"]))
if not self.training:
return (result['fine'],)
return result


class PointHead(nn.Module):
def __init__(self, in_c=275, num_classes=19, k=3, beta=0.75):
super().__init__()
self.mlp = nn.Conv1d(in_c, num_classes, 1)
self.k = k
self.beta = beta

def forward(self, x, res2, out):
"""
1. Fine-grained features are interpolated from res2 for DeeplabV3
2. During training we sample as many points as there are on a stride 16 feature map of the input
3. To measure prediction uncertainty
we use the same strategy during training and inference: the difference between the most
confident and second most confident class probabilities.
"""
if not self.training:
return self.inference(x, res2, out)

points = sampling_points(out, x.shape[-1] // 16, self.k, self.beta)

coarse = point_sample(out, points, align_corners=False)
fine = point_sample(res2, points, align_corners=False)

feature_representation = torch.cat([coarse, fine], dim=1)

rend = self.mlp(feature_representation)

return {"rend": rend, "points": points}

@torch.no_grad()
def inference(self, x, res2, out):
"""
During inference, subdivision uses N=8096
(i.e., the number of points in the stride 16 map of a 1024×2048 image)
"""
num_points = 8096

while out.shape[-1] != x.shape[-1]:
out = F.interpolate(out, scale_factor=2, mode="bilinear", align_corners=True)

points_idx, points = sampling_points(out, num_points, training=self.training)

coarse = point_sample(out, points, align_corners=False)
fine = point_sample(res2, points, align_corners=False)

feature_representation = torch.cat([coarse, fine], dim=1)

rend = self.mlp(feature_representation)

B, C, H, W = out.shape
points_idx = points_idx.unsqueeze(1).expand(-1, C, -1)
out = (out.reshape(B, C, -1)
.scatter_(2, points_idx, rend)
.view(B, C, H, W))

return {"fine": out}


def point_sample(input, point_coords, **kwargs):
"""
From Detectron2, point_features.py#19
A wrapper around :function:`torch.nn.functional.grid_sample` to support 3D point_coords tensors.
Unlike :function:`torch.nn.functional.grid_sample` it assumes `point_coords` to lie inside
[0, 1] x [0, 1] square.
Args:
input (Tensor): A tensor of shape (N, C, H, W) that contains features map on a H x W grid.
point_coords (Tensor): A tensor of shape (N, P, 2) or (N, Hgrid, Wgrid, 2) that contains
[0, 1] x [0, 1] normalized point coordinates.
Returns:
output (Tensor): A tensor of shape (N, C, P) or (N, C, Hgrid, Wgrid) that contains
features for points in `point_coords`. The features are obtained via bilinear
interplation from `input` the same way as :function:`torch.nn.functional.grid_sample`.
"""
add_dim = False
if point_coords.dim() == 3:
add_dim = True
point_coords = point_coords.unsqueeze(2)
    output = F.grid_sample(input, 2.0 * point_coords - 1.0, **kwargs)
if add_dim:
output = output.squeeze(3)
return output


@torch.no_grad()
def sampling_points(mask, N, k=3, beta=0.75, training=True):
"""
Follows 3.1. Point Selection for Inference and Training
In Train:, `The sampling strategy selects N points on a feature map to train on.`
In Inference, `then selects the N most uncertain points`
Args:
mask(Tensor): [B, C, H, W]
N(int): `During training we sample as many points as there are on a stride 16 feature map of the input`
k(int): Over generation multiplier
beta(float): ratio of importance points
training(bool): flag
Return:
selected_point(Tensor) : flattened indexing points [B, num_points, 2]
"""
assert mask.dim() == 4, "Dim must be N(Batch)CHW"
device = mask.device
B, _, H, W = mask.shape
mask, _ = mask.sort(1, descending=True)

if not training:
H_step, W_step = 1 / H, 1 / W
N = min(H * W, N)
uncertainty_map = -1 * (mask[:, 0] - mask[:, 1])
_, idx = uncertainty_map.view(B, -1).topk(N, dim=1)

points = torch.zeros(B, N, 2, dtype=torch.float, device=device)
points[:, :, 0] = W_step / 2.0 + (idx % W).to(torch.float) * W_step
points[:, :, 1] = H_step / 2.0 + (idx // W).to(torch.float) * H_step
return idx, points

    # Official Comment : point_features.py#92
    # It is crucial to calculate uncertainty based on the sampled prediction value for the points.
    # Calculating uncertainties of the coarse predictions first and sampling them for points leads
    # to worse results. To illustrate the difference: a sampled point between two coarse predictions
    # with -1 and 1 logits has 0 logit prediction and therefore 0 uncertainty value; however, if one
    # calculates uncertainties for the coarse predictions first (-1 and -1) and samples it for the
    # center point, they will get -1 uncertainty.

over_generation = torch.rand(B, k * N, 2, device=device)
over_generation_map = point_sample(mask, over_generation, align_corners=False)

uncertainty_map = -1 * (over_generation_map[:, 0] - over_generation_map[:, 1])
_, idx = uncertainty_map.topk(int(beta * N), -1)

shift = (k * N) * torch.arange(B, dtype=torch.long, device=device)

idx += shift[:, None]

importance = over_generation.view(-1, 2)[idx.view(-1), :].view(B, int(beta * N), 2)
coverage = torch.rand(B, N - int(beta * N), 2, device=device)
return torch.cat([importance, coverage], 1).to(device)
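A minimal smoke test for PointHead with random tensors; the shapes assume the default in_c=275 (19 coarse classes + 256 res2 channels) and are illustrative rather than tied to a real backbone:

import torch
from segmentron.models.pointrend import PointHead

head = PointHead(in_c=275, num_classes=19)
x = torch.randn(2, 3, 512, 512)        # input image
res2 = torch.randn(2, 256, 128, 128)   # fine-grained features (stride 4)
coarse = torch.randn(2, 19, 32, 32)    # coarse logits (stride 16)

head.train()
out = head(x, res2, coarse)
print(out['rend'].shape, out['points'].shape)  # [2, 19, 32], [2, 32, 2]

head.eval()
out = head(x, res2, coarse)                    # iterative subdivision up to full size
print(out['fine'].shape)                       # [2, 19, 512, 512]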
31 changes: 30 additions & 1 deletion segmentron/solver/loss.py
@@ -6,6 +6,7 @@

from torch.autograd import Variable
from .lovasz_losses import lovasz_softmax
from ..models.pointrend import point_sample
from ..data.dataloader import datasets
from ..config import cfg

@@ -360,6 +361,32 @@ def forward(self, *inputs):
return dict(loss=self._aux_forward(*inputs))


class PointRendLoss(nn.CrossEntropyLoss):
def __init__(self, aux=True, aux_weight=0.2, ignore_index=-1, **kwargs):
super(PointRendLoss, self).__init__(ignore_index=ignore_index)
self.aux = aux
self.aux_weight = aux_weight
self.ignore_index = ignore_index

def forward(self, *inputs, **kwargs):
result, gt = tuple(inputs)

pred = F.interpolate(result["coarse"], gt.shape[-2:], mode="bilinear", align_corners=True)
seg_loss = F.cross_entropy(pred, gt, ignore_index=self.ignore_index)

gt_points = point_sample(
gt.float().unsqueeze(1),
result["points"],
mode="nearest",
align_corners=False
).squeeze_(1).long()
points_loss = F.cross_entropy(result["rend"], gt_points, ignore_index=self.ignore_index)

loss = seg_loss + points_loss

return dict(loss=loss)


def get_segmentation_loss(model, use_ohem=False, **kwargs):
if use_ohem:
return MixSoftmaxCrossEntropyOHEMLoss(**kwargs)
@@ -373,11 +400,13 @@ def get_segmentation_loss(model, use_ohem=False, **kwargs):
logging.info('Use dice loss!')
return DiceLoss(**kwargs)


model = model.lower()
if model == 'icnet':
return ICNetLoss(**kwargs)
elif model == 'encnet':
return EncNetLoss(**kwargs)
elif model == 'pointrend':
logging.info('Use pointrend loss!')
return PointRendLoss(**kwargs)
else:
return MixSoftmaxCrossEntropyLoss(**kwargs)
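A quick sketch of feeding PointRend's training output to this loss (dummy tensors; in real training the point coordinates come from sampling_points, and 'rend'/'points' must agree on the point count):

import torch
from segmentron.solver.loss import PointRendLoss

criterion = PointRendLoss(ignore_index=-1)
result = {
    'coarse': torch.randn(2, 19, 32, 32),  # coarse logits (stride 16)
    'rend':   torch.randn(2, 19, 48),      # per-point logits
    'points': torch.rand(2, 48, 2),        # normalized coords in [0, 1]
}
gt = torch.randint(0, 19, (2, 512, 512))   # dummy labels
print(criterion(result, gt))               # {'loss': tensor(...)}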
5 changes: 1 addition & 4 deletions segmentron/utils/score.py
@@ -30,10 +30,7 @@ def update(self, preds, labels):
"""

def reduce_tensor(tensor):
if isinstance(tensor, torch.Tensor):
rt = tensor.clone()
else:
rt = copy.deepcopy(tensor)
rt = tensor.clone()
dist.all_reduce(rt, op=dist.ReduceOp.SUM)
return rt

7 changes: 4 additions & 3 deletions tools/eval.py
@@ -47,9 +47,10 @@ def __init__(self, args):
# create network
self.model = get_segmentation_model().to(self.device)

if hasattr(self.model, 'encoder') and cfg.MODEL.BN_EPS_FOR_ENCODER:
logging.info('set bn custom eps for bn in encoder: {}'.format(cfg.MODEL.BN_EPS_FOR_ENCODER))
self.set_batch_norm_attr(self.model.encoder.named_modules(), 'eps', cfg.MODEL.BN_EPS_FOR_ENCODER)
if hasattr(self.model, 'encoder') and hasattr(self.model.encoder, 'named_modules') and \
cfg.MODEL.BN_EPS_FOR_ENCODER:
logging.info('set bn custom eps for bn in encoder: {}'.format(cfg.MODEL.BN_EPS_FOR_ENCODER))
self.set_batch_norm_attr(self.model.encoder.named_modules(), 'eps', cfg.MODEL.BN_EPS_FOR_ENCODER)

if args.distributed:
self.model = nn.parallel.DistributedDataParallel(self.model,
