erl_net.py
from typing import List

import torch
import torch.nn as nn

TEN = torch.Tensor


class QNetBase(nn.Module):  # nn.Module is a standard PyTorch network
    def __init__(self, state_dim: int, action_dim: int):
        super().__init__()
        self.explore_rate = 0.125
        self.state_dim = state_dim
        self.action_dim = action_dim
        self.net = None  # build_mlp(dims=[state_dim + action_dim, *dims, 1])

        self.state_avg = nn.Parameter(torch.zeros((state_dim,)), requires_grad=False)
        self.state_std = nn.Parameter(torch.ones((state_dim,)), requires_grad=False)
        self.value_avg = nn.Parameter(torch.zeros((1,)), requires_grad=False)
        self.value_std = nn.Parameter(torch.ones((1,)), requires_grad=False)

    def state_norm(self, state: TEN) -> TEN:
        return (state - self.state_avg) / self.state_std

    def value_re_norm(self, value: TEN) -> TEN:
        return value * self.value_std + self.value_avg
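

# Illustrative sketch, not part of the original file: one way the normalization buffers
# above could be filled from collected data before training. The names `sampled_states`
# and `sampled_returns` are hypothetical placeholders.
def _example_fill_norm_buffers(net: QNetBase, sampled_states: TEN, sampled_returns: TEN):
    net.state_avg[:] = sampled_states.mean(dim=0)  # per-dimension mean of observed states
    net.state_std[:] = sampled_states.std(dim=0).clamp_min(1e-4)  # avoid division by zero
    net.value_avg[:] = sampled_returns.mean()  # scalar statistics for value re-normalization
    net.value_std[:] = sampled_returns.std().clamp_min(1e-4)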


class QNetTwin(QNetBase):  # Double DQN
    def __init__(self, dims: List[int], state_dim: int, action_dim: int):
        super().__init__(state_dim=state_dim, action_dim=action_dim)
        self.net_state = build_mlp(dims=[state_dim, *dims])
        self.net_val1 = build_mlp(dims=[dims[-1], action_dim])  # Q value 1
        self.net_val2 = build_mlp(dims=[dims[-1], action_dim])  # Q value 2
        self.soft_max = nn.Softmax(dim=1)

        layer_init_with_orthogonal(self.net_val1[-1], std=0.1)
        layer_init_with_orthogonal(self.net_val2[-1], std=0.1)

    def forward(self, state):
        state = self.state_norm(state)
        s_enc = self.net_state(state)  # encoded state
        q_val = self.net_val1(s_enc)  # q value
        return q_val  # one group of Q values

    def get_q1_q2(self, state):
        state = self.state_norm(state)
        s_enc = self.net_state(state)  # encoded state
        q_val1 = self.net_val1(s_enc)  # q value 1
        q_val1 = self.value_re_norm(q_val1)
        q_val2 = self.net_val2(s_enc)  # q value 2
        q_val2 = self.value_re_norm(q_val2)
        return q_val1, q_val2  # two groups of Q values

    def get_action(self, state):
        state = self.state_norm(state)
        s_enc = self.net_state(state)  # encoded state
        q_val = self.net_val1(s_enc)  # q value
        if self.explore_rate < torch.rand(1):
            action = q_val.argmax(dim=1, keepdim=True)  # exploit: greedy action
        else:
            # a_prob = self.soft_max(q_val)
            # action = torch.multinomial(a_prob, num_samples=1)
            action = torch.randint(self.action_dim, size=(state.shape[0], 1))  # explore: random action
        return action
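

# Illustrative usage sketch, not part of the original file; the network sizes and batch
# size below are hypothetical. It shows the twin Q heads and the epsilon-greedy action
# selection of QNetTwin on a random batch of states.
def _example_qnet_twin_usage():
    net = QNetTwin(dims=[64, 64], state_dim=4, action_dim=2)
    state = torch.rand(8, 4)  # a batch of 8 states
    q1, q2 = net.get_q1_q2(state)  # two Q estimates, each of shape (8, 2)
    action = net.get_action(state)  # greedy or random actions, shape (8, 1)
    return q1, q2, action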


class QNetTwinDuel(QNetBase):  # D3QN: Dueling Double DQN
    def __init__(self, dims: List[int], state_dim: int, action_dim: int):
        super().__init__(state_dim=state_dim, action_dim=action_dim)
        self.net_state = build_mlp(dims=[state_dim, *dims])
        self.net_adv1 = build_mlp(dims=[dims[-1], 1])  # advantage value 1
        self.net_val1 = build_mlp(dims=[dims[-1], action_dim])  # Q value 1
        self.net_adv2 = build_mlp(dims=[dims[-1], 1])  # advantage value 2
        self.net_val2 = build_mlp(dims=[dims[-1], action_dim])  # Q value 2
        self.soft_max = nn.Softmax(dim=1)

        layer_init_with_orthogonal(self.net_adv1[-1], std=0.1)
        layer_init_with_orthogonal(self.net_val1[-1], std=0.1)
        layer_init_with_orthogonal(self.net_adv2[-1], std=0.1)
        layer_init_with_orthogonal(self.net_val2[-1], std=0.1)

    def forward(self, state):
        state = self.state_norm(state)
        s_enc = self.net_state(state)  # encoded state
        q_val = self.net_val1(s_enc)  # q value
        q_adv = self.net_adv1(s_enc)  # advantage value
        value = q_val - q_val.mean(dim=1, keepdim=True) + q_adv  # one dueling Q value
        value = self.value_re_norm(value)
        return value

    def get_q1_q2(self, state):
        state = self.state_norm(state)
        s_enc = self.net_state(state)  # encoded state

        q_val1 = self.net_val1(s_enc)  # q value 1
        q_adv1 = self.net_adv1(s_enc)  # advantage value 1
        q_duel1 = q_val1 - q_val1.mean(dim=1, keepdim=True) + q_adv1
        q_duel1 = self.value_re_norm(q_duel1)

        q_val2 = self.net_val2(s_enc)  # q value 2
        q_adv2 = self.net_adv2(s_enc)  # advantage value 2
        q_duel2 = q_val2 - q_val2.mean(dim=1, keepdim=True) + q_adv2
        q_duel2 = self.value_re_norm(q_duel2)
        return q_duel1, q_duel2  # two dueling Q values

    def get_action(self, state):
        state = self.state_norm(state)
        s_enc = self.net_state(state)  # encoded state
        q_val = self.net_val1(s_enc)  # q value
        if self.explore_rate < torch.rand(1):
            action = q_val.argmax(dim=1, keepdim=True)  # exploit: greedy action
        else:
            # a_prob = self.soft_max(q_val)
            # action = torch.multinomial(a_prob, num_samples=1)
            action = torch.randint(self.action_dim, size=(state.shape[0], 1))  # explore: random action
        return action
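

# Illustrative sketch, not part of the original file; sizes are hypothetical. In the
# dueling heads above, the scalar stream (`net_adv*`) shifts every entry of the
# per-action stream (`net_val*`) by the same state-dependent amount, so the greedy
# action is decided by the per-action stream alone.
def _example_qnet_twin_duel_usage():
    net = QNetTwinDuel(dims=[64, 64], state_dim=4, action_dim=2)
    state = torch.rand(8, 4)
    q_duel1, q_duel2 = net.get_q1_q2(state)  # two dueling Q estimates, each of shape (8, 2)
    greedy = net(state).argmax(dim=1, keepdim=True)  # same argmax as the per-action stream
    return q_duel1, q_duel2, greedy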


def build_mlp(dims: List[int], activation: type = None, if_raw_out: bool = True) -> nn.Sequential:
    """
    build MLP (MultiLayer Perceptron)

    dims: layer dimensions; `dims[0]` is the input dimension and `dims[-1]` is the output dimension
    activation: the activation function class (defaults to nn.ReLU)
    if_raw_out: if True, remove the activation function after the output layer to keep the raw output
    """
    if activation is None:
        activation = nn.ReLU
    net_list = []
    for i in range(len(dims) - 1):
        net_list.extend([nn.Linear(dims[i], dims[i + 1]), activation()])
    if if_raw_out:
        del net_list[-1]  # delete the activation function of the output layer to keep raw output
    return nn.Sequential(*net_list)
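

# Illustrative sketch, not part of the original file: with the defaults above,
# build_mlp(dims=[4, 64, 2]) produces Linear(4, 64) -> ReLU -> Linear(64, 2); the final
# activation is removed because if_raw_out=True, so the Q-value outputs stay unbounded.
def _example_build_mlp():
    mlp = build_mlp(dims=[4, 64, 2])
    assert isinstance(mlp[-1], nn.Linear)  # output layer is raw (no trailing activation)
    return mlp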


def layer_init_with_orthogonal(layer, std=1.0, bias_const=1e-6):
    """Initialize a linear layer: orthogonal weights with gain `std`, constant bias."""
    torch.nn.init.orthogonal_(layer.weight, std)
    torch.nn.init.constant_(layer.bias, bias_const)