Skip to content

Commit

Permalink
a bit update of bisenetv2
Browse files Browse the repository at this point in the history
  • Loading branch information
CoinCheung committed May 29, 2020
1 parent 7409bbd commit 81a8885
Show file tree
Hide file tree
Showing 3 changed files with 14 additions and 11 deletions.
11 changes: 7 additions & 4 deletions bisenetv2/bisenetv2.py
Original file line number Diff line number Diff line change
Expand Up @@ -100,6 +100,7 @@ def __init__(self, in_chan, out_chan, exp_ratio=6):
in_chan, mid_chan, kernel_size=3, stride=1,
padding=1, groups=in_chan, bias=False),
nn.BatchNorm2d(mid_chan),
nn.ReLU(inplace=True), # not shown in paper
)
self.conv2 = nn.Sequential(
nn.Conv2d(
Expand Down Expand Up @@ -136,6 +137,7 @@ def __init__(self, in_chan, out_chan, exp_ratio=6):
mid_chan, mid_chan, kernel_size=3, stride=1,
padding=1, groups=mid_chan, bias=False),
nn.BatchNorm2d(mid_chan),
nn.ReLU(inplace=True), # not shown in paper
)
self.conv2 = nn.Sequential(
nn.Conv2d(
Expand Down Expand Up @@ -238,6 +240,7 @@ def __init__(self):
128, 128, kernel_size=3, stride=1,
padding=1, bias=False),
nn.BatchNorm2d(128),
nn.ReLU(inplace=True), # not shown in paper
)

def forward(self, x_d, x_s):
Expand Down Expand Up @@ -286,11 +289,11 @@ def __init__(self, n_classes):
self.bga = BGALayer()

## TODO: what is the number of mid chan ?
self.head = SegmentHead(128, 256, n_classes)
self.aux2 = SegmentHead(16, 32, n_classes)
self.aux3 = SegmentHead(32, 64, n_classes)
self.head = SegmentHead(128, 1024, n_classes)
self.aux2 = SegmentHead(16, 128, n_classes)
self.aux3 = SegmentHead(32, 128, n_classes)
self.aux4 = SegmentHead(64, 128, n_classes)
self.aux5_4 = SegmentHead(128, 256, n_classes)
self.aux5_4 = SegmentHead(128, 128, n_classes)

self.init_weights()

Expand Down
10 changes: 5 additions & 5 deletions bisenetv2/cityscapes_cv2.py
Original file line number Diff line number Diff line change
Expand Up @@ -129,11 +129,11 @@ def __init__(self):
self.trans_func = T.Compose([
T.RandomResizedCrop([0.375, 1.], [512, 1024]),
T.RandomHorizontalFlip(),
# T.ColorJitter(
# brightness=0.4,
# contrast=0.4,
# saturation=0.4
# ),
T.ColorJitter(
brightness=0.4,
contrast=0.4,
saturation=0.4
),
])

def __call__(self, im_lb):
Expand Down
4 changes: 2 additions & 2 deletions bisenetv2/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,9 +41,9 @@
# torch.backends.cudnn.benchmark = True
# torch.multiprocessing.set_sharing_strategy('file_system')

lr_start = 2.5e-2
lr_start = 5e-2
warmup_iters = 1000
max_iter = 300000 + warmup_iters
max_iter = 150000 + warmup_iters
ims_per_gpu = 8


Expand Down

0 comments on commit 81a8885

Please sign in to comment.