Skip to content

Commit

Permalink
fix: adjust video process, reduce to 1 fps and adjust tensor shape
Browse files Browse the repository at this point in the history
  • Loading branch information
drbh committed Nov 25, 2024
1 parent 94da944 commit 32438fc
Show file tree
Hide file tree
Showing 6 changed files with 171 additions and 132 deletions.
2 changes: 1 addition & 1 deletion Dockerfile
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ RUN apt-get update && DEBIAN_FRONTEND=noninteractive apt-get install -y --no-ins
ffmpeg \
libavcodec-dev \
libavfilter-dev \
libavdevice-dev \
libavdevice-dev \
libavformat-dev \
libavutil-dev \
libswscale-dev \
Expand Down
19 changes: 12 additions & 7 deletions backends/client/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -79,14 +79,19 @@ impl ChunksToString for Vec<InputChunk> {
let encoded = STANDARD.encode(data);
output.push_str(&format!("![](data:{};base64,{})", mimetype, encoded))
}
Some(Chunk::Video(video)) => {
let encoded = STANDARD.encode(&video.as_bytes());
output.push_str(&format!("<video>(data:{};base64,{})", video.mimetype, encoded))
Some(Chunk::Video(Video {
data,
mimetype,
width,
frames: _,
})) => {
// TODO: revisit if we should limit video support to v3 - to avoid sending very large base64 strings
let encoded = STANDARD.encode(data);
output.push_str(&format!(
r#"<video width="{}"><source src="data:{};base64,{}" type="{}"></video>"#,
width, mimetype, encoded, mimetype
));
}
// Some(Chunk::Video(Video { data, mimetype })) => {
// let encoded = STANDARD.encode(data);
// output.push_str(&format!("<video>(data:{};base64,{})", mimetype, encoded))
// }
// We don't create empty chunks, so this should be unreachable.
None => unreachable!("Chunks should never be empty"),
});
Expand Down
4 changes: 3 additions & 1 deletion backends/v3/src/queue.rs
Original file line number Diff line number Diff line change
Expand Up @@ -440,8 +440,10 @@ impl State {
mimetype: image.mimetype,
}),
Chunk::Video(video) => client::Chunk::Video(client::Video {
data: video.frames,
data: video.data,
mimetype: video.mimetype,
width: video.width,
frames: video.num_frames,
}),
}),
})
Expand Down
8 changes: 7 additions & 1 deletion proto/v3/generate.proto
Original file line number Diff line number Diff line change
Expand Up @@ -65,11 +65,17 @@ message Image {
}

message Video {
/// Binary video data.
/// Binary video data (array of RGB data)
bytes data = 1;

/// Video MIME type.
string mimetype = 2;

/// Video width
uint32 width = 3;

/// Total number of frames
uint32 frames = 4;
}

message InputChunk {
Expand Down
233 changes: 125 additions & 108 deletions router/src/validation.rs
Original file line number Diff line number Diff line change
Expand Up @@ -23,11 +23,10 @@ use tokio::sync::mpsc;
use tokio::sync::oneshot;
use tracing::{instrument, Span};
use {once_cell::sync::Lazy, regex::Regex};

/// Video constants
const TARGET_WIDTH: u32 = 360;
const TARGET_HEIGHT: u32 = 420;
const TARGET_FPS: u32 = 1; // Sample at 1fps
// video processing
use ffmpeg::media::Type;
use ffmpeg::software::scaling::{context::Context, flag::Flags};
use ffmpeg_next::format::Pixel;

/// Validation
#[derive(Debug, Clone)]
Expand Down Expand Up @@ -527,14 +526,18 @@ fn format_to_mimetype(format: ImageFormat) -> String {
.to_string()
}

pub fn fetch_video(input: &str) -> Result<ProcessedVideo, ValidationError> {
pub fn fetch_video(
input: &str,
target_width: u32,
target_height: u32,
) -> Result<ProcessedVideo, ValidationError> {
let (data, mimetype) =
if input.starts_with("(http://") || input.starts_with("(https://") {
let url = &input["(".len()..input.len() - 1];
if input.starts_with("<video>(http://") || input.starts_with("<video>(https://") {
let url = &input["<video>(".len()..input.len() - 1];
let data = reqwest::blocking::get(url)?.bytes()?.to_vec();
(data, "video/mp4".to_string())
} else if input.starts_with("(data:") {
let content = &input["(data:".len()..input.len() - 1];
} else if input.starts_with("<video>(data:") {
let content = &input["<video>(data:".len()..input.len() - 1];
let tokens: Vec<&str> = content.split(';').collect();
if tokens.len() != 2 {
return Err(ValidationError::InvalidVideoContent(content.to_string()));
Expand All @@ -550,101 +553,88 @@ pub fn fetch_video(input: &str) -> Result<ProcessedVideo, ValidationError> {
return Err(ValidationError::InvalidVideoContent(input.to_string()));
};

// Initialize ffmpeg
// init ffmpeg
ffmpeg::init().map_err(|_| ValidationError::FFmpegError)?;

// Create temporary file for ffmpeg input
// create temporary file for ffmpeg input
let mut temp_file = tempfile::NamedTempFile::new().map_err(ValidationError::IoError)?;
temp_file.write_all(&data).map_err(ValidationError::IoError)?;

let input_ctx = ffmpeg::format::input(&temp_file.path())
.map_err(|_| ValidationError::FFmpegError)?;
temp_file
.write_all(&data)
.map_err(ValidationError::IoError)?;

let video_stream = input_ctx
.streams()
.best(ffmpeg::media::Type::Video)
.ok_or(ValidationError::NoVideoStream)?;
let mut ictx =
ffmpeg::format::input(&temp_file.path()).map_err(|_| ValidationError::FFmpegError)?;

// Get video metadata
let time_base = video_stream.time_base();
let fps = 1.0_f32 / f64::from(time_base.numerator()) as f32 * f64::from(time_base.denominator()) as f32;
let total_frames = video_stream.frames() as usize;
let input = ictx
.streams()
.best(Type::Video)
.ok_or(ValidationError::FFmpegError)?;
let video_stream_index = input.index();

// Setup decoder
let context_decoder = ffmpeg::codec::context::Context::from_parameters(video_stream.parameters())
let context_decoder = ffmpeg::codec::context::Context::from_parameters(input.parameters())
.map_err(|_| ValidationError::FFmpegError)?;
let mut decoder = context_decoder
.decoder()
.video()
.map_err(|_| ValidationError::FFmpegError)?;

// Calculate dimensions
let (width, height) = if decoder.width() > TARGET_WIDTH || decoder.height() > TARGET_HEIGHT {
let ratio = (TARGET_WIDTH as f32 / decoder.width() as f32)
.min(TARGET_HEIGHT as f32 / decoder.height() as f32);
let width = (decoder.width() as f32 * ratio) as u32;
let height = (decoder.height() as f32 * ratio) as u32;
(width as usize, height as usize)
} else {
(decoder.width() as usize, decoder.height() as usize)
};
let width = target_width;
let height = target_height;

// Setup scaler if needed
let mut scaler = if width != decoder.width() as usize || height != decoder.height() as usize {
Some(ffmpeg::software::scaling::Context::get(
decoder.format(),
decoder.width(),
decoder.height(),
ffmpeg::format::Pixel::RGB24,
width as u32,
height as u32,
ffmpeg::software::scaling::Flags::BILINEAR,
).map_err(|_| ValidationError::FFmpegError)?)
} else {
None
};
let mut scaler = Context::get(
decoder.format(),
decoder.width(), // original width
decoder.height(),
Pixel::RGB24,
width, // target width
height,
Flags::BILINEAR,
)
.map_err(|_| ValidationError::FFmpegError)?;

let mut frame_index = 0;
let mut captured_frame_index = 0;
let mut frames = vec![];

let mut receive_and_process_decoded_frames =
|decoder: &mut ffmpeg::decoder::Video, fps: f32| -> Result<(), ffmpeg::Error> {
let mut decoded = ffmpeg::util::frame::video::Video::empty();
while decoder.receive_frame(&mut decoded).is_ok() {
let mut rgb_frame = ffmpeg::util::frame::video::Video::empty();
scaler.run(&decoded, &mut rgb_frame)?;
if frame_index as f32 % fps == 0.0 {
captured_frame_index += 1;
frames.push(rgb_frame.data(0).to_vec());
}
frame_index += 1;
}
Ok(())
};

// Sample frames at 1fps
let mut frames = Vec::new();
let frame_interval = (fps / TARGET_FPS as f32).round() as usize;
let mut frame_count = 0;
let mut receive_frame = ffmpeg::frame::Video::empty();

while let Ok(..) = decoder.receive_frame(&mut receive_frame) {
if frame_count % frame_interval == 0 {
let mut rgb_frame = if let Some(scaler) = scaler.as_mut() {
let mut scaled = ffmpeg::frame::Video::empty();
scaler.run(&receive_frame, &mut scaled)
.map_err(|_| ValidationError::FFmpegError)?;
scaled
} else {
receive_frame.clone()
};
let mut fps = 0.0;
let mut total_frames = 0;

// Convert to RGB if needed
if rgb_frame.format() != ffmpeg::format::Pixel::RGB24 {
let mut rgb = ffmpeg::frame::Video::empty();
let mut rgb_scaler = ffmpeg::software::scaling::Context::get(
rgb_frame.format(),
width as u32,
height as u32,
ffmpeg::format::Pixel::RGB24,
width as u32,
height as u32,
ffmpeg::software::scaling::Flags::BILINEAR,
).map_err(|_| ValidationError::FFmpegError)?;
rgb_scaler.run(&rgb_frame, &mut rgb)
.map_err(|_| ValidationError::FFmpegError)?;
rgb_frame = rgb;
}
for (stream, packet) in ictx.packets() {
total_frames = stream.frames();
fps = stream.rate().numerator() as f32 / stream.rate().denominator() as f32;

// Extract RGB data
let frame_data: Vec<u8> = rgb_frame.data(0).to_vec();
frames.push(frame_data);
if stream.index() == video_stream_index {
decoder
.send_packet(&packet)
.map_err(|_| ValidationError::FFmpegError)?;
receive_and_process_decoded_frames(&mut decoder, fps)
.map_err(|_| ValidationError::FFmpegError)?;
}
frame_count += 1;
}
decoder
.send_eof()
.map_err(|_| ValidationError::FFmpegError)?;

let total_frames = total_frames.try_into().map_err(|_| {
ValidationError::InvalidVideoContent(
"Total frames is too large to be represented as usize".to_string(),
)
})?;
Ok(ProcessedVideo {
mimetype,
height,
Expand Down Expand Up @@ -743,7 +733,7 @@ fn image_tokens(
}
}

fn video_tokens(config: &Config, height: usize, width: usize, total_frames: f32) -> String {
fn video_tokens(config: &Config, height: u32, width: u32, total_frames: f32) -> String {
use Config::*;

match config {
Expand All @@ -752,11 +742,9 @@ fn video_tokens(config: &Config, height: usize, width: usize, total_frames: f32)
let min_frames = 2_f32;
let max_frames = 256_f32;
// make sure the frames are within the range and are even
let nframes = (total_frames)
.max(min_frames)
.min(max_frames);
let nframes = (total_frames).max(min_frames).min(max_frames);
let nframes = (nframes / 2.0).round() as usize * 2;
let num_tokens = nframes * height * width / 1541;
let num_tokens = nframes * height as usize * width as usize / 1541;
format!(
"<|vision_start|>{:?}<|vision_end|>",
"<|video_pad|>".repeat(num_tokens)
Expand Down Expand Up @@ -807,10 +795,35 @@ fn prepare_input<T: TokenizerTrait>(
input_chunks.push(Chunk::Text(inputs[start..chunk_start].to_string()));
tokenizer_query.push_str(&inputs[start..chunk_start]);
}
let ProcessedVideo { mimetype, height, width, frames, fps: _, total_frames: _ } =
fetch_video(&inputs[chunk_start..chunk_end])?;
input_chunks.push(Chunk::Video(Video { frames: frames.clone(), mimetype }));
let video_tokens = video_tokens(config, height, width, frames.len() as f32);
let ProcessedVideo {
mimetype,
height,
width,
frames,
fps,
total_frames,
} = match config {
Idefics | Mllama | Idefics2(_) | Paligemma(_) | LlavaNext(_) => {
let default_target_width = 224;
let default_target_height = 224;
fetch_video(
&inputs[chunk_start..chunk_end],
default_target_width,
default_target_height,
)?
}
Qwen2Vl(_) => fetch_video(&inputs[chunk_start..chunk_end], 360, 420)?,
_ => {
unreachable!("Video tokens are not supported for this model configuration")
}
};
input_chunks.push(Chunk::Video(Video {
data: frames.iter().flatten().cloned().collect(),
mimetype,
width,
num_frames: (total_frames as f32 / fps) as u32,
}));
let video_tokens = video_tokens(config, height, width, total_frames as f32);
tokenizer_query.push_str(&video_tokens);
start = chunk_end;
}
Expand Down Expand Up @@ -865,23 +878,19 @@ pub struct Image {

pub struct ProcessedVideo {
mimetype: String,
height: usize,
width: usize,
height: u32,
width: u32,
frames: Vec<Vec<u8>>, // RGB frames
fps: f32,
total_frames: usize,
}

#[derive(Debug, Clone, Eq, PartialEq)]
pub struct Video {
pub frames: Vec<Vec<u8>>,
pub data: Vec<u8>,
pub mimetype: String,
}

impl Video {
pub fn as_bytes(&self) -> Vec<u8> {
self.frames.iter().flatten().cloned().collect()
}
pub width: u32,
pub num_frames: u32,
}

#[derive(Debug, Clone, Eq, PartialEq)]
Expand All @@ -907,10 +916,18 @@ impl ChunksToString for Vec<Chunk> {
let encoded = STANDARD.encode(data);
output.push_str(&format!("![](data:{};base64,{})", mimetype, encoded))
}
Chunk::Video(Video { frames, mimetype }) => {
// For backward compatibility, we'll just use the first frame's data
let encoded = STANDARD.encode(frames.first().unwrap_or(&Vec::new()));
output.push_str(&format!("<video>(data:{};base64,{})", mimetype, encoded))
Chunk::Video(Video {
data,
mimetype,
width,
num_frames: _,
}) => {
// TODO: revisit if we should limit video support to v3 - to avoid sending very large base64 strings
let encoded = STANDARD.encode(data);
output.push_str(&format!(
r#"<video width="{}"><source src="data:{};base64,{}" type="{}"></video>"#,
width, mimetype, encoded, mimetype
));
}
});
output
Expand Down
Loading

0 comments on commit 32438fc

Please sign in to comment.