From 5c310f0d7b77634290649b713ecb9581a81ab9b4 Mon Sep 17 00:00:00 2001 From: Quinn Damerell Date: Thu, 26 Dec 2024 19:37:08 -0800 Subject: [PATCH] Starting work! --- .vscode/settings.json | 1 + homeway/homeway/sentry.py | 66 +++--- homeway/homeway_linuxhost/linuxhost.py | 8 + homeway/homeway_linuxhost/sage/__init__.py | 1 + homeway/homeway_linuxhost/sage/fabric.py | 206 ++++++++++++++++++ homeway/homeway_linuxhost/sage/sagehandler.py | 102 +++++++++ homeway/homeway_linuxhost/sage/sagehost.py | 183 ++++++++++++++++ 7 files changed, 534 insertions(+), 33 deletions(-) create mode 100644 homeway/homeway_linuxhost/sage/__init__.py create mode 100644 homeway/homeway_linuxhost/sage/fabric.py create mode 100644 homeway/homeway_linuxhost/sage/sagehandler.py create mode 100644 homeway/homeway_linuxhost/sage/sagehost.py diff --git a/.vscode/settings.json b/.vscode/settings.json index 0018456..e397c29 100644 --- a/.vscode/settings.json +++ b/.vscode/settings.json @@ -76,6 +76,7 @@ "Roboto", "routable", "rwix", + "sagehandler", "serverauth", "servercon", "serverdiscovery", diff --git a/homeway/homeway/sentry.py b/homeway/homeway/sentry.py index 4a2a494..4867f4f 100644 --- a/homeway/homeway/sentry.py +++ b/homeway/homeway/sentry.py @@ -35,39 +35,39 @@ def Setup(versionString:str, distType:str, isDevMode:bool = False): Sentry.IsDevMode = isDevMode # Only setup sentry if we aren't in dev mode. - if Sentry.IsDevMode is False: - try: - # We don't want sentry to capture error logs, which is it's default. - # We do want the logging for breadcrumbs, so we will leave it enabled. - sentry_logging = LoggingIntegration( - level=logging.INFO, # Capture info and above as breadcrumbs - event_level=logging.FATAL # Only send FATAL errors and above. - ) - - # Setup and init Sentry with our private Sentry server. - sentry_sdk.init( - dsn= "https://0f277df18f036d44f9ca11e653485da1@oe-sentry.octoeverywhere.com/5", - integrations= [ - sentry_logging, - ThreadingIntegration(propagate_hub=True), - ], - # This is the recommended format - release= f"homeway-plugin@{versionString}", - dist= distType, - environment= "dev" if isDevMode else "production", - before_send= Sentry._beforeSendFilter, - enable_tracing= True, - # This means we will send 100% of errors, maybe we want to reduce this in the future? - sample_rate= 1.0, - traces_sample_rate= 0.01, - profiles_sample_rate= 0.01, - ) - except Exception as e: - if Sentry._Logger is not None: - Sentry._Logger.error("Failed to init Sentry: "+str(e)) - - # Set that sentry is ready to use. - Sentry.IsSentrySetup = True + # if Sentry.IsDevMode is False: + # try: + # # We don't want sentry to capture error logs, which is it's default. + # # We do want the logging for breadcrumbs, so we will leave it enabled. + # sentry_logging = LoggingIntegration( + # level=logging.INFO, # Capture info and above as breadcrumbs + # event_level=logging.FATAL # Only send FATAL errors and above. + # ) + + # # Setup and init Sentry with our private Sentry server. + # sentry_sdk.init( + # dsn= "https://0f277df18f036d44f9ca11e653485da1@oe-sentry.octoeverywhere.com/5", + # integrations= [ + # sentry_logging, + # ThreadingIntegration(propagate_hub=True), + # ], + # # This is the recommended format + # release= f"homeway-plugin@{versionString}", + # dist= distType, + # environment= "dev" if isDevMode else "production", + # before_send= Sentry._beforeSendFilter, + # enable_tracing= True, + # # This means we will send 100% of errors, maybe we want to reduce this in the future? + # sample_rate= 1.0, + # traces_sample_rate= 0.01, + # profiles_sample_rate= 0.01, + # ) + # except Exception as e: + # if Sentry._Logger is not None: + # Sentry._Logger.error("Failed to init Sentry: "+str(e)) + + # # Set that sentry is ready to use. + # Sentry.IsSentrySetup = True @staticmethod diff --git a/homeway/homeway_linuxhost/linuxhost.py b/homeway/homeway_linuxhost/linuxhost.py index f82848a..c363024 100644 --- a/homeway/homeway_linuxhost/linuxhost.py +++ b/homeway/homeway_linuxhost/linuxhost.py @@ -26,6 +26,7 @@ from .ha.eventhandler import EventHandler from .ha.serverinfo import ServerInfo from .ha.serverdiscovery import ServerDiscovery +from .sage.sagehost import SageHost # This file is the main host for the linux service. @@ -36,6 +37,7 @@ def __init__(self, addonDataRootDir:str, logsDir:str, addonType:AddonTypes, devC self.Secrets = None self.WebServer = None self.HaEventHandler = None + self.Sage:SageHost = None # Indicates if we are running as the Home Assistant addon, or standalone docker or cli. self.AddonType = addonType @@ -187,6 +189,9 @@ def RunBlocking(self, storageDir, versionFileDir, devConfig_CanBeNone): configManager.SetHaConnection(haConnection) configManager.UpdateConfigIfNeeded() + # Setup the sage sub system, it won't be started until the primary connection is established. + self.Sage = SageHost(self.Logger) + # Now start the main runner! pluginConnectUrl = HostCommon.GetPluginConnectionUrl() if devLocalHomewayServerAddress_CanBeNone is not None: @@ -269,6 +274,9 @@ def OnPrimaryConnectionEstablished(self, apiKey, connectedAccounts): # Set the current API key to the event handler self.HaEventHandler.SetHomewayApiKey(apiKey) + # Once we have the API key, we can start the Sage system. + self.Sage.Start(self.GetPluginId(), apiKey) + # Set the current API key to the custom file server CustomFileServer.Get().UpdateAddonConfig(self.GetPluginId(), apiKey) diff --git a/homeway/homeway_linuxhost/sage/__init__.py b/homeway/homeway_linuxhost/sage/__init__.py new file mode 100644 index 0000000..564091d --- /dev/null +++ b/homeway/homeway_linuxhost/sage/__init__.py @@ -0,0 +1 @@ +# Need to make this a module diff --git a/homeway/homeway_linuxhost/sage/fabric.py b/homeway/homeway_linuxhost/sage/fabric.py new file mode 100644 index 0000000..1461025 --- /dev/null +++ b/homeway/homeway_linuxhost/sage/fabric.py @@ -0,0 +1,206 @@ +import time +import json +import logging +import threading +import octoflatbuffers + +from homeway.sentry import Sentry +from homeway.websocketimpl import Client + +from homeway.Proto import SageFiber + + +# Connects to Home Assistant and manages the connection. +class Fabric: + + # For debugging, it's too chatty to enable always. + c_LogWsMessages = False + + def __init__(self, logger:logging.Logger, pluginId:str, apiKey:str) -> None: + self.Logger = logger + #self.EventHandler = eventHandler + self.HaVersionString = None + + # The current websocket connection and Id + self.ConId = 0 + self.BackoffCounter = 0 + self.Ws = None + + # We need to send a message id with each message. + self.MsgIdLock = threading.Lock() + self.MsgId = 1 + + # Indicates if the connection is connection and authed. + self.IsConnected = False + + # If set, when the websocket is connected, we should send the HA restart command. + self.IssueRestartOnConnect = False + + self.PendingContext = None + + + def Start(self) -> None: + t = threading.Thread(target=self._ConnectionThread) + t.daemon = True + t.start() + + + # Called when the websocket is up and authed. + def _OnConnected(self) -> None: + self.Logger.info(f"{self._getLogTag()} Successfully authed and connected!") + self.IsConnected = True + + + + # Runs the main connection we maintain with Home Assistant. + def _ConnectionThread(self): + while True: + + # Reset the state vars + self.IsConnected = False + self.Ws = None + self.MsgId = 1 + + # If this isn't the first connection, sleep a bit before trying again. + if self.ConId != 0: + self.BackoffCounter += 1 + self.BackoffCounter = min(self.BackoffCounter, 12) + self.Logger.error(f"{self._getLogTag()} sleeping before trying the HA connection again.") + time.sleep(5 * self.BackoffCounter) + self.ConId += 1 + + try: + + # Setup our handlers. + + # This is called when the socket is opened. + def Opened(ws:Client): + self.Logger.info(f"{self._getLogTag()} Websocket opened") + + # Called when the websocket is closed. + def Closed(ws:Client): + self.Logger.info(f"{self._getLogTag()} Websocket closed") + + # Start the web socket connection. + # If we got auth from the env var, we running in the add on and use this address. + uri = "wss://homeway.io/sage-fabric-websocket" + self.Logger.info(f"{self._getLogTag()} Starting connection to [{uri}]") + self.Ws = Client(uri, onWsOpen=Opened, onWsData=self._OnData, onWsClose=Closed) + + # Run until success or failure. + self.Ws.RunUntilClosed() + + self.Logger.info(f"{self._getLogTag()} Loop restarting.") + + except Exception as e: + Sentry.Exception("ConnectionThread exception.", e) + + + def _OnData(self, ws:Client, buffer:bytes, msgType): + try: + # Parse the message + # sageFiber = SageFiber.SageFiber() + # sageFiber.Init(buffer, 0) + # text = sageFiber.Text() + + if self.PendingContext is not None: + self.PendingContext.Result = buffer.decode() + self.PendingContext.Event.set() + + + # jsonStr = buffer.decode() + # jsonObj = json.loads(jsonStr) + # if self.Logger.isEnabledFor(logging.DEBUG) and Connection.c_LogWsMessages: + # jsonFormatted = json.dumps(jsonObj, indent=2) + # self.Logger.debug(f"{self._getLogTag()} WS Message \r\n{jsonFormatted}\r\n") + + except Exception as e: + Sentry.Exception("ConnectionThread exception.", e) + self.Close() + + + def Listen(self, audio:bytes) -> str: + + try: + # builder = octoflatbuffers.Builder(len(audio) + 500) + + # audioOffset = builder.CreateByteVector(audio) + + # SageFiber.Start(builder) + # SageFiber.AddData(builder, audioOffset) + # streamMsgOffset = SageFiber.End(builder) + # SageFiber.fin + + # # Finalize the message. We use the size prefixed + # builder.FinishSizePrefixed(streamMsgOffset) + # builder.Output() + + # Instead of using Output, which will create a copy of the buffer that's trimmed, we return the fully built buffer + # with the header offset set and size. Flatbuffers are built backwards, so there's usually space in the front were we can add data + # without creating a new buffer! + # Note that the buffer is a bytearray + # buffer = builder.Bytes + # msgStartOffsetBytes = builder.Head() + # msgSize = len(buffer) - msgStartOffsetBytes + #return builder.Output() + self.Ws.Send(audio, 0, len(audio)) + + self.PendingContext = Context() + self.PendingContext.Event.wait(5) + text = self.PendingContext.Result + self.PendingContext = None + return text + except Exception as e: + self.Logger.error(str(e)) + return "" + + + # def SendMsg(self, msg:dict, ignoreConnectionState:bool = False) -> bool: + # # Check the connection state. + # if ignoreConnectionState is False: + # if self.IsConnected is False: + # self.Logger.error(f"{self._getLogTag()} message tired to be sent while we weren't authed.") + # return False + + # # Capture and check the websocket. + # ws = self.Ws + # if ws is None: + # self.Logger.error(f"{self._getLogTag()} message tired to be sent while we weren't connected.") + # return False + + # try: + # # Add the id field to all messages that are post auth. + # if self.IsConnected: + # with self.MsgIdLock: + # msg["id"] = self.MsgId + # self.MsgId += 1 + + # # Dump the message + # jsonStr = json.dumps(msg) + # if self.Logger.isEnabledFor(logging.DEBUG) and Connection.c_LogWsMessages: + # self.Logger.debug(f"{self._getLogTag()} Sending Ws Message {jsonStr}") + + # # Since we must encode the data, which will create a copy, we might as well just send the buffer as normal, + # # without adding the extra space for the header. We can add the header here or in the WS lib, it's the same amount of work. + # ws.Send(jsonStr.encode(), isData=False) + # return True + # except Exception as e: + # Sentry.Exception("SendMsg exception.", e) + # return False + + + # Closes the connection if it's open. + def Close(self) -> None: + ws = self.Ws + if ws is not None: + ws.Close() + + + def _getLogTag(self) -> str: + return f"HaCon [{self.ConId}]" + + +class Context: + def __init__(self): + self.Event = threading.Event() + self.Result = None diff --git a/homeway/homeway_linuxhost/sage/sagehandler.py b/homeway/homeway_linuxhost/sage/sagehandler.py new file mode 100644 index 0000000..2e1e458 --- /dev/null +++ b/homeway/homeway_linuxhost/sage/sagehandler.py @@ -0,0 +1,102 @@ +import logging +import time +import math + +from wyoming.asr import Transcribe, Transcript +from wyoming.tts import Synthesize +from wyoming.audio import AudioChunk, AudioStop +from wyoming.event import Event +from wyoming.handle import Handled +from wyoming.info import Describe, Info +from wyoming.server import AsyncEventHandler +from wyoming.audio import AudioChunk, AudioStart, AudioStop + +from homeway.httpsessions import HttpSessions +from .fabric import Fabric + + +class SageHandler(AsyncEventHandler): + + def __init__(self, info: Info, logger:logging.Logger, fabric: Fabric, *args, **kwargs) -> None: + super().__init__(*args, **kwargs) + self.Logger = logger + self.InfoEvent = info.event() + self.Fabric = fabric + self.IncomingAudioBuffer = bytearray() + + + # The main event handler for all wyoming events. + # Returning False will disconnect the client. + async def handle_event(self, event: Event) -> bool: + self.Logger.debug(f"Wyoming event: {event.type}") + + # Fired when the server is first connected to and the client wants to know what models are available. + if Describe.is_type(event.type): + await self.write_event(self.InfoEvent) + return True + + if AudioStart.is_type(event.type): + self.IncomingAudioBuffer = bytearray() + return True + if AudioChunk.is_type(event.type): + e = AudioChunk.from_event(event) + self.IncomingAudioBuffer.extend(AudioChunk.from_event(event).audio) + return True + if AudioStop.is_type(event.type): + start = time.time() + text = self.Fabric.Listen(self.IncomingAudioBuffer) + self.Logger.warn(f"Sage WS Listen End - {text} - time: {time.time() - start}") + await self.write_event(Transcript(text=text).event()) + return True + + if Transcribe.is_type(event.type): + transcribe = Transcribe.from_event(event) + return True + + # Fired when there's a input phrase from the user that the client wants to run the model on. + if Transcript.is_type(event.type): + transcript = Transcript.from_event(event) + self.Logger.info(f"Transcript: {transcript.text}") + await self.write_event(Handled("Are the office lights on?").event()) + return True + + # Fired when the client wants to synthesize a voice for the given text. + if Synthesize.is_type(event.type): + transcript = Synthesize.from_event(event) + + # Ensure all of the text is joined on one line. + # text = " ".join(transcript.text.strip().splitlines()) + + self.Logger.debug(f"Sage - Synthesize Start - {text}") + + start = time.time() + url = "https://homeway.io/api/sage/speak" + response = HttpSessions.GetSession(url).post(url, json={"Text": text}, timeout=120) + + # Compute the audio values. + data = response.content + rate = 24000 + width = 2 + channels = 1 + bytesPerSample = width * channels + bytesPerChunk = bytesPerSample * 1024 + chunks = int(math.ceil(len(data) / bytesPerChunk)) + + # Start the response. + await self.write_event(AudioStart(rate=rate, width=width, channels=channels).event()) + + # Write the audio chunks. + for i in range(chunks): + offset = i * bytesPerChunk + chunk = data[offset : offset + bytesPerChunk] + await self.write_event(AudioChunk(audio=chunk, rate=rate, width=width, channels=channels).event()) + + # Write the end event. + await self.write_event(AudioStop().event()) + self.Logger.warn(f"Sage Synthesize End - {text} - time: {time.time() - start}") + return True + + + # For all other events, return True. + # Returning False will disconnect the client. + return True \ No newline at end of file diff --git a/homeway/homeway_linuxhost/sage/sagehost.py b/homeway/homeway_linuxhost/sage/sagehost.py new file mode 100644 index 0000000..47b39de --- /dev/null +++ b/homeway/homeway_linuxhost/sage/sagehost.py @@ -0,0 +1,183 @@ +import asyncio +import logging +import threading +from functools import partial + +from wyoming.info import AsrModel, AsrProgram, Attribution, Info, TtsProgram, TtsVoice, TtsVoiceSpeaker, HandleProgram,HandleModel, IntentProgram, IntentModel +from wyoming.server import AsyncServer + +from homeway.sentry import Sentry + +from .sagehandler import SageHandler +from .fabric import Fabric + +# The main root host for Sage +class SageHost: + + # TODO - This should be dynamic to support multiple instances, but it can't change. + c_ServerPort = 8765 + + def __init__(self, logger:logging.Logger): + self.Logger = logger + self.PluginId:str = None + self.ApiKey:str = None + self.Fabric:Fabric = None + + + # Once the api key is known, we can start. + def Start(self, pluginId:str, apiKey:str): + self.PluginId = pluginId + self.ApiKey = apiKey + + # Start the fabric connection with Homeway + self.Fabric = Fabric(self.Logger, self.PluginId, self.ApiKey) + self.Fabric.Start() + + # Start an independent thread to run asyncio. + threading.Thread(target=self._run).start() + + + def _run(self): + # A main protector for the asyncio loop. + while True: + try: + asyncio.run(self._ServerThread()) + except Exception as e: + Sentry.Exception("SageHost Asyncio Error", e) + + + # The main asyncio loop for the server. + async def _ServerThread(self): + + info = self._GetInfo() + + self.Logger.info(f"Starting wyoming server on port {SageHost.c_ServerPort}") + server = AsyncServer.from_uri(f"tcp://0.0.0.0:{SageHost.c_ServerPort}") + model_lock = asyncio.Lock() + + # Run! + await server.run( + partial( + SageHandler, + info, + self.Logger, + self.Fabric, + ) + ) + + + def _GetInfo(self) -> Info: + models = [ + AsrModel( + name="Hw Test", + description="Some model?", + attribution=Attribution( + name="Homeway", + url="https://homeway.io/", + ), + installed=True, + languages=["en"], + version="0.0.1" + ) + ] + + voices = [ + TtsVoice( + name="homeway-voice", + description="test voice - deepgram long name test", + attribution=Attribution( + name="Homeway", + url="https://homeway.io/", + ), + installed=True, + version="0.0.1", + languages=[ + "en" + ], + speakers=[ + TtsVoiceSpeaker("name-speaker") + ] + ) + ] + + info = Info( + asr=[ + AsrProgram( + name="homeway-voice-speech-render", + description="Test", + attribution=Attribution( + name="Homeway", + url="https://homeway.io/", + ), + installed=True, + version="0.0.1", + models=models, + + ) + ], + tts=[ + TtsProgram( + name="homeway-text-to-speech", + description="test", + attribution=Attribution( + name="Homeway", + url="https://homeway.io/", + ), + installed=True, + voices=voices, + version="0.0.1", + ) + ], + handle= [ + HandleProgram( + name="homeway-handle", + description="test", + attribution=Attribution( + name="Homeway", + url="https://homeway.io/", + ), + installed=True, + version="0.0.1", + models=[ + HandleModel( + name="homeway-handle-model", + description="test", + attribution=Attribution( + name="Homeway", + url="https://homeway.io/", + ), + installed=True, + version="0.0.1", + languages=["en"], + ) + ] + ) + ], + intent= + [ + IntentProgram( + name="homeway-intent", + description="test", + attribution=Attribution( + name="Homeway", + url="https://homeway.io/", + ), + installed=True, + version="0.0.1", + models=[ + IntentModel( + name="homeway-intent-model", + description="test", + attribution=Attribution( + name="Homeway", + url="https://homeway.io/", + ), + installed=True, + version="0.0.1", + languages=["en"], + ) + ] + ) + ] + ) + return info