+
Llama 2 |
diff --git a/tests/python_tests/README.md b/tests/python_tests/README.md
new file mode 100644
index 0000000000..e5381708de
--- /dev/null
+++ b/tests/python_tests/README.md
@@ -0,0 +1,47 @@
+# OpenVINO™ GenAI Tests
+
+These tests aim to validate support for the vanilla and continuous batching GenAI APIs.
+
+## Set up environment
+
+To run the tests, first build or install the OpenVINO GenAI library by following the instructions in the [GenAI Library README](../../src/README.md).
+
+Then install the test requirements:
+```sh
+pip install -r tests/python_tests/requirements.txt
+```
+
+## Run Tests
+
+```sh
+python -m pytest tests/python_tests/ -m precommit
+```
+
+During the test run, downloaded HuggingFace (HF) models are saved into the current directory. If you wish to place them somewhere else, you can specify the `GENAI_MODELS_PATH_PREFIX` environment variable, e.g.
+```sh
+GENAI_MODELS_PATH_PREFIX=$HOME/test_models python -m pytest tests/python_tests/ -m precommit
+```
+
+If you have built the GenAI library yourself instead of using the wheel, please set `PYTHONPATH` so that the tests can find the library, e.g.
+```sh
+PYTHONPATH=$PYTHONPATH:.../openvino.genai/build-Release/ python -m pytest tests/python_tests/ -m precommit
+```
+
+## Customise test runs
+
+Tests use `precommit` and `nightly` sets of models. `precommit` contains lightweight models which can be inferred quickly, while `nightly` models are heavier and require more time for inference. If you wish to run specific tests only for the nightly models, you can use the `-k` option, for example to run only the multibatch and chat tests:
+```sh
+python -m pytest tests/python_tests/ -m nightly -k "test_multibatch and test_chat"
+```
+
+If you wish to run all tests except beam search, do the following:
+```sh
+python -m pytest tests/python_tests/ -m precommit -k "not test_beam_search"
+```
+
+The `--model_ids` argument can be used to run tests selectively for specific models only. HF model ids should be separated by spaces, e.g.:
+```sh
+python -m pytest tests/python_tests/ -m nightly -k "test_multibatch" --model_ids "TinyLlama/TinyLlama-1.1B-Chat-v1.0 Qwen/Qwen2-0.5B-Instruct"
+```
+
+The list of currently supported `nightly` and `precommit` models can be found in `tests/python_tests/ov_genai_test_utils.py:get_models_list`.
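
The resolved model set depends on the selected marker and the optional `--model_ids` filter. A minimal sketch of inspecting that resolution outside of pytest (assuming it is run from `tests/python_tests/`; `get_models_list()` relies on attributes that `conftest.py` normally sets during `pytest_configure`):

```python
# Hypothetical standalone check of which models the suite would use.
import pytest

pytest.run_marker = "precommit"      # set to "nightly" for the heavier model set
pytest.selected_model_ids = None     # or e.g. "TinyLlama/TinyLlama-1.1B-Chat-v1.0"

from ov_genai_test_utils import get_models_list

for model_id, model_path in get_models_list():
    print(model_id, "->", model_path)
```
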
diff --git a/tests/python_tests/conftest.py b/tests/python_tests/conftest.py
index 66212468af..f98f47ecf3 100644
--- a/tests/python_tests/conftest.py
+++ b/tests/python_tests/conftest.py
@@ -14,6 +14,11 @@ def pytest_make_parametrize_id(config, val, argname):
return f'{argname}={val}'
return None
-def pytest_configure(config):
+def pytest_addoption(parser):
+ parser.addoption("--model_ids", help="Select models to run")
+
+def pytest_configure(config: pytest.Config):
marker = 'precommit' if config.getoption('-m') == 'precommit' else 'nightly'
pytest.run_marker = marker
+ pytest.selected_model_ids = config.getoption('--model_ids', default=None)
+
diff --git a/tests/python_tests/ov_genai_test_utils.py b/tests/python_tests/ov_genai_test_utils.py
index 7bceb29458..7560486d42 100644
--- a/tests/python_tests/ov_genai_test_utils.py
+++ b/tests/python_tests/ov_genai_test_utils.py
@@ -49,7 +49,10 @@ def get_models_list():
model_ids = precommit_models
else:
model_ids = nightly_models
-
+
+ if pytest.selected_model_ids:
+ model_ids = [model_id for model_id in model_ids if model_id in pytest.selected_model_ids.split(' ')]
+ # pytest.set_trace()
prefix = pathlib.Path(os.getenv('GENAI_MODELS_PATH_PREFIX', ''))
return [(model_id, prefix / model_id.split('/')[1]) for model_id in model_ids]
diff --git a/tests/python_tests/test_chat_generate_api.py b/tests/python_tests/test_chat_generate_api.py
index 94de8f6cc2..5a73d481d3 100644
--- a/tests/python_tests/test_chat_generate_api.py
+++ b/tests/python_tests/test_chat_generate_api.py
@@ -33,6 +33,7 @@
@pytest.mark.parametrize("generation_config", configs)
@pytest.mark.parametrize("model_descr", get_chat_models_list())
@pytest.mark.precommit
+@pytest.mark.nightly
def test_chat_compare_with_HF(model_descr, generation_config: Dict):
device = 'CPU'
chat_history_hf = []
@@ -69,6 +70,7 @@ def test_chat_compare_with_HF(model_descr, generation_config: Dict):
@pytest.mark.parametrize("generation_config", configs)
@pytest.mark.parametrize("model_descr", get_chat_models_list())
@pytest.mark.precommit
+@pytest.mark.nightly
def test_chat_compare_text_history_with_HF(model_descr, generation_config: Dict):
# compares with HF when history in ov_genai is save as a text
device = 'CPU'
@@ -104,6 +106,7 @@ def test_chat_compare_text_history_with_HF(model_descr, generation_config: Dict)
@pytest.mark.parametrize("generation_config", configs)
@pytest.mark.parametrize("model_descr", get_chat_models_list())
@pytest.mark.precommit
+@pytest.mark.nightly
def test_chat_compare_statefull_vs_text_history(model_descr, generation_config: Dict):
# Check that when history is stored in KV cache results are the same as when history stored in a text.
device ='CPU'
@@ -144,6 +147,7 @@ def test_chat_compare_statefull_vs_text_history(model_descr, generation_config:
{'role': 'user', 'content': 'What was my first question?'},
]
@pytest.mark.precommit
+@pytest.mark.nightly
@pytest.mark.parametrize('chat_config', get_chat_templates())
def test_apply_chat_template(model_tmp_path, chat_config: Tuple[str, Dict]):
tokenizer_config = chat_config[1]
diff --git a/tests/python_tests/test_generate_api.py b/tests/python_tests/test_generate_api.py
index 40bc121293..b4e275eef2 100644
--- a/tests/python_tests/test_generate_api.py
+++ b/tests/python_tests/test_generate_api.py
@@ -151,6 +151,7 @@ def hf_ov_genai_tensors_comparison(
@pytest.mark.parametrize("generation_config,prompt", test_cases)
@pytest.mark.parametrize("model_descr", get_models_list())
@pytest.mark.precommit
+@pytest.mark.nightly
def test_decoding(model_descr, generation_config, prompt):
run_hf_ov_genai_comparison(read_model(model_descr), generation_config, prompt)
@@ -168,6 +169,7 @@ def test_decoding(model_descr, generation_config, prompt):
condition=sys.platform in ["linux", "win32"]
)
@pytest.mark.precommit
+@pytest.mark.nightly
def test_ov_tensors(model_descr, inputs):
hf_ov_genai_tensors_comparison(read_model(model_descr), dict(max_new_tokens=20), *inputs)
@@ -182,6 +184,7 @@ def test_ov_tensors(model_descr, inputs):
@pytest.mark.parametrize("model_descr", get_models_list())
@pytest.mark.parametrize("prompt", prompts)
@pytest.mark.precommit
+@pytest.mark.nightly
@pytest.mark.xfail(
raises=TypeError,
reason="pybind was unable to find ov::Tensor from openvino yet",
@@ -217,6 +220,7 @@ def test_genai_tokenizer_encode(model_descr, prompt):
@pytest.mark.parametrize("model_descr", get_models_list())
@pytest.mark.parametrize("encoded_prompt", encoded_prompts)
@pytest.mark.precommit
+@pytest.mark.nightly
@pytest.mark.xfail(
raises=TypeError,
reason="pybind was unable to find ov::Tensor from openvino yet",
@@ -252,6 +256,7 @@ def test_genai_tokenizer_decode(model_descr, encoded_prompt):
@pytest.mark.parametrize("prompts", batched_prompts)
@pytest.mark.parametrize("model_descr", get_models_list())
@pytest.mark.precommit
+@pytest.mark.nightly
def test_multibatch(model_descr, generation_config, prompts):
run_hf_ov_genai_comparison_batched(read_model(model_descr), generation_config, prompts)
@@ -264,6 +269,7 @@ def test_multibatch(model_descr, generation_config, prompts):
@pytest.mark.parametrize("prompt", prompts)
@pytest.mark.parametrize("model_descr", get_models_list())
@pytest.mark.precommit
+@pytest.mark.nightly
def test_beam_search_decoding(model_descr, num_beam_groups, group_size,
max_new_tokens, diversity_penalty, prompt):
generation_config = dict(
@@ -281,6 +287,7 @@ def test_beam_search_decoding(model_descr, num_beam_groups, group_size,
@pytest.mark.parametrize("max_new_tokens", [10, 80])
@pytest.mark.parametrize("model_descr", get_models_list())
@pytest.mark.precommit
+@pytest.mark.nightly
def test_stop_criteria(model_descr, stop_criteria, prompt, max_new_tokens):
# todo: with EARLY stop_criteria looks like HF return unvalid out with sentence
# while genai ends sentence with
@@ -323,6 +330,7 @@ def user_defined_callback(subword):
@pytest.mark.parametrize("callback", [print, user_defined_callback, lambda subword: print(subword)])
@pytest.mark.precommit
+@pytest.mark.nightly
def test_callback_one_string(callback):
pipe = read_model(get_models_list()[0])[4]
generation_config = pipe.get_generation_config()
@@ -332,6 +340,7 @@ def test_callback_one_string(callback):
@pytest.mark.parametrize("callback", [print, user_defined_callback, lambda subword: print(subword)])
@pytest.mark.precommit
+@pytest.mark.nightly
def test_callback_batch_fail(callback):
pipe = read_model(get_models_list()[0])[4]
with pytest.raises(RuntimeError):
@@ -340,12 +349,14 @@ def test_callback_batch_fail(callback):
@pytest.mark.parametrize("callback", [print, user_defined_callback, lambda subword: print(subword)])
@pytest.mark.precommit
+@pytest.mark.nightly
def test_callback_kwargs_one_string(callback):
pipe = read_model(get_models_list()[0])[4]
pipe.generate('table is made of', max_new_tokens=10, streamer=callback)
@pytest.mark.parametrize("callback", [print, user_defined_callback, lambda subword: print(subword)])
@pytest.mark.precommit
+@pytest.mark.nightly
@pytest.mark.parametrize("model_descr", get_models_list())
def test_callback_decoding_metallama(model_descr, callback):
# On metallam this prompt generates output which can shorten after adding new tokens.
@@ -359,6 +370,7 @@ def test_callback_decoding_metallama(model_descr, callback):
@pytest.mark.parametrize("callback", [print, user_defined_callback, lambda subword: print(subword)])
@pytest.mark.precommit
+@pytest.mark.nightly
def test_callback_kwargs_batch_fail(callback):
pipe = read_model(get_models_list()[0])[4]
with pytest.raises(RuntimeError):
@@ -380,6 +392,7 @@ def end(self):
@pytest.mark.precommit
+@pytest.mark.nightly
def test_streamer_one_string():
pipe = read_model(get_models_list()[0])[4]
generation_config = pipe.get_generation_config()
@@ -389,6 +402,7 @@ def test_streamer_one_string():
@pytest.mark.precommit
+@pytest.mark.nightly
def test_streamer_batch_fail():
pipe = read_model(get_models_list()[0])[4]
printer = Printer(pipe.get_tokenizer())
@@ -397,6 +411,7 @@ def test_streamer_batch_fail():
@pytest.mark.precommit
+@pytest.mark.nightly
def test_streamer_kwargs_one_string():
pipe = read_model(get_models_list()[0])[4]
printer = Printer(pipe.get_tokenizer())
@@ -404,6 +419,7 @@ def test_streamer_kwargs_one_string():
@pytest.mark.precommit
+@pytest.mark.nightly
def test_streamer_kwargs_batch_fail():
pipe = read_model(get_models_list()[0])[4]
printer = Printer(pipe.get_tokenizer())
@@ -412,6 +428,7 @@ def test_streamer_kwargs_batch_fail():
@pytest.mark.precommit
+@pytest.mark.nightly
@pytest.mark.parametrize("callback", [print, user_defined_callback, lambda subword: print(subword)])
def test_operator_with_callback_one_string(callback):
pipe = read_model(get_models_list()[0])[4]
@@ -421,6 +438,7 @@ def test_operator_with_callback_one_string(callback):
@pytest.mark.precommit
+@pytest.mark.nightly
@pytest.mark.parametrize("callback", [print, user_defined_callback, lambda subword: print(subword)])
def test_operator_with_callback_batch_fail(callback):
pipe = read_model(get_models_list()[0])[4]
@@ -429,6 +447,7 @@ def test_operator_with_callback_batch_fail(callback):
@pytest.mark.precommit
+@pytest.mark.nightly
def test_operator_with_streamer_kwargs_one_string():
pipe = read_model(get_models_list()[0])[4]
printer = Printer(pipe.get_tokenizer())
@@ -436,6 +455,7 @@ def test_operator_with_streamer_kwargs_one_string():
@pytest.mark.precommit
+@pytest.mark.nightly
def test_operator_with_streamer_kwargs_batch_fail():
pipe = read_model(get_models_list()[0])[4]
printer = Printer(pipe.get_tokenizer())
@@ -444,6 +464,7 @@ def test_operator_with_streamer_kwargs_batch_fail():
@pytest.mark.precommit
+@pytest.mark.nightly
def test_load_special_tokens_ids_1(model_tmp_path):
# test when there is an available config.json
config_json = {
@@ -458,6 +479,7 @@ def test_load_special_tokens_ids_1(model_tmp_path):
@pytest.mark.precommit
+@pytest.mark.nightly
def test_load_special_tokens_str_2(model_tmp_path):
# test with special_tokens_map
special_tokens_map_json = {
@@ -472,6 +494,7 @@ def test_load_special_tokens_str_2(model_tmp_path):
@pytest.mark.precommit
+@pytest.mark.nightly
def test_load_special_tokens_3_(model_tmp_path):
# special_tokens_map is not available
# but tokenize_config.json exists
@@ -498,6 +521,7 @@ def test_load_special_tokens_3_(model_tmp_path):
@pytest.mark.precommit
+@pytest.mark.nightly
def test_load_special_tokens_3(model_tmp_path):
# both config.json is availabel and tokenizer_config.json available
# check that it does not read int values from tokenizer_config.json if they are in config.json
@@ -532,6 +556,7 @@ def test_load_special_tokens_3(model_tmp_path):
@pytest.mark.precommit
+@pytest.mark.nightly
@pytest.mark.xfail(
raises=AssertionError,
reason="CVS-143410 ov tokenizer should be aligned with hf",
@@ -575,6 +600,7 @@ def test_load_special_tokens_4(model_tmp_path):
]
@pytest.mark.parametrize("generation_config", invalid_configs)
@pytest.mark.precommit
+@pytest.mark.nightly
def test_invalid_configs(model_tmp_path, generation_config):
model_id, temp_path = model_tmp_path
config_json = {}
@@ -584,6 +610,7 @@ def test_invalid_configs(model_tmp_path, generation_config):
@pytest.mark.precommit
+@pytest.mark.nightly
def test_valid_configs(model_tmp_path):
model_id, temp_path = model_tmp_path
pipe = load_pipe([({"eos_token_id": 37}, "config.json")], temp_path)
@@ -602,6 +629,7 @@ def test_valid_configs(model_tmp_path):
dict(top_k=0, do_sample=True, eos_token_id=42, max_new_tokens=20), # invalid top_k
]
@pytest.mark.precommit
+@pytest.mark.nightly
@pytest.mark.parametrize("generation_config", invalid_py_configs)
def test_python_generation_config_validation(model_tmp_path, generation_config):
model_id, temp_path = model_tmp_path
@@ -615,6 +643,7 @@ def test_python_generation_config_validation(model_tmp_path, generation_config):
@pytest.mark.precommit
+@pytest.mark.nightly
def test_unicode_pybind_decoding_1():
# On this model this prompt generates unfinished utf string.
# Test that pybind will not fail.
@@ -626,6 +655,7 @@ def test_unicode_pybind_decoding_1():
@pytest.mark.precommit
+@pytest.mark.nightly
def test_unicode_pybind_decoding_2():
# On this model this prompt generates unfinished utf string.
# Test that pybind will not fail.
@@ -636,6 +666,7 @@ def test_unicode_pybind_decoding_2():
@pytest.mark.precommit
+@pytest.mark.nightly
def test_unicode_pybind_decoding_3():
# On this model this prompt generates unfinished utf-8 string
# and streams it. Test that pybind will not fail while we pass string to python.
@@ -648,6 +679,7 @@ def test_unicode_pybind_decoding_3():
@pytest.mark.skip(reason="probably both models ov + hf doesn't fit to memory")
@pytest.mark.precommit
+@pytest.mark.nightly
@pytest.mark.skipif(sys.platform.startswith("win"), reason="not enough space for this model on Win")
def test_left_pad():
# test left pad tokenizer post processing implementation
From 944321854d77c14cf02a0ff1d32b89ba4e7a1f62 Mon Sep 17 00:00:00 2001
From: Damian Kalinowski
Date: Wed, 24 Jul 2024 08:37:34 +0200
Subject: [PATCH 11/19] Add infer request queue for tokenizers and allow for
optional plugin_config in tokenizer (#651)
This improves performance of CB lib when tested within OVMS.
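
As a usage illustration, the new bindings let callers pass a tokenizer plugin config separately from the LLM plugin config. A minimal sketch, assuming the module and class names exposed by `py_generate_pipeline.cpp` below; paths and properties are placeholders:

```python
import openvino_genai as ov_genai

scheduler_config = ov_genai.SchedulerConfig()

# Tokenizer now takes an optional plugin_config dict (an empty dict keeps plugin defaults).
tokenizer = ov_genai.Tokenizer("/path/to/exported_model", {})

# ContinuousBatchingPipeline now takes separate llm_plugin_config and tokenizer_plugin_config.
pipe = ov_genai.ContinuousBatchingPipeline(
    "/path/to/exported_model", scheduler_config, "CPU",
    {},   # llm_plugin_config
    {},   # tokenizer_plugin_config
)
```
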
---
.../genai/continuous_batching_pipeline.hpp | 3 +-
src/cpp/include/openvino/genai/tokenizer.hpp | 2 +-
src/cpp/src/circular_buffer_queue.hpp | 100 ++++++++++++++++++
src/cpp/src/continuous_batching_pipeline.cpp | 9 +-
src/cpp/src/tokenizer.cpp | 98 ++++++++++-------
src/python/py_generate_pipeline.cpp | 12 +--
tests/python_tests/common.py | 2 +-
tests/python_tests/ov_genai_test_utils.py | 2 +-
tests/python_tests/test_sampling.py | 2 +-
9 files changed, 179 insertions(+), 51 deletions(-)
create mode 100644 src/cpp/src/circular_buffer_queue.hpp
diff --git a/src/cpp/include/openvino/genai/continuous_batching_pipeline.hpp b/src/cpp/include/openvino/genai/continuous_batching_pipeline.hpp
index be9a5fd8c1..f5f8c53309 100644
--- a/src/cpp/include/openvino/genai/continuous_batching_pipeline.hpp
+++ b/src/cpp/include/openvino/genai/continuous_batching_pipeline.hpp
@@ -30,7 +30,8 @@ class OPENVINO_GENAI_EXPORTS ContinuousBatchingPipeline {
ContinuousBatchingPipeline(const std::string& models_path,
const SchedulerConfig& scheduler_config,
const std::string& device = "CPU",
- const ov::AnyMap& plugin_config = {});
+ const ov::AnyMap& llm_plugin_config = {},
+ const ov::AnyMap& tokenizer_plugin_config = {});
/**
* @brief Constructs a ContinuousBatchingPipeline when ov::genai::Tokenizer is initialized manually using file from the different dirs.
diff --git a/src/cpp/include/openvino/genai/tokenizer.hpp b/src/cpp/include/openvino/genai/tokenizer.hpp
index 5a1e181e21..425c30128b 100644
--- a/src/cpp/include/openvino/genai/tokenizer.hpp
+++ b/src/cpp/include/openvino/genai/tokenizer.hpp
@@ -29,7 +29,7 @@ class OPENVINO_GENAI_EXPORTS Tokenizer {
* @brief ov::genai::Tokenizer constructor.
* @param tokenizer_path openvino_tokenizer.xml and openvino_detokenizer.xml should be located in the tokenizer_path
*/
- Tokenizer(const std::string& tokenizer_path);
+ Tokenizer(const std::string& tokenizer_path, const ov::AnyMap& plugin_config = {});
/**
* @brief encode a single prompt
diff --git a/src/cpp/src/circular_buffer_queue.hpp b/src/cpp/src/circular_buffer_queue.hpp
new file mode 100644
index 0000000000..086854e68e
--- /dev/null
+++ b/src/cpp/src/circular_buffer_queue.hpp
@@ -0,0 +1,100 @@
+#pragma once
+
+#include <atomic>
+#include <functional>
+#include <future>
+#include <mutex>
+#include <numeric>
+#include <queue>
+#include <vector>
+
+namespace ov::genai {
+
+// From OVMS:
+// https://github.com/openvinotoolkit/model_server/blob/d73e85cbb8ac1d761754cb2064a00551a9ffc655/src/queue.hpp#L34
+template <typename T>
+class CircularBufferQueue
+{
+    int m_front_idx;
+    std::atomic<int> m_back_idx;
+    std::vector<int> m_values;
+    std::queue<std::promise<int>> m_promises;
+    std::vector<T> m_data;
+    std::mutex m_front_mut;
+    std::mutex m_queue_mutex;
+
+public:
+
+    CircularBufferQueue(size_t length, const std::function<T()>& create_fn) :
+        m_values(length),
+        m_front_idx{0},
+        m_back_idx{0} {
+        std::iota(m_values.begin(), m_values.end(), 0);
+        m_data.reserve(length);
+        for (size_t i = 0; i < length; i++) {
+            m_data.emplace_back(std::move(create_fn()));
+        }
+    }
+
+    CircularBufferQueue(const CircularBufferQueue&) = delete;
+    CircularBufferQueue(const CircularBufferQueue&&) = delete;
+    CircularBufferQueue& operator=(const CircularBufferQueue&) = delete;
+
+    T& get(int value) {
+        return m_data[value];
+    }
+
+    std::future<int> get_idle() {
+        int value;
+        std::promise<int> idle_promise;
+        std::future<int> idle_future = idle_promise.get_future();
+        std::unique_lock<std::mutex> lk(m_front_mut);
+        if (m_values[m_front_idx] < 0) {
+            std::unique_lock<std::mutex> queueLock(m_queue_mutex);
+            m_promises.push(std::move(idle_promise));
+        } else {
+            value = m_values[m_front_idx];
+            m_values[m_front_idx] = -1;
+            m_front_idx = (m_front_idx + 1) % m_values.size();
+            lk.unlock();
+            idle_promise.set_value(value);
+        }
+        return idle_future;
+    }
+
+    void return_to(int value) {
+        std::unique_lock<std::mutex> lk(m_queue_mutex);
+        if (m_promises.size()) {
+            std::promise<int> promise = std::move(m_promises.front());
+            m_promises.pop();
+            lk.unlock();
+            promise.set_value(value);
+            return;
+        }
+        int old_back = m_back_idx.load();
+        while (!m_back_idx.compare_exchange_weak(
+            old_back,
+            (old_back + 1) % m_values.size(),
+            std::memory_order_relaxed)) {
+        }
+        m_values[old_back] = value;
+    }
+};
+
+template <typename T>
+class CircularBufferQueueElementGuard {
+    CircularBufferQueue<T>* m_queue;
+    int m_value;
+public:
+    CircularBufferQueueElementGuard(CircularBufferQueue<T>* queue) : m_queue(queue) {
+        m_value = m_queue->get_idle().get();  // blocking until we get the element
+    }
+
+    T& get() {
+        return m_queue->get(m_value);
+    }
+
+    ~CircularBufferQueueElementGuard() {
+        m_queue->return_to(m_value);
+    }
+};
+
+}
diff --git a/src/cpp/src/continuous_batching_pipeline.cpp b/src/cpp/src/continuous_batching_pipeline.cpp
index ddfebc5926..55100f3cb4 100644
--- a/src/cpp/src/continuous_batching_pipeline.cpp
+++ b/src/cpp/src/continuous_batching_pipeline.cpp
@@ -105,8 +105,8 @@ class ContinuousBatchingPipeline::Impl {
// read default generation config
}
- Impl(const std::string& models_path, const SchedulerConfig& scheduler_config, const std::string& device, const ov::AnyMap& plugin_config)
- : Impl{models_path, Tokenizer(models_path), scheduler_config, device, plugin_config} {}
+ Impl(const std::string& models_path, const SchedulerConfig& scheduler_config, const std::string& device, const ov::AnyMap& llm_plugin_config, const ov::AnyMap& tokenizer_plugin_config)
+ : Impl{models_path, Tokenizer(models_path, tokenizer_plugin_config), scheduler_config, device, llm_plugin_config} {}
ov::genai::GenerationConfig get_config() const {
return m_generation_config;
@@ -282,8 +282,9 @@ class ContinuousBatchingPipeline::Impl {
ContinuousBatchingPipeline::ContinuousBatchingPipeline( const std::string& models_path,
const SchedulerConfig& scheduler_config,
const std::string& device,
- const ov::AnyMap& plugin_config ) {
- m_impl = std::make_shared(models_path, scheduler_config, device, plugin_config);
+ const ov::AnyMap& llm_plugin_config,
+ const ov::AnyMap& tokenizer_plugin_config) {
+ m_impl = std::make_shared(models_path, scheduler_config, device, llm_plugin_config, tokenizer_plugin_config);
}
ContinuousBatchingPipeline::ContinuousBatchingPipeline(
diff --git a/src/cpp/src/tokenizer.cpp b/src/cpp/src/tokenizer.cpp
index ac6b925dcb..b1e36033ee 100644
--- a/src/cpp/src/tokenizer.cpp
+++ b/src/cpp/src/tokenizer.cpp
@@ -7,7 +7,9 @@
#include
#include
#include "tokenizers_path.hpp"
+#include "circular_buffer_queue.hpp"
#include
+#include
namespace {
@@ -55,10 +57,12 @@ namespace genai {
class Tokenizer::TokenizerImpl {
public:
- ov::InferRequest m_tokenizer_request;
- ov::InferRequest m_detokenizer_request;
- std::mutex m_tokenizer_mutex;
- std::mutex m_detokenizer_mutex;
+ ov::CompiledModel m_tokenizer;
+ ov::CompiledModel m_detokenizer;
+
+ std::unique_ptr<CircularBufferQueue<ov::InferRequest>> m_ireq_queue_tokenizer;
+ std::unique_ptr<CircularBufferQueue<ov::InferRequest>> m_ireq_queue_detokenizer;
+
int64_t m_pad_token_id = -1;
int64_t m_bos_token_id = -1;
int64_t m_eos_token_id = -1;
@@ -71,7 +75,7 @@ class Tokenizer::TokenizerImpl {
TokenizerImpl() = default;
- TokenizerImpl(std::filesystem::path tokenizer_path)
+ TokenizerImpl(std::filesystem::path tokenizer_path, const ov::AnyMap& plugin_config)
: m_chat_template{chat_template_from_tokenizer_json_if_exists(tokenizer_path)} {
ov::Core core;
@@ -92,10 +96,23 @@ class Tokenizer::TokenizerImpl {
read_tokenizer_config_if_necessary(tokenizer_path);
auto device = "CPU"; // currently openvino_tokenizer supports only CPU
- m_tokenizer_request = core.compile_model(tokenizer_path / "openvino_tokenizer.xml",
- device).create_infer_request();
- m_detokenizer_request = core.compile_model(tokenizer_path / "openvino_detokenizer.xml",
- device).create_infer_request();
+ m_tokenizer = core.compile_model(tokenizer_path / "openvino_tokenizer.xml",
+ device, plugin_config);
+ m_detokenizer = core.compile_model(tokenizer_path / "openvino_detokenizer.xml",
+ device, plugin_config);
+
+
+ const size_t INFER_REQUEST_QUEUE_SIZE = m_tokenizer.get_property(ov::optimal_number_of_infer_requests);
+ m_ireq_queue_tokenizer = std::make_unique<CircularBufferQueue<ov::InferRequest>>(
+ INFER_REQUEST_QUEUE_SIZE,
+ [this]() -> ov::InferRequest {
+ return std::move(this->m_tokenizer.create_infer_request());
+ });
+ m_ireq_queue_detokenizer = std::make_unique<CircularBufferQueue<ov::InferRequest>>(
+ INFER_REQUEST_QUEUE_SIZE,
+ [this]() -> ov::InferRequest {
+ return std::move(this->m_detokenizer.create_infer_request());
+ });
// Get special token ids by inference if they are not defined.
infer_special_tokens_if_necessary();
@@ -231,29 +248,35 @@ class Tokenizer::TokenizerImpl {
}
TokenizedInputs encode(std::string prompt) {
+ CircularBufferQueueElementGuard<ov::InferRequest> infer_request_guard(this->m_ireq_queue_tokenizer.get());
size_t batch_size = 1;
- std::unique_lock<std::mutex> lock(m_tokenizer_mutex);
- m_tokenizer_request.set_input_tensor(ov::Tensor{ov::element::string, {batch_size}, &prompt});
- m_tokenizer_request.infer();
- return get_copied_results();
+ infer_request_guard.get().set_input_tensor(ov::Tensor{ov::element::string, {batch_size}, &prompt});
+ infer_request_guard.get().start_async();
+ infer_request_guard.get().wait();
+ return get_copied_results(
+ infer_request_guard.get().get_tensor("input_ids"),
+ infer_request_guard.get().get_tensor("attention_mask")
+ );
}
 TokenizedInputs encode(std::vector<std::string>& prompts) {
TokenizedInputs unpadded;
{
- std::unique_lock<std::mutex> lock(m_tokenizer_mutex);
- m_tokenizer_request.set_input_tensor(ov::Tensor{ov::element::string, {prompts.size()}, prompts.data()});
- auto size_ = m_tokenizer_request.get_input_tensor().get_shape();
- m_tokenizer_request.infer();
-
- unpadded = get_copied_results();
+ CircularBufferQueueElementGuard<ov::InferRequest> infer_request_guard(this->m_ireq_queue_tokenizer.get());
+ infer_request_guard.get().set_input_tensor(ov::Tensor{ov::element::string, {prompts.size()}, prompts.data()});
+ auto size_ = infer_request_guard.get().get_input_tensor().get_shape();
+ infer_request_guard.get().start_async();
+ infer_request_guard.get().wait();
+
+ unpadded = get_copied_results(
+ infer_request_guard.get().get_tensor("input_ids"),
+ infer_request_guard.get().get_tensor("attention_mask")
+ );
}
return pad_left(unpadded.input_ids, unpadded.attention_mask);
}
- TokenizedInputs get_copied_results() {
- auto input_ids = m_tokenizer_request.get_tensor("input_ids");
- auto attention_mask = m_tokenizer_request.get_tensor("attention_mask");
+ TokenizedInputs get_copied_results(ov::Tensor input_ids, ov::Tensor attention_mask) {
ov::Tensor input_ids_ = ov::Tensor(input_ids.get_element_type(), input_ids.get_shape());
ov::Tensor attention_mask_ = ov::Tensor(attention_mask.get_element_type(), attention_mask.get_shape());
input_ids.copy_to(input_ids_);
@@ -263,22 +286,24 @@ class Tokenizer::TokenizerImpl {
}
 std::string decode(std::vector<int64_t> tokens) {
+ CircularBufferQueueElementGuard<ov::InferRequest> infer_request_guard(this->m_ireq_queue_detokenizer.get());
size_t batch_size = 1;
- std::unique_lock<std::mutex> lock(m_detokenizer_mutex);
- m_detokenizer_request.set_input_tensor(ov::Tensor{ov::element::i64, {batch_size, tokens.size()}, tokens.data()});
- m_detokenizer_request.infer();
- return m_detokenizer_request.get_output_tensor().data<std::string>()[0];
+ infer_request_guard.get().set_input_tensor(ov::Tensor{ov::element::i64, {batch_size, tokens.size()}, tokens.data()});
+ infer_request_guard.get().start_async();
+ infer_request_guard.get().wait();
+ return infer_request_guard.get().get_output_tensor().data<std::string>()[0];
}
 std::vector<std::string> decode(ov::Tensor tokens) {
OPENVINO_ASSERT(tokens.get_element_type() == ov::element::i64, "tokens tensor element type should be an i64");
OPENVINO_ASSERT(tokens.get_shape().size() == 2, "tokens tensor should of rank 2 with shape [batch_size, seq_len]");
- std::unique_lock<std::mutex> lock(m_detokenizer_mutex);
- m_detokenizer_request.set_input_tensor(tokens);
- m_detokenizer_request.infer();
+ CircularBufferQueueElementGuard<ov::InferRequest> infer_request_guard(this->m_ireq_queue_detokenizer.get());
+ infer_request_guard.get().set_input_tensor(tokens);
+ infer_request_guard.get().start_async();
+ infer_request_guard.get().wait();
- auto res = m_detokenizer_request.get_output_tensor();
+ auto res = infer_request_guard.get().get_output_tensor();
 auto res_data = res.data<std::string>();
 return std::vector<std::string>(res_data, res_data + res.get_shape()[0]);
}
@@ -299,10 +324,11 @@ class Tokenizer::TokenizerImpl {
std::fill(tokens_data + i * max_len + line_len, tokens_data + (i + 1) * max_len, m_pad_token_id);
}
- std::unique_lock<std::mutex> lock(m_detokenizer_mutex);
- m_detokenizer_request.set_input_tensor(tokens);
- m_detokenizer_request.infer();
- auto res = m_detokenizer_request.get_output_tensor();
+ CircularBufferQueueElementGuard<ov::InferRequest> infer_request_guard(this->m_ireq_queue_detokenizer.get());
+ infer_request_guard.get().set_input_tensor(tokens);
+ infer_request_guard.get().start_async();
+ infer_request_guard.get().wait();
+ auto res = infer_request_guard.get().get_output_tensor();
 auto res_data = res.data<std::string>();
 return std::vector<std::string>(res_data, res_data + res.get_shape()[0]);
}
@@ -411,9 +437,9 @@ class Tokenizer::TokenizerImpl {
};
-Tokenizer::Tokenizer(const std::string& tokenizer_path) {
+Tokenizer::Tokenizer(const std::string& tokenizer_path, const ov::AnyMap& plugin_config) {
ScopedVar env_manager(tokenizers_relative_to_genai().string());
- m_pimpl = std::make_shared(tokenizer_path);
+ m_pimpl = std::make_shared(tokenizer_path, plugin_config);
}
TokenizedInputs Tokenizer::encode(const std::string prompt) {
diff --git a/src/python/py_generate_pipeline.cpp b/src/python/py_generate_pipeline.cpp
index d7b2aab29c..8a1a226bc1 100644
--- a/src/python/py_generate_pipeline.cpp
+++ b/src/python/py_generate_pipeline.cpp
@@ -436,10 +436,10 @@ PYBIND11_MODULE(py_generate_pipeline, m) {
R"(openvino_genai.Tokenizer object is used to initialize Tokenizer
if it's located in a different path than the main model.)")
- .def(py::init([](const std::string& tokenizer_path) {
+ .def(py::init([](const std::string& tokenizer_path, const std::map<std::string, py::object>& plugin_config) {
ScopedVar env_manager(ov_tokenizers_module_path());
- return std::make_unique(tokenizer_path);
- }), py::arg("tokenizer_path"))
+ return std::make_unique(tokenizer_path, properties_to_any_map(plugin_config));
+ }), py::arg("tokenizer_path"), py::arg("plugin_config") = ov::AnyMap({}))
.def("encode", [](Tokenizer& tok, std::vector& prompts) { return tok.encode(prompts); },
py::arg("prompts"),
@@ -596,10 +596,10 @@ PYBIND11_MODULE(py_generate_pipeline, m) {
.def_readwrite("max_num_seqs", &SchedulerConfig::max_num_seqs);
py::class_(m, "ContinuousBatchingPipeline")
- .def(py::init([](const std::string& model_path, const SchedulerConfig& scheduler_config, const std::string& device, const std::map<std::string, py::object>& plugin_config) {
+ .def(py::init([](const std::string& model_path, const SchedulerConfig& scheduler_config, const std::string& device, const std::map<std::string, py::object>& llm_plugin_config, const std::map<std::string, py::object>& tokenizer_plugin_config) {
ScopedVar env_manager(ov_tokenizers_module_path());
- return std::make_unique(model_path, scheduler_config, device, properties_to_any_map(plugin_config));
- }), py::arg("model_path"), py::arg("scheduler_config"), py::arg("device") = "CPU", py::arg("plugin_config") = ov::AnyMap({}))
+ return std::make_unique(model_path, scheduler_config, device, properties_to_any_map(llm_plugin_config), properties_to_any_map(tokenizer_plugin_config));
+ }), py::arg("model_path"), py::arg("scheduler_config"), py::arg("device") = "CPU", py::arg("llm_plugin_config") = ov::AnyMap({}), py::arg("tokenizer_plugin_config") = ov::AnyMap({}))
 .def(py::init([](const std::string& model_path, const ov::genai::Tokenizer& tokenizer, const SchedulerConfig& scheduler_config, const std::string& device, const std::map<std::string, py::object>& plugin_config) {
ScopedVar env_manager(ov_tokenizers_module_path());
return std::make_unique(model_path, tokenizer, scheduler_config, device, properties_to_any_map(plugin_config));
diff --git a/tests/python_tests/common.py b/tests/python_tests/common.py
index 95046a463a..0a94558274 100644
--- a/tests/python_tests/common.py
+++ b/tests/python_tests/common.py
@@ -273,7 +273,7 @@ def run_continuous_batching(
prompts: List[str],
generation_configs : List[GenerationConfig]
) -> List[GenerationResult]:
- pipe = ContinuousBatchingPipeline(model_path.absolute().as_posix(), scheduler_config, "CPU", {})
+ pipe = ContinuousBatchingPipeline(model_path.absolute().as_posix(), scheduler_config, "CPU", {}, {})
output = pipe.generate(prompts, generation_configs)
del pipe
shutil.rmtree(model_path)
diff --git a/tests/python_tests/ov_genai_test_utils.py b/tests/python_tests/ov_genai_test_utils.py
index 7560486d42..bf76df534d 100644
--- a/tests/python_tests/ov_genai_test_utils.py
+++ b/tests/python_tests/ov_genai_test_utils.py
@@ -208,7 +208,7 @@ def load_tok(configs: List[Tuple], temp_path):
for config_json, config_name in configs:
with (temp_path / config_name).open('w') as f:
json.dump(config_json, f)
- return ov_genai.Tokenizer(str(temp_path))
+ return ov_genai.Tokenizer(str(temp_path), {})
def load_pipe(configs: List[Tuple], temp_path):
diff --git a/tests/python_tests/test_sampling.py b/tests/python_tests/test_sampling.py
index 27596359bf..9b34cd2f5b 100644
--- a/tests/python_tests/test_sampling.py
+++ b/tests/python_tests/test_sampling.py
@@ -306,7 +306,7 @@ def test_post_oom_health(tmp_path):
model_path : Path = tmp_path / model_id
save_ov_model_from_optimum(model, hf_tokenizer, model_path)
- pipe = ContinuousBatchingPipeline(model_path.absolute().as_posix(), Tokenizer(model_path.absolute().as_posix()), scheduler_config)
+ pipe = ContinuousBatchingPipeline(model_path.absolute().as_posix(), Tokenizer(model_path.absolute().as_posix(), {}), scheduler_config, "CPU", {})
# First run should return incomplete response
output = pipe.generate(["What is OpenVINO?"], generation_configs)
assert(len(output))
From 04012f473c0eac190701926366e9b05704b80196 Mon Sep 17 00:00:00 2001
From: Alexander Suvorov
Date: Wed, 24 Jul 2024 09:40:37 +0200
Subject: [PATCH 12/19] Skip test_preemption_with_multinomial_n_seq (#667)
Random sampling
---
tests/python_tests/test_preemption.py | 1 +
1 file changed, 1 insertion(+)
diff --git a/tests/python_tests/test_preemption.py b/tests/python_tests/test_preemption.py
index 8c9bda1d33..cce74136eb 100644
--- a/tests/python_tests/test_preemption.py
+++ b/tests/python_tests/test_preemption.py
@@ -161,6 +161,7 @@ def test_preemption_with_multinomial(tmp_path, dynamic_split_fuse):
@pytest.mark.parametrize("dynamic_split_fuse", [True, False])
@pytest.mark.precommit
+@pytest.mark.skip(reason="Random sampling results are non deterministic due to: discrete_distribution impl depends on platform, model inference results may depend on CPU. Test passes on CI but fails locally.")
def test_preemption_with_multinomial_n_seq(tmp_path, dynamic_split_fuse):
generation_configs = multinomial_params_n_seq.generation_config
for config in generation_configs:
From cc5e2356d64b709f765fda5563113b7802855db4 Mon Sep 17 00:00:00 2001
From: Sylwia Kuros
Date: Wed, 24 Jul 2024 12:19:54 +0200
Subject: [PATCH 13/19] Set torchvision to < 0.19.0 (#668)
Using torchvision with version 0.19.0 causes the following issue:
```
Traceback (most recent call last):
File "C:\Program Files\Python310\lib\site-packages\transformers\utils\import_utils.py", line 1567, in _get_module
return importlib.import_module("." + module_name, self.__name__)
File "C:\Program Files\Python310\lib\importlib\__init__.py", line 126, in import_module
return _bootstrap._gcd_import(name[level:], package, level)
File "", line 1050, in _gcd_import
File "", line 1027, in _find_and_load
File "", line 1006, in _find_and_load_unlocked
File "", line 688, in _load_unlocked
File "", line 883, in exec_module
File "", line 241, in _call_with_frames_removed
File "C:\Program Files\Python310\lib\site-packages\transformers\models\auto\image_processing_auto.py", line 27, in
from ...image_processing_utils import BaseImageProcessor, ImageProcessingMixin
File "C:\Program Files\Python310\lib\site-packages\transformers\image_processing_utils.py", line 21, in
from .image_transforms import center_crop, normalize, rescale
File "C:\Program Files\Python310\lib\site-packages\transformers\image_transforms.py", line 22, in
from .image_utils import (
File "C:\Program Files\Python310\lib\site-packages\transformers\image_utils.py", line 58, in
from torchvision.transforms import InterpolationMode
File "C:\Program Files\Python310\lib\site-packages\torchvision\__init__.py", line 10, in
from torchvision import _meta_registrations, datasets, io, models, ops, transforms, utils # usort:skip
File "C:\Program Files\Python310\lib\site-packages\torchvision\_meta_registrations.py", line 163, in
@torch.library.register_fake("torchvision::nms")
AttributeError: module 'torch.library' has no attribute 'register_fake'
```
---
llm_bench/python/requirements.txt | 1 +
1 file changed, 1 insertion(+)
diff --git a/llm_bench/python/requirements.txt b/llm_bench/python/requirements.txt
index ed80a66deb..d83cd5a376 100644
--- a/llm_bench/python/requirements.txt
+++ b/llm_bench/python/requirements.txt
@@ -7,6 +7,7 @@ openvino_genai
auto-gptq>=0.5.1 # for gptq
pillow
torch
+torchvision<0.19.0
transformers>=4.40.0
diffusers>=0.22.0
#optimum is in dependency list of optimum-intel
From 42dd04900cded77671ae1fa9d50f888180ace73f Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Mi=C5=82osz=20=C5=BBeglarski?=
Date: Wed, 24 Jul 2024 12:14:35 +0200
Subject: [PATCH 14/19] [Continuous batching] In the event of OOM, return
tokens generated so far for the request (#661)
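
A minimal sketch of the behaviour this patch enables, mirroring the updated `test_post_oom_health` below; class and helper names follow the tests in this series, and the exact scheduler settings are assumptions:

```python
# A deliberately tiny KV-cache forces an out-of-memory condition mid-generation;
# generate() should still return the tokens produced so far for the request.
from openvino_genai import ContinuousBatchingPipeline, GenerationConfig, SchedulerConfig, Tokenizer

generation_config = GenerationConfig()
generation_config.ignore_eos = True
generation_config.max_new_tokens = 1_000_000     # keep generating until the cache runs out

scheduler_config = SchedulerConfig()
scheduler_config.num_kv_blocks = 10              # tiny cache to trigger OOM quickly

model_dir = "/path/to/exported_model"
pipe = ContinuousBatchingPipeline(model_dir, Tokenizer(model_dir, {}), scheduler_config, "CPU", {})

output = pipe.generate(["What is OpenVINO?"], [generation_config])
assert len(output)
assert len(output[0].m_generation_ids)           # partial generation is preserved after OOM
```
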
---
src/cpp/src/sequence_group.hpp | 71 ++++++++++++-----------------
tests/python_tests/test_sampling.py | 11 +++--
2 files changed, 36 insertions(+), 46 deletions(-)
diff --git a/src/cpp/src/sequence_group.hpp b/src/cpp/src/sequence_group.hpp
index 3df1820cfb..88b86b4484 100644
--- a/src/cpp/src/sequence_group.hpp
+++ b/src/cpp/src/sequence_group.hpp
@@ -425,59 +425,46 @@ class SequenceGroup {
return m_generation_stream->get_status() == GenerationStatus::DROPPED_BY_HANDLE;
}
- void notify_handle() {
+ void push_outputs() {
+ GenerationOutputs outputs;
+ for (auto& sequence: m_sequences) {
+ GenerationOutput output;
+ output.generated_token_ids = sequence->get_generated_ids();
+ output.score = sequence->get_beam_search_score(m_sampling_params);
+ outputs.emplace(sequence->get_grouped_id(), output);
+ }
+ m_generation_stream->push(outputs);
+ }
+
+ void push_partial_outputs() {
+ GenerationOutputs outputs;
+ // TODO: support streaming for n seqs
+ for (auto& sequence : m_sequences) {
+ // todo: check seq.is_finished() to generate without several
+ // or is it ok to use padding?
+ const auto last_gen_token = sequence->get_last_generation_output();
+ outputs.emplace(sequence->get_grouped_id(), last_gen_token);
+ }
+ m_generation_stream->push(outputs);
+ }
+ void notify_handle() {
if (out_of_memory()) {
set_generation_status(GenerationStatus::IGNORED);
} else if (has_finished()) {
set_generation_status(GenerationStatus::FINISHED);
}
-
- GenerationOutputs outputs;
-
// For beam search streaming is not available, so we notify only upon finishing
if(m_sampling_params.is_beam_search()) {
- if (has_finished()) {
- std::vector finished_sequences = get_finished_sequences();
-
- OPENVINO_ASSERT(finished_sequences.size() == num_total_seqs() && has_finished());
- for (auto& sequence: finished_sequences) {
- GenerationOutput output;
- output.generated_token_ids = sequence->get_generated_ids();
- output.score = sequence->get_beam_search_score(m_sampling_params);
- outputs.emplace(sequence->get_grouped_id(), output);
- }
-
- if (outputs.size()) {
- m_generation_stream->push(outputs);
- }
+ if (has_finished() || out_of_memory()) {
+ push_outputs();
}
- // For greedy or multinomial sampling we decide whever to stream partial results depending on the user parameter
} else if (m_sampling_params.is_greedy_decoding() || m_sampling_params.is_multinomial()) {
// TO DO: Now we always stream for greedy search for the sake of benchmarking
- if (num_total_seqs() == 1 /* m_sampling_params.stream */) {
- // TODO: support streamimg for n seqs
- for (auto& sequence : m_sequences) {
- // todo: check seq.is_finished() to generate without several
- // or is it ok to use padding?
- const auto last_gen_token = sequence->get_last_generation_output();
- outputs.emplace(sequence->get_grouped_id(), last_gen_token);
- }
- m_generation_stream->push(outputs);
- } else if (has_finished()) {
- std::vector finished_sequences = get_finished_sequences();
-
- OPENVINO_ASSERT(finished_sequences.size() == num_total_seqs() && has_finished());
- for (auto& sequence: finished_sequences) {
- GenerationOutput output;
- output.generated_token_ids = sequence->get_generated_ids();
- output.score = sequence->get_cumulative_log_probs();
- outputs.emplace(sequence->get_grouped_id(), output);
- }
-
- if (outputs.size()) {
- m_generation_stream->push(outputs);
- }
+ if (num_total_seqs() == 1) {
+ push_partial_outputs();
+ } else if (has_finished() || out_of_memory()) {
+ push_outputs();
}
}
}
diff --git a/tests/python_tests/test_sampling.py b/tests/python_tests/test_sampling.py
index 9b34cd2f5b..741c89db78 100644
--- a/tests/python_tests/test_sampling.py
+++ b/tests/python_tests/test_sampling.py
@@ -291,8 +291,9 @@ def test_individual_generation_configs_random(tmp_path, test_struct: RandomSampl
@pytest.mark.precommit
-def test_post_oom_health(tmp_path):
- generation_config = get_greedy()
+@pytest.mark.parametrize("sampling_config", [get_greedy(), get_beam_search(), get_multinomial_all_parameters()])
+def test_post_oom_health(tmp_path, sampling_config):
+ generation_config = sampling_config
generation_config.ignore_eos = True
generation_config.max_new_tokens = 1000000
@@ -309,9 +310,11 @@ def test_post_oom_health(tmp_path):
pipe = ContinuousBatchingPipeline(model_path.absolute().as_posix(), Tokenizer(model_path.absolute().as_posix(), {}), scheduler_config, "CPU", {})
# First run should return incomplete response
output = pipe.generate(["What is OpenVINO?"], generation_configs)
- assert(len(output))
+ assert (len(output))
+ assert(len(output[0].m_generation_ids))
# Same for the second run, here we want to make sure the cleanup works and we have free blocks after recent OOM
output = pipe.generate(["What is OpenVINO?"], generation_configs)
- assert(len(output))
+ assert (len(output))
+ assert(len(output[0].m_generation_ids))
del pipe
shutil.rmtree(model_path)
\ No newline at end of file
From 97595208b02dd479bf159305bec00b5cf1a9999f Mon Sep 17 00:00:00 2001
From: Zlobin Vladimir
Date: Thu, 25 Jul 2024 14:30:39 +0400
Subject: [PATCH 15/19] Bump version only (#684)
---
CMakeLists.txt | 2 +-
pyproject.toml | 2 +-
2 files changed, 2 insertions(+), 2 deletions(-)
diff --git a/CMakeLists.txt b/CMakeLists.txt
index 27ed56b453..f45ab24279 100644
--- a/CMakeLists.txt
+++ b/CMakeLists.txt
@@ -18,7 +18,7 @@ elseif(NOT GENERATOR_IS_MULTI_CONFIG_VAR AND NOT DEFINED CMAKE_BUILD_TYPE)
endif()
project(OpenVINOGenAI
- VERSION 2024.3.0.0
+ VERSION 2024.4.0.0
DESCRIPTION "OpenVINO GenAI"
HOMEPAGE_URL "https://github.com/openvinotoolkit/openvino.genai"
LANGUAGES CXX)
diff --git a/pyproject.toml b/pyproject.toml
index 7cfa564ef9..af55c3f684 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -1,6 +1,6 @@
[project]
name = "openvino_genai"
-version = "2024.3.0.0"
+version = "2024.4.0.0"
description = "Python bindings for https://github.com/openvinotoolkit/openvino.genai"
requires-python = ">=3.8"
readme = {file = "src/README.md", content-type="text/markdown"}
From f42e63d706c4a51a9f470d19b5677f1b3d498c35 Mon Sep 17 00:00:00 2001
From: Zlobin Vladimir
Date: Thu, 25 Jul 2024 17:31:07 +0400
Subject: [PATCH 16/19] Fix merge conflicts resolution (#685)
---
CMakeLists.txt | 18 +-----------------
thirdparty/openvino_tokenizers | 2 +-
2 files changed, 2 insertions(+), 18 deletions(-)
diff --git a/CMakeLists.txt b/CMakeLists.txt
index f45ab24279..e080b4a97a 100644
--- a/CMakeLists.txt
+++ b/CMakeLists.txt
@@ -57,33 +57,17 @@ if(ENABLE_PYTHON)
endif()
endif()
-if(ENABLE_PYTHON)
- # the following two calls are required for cross-compilation
- if(OpenVINODeveloperPackage_DIR)
- ov_find_python3(REQUIRED)
- ov_detect_python_module_extension()
- else()
- if(CMAKE_VERSION VERSION_GREATER_EQUAL 3.18)
- find_package(Python3 REQUIRED COMPONENTS Interpreter Development.Module)
- else()
- find_package(Python3 REQUIRED COMPONENTS Interpreter Development)
- endif()
- endif()
-endif()
-
add_subdirectory(thirdparty)
add_subdirectory(src)
add_subdirectory(samples)
add_subdirectory(tests/cpp)
-install(FILES LICENSE DESTINATION docs/licensing COMPONENT licensing_genai RENAME LICENSE-GENAI)
-install(FILES third-party-programs.txt DESTINATION docs/licensing COMPONENT licensing_genai RENAME third-party-programs-genai.txt)
install(FILES LICENSE DESTINATION docs/licensing COMPONENT licensing_genai RENAME LICENSE-GENAI)
install(FILES third-party-programs.txt DESTINATION docs/licensing COMPONENT licensing_genai RENAME third-party-programs-genai.txt)
set(CPACK_ARCHIVE_COMPONENT_INSTALL ON)
set(CPACK_INCLUDE_TOPLEVEL_DIRECTORY OFF)
# Workaround https://gitlab.kitware.com/cmake/cmake/-/issues/2614
-set(CPACK_COMPONENTS_ALL core_genai core_genai_dev cpp_samples_genai licensing_genai openvino_tokenizers openvino_tokenizers_licenses)
+set(CPACK_COMPONENTS_ALL core_genai core_genai_dev cpp_samples_genai licensing_genai openvino_tokenizers openvino_tokenizers_docs)
if(ENABLE_PYTHON)
list(APPEND CPACK_COMPONENTS_ALL pygenai_${Python3_VERSION_MAJOR}_${Python3_VERSION_MINOR})
endif()
diff --git a/thirdparty/openvino_tokenizers b/thirdparty/openvino_tokenizers
index 04795c1b78..fb0157c30a 160000
--- a/thirdparty/openvino_tokenizers
+++ b/thirdparty/openvino_tokenizers
@@ -1 +1 @@
-Subproject commit 04795c1b78c61e3294d1744c78a8ebb5e129256c
+Subproject commit fb0157c30a8a7f6538471fe622b8b52a3800278a
From 14f9c2b1b935d805e7bcb270791880a6cfdbc657 Mon Sep 17 00:00:00 2001
From: Nikita Malinin
Date: Thu, 25 Jul 2024 17:25:24 +0200
Subject: [PATCH 17/19] Partial revert of #616 (#687)
Reverts broken `data-aware` changes from #616
---
llm_bench/python/utils/nncf_utils.py | 6 ++----
1 file changed, 2 insertions(+), 4 deletions(-)
diff --git a/llm_bench/python/utils/nncf_utils.py b/llm_bench/python/utils/nncf_utils.py
index 25ef8aff18..01d0dd95b3 100644
--- a/llm_bench/python/utils/nncf_utils.py
+++ b/llm_bench/python/utils/nncf_utils.py
@@ -38,7 +38,7 @@ def get_compressed_path(output_dir: str, base_precision, option: str):
INT4_MODEL_CONFIGURATION = {
- "dolly-v2-3b": {"mode": nncf.CompressWeightsMode.INT4_ASYM, "group_size": 128, "ratio": 1.0, "scale": True},
+ "dolly-v2-3b": {"mode": nncf.CompressWeightsMode.INT4_ASYM, "group_size": 128, "ratio": 0.8},
"gpt-j-6b": {"mode": nncf.CompressWeightsMode.INT4_ASYM, "group_size": 64},
"opt-6.7b": {"mode": nncf.CompressWeightsMode.INT4_ASYM, "group_size": 64, "ratio": 0.8},
"red-pajama-incite-7b-instruct": {"mode": nncf.CompressWeightsMode.INT4_ASYM, "group_size": 128},
@@ -69,13 +69,11 @@ def get_compressed_path(output_dir: str, base_precision, option: str):
"mistral-7b-v0.1": {"mode": nncf.CompressWeightsMode.INT4_SYM, "group_size": 128, "ratio": 0.9},
"llama-7b": {"mode": nncf.CompressWeightsMode.INT4_SYM, "group_size": 128, "ratio": 0.7},
"opt-2.7b": {"mode": nncf.CompressWeightsMode.INT4_SYM, "group_size": 128, "ratio": 0.7},
- "red-pajama-incite-chat-3b-v1": {"mode": nncf.CompressWeightsMode.INT4_ASYM, "group_size": 128, "ratio": 1.0, "scale": True},
+ "red-pajama-incite-chat-3b-v1": {"mode": nncf.CompressWeightsMode.INT4_ASYM, "group_size": 128, "ratio": 0.8},
"vicuna-7b-v1.5": {"mode": nncf.CompressWeightsMode.INT4_ASYM, "group_size": 128, "ratio": 1.0},
"stablelm-tuned-alpha-3b": {"mode": nncf.CompressWeightsMode.INT4_ASYM, "group_size": 128, "ratio": 0.8},
- "gpt-2": {"mode": nncf.CompressWeightsMode.INT4_ASYM, "group_size": 128, "ratio": 0.5, "scale": True},
"longchat-b7": {"mode": nncf.CompressWeightsMode.INT4_ASYM, "group_size": 128, "ratio": 0.9},
"starcoder2-3b": {"mode": nncf.CompressWeightsMode.INT4_ASYM, "group_size": 128, "ratio": 0.9},
"tiny-llama-1.1b-chat": {"mode": nncf.CompressWeightsMode.INT4_ASYM, "group_size": 128, "ratio": 0.8},
- "stablelm-7b": {"mode": nncf.CompressWeightsMode.INT4_ASYM, "group_size": 128, "ratio": 0.6, "scale": True},
"phi-2": {"mode": nncf.CompressWeightsMode.INT4_ASYM, "group_size": 128, "ratio": 0.9},
}
From f2010de9fbcf69ff44b465535c3ff9efeb749f7e Mon Sep 17 00:00:00 2001
From: Sylwia Kuros
Date: Fri, 26 Jul 2024 08:47:09 +0200
Subject: [PATCH 18/19] Update requirements.txt
---
llm_bench/python/requirements.txt | 1 -
1 file changed, 1 deletion(-)
diff --git a/llm_bench/python/requirements.txt b/llm_bench/python/requirements.txt
index d83cd5a376..ed80a66deb 100644
--- a/llm_bench/python/requirements.txt
+++ b/llm_bench/python/requirements.txt
@@ -7,7 +7,6 @@ openvino_genai
auto-gptq>=0.5.1 # for gptq
pillow
torch
-torchvision<0.19.0
transformers>=4.40.0
diffusers>=0.22.0
#optimum is in dependency list of optimum-intel
From 4bd1a26a08cca1895475add911bc53d8eff34a6c Mon Sep 17 00:00:00 2001
From: Anastasiia Pnevskaia
Date: Fri, 26 Jul 2024 08:51:58 +0200
Subject: [PATCH 19/19] Prefix caching. (#639)
Implementation of prefix caching.
Ticket: CVS-138669
---
.../openvino/genai/scheduler_config.hpp | 8 +
src/cpp/src/block_manager.hpp | 258 +++++++++++++++++-
src/cpp/src/scheduler.hpp | 28 +-
src/cpp/src/sequence_group.hpp | 21 ++
src/python/py_generate_pipeline.cpp | 5 +-
tests/cpp/CMakeLists.txt | 5 +-
tests/cpp/block_manager.cpp | 31 ++-
tests/cpp/evictor.cpp | 54 ++++
tests/cpp/scheduler.cpp | 68 +++++
9 files changed, 443 insertions(+), 35 deletions(-)
create mode 100644 tests/cpp/evictor.cpp
diff --git a/src/cpp/include/openvino/genai/scheduler_config.hpp b/src/cpp/include/openvino/genai/scheduler_config.hpp
index 787060d07e..d9bf7a7b41 100644
--- a/src/cpp/include/openvino/genai/scheduler_config.hpp
+++ b/src/cpp/include/openvino/genai/scheduler_config.hpp
@@ -30,5 +30,13 @@ struct SchedulerConfig {
// max number of scheduled sequences (you can think of it as "max batch size")
std::size_t max_num_seqs = 256;
+
+ // Enable caching of KV-blocks.
+ // When turned on, all previously calculated KV-caches are kept in memory for future use.
+ // KV-caches can be overwritten if the KV-cache limit is reached, but blocks are not released.
+ // This results in more RAM usage; maximum RAM usage is determined by the cache_size or num_kv_blocks parameters.
+ // When turned off, only the KV-cache required for batch calculation is kept in memory and
+ // when a sequence has finished generation its cache is released.
+ bool enable_prefix_caching = false;
};
}
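
A sketch of turning the new flag on from Python, assuming the `SchedulerConfig` binding exposes `enable_prefix_caching` (the `py_generate_pipeline.cpp` change in this patch suggests it does); the path is a placeholder:

```python
# Reuse previously computed KV-blocks across requests that share a prompt prefix
# (for example a fixed system prompt). RAM usage grows up to the configured cache size.
from openvino_genai import ContinuousBatchingPipeline, SchedulerConfig

scheduler_config = SchedulerConfig()
scheduler_config.enable_prefix_caching = True

pipe = ContinuousBatchingPipeline("/path/to/exported_model", scheduler_config, "CPU", {}, {})
```
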
diff --git a/src/cpp/src/block_manager.hpp b/src/cpp/src/block_manager.hpp
index ab60b7f5ff..3b1a663235 100644
--- a/src/cpp/src/block_manager.hpp
+++ b/src/cpp/src/block_manager.hpp
@@ -6,6 +6,7 @@
#include
#include
#include |