storage.py (forked from ikostrikov/pytorch-a2c-ppo-acktr-gail)
import torch
from torch.utils.data.sampler import BatchSampler, SubsetRandomSampler
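

# RolloutStorage keeps one rollout of experience (observations, recurrent
# states, actions, action log-probabilities, value predictions, rewards and
# episode masks) for num_processes parallel environments over num_steps steps,
# computes returns, and serves mini-batches to the A2C/PPO/ACKTR updates.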
class RolloutStorage(object):
    def __init__(self, num_steps, num_processes, obs_shape, action_space, state_size):
        # Per-step tensors carry one extra slot (num_steps + 1) so the state at the
        # end of the rollout is kept for bootstrapping and for starting the next rollout.
        self.observations = torch.zeros(num_steps + 1, num_processes, *obs_shape)
        self.states = torch.zeros(num_steps + 1, num_processes, state_size)
        self.rewards = torch.zeros(num_steps, num_processes, 1)
        self.value_preds = torch.zeros(num_steps + 1, num_processes, 1)
        self.returns = torch.zeros(num_steps + 1, num_processes, 1)
        self.action_log_probs = torch.zeros(num_steps, num_processes, 1)
        # Discrete action spaces store a single integer index per action;
        # continuous spaces store one column per action dimension.
        if action_space.__class__.__name__ == 'Discrete':
            action_shape = 1
        else:
            action_shape = action_space.shape[0]
        self.actions = torch.zeros(num_steps, num_processes, action_shape)
        if action_space.__class__.__name__ == 'Discrete':
            self.actions = self.actions.long()
        self.masks = torch.ones(num_steps + 1, num_processes, 1)

    def cuda(self):
        # Move every stored tensor to the GPU.
        self.observations = self.observations.cuda()
        self.states = self.states.cuda()
        self.rewards = self.rewards.cuda()
        self.value_preds = self.value_preds.cuda()
        self.returns = self.returns.cuda()
        self.action_log_probs = self.action_log_probs.cuda()
        self.actions = self.actions.cuda()
        self.masks = self.masks.cuda()

    def insert(self, step, current_obs, state, action, action_log_prob, value_pred, reward, mask):
        # Write one transition into the buffers; observations, states and masks
        # are offset by one because index 0 holds the pre-rollout values.
        self.observations[step + 1].copy_(current_obs)
        self.states[step + 1].copy_(state)
        self.actions[step].copy_(action)
        self.action_log_probs[step].copy_(action_log_prob)
        self.value_preds[step].copy_(value_pred)
        self.rewards[step].copy_(reward)
        self.masks[step + 1].copy_(mask)

    def after_update(self):
        # Carry the last observation, state and mask over as the starting
        # point of the next rollout.
        self.observations[0].copy_(self.observations[-1])
        self.states[0].copy_(self.states[-1])
        self.masks[0].copy_(self.masks[-1])

    def compute_returns(self, next_value, use_gae, gamma, tau):
        if use_gae:
            # Generalized Advantage Estimation: returns are the GAE advantage
            # plus the value baseline.
            self.value_preds[-1] = next_value
            gae = 0
            for step in reversed(range(self.rewards.size(0))):
                delta = self.rewards[step] + gamma * self.value_preds[step + 1] * self.masks[step + 1] - self.value_preds[step]
                gae = delta + gamma * tau * self.masks[step + 1] * gae
                self.returns[step] = gae + self.value_preds[step]
        else:
            # Plain discounted returns, bootstrapped from next_value and reset
            # at episode boundaries via the masks.
            self.returns[-1] = next_value
            for step in reversed(range(self.rewards.size(0))):
                self.returns[step] = self.returns[step + 1] * \
                    gamma * self.masks[step + 1] + self.rewards[step]

    def feed_forward_generator(self, advantages, num_mini_batch):
        # Flatten the (num_steps, num_processes) dimensions and yield shuffled
        # mini-batches for feed-forward (non-recurrent) policies.
        num_steps, num_processes = self.rewards.size()[0:2]
        batch_size = num_processes * num_steps
        mini_batch_size = batch_size // num_mini_batch
        sampler = BatchSampler(SubsetRandomSampler(range(batch_size)), mini_batch_size, drop_last=False)
        for indices in sampler:
            indices = torch.LongTensor(indices)
            if advantages.is_cuda:
                indices = indices.cuda()
            observations_batch = self.observations[:-1].view(-1,
                                                             *self.observations.size()[2:])[indices]
            states_batch = self.states[:-1].view(-1, self.states.size(-1))[indices]
            actions_batch = self.actions.view(-1, self.actions.size(-1))[indices]
            return_batch = self.returns[:-1].view(-1, 1)[indices]
            masks_batch = self.masks[:-1].view(-1, 1)[indices]
            old_action_log_probs_batch = self.action_log_probs.view(-1, 1)[indices]
            adv_targ = advantages.view(-1, 1)[indices]

            yield observations_batch, states_batch, actions_batch, \
                return_batch, masks_batch, old_action_log_probs_batch, adv_targ

    def recurrent_generator(self, advantages, num_mini_batch):
        # Yield mini-batches of whole-environment trajectories so that the
        # recurrent state can be propagated through time during the update.
        num_processes = self.rewards.size(1)
        num_envs_per_batch = num_processes // num_mini_batch
        perm = torch.randperm(num_processes)
        for start_ind in range(0, num_processes, num_envs_per_batch):
            observations_batch = []
            states_batch = []
            actions_batch = []
            return_batch = []
            masks_batch = []
            old_action_log_probs_batch = []
            adv_targ = []
            for offset in range(num_envs_per_batch):
                ind = perm[start_ind + offset]
                observations_batch.append(self.observations[:-1, ind])
                # Only the initial recurrent state is needed per trajectory.
                states_batch.append(self.states[0:1, ind])
                actions_batch.append(self.actions[:, ind])
                return_batch.append(self.returns[:-1, ind])
                masks_batch.append(self.masks[:-1, ind])
                old_action_log_probs_batch.append(self.action_log_probs[:, ind])
                adv_targ.append(advantages[:, ind])
            observations_batch = torch.cat(observations_batch, 0)
            states_batch = torch.cat(states_batch, 0)
            actions_batch = torch.cat(actions_batch, 0)
            return_batch = torch.cat(return_batch, 0)
            masks_batch = torch.cat(masks_batch, 0)
            old_action_log_probs_batch = torch.cat(old_action_log_probs_batch, 0)
            adv_targ = torch.cat(adv_targ, 0)

            yield observations_batch, states_batch, actions_batch, \
                return_batch, masks_batch, old_action_log_probs_batch, adv_targ
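

# ---------------------------------------------------------------------------
# Minimal usage sketch (not part of the original file): it shows how a training
# loop typically drives RolloutStorage: fill the buffers step by step, compute
# GAE returns, then sample mini-batches. The Discrete stand-in class, the random
# "policy" outputs and all hyperparameter values below are illustrative
# assumptions, not the repository's actual training code.
# ---------------------------------------------------------------------------
if __name__ == '__main__':

    class Discrete(object):
        """Stand-in for gym.spaces.Discrete; only the class name is inspected."""
        def __init__(self, n):
            self.n = n

    num_steps, num_processes, obs_shape, state_size = 5, 2, (4,), 1
    rollouts = RolloutStorage(num_steps, num_processes, obs_shape, Discrete(3), state_size)

    # Index 0 holds the observation/state the rollout starts from.
    rollouts.observations[0].copy_(torch.rand(num_processes, *obs_shape))

    for step in range(num_steps):
        # Placeholder policy outputs; a real agent would produce these.
        action = torch.randint(0, 3, (num_processes, 1))
        action_log_prob = torch.rand(num_processes, 1).log()
        value_pred = torch.rand(num_processes, 1)
        next_obs = torch.rand(num_processes, *obs_shape)
        state = torch.zeros(num_processes, state_size)
        reward = torch.rand(num_processes, 1)
        mask = torch.ones(num_processes, 1)  # 0 would mark an episode end
        rollouts.insert(step, next_obs, state, action, action_log_prob, value_pred, reward, mask)

    # Bootstrap from the value of the final observation and compute GAE returns.
    next_value = torch.rand(num_processes, 1)
    rollouts.compute_returns(next_value, use_gae=True, gamma=0.99, tau=0.95)

    advantages = rollouts.returns[:-1] - rollouts.value_preds[:-1]
    for batch in rollouts.feed_forward_generator(advantages, num_mini_batch=2):
        obs_b, states_b, actions_b, returns_b, masks_b, old_log_probs_b, adv_b = batch
        print(obs_b.shape, actions_b.shape, adv_b.shape)

    # Prepare the buffers for the next rollout.
    rollouts.after_update()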