Skip to content

Commit

Permalink
Merge remote-tracking branch 'origin/main' into rafal/comfyui
Browse files Browse the repository at this point in the history
  • Loading branch information
leszko committed Nov 12, 2024
2 parents e1ffcfa + f371196 commit 345cc0e
Show file tree
Hide file tree
Showing 8 changed files with 54 additions and 49 deletions.
36 changes: 19 additions & 17 deletions runner/app/live/infer.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,10 +14,17 @@

from params_api import start_http_server
from streamer.trickle import TrickleStreamer
from streamer.zeromq import ZeroMQStreamer


async def main(http_port: int, subscribe_url: str, publish_url: str, pipeline: str, params: dict):
handler = TrickleStreamer(subscribe_url, publish_url, pipeline, **(params or {}))
async def main(http_port: int, stream_protocol: str, subscribe_url: str, publish_url: str, pipeline: str, params: dict):
if stream_protocol == "trickle":
handler = TrickleStreamer(subscribe_url, publish_url, pipeline, **(params or {}))
elif stream_protocol == "zeromq":
handler = ZeroMQStreamer(subscribe_url, publish_url, pipeline, **(params or {}))
else:
raise ValueError(f"Unsupported protocol: {stream_protocol}")

runner = None
try:
handler.start()
Expand Down Expand Up @@ -55,29 +62,24 @@ def signal_handler(sig, _):
parser.add_argument(
"--http-port", type=int, default=8888, help="Port for the HTTP server"
)
parser.add_argument(
"--input-address",
type=str,
default="tcp://localhost:5555",
help="Address for the input socket",
)
parser.add_argument(
"--output-address",
type=str,
default="tcp://localhost:5556",
help="Address for the output socket",
)
parser.add_argument(
"--pipeline", type=str, default="streamdiffusion", help="Pipeline to use"
)
parser.add_argument(
"--initial-params", type=str, default="{}", help="Initial parameters for the pipeline"
)
parser.add_argument(
"--subscribe-url", type=str, required=True, help="url to pull incoming streams"
"--stream-protocol",
type=str,
choices=["trickle", "zeromq"],
default="trickle",
help="Protocol to use for streaming frames in and out. One of: trickle, zeromq"
)
parser.add_argument(
"--subscribe-url", type=str, required=True, help="URL to subscribe for the input frames (trickle). For zeromq this is the input socket address"
)
parser.add_argument(
"--publish-url", type=str, required=True, help="url to push outgoing streams"
"--publish-url", type=str, required=True, help="URL to publish output frames (trickle). For zeromq this is the output socket address"
)
parser.add_argument(
"-v", "--verbose",
Expand All @@ -101,7 +103,7 @@ def signal_handler(sig, _):

try:
asyncio.run(
main(args.http_port, args.subscribe_url, args.publish_url, args.pipeline, params)
main(args.http_port, args.stream_protocol, args.subscribe_url, args.publish_url, args.pipeline, params)
)
except Exception as e:
logging.error(f"Fatal error in main: {e}")
Expand Down
8 changes: 6 additions & 2 deletions runner/app/live/pipelines/streamdiffusion.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ class Config:
model_id: str = "KBlueLeaf/kohaku-v2.1"
lora_dict: Optional[Dict[str, float]] = None
use_lcm_lora: bool = True
lcm_lora_id: str = "latent-consistency/lcm-lora-sdv1-5"
num_inference_steps: int = 50
t_index_list: Optional[List[int]] = None
t_index_ratio_list: Optional[List[float]] = [0.75, 0.9, 0.975]
Expand All @@ -25,6 +26,8 @@ class Config:
enable_similar_image_filter: bool = False
seed: int = 2
guidance_scale: float = 1.2
do_add_noise: bool = False
similar_image_filter_threshold: float = 0.98

def __init__(self, **data):
super().__init__(**data)
Expand Down Expand Up @@ -68,17 +71,18 @@ def update_params(self, **params):
model_id_or_path=new_params.model_id,
lora_dict=new_params.lora_dict,
use_lcm_lora=new_params.use_lcm_lora,
lcm_lora_id=new_params.lcm_lora_id,
t_index_list=new_params.t_index_list,
frame_buffer_size=1,
width=512,
height=512,
warmup=10,
acceleration=new_params.acceleration,
do_add_noise=False,
do_add_noise=new_params.do_add_noise,
mode="img2img",
# output_type="pt",
enable_similar_image_filter=new_params.enable_similar_image_filter,
similar_image_filter_threshold=0.98,
similar_image_filter_threshold=new_params.similar_image_filter_threshold,
use_denoising_batch=new_params.use_denoising_batch,
seed=new_params.seed,
)
Expand Down
1 change: 0 additions & 1 deletion runner/app/live/trickle/jpeg_parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,6 @@ def __enter__(self):
return self

def __exit__(self, exec_type, exec_val, exec_tb):
logging.info("JOSH closing jpeg parser via exit")
self.close()

def close(self):
Expand Down
3 changes: 1 addition & 2 deletions runner/app/live/trickle/media.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,6 @@

# target framerate
FRAMERATE=segmenter.FRAMERATE
GOP_SECS=segmenter.GOP_SECS

# TODO make this better configurable
GPU=segmenter.GPU
Expand All @@ -32,7 +31,7 @@ async def run_subscribe(subscribe_url: str, image_callback):

async def subscribe(subscribe_url, out_pipe):
subscriber = TrickleSubscriber(url=subscribe_url)
logging.info(f"JOSH - launching subscribe loop for {subscribe_url}")
logging.info(f"launching subscribe loop for {subscribe_url}")
while True:
segment = None
try:
Expand Down
10 changes: 5 additions & 5 deletions runner/app/live/trickle/segmenter.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,12 +41,12 @@ def ffmpeg_cmd(out_pattern):
cmd = [
'ffmpeg',
'-loglevel', 'warning',
'-use_wallclock_as_timestamps', '1',
'-f', 'image2pipe',
'-framerate', f"{FRAMERATE}",
'-i', 'pipe:0', # stdin
'-c:v', 'h264_nvenc',
'-bf', '0', # disable bframes for webrtc
'-g', f'{GOP_SECS*FRAMERATE}',
'-force_key_frames', f'expr:gte(t,n_forced*{GOP_SECS})',
'-preset', 'p1',
'-tune', 'ull',
'-f', 'segment',
Expand All @@ -56,19 +56,19 @@ def ffmpeg_cmd(out_pattern):
cmd = [
'ffmpeg',
'-loglevel', 'warning',
'-use_wallclock_as_timestamps', '1',
'-f', 'image2pipe',
'-framerate', f"{FRAMERATE}",
'-i', 'pipe:0', # stdin
'-c:v', 'libx264',
'-bf', '0', # disable bframes for webrtc
'-g', f'{GOP_SECS*FRAMERATE}',
'-force_key_frames', f'expr:gte(t,n_forced*{GOP_SECS})',
'-preset', 'superfast',
'-tune', 'zerolatency',
'-f', 'segment',
out_pattern
]

logging.info(f"JOSH - ffmpeg (output) {cmd}")
logging.info(f"ffmpeg (output) {cmd}")
return cmd


Expand Down
2 changes: 1 addition & 1 deletion runner/app/live/trickle/trickle_publisher.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ def __init__(self, url: str, mime_type: str):
self.idx = 0 # Start index for POSTs
self.next_writer = None
self.lock = asyncio.Lock() # Lock to manage concurrent access
self.session = aiohttp.ClientSession()
self.session = aiohttp.ClientSession(connector=aiohttp.TCPConnector(verify_ssl=False))

async def __aenter__(self):
"""Enter context manager."""
Expand Down
41 changes: 22 additions & 19 deletions runner/app/live/trickle/trickle_subscriber.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,20 +9,9 @@ def __init__(self, url: str):
self.idx = -1 # Start with -1 for 'latest' index
self.pending_get = None # Pre-initialized GET request
self.lock = asyncio.Lock() # Lock to manage concurrent access
self.session = aiohttp.ClientSession()
self.session = aiohttp.ClientSession(connector=aiohttp.TCPConnector(verify_ssl=False))
self.errored = False

async def get_index(self, resp):
"""Extract the index from the response headers."""
if resp is None:
return -1
idx_str = resp.headers.get("Lp-Trickle-Idx")
try:
idx = int(idx_str)
except (TypeError, ValueError):
return -1
return idx

async def preconnect(self):
"""Preconnect to the server by making a GET request to fetch the next segment."""
url = f"{self.base_url}/{self.idx}"
Expand Down Expand Up @@ -58,18 +47,23 @@ async def next(self):
self.pending_get = await self.preconnect()

# Extract the current connection to use for reading
conn = self.pending_get
resp = self.pending_get
self.pending_get = None

# Extract and set the next index from the response headers
idx = await self.get_index(conn)
if idx != -1:
segment = Segment(resp)

if segment.eos():
return None

idx = segment.seq()
if idx >= 0:
self.idx = idx + 1

# Set up the next connection in the background
asyncio.create_task(self._preconnect_next_segment())

return Segment(conn)
return segment

async def _preconnect_next_segment(self):
"""Preconnect to the next segment in the background."""
Expand All @@ -80,14 +74,23 @@ async def _preconnect_next_segment(self):
next_conn = await self.preconnect()
if next_conn:
self.pending_get = next_conn
next_idx = await self.get_index(next_conn)
if next_idx != -1:
self.idx = next_idx + 1

class Segment:
def __init__(self, response):
self.response = response

def seq(self):
"""Extract the sequence number from the response headers."""
seq_str = self.response.headers.get('Lp-Trickle-Seq')
try:
seq = int(seq_str)
except (TypeError, ValueError):
return -1
return seq

def eos(self):
return self.response.headers.get('Lp-Trickle-Closed') != None

async def read(self, chunk_size=32 * 1024):
"""Read the next chunk of the segment."""
if not self.response:
Expand Down
2 changes: 0 additions & 2 deletions runner/app/pipelines/live_video_to_video.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,8 +33,6 @@ def __call__(
if not self.process:
self.start_process(
pipeline=self.model_id, # we use the model_id as the pipeline name for now
input_address="tcp://localhost:5555",
output_address="tcp://localhost:5556",
http_port=8888,
subscribe_url=kwargs["subscribe_url"],
publish_url=kwargs["publish_url"],
Expand Down

0 comments on commit 345cc0e

Please sign in to comment.