From 5b3c1857c631c11de545e50f1ed2f5141f952d27 Mon Sep 17 00:00:00 2001 From: mzegla Date: Thu, 1 Aug 2024 16:51:12 +0200 Subject: [PATCH] beam search --- src/cpp/src/sampler.hpp | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/src/cpp/src/sampler.hpp b/src/cpp/src/sampler.hpp index 6390fc8725..2ce87531ff 100644 --- a/src/cpp/src/sampler.hpp +++ b/src/cpp/src/sampler.hpp @@ -193,6 +193,8 @@ class GroupBeamSearcher { // mark current sequence as finished beam.m_sequence->set_status(SequenceStatus::FINISHED); + // Setting finish reason to LENGTH since this function is used when the number of tokens generated by the sequence reaches max_new_tokens + beam.m_sequence->set_finish_reason(GenerationFinishReason::LENGTH); // we also need to drop add ongoing / forked sequences from scheduler sampler_output.m_dropped_sequences.push_back(sequence_id); } @@ -432,6 +434,8 @@ void GroupBeamSearcher::select_next_tokens(const ov::Tensor& logits, SamplerOutp Sequence::Ptr forked_sequence = m_sequence_group->fork_sequence(candidate.m_sequence); // and finish immidiately forked_sequence->set_status(SequenceStatus::FINISHED); + // Setting finish reason to STOP since this code is used when the sequence has generated the eos token + forked_sequence->set_finish_reason(GenerationFinishReason::STOP); // TODO: make it more simplier // currently, we finish sequence and then fork it in current code