From 5031a509e417c3c35021fa7f68972d107aac87a6 Mon Sep 17 00:00:00 2001 From: Victor Elias Date: Fri, 8 Nov 2024 17:36:45 +0000 Subject: [PATCH] live/infer: Allow running with zeromq protocol (#267) * live/infer: Create protocol arg * live/inver: Instantiate a different streamer based on protocol * live/infer: Cleanup args/docs --- runner/app/live/infer.py | 36 +++++++++++++++++++----------------- 1 file changed, 19 insertions(+), 17 deletions(-) diff --git a/runner/app/live/infer.py b/runner/app/live/infer.py index 21b4303f..cb9428ba 100644 --- a/runner/app/live/infer.py +++ b/runner/app/live/infer.py @@ -14,10 +14,17 @@ from params_api import start_http_server from streamer.trickle import TrickleStreamer +from streamer.zeromq import ZeroMQStreamer -async def main(http_port: int, subscribe_url: str, publish_url: str, pipeline: str, params: dict): - handler = TrickleStreamer(subscribe_url, publish_url, pipeline, **(params or {})) +async def main(http_port: int, stream_protocol: str, subscribe_url: str, publish_url: str, pipeline: str, params: dict): + if stream_protocol == "trickle": + handler = TrickleStreamer(subscribe_url, publish_url, pipeline, **(params or {})) + elif stream_protocol == "zeromq": + handler = ZeroMQStreamer(subscribe_url, publish_url, pipeline, **(params or {})) + else: + raise ValueError(f"Unsupported protocol: {stream_protocol}") + runner = None try: handler.start() @@ -55,18 +62,6 @@ def signal_handler(sig, _): parser.add_argument( "--http-port", type=int, default=8888, help="Port for the HTTP server" ) - parser.add_argument( - "--input-address", - type=str, - default="tcp://localhost:5555", - help="Address for the input socket", - ) - parser.add_argument( - "--output-address", - type=str, - default="tcp://localhost:5556", - help="Address for the output socket", - ) parser.add_argument( "--pipeline", type=str, default="streamdiffusion", help="Pipeline to use" ) @@ -74,10 +69,17 @@ def signal_handler(sig, _): "--initial-params", type=str, default="{}", help="Initial parameters for the pipeline" ) parser.add_argument( - "--subscribe-url", type=str, required=True, help="url to pull incoming streams" + "--stream-protocol", + type=str, + choices=["trickle", "zeromq"], + default="trickle", + help="Protocol to use for streaming frames in and out. One of: trickle, zeromq" + ) + parser.add_argument( + "--subscribe-url", type=str, required=True, help="URL to subscribe for the input frames (trickle). For zeromq this is the input socket address" ) parser.add_argument( - "--publish-url", type=str, required=True, help="url to push outgoing streams" + "--publish-url", type=str, required=True, help="URL to publish output frames (trickle). For zeromq this is the output socket address" ) parser.add_argument( "-v", "--verbose", @@ -101,7 +103,7 @@ def signal_handler(sig, _): try: asyncio.run( - main(args.http_port, args.subscribe_url, args.publish_url, args.pipeline, params) + main(args.http_port, args.stream_protocol, args.subscribe_url, args.publish_url, args.pipeline, params) ) except Exception as e: logging.error(f"Fatal error in main: {e}")