Skip to content

Commit

Permalink
feat:initial implementation to add support for LTX-Video model
Browse files Browse the repository at this point in the history
  • Loading branch information
RUFFY-369 committed Jan 8, 2025
1 parent 5237d5c commit 4d2275a
Show file tree
Hide file tree
Showing 3 changed files with 53 additions and 2 deletions.
35 changes: 33 additions & 2 deletions runner/app/pipelines/image_to_video.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@

import PIL
import torch
from diffusers import StableVideoDiffusionPipeline
from diffusers import LTXImageToVideoPipeline, StableVideoDiffusionPipeline
from huggingface_hub import file_download
from PIL import ImageFile

Expand All @@ -22,6 +22,8 @@

class ImageToVideoPipeline(Pipeline):
def __init__(self, model_id: str):
self.pipeline_name = ""

self.model_id = model_id
kwargs = {"cache_dir": get_model_dir()}

Expand All @@ -41,7 +43,28 @@ def __init__(self, model_id: str):
kwargs["torch_dtype"] = torch.float16
kwargs["variant"] = "fp16"

self.ldm = StableVideoDiffusionPipeline.from_pretrained(model_id, **kwargs)
try:
if any(substring in model_id.lower() for substring in ("ltx-video", "ltx")):
logger.info("Loading LTXImageToVideoPipeline for model_id: %s", model_id)
self.pipeline_name = "LTXImageToVideoPipeline"
self.ldm = LTXImageToVideoPipeline.from_pretrained(model_id, **kwargs)
else:
logger.info("Loading StableVideoDiffusionPipeline for model_id: %s", model_id)
self.pipeline_name = "StableVideoDiffusionPipeline"
self.ldm = StableVideoDiffusionPipeline.from_pretrained(model_id, **kwargs)
except Exception as loading_error:
logger.error("Failed to load %s : %s." %(self.pipeline_name,loading_error))
# Trying to load the LTXImageToVideoPipeline if the StableVideoDiffusionPipeline fails to load and there is a chance that model name doesn't match the if condition for LTX-Video
# (for future extra models support)
try:
logger.info("Trying LTXImageToVideoPipeline for model_id: %s", model_id)
self.pipeline_name = "LTXImageToVideoPipeline"
self.ldm = LTXImageToVideoPipeline.from_pretrained(model_id, **kwargs)
except Exception as loading_error:
logger.error("Failed to load both LTXImageToVideoPipeline and StableVideoDiffusionPipeline: %s. Please ensure the model ID is compatible.", loading_error)
raise loading_error


self.ldm.to(get_torch_device())

sfast_enabled = os.getenv("SFAST", "").strip().lower() == "true"
Expand Down Expand Up @@ -113,6 +136,14 @@ def __call__(
seed = kwargs.pop("seed", None)
safety_check = kwargs.pop("safety_check", True)

if self.pipeline_name == "LTXImageToVideoPipeline":
del kwargs["fps"]
del kwargs["motion_bucket_id"]
del kwargs["noise_aug_strength"]
elif self.pipeline_name == "StableVideoDiffusionPipeline":
del kwargs["prompt"]
del kwargs["negative_prompt"]

if "decode_chunk_size" not in kwargs:
# Decrease decode_chunk_size to reduce memory usage.
kwargs["decode_chunk_size"] = 4
Expand Down
19 changes: 19 additions & 0 deletions runner/app/routes/image_to_video.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,6 +74,19 @@ async def image_to_video(
UploadFile,
File(description="Uploaded image to generate a video from."),
],
prompt: Annotated[
str,
Form(description="Text prompt(s) to guide video generation for prompt accepting models.")
] = "",
negative_prompt: Annotated[
str,
Form(
description=(
"Text prompt(s) to guide what to exclude from video generation for prompt accepting models. "
"Ignored if guidance_scale < 1."
)
),
] = "",
model_id: Annotated[
str, Form(description="Hugging Face model ID used for video generation.")
] = "",
Expand Down Expand Up @@ -123,6 +136,9 @@ async def image_to_video(
)
),
] = 25, # NOTE: Hardcoded due to varying pipeline values.
num_frames: Annotated[
int, Form(description="The number of video frames to generate.")
] = 25, # NOTE: Added `25` as default value to consider for `stable-video-diffusion-img2vid-xt` model having smaller default value than LTX-V in its pipeline.
pipeline: Pipeline = Depends(get_pipeline),
token: HTTPAuthorizationCredentials = Depends(HTTPBearer(auto_error=False)),
):
Expand Down Expand Up @@ -159,6 +175,9 @@ async def image_to_video(
try:
batch_frames, has_nsfw_concept = pipeline(
image=Image.open(image.file).convert("RGB"),
prompt=prompt,
negative_prompt=negative_prompt,
num_frames=num_frames,
height=height,
width=width,
fps=fps,
Expand Down
1 change: 1 addition & 0 deletions runner/dl_checkpoints.sh
Original file line number Diff line number Diff line change
Expand Up @@ -73,6 +73,7 @@ function download_all_models() {

# Download image-to-video models.
huggingface-cli download stabilityai/stable-video-diffusion-img2vid-xt --include "*.fp16.safetensors" "*.json" --cache-dir models
huggingface-cli download Lightricks/LTX-Video --include "*.safetensors" "*.json" "*.txt" --exclude ".onnx" ".onnx_data" --cache-dir models

# Download image-to-text models.
huggingface-cli download Salesforce/blip-image-captioning-large --include "*.safetensors" "*.json" --cache-dir models
Expand Down

0 comments on commit 4d2275a

Please sign in to comment.