From ca37cecce8ba5a3423c771447e8844d58890ba1c Mon Sep 17 00:00:00 2001 From: Dariusz Trawinski Date: Fri, 26 Jul 2024 08:44:18 +0200 Subject: [PATCH 1/3] change kv cache type if forcing precision type --- src/cpp/src/continuous_batching_pipeline.cpp | 2 +- src/cpp/src/device_config.hpp | 16 +++++++++++++++- 2 files changed, 16 insertions(+), 2 deletions(-) diff --git a/src/cpp/src/continuous_batching_pipeline.cpp b/src/cpp/src/continuous_batching_pipeline.cpp index 55100f3cb4..1639aeb3eb 100644 --- a/src/cpp/src/continuous_batching_pipeline.cpp +++ b/src/cpp/src/continuous_batching_pipeline.cpp @@ -77,7 +77,7 @@ class ContinuousBatchingPipeline::Impl { // The model can be compiled for GPU as well std::shared_ptr model = core.read_model(models_path + "/openvino_model.xml"); - DeviceConfig device_config(core, scheduler_config, device); + DeviceConfig device_config(core, scheduler_config, device, plugin_config); apply_paged_attention_transformations(model, device_config); diff --git a/src/cpp/src/device_config.hpp b/src/cpp/src/device_config.hpp index f2ed5d424b..9be7bf170d 100644 --- a/src/cpp/src/device_config.hpp +++ b/src/cpp/src/device_config.hpp @@ -20,7 +20,7 @@ class DeviceConfig { std::string m_device; public: - DeviceConfig(ov::Core& core, const SchedulerConfig& scheduling_config, const std::string& device) { + DeviceConfig(ov::Core& core, const SchedulerConfig& scheduling_config, const std::string& device, const ov::AnyMap& plugin_config = {}) { m_device = device; // keep information about blocsk @@ -29,6 +29,20 @@ class DeviceConfig { if (m_device == "CPU") { auto inference_precision = core.get_property(device, ov::hint::inference_precision); m_kv_cache_type = inference_precision == ov::element::bf16 ? ov::element::bf16 : ov::element::f16; + // if user sets precision hint, kv cache type should be changed + if (plugin_config.find("INFERENCE_PRECISION_HINT") != plugin_config.end()) { + const auto& type_name = plugin_config.at("INFERENCE_PRECISION_HINT").as(); + if (type_name == "f32") { + m_kv_cache_type = ov::element::f32; + } else if (type_name == "f16") { + m_kv_cache_type = ov::element::f16; + } else if (type_name == "bf16") { + m_kv_cache_type = ov::element::bf16; + } else { + // use default f32 + m_kv_cache_type = ov::element::f32; + } + } } else if (m_device == "GPU") { OPENVINO_ASSERT("GPU is not currently supported. Please, remove this assert and fill configuration"); } else { From 5dc319f7719c9b1bbe678f9295e315991e77108e Mon Sep 17 00:00:00 2001 From: Luo Cheng Date: Mon, 5 Aug 2024 09:20:30 +0800 Subject: [PATCH 2/3] apply review comments --- src/cpp/src/device_config.hpp | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/cpp/src/device_config.hpp b/src/cpp/src/device_config.hpp index 9be7bf170d..c704a326dc 100644 --- a/src/cpp/src/device_config.hpp +++ b/src/cpp/src/device_config.hpp @@ -30,8 +30,8 @@ class DeviceConfig { auto inference_precision = core.get_property(device, ov::hint::inference_precision); m_kv_cache_type = inference_precision == ov::element::bf16 ? ov::element::bf16 : ov::element::f16; // if user sets precision hint, kv cache type should be changed - if (plugin_config.find("INFERENCE_PRECISION_HINT") != plugin_config.end()) { - const auto& type_name = plugin_config.at("INFERENCE_PRECISION_HINT").as(); + if (plugin_config.find(ov::hint::inference_precision.name()) != plugin_config.end()) { + const auto& type_name = plugin_config.at(ov::hint::inference_precision.name()).as(); if (type_name == "f32") { m_kv_cache_type = ov::element::f32; } else if (type_name == "f16") { From a8f173c4e18a45b042e3a567cdfa977254ed8c3d Mon Sep 17 00:00:00 2001 From: Luo Cheng Date: Mon, 5 Aug 2024 18:00:56 +0800 Subject: [PATCH 3/3] apply review comment --- src/cpp/src/device_config.hpp | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/src/cpp/src/device_config.hpp b/src/cpp/src/device_config.hpp index c704a326dc..cb653ac514 100644 --- a/src/cpp/src/device_config.hpp +++ b/src/cpp/src/device_config.hpp @@ -31,12 +31,12 @@ class DeviceConfig { m_kv_cache_type = inference_precision == ov::element::bf16 ? ov::element::bf16 : ov::element::f16; // if user sets precision hint, kv cache type should be changed if (plugin_config.find(ov::hint::inference_precision.name()) != plugin_config.end()) { - const auto& type_name = plugin_config.at(ov::hint::inference_precision.name()).as(); - if (type_name == "f32") { + const auto precision = plugin_config.at(ov::hint::inference_precision.name()).as(); + if (precision == ov::element::f32) { m_kv_cache_type = ov::element::f32; - } else if (type_name == "f16") { + } else if (precision == ov::element::f16) { m_kv_cache_type = ov::element::f16; - } else if (type_name == "bf16") { + } else if (precision == ov::element::bf16) { m_kv_cache_type = ov::element::bf16; } else { // use default f32