From 67bcef14e0833decabc215d2f744dba69ac27412 Mon Sep 17 00:00:00 2001 From: Oleg Pipikin Date: Sat, 12 Oct 2024 07:20:56 +0200 Subject: [PATCH] Slice the last matmull in stateful llm pipeline (#814) Ticket: CVS-154175 Co-authored-by: Andrei Kochin --- src/cpp/src/greedy_decoding.cpp | 1 - src/cpp/src/llm_pipeline.cpp | 5 ++++- src/cpp/src/utils.cpp | 34 +++++++++++++++++++++++++++++++++ src/cpp/src/utils.hpp | 2 ++ 4 files changed, 40 insertions(+), 2 deletions(-) diff --git a/src/cpp/src/greedy_decoding.cpp b/src/cpp/src/greedy_decoding.cpp index 95a1843645..2f1ed3f89d 100644 --- a/src/cpp/src/greedy_decoding.cpp +++ b/src/cpp/src/greedy_decoding.cpp @@ -73,7 +73,6 @@ EncodedResults greedy_decoding( bool all_are_eos = std::all_of(eos_met.begin(), eos_met.end(), [](int elem) { return elem == 1; }); if (!generation_config.ignore_eos && all_are_eos) return results; - for (size_t i = 0; i < max_new_tokens - 1; ++i) { if (position_ids.has_value()) diff --git a/src/cpp/src/llm_pipeline.cpp b/src/cpp/src/llm_pipeline.cpp index ff7ceb051e..417a66edc0 100644 --- a/src/cpp/src/llm_pipeline.cpp +++ b/src/cpp/src/llm_pipeline.cpp @@ -82,12 +82,15 @@ class StatefulLLMPipeline final : public LLMPipelineImplBase { core.set_property(core_plugin_config); auto model = core.read_model(model_path / "openvino_model.xml"); m_adapter_controller = AdapterController(model, m_generation_config.adapters, "base_model.model.model.", device); // TODO: Make the prefix name configurable + utils::slice_matmul_statefull_model(model); m_model_runner = core.compile_model(model, device, compile_plugin_config).create_infer_request(); m_adapter_controller->apply(m_model_runner, m_generation_config.adapters); } else { auto [core_plugin_config, compile_plugin_config] = ov::genai::utils::split_core_complile_config(plugin_config); core.set_property(core_plugin_config); - m_model_runner = core.compile_model(model_path / "openvino_model.xml", device, compile_plugin_config).create_infer_request(); + auto model = core.read_model(model_path / "openvino_model.xml"); + utils::slice_matmul_statefull_model(model); + m_model_runner = core.compile_model(model, device, compile_plugin_config).create_infer_request(); } // If eos_token_id was not provided, take value diff --git a/src/cpp/src/utils.cpp b/src/cpp/src/utils.cpp index 229c418e54..e7f58a015e 100644 --- a/src/cpp/src/utils.cpp +++ b/src/cpp/src/utils.cpp @@ -5,6 +5,14 @@ #include +#include "openvino/op/add.hpp" +#include "openvino/op/divide.hpp" +#include "openvino/op/multiply.hpp" +#include "openvino/op/matmul.hpp" +#include "openvino/op/slice.hpp" +#include "openvino/op/tanh.hpp" +#include "openvino/op/transpose.hpp" + namespace ov { namespace genai { namespace utils { @@ -225,6 +233,32 @@ ov::genai::TokenizedInputs subtract_chat_tokenized_inputs(const ov::genai::Token return {new_input_ids, new_attention_mask}; } + +void slice_matmul_statefull_model(std::shared_ptr model) { + ov::Node* matmul = nullptr; + auto last_node = model->output(0).get_node()->input_value(0).get_node(); + if (matmul = dynamic_cast(last_node)) { + } else if(auto add = dynamic_cast(last_node)) { + matmul = dynamic_cast(add->input_value(0).get_node()); + } else if (auto transpose = dynamic_cast(last_node)) { + matmul = dynamic_cast(transpose->input_value(0).get_node()); + } else if (auto multiply = dynamic_cast(last_node)) { + if (auto tanh = dynamic_cast(multiply->input_value(0).get_node())) { + if (auto divide = dynamic_cast(tanh->input_value(0).get_node())) { + matmul = dynamic_cast(divide->input_value(0).get_node()); + } + } + } + + if (matmul && matmul->input(0).get_partial_shape().rank().get_length() == 3) { + auto start = std::make_shared(ov::element::i64, ov::Shape{1}, std::vector{-1}); + auto stop = std::make_shared(ov::element::i64, ov::Shape{1}, std::vector{-2}); + auto step = std::make_shared(ov::element::i64, ov::Shape{1}, std::vector{-1}); + auto axis = std::make_shared(ov::element::i64, ov::Shape{1}, std::vector{1}); + auto slice = std::make_shared(matmul->input_value(0), start, stop, step, axis); + matmul->input(0).replace_source_output(slice); + } +} } // namespace utils } // namespace genai } // namespace ov diff --git a/src/cpp/src/utils.hpp b/src/cpp/src/utils.hpp index 2b7ff18e2d..b5228eede0 100644 --- a/src/cpp/src/utils.hpp +++ b/src/cpp/src/utils.hpp @@ -87,6 +87,8 @@ ProcessorConfig from_any_map( std::pair split_core_complile_config(const ov::AnyMap& plugin_config); ov::genai::TokenizedInputs subtract_chat_tokenized_inputs(const ov::genai::TokenizedInputs& minuend, const ov::genai::TokenizedInputs& subtrahend); + +void slice_matmul_statefull_model(std::shared_ptr model); } // namespace utils } // namespace genai } // namespace ov