From 88cffa009f4da80b5dd46b1a075b25fbdc6fc771 Mon Sep 17 00:00:00 2001 From: Miquel Farre Date: Thu, 14 Nov 2024 12:30:59 +0000 Subject: [PATCH] refactoring --- .../models/custom_modeling/qwen2_vl.py | 47 ++++++++++++------- .../models/vlm_causal_lm.py | 35 ++------------ 2 files changed, 34 insertions(+), 48 deletions(-) diff --git a/server/text_generation_server/models/custom_modeling/qwen2_vl.py b/server/text_generation_server/models/custom_modeling/qwen2_vl.py index 75d4b5667ef..1fdd75d2483 100644 --- a/server/text_generation_server/models/custom_modeling/qwen2_vl.py +++ b/server/text_generation_server/models/custom_modeling/qwen2_vl.py @@ -16,10 +16,14 @@ from typing import Dict, Optional, Tuple, List +import os +import tempfile +import requests import torch import torch.utils.checkpoint from torch import nn from text_generation_server.utils.import_utils import SYSTEM +from contextlib import contextmanager from qwen_vl_utils import process_vision_info @@ -535,20 +539,31 @@ def forward( logits, speculative_logits = self.lm_head(hidden_states) return logits, speculative_logits -class QwenVideoProcessor: - """Utility class to handle video processing specifically for Qwen models""" - - @staticmethod - def prepare_video_inputs(messages: List[Dict]) -> Tuple[Dict, Optional[torch.Tensor]]: - """ - Process messages containing video inputs for Qwen models - Returns a tuple of (processed_messages, video_pixels) - """ - # Use Qwen's built-in video processing - vision_info = process_vision_info(messages) + @contextmanager + def temp_video_download(url: str) -> str: + """Downloads video to temporary file and cleans it up after use.""" + temp_dir = os.path.join(tempfile.gettempdir(), "qwen_videos") + os.makedirs(temp_dir, exist_ok=True) + temp_path = os.path.abspath(os.path.join(temp_dir, os.path.basename(url))) - if vision_info is not None: - _, video_inputs = vision_info - return video_inputs[0] if video_inputs else None - - return None \ No newline at end of file + try: + with open(temp_path, 'wb') as tmp_file: + with requests.get(url, stream=True) as r: + r.raise_for_status() + for chunk in r.iter_content(chunk_size=8192): + if chunk: + tmp_file.write(chunk) + yield temp_path + finally: + if os.path.exists(temp_path): + os.unlink(temp_path) + + def process_qwen_video(chunk_video: str): + """Process video for Qwen2VL model""" + vision_info = [{ + "type": "video", + "video": chunk_video, + "max_pixels": 360 * 420, + "fps": 1.0 + }] + return process_vision_info(vision_info) \ No newline at end of file diff --git a/server/text_generation_server/models/vlm_causal_lm.py b/server/text_generation_server/models/vlm_causal_lm.py index 5f788085a95..c297966a9a8 100644 --- a/server/text_generation_server/models/vlm_causal_lm.py +++ b/server/text_generation_server/models/vlm_causal_lm.py @@ -1,11 +1,7 @@ -import os import torch -import tempfile -import requests from PIL import Image from io import BytesIO -from contextlib import contextmanager from opentelemetry import trace from typing import Iterable, Optional, Tuple, List, Type, Dict @@ -196,8 +192,6 @@ def batch_tokenized_inputs( images.append([image]) elif chunk_type == "video": if config.model_type == "qwen2_vl": - # For now, treat video URLs as special tokens - # This will be processed in the text replacement section below pass else: raise RuntimeError(f"Invalid chunk type {chunk_type}") @@ -222,13 +216,11 @@ def batch_tokenized_inputs( ) image_id += 1 elif chunk_type == "video" and config.model_type == "qwen2_vl": - # Download and process video in a temporary context - with cls.temp_video_download(chunk.video) as local_path: - # Now the video is available at local_path for processing - full_text += f"" + from text_generation_server.models.custom_modeling.qwen2_vl import process_qwen_video + text, _ = process_qwen_video(chunk.video) + full_text += text full_text = image_text_replacement_fixup(config, full_text) - batch_inputs.append(full_text) max_truncation = max(max_truncation, r.truncate) @@ -277,27 +269,6 @@ def from_pb_processor( batch.image_sizes = None batch.image_grid_thw = None return batch - - @staticmethod - @contextmanager - def temp_video_download(url: str) -> str: - """Downloads video to a temporary file and cleans it up after use.""" - with tempfile.NamedTemporaryFile(suffix=os.path.splitext(url)[1], delete=False) as tmp_file: - try: - # Download video - with requests.get(url, stream=True) as r: - r.raise_for_status() - for chunk in r.iter_content(chunk_size=8192): - if chunk: - tmp_file.write(chunk) - tmp_file.flush() - yield tmp_file.name - finally: - # Clean up temp file - try: - os.unlink(tmp_file.name) - except OSError: - pass class VlmCausalLM(FlashCausalLM): def __init__(