-
Notifications
You must be signed in to change notification settings - Fork 4
/
Copy pathES.py
140 lines (123 loc) · 4.08 KB
/
ES.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
import torch
import numpy as np
class CEM:
    """
    Batched cross-entropy method (CEM) optimiser, adapted to PyTorch.

    One independent search distribution is kept per batch row: a mean
    ``mu`` and a per-dimension variance ``cov``, both of shape
    ``(batch_size, num_params)``. ``ask`` samples candidates around the
    mean, ``tell`` refits the distribution on the best-scoring candidates.
    """

    def __init__(self, num_params,
                 mu_init=None,
                 batch_size=256,
                 sigma_init=1e-3,
                 clip=0.5,
                 pop_size=256,
                 damp=1e-3,
                 damp_limit=1e-5,
                 parents=None,
                 elitism=True,
                 device=torch.device('cuda')
                 ):
        """
        Args:
            num_params: number of dimensions of each candidate.
            mu_init: initial mean; cloned if given, otherwise zeros.
                Assumed to have shape (batch_size, num_params) -- TODO confirm
                against callers.
            batch_size: number of independent distributions.
            sigma_init: initial value of every entry of the variance ``cov``.
            clip: bound on the absolute per-dimension perturbation added to
                the mean when sampling.
            pop_size: default number of candidates sampled by ``ask``; also
                used to derive the default ``parents``.
            damp: additive term on the variance, decayed towards
                ``damp_limit`` on every ``tell``.
            damp_limit: value ``damp`` decays towards.
            parents: number of top candidates used to refit the
                distribution; ``None`` or a non-positive value selects
                ``pop_size // 2`` (at least 1).
            elitism: if true, the last candidate returned by ``ask`` is
                replaced by the stored elite.
            device: device on which the internal tensors are allocated.
        """
        # misc
        self.num_params = num_params
        self.batch_size = batch_size
        self.device = device

        # distribution parameters
        if mu_init is None:
            self.mu = torch.zeros([self.batch_size, self.num_params], device=device)
        else:
            self.mu = mu_init.clone()
        self.sigma = sigma_init
        self.damp = damp
        self.damp_limit = damp_limit
        self.tau = 0.95  # decay rate of damp towards damp_limit
        self.cov = self.sigma * torch.ones([self.batch_size, self.num_params], device=device)
        self.clip = clip

        # elite stuff
        self.elitism = elitism
        # Until the first tell() the elite is a random point, not the mean.
        self.elite = torch.sqrt(torch.tensor(self.sigma, device=device)) * torch.rand(
            self.batch_size, self.num_params, device=device)
        self.elite_score = None

        # sampling stuff
        self.pop_size = pop_size
        if parents is None or parents <= 0:
            # At least one parent, otherwise tell() divides by zero.
            self.parents = max(1, pop_size // 2)
        else:
            self.parents = parents
        # Log-rank weights: the best parent gets the largest weight.
        self.weights = torch.tensor(
            [float(np.log((self.parents + 1) / i)) for i in range(1, self.parents + 1)],
            dtype=torch.float32, device=device)
        self.weights /= self.weights.sum()

    def ask(self, pop_size=None):
        """
        Samples candidate parameters around the current mean.

        Args:
            pop_size: number of candidates per batch row; defaults to the
                ``pop_size`` given at construction.

        Returns:
            Tensor of shape (batch_size, pop_size, num_params).
        """
        if pop_size is None:
            pop_size = self.pop_size
        epsilon = torch.randn(self.batch_size, pop_size, self.num_params, device=self.device)
        # The clamp bounds the perturbation, not the resulting candidate.
        noise = (epsilon * torch.sqrt(self.cov).unsqueeze(1)).clamp(-self.clip, self.clip)
        inds = self.mu.unsqueeze(1) + noise
        if self.elitism:
            inds[:, -1] = self.elite
        return inds

    def tell(self, solutions, scores):
        """
        Updates the distribution from scored candidates (higher is better).

        Args:
            solutions: tensor of shape (batch, pop, num_params).
            scores: one score per candidate; any shape that reshapes to
                (batch, pop). The input tensor is not modified.

        Returns:
            The best solution of every batch row, shape (batch, num_params).
            The returned tensor is also stored as ``self.elite``.
        """
        # Reshape against the solutions' batch axis instead of squeezing, so
        # that a population or batch of size one keeps its own axis.
        neg_scores = -scores.reshape(solutions.shape[0], -1)
        # Ascending sort of the negated scores puts the best candidate first.
        sorted_neg, idx_sorted = torch.sort(neg_scores, dim=1)

        old_mu = self.mu.clone()
        self.damp = self.damp * self.tau + (1 - self.tau) * self.damp_limit

        idx_sorted = idx_sorted[:, :self.parents]
        gather_idx = idx_sorted.unsqueeze(2).expand(-1, -1, solutions.shape[-1])
        top_solutions = torch.gather(solutions, 1, gather_idx)

        self.mu = self.weights @ top_solutions
        # Spread is measured around the previous mean, not the new one.
        z = top_solutions - old_mu.unsqueeze(1)
        self.cov = (1 / self.parents * self.weights) @ (z * z) \
            + self.damp * torch.ones([self.batch_size, self.num_params], device=self.device)

        self.elite = top_solutions[:, 0, :]
        # Score of the elite with its original sign restored.
        self.elite_score = -sorted_neg[:, 0]
        return self.elite

    def get_distrib_params(self):
        """
        Returns copies of the parameters of the distribution:
        the mean and the per-dimension variance
        """
        return self.mu.clone(), self.cov.clone()
class Searcher():
    """
    Refines a batch of actions by running a few CEM iterations against a
    critic, never returning an action the critic scores below the initial one.
    """

    def __init__(self,
                 action_dim,
                 max_action,
                 batch_size=256,
                 sigma_init=1e-3,
                 clip=0.5,
                 pop_size=25,
                 damp=0.1,
                 damp_limit=0.05,
                 parents=5,
                 device=torch.device('cuda')):
        """
        Stores the settings forwarded to ``CEM`` on every ``search`` call.

        Args:
            action_dim: number of action dimensions (CEM ``num_params``).
            max_action: symmetric bound used to clamp candidate actions.
            batch_size: default number of states searched at once.
            sigma_init, clip, pop_size, damp, damp_limit, parents: passed
                through unchanged to the ``CEM`` constructor.
            device: device passed to ``CEM``.
        """
        self.sigma_init = sigma_init
        self.clip = clip
        self.pop_size = pop_size
        self.damp = damp
        self.damp_limit = damp_limit
        self.parents = parents
        self.action_dim = action_dim
        self.batch_size = batch_size
        self.max_action = max_action
        self.device = device

    def search(self, state, action_init, critic, batch_size=None, n_iter=2, action_bound=True):
        """
        Searches for actions with a higher critic value than ``action_init``.

        Args:
            state: batch of states; assumed 2-D (batch, state_dim) -- TODO
                confirm against callers.
            action_init: initial actions, used as the CEM mean and as the
                fallback result for rows where the search did not improve.
            critic: callable ``critic(state, action)`` returning one value
                per row.
            batch_size: number of rows; defaults to ``self.batch_size``.
            n_iter: number of ask/tell iterations.
            action_bound: if true, candidates are clamped to
                ``[-max_action, max_action]`` before being evaluated.

        Returns:
            Tensor of actions with one row per state. With ``n_iter <= 0`` a
            copy of ``action_init`` is returned.
        """
        if batch_size is None:
            batch_size = self.batch_size
        if n_iter <= 0:
            # No search step was requested, so the initial action stands.
            return action_init.clone()

        cem = CEM(self.action_dim, action_init, batch_size, self.sigma_init, self.clip,
                  self.pop_size, self.damp, self.damp_limit, self.parents, device=self.device)
        n_flat = self.pop_size * batch_size
        with torch.no_grad():
            # Each state is repeated once per candidate; built once because
            # it does not change between iterations.
            tiled_state = state.unsqueeze(1).repeat(1, self.pop_size, 1).view(n_flat, -1)
            for _ in range(n_iter):
                actions = cem.ask(self.pop_size)
                if action_bound:
                    actions = actions.clamp(-self.max_action, self.max_action)
                q_values = critic(tiled_state, actions.reshape(n_flat, -1))
                best_action = cem.tell(actions, q_values.reshape(batch_size, self.pop_size))

            # Keep the initial action wherever the search made things worse.
            best_q = critic(state, best_action)
            init_q = critic(state, action_init)
            worse = (best_q < init_q).reshape(-1)
            best_action[worse] = action_init[worse]
        return best_action