-
-
Notifications
You must be signed in to change notification settings - Fork 173
/
websocket.py
395 lines (360 loc) · 15.5 KB
/
websocket.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
from __future__ import annotations
import json
import asyncio
import logging
from time import time
from contextlib import suppress
from typing import Any, Literal, TYPE_CHECKING
import aiohttp
from translate import _
from exceptions import MinerException, WebsocketClosed
from constants import PING_INTERVAL, PING_TIMEOUT, MAX_WEBSOCKETS, WS_TOPICS_LIMIT
from utils import (
CHARS_ASCII,
task_wrapper,
create_nonce,
json_minify,
format_traceback,
AwaitableValue,
ExponentialBackoff,
)
if TYPE_CHECKING:
from collections import abc
from twitch import Twitch
from gui import WebsocketStatus
from constants import JsonType, WebsocketTopic
# module loggers: the general application logger, and its websocket-specific child
logger = logging.getLogger("TwitchDrops")
ws_logger = logging.getLogger("TwitchDrops.websocket")
# short alias for aiohttp's websocket message type enum
WSMsgType = aiohttp.WSMsgType
class Websocket:
    """
    One Twitch PubSub websocket connection, managed by a `WebsocketPool`.

    A single background task (`_handle`) owns the connection: it connects with exponential
    backoff, sends periodic PINGs, (un)subscribes topics whenever they change, and forwards
    incoming MESSAGE payloads to the matching topic handler.
    """

    # Strong references to fire-and-forget tasks (topic callbacks and deferred stops).
    # asyncio only keeps weak references to tasks, so a task nobody references can be
    # garbage-collected before it finishes. This lives on the class on purpose:
    # a deferred stop has to outlive the (already discarded) Websocket that scheduled it.
    _background_tasks: set[asyncio.Task[Any]] = set()

    def __init__(self, pool: WebsocketPool, index: int):
        self._pool: WebsocketPool = pool
        self._twitch: Twitch = pool._twitch
        self._ws_gui: WebsocketStatus = self._twitch.gui.websockets
        # serializes start() and stop()
        self._state_lock = asyncio.Lock()
        # websocket index, also used to identify this websocket in the GUI
        self._idx: int = index
        # current websocket connection
        self._ws: AwaitableValue[aiohttp.ClientWebSocketResponse] = AwaitableValue()
        # set when the websocket needs to be closed (stop requested);
        # cleared only when a new handler task is started
        self._closed = asyncio.Event()
        # set when the current connection should be dropped and re-established
        self._reconnect_requested = asyncio.Event()
        # set when the topics changed and need to be (re)submitted
        self._topics_changed = asyncio.Event()
        # ping timestamps: when the next PING is due, and the deadline for its PONG
        self._next_ping: float = time()
        self._max_pong: float = self._next_ping + PING_TIMEOUT.total_seconds()
        # main task, responsible for receiving messages, sending them, and websocket ping
        self._handle_task: asyncio.Task[None] | None = None
        # topics assigned to this websocket (keyed by their string form),
        # and the subset already LISTEN-ed on the current connection
        self.topics: dict[str, WebsocketTopic] = {}
        self._submitted: set[WebsocketTopic] = set()
        # notify GUI
        self.set_status(_("gui", "websocket", "disconnected"))

    @property
    def connected(self) -> bool:
        """Whether a websocket connection is currently established."""
        return self._ws.has_value()

    def wait_until_connected(self):
        """Return an awaitable that completes once a connection is established."""
        return self._ws.wait()

    def set_status(self, status: str | None = None, refresh_topics: bool = False):
        """
        Update this websocket's GUI entry.

        `status` is the status text to show; with `refresh_topics=True`,
        the current topic count is pushed as well.
        """
        self._twitch.gui.websockets.update(
            self._idx, status=status, topics=(len(self.topics) if refresh_topics else None)
        )

    def request_reconnect(self):
        """Ask the handler to drop the current connection and connect again."""
        # reset our ping interval, so we send a PING after reconnect right away
        self._next_ping = time()
        self._reconnect_requested.set()

    async def start(self):
        """Start the connection handler (if not running) and wait until connected."""
        async with self._state_lock:
            self.start_nowait()
            await self.wait_until_connected()

    def start_nowait(self):
        """Start the connection handler task, unless one is already running."""
        if self._handle_task is None or self._handle_task.done():
            # Clear the stop flag here rather than inside the task, so that a stop()
            # issued before the new task gets to run isn't silently discarded.
            self._closed.clear()
            self._handle_task = asyncio.create_task(self._handle())

    async def stop(self, *, remove: bool = False):
        """
        Stop the connection handler and close the current connection, if any.

        With `remove=True`, also drop all topics and remove this websocket from the GUI.
        Waits up to 2 seconds for the handler task to finish (wait_for cancels it
        if it doesn't).
        """
        async with self._state_lock:
            if self._closed.is_set():
                # NOTE(review): an already-stopped websocket returns here without honoring
                # `remove=True`, so its GUI entry is kept - confirm whether that's intended.
                return
            self._closed.set()
            ws = self._ws.get_with_default(None)
            if ws is not None:
                self.set_status(_("gui", "websocket", "disconnecting"))
                await ws.close()
            if self._handle_task is not None:
                with suppress(asyncio.TimeoutError, asyncio.CancelledError):
                    await asyncio.wait_for(self._handle_task, timeout=2)
                self._handle_task = None
            if remove:
                self.topics.clear()
                self._topics_changed.set()
                self._twitch.gui.websockets.remove(self._idx)

    def stop_nowait(self, *, remove: bool = False):
        """Schedule `stop()` as a background task, without waiting for it."""
        # weird syntax but that's what we get for using a decorator for this
        # return type of 'task_wrapper' is a coro, so we need to instance it for the task
        self._spawn(task_wrapper(self.stop)(remove=remove))

    def _spawn(self, coro: abc.Coroutine[Any, Any, Any]) -> asyncio.Task[Any]:
        """Run `coro` as a task, keeping a strong reference to it until it's done."""
        task = asyncio.create_task(coro)
        self._background_tasks.add(task)
        task.add_done_callback(self._background_tasks.discard)
        return task

    async def _backoff_connect(
        self, ws_url: str, **kwargs
    ) -> abc.AsyncGenerator[aiohttp.ClientWebSocketResponse, None]:
        """
        Yield successive websocket connections to `ws_url`, retrying with exponential backoff.

        Each yielded connection stays open until the consumer asks for the next one;
        at that point it's closed and the backoff is reset. Connection errors are logged
        and retried after the current backoff delay. A RuntimeError (raised when the
        session is closed) ends the generator. Extra **kwargs go to `ExponentialBackoff`.
        """
        session = await self._twitch.get_session()
        backoff = ExponentialBackoff(**kwargs)
        if self._twitch.settings.proxy:
            proxy = self._twitch.settings.proxy
        else:
            proxy = None
        for delay in backoff:
            try:
                async with session.ws_connect(ws_url, proxy=proxy) as websocket:
                    yield websocket
                    backoff.reset()
            except (
                asyncio.TimeoutError,
                aiohttp.ClientResponseError,
                aiohttp.ClientConnectionError,
            ):
                ws_logger.info(
                    f"Websocket[{self._idx}] connection problem (sleep: {round(delay)}s)"
                )
                await asyncio.sleep(delay)
            except RuntimeError:
                ws_logger.warning(
                    f"Websocket[{self._idx}] exiting backoff connect loop "
                    "because session is closed (RuntimeError)"
                )
                break

    def _report_stopped(self):
        """Log and display that the handler exited because stop() was requested."""
        ws_logger.info(f"Websocket[{self._idx}] stopped.")
        self.set_status(_("gui", "websocket", "disconnected"))

    @task_wrapper(critical=True)
    async def _handle(self):
        """
        Main connection loop.

        Connect (with backoff), then ping, sync topics and receive messages until either
        a reconnect is requested (connect again) or stop() is requested (exit).
        """
        # ensure we're logged in before connecting
        self.set_status(_("gui", "websocket", "initializing"))
        await self._twitch.wait_until_login()
        if self._closed.is_set():
            # stop() was requested while we were waiting for the login - don't connect
            self._report_stopped()
            return
        self.set_status(_("gui", "websocket", "connecting"))
        ws_logger.info(f"Websocket[{self._idx}] connecting...")
        connections = self._backoff_connect(
            "wss://pubsub-edge.twitch.tv/v1", maximum=3*60  # 3 minutes maximum backoff time
        )
        try:
            # Connect/Reconnect loop
            async for websocket in connections:
                self._ws.set(websocket)
                self._reconnect_requested.clear()
                # NOTE: _topics_changed doesn't start set,
                # because there's no initial topics we can sub to right away
                self.set_status(_("gui", "websocket", "connected"))
                ws_logger.info(f"Websocket[{self._idx}] connected.")
                try:
                    try:
                        while not (
                            self._reconnect_requested.is_set() or self._closed.is_set()
                        ):
                            await self._handle_ping()
                            await self._handle_topics()
                            await self._handle_recv()
                    finally:
                        self._ws.clear()
                        self._submitted.clear()
                        # set _topics_changed to let the next WS connection resub to the topics
                        self._topics_changed.set()
                except WebsocketClosed as exc:
                    if exc.received and not self._closed.is_set():
                        # server closed the connection, not us - reconnect
                        ws_logger.warning(
                            f"Websocket[{self._idx}] closed unexpectedly: {websocket.close_code}"
                        )
                except Exception:
                    if self._closed.is_set():
                        # most likely caused by stop() closing the connection underneath us
                        ws_logger.debug(
                            f"Websocket[{self._idx}] error while stopping", exc_info=True
                        )
                    else:
                        ws_logger.exception(f"Exception in Websocket[{self._idx}]")
                if self._closed.is_set():
                    # we were asked to stop - exit instead of reconnecting
                    self._report_stopped()
                    return
                self.set_status(_("gui", "websocket", "reconnecting"))
                ws_logger.warning(f"Websocket[{self._idx}] reconnecting...")
        finally:
            # close the generator explicitly, so the last connection's context manager
            # exits now rather than whenever the generator gets garbage-collected
            await connections.aclose()

    async def _handle_ping(self):
        """Send a PING when it's due, or request a reconnect if the PONG is overdue."""
        now = time()
        if now >= self._next_ping:
            self._next_ping = now + PING_INTERVAL.total_seconds()
            # wait for a PONG for up to PING_TIMEOUT
            self._max_pong = now + PING_TIMEOUT.total_seconds()
            await self.send({"type": "PING"})
        elif now >= self._max_pong:
            # the PONG deadline passed and no PONG arrived
            ws_logger.warning(f"Websocket[{self._idx}] didn't receive a PONG, reconnecting...")
            self.request_reconnect()

    async def _handle_topics(self):
        """
        Sync the server-side subscriptions with `self.topics`, if they changed.

        Sends UNLISTEN for submitted topics no longer assigned, and LISTEN for newly
        assigned ones, updating `self._submitted` to match.
        """
        if not self._topics_changed.is_set():
            # nothing to do
            return
        self._topics_changed.clear()
        self.set_status(refresh_topics=True)
        auth_state = await self._twitch.get_auth()
        current: set[WebsocketTopic] = set(self.topics.values())
        # handle removed topics
        removed = self._submitted.difference(current)
        if removed:
            topics_list = list(map(str, removed))
            ws_logger.debug(f"Websocket[{self._idx}]: Removing topics: {', '.join(topics_list)}")
            await self.send(
                {
                    "type": "UNLISTEN",
                    "data": {
                        "topics": topics_list,
                        "auth_token": auth_state.access_token,
                    }
                }
            )
            self._submitted.difference_update(removed)
        # handle added topics
        added = current.difference(self._submitted)
        if added:
            topics_list = list(map(str, added))
            ws_logger.debug(f"Websocket[{self._idx}]: Adding topics: {', '.join(topics_list)}")
            await self.send(
                {
                    "type": "LISTEN",
                    "data": {
                        "topics": topics_list,
                        "auth_token": auth_state.access_token,
                    }
                }
            )
            self._submitted.update(added)

    async def _gather_recv(self, messages: list[JsonType], timeout: float = 0.5):
        """
        Gather incoming messages for at most `timeout` seconds in total.

        Note that there's no return value - this modifies `messages` in-place, so messages
        gathered before an asyncio.TimeoutError (no message within the remaining time)
        aren't lost. The overall deadline ensures a steady message stream can't keep us
        here forever and starve pings and topic updates.

        Raises WebsocketClosed when the connection closes or errors out.
        """
        ws = self._ws.get_with_default(None)
        assert ws is not None
        loop = asyncio.get_running_loop()
        deadline = loop.time() + timeout
        while True:
            remaining = deadline - loop.time()
            if remaining <= 0:
                # always pass a positive timeout - a falsy one would mean "no timeout"
                return
            raw_message: aiohttp.WSMessage = await ws.receive(timeout=remaining)
            ws_logger.debug(f"Websocket[{self._idx}] received: {raw_message}")
            if raw_message.type is WSMsgType.TEXT:
                message: JsonType = json.loads(raw_message.data)
                messages.append(message)
            elif raw_message.type is WSMsgType.CLOSE:
                raise WebsocketClosed(received=True)
            elif raw_message.type is WSMsgType.CLOSED:
                raise WebsocketClosed(received=False)
            elif raw_message.type is WSMsgType.CLOSING:
                pass  # skip these
            elif raw_message.type is WSMsgType.ERROR:
                ws_logger.error(
                    f"Websocket[{self._idx}] error: {format_traceback(raw_message.data)}"
                )
                raise WebsocketClosed()
            else:
                ws_logger.error(f"Websocket[{self._idx}] error: Unknown message: {raw_message}")

    def _handle_message(self, message):
        """Dispatch a MESSAGE payload to the topic it's addressed to, if we hold it."""
        # request the assigned topic to process the response
        topic = self.topics.get(message["data"]["topic"])
        if topic is not None:
            # the inner message is itself JSON-encoded;
            # use a (referenced) task to not block the websocket
            self._spawn(topic(json.loads(message["data"]["message"])))

    async def _handle_recv(self):
        """
        Handle receiving messages from the websocket.

        Gathers messages for up to 0.5s, then processes each by its "type".
        """
        messages: list[JsonType] = []
        with suppress(asyncio.TimeoutError):
            await self._gather_recv(messages, timeout=0.5)
        # process them
        for message in messages:
            msg_type = message["type"]
            if msg_type == "MESSAGE":
                self._handle_message(message)
            elif msg_type == "PONG":
                # move the PONG deadline up to the next PING, so it can't trigger before then
                self._max_pong = self._next_ping
            elif msg_type == "RESPONSE":
                # no special handling for these (for now)
                pass
            elif msg_type == "RECONNECT":
                # We've received a reconnect request
                ws_logger.warning(f"Websocket[{self._idx}] requested reconnect.")
                self.request_reconnect()
            else:
                ws_logger.warning(f"Websocket[{self._idx}] received unknown payload: {message}")

    def add_topics(self, topics_set: set[WebsocketTopic]):
        """
        Take topics from `topics_set` until it's empty or WS_TOPICS_LIMIT is reached.

        Taken topics are removed from `topics_set` in-place, so the caller can hand
        the leftovers to the next websocket.
        """
        changed: bool = False
        while topics_set and len(self.topics) < WS_TOPICS_LIMIT:
            topic = topics_set.pop()
            self.topics[str(topic)] = topic
            changed = True
        if changed:
            self._topics_changed.set()

    def remove_topics(self, topics_set: set[str]):
        """
        Remove the topics (by string form) from `topics_set` that this websocket holds.

        Removed names are also discarded from `topics_set` in-place, so the caller
        can pass what's left to the next websocket.
        """
        existing = topics_set.intersection(self.topics.keys())
        if not existing:
            # nothing to remove from here
            return
        topics_set.difference_update(existing)
        for topic in existing:
            del self.topics[topic]
        self._topics_changed.set()

    async def send(self, message: JsonType):
        """Send `message` as minified JSON; non-PING messages get a random 30-char nonce."""
        ws = self._ws.get_with_default(None)
        assert ws is not None
        if message["type"] != "PING":
            message["nonce"] = create_nonce(CHARS_ASCII, 30)
        await ws.send_json(message, dumps=json_minify)
        ws_logger.debug(f"Websocket[{self._idx}] sent: {message}")
class WebsocketPool:
    """
    A pool of `Websocket` connections that topics get distributed across.

    Every websocket carries up to `WS_TOPICS_LIMIT` topics, and at most `MAX_WEBSOCKETS`
    websockets get created; trailing websockets that are no longer needed are shut down.
    """

    def __init__(self, twitch: Twitch):
        self._twitch: Twitch = twitch
        # set between start() and stop(); websockets created meanwhile start immediately
        self._running = asyncio.Event()
        self.websockets: list[Websocket] = []

    @property
    def running(self) -> bool:
        """True while the pool is started."""
        return self._running.is_set()

    def wait_until_connected(self) -> abc.Coroutine[Any, Any, Literal[True]]:
        """Return a coroutine that finishes once the pool has been started."""
        return self._running.wait()

    async def start(self):
        """Mark the pool as running, then start every websocket and wait for all of them."""
        self._running.set()
        pending_starts = [ws.start() for ws in self.websockets]
        await asyncio.gather(*pending_starts)

    async def stop(self, *, clear_topics: bool = False):
        """Mark the pool as stopped, then stop every websocket (dropping topics if asked)."""
        self._running.clear()
        pending_stops = [ws.stop(remove=clear_topics) for ws in self.websockets]
        await asyncio.gather(*pending_stops)

    def _new_websocket(self, index: int) -> Websocket:
        """Create websocket number `index`, start it if the pool runs, and register it."""
        fresh = Websocket(self, index)
        if self.running:
            fresh.start_nowait()
        self.websockets.append(fresh)
        return fresh

    def add_topics(self, topics: abc.Iterable[WebsocketTopic]):
        """
        Assign `topics` to websockets, filling existing ones first and creating new ones
        as needed. Topics some websocket already holds are skipped.

        Raises MinerException if they can't all fit within MAX_WEBSOCKETS websockets.
        """
        pending = set(topics)
        # drop the topics some websocket already carries, so nothing ends up duplicated
        for existing in self.websockets:
            pending.difference_update(existing.topics.values())
        if not pending:
            # nothing (new) to add
            return
        for ws_idx in range(MAX_WEBSOCKETS):
            if ws_idx >= len(self.websockets):
                ws = self._new_websocket(ws_idx)
            else:
                ws = self.websockets[ws_idx]
            # the websocket takes whatever fits, removing those topics from `pending`
            ws.add_topics(pending)
            if not pending:
                return
        # every websocket is full, yet some topics are still left over
        raise MinerException("Maximum topics limit has been reached")

    def remove_topics(self, topics: abc.Iterable[str]):
        """
        Remove `topics` (by their string form) from whichever websockets hold them.

        Afterwards, as long as the remaining topics would fit into one websocket less,
        the last websocket is stopped and dropped, and its topics are redistributed.
        """
        doomed = set(topics)
        if not doomed:
            return
        for ws in self.websockets:
            # each websocket discards the names it held from `doomed`
            ws.remove_topics(doomed)
        leftovers: list[WebsocketTopic] = []
        while True:
            held = sum(len(ws.topics) for ws in self.websockets)
            if held > (len(self.websockets) - 1) * WS_TOPICS_LIMIT:
                break
            last = self.websockets.pop()
            leftovers.extend(last.topics.values())
            last.stop_nowait(remove=True)
        if leftovers:
            self.add_topics(leftovers)