Add a choice of how to end streaming from callback: STOP or CANCEL #1476

Draft · wants to merge 1 commit into base: master
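In practice the change means a streaming callback can report not just whether to stop but how: STOP keeps the partial answer in the chat history and KV cache, while CANCEL rolls both back. A minimal caller-side sketch under stated assumptions (the model directory and device are illustrative, and the generate() overload taking a config plus a streamer is assumed to accept the new callback type):

#include <iostream>
#include <string>

#include "openvino/genai/llm_pipeline.hpp"

int main() {
    // Model directory and device are illustrative.
    ov::genai::LLMPipeline pipe("TinyLlama-1.1B-Chat", "CPU");

    // New-style callback: instead of a bare bool it reports an explicit status, so the
    // pipeline can tell "stop but keep what was generated" (STOP) apart from
    // "stop and roll the chat history / KV cache back" (CANCEL).
    auto callback = [](std::string subword) -> ov::genai::CallbacWorkStatus {
        std::cout << subword << std::flush;
        bool user_pressed_stop = false;   // placeholders for real UI signals
        bool user_pressed_cancel = false;
        if (user_pressed_cancel)
            return ov::genai::CallbacWorkStatus::CANCEL;
        if (user_pressed_stop)
            return ov::genai::CallbacWorkStatus::STOP;
        return ov::genai::CallbacWorkStatus::RUNNING;
    };

    ov::genai::GenerationConfig config = pipe.get_generation_config();
    config.max_new_tokens = 128;

    pipe.start_chat();
    pipe.generate("Tell me a short story", config, callback);
    pipe.finish_chat();
    return 0;
}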
4 changes: 4 additions & 0 deletions src/cpp/include/openvino/genai/generation_handle.hpp
@@ -7,6 +7,7 @@
#include <unordered_map>

#include "openvino/genai/generation_config.hpp"
#include "openvino/genai/streamer_base.hpp"
#include "openvino/genai/visibility.hpp"

namespace ov::genai {
@@ -30,6 +31,9 @@ struct EncodedGenerationResult {

// Status of generation
GenerationStatus m_status = GenerationStatus::RUNNING;

// Status of streaming
CallbacWorkStatus m_streaming_status = ov::genai::CallbacWorkStatus::UNDEF;
};

enum class GenerationFinishReason {
2 changes: 1 addition & 1 deletion src/cpp/include/openvino/genai/llm_pipeline.hpp
@@ -19,7 +19,7 @@ namespace ov {
namespace genai {

// The return flag indicates whether generation should be stopped: false means continue generation, true means stop.
using StreamerVariant = std::variant<std::function<bool(std::string)>, std::shared_ptr<StreamerBase>, std::monostate>;
using StreamerVariant = std::variant<std::function<bool(std::string)>, std::function<CallbacWorkStatus(std::string)>, std::shared_ptr<StreamerBase>, std::monostate>;
using OptionalGenerationConfig = std::optional<GenerationConfig>;
using EncodedInputs = std::variant<ov::Tensor, TokenizedInputs>;
using StringInputs = std::variant<std::string, std::vector<std::string>>;
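Both callback forms now coexist inside StreamerVariant. A purely illustrative sketch of constructing each alternative, assuming no other overloads change:

#include <functional>
#include <string>

#include "openvino/genai/llm_pipeline.hpp"

// Legacy callback: the bool return only says "stop or not".
ov::genai::StreamerVariant legacy_cb = std::function<bool(std::string)>(
    [](std::string) { return false; });

// New callback: the returned status also says how to stop (STOP vs CANCEL).
ov::genai::StreamerVariant status_cb = std::function<ov::genai::CallbacWorkStatus(std::string)>(
    [](std::string) { return ov::genai::CallbacWorkStatus::RUNNING; });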
16 changes: 16 additions & 0 deletions src/cpp/include/openvino/genai/streamer_base.hpp
@@ -4,16 +4,28 @@
#pragma once

#include "openvino/genai/tokenizer.hpp"
#include <variant>

namespace ov {
namespace genai {

enum class CallbacWorkStatus {
UNDEF = 0, // Streaming has not run
RUNNING = 1, // Continue running inference
STOP = 2, // Stop generation, keep the history as is; the KV cache keeps the last prompt and the generated tokens at the end
CANCEL = 3 // Stop generation, drop the last prompt and all generated tokens from the history; the KV cache keeps the history except the last step
};

using CallbackTypeVariant = std::variant<bool, CallbacWorkStatus>;

/**
* @brief Base class for streamers. To use it, inherit from this class and implement the put() and end() methods.
*
* @param m_tokenizer tokenizer
*/
class OPENVINO_GENAI_EXPORTS StreamerBase {
protected:
CallbacWorkStatus streaming_finish_status = CallbacWorkStatus::UNDEF;
public:
/// @brief put is called every time a new token is decoded.
/// @return bool flag indicating whether generation should be stopped; if true, generation stops
@@ -22,6 +34,10 @@ class OPENVINO_GENAI_EXPORTS StreamerBase {
/// @brief end is called at the end of generation. It can be used to flush the cache if your streamer has one
virtual void end() = 0;

virtual CallbacWorkStatus get_finish_streaming_reason() {
return streaming_finish_status;
}

virtual ~StreamerBase();
};

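A derived streamer can set the new protected streaming_finish_status member before returning true from put(). A hedged sketch, assuming put() keeps its existing bool put(int64_t) signature, which is not shown in this hunk:

#include <atomic>
#include <cstdint>

#include "openvino/genai/streamer_base.hpp"

// Hypothetical streamer that cancels generation when an external flag is raised.
class CancellableStreamer : public ov::genai::StreamerBase {
public:
    explicit CancellableStreamer(std::atomic<bool>& cancel_requested)
        : m_cancel_requested(cancel_requested) {}

    bool put(int64_t /*token*/) override {
        if (m_cancel_requested.load()) {
            // Record why streaming ended so the pipeline can roll the chat
            // history and KV cache back instead of keeping the partial answer.
            streaming_finish_status = ov::genai::CallbacWorkStatus::CANCEL;
            return true;   // stop the generation loop
        }
        streaming_finish_status = ov::genai::CallbacWorkStatus::RUNNING;
        return false;      // keep generating
    }

    void end() override {
        // Nothing buffered in this sketch; a real streamer would flush its text cache here.
    }

private:
    std::atomic<bool>& m_cancel_requested;
};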
6 changes: 6 additions & 0 deletions src/cpp/src/continuous_batching_impl.cpp
@@ -274,6 +274,9 @@ ContinuousBatchingPipeline::ContinuousBatchingImpl::generate(const std::vector<o
},
[this](const std::function<bool(std::string)>& streamer) -> std::shared_ptr<StreamerBase> {
return std::make_unique<TextCallbackStreamer>(m_tokenizer, streamer);
},
[this](const std::function<CallbacWorkStatus(std::string)>& streamer) -> std::shared_ptr<StreamerBase> {
return std::make_unique<TextCallbackStreamer>(m_tokenizer, streamer);
}
}, streamer);

@@ -346,6 +349,9 @@ ContinuousBatchingPipeline::ContinuousBatchingImpl::generate(const std::vector<o
result.m_generation_ids.resize(num_outputs);
result.m_scores.resize(num_outputs);

if (streamer_ptr)
result.m_streaming_status = streamer_ptr->get_finish_streaming_reason();

for (size_t i = 0; i < num_outputs; ++i) {
const auto & sequence = sequences[i];
const float score = sampling_params.is_beam_search() ? sequence->get_beam_search_score(sampling_params) : sequence->get_cumulative_log_prob();
6 changes: 5 additions & 1 deletion src/cpp/src/icontinuous_batching.cpp
@@ -64,7 +64,7 @@ ContinuousBatchingPipeline::IContinuousBatchingPipeline::generate(
generated.reserve(res.m_generation_ids.size());
for (size_t idx = 0; idx < res.m_generation_ids.size(); ++idx) {
generated.push_back(m_tokenizer.decode(res.m_generation_ids.at(idx)));
if (m_is_chat_conversation && 0 == idx) {
if (m_is_chat_conversation && 0 == idx && res.m_streaming_status != ov::genai::CallbacWorkStatus::CANCEL) {
m_history.push_back({{"role", "assistant"}, {"content", generated.back()}});
}
}
Expand All @@ -77,6 +77,10 @@ ContinuousBatchingPipeline::IContinuousBatchingPipeline::generate(
});
}

// if streaming was canceled, the prompt/answer of the current step shouldn't be present in the history, so let's remove the prompt from the history
if (m_is_chat_conversation && !encoded.empty() && encoded[0].m_streaming_status == ov::genai::CallbacWorkStatus::CANCEL)
m_history.pop_back();

return decoded;
}
}
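The net effect of the branch above, condensed into a standalone sketch with illustrative names: on CANCEL the user prompt pushed before generation is removed and no assistant turn is appended, so the chat history returns to its pre-call state.

#include <map>
#include <string>
#include <vector>

#include "openvino/genai/streamer_base.hpp"

using ChatHistory = std::vector<std::map<std::string, std::string>>;

// Sketch of the bookkeeping above for a single chat turn.
void finish_turn(ChatHistory& history, const std::string& answer,
                 ov::genai::CallbacWorkStatus status) {
    if (status == ov::genai::CallbacWorkStatus::CANCEL) {
        // CANCEL: the user prompt pushed before generation is dropped as well,
        // so the history ends up exactly as it was before this turn.
        history.pop_back();
    } else {
        // STOP or normal completion: keep the (possibly partial) answer.
        history.push_back({{"role", "assistant"}, {"content", answer}});
    }
}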
77 changes: 52 additions & 25 deletions src/cpp/src/llm_pipeline_stateful.cpp
@@ -9,6 +9,8 @@
#include "text_callback_streamer.hpp"
#include "utils.hpp"

#include "debug_utils.hpp"

namespace ov::genai {

StatefulLLMPipeline::StatefulLLMPipeline(
@@ -38,8 +40,8 @@ StatefulLLMPipeline::StatefulLLMPipeline(
const ov::AnyMap& properties,
const ov::genai::GenerationConfig& generation_config)
: LLMPipelineImplBase(tokenizer, generation_config), m_sampler(m_tokenizer) {
utils::apply_slice_before_matmul_transformation(model);
m_kv_cache_seq_length_axis = ov::genai::utils::get_seq_len_axis(model);
utils::slice_matmul_stateful_model(model);
m_kv_history_manager.kv_cache_seq_length_axis = ov::genai::utils::get_seq_len_axis(model);

ov::CompiledModel compiled_model;
if (auto filtered_properties = extract_adapters_from_properties(properties, &m_generation_config.adapters)) {
@@ -86,6 +88,9 @@ DecodedResults StatefulLLMPipeline::generate(

TokenizedInputs encoded_input;

std::string prev_templated_chat_history(m_templated_chat_history);
std::vector<int64_t> prev_tokenized_chat_history(m_tokenized_chat_history);

if (auto input_vector = std::get_if<std::vector<std::string>>(&inputs)) {
OPENVINO_ASSERT(!is_chat_conversation, "Can't chat with multiple prompts");
encoded_input = m_tokenizer.encode(*input_vector);
@@ -104,7 +109,7 @@

m_history.push_back({{"role", "user"}, {"content", prompt}});
constexpr bool add_generation_prompt = true;
auto new_templated_chat_history = m_tokenizer.apply_chat_template(m_history, add_generation_prompt);
auto new_templated_chat_history = m_tokenizer.apply_chat_template(m_history, add_generation_prompt);
// Do not add special tokens in chat scenario to be aligned with HF.
auto new_chat_tokens = m_tokenizer.encode(new_templated_chat_history, ov::genai::add_special_tokens(false));
auto prev_chat_tokens = m_tokenizer.encode(m_templated_chat_history, ov::genai::add_special_tokens(false));
@@ -116,21 +121,24 @@
if (!m_tokenized_chat_history.empty()) {
std::set<int64_t> stop_tokens = config.stop_token_ids;
trusted_history_length = ov::genai::utils::get_first_history_difference(prev_chat_tokens.input_ids, m_tokenized_chat_history, stop_tokens);
m_trust_encoded_history = trusted_history_length == SIZE_MAX;
}

if (m_tokenized_chat_history.empty()) {
encoded_input = new_chat_tokens;
} else if (trusted_history_length != SIZE_MAX || m_kv_history_manager.does_kv_cache_need_to_update()) {
// does_kv_cache_need_to_update will be true here if beam search is activated
} else if (trusted_history_length != SIZE_MAX || m_kv_history_manager.does_history_cache_need_to_update()) {
// does_history_cache_need_to_update will be true here if beam search is activated
// in beam search mode we want to remove all history about the last model answer from the kv cache and add the best answer directly
// if there is a difference between the model answer and the decoded answer, it will still be smaller than the entire history, so let's use the data from m_kv_history_manager
if (m_kv_history_manager.does_kv_cache_need_to_update()) {
if (m_kv_history_manager.does_history_cache_need_to_update()) {
trusted_history_length = m_kv_history_manager.trusted_history_length;
} else {
m_kv_history_manager.num_tokens_to_remove_from_kv_cache = m_tokenized_chat_history.size() - trusted_history_length;
size_t num_tokens_to_remove_from_kv_cache = m_tokenized_chat_history.size() - trusted_history_length;
// if the previous generation finished because the max length was reached, the kv cache is missing the last token, so let's keep it
m_kv_history_manager.num_tokens_to_remove_from_kv_cache -= m_last_disappeared_token.has_value() ? 1 : 0;
num_tokens_to_remove_from_kv_cache -= m_last_disappeared_token.has_value() ? 1 : 0;

// if streaming was used and canceled on the previous step, num_tokens_to_remove_from_kv_cache may already be set and will be larger, since it includes the answer + prompt
m_kv_history_manager.num_tokens_to_remove_from_kv_cache = num_tokens_to_remove_from_kv_cache > m_kv_history_manager.num_tokens_to_remove_from_kv_cache ?
num_tokens_to_remove_from_kv_cache : m_kv_history_manager.num_tokens_to_remove_from_kv_cache;
}

ov::Tensor new_tensor = ov::Tensor(new_chat_tokens.input_ids.get_element_type(),
@@ -169,11 +177,19 @@ DecodedResults StatefulLLMPipeline::generate(
auto decode_stop_time = std::chrono::steady_clock::now();

if (is_chat_conversation) {
// Tail of chat template is missing in KV cache.
// Find the tail to concatenate it with the next input prompt.
auto answer = decoded_results.texts[0];
m_templated_chat_history.append(answer);
m_history.push_back({{"role", "assistant"}, {"content", answer}});
if (m_chat_generation_finish_status == ov::genai::CallbacWorkStatus::CANCEL) {
// If the chat generation process was canceled by the user, roll back to the previous state of the history
m_history.pop_back();
m_kv_history_manager.num_tokens_to_remove_from_kv_cache += m_tokenized_chat_history.size() - prev_tokenized_chat_history.size();
m_templated_chat_history = prev_templated_chat_history;
m_tokenized_chat_history = prev_tokenized_chat_history;
} else {
// Tail of chat template is missing in KV cache.
// Find the tail to concatenate it with the next input prompt.
auto answer = decoded_results.texts[0];
m_templated_chat_history.append(answer);
m_history.push_back({{"role", "assistant"}, {"content", answer}});
}
}

// generate_durations
@@ -218,6 +234,8 @@ EncodedResults StatefulLLMPipeline::generate(
if (is_chat_conversation && m_chat_input_type == ov::genai::utils::GenerationChatInputsType::ENCODED_INPUTS)
std::copy(input_ids.data<int64_t>(), input_ids.data<int64_t>() + input_ids.get_size(), std::back_inserter(m_tokenized_chat_history));

size_t real_input_ids_size = input_ids.get_shape().at(1);

// Tail of previous output in chat mode is missing in KV cache.
if (m_last_disappeared_token.has_value()) {
attention_mask = ov::genai::utils::push_front_inputs(attention_mask, 1);
@@ -241,6 +259,8 @@ EncodedResults StatefulLLMPipeline::generate(
streamer_ptr = *streamer_obj;
} else if (auto callback = std::get_if<std::function<bool(std::string)>>(&streamer)) {
streamer_ptr = std::make_shared<TextCallbackStreamer>(m_tokenizer, *callback);
} else if (auto callback = std::get_if<std::function<ov::genai::CallbacWorkStatus(std::string)>>(&streamer)) {
streamer_ptr = std::make_shared<TextCallbackStreamer>(m_tokenizer, *callback);
}

auto batch_size = input_ids.get_shape().at(0);
@@ -254,7 +274,8 @@ EncodedResults StatefulLLMPipeline::generate(
"(input_ids, attention_mask, position_ids, beam_idx) "
"but you have '" + std::to_string(num_inputs) + "' inputs");

ov::genai::utils::trim_kv_cache(m_model_runner, m_kv_history_manager.num_tokens_to_remove_from_kv_cache, m_kv_cache_seq_length_axis, m_adapter_controller);
ov::genai::utils::trim_kv_cache(m_model_runner, m_kv_history_manager.num_tokens_to_remove_from_kv_cache,
m_kv_history_manager.kv_cache_seq_length_axis, m_adapter_controller);

size_t kv_cache_len = 0;
ov::Tensor concatenated_attention_mask;
@@ -292,8 +313,7 @@ EncodedResults StatefulLLMPipeline::generate(
m_adapter_controller->apply(m_model_runner, config.adapters);
}

if (is_chat_conversation && !m_trust_encoded_history) {
m_trust_encoded_history = true;
if (is_chat_conversation) {
m_kv_history_manager.reset();
}

@@ -321,26 +341,35 @@ EncodedResults StatefulLLMPipeline::generate(
m_sampler.set_seed(config.rng_seed);
}

ov::genai::EncodedResults result;
std::tie(result, m_last_disappeared_token) = get_lm_encoded_results(m_model_runner, input_ids, concatenated_attention_mask,
ov::genai::utils::GenerationFinishInfo finish_info = get_lm_encoded_results(m_model_runner, input_ids, concatenated_attention_mask,
streamer_ptr, m_sampler, requests, position_ids, std::nullopt);

ov::genai::EncodedResults result = finish_info.results;
m_last_disappeared_token = finish_info.probably_disappeared_token;
m_chat_generation_finish_status = finish_info.streaming_finish_status;

if (is_chat_conversation) {
// force removal of the last answer from the kv cache
if (config.is_beam_search() && m_chat_input_type != ov::genai::utils::GenerationChatInputsType::ENCODED_INPUTS) {
m_kv_history_manager.trusted_history_length = m_tokenized_chat_history.size();
m_kv_history_manager.num_tokens_to_remove_from_kv_cache = m_model_runner.get_tensor("attention_mask").get_shape()[1] - prev_attn_mask_size;
}

std::copy(result.tokens[0].begin(), result.tokens[0].end(), std::back_inserter(m_tokenized_chat_history));
if (m_chat_generation_finish_status == ov::genai::CallbacWorkStatus::CANCEL) {
m_kv_history_manager.num_tokens_to_remove_from_kv_cache = m_model_runner.get_tensor("attention_mask").get_shape()[1] - prev_attn_mask_size;

if (m_chat_input_type == ov::genai::utils::GenerationChatInputsType::ENCODED_INPUTS) {
m_tokenized_chat_history.resize(m_tokenized_chat_history.size() - real_input_ids_size);
m_kv_history_manager.num_tokens_to_remove_from_kv_cache += real_input_ids_size;
}
} else {
std::copy(result.tokens[0].begin(), result.tokens[0].end(), std::back_inserter(m_tokenized_chat_history));
}
} else {
reset_kv_state();
m_last_disappeared_token = std::nullopt;
}

if (is_chat_conversation && m_chat_input_type == ov::genai::utils::GenerationChatInputsType::ENCODED_INPUTS)
std::copy(result.tokens[0].begin(), result.tokens[0].end(), std::back_inserter(m_tokenized_chat_history));

auto stop_time = std::chrono::steady_clock::now();

// If called without tokenization, that stat will not be reported.
@@ -354,7 +383,6 @@ EncodedResults StatefulLLMPipeline::generate(

void StatefulLLMPipeline::start_chat(const std::string& system_message) {
is_chat_conversation = true;
m_trust_encoded_history = true;
m_kv_history_manager.reset();
m_chat_input_type = ov::genai::utils::GenerationChatInputsType::UNDEF;
m_last_disappeared_token = std::nullopt;
@@ -387,7 +415,6 @@ void StatefulLLMPipeline::reset_kv_state() {

void StatefulLLMPipeline::finish_chat() {
is_chat_conversation = false;
m_trust_encoded_history = true;
m_kv_history_manager.reset();
m_chat_input_type = ov::genai::utils::GenerationChatInputsType::UNDEF;
m_last_disappeared_token = std::nullopt;
5 changes: 3 additions & 2 deletions src/cpp/src/llm_pipeline_stateful.hpp
@@ -24,8 +24,9 @@ class StatefulLLMPipeline final : public LLMPipelineImplBase {
// If the sequence contains symbols that could be ambiguously encoded by the tokenizer, we need to trim the kv cache
// If we use beam search sampling in chat mode, we need to remove the model's last answer from the kv cache and add the best answer to the history
// so let's keep info about the number of tokens to trim from the kv cache and the number of tokens to keep in the history
ov::genai::utils::HistoryRemoveManager m_kv_history_manager = {0, 0};
size_t m_kv_cache_seq_length_axis = 2;
ov::genai::utils::HistoryRemoveManager m_kv_history_manager = {0, 0, 2};
// Finish reason of last generation for chat scenario
ov::genai::CallbacWorkStatus m_chat_generation_finish_status = ov::genai::CallbacWorkStatus::UNDEF;

void reset_kv_state();
public:
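For context, a reconstructed sketch of what utils::HistoryRemoveManager is assumed to look like after the third field is added. The field names come from the call sites in this diff; the real definition lives in utils.hpp and is not part of the diff, so the method bodies below are guesses.

#include <cstddef>

namespace ov::genai::utils {

// Assumed shape only; see the note above.
struct HistoryRemoveManager {
    size_t num_tokens_to_remove_from_kv_cache = 0;
    size_t trusted_history_length = 0;
    // Replaces StatefulLLMPipeline::m_kv_cache_seq_length_axis; 2 matches the old default.
    size_t kv_cache_seq_length_axis = 2;

    // Guessed: true only when beam search has scheduled a replacement of the last answer.
    bool does_history_cache_need_to_update() const {
        return trusted_history_length > 0;
    }

    void reset() {
        num_tokens_to_remove_from_kv_cache = 0;
        trusted_history_length = 0;
        // kv_cache_seq_length_axis is a model property, so it survives reset().
    }
};

}  // namespace ov::genai::utils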
19 changes: 10 additions & 9 deletions src/cpp/src/lm_encoding.cpp
@@ -13,6 +13,7 @@
#include "debug_utils.hpp"
#include "lm_encoding.hpp"
#include "openvino/genai/perf_metrics.hpp"
#include "openvino/genai/streamer_base.hpp"


namespace ov {
@@ -50,7 +51,7 @@ void update_attention_mask_with_beams(ov::Tensor&& attention_mask, std::vector<i
}


std::pair<EncodedResults, std::optional<int64_t>> get_lm_encoded_results(
ov::genai::utils::GenerationFinishInfo get_lm_encoded_results(
ov::InferRequest& m_llm,
const ov::Tensor& input_ids,
const ov::Tensor& attention_mask,
@@ -92,8 +93,8 @@

// Initialize results and performance metrics.

EncodedResults results;
auto& raw_perf_counters = results.perf_metrics.raw_metrics;
ov::genai::utils::GenerationFinishInfo finish_info;
auto& raw_perf_counters = finish_info.results.perf_metrics.raw_metrics;
raw_perf_counters.m_inference_durations = {{ MicroSeconds(0.0f) }};

// Initialize inputs
@@ -211,6 +212,7 @@ std::pair<EncodedResults, std::optional<int64_t>> get_lm_encoded_results(

if (streamer_ptr) { // push streamer's cache
streamer_ptr->end();
finish_info.streaming_finish_status = streamer_ptr->get_finish_streaming_reason();
}

for (auto& sequence_group : sequence_groups) {
@@ -222,20 +224,19 @@ std::pair<EncodedResults, std::optional<int64_t>> get_lm_encoded_results(
const auto & sequence = sequences[seq_id];
const float score = sampling_params.is_beam_search() ? sequence->get_beam_search_score(sampling_params) : sequence->get_cumulative_log_prob();

results.tokens.push_back(sequence->get_generated_ids());
results.scores.push_back(score);
finish_info.results.tokens.push_back(sequence->get_generated_ids());
finish_info.results.scores.push_back(score);
}
}

for (SequenceGroup::Ptr sequence_group : sequence_groups)
sampler.clear_request_info(sequence_group->get_request_id());

// it is not saved in KV cache, we need to add it for some cases
std::optional<int64_t> last_token_of_best_sequence = std::nullopt;
// the last generated token is not saved in the KV cache, so we need to add it in some cases
if (sequence_groups[0]->get_finished_sequences()[0]->get_finish_reason() == GenerationFinishReason::LENGTH || sequence_groups[0]->handle_dropped())
last_token_of_best_sequence = results.tokens[0].back();
finish_info.probably_disappeared_token = finish_info.results.tokens[0].back();

return {results, last_token_of_best_sequence};
return finish_info;
}

} // namespace genai
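For reference, a sketch of the GenerationFinishInfo helper this function now returns. The actual definition sits in utils.hpp and is not shown in the diff, so treat this as an assumption that only mirrors the three members used above.

#include <cstdint>
#include <optional>

#include "openvino/genai/llm_pipeline.hpp"   // ov::genai::EncodedResults
#include "openvino/genai/streamer_base.hpp"  // ov::genai::CallbacWorkStatus

namespace ov::genai::utils {

// Assumed definition; members match the usages in this diff.
struct GenerationFinishInfo {
    EncodedResults results;                            // generated tokens, scores, perf metrics
    std::optional<int64_t> probably_disappeared_token; // last token that is not stored in the KV cache
    CallbacWorkStatus streaming_finish_status = CallbacWorkStatus::UNDEF; // how the streamer ended
};

}  // namespace ov::genai::utils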