-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathac_agent.py
320 lines (263 loc) · 14.5 KB
/
ac_agent.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
import torch
import os
import numpy as np
from spirl.rl.components.agent import BaseAgent
from spirl.utils.general_utils import ParamDict, map_dict, AttrDict
from spirl.utils.pytorch_utils import ten2ar, avg_grad_norm, TensorModule, check_shape, map2torch, map2np
from spirl.rl.utils.mpi import sync_networks
class ACAgent(BaseAgent):
    """Implements actor-critic agent. (does not implement update function, this should be handled by RL algo agent)

    The agent owns a policy built from the ``policy`` / ``policy_params``
    hyperparameters and, if that policy reports trainable parameters, an
    optimizer for it (``self.policy_opt``).
    """

    def __init__(self, config):
        """Builds the policy and (if it is trainable) its optimizer from `config`."""
        BaseAgent.__init__(self, config)
        self._hp = self._default_hparams().overwrite(config)
        self.policy = self._hp.policy(self._hp.policy_params)
        # policy_opt only exists for trainable policies; every user of it must check the same flag
        if self.policy.has_trainable_params:
            self.policy_opt = self._get_optimizer(self._hp.optimizer, self.policy, self._hp.policy_lr)

    def _default_hparams(self):
        """Returns the parent's default hyperparameters extended with the policy-related ones."""
        default_dict = ParamDict({
            'policy': None,     # policy class
            'policy_params': None,  # parameters for the policy class
            'policy_lr': 3e-4,  # learning rate for policy update
        })
        return super()._default_hparams().overwrite(default_dict)

    def _act(self, obs):
        """Runs the policy on a (normalized) observation and returns its output as numpy.

        A 1-dim observation is batched before the policy call; the batch dimension is
        removed again afterwards and any 'dist' entry is dropped from the output.
        Observations with more dimensions are passed to the policy unchanged.
        """
        # TODO implement non-sampling validation mode
        obs = map2torch(self._obs_normalizer(obs), self._hp.device)
        if len(obs.shape) == 1:     # we need batched inputs for policy
            policy_output = self._remove_batch(self.policy(obs[None]))
            if 'dist' in policy_output:
                del policy_output['dist']
            return map2np(policy_output)
        return map2np(self.policy(obs))

    def _act_rand(self, obs):
        """Queries the policy's `sample_rand` and returns the result as numpy, without any 'dist' entry."""
        policy_output = self.policy.sample_rand(map2torch(obs, self.policy.device))
        if 'dist' in policy_output:
            del policy_output['dist']
        return map2np(policy_output)

    def state_dict(self, *args, **kwargs):
        """Returns the parent's state dict plus the policy optimizer state (trainable policies only)."""
        d = super().state_dict()
        if self.policy.has_trainable_params:
            d['policy_opt'] = self.policy_opt.state_dict()
        return d

    def load_state_dict(self, state_dict, *args, **kwargs):
        """Restores the policy optimizer state (if any) and forwards the remaining entries to the parent.

        Note that the 'policy_opt' entry is popped from the passed `state_dict` in place.
        """
        if self.policy.has_trainable_params:
            self.policy_opt.load_state_dict(state_dict.pop('policy_opt'))
        else:
            # mirrors state_dict(): no optimizer exists for non-trainable policies;
            # drop a stray optimizer entry so it is not forwarded to the parent
            state_dict.pop('policy_opt', None)
        super().load_state_dict(state_dict, *args, **kwargs)

    def visualize(self, logger, rollout_storage, step):
        """Runs the parent's visualization followed by the policy's."""
        super().visualize(logger, rollout_storage, step)
        self.policy.visualize(logger, rollout_storage, step)

    def reset(self):
        """Resets the policy."""
        self.policy.reset()

    def sync_networks(self):
        """Synchronizes the policy via `sync_networks` if it has trainable parameters."""
        if self.policy.has_trainable_params:
            sync_networks(self.policy)

    def _preprocess_experience(self, experience_batch):
        """Optionally pre-process experience before it is used for policy training."""
        return experience_batch
class SACAgent(ACAgent):
    """Implements SAC algorithm.

    Holds two critics with matching target networks, a learnable log-alpha
    (entropy multiplier) and a replay buffer, all built from hyperparameters.
    """

    def __init__(self, config):
        """Builds critics, target critics, optimizers, entropy multiplier and replay buffer."""
        ACAgent.__init__(self, config)
        self._hp = self._default_hparams().overwrite(config)

        # build critics and target networks, copy weights of critics to target networks
        self.critics = torch.nn.ModuleList([self._hp.critic(self._hp.critic_params) for _ in range(2)])
        self.critic_targets = torch.nn.ModuleList([self._hp.critic(self._hp.critic_params) for _ in range(2)])
        # argument order is (target, source), same pairing as the soft update in update()
        for critic_target, critic in zip(self.critic_targets, self.critics):
            self._copy_to_target_network(critic_target, critic)

        # build optimizers for critics
        self.critic_opts = [self._get_optimizer(self._hp.optimizer, critic, self._hp.critic_lr)
                            for critic in self.critics]

        # define entropy multiplier alpha
        self._log_alpha = TensorModule(torch.zeros(1, requires_grad=True, device=self._hp.device))
        self.alpha_opt = self._get_optimizer(self._hp.optimizer, self._log_alpha, self._hp.alpha_lr)
        self._target_entropy = self._hp.target_entropy if self._hp.target_entropy is not None \
            else -1 * self._hp.policy_params.action_dim

        # build replay buffer
        self.replay_buffer = self._hp.replay(self._hp.replay_params)

        self._update_steps = 0      # counts the number of alpha updates for optional variable schedules

    def _default_hparams(self):
        """Returns the parent's default hyperparameters extended with the SAC-specific ones."""
        default_dict = ParamDict({
            'critic': None,           # critic class
            'critic_params': None,    # parameters for the critic class
            'replay': None,           # replay buffer class
            'replay_params': None,    # parameters for replay buffer
            'critic_lr': 3e-4,        # learning rate for critic update
            'alpha_lr': 3e-4,         # learning rate for alpha coefficient update
            'fixed_alpha': None,      # optionally fixed value for alpha
            'reward_scale': 1.0,      # SAC reward scale
            'clip_q_target': False,   # if True, clips Q target
            'target_entropy': None,   # target value for automatic entropy tuning, if None uses -action_dim
        })
        return super()._default_hparams().overwrite(default_dict)

    def update(self, experience_batch):
        """Updates actor and critics.

        Pushes `experience_batch` into the replay buffer, then runs
        `update_iterations` rounds of: sample batch, update alpha, update policy,
        update critics, soft-update target critics.

        Returns the logging info of the last iteration (losses, periodic gradient
        norms and misc statistics) converted via `ten2ar`; an empty AttrDict if no
        iteration was run.
        """
        # push experience batch into replay buffer
        self.add_experience(experience_batch)

        info = AttrDict()   # stays empty if update_iterations is 0
        for _ in range(self._hp.update_iterations):
            # sample batch and normalize
            experience_batch = self._sample_experience()
            experience_batch = self._normalize_batch(experience_batch)
            experience_batch = map2torch(experience_batch, self._hp.device)
            experience_batch = self._preprocess_experience(experience_batch)

            policy_output = self.policy(experience_batch.observation)

            # update alpha
            alpha_loss = self._update_alpha(experience_batch, policy_output)

            # compute policy loss
            policy_loss = self._compute_policy_loss(experience_batch, policy_output)

            # compute target Q value
            with torch.no_grad():
                policy_output_next = self.policy(experience_batch.observation_next)
                value_next = self._compute_next_value(experience_batch, policy_output_next)
                q_target = experience_batch.reward * self._hp.reward_scale + \
                    (1 - experience_batch.done) * self._hp.discount_factor * value_next
                if self._hp.clip_q_target:
                    q_target = self._clip_q_target(q_target)
                q_target = q_target.detach()
                check_shape(q_target, [self._hp.batch_size])

            # compute critic loss
            critic_losses, qs = self._compute_critic_loss(experience_batch, q_target)

            # update policy network on policy loss
            self._perform_update(policy_loss, self.policy_opt, self.policy)

            # update critic networks
            for critic_loss, critic_opt, critic in zip(critic_losses, self.critic_opts, self.critics):
                self._perform_update(critic_loss, critic_opt, critic)

            # update target networks
            for critic_target, critic in zip(self.critic_targets, self.critics):
                self._soft_update_target_network(critic_target, critic)

            # logging
            info = AttrDict(    # losses
                policy_loss=policy_loss,
                alpha_loss=alpha_loss,
                critic_loss_1=critic_losses[0],
                critic_loss_2=critic_losses[1],
            )
            if self._update_steps % 100 == 0:
                info.update(AttrDict(       # gradient norms
                    policy_grad_norm=avg_grad_norm(self.policy),
                    critic_1_grad_norm=avg_grad_norm(self.critics[0]),
                    critic_2_grad_norm=avg_grad_norm(self.critics[1]),
                ))
            info.update(AttrDict(       # misc
                alpha=self.alpha,
                pi_log_prob=policy_output.log_prob.mean(),
                policy_entropy=policy_output.entropy.mean(),
                q_target=q_target.mean(),
                q_1=qs[0].mean(),
                q_2=qs[1].mean(),
            ))
            info.update(self._aux_info(experience_batch, policy_output))
            info = map_dict(ten2ar, info)

            self._update_steps += 1

        return info

    def add_experience(self, experience_batch):
        """Adds experience to replay buffer and updates the observation normalizer with it."""
        if not experience_batch:
            return  # pass if experience_batch is empty
        self.replay_buffer.append(experience_batch)
        self._obs_normalizer.update(experience_batch.observation)

    def _sample_experience(self):
        """Samples `batch_size` transitions from the replay buffer."""
        return self.replay_buffer.sample(n_samples=self._hp.batch_size)

    def _normalize_batch(self, experience_batch):
        """Optionally apply observation normalization."""
        experience_batch.observation = self._obs_normalizer(experience_batch.observation)
        experience_batch.observation_next = self._obs_normalizer(experience_batch.observation_next)
        return experience_batch

    def _run_policy(self, obs):
        """Allows child classes to post-process policy outputs."""
        return self.policy(obs)

    def _update_alpha(self, experience_batch, policy_output):
        """Performs one optimizer step on log-alpha and returns the loss; returns 0. if alpha is fixed."""
        if self._hp.fixed_alpha is not None:
            return 0.
        alpha_loss = self._compute_alpha_loss(policy_output)
        self._perform_update(alpha_loss, self.alpha_opt, self._log_alpha)
        return alpha_loss

    def _compute_alpha_loss(self, policy_output):
        """Computes the entropy-multiplier loss -(alpha * (log_prob + target_entropy)).mean().

        The bracketed term is detached so that gradients only flow into alpha.
        """
        # the temperature objective is defined on the policy's log-probability, not on the action
        return -1 * (self.alpha * (self._target_entropy + policy_output.log_prob).detach()).mean()

    def _compute_policy_loss(self, experience_batch, policy_output):
        """Computes the policy loss as mean of (alpha * log_prob - min over critics of Q(obs, policy action))."""
        q_est = torch.min(*[critic(experience_batch.observation, self._prep_action(policy_output.action)).q
                            for critic in self.critics])
        policy_loss = -1 * q_est + self.alpha * policy_output.log_prob[:, None]
        check_shape(policy_loss, [self._hp.batch_size, 1])
        return policy_loss.mean()

    def _compute_next_value(self, experience_batch, policy_output):
        """Computes the next-state value: min over target critics of Q(obs_next, action) - alpha * log_prob."""
        q_next = torch.min(*[critic_target(experience_batch.observation_next,
                                           self._prep_action(policy_output.action)).q
                             for critic_target in self.critic_targets])
        next_val = (q_next - self.alpha * policy_output.log_prob[:, None])
        check_shape(next_val, [self._hp.batch_size, 1])
        return next_val.squeeze(-1)

    def _compute_critic_loss(self, experience_batch, q_target):
        """Returns the per-critic losses 0.5 * mean((q - q_target)^2) and the list of Q estimates."""
        qs = self._compute_q_estimates(experience_batch)
        check_shape(qs[0], [self._hp.batch_size])
        critic_losses = [0.5 * (q - q_target).pow(2).mean() for q in qs]
        return critic_losses, qs

    def _compute_q_estimates(self, experience_batch):
        """Evaluates every critic on the (observation, action) pairs stored in the batch."""
        return [critic(experience_batch.observation, self._prep_action(experience_batch.action.detach())).q.squeeze(-1)
                for critic in self.critics]     # no gradient propagation into policy here!

    def _prep_action(self, action):
        """Preprocessing of action in case of discrete action space."""
        if len(action.shape) == 1:
            action = action[:, None]  # unsqueeze for single-dim action spaces
        return action

    def _clip_q_target(self, q_target):
        """Clamps the Q target to [-1 / (1 - discount_factor), 1 / (1 - discount_factor)]."""
        clip = 1 / (1 - self._hp.discount_factor)
        return torch.clamp(q_target, -clip, clip)

    def _aux_info(self, experience_batch, policy_output):
        """Optionally add auxiliary info about policy outputs etc."""
        return AttrDict()

    def sync_networks(self):
        """Synchronizes policy (via parent), both critics and log-alpha."""
        super().sync_networks()
        for critic in self.critics:
            sync_networks(critic)
        sync_networks(self._log_alpha)

    def state_dict(self, *args, **kwargs):
        """Returns the parent's state dict plus the critic and alpha optimizer states."""
        d = super().state_dict()
        d['critic_opts'] = [o.state_dict() for o in self.critic_opts]
        d['alpha_opt'] = self.alpha_opt.state_dict()
        return d

    def load_state_dict(self, state_dict, *args, **kwargs):
        """Restores critic and alpha optimizer states, then forwards the remaining entries to the parent.

        Note that 'critic_opts' and 'alpha_opt' are popped from the passed `state_dict` in place.
        """
        for critic_opt, opt_state in zip(self.critic_opts, state_dict.pop('critic_opts')):
            critic_opt.load_state_dict(opt_state)
        self.alpha_opt.load_state_dict(state_dict.pop('alpha_opt'))
        super().load_state_dict(state_dict, *args, **kwargs)

    def save_state(self, save_dir):
        """Saves compressed replay buffer to disk."""
        self.replay_buffer.save(os.path.join(save_dir, 'replay'))

    def load_state(self, save_dir):
        """Loads replay buffer from disk."""
        self.replay_buffer.load(os.path.join(save_dir, 'replay'))

    @property
    def alpha(self):
        """Entropy multiplier: the `fixed_alpha` hyperparameter if set, otherwise exp(log_alpha)."""
        # _update_alpha skips the alpha update when fixed_alpha is set, so the fixed value must be used here
        if self._hp.fixed_alpha is not None:
            return self._hp.fixed_alpha
        return self._log_alpha().exp()

    @property
    def schedule_steps(self):
        """Number of update iterations performed so far."""
        return self._update_steps
class CodebookBasedSACAgent(SACAgent):
    """SAC agent whose policy selects an entry ("code") of a codebook instead of a raw action.

    The critics are called with the observation only and their `.q` output is
    indexed with the selected code index, i.e. they are treated as producing
    one Q-value per code.
    """

    def __init__(self, *args, **kwargs):
        """Builds the SAC agent and stores the `codebook` hyperparameter (a callable returning the codebook)."""
        super().__init__(*args, **kwargs)
        self.get_codebook = self._hp.codebook

    def act(self, obs):
        """Samples a code index from the policy's output distribution and looks up the code.

        A 1-dim observation is batched before the policy network is queried.
        Returns an AttrDict with `action` (selected codebook entry), `idx`
        (sampled code index) and `log_prob`.
        """
        obs = map2torch(self._obs_normalizer(obs), self._hp.device)
        if len(obs.shape) == 1:
            obs = obs[None]     # policy net needs batched inputs
        output = self.policy.net._compute_output_dist(obs)
        # positional argument of Categorical is `probs`
        action_dist = torch.distributions.Categorical(output)
        code_idx = action_dist.sample()
        policy_output = self.get_codebook()[code_idx]   # choose code
        # NOTE(review): `log_prob` carries the raw policy-net output (fed to Categorical as probs),
        # not the log-probability of the sampled code -- confirm this is what consumers expect.
        return AttrDict(action=policy_output, idx=code_idx, log_prob=output)   # warning: log_prob

    def _code_index(self, idx):
        """Converts code indices into an int64 index tensor on the configured device for use with `gather`."""
        return self._prep_action(idx).type(torch.int64).to(self._hp.device)

    def _compute_policy_loss(self, experience_batch, policy_output):
        """Computes loss for policy update."""
        code_index = self._code_index(policy_output.idx)
        q_est = torch.min(*[critic(experience_batch.observation).q.gather(1, code_index)
                            for critic in self.critics])
        policy_loss = -1 * q_est + self.alpha * policy_output.log_prob.to(self._hp.device)
        check_shape(policy_loss, [self._hp.batch_size, 1])
        return policy_loss.mean()

    def _compute_next_value(self, experience_batch, policy_output):
        """Computes value of next state for target value computation."""
        code_index = self._code_index(policy_output.idx)
        q_next = torch.min(*[critic_target(experience_batch.observation_next).q.gather(1, code_index)
                             for critic_target in self.critic_targets])
        next_val = q_next - self.alpha * policy_output.log_prob.to(self._hp.device)
        check_shape(next_val, [self._hp.batch_size, 1])
        return next_val.squeeze(-1)

    def _compute_q_estimates(self, experience_batch):
        """Returns, per critic, the Q-value of the code index stored in the batch."""
        code_index = self._code_index(experience_batch.idx).detach()
        # squeeze only the gathered column so a batch of size 1 keeps its batch dimension
        return [critic(experience_batch.observation).q.squeeze(-1).gather(1, code_index).squeeze(-1)
                for critic in self.critics]     # no gradient propagation into policy here!