diff --git a/text_generation/causal_lm/cpp/speculative_decoding_lm.cpp b/text_generation/causal_lm/cpp/speculative_decoding_lm.cpp
index f5e79ae8f0..92523f82a5 100644
--- a/text_generation/causal_lm/cpp/speculative_decoding_lm.cpp
+++ b/text_generation/causal_lm/cpp/speculative_decoding_lm.cpp
@@ -11,8 +11,7 @@ constexpr size_t BATCH_SIZE = 1;
 // threfore usually SEQ_LEN_AXIS = 2
 constexpr size_t SEQ_LEN_AXIS = 2;
 
-// There's no way to extract special token values from the detokenizer for now
-constexpr int64_t SPECIAL_EOS_TOKEN = 2;
+int64_t SPECIAL_EOS_TOKEN;
 
 namespace {
 std::pair<ov::Tensor, ov::Tensor> tokenize(ov::InferRequest& tokenizer, std::string&& prompt) {
@@ -117,9 +116,10 @@ int main(int argc, char* argv[]) try {
     // tokenizer model
     ov::Core core;
     core.add_extension(OPENVINO_TOKENIZERS_PATH);  // OPENVINO_TOKENIZERS_PATH is defined in CMakeLists.txt
+    auto tokenizer_model = core.read_model(std::string{argv[1]} + "/openvino_tokenizer.xml");
     // tokenizer and detokenizer work on CPU only
     ov::InferRequest tokenizer = core.compile_model(
-        std::string{argv[1]} + "/openvino_tokenizer.xml", "CPU").create_infer_request();
+        tokenizer_model, "CPU").create_infer_request();
     auto [draft_input_ids, draft_attention_mask] = tokenize(tokenizer, argv[3]);
     ov::InferRequest detokenizer = core.compile_model(
         std::string{argv[1]} + "/openvino_detokenizer.xml", "CPU").create_infer_request();
@@ -183,6 +183,14 @@ int main(int argc, char* argv[]) try {
     draft_input_ids.set_shape({BATCH_SIZE, 1});
     draft_position_ids.set_shape({BATCH_SIZE, 1});
 
+    auto rt_info = tokenizer_model->get_rt_info(); //Get the runtime info for the model
+
+    if (rt_info.count("eos_token_id") > 0) { //check if the runtime information has a valid EOS token ID
+        SPECIAL_EOS_TOKEN = rt_info["eos_token_id"].as<int64_t>();
+    } else {
+        throw std::runtime_error("EOS token ID not found in model's runtime information.");
+    }
+
     /* Speculative decoding works the following way. The draft model predicts the next K tokens one by one in an
     autoregressive manner, while the main model validates these predictions and corrects them if necessary. We go
     through each predicted token, and
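
For context, a minimal standalone sketch of the lookup this patch introduces: read the converted tokenizer IR with `ov::Core::read_model`, pull `eos_token_id` out of the model's runtime info, and fail loudly if the attribute is missing. The model and extension paths here are placeholders (the sample takes the model directory as `argv[1]` and gets the extension path from CMake), and the `main` wrapper is only for illustration.

```cpp
#include <cstdint>
#include <iostream>
#include <memory>
#include <stdexcept>

#include <openvino/openvino.hpp>

int main() {
    ov::Core core;
    // Placeholder: the sample loads the openvino_tokenizers extension via OPENVINO_TOKENIZERS_PATH.
    core.add_extension("<path to libopenvino_tokenizers>");

    // Placeholder path: the sample builds this from argv[1].
    std::shared_ptr<ov::Model> tokenizer_model =
        core.read_model("<model_dir>/openvino_tokenizer.xml");

    // openvino_tokenizers stores special-token IDs in the converted model's runtime info.
    ov::AnyMap rt_info = tokenizer_model->get_rt_info();

    int64_t eos_token_id = 0;
    if (rt_info.count("eos_token_id") > 0) {
        eos_token_id = rt_info.at("eos_token_id").as<int64_t>();
    } else {
        throw std::runtime_error("EOS token ID not found in model's runtime information.");
    }
    std::cout << "eos_token_id = " << eos_token_id << '\n';
}
```

This is why the hard-coded `constexpr int64_t SPECIAL_EOS_TOKEN = 2;` can be dropped: the value now comes from whatever tokenizer was converted, instead of assuming every model uses token ID 2 for EOS.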