-
Notifications
You must be signed in to change notification settings - Fork 13
/
exposure_augment.py
60 lines (49 loc) · 1.72 KB
/
exposure_augment.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
import math
import os
import PIL.Image as Image
import numpy as np
import torch
import torchvision.transforms as vtrans
import tqdm
def _exposure_gains(im, max_overex_rate, steps, num_gen):
    """Return the brightness gains to apply to ``im`` and their step size.

    ``im`` is expected to be a (C, H, W) float tensor in [0, 1], as produced
    by ``torchvision.transforms.ToTensor`` (TODO confirm for other callers).
    A gain of ``1 / v`` makes every pixel whose brightest channel is >= ``v``
    reach 1.0. Candidate gains are therefore taken from the brightest
    ``max_overex_rate`` fraction of pixels. This roughly bounds how much of
    the image saturates at the largest gain.

    Returns:
        (gains, min_gain): ``gains`` is a 1-D numpy array of ``2 * num_gen``
        gains starting at 1. ``min_gain`` (float) is the gain increment used
        to space the lower half of the schedule.

    Raises:
        ValueError: if no positive gain increment exists (e.g. a flat or
            all-black image).
    """
    # Brightest channel of every pixel, one value per pixel.
    im_max = torch.flatten(torch.max(im, dim=0).values)
    # Never ask topk for more elements than there are pixels.
    n_bright = min(len(im_max), math.floor(len(im_max) * max_overex_rate + 1))
    # topk sorts descending, so these saturation gains come out ascending.
    mag = 1. / torch.topk(im_max, n_bright).values
    # Subsample to about `steps` candidates. The stride must be >= 1,
    # otherwise images with fewer candidates than `steps` would crash.
    mag = mag[::max(1, len(mag) // steps)]
    # mag_diff[j] is the gain increase from candidate j to candidate j + 1.
    mag_diff = torch.diff(mag)
    mag = mag[:-1]

    k = min(num_gen, len(mag_diff))
    top_mag_diff = torch.topk(mag_diff, k).values if k > 0 else mag_diff
    positive = top_mag_diff[top_mag_diff > 0]
    if len(positive) == 0:
        raise ValueError(
            'no positive gain increment found; image too flat or too small')
    # Smallest of the (up to) num_gen largest positive increments.
    min_gain = positive[-1]

    # Candidates followed by a jump larger than min_gain. If every top
    # increment ties with min_gain, the strict test is empty, so fall back
    # to ">=" (always non-empty because min_gain is itself in mag_diff).
    jumps = mag[mag_diff > min_gain]
    if len(jumps) == 0:
        jumps = mag[mag_diff >= min_gain]

    # Plain floats, so numpy never has to convert torch tensors implicitly.
    min_gain = float(min_gain)
    min_mag = float(mag[0])  # gain at which the brightest pixel reaches 1.0
    max_mag = float(jumps[-1])

    # Lower half: gains from 1 up to just below min_mag, which saturate
    # nothing. Upper half: min_mag .. max_mag. If fewer than num_gen steps
    # of 2 * min_gain fit below min_mag, use one even ramp from 1 to max_mag.
    dark_top = min_mag - min_gain
    coarse = np.arange(1, dark_top, min_gain * 2)
    bright = np.linspace(min_mag, max_mag, num_gen)
    if len(coarse) > num_gen:
        gains = np.append(np.linspace(1, dark_top, num_gen), bright)
    elif len(coarse) == num_gen:
        gains = np.append(coarse, bright)
    else:
        gains = np.linspace(1, max_mag, num_gen * 2)
    return gains, min_gain


def main(fip, fod, *, max_overex_rate=0.25, steps=20, num_gen=4, bar=None):
    """Save ``2 * num_gen`` exposure-scaled copies of the image at ``fip``.

    Each copy is the image multiplied by one gain from ``_exposure_gains``
    and clipped to 1.0. It is written to ``fod`` as ``{stem}_{i}{ext}``.
    Output files that already exist are left untouched.

    Args:
        fip: path of the input image.
        fod: output directory (must already exist; it is not created here).
        max_overex_rate: fraction of the brightest pixels used to pick gains.
        steps: approximate number of candidate gains sampled from them.
        num_gen: number of gains in each half of the schedule.
        bar: optional progress bar whose ``set_description`` is updated.
            Defaults to a module-level ``bar`` if one exists.

    Raises:
        ValueError: if no gain schedule can be derived for the image.
    """
    # Decode the image once and release the file handle.
    with Image.open(fip) as img:
        im_raw = vtrans.ToTensor()(img)

    gains, min_gain = _exposure_gains(im_raw, max_overex_rate, steps, num_gen)

    # splitext copes with names containing several dots.
    fn, ext = os.path.splitext(os.path.basename(fip))

    # The script entry point keeps its progress bar in a module-level `bar`.
    # Look it up lazily so importing and calling main() no longer raises
    # NameError.
    if bar is None:
        bar = globals().get('bar')
    if bar is not None:
        bar.set_description(f'{fn}: {min_gain}')

    to_pil = vtrans.ToPILImage()
    for i, gain in enumerate(gains):
        fop = os.path.join(fod, f'{fn}_{i}{ext}')
        if os.path.exists(fop):
            continue  # keep earlier results; skip the wasted work
        out = (im_raw * float(gain)).clamp_max_(1.)
        to_pil(out).save(fop)
if __name__ == '__main__':
    # The LOL dataset is not included with this script. Download it and put
    # its training images in `fid` before running.
    fid = './data/LOL/train/images'
    fod = './data/LOL/train/images_aug'
    os.makedirs(fod, exist_ok=True)
    # Sorted for a deterministic processing order. Only regular files are
    # passed on, because os.listdir also returns subdirectories.
    names = sorted(
        name for name in os.listdir(fid)
        if os.path.isfile(os.path.join(fid, name))
    )
    # `bar` stays a module-level global on purpose: main() reads it to
    # update the progress-bar description.
    bar = tqdm.tqdm(names)
    for fn in bar:
        fip = os.path.join(fid, fn)
        main(fip, fod)