Skip to content

Commit

Permalink
Support multiple images for InternVL models (#1099)
Browse files Browse the repository at this point in the history
Ticket: CVS-155384
  • Loading branch information
yatarkan authored Oct 29, 2024
1 parent 5ecf1e6 commit 5f1c8ae
Show file tree
Hide file tree
Showing 3 changed files with 48 additions and 32 deletions.
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
// Copyright (C) 2023-2024 Intel Corporation
// SPDX-License-Identifier: Apache-2.0

#include "visual_language/image_embedder.hpp"
#include "visual_language/inputs_embedder.hpp"

#include "visual_language/clip.hpp"
#include "visual_language/vision_encoder.hpp"
Expand Down Expand Up @@ -905,56 +905,60 @@ class InputsEmbedderInternVLChat : public InputsEmbedder::IInputsEmbedder {
IInputsEmbedder(vlm_config, model_dir, device, device_config) { }

// Builds the input embeddings for `prompt` together with zero or more `images`.
// NOTE(review): this span is a rendered commit diff. Removed and added lines are
// interleaved without +/- markers (several declarations appear twice and the braces
// do not balance), so read it as "old version, then new version", not as one body.
virtual ov::Tensor get_inputs_embeds(const std::string& prompt, const std::vector<ov::Tensor>& images) override {
// Older variant: a prompt without images is tokenized and embedded directly;
// otherwise exactly one image is accepted (asserted) and encoded.
if (images.empty()) {
ov::Tensor input_ids = get_encoded_input_ids(prompt);
return m_embedding.infer(input_ids);
} else {
OPENVINO_ASSERT(1 == images.size(), "Only a single image allowed");
EncodedImage encoded_image = m_vision_encoder.encode(images.at(0));
ov::Tensor image_embeds = encoded_image.resized_source;

// Marker strings that delimit / fill the image placeholder region of the prompt.
// (Shown twice: once per side of the diff.)
std::string image_start_token = m_vlm_config.image_start_token;
std::string image_context_token = m_vlm_config.image_context_token;
std::string image_end_token = m_vlm_config.image_end_token;
std::string image_start_token = m_vlm_config.image_start_token;
std::string image_context_token = m_vlm_config.image_context_token;
std::string image_end_token = m_vlm_config.image_end_token;

// Newer variant: the tensors returned by to_single_image_tensors(images) are
// encoded one at a time in the loop below.
std::vector<ov::Tensor> single_images = to_single_image_tensors(images);

// formatted_prompt accumulates one placeholder block per image, then the prompt text.
std::string formatted_prompt;
std::vector<ov::Tensor> image_embeds;
image_embeds.reserve(single_images.size());

for (const auto& image : single_images) {
EncodedImage encoded_image = m_vision_encoder.encode(image);
ov::Tensor single_image_embeds = encoded_image.resized_source;

// Dimensions 0 and 1 of the embeddings shape are read as the patch count and
// the per-patch image-token count (first pair: older side, second pair: newer side).
const size_t num_patches = image_embeds.get_shape().at(0);
const size_t num_image_tokens = image_embeds.get_shape().at(1);
const size_t num_patches = single_image_embeds.get_shape().at(0);
const size_t num_image_tokens = single_image_embeds.get_shape().at(1);

// Placeholder block: start token, num_patches * num_image_tokens context tokens,
// end token, then "\n". The older side builds it in concated_image_tokens, the
// newer side appends straight to formatted_prompt.
std::string concated_image_tokens;
concated_image_tokens += image_start_token;
formatted_prompt += image_start_token;
for (int i = 0; i < num_patches * num_image_tokens; ++i) {
concated_image_tokens += image_context_token;
formatted_prompt += image_context_token;
}
concated_image_tokens += image_end_token;
formatted_prompt += image_end_token + "\n";

std::string formatted_prompt = concated_image_tokens + "\n" + prompt;

ov::Tensor input_ids = get_encoded_input_ids(formatted_prompt);
ov::Tensor text_embeds = m_embedding.infer(input_ids);
// Keep this image's embeddings; they are consumed by the merge call at the end.
image_embeds.push_back(std::move(single_image_embeds));
}
// The user prompt follows all image placeholder blocks.
formatted_prompt += prompt;

// The id of the context token is taken as the LAST id produced when encoding
// image_context_token with add_special_tokens(false).
ov::Tensor encoded_image_context_token = m_tokenizer.encode(image_context_token, ov::genai::add_special_tokens(false)).input_ids;
int64_t image_context_token_id = encoded_image_context_token.data<int64_t>()[encoded_image_context_token.get_size() - 1];
ov::Tensor input_ids = get_encoded_input_ids(formatted_prompt);
ov::Tensor text_embeds = m_embedding.infer(input_ids);

return merge_text_and_image_embeddings_internvl(input_ids, text_embeds, image_embeds, image_context_token_id);
// Newer variant: with no images the text embeddings are returned as-is.
if (images.empty()) {
return text_embeds;
}

ov::Tensor encoded_image_context_token = m_tokenizer.encode(image_context_token, ov::genai::add_special_tokens(false)).input_ids;
int64_t image_context_token_id = encoded_image_context_token.data<int64_t>()[encoded_image_context_token.get_size() - 1];

// Hand the ids, text embeddings, per-image embeddings and the context-token id
// to the merge helper, whose result is returned.
return merge_text_and_image_embeddings_internvl(input_ids, text_embeds, image_embeds, image_context_token_id);
}

protected:
ov::Tensor merge_text_and_image_embeddings_internvl(
const ov::Tensor& input_ids,
const ov::Tensor& text_embeds,
const ov::Tensor& image_embeds,
const std::vector<ov::Tensor>& image_embeds,
int64_t image_context_token_id
) {
auto text_embeds_shape = text_embeds.get_shape();
auto image_embeds_shape = image_embeds.get_shape();
size_t batch_size = text_embeds_shape.at(0);
size_t seq_len = text_embeds_shape.at(1);
size_t embed_dim = text_embeds_shape.at(2);

ov::Tensor merged_embeds(text_embeds.get_element_type(), text_embeds_shape);

const float* image_embeds_data = image_embeds.data<float>();
const float* text_embeds_data = text_embeds.data<float>();
const int64_t* input_ids_data = input_ids.data<int64_t>();
float* merged_embeds_data = merged_embeds.data<float>();
Expand All @@ -972,15 +976,27 @@ class InputsEmbedderInternVLChat : public InputsEmbedder::IInputsEmbedder {

OPENVINO_ASSERT(image_context_tokens_count > 0, "input_ids does not contain image context token ids");

size_t vision_idx = 0;
size_t image_idx = 0;
size_t image_context_token_idx = 0;
for (size_t i = 0; i < batch_size; ++i) {
for (size_t j = 0; j < seq_len; ++j) {
size_t flat_idx = i * seq_len + j;
size_t offset = flat_idx * embed_dim;

if (image_context_tokens_mask[flat_idx]) {
std::copy_n(image_embeds_data + vision_idx * embed_dim, embed_dim, merged_embeds_data + offset);
++vision_idx;
const ov::Tensor& single_image_embeds = image_embeds[image_idx];
const size_t num_all_image_tokens = single_image_embeds.get_shape().at(0) * single_image_embeds.get_shape().at(1); // num_patches * num_image_tokens
const float* image_embeds_data = single_image_embeds.data<float>();
std::copy_n(image_embeds_data + image_context_token_idx * embed_dim,
embed_dim,
merged_embeds_data + offset);

++image_context_token_idx;

if (image_context_token_idx == num_all_image_tokens) {
++image_idx;
image_context_token_idx = 0;
}
} else {
std::copy_n(text_embeds_data + offset, embed_dim, merged_embeds_data + offset);
}
Expand Down
2 changes: 1 addition & 1 deletion src/cpp/src/visual_language/pipeline.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
#include "openvino/genai/tokenizer.hpp"

#include "visual_language/vlm_config.hpp"
#include "visual_language/image_embedder.hpp"
#include "visual_language/inputs_embedder.hpp"
#include "visual_language/embedding_model.hpp"

#include "sampler.hpp"
Expand Down

0 comments on commit 5f1c8ae

Please sign in to comment.