#!/usr/bin/env python
# -*- coding: utf-8 -*-
##
# @file LSTM_test.py
# @author Kyeong Soo (Joseph) Kim <[email protected]>
# @date 2020-05-15
#
# @brief Implementation of D-DASH [1], a framework that combines deep
#        learning and reinforcement learning techniques to optimize the
#        quality of experience (QoE) of DASH. In this variant the policy
#        network is implemented with an LSTM-based recurrent neural network
#        (RNN), together with a target network and a replay memory.
#        The implementation is based on the PyTorch reinforcement learning
#        (DQN) tutorial [2].
#
# @remark [1] M. Gadaleta, F. Chiariotti, M. Rossi, and A. Zanella, “D-DASH: A
#         deep Q-learning framework for DASH video streaming,” IEEE Trans. on
#         Cogn. Commun. Netw., vol. 3, no. 4, pp. 703–718, Dec. 2017.
#         [2] PyTorch reinforcement learning (DQN) tutorial. Available online:
#         https://pytorch.org/tutorials/intermediate/reinforcement_q_learning.html
# imports
import copy  # for the target network
import math
import random
from dataclasses import dataclass

import matplotlib
import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.nn as nn
# global variables
# - DQL
CH_HISTORY = 1 # number of channel capacity history samples
BATCH_SIZE = 50
EPS_START = 0.8
EPS_END = 0.0
EPS_DECAY = 200
LEARNING_RATE = 1e-4
# - FNN/LSTM policy network dimensions
N_I = 3 + CH_HISTORY # input dimension (= state dimension)
N_H1 = 128
N_H2 = 256
N_O = 4
# - D-DASH
BETA = 2
GAMMA = 50
DELTA = 0.001
B_MAX = 20
B_THR = 10
T = 2 # segment duration
TARGET_UPDATE = 20
LAMBDA = 0.9
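# BETA, GAMMA, and DELTA weight the quality-switch, rebuffering, and
# buffer-level penalties in the reward (4) of [1]; LAMBDA is the discount
# factor of the Q-learning target (13) in [1].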
# - LSTM parameters (hidden size, number of layers) are set inside the RNN class below
# set up matplotlib
is_ipython = 'inline' in matplotlib.get_backend()
plt.ion() # turn interactive mode on
# set device
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# define the LSTM-based policy network
class RNN(nn.Module):
    def __init__(self):
        super(RNN, self).__init__()
        self.rnn = nn.LSTM(    # LSTM layer
            input_size=N_I,    # input feature dimension per time step (= state dimension)
            hidden_size=16,    # number of hidden units
            num_layers=1,      # number of stacked LSTM layers
            batch_first=True,  # input & output use batch size as the first dimension, e.g., (batch, time_step, input_size)
        )
        self.fc = nn.Linear(16, N_O)  # output layer mapping hidden states to Q-values

    def forward(self, x):
        # x shape (batch, time_step, input_size)
        # r_out shape (batch, time_step, hidden_size)
        # h_n shape (n_layers, batch, hidden_size): final hidden state
        # h_c shape (n_layers, batch, hidden_size): final cell state
        r_out, (h_n, h_c) = self.rnn(x, None)  # None: initial hidden/cell states default to all zeros
        # apply the output layer to every time step; callers read the last
        # time step via out[:, -1, :], whose value corresponds to h_n
        out = self.fc(r_out)
        return out
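# Shape note (based on how RNN is used below): a single state tensor of shape
# (N_I,) is viewed as (1, 1, N_I) and the action is taken from out[:, -1, :];
# a mini-batch of BATCH_SIZE states is viewed as (1, BATCH_SIZE, N_I), i.e.,
# the batch dimension is treated as the time dimension, and the squeezed
# output then gives one row of Q-values per state.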
@dataclass
class State:
"""
$s_t = (q_{t-1}, F_{t-1}(q_{t-1}), B_t, \bm{C}_t)$, which is a modified
version of the state defined in [1].
"""
sg_quality: int
sg_size: float
buffer: float
ch_history: np.ndarray
def tensor(self):
return torch.tensor(
np.concatenate(
(
np.array([
self.sg_quality,
self.sg_size,
self.buffer]),
self.ch_history
),
axis=None
),
dtype=torch.float32
)
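# Note: State.tensor() concatenates [sg_quality, sg_size, buffer] with the
# CH_HISTORY channel-capacity samples into a flat float32 vector of length
# N_I = 3 + CH_HISTORY, which is the input dimension of the policy network.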
@dataclass
class Experience:
"""$e_t = (s_t, q_t, r_t, s_{t+1})$ in [1]"""
state: State
action: int
reward: float
next_state: State
class ReplayMemory(object):
    """Replay memory based on a circular buffer (older experiences are
    overwritten once the capacity is reached)."""

    def __init__(self, capacity):
        self.capacity = capacity
        self.memory = [None] * self.capacity
        self.position = 0
        self.num_elements = 0

    def push(self, experience):
        self.memory[self.position] = experience
        self.position = (self.position + 1) % self.capacity
        if self.num_elements < self.capacity:
            self.num_elements += 1

    def sample(self, batch_size):
        # sample only from the slots filled so far; otherwise the None
        # placeholders could be drawn before the buffer is full
        return random.sample(self.memory[:self.num_elements], batch_size)

    def get_num_elements(self):
        return self.num_elements
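# Usage (mirroring simulate_dash() below): one experience is push()ed per
# segment, and a random mini-batch is sample()d once get_num_elements()
# reaches BATCH_SIZE.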
class ActionSelector(object):
    """
    Select an action (i.e., a quality level) based on an epsilon-greedy
    exploration policy.
    """

    def __init__(self, num_actions):
        self.steps_done = 0
        self.num_actions = num_actions

    def reset(self):
        self.steps_done = 0

    def increase_step_number(self):
        self.steps_done += 1

    def action(self, state):
        sample = random.random()
        # epsilon decays exponentially from EPS_START to EPS_END with the
        # number of steps taken (incremented once per episode below)
        eps_threshold = EPS_END + (EPS_START - EPS_END) * \
            math.exp(-1. * self.steps_done / EPS_DECAY)
        if sample > eps_threshold:
            # exploit: choose the action with the largest Q-value at the last time step
            with torch.no_grad():
                net_output = policy_net_lstm(state.tensor().view(-1, 1, N_I))
                return int(torch.argmax(net_output[:, -1, :]))
        else:
            # explore: choose a random action
            return random.randrange(self.num_actions)
# policy network based on an FNN with 2 hidden layers (kept for reference;
# not used in the training loop below)
policy_net = torch.nn.Sequential(
    torch.nn.Linear(N_I, N_H1),
    torch.nn.ReLU(),
    torch.nn.Linear(N_H1, N_H2),
    torch.nn.ReLU(),
    torch.nn.Linear(N_H2, N_O),
    torch.nn.Sigmoid()
).to(device)
optimizer = torch.optim.Adam(policy_net.parameters(), lr=LEARNING_RATE)
# policy network based on an LSTM-based RNN with one recurrent layer
policy_net_lstm = RNN()
optimizer_lstm = torch.optim.Adam(policy_net_lstm.parameters(), lr=LEARNING_RATE)
mse_loss = torch.nn.MSELoss(reduction='mean')
# target network: a periodically updated copy of the LSTM policy network
target_net = copy.deepcopy(policy_net_lstm)
target_net.load_state_dict(policy_net_lstm.state_dict())
target_net.eval()
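# The target network provides $\hat{Q}(s_{t+1}, q|\bar{\bm{w}}_t)$ in (13) of
# [1]; its weights are re-synchronized with the policy network every
# TARGET_UPDATE segments inside simulate_dash() below.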
def simulate_dash(sss, bws):
# initialize parameters
num_segments = sss.shape[0] # number of segments
num_qualities = sss.shape[1] # number of quality levels
# initialize replay memory and action_selector
memory = ReplayMemory(1000)
selector = ActionSelector(num_qualities)
##########
# training
##########
num_episodes = 30
mean_sqs = np.empty(num_episodes) # mean segment qualities
mean_rewards = np.empty(num_episodes) # mean rewards
for i_episode in range(num_episodes):
# TODO: use different video traces per episode
sqs = np.empty(num_segments - CH_HISTORY)
rewards = np.empty(num_segments - CH_HISTORY)
# initialize the state
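        # the episode starts with a random quality level for segment
        # CH_HISTORY - 1, a buffer holding one segment duration (T seconds),
        # and the first CH_HISTORY channel-capacity samples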
sg_quality = random.randrange(num_qualities) # random action
state = State(
sg_quality=sg_quality,
sg_size=sss[CH_HISTORY - 1, sg_quality],
buffer=T,
ch_history=bws[0:CH_HISTORY]
)
for t in range(CH_HISTORY, num_segments):
sg_quality = selector.action(state)
sqs[t - CH_HISTORY] = sg_quality
# update the state
tau = sss[t, sg_quality] / bws[t]
buffer_next = T - max(0, state.buffer - tau)
next_state = State(
sg_quality=sg_quality,
sg_size=sss[t, sg_quality],
buffer=buffer_next,
ch_history=bws[t - CH_HISTORY + 1:t + 1]
)
# calculate reward (i.e., (4) in [1]).
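            # r_t = q_t - BETA*|q_t - q_{t-1}| - GAMMA*(rebuffering time)
            #       - DELTA*max(0, B_THR - B_{t+1})^2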
downloading_time = next_state.sg_size / next_state.ch_history[-1]
rebuffering = max(0, downloading_time - state.buffer)
rewards[t - CH_HISTORY] = next_state.sg_quality \
- BETA * abs(next_state.sg_quality - state.sg_quality) \
- GAMMA * rebuffering - DELTA * max(0, B_THR - next_state.buffer) ** 2
# store the experience in the replay memory
experience = Experience(
state=state,
action=sg_quality,
reward=rewards[t - CH_HISTORY],
next_state=next_state
)
memory.push(experience)
# move to the next state
state = next_state
#############################
# optimize the policy network
#############################
if memory.get_num_elements() < BATCH_SIZE:
continue
experiences = memory.sample(BATCH_SIZE)
state_batch = torch.stack([experiences[i].state.tensor()
for i in range(BATCH_SIZE)])
next_state_batch = torch.stack([experiences[i].next_state.tensor()
for i in range(BATCH_SIZE)])
            action_batch = torch.tensor([experiences[i].action
                                         for i in range(BATCH_SIZE)])
            # cast rewards to float32 so that the Q-learning targets match the
            # dtype of the network output
            reward_batch = torch.tensor([experiences[i].reward
                                         for i in range(BATCH_SIZE)],
                                        dtype=torch.float32)
            # $Q(s_t, q_t|\bm{w}_t)$ in (13) in [1]
            # 1. policy_net_lstm generates a batch of Q-values for all actions.
            # 2. the columns of the actions actually taken are selected using 'action_batch'.
            state_Q_values = torch.squeeze(policy_net_lstm(state_batch.view(-1, BATCH_SIZE, N_I)))
            state_action_values = state_Q_values.gather(1, action_batch.view(BATCH_SIZE, -1))
            # $\max_{q}\hat{Q}(s_{t+1}, q|\bar{\bm{w}}_t)$ in (13) in [1],
            # computed with the target network
            target_values = torch.squeeze(target_net(next_state_batch.view(-1, BATCH_SIZE, N_I)))
            next_state_values = target_values.max(1)[0].detach()
            # expected Q values, i.e., the target $r_t + \lambda\max_{q}\hat{Q}(s_{t+1}, q|\bar{\bm{w}}_t)$
            expected_state_action_values = reward_batch + (LAMBDA * next_state_values)
            # loss function, i.e., (14) in [1]
            loss = mse_loss(state_action_values,
                            expected_state_action_values.unsqueeze(1))
# optimize the model
optimizer_lstm.zero_grad()
loss.backward()
for param in policy_net_lstm.parameters():
param.grad.data.clamp_(-1, 1)
optimizer_lstm.step()
            # update the target network every TARGET_UPDATE segments
            if t % TARGET_UPDATE == 0:
                target_net.load_state_dict(policy_net_lstm.state_dict())
        # processing after each episode
        selector.increase_step_number()
mean_sqs[i_episode] = sqs.mean()
mean_rewards[i_episode] = rewards.mean()
print("Mean Segment Quality[{0:2d}]: {1:E}".format(i_episode, mean_sqs[i_episode]))
print("Mean Reward[{0:2d}]: {1:E}".format(i_episode, mean_rewards[i_episode]))
return (mean_sqs, mean_rewards)
if __name__ == "__main__":
import argparse
parser = argparse.ArgumentParser()
parser.add_argument(
"-V",
"--video_trace",
help="video trace file name; default is 'bigbuckbunny.npy'",
default='bigbuckbunny.npy',
type=str)
parser.add_argument(
"-C",
"--channel_bandwidths",
help="channel bandwidths file name; default is 'bandwidths.npy'",
default='bandwidths.npy',
type=str)
args = parser.parse_args()
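    # example invocation (assuming the default trace files exist in the
    # working directory):
    #   python LSTM_test.py -V bigbuckbunny.npy -C bandwidths.npy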
video_trace = args.video_trace
channel_bandwidths = args.channel_bandwidths
# read data
sss = np.load(video_trace) # segment sizes [bit]
    bws = np.load(channel_bandwidths)  # channel bandwidths [bit/s]
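    # as used by simulate_dash(): sss is a (num_segments, num_qualities) array
    # of segment sizes, and bws provides one channel-capacity sample per segment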
# simulate D-DASH
mean_sqs, mean_rewards = simulate_dash(sss, bws)
# plot results
fig, axs = plt.subplots(nrows=2, sharex=True)
axs[0].plot(mean_rewards)
axs[0].set_ylabel("Reward")
axs[1].plot(mean_sqs)
axs[1].set_ylabel("Video Quality")
axs[1].set_xlabel("Video Episode")
plt.show()
input("Press ENTER to continue...")
plt.close('all')