Skip to content

Commit

Permalink
refactoring
Browse files Browse the repository at this point in the history
  • Loading branch information
mfarre authored and drbh committed Nov 18, 2024
1 parent f7cf45d commit cee1dea
Show file tree
Hide file tree
Showing 2 changed files with 34 additions and 48 deletions.
47 changes: 31 additions & 16 deletions server/text_generation_server/models/custom_modeling/qwen2_vl.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,10 +16,14 @@

from typing import Dict, Optional, Tuple, List

import os
import tempfile
import requests
import torch
import torch.utils.checkpoint
from torch import nn
from text_generation_server.utils.import_utils import SYSTEM
from contextlib import contextmanager
from qwen_vl_utils import process_vision_info


Expand Down Expand Up @@ -561,20 +565,31 @@ def forward(
logits, speculative_logits = self.lm_head(hidden_states)
return logits, speculative_logits

class QwenVideoProcessor:
    """Utility class to handle video processing specifically for Qwen models."""

    @staticmethod
    def prepare_video_inputs(messages: List[Dict]) -> Optional[torch.Tensor]:
        """Return the first video input extracted from ``messages``.

        The messages are handed to ``process_vision_info`` and only the
        second element of its result (the video inputs) is inspected.

        Args:
            messages: Chat messages that may contain video entries.

        Returns:
            The first element of the video inputs, or ``None`` when
            ``process_vision_info`` returns ``None`` or finds no video.
        """
        # Use Qwen's built-in video processing
        vision_info = process_vision_info(messages)
        if vision_info is None:
            return None

        _, video_inputs = vision_info
        return video_inputs[0] if video_inputs else None


@contextmanager
def temp_video_download(url: str, timeout: float = 30.0) -> str:
    """Download the video at ``url`` to a temporary file; delete it on exit.

    The file is created inside ``<system tmp>/qwen_videos`` under a unique
    name that keeps the extension found in the URL path.

    Args:
        url: HTTP(S) location of the video.
        timeout: Seconds passed to ``requests.get`` as its timeout, so a
            stalled server cannot block the caller forever.

    Yields:
        Absolute path of the downloaded file. The file is removed when the
        ``with`` block exits, whether or not an exception was raised.

    Raises:
        requests.RequestException: If the request fails, times out, or the
            server answers with an error status.
    """
    # Function-scope import so the block does not depend on a file-level one.
    from urllib.parse import urlparse

    temp_dir = os.path.join(tempfile.gettempdir(), "qwen_videos")
    os.makedirs(temp_dir, exist_ok=True)

    # Only the URL *path* contributes the extension; query strings and
    # fragments must not leak into the file name.
    suffix = os.path.splitext(urlparse(url).path)[1]

    # mkstemp hands every call its own file. Naming the file after the URL
    # basename breaks for URLs ending in "/" (empty name) and lets two
    # concurrent downloads of the same basename overwrite and delete each
    # other's file.
    fd, temp_path = tempfile.mkstemp(suffix=suffix, dir=temp_dir)
    try:
        with os.fdopen(fd, "wb") as tmp_file:
            with requests.get(url, stream=True, timeout=timeout) as r:
                r.raise_for_status()
                for chunk in r.iter_content(chunk_size=8192):
                    if chunk:  # skip keep-alive chunks
                        tmp_file.write(chunk)
        # The handle is closed here, so the data is flushed before use.
        yield temp_path
    finally:
        try:
            os.unlink(temp_path)
        except FileNotFoundError:
            # The consumer already removed the file; nothing left to clean.
            pass

def process_qwen_video(chunk_video: str, max_pixels: int = 360 * 420, fps: float = 1.0):
    """Run Qwen's vision preprocessing on a single video.

    Args:
        chunk_video: Location of the video (path or URL) as given in the
            request chunk.
        max_pixels: Value forwarded as the entry's ``"max_pixels"`` field.
        fps: Value forwarded as the entry's ``"fps"`` field.

    Returns:
        Whatever ``process_vision_info`` returns for a one-message
        conversation holding this video. NOTE(review): the qwen-vl-utils
        documentation describes this as an ``(image_inputs, video_inputs)``
        pair; confirm against the installed version.
    """
    video_entry = {
        "type": "video",
        "video": chunk_video,
        "max_pixels": max_pixels,
        "fps": fps,
    }
    # process_vision_info expects chat messages and looks for vision entries
    # inside each message's "content" list, so the bare entry has to be
    # wrapped in a message rather than passed on its own.
    messages = [{"role": "user", "content": [video_entry]}]
    return process_vision_info(messages)
35 changes: 3 additions & 32 deletions server/text_generation_server/models/vlm_causal_lm.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,7 @@
import os
import torch
import tempfile
import requests
from PIL import Image
from io import BytesIO

from contextlib import contextmanager

from opentelemetry import trace
from typing import Iterable, Optional, Tuple, List, Type, Dict
Expand Down Expand Up @@ -196,8 +192,6 @@ def batch_tokenized_inputs(
images.append([image])
elif chunk_type == "video":
if config.model_type == "qwen2_vl":
# For now, treat video URLs as special tokens
# This will be processed in the text replacement section below
pass
else:
raise RuntimeError(f"Invalid chunk type {chunk_type}")
Expand All @@ -222,13 +216,11 @@ def batch_tokenized_inputs(
)
image_id += 1
elif chunk_type == "video" and config.model_type == "qwen2_vl":
# Download and process video in a temporary context
with cls.temp_video_download(chunk.video) as local_path:
# Now the video is available at local_path for processing
full_text += f"<video>{local_path}</video>"
from text_generation_server.models.custom_modeling.qwen2_vl import process_qwen_video
text, _ = process_qwen_video(chunk.video)
full_text += text

full_text = image_text_replacement_fixup(config, full_text)

batch_inputs.append(full_text)
max_truncation = max(max_truncation, r.truncate)

Expand Down Expand Up @@ -277,27 +269,6 @@ def from_pb_processor(
batch.image_sizes = None
batch.image_grid_thw = None
return batch

@staticmethod
@contextmanager
def temp_video_download(url: str, timeout: float = 30.0) -> str:
    """Download the video at ``url`` to a temporary file; delete it on exit.

    Args:
        url: HTTP(S) location of the video.
        timeout: Seconds passed to ``requests.get`` as its timeout, so a
            stalled server cannot block the caller forever.

    Yields:
        Path of the downloaded file. The file keeps the extension found in
        the URL path and is removed when the ``with`` block exits.

    Raises:
        requests.RequestException: If the request fails, times out, or the
            server answers with an error status.
    """
    # Function-scope import so the block does not depend on a file-level one.
    from urllib.parse import urlparse

    # Derive the suffix from the URL path only; splitting the raw URL would
    # turn "video.mp4?sig=a.b" into the bogus extension ".b".
    suffix = os.path.splitext(urlparse(url).path)[1]
    tmp_file = tempfile.NamedTemporaryFile(suffix=suffix, delete=False)
    try:
        # Leaving this ``with`` closes (and thereby flushes) the handle
        # before the path is handed to the caller.
        with tmp_file:
            with requests.get(url, stream=True, timeout=timeout) as r:
                r.raise_for_status()
                for chunk in r.iter_content(chunk_size=8192):
                    if chunk:  # skip keep-alive chunks
                        tmp_file.write(chunk)
        yield tmp_file.name
    finally:
        # Best-effort cleanup: a failure to delete must not mask an error
        # raised by the download or by the caller's block.
        try:
            os.unlink(tmp_file.name)
        except OSError:
            pass

class VlmCausalLM(FlashCausalLM):
def __init__(
Expand Down

0 comments on commit cee1dea

Please sign in to comment.