-
Notifications
You must be signed in to change notification settings - Fork 2
/
memory.py
98 lines (80 loc) · 3.53 KB
/
memory.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
from collections import namedtuple
import numpy as np
import random
# Snapshot of what the agent observes, stored as three named fields:
# ``obs``, ``description`` and ``inventory``.
State = namedtuple('State', 'obs description inventory')

# One replay-buffer entry. Field order: state, act, reward, next_state,
# next_acts, done, acts.
Transition = namedtuple('Transition', 'state act reward next_state next_acts done acts')
class ReplayMemory(object):
    """Uniform-sampling experience buffer with a fixed maximum size.

    Once ``capacity`` entries are stored, each new push overwrites the oldest
    slot (ring-buffer order).
    """

    def __init__(self, capacity):
        self.capacity = capacity
        self.memory = []
        self.position = 0  # index of the slot the next push writes to

    def push(self, transition):
        """Store ``transition``, overwriting the oldest entry once the buffer is full."""
        if len(self.memory) >= self.capacity:
            self.memory[self.position] = transition
        else:
            # Still filling: the write slot is always the end of the list.
            self.memory.append(transition)
        self.position = (self.position + 1) % self.capacity

    def sample(self, batch_size):
        """Return ``batch_size`` distinct stored entries chosen uniformly at random.

        Raises:
            ValueError: (from ``random.sample``) if fewer than ``batch_size``
                entries are stored.
        """
        population = self.memory
        return random.sample(population, batch_size)

    def __len__(self):
        """Number of entries currently stored (at most ``capacity``)."""
        return len(self.memory)
class PrioritizedReplayMemory(object):
    """Fixed-capacity ring buffer that samples entries in proportion to priority.

    Each stored transition carries a priority; ``sample`` draws indices with
    probability proportional to ``(priority + 1e-5) ** alpha``. Sampling is
    with replacement, so the same entry may appear more than once in a batch.
    """

    def __init__(self, capacity, alpha):
        self.capacity = capacity
        self.alpha = alpha  # exponent on priorities; alpha == 0 gives uniform sampling
        self.memory = []
        self.priorities = []  # priorities[i] belongs to memory[i]
        self.position = 0  # slot the next push writes to

    def push(self, transition, priority):
        """Store ``transition`` with ``priority``, overwriting the oldest slot once full."""
        if len(self.memory) < self.capacity:
            self.memory.append(None)
            self.priorities.append(None)
        self.memory[self.position] = transition
        self.priorities[self.position] = priority
        self.position = (self.position + 1) % self.capacity

    def sample(self, batch_size, return_idxs=False):
        """Draw ``batch_size`` transitions (with replacement), weighted by priority.

        Args:
            batch_size: number of draws.
            return_idxs: if True, also return the sampled buffer indices so the
                caller can later pass them to ``update``. Defaults to False,
                which keeps the original list-only return value.

        Returns:
            A list of transitions, or ``(transitions, idxs)`` when
            ``return_idxs`` is True (``idxs`` is a NumPy integer array).

        Raises:
            ValueError: from ``np.random.choice`` when the buffer is empty and
                ``batch_size`` > 0.
        """
        # NOTE(review): priorities are assumed non-negative; a value below -1e-5
        # combined with a fractional alpha yields NaN, which np.random.choice rejects.
        # The 1e-5 offset keeps zero-priority entries sampleable.
        weights = np.power(np.array(self.priorities) + 1e-5, self.alpha)
        p = weights / np.sum(weights)
        idxs = np.random.choice(np.arange(len(p)), size=batch_size, p=p)
        batch = [self.memory[i] for i in idxs]
        if return_idxs:
            return batch, idxs
        return batch

    def update(self, idxs, priorities):
        """Overwrite the priority at each index in ``idxs`` with the matching new value.

        ``idxs`` would typically be the indices returned by
        ``sample(..., return_idxs=True)``.
        """
        for i, priority in zip(idxs, priorities):
            self.priorities[i] = priority

    def __len__(self):
        """Number of transitions currently stored (at most ``capacity``)."""
        return len(self.memory)
class ABReplayMemory(object):
    """Replay buffer split into a priority ("alpha") ring and a regular ("beta") ring.

    ``capacity`` is divided between the two rings according to
    ``priority_fraction``; ``sample`` draws ``int(priority_fraction * batch_size)``
    entries from the alpha ring and the rest from the beta ring.
    """

    def __init__(self, capacity, priority_fraction):
        self.priority_fraction = priority_fraction
        # int() truncates, so a small fraction (or 0.0) can leave the alpha ring
        # with zero slots.
        self.alpha_capacity = int(capacity * priority_fraction)
        self.beta_capacity = capacity - self.alpha_capacity
        self.alpha_memory, self.beta_memory = [], []
        self.alpha_position, self.beta_position = 0, 0

    def clear_alpha(self):
        """Empty the alpha ring; the beta ring is left untouched."""
        self.alpha_memory = []
        self.alpha_position = 0

    @staticmethod
    def _ring_push(buffer, capacity, position, transition):
        """Write ``transition`` at ``position`` of a ring of size ``capacity``; return the next position."""
        if len(buffer) < capacity:
            buffer.append(None)
        buffer[position] = transition
        return (position + 1) % capacity

    def push(self, transition, is_prior=False):
        """Saves a transition.

        Args:
            transition: the entry to store.
            is_prior: if True, store in the alpha ring, else in the beta ring.
                When the alpha ring has no capacity (``priority_fraction`` of
                0.0, or one too small to yield a slot after truncation), the
                transition goes to the beta ring instead.
        """
        # Previously only priority_fraction == 0.0 was redirected; a small
        # non-zero fraction with alpha_capacity == 0 raised IndexError here.
        if self.priority_fraction == 0.0 or self.alpha_capacity <= 0:
            is_prior = False
        if is_prior:
            self.alpha_position = self._ring_push(
                self.alpha_memory, self.alpha_capacity, self.alpha_position, transition)
        else:
            # NOTE(review): with beta_capacity == 0 (priority_fraction == 1.0) a
            # non-priority push still raises IndexError; confirm intended handling.
            self.beta_position = self._ring_push(
                self.beta_memory, self.beta_capacity, self.beta_position, transition)

    def sample(self, batch_size):
        """Return a shuffled batch drawn without replacement from both rings.

        Up to ``int(priority_fraction * batch_size)`` entries come from the
        alpha ring and up to the remainder from the beta ring. When a ring holds
        fewer entries than its share, the batch is simply smaller; the
        shortfall is not filled from the other ring.
        """
        # With priority_fraction == 0.0 this is 0, so every draw comes from beta.
        n_alpha = int(self.priority_fraction * batch_size)
        from_alpha = min(n_alpha, len(self.alpha_memory))
        from_beta = min(batch_size - n_alpha, len(self.beta_memory))
        res = random.sample(self.alpha_memory, from_alpha) + random.sample(self.beta_memory, from_beta)
        random.shuffle(res)
        return res

    def __len__(self):
        """Total number of entries stored across both rings."""
        return len(self.alpha_memory) + len(self.beta_memory)