Commit: Merge branch 'add-VLM-matching-test'

Wovchena committed Oct 17, 2024
2 parents d6e56a7 + f90725e commit 330f122
Showing 4 changed files with 100 additions and 67 deletions.
31 changes: 28 additions & 3 deletions .github/workflows/causal_lm_cpp.yml
@@ -702,19 +702,44 @@ jobs:
run: |
source ./ov/setupvars.sh
python -m pip install ./thirdparty/openvino_tokenizers/[transformers] --pre --extra-index-url https://storage.openvinotoolkit.org/simple/wheels/nightly
python -m pip install --upgrade-strategy eager -r ./samples/requirements.txt --pre --extra-index-url https://storage.openvinotoolkit.org/simple/wheels/nightly
python -m pip install --upgrade-strategy eager -r ./samples/requirements.txt opencv-python --pre --extra-index-url https://storage.openvinotoolkit.org/simple/wheels/nightly
- name: Download and convert MiniCPM-V-2_6 model and an image
run: |
python -m pip install git+https://github.com/eaidova/optimum-intel.git@ea/minicpmv
python -m pip install -U "optimum<1.23" --no-dependencies
source ./ov/setupvars.sh
optimum-cli export openvino -m openbmb/MiniCPM-V-2_6 MiniCPM-V-2_6 --trust-remote-code
wget https://github.com/openvinotoolkit/openvino_notebooks/assets/29454499/d5fbbd1a-d484-415c-88cb-9986625b7b11 --output-document cat.jpg
- name: Generate reference
shell: python
run: |
from optimum.intel.openvino import OVModelForVisualCausalLM
from transformers import AutoProcessor
from PIL import Image
import cv2
import numpy as np
res = 448, 448
lines = np.arange(res[0] * res[1] * 3, dtype=np.uint8) % 255
lines = lines.reshape([*res, 3])
cv2.imwrite("lines.png", lines)
lines = Image.open("lines.png").convert('RGB')
model_id = "openbmb/MiniCPM-V-2_6"
processor = AutoProcessor.from_pretrained(model_id, trust_remote_code=True)
prompt = processor.tokenizer.apply_chat_template([{"role": "user", "content": "(<image>./</image>)\nWhat is unusual on this image?"}], tokenize=False, add_generation_prompt=True)
inputs = processor([prompt], [lines], return_tensors="pt")
model = OVModelForVisualCausalLM.from_pretrained("MiniCPM-V-2_6", device="CPU", trust_remote_code=True)
result = model.generate(**inputs, max_new_tokens=200)
decoded = processor.tokenizer.batch_decode(result[:, inputs["input_ids"].shape[1]:], skip_special_tokens=True)[0]
print(decoded)
with open("ref.txt", "w") as f:
    f.write(f"question:\n{decoded}\n----------\nquestion:\n")
- name: Run visual_language_chat C++ sample - MiniCPM-V-2_6
run: >
source ./ov/setupvars.sh
&& timeout 120s ./build/samples/cpp/visual_language_chat/visual_language_chat ./MiniCPM-V-2_6/ cat.jpg
<<< $'What is on the image?\nWhat is special on the image?'
&& timeout 120s ./build/samples/cpp/visual_language_chat/visual_language_chat ./MiniCPM-V-2_6/ lines.png
<<< $'What is unusual on this image?' | tee cpp.txt
- run: diff cpp.txt ref.txt
- name: Download and convert LLaVa 1.5 model and an image
run: |
source ./ov/setupvars.sh
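The two new steps at the end only pass when the sample's stdout matches ref.txt byte for byte, which is why the reference script wraps the decoded answer in the "question:" / "----------" framing: the chat sample is assumed to print a "question:" prompt before and after each answer. A minimal C++ sketch of that contract (write_reference is a hypothetical helper; the exact framing is inferred from the f.write above):

```cpp
#include <fstream>
#include <string>

// Hypothetical mirror of the workflow's f.write(): the reference file must
// reproduce the sample's prompt framing exactly for `diff cpp.txt ref.txt`
// to succeed.
void write_reference(const std::string& decoded) {
    std::ofstream f{"ref.txt"};
    f << "question:\n" << decoded << "\n----------\nquestion:\n";
}

int main() {
    write_reference("example decoded answer");
}
```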
20 changes: 2 additions & 18 deletions src/cpp/src/visual_language/clip.hpp
@@ -6,25 +6,9 @@
#include <vector>
#include <numeric>

//#define CLIP_DEBUG_FUNCTIONS
enum projector_type {
PROJECTOR_TYPE_RESAMPLER,
PROJECTOR_TYPE_UNKNOWN,
};

struct clip_ctx {
bool has_text_encoder = false;
bool has_vision_encoder = false;
bool has_minicpmv_projector = false;

float image_mean[3];
float image_std[3];
int32_t ftype = 1;

std::vector<uint8_t> buf_compute_meta;

projector_type proj_type = PROJECTOR_TYPE_RESAMPLER;
size_t patch_size = 0;
float image_mean[3] = {0.0f, 0.0f, 0.0f};
float image_std[3] = {1.0f, 1.0f, 1.0f};
size_t image_size = 0;
};

16 changes: 12 additions & 4 deletions src/cpp/src/visual_language/pipeline.cpp
@@ -557,6 +557,13 @@ ov::Tensor pack_image_features_llava_next(
return result;
}
}

// A GPU tensor produced by a model compiled on one ov::Core instance
// can't be passed to a model compiled on a different ov::Core instance.
ov::Core singleton_core() {
static ov::Core core;
return core;
}
}

class ov::genai::VLMPipeline::VLMPipelineImpl {
@@ -604,21 +611,22 @@ class ov::genai::VLMPipeline::VLMPipelineImpl {
)
},
m_tokenizer{Tokenizer(model_dir.string(), device_config)},
m_vision_encoder(model_dir, m_vlm_config.model_type, device, device_config, ov::Core{}),
m_vision_encoder(model_dir, m_vlm_config.model_type, device, device_config, singleton_core()),
m_is_chat_conversation{false},
m_image_id{0} {
ov::Core core = singleton_core();
if (m_vlm_config.model_type == VLMModelType::MINICPM) {
m_resampler = ov::Core{}.compile_model(
m_resampler = core.compile_model(
model_dir / "openvino_resampler_model.xml", device, device_config
).create_infer_request();

m_pos_embed_cache = get_2d_sincos_pos_embed(m_vlm_config.hidden_size, {70, 70});
}
m_embedding = ov::Core{}.compile_model(
m_embedding = core.compile_model(
model_dir / "openvino_text_embeddings_model.xml", device, device_config
).create_infer_request();

m_language = ov::Core{}.compile_model(
m_language = core.compile_model(
model_dir / "openvino_language_model.xml", device, device_config
).create_infer_request();

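Replacing the ad-hoc `ov::Core{}` temporaries with `singleton_core()` compiles every model on one shared core, which is exactly what the new comment requires for handing GPU tensors from one model to another. A minimal sketch of the pattern, assuming ov::Core copies share one underlying implementation (which the by-value singleton relies on); the model paths and device are illustrative:

```cpp
#include <openvino/openvino.hpp>

ov::Core singleton_core() {
    static ov::Core core;  // thread-safe one-time initialization
    return core;           // copies share the same underlying core
}

int main() {
    // Both models are compiled on the same core, so a tensor produced by one
    // infer request can be consumed by the other even on GPU.
    ov::InferRequest embed = singleton_core().compile_model("embedder.xml", "GPU").create_infer_request();
    ov::InferRequest lm = singleton_core().compile_model("lm.xml", "GPU").create_infer_request();
    embed.infer();
    lm.set_input_tensor(embed.get_output_tensor());
    lm.infer();
}
```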
100 changes: 58 additions & 42 deletions src/cpp/src/visual_language/vision_encoder.cpp
@@ -242,7 +242,6 @@ ov::Tensor prepare_vis_position_ids(
});
size_t position_ids_batch_elem = max_nb_patches_h * max_nb_patches_w;
ov::Tensor position_ids{ov::element::i64, {batch_size, position_ids_batch_elem}};
// throw std::runtime_error("");
int64_t* res_data = position_ids.data<int64_t>();
std::fill_n(res_data, position_ids.get_size(), 0);

@@ -285,66 +284,84 @@ EncodedImage llava_image_embed_make_with_bytes_slice(clip_ctx& ctx_clip, const o
std::vector<std::vector<ov::Tensor>> results;
std::vector<std::vector<ImageSize>> sizes;

// std::vector<clip_image_f32*> img_res_v; // format N x H x W x RGB (N x 336 x 336 x 3), so interleaved RGB - different to the python implementation which is N x 3 x 336 x 336
std::vector<std::vector<clip_image_f32>> preprocessed{imgs.size()};
std::transform(imgs.begin(), imgs.end(), preprocessed.begin(), [&ctx_clip](const std::vector<clip_image_u8>& row) {
size_t max_h = 0, max_w = 0, n_images = 0;
std::transform(imgs.begin(), imgs.end(), preprocessed.begin(), [&ctx_clip, &max_h, &max_w, &n_images](const std::vector<clip_image_u8>& row) {
std::vector<clip_image_f32> processed_row{row.size()};
std::transform(row.begin(), row.end(), processed_row.begin(), [&ctx_clip](const clip_image_u8& raw) {
return clip_image_preprocess(ctx_clip, raw);
std::transform(row.begin(), row.end(), processed_row.begin(), [&ctx_clip, &max_h, &max_w, &n_images](const clip_image_u8& raw) {
clip_image_f32 im = clip_image_preprocess(ctx_clip, raw);
max_h = std::max(size_t(im.ny), max_h);
max_w = std::max(size_t(im.nx), max_w);
++n_images;
return im;
});
return processed_row;
});

ov::Tensor batched_images{ov::element::f32, {n_images, 3, max_h, max_w}};
float* batched_data = batched_images.data<float>();
const clip_image_f32& resized_preprocessed = preprocessed.at(0).at(0);
ImageSize resized_source_size{resized_preprocessed.ny / patch_size, resized_preprocessed.nx / patch_size};
ov::Tensor input_tensor{ov::element::f32, {1, 3, size_t(resized_preprocessed.ny), size_t(resized_preprocessed.nx)}, (void*)(resized_preprocessed.buf.data())};
ov::Tensor pixel_values = preprocess_for_encoder(input_tensor, patch_size);
std::copy(resized_preprocessed.buf.begin(), resized_preprocessed.buf.end(), batched_data);
if (1 < preprocessed.size()) {
for (size_t row = 1; row < preprocessed.size(); ++row) {
size_t n_slices = preprocessed.at(row).size();
for (size_t col = 0; col < n_slices; ++col) {
const clip_image_f32& elem = preprocessed.at(row).at(col);
std::copy(elem.buf.begin(), elem.buf.end(), batched_data + ((row - 1) * n_slices + col + 1) * 3 * max_h * max_w);
}
}
}
ov::Tensor pixel_values = preprocess_for_encoder(batched_images, patch_size);
encoder.set_tensor("pixel_values", pixel_values);
ov::Tensor patch_attention_mask{ov::element::f32, {pixel_values.get_shape().at(0), 1, resized_source_size.height * resized_source_size.width}};
std::fill_n(patch_attention_mask.data<float>(), patch_attention_mask.get_size(), 1.0f);

ov::Tensor patch_attention_mask{ov::element::f32, {pixel_values.get_shape().at(0), 1, max_h / patch_size * max_w / patch_size}};
float* attention_data = patch_attention_mask.data<float>();
std::fill_n(attention_data, patch_attention_mask.get_size(), 0.0f);
std::fill_n(attention_data, resized_preprocessed.ny / patch_size * resized_preprocessed.nx / patch_size, 1.0f);
if (1 < preprocessed.size()) {
for (size_t row = 1; row < preprocessed.size(); ++row) {
size_t n_slices = preprocessed.at(row).size();
for (size_t col = 0; col < n_slices; ++col) {
const clip_image_f32& elem = preprocessed.at(row).at(col);
std::fill_n(attention_data + ((row - 1) * n_slices + col + 1) * max_h / patch_size * max_w / patch_size, elem.ny / patch_size * elem.nx / patch_size, 1.0f);
}
}
}
encoder.set_tensor("patch_attention_mask", patch_attention_mask);
ov::Tensor position_ids = prepare_vis_position_ids(pixel_values, patch_attention_mask, {resized_source_size}, ctx_clip.patch_size, ctx_clip.image_size / ctx_clip.patch_size);

ImageSize resized_source_size{resized_preprocessed.ny / patch_size, resized_preprocessed.nx / patch_size};
std::vector<ImageSize> tgt_sizes{resized_source_size};
if (1 < preprocessed.size()) {
for (const std::vector<clip_image_f32>& row : preprocessed) {
for (const clip_image_f32& elem : row) {
tgt_sizes.push_back({elem.ny / patch_size, elem.nx / patch_size});
}
}
}
ov::Tensor position_ids = prepare_vis_position_ids(pixel_values, patch_attention_mask, tgt_sizes, patch_size, ctx_clip.image_size / patch_size);
encoder.set_tensor("position_ids", position_ids);
encoder.infer();
const ov::Tensor& output_tensor = encoder.get_output_tensor();
ov::Tensor resized_source{ov::element::f32, output_tensor.get_shape()};
output_tensor.copy_to(resized_source);

if (1 == preprocessed.size()) {
ov::Tensor resized_source{ov::element::f32, output_tensor.get_shape()};
output_tensor.copy_to(resized_source);
return {std::move(resized_source), resized_source_size};
}

ImageSize raw_size{
size_t(preprocessed.at(1).at(0).ny),
size_t(preprocessed.at(1).at(0).nx)
};
ImageSize slices_size{
raw_size.height / patch_size,
raw_size.width / patch_size
};
size_t n_patches = slices_size.height * slices_size.width,
old_hidden_size = resized_source.get_shape().at(2);
size_t old_hidden_size = output_tensor.get_shape().at(2);
const float* out = output_tensor.data<float>();
ov::Tensor resized_source{ov::element::f32, {1, resized_source_size.height * resized_source_size.width, old_hidden_size}};
std::copy_n(out, resized_source.get_size(), resized_source.data<float>());

size_t n_patches = tgt_sizes.at(1).height * tgt_sizes.at(1).width;
ov::Tensor encoded_slices{ov::element::f32, {preprocessed.size() - 1, preprocessed.at(1).size(), n_patches, old_hidden_size}};
for (size_t row = 1; row < preprocessed.size(); ++row) {
for (size_t col = 0; col < preprocessed.at(row).size(); ++col) {
clip_image_f32& elem = preprocessed.at(row).at(col);
ov::Tensor pixel_values = preprocess_for_encoder(
{ov::element::f32, {1, 3, size_t(elem.ny), size_t(elem.nx)}, elem.buf.data()},
patch_size
);
encoder.set_tensor("pixel_values", pixel_values);
ov::Tensor patch_attention_mask{ov::element::f32, {1, 1, slices_size.height * slices_size.width}};
std::fill_n(patch_attention_mask.data<float>(), patch_attention_mask.get_size(), 1.0f);
encoder.set_tensor("patch_attention_mask", patch_attention_mask);
ov::Tensor position_ids = prepare_vis_position_ids(pixel_values, patch_attention_mask, {slices_size}, ctx_clip.patch_size, ctx_clip.image_size / ctx_clip.patch_size);
encoder.set_tensor("position_ids", position_ids);
const ov::Tensor& old = encoder.get_output_tensor();
encoder.set_output_tensor({ov::element::f32, {1, n_patches, old_hidden_size}, encoded_slices.data<float>() + ((row - 1) * preprocessed.at(row).size() + col) * n_patches * old_hidden_size});
encoder.infer();
encoder.set_output_tensor(old);
for (size_t col = 0; col < preprocessed.size() - 1; ++col) {
for (size_t row = 0; row < preprocessed.at(1).size(); ++row) {
std::copy_n(out + (col * preprocessed.at(1).size() + row + 1) * n_patches * old_hidden_size, n_patches * old_hidden_size, encoded_slices.data<float>() + (col * preprocessed.at(1).size() + row) * n_patches * old_hidden_size);
}
}
return {resized_source, resized_source_size, encoded_slices, slices_size};
return {resized_source, resized_source_size, encoded_slices, tgt_sizes.at(1)};
}

ProcessorConfig from_any_map(
@@ -504,7 +521,6 @@ EncodedImage VisionEncoder::encode(const ov::Tensor& image, const ov::AnyMap& co

EncodedImage VisionEncoder::encode_minicpm(const ov::Tensor& image, const ProcessorConfig& config) {
clip_ctx ctx_clip;
ctx_clip.patch_size = m_processor_config.patch_size;
ctx_clip.image_size = m_processor_config.image_size;
std::copy(config.norm_mean.begin(), config.norm_mean.end(), ctx_clip.image_mean);
std::copy(config.norm_std.begin(), config.norm_std.end(), ctx_clip.image_std);
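The rewritten llava_image_embed_make_with_bytes_slice runs the vision encoder once over all slices instead of once per slice: every preprocessed image is copied into a zero-padded {n_images, 3, max_h, max_w} batch, and patch_attention_mask marks which patch slots contain real pixels. A self-contained sketch of that padding-plus-mask idea using plain vectors (Slice and batch_with_mask are illustrative stand-ins for the file's clip types):

```cpp
#include <algorithm>
#include <cstddef>
#include <vector>

struct Slice {                 // stand-in for clip_image_f32
    size_t h = 0, w = 0;
    std::vector<float> buf;    // 3 * h * w float pixels
};

// Copy variable-size slices into one zero-padded {n, 3, max_h, max_w} buffer
// and set the first (h/patch)*(w/patch) mask slots of each slice to 1.0f,
// mirroring the batched_images / patch_attention_mask construction above.
void batch_with_mask(const std::vector<Slice>& slices, size_t patch,
                     std::vector<float>& batch, std::vector<float>& mask) {
    size_t max_h = 0, max_w = 0;
    for (const Slice& s : slices) {
        max_h = std::max(max_h, s.h);
        max_w = std::max(max_w, s.w);
    }
    const size_t stride = 3 * max_h * max_w;
    const size_t mask_stride = (max_h / patch) * (max_w / patch);
    batch.assign(slices.size() * stride, 0.0f);
    mask.assign(slices.size() * mask_stride, 0.0f);
    for (size_t i = 0; i < slices.size(); ++i) {
        // As in the code above, each slice's pixels stay contiguous at the
        // start of its batch element; the mask tells the encoder what is real.
        std::copy(slices[i].buf.begin(), slices[i].buf.end(), batch.begin() + i * stride);
        std::fill_n(mask.begin() + i * mask_stride,
                    (slices[i].h / patch) * (slices[i].w / patch), 1.0f);
    }
}

int main() {
    std::vector<Slice> slices{{448, 448, std::vector<float>(3 * 448 * 448)},
                              {224, 448, std::vector<float>(3 * 224 * 448)}};
    std::vector<float> batch, mask;
    batch_with_mask(slices, 14, batch, mask);  // 14 is a typical ViT patch size
}
```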
