Skip to content

Commit

Permalink
feat: add signal handling for graceful shutdown in main execution (#86)
Browse files Browse the repository at this point in the history
* feat: enhance argument parsing and validation with environment variable support

* feat: add signal handling for graceful shutdown in main execution
  • Loading branch information
hugobloem authored Nov 23, 2024
1 parent 88ff459 commit a63a1f4
Showing 1 changed file with 43 additions and 6 deletions.
49 changes: 43 additions & 6 deletions wyoming_microsoft_tts/__main__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,9 @@
import asyncio
import contextlib
import logging
import os
import re
import signal
from functools import partial
from typing import Any

Expand All @@ -14,18 +17,26 @@

_LOGGER = logging.getLogger(__name__)

stop_event = asyncio.Event()

async def main() -> None:
"""Start Wyoming Microsoft TTS server."""

def handle_stop_signal(*args):
"""Handle shutdown signal and set the stop event."""
_LOGGER.info("Received stop signal. Shutting down...")
stop_event.set()


def parse_arguments():
"""Parse command-line arguments."""
parser = argparse.ArgumentParser()
parser.add_argument(
"--service-region",
required=True,
default=os.getenv("AZURE_SERVICE_REGION"),
help="Microsoft Azure region (e.g., westus2)",
)
parser.add_argument(
"--subscription-key",
required=True,
default=os.getenv("AZURE_SUBSCRIPTION_KEY"),
help="Microsoft Azure subscription key",
)
parser.add_argument(
Expand Down Expand Up @@ -59,9 +70,30 @@ async def main() -> None:
)
#
parser.add_argument("--debug", action="store_true", help="Log DEBUG messages")
args = parser.parse_args()
return parser.parse_args()


def validate_args(args):
"""Validate command-line arguments."""
if not args.service_region or not args.subscription_key:
raise ValueError(
"Both --service-region and --subscription-key must be provided either as command-line arguments or environment variables."
)
# Reinstate key validation with more flexibility to accommodate complex keys
if not re.match(r"^[A-Za-z0-9\-_]{40,}$", args.subscription_key):
_LOGGER.warning(
"The subscription key does not match the expected format but will attempt to initialize."
)


async def main() -> None:
"""Start Wyoming Microsoft TTS server."""
args = parse_arguments()
validate_args(args)

# setup logging
logging.basicConfig(level=logging.DEBUG if args.debug else logging.INFO)
_LOGGER.debug("Arguments parsed successfully.")

# Load voice info
voices_info = get_voices(
Expand All @@ -80,7 +112,8 @@ async def main() -> None:
# Make sure default voice is in the list
if args.voice not in voices_info:
raise ValueError(
f"Voice {args.voice} not found in voices.json, please look up the correct voice name here\nhttps://learn.microsoft.com/en-us/azure/ai-services/speech-service/language-support?tabs=tts"
f"Voice {args.voice} not found in voices.json, please look up the correct voice name here"
+ "\nhttps://learn.microsoft.com/en-us/azure/ai-services/speech-service/language-support?tabs=tts"
)

voices_info.update(aliases_info)
Expand Down Expand Up @@ -157,5 +190,9 @@ def get_description(voice_info: dict[str, Any]):
# -----------------------------------------------------------------------------

if __name__ == "__main__":
# Set up signal handling for graceful shutdown
signal.signal(signal.SIGTERM, handle_stop_signal)
signal.signal(signal.SIGINT, handle_stop_signal)

with contextlib.suppress(KeyboardInterrupt):
asyncio.run(main())

0 comments on commit a63a1f4

Please sign in to comment.