Skip to content

Commit

Permalink
fix
Browse files Browse the repository at this point in the history
  • Loading branch information
mfarre committed Nov 14, 2024
1 parent 88cffa0 commit b780f00
Showing 1 changed file with 39 additions and 29 deletions.
68 changes: 39 additions & 29 deletions server/text_generation_server/models/custom_modeling/qwen2_vl.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,8 @@
# limitations under the License.
"""PyTorch Qwen2 VL model."""

__all__ = ['Qwen2VLForConditionalGeneration', 'process_qwen_video']

from typing import Dict, Optional, Tuple, List

import os
Expand Down Expand Up @@ -72,6 +74,43 @@ def apply_rotary_pos_emb_vision(
return output


@contextmanager
def temp_video_download(url: str) -> str:
"""Downloads video to temporary file and cleans it up after use."""
temp_dir = os.path.join(tempfile.gettempdir(), "qwen_videos")
os.makedirs(temp_dir, exist_ok=True)
temp_path = os.path.abspath(os.path.join(temp_dir, os.path.basename(url)))

try:
with open(temp_path, 'wb') as tmp_file:
with requests.get(url, stream=True) as r:
r.raise_for_status()
for chunk in r.iter_content(chunk_size=8192):
if chunk:
tmp_file.write(chunk)
yield temp_path
finally:
if os.path.exists(temp_path):
os.unlink(temp_path)

def process_qwen_video(video_url: str):
"""Process video for Qwen2VL model"""
with temp_video_download(video_url) as local_path:
messages = [
{
"role": "user",
"content": [
{
"type": "video",
"video": local_path,
"max_pixels": 360 * 420,
"fps": 1.0
}
]
}
]
return process_vision_info(messages)

class Qwen2VLSdpaAttention(nn.Module):
def __init__(self, *, prefix, config, weights):
super().__init__()
Expand Down Expand Up @@ -538,32 +577,3 @@ def forward(
hidden_states = hidden_states[lm_head_indices]
logits, speculative_logits = self.lm_head(hidden_states)
return logits, speculative_logits

@contextmanager
def temp_video_download(url: str) -> str:
"""Downloads video to temporary file and cleans it up after use."""
temp_dir = os.path.join(tempfile.gettempdir(), "qwen_videos")
os.makedirs(temp_dir, exist_ok=True)
temp_path = os.path.abspath(os.path.join(temp_dir, os.path.basename(url)))

try:
with open(temp_path, 'wb') as tmp_file:
with requests.get(url, stream=True) as r:
r.raise_for_status()
for chunk in r.iter_content(chunk_size=8192):
if chunk:
tmp_file.write(chunk)
yield temp_path
finally:
if os.path.exists(temp_path):
os.unlink(temp_path)

def process_qwen_video(chunk_video: str):
"""Process video for Qwen2VL model"""
vision_info = [{
"type": "video",
"video": chunk_video,
"max_pixels": 360 * 420,
"fps": 1.0
}]
return process_vision_info(vision_info)

0 comments on commit b780f00

Please sign in to comment.