From 5e264c608cfa7f9826ba247ed6aa661f2cec1831 Mon Sep 17 00:00:00 2001
From: Jerry Jiarui XU
Date: Sun, 9 Aug 2020 23:49:23 +0800
Subject: [PATCH] Generalized OHEM (#54)

* Generalized OHEM

* remove config

* update docstring

* fixed sort prob

* fixed valid_mask
---
 docs/getting_started.md                      | 23 ++------
 docs/tutorials/training_tricks.md            |  2 +-
 mmseg/core/seg/sampler/ohem_pixel_sampler.py | 60 ++++++++++++--------
 mmseg/models/decode_heads/decode_head.py     |  2 +-
 tests/test_sampler.py                        | 21 ++++++-
 5 files changed, 62 insertions(+), 46 deletions(-)

diff --git a/docs/getting_started.md b/docs/getting_started.md
index 3098ea1..9140435 100644
--- a/docs/getting_started.md
+++ b/docs/getting_started.md
@@ -271,36 +271,23 @@ Usually it is slow if you do not have high speed networking like InfiniBand.
 
 ### Launch multiple jobs on a single machine
 
 If you launch multiple jobs on a single machine, e.g., 2 jobs of 4-GPU training on a machine with 8 GPUs,
-you need to specify different ports (29500 by default) for each job to avoid communication conflict.
+you need to specify different ports (29500 by default) for each job to avoid communication conflicts. Otherwise, there will be an error message saying `RuntimeError: Address already in use`.
 
-If you use `dist_train.sh` to launch training jobs, you can set the port in commands.
+If you use `dist_train.sh` to launch training jobs, you can set the port in the commands with the environment variable `PORT`.
 
 ```shell
 CUDA_VISIBLE_DEVICES=0,1,2,3 PORT=29500 ./tools/dist_train.sh ${CONFIG_FILE} 4
 CUDA_VISIBLE_DEVICES=4,5,6,7 PORT=29501 ./tools/dist_train.sh ${CONFIG_FILE} 4
 ```
 
-If you use launch training jobs with Slurm, you need to modify the config files (usually the 6th line from the bottom in config files) to set different communication ports.
+If you use `slurm_train.sh` to launch training jobs, you can set the port in the commands with the environment variable `MASTER_PORT`.
 
-In `config1.py`,
-```python
-dist_params = dict(backend='nccl', port=29500)
-```
-
-In `config2.py`,
-```python
-dist_params = dict(backend='nccl', port=29501)
-```
-
-Then you can launch two jobs with `config1.py` ang `config2.py`.
 ```shell
-CUDA_VISIBLE_DEVICES=0,1,2,3 GPUS=4 ./tools/slurm_train.sh ${PARTITION} ${JOB_NAME} config1.py ${WORK_DIR}
-CUDA_VISIBLE_DEVICES=4,5,6,7 GPUS=4 ./tools/slurm_train.sh ${PARTITION} ${JOB_NAME} config2.py ${WORK_DIR}
+MASTER_PORT=29500 ./tools/slurm_train.sh ${PARTITION} ${JOB_NAME} ${CONFIG_FILE}
+MASTER_PORT=29501 ./tools/slurm_train.sh ${PARTITION} ${JOB_NAME} ${CONFIG_FILE}
 ```
 
-Or you could specify port by `---options dist_params.port=29501`
-
 ## Useful tools
 
 We provide lots of useful tools under `tools/` directory.
diff --git a/docs/tutorials/training_tricks.md b/docs/tutorials/training_tricks.md
index 2a56daf..11b3480 100644
--- a/docs/tutorials/training_tricks.md
+++ b/docs/tutorials/training_tricks.md
@@ -25,7 +25,7 @@ model=dict(
     decode_head=dict(
         sampler=dict(type='OHEMPixelSampler', thresh=0.7, min_kept=100000)) )
 ```
-In this way, only pixels with confidence score under 0.7 are used to train. And we keep at least 100000 pixels during training.
+In this way, only pixels with a confidence score under 0.7 are used for training, and at least 100000 pixels are kept during training. If `thresh` is not specified, the pixels with the top ``min_kept`` losses will be selected instead.
 
 ## Class Balanced Loss
 
 For dataset that is not balanced in classes distribution, you may change the loss weight of each class.
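The two documented behaviours map onto two config variants. The snippet below is an illustrative sketch, not part of the patch: it follows the `sampler=dict(...)` syntax shown in `training_tricks.md`, with the surrounding `model`/`decode_head` fields elided.

```python
# Threshold-based OHEM: pixels whose ground-truth class probability is
# below `thresh` count as hard examples; at least `min_kept` pixels per
# image are kept.
sampler_by_confidence = dict(type='OHEMPixelSampler', thresh=0.7, min_kept=100000)

# Loss-based OHEM (new in this patch): with `thresh` omitted, the
# `min_kept` pixels with the highest loss are selected instead.
sampler_by_loss = dict(type='OHEMPixelSampler', min_kept=100000)
```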
diff --git a/mmseg/core/seg/sampler/ohem_pixel_sampler.py b/mmseg/core/seg/sampler/ohem_pixel_sampler.py
index 28c14ab..88bb10d 100644
--- a/mmseg/core/seg/sampler/ohem_pixel_sampler.py
+++ b/mmseg/core/seg/sampler/ohem_pixel_sampler.py
@@ -10,22 +10,25 @@ class OHEMPixelSampler(BasePixelSampler):
     """Online Hard Example Mining Sampler for segmentation.
 
     Args:
-        thresh (float): The threshold for hard example selection. Below
-            which, are prediction with low confidence. Default: 0.7.
-        min_kept (int): The minimum number of predictions to keep.
+        context (nn.Module): The context of sampler, subclass of
+            :obj:`BaseDecodeHead`.
+        thresh (float, optional): The threshold for hard example selection.
+            Predictions with confidence below this threshold are treated
+            as hard examples. If not specified, the hard examples will be
+            the pixels with the top ``min_kept`` losses. Default: None.
+        min_kept (int, optional): The minimum number of predictions to keep.
             Default: 100000.
-        ignore_index (int): The ignore index for training. Default: 255.
     """
 
-    def __init__(self, thresh=0.7, min_kept=100000, ignore_index=255):
+    def __init__(self, context, thresh=None, min_kept=100000):
         super(OHEMPixelSampler, self).__init__()
+        self.context = context
         assert min_kept > 1
         self.thresh = thresh
         self.min_kept = min_kept
-        self.ignore_index = ignore_index
 
     def sample(self, seg_logit, seg_label):
-        """
+        """Sample pixels that have high loss or low prediction confidence.
 
         Args:
             seg_logit (torch.Tensor): segmentation logits, shape (N, C, H, W)
@@ -33,32 +36,41 @@ def sample(self, seg_logit, seg_label):
 
         Returns:
             torch.Tensor: segmentation weight, shape (N, H, W)
-
         """
         with torch.no_grad():
             assert seg_logit.shape[2:] == seg_label.shape[2:]
             assert seg_label.shape[1] == 1
             seg_label = seg_label.squeeze(1).long()
             batch_kept = self.min_kept * seg_label.size(0)
-            seg_prob = F.softmax(seg_logit, dim=1)
-            mask = seg_label.contiguous().view(-1, ) != self.ignore_index
+            valid_mask = seg_label != self.context.ignore_index
+            seg_weight = seg_logit.new_zeros(size=seg_label.size())
+            valid_seg_weight = seg_weight[valid_mask]
+            if self.thresh is not None:
+                seg_prob = F.softmax(seg_logit, dim=1)
 
-            tmp_seg_label = seg_label.clone()
-            tmp_seg_label[tmp_seg_label == self.ignore_index] = 0
-            seg_prob = seg_prob.gather(1, tmp_seg_label.unsqueeze(1))
-            sort_prob, sort_indices = seg_prob.contiguous().view(
-                -1, )[mask].contiguous().sort()
+                tmp_seg_label = seg_label.clone().unsqueeze(1)
+                tmp_seg_label[tmp_seg_label == self.context.ignore_index] = 0
+                seg_prob = seg_prob.gather(1, tmp_seg_label).squeeze(1)
+                sort_prob, sort_indices = seg_prob[valid_mask].sort()
 
-            if sort_prob.numel() > 0:
-                min_threshold = sort_prob[min(batch_kept,
-                                              sort_prob.numel() - 1)]
+                if sort_prob.numel() > 0:
+                    min_threshold = sort_prob[min(batch_kept,
+                                                  sort_prob.numel() - 1)]
+                else:
+                    min_threshold = 0.0
+                threshold = max(min_threshold, self.thresh)
+                valid_seg_weight[seg_prob[valid_mask] < threshold] = 1.
             else:
-                min_threshold = 0.0
-            threshold = max(min_threshold, self.thresh)
+                losses = self.context.loss_decode(
+                    seg_logit,
+                    seg_label,
+                    weight=None,
+                    ignore_index=self.context.ignore_index,
+                    reduction_override='none')
+                # faster than topk according to https://github.com/pytorch/pytorch/issues/22812  # noqa
+                _, sort_indices = losses[valid_mask].sort(descending=True)
+                valid_seg_weight[sort_indices[:batch_kept]] = 1.
 
-            seg_weight = seg_logit.new_ones(size=seg_label.size())
-            seg_weight = seg_weight.view(-1)
-            seg_weight[mask][sort_prob < threshold] = 0.
-            seg_weight = seg_weight.view_as(seg_label)
+            seg_weight[valid_mask] = valid_seg_weight
 
             return seg_weight
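As a standalone illustration of the new loss-based branch, here is a minimal sketch that reproduces the selection logic outside the class. It uses plain `F.cross_entropy` in place of the head's configurable `loss_decode`, and the helper name `ohem_weight_by_loss` is ours, not part of the patch.

```python
import torch
import torch.nn.functional as F


def ohem_weight_by_loss(seg_logit, seg_label, min_kept=100000,
                        ignore_index=255):
    """Sketch of OHEM without `thresh`: keep the `min_kept` highest-loss
    pixels per image (F.cross_entropy stands in for `loss_decode`)."""
    with torch.no_grad():
        seg_label = seg_label.squeeze(1).long()  # (N, 1, H, W) -> (N, H, W)
        batch_kept = min_kept * seg_label.size(0)
        valid_mask = seg_label != ignore_index
        seg_weight = seg_logit.new_zeros(seg_label.size())
        valid_seg_weight = seg_weight[valid_mask]
        losses = F.cross_entropy(
            seg_logit, seg_label, ignore_index=ignore_index, reduction='none')
        # sort() is faster than topk() here, see pytorch/pytorch#22812
        _, sort_indices = losses[valid_mask].sort(descending=True)
        valid_seg_weight[sort_indices[:batch_kept]] = 1.
        seg_weight[valid_mask] = valid_seg_weight
    return seg_weight
```

Pixels outside the top `batch_kept` keep weight 0 and drop out of the loss; the thresholded branch instead weights every valid pixel whose ground-truth probability falls below `max(min_threshold, thresh)`.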
diff --git a/mmseg/models/decode_heads/decode_head.py b/mmseg/models/decode_heads/decode_head.py
index 9f55fee..0f58c80 100644
--- a/mmseg/models/decode_heads/decode_head.py
+++ b/mmseg/models/decode_heads/decode_head.py
@@ -73,7 +73,7 @@ def __init__(self,
         self.ignore_index = ignore_index
         self.align_corners = align_corners
         if sampler is not None:
-            self.sampler = build_pixel_sampler(sampler)
+            self.sampler = build_pixel_sampler(sampler, context=self)
         else:
             self.sampler = None
 
diff --git a/tests/test_sampler.py b/tests/test_sampler.py
index af26b8d..3c79c16 100644
--- a/tests/test_sampler.py
+++ b/tests/test_sampler.py
@@ -2,20 +2,37 @@
 import torch
 
 from mmseg.core import OHEMPixelSampler
+from mmseg.models.decode_heads import FCNHead
+
+
+def _context_for_ohem():
+    return FCNHead(in_channels=32, channels=16, num_classes=19)
 
 
 def test_ohem_sampler():
 
     with pytest.raises(AssertionError):
         # seg_logit and seg_label must be of the same size
-        sampler = OHEMPixelSampler()
+        sampler = OHEMPixelSampler(context=_context_for_ohem())
         seg_logit = torch.randn(1, 19, 45, 45)
         seg_label = torch.randint(0, 19, size=(1, 1, 89, 89))
         sampler.sample(seg_logit, seg_label)
 
-    sampler = OHEMPixelSampler()
+    # test with thresh
+    sampler = OHEMPixelSampler(
+        context=_context_for_ohem(), thresh=0.7, min_kept=200)
+    seg_logit = torch.randn(1, 19, 45, 45)
+    seg_label = torch.randint(0, 19, size=(1, 1, 45, 45))
+    seg_weight = sampler.sample(seg_logit, seg_label)
+    assert seg_weight.shape[0] == seg_logit.shape[0]
+    assert seg_weight.shape[1:] == seg_logit.shape[2:]
+    assert seg_weight.sum() > 200
+
+    # test w/o thresh
+    sampler = OHEMPixelSampler(context=_context_for_ohem(), min_kept=200)
     seg_logit = torch.randn(1, 19, 45, 45)
     seg_label = torch.randint(0, 19, size=(1, 1, 45, 45))
     seg_weight = sampler.sample(seg_logit, seg_label)
     assert seg_weight.shape[0] == seg_logit.shape[0]
     assert seg_weight.shape[1:] == seg_logit.shape[2:]
+    assert seg_weight.sum() == 200
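To see the `context=self` wiring end to end, the following usage sketch mirrors the updated test: passing a `sampler` config to a decode head makes the head build the sampler with itself as context, so the sampler can read `ignore_index` and `loss_decode` from the head. This assumes `FCNHead` forwards the `sampler` keyword to `BaseDecodeHead` (as the test's imports suggest); the tensor shapes follow the test.

```python
import torch
from mmseg.models.decode_heads import FCNHead

# The head builds its sampler via build_pixel_sampler(sampler, context=self),
# so OHEMPixelSampler picks up ignore_index and loss_decode from the head.
head = FCNHead(
    in_channels=32,
    channels=16,
    num_classes=19,
    sampler=dict(type='OHEMPixelSampler', thresh=0.7, min_kept=200))

seg_logit = torch.randn(1, 19, 45, 45)
seg_label = torch.randint(0, 19, size=(1, 1, 45, 45))
seg_weight = head.sampler.sample(seg_logit, seg_label)
assert seg_weight.shape == (1, 45, 45)
```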