diff --git a/server/text_generation_server/models/custom_modeling/qwen2_vl.py b/server/text_generation_server/models/custom_modeling/qwen2_vl.py
index 1fdd75d2483..b73a78e040d 100644
--- a/server/text_generation_server/models/custom_modeling/qwen2_vl.py
+++ b/server/text_generation_server/models/custom_modeling/qwen2_vl.py
@@ -14,6 +14,8 @@
 # limitations under the License.
 """PyTorch Qwen2 VL model."""
 
-from typing import Dict, Optional, Tuple, List
+__all__ = ["Qwen2VLForConditionalGeneration", "process_qwen_video"]
+
+from typing import Dict, Iterator, List, Optional, Tuple
 
 import os
@@ -72,6 +74,82 @@ def apply_rotary_pos_emb_vision(
     return output
 
 
+@contextmanager
+def temp_video_download(url: str, timeout: float = 30.0) -> Iterator[str]:
+    """Download the video at ``url`` to a temporary file and yield its path.
+
+    The file gets a unique name under ``<tmpdir>/qwen_videos``, so concurrent
+    downloads of URLs that share a basename (or have none, e.g. a URL ending
+    in ``/``) cannot overwrite or delete each other's data. It is removed when
+    the ``with`` block exits, whether or not an error occurred.
+
+    Args:
+        url: Location of the video, fetched with ``requests.get``.
+        timeout: Seconds to wait for the connection and between received
+            bytes before ``requests`` gives up.
+
+    Yields:
+        Absolute path of the downloaded file.
+
+    Raises:
+        requests.RequestException: If the request fails, times out, or the
+            server answers with an error status.
+    """
+    temp_dir = os.path.join(tempfile.gettempdir(), "qwen_videos")
+    os.makedirs(temp_dir, exist_ok=True)
+    # Keep only the extension from the URL path (query string and fragment
+    # dropped) in case the video reader relies on it; the rest of the name
+    # is generated by mkstemp.
+    url_path = url.split("?", 1)[0].split("#", 1)[0]
+    suffix = os.path.splitext(os.path.basename(url_path))[1]
+    fd, temp_path = tempfile.mkstemp(suffix=suffix, dir=temp_dir)
+
+    try:
+        with os.fdopen(fd, "wb") as tmp_file:
+            with requests.get(url, stream=True, timeout=timeout) as r:
+                r.raise_for_status()
+                for chunk in r.iter_content(chunk_size=8192):
+                    if chunk:
+                        tmp_file.write(chunk)
+        yield temp_path
+    finally:
+        # The caller may already have removed the file; that is not an error.
+        try:
+            os.unlink(temp_path)
+        except FileNotFoundError:
+            pass
+
+
+def process_qwen_video(video_url: str):
+    """Download ``video_url`` and run it through ``process_vision_info``.
+
+    The video is wrapped in a single user message (``max_pixels`` of
+    360 * 420, 1 fps) and processed while the temporary download still
+    exists; the file is deleted before this function returns.
+
+    Args:
+        video_url: Location of the video to download.
+
+    Returns:
+        Whatever ``process_vision_info`` returns for the message list.
+    """
+    with temp_video_download(video_url) as local_path:
+        messages = [
+            {
+                "role": "user",
+                "content": [
+                    {
+                        "type": "video",
+                        "video": local_path,
+                        "max_pixels": 360 * 420,
+                        "fps": 1.0,
+                    }
+                ],
+            }
+        ]
+        return process_vision_info(messages)
+
+
 class Qwen2VLSdpaAttention(nn.Module):
     def __init__(self, *, prefix, config, weights):
         super().__init__()
@@ -538,32 +616,3 @@ def forward(
             hidden_states = hidden_states[lm_head_indices]
         logits, speculative_logits = self.lm_head(hidden_states)
         return logits, speculative_logits
-
-    @contextmanager
-    def temp_video_download(url: str) -> str:
-        """Downloads video to temporary file and cleans it up after use."""
-        temp_dir = os.path.join(tempfile.gettempdir(), "qwen_videos")
-        os.makedirs(temp_dir, exist_ok=True)
-        temp_path = os.path.abspath(os.path.join(temp_dir, os.path.basename(url)))
-
-        try:
-            with open(temp_path, 'wb') as tmp_file:
-                with requests.get(url, stream=True) as r:
-                    r.raise_for_status()
-                    for chunk in r.iter_content(chunk_size=8192):
-                        if chunk:
-                            tmp_file.write(chunk)
-            yield temp_path
-        finally:
-            if os.path.exists(temp_path):
-                os.unlink(temp_path)
-
-    def process_qwen_video(chunk_video: str):
-        """Process video for Qwen2VL model"""
-        vision_info = [{
-            "type": "video",
-            "video": chunk_video,
-            "max_pixels": 360 * 420,
-            "fps": 1.0
-        }]
-        return process_vision_info(vision_info)
\ No newline at end of file