diff --git a/src/cpp/src/generation_handle.cpp b/src/cpp/src/generation_handle.cpp index 26cc12604f..f8e88bfecb 100644 --- a/src/cpp/src/generation_handle.cpp +++ b/src/cpp/src/generation_handle.cpp @@ -36,6 +36,7 @@ void add_partial_result(std::unordered_map& partial_ } else { partial_result_iter->second.generated_token_ids.push_back(iteration_result.second.generated_token_ids[0]); partial_result_iter->second.score = iteration_result.second.score; + partial_result_iter->second.finish_reason = iteration_result.second.finish_reason; } } } diff --git a/src/cpp/src/sequence_group.hpp b/src/cpp/src/sequence_group.hpp index db227a3436..a8fb528554 100644 --- a/src/cpp/src/sequence_group.hpp +++ b/src/cpp/src/sequence_group.hpp @@ -111,6 +111,7 @@ class Sequence { OPENVINO_ASSERT(m_generated_ids.size()); output.score = get_cumulative_log_probs(); output.generated_token_ids = std::vector {m_generated_ids.back()}; + output.finish_reason = get_finish_reason(); return output; } @@ -215,10 +216,11 @@ class SequenceGroup { // stop sequence by max_new_tokens or EOS token running_sequence->set_status(SequenceStatus::FINISHED); - if (running_sequence->get_generated_ids().back() == m_sampling_params.eos_token_id && !m_sampling_params.ignore_eos) + if (running_sequence->get_generated_ids().back() == m_sampling_params.eos_token_id && !m_sampling_params.ignore_eos) { running_sequence->set_finish_reason(GenerationFinishReason::STOP); - else if (m_sampling_params.max_new_tokens == generated_len) + } else if (m_sampling_params.max_new_tokens == generated_len) { running_sequence->set_finish_reason(GenerationFinishReason::LENGTH); + } dropped_seq_ids.push_back(running_sequence->get_id()); }