diff --git a/.github/workflows/causal_lm_cpp.yml b/.github/workflows/causal_lm_cpp.yml
index 097cca34dd..6f6a11c1e9 100644
--- a/.github/workflows/causal_lm_cpp.yml
+++ b/.github/workflows/causal_lm_cpp.yml
@@ -704,11 +704,17 @@ jobs:
           python -m pip install ./thirdparty/openvino_tokenizers/[transformers] --pre --extra-index-url https://storage.openvinotoolkit.org/simple/wheels/nightly
           python -m pip install --upgrade-strategy eager -r ./samples/requirements.txt --pre --extra-index-url https://storage.openvinotoolkit.org/simple/wheels/nightly
           python ./samples/cpp/visual_language_chat/export_MiniCPM-V-2_6.py ./miniCPM-V-2_6/
-          wget https://github.com/openvinotoolkit/openvino_notebooks/assets/29454499/d5fbbd1a-d484-415c-88cb-9986625b7b11 --output-document cat.jpg
-      - name: Run visual_language_chat sample - MiniCPM-V-2_6
+          mkdir cat_img
+          wget https://github.com/openvinotoolkit/openvino_notebooks/assets/29454499/d5fbbd1a-d484-415c-88cb-9986625b7b11 --output-document cat_img/cat.jpg
+      - name: Run visual_language_chat sample with file as input - MiniCPM-V-2_6
         run: >
           source ./ov/setupvars.sh
-          && timeout 120s ./build/samples/cpp/visual_language_chat/visual_language_chat ./miniCPM-V-2_6/ cat.jpg
+          && timeout 120s ./build/samples/cpp/visual_language_chat/visual_language_chat ./miniCPM-V-2_6/ cat_img/cat.jpg
+          <<< $'What is on the image?\nWhat is special on the image?'
+      - name: Run visual_language_chat sample with dir as input - MiniCPM-V-2_6
+        run: >
+          source ./ov/setupvars.sh
+          && timeout 120s ./build/samples/cpp/visual_language_chat/visual_language_chat ./miniCPM-V-2_6/ cat_img
           <<< $'What is on the image?\nWhat is special on the image?'
       - name: Download and convert LLaVa 1.5 model and an image
         run: |
@@ -729,7 +735,7 @@ jobs:
           source ./ov/setupvars.sh
           export PYTHONPATH=./build/:$PYTHONPATH
           printf 'What is on the image?\nWhat is special on the image?\n' > ./input.txt
-          timeout 120s python ./samples/python/visual_language_chat/visual_language_chat.py ./miniCPM-V-2_6/ cat.jpg < input.txt > ./pred.txt
+          timeout 120s python ./samples/python/visual_language_chat/visual_language_chat.py ./miniCPM-V-2_6/ cat_img/cat.jpg < input.txt > ./pred.txt
 
   cpp-continuous-batching-ubuntu:
     runs-on: ubuntu-20.04-8-cores
diff --git a/samples/cpp/visual_language_chat/load_image.cpp b/samples/cpp/visual_language_chat/load_image.cpp
index 855f7567bf..554583a10d 100644
--- a/samples/cpp/visual_language_chat/load_image.cpp
+++ b/samples/cpp/visual_language_chat/load_image.cpp
@@ -6,6 +6,28 @@
 #include "stb_image.h"
 #include "load_image.hpp"
 
+namespace fs = std::filesystem;
+
+std::vector<ov::Tensor> utils::load_images(const std::filesystem::path& input_path) {
+    std::vector<ov::Tensor> images;
+    if (!input_path.empty() && fs::exists(input_path)) {
+        if (fs::is_directory(input_path)) {
+            for (const auto& dir_entry : fs::directory_iterator(input_path)) {
+                ov::Tensor image = utils::load_image(dir_entry.path());
+                images.push_back(std::move(image));
+            }
+        } else if (fs::is_regular_file(input_path)) {
+            ov::Tensor image = utils::load_image(input_path);
+            images.push_back(std::move(image));
+        }
+    }
+
+    if (images.empty())
+        throw std::runtime_error(std::string{"No images were found at path "} + input_path.string());
+
+    return images;
+}
+
 ov::Tensor utils::load_image(const std::filesystem::path& image_path) {
     int x = 0, y = 0, channels_in_file = 0;
     constexpr int desired_channels = 3;
diff --git a/samples/cpp/visual_language_chat/load_image.hpp b/samples/cpp/visual_language_chat/load_image.hpp
index f66dd2caf2..d0dcc271cd 100644
--- a/samples/cpp/visual_language_chat/load_image.hpp
+++ b/samples/cpp/visual_language_chat/load_image.hpp
@@ -9,4 +9,5 @@
 
 namespace utils {
 ov::Tensor load_image(const std::filesystem::path& image_path);
+std::vector<ov::Tensor> load_images(const std::filesystem::path& image_path);
 }
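For reference, a minimal sketch of how the new utils::load_images helper is driven (the main() below is illustrative only; the helper's behaviour, including the throw on an empty result, is exactly as defined above):

    #include "load_image.hpp"
    #include <iostream>

    int main() {
        try {
            // Accepts either a single image file or a directory of images.
            std::vector<ov::Tensor> images = utils::load_images("cat_img");
            std::cout << "Loaded " << images.size() << " image(s)\n";
        } catch (const std::runtime_error& e) {
            std::cerr << e.what() << '\n';  // thrown when the path yields no images
        }
    }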
diff --git a/samples/cpp/visual_language_chat/visual_language_chat.cpp b/samples/cpp/visual_language_chat/visual_language_chat.cpp
index 06f8f4e696..c924217764 100644
--- a/samples/cpp/visual_language_chat/visual_language_chat.cpp
+++ b/samples/cpp/visual_language_chat/visual_language_chat.cpp
@@ -14,41 +14,10 @@ bool print_subword(std::string&& subword) {
 
 int main(int argc, char* argv[]) try {
     if (3 != argc) {
-        throw std::runtime_error(std::string{"Usage "} + argv[0] + " <MODEL_DIR> <IMAGE_FILE>");
+        throw std::runtime_error(std::string{"Usage "} + argv[0] + " <MODEL_DIR> <IMAGE_FILE OR DIR_WITH_IMAGES>");
     }
 
-    // multinomial or beam_search can be used as well
-    ov::genai::GenerationConfig generation_config = ov::genai::greedy();
-    // ov::genai::GenerationConfig generation_config = ov::genai::multinomial();
-    // ov::genai::GenerationConfig generation_config = ov::genai::beam_search();
-
-    ov::AnyMap properies;
-    properies.insert(ov::genai::generation_config(generation_config));
-
-    // streamer could be used with greedy and multinomial
-    // if num_return_sequences > 1 in case of multinomial, the streamer will use the output from the first sequence
-    if (generation_config.is_greedy_decoding() or generation_config.is_multinomial()) {
-        properies.insert(ov::genai::streamer(print_subword));
-    }
-
-    std::vector<ov::Tensor> images;
-    std::string input_path = argv[2];
-    if (!input_path.empty() && fs::exists(input_path)) {
-        if (fs::is_directory(input_path)) {
-            for (const auto& dir_entry : fs::directory_iterator(input_path)) {
-                ov::Tensor image = utils::load_image(dir_entry.path());
-                images.push_back(std::move(image));
-            }
-        } else if (fs::is_regular_file(input_path)) {
-            ov::Tensor image = utils::load_image(input_path);
-            images.push_back(std::move(image));
-        }
-    }
-
-    if (images.empty())
-        throw std::runtime_error("No one image found by path " + input_path);
-    else
-        properies.insert(images.size() == 1 ? ov::genai::image(images.at(0)) : ov::genai::images(images));
+    std::vector<ov::Tensor> images = utils::load_images(argv[2]);
 
     std::string device = "CPU";  // GPU can be used as well
     ov::AnyMap enable_compile_cache;
@@ -64,19 +33,16 @@ int main(int argc, char* argv[]) try {
         std::cout << "question:\n";
         std::getline(std::cin, prompt);
 
-        auto resuls = pipe.generate(prompt, properies);
-        if (generation_config.is_beam_search()) {
-            std::cout << resuls.texts.at(0) << std::endl;
-        }
-        properies.erase(images.size() == 1 ? "image" : "images");
-
+        auto results = pipe.generate(prompt,
+                                     images.size() == 1 ? ov::genai::image(images.at(0)) : ov::genai::images(images),
+                                     ov::genai::generation_config(ov::genai::greedy()),
+                                     ov::genai::streamer(print_subword));
         std::cout << "\n----------\n"
             "question:\n";
         while (std::getline(std::cin, prompt)) {
-            resuls = pipe.generate(prompt, properies);
-            if (generation_config.is_beam_search()) {
-                std::cout << resuls.texts.at(0) << std::endl;
-            }
+            results = pipe.generate(prompt,
+                                    ov::genai::generation_config(ov::genai::greedy()),
+                                    ov::genai::streamer(print_subword));
             std::cout << "\n----------\n"
                 "question:\n";
         }
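The reworked sample above binds the image tensors to the first generate() call only; follow-up turns rely on the visual tokens already stored in the pipeline's KV cache. A sketch of that per-turn call shape (pipe, images, and print_subword as in the sample; the prompts are illustrative):

    #include "openvino/genai/visual_language/pipeline.hpp"

    void chat_turns(ov::genai::VLMPipeline& pipe, const std::vector<ov::Tensor>& images) {
        // First turn: a single image goes in via ov::genai::image, several via ov::genai::images.
        pipe.generate("What is on the image?",
                      images.size() == 1 ? ov::genai::image(images.at(0)) : ov::genai::images(images),
                      ov::genai::generation_config(ov::genai::greedy()),
                      ov::genai::streamer(print_subword));
        // Later turns: no image properties are passed again.
        pipe.generate("What is special on the image?",
                      ov::genai::generation_config(ov::genai::greedy()),
                      ov::genai::streamer(print_subword));
    }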
diff --git a/src/cpp/src/visual_language/pipeline.cpp b/src/cpp/src/visual_language/pipeline.cpp
index b82677f689..f076b5bcc5 100644
--- a/src/cpp/src/visual_language/pipeline.cpp
+++ b/src/cpp/src/visual_language/pipeline.cpp
@@ -3,7 +3,7 @@
 
 #include "openvino/genai/visual_language/pipeline.hpp"
 #include "openvino/genai/tokenizer.hpp"
-#include "vlm_sampling.hpp"
+#include "sampler.hpp"
 #include "clip.hpp"
 #include "text_callback_streamer.hpp"
 #include "utils.hpp"
@@ -21,64 +21,6 @@ template <class... Ts> overloaded(Ts...) -> overloaded<Ts...>;
 
 constexpr size_t BATCH_SIZE = 1;
 
-struct Args {
-    bool do_sample = false;
-    int top_k = 0;
-    float top_p = 0.7f;
-    float temp = 0.95f;
-    float repeat_penalty = 1.0f;
-};
-
-int64_t get_out_token_id(const std::vector<int>& input_ids, float* logits, size_t vocab_size, Args args) {
-    int64_t out_token;
-
-    // logits pre-process
-    if (args.repeat_penalty != 1.f) {
-        sampling_repetition_penalty(logits, logits + vocab_size, input_ids, args.repeat_penalty);
-    }
-
-    if (args.do_sample)
-    {
-        if (args.temp > 0) {
-            sampling_temperature(logits, logits + vocab_size, args.temp);
-        }
-
-        std::vector<TokenIdScore> token_scores(vocab_size);
-        for (int i = 0; i < vocab_size; i++) {
-            token_scores[i] = TokenIdScore(i, logits[i]);
-        }
-
-        // top_k sampling
-        if (0 < args.top_k && args.top_k < (int)token_scores.size()) {
-            sampling_top_k(token_scores.data(), token_scores.data() + args.top_k,
-                token_scores.data() + token_scores.size());
-            token_scores.resize(args.top_k);
-        }
-
-        // top_p sampling
-        if (0.f < args.top_p && args.top_p < 1.f) {
-            auto pos = sampling_top_p(token_scores.data(), token_scores.data() + token_scores.size(), args.top_p);
-            token_scores.resize(pos - token_scores.data());
-        }
-
-        // sample next token
-        sampling_softmax_inplace(token_scores.data(), token_scores.data() + token_scores.size());
-        for (size_t i = 0; i < token_scores.size(); i++) {
-            logits[i] = token_scores[i].score;
-        }
-
-        thread_local std::random_device rd;
-        thread_local std::mt19937 gen(rd());
-
-        std::discrete_distribution<> dist(logits, logits + token_scores.size());
-        out_token = token_scores[dist(gen)].id;
-    }
-    else {
-        out_token = std::max_element(logits, logits + vocab_size) - logits;
-    }
-
-    return out_token;
-}
 
 ov::Tensor process_prompt(ov::InferRequest& embedding, const ov::Tensor& prompt, float scale_emb) {
     embedding.set_input_tensor(prompt);
     embedding.infer();
@@ -302,9 +244,9 @@ ov::Tensor merge_text_and_image_embeddings_llava(
 EncodedGenerationResult get_lm_encoded_results(
     ov::InferRequest& language,
     ov::InferRequest& embedding,
-    ov::Tensor inputs_embeds,
-    const VLMConfig m_vlm_config,
-    const std::shared_ptr<StreamerBase> streamer_ptr,
+    const ov::Tensor& inputs_embeds,
+    const VLMConfig& m_vlm_config,
+    const std::shared_ptr<StreamerBase>& streamer_ptr,
     Sampler& sampler,
     std::vector<SequenceGroup::Ptr> requests
 ) {
@@ -545,9 +487,8 @@ class ov::genai::VLMPipeline::VLMPipelineImpl {
             },
         }, streamer);
 
-        if ((!(generation_config.is_greedy_decoding() || generation_config.is_multinomial())) && streamer_ptr) {
-            OPENVINO_THROW("Currently streaming is possible only for greedy or multinomial decoding");
-        }
+        OPENVINO_ASSERT((generation_config.is_greedy_decoding() || generation_config.is_multinomial() || !streamer_ptr),
+                        "Currently streaming is possible only for greedy or multinomial decoding");
 
         EncodedGenerationResult encoded_result = get_lm_encoded_results(m_language, m_embedding, processed_input.first, m_vlm_config, streamer_ptr, sampler, requests);
 
@@ -591,6 +532,11 @@ class ov::genai::VLMPipeline::VLMPipelineImpl {
         ov::genai::OptionalGenerationConfig config_arg = utils::get_config_from_map(config_map);
         GenerationConfig config = (config_arg.has_value()) ? *config_arg : get_generation_config();
         config.update_generation_config(config_map);
+
+        // If eos_token_id was not provided, take value from the tokenizer
+        if (config.eos_token_id == -1)
+            config.set_eos_token_id(m_tokenizer.get_eos_token_id());
+
         return generate(
             prompt,
             rgbs,
@@ -634,12 +580,12 @@ class ov::genai::VLMPipeline::VLMPipelineImpl {
         m_generation_config = new_config;
     }
 
-    ov::Tensor get_inputs_embeds_llava(const std::string& prompt, const std::vector<ov::Tensor>& images) {
+    std::pair<ov::Tensor, ov::Tensor> get_inputs_embeds_llava(const std::string& prompt, const std::vector<ov::Tensor>& images) {
        std::string image_token = "<image>";  // TODO Consider getting from vlm_config or json
        std::string formatted_prompt = "USER: " + (images.empty() ? prompt : image_token + "\n" + prompt) + " ASSISTANT:";
        ov::Tensor input_ids = m_tokenizer.encode(formatted_prompt).input_ids;
        if (images.empty()) {
-           return process_prompt(m_embedding, input_ids, m_vlm_config.scale_emb);
+           return std::pair{process_prompt(m_embedding, input_ids, m_vlm_config.scale_emb), input_ids};
        } else {
            OPENVINO_ASSERT(1 == images.size(), "Only a single image allowed");
            EncodedImage encoded_image = m_vision_encoder.encode(images.at(0));
@@ -653,7 +599,7 @@ class ov::genai::VLMPipeline::VLMPipelineImpl {
        }
    }
 
-    ov::Tensor get_inputs_embeds_minicpm(const std::string& prompt, const std::vector<ov::Tensor>& images) {
+    std::pair<ov::Tensor, ov::Tensor> get_inputs_embeds_minicpm(const std::string& prompt, const std::vector<ov::Tensor>& images) {
        std::string images_prompt;
        std::vector<EncodedImage> embeds;
        for (const ov::Tensor& rgb : images) {
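A note on the signature change above: get_inputs_embeds_llava() and get_inputs_embeds_minicpm() now return both the merged embeddings and the tokenized prompt, because the sampler's SequenceGroup bookkeeping needs the raw input_ids alongside the embeddings. The merged length follows merged_seq_length = text_embeds_seq_length + (image_embeds_seq_length - 1) from merge_text_and_image_embeddings_llava: the single <image> placeholder token is replaced by the whole run of visual embeddings, so, with illustrative numbers, a 10-token prompt containing one <image> token plus 576 visual embeddings yields 10 + 576 - 1 = 585 merged positions.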
diff --git a/src/cpp/src/vlm_pipeline.cpp b/src/cpp/src/vlm_pipeline.cpp
deleted file mode 100644
index 332acb3928..0000000000
--- a/src/cpp/src/vlm_pipeline.cpp
+++ /dev/null
@@ -1,747 +0,0 @@
-// Copyright (C) 2023-2024 Intel Corporation
-// SPDX-License-Identifier: Apache-2.0
-
-#include "openvino/genai/vlm_pipeline.hpp"
-#include "openvino/genai/tokenizer.hpp"
-#include "sampler.hpp"
-#include "clip.hpp"
-#include <openvino/openvino.hpp>
-#include "../src/text_callback_streamer.hpp"
-#include "utils.hpp"
-#include <optional>
-#include <random>
-
-using namespace ov::genai;
-
-namespace {
-template <class... Ts> struct overloaded : Ts... {using Ts::operator()...;};
-template <class... Ts> overloaded(Ts...) -> overloaded<Ts...>;
-
-constexpr size_t BATCH_SIZE = 1;
-
-ov::Tensor process_prompt(ov::InferRequest& embedding, const ov::Tensor& prompt, float scale_emb) {
-    embedding.set_input_tensor(prompt);
-    embedding.infer();
-
-    const ov::Tensor& embed_output_tensor = embedding.get_output_tensor();
-
-    ov::Shape out_shape = embed_output_tensor.get_shape();
-    float* data = embed_output_tensor.data<float>();
-
-    //embedding * scale_emb
-    for (size_t idx = 0; idx < embed_output_tensor.get_size(); idx++) {
-        data[idx] = data[idx] * scale_emb;
-    }
-    return embed_output_tensor;
-}
-
-ov::Tensor concatenate_last_dim(const ov::Tensor& first, const ov::Tensor& second) {
-    size_t res_d_0 = first.get_shape().at(0);
-    size_t res_d_1 = first.get_shape().at(1);
-    OPENVINO_ASSERT(second.get_shape().at(0) == res_d_0);
-    OPENVINO_ASSERT(second.get_shape().at(1) == res_d_1);
-    size_t res_d_2 = first.get_shape().at(2) + second.get_shape().at(2);
-    ov::Tensor res{first.get_element_type(), {res_d_0, res_d_1, res_d_2}};
-    float* first_data = first.data<float>();
-    float* second_data = second.data<float>();
-    float* res_data = res.data<float>();
-    for (size_t i = 0; i < res_d_0; ++i) {
-        for (size_t j = 0; j < res_d_1; ++j) {
-            size_t k = 0;
-            for (; k < first.get_shape().at(2); ++k) {
-                res_data[i * res_d_1 * res_d_2 + j * res_d_2 + k]
-                    = first_data[i * res_d_1 * first.get_shape().at(2) + j * first.get_shape().at(2) + k];
-            }
-            for (size_t l = 0; l < second.get_shape().at(2); ++l, ++k) {
-                res_data[i * res_d_1 * res_d_2 + j * res_d_2 + k]
-                    = second_data[i * res_d_1 * second.get_shape().at(2) + j * second.get_shape().at(2) + l];
-            }
-        }
-    }
-    return res;
-}
-
-ov::Tensor concatenate_mid_dim(const ov::Tensor& first, const ov::Tensor& second) {
-    size_t res_d_0 = first.get_shape().at(0);
-    size_t res_d_2 = first.get_shape().at(2);
-    OPENVINO_ASSERT(second.get_shape().at(0) == res_d_0);
-    OPENVINO_ASSERT(second.get_shape().at(2) == res_d_2);
-    size_t res_d_1 = first.get_shape().at(1) + second.get_shape().at(1);
-    ov::Tensor res{first.get_element_type(), {res_d_0, res_d_1, res_d_2}};
-    float* first_data = first.data<float>();
-    float* second_data = second.data<float>();
-    float* res_data = res.data<float>();
-    for (size_t i = 0; i < res_d_0; ++i) {
-        size_t j = 0;
-        for (; j < first.get_shape().at(1); ++j) {
-            std::copy_n(
-                first_data + i * first.get_shape().at(1) * res_d_2 + j * res_d_2,
-                res_d_2,
-                res_data + i * res_d_1 * res_d_2 + j * res_d_2
-            );
-        }
-        for (size_t k = 0; k < second.get_shape().at(1); ++k, ++j) {
-            std::copy_n(
-                second_data + i * second.get_shape().at(1) * res_d_2 + k * res_d_2,
-                res_d_2,
-                res_data + i * res_d_1 * res_d_2 + j * res_d_2
-            );
-        }
-    }
-    return res;
-}
-
-/// embed_dim: output dimension for each position
-/// pos: a list of positions to be encoded: size (H, W)
-/// out: (H, W, D)
-ov::Tensor get_1d_sincos_pos_embed_from_grid_new(size_t embed_dim, const ov::Tensor& pos) {
-    OPENVINO_ASSERT(embed_dim % 2 == 0);
-    ov::Shape pos_shape = pos.get_shape();
-    size_t H = pos_shape[0];
-    size_t W = pos_shape[1];
-
-    std::vector<float> omega(embed_dim / 2);
-    for (size_t i = 0; i < omega.size(); ++i) {
-        omega[i] = 1.0f / std::pow(10000.0f, float(i) / (embed_dim / 2));
-    }
-
-    std::vector<size_t> out_shape = {H, W, embed_dim};
-    ov::Tensor emb(ov::element::f32, out_shape);
-
-    float* pos_data = pos.data<float>();
-    float* emb_data = emb.data<float>();
-
-    size_t counter = 0;
-    for (size_t h = 0; h < H; ++h) {
-        for (size_t w = 0; w < W; ++w) {
-            for (size_t d = 0; d < embed_dim / 2; ++d) {
-                // Correctly access the 2D position grid
-                float value = omega[d] * pos_data[h * W + w];
-                // There should be sinf() and cosf(), but they don't exist on default Ubuntu20 gcc.
-                emb_data[h * W * embed_dim + w * embed_dim + d] = std::sin(double(value));
-                emb_data[h * W * embed_dim + w * embed_dim + d + (embed_dim / 2)] = std::cos(double(value));
-            }
-        }
-    }
-    return emb;
-}
-
-ov::Tensor get_2d_sincos_pos_embed_from_grid(size_t embed_dim, const ov::Tensor& grid) {
-    OPENVINO_ASSERT(embed_dim % 2 == 0);
-    ov::Shape grid_shape = grid.get_shape();
-    float* grid_data = grid.data<float>();
-    ov::Shape plane_shape{grid_shape.at(1), grid_shape.at(2)};
-    ov::Tensor emb_h = get_1d_sincos_pos_embed_from_grid_new(embed_dim / 2, ov::Tensor{
-        ov::element::f32,
-        plane_shape,
-        grid_data
-    });  // (H, W, D/2)
-    ov::Tensor emb_w = get_1d_sincos_pos_embed_from_grid_new(embed_dim / 2, ov::Tensor{
-        ov::element::f32,
-        plane_shape,
-        grid_data + plane_shape.at(0) * plane_shape.at(1)
-    });  // (H, W, D/2)
-    return concatenate_last_dim(emb_h, emb_w);
-}
-
-/// image_size: image_size or (image_height, image_width)
-/// return:
-/// pos_embed: [image_height, image_width, embed_dim]
-ov::Tensor get_2d_sincos_pos_embed(size_t embed_dim, const ImageSize& image_size) {
-    size_t grid_h_size = image_size.height, grid_w_size = image_size.width;
-    ov::Tensor grid(ov::element::f32, {2, grid_h_size, grid_w_size});
-    float* data = grid.data<float>();
-    for (size_t y = 0; y < grid_h_size; ++y) {
-        std::iota(data, data + grid_w_size, 0.0f);
-        data += grid_w_size;
-    }
-    for (float y = 0.0f; y < grid_h_size; ++y) {
-        std::fill(data, data + grid_w_size, y);
-        data += grid_w_size;
-    }
-    return get_2d_sincos_pos_embed_from_grid(embed_dim, grid);
-}
-
-void adjust_pos_cache(
-    const std::vector<ImageSize>& target_sizes,
-    size_t hidden_size,
-    ov::Tensor& pos_embed_cache
-) {
-    size_t max_h = std::max_element(target_sizes.begin(), target_sizes.end(), [](const ImageSize& left, const ImageSize& right) {
-        return left.height < right.height;
-    })->height;
-    size_t max_w = std::max_element(target_sizes.begin(), target_sizes.end(), [](const ImageSize& left, const ImageSize& right) {
-        return left.width < right.width;
-    })->width;
-    size_t allocated_height, allocated_width;
-    if (pos_embed_cache) {
-        const ov::Shape& allocated_shape = pos_embed_cache.get_shape();
-        allocated_height = allocated_shape.at(0);
-        allocated_width = allocated_shape.at(1);
-    } else {
-        allocated_height = allocated_width = 70;
-    }
-    if (max_h > allocated_height || max_w > allocated_width) {
-        allocated_height = std::max(max_h, allocated_height);
-        allocated_width = std::max(max_w, allocated_width);
-        pos_embed_cache = get_2d_sincos_pos_embed(
-            hidden_size, {allocated_height, allocated_width}
-        );
-    }
-}
-
-ov::Tensor resample(VLMPipeline& pipe, const ov::Tensor& encoded_image, const std::vector<ImageSize>& target_sizes) {
-    size_t bs = encoded_image.get_shape().at(0);
-    std::vector<size_t> patch_len{target_sizes.size()};
-    std::transform(target_sizes.begin(), target_sizes.end(), patch_len.begin(), [](const ImageSize& height_width) {
-        return height_width.height * height_width.width;
-    });
-    adjust_pos_cache(
-        target_sizes,
-        pipe.m_vlm_config.hidden_size,
-        pipe.m_pos_embed_cache
-    );
-    size_t max_patch_len = *std::max_element(patch_len.begin(), patch_len.end());
-    ov::Tensor key_padding_mask(ov::element::boolean, {bs, max_patch_len});
-    bool* mask_data = key_padding_mask.data<bool>();
-    size_t embed_len = pipe.m_pos_embed_cache.get_shape().at(2);
-    ov::Tensor pos_embed(ov::element::f32, {max_patch_len, bs, embed_len});  // BLD => L * B * D
-    float* pos_embed_data = pos_embed.data<float>();
-    float* cache_data = pipe.m_pos_embed_cache.data<float>();
-    size_t _d0 = pipe.m_pos_embed_cache.get_shape().at(0);
-    size_t _d1 = pipe.m_pos_embed_cache.get_shape().at(1);
-    for (size_t i = 0; i < bs; ++i) {
-        size_t target_h = target_sizes.at(i).height;
-        size_t target_w = target_sizes.at(i).width;
-        for (size_t h_idx = 0; h_idx < target_h; ++h_idx) {
-            for (size_t w_idx = 0; w_idx < target_w; ++w_idx) {
-                std::copy_n(
-                    cache_data + (h_idx * _d1 + w_idx) * embed_len,
-                    embed_len,
-                    pos_embed_data + (h_idx * target_w + w_idx) * bs * embed_len + i * embed_len
-                );
-            }
-        }
-        for (size_t flat = target_h * target_w; flat < max_patch_len; ++flat) {
-            std::fill_n(pos_embed_data + flat * bs * embed_len + i * embed_len, embed_len, 0.0f);
-        }
-        std::fill_n(mask_data + i * max_patch_len, patch_len[i], false);
-        std::fill_n(mask_data + i * max_patch_len + patch_len[i], max_patch_len - patch_len[i], true);
-    }
-    pipe.m_resampler.set_tensor("x", encoded_image);  // [N, H*W, old_hidden_size]
-    pipe.m_resampler.set_tensor("pos_embed", pos_embed);  // [H*W, N, new_hidden_size]
-    pipe.m_resampler.set_tensor("key_padding_mask", key_padding_mask);  // [N, H*W]
-    pipe.m_resampler.infer();
-    return pipe.m_resampler.get_output_tensor();  // [N, query_num, new_hidden_size]
-}
-
-ov::Tensor merge_text_and_image_embeddings_llava(
-    const ov::Tensor& input_ids,
-    const ov::Tensor& text_embeds,
-    const ov::Tensor& image_embeds,
-    int64_t image_token_index
-) {
-    auto text_embeds_shape = text_embeds.get_shape();
-    auto image_embeds_shape = image_embeds.get_shape();
-
-    OPENVINO_ASSERT(
-        text_embeds_shape[2] == image_embeds_shape[2],
-        "Incompatible shapes between text_embeds and image_embeds"
-    );
-
-    size_t text_embeds_seq_length = text_embeds_shape[1];
-    size_t hidden_size = text_embeds_shape[2];
-    size_t image_embeds_seq_length = image_embeds_shape[1];
-
-    size_t merged_seq_length = text_embeds_seq_length + (image_embeds_seq_length - 1);
-
-    ov::Tensor merged_embeds(text_embeds.get_element_type(), {BATCH_SIZE, merged_seq_length, hidden_size});
-
-    const int64_t* input_ids_data = input_ids.data<int64_t>();
-    const float* text_embeds_data = text_embeds.data<float>();
-    const float* image_embeds_data = image_embeds.data<float>();
-    float* merged_data = merged_embeds.data<float>();
-
-
-    size_t merged_idx = 0;
-    for (size_t s = 0; s < text_embeds_seq_length; ++s) {
-        if (input_ids_data[s] == image_token_index) {
-            for (size_t i = 0; i < image_embeds_seq_length; ++i) {
-                std::copy_n(image_embeds_data + i * hidden_size,
-                            hidden_size,
-                            merged_data + merged_idx * hidden_size);
-                merged_idx++;
-            }
-        } else {
-            std::copy_n(text_embeds_data + s * hidden_size,
-                        hidden_size,
-                        merged_data + merged_idx * hidden_size);
-            merged_idx++;
-        }
-    }
-
-    return merged_embeds;
-}
-
-EncodedGenerationResult get_lm_encoded_results(
-    ov::InferRequest& language,
-    ov::InferRequest& embedding,
-    ov::Tensor inputs_embeds,
-    const VLMConfig m_vlm_config,
-    const std::shared_ptr<StreamerBase> streamer_ptr,
-    Sampler& sampler,
-    std::vector<SequenceGroup::Ptr> requests
-) {
-    SequenceGroup::Ptr request = requests.back();
-    GenerationHandle generation = std::make_shared<GenerationHandleImpl>(request->get_generation_stream(), request->get_sampling_parameters());
-
-    language.set_tensor("inputs_embeds", inputs_embeds);
-
-    size_t history_len = language.get_tensor("attention_mask").get_shape().at(1);
-    language.get_tensor("attention_mask").set_shape({1, history_len + inputs_embeds.get_shape()[1]});
-    std::fill_n(language.get_tensor("attention_mask").data<int64_t>(), language.get_tensor("attention_mask").get_size(), 1);
-
-    language.get_tensor("position_ids").set_shape({1, inputs_embeds.get_shape().at(1)});
-    std::iota(language.get_tensor("position_ids").data<int64_t>(), language.get_tensor("position_ids").data<int64_t>() + language.get_tensor("position_ids").get_size(), history_len);
-
-    language.get_tensor("beam_idx").set_shape({ BATCH_SIZE });
-    language.get_tensor("beam_idx").data<int32_t>()[0] = 0;
-
-    language.infer();
-
-    int64_t sequence_len = language.get_tensor("logits").get_shape().at(1);
-    request->schedule_tokens(sequence_len);
-
-    SamplerOutput sampler_output = sampler.sample(requests, language.get_tensor("logits"));
-
-    // logits include image and prompt embedings
-    if (m_vlm_config.model_type == VLMModelType::LLAVA) {
-        request->update_processed_tokens_num(request->get_prompt_len());
-    }
-
-    language.get_tensor("inputs_embeds").set_shape({BATCH_SIZE, 1, m_vlm_config.hidden_size});
-    language.get_tensor("position_ids").set_shape({ BATCH_SIZE, 1 });
-
-    while (!request->has_finished()) {
-        request->schedule_tokens(1);
-        size_t num_sequences = request->num_running_seqs();
-        size_t total_num_tokens = request->get_num_scheduled_tokens() * num_sequences;
-
-        ov::Tensor
-            input_ids(ov::element::i64, {total_num_tokens, 1}),
-            position_ids(ov::element::i64, {total_num_tokens, 1}),
-            beam_idx(ov::element::i32, { total_num_tokens });
-
-        int64_t
-            * input_ids_data = input_ids.data<int64_t>(),
-            * position_ids_data = position_ids.data<int64_t>();
-
-        size_t num_scheduled_tokens = request->get_num_scheduled_tokens();
-        size_t group_position_id = request->get_num_processed_tokens();
-        for (Sequence::Ptr& sequence : request->get_running_sequences()) {
-            for (size_t token_id = 0, position_id = group_position_id; token_id < num_scheduled_tokens; ++token_id, ++position_id) {
-                // compute token for current sequence
-                input_ids_data[token_id] = position_id < request->get_prompt_len() ?
-                    request->get_prompt_ids()[position_id] :
-                    sequence->get_generated_ids()[position_id - request->get_prompt_len()];
-
-                position_ids_data[token_id] = position_id;
-            }
-            // apply strides to shift to a next sequence
-            input_ids_data += num_scheduled_tokens;
-            position_ids_data += num_scheduled_tokens;
-        }
-
-        embedding.set_input_tensor(input_ids);
-
-        embedding.infer();
-        const ov::Tensor& embed_prompt_tensor = embedding.get_output_tensor();
-        float* embed_data = embed_prompt_tensor.data<float>();
-        for (auto idx = 0; idx < embed_prompt_tensor.get_size(); idx++) {
-            embed_data[idx] = embed_data[idx] * m_vlm_config.scale_emb;
-        }
-
-        language.set_tensor("inputs_embeds", embed_prompt_tensor);
-
-        language.get_tensor("attention_mask").set_shape({ total_num_tokens, language.get_tensor("attention_mask").get_shape()[1] + 1 });
-        std::fill_n(language.get_tensor("attention_mask").data<int64_t>(), language.get_tensor("attention_mask").get_size(), 1);
-
-        language.set_tensor("position_ids", position_ids);
-
-        std::vector<int32_t> beam_idxs = sampler.get_beam_idxs(request->get_request_id());
-        int32_t *beam_idx_data = beam_idx.data<int32_t>();
-        if (total_num_tokens > beam_idxs.size()) {
-            std::fill_n(beam_idx_data, total_num_tokens, 0);
-        } else {
-            copy(beam_idxs.begin(), beam_idxs.end(), beam_idx_data);
-        }
-        language.set_tensor("beam_idx", beam_idx);
-
-        language.infer();
-
-        if (streamer_ptr) {
-            // first sequence
-            int64_t out_token = request.get()->operator[](0)->get_generated_ids().back();
-            if (streamer_ptr->put(out_token)) {
-                break;
-            }
-        }
-
-        sampler_output = sampler.sample(requests, language.get_tensor("logits"));
-    }
-
-    if (streamer_ptr) {
-        streamer_ptr->end();
-    }
-
-    EncodedGenerationResult result;
-    result.m_request_id = 1;
-    std::vector<GenerationOutput> generation_outputs = generation->read_all();
-    std::sort(generation_outputs.begin(), generation_outputs.end(), [=] (GenerationOutput& r1, GenerationOutput& r2) {
-        return r1.score > r2.score;
-    });
-
-    auto num_outputs = std::min(request->get_sampling_parameters().num_return_sequences, generation_outputs.size());
-    for (size_t generation_output_idx = 0; generation_output_idx < num_outputs; ++generation_output_idx) {
-        const auto& generation_output = generation_outputs[generation_output_idx];
-        result.m_generation_ids.push_back(std::move(generation_output.generated_ids));
-        result.m_scores.push_back(generation_output.score);
-    }
-    result.m_status = generation->get_status();
-
-    return result;
-}
-}  // anonymous
-
-
-class ov::genai::VLMPipeline::VLMPipelineImpl {
-};
-
-VLMPipeline::VLMPipeline(
-    const std::filesystem::path& model_dir,
-    const std::string& device,
-    const ov::AnyMap device_config
-) :
-    m_vlm_config{
-        utils::from_config_json_if_exists<VLMConfig>(
-            model_dir, "config.json"
-        )
-    },
-    m_tokenizer{Tokenizer(model_dir.string(), device_config)},
-    m_vision_encoder(model_dir, m_vlm_config.model_type, device, device_config, ov::Core{}),
-    m_is_chat_conversation{false} {
-    if (m_vlm_config.model_type == VLMModelType::MINICPM) {
-        m_resampler = ov::Core{}.compile_model(
-            model_dir / "resampler.xml", device, device_config
-        ).create_infer_request();
-
-        m_embedding = ov::Core{}.compile_model(
-            model_dir / "embed_tokens.xml", device, device_config
-        ).create_infer_request();
-
-        m_language = ov::Core{}.compile_model(
-            model_dir / "language_model.xml", device, device_config
-        ).create_infer_request();
-
-        m_pos_embed_cache = get_2d_sincos_pos_embed(m_vlm_config.hidden_size, {70, 70});
-    } else if (m_vlm_config.model_type == VLMModelType::LLAVA) {
-        m_language = ov::Core{}.compile_model(
-            model_dir / "openvino_language_model.xml", device, device_config
-        ).create_infer_request();
-
-        // Reusing the same m_embedding for llava text_embeddings model
-        m_embedding = ov::Core{}.compile_model(
-            model_dir / "openvino_text_embeddings_model.xml", device, device_config
-        ).create_infer_request();
-    }
-
-    m_language.get_tensor("attention_mask").set_shape({1, 0});
-}
-
-ov::genai::VLMPipeline::~VLMPipeline() = default;
-
-DecodedResults VLMPipeline::generate(
-    const std::string& prompt,
-    const std::vector<ov::Tensor>& rgbs,
-    const GenerationConfig& generation_config,
-    const StreamerVariant& streamer
-) {
-
-    // inputs_embeds, tokenized_input;
-    std::pair<ov::Tensor, ov::Tensor> processed_input;
-    if (m_vlm_config.model_type == VLMModelType::MINICPM) {
-        processed_input = get_inputs_embeds_minicpm(prompt, rgbs);
-    } else if (m_vlm_config.model_type == VLMModelType::LLAVA) {
-        processed_input = get_inputs_embeds_llava(prompt, rgbs);
-    }
-
-    Sampler sampler = Sampler(m_tokenizer);
-    std::vector<SequenceGroup::Ptr> requests;
-    // request_id, input_ids, generation_config, block_size, enable_prefix_caching
-    // now we have one prompt as input, so we need one request
-    SequenceGroup::Ptr sequence_group = std::make_shared<SequenceGroup>(0, processed_input.second, generation_config, 1, false);
-    size_t inputs_embeds_size = processed_input.first.get_shape()[1];
-    size_t tokenized_prompt_size = processed_input.second.get_size();
-    size_t num_processed_tokens = inputs_embeds_size <= tokenized_prompt_size ? tokenized_prompt_size - inputs_embeds_size : 0;
-    sequence_group->update_processed_tokens_num(num_processed_tokens);
-    sequence_group->set_sequence_group_ptr(sequence_group);
-    requests.push_back(sequence_group);
-
-    std::shared_ptr<StreamerBase> streamer_ptr = std::visit(overloaded{
-        [&m_tokenizer = m_tokenizer](
-            const std::function<bool(std::string)>& callback
-        ) -> std::shared_ptr<StreamerBase> {
-            return std::make_shared<TextCallbackStreamer>(m_tokenizer, callback);
-        },
-        [](const std::shared_ptr<StreamerBase>& ptr) {
-            return ptr;
-        },
-        [](std::monostate) {
-            return std::shared_ptr<StreamerBase>{nullptr};
-        },
-    }, streamer);
-
-    if ((!(generation_config.is_greedy_decoding() || generation_config.is_multinomial())) && streamer_ptr) {
-        OPENVINO_THROW("Currently streaming is possible only for greedy or multinomial decoding");
-    }
-
-    EncodedGenerationResult encoded_result = get_lm_encoded_results(m_language, m_embedding, processed_input.first, m_vlm_config, streamer_ptr, sampler, requests);
-
-    DecodedResults decoded;
-    for (size_t idx = 0; idx < encoded_result.m_generation_ids.size(); ++idx) {
-        decoded.texts.push_back(m_tokenizer.decode(encoded_result.m_generation_ids.at(idx)));
-        decoded.scores.push_back(encoded_result.m_scores.at(idx));
-    }
-
-    std::string decoded_results = decoded.texts.at(0);
-    if (m_is_chat_conversation) {
-        // Tail of chat template is missing in KV cache.
-        // Find the tail to concatenate it with the next input prompt.
-        m_templated_chat_history.append(decoded_results);
-        m_history.push_back({{"role", "assistant"}, {"content", decoded_results}});
-    } else {
-        for (auto& variable : m_language.query_state()) {
-            variable.reset();
-        }
-        m_language.get_tensor("attention_mask").set_shape({1, 0});
-    }
-    return decoded;
-}
-
-DecodedResults VLMPipeline::generate(
-    const std::string& prompt,
-    const ov::AnyMap& config_map
-) {
-    auto image = config_map.find(ov::genai::image.name());
-    auto images = config_map.find(ov::genai::images.name());
-    OPENVINO_ASSERT(
-        config_map.end() == image || config_map.end() == images,
-        "Only one property can be set: image of images."
-    );
-    std::vector<ov::Tensor> rgbs;
-    if (config_map.end() != image) {
-        rgbs = {image->second.as<ov::Tensor>()};
-    } if (config_map.end() != images) {
-        rgbs = images->second.as<std::vector<ov::Tensor>>();
-    }
-    ov::genai::OptionalGenerationConfig config_arg = utils::get_config_from_map(config_map);
-    GenerationConfig config = (config_arg.has_value()) ? *config_arg : get_generation_config();
-    config.update_generation_config(config_map);
-
-    // If eos_token_id was not provided, take value from the tokenizer
-    if (config.eos_token_id == -1)
-        config.set_eos_token_id(m_tokenizer.get_eos_token_id());
-
-    return generate(
-        prompt,
-        rgbs,
-        config,
-        utils::get_streamer_from_map(config_map)
-    );
-}
-
-void VLMPipeline::start_chat(const std::string& system_message) {
-    m_is_chat_conversation = true;
-    bool have_state = 0 != m_language.get_tensor("attention_mask").get_size();
-    if (have_state) {
-        // Resetting state may be slow.
-        for (ov::VariableState& variable : m_language.query_state()) {
-            variable.reset();
-        }
-        // Since if is already introduced, move all resetting here.
-        m_language.get_tensor("attention_mask").set_shape({1, 0});
-        m_history.clear();
-        m_templated_chat_history.clear();
-    }
-    if (system_message.empty()) {
-        return;
-    }
-    m_history = {{{"role", "system"}, {"content", system_message}}};
-    constexpr bool add_generation_prompt = false;
-    m_templated_chat_history = m_tokenizer.apply_chat_template(m_history, add_generation_prompt);
-}
-
-void VLMPipeline::set_chat_template(const std::string& new_template) {
-    m_tokenizer.set_chat_template(new_template);
-}
-
-GenerationConfig VLMPipeline::get_generation_config() const {
-    return m_generation_config;
-}
-
-void VLMPipeline::set_generation_config(const GenerationConfig& new_config) {
-    m_generation_config = new_config;
-}
-
-std::pair<ov::Tensor, ov::Tensor> VLMPipeline::get_inputs_embeds_llava(const std::string& prompt, const std::vector<ov::Tensor>& images) {
-    std::string image_token = "<image>";  // TODO Consider getting from vlm_config or json
-    std::string formatted_prompt = "USER: " + (images.empty() ? prompt : image_token + "\n" + prompt) + " ASSISTANT:";
-    ov::Tensor input_ids = m_tokenizer.encode(formatted_prompt).input_ids;
-    if (images.empty()) {
-        return std::pair{process_prompt(m_embedding, input_ids, m_vlm_config.scale_emb), input_ids};
-    } else {
-        OPENVINO_ASSERT(1 == images.size(), "Only a single image allowed");
-        EncodedImage encoded_image = m_vision_encoder.encode(images.at(0));
-        ov::Tensor image_embeds = encoded_image.resized_source;
-
-        ov::Tensor text_embeds = process_prompt(m_embedding, input_ids, m_vlm_config.scale_emb);
-
-        int64_t image_token_index = 32000;  // TODO Consider getting from m_vlm_config.image_token_index or config.json
-
-        return std::pair{merge_text_and_image_embeddings_llava(input_ids, text_embeds, image_embeds, image_token_index), input_ids};
-    }
-}
-
-std::pair<ov::Tensor, ov::Tensor> VLMPipeline::get_inputs_embeds_minicpm(const std::string& prompt, const std::vector<ov::Tensor>& images) {
-    std::string images_prompt;
-    std::vector<EncodedImage> embeds;
-    for (const ov::Tensor& rgb : images) {
-        ov::Tensor reshaped = rgb;
-        ov::Shape rgb_shape = rgb.get_shape();
-        switch (rgb_shape.size()) {
-            case 3:
-                reshaped.set_shape({1, rgb_shape.at(0), rgb_shape.at(1), rgb_shape.at(2)});
-                break;
-            case 4: break;
-            default: OPENVINO_THROW("Input image must have [NHWC] or [HWC] layout");
-        }
-        ov::Shape reshaped_shape = reshaped.get_shape();
-        for (size_t batch_idx = 0; batch_idx < reshaped_shape.at(0); ++batch_idx) {
-            ov::Tensor single_image{
-                ov::element::u8,
-                {1, reshaped_shape.at(1), reshaped_shape.at(2), reshaped_shape.at(3)},
-                reshaped.data<uint8_t>() + batch_idx * reshaped_shape.at(1) * reshaped_shape.at(2) * reshaped_shape.at(3)
-            };
-            EncodedImage encoded_image = m_vision_encoder.encode(single_image);
-            if (m_vlm_config.use_image_id) {
-                images_prompt += m_vlm_config.im_id_start + std::to_string(image_id) + m_vlm_config.im_id_end;
-                ++image_id;
-            }
-            std::string unk64;
-            for (size_t idx = 0; idx < m_vlm_config.query_num; ++idx) {
-                unk64 += m_vlm_config.unk;
-            }
-            images_prompt += m_vlm_config.im_start + unk64 + m_vlm_config.im_end;
-            if (encoded_image.slices) {
-                ov::Shape slices_shape = encoded_image.slices.get_shape();
-                for (size_t row_idx = 0; row_idx < slices_shape.at(0); ++row_idx) {
-                    for (size_t col_idx = 0; col_idx < slices_shape.at(1); ++col_idx) {
-                        images_prompt += m_vlm_config.slice_start + unk64 + m_vlm_config.slice_end;
-                    }
-                    images_prompt += '\n';
-                }
-            }
-            if ('\n' != *(images_prompt.end() - 1)) {
-                // Image wasn't sliced, add \n to the end of image anyway.
-                // Strangely, \n isn't placed between <image><slice>.
-                images_prompt += '\n';
-            }
-            embeds.push_back(std::move(encoded_image));
-        }
-    }
-    images_prompt += prompt;
-    ov::Tensor encoded_input;
-    ov::Tensor new_chat_tokens;
-    if (m_is_chat_conversation) {
-        // KV cache in model already contains prompts and answers from previous iterations.
-        // So only new prompt wrapped into chat template to be sent into model. Tokenizer always returns
-        // token_ids = {<bos token>, ...}. So if tokenizer applies only to the new prompt,
-        // <bos token> will be inserted on every iteration.
-        // So actual pipeline calculates input_ids for whole chat history + for whole chat history without the new prompt
-        // and takes only the difference between them.
-        // The chat history cannot be saved as already encoded tokens because generate call doesn't return <eos> token, but
-        // KV cache contains it. So we have to add it manually or get it by tokenization all chat history.
-        m_history.push_back({{"role", "user"}, {"content", images_prompt}});
-        constexpr bool add_generation_prompt = true;
-        std::string new_templated_chat_history = m_tokenizer.apply_chat_template(m_history, add_generation_prompt);
-        new_chat_tokens = m_tokenizer.encode(new_templated_chat_history).input_ids;
-        if (0 == m_language.get_tensor("attention_mask").get_shape().at(1)) {
-            encoded_input = new_chat_tokens;
-        } else {
-            TokenizedInputs prev_chat_tokens = m_tokenizer.encode(
-                m_templated_chat_history
-            );
-            encoded_input = utils::subtract_chat_tokenized_inputs(
-                {new_chat_tokens}, prev_chat_tokens
-            ).input_ids;
-        }
-        m_templated_chat_history = std::move(new_templated_chat_history);
-    } else {
-        encoded_input = m_tokenizer.encode(images_prompt).input_ids;
-    }
-    m_embedding.set_input_tensor(encoded_input);
-    m_embedding.infer();
-    ov::Tensor inputs_embeds = m_embedding.get_output_tensor();
-    OPENVINO_ASSERT(
-        m_vlm_config.hidden_size == inputs_embeds.get_shape().at(2),
-        "Unexpected embedding size"
-    );
-    ov::Tensor special_tokens = m_tokenizer.encode(
-        m_vlm_config.im_start
-        + m_vlm_config.im_end
-        + m_vlm_config.slice_start
-        + m_vlm_config.slice_end
-    ).input_ids;
-    OPENVINO_ASSERT(
-        4 == special_tokens.get_shape().at(1),
-        "Every special token must be represented with a single int."
-    );
-    int64_t im_start_id = special_tokens.data<int64_t>()[0];
-    int64_t im_end_id = special_tokens.data<int64_t>()[1];
-    int64_t slice_start_id = special_tokens.data<int64_t>()[2];
-    int64_t slice_end_id = special_tokens.data<int64_t>()[3];
-    int64_t im_start_pos = 0, slice_start_pos = 0;
-    int64_t* begin = encoded_input.data<int64_t>();
-    int64_t* ids = begin;
-    size_t encoded_input_size = encoded_input.get_size();
-    int64_t* end = ids + encoded_input_size;
-    float* inputs_embeds_data = inputs_embeds.data<float>();
-    for (const EncodedImage& encoded_image : embeds) {
-        const ov::Tensor& resampled_source = resample(*this, encoded_image.resized_source, {encoded_image.resized_source_size});
-        float* emb = resampled_source.data<float>();
-        ids = std::find(ids, end, im_start_id);
-        OPENVINO_ASSERT(end != ids);
-        ++ids;
-        std::copy_n(emb, resampled_source.get_size(), inputs_embeds_data + std::distance(begin, ids) * m_vlm_config.hidden_size);
-        ids += m_vlm_config.query_num;
-        if (encoded_image.slices) {
-            size_t token_idx = 0;
-            const ov::Shape& slices_shape = encoded_image.slices.get_shape();
-            for (size_t i = 0; i < slices_shape.at(0); ++i) {
-                for (size_t ja = 0; ja < slices_shape.at(1); ++ja) {
-                    size_t d2 = slices_shape.at(2);
-                    size_t d3 = slices_shape.at(3);
-                    ov::Tensor encoded_view{ov::element::f32, {1, d2, d3}, encoded_image.slices.data<float>() + (i * slices_shape.at(1) + ja) * d2 * d3};
-                    const ov::Tensor& vision_embed_tensor_i_j = resample(*this, encoded_view, {encoded_image.slices_size});
-                    ids = std::find(ids, end, slice_start_id);
-                    OPENVINO_ASSERT(end != ids);
-                    ++ids;
-                    std::copy_n(vision_embed_tensor_i_j.data<float>(), vision_embed_tensor_i_j.get_size(), inputs_embeds_data + std::distance(begin, ids) * m_vlm_config.hidden_size);
-                    ids += m_vlm_config.query_num;
-                }
-            }
-        }
-    }
-
-    return std::pair{inputs_embeds, m_is_chat_conversation ? new_chat_tokens : encoded_input};
-}
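For readers tracing the MiniCPM resampler code above, the sin/cos positional-embedding rule used by get_1d_sincos_pos_embed_from_grid_new, shown in isolation (a self-contained sketch, not repository code):

    #include <cmath>
    #include <vector>

    // For a position p and an even embedding width D:
    //   omega_i = 1 / 10000^(i / (D/2)),  emb[i] = sin(p * omega_i),  emb[i + D/2] = cos(p * omega_i)
    std::vector<float> sincos_embed_1d(float p, size_t embed_dim) {
        std::vector<float> emb(embed_dim);
        for (size_t i = 0; i < embed_dim / 2; ++i) {
            float omega = 1.0f / std::pow(10000.0f, float(i) / (embed_dim / 2));
            emb[i] = std::sin(p * omega);
            emb[i + embed_dim / 2] = std::cos(p * omega);
        }
        return emb;
    }

The 2D variant backing the resampler's pos_embed cache concatenates the height and width embeddings, each of width embed_dim / 2, exactly as concatenate_last_dim does in the file above.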