-
Notifications
You must be signed in to change notification settings - Fork 45
/
memory.py
88 lines (72 loc) · 2.73 KB
/
memory.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
import torch
import horovod.torch as hvd
# Public names exported by ``from <module> import *``.
__all__ = ['Memory', 'DGCSGDMemory']
# code modified from https://github.com/sands-lab/grace/blob/master/grace_dl/torch/memory/dgc.py
class Memory:
    """No-op gradient memory that defines the hook interface.

    Every hook is a static method that ignores its arguments;
    ``compensate`` hands the incoming tensor back unchanged and
    ``state_dict`` reports that there is nothing to save. Subclasses
    override the hooks they need.
    """

    @staticmethod
    def compensate(tensor, *args, **kwargs):
        """Return ``tensor`` as-is; any extra arguments are ignored."""
        return tensor

    @staticmethod
    def initialize(*args, **kwargs):
        """Accept any arguments and do nothing."""

    @staticmethod
    def update(*args, **kwargs):
        """Accept any arguments and do nothing."""

    @staticmethod
    def state_dict():
        """Return ``None``: this memory holds no state."""
        return None

    @staticmethod
    def load_state_dict(state_dict):
        """Ignore ``state_dict``: this memory holds no state to restore."""
class DGCSGDMemory(Memory):
    """Memory for momentum correction in DGC for momentum SGD optimizer.

    For every named parameter two buffers are kept: a momentum buffer and
    a velocity buffer that accumulates momentum across steps.
    ``compensate`` folds a new gradient into these buffers and ``update``
    zeroes the entries at the indices it is given.

    Args:
        momentum: momentum factor used to scale the momentum buffer.
        nesterov: if True, use the Nesterov form of the momentum update.
        gradient_clipping: optional callable applied to each incoming
            gradient before accumulation; ``None`` disables clipping.
        momentum_masking: if True, ``update`` zeroes the momentum buffer at
            the given indices as well (velocities are always zeroed).
    """

    def __init__(self, momentum=0.9, nesterov=False,
                 gradient_clipping=None, momentum_masking=True):
        self.gradient_clipping = gradient_clipping
        self.momentum_masking = momentum_masking
        self.momentum = momentum
        self.nesterov = nesterov
        # name -> tensor buffers; populated by ``initialize``.
        self.momentums = {}
        self.velocities = {}

    def initialize(self, named_parameters):
        """Allocate zero-filled momentum and velocity buffers.

        Args:
            named_parameters: iterable of ``(name, param)`` pairs
                (presumably ``model.named_parameters()`` — confirm against
                caller). Each buffer matches ``param.data``.
        """
        if hvd.rank() == 0:
            print("=> initializing dgc sgd memory")
        for name, param in named_parameters:
            self.momentums[name] = torch.zeros_like(param.data)
            self.velocities[name] = torch.zeros_like(param.data)

    def compensate(self, grad, name, accumulate=True):
        """Update the velocities with the momentums.

        The momentum buffer of ``name`` is always updated in place.

        Args:
            grad: gradient tensor for parameter ``name``.
            name: parameter key, as registered by ``initialize``.
            accumulate: if True, also add the momentum into the velocity
                buffer and return that buffer itself (not a copy). If False,
                velocities are untouched and a new tensor is returned.

        Returns:
            The compensated gradient tensor.
        """
        if self.gradient_clipping is not None:
            grad = self.gradient_clipping(grad)
        mmt = self.momentums[name]
        if accumulate:
            vec = self.velocities[name]
            if self.nesterov:
                mmt.add_(grad).mul_(self.momentum)
                vec.add_(mmt).add_(grad)
            else:
                mmt.mul_(self.momentum).add_(grad)
                vec.add_(mmt)
            return vec
        if self.nesterov:
            mmt.add_(grad).mul_(self.momentum)
            return mmt.add(grad)
        mmt.mul_(self.momentum).add_(grad)
        return mmt.clone()  # TODO: save this clone

    def update(self, name, ctx):
        """Update the momentums: zero buffer entries at ``ctx[0]``.

        Args:
            name: parameter key.
            ctx: sequence whose first element is a 1-D index tensor into the
                flattened buffers (presumably the indices the compressor
                selected for transmission — confirm against caller).
        """
        indices = ctx[0]
        if self.momentum_masking:
            self.momentums[name].view(-1).index_fill_(0, indices, 0)
        self.velocities[name].view(-1).index_fill_(0, indices, 0)

    def state_dict(self):
        """Return the live buffer dicts (not copies).

        The result has keys ``'momentums'`` and ``'velocities'``.
        """
        return dict(momentums=self.momentums, velocities=self.velocities)

    def load_state_dict(self, state_dict):
        """Restore buffers from a dict produced by ``state_dict``.

        Only names already registered (via ``initialize``) are restored.
        Names unknown to this memory are ignored, and a name missing from
        either saved dict keeps its current buffer. Each loaded tensor is
        cast to the device and dtype of the buffer it replaces, so a
        checkpoint loaded on CPU can be restored onto the training device.
        ``Tensor.to`` returns the tensor itself when both already match.
        """
        momentums = state_dict['momentums']
        velocities = state_dict['velocities']
        for name in list(self.momentums):
            if name in momentums:
                current = self.momentums[name]
                self.momentums[name] = momentums[name].to(
                    device=current.device, dtype=current.dtype)
            # Checked independently so a partially saved state cannot raise
            # KeyError half-way through the load.
            if name in velocities and name in self.velocities:
                current = self.velocities[name]
                self.velocities[name] = velocities[name].to(
                    device=current.device, dtype=current.dtype)