Merge pull request #46 from LikeLy-Journey/pointrend-v1
Refine pointrend
LikeLy-Journey authored May 3, 2020
2 parents ad44e77 + fe779a2 commit 4bc605e
Showing 4 changed files with 43 additions and 12 deletions.
9 changes: 4 additions & 5 deletions configs/cityscapes_pointrend_deeplabv3_plus.yaml
@@ -3,21 +3,20 @@ DATASET:
   MEAN: [0.5, 0.5, 0.5]
   STD: [0.5, 0.5, 0.5]
 TRAIN:
-  EPOCHS: 400
+  EPOCHS: 200
   BATCH_SIZE: 2
-  CROP_SIZE: 768
+  CROP_SIZE: 769
 TEST:
   BATCH_SIZE: 2
-  CROP_SIZE: (1024, 2048)
+  CROP_SIZE: (1025, 2049)
   # TEST_MODEL_PATH: trained_models/deeplabv3_plus_xception_segmentron.pth

 SOLVER:
-  LR: 0.01
+  LR: 0.02

 MODEL:
   MODEL_NAME: "PointRend"
   BACKBONE: "xception65"
-  BN_EPS_FOR_ENCODER: 1e-3
   DEEPLABV3_PLUS:
     ENABLE_DECODER: False
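
Both crop sizes move to the "multiple of the stride, plus one" convention (769 instead of 768, (1025, 2049) instead of (1024, 2048)) commonly used by DeepLab-style implementations so that downsampled feature grids align with the image corners; it also means the full-resolution output can no longer be reached from the coarse map by pure x2 upsampling, which the inference changes further below account for. A quick arithmetic check, assuming an output stride of 16:

# Sketch: with the new sizes, (size - 1) is an exact multiple of an assumed
# output stride of 16, so a 1/16 feature map has (size - 1) // 16 + 1 cells
# whose corners coincide with the image corners.
for size in (768, 769, 1024, 1025, 2048, 2049):
    print(size, (size - 1) % 16 == 0, (size - 1) // 16 + 1)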

41 changes: 35 additions & 6 deletions segmentron/models/pointrend.py
@@ -32,7 +32,15 @@ def forward(self, x):
 class PointHead(nn.Module):
     def __init__(self, in_c=275, num_classes=19, k=3, beta=0.75):
         super().__init__()
-        self.mlp = nn.Conv1d(in_c, num_classes, 1)
+        self.mlp = nn.Sequential(
+            nn.Conv1d(in_c, 256, kernel_size=1, stride=1, padding=0, bias=True),
+            nn.ReLU(True),
+            nn.Conv1d(256, 256, kernel_size=1, stride=1, padding=0, bias=True),
+            nn.ReLU(True),
+            nn.Conv1d(256, 256, kernel_size=1, stride=1, padding=0, bias=True),
+            nn.ReLU(True),
+            nn.Conv1d(256, num_classes, 1)
+        )
         self.k = k
         self.beta = beta
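
The single 1x1 Conv1d point head is replaced by a small MLP over per-point features, in line with the point head described in the PointRend paper. A minimal stand-alone sketch of the shape flow, assuming in_c=275 is 256 fine-grained channels (res2) concatenated with 19 coarse class logits:

import torch
import torch.nn as nn

in_c, num_classes, num_points = 275, 19, 2048
mlp = nn.Sequential(
    nn.Conv1d(in_c, 256, 1), nn.ReLU(True),
    nn.Conv1d(256, 256, 1), nn.ReLU(True),
    nn.Conv1d(256, 256, 1), nn.ReLU(True),
    nn.Conv1d(256, num_classes, 1),
)
coarse = torch.randn(2, num_classes, num_points)  # coarse logits sampled at the points
fine = torch.randn(2, 256, num_points)            # res2 features sampled at the points
rend = mlp(torch.cat([coarse, fine], dim=1))
print(rend.shape)                                 # torch.Size([2, 19, 2048])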

@@ -47,7 +55,10 @@ def forward(self, x, res2, out):
         if not self.training:
             return self.inference(x, res2, out)

-        points = sampling_points(out, x.shape[-1] // 16, self.k, self.beta)
+        # out = F.interpolate(out, size=x.shape[-2:], mode="bilinear", align_corners=True)
+        # res2 = F.interpolate(res2, size=out.shape[-2:], mode="bilinear", align_corners=True)
+        N = x.shape[-1] // 16
+        points = sampling_points(out, N * N, self.k, self.beta)

         coarse = point_sample(out, points, align_corners=False)
         fine = point_sample(res2, points, align_corners=False)
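
During training the head now samples N * N points, where N is 1/16 of the input width, instead of just N. sampling_points itself is defined elsewhere in pointrend.py; for reference, here is a hypothetical minimal re-implementation of the PointRend training-time strategy it follows (over-generate k*N candidates, keep the beta*N most uncertain, fill the rest with uniform random points). The function name and the top-2-margin uncertainty measure are assumptions, not the repository's exact code:

import torch
import torch.nn.functional as F

def sampling_points_sketch(mask, N, k=3, beta=0.75):
    """mask: (B, C, H, W) coarse logits, C >= 2; returns (B, N, 2) coords in [0, 1]."""
    B = mask.shape[0]
    # 1. Over-generate k*N random candidate points ((x, y) order for grid_sample).
    over = torch.rand(B, k * N, 2, device=mask.device)
    # 2. Score candidates by uncertainty: a small top-1 / top-2 margin means uncertain.
    logits = F.grid_sample(mask, 2.0 * over.unsqueeze(2) - 1.0,
                           align_corners=False).squeeze(3)        # (B, C, k*N)
    top2 = logits.topk(2, dim=1).values
    uncertainty = -(top2[:, 0] - top2[:, 1])                      # (B, k*N)
    # 3. Keep the beta*N most uncertain candidates ...
    n_important = int(beta * N)
    idx = uncertainty.topk(n_important, dim=1).indices
    important = torch.gather(over, 1, idx.unsqueeze(-1).expand(-1, -1, 2))
    # 4. ... and fill the remainder with uniformly random coverage points.
    coverage = torch.rand(B, N - n_important, 2, device=mask.device)
    return torch.cat([important, coverage], dim=1)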
@@ -66,8 +77,9 @@ def inference(self, x, res2, out):
         """
         num_points = 8096

-        while out.shape[-1] != x.shape[-1]:
-            out = F.interpolate(out, scale_factor=2, mode="bilinear", align_corners=True)
+        while out.shape[-1] * 2 < x.shape[-1]:
+            out = F.interpolate(out, scale_factor=2, mode="bilinear", align_corners=False)
+            # res2 = F.interpolate(res2, size=out.shape[-2:], mode="bilinear", align_corners=True)

             points_idx, points = sampling_points(out, num_points, training=self.training)
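
The old loop doubled out until its width exactly matched the input, which only terminates when the input width is the coarse width times an exact power of two; with the new crop sizes (769, 1025, 2049) that equality is never reached. The new condition stops one doubling short, and the block added in the next hunk finishes with an interpolate to the exact input size plus one more refinement pass. A quick numeric check, assuming a 65-pixel-wide coarse map for a 1025-pixel-wide input:

x_w, w = 1025, 65            # assumed input width and 1/16-resolution coarse width
while w * 2 < x_w:           # new condition: stop before overshooting
    w *= 2
print(w)                     # 520 -> one final resize to 1025 remains
# Old condition (w != x_w) never becomes False: 65 -> 130 -> 260 -> 520 -> 1040 -> ...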

@@ -83,7 +95,24 @@ def inference(self, x, res2, out):
             out = (out.reshape(B, C, -1)
                    .scatter_(2, points_idx, rend)
                    .view(B, C, H, W))

+        out = F.interpolate(out, size=x.shape[-2:], mode="bilinear", align_corners=False)
+        # res2 = F.interpolate(res2, size=out.shape[-2:], mode="bilinear", align_corners=True)
+
+        points_idx, points = sampling_points(out, num_points, training=self.training)
+
+        coarse = point_sample(out, points, align_corners=False)
+        fine = point_sample(res2, points, align_corners=False)
+
+        feature_representation = torch.cat([coarse, fine], dim=1)
+
+        rend = self.mlp(feature_representation)
+
+        B, C, H, W = out.shape
+        points_idx = points_idx.unsqueeze(1).expand(-1, C, -1)
+        out = (out.reshape(B, C, -1)
+               .scatter_(2, points_idx, rend)
+               .view(B, C, H, W))
         return {"fine": out}
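
The block added above runs the refinement once more after the final resize to the exact input resolution. The scatter step that writes the refined per-point logits back into the dense map can be shown in isolation; shapes here are made up:

import torch

B, C, H, W = 1, 19, 4, 4
out = torch.zeros(B, C, H, W)                      # dense (upsampled) logits
points_idx = torch.tensor([[0, 5, 10]])            # (B, N) flat pixel indices
rend = torch.randn(B, C, 3)                        # (B, C, N) refined point logits

idx = points_idx.unsqueeze(1).expand(-1, C, -1)    # (B, C, N)
out = out.reshape(B, C, -1).scatter_(2, idx, rend).view(B, C, H, W)
# Only pixels 0, 5 and 10 (row-major order) are overwritten; the rest keep
# their upsampled values.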


@@ -106,7 +135,7 @@ def point_sample(input, point_coords, **kwargs):
     if point_coords.dim() == 3:
         add_dim = True
         point_coords = point_coords.unsqueeze(2)
-    output = F.grid_sample(input, 2.0 * point_coords - 1.0)#, **kwargs)
+    output = F.grid_sample(input, 2.0 * point_coords - 1.0, **kwargs)
     if add_dim:
         output = output.squeeze(3)
     return output
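
point_sample is a thin wrapper around F.grid_sample: point coordinates live in [0, 1] while grid_sample expects [-1, 1], hence the 2.0 * point_coords - 1.0. Forwarding **kwargs means the align_corners=False passed by the callers above actually reaches grid_sample instead of being silently dropped. A minimal stand-alone equivalent:

import torch
import torch.nn.functional as F

feat = torch.randn(1, 256, 64, 128)       # e.g. a res2-like feature map
pts = torch.rand(1, 500, 2)               # 500 points in [0, 1], (x, y) order

grid = 2.0 * pts.unsqueeze(2) - 1.0       # (1, 500, 1, 2), now in [-1, 1]
sampled = F.grid_sample(feat, grid, align_corners=False).squeeze(3)
print(sampled.shape)                      # torch.Size([1, 256, 500])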
1 change: 1 addition & 0 deletions segmentron/utils/visualize.py
@@ -33,6 +33,7 @@ def print_iou(iu, mean_pixel_acc, class_names=None, show_no_back=False):
     print(line)


+@torch.no_grad()
 def show_flops_params(model, device, input_shape=[1, 3, 1024, 2048]):
     #summary(model, tuple(input_shape[1:]), device=device)
     input = torch.randn(*input_shape).to(torch.device(device))
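
show_flops_params runs a full forward pass on a 1 x 3 x 1024 x 2048 tensor just to count parameters and FLOPs; decorating it with @torch.no_grad() stops autograd from recording that pass, so no activation buffers are kept for a backward that never happens. A tiny sketch of the decorator form with a stand-in function:

import torch
import torch.nn as nn

@torch.no_grad()                 # every tensor op inside runs without autograd
def profile_forward(model, shape=(1, 3, 64, 64)):
    return model(torch.randn(*shape)).requires_grad

print(profile_forward(nn.Conv2d(3, 19, 3, padding=1)))   # False: no graph was built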
4 changes: 3 additions & 1 deletion tools/train.py
@@ -1,4 +1,5 @@
 import time
+import copy
 import datetime
 import os
 import sys
@@ -61,10 +62,11 @@ def __init__(self, args):

         # create network
         self.model = get_segmentation_model().to(self.device)

+        # print params and flops
         if get_rank() == 0:
             try:
-                show_flops_params(self.model, args.device)
+                show_flops_params(copy.deepcopy(self.model), args.device)
             except Exception as e:
                 logging.warning('get flops and params error: {}'.format(e))
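
The model handed to the profiler is now a copy.deepcopy of self.model, presumably so that any hooks or extra counters the FLOPs profiler attaches stay on a throwaway copy rather than on the network that is about to be trained (the profiler used inside show_flops_params is not shown in this diff, so that motivation is an assumption). A minimal sketch of the guarded pattern with a hypothetical profile_sketch stand-in:

import copy
import logging
import torch
import torch.nn as nn

def profile_sketch(m, shape=(1, 3, 64, 64)):
    # Stand-in for show_flops_params: profilers typically run a forward pass
    # and may register hooks or attach counters to the modules they visit.
    with torch.no_grad():
        m(torch.randn(*shape))

model = nn.Conv2d(3, 19, 3, padding=1)
try:
    profile_sketch(copy.deepcopy(model))   # the model kept for training stays untouched
except Exception as e:
    logging.warning('get flops and params error: {}'.format(e))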

