This repository has been archived by the owner on Nov 6, 2024. It is now read-only.
-
Notifications
You must be signed in to change notification settings - Fork 1
/
gridmask.py
103 lines (87 loc) · 4.08 KB
/
gridmask.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
## Grid Mask
# original code taken from https://www.kaggle.com/haqishen/gridmask
# albumentations
import albumentations
from albumentations.core.transforms_interface import DualTransform
from albumentations.augmentations import functional as F
import numpy as np
class GridMask(DualTransform):
    """GridMask augmentation for image classification and object detection.

    Args:
        num_grid (int or (int, int)): number of grid cells in a row or column.
            A pair ``(lo, hi)`` pre-builds one mask for every count in
            ``lo..hi`` (inclusive) and picks one of them at random per call.
        fill_value (int, float, list of int, list of float): value written into
            dropped pixels. A list supplies one value per channel.
        rotate ((int, int) or int): range from which a random integer angle
            (degrees) is picked, both ends inclusive. If rotate is a single int,
            the range is (-rotate, rotate). A range of (0, 0) disables rotation.
            Default: 90, i.e. (-90, 90).
        mode (int):
            0 - cut out the top-left quarter of each grid cell
            1 - keep only the top-left quarter of each grid cell
            2 - cut out the top-left and bottom-right quarters of each grid cell
        always_apply (bool): forwarded to ``DualTransform``.
        p (float): probability of applying the transform. Default: 0.5.

    Targets:
        image, mask

    Image types:
        uint8, float32

    Reference:
        | https://arxiv.org/abs/2001.04086
        | https://github.com/akuxcw/GridMask
    """

    def __init__(self, num_grid=3, fill_value=0, rotate=90, mode=0, always_apply=False, p=0.5):
        super(GridMask, self).__init__(always_apply, p)
        if isinstance(num_grid, int):
            num_grid = (num_grid, num_grid)
        if isinstance(rotate, int):
            rotate = (-rotate, rotate)
        self.num_grid = num_grid
        self.fill_value = fill_value
        self.rotate = rotate
        self.mode = mode
        # Cached binary masks (1 = keep, 0 = drop), one per grid count, plus the
        # per-mask upper bounds for the random window offset. The cache is only
        # valid for the (height, width) recorded in _mask_shape.
        self.masks = None
        self.rand_h_max = []
        self.rand_w_max = []
        self._mask_shape = None

    def init_masks(self, height, width):
        """Build (or reuse) the grid masks for a ``height`` x ``width`` image.

        The masks are cached and rebuilt whenever the image size differs from
        the size they were built for.

        Each mask is one grid cell larger than the image in both directions, so
        that a random offset in ``[0, cell size)`` still leaves an image-sized
        window inside it.
        """
        if self.masks is not None and getattr(self, "_mask_shape", None) == (height, width):
            return
        self.masks = []
        self.rand_h_max = []
        self.rand_w_max = []
        for n_g in range(self.num_grid[0], self.num_grid[1] + 1):
            grid_h = height / n_g
            grid_w = width / n_g
            # Binary mask: 1 keeps the pixel, 0 drops it. fill_value is applied
            # later in apply(), so it never interacts with the mask arithmetic.
            this_mask = np.ones(
                (int((n_g + 1) * grid_h), int((n_g + 1) * grid_w)), dtype=np.uint8
            )
            for i in range(n_g + 1):
                top = int(i * grid_h)
                mid_h = int(i * grid_h + grid_h / 2)
                bottom = int(i * grid_h + grid_h)
                for j in range(n_g + 1):
                    left = int(j * grid_w)
                    mid_w = int(j * grid_w + grid_w / 2)
                    right = int(j * grid_w + grid_w)
                    # Top-left quarter of the cell.
                    this_mask[top:mid_h, left:mid_w] = 0
                    if self.mode == 2:
                        # Bottom-right quarter of the cell.
                        this_mask[mid_h:bottom, mid_w:right] = 0
            if self.mode == 1:
                # Invert: keep only the top-left quarters.
                this_mask = 1 - this_mask
            self.masks.append(this_mask)
            self.rand_h_max.append(grid_h)
            self.rand_w_max.append(grid_w)
        self._mask_shape = (height, width)

    def apply(self, image, mask, rand_h, rand_w, angle, **params):
        """Return a copy of ``image`` with grid-masked pixels set to ``fill_value``.

        Args:
            image: array whose first two axes are (height, width).
            mask: binary grid mask built by ``init_masks`` (1 = keep, 0 = drop).
            rand_h, rand_w: top-left offset of the image-sized window cut
                from ``mask``.
            angle: rotation in degrees applied to ``mask`` before cropping;
                0 skips the rotation (rotating by 0 is an identity anyway).
        """
        h, w = image.shape[:2]
        if angle:
            mask = F.rotate(mask, angle)
        dropped = mask[rand_h:rand_h + h, rand_w:rand_w + w] == 0
        image_masked = image.copy()
        # A 2-D boolean index selects whole pixels, so a per-channel list
        # fill_value broadcasts over the trailing channel axis of HxWxC images.
        image_masked[dropped] = self.fill_value
        return image_masked

    def get_params_dependent_on_targets(self, params):
        """Pick a random mask, window offset and rotation angle for this call."""
        height, width = params['image'].shape[:2]
        self.init_masks(height, width)
        mid = np.random.randint(len(self.masks))
        mask = self.masks[mid]
        # randint needs a positive int bound; cells narrower than one pixel
        # (image smaller than num_grid) fall back to a zero offset.
        rand_h = np.random.randint(max(1, int(self.rand_h_max[mid])))
        rand_w = np.random.randint(max(1, int(self.rand_w_max[mid])))
        low, high = self.rotate[0], self.rotate[1]
        # Both ends of the rotate range are inclusive; (0, 0) disables rotation.
        angle = np.random.randint(low, high + 1) if (low, high) != (0, 0) else 0
        return {'mask': mask, 'rand_h': rand_h, 'rand_w': rand_w, 'angle': angle}

    @property
    def targets_as_params(self):
        return ['image']

    def get_transform_init_args_names(self):
        return ('num_grid', 'fill_value', 'rotate', 'mode')
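# Example usage: a Compose pipeline that always applies GridMask (p=1.0),
# choosing between 3 and 6 grid cells per side for each image.
# grid_mask = albumentations.Compose([GridMask(num_grid=(3, 6), mode=0, p=1.0)])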