Commit 2e86233

Merge branch 'master' into genai

ilya-lavrenov authored Oct 12, 2024
2 parents 77c2f23 + 67bcef1
Showing 23 changed files with 499 additions and 181 deletions.
17 changes: 14 additions & 3 deletions .github/workflows/causal_lm_cpp.yml
@@ -698,19 +698,30 @@ jobs:
source ./ov/setupvars.sh
cmake -DCMAKE_BUILD_TYPE=Release -S ./ -B ./build/
cmake --build ./build/ --config Release --target visual_language_chat py_generate_pipeline -j
-      - name: Download and convert a model and an image
+      - name: Download and convert MiniCPM-V-2_6 model and an image
run: |
source ./ov/setupvars.sh
python -m pip install ./thirdparty/openvino_tokenizers/[transformers] --pre --extra-index-url https://storage.openvinotoolkit.org/simple/wheels/nightly
python -m pip install --upgrade-strategy eager -r ./samples/requirements.txt --pre --extra-index-url https://storage.openvinotoolkit.org/simple/wheels/nightly
python ./samples/cpp/visual_language_chat/export_MiniCPM-V-2_6.py ./miniCPM-V-2_6/
wget https://github.com/openvinotoolkit/openvino_notebooks/assets/29454499/d5fbbd1a-d484-415c-88cb-9986625b7b11 --output-document cat.jpg
-      - name: Run chat chat sample
+      - name: Run visual_language_chat sample - MiniCPM-V-2_6
run: >
source ./ov/setupvars.sh
&& timeout 120s ./build/samples/cpp/visual_language_chat/visual_language_chat ./miniCPM-V-2_6/ cat.jpg
<<< $'What is on the image?\nWhat is special on the image?'
+      - name: Download and convert LLaVa 1.5 model and an image
+        run: |
+          source ./ov/setupvars.sh
+          python -m pip install ./thirdparty/openvino_tokenizers/[transformers] --pre --extra-index-url https://storage.openvinotoolkit.org/simple/wheels/nightly
+          python -m pip install --upgrade-strategy eager -r ./samples/requirements.txt --pre --extra-index-url https://storage.openvinotoolkit.org/simple/wheels/nightly
+          optimum-cli export openvino --model llava-hf/llava-1.5-7b-hf ./llava_1_5_7b_ov/
+          wget https://llava-vl.github.io/static/images/monalisa.jpg
+      - name: Run visual_language_chat sample - LLaVa 1.5
+        run: >
+          source ./ov/setupvars.sh
+          && timeout 120s ./build/samples/cpp/visual_language_chat/visual_language_chat ./llava_1_5_7b_ov/ monalisa.jpg
+          <<< $'Who drew this painting?\nWhen did the painter live?'
- name: Run python chat sample
run: |
4 changes: 2 additions & 2 deletions .github/workflows/llm_bench-python.yml
@@ -62,11 +62,11 @@ jobs:
python ./llm_bench/python/benchmark.py -m tiny-random-qwen -d cpu -n 1 -f pt
- name: Test tiny-random-baichuan2 on Linux
run: |
-          python ./llm_bench/python/convert.py --model_id katuni4ka/tiny-random-baichuan2 --output_dir ./ov_models/tiny-random-baichuan2 --precision FP16
+          optimum-cli export openvino --model katuni4ka/tiny-random-baichuan2 --trust-remote-code --weight-format fp16 ./ov_models/tiny-random-baichuan2/pytorch/dldt/FP16
python ./llm_bench/python/benchmark.py -m ./ov_models/tiny-random-baichuan2/pytorch/dldt/FP16/ -d cpu -n 1
- name: Test tiny-stable-diffusion on Linux
run: |
-          python ./llm_bench/python/convert.py --model_id segmind/tiny-sd --output_dir ./ov_models/tiny-sd --precision FP16
+          optimum-cli export openvino --model segmind/tiny-sd --trust-remote-code --weight-format fp16 ./ov_models/tiny-sd/pytorch/dldt/FP16/
python ./llm_bench/python/benchmark.py -m ./ov_models/tiny-sd/pytorch/dldt/FP16/ -pf ./llm_bench/python/prompts/stable-diffusion.jsonl -d cpu -n 1
- name: WWB Tests
run: |
3 changes: 1 addition & 2 deletions README.md
@@ -40,10 +40,9 @@ Continuous batching functionality is used within OpenVINO Model Server (OVMS) to

# Install optimum-intel to be able to download, convert and optimize LLMs from Hugging Face
# Optimum is not required to run models, only to convert and compress
-pip install optimum[openvino]
+pip install optimum-intel@git+https://github.com/huggingface/optimum-intel.git

# (Optional) Install (TBD) to be able to download models from Model Scope
-#pip install optimum[openvino]
```

## Performing text generation
26 changes: 14 additions & 12 deletions llm_bench/python/benchmark.py
@@ -202,13 +202,14 @@ def run_text_generation(input_text, num, model, tokenizer, args, iter_data_list,
log.warning(f"[{num}] Prompt[{prompt_index}]'s md5 {result_md5_list} "
f"is different from md5 of the {num - 1} iteration {prev_md5}")
llm_bench_utils.metrics_print.print_generated(num, warm_up=(num == 0), generated=generated_text[0])
-    if num == 1:
-        # if the device is CPU, throw exception
-        if args['devices'].lower().startswith('cpu') is True:
+    if not args.get("use_cb", False):
+        if num == 1:
+            # if the device is CPU, throw exception
+            if args['devices'].lower().startswith('cpu') is True:
+                assert (result_md5_list == prev_md5)
+        else:
+            # throw exception
            assert (result_md5_list == prev_md5)
-    else:
-        # throw exception
-        assert (result_md5_list == prev_md5)
else:
    llm_bench_utils.metrics_print.print_generated(num, warm_up=(num == 0), generated=generated_text[0])
if bench_hook is not None:
@@ -412,13 +412,14 @@ def run_text_generation_genai_with_stream(input_text, num, model, tokenizer, arg
log.warning(f"[{num}] Prompt[{prompt_index}]'s md5 {result_md5_list} "
f"is different from md5 of the {num - 1} iteration {prev_md5}")
llm_bench_utils.metrics_print.print_generated(num, warm_up=(num == 0), generated=generated_text[0])
-    if num == 1:
-        # if the device is CPU, throw exception
-        if args['devices'].lower().startswith('cpu') is True:
+    if not args.get("use_cb", False):
+        if num == 1:
+            # if the device is CPU, throw exception
+            if args['devices'].lower().startswith('cpu') is True:
+                assert (result_md5_list == prev_md5)
+        else:
+            # throw exception
            assert (result_md5_list == prev_md5)
-    else:
-        # throw exception
-        assert (result_md5_list == prev_md5)
else:
    llm_bench_utils.metrics_print.print_generated(num, warm_up=(num == 0), generated=generated_text[0])
streamer.reset()
2 changes: 1 addition & 1 deletion llm_bench/python/requirements.txt
@@ -10,7 +10,7 @@ torch
transformers>=4.40.0
diffusers>=0.22.0
#optimum is in dependency list of optimum-intel
-git+https://github.com/huggingface/optimum-intel.git@f34bd61df89f57f61c282c02297980299981ee78#egg=optimum-intel
+git+https://github.com/huggingface/optimum-intel.git@main#egg=optimum-intel
git+https://github.com/openvinotoolkit/nncf.git@develop#egg=nncf
packaging
psutil
2 changes: 1 addition & 1 deletion llm_bench/python/who_what_benchmark/requirements.txt
@@ -2,7 +2,7 @@ transformers>=4.35.2
sentence-transformers>=2.2.2
openvino>=2024.3.0
openvino-telemetry
-optimum-intel>=1.14
+optimum-intel @ git+https://github.com/huggingface/optimum-intel.git
openvino-tokenizers
pandas>=2.0.3
numpy>=1.23.5
3 changes: 2 additions & 1 deletion samples/requirements.txt
@@ -1,5 +1,6 @@
--extra-index-url https://download.pytorch.org/whl/cpu
-optimum[openvino]==1.22.0
+optimum-intel @ git+https://github.com/huggingface/optimum-intel.git
numpy<2.0.0; sys_platform == 'darwin'
einops==0.8.0 # For Qwen
transformers_stream_generator==0.0.5 # For Qwen
diffusers==0.30.3
8 changes: 8 additions & 0 deletions src/cpp/include/openvino/genai/processor_config.hpp
@@ -34,6 +34,14 @@ class OPENVINO_GENAI_EXPORTS ProcessorConfig {
/// Applied after norm_mean.
/// llava calls it image_std.
std::array<float, 3> norm_std{1.0f, 1.0f, 1.0f};

+    // llava specific config params
+    std::array<float, 3> image_mean{0.0f, 0.0f, 0.0f};
+    std::array<float, 3> image_std{1.0f, 1.0f, 1.0f};
+    size_t crop_size_height = 336;
+    size_t crop_size_width = 336;
+    size_t size_shortest_edge = 336;

/// @brief Default constructor
ProcessorConfig() = default;
/// @brief Construct ProcessorConfig from values in json_path.
17 changes: 15 additions & 2 deletions src/cpp/include/openvino/genai/vision_encoder.hpp
@@ -5,6 +5,7 @@

#include "openvino/genai/processor_config.hpp"
#include <openvino/openvino.hpp>
#include "vlm_model_type.hpp"

namespace ov::genai {
/// @brief A pair describing image size.
@@ -41,8 +42,10 @@ struct EncodedImage {
/// ov::InferRequest and configured by ProcessorConfig.
class OPENVINO_GENAI_EXPORTS VisionEncoder {
public:
+    /// @brief A enum denoting model type.
+    VLMModelType model_type;
    /// @brief A model for image encoding.
-    ov::InferRequest m_encoder;
+    ov::InferRequest m_vision_encoder;
/// @brief A config to follow.
ProcessorConfig m_processor_config;

@@ -52,7 +55,7 @@ class OPENVINO_GENAI_EXPORTS VisionEncoder {
explicit VisionEncoder(
const ov::InferRequest& encoder,
const ProcessorConfig& processor_config=ProcessorConfig{}
-    ) : m_encoder{encoder}, m_processor_config{processor_config} {}
+    ) : m_vision_encoder{encoder}, m_processor_config{processor_config} {}

/// @brief Construct the encoder from model_dir.
/// @param model_dir A folder containing openvino_embedding.xml and
@@ -63,6 +66,7 @@ class OPENVINO_GENAI_EXPORTS VisionEncoder {
/// @param core ov::Core to be used to compile the model.
explicit VisionEncoder(
const std::filesystem::path& model_dir,
+        const VLMModelType model_type,
const std::string& device="CPU",
const ov::AnyMap device_config={},
ov::Core core=ov::Core{}
@@ -117,5 +121,14 @@ class OPENVINO_GENAI_EXPORTS VisionEncoder {
image, AnyMap{std::forward<Properties>(properties)...}
);
}

+private:
+    EncodedImage encode_minicpm(
+        const ov::Tensor& image, const ProcessorConfig& config
+    );
+
+    EncodedImage encode_llava(
+        const ov::Tensor& image, const ProcessorConfig& config
+    );
};
}
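A minimal usage sketch for the reworked encoder. This is an assumption-laden illustration, not code from the commit: the single-argument encode() overload is inferred from the properties-forwarding template above, and the model directory and input tensor are caller-supplied placeholders.

#include <filesystem>
#include <openvino/openvino.hpp>
#include "openvino/genai/vision_encoder.hpp"

// Hypothetical helper: encode one image with a llava-style model.
ov::genai::EncodedImage encode_one(const std::filesystem::path& model_dir,
                                   const ov::Tensor& image) {
    // The constructor now takes the model type; internally encode()
    // is expected to route to encode_llava() or encode_minicpm().
    ov::genai::VisionEncoder encoder{model_dir, ov::genai::VLMModelType::LLAVA, "CPU"};
    return encoder.encode(image);
}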
3 changes: 3 additions & 0 deletions src/cpp/include/openvino/genai/vlm_config.hpp
@@ -6,12 +6,15 @@
#include "openvino/genai/visibility.hpp"
#include <openvino/runtime/properties.hpp>
#include <filesystem>
#include "vlm_model_type.hpp"

namespace ov::genai {
/// @brief A Configuration class passed to VLMPipeline and used to
/// change VLMPipeline's behavior. Corresponds to config.json.
class OPENVINO_GENAI_EXPORTS VLMConfig {
public:
+    /// @brief A enum denoting model type.
+    VLMModelType model_type;
/// @brief A size of a single embedding returned by a resampler.
/// Used to initialize positional embeddings for resampler input.
size_t hidden_size = 2304;
31 changes: 31 additions & 0 deletions src/cpp/include/openvino/genai/vlm_model_type.hpp
@@ -0,0 +1,31 @@
// Copyright (C) 2023-2024 Intel Corporation
// SPDX-License-Identifier: Apache-2.0

#pragma once

#include <string>
#include <unordered_map>

#include "openvino/genai/visibility.hpp"
#include <openvino/core/except.hpp>

namespace ov::genai {

enum class OPENVINO_GENAI_EXPORTS VLMModelType {
MINICPM,
LLAVA,
};

inline VLMModelType to_vlm_model_type(const std::string& value) {
static const std::unordered_map<std::string, VLMModelType> model_types_map = {
{"minicpmv", VLMModelType::MINICPM},
{"llava", VLMModelType::LLAVA}
};

auto it = model_types_map.find(value);
if (it != model_types_map.end()) {
return it->second;
}
OPENVINO_THROW("Unsupported '", value, "' VLM model type");
}
}
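A short sketch of how the new mapping can be driven from a model folder. Reading "model_type" out of config.json with nlohmann::json is an assumption here (mirroring how processor_config.cpp parses its own JSON), and the helper name is hypothetical.

#include <filesystem>
#include <fstream>
#include <nlohmann/json.hpp>
#include "openvino/genai/vlm_model_type.hpp"

// "minicpmv" maps to MINICPM, "llava" to LLAVA; any other string makes
// to_vlm_model_type() raise via OPENVINO_THROW.
ov::genai::VLMModelType detect_model_type(const std::filesystem::path& model_dir) {
    std::ifstream stream{model_dir / "config.json"};
    nlohmann::json parsed = nlohmann::json::parse(stream);
    return ov::genai::to_vlm_model_type(parsed.at("model_type").get<std::string>());
}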
3 changes: 3 additions & 0 deletions src/cpp/include/openvino/genai/vlm_pipeline.hpp
@@ -139,6 +139,9 @@ class OPENVINO_GENAI_EXPORTS VLMPipeline {
private:
class VLMPipelineImpl;
std::unique_ptr<VLMPipelineImpl> m_pimpl;

+    ov::Tensor get_inputs_embeds_minicpm(const std::string& prompt, const std::vector<ov::Tensor>& images);
+    ov::Tensor get_inputs_embeds_llava(const std::string& prompt, const std::vector<ov::Tensor>& images);
};

/*
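The two private helpers suggest a per-model-type embedding path. A hypothetical, free-standing rendering of that dispatch follows; the real call site is in vlm_pipeline.cpp, which this diff does not show, and the builder parameters stand in for the private members.

#include <string>
#include <vector>
#include <openvino/openvino.hpp>
#include "openvino/genai/vlm_config.hpp"

// Sketch: MiniCPM and LLaVa fuse prompt and image embeddings differently,
// so each model type gets its own builder and the pipeline picks one.
template <typename MiniCPMFn, typename LlavaFn>
ov::Tensor build_inputs_embeds(const ov::genai::VLMConfig& config,
                               const std::string& prompt,
                               const std::vector<ov::Tensor>& images,
                               MiniCPMFn build_minicpm, LlavaFn build_llava) {
    return config.model_type == ov::genai::VLMModelType::MINICPM
        ? build_minicpm(prompt, images)
        : build_llava(prompt, images);
}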
2 changes: 1 addition & 1 deletion src/cpp/src/continuous_batching_impl.cpp
@@ -258,7 +258,7 @@ ContinuousBatchingPipeline::ContinuousBatchingImpl::generate(const std::vector<o
bool continue_generation = true;
while (has_non_finished_requests() && continue_generation) {
step();
-        if (streamer_ptr) {
+        if (streamer_ptr && generations.at(0)->can_read()) {
std::unordered_map<uint64_t, GenerationOutput> token = generations.at(0).get()->back();
OPENVINO_ASSERT(1 == token.size());
OPENVINO_ASSERT(1 == token.begin()->second.generated_ids.size());
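The added can_read() guard makes the loop drain the generation handle only when output is actually queued, so back() is never called on an empty handle. A condensed sketch of the pattern, with the handle and streamer callback passed in as assumptions (the real logic lives in the generate() body above):

#include <cstdint>
#include <functional>
#include "openvino/genai/continuous_batching_pipeline.hpp"

// Sketch: poll for available output before reading, then forward the
// newly generated ids to the streamer callback.
void stream_step(ov::genai::GenerationHandle& handle,
                 const std::function<void(int64_t)>& put) {
    if (handle->can_read()) {            // the guard added by this commit
        auto outputs = handle->back();   // safe: data is available
        for (int64_t id : outputs.begin()->second.generated_ids)
            put(id);
    }
}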
1 change: 0 additions & 1 deletion src/cpp/src/greedy_decoding.cpp
@@ -73,7 +73,6 @@ EncodedResults greedy_decoding(
bool all_are_eos = std::all_of(eos_met.begin(), eos_met.end(), [](int elem) { return elem == 1; });
if (!generation_config.ignore_eos && all_are_eos)
return results;
-

for (size_t i = 0; i < max_new_tokens - 1; ++i) {
if (position_ids.has_value())
5 changes: 4 additions & 1 deletion src/cpp/src/llm_pipeline.cpp
@@ -82,12 +82,15 @@ class StatefulLLMPipeline final : public LLMPipelineImplBase {
core.set_property(core_plugin_config);
auto model = core.read_model(model_path / "openvino_model.xml");
m_adapter_controller = AdapterController(model, m_generation_config.adapters, "base_model.model.model.", device); // TODO: Make the prefix name configurable
+        utils::slice_matmul_statefull_model(model);
m_model_runner = core.compile_model(model, device, compile_plugin_config).create_infer_request();
m_adapter_controller->apply(m_model_runner, m_generation_config.adapters);
} else {
auto [core_plugin_config, compile_plugin_config] = ov::genai::utils::split_core_complile_config(plugin_config);
core.set_property(core_plugin_config);
-        m_model_runner = core.compile_model(model_path / "openvino_model.xml", device, compile_plugin_config).create_infer_request();
+        auto model = core.read_model(model_path / "openvino_model.xml");
+        utils::slice_matmul_statefull_model(model);
+        m_model_runner = core.compile_model(model, device, compile_plugin_config).create_infer_request();
}

// If eos_token_id was not provided, take value
18 changes: 17 additions & 1 deletion src/cpp/src/processor_config.cpp
@@ -10,7 +10,7 @@ ov::genai::ProcessorConfig::ProcessorConfig(const std::filesystem::path& json_pa
OPENVINO_ASSERT(stream.is_open(), "Failed to open '" + json_path.string() + "' with processor config");
nlohmann::json parsed = nlohmann::json::parse(stream);
using ov::genai::utils::read_json_param;
read_json_param(parsed, "patch_size", patch_size);
read_json_param(parsed, "patch_size", patch_size); // For llava - stored in config.json vision_config
read_json_param(parsed, "scale_resolution", scale_resolution);
read_json_param(parsed, "max_slice_nums", max_slice_nums);
if (parsed.contains("norm_mean")) {
@@ -19,4 +19,20 @@ ov::genai::ProcessorConfig::ProcessorConfig(const std::filesystem::path& json_pa
if (parsed.contains("norm_std")) {
norm_std = parsed.at("norm_std").get<std::array<float, 3>>();
}

+    // Setting llava config params
+    if (parsed.contains("image_mean")) {
+        image_mean = parsed.at("image_mean").get<std::array<float, 3>>();
+    }
+    if (parsed.contains("image_std")) {
+        image_std = parsed.at("image_std").get<std::array<float, 3>>();
+    }
+
+    if (parsed.contains("crop_size")) {
+        crop_size_height = parsed.at("crop_size").at("height");
+        crop_size_width = parsed.at("crop_size").at("width");
+    }
+    if (parsed.contains("size")) {
+        size_shortest_edge = parsed.at("size").at("shortest_edge");
+    }
}
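A small sketch of exercising the new fields, assuming the ProcessorConfig(json_path) constructor declared in processor_config.hpp and a llava-style preprocessor_config.json; the path is a placeholder.

#include <iostream>
#include "openvino/genai/processor_config.hpp"

int main() {
    // crop_size_* and size_shortest_edge keep their 336 defaults when the
    // JSON omits "crop_size"/"size", matching the header's initializers.
    ov::genai::ProcessorConfig config{"./llava_1_5_7b_ov/preprocessor_config.json"};
    std::cout << "crop: " << config.crop_size_height << "x" << config.crop_size_width
              << ", shortest edge: " << config.size_shortest_edge << "\n";
    return 0;
}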
34 changes: 34 additions & 0 deletions src/cpp/src/utils.cpp
@@ -5,6 +5,14 @@

#include <fstream>

#include "openvino/op/add.hpp"
#include "openvino/op/divide.hpp"
#include "openvino/op/multiply.hpp"
#include "openvino/op/matmul.hpp"
#include "openvino/op/slice.hpp"
#include "openvino/op/tanh.hpp"
#include "openvino/op/transpose.hpp"

namespace ov {
namespace genai {
namespace utils {
@@ -225,6 +233,32 @@ ov::genai::TokenizedInputs subtract_chat_tokenized_inputs(const ov::genai::Token

return {new_input_ids, new_attention_mask};
}

+void slice_matmul_statefull_model(std::shared_ptr<ov::Model> model) {
+    ov::Node* matmul = nullptr;
+    auto last_node = model->output(0).get_node()->input_value(0).get_node();
+    if (matmul = dynamic_cast<ov::op::v0::MatMul*>(last_node)) {
+    } else if(auto add = dynamic_cast<ov::op::v1::Add*>(last_node)) {
+        matmul = dynamic_cast<ov::op::v0::MatMul*>(add->input_value(0).get_node());
+    } else if (auto transpose = dynamic_cast<ov::op::v1::Transpose*>(last_node)) {
+        matmul = dynamic_cast<ov::op::v0::MatMul*>(transpose->input_value(0).get_node());
+    } else if (auto multiply = dynamic_cast<ov::op::v1::Multiply*>(last_node)) {
+        if (auto tanh = dynamic_cast<ov::op::v0::Tanh*>(multiply->input_value(0).get_node())) {
+            if (auto divide = dynamic_cast<ov::op::v1::Divide*>(tanh->input_value(0).get_node())) {
+                matmul = dynamic_cast<ov::op::v0::MatMul*>(divide->input_value(0).get_node());
+            }
+        }
+    }
+
+    if (matmul && matmul->input(0).get_partial_shape().rank().get_length() == 3) {
+        auto start = std::make_shared<ov::op::v0::Constant>(ov::element::i64, ov::Shape{1}, std::vector<int64_t>{-1});
+        auto stop = std::make_shared<ov::op::v0::Constant>(ov::element::i64, ov::Shape{1}, std::vector<int64_t>{-2});
+        auto step = std::make_shared<ov::op::v0::Constant>(ov::element::i64, ov::Shape{1}, std::vector<int64_t>{-1});
+        auto axis = std::make_shared<ov::op::v0::Constant>(ov::element::i64, ov::Shape{1}, std::vector<int64_t>{1});
+        auto slice = std::make_shared<ov::op::v8::Slice>(matmul->input_value(0), start, stop, step, axis);
+        matmul->input(0).replace_source_output(slice);
+    }
+}
} // namespace utils
} // namespace genai
} // namespace ov
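In effect, when the network ends in a rank-3 MatMul (the LM head, possibly behind an Add bias, a Transpose, or a Multiply(Tanh(Divide(...))) tail), the new pass inserts a Slice so only the last token's hidden state reaches that MatMul, shrinking the logits computation for stateful decoding. A usage sketch mirroring the llm_pipeline.cpp hunk above; the wrapper function and device argument are illustrative.

#include <filesystem>
#include <openvino/openvino.hpp>
#include "utils.hpp"  // declares ov::genai::utils::slice_matmul_statefull_model

// Apply the pass between read_model() and compile_model(), as llm_pipeline.cpp
// now does in both of its branches.
ov::InferRequest build_runner(const std::filesystem::path& model_path,
                              const std::string& device) {
    ov::Core core;
    auto model = core.read_model(model_path / "openvino_model.xml");
    ov::genai::utils::slice_matmul_statefull_model(model);
    return core.compile_model(model, device).create_infer_request();
}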
2 changes: 2 additions & 0 deletions src/cpp/src/utils.hpp
@@ -87,6 +87,8 @@ ProcessorConfig from_any_map(
std::pair<ov::AnyMap, ov::AnyMap> split_core_complile_config(const ov::AnyMap& plugin_config);

ov::genai::TokenizedInputs subtract_chat_tokenized_inputs(const ov::genai::TokenizedInputs& minuend, const ov::genai::TokenizedInputs& subtrahend);

+void slice_matmul_statefull_model(std::shared_ptr<ov::Model> model);
} // namespace utils
} // namespace genai
} // namespace ov