From a427d0a52d7e85a31d2fb2b2d5811f81ea13372c Mon Sep 17 00:00:00 2001
From: wenyi5608 <93560477+wenyi5608@users.noreply.github.com>
Date: Fri, 8 Mar 2024 15:11:17 +0800
Subject: [PATCH 1/5] greedy-sampling

---
 text_generation/causal_lm/cpp/CMakeLists.txt | 1 +
 1 file changed, 1 insertion(+)

diff --git a/text_generation/causal_lm/cpp/CMakeLists.txt b/text_generation/causal_lm/cpp/CMakeLists.txt
index 63c6490b9f..55ecc18378 100644
--- a/text_generation/causal_lm/cpp/CMakeLists.txt
+++ b/text_generation/causal_lm/cpp/CMakeLists.txt
@@ -8,6 +8,7 @@ add_subdirectory(../../../thirdparty/openvino_tokenizers/ "${CMAKE_CURRENT_BINAR
 add_executable(greedy_causal_lm greedy_causal_lm.cpp)
 target_compile_definitions(greedy_causal_lm PRIVATE OPENVINO_TOKENIZERS_PATH=\"$<TARGET_FILE:openvino_tokenizers>\")
+target_include_directories(greedy_causal_lm PRIVATE ./)
 find_package(OpenVINO REQUIRED COMPONENTS Runtime)
 target_link_libraries(greedy_causal_lm PRIVATE openvino::runtime)
 set_target_properties(greedy_causal_lm PROPERTIES CXX_STANDARD 17)

From d9320dfef4fe131a5aaaa98907d9e8354b49d3f1 Mon Sep 17 00:00:00 2001
From: wenyi5608 <93560477+wenyi5608@users.noreply.github.com>
Date: Fri, 8 Mar 2024 15:18:11 +0800
Subject: [PATCH 2/5] greedy sampling

---
 text_generation/causal_lm/cpp/greedy_causal_lm.cpp | 11 ++++++++---
 1 file changed, 8 insertions(+), 3 deletions(-)

diff --git a/text_generation/causal_lm/cpp/greedy_causal_lm.cpp b/text_generation/causal_lm/cpp/greedy_causal_lm.cpp
index 7f287378b8..de0bec26b4 100644
--- a/text_generation/causal_lm/cpp/greedy_causal_lm.cpp
+++ b/text_generation/causal_lm/cpp/greedy_causal_lm.cpp
@@ -1,6 +1,7 @@
 // Copyright (C) 2023-2024 Intel Corporation
 // SPDX-License-Identifier: Apache-2.0
 
+#include <greedy_sampling.hpp>
 #include <openvino/openvino.hpp>
 
 namespace {
@@ -82,9 +83,13 @@ int main(int argc, char* argv[]) try {
     lm.get_tensor("beam_idx").set_shape({BATCH_SIZE});
     lm.get_tensor("beam_idx").data<int32_t>()[0] = 0;
     lm.infer();
+    int64_t sequence_len = lm.get_tensor("logits").get_shape().at(1) - 1;
     size_t vocab_size = lm.get_tensor("logits").get_shape().back();
-    float* logits = lm.get_tensor("logits").data<float>() + (input_ids.get_size() - 1) * vocab_size;
-    int64_t out_token = std::max_element(logits, logits + vocab_size) - logits;
+    float* logits = lm.get_tensor("logits").data<float>() + (sequence_len) * vocab_size;
+    const int64_t* prompt_data = input_ids.data<int64_t>();
+    SamplingParameters parameters{ std::vector<int64_t>{prompt_data, prompt_data + input_ids.get_size()} };
+    GreedySampling greedy_sampling{ parameters };
+    int64_t out_token = greedy_sampling.get_out_token(logits, vocab_size);
 
     lm.get_tensor("input_ids").set_shape({BATCH_SIZE, 1});
     position_ids.set_shape({BATCH_SIZE, 1});
@@ -100,7 +105,7 @@ int main(int argc, char* argv[]) try {
         text_streamer.put(out_token);
         lm.wait();
         logits = lm.get_tensor("logits").data<float>();
-        out_token = std::max_element(logits, logits + vocab_size) - logits;
+        out_token = greedy_sampling.get_out_token(logits, vocab_size);
     }
     text_streamer.end();
     // Model is stateful which means that context (kv-cache) which belongs to a particular

From a5aa97fbfba968ceeaa436f815539341d6ab461e Mon Sep 17 00:00:00 2001
From: wenyi5608 <93560477+wenyi5608@users.noreply.github.com>
Date: Fri, 8 Mar 2024 15:18:56 +0800
Subject: [PATCH 3/5] greedy sampling

---
 .../causal_lm/cpp/greedy_sampling.hpp         | 165 ++++++++++++++++++
 1 file changed, 165 insertions(+)
 create mode 100644 text_generation/causal_lm/cpp/greedy_sampling.hpp

diff --git a/text_generation/causal_lm/cpp/greedy_sampling.hpp b/text_generation/causal_lm/cpp/greedy_sampling.hpp
new file mode 100644
index 0000000000..0726475b39
--- /dev/null
+++ b/text_generation/causal_lm/cpp/greedy_sampling.hpp
@@ -0,0 +1,165 @@
+// Copyright (C) 2023-2024 Intel Corporation
+// SPDX-License-Identifier: Apache-2.0
+
+#include <algorithm>
+#include <cmath>
+#include <cstdint>
+#include <iostream>
+#include <numeric>
+#include <random>
+#include <vector>
+
+struct TokenIdScore {
+    int id;
+    float score;
+
+    TokenIdScore() = default;
+    TokenIdScore(int id, float score) : id(id), score(score) {}
+
+    bool operator<(const TokenIdScore& other) const { return score < other.score; }
+    bool operator>(const TokenIdScore& other) const { return score > other.score; }
+
+    friend std::ostream& operator<<(std::ostream& os, const TokenIdScore& self) {
+        return os << "TokenIdScore(id=" << self.id << ", score=" << self.score << ")";
+    }
+};
+
+void sampling_softmax_inplace(TokenIdScore* first, TokenIdScore* last) {
+    float max_score = std::max_element(first, last)->score;
+    float sum = 0.f;
+    for (TokenIdScore* p = first; p != last; p++) {
+        float s = std::exp(p->score - max_score);
+        p->score = s;
+        sum += s;
+    }
+    float inv_sum = 1.f / sum;
+    for (TokenIdScore* p = first; p != last; p++) {
+        p->score *= inv_sum;
+    }
+}
+
+void sampling_top_k(TokenIdScore* first, TokenIdScore* kth, TokenIdScore* last) {
+    std::nth_element(first, kth, last, std::greater<TokenIdScore>());
+}
+
+TokenIdScore* sampling_top_p(TokenIdScore* first, TokenIdScore* last, float top_p) {
+    // fast top_p in expected O(n) time complexity
+    sampling_softmax_inplace(first, last);
+
+    while (first + 1 < last) {
+        const float pivot_score = (last - 1)->score; // use mid score?
+        TokenIdScore* mid =
+            std::partition(first, last - 1, [pivot_score](const TokenIdScore& x) { return x.score > pivot_score; });
+        std::swap(*mid, *(last - 1));
+
+        const float prefix_sum =
+            std::accumulate(first, mid, 0.f, [](float sum, const TokenIdScore& x) { return sum + x.score; });
+        if (prefix_sum >= top_p) {
+            last = mid;
+        }
+        else if (prefix_sum + mid->score < top_p) {
+            first = mid + 1;
+            top_p -= prefix_sum + mid->score;
+        }
+        else {
+            return mid + 1;
+        }
+    }
+    return last;
+}
+
+void sampling_repetition_penalty(float* first, float* last, const std::vector<int64_t>& input_ids,
+                                 float penalty) {
+    if (penalty <= 0) {
+        std::cout << "penalty must be a positive float, but got " << penalty;
+        return;
+    }
+    const float inv_penalty = 1.f / penalty;
+    const int vocab_size = last - first;
+    std::vector<bool> occurrence(vocab_size, false);
+    for (const int64_t id : input_ids) {
+        if (!occurrence[id]) {
+            first[id] *= (first[id] > 0) ? inv_penalty : penalty;
+        }
+        occurrence[id] = true;
+    }
+}
+
+void sampling_temperature(float* first, float* last, float temp) {
+    const float inv_temp = 1.f / temp;
+    for (float* it = first; it != last; it++) {
+        *it *= inv_temp;
+    }
+}
+
+struct SamplingParameters {
+    std::vector<int64_t> prompt;
+    int top_k = 0;
+    float top_p = 0.7;
+    float temp = 0.95;
+    float repeat_penalty = 1.1;
+    bool do_sample = true;
+};
+
+// GreedySampling post-processes the logits produced by a language model and selects the next
+// token in the sequence: greedily when do_sample == false, otherwise by sampling from the
+// top-k/top-p filtered distribution. get_out_token() returns the selected token id, which is
+// fed to the next inference call.
+struct GreedySampling {
+    SamplingParameters parameters;
+    GreedySampling(SamplingParameters parameters) : parameters{ std::move(parameters) } {
+    }
+
+    int64_t get_out_token(float* logits, size_t vocab_size) {
+        int64_t out_token;
+        // Generated tokens are appended below so the repetition penalty sees the full history.
+        std::vector<int64_t>& prompt = parameters.prompt;
+
+        // logits pre-processing
+        if (parameters.repeat_penalty != 1.f) {
+            sampling_repetition_penalty(logits, logits + vocab_size, prompt, parameters.repeat_penalty);
+        }
+
+        if (parameters.do_sample) {
+            if (parameters.temp > 0) {
+                sampling_temperature(logits, logits + vocab_size, parameters.temp);
+            }
+
+            std::vector<TokenIdScore> token_scores(vocab_size);
+            for (size_t i = 0; i < vocab_size; i++) {
+                token_scores[i] = TokenIdScore(int(i), logits[i]);
+            }
+
+            // top_k sampling
+            if (0 < parameters.top_k && parameters.top_k < (int)token_scores.size()) {
+                sampling_top_k(token_scores.data(), token_scores.data() + parameters.top_k,
+                    token_scores.data() + token_scores.size());
+                token_scores.resize(parameters.top_k);
+            }
+
+            // top_p sampling
+            if (0.f < parameters.top_p && parameters.top_p < 1.f) {
+                auto pos = sampling_top_p(token_scores.data(), token_scores.data() + token_scores.size(), parameters.top_p);
+                token_scores.resize(pos - token_scores.data());
+            }
+
+            // sample the next token from the remaining candidates
+            sampling_softmax_inplace(token_scores.data(), token_scores.data() + token_scores.size());
+            for (size_t i = 0; i < token_scores.size(); i++) {
+                logits[i] = token_scores[i].score;
+            }
+
+            thread_local std::random_device rd;
+            thread_local std::mt19937 gen(rd());
+
+            std::discrete_distribution<> dist(logits, logits + token_scores.size());
+            out_token = token_scores[dist(gen)].id;
+        }
+        else {
+            out_token = std::max_element(logits, logits + vocab_size) - logits;
+        }
+
+        prompt.push_back(out_token);
+
+        return { out_token };
+    }
+};
\ No newline at end of file

From 3a598ddc3795d1ede77e50b140351bc44d703827 Mon Sep 17 00:00:00 2001
From: wenyi5608 <93560477+wenyi5608@users.noreply.github.com>
Date: Wed, 17 Apr 2024 15:15:12 +0800
Subject: [PATCH 4/5] Update greedy_sampling.hpp

don't transform the logits during top_p

---
 .../causal_lm/cpp/greedy_sampling.hpp         | 43 ++++++++++---------
 1 file changed, 22 insertions(+), 21 deletions(-)

diff --git a/text_generation/causal_lm/cpp/greedy_sampling.hpp b/text_generation/causal_lm/cpp/greedy_sampling.hpp
index 0726475b39..c854a38ff2 100644
--- a/text_generation/causal_lm/cpp/greedy_sampling.hpp
+++ b/text_generation/causal_lm/cpp/greedy_sampling.hpp
@@ -43,28 +43,29 @@ void sampling_top_k(TokenIdScore* first, TokenIdScore* kth, TokenIdScore* last)
 }
 
 TokenIdScore* sampling_top_p(TokenIdScore* first, TokenIdScore* last, float top_p) {
-    // fast top_p in expected O(n) time complexity
-    sampling_softmax_inplace(first, last);
-
-    while (first + 1 < last) {
-        const float pivot_score = (last - 1)->score; // use mid score?
-        TokenIdScore* mid =
-            std::partition(first, last - 1, [pivot_score](const TokenIdScore& x) { return x.score > pivot_score; });
-        std::swap(*mid, *(last - 1));
-
-        const float prefix_sum =
-            std::accumulate(first, mid, 0.f, [](float sum, const TokenIdScore& x) { return sum + x.score; });
-        if (prefix_sum >= top_p) {
-            last = mid;
-        }
-        else if (prefix_sum + mid->score < top_p) {
-            first = mid + 1;
-            top_p -= prefix_sum + mid->score;
-        }
-        else {
-            return mid + 1;
+    // sort the scores in descending order
+    std::sort(first, last, std::greater<TokenIdScore>());
+
+    int vocab_size = last - first;
+    std::vector<TokenIdScore> token_scores(vocab_size);
+    for (int i = 0; i < vocab_size; i++) {
+        token_scores[i] = first[i];
+    }
+
+    // calculate softmax on a copy so the original logits are not transformed
+    sampling_softmax_inplace(token_scores.data(), token_scores.data() + token_scores.size());
+
+    float prefix_sum = 0.0f;
+
+    // keep the smallest prefix whose cumulative probability reaches top_p
+    for (int i = 0; i < vocab_size; i++) {
+        prefix_sum += token_scores[i].score;
+        if (prefix_sum >= top_p) {
+            return first + (i + 1);
         }
     }
+
     return last;
 }
@@ -162,4 +163,4 @@ struct GreedySampling {
 
         return { out_token };
     }
-};
\ No newline at end of file
+};

From 5017097fafc0cec54368116ebb2fa17207976f2d Mon Sep 17 00:00:00 2001
From: wenyi5608 <93560477+wenyi5608@users.noreply.github.com>
Date: Wed, 17 Apr 2024 16:40:36 +0800
Subject: [PATCH 5/5] Update greedy_causal_lm.cpp

---
 .../causal_lm/cpp/greedy_causal_lm.cpp        | 32 +++++++++++++++----
 1 file changed, 25 insertions(+), 7 deletions(-)

diff --git a/text_generation/causal_lm/cpp/greedy_causal_lm.cpp b/text_generation/causal_lm/cpp/greedy_causal_lm.cpp
index de0bec26b4..edcc6283e4 100644
--- a/text_generation/causal_lm/cpp/greedy_causal_lm.cpp
+++ b/text_generation/causal_lm/cpp/greedy_causal_lm.cpp
@@ -59,33 +59,40 @@ int main(int argc, char* argv[]) try {
     if (argc != 3) {
         throw std::runtime_error(std::string{"Usage: "} + argv[0] + " <MODEL_DIR> '<PROMPT>'");
     }
+    // Compile models
     ov::Core core;
     core.add_extension(OPENVINO_TOKENIZERS_PATH);  // OPENVINO_TOKENIZERS_PATH is defined in CMakeLists.txt
+    // Read the tokenizer model information from the file to later get the runtime information
+    auto tokenizer_model = core.read_model(std::string{ argv[1] } + "/openvino_tokenizer.xml");
     // tokenizer and detokenizer work on CPU only
     ov::InferRequest tokenizer = core.compile_model(
-        std::string{argv[1]} + "/openvino_tokenizer.xml", "CPU").create_infer_request();
+        tokenizer_model, "CPU").create_infer_request();
     auto [input_ids, attention_mask] = tokenize(tokenizer, argv[2]);
     ov::InferRequest detokenizer = core.compile_model(
         std::string{argv[1]} + "/openvino_detokenizer.xml", "CPU").create_infer_request();
     // The model can be compiled for GPU as well
     ov::InferRequest lm = core.compile_model(
         std::string{argv[1]} + "/openvino_model.xml", "CPU").create_infer_request();
+    auto seq_len = input_ids.get_size();
+    // Initialize inputs
     lm.set_tensor("input_ids", input_ids);
     lm.set_tensor("attention_mask", attention_mask);
     ov::Tensor position_ids = lm.get_tensor("position_ids");
     position_ids.set_shape(input_ids.get_shape());
-    std::iota(position_ids.data<int64_t>(), position_ids.data<int64_t>() + position_ids.get_size(), 0);
+    std::iota(position_ids.data<int64_t>(), position_ids.data<int64_t>() + seq_len, 0);
     constexpr size_t BATCH_SIZE = 1;
     // Input values are persistent between inference calls.
     // That allows to set values, which aren't going to change, only once
     lm.get_tensor("beam_idx").set_shape({BATCH_SIZE});
     lm.get_tensor("beam_idx").data<int32_t>()[0] = 0;
     lm.infer();
+
-    int64_t sequence_len = lm.get_tensor("logits").get_shape().at(1) - 1;
+    int64_t sequence_offset = lm.get_tensor("logits").get_shape().at(1) - 1;
     size_t vocab_size = lm.get_tensor("logits").get_shape().back();
-    float* logits = lm.get_tensor("logits").data<float>() + (sequence_len) * vocab_size;
+    float* logits = lm.get_tensor("logits").data<float>() + (sequence_offset) * vocab_size;
+
     const int64_t* prompt_data = input_ids.data<int64_t>();
     SamplingParameters parameters{ std::vector<int64_t>{prompt_data, prompt_data + input_ids.get_size()} };
     GreedySampling greedy_sampling{ parameters };
     int64_t out_token = greedy_sampling.get_out_token(logits, vocab_size);
 
     lm.get_tensor("input_ids").set_shape({BATCH_SIZE, 1});
     position_ids.set_shape({BATCH_SIZE, 1});
     TextStreamer text_streamer{std::move(detokenizer)};
-    // There's no way to extract special token values from the detokenizer for now
-    constexpr int64_t SPECIAL_EOS_TOKEN = 2;
-    while (out_token != SPECIAL_EOS_TOKEN) {
+    // Get the runtime info from the tokenizer model that we read earlier
+    auto rt_info = tokenizer_model->get_rt_info();  // Get the runtime info for the model
+    int64_t SPECIAL_EOS_TOKEN;
+
+    if (rt_info.count("eos_token_id") > 0) {  // check if the runtime information has a valid EOS token ID
+        SPECIAL_EOS_TOKEN = rt_info["eos_token_id"].as<int64_t>();
+    }
+    else {
+        throw std::runtime_error("EOS token ID not found in model's runtime information.");
+    }
+
+    int max_sequence_length = 100;
+    while (out_token != SPECIAL_EOS_TOKEN && seq_len < max_sequence_length) {
+        ++seq_len;
         lm.get_tensor("input_ids").data<int64_t>()[0] = out_token;
         lm.get_tensor("attention_mask").set_shape({BATCH_SIZE, lm.get_tensor("attention_mask").get_shape().at(1) + 1});
         std::fill_n(lm.get_tensor("attention_mask").data<int64_t>(), lm.get_tensor("attention_mask").get_size(), 1);