Skip to content

Commit

Permalink
live/infer: Allow running with zeromq protocol (#267)
Browse files Browse the repository at this point in the history
* live/infer: Create protocol arg

* live/inver: Instantiate a different streamer based on protocol

* live/infer: Cleanup args/docs
  • Loading branch information
victorges authored Nov 8, 2024
1 parent 232cf6e commit 5031a50
Showing 1 changed file with 19 additions and 17 deletions.
36 changes: 19 additions & 17 deletions runner/app/live/infer.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,10 +14,17 @@

from params_api import start_http_server
from streamer.trickle import TrickleStreamer
from streamer.zeromq import ZeroMQStreamer


async def main(http_port: int, subscribe_url: str, publish_url: str, pipeline: str, params: dict):
handler = TrickleStreamer(subscribe_url, publish_url, pipeline, **(params or {}))
async def main(http_port: int, stream_protocol: str, subscribe_url: str, publish_url: str, pipeline: str, params: dict):
if stream_protocol == "trickle":
handler = TrickleStreamer(subscribe_url, publish_url, pipeline, **(params or {}))
elif stream_protocol == "zeromq":
handler = ZeroMQStreamer(subscribe_url, publish_url, pipeline, **(params or {}))
else:
raise ValueError(f"Unsupported protocol: {stream_protocol}")

runner = None
try:
handler.start()
Expand Down Expand Up @@ -55,29 +62,24 @@ def signal_handler(sig, _):
parser.add_argument(
"--http-port", type=int, default=8888, help="Port for the HTTP server"
)
parser.add_argument(
"--input-address",
type=str,
default="tcp://localhost:5555",
help="Address for the input socket",
)
parser.add_argument(
"--output-address",
type=str,
default="tcp://localhost:5556",
help="Address for the output socket",
)
parser.add_argument(
"--pipeline", type=str, default="streamdiffusion", help="Pipeline to use"
)
parser.add_argument(
"--initial-params", type=str, default="{}", help="Initial parameters for the pipeline"
)
parser.add_argument(
"--subscribe-url", type=str, required=True, help="url to pull incoming streams"
"--stream-protocol",
type=str,
choices=["trickle", "zeromq"],
default="trickle",
help="Protocol to use for streaming frames in and out. One of: trickle, zeromq"
)
parser.add_argument(
"--subscribe-url", type=str, required=True, help="URL to subscribe for the input frames (trickle). For zeromq this is the input socket address"
)
parser.add_argument(
"--publish-url", type=str, required=True, help="url to push outgoing streams"
"--publish-url", type=str, required=True, help="URL to publish output frames (trickle). For zeromq this is the output socket address"
)
parser.add_argument(
"-v", "--verbose",
Expand All @@ -101,7 +103,7 @@ def signal_handler(sig, _):

try:
asyncio.run(
main(args.http_port, args.subscribe_url, args.publish_url, args.pipeline, params)
main(args.http_port, args.stream_protocol, args.subscribe_url, args.publish_url, args.pipeline, params)
)
except Exception as e:
logging.error(f"Fatal error in main: {e}")
Expand Down

0 comments on commit 5031a50

Please sign in to comment.