From af212e38d0ca7edd9cf758de7a3937c0c63a9c03 Mon Sep 17 00:00:00 2001 From: LikeLy-Journey Date: Sun, 3 May 2020 11:06:54 +0000 Subject: [PATCH 1/2] refine pointrend --- .../cityscapes_pointrend_deeplabv3_plus.yaml | 9 ++-- segmentron/models/pointrend.py | 41 ++++++++++++++++--- 2 files changed, 39 insertions(+), 11 deletions(-) diff --git a/configs/cityscapes_pointrend_deeplabv3_plus.yaml b/configs/cityscapes_pointrend_deeplabv3_plus.yaml index a818a6b..102ee8b 100644 --- a/configs/cityscapes_pointrend_deeplabv3_plus.yaml +++ b/configs/cityscapes_pointrend_deeplabv3_plus.yaml @@ -3,16 +3,16 @@ DATASET: MEAN: [0.5, 0.5, 0.5] STD: [0.5, 0.5, 0.5] TRAIN: - EPOCHS: 400 + EPOCHS: 200 BATCH_SIZE: 2 - CROP_SIZE: 768 + CROP_SIZE: 769 TEST: BATCH_SIZE: 2 - CROP_SIZE: (1024, 2048) + CROP_SIZE: (1025, 2049) # TEST_MODEL_PATH: trained_models/deeplabv3_plus_xception_segmentron.pth SOLVER: - LR: 0.01 + LR: 0.02 MODEL: MODEL_NAME: "PointRend" @@ -20,4 +20,3 @@ MODEL: BN_EPS_FOR_ENCODER: 1e-3 DEEPLABV3_PLUS: ENABLE_DECODER: False - diff --git a/segmentron/models/pointrend.py b/segmentron/models/pointrend.py index f57b5a3..195279a 100644 --- a/segmentron/models/pointrend.py +++ b/segmentron/models/pointrend.py @@ -32,7 +32,15 @@ def forward(self, x): class PointHead(nn.Module): def __init__(self, in_c=275, num_classes=19, k=3, beta=0.75): super().__init__() - self.mlp = nn.Conv1d(in_c, num_classes, 1) + self.mlp = nn.Sequential( + nn.Conv1d(in_c, 256, kernel_size=1, stride=1, padding=0, bias=True), + nn.ReLU(True), + nn.Conv1d(256, 256, kernel_size=1, stride=1, padding=0, bias=True), + nn.ReLU(True), + nn.Conv1d(256, 256, kernel_size=1, stride=1, padding=0, bias=True), + nn.ReLU(True), + nn.Conv1d(256, num_classes, 1) + ) self.k = k self.beta = beta @@ -47,7 +55,10 @@ def forward(self, x, res2, out): if not self.training: return self.inference(x, res2, out) - points = sampling_points(out, x.shape[-1] // 16, self.k, self.beta) + # out = F.interpolate(out, size=x.shape[-2:], mode="bilinear", align_corners=True) + # res2 = F.interpolate(res2, size=out.shape[-2:], mode="bilinear", align_corners=True) + N = x.shape[-1] // 16 + points = sampling_points(out, N * N, self.k, self.beta) coarse = point_sample(out, points, align_corners=False) fine = point_sample(res2, points, align_corners=False) @@ -66,8 +77,9 @@ def inference(self, x, res2, out): """ num_points = 8096 - while out.shape[-1] != x.shape[-1]: - out = F.interpolate(out, scale_factor=2, mode="bilinear", align_corners=True) + while out.shape[-1] * 2 < x.shape[-1]: + out = F.interpolate(out, scale_factor=2, mode="bilinear", align_corners=False) + # res2 = F.interpolate(res2, size=out.shape[-2:], mode="bilinear", align_corners=True) points_idx, points = sampling_points(out, num_points, training=self.training) @@ -83,7 +95,24 @@ def inference(self, x, res2, out): out = (out.reshape(B, C, -1) .scatter_(2, points_idx, rend) .view(B, C, H, W)) - + + out = F.interpolate(out, size=x.shape[-2:], mode="bilinear", align_corners=False) + # res2 = F.interpolate(res2, size=out.shape[-2:], mode="bilinear", align_corners=True) + + points_idx, points = sampling_points(out, num_points, training=self.training) + + coarse = point_sample(out, points, align_corners=False) + fine = point_sample(res2, points, align_corners=False) + + feature_representation = torch.cat([coarse, fine], dim=1) + + rend = self.mlp(feature_representation) + + B, C, H, W = out.shape + points_idx = points_idx.unsqueeze(1).expand(-1, C, -1) + out = (out.reshape(B, C, -1) + .scatter_(2, points_idx, rend) + .view(B, C, H, W)) return {"fine": out} @@ -106,7 +135,7 @@ def point_sample(input, point_coords, **kwargs): if point_coords.dim() == 3: add_dim = True point_coords = point_coords.unsqueeze(2) - output = F.grid_sample(input, 2.0 * point_coords - 1.0)#, **kwargs) + output = F.grid_sample(input, 2.0 * point_coords - 1.0, **kwargs) if add_dim: output = output.squeeze(3) return output From fe779a2d9cabe43d731aa7f5bc5d8726ead98f9c Mon Sep 17 00:00:00 2001 From: LikeLy-Journey Date: Sun, 3 May 2020 11:07:24 +0000 Subject: [PATCH 2/2] fix params count --- segmentron/utils/visualize.py | 1 + tools/train.py | 4 +++- 2 files changed, 4 insertions(+), 1 deletion(-) diff --git a/segmentron/utils/visualize.py b/segmentron/utils/visualize.py index 6314368..0f923ac 100644 --- a/segmentron/utils/visualize.py +++ b/segmentron/utils/visualize.py @@ -33,6 +33,7 @@ def print_iou(iu, mean_pixel_acc, class_names=None, show_no_back=False): print(line) +@torch.no_grad() def show_flops_params(model, device, input_shape=[1, 3, 1024, 2048]): #summary(model, tuple(input_shape[1:]), device=device) input = torch.randn(*input_shape).to(torch.device(device)) diff --git a/tools/train.py b/tools/train.py index ba3e2ff..425b290 100644 --- a/tools/train.py +++ b/tools/train.py @@ -1,4 +1,5 @@ import time +import copy import datetime import os import sys @@ -61,10 +62,11 @@ def __init__(self, args): # create network self.model = get_segmentation_model().to(self.device) + # print params and flops if get_rank() == 0: try: - show_flops_params(self.model, args.device) + show_flops_params(copy.deepcopy(self.model), args.device) except Exception as e: logging.warning('get flops and params error: {}'.format(e))