diff --git a/.github/workflows/causal_lm_cpp.yml b/.github/workflows/causal_lm_cpp.yml
index 34a7f1fb12..b8fbe397d2 100644
--- a/.github/workflows/causal_lm_cpp.yml
+++ b/.github/workflows/causal_lm_cpp.yml
@@ -698,19 +698,30 @@ jobs:
           source ./ov/setupvars.sh
           cmake -DCMAKE_BUILD_TYPE=Release -S ./ -B ./build/
           cmake --build ./build/ --config Release --target visual_language_chat py_generate_pipeline -j
-      - name: Download and convert a model and an image
+      - name: Download and convert MiniCPM-V-2_6 model and an image
         run: |
           source ./ov/setupvars.sh
           python -m pip install ./thirdparty/openvino_tokenizers/[transformers] --pre --extra-index-url https://storage.openvinotoolkit.org/simple/wheels/nightly
           python -m pip install --upgrade-strategy eager -r ./samples/requirements.txt --pre --extra-index-url https://storage.openvinotoolkit.org/simple/wheels/nightly
           python ./samples/cpp/visual_language_chat/export_MiniCPM-V-2_6.py ./miniCPM-V-2_6/
           wget https://github.com/openvinotoolkit/openvino_notebooks/assets/29454499/d5fbbd1a-d484-415c-88cb-9986625b7b11 --output-document cat.jpg
-
-      - name: Run chat chat sample
+      - name: Run visual_language_chat sample - MiniCPM-V-2_6
         run: >
           source ./ov/setupvars.sh
           && timeout 120s ./build/samples/cpp/visual_language_chat/visual_language_chat ./miniCPM-V-2_6/ cat.jpg
           <<< $'What is on the image?\nWhat is special on the image?'
+      - name: Download and convert LLaVa 1.5 model and an image
+        run: |
+          source ./ov/setupvars.sh
+          python -m pip install ./thirdparty/openvino_tokenizers/[transformers] --pre --extra-index-url https://storage.openvinotoolkit.org/simple/wheels/nightly
+          python -m pip install --upgrade-strategy eager -r ./samples/requirements.txt --pre --extra-index-url https://storage.openvinotoolkit.org/simple/wheels/nightly
+          optimum-cli export openvino --model llava-hf/llava-1.5-7b-hf ./llava_1_5_7b_ov/
+          wget https://llava-vl.github.io/static/images/monalisa.jpg
+      - name: Run visual_language_chat sample - LLaVa 1.5
+        run: >
+          source ./ov/setupvars.sh
+          && timeout 120s ./build/samples/cpp/visual_language_chat/visual_language_chat ./llava_1_5_7b_ov/ monalisa.jpg
+          <<< $'Who drew this painting?\nWhen did the painter live?'
      - name: Run python chat sample
        run: |
diff --git a/README.md b/README.md
index 9a4d73802b..163768b18e 100644
--- a/README.md
+++ b/README.md
@@ -40,10 +40,9 @@ Continuous batching functionality is used within OpenVINO Model Server (OVMS) to
 # Install optimum-intel to be able to download, convert and optimize LLMs from Hugging Face
 # Optimum is not required to run models, only to convert and compress
-pip install optimum[openvino]
+pip install optimum-intel@git+https://github.com/huggingface/optimum-intel.git
 
 # (Optional) Install (TBD) to be able to download models from Model Scope
-#pip install optimum[openvino]
 ```
 
 ## Performing text generation
diff --git a/samples/requirements.txt b/samples/requirements.txt
index 4821d6dbef..df71d0cbb1 100644
--- a/samples/requirements.txt
+++ b/samples/requirements.txt
@@ -1,5 +1,6 @@
 --extra-index-url https://download.pytorch.org/whl/cpu
-optimum[openvino]==1.22.0
+optimum-intel @ git+https://github.com/huggingface/optimum-intel.git
+numpy<2.0.0; sys_platform == 'darwin'
 einops==0.8.0 # For Qwen
 transformers_stream_generator==0.0.5 # For Qwen
 diffusers==0.30.3
diff --git a/src/cpp/include/openvino/genai/processor_config.hpp b/src/cpp/include/openvino/genai/processor_config.hpp
index bef6754e14..f4fc5d33ec 100644
--- a/src/cpp/include/openvino/genai/processor_config.hpp
+++ b/src/cpp/include/openvino/genai/processor_config.hpp
@@ -34,6 +34,14 @@ class OPENVINO_GENAI_EXPORTS ProcessorConfig {
     /// Applied after norm_mean.
     /// llava calls it image_std.
     std::array<float, 3> norm_std{1.0f, 1.0f, 1.0f};
+
+    // llava specific config params
+    std::array<float, 3> image_mean{0.0f, 0.0f, 0.0f};
+    std::array<float, 3> image_std{1.0f, 1.0f, 1.0f};
+    size_t crop_size_height = 336;
+    size_t crop_size_width = 336;
+    size_t size_shortest_edge = 336;
+
     /// @brief Default constructor
     ProcessorConfig() = default;
     /// @brief Construct ProcessorConfig from values in json_path.
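The new llava-specific ProcessorConfig fields mirror what Hugging Face image processors store in preprocessor_config.json. A minimal sketch of that mapping, using illustrative CLIP-style values rather than anything taken from this patch:

#include <nlohmann/json.hpp>
#include "openvino/genai/processor_config.hpp"

int main() {
    // Illustrative preprocessor_config.json snippet (values are typical CLIP/LLaVA
    // numbers, assumed here for the example only).
    nlohmann::json preprocessor = nlohmann::json::parse(R"({
        "crop_size": {"height": 336, "width": 336},
        "size": {"shortest_edge": 336},
        "image_mean": [0.48145466, 0.4578275, 0.40821073],
        "image_std": [0.26862954, 0.26130258, 0.27577711]
    })");

    ov::genai::ProcessorConfig config;
    // Same field-by-field mapping the ProcessorConfig(json_path) constructor performs below.
    config.crop_size_height = preprocessor.at("crop_size").at("height");
    config.crop_size_width = preprocessor.at("crop_size").at("width");
    config.size_shortest_edge = preprocessor.at("size").at("shortest_edge");
    config.image_mean = preprocessor.at("image_mean").get<std::array<float, 3>>();
    config.image_std = preprocessor.at("image_std").get<std::array<float, 3>>();
    return 0;
}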
diff --git a/src/cpp/include/openvino/genai/vision_encoder.hpp b/src/cpp/include/openvino/genai/vision_encoder.hpp
index 474216736c..902557d316 100644
--- a/src/cpp/include/openvino/genai/vision_encoder.hpp
+++ b/src/cpp/include/openvino/genai/vision_encoder.hpp
@@ -5,6 +5,7 @@
 #include "openvino/genai/processor_config.hpp"
 #include <openvino/openvino.hpp>
+#include "vlm_model_type.hpp"

 namespace ov::genai {
 /// @brief A pair describing image size.
@@ -41,8 +42,10 @@ struct EncodedImage {
 /// ov::InferRequest and configured by ProcessorConfig.
 class OPENVINO_GENAI_EXPORTS VisionEncoder {
 public:
+    /// @brief An enum denoting model type.
+    VLMModelType model_type;
     /// @brief A model for image encoding.
-    ov::InferRequest m_encoder;
+    ov::InferRequest m_vision_encoder;
     /// @brief A config to follow.
     ProcessorConfig m_processor_config;

@@ -52,7 +55,7 @@ class OPENVINO_GENAI_EXPORTS VisionEncoder {
     explicit VisionEncoder(
         const ov::InferRequest& encoder,
         const ProcessorConfig& processor_config=ProcessorConfig{}
-    ) : m_encoder{encoder}, m_processor_config{processor_config} {}
+    ) : m_vision_encoder{encoder}, m_processor_config{processor_config} {}

     /// @brief Construct the encoder from model_dir.
     /// @param model_dir A folder containing openvino_embedding.xml and
@@ -63,6 +66,7 @@
     /// @param core ov::Core to be used to compile the model.
     explicit VisionEncoder(
         const std::filesystem::path& model_dir,
+        const VLMModelType model_type,
         const std::string& device="CPU",
         const ov::AnyMap device_config={},
         ov::Core core=ov::Core{}
@@ -117,5 +121,14 @@ class OPENVINO_GENAI_EXPORTS VisionEncoder {
             image, AnyMap{std::forward<Properties>(properties)...}
         );
     }
+
+private:
+    EncodedImage encode_minicpm(
+        const ov::Tensor& image, const ProcessorConfig& config
+    );
+
+    EncodedImage encode_llava(
+        const ov::Tensor& image, const ProcessorConfig& config
+    );
 };
 }
diff --git a/src/cpp/include/openvino/genai/vlm_config.hpp b/src/cpp/include/openvino/genai/vlm_config.hpp
index dd22e422bf..46983c080a 100644
--- a/src/cpp/include/openvino/genai/vlm_config.hpp
+++ b/src/cpp/include/openvino/genai/vlm_config.hpp
@@ -6,12 +6,15 @@
 #include "openvino/genai/visibility.hpp"
 #include <filesystem>
 #include <string>
+#include "vlm_model_type.hpp"

 namespace ov::genai {
 /// @brief A Configuration class passed to VLMPipeline and used to
 /// change VLMPipeline's behavior. Corresponds to config.json.
 class OPENVINO_GENAI_EXPORTS VLMConfig {
 public:
+    /// @brief An enum denoting model type.
+    VLMModelType model_type;
     /// @brief A size of a single embedding returned by a resampler.
     /// Used to initialize positional embeddings for resampler input.
     size_t hidden_size = 2304;
diff --git a/src/cpp/include/openvino/genai/vlm_model_type.hpp b/src/cpp/include/openvino/genai/vlm_model_type.hpp
new file mode 100644
index 0000000000..0f811a116a
--- /dev/null
+++ b/src/cpp/include/openvino/genai/vlm_model_type.hpp
@@ -0,0 +1,31 @@
+// Copyright (C) 2023-2024 Intel Corporation
+// SPDX-License-Identifier: Apache-2.0
+
+#pragma once
+
+#include <string>
+#include <unordered_map>
+
+#include "openvino/genai/visibility.hpp"
+#include <openvino/core/except.hpp>
+
+namespace ov::genai {
+
+enum class OPENVINO_GENAI_EXPORTS VLMModelType {
+    MINICPM,
+    LLAVA,
+};
+
+inline VLMModelType to_vlm_model_type(const std::string& value) {
+    static const std::unordered_map<std::string, VLMModelType> model_types_map = {
+        {"minicpmv", VLMModelType::MINICPM},
+        {"llava", VLMModelType::LLAVA}
+    };
+
+    auto it = model_types_map.find(value);
+    if (it != model_types_map.end()) {
+        return it->second;
+    }
+    OPENVINO_THROW("Unsupported '", value, "' VLM model type");
+}
+}
\ No newline at end of file
diff --git a/src/cpp/include/openvino/genai/vlm_pipeline.hpp b/src/cpp/include/openvino/genai/vlm_pipeline.hpp
index 38595f1b96..0eb0b5a646 100644
--- a/src/cpp/include/openvino/genai/vlm_pipeline.hpp
+++ b/src/cpp/include/openvino/genai/vlm_pipeline.hpp
@@ -139,6 +139,9 @@ class OPENVINO_GENAI_EXPORTS VLMPipeline {
 private:
     class VLMPipelineImpl;
     std::unique_ptr<VLMPipelineImpl> m_pimpl;
+
+    ov::Tensor get_inputs_embeds_minicpm(const std::string& prompt, const std::vector<ov::Tensor>& images);
+    ov::Tensor get_inputs_embeds_llava(const std::string& prompt, const std::vector<ov::Tensor>& images);
 };

 /*
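A minimal sketch of how the new vlm_model_type.hpp helper is intended to be used; the model file names in the comments repeat what the loaders below expect, and the standalone main() is purely illustrative:

#include "openvino/genai/vlm_model_type.hpp"
#include <iostream>

int main() {
    // "minicpmv" and "llava" are the only strings the helper recognizes; anything
    // else triggers OPENVINO_THROW("Unsupported ... VLM model type").
    ov::genai::VLMModelType type = ov::genai::to_vlm_model_type("llava");
    if (type == ov::genai::VLMModelType::LLAVA) {
        std::cout << "LLaVA-style export: openvino_vision_embeddings_model.xml expected\n";
    } else {
        std::cout << "MiniCPM-style export: image_encoder.xml expected\n";
    }
    return 0;
}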
diff --git a/src/cpp/src/processor_config.cpp b/src/cpp/src/processor_config.cpp
index 33673f7e79..cea7f98fd4 100644
--- a/src/cpp/src/processor_config.cpp
+++ b/src/cpp/src/processor_config.cpp
@@ -10,7 +10,7 @@ ov::genai::ProcessorConfig::ProcessorConfig(const std::filesystem::path& json_path) {
     OPENVINO_ASSERT(stream.is_open(), "Failed to open '" + json_path.string() + "' with processor config");
     nlohmann::json parsed = nlohmann::json::parse(stream);
     using ov::genai::utils::read_json_param;
-    read_json_param(parsed, "patch_size", patch_size);
+    read_json_param(parsed, "patch_size", patch_size); // For llava - stored in config.json vision_config
     read_json_param(parsed, "scale_resolution", scale_resolution);
     read_json_param(parsed, "max_slice_nums", max_slice_nums);
     if (parsed.contains("norm_mean")) {
@@ -19,4 +19,20 @@ ov::genai::ProcessorConfig::ProcessorConfig(const std::filesystem::path& json_path) {
     if (parsed.contains("norm_std")) {
         norm_std = parsed.at("norm_std").get<std::array<float, 3>>();
     }
+
+    // Setting llava config params
+    if (parsed.contains("image_mean")) {
+        image_mean = parsed.at("image_mean").get<std::array<float, 3>>();
+    }
+    if (parsed.contains("image_std")) {
+        image_std = parsed.at("image_std").get<std::array<float, 3>>();
+    }
+
+    if (parsed.contains("crop_size")) {
+        crop_size_height = parsed.at("crop_size").at("height");
+        crop_size_width = parsed.at("crop_size").at("width");
+    }
+    if (parsed.contains("size")) {
+        size_shortest_edge = parsed.at("size").at("shortest_edge");
+    }
 }
diff --git a/src/cpp/src/vision_encoder.cpp b/src/cpp/src/vision_encoder.cpp
index 8e8612697c..6c926e0ed8 100644
--- a/src/cpp/src/vision_encoder.cpp
+++ b/src/cpp/src/vision_encoder.cpp
@@ -362,29 +362,117 @@ ProcessorConfig from_any_map(
     read_anymap_param(config_map, "norm_std", extracted_config.norm_std);
     return extracted_config;
 }
+
+
+ov::Tensor preprocess_image_llava(const ov::Tensor& image, const ProcessorConfig& config) {
+    bool do_resize = true;
+    bool do_center_crop = true;
+
+    // ov::Tensor to clip_image_u8
+    clip_image_u8 input_image{
+        int(image.get_shape().at(3)),
+        int(image.get_shape().at(2)),
+        {image.data<uint8_t>(), image.data<uint8_t>() + image.get_size()}
+    };
+
+    // Resize
+    clip_image_u8 resized_image;
+    if (do_resize) {
+        int target_size = config.size_shortest_edge;
+        float scale = static_cast<float>(target_size) / std::min(input_image.nx, input_image.ny);
+        int new_width = static_cast<int>(input_image.nx * scale);
+        int new_height = static_cast<int>(input_image.ny * scale);
+        bicubic_resize(input_image, resized_image, new_width, new_height);
+    } else {
+        resized_image = input_image;
+    }
+
+    // Center crop
+    clip_image_u8 cropped_image;
+    if (do_center_crop) {
+        int crop_height = config.crop_size_height;
+        int crop_width = config.crop_size_width;
+        int start_x = (resized_image.nx - crop_width) / 2;
+        int start_y = (resized_image.ny - crop_height) / 2;
+
+        cropped_image.nx = crop_width;
+        cropped_image.ny = crop_height;
+        cropped_image.buf.resize(3 * crop_width * crop_height);
+
+        for (int y = 0; y < crop_height; ++y) {
+            for (int x = 0; x < crop_width; ++x) {
+                for (int c = 0; c < 3; ++c) {
+                    cropped_image.buf[(y * crop_width + x) * 3 + c] =
+                        resized_image.buf[((start_y + y) * resized_image.nx + (start_x + x)) * 3 + c];
+                }
+            }
+        }
+    } else {
+        cropped_image = resized_image;
+    }
+
+    // Normalize
+    clip_ctx ctx;
+    std::copy(config.image_mean.begin(), config.image_mean.end(), ctx.image_mean);
+    std::copy(config.image_std.begin(), config.image_std.end(), ctx.image_std);
+
+    clip_image_f32 normalized_image = clip_image_preprocess(ctx, cropped_image);
+
+    // Convert clip_image_f32 to ov::Tensor
+    ov::Tensor result(
+        ov::element::f32,
+        {1, 3, size_t(normalized_image.ny), size_t(normalized_image.nx)},
+        (void*)(normalized_image.buf.data())
+    );
+
+    return result;
+}
 }
-VisionEncoder::VisionEncoder(const std::filesystem::path& model_dir, const std::string& device, const ov::AnyMap device_config, ov::Core core) :
-    VisionEncoder{
-        core.compile_model(
-            model_dir / "image_encoder.xml", device, device_config
-        ).create_infer_request(),
-        ov::genai::utils::from_config_json_if_exists<ProcessorConfig>(
+VisionEncoder::VisionEncoder(const std::filesystem::path& model_dir, const VLMModelType model_type, const std::string& device, const ov::AnyMap device_config, ov::Core core) :
+    model_type(model_type) {
+    if (model_type == VLMModelType::MINICPM) {
+        m_vision_encoder = core.compile_model(model_dir / "image_encoder.xml", device, device_config).create_infer_request();
+    } else if (model_type == VLMModelType::LLAVA) {
+        // Vision embeddings model is merged with multi modal projector at model export stage by optimum-intel
+        m_vision_encoder = core.compile_model(model_dir / "openvino_vision_embeddings_model.xml", device, device_config).create_infer_request();
+    }
+    m_processor_config = ov::genai::utils::from_config_json_if_exists<ProcessorConfig>(
         model_dir, "preprocessor_config.json"
-        )
-    } {}
+    );
+}

 EncodedImage VisionEncoder::encode(const ov::Tensor& image, const ProcessorConfig& config) {
+    if (model_type == VLMModelType::MINICPM) {
+        return encode_minicpm(image, config);
+    } else if (model_type == VLMModelType::LLAVA) {
+        return encode_llava(image, config);
+    }
+}
+
+EncodedImage VisionEncoder::encode(const ov::Tensor& image, const ov::AnyMap& config_map) {
+    return encode(image, from_any_map(
+        config_map, m_processor_config
+    ));
+}
+
+EncodedImage VisionEncoder::encode_minicpm(const ov::Tensor& image, const ProcessorConfig& config) {
     clip_ctx ctx_clip;
     ctx_clip.patch_size = m_processor_config.patch_size;
     ctx_clip.image_size = m_processor_config.image_size;
     std::copy(config.norm_mean.begin(), config.norm_mean.end(), ctx_clip.image_mean);
     std::copy(config.norm_std.begin(), config.norm_std.end(), ctx_clip.image_std);
-    return llava_image_embed_make_with_bytes_slice(ctx_clip, image, m_encoder, config.max_slice_nums, config.scale_resolution, config.patch_size, 0 == config.max_slice_nums);
+    return llava_image_embed_make_with_bytes_slice(ctx_clip, image, m_vision_encoder, config.max_slice_nums, config.scale_resolution, config.patch_size, 0 == config.max_slice_nums);
 }

-EncodedImage VisionEncoder::encode(const ov::Tensor& image, const ov::AnyMap& config_map) {
-    return encode(image, from_any_map(
-        config_map, m_processor_config
-    ));
+EncodedImage VisionEncoder::encode_llava(const ov::Tensor& image, const ProcessorConfig& config) {
+    ov::Tensor preprocessed_image = preprocess_image_llava(image, config);
+
+    m_vision_encoder.set_tensor("pixel_values", preprocessed_image);
+    m_vision_encoder.infer();
+
+    ov::Tensor image_features = m_vision_encoder.get_output_tensor();
+    ImageSize resized_source_size{config.crop_size_height / config.patch_size, config.crop_size_width / config.patch_size};
+
+    return {image_features, resized_source_size};
 }
diff --git a/src/cpp/src/vlm_config.cpp b/src/cpp/src/vlm_config.cpp
index 36d997ecbe..8d7585f2bb 100644
--- a/src/cpp/src/vlm_config.cpp
+++ b/src/cpp/src/vlm_config.cpp
@@ -10,6 +10,7 @@ ov::genai::VLMConfig::VLMConfig(const std::filesystem::path& json_path) {
     OPENVINO_ASSERT(stream.is_open(), "Failed to open '" + json_path.string() + "' with processor config");
     nlohmann::json parsed = nlohmann::json::parse(stream);
     using ov::genai::utils::read_json_param;
+    model_type = to_vlm_model_type(parsed.at("model_type"));
     read_json_param(parsed, "hidden_size", hidden_size);
     read_json_param(parsed, "scale_emb", scale_emb);
     read_json_param(parsed, "query_num", query_num);
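A rough usage sketch of the reworked VisionEncoder, assuming an exported LLaVA folder named ./llava_1_5_7b_ov; the constructor arguments follow the new signature above, while the folder path and the placeholder image are illustrative only:

#include "openvino/genai/vision_encoder.hpp"
#include <openvino/openvino.hpp>
#include <filesystem>

int main() {
    // With VLMModelType::LLAVA the constructor loads openvino_vision_embeddings_model.xml
    // and encode() dispatches to encode_llava().
    ov::genai::VisionEncoder encoder{
        std::filesystem::path{"./llava_1_5_7b_ov"},  // hypothetical export dir
        ov::genai::VLMModelType::LLAVA,
        "CPU"
    };
    ov::Tensor image{ov::element::u8, {1, 336, 336, 3}};  // placeholder NHWC u8 pixels
    ov::genai::EncodedImage encoded = encoder.encode(image, ov::AnyMap{});
    // encoded.resized_source holds the projected image features;
    // encoded.resized_source_size is crop_size / patch_size per side.
    return 0;
}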
diff --git a/src/cpp/src/vlm_pipeline.cpp b/src/cpp/src/vlm_pipeline.cpp
index 3bdc3d9ae9..0678f2b074 100644
--- a/src/cpp/src/vlm_pipeline.cpp
+++ b/src/cpp/src/vlm_pipeline.cpp
@@ -294,6 +294,54 @@ ov::Tensor resample(VLMPipeline& pipe, const ov::Tensor& encoded_image, const std::vector<ImageSize>& target_sizes) {
     pipe.m_resampler.infer();
     return pipe.m_resampler.get_output_tensor();  // [N, query_num, new_hidden_size]
 }
+
+ov::Tensor merge_text_and_image_embeddings_llava(
+    const ov::Tensor& input_ids,
+    const ov::Tensor& text_embeds,
+    const ov::Tensor& image_embeds,
+    int64_t image_token_index
+) {
+    auto text_embeds_shape = text_embeds.get_shape();
+    auto image_embeds_shape = image_embeds.get_shape();
+
+    OPENVINO_ASSERT(
+        text_embeds_shape[2] == image_embeds_shape[2],
+        "Incompatible shapes between text_embeds and image_embeds"
+    );
+
+    size_t text_embeds_seq_length = text_embeds_shape[1];
+    size_t hidden_size = text_embeds_shape[2];
+    size_t image_embeds_seq_length = image_embeds_shape[1];
+
+    size_t merged_seq_length = text_embeds_seq_length + (image_embeds_seq_length - 1);
+
+    ov::Tensor merged_embeds(text_embeds.get_element_type(), {BATCH_SIZE, merged_seq_length, hidden_size});
+
+    const int64_t* input_ids_data = input_ids.data<int64_t>();
+    const float* text_embeds_data = text_embeds.data<float>();
+    const float* image_embeds_data = image_embeds.data<float>();
+    float* merged_data = merged_embeds.data<float>();
+
+    size_t merged_idx = 0;
+    for (size_t s = 0; s < text_embeds_seq_length; ++s) {
+        if (input_ids_data[s] == image_token_index) {
+            for (size_t i = 0; i < image_embeds_seq_length; ++i) {
+                std::copy_n(image_embeds_data + i * hidden_size,
+                            hidden_size,
+                            merged_data + merged_idx * hidden_size);
+                merged_idx++;
+            }
+        } else {
+            std::copy_n(text_embeds_data + s * hidden_size,
+                        hidden_size,
+                        merged_data + merged_idx * hidden_size);
+            merged_idx++;
+        }
+    }
+
+    return merged_embeds;
+}
 }

 class ov::genai::VLMPipeline::VLMPipelineImpl {
@@ -310,20 +358,33 @@ VLMPipeline::VLMPipeline(
         )
     },
     m_tokenizer{Tokenizer(model_dir.string(), device_config)},
-    m_vision_encoder(model_dir, device, device_config, ov::Core{}),
-    m_resampler{ov::Core{}.compile_model(
-        model_dir / "resampler.xml", device, device_config
-    ).create_infer_request()},
-    m_embedding{ov::Core{}.compile_model(
-        model_dir / "embed_tokens.xml", device, device_config
-    ).create_infer_request()},
-    m_language{ov::Core{}.compile_model(
-        model_dir / "language_model.xml", device, device_config
-    ).create_infer_request()},
-    m_pos_embed_cache{
-        get_2d_sincos_pos_embed(m_vlm_config.hidden_size, {70, 70})
-    },
+    m_vision_encoder(model_dir, m_vlm_config.model_type, device, device_config, ov::Core{}),
     m_is_chat_conversation{false} {
+    if (m_vlm_config.model_type == VLMModelType::MINICPM) {
+        m_resampler = ov::Core{}.compile_model(
+            model_dir / "resampler.xml", device, device_config
+        ).create_infer_request();
+
+        m_embedding = ov::Core{}.compile_model(
+            model_dir / "embed_tokens.xml", device, device_config
+        ).create_infer_request();
+
+        m_language = ov::Core{}.compile_model(
+            model_dir / "language_model.xml", device, device_config
+        ).create_infer_request();
+
+        m_pos_embed_cache = get_2d_sincos_pos_embed(m_vlm_config.hidden_size, {70, 70});
+    } else if (m_vlm_config.model_type == VLMModelType::LLAVA) {
+        m_language = ov::Core{}.compile_model(
+            model_dir / "openvino_language_model.xml", device, device_config
+        ).create_infer_request();
+
+        // Reusing the same m_embedding for llava text_embeddings model
+        m_embedding = ov::Core{}.compile_model(
+            model_dir / "openvino_text_embeddings_model.xml", device, device_config
+        ).create_infer_request();
+    }
     m_language.get_tensor("attention_mask").set_shape({1, 0});
 }
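To make the sequence-length arithmetic in merge_text_and_image_embeddings_llava above concrete, a tiny worked example; the 576 figure assumes LLaVA-1.5's 336 px crop with patch size 14 ((336/14)^2 vision tokens) and is an assumption for illustration, not something this patch asserts:

#include <cstddef>
#include <iostream>

int main() {
    // Hypothetical sizes for a short prompt containing one image placeholder token.
    std::size_t text_len  = 20;   // tokens in "USER: <image>\n... ASSISTANT:"
    std::size_t image_len = 576;  // assumed (336 / 14) * (336 / 14) image embeddings
    // The single image token is replaced by all image embeddings, hence the -1:
    std::size_t merged_len = text_len + (image_len - 1);
    std::cout << merged_len << "\n";  // 595
    return 0;
}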
@@ -335,138 +396,21 @@ DecodedResults VLMPipeline::generate(
     const GenerationConfig& generation_config,
     const StreamerVariant& streamer
 ) {
-    std::string images_prompt;
-    std::vector<EncodedImage> embeds;
-    for (const ov::Tensor& rgb : rgbs) {
-        ov::Tensor reshaped = rgb;
-        ov::Shape rgb_shape = rgb.get_shape();
-        switch (rgb_shape.size()) {
-            case 3:
-                reshaped.set_shape({1, rgb_shape.at(0), rgb_shape.at(1), rgb_shape.at(2)});
-                break;
-            case 4: break;
-            default: OPENVINO_THROW("Input image must have [NHWC] or [HWC] layout");
-        }
-        ov::Shape reshaped_shape = reshaped.get_shape();
-        for (size_t batch_idx = 0; batch_idx < reshaped_shape.at(0); ++batch_idx) {
-            ov::Tensor single_image{
-                ov::element::u8,
-                {1, reshaped_shape.at(1), reshaped_shape.at(2), reshaped_shape.at(3)},
-                reshaped.data<uint8_t>() + batch_idx * reshaped_shape.at(1) * reshaped_shape.at(1) * reshaped_shape.at(1)
-            };
-            EncodedImage encoded_image = m_vision_encoder.encode(single_image);
-            if (m_vlm_config.use_image_id) {
-                images_prompt += m_vlm_config.im_id_start + std::to_string(image_id) + m_vlm_config.im_id_end;
-                ++image_id;
-            }
-            std::string unk64;
-            for (size_t idx = 0; idx < m_vlm_config.query_num; ++idx) {
-                unk64 += m_vlm_config.unk;
-            }
-            images_prompt += m_vlm_config.im_start + unk64 + m_vlm_config.im_end;
-            if (encoded_image.slices) {
-                ov::Shape slices_shape = encoded_image.slices.get_shape();
-                for (size_t row_idx = 0; row_idx < slices_shape.at(0); ++row_idx) {
-                    for (size_t col_idx = 0; col_idx < slices_shape.at(1); ++col_idx) {
-                        images_prompt += m_vlm_config.slice_start + unk64 + m_vlm_config.slice_end;
-                    }
-                    images_prompt += '\n';
-                }
-            }
-            if ('\n' != *(images_prompt.end() - 1)) {
-                // Image wasn't sliced, add \n to the end of image anyway.
-                // Strangely, \n isn't placed between </image><slice>.
-                images_prompt += '\n';
-            }
-            embeds.push_back(std::move(encoded_image));
-        }
-    }
-    images_prompt += prompt;
-    ov::Tensor encoded_input;
-    if (m_is_chat_conversation) {
-        // KV cache in model already contains prompts and answers from previous iterations.
-        // So only new prompt wrapped into chat template to be sent into model. Tokenizer always returns
-        // token_ids = {<bos token>, ...}. So if tokenizer applies only to the new prompt,
-        // <bos token> will be inserted on every iteration.
-        // So actual pipeline calculates input_ids for whole chat history + for whole chat history without the new prompt
-        // and takes only the difference between them.
-        // The chat history cannot be saved as already encoded tokens because generate call doesn't return <eos> token, but
-        // KV cache contains it. So we have to add it manually or get it by tokenization all chat history.
-        m_history.push_back({{"role", "user"}, {"content", images_prompt}});
-        constexpr bool add_generation_prompt = true;
-        std::string new_templated_chat_history = m_tokenizer.apply_chat_template(m_history, add_generation_prompt);
-        ov::Tensor new_chat_tokens = m_tokenizer.encode(new_templated_chat_history).input_ids;
-        if (0 == m_language.get_tensor("attention_mask").get_shape().at(1)) {
-            encoded_input = new_chat_tokens;
-        } else {
-            TokenizedInputs prev_chat_tokens = m_tokenizer.encode(
-                m_templated_chat_history
-            );
-            encoded_input = utils::subtract_chat_tokenized_inputs(
-                {new_chat_tokens}, prev_chat_tokens
-            ).input_ids;
-        }
-        m_templated_chat_history = std::move(new_templated_chat_history);
-    } else {
-        encoded_input = m_tokenizer.encode(images_prompt).input_ids;
-    }
-    m_embedding.set_input_tensor(encoded_input);
-    m_embedding.infer();
-    ov::Tensor inputs_embeds = m_embedding.get_output_tensor();
-    OPENVINO_ASSERT(
-        m_vlm_config.hidden_size == inputs_embeds.get_shape().at(2),
-        "Unexpected embedding size"
-    );
-    ov::Tensor special_tokens = m_tokenizer.encode(
-        m_vlm_config.im_start
-        + m_vlm_config.im_end
-        + m_vlm_config.slice_start
-        + m_vlm_config.slice_end
-    ).input_ids;
-    OPENVINO_ASSERT(
-        4 == special_tokens.get_shape().at(1),
-        "Every special token must be represented with a single int."
-    );
-    int64_t im_start_id = special_tokens.data<int64_t>()[0];
-    int64_t im_end_id = special_tokens.data<int64_t>()[1];
-    int64_t slice_start_id = special_tokens.data<int64_t>()[2];
-    int64_t slice_end_id = special_tokens.data<int64_t>()[3];
-    int64_t im_start_pos = 0, slice_start_pos = 0;
-    int64_t* begin = encoded_input.data<int64_t>();
-    int64_t* ids = begin;
-    size_t encoded_input_size = encoded_input.get_size();
-    int64_t* end = ids + encoded_input_size;
-    float* inputs_embeds_data = inputs_embeds.data<float>();
-    for (const EncodedImage& encoded_image : embeds) {
-        const ov::Tensor& resampled_source = resample(*this, encoded_image.resized_source, {encoded_image.resized_source_size});
-        float* emb = resampled_source.data<float>();
-        ids = std::find(ids, end, im_start_id);
-        OPENVINO_ASSERT(end != ids);
-        std::copy_n(emb, resampled_source.get_size(), inputs_embeds_data + std::distance(begin, ids) * m_vlm_config.hidden_size);
-        ids += m_vlm_config.query_num;
-        if (encoded_image.slices) {
-            size_t token_idx = 0;
-            const ov::Shape& slices_shape = encoded_image.slices.get_shape();
-            for (size_t i = 0; i < slices_shape.at(0); ++i) {
-                for (size_t ja = 0; ja < slices_shape.at(1); ++ja) {
-                    size_t d2 = slices_shape.at(2);
-                    size_t d3 = slices_shape.at(3);
-                    ov::Tensor encoded_view{ov::element::f32, {1, d2, d3}, encoded_image.slices.data<float>() + (i * slices_shape.at(1) + ja) * d2 * d3};
-                    const ov::Tensor& vision_embed_tensor_i_j = resample(*this, encoded_view, {encoded_image.slices_size});
-                    ids = std::find(ids, end, slice_start_id);
-                    OPENVINO_ASSERT(end != ids);
-                    std::copy_n(vision_embed_tensor_i_j.data<float>(), vision_embed_tensor_i_j.get_size(), inputs_embeds_data + std::distance(begin, ids) * m_vlm_config.hidden_size);
-                    ids += m_vlm_config.query_num;
-                }
-            }
-        }
+    ov::Tensor inputs_embeds;
+    if (m_vlm_config.model_type == VLMModelType::MINICPM) {
+        inputs_embeds = get_inputs_embeds_minicpm(prompt, rgbs);
+    } else if (m_vlm_config.model_type == VLMModelType::LLAVA) {
+        inputs_embeds = get_inputs_embeds_llava(prompt, rgbs);
     }
+
     m_language.set_tensor("inputs_embeds", inputs_embeds);
     size_t history_len = m_language.get_tensor("attention_mask").get_shape().at(1);
     m_language.get_tensor("attention_mask").set_shape({1, history_len + inputs_embeds.get_shape()[1]});
     std::fill_n(m_language.get_tensor("attention_mask").data<int64_t>(), m_language.get_tensor("attention_mask").get_size(), 1);
+
     m_language.get_tensor("position_ids").set_shape({1, inputs_embeds.get_shape().at(1)});
     std::iota(m_language.get_tensor("position_ids").data<int64_t>(), m_language.get_tensor("position_ids").data<int64_t>() + m_language.get_tensor("position_ids").get_size(), history_len);
+
     m_language.get_tensor("beam_idx").set_shape({ BATCH_SIZE });
     m_language.get_tensor("beam_idx").data<int32_t>()[0] = 0;
@@ -606,3 +550,153 @@ GenerationConfig VLMPipeline::get_generation_config() const {
 void VLMPipeline::set_generation_config(const GenerationConfig& new_config) {
     m_generation_config = new_config;
 }
+
+ov::Tensor VLMPipeline::get_inputs_embeds_llava(const std::string& prompt, const std::vector<ov::Tensor>& images) {
+    std::string image_token = "<image>"; // TODO Consider getting from vlm_config or json
+    std::string formatted_prompt = "USER: " + (images.empty() ? prompt : image_token + "\n" + prompt) + " ASSISTANT:";
+    ov::Tensor input_ids = m_tokenizer.encode(formatted_prompt).input_ids;
+    if (images.empty()) {
+        return process_prompt(m_embedding, input_ids, m_vlm_config.scale_emb);
+    } else {
+        OPENVINO_ASSERT(1 == images.size(), "Only a single image allowed");
+        EncodedImage encoded_image = m_vision_encoder.encode(images.at(0));
+        ov::Tensor image_embeds = encoded_image.resized_source;
+
+        ov::Tensor text_embeds = process_prompt(m_embedding, input_ids, m_vlm_config.scale_emb);
+
+        int64_t image_token_index = 32000; // TODO Consider getting from m_vlm_config.image_token_index or config.json
+
+        return merge_text_and_image_embeddings_llava(input_ids, text_embeds, image_embeds, image_token_index);
+    }
+}
+
+ov::Tensor VLMPipeline::get_inputs_embeds_minicpm(const std::string& prompt, const std::vector<ov::Tensor>& images) {
+    std::string images_prompt;
+    std::vector<EncodedImage> embeds;
+    for (const ov::Tensor& rgb : images) {
+        ov::Tensor reshaped = rgb;
+        ov::Shape rgb_shape = rgb.get_shape();
+        switch (rgb_shape.size()) {
+            case 3:
+                reshaped.set_shape({1, rgb_shape.at(0), rgb_shape.at(1), rgb_shape.at(2)});
+                break;
+            case 4: break;
+            default: OPENVINO_THROW("Input image must have [NHWC] or [HWC] layout");
+        }
+        ov::Shape reshaped_shape = reshaped.get_shape();
+        for (size_t batch_idx = 0; batch_idx < reshaped_shape.at(0); ++batch_idx) {
+            ov::Tensor single_image{
+                ov::element::u8,
+                {1, reshaped_shape.at(1), reshaped_shape.at(2), reshaped_shape.at(3)},
+                reshaped.data<uint8_t>() + batch_idx * reshaped_shape.at(1) * reshaped_shape.at(1) * reshaped_shape.at(1)
+            };
+            EncodedImage encoded_image = m_vision_encoder.encode(single_image);
+            if (m_vlm_config.use_image_id) {
+                images_prompt += m_vlm_config.im_id_start + std::to_string(image_id) + m_vlm_config.im_id_end;
+                ++image_id;
+            }
+            std::string unk64;
+            for (size_t idx = 0; idx < m_vlm_config.query_num; ++idx) {
+                unk64 += m_vlm_config.unk;
+            }
+            images_prompt += m_vlm_config.im_start + unk64 + m_vlm_config.im_end;
+            if (encoded_image.slices) {
+                ov::Shape slices_shape = encoded_image.slices.get_shape();
+                for (size_t row_idx = 0; row_idx < slices_shape.at(0); ++row_idx) {
+                    for (size_t col_idx = 0; col_idx < slices_shape.at(1); ++col_idx) {
+                        images_prompt += m_vlm_config.slice_start + unk64 + m_vlm_config.slice_end;
+                    }
+                    images_prompt += '\n';
+                }
+            }
+            if ('\n' != *(images_prompt.end() - 1)) {
+                // Image wasn't sliced, add \n to the end of image anyway.
+                // Strangely, \n isn't placed between </image><slice>.
+                images_prompt += '\n';
+            }
+            embeds.push_back(std::move(encoded_image));
+        }
+    }
+    images_prompt += prompt;
+    ov::Tensor encoded_input;
+    if (m_is_chat_conversation) {
+        // KV cache in model already contains prompts and answers from previous iterations.
+        // So only new prompt wrapped into chat template to be sent into model. Tokenizer always returns
+        // token_ids = {<bos token>, ...}. So if tokenizer applies only to the new prompt,
+        // <bos token> will be inserted on every iteration.
+        // So actual pipeline calculates input_ids for whole chat history + for whole chat history without the new prompt
+        // and takes only the difference between them.
+        // The chat history cannot be saved as already encoded tokens because generate call doesn't return <eos> token, but
+        // KV cache contains it. So we have to add it manually or get it by tokenization all chat history.
+        m_history.push_back({{"role", "user"}, {"content", images_prompt}});
+        constexpr bool add_generation_prompt = true;
+        std::string new_templated_chat_history = m_tokenizer.apply_chat_template(m_history, add_generation_prompt);
+        ov::Tensor new_chat_tokens = m_tokenizer.encode(new_templated_chat_history).input_ids;
+        if (0 == m_language.get_tensor("attention_mask").get_shape().at(1)) {
+            encoded_input = new_chat_tokens;
+        } else {
+            TokenizedInputs prev_chat_tokens = m_tokenizer.encode(
+                m_templated_chat_history
+            );
+            encoded_input = utils::subtract_chat_tokenized_inputs(
+                {new_chat_tokens}, prev_chat_tokens
+            ).input_ids;
+        }
+        m_templated_chat_history = std::move(new_templated_chat_history);
+    } else {
+        encoded_input = m_tokenizer.encode(images_prompt).input_ids;
+    }
+    m_embedding.set_input_tensor(encoded_input);
+    m_embedding.infer();
+    ov::Tensor inputs_embeds = m_embedding.get_output_tensor();
+    OPENVINO_ASSERT(
+        m_vlm_config.hidden_size == inputs_embeds.get_shape().at(2),
+        "Unexpected embedding size"
+    );
+    ov::Tensor special_tokens = m_tokenizer.encode(
+        m_vlm_config.im_start
+        + m_vlm_config.im_end
+        + m_vlm_config.slice_start
+        + m_vlm_config.slice_end
+    ).input_ids;
+    OPENVINO_ASSERT(
+        4 == special_tokens.get_shape().at(1),
+        "Every special token must be represented with a single int."
+    );
+    int64_t im_start_id = special_tokens.data<int64_t>()[0];
+    int64_t im_end_id = special_tokens.data<int64_t>()[1];
+    int64_t slice_start_id = special_tokens.data<int64_t>()[2];
+    int64_t slice_end_id = special_tokens.data<int64_t>()[3];
+    int64_t im_start_pos = 0, slice_start_pos = 0;
+    int64_t* begin = encoded_input.data<int64_t>();
+    int64_t* ids = begin;
+    size_t encoded_input_size = encoded_input.get_size();
+    int64_t* end = ids + encoded_input_size;
+    float* inputs_embeds_data = inputs_embeds.data<float>();
+    for (const EncodedImage& encoded_image : embeds) {
+        const ov::Tensor& resampled_source = resample(*this, encoded_image.resized_source, {encoded_image.resized_source_size});
+        float* emb = resampled_source.data<float>();
+        ids = std::find(ids, end, im_start_id);
+        OPENVINO_ASSERT(end != ids);
+        std::copy_n(emb, resampled_source.get_size(), inputs_embeds_data + std::distance(begin, ids) * m_vlm_config.hidden_size);
+        ids += m_vlm_config.query_num;
+        if (encoded_image.slices) {
+            size_t token_idx = 0;
+            const ov::Shape& slices_shape = encoded_image.slices.get_shape();
+            for (size_t i = 0; i < slices_shape.at(0); ++i) {
+                for (size_t ja = 0; ja < slices_shape.at(1); ++ja) {
+                    size_t d2 = slices_shape.at(2);
+                    size_t d3 = slices_shape.at(3);
+                    ov::Tensor encoded_view{ov::element::f32, {1, d2, d3}, encoded_image.slices.data<float>() + (i * slices_shape.at(1) + ja) * d2 * d3};
+                    const ov::Tensor& vision_embed_tensor_i_j = resample(*this, encoded_view, {encoded_image.slices_size});
+                    ids = std::find(ids, end, slice_start_id);
+                    OPENVINO_ASSERT(end != ids);
+                    std::copy_n(vision_embed_tensor_i_j.data<float>(), vision_embed_tensor_i_j.get_size(), inputs_embeds_data + std::distance(begin, ids) * m_vlm_config.hidden_size);
+                    ids += m_vlm_config.query_num;
+                }
+            }
+        }
+    }
+
+    return inputs_embeds;
+}
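For orientation, a small sketch of the prompt template that get_inputs_embeds_llava builds before tokenization; the sample question reuses the workflow test above, the template and the 32000 image-token id come from the code:

#include <string>
#include <iostream>

int main() {
    // Mirrors the template used in get_inputs_embeds_llava().
    std::string image_token = "<image>";
    std::string prompt = "Who drew this painting?";
    std::string formatted_prompt = "USER: " + image_token + "\n" + prompt + " ASSISTANT:";
    std::cout << formatted_prompt << "\n";
    // After tokenization, the single id 32000 (the <image> token) is replaced by all
    // image embeddings via merge_text_and_image_embeddings_llava().
    return 0;
}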
diff --git a/src/docs/SUPPORTED_MODELS.md b/src/docs/SUPPORTED_MODELS.md
index 1232a081dd..fb6df36950 100644
--- a/src/docs/SUPPORTED_MODELS.md
+++ b/src/docs/SUPPORTED_MODELS.md
@@ -167,8 +167,17 @@ The pipeline can work with other similar topologies produced by `optimum-intel`
       Example HuggingFace Models
-      MiniCPM-V-2_6
+      LLaVA
+      LLaVA-v1.5
+
+
+
+      MiniCPMV
+      MiniCPM-V-2_6
         • openbmb/MiniCPM-V-2_6
         •
diff --git a/tests/python_tests/requirements.txt b/tests/python_tests/requirements.txt
index 8c49f7c1e6..372d3ac950 100644
--- a/tests/python_tests/requirements.txt
+++ b/tests/python_tests/requirements.txt
@@ -1,5 +1,6 @@
 --extra-index-url https://download.pytorch.org/whl/cpu
-optimum[openvino]==1.22.0
+optimum-intel @ git+https://github.com/huggingface/optimum-intel.git
+numpy<2.0.0; sys_platform == 'darwin'
 onnx==1.16.1
 pytest
 llm_bench/python/who_what_benchmark