Skip to content

Commit

Permalink
Generalized OHEM (#54)
Browse files Browse the repository at this point in the history
* Generalized OHEM

* remove config

* update docstring

* fixed sort prob

* fixed valid_mask
  • Loading branch information
xvjiarui authored Aug 9, 2020
1 parent 00f56eb commit 5e264c6
Show file tree
Hide file tree
Showing 5 changed files with 62 additions and 46 deletions.
23 changes: 5 additions & 18 deletions docs/getting_started.md
Original file line number Diff line number Diff line change
Expand Up @@ -271,36 +271,23 @@ Usually it is slow if you do not have high speed networking like InfiniBand.
### Launch multiple jobs on a single machine

If you launch multiple jobs on a single machine, e.g., 2 jobs of 4-GPU training on a machine with 8 GPUs,
you need to specify different ports (29500 by default) for each job to avoid communication conflict.
you need to specify different ports (29500 by default) for each job to avoid communication conflict. Otherwise, there will be an error message saying `RuntimeError: Address already in use`.
If you use `dist_train.sh` to launch training jobs, you can set the port in commands.
If you use `dist_train.sh` to launch training jobs, you can set the port in commands with environment variable `PORT`.
```shell
CUDA_VISIBLE_DEVICES=0,1,2,3 PORT=29500 ./tools/dist_train.sh ${CONFIG_FILE} 4
CUDA_VISIBLE_DEVICES=4,5,6,7 PORT=29501 ./tools/dist_train.sh ${CONFIG_FILE} 4
```
If you launch training jobs with Slurm, you need to modify the config files (usually the 6th line from the bottom in config files) to set different communication ports.
If you use `slurm_train.sh` to launch training jobs, you can set the port in commands with environment variable `MASTER_PORT`.
In `config1.py`,
```python
dist_params = dict(backend='nccl', port=29500)
```

In `config2.py`,
```python
dist_params = dict(backend='nccl', port=29501)
```

Then you can launch two jobs with `config1.py` and `config2.py`.
```shell
CUDA_VISIBLE_DEVICES=0,1,2,3 GPUS=4 ./tools/slurm_train.sh ${PARTITION} ${JOB_NAME} config1.py ${WORK_DIR}
CUDA_VISIBLE_DEVICES=4,5,6,7 GPUS=4 ./tools/slurm_train.sh ${PARTITION} ${JOB_NAME} config2.py ${WORK_DIR}
MASTER_PORT=29500 ./tools/slurm_train.sh ${PARTITION} ${JOB_NAME} ${CONFIG_FILE}
MASTER_PORT=29501 ./tools/slurm_train.sh ${PARTITION} ${JOB_NAME} ${CONFIG_FILE}
```
Or you could specify the port with `--options dist_params.port=29501`.

## Useful tools
We provide lots of useful tools under `tools/` directory.
Expand Down
2 changes: 1 addition & 1 deletion docs/tutorials/training_tricks.md
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ model=dict(
decode_head=dict(
sampler=dict(type='OHEMPixelSampler', thresh=0.7, min_kept=100000)) )
```
In this way, only pixels with confidence score under 0.7 are used to train. And we keep at least 100000 pixels during training.
In this way, only pixels with a confidence score under 0.7 are used for training, and at least 100000 pixels are kept during training. If `thresh` is not specified, the `min_kept` pixels with the highest loss will be selected.

## Class Balanced Loss
For a dataset whose class distribution is not balanced, you may change the loss weight of each class.
Expand Down
60 changes: 36 additions & 24 deletions mmseg/core/seg/sampler/ohem_pixel_sampler.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,55 +10,67 @@ class OHEMPixelSampler(BasePixelSampler):
"""Online Hard Example Mining Sampler for segmentation.
Args:
thresh (float): The threshold for hard example selection. Below
which, are prediction with low confidence. Default: 0.7.
min_kept (int): The minimum number of predictions to keep.
context (nn.Module): The context of sampler, subclass of
:obj:`BaseDecodeHead`.
thresh (float, optional): The threshold for hard example selection.
Below which, are prediction with low confidence. If not
specified, the hard examples will be pixels of top ``min_kept``
loss. Default: None.
min_kept (int, optional): The minimum number of predictions to keep.
Default: 100000.
ignore_index (int): The ignore index for training. Default: 255.
"""

def __init__(self, context, thresh=None, min_kept=100000):
    """Store the sampling configuration.

    Args:
        context: Object that owns this sampler. ``sample`` reads
            ``context.ignore_index`` and, when ``thresh`` is None, calls
            ``context.loss_decode``.
        thresh (float, optional): Confidence threshold for hard example
            selection. If None, the pixels with the highest loss are
            selected instead. Default: None.
        min_kept (int, optional): The minimum number of predictions to
            keep per image. Default: 100000.
    """
    super(OHEMPixelSampler, self).__init__()
    self.context = context
    # Kept as an assert so that callers relying on AssertionError for an
    # invalid ``min_kept`` keep working.
    assert min_kept > 1
    self.thresh = thresh
    self.min_kept = min_kept

def sample(self, seg_logit, seg_label):
    """Sample pixels that have high loss or low prediction confidence.

    Args:
        seg_logit (torch.Tensor): segmentation logits, shape (N, C, H, W)
        seg_label (torch.Tensor): segmentation label, shape (N, 1, H, W)

    Returns:
        torch.Tensor: segmentation weight, shape (N, H, W). Selected
            pixels have weight 1; all other pixels, including those
            labelled ``context.ignore_index``, have weight 0.

    Raises:
        AssertionError: If the spatial sizes of ``seg_logit`` and
            ``seg_label`` differ, or ``seg_label`` does not have exactly
            one channel.
    """
    with torch.no_grad():
        assert seg_logit.shape[2:] == seg_label.shape[2:]
        assert seg_label.shape[1] == 1
        seg_label = seg_label.squeeze(1).long()
        # ``min_kept`` is per image, so scale it by the batch size.
        batch_kept = self.min_kept * seg_label.size(0)
        valid_mask = seg_label != self.context.ignore_index
        seg_weight = seg_logit.new_zeros(size=seg_label.size())
        # Boolean-mask indexing returns a copy; it is written back into
        # ``seg_weight`` once the selection is done.
        valid_seg_weight = seg_weight[valid_mask]
        if self.thresh is not None:
            seg_prob = F.softmax(seg_logit, dim=1)

            # ``gather`` needs in-range class indices, so ignored pixels
            # are temporarily mapped to class 0; they are excluded by
            # ``valid_mask`` below.
            tmp_seg_label = seg_label.clone().unsqueeze(1)
            tmp_seg_label[tmp_seg_label == self.context.ignore_index] = 0
            seg_prob = seg_prob.gather(1, tmp_seg_label).squeeze(1)
            sort_prob, _ = seg_prob[valid_mask].sort()

            if sort_prob.numel() > 0:
                # Raise the threshold if needed so that roughly
                # ``batch_kept`` lowest-confidence pixels are selected.
                min_threshold = sort_prob[min(batch_kept,
                                              sort_prob.numel() - 1)]
            else:
                min_threshold = 0.0
            threshold = max(min_threshold, self.thresh)
            valid_seg_weight[seg_prob[valid_mask] < threshold] = 1.
        else:
            losses = self.context.loss_decode(
                seg_logit,
                seg_label,
                weight=None,
                ignore_index=self.context.ignore_index,
                reduction_override='none')
            # faster than topk according to https://github.com/pytorch/pytorch/issues/22812 # noqa
            _, sort_indices = losses[valid_mask].sort(descending=True)
            valid_seg_weight[sort_indices[:batch_kept]] = 1.

        seg_weight[valid_mask] = valid_seg_weight

        return seg_weight
2 changes: 1 addition & 1 deletion mmseg/models/decode_heads/decode_head.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,7 +73,7 @@ def __init__(self,
self.ignore_index = ignore_index
self.align_corners = align_corners
if sampler is not None:
self.sampler = build_pixel_sampler(sampler)
self.sampler = build_pixel_sampler(sampler, context=self)
else:
self.sampler = None

Expand Down
21 changes: 19 additions & 2 deletions tests/test_sampler.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,20 +2,37 @@
import torch

from mmseg.core import OHEMPixelSampler
from mmseg.models.decode_heads import FCNHead


def _context_for_ohem():
    """Build a small FCNHead used as the ``context`` of the sampler."""
    head = FCNHead(in_channels=32, channels=16, num_classes=19)
    return head


def test_ohem_sampler():
    """Check OHEMPixelSampler in threshold mode and in top-loss mode."""
    # seg_logit and seg_label must be of the same size; only the
    # ``sample`` call is expected to raise, so construction stays outside
    # the ``pytest.raises`` block.
    sampler = OHEMPixelSampler(context=_context_for_ohem())
    seg_logit = torch.randn(1, 19, 45, 45)
    seg_label = torch.randint(0, 19, size=(1, 1, 89, 89))
    with pytest.raises(AssertionError):
        sampler.sample(seg_logit, seg_label)

    # test with thresh
    sampler = OHEMPixelSampler(
        context=_context_for_ohem(), thresh=0.7, min_kept=200)
    seg_logit = torch.randn(1, 19, 45, 45)
    seg_label = torch.randint(0, 19, size=(1, 1, 45, 45))
    seg_weight = sampler.sample(seg_logit, seg_label)
    assert seg_weight.shape[0] == seg_logit.shape[0]
    assert seg_weight.shape[1:] == seg_logit.shape[2:]
    assert seg_weight.sum() > 200

    # test w.o thresh: exactly ``min_kept`` pixels per image are selected
    sampler = OHEMPixelSampler(context=_context_for_ohem(), min_kept=200)
    seg_logit = torch.randn(1, 19, 45, 45)
    seg_label = torch.randint(0, 19, size=(1, 1, 45, 45))
    seg_weight = sampler.sample(seg_logit, seg_label)
    assert seg_weight.shape[0] == seg_logit.shape[0]
    assert seg_weight.shape[1:] == seg_logit.shape[2:]
    assert seg_weight.sum() == 200

0 comments on commit 5e264c6

Please sign in to comment.