diff --git a/samples/cpp/visual_language_chat/visual_language_chat.cpp b/samples/cpp/visual_language_chat/visual_language_chat.cpp
index 95342402cb..06f8f4e696 100644
--- a/samples/cpp/visual_language_chat/visual_language_chat.cpp
+++ b/samples/cpp/visual_language_chat/visual_language_chat.cpp
@@ -3,8 +3,11 @@
 #include "load_image.hpp"
 #include
+#include <filesystem>
 #include
+
+namespace fs = std::filesystem;
 
 bool print_subword(std::string&& subword) {
     return !(std::cout << subword << std::flush);
 }
@@ -13,7 +16,40 @@ int main(int argc, char* argv[]) try {
     if (3 != argc) {
         throw std::runtime_error(std::string{"Usage "} + argv[0] + " ");
     }
-    ov::Tensor image = utils::load_image(argv[2]);
+
+    // multinomial or beam_search can be used as well
+    ov::genai::GenerationConfig generation_config = ov::genai::greedy();
+    // ov::genai::GenerationConfig generation_config = ov::genai::multinomial();
+    // ov::genai::GenerationConfig generation_config = ov::genai::beam_search();
+
+    ov::AnyMap properties;
+    properties.insert(ov::genai::generation_config(generation_config));
+
+    // A streamer can only be used with greedy and multinomial decoding.
+    // With multinomial decoding and num_return_sequences > 1, the streamer receives output from the first sequence only.
+    if (generation_config.is_greedy_decoding() || generation_config.is_multinomial()) {
+        properties.insert(ov::genai::streamer(print_subword));
+    }
+
+    std::vector<ov::Tensor> images;
+    std::string input_path = argv[2];
+    if (!input_path.empty() && fs::exists(input_path)) {
+        if (fs::is_directory(input_path)) {
+            for (const auto& dir_entry : fs::directory_iterator(input_path)) {
+                ov::Tensor image = utils::load_image(dir_entry.path());
+                images.push_back(std::move(image));
+            }
+        } else if (fs::is_regular_file(input_path)) {
+            ov::Tensor image = utils::load_image(input_path);
+            images.push_back(std::move(image));
+        }
+    }
+
+    if (images.empty())
+        throw std::runtime_error("No images found at path " + input_path);
+    else
+        properties.insert(images.size() == 1 ? ov::genai::image(images.at(0)) : ov::genai::images(images));
+
     std::string device = "CPU"; // GPU can be used as well
     ov::AnyMap enable_compile_cache;
     if ("GPU" == device) {
@@ -26,16 +62,21 @@ int main(int argc, char* argv[]) try {
     pipe.start_chat();
     std::cout << "question:\n";
+
     std::getline(std::cin, prompt);
-    pipe.generate(
-        prompt,
-        ov::genai::image(image),
-        ov::genai::streamer(print_subword)
-    );
+    auto results = pipe.generate(prompt, properties);
+    if (generation_config.is_beam_search()) {
+        std::cout << results.texts.at(0) << std::endl;
+    }
+    properties.erase(images.size() == 1 ? "image" : "images");
+
     std::cout << "\n----------\n"
         "question:\n";
     while (std::getline(std::cin, prompt)) {
-        pipe.generate(prompt, ov::genai::streamer(print_subword));
+        results = pipe.generate(prompt, properties);
+        if (generation_config.is_beam_search()) {
+            std::cout << results.texts.at(0) << std::endl;
+        }
         std::cout << "\n----------\n"
             "question:\n";
     }
diff --git a/src/cpp/src/sampler.cpp b/src/cpp/src/sampler.cpp
index 5ae604c725..2e631f6201 100644
--- a/src/cpp/src/sampler.cpp
+++ b/src/cpp/src/sampler.cpp
@@ -230,6 +230,22 @@ Sampler::GroupBeamSearcher::GroupBeamSearcher(SequenceGroup::Ptr sequence_group,
     }
 }
 
+
+std::vector<int32_t> Sampler::GroupBeamSearcher::get_beam_idxs() {
+    std::vector<int32_t> next_beams;
+
+    for (Group& group : m_groups) {
+        if (!group.done) {
+            for (Beam& beam : group.ongoing) {
+                next_beams.push_back(beam.m_global_beam_idx);
+            }
+        }
+    }
+
+    return next_beams;
+}
+
+
 void Sampler::GroupBeamSearcher::select_next_tokens(const ov::Tensor& logits, SamplerOutput& sampler_output) {
     assert(m_parameters.num_beams % m_parameters.num_beam_groups == 0 &&
         "number of beams should be divisible by number of groups");
@@ -581,6 +597,14 @@ void register_new_token(const Token& sampled_token_id,
     }
 };
 
+std::vector<int32_t> Sampler::get_beam_idxs(uint64_t request_id) {
+    auto beam_searcher = m_beam_search_info.find(request_id);
+    if (beam_searcher == m_beam_search_info.end()) {
+        return { 0 };
+    }
+    return beam_searcher->second.get_beam_idxs();
+}
+
 std::list create_n_forked_sequences(SequenceGroup::Ptr sequence_group,
                                     LogitProcessor& logit_processor,
diff --git a/src/cpp/src/sampler.hpp b/src/cpp/src/sampler.hpp
index 8188b35573..ca73cbb92d 100644
--- a/src/cpp/src/sampler.hpp
+++ b/src/cpp/src/sampler.hpp
@@ -65,6 +65,7 @@ class Sampler {
     SamplerOutput sample(std::vector<SequenceGroup::Ptr> & sequence_groups, ov::Tensor logits, bool is_validation_mode_enabled = false);
     void set_seed(size_t seed) { rng_engine.seed(seed); }
     void clear_beam_search_info(uint64_t request_id);
+    std::vector<int32_t> get_beam_idxs(uint64_t request_id);
 };
 
 class Sampler::GroupBeamSearcher {
@@ -109,5 +110,6 @@ class Sampler::GroupBeamSearcher {
 
     void select_next_tokens(const ov::Tensor& logits, SamplerOutput& sampler_output);
     void finalize(SamplerOutput& sampler_output);
+    std::vector<int32_t> get_beam_idxs();
 };
 }
diff --git a/src/cpp/src/visual_language/pipeline.cpp b/src/cpp/src/visual_language/pipeline.cpp
index a75b5a5bb8..b82677f689 100644
--- a/src/cpp/src/visual_language/pipeline.cpp
+++ b/src/cpp/src/visual_language/pipeline.cpp
@@ -298,6 +298,135 @@ ov::Tensor merge_text_and_image_embeddings_llava(
     return merged_embeds;
 }
 
+
+EncodedGenerationResult get_lm_encoded_results(
+    ov::InferRequest& language,
+    ov::InferRequest& embedding,
+    ov::Tensor inputs_embeds,
+    const VLMConfig m_vlm_config,
+    const std::shared_ptr<StreamerBase> streamer_ptr,
+    Sampler& sampler,
+    std::vector<SequenceGroup::Ptr> requests
+) {
+    SequenceGroup::Ptr request = requests.back();
+    GenerationHandle generation = std::make_shared<GenerationHandleImpl>(request->get_generation_stream(), request->get_sampling_parameters());
+
+    language.set_tensor("inputs_embeds", inputs_embeds);
+
+    size_t history_len = language.get_tensor("attention_mask").get_shape().at(1);
+    language.get_tensor("attention_mask").set_shape({1, history_len + inputs_embeds.get_shape()[1]});
+    std::fill_n(language.get_tensor("attention_mask").data(), language.get_tensor("attention_mask").get_size(), 1);
+
+    language.get_tensor("position_ids").set_shape({1, inputs_embeds.get_shape().at(1)});
+    std::iota(language.get_tensor("position_ids").data<int64_t>(),
language.get_tensor("position_ids").data() + language.get_tensor("position_ids").get_size(), history_len); + + language.get_tensor("beam_idx").set_shape({ BATCH_SIZE }); + language.get_tensor("beam_idx").data()[0] = 0; + + language.infer(); + + int64_t sequence_len = language.get_tensor("logits").get_shape().at(1); + request->schedule_tokens(sequence_len); + + SamplerOutput sampler_output = sampler.sample(requests, language.get_tensor("logits")); + + // logits include image and prompt embedings + if (m_vlm_config.model_type == VLMModelType::LLAVA) { + request->update_processed_tokens_num(request->get_prompt_len()); + } + + language.get_tensor("inputs_embeds").set_shape({BATCH_SIZE, 1, m_vlm_config.hidden_size}); + language.get_tensor("position_ids").set_shape({ BATCH_SIZE, 1 }); + + while (!request->has_finished()) { + request->schedule_tokens(1); + size_t num_sequences = request->num_running_seqs(); + size_t total_num_tokens = request->get_num_scheduled_tokens() * num_sequences; + + ov::Tensor + input_ids(ov::element::i64, {total_num_tokens, 1}), + position_ids(ov::element::i64, {total_num_tokens, 1}), + beam_idx(ov::element::i32, { total_num_tokens }); + + int64_t + * input_ids_data = input_ids.data(), + * position_ids_data = position_ids.data(); + + size_t num_scheduled_tokens = request->get_num_scheduled_tokens(); + size_t group_position_id = request->get_num_processed_tokens(); + for (Sequence::Ptr& sequence : request->get_running_sequences()) { + for (size_t token_id = 0, position_id = group_position_id; token_id < num_scheduled_tokens; ++token_id, ++position_id) { + // compute token for current sequence + input_ids_data[token_id] = position_id < request->get_prompt_len() ? + request->get_prompt_ids()[position_id] : + sequence->get_generated_ids()[position_id - request->get_prompt_len()]; + + position_ids_data[token_id] = position_id; + } + // apply strides to shift to a next sequence + input_ids_data += num_scheduled_tokens; + position_ids_data += num_scheduled_tokens; + } + + embedding.set_input_tensor(input_ids); + + embedding.infer(); + const ov::Tensor& embed_prompt_tensor = embedding.get_output_tensor(); + float* embed_data = embed_prompt_tensor.data(); + for (auto idx = 0; idx < embed_prompt_tensor.get_size(); idx++) { + embed_data[idx] = embed_data[idx] * m_vlm_config.scale_emb; + } + + language.set_tensor("inputs_embeds", embed_prompt_tensor); + + language.get_tensor("attention_mask").set_shape({ total_num_tokens, language.get_tensor("attention_mask").get_shape()[1] + 1 }); + std::fill_n(language.get_tensor("attention_mask").data(), language.get_tensor("attention_mask").get_size(), 1); + + language.set_tensor("position_ids", position_ids); + + std::vector beam_idxs = sampler.get_beam_idxs(request->get_request_id()); + int32_t *beam_idx_data = beam_idx.data(); + if (total_num_tokens > beam_idxs.size()) { + std::fill_n(beam_idx_data, total_num_tokens, 0); + } else { + copy(beam_idxs.begin(), beam_idxs.end(), beam_idx_data); + } + language.set_tensor("beam_idx", beam_idx); + + language.infer(); + + if (streamer_ptr) { + // first sequence + int64_t out_token = request.get()->operator[](0)->get_generated_ids().back(); + if (streamer_ptr->put(out_token)) { + break; + } + } + + sampler_output = sampler.sample(requests, language.get_tensor("logits")); + } + + if (streamer_ptr) { + streamer_ptr->end(); + } + + EncodedGenerationResult result; + result.m_request_id = 1; + std::vector generation_outputs = generation->read_all(); + std::sort(generation_outputs.begin(), 
generation_outputs.end(), [=] (GenerationOutput& r1, GenerationOutput& r2) { + return r1.score > r2.score; + }); + + auto num_outputs = std::min(request->get_sampling_parameters().num_return_sequences, generation_outputs.size()); + for (size_t generation_output_idx = 0; generation_output_idx < num_outputs; ++generation_output_idx) { + const auto& generation_output = generation_outputs[generation_output_idx]; + result.m_generation_ids.push_back(std::move(generation_output.generated_ids)); + result.m_scores.push_back(generation_output.score); + } + result.m_status = generation->get_status(); + + return result; +} } class ov::genai::VLMPipeline::VLMPipelineImpl { @@ -382,40 +511,26 @@ class ov::genai::VLMPipeline::VLMPipelineImpl { const GenerationConfig& generation_config, const StreamerVariant& streamer ) { - ov::Tensor inputs_embeds; + // inputs_embeds, tokenized_input; + std::pair processed_input; if (m_vlm_config.model_type == VLMModelType::MINICPM) { - inputs_embeds = get_inputs_embeds_minicpm(prompt, rgbs); + processed_input = get_inputs_embeds_minicpm(prompt, rgbs); } else if (m_vlm_config.model_type == VLMModelType::LLAVA) { - inputs_embeds = get_inputs_embeds_llava(prompt, rgbs); + processed_input = get_inputs_embeds_llava(prompt, rgbs); } - m_language.set_tensor("inputs_embeds", inputs_embeds); - size_t history_len = m_language.get_tensor("attention_mask").get_shape().at(1); - m_language.get_tensor("attention_mask").set_shape({1, history_len + inputs_embeds.get_shape()[1]}); - std::fill_n(m_language.get_tensor("attention_mask").data(), m_language.get_tensor("attention_mask").get_size(), 1); - - m_language.get_tensor("position_ids").set_shape({1, inputs_embeds.get_shape().at(1)}); - std::iota(m_language.get_tensor("position_ids").data(), m_language.get_tensor("position_ids").data() + m_language.get_tensor("position_ids").get_size(), history_len); - - m_language.get_tensor("beam_idx").set_shape({ BATCH_SIZE }); - m_language.get_tensor("beam_idx").data()[0] = 0; - - m_language.infer(); - - ov::Shape logits_shape = m_language.get_tensor("logits").get_shape(); - auto attention_size = m_language.get_tensor("attention_mask").get_size(); + Sampler sampler = Sampler(m_tokenizer); + std::vector requests; + // request_id, input_ids, generation_config, block_size, enable_prefix_caching + // now we have one prompt as input, so we need one request + SequenceGroup::Ptr sequence_group = std::make_shared(0, processed_input.second, generation_config, 1, false); + size_t inputs_embeds_size = processed_input.first.get_shape()[1]; + size_t tokenized_prompt_size = processed_input.second.get_size(); + size_t num_processed_tokens = inputs_embeds_size <= tokenized_prompt_size ? 
tokenized_prompt_size - inputs_embeds_size : 0; + sequence_group->update_processed_tokens_num(num_processed_tokens); + sequence_group->set_sequence_group_ptr(sequence_group); + requests.push_back(sequence_group); - int64_t sequence_len = m_language.get_tensor("logits").get_shape().at(1) - 1; - size_t vocab_size = m_language.get_tensor("logits").get_shape().back(); - float* logits = m_language.get_tensor("logits").data() + sequence_len * vocab_size; - int64_t out_token = std::max_element(logits, logits + vocab_size) - logits; - - m_language.get_tensor("inputs_embeds").set_shape({BATCH_SIZE, 1, m_vlm_config.hidden_size}); - m_language.get_tensor("position_ids").set_shape({ BATCH_SIZE, 1 }); - - m_embedding.get_input_tensor().set_shape({ 1, 1 }); - - int64_t eos_token_id = m_tokenizer.get_eos_token_id(); std::shared_ptr streamer_ptr = std::visit(overloaded{ [&m_tokenizer = m_tokenizer]( const std::function& callback @@ -429,40 +544,20 @@ class ov::genai::VLMPipeline::VLMPipelineImpl { return std::shared_ptr{nullptr}; }, }, streamer); - std::vector generated; - while (true) { //(out_token != eos_token_id) - m_embedding.get_input_tensor().data()[0] = out_token; - m_embedding.infer(); - const ov::Tensor& embed_prompt_tensor = m_embedding.get_output_tensor(); - float* embed_data = embed_prompt_tensor.data(); - for (auto idx = 0; idx < embed_prompt_tensor.get_size(); idx++) { - embed_data[idx] = embed_data[idx] * m_vlm_config.scale_emb; - } - m_language.set_tensor("inputs_embeds", embed_prompt_tensor); - m_language.get_tensor("attention_mask").set_shape({ BATCH_SIZE, m_language.get_tensor("attention_mask").get_shape()[1] + 1 }); - std::fill_n(m_language.get_tensor("attention_mask").data(), m_language.get_tensor("attention_mask").get_size(), 1); - m_language.get_tensor("position_ids").data()[0] = int64_t(m_language.get_tensor("attention_mask").get_size() - 1); - - m_language.infer(); - - generated.push_back(out_token); - if (streamer_ptr && streamer_ptr->put(out_token)) { - break; - } - logits = m_language.get_tensor("logits").data(); - - out_token = std::max_element(logits, logits + vocab_size) - logits; - if (out_token == eos_token_id) { - break; - } + if ((!(generation_config.is_greedy_decoding() || generation_config.is_multinomial())) && streamer_ptr) { + OPENVINO_THROW("Currently streaming is possible only for greedy or multinomial decoding"); } - if (streamer_ptr) { - streamer_ptr->end(); + EncodedGenerationResult encoded_result = get_lm_encoded_results(m_language, m_embedding, processed_input.first, m_vlm_config, streamer_ptr, sampler, requests); + + DecodedResults decoded; + for (size_t idx = 0; idx < encoded_result.m_generation_ids.size(); ++idx) { + decoded.texts.push_back(m_tokenizer.decode(encoded_result.m_generation_ids.at(idx))); + decoded.scores.push_back(encoded_result.m_scores.at(idx)); } - std::string decoded_results = m_tokenizer.decode(generated); + std::string decoded_results = decoded.texts.at(0); if (m_is_chat_conversation) { // Tail of chat template is missing in KV cache. // Find the tail to concatenate it with the next input prompt. 
@@ -474,7 +569,7 @@ class ov::genai::VLMPipeline::VLMPipelineImpl { } m_language.get_tensor("attention_mask").set_shape({1, 0}); } - return {{std::move(decoded_results)}}; + return decoded; } DecodedResults generate( @@ -554,7 +649,7 @@ class ov::genai::VLMPipeline::VLMPipelineImpl { int64_t image_token_index = 32000; // TODO Consider getting from m_vlm_config.image_token_index or config.json - return merge_text_and_image_embeddings_llava(input_ids, text_embeds, image_embeds, image_token_index); + return std::pair{merge_text_and_image_embeddings_llava(input_ids, text_embeds, image_embeds, image_token_index), input_ids}; } } @@ -607,6 +702,7 @@ class ov::genai::VLMPipeline::VLMPipelineImpl { } images_prompt += prompt; ov::Tensor encoded_input; + ov::Tensor new_chat_tokens; if (m_is_chat_conversation) { // KV cache in model already contains prompts and answers from previous iterations. // So only new prompt wrapped into chat template to be sent into model. Tokenizer always returns @@ -619,7 +715,7 @@ class ov::genai::VLMPipeline::VLMPipelineImpl { m_history.push_back({{"role", "user"}, {"content", images_prompt}}); constexpr bool add_generation_prompt = true; std::string new_templated_chat_history = m_tokenizer.apply_chat_template(m_history, add_generation_prompt); - ov::Tensor new_chat_tokens = m_tokenizer.encode(new_templated_chat_history).input_ids; + new_chat_tokens = m_tokenizer.encode(new_templated_chat_history).input_ids; if (0 == m_language.get_tensor("attention_mask").get_shape().at(1)) { encoded_input = new_chat_tokens; } else { @@ -688,7 +784,7 @@ class ov::genai::VLMPipeline::VLMPipelineImpl { } } - return inputs_embeds; + return std::pair{inputs_embeds, m_is_chat_conversation ? new_chat_tokens : encoded_input}; } ov::Tensor resample(VLMPipeline::VLMPipelineImpl& pipe, const ov::Tensor& encoded_image, const std::vector& target_sizes) { diff --git a/src/cpp/src/vlm_pipeline.cpp b/src/cpp/src/vlm_pipeline.cpp new file mode 100644 index 0000000000..332acb3928 --- /dev/null +++ b/src/cpp/src/vlm_pipeline.cpp @@ -0,0 +1,747 @@ +// Copyright (C) 2023-2024 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 + +#include "openvino/genai/vlm_pipeline.hpp" +#include "openvino/genai/tokenizer.hpp" +#include "sampler.hpp" +#include "clip.hpp" +#include +#include "../src/text_callback_streamer.hpp" +#include "utils.hpp" +#include +#include + +using namespace ov::genai; + +namespace { +template struct overloaded : Ts... {using Ts::operator()...;}; +template overloaded(Ts...) 
-> overloaded; + +constexpr size_t BATCH_SIZE = 1; + +ov::Tensor process_prompt(ov::InferRequest& embedding, const ov::Tensor& prompt, float scale_emb) { + embedding.set_input_tensor(prompt); + embedding.infer(); + + const ov::Tensor& embed_output_tensor = embedding.get_output_tensor(); + + ov::Shape out_shape = embed_output_tensor.get_shape(); + float* data = embed_output_tensor.data(); + + //embedding * scale_emb + for (size_t idx = 0; idx < embed_output_tensor.get_size(); idx++) { + data[idx] = data[idx] * scale_emb; + } + return embed_output_tensor; +} + +ov::Tensor concatenate_last_dim(const ov::Tensor& first, const ov::Tensor& second) { + size_t res_d_0 = first.get_shape().at(0); + size_t res_d_1 = first.get_shape().at(1); + OPENVINO_ASSERT(second.get_shape().at(0) == res_d_0); + OPENVINO_ASSERT(second.get_shape().at(1) == res_d_1); + size_t res_d_2 = first.get_shape().at(2) + second.get_shape().at(2); + ov::Tensor res{first.get_element_type(), {res_d_0, res_d_1, res_d_2}}; + float* first_data = first.data(); + float* second_data = second.data(); + float* res_data = res.data(); + for (size_t i = 0; i < res_d_0; ++i) { + for (size_t j = 0; j < res_d_1; ++j) { + size_t k = 0; + for (; k < first.get_shape().at(2); ++k) { + res_data[i * res_d_1 * res_d_2 + j * res_d_2 + k] + = first_data[i * res_d_1 * first.get_shape().at(2) + j * first.get_shape().at(2) + k]; + } + for (size_t l = 0; l < second.get_shape().at(2); ++l, ++k) { + res_data[i * res_d_1 * res_d_2 + j * res_d_2 + k] + = second_data[i * res_d_1 * second.get_shape().at(2) + j * second.get_shape().at(2) + l]; + } + } + } + return res; +} + +ov::Tensor concatenate_mid_dim(const ov::Tensor& first, const ov::Tensor& second) { + size_t res_d_0 = first.get_shape().at(0); + size_t res_d_2 = first.get_shape().at(2); + OPENVINO_ASSERT(second.get_shape().at(0) == res_d_0); + OPENVINO_ASSERT(second.get_shape().at(2) == res_d_2); + size_t res_d_1 = first.get_shape().at(1) + second.get_shape().at(1); + ov::Tensor res{first.get_element_type(), {res_d_0, res_d_1, res_d_2}}; + float* first_data = first.data(); + float* second_data = second.data(); + float* res_data = res.data(); + for (size_t i = 0; i < res_d_0; ++i) { + size_t j = 0; + for (; j < first.get_shape().at(1); ++j) { + std::copy_n( + first_data + i * first.get_shape().at(1) * res_d_2 + j * res_d_2, + res_d_2, + res_data + i * res_d_1 * res_d_2 + j * res_d_2 + ); + } + for (size_t k = 0; k < second.get_shape().at(1); ++k, ++j) { + std::copy_n( + second_data + i * second.get_shape().at(1) * res_d_2 + k * res_d_2, + res_d_2, + res_data + i * res_d_1 * res_d_2 + j * res_d_2 + ); + } + } + return res; +} + +/// embed_dim: output dimension for each position +/// pos: a list of positions to be encoded: size (H, W) +/// out: (H, W, D) +ov::Tensor get_1d_sincos_pos_embed_from_grid_new(size_t embed_dim, const ov::Tensor& pos) { + OPENVINO_ASSERT(embed_dim % 2 == 0); + ov::Shape pos_shape = pos.get_shape(); + size_t H = pos_shape[0]; + size_t W = pos_shape[1]; + + std::vector omega(embed_dim / 2); + for (size_t i = 0; i < omega.size(); ++i) { + omega[i] = 1.0f / std::pow(10000.0f, float(i) / (embed_dim / 2)); + } + + std::vector out_shape = {H, W, embed_dim}; + ov::Tensor emb(ov::element::f32, out_shape); + + float* pos_data = pos.data(); + float* emb_data = emb.data(); + + size_t counter = 0; + for (size_t h = 0; h < H; ++h) { + for (size_t w = 0; w < W; ++w) { + for (size_t d = 0; d < embed_dim / 2; ++d) { + // Correctly access the 2D position grid + float value = omega[d] * pos_data[h * W + 
w]; + // There should be sinf() and cosf(), but they don't exist on default Ubuntu20 gcc. + emb_data[h * W * embed_dim + w * embed_dim + d] = std::sin(double(value)); + emb_data[h * W * embed_dim + w * embed_dim + d + (embed_dim / 2)] = std::cos(double(value)); + } + } + } + return emb; +} + +ov::Tensor get_2d_sincos_pos_embed_from_grid(size_t embed_dim, const ov::Tensor& grid) { + OPENVINO_ASSERT(embed_dim % 2 == 0); + ov::Shape grid_shape = grid.get_shape(); + float* grid_data = grid.data(); + ov::Shape plane_shape{grid_shape.at(1), grid_shape.at(2)}; + ov::Tensor emb_h = get_1d_sincos_pos_embed_from_grid_new(embed_dim / 2, ov::Tensor{ + ov::element::f32, + plane_shape, + grid_data + }); // (H, W, D/2) + ov::Tensor emb_w = get_1d_sincos_pos_embed_from_grid_new(embed_dim / 2, ov::Tensor{ + ov::element::f32, + plane_shape, + grid_data + plane_shape.at(0) * plane_shape.at(1) + }); // (H, W, D/2) + return concatenate_last_dim(emb_h, emb_w); +} + +/// image_size: image_size or (image_height, image_width) +/// return: +/// pos_embed: [image_height, image_width, embed_dim] +ov::Tensor get_2d_sincos_pos_embed(size_t embed_dim, const ImageSize& image_size) { + size_t grid_h_size = image_size.height, grid_w_size = image_size.width; + ov::Tensor grid(ov::element::f32, {2, grid_h_size, grid_w_size}); + float* data = grid.data(); + for (size_t y = 0; y < grid_h_size; ++y) { + std::iota(data, data + grid_w_size, 0.0f); + data += grid_w_size; + } + for (float y = 0.0f; y < grid_h_size; ++y) { + std::fill(data, data + grid_w_size, y); + data += grid_w_size; + } + return get_2d_sincos_pos_embed_from_grid(embed_dim, grid); +} + +void adjust_pos_cache( + const std::vector& target_sizes, + size_t hidden_size, + ov::Tensor& pos_embed_cache +) { + size_t max_h = std::max_element(target_sizes.begin(), target_sizes.end(), [](const ImageSize& left, const ImageSize& right) { + return left.height < right.height; + })->height; + size_t max_w = std::max_element(target_sizes.begin(), target_sizes.end(), [](const ImageSize& left, const ImageSize& right) { + return left.width < right.width; + })->width; + size_t allocated_height, allocated_width; + if (pos_embed_cache) { + const ov::Shape& allocated_shape = pos_embed_cache.get_shape(); + allocated_height = allocated_shape.at(0); + allocated_width = allocated_shape.at(1); + } else { + allocated_height = allocated_width = 70; + } + if (max_h > allocated_height || max_w > allocated_width) { + allocated_height = std::max(max_h, allocated_height); + allocated_width = std::max(max_w, allocated_width); + pos_embed_cache = get_2d_sincos_pos_embed( + hidden_size, {allocated_height, allocated_width} + ); + } +} + +ov::Tensor resample(VLMPipeline& pipe, const ov::Tensor& encoded_image, const std::vector& target_sizes) { + size_t bs = encoded_image.get_shape().at(0); + std::vector patch_len{target_sizes.size()}; + std::transform(target_sizes.begin(), target_sizes.end(), patch_len.begin(), [](const ImageSize& height_width) { + return height_width.height * height_width.width; + }); + adjust_pos_cache( + target_sizes, + pipe.m_vlm_config.hidden_size, + pipe.m_pos_embed_cache + ); + size_t max_patch_len = *std::max_element(patch_len.begin(), patch_len.end()); + ov::Tensor key_padding_mask(ov::element::boolean, {bs, max_patch_len}); + bool* mask_data = key_padding_mask.data(); + size_t embed_len = pipe.m_pos_embed_cache.get_shape().at(2); + ov::Tensor pos_embed(ov::element::f32, {max_patch_len, bs, embed_len}); // BLD => L * B * D + float* pos_embed_data = pos_embed.data(); + float* 
cache_data = pipe.m_pos_embed_cache.data(); + size_t _d0 = pipe.m_pos_embed_cache.get_shape().at(0); + size_t _d1 = pipe.m_pos_embed_cache.get_shape().at(1); + for (size_t i = 0; i < bs; ++i) { + size_t target_h = target_sizes.at(i).height; + size_t target_w = target_sizes.at(i).width; + for (size_t h_idx = 0; h_idx < target_h; ++h_idx) { + for (size_t w_idx = 0; w_idx < target_w; ++w_idx) { + std::copy_n( + cache_data + (h_idx * _d1 + w_idx) * embed_len, + embed_len, + pos_embed_data + (h_idx * target_w + w_idx) * bs * embed_len + i * embed_len + ); + } + } + for (size_t flat = target_h * target_w; flat < max_patch_len; ++flat) { + std::fill_n(pos_embed_data + flat * bs * embed_len + i * embed_len, embed_len, 0.0f); + } + std::fill_n(mask_data + i * max_patch_len, patch_len[i], false); + std::fill_n(mask_data + i * max_patch_len + patch_len[i], max_patch_len - patch_len[i], true); + } + pipe.m_resampler.set_tensor("x", encoded_image); // [N, H*W, old_hidden_size] + pipe.m_resampler.set_tensor("pos_embed", pos_embed); // [H*W, N, new_hidden_size] + pipe.m_resampler.set_tensor("key_padding_mask", key_padding_mask); // [N, H*W] + pipe.m_resampler.infer(); + return pipe.m_resampler.get_output_tensor(); // [N, query_num, new_hidden_size] +} + +ov::Tensor merge_text_and_image_embeddings_llava( + const ov::Tensor& input_ids, + const ov::Tensor& text_embeds, + const ov::Tensor& image_embeds, + int64_t image_token_index +) { + auto text_embeds_shape = text_embeds.get_shape(); + auto image_embeds_shape = image_embeds.get_shape(); + + OPENVINO_ASSERT( + text_embeds_shape[2] == image_embeds_shape[2], + "Incompatible shapes between text_embeds and image_embeds" + ); + + size_t text_embeds_seq_length = text_embeds_shape[1]; + size_t hidden_size = text_embeds_shape[2]; + size_t image_embeds_seq_length = image_embeds_shape[1]; + + size_t merged_seq_length = text_embeds_seq_length + (image_embeds_seq_length - 1); + + ov::Tensor merged_embeds(text_embeds.get_element_type(), {BATCH_SIZE, merged_seq_length, hidden_size}); + + const int64_t* input_ids_data = input_ids.data(); + const float* text_embeds_data = text_embeds.data(); + const float* image_embeds_data = image_embeds.data(); + float* merged_data = merged_embeds.data(); + + + size_t merged_idx = 0; + for (size_t s = 0; s < text_embeds_seq_length; ++s) { + if (input_ids_data[s] == image_token_index) { + for (size_t i = 0; i < image_embeds_seq_length; ++i) { + std::copy_n(image_embeds_data + i * hidden_size, + hidden_size, + merged_data + merged_idx * hidden_size); + merged_idx++; + } + } else { + std::copy_n(text_embeds_data + s * hidden_size, + hidden_size, + merged_data + merged_idx * hidden_size); + merged_idx++; + } + } + + return merged_embeds; +} + +EncodedGenerationResult get_lm_encoded_results( + ov::InferRequest& language, + ov::InferRequest& embedding, + ov::Tensor inputs_embeds, + const VLMConfig m_vlm_config, + const std::shared_ptr streamer_ptr, + Sampler& sampler, + std::vector requests +) { + SequenceGroup::Ptr request = requests.back(); + GenerationHandle generation = std::make_shared(request->get_generation_stream(), request->get_sampling_parameters()); + + language.set_tensor("inputs_embeds", inputs_embeds); + + size_t history_len = language.get_tensor("attention_mask").get_shape().at(1); + language.get_tensor("attention_mask").set_shape({1, history_len + inputs_embeds.get_shape()[1]}); + std::fill_n(language.get_tensor("attention_mask").data(), language.get_tensor("attention_mask").get_size(), 1); + + 
language.get_tensor("position_ids").set_shape({1, inputs_embeds.get_shape().at(1)}); + std::iota(language.get_tensor("position_ids").data(), language.get_tensor("position_ids").data() + language.get_tensor("position_ids").get_size(), history_len); + + language.get_tensor("beam_idx").set_shape({ BATCH_SIZE }); + language.get_tensor("beam_idx").data()[0] = 0; + + language.infer(); + + int64_t sequence_len = language.get_tensor("logits").get_shape().at(1); + request->schedule_tokens(sequence_len); + + SamplerOutput sampler_output = sampler.sample(requests, language.get_tensor("logits")); + + // logits include image and prompt embedings + if (m_vlm_config.model_type == VLMModelType::LLAVA) { + request->update_processed_tokens_num(request->get_prompt_len()); + } + + language.get_tensor("inputs_embeds").set_shape({BATCH_SIZE, 1, m_vlm_config.hidden_size}); + language.get_tensor("position_ids").set_shape({ BATCH_SIZE, 1 }); + + while (!request->has_finished()) { + request->schedule_tokens(1); + size_t num_sequences = request->num_running_seqs(); + size_t total_num_tokens = request->get_num_scheduled_tokens() * num_sequences; + + ov::Tensor + input_ids(ov::element::i64, {total_num_tokens, 1}), + position_ids(ov::element::i64, {total_num_tokens, 1}), + beam_idx(ov::element::i32, { total_num_tokens }); + + int64_t + * input_ids_data = input_ids.data(), + * position_ids_data = position_ids.data(); + + size_t num_scheduled_tokens = request->get_num_scheduled_tokens(); + size_t group_position_id = request->get_num_processed_tokens(); + for (Sequence::Ptr& sequence : request->get_running_sequences()) { + for (size_t token_id = 0, position_id = group_position_id; token_id < num_scheduled_tokens; ++token_id, ++position_id) { + // compute token for current sequence + input_ids_data[token_id] = position_id < request->get_prompt_len() ? 
+ request->get_prompt_ids()[position_id] : + sequence->get_generated_ids()[position_id - request->get_prompt_len()]; + + position_ids_data[token_id] = position_id; + } + // apply strides to shift to a next sequence + input_ids_data += num_scheduled_tokens; + position_ids_data += num_scheduled_tokens; + } + + embedding.set_input_tensor(input_ids); + + embedding.infer(); + const ov::Tensor& embed_prompt_tensor = embedding.get_output_tensor(); + float* embed_data = embed_prompt_tensor.data(); + for (auto idx = 0; idx < embed_prompt_tensor.get_size(); idx++) { + embed_data[idx] = embed_data[idx] * m_vlm_config.scale_emb; + } + + language.set_tensor("inputs_embeds", embed_prompt_tensor); + + language.get_tensor("attention_mask").set_shape({ total_num_tokens, language.get_tensor("attention_mask").get_shape()[1] + 1 }); + std::fill_n(language.get_tensor("attention_mask").data(), language.get_tensor("attention_mask").get_size(), 1); + + language.set_tensor("position_ids", position_ids); + + std::vector beam_idxs = sampler.get_beam_idxs(request->get_request_id()); + int32_t *beam_idx_data = beam_idx.data(); + if (total_num_tokens > beam_idxs.size()) { + std::fill_n(beam_idx_data, total_num_tokens, 0); + } else { + copy(beam_idxs.begin(), beam_idxs.end(), beam_idx_data); + } + language.set_tensor("beam_idx", beam_idx); + + language.infer(); + + if (streamer_ptr) { + // first sequence + int64_t out_token = request.get()->operator[](0)->get_generated_ids().back(); + if (streamer_ptr->put(out_token)) { + break; + } + } + + sampler_output = sampler.sample(requests, language.get_tensor("logits")); + } + + if (streamer_ptr) { + streamer_ptr->end(); + } + + EncodedGenerationResult result; + result.m_request_id = 1; + std::vector generation_outputs = generation->read_all(); + std::sort(generation_outputs.begin(), generation_outputs.end(), [=] (GenerationOutput& r1, GenerationOutput& r2) { + return r1.score > r2.score; + }); + + auto num_outputs = std::min(request->get_sampling_parameters().num_return_sequences, generation_outputs.size()); + for (size_t generation_output_idx = 0; generation_output_idx < num_outputs; ++generation_output_idx) { + const auto& generation_output = generation_outputs[generation_output_idx]; + result.m_generation_ids.push_back(std::move(generation_output.generated_ids)); + result.m_scores.push_back(generation_output.score); + } + result.m_status = generation->get_status(); + + return result; +} +} // anonymous + + +class ov::genai::VLMPipeline::VLMPipelineImpl { +}; + +VLMPipeline::VLMPipeline( + const std::filesystem::path& model_dir, + const std::string& device, + const ov::AnyMap device_config +) : + m_vlm_config{ + utils::from_config_json_if_exists( + model_dir, "config.json" + ) + }, + m_tokenizer{Tokenizer(model_dir.string(), device_config)}, + m_vision_encoder(model_dir, m_vlm_config.model_type, device, device_config, ov::Core{}), + m_is_chat_conversation{false} { + if (m_vlm_config.model_type == VLMModelType::MINICPM) { + m_resampler = ov::Core{}.compile_model( + model_dir / "resampler.xml", device, device_config + ).create_infer_request(); + + m_embedding = ov::Core{}.compile_model( + model_dir / "embed_tokens.xml", device, device_config + ).create_infer_request(); + + m_language = ov::Core{}.compile_model( + model_dir / "language_model.xml", device, device_config + ).create_infer_request(); + + m_pos_embed_cache = get_2d_sincos_pos_embed(m_vlm_config.hidden_size, {70, 70}); + } else if (m_vlm_config.model_type == VLMModelType::LLAVA) { + m_language = 
ov::Core{}.compile_model( + model_dir / "openvino_language_model.xml", device, device_config + ).create_infer_request(); + + // Reusing the same m_embedding for llava text_embeddings model + m_embedding = ov::Core{}.compile_model( + model_dir / "openvino_text_embeddings_model.xml", device, device_config + ).create_infer_request(); + } + + m_language.get_tensor("attention_mask").set_shape({1, 0}); +} + +ov::genai::VLMPipeline::~VLMPipeline() = default; + +DecodedResults VLMPipeline::generate( + const std::string& prompt, + const std::vector& rgbs, + const GenerationConfig& generation_config, + const StreamerVariant& streamer +) { + + // inputs_embeds, tokenized_input; + std::pair processed_input; + if (m_vlm_config.model_type == VLMModelType::MINICPM) { + processed_input = get_inputs_embeds_minicpm(prompt, rgbs); + } else if (m_vlm_config.model_type == VLMModelType::LLAVA) { + processed_input = get_inputs_embeds_llava(prompt, rgbs); + } + + Sampler sampler = Sampler(m_tokenizer); + std::vector requests; + // request_id, input_ids, generation_config, block_size, enable_prefix_caching + // now we have one prompt as input, so we need one request + SequenceGroup::Ptr sequence_group = std::make_shared(0, processed_input.second, generation_config, 1, false); + size_t inputs_embeds_size = processed_input.first.get_shape()[1]; + size_t tokenized_prompt_size = processed_input.second.get_size(); + size_t num_processed_tokens = inputs_embeds_size <= tokenized_prompt_size ? tokenized_prompt_size - inputs_embeds_size : 0; + sequence_group->update_processed_tokens_num(num_processed_tokens); + sequence_group->set_sequence_group_ptr(sequence_group); + requests.push_back(sequence_group); + + std::shared_ptr streamer_ptr = std::visit(overloaded{ + [&m_tokenizer = m_tokenizer]( + const std::function& callback + ) -> std::shared_ptr { + return std::make_shared(m_tokenizer, callback); + }, + [](const std::shared_ptr& ptr) { + return ptr; + }, + [](std::monostate) { + return std::shared_ptr{nullptr}; + }, + }, streamer); + + if ((!(generation_config.is_greedy_decoding() || generation_config.is_multinomial())) && streamer_ptr) { + OPENVINO_THROW("Currently streaming is possible only for greedy or multinomial decoding"); + } + + EncodedGenerationResult encoded_result = get_lm_encoded_results(m_language, m_embedding, processed_input.first, m_vlm_config, streamer_ptr, sampler, requests); + + DecodedResults decoded; + for (size_t idx = 0; idx < encoded_result.m_generation_ids.size(); ++idx) { + decoded.texts.push_back(m_tokenizer.decode(encoded_result.m_generation_ids.at(idx))); + decoded.scores.push_back(encoded_result.m_scores.at(idx)); + } + + std::string decoded_results = decoded.texts.at(0); + if (m_is_chat_conversation) { + // Tail of chat template is missing in KV cache. + // Find the tail to concatenate it with the next input prompt. + m_templated_chat_history.append(decoded_results); + m_history.push_back({{"role", "assistant"}, {"content", decoded_results}}); + } else { + for (auto& variable : m_language.query_state()) { + variable.reset(); + } + m_language.get_tensor("attention_mask").set_shape({1, 0}); + } + return decoded; +} + +DecodedResults VLMPipeline::generate( + const std::string& prompt, + const ov::AnyMap& config_map +) { + auto image = config_map.find(ov::genai::image.name()); + auto images = config_map.find(ov::genai::images.name()); + OPENVINO_ASSERT( + config_map.end() == image || config_map.end() == images, + "Only one property can be set: image of images." 
+ ); + std::vector rgbs; + if (config_map.end() != image) { + rgbs = {image->second.as()}; + } if (config_map.end() != images) { + rgbs = images->second.as>(); + } + ov::genai::OptionalGenerationConfig config_arg = utils::get_config_from_map(config_map); + GenerationConfig config = (config_arg.has_value()) ? *config_arg : get_generation_config(); + config.update_generation_config(config_map); + + // If eos_token_id was not provided, take value + if (config.eos_token_id == -1) + config.set_eos_token_id(m_tokenizer.get_eos_token_id()); + + return generate( + prompt, + rgbs, + config, + utils::get_streamer_from_map(config_map) + ); +} + +void VLMPipeline::start_chat(const std::string& system_message) { + m_is_chat_conversation = true; + bool have_state = 0 != m_language.get_tensor("attention_mask").get_size(); + if (have_state) { + // Resetting state may be slow. + for (ov::VariableState& variable : m_language.query_state()) { + variable.reset(); + } + // Since if is already introduced, move all resetting here. + m_language.get_tensor("attention_mask").set_shape({1, 0}); + m_history.clear(); + m_templated_chat_history.clear(); + } + if (system_message.empty()) { + return; + } + m_history = {{{"role", "system"}, {"content", system_message}}}; + constexpr bool add_generation_prompt = false; + m_templated_chat_history = m_tokenizer.apply_chat_template(m_history, add_generation_prompt); +} + +void VLMPipeline::set_chat_template(const std::string& new_template) { + m_tokenizer.set_chat_template(new_template); +} + +GenerationConfig VLMPipeline::get_generation_config() const { + return m_generation_config; +} + +void VLMPipeline::set_generation_config(const GenerationConfig& new_config) { + m_generation_config = new_config; +} + +std::pair VLMPipeline::get_inputs_embeds_llava(const std::string& prompt, const std::vector& images) { + std::string image_token = ""; // TODO Consider getting from vlm_config or json + std::string formatted_prompt = "USER: " + (images.empty() ? 
prompt : image_token + "\n" + prompt) + " ASSISTANT:"; + ov::Tensor input_ids = m_tokenizer.encode(formatted_prompt).input_ids; + if (images.empty()) { + return std::pair{process_prompt(m_embedding, input_ids, m_vlm_config.scale_emb), input_ids}; + } else { + OPENVINO_ASSERT(1 == images.size(), "Only a single image allowed"); + EncodedImage encoded_image = m_vision_encoder.encode(images.at(0)); + ov::Tensor image_embeds = encoded_image.resized_source; + + ov::Tensor text_embeds = process_prompt(m_embedding, input_ids, m_vlm_config.scale_emb); + + int64_t image_token_index = 32000; // TODO Consider getting from m_vlm_config.image_token_index or config.json + + return std::pair{merge_text_and_image_embeddings_llava(input_ids, text_embeds, image_embeds, image_token_index), input_ids}; + } +} + +std::pair VLMPipeline::get_inputs_embeds_minicpm(const std::string& prompt, const std::vector& images) { + std::string images_prompt; + std::vector embeds; + for (const ov::Tensor& rgb : images) { + ov::Tensor reshaped = rgb; + ov::Shape rgb_shape = rgb.get_shape(); + switch (rgb_shape.size()) { + case 3: + reshaped.set_shape({1, rgb_shape.at(0), rgb_shape.at(1), rgb_shape.at(2)}); + break; + case 4: break; + default: OPENVINO_THROW("Input image must have [NHWC] or [HWC] layout"); + } + ov::Shape reshaped_shape = reshaped.get_shape(); + for (size_t batch_idx = 0; batch_idx < reshaped_shape.at(0); ++batch_idx) { + ov::Tensor single_image{ + ov::element::u8, + {1, reshaped_shape.at(1), reshaped_shape.at(2), reshaped_shape.at(3)}, + reshaped.data() + batch_idx * reshaped_shape.at(1) * reshaped_shape.at(1) * reshaped_shape.at(1) + }; + EncodedImage encoded_image = m_vision_encoder.encode(single_image); + if (m_vlm_config.use_image_id) { + images_prompt += m_vlm_config.im_id_start + std::to_string(image_id) + m_vlm_config.im_id_end; + ++image_id; + } + std::string unk64; + for (size_t idx = 0; idx < m_vlm_config.query_num; ++idx) { + unk64 += m_vlm_config.unk; + } + images_prompt += m_vlm_config.im_start + unk64 + m_vlm_config.im_end; + if (encoded_image.slices) { + ov::Shape slices_shape = encoded_image.slices.get_shape(); + for (size_t row_idx = 0; row_idx < slices_shape.at(0); ++row_idx) { + for (size_t col_idx = 0; col_idx < slices_shape.at(1); ++col_idx) { + images_prompt += m_vlm_config.slice_start + unk64 + m_vlm_config.slice_end; + } + images_prompt += '\n'; + } + } + if ('\n' != *(images_prompt.end() - 1)) { + // Image wasn't sliced, add \n to the end of image anyway. + // Strangely, \n isn't placed between . + images_prompt += '\n'; + } + embeds.push_back(std::move(encoded_image)); + } + } + images_prompt += prompt; + ov::Tensor encoded_input; + ov::Tensor new_chat_tokens; + if (m_is_chat_conversation) { + // KV cache in model already contains prompts and answers from previous iterations. + // So only new prompt wrapped into chat template to be sent into model. Tokenizer always returns + // token_ids = {, ...}. So if tokenizer applies only to the new prompt, + // will be inserted on every iteration. + // So actual pipeline calculates input_ids for whole chat history + for whole chat history without the new prompt + // and takes only the difference between them. + // The chat history cannot be saved as already encoded tokens because generate call doesn't return token, but + // KV cache contains it. So we have to add it manually or get it by tokenization all chat history. 
+ m_history.push_back({{"role", "user"}, {"content", images_prompt}}); + constexpr bool add_generation_prompt = true; + std::string new_templated_chat_history = m_tokenizer.apply_chat_template(m_history, add_generation_prompt); + new_chat_tokens = m_tokenizer.encode(new_templated_chat_history).input_ids; + if (0 == m_language.get_tensor("attention_mask").get_shape().at(1)) { + encoded_input = new_chat_tokens; + } else { + TokenizedInputs prev_chat_tokens = m_tokenizer.encode( + m_templated_chat_history + ); + encoded_input = utils::subtract_chat_tokenized_inputs( + {new_chat_tokens}, prev_chat_tokens + ).input_ids; + } + m_templated_chat_history = std::move(new_templated_chat_history); + } else { + encoded_input = m_tokenizer.encode(images_prompt).input_ids; + } + m_embedding.set_input_tensor(encoded_input); + m_embedding.infer(); + ov::Tensor inputs_embeds = m_embedding.get_output_tensor(); + OPENVINO_ASSERT( + m_vlm_config.hidden_size == inputs_embeds.get_shape().at(2), + "Unexpected embedding size" + ); + ov::Tensor special_tokens = m_tokenizer.encode( + m_vlm_config.im_start + + m_vlm_config.im_end + + m_vlm_config.slice_start + + m_vlm_config.slice_end + ).input_ids; + OPENVINO_ASSERT( + 4 == special_tokens.get_shape().at(1), + "Every special token must be represented with a single int." + ); + int64_t im_start_id = special_tokens.data()[0]; + int64_t im_end_id = special_tokens.data()[1]; + int64_t slice_start_id = special_tokens.data()[2]; + int64_t slice_end_id = special_tokens.data()[3]; + int64_t im_start_pos = 0, slice_start_pos = 0; + int64_t* begin = encoded_input.data(); + int64_t* ids = begin; + size_t encoded_input_size = encoded_input.get_size(); + int64_t* end = ids + encoded_input_size; + float* inputs_embeds_data = inputs_embeds.data(); + for (const EncodedImage& encoded_image : embeds) { + const ov::Tensor& resampled_source = resample(*this, encoded_image.resized_source, {encoded_image.resized_source_size}); + float* emb = resampled_source.data(); + ids = std::find(ids, end, im_start_id); + OPENVINO_ASSERT(end != ids); + ++ids; + std::copy_n(emb, resampled_source.get_size(), inputs_embeds_data + std::distance(begin, ids) * m_vlm_config.hidden_size); + ids += m_vlm_config.query_num; + if (encoded_image.slices) { + size_t token_idx = 0; + const ov::Shape& slices_shape = encoded_image.slices.get_shape(); + for (size_t i = 0; i < slices_shape.at(0); ++i) { + for (size_t ja = 0; ja < slices_shape.at(1); ++ja) { + size_t d2 = slices_shape.at(2); + size_t d3 = slices_shape.at(3); + ov::Tensor encoded_view{ov::element::f32, {1, d2, d3}, encoded_image.slices.data() + (i * slices_shape.at(1) + ja) * d2 * d3}; + const ov::Tensor& vision_embed_tensor_i_j = resample(*this, encoded_view, {encoded_image.slices_size}); + ids = std::find(ids, end, slice_start_id); + OPENVINO_ASSERT(end != ids); + ++ids; + std::copy_n(vision_embed_tensor_i_j.data(), vision_embed_tensor_i_j.get_size(), inputs_embeds_data + std::distance(begin, ids) * m_vlm_config.hidden_size); + ids += m_vlm_config.query_num; + } + } + } + } + + return std::pair{inputs_embeds, m_is_chat_conversation ? 
new_chat_tokens : encoded_input}; +} diff --git a/src/cpp/src/vlm_sampling.hpp b/src/cpp/src/vlm_sampling.hpp deleted file mode 100644 index b0a7d2341f..0000000000 --- a/src/cpp/src/vlm_sampling.hpp +++ /dev/null @@ -1,96 +0,0 @@ -// Copyright (C) 2023-2024 Intel Corporation -// SPDX-License-Identifier: Apache-2.0 - -#pragma once - -#include -#include -#include -#include -#include - -struct TokenIdScore { - int id; - float score; - - TokenIdScore() = default; - TokenIdScore(int id, float score) : id(id), score(score) {} - - bool operator<(const TokenIdScore& other) const { return score < other.score; } - bool operator>(const TokenIdScore& other) const { return score > other.score; } - - friend std::ostream& operator<<(std::ostream& os, const TokenIdScore& self) { - return os << "TokenIdScore(id=" << self.id << ", score=" << self.score << ")"; - } -}; - -void sampling_softmax_inplace(TokenIdScore* first, TokenIdScore* last) { - float max_score = std::max_element(first, last)->score; - float sum = 0.f; - for (TokenIdScore* p = first; p != last; p++) { - float s = std::exp(p->score - max_score); - p->score = s; - sum += s; - } - float inv_sum = 1.f / sum; - for (TokenIdScore* p = first; p != last; p++) { - p->score *= inv_sum; - } -} - -void sampling_top_k(TokenIdScore* first, TokenIdScore* kth, TokenIdScore* last) { - std::nth_element(first, kth, last, std::greater()); -} - -TokenIdScore* sampling_top_p(TokenIdScore* first, TokenIdScore* last, float top_p) { - // fast top_p in expected O(n) time complexity - sampling_softmax_inplace(first, last); - - while (first + 1 < last) { - const float pivot_score = (last - 1)->score; // use mid score? - TokenIdScore* mid = - std::partition(first, last - 1, [pivot_score](const TokenIdScore& x) { return x.score > pivot_score; }); - std::swap(*mid, *(last - 1)); - - const float prefix_sum = - std::accumulate(first, mid, 0.f, [](float sum, const TokenIdScore& x) { return sum + x.score; }); - if (prefix_sum >= top_p) { - last = mid; - } - else if (prefix_sum + mid->score < top_p) { - first = mid + 1; - top_p -= prefix_sum + mid->score; - } - else { - return mid + 1; - } - } - return last; -} - -void sampling_repetition_penalty(float* first, float* last, const std::vector& input_ids, - float penalty) { - if (penalty < 0) { - std::cout << "penalty must be a positive float, but got " << penalty; - return; - } - const float inv_penalty = 1.f / penalty; - const ptrdiff_t vocab_size = last - first; - std::vector occurrence(vocab_size, false); - for (const int id : input_ids) { - if (!occurrence[id]) { - first[id] *= (first[id] > 0) ? inv_penalty : penalty; - } - occurrence[id] = true; - } -} - -void sampling_temperature(float* first, float* last, float temp) { - const float inv_temp = 1.f / temp; - for (float* it = first; it != last; it++) { - *it *= inv_temp; - } -} - - -