Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

worker+runner: Dynamically start live inference runners #275

Merged
merged 17 commits into from
Nov 19, 2024
Merged
Show file tree
Hide file tree
Changes from 10 commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
24 changes: 15 additions & 9 deletions runner/app/live/infer.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,11 +17,11 @@
from streamer.zeromq import ZeroMQStreamer


async def main(http_port: int, stream_protocol: str, subscribe_url: str, publish_url: str, pipeline: str, params: dict):
async def main(http_port: int, stream_protocol: str, subscribe_url: str, publish_url: str, pipeline: str, input_timeout: int, params: dict):
if stream_protocol == "trickle":
handler = TrickleStreamer(subscribe_url, publish_url, pipeline, **(params or {}))
handler = TrickleStreamer(subscribe_url, publish_url, pipeline, input_timeout, params or {})
elif stream_protocol == "zeromq":
handler = ZeroMQStreamer(subscribe_url, publish_url, pipeline, **(params or {}))
handler = ZeroMQStreamer(subscribe_url, publish_url, pipeline, input_timeout, params or {})
else:
raise ValueError(f"Unsupported protocol: {stream_protocol}")

Expand All @@ -34,14 +34,14 @@ async def main(http_port: int, stream_protocol: str, subscribe_url: str, publish
logging.error(f"Stack trace:\n{traceback.format_exc()}")
raise e

await block_until_signal([signal.SIGINT, signal.SIGTERM])
try:
await asyncio.wait(
[block_until_signal([signal.SIGINT, signal.SIGTERM]), handler.wait()],
return_when=asyncio.FIRST_COMPLETED
)
finally:
await runner.cleanup()
await handler.stop()
except Exception as e:
logging.error(f"Error stopping room handler: {e}")
logging.error(f"Stack trace:\n{traceback.format_exc()}")
raise e


async def block_until_signal(sigs: List[signal.Signals]):
Expand Down Expand Up @@ -81,6 +81,12 @@ def signal_handler(sig, _):
parser.add_argument(
"--publish-url", type=str, required=True, help="URL to publish output frames (trickle). For zeromq this is the output socket address"
)
parser.add_argument(
"--input-timeout",
type=int,
default=60,
help="Timeout in seconds to wait after input frames stop before shutting down. Set to 0 to disable."
)
parser.add_argument(
"-v", "--verbose",
action="store_true",
Expand All @@ -103,7 +109,7 @@ def signal_handler(sig, _):

try:
asyncio.run(
main(args.http_port, args.stream_protocol, args.subscribe_url, args.publish_url, args.pipeline, params)
main(args.http_port, args.stream_protocol, args.subscribe_url, args.publish_url, args.pipeline, args.input_timeout, params)
)
except Exception as e:
logging.error(f"Fatal error in main: {e}")
Expand Down
2 changes: 1 addition & 1 deletion runner/app/live/params_api/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,7 @@ async def handle_params_update(request):
raise ValueError(f"Unknown content type: {request.content_type}")

handler = cast(PipelineStreamer, request.app["handler"])
handler.update_params(**params)
handler.update_params(params)

return web.Response(text="Params updated successfully")
except Exception as e:
Expand Down
19 changes: 17 additions & 2 deletions runner/app/live/streamer/streamer.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,18 +15,28 @@


class PipelineStreamer(ABC):
def __init__(self, pipeline: str, **params):
def __init__(self, pipeline: str, input_timeout: int, params: dict):
self.pipeline = pipeline
self.params = params
self.process = None
self.last_params_time = 0.0
self.restart_count = 0
self.input_timeout = input_timeout # 0 means disabled
self.done_future = None

def start(self):
self.done_future = asyncio.get_running_loop().create_future()
self._start_process()

async def wait(self):
if not self.done_future:
raise RuntimeError("Streamer not started")
return await self.done_future

async def stop(self):
await self._stop_process()
if self.done_future and not self.done_future.done():
self.done_future.set_result(None)

def _start_process(self):
if self.process:
Expand Down Expand Up @@ -68,7 +78,7 @@ async def _restart(self):
logging.error(f"Stack trace:\n{traceback.format_exc()}")
os._exit(1)

def update_params(self, **params):
def update_params(self, params: dict):
self.params = params
self.last_params_time = time.time()
if self.process:
Expand All @@ -92,6 +102,11 @@ async def monitor_loop(self, done: Event):
time_since_last_params = current_time - self.last_params_time
time_since_reload = min(time_since_last_params, time_since_start)

if self.input_timeout > 0 and time_since_last_input >= self.input_timeout:
logging.info(f"Input stream stopped for {self.input_timeout} seconds. Shutting down...")
await self.stop()
return

gone_stale = (
time_since_last_output > time_since_last_input
and time_since_last_output > 60
Expand Down
5 changes: 3 additions & 2 deletions runner/app/live/streamer/trickle.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,9 +17,10 @@ def __init__(
subscribe_url: str,
publish_url: str,
pipeline: str,
**params,
input_timeout: int,
params: dict,
):
super().__init__(pipeline, **params)
super().__init__(pipeline, input_timeout, params)
self.subscribe_url = subscribe_url
self.publish_url = publish_url
self.subscribe_queue = queue.Queue[bytearray]()
Expand Down
5 changes: 3 additions & 2 deletions runner/app/live/streamer/zeromq.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,9 +15,10 @@ def __init__(
input_address: str,
output_address: str,
pipeline: str,
**params,
input_timeout: int,
params: dict,
):
super().__init__(pipeline, **params)
super().__init__(pipeline, input_timeout, params)
self.input_address = input_address
self.output_address = output_address

Expand Down
3 changes: 2 additions & 1 deletion runner/app/pipelines/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,9 @@
class Pipeline(ABC):
@abstractmethod
def __init__(self, model_id: str, model_dir: str):
self.model_id: str # type hint so we can use the field in routes
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

what does this comment mean?

# type hint so we can use the field in routes

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This avoids a typing error on all the route/ implementations when they used the model_id field on the abstract Pipeline class (currently tons of typing errors on this proj, I'll try to clear them up as I make changes).

I'm changing this comment to this which might be clearer?

# declare the field here so the type hint is available when using this abstract class

raise NotImplementedError("Pipeline should implement an __init__ method")

@abstractmethod
def __call__(self, inputs: Any) -> Any:
def __call__(self, **kwargs) -> Any:
raise NotImplementedError("Pipeline should implement a __call__ method")
32 changes: 19 additions & 13 deletions runner/app/pipelines/live_video_to_video.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,21 +26,23 @@ def __init__(self, model_id: str):
self.monitor_thread = None
self.log_thread = None

def __call__(
self, **kwargs

def __call__( # type: ignore
self, *, subscribe_url: str, publish_url: str, params: dict, **kwargs
):
if self.process:
raise RuntimeError("Pipeline already running")

try:
if not self.process:
self.start_process(
pipeline=self.model_id, # we use the model_id as the pipeline name for now
http_port=8888,
subscribe_url=kwargs["subscribe_url"],
publish_url=kwargs["publish_url"],
initial_params=json.dumps(kwargs["params"]),
# TODO: set torch device from self.torch_device
)
logger.info(f"Starting stream, subscribe={kwargs['subscribe_url']} publish={kwargs['publish_url']}")
return
logger.info(f"Starting stream, subscribe={subscribe_url} publish={publish_url}")
self.start_process(
pipeline=self.model_id, # we use the model_id as the pipeline name for now
http_port=8888,
subscribe_url=subscribe_url,
publish_url=publish_url,
initial_params=json.dumps(params),
# TODO: set torch device from self.torch_device
)
except Exception as e:
raise InferenceError(original_exception=e)

Expand Down Expand Up @@ -82,6 +84,10 @@ def monitor_process(self):
logger.error(
f"infer.py process failed with return code {return_code}. Error: {stderr}"
)
else:
# If process exited cleanly (return code 0) and exit the main process
logger.info("infer.py process exited cleanly, shutting down...")
sys.exit(0)
break

logger.info("infer.py process is running...")
Expand Down
17 changes: 7 additions & 10 deletions runner/app/routes/live_video_to_video.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
import logging
import os
import random
from typing import Annotated, Dict, Tuple, Union
from typing import Annotated, Any, Dict, Tuple, Union

import torch
import traceback
Expand Down Expand Up @@ -48,18 +47,18 @@ class LiveVideoToVideoParams(BaseModel):
model_id: Annotated[
str,
Field(
default="", description="Hugging Face model ID used for image generation."
default="", description="Name of the pipeline to run in the live video to video job. Notice that this is named model_id for consistency with other routes, but it does not refer to a Hugging Face model ID. The exact model(s) depends on the pipeline implementation and might be configurable via the `params` argument."
),
]
params: Annotated[
Dict,
Field(
default={},
description="Initial parameters for the model."
description="Initial parameters for the pipeline."
),
]

RESPONSES = {
RESPONSES: dict[int | str, dict[str, Any]]= {
status.HTTP_200_OK: {
"content": {
"application/json": {
Expand All @@ -78,9 +77,9 @@ class LiveVideoToVideoParams(BaseModel):
"/live-video-to-video",
response_model=LiveVideoToVideoResponse,
responses=RESPONSES,
description="Apply video-like transformations to a provided image.",
description="Apply transformations to a live video streamed to the returned endpoints.",
operation_id="genLiveVideoToVideo",
summary="Video To Video",
summary="Live Video To Video",
tags=["generate"],
openapi_extra={"x-speakeasy-name-override": "liveVideoToVideo"},
)
Expand Down Expand Up @@ -113,10 +112,8 @@ async def live_video_to_video(
),
)

seed = random.randint(0, 2**32 - 1)
kwargs = {k: v for k, v in params.model_dump().items()}
try:
pipeline(**kwargs)
pipeline(**params.model_dump())
except Exception as e:
if isinstance(e, torch.cuda.OutOfMemoryError):
torch.cuda.empty_cache()
Expand Down
4 changes: 2 additions & 2 deletions runner/app/routes/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -228,9 +228,9 @@ def json_str_to_np_array(

def handle_pipeline_exception(
e: object,
default_error_message: Union[str, Dict[str, object]] = "Pipeline error.",
default_error_message: Union[str, Dict[str, object], None] = "Pipeline error.",
default_status_code: int = status.HTTP_500_INTERNAL_SERVER_ERROR,
custom_error_config: Dict[str, Tuple[str, int]] = None,
custom_error_config: Dict[str, Tuple[str | None, int]] | None = None,
) -> JSONResponse:
"""Handles pipeline exceptions by returning a JSON response with the appropriate
error message and status code.
Expand Down
13 changes: 9 additions & 4 deletions runner/gateway.openapi.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -415,8 +415,9 @@ paths:
post:
tags:
- generate
summary: Video To Video
description: Apply video-like transformations to a provided image.
summary: Live Video To Video
description: Apply transformations to a live video streamed to the returned
endpoints.
operationId: genLiveVideoToVideo
requestBody:
content:
Expand Down Expand Up @@ -927,12 +928,16 @@ components:
model_id:
type: string
title: Model Id
description: Hugging Face model ID used for image generation.
description: Name of the pipeline to run in the live video to video job.
Notice that this is named model_id for consistency with other routes,
but it does not refer to a Hugging Face model ID. The exact model depends
on the pipeline implementation and might be configurable via the `params`
argument.
default: ''
params:
type: object
title: Params
description: Initial parameters for the model.
description: Initial parameters for the pipeline.
default: {}
type: object
required:
Expand Down
13 changes: 9 additions & 4 deletions runner/openapi.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -426,8 +426,9 @@ paths:
post:
tags:
- generate
summary: Video To Video
description: Apply video-like transformations to a provided image.
summary: Live Video To Video
description: Apply transformations to a live video streamed to the returned
endpoints.
operationId: genLiveVideoToVideo
requestBody:
content:
Expand Down Expand Up @@ -944,12 +945,16 @@ components:
model_id:
type: string
title: Model Id
description: Hugging Face model ID used for image generation.
description: Name of the pipeline to run in the live video to video job.
Notice that this is named model_id for consistency with other routes,
but it does not refer to a Hugging Face model ID. The exact model depends
on the pipeline implementation and might be configurable via the `params`
argument.
default: ''
params:
type: object
title: Params
description: Initial parameters for the model.
description: Initial parameters for the pipeline.
default: {}
type: object
required:
Expand Down
1 change: 1 addition & 0 deletions runner/requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -19,3 +19,4 @@ sentencepiece== 0.2.0
protobuf==5.27.2
bitsandbytes==0.43.3
psutil==6.0.0
PyYAML==6.0.2
Loading
Loading