From 90b1ab1f5795bca9df761bcf5fda5949b66c9a5f Mon Sep 17 00:00:00 2001
From: Sofya Balandina
Date: Thu, 17 Oct 2024 14:44:11 +0100
Subject: [PATCH] Add sampling to vlm pipeline by Sampler (#950)

CVS-152890
---
 .github/workflows/causal_lm_cpp.yml           |  10 +-
 .../cpp/visual_language_chat/load_image.cpp   |  22 ++
 .../cpp/visual_language_chat/load_image.hpp   |   1 +
 .../visual_language_chat.cpp                  |  24 +-
 src/cpp/src/sampler.cpp                       |  25 ++
 src/cpp/src/sampler.hpp                       |   2 +
 src/cpp/src/visual_language/pipeline.cpp      | 260 ++++++++++--------
 src/cpp/src/vlm_sampling.hpp                  |  96 -------
 8 files changed, 221 insertions(+), 219 deletions(-)
 delete mode 100644 src/cpp/src/vlm_sampling.hpp

diff --git a/.github/workflows/causal_lm_cpp.yml b/.github/workflows/causal_lm_cpp.yml
index 2537a8eb36..4b7913bafd 100644
--- a/.github/workflows/causal_lm_cpp.yml
+++ b/.github/workflows/causal_lm_cpp.yml
@@ -710,7 +710,8 @@ jobs:
           python -m pip install -U "optimum<1.23" --no-dependencies
           source ./ov/setupvars.sh
           optimum-cli export openvino -m openbmb/MiniCPM-V-2_6 MiniCPM-V-2_6 --trust-remote-code
-          wget https://github.com/openvinotoolkit/openvino_notebooks/assets/29454499/d5fbbd1a-d484-415c-88cb-9986625b7b11 --output-document cat.jpg
+          mkdir cat_img
+          wget https://github.com/openvinotoolkit/openvino_notebooks/assets/29454499/d5fbbd1a-d484-415c-88cb-9986625b7b11 --output-document cat_img/cat.jpg
       - name: Generate reference
         shell: python
         run: |
@@ -741,6 +742,11 @@ jobs:
           && timeout 120s ./build/samples/cpp/visual_language_chat/visual_language_chat ./MiniCPM-V-2_6/ lines.png
          <<< $'What is unusual on this image?' | tee cpp.txt
       - run: diff cpp.txt ref.txt
+      - name: Run visual_language_chat C++ sample with dir - MiniCPM-V-2_6
+        run: >
+          source ./ov/setupvars.sh
+          && timeout 120s ./build/samples/cpp/visual_language_chat/visual_language_chat ./MiniCPM-V-2_6/ cat_img
+          <<< $'What is unusual on this image?'
       - name: Download and convert LLaVa 1.5 model and an image
         run: |
           source ./ov/setupvars.sh
@@ -768,7 +774,7 @@ jobs:
           source ./ov/setupvars.sh
           export PYTHONPATH=./build/:$PYTHONPATH
           printf 'What is on the image?\nWhat is special on the image?\n' > ./input.txt
-          timeout 120s python ./samples/python/visual_language_chat/visual_language_chat.py ./MiniCPM-V-2_6/ cat.jpg < input.txt > ./pred.txt
+          timeout 120s python ./samples/python/visual_language_chat/visual_language_chat.py ./MiniCPM-V-2_6/ cat_img/cat.jpg < input.txt > ./pred.txt

   cpp-continuous-batching-ubuntu:
     runs-on: ubuntu-20.04-8-cores
diff --git a/samples/cpp/visual_language_chat/load_image.cpp b/samples/cpp/visual_language_chat/load_image.cpp
index 855f7567bf..7956f8c128 100644
--- a/samples/cpp/visual_language_chat/load_image.cpp
+++ b/samples/cpp/visual_language_chat/load_image.cpp
@@ -6,6 +6,28 @@
 #include "stb_image.h"
 #include "load_image.hpp"
 
+namespace fs = std::filesystem;
+
+std::vector<ov::Tensor> utils::load_images(const std::filesystem::path& input_path) {
+    std::vector<ov::Tensor> images;
+    if (!input_path.empty() && fs::exists(input_path)) {
+        if (fs::is_directory(input_path)) {
+            for (const auto& dir_entry : fs::directory_iterator(input_path)) {
+                ov::Tensor image = utils::load_image(dir_entry.path());
+                images.push_back(std::move(image));
+            }
+        } else if (fs::is_regular_file(input_path)) {
+            ov::Tensor image = utils::load_image(input_path);
+            images.push_back(std::move(image));
+        }
+    }
+
+    if (images.empty())
+        throw std::runtime_error(std::string{"No images were found in path "} + input_path.string());
+
+    return images;
+}
+
 ov::Tensor utils::load_image(const std::filesystem::path& image_path) {
     int x = 0, y = 0, channels_in_file = 0;
     constexpr int desired_channels = 3;
diff --git a/samples/cpp/visual_language_chat/load_image.hpp b/samples/cpp/visual_language_chat/load_image.hpp
index f66dd2caf2..d0dcc271cd 100644
--- a/samples/cpp/visual_language_chat/load_image.hpp
+++ b/samples/cpp/visual_language_chat/load_image.hpp
@@ -9,4 +9,5 @@
 namespace utils {
 ov::Tensor load_image(const std::filesystem::path& image_path);
+std::vector<ov::Tensor> load_images(const std::filesystem::path& image_path);
 }
diff --git a/samples/cpp/visual_language_chat/visual_language_chat.cpp b/samples/cpp/visual_language_chat/visual_language_chat.cpp
index 95342402cb..d38b98a9a2 100644
--- a/samples/cpp/visual_language_chat/visual_language_chat.cpp
+++ b/samples/cpp/visual_language_chat/visual_language_chat.cpp
@@ -3,6 +3,7 @@
 #include "load_image.hpp"
 #include 
+#include 
 #include 
 
 bool print_subword(std::string&& subword) {
@@ -11,9 +12,14 @@
 int main(int argc, char* argv[]) try {
     if (3 != argc) {
-        throw std::runtime_error(std::string{"Usage "} + argv[0] + " <MODEL_DIR> <IMAGE_FILE>");
+        throw std::runtime_error(std::string{"Usage "} + argv[0] + " <MODEL_DIR> <IMAGE_FILE OR DIR_WITH_IMAGES>");
     }
-    ov::Tensor image = utils::load_image(argv[2]);
+
+    std::vector<ov::Tensor> images = utils::load_images(argv[2]);
+
+    ov::genai::GenerationConfig generation_config;
+    generation_config.max_new_tokens = 200;
+
     std::string device = "CPU";  // GPU can be used as well
     ov::AnyMap enable_compile_cache;
     if ("GPU" == device) {
@@ -26,16 +32,18 @@ int main(int argc, char* argv[]) try {
     pipe.start_chat();
     std::cout << "question:\n";
+
     std::getline(std::cin, prompt);
-    pipe.generate(
-        prompt,
-        ov::genai::image(image),
-        ov::genai::streamer(print_subword)
-    );
+    pipe.generate(prompt,
+                  ov::genai::images(images),
+                  ov::genai::generation_config(generation_config),
+                  ov::genai::streamer(print_subword));
     std::cout << "\n----------\n"
"question:\n"; while (std::getline(std::cin, prompt)) { - pipe.generate(prompt, ov::genai::streamer(print_subword)); + pipe.generate(prompt, + ov::genai::generation_config(generation_config), + ov::genai::streamer(print_subword)); std::cout << "\n----------\n" "question:\n"; } diff --git a/src/cpp/src/sampler.cpp b/src/cpp/src/sampler.cpp index 5ae604c725..4885a21b8f 100644 --- a/src/cpp/src/sampler.cpp +++ b/src/cpp/src/sampler.cpp @@ -230,6 +230,22 @@ Sampler::GroupBeamSearcher::GroupBeamSearcher(SequenceGroup::Ptr sequence_group, } } + +std::vector Sampler::GroupBeamSearcher::get_beam_idxs() { + std::vector next_beams; + + for (Group& group : m_groups) { + if (!group.done) { + for (Beam& beam : group.ongoing) { + next_beams.push_back(beam.m_global_beam_idx); + } + } + } + + return next_beams; +} + + void Sampler::GroupBeamSearcher::select_next_tokens(const ov::Tensor& logits, SamplerOutput& sampler_output) { assert(m_parameters.num_beams % m_parameters.num_beam_groups == 0 && "number of beams should be divisible by number of groups"); @@ -581,6 +597,15 @@ void register_new_token(const Token& sampled_token_id, } }; +std::vector Sampler::get_beam_idxs(SequenceGroup::CPtr sequence_group) { + size_t request_id = sequence_group->get_request_id(); + auto beam_searcher = m_beam_search_info.find(request_id); + if (m_beam_search_info.find(request_id) == m_beam_search_info.end()) { + return std::vector(sequence_group->num_running_seqs(), 0); + } + return beam_searcher->second.get_beam_idxs(); +} + std::list create_n_forked_sequences(SequenceGroup::Ptr sequence_group, LogitProcessor& logit_processor, diff --git a/src/cpp/src/sampler.hpp b/src/cpp/src/sampler.hpp index 8188b35573..13933e0b75 100644 --- a/src/cpp/src/sampler.hpp +++ b/src/cpp/src/sampler.hpp @@ -65,6 +65,7 @@ class Sampler { SamplerOutput sample(std::vector & sequence_groups, ov::Tensor logits, bool is_validation_mode_enabled = false); void set_seed(size_t seed) { rng_engine.seed(seed); } void clear_beam_search_info(uint64_t request_id); + std::vector get_beam_idxs(SequenceGroup::CPtr sequence_group); }; class Sampler::GroupBeamSearcher { @@ -109,5 +110,6 @@ class Sampler::GroupBeamSearcher { void select_next_tokens(const ov::Tensor& logits, SamplerOutput& sampler_output); void finalize(SamplerOutput& sampler_output); + std::vector get_beam_idxs(); }; } diff --git a/src/cpp/src/visual_language/pipeline.cpp b/src/cpp/src/visual_language/pipeline.cpp index 8c3882e4bf..7413d0cccf 100644 --- a/src/cpp/src/visual_language/pipeline.cpp +++ b/src/cpp/src/visual_language/pipeline.cpp @@ -3,7 +3,7 @@ #include "openvino/genai/visual_language/pipeline.hpp" #include "openvino/genai/tokenizer.hpp" -#include "vlm_sampling.hpp" +#include "sampler.hpp" #include "clip.hpp" #include "text_callback_streamer.hpp" #include "utils.hpp" @@ -21,64 +21,6 @@ template overloaded(Ts...) 
 
 constexpr size_t BATCH_SIZE = 1;
 
-struct Args {
-    bool do_sample = false;
-    int top_k = 0;
-    float top_p = 0.7f;
-    float temp = 0.95f;
-    float repeat_penalty = 1.0f;
-};
-
-int64_t get_out_token_id(const std::vector<int>& input_ids, float* logits, size_t vocab_size, Args args) {
-    int64_t out_token;
-
-    // logits pre-process
-    if (args.repeat_penalty != 1.f) {
-        sampling_repetition_penalty(logits, logits + vocab_size, input_ids, args.repeat_penalty);
-    }
-
-    if (args.do_sample)
-    {
-        if (args.temp > 0) {
-            sampling_temperature(logits, logits + vocab_size, args.temp);
-        }
-
-        std::vector<TokenIdScore> token_scores(vocab_size);
-        for (int i = 0; i < vocab_size; i++) {
-            token_scores[i] = TokenIdScore(i, logits[i]);
-        }
-
-        // top_k sampling
-        if (0 < args.top_k && args.top_k < (int)token_scores.size()) {
-            sampling_top_k(token_scores.data(), token_scores.data() + args.top_k,
-                token_scores.data() + token_scores.size());
-            token_scores.resize(args.top_k);
-        }
-
-        // top_p sampling
-        if (0.f < args.top_p && args.top_p < 1.f) {
-            auto pos = sampling_top_p(token_scores.data(), token_scores.data() + token_scores.size(), args.top_p);
-            token_scores.resize(pos - token_scores.data());
-        }
-
-        // sample next token
-        sampling_softmax_inplace(token_scores.data(), token_scores.data() + token_scores.size());
-        for (size_t i = 0; i < token_scores.size(); i++) {
-            logits[i] = token_scores[i].score;
-        }
-
-        thread_local std::random_device rd;
-        thread_local std::mt19937 gen(rd());
-
-        std::discrete_distribution<> dist(logits, logits + token_scores.size());
-        out_token = token_scores[dist(gen)].id;
-    }
-    else {
-        out_token = std::max_element(logits, logits + vocab_size) - logits;
-    }
-
-    return out_token;
-}
 
 ov::Tensor process_prompt(ov::InferRequest& embedding, const ov::Tensor& prompt, float scale_emb) {
     embedding.set_input_tensor(prompt);
@@ -564,6 +506,126 @@ ov::Core singleton_core() {
     static ov::Core core;
     return core;
 }
+
+EncodedGenerationResult get_lm_encoded_results(
+    ov::InferRequest& language,
+    ov::InferRequest& embedding,
+    const ov::Tensor& inputs_embeds,
+    const VLMConfig& m_vlm_config,
+    const std::shared_ptr<StreamerBase>& streamer_ptr,
+    Sampler& sampler,
+    std::vector<SequenceGroup::Ptr> requests
+) {
+    SequenceGroup::Ptr request = requests.back();
+    GenerationHandle generation = std::make_shared<GenerationHandleImpl>(request->get_generation_stream(), request->get_sampling_parameters());
+
+    language.set_tensor("inputs_embeds", inputs_embeds);
+
+    size_t history_len = language.get_tensor("attention_mask").get_shape().at(1);
+    language.get_tensor("attention_mask").set_shape({1, history_len + inputs_embeds.get_shape()[1]});
+    std::fill_n(language.get_tensor("attention_mask").data<int64_t>(), language.get_tensor("attention_mask").get_size(), 1);
+
+    language.get_tensor("position_ids").set_shape({1, inputs_embeds.get_shape().at(1)});
+    std::iota(language.get_tensor("position_ids").data<int64_t>(), language.get_tensor("position_ids").data<int64_t>() + language.get_tensor("position_ids").get_size(), history_len);
+
+    language.get_tensor("beam_idx").set_shape({ BATCH_SIZE });
+    language.get_tensor("beam_idx").data<int32_t>()[0] = 0;
+
+    language.infer();
+
+    int64_t sequence_len = language.get_tensor("logits").get_shape().at(1);
+    request->schedule_tokens(sequence_len);
+
+    SamplerOutput sampler_output = sampler.sample(requests, language.get_tensor("logits"));
+
+    language.get_tensor("inputs_embeds").set_shape({BATCH_SIZE, 1, m_vlm_config.hidden_size});
+    language.get_tensor("position_ids").set_shape({ BATCH_SIZE, 1 });
+
+    while (!request->has_finished()) {
+        request->schedule_tokens(1);
+        size_t num_sequences = request->num_running_seqs();
+        size_t total_num_tokens = request->get_num_scheduled_tokens() * num_sequences;
+
+        ov::Tensor
+            input_ids(ov::element::i64, {total_num_tokens, 1}),
+            position_ids(ov::element::i64, {total_num_tokens, 1}),
+            beam_idx(ov::element::i32, { total_num_tokens });
+
+        int64_t
+            * input_ids_data = input_ids.data<int64_t>(),
+            * position_ids_data = position_ids.data<int64_t>();
+
+        size_t num_scheduled_tokens = request->get_num_scheduled_tokens();
+        size_t group_position_id = request->get_num_processed_tokens();
+        for (Sequence::Ptr& sequence : request->get_running_sequences()) {
+            for (size_t token_id = 0, position_id = group_position_id; token_id < num_scheduled_tokens; ++token_id, ++position_id) {
+                // compute token for current sequence
+                input_ids_data[token_id] = position_id < request->get_prompt_len() ?
+                    request->get_prompt_ids()[position_id] :
+                    sequence->get_generated_ids()[position_id - request->get_prompt_len()];
+
+                position_ids_data[token_id] = position_id;
+            }
+            // apply strides to shift to a next sequence
+            input_ids_data += num_scheduled_tokens;
+            position_ids_data += num_scheduled_tokens;
+        }
+
+        embedding.set_input_tensor(input_ids);
+
+        embedding.infer();
+        const ov::Tensor& embed_prompt_tensor = embedding.get_output_tensor();
+        float* embed_data = embed_prompt_tensor.data<float>();
+        for (auto idx = 0; idx < embed_prompt_tensor.get_size(); idx++) {
+            embed_data[idx] = embed_data[idx] * m_vlm_config.scale_emb;
+        }
+
+        language.set_tensor("inputs_embeds", embed_prompt_tensor);
+
+        language.get_tensor("attention_mask").set_shape({ total_num_tokens, language.get_tensor("attention_mask").get_shape()[1] + 1 });
+        std::fill_n(language.get_tensor("attention_mask").data<int64_t>(), language.get_tensor("attention_mask").get_size(), 1);
+
+        language.set_tensor("position_ids", position_ids);
+
+        std::vector<int32_t> beam_idxs = sampler.get_beam_idxs(request);
+        int32_t *beam_idx_data = beam_idx.data<int32_t>();
+        copy(beam_idxs.begin(), beam_idxs.end(), beam_idx_data);
+        language.set_tensor("beam_idx", beam_idx);
+
+        language.infer();
+
+        if (streamer_ptr) {
+            // first sequence
+            int64_t out_token = request.get()->operator[](0)->get_generated_ids().back();
+            if (streamer_ptr->put(out_token)) {
+                break;
+            }
+        }
+
+        sampler_output = sampler.sample(requests, language.get_tensor("logits"));
+    }
+
+    if (streamer_ptr) {
+        streamer_ptr->end();
+    }
+
+    EncodedGenerationResult result;
+    result.m_request_id = 1;
+    std::vector<GenerationOutput> generation_outputs = generation->read_all();
+    std::sort(generation_outputs.begin(), generation_outputs.end(), [] (const GenerationOutput& r1, const GenerationOutput& r2) {
+        return r1.score > r2.score;
+    });
+
+    auto num_outputs = std::min(request->get_sampling_parameters().num_return_sequences, generation_outputs.size());
+    for (size_t generation_output_idx = 0; generation_output_idx < num_outputs; ++generation_output_idx) {
+        const auto& generation_output = generation_outputs[generation_output_idx];
+        result.m_generation_ids.push_back(std::move(generation_output.generated_ids));
+        result.m_scores.push_back(generation_output.score);
+    }
+    result.m_status = generation->get_status();
+
+    return result;
+}
 }
 
 class ov::genai::VLMPipeline::VLMPipelineImpl {
@@ -648,33 +710,21 @@ class ov::genai::VLMPipeline::VLMPipelineImpl {
             inputs_embeds = get_inputs_embeds_llava_next(prompt, rgbs);
         }
 
-        m_language.set_tensor("inputs_embeds", inputs_embeds);
-        size_t history_len = m_language.get_tensor("attention_mask").get_shape().at(1);
-        m_language.get_tensor("attention_mask").set_shape({1, history_len + inputs_embeds.get_shape()[1]});
-        std::fill_n(m_language.get_tensor("attention_mask").data<int64_t>(), m_language.get_tensor("attention_mask").get_size(), 1);
-
-        m_language.get_tensor("position_ids").set_shape({1, inputs_embeds.get_shape().at(1)});
-        std::iota(m_language.get_tensor("position_ids").data<int64_t>(), m_language.get_tensor("position_ids").data<int64_t>() + m_language.get_tensor("position_ids").get_size(), history_len);
-
-        m_language.get_tensor("beam_idx").set_shape({ BATCH_SIZE });
-        m_language.get_tensor("beam_idx").data<int32_t>()[0] = 0;
-
-        m_language.infer();
+        Sampler sampler = Sampler(m_tokenizer);
 
-        ov::Shape logits_shape = m_language.get_tensor("logits").get_shape();
-        auto attention_size = m_language.get_tensor("attention_mask").get_size();
+        std::vector<SequenceGroup::Ptr> requests;
+        size_t request_id = 0;
+        size_t block_size = 1;
+        bool enable_prefix_caching = false;
+        size_t history_size = m_language.get_tensor("attention_mask").get_shape().at(1);
+        size_t inputs_embeds_size = inputs_embeds.get_shape().at(1);
+        ov::Tensor prompt_ids(ov::element::i64, { history_size + inputs_embeds_size });
 
-        int64_t sequence_len = m_language.get_tensor("logits").get_shape().at(1) - 1;
-        size_t vocab_size = m_language.get_tensor("logits").get_shape().back();
-        float* logits = m_language.get_tensor("logits").data<float>() + sequence_len * vocab_size;
-        int64_t out_token = std::max_element(logits, logits + vocab_size) - logits;
+        SequenceGroup::Ptr sequence_group = std::make_shared<SequenceGroup>(request_id, prompt_ids, generation_config, block_size, enable_prefix_caching);
+        sequence_group->update_processed_tokens_num(history_size);
+        sequence_group->set_sequence_group_ptr(sequence_group);
+        requests.push_back(sequence_group);
 
-        m_language.get_tensor("inputs_embeds").set_shape({BATCH_SIZE, 1, m_vlm_config.hidden_size});
-        m_language.get_tensor("position_ids").set_shape({ BATCH_SIZE, 1 });
-
-        m_embedding.get_input_tensor().set_shape({ 1, 1 });
-
-        int64_t eos_token_id = m_tokenizer.get_eos_token_id();
         std::shared_ptr<StreamerBase> streamer_ptr = std::visit(overloaded{
             [&m_tokenizer = m_tokenizer](
                 const std::function<bool(std::string)>& callback
@@ -688,40 +738,19 @@ class ov::genai::VLMPipeline::VLMPipelineImpl {
                 return std::shared_ptr<StreamerBase>{nullptr};
             },
         }, streamer);
-        std::vector<int64_t> generated;
-        while (true) {  //(out_token != eos_token_id)
-            m_embedding.get_input_tensor().data<int64_t>()[0] = out_token;
-            m_embedding.infer();
-            const ov::Tensor& embed_prompt_tensor = m_embedding.get_output_tensor();
-            float* embed_data = embed_prompt_tensor.data<float>();
-            for (auto idx = 0; idx < embed_prompt_tensor.get_size(); idx++) {
-                embed_data[idx] = embed_data[idx] * m_vlm_config.scale_emb;
-            }
-
-            m_language.set_tensor("inputs_embeds", embed_prompt_tensor);
-            m_language.get_tensor("attention_mask").set_shape({ BATCH_SIZE, m_language.get_tensor("attention_mask").get_shape()[1] + 1 });
-            std::fill_n(m_language.get_tensor("attention_mask").data<int64_t>(), m_language.get_tensor("attention_mask").get_size(), 1);
-            m_language.get_tensor("position_ids").data<int64_t>()[0] = int64_t(m_language.get_tensor("attention_mask").get_size() - 1);
-            m_language.infer();
+        OPENVINO_ASSERT((generation_config.is_greedy_decoding() || generation_config.is_multinomial() || !streamer_ptr),
+                        "Currently streaming is possible only for greedy or multinomial decoding");
 
-            generated.push_back(out_token);
-            if (streamer_ptr && streamer_ptr->put(out_token)) {
-                break;
-            }
-            logits = m_language.get_tensor("logits").data<float>();
+        EncodedGenerationResult encoded_result = get_lm_encoded_results(m_language, m_embedding, inputs_embeds, m_vlm_config,
+                                                                         streamer_ptr, sampler, requests);
 
-            out_token = std::max_element(logits, logits + vocab_size) - logits;
-            if (out_token == eos_token_id) {
-                break;
-            }
+        DecodedResults decoded;
+        for (size_t idx = 0; idx < encoded_result.m_generation_ids.size(); ++idx) {
+            decoded.texts.push_back(m_tokenizer.decode(encoded_result.m_generation_ids.at(idx)));
+            decoded.scores.push_back(encoded_result.m_scores.at(idx));
         }
-        if (streamer_ptr) {
-            streamer_ptr->end();
-        }
-
-        std::string decoded_results = m_tokenizer.decode(generated);
+        std::string decoded_results = decoded.texts.at(0);
         if (m_is_chat_conversation) {
             // Tail of chat template is missing in KV cache.
             // Find the tail to concatenate it with the next input prompt.
@@ -733,7 +762,7 @@ class ov::genai::VLMPipeline::VLMPipelineImpl {
             }
             m_language.get_tensor("attention_mask").set_shape({1, 0});
         }
-        return {{std::move(decoded_results)}};
+        return decoded;
     }
 
     DecodedResults generate(
@@ -755,6 +784,11 @@ class ov::genai::VLMPipeline::VLMPipelineImpl {
         ov::genai::OptionalGenerationConfig config_arg = utils::get_config_from_map(config_map);
         GenerationConfig config = (config_arg.has_value()) ? *config_arg : get_generation_config();
         config.update_generation_config(config_map);
+
+        // If eos_token_id was not provided, take value
+        if (config.eos_token_id == -1)
+            config.set_eos_token_id(m_tokenizer.get_eos_token_id());
+
         return generate(
             prompt,
             rgbs,
diff --git a/src/cpp/src/vlm_sampling.hpp b/src/cpp/src/vlm_sampling.hpp
deleted file mode 100644
index b0a7d2341f..0000000000
--- a/src/cpp/src/vlm_sampling.hpp
+++ /dev/null
@@ -1,96 +0,0 @@
-// Copyright (C) 2023-2024 Intel Corporation
-// SPDX-License-Identifier: Apache-2.0
-
-#pragma once
-
-#include 
-#include 
-#include 
-#include 
-#include 
-
-struct TokenIdScore {
-    int id;
-    float score;
-
-    TokenIdScore() = default;
-    TokenIdScore(int id, float score) : id(id), score(score) {}
-
-    bool operator<(const TokenIdScore& other) const { return score < other.score; }
-    bool operator>(const TokenIdScore& other) const { return score > other.score; }
-
-    friend std::ostream& operator<<(std::ostream& os, const TokenIdScore& self) {
-        return os << "TokenIdScore(id=" << self.id << ", score=" << self.score << ")";
-    }
-};
-
-void sampling_softmax_inplace(TokenIdScore* first, TokenIdScore* last) {
-    float max_score = std::max_element(first, last)->score;
-    float sum = 0.f;
-    for (TokenIdScore* p = first; p != last; p++) {
-        float s = std::exp(p->score - max_score);
-        p->score = s;
-        sum += s;
-    }
-    float inv_sum = 1.f / sum;
-    for (TokenIdScore* p = first; p != last; p++) {
-        p->score *= inv_sum;
-    }
-}
-
-void sampling_top_k(TokenIdScore* first, TokenIdScore* kth, TokenIdScore* last) {
-    std::nth_element(first, kth, last, std::greater<TokenIdScore>());
-}
-
-TokenIdScore* sampling_top_p(TokenIdScore* first, TokenIdScore* last, float top_p) {
-    // fast top_p in expected O(n) time complexity
-    sampling_softmax_inplace(first, last);
-
-    while (first + 1 < last) {
-        const float pivot_score = (last - 1)->score; // use mid score?
-        TokenIdScore* mid =
-            std::partition(first, last - 1, [pivot_score](const TokenIdScore& x) { return x.score > pivot_score; });
-        std::swap(*mid, *(last - 1));
-
-        const float prefix_sum =
-            std::accumulate(first, mid, 0.f, [](float sum, const TokenIdScore& x) { return sum + x.score; });
-        if (prefix_sum >= top_p) {
-            last = mid;
-        }
-        else if (prefix_sum + mid->score < top_p) {
-            first = mid + 1;
-            top_p -= prefix_sum + mid->score;
-        }
-        else {
-            return mid + 1;
-        }
-    }
-    return last;
-}
-
-void sampling_repetition_penalty(float* first, float* last, const std::vector<int>& input_ids,
-    float penalty) {
-    if (penalty < 0) {
-        std::cout << "penalty must be a positive float, but got " << penalty;
-        return;
-    }
-    const float inv_penalty = 1.f / penalty;
-    const ptrdiff_t vocab_size = last - first;
-    std::vector<bool> occurrence(vocab_size, false);
-    for (const int id : input_ids) {
-        if (!occurrence[id]) {
-            first[id] *= (first[id] > 0) ? inv_penalty : penalty;
-        }
-        occurrence[id] = true;
-    }
-}
-
-void sampling_temperature(float* first, float* last, float temp) {
-    const float inv_temp = 1.f / temp;
-    for (float* it = first; it != last; it++) {
-        *it *= inv_temp;
-    }
-}
-
-
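
Usage sketch (not part of the patch): after this change the VLM pipeline selects tokens through the shared Sampler, so the GenerationConfig passed to VLMPipeline::generate controls greedy, multinomial, or beam-search decoding, and the updated sample helper accepts either a single image file or a directory of images. The snippet below only illustrates that flow; the model directory, image path, and parameter values are placeholders, and the constructor call assumes the VLMPipeline(model_dir, device, properties) form used by the sample.

// Illustrative only; assumes the OpenVINO GenAI VLMPipeline API touched by this patch.
#include "load_image.hpp"   // sample helper providing utils::load_images
#include <openvino/genai/visual_language/pipeline.hpp>
#include <iostream>
#include <vector>

int main() {
    ov::AnyMap properties;  // left empty here; the sample fills it with cache_dir for GPU
    ov::genai::VLMPipeline pipe("./MiniCPM-V-2_6/", "CPU", properties);  // placeholder model dir

    ov::genai::GenerationConfig config;
    config.max_new_tokens = 100;
    config.do_sample = true;   // multinomial sampling, now handled by Sampler
    config.top_p = 0.9f;
    config.top_k = 50;

    // A single image file or a directory of images is accepted by the updated helper.
    std::vector<ov::Tensor> images = utils::load_images("./cat_img");

    ov::genai::DecodedResults result = pipe.generate(
        "What is unusual on this image?",
        ov::genai::images(images),
        ov::genai::generation_config(config));
    std::cout << result.texts.at(0) << '\n';
    return 0;
}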