From 6ff305529b38c0196eeb739c073c00d3e78e891a Mon Sep 17 00:00:00 2001 From: sbalandi Date: Wed, 16 Oct 2024 12:41:43 +0100 Subject: [PATCH] apply comments --- .../visual_language_chat.cpp | 19 ++++++++++--------- src/cpp/src/sampler.cpp | 5 +++-- src/cpp/src/sampler.hpp | 2 +- src/cpp/src/visual_language/pipeline.cpp | 8 ++------ 4 files changed, 16 insertions(+), 18 deletions(-) diff --git a/samples/cpp/visual_language_chat/visual_language_chat.cpp b/samples/cpp/visual_language_chat/visual_language_chat.cpp index 3d6e412ee5..c5e24247c2 100644 --- a/samples/cpp/visual_language_chat/visual_language_chat.cpp +++ b/samples/cpp/visual_language_chat/visual_language_chat.cpp @@ -6,8 +6,6 @@ #include #include -namespace fs = std::filesystem; - bool print_subword(std::string&& subword) { return !(std::cout << subword << std::flush); } @@ -19,6 +17,9 @@ int main(int argc, char* argv[]) try { std::vector images = utils::load_images(argv[2]); + ov::genai::GenerationConfig generation_config; + generation_config.max_new_tokens = 30; + std::string device = "CPU"; // GPU can be used as well ov::AnyMap enable_compile_cache; if ("GPU" == device) { @@ -33,16 +34,16 @@ int main(int argc, char* argv[]) try { std::cout << "question:\n"; std::getline(std::cin, prompt); - auto resuls = pipe.generate(prompt, - ov::genai::images(images), - ov::genai::generation_config(ov::genai::greedy()), - ov::genai::streamer(print_subword)); + pipe.generate(prompt, + ov::genai::images(images), + ov::genai::generation_config(generation_config), + ov::genai::streamer(print_subword)); std::cout << "\n----------\n" "question:\n"; while (std::getline(std::cin, prompt)) { - resuls = pipe.generate(prompt, - ov::genai::generation_config(ov::genai::greedy()), - ov::genai::streamer(print_subword)); + pipe.generate(prompt, + ov::genai::generation_config(generation_config), + ov::genai::streamer(print_subword)); std::cout << "\n----------\n" "question:\n"; } diff --git a/src/cpp/src/sampler.cpp b/src/cpp/src/sampler.cpp index 2e631f6201..4885a21b8f 100644 --- a/src/cpp/src/sampler.cpp +++ b/src/cpp/src/sampler.cpp @@ -597,10 +597,11 @@ void register_new_token(const Token& sampled_token_id, } }; -std::vector Sampler::get_beam_idxs(uint64_t request_id) { +std::vector Sampler::get_beam_idxs(SequenceGroup::CPtr sequence_group) { + size_t request_id = sequence_group->get_request_id(); auto beam_searcher = m_beam_search_info.find(request_id); if (m_beam_search_info.find(request_id) == m_beam_search_info.end()) { - return { 0 }; + return std::vector(sequence_group->num_running_seqs(), 0); } return beam_searcher->second.get_beam_idxs(); } diff --git a/src/cpp/src/sampler.hpp b/src/cpp/src/sampler.hpp index ca73cbb92d..13933e0b75 100644 --- a/src/cpp/src/sampler.hpp +++ b/src/cpp/src/sampler.hpp @@ -65,7 +65,7 @@ class Sampler { SamplerOutput sample(std::vector & sequence_groups, ov::Tensor logits, bool is_validation_mode_enabled = false); void set_seed(size_t seed) { rng_engine.seed(seed); } void clear_beam_search_info(uint64_t request_id); - std::vector get_beam_idxs(uint64_t request_id); + std::vector get_beam_idxs(SequenceGroup::CPtr sequence_group); }; class Sampler::GroupBeamSearcher { diff --git a/src/cpp/src/visual_language/pipeline.cpp b/src/cpp/src/visual_language/pipeline.cpp index 42bc8eb465..773536270d 100644 --- a/src/cpp/src/visual_language/pipeline.cpp +++ b/src/cpp/src/visual_language/pipeline.cpp @@ -326,13 +326,9 @@ EncodedGenerationResult get_lm_encoded_results( language.set_tensor("position_ids", position_ids); - std::vector beam_idxs = sampler.get_beam_idxs(request->get_request_id()); + std::vector beam_idxs = sampler.get_beam_idxs(request); int32_t *beam_idx_data = beam_idx.data(); - if (total_num_tokens > beam_idxs.size()) { - std::fill_n(beam_idx_data, total_num_tokens, 0); - } else { - copy(beam_idxs.begin(), beam_idxs.end(), beam_idx_data); - } + std::fill_n(beam_idx_data, total_num_tokens, 0); language.set_tensor("beam_idx", beam_idx); language.infer();