-
Notifications
You must be signed in to change notification settings - Fork 0
/
scheduler.py
executable file
·71 lines (53 loc) · 2.46 KB
/
scheduler.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
class CircularLRBeta:
def __init__(
self, optimizer, lr_max, lr_divider, cut_point, step_size, momentum=None
):
self.lr_max = lr_max
self.lr_divider = lr_divider
self.cut_point = step_size // cut_point
self.step_size = step_size
self.iteration = 0
self.cycle_step = int(step_size * (1 - cut_point / 100) / 2)
self.momentum = momentum
self.optimizer = optimizer
def get_lr(self):
if self.iteration > 2 * self.cycle_step:
cut = (self.iteration - 2 * self.cycle_step) / (
self.step_size - 2 * self.cycle_step
)
lr = self.lr_max * (1 + (cut * (1 - 100) / 100)) / self.lr_divider
elif self.iteration > self.cycle_step:
cut = 1 - (self.iteration - self.cycle_step) / self.cycle_step
lr = self.lr_max * (1 + cut * (self.lr_divider - 1)) / self.lr_divider
else:
cut = self.iteration / self.cycle_step
lr = self.lr_max * (1 + cut * (self.lr_divider - 1)) / self.lr_divider
return lr
def get_momentum(self):
if self.iteration > 2 * self.cycle_step:
momentum = self.momentum[0]
elif self.iteration > self.cycle_step:
cut = 1 - (self.iteration - self.cycle_step) / self.cycle_step
momentum = self.momentum[0] + cut * (self.momentum[1] - self.momentum[0])
else:
cut = self.iteration / self.cycle_step
momentum = self.momentum[0] + cut * (self.momentum[1] - self.momentum[0])
return momentum
def step(self):
lr = self.get_lr()
if self.momentum is not None:
momentum = self.get_momentum()
self.iteration += 1
if self.iteration == self.step_size:
self.iteration = 0
'''get_lr()로 learning rate를 업데이트(새로운 learning rate를 계산) 하고 init에서 받은 optimizer의
파라미터 중에 'lr'에 해당하는 값에 업데이트된 lr을 집어넣는다.
'''
for group in self.optimizer.param_groups:
group['lr'] = lr
'''여기서는 파이토치의 라이브러리로 구현되어 있는 CyclicLR과는 다르게 momentum이 None이 아니라면
momentum에 해당하는 파라미터까지 업데이트 해주고 있다.
'''
if self.momentum is not None:
group['betas'] = (momentum, group['betas'][1])
return lr