Skip to content

Commit

Permalink
Implement Auto Flush
Browse files Browse the repository at this point in the history
  • Loading branch information
davidvonthenen committed Jun 12, 2024
1 parent 3177017 commit 13a0833
Show file tree
Hide file tree
Showing 7 changed files with 885 additions and 52 deletions.
230 changes: 217 additions & 13 deletions deepgram/clients/live/v1/async_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
import json
import logging
from typing import Dict, Union, Optional, cast, Any
from datetime import datetime

import websockets
from websockets.client import WebSocketClientProtocol
Expand All @@ -28,6 +29,7 @@
from .options import LiveOptions

# Timing constants for the background tasks of the live client.
# NOTE(review): ONE_SECOND is presumably the sleep period of the keepalive
# loop (in seconds); its use is not visible in this excerpt — confirm.
ONE_SECOND = 1
# Sleep period, in seconds, awaited at the top of every _flush loop iteration.
HALF_SECOND = 0.5
# The keepalive loop sends a Deepgram KeepAlive message whenever its counter
# is a multiple of this value.
DEEPGRAM_INTERVAL = 5
# NOTE(review): presumably the websocket ping interval; where it is consumed
# is not visible in this excerpt — confirm.
PING_INTERVAL = 20

Expand All @@ -49,8 +51,12 @@ class AsyncLiveClient: # pylint: disable=too-many-instance-attributes

_socket: WebSocketClientProtocol
_event_handlers: Dict[LiveTranscriptionEvents, list]
_listen_thread: asyncio.Task
_keep_alive_thread: asyncio.Task

_last_datagram: Optional[datetime] = None

_listen_thread: Union[asyncio.Task, None]
_keep_alive_thread: Union[asyncio.Task, None]
_flush_thread: Union[asyncio.Task, None]

_kwargs: Optional[Dict] = None
_addons: Optional[Dict] = None
Expand All @@ -67,7 +73,16 @@ def __init__(self, config: DeepgramClientOptions):

self._config = config
self._endpoint = "v1/listen"

self._listen_thread = None
self._keep_alive_thread = None
self._flush_thread = None

# exit
self._exit_event = asyncio.Event()

# auto flush
self._flush_event = asyncio.Event()
self._event_handlers = {
event: [] for event in LiveTranscriptionEvents.__members__.values()
}
Expand Down Expand Up @@ -112,7 +127,7 @@ async def start(

if isinstance(options, LiveOptions):
self._logger.info("LiveOptions switching class -> dict")
self._options = cast(Dict[str, str], options.to_dict())
self._options = options.to_dict()
elif options is not None:
self._options = options
else:
Expand Down Expand Up @@ -146,12 +161,19 @@ async def start(
self._listen_thread = asyncio.create_task(self._listening())

# keepalive thread
if self._config.options.get("keepalive") == "true":
if self._config.is_keep_alive_enabled():
self._logger.notice("keepalive is enabled")
self._keep_alive_thread = asyncio.create_task(self._keep_alive())
else:
self._logger.notice("keepalive is disabled")

# flush thread
if self._config.is_auto_flush_enabled():
self._logger.notice("autoflush is enabled")
self._flush_thread = asyncio.create_task(self._flush())
else:
self._logger.notice("autoflush is disabled")

# push open event
await self._emit(
LiveTranscriptionEvents(LiveTranscriptionEvents.Open),
def on(self, event: LiveTranscriptionEvents, handler) -> None:
    """
    Registers event handlers for specific events.

    Args:
        event: The LiveTranscriptionEvents value to subscribe to.
        handler: Callable stored for that event; _emit later invokes it as
            handler(client, *args, **kwargs).

    The registration is skipped (and a warning is logged) when `event` is not
    one of the LiveTranscriptionEvents members or `handler` is not callable.
    """
    # Nothing is fired here: this only records the subscription.
    self._logger.info("event subscribed: %s", event)
    if event in LiveTranscriptionEvents.__members__.values() and callable(handler):
        self._event_handlers[event].append(handler)
    else:
        # Make the dropped registration visible instead of ignoring it silently.
        self._logger.warning("event handler not registered for: %s", event)

async def _emit(self, event: LiveTranscriptionEvents, *args, **kwargs) -> None:
    """
    Emits events to the registered event handlers.

    Every handler registered for `event` is called as
    handler(self, *args, **kwargs), in registration order:

    - coroutine functions are awaited before the next handler runs;
    - any other callable is invoked directly, and if it hands back a
      coroutine (e.g. a partial wrapping an async function) that coroutine
      is scheduled as a task without being awaited.
    """
    self._logger.debug("callback handlers for: %s", event)
    for handler in self._event_handlers[event]:
        if asyncio.iscoroutinefunction(handler):
            await handler(self, *args, **kwargs)
            continue

        # asyncio.create_task only accepts coroutines; passing it the None
        # returned by an ordinary synchronous callback raises TypeError, so
        # the result is inspected before anything is scheduled.
        outcome = handler(self, *args, **kwargs)
        if asyncio.iscoroutine(outcome):
            # NOTE(review): the task reference is not retained, as before;
            # asyncio only keeps weak references to tasks — confirm that
            # fire-and-forget is acceptable here.
            asyncio.create_task(outcome)

# pylint: disable=too-many-return-statements,too-many-statements,too-many-locals
# pylint: disable=too-many-return-statements,too-many-statements,too-many-locals,too-many-branches
async def _listening(self) -> None:
"""
Listens for messages from the WebSocket connection.
Expand Down Expand Up @@ -244,6 +267,13 @@ async def _listening(self) -> None:
message
)
self._logger.verbose("LiveResultResponse: %s", msg_result)

# auto flush
if self._config.is_inspecting_messages():
inspect_res = await self._inspect(msg_result)
if not inspect_res:
self._logger.error("inspect_res failed")

await self._emit(
LiveTranscriptionEvents(LiveTranscriptionEvents.Transcript),
result=msg_result,
Expand Down Expand Up @@ -426,8 +456,7 @@ async def _keep_alive(self) -> None:

# deepgram keepalive
if counter % DEEPGRAM_INTERVAL == 0:
self._logger.verbose("Sending KeepAlive...")
await self.send(json.dumps({"type": "KeepAlive"}))
await self.keep_alive()

except websockets.exceptions.ConnectionClosedOK as e:
self._logger.notice(f"_keep_alive({e.code}) exiting gracefully")
Expand Down Expand Up @@ -514,6 +543,132 @@ async def _keep_alive(self) -> None:
raise
return

## pylint: disable=too-many-return-statements,too-many-statements
async def _flush(self) -> None:
    """
    Background loop that auto-flushes the stream with a Finalize message.

    Every HALF_SECOND the loop compares the current time with
    self._last_datagram. Once at least `auto_flush_reply_delta` milliseconds
    (read from self._config.options) have elapsed, the timestamp is cleared
    and self.finalize() is awaited.

    The loop returns when:
    - `auto_flush_reply_delta` is missing or is not a number,
    - the exit event is set or the socket is None,
    - the connection closes normally,
    - an error occurred; in that case an Error event is emitted and
      _signal_exit() is awaited first.

    Raises:
        The caught exception is re-raised only when the
        `termination_exception` option equals "true".
    """
    self._logger.debug("AsyncLiveClient._flush ENTER")

    delta_in_ms_str = self._config.options.get("auto_flush_reply_delta")
    if delta_in_ms_str is None:
        self._logger.error("auto_flush_reply_delta is None")
        self._logger.debug("AsyncLiveClient._flush LEAVE")
        return

    # A malformed value would otherwise end this task with an exception that
    # nobody observes; treat it like the missing-option case instead.
    try:
        delta_in_ms = float(delta_in_ms_str)
    except (TypeError, ValueError):
        self._logger.error(
            "auto_flush_reply_delta is not a number: %s", delta_in_ms_str
        )
        self._logger.debug("AsyncLiveClient._flush LEAVE")
        return

    while True:
        try:
            await asyncio.sleep(HALF_SECOND)

            if self._exit_event.is_set():
                self._logger.notice("_flush exiting gracefully")
                self._logger.debug("AsyncLiveClient._flush LEAVE")
                return

            if self._socket is None:
                self._logger.notice("socket is None, exiting flush")
                self._logger.debug("AsyncLiveClient._flush LEAVE")
                return

            # No timestamp recorded (or already consumed): nothing to flush.
            if self._last_datagram is None:
                self._logger.debug("AutoFlush last_datagram is None")
                continue

            # Both sides are naive datetime.now() values, so the difference
            # is the wall-clock time since the timestamp was recorded.
            delta = datetime.now() - self._last_datagram
            diff_in_ms = delta.total_seconds() * 1000
            self._logger.debug("AutoFlush delta: %f", diff_in_ms)
            if diff_in_ms < delta_in_ms:
                self._logger.debug("AutoFlush delta is less than threshold")
                continue

            # Clear the timestamp first so one quiet period triggers exactly
            # one Finalize.
            self._last_datagram = None
            await self.finalize()

        except websockets.exceptions.ConnectionClosedOK as e:
            self._logger.notice(f"_flush({e.code}) exiting gracefully")
            self._logger.debug("AsyncLiveClient._flush LEAVE")
            return

        except websockets.exceptions.ConnectionClosed as e:
            # Code 1000 is a normal closure, not an error.
            if e.code == 1000:
                self._logger.notice(f"_flush({e.code}) exiting gracefully")
                self._logger.debug("AsyncLiveClient._flush LEAVE")
                return

            self._logger.error(
                "ConnectionClosed in AsyncLiveClient._flush with code %s: %s",
                e.code,
                e.reason,
            )
            cc_error: ErrorResponse = ErrorResponse(
                "ConnectionClosed in AsyncLiveClient._flush",
                f"{e}",
                "ConnectionClosed",
            )
            await self._emit(
                LiveTranscriptionEvents(LiveTranscriptionEvents.Error),
                error=cc_error,
                **dict(cast(Dict[Any, Any], self._kwargs)),
            )

            # signal exit and close
            await self._signal_exit()

            self._logger.debug("AsyncLiveClient._flush LEAVE")

            if self._config.options.get("termination_exception") == "true":
                raise
            return

        except websockets.exceptions.WebSocketException as e:
            self._logger.error(
                "WebSocketException in AsyncLiveClient._flush: %s", e
            )
            ws_error: ErrorResponse = ErrorResponse(
                "WebSocketException in AsyncLiveClient._flush",
                f"{e}",
                "Exception",
            )
            await self._emit(
                LiveTranscriptionEvents(LiveTranscriptionEvents.Error),
                error=ws_error,
                **dict(cast(Dict[Any, Any], self._kwargs)),
            )

            # signal exit and close
            await self._signal_exit()

            self._logger.debug("AsyncLiveClient._flush LEAVE")

            if self._config.options.get("termination_exception") == "true":
                raise
            return

        except Exception as e:  # pylint: disable=broad-except
            # Logged once; the original repeated this same message a second
            # time right before emitting the event.
            self._logger.error("Exception in AsyncLiveClient._flush: %s", e)
            e_error: ErrorResponse = ErrorResponse(
                "Exception in AsyncLiveClient._flush",
                f"{e}",
                "Exception",
            )
            await self._emit(
                LiveTranscriptionEvents(LiveTranscriptionEvents.Error),
                error=e_error,
                **dict(cast(Dict[Any, Any], self._kwargs)),
            )

            # signal exit and close
            await self._signal_exit()

            self._logger.debug("AsyncLiveClient._flush LEAVE")

            if self._config.options.get("termination_exception") == "true":
                raise
            return

# pylint: enable=too-many-return-statements

# pylint: disable=too-many-return-statements

async def send(self, data: Union[str, bytes]) -> bool:
"""
Sends data over the WebSocket connection.
Expand Down Expand Up @@ -570,6 +725,31 @@ async def send(self, data: Union[str, bytes]) -> bool:

# pylint: enable=too-many-return-statements

async def keep_alive(self) -> bool:
    """
    Sends a KeepAlive message

    Returns:
        False when the exit event is already set or when send() reports a
        failure; True otherwise.
    """
    self._logger.spam("AsyncLiveClient.keep_alive ENTER")

    if self._exit_event.is_set():
        self._logger.notice("keep_alive exiting gracefully")
        # LEAVE is traced at the same level as ENTER so the pair stays
        # balanced whatever log level is configured.
        self._logger.spam("AsyncLiveClient.keep_alive LEAVE")
        return False

    if self._socket is not None:
        self._logger.notice("Sending KeepAlive...")
        ret = await self.send(json.dumps({"type": "KeepAlive"}))

        if not ret:
            self._logger.error("keep_alive failed")
            self._logger.spam("AsyncLiveClient.keep_alive LEAVE")
            return False

    # NOTE(review): this point is also reached when the socket is None, i.e.
    # when nothing was sent — confirm that reporting success is intended.
    self._logger.notice("keep_alive succeeded")
    self._logger.spam("AsyncLiveClient.keep_alive LEAVE")

    return True

async def finalize(self) -> bool:
"""
Finalizes the Transcript connection by flushing it
Expand All @@ -582,7 +762,7 @@ async def finalize(self) -> bool:
return False

if self._socket is not None:
self._logger.notice("sending Finalize...")
self._logger.notice("Sending Finalize...")
ret = await self.send(json.dumps({"type": "Finalize"}))

if not ret:
Expand All @@ -609,13 +789,20 @@ async def finish(self) -> bool:
try:
# Before cancelling, check if the tasks were created
tasks = []
if self._config.options.get("keepalive") == "true":
if self._keep_alive_thread is not None:
self._keep_alive_thread.cancel()
tasks.append(self._keep_alive_thread)
if self._keep_alive_thread is not None:
self._keep_alive_thread.cancel()
tasks.append(self._keep_alive_thread)
self._logger.notice("processing _keep_alive_thread cancel...")

if self._flush_thread is not None:
self._flush_thread.cancel()
tasks.append(self._flush_thread)
self._logger.notice("processing _flush_thread cancel...")

if self._listen_thread is not None:
self._listen_thread.cancel()
tasks.append(self._listen_thread)
self._logger.notice("processing _listen_thread cancel...")

# Use asyncio.gather to wait for tasks to be cancelled
await asyncio.gather(*filter(None, tasks), return_exceptions=True)
Expand Down Expand Up @@ -673,3 +860,20 @@ async def _signal_exit(self) -> None:
self._logger.error("socket.wait_closed failed: %s", e)

self._socket = None # type: ignore

async def _inspect(self, msg_result: LiveResultResponse) -> bool:
    """
    Updates the auto-flush timestamp from a transcript result.

    The first alternative's transcript decides what happens:
    - no alternatives, or an empty transcript: self._last_datagram is left
      untouched;
    - a final result: self._last_datagram is cleared;
    - an interim result: self._last_datagram is set to the current time.

    Returns:
        Always True.
    """
    # A result without alternatives carries no text; indexing [0] on it
    # would raise IndexError.
    alternatives = msg_result.channel.alternatives
    if not alternatives:
        return True

    sentence = alternatives[0].transcript
    if not sentence:
        return True

    if msg_result.is_final:
        self._logger.debug("AutoFlush is_final received")
        self._last_datagram = None
    else:
        self._last_datagram = datetime.now()
        self._logger.debug(
            "AutoFlush interim received: %s",
            str(self._last_datagram),
        )

    return True
Loading

0 comments on commit 13a0833

Please sign in to comment.