Read EOS token from model runtime information for speculative_decoding_lm (#353)

Extension to issue #277: added the functionality to read the EOS token from the model's runtime information in speculative_decoding_lm.
anzr299 authored Apr 11, 2024
1 parent 28286d4 commit e84defc
Showing 1 changed file with 11 additions and 3 deletions.
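For context, here is a minimal standalone sketch of the technique this commit adopts, assuming OpenVINO's C++ API (ov::Core, ov::Model, ov::AnyMap) and a placeholder "model_dir" path; it illustrates the idea and is not code from the commit itself. The tokenizer model is read from disk, and the EOS token id is looked up in the model's runtime information (rt_info) instead of being hardcoded.

#include <cstdint>
#include <iostream>
#include <memory>
#include <stdexcept>

#include <openvino/openvino.hpp>

int main() {
    ov::Core core;
    // "model_dir" is a placeholder path for this sketch.
    std::shared_ptr<ov::Model> tokenizer_model =
        core.read_model("model_dir/openvino_tokenizer.xml");

    // rt_info is a string-keyed metadata map (ov::AnyMap) attached to the model;
    // converters such as openvino_tokenizers can record "eos_token_id" there.
    ov::AnyMap rt_info = tokenizer_model->get_rt_info();

    auto it = rt_info.find("eos_token_id");
    if (it == rt_info.end()) {
        // Fail loudly instead of falling back to a guessed constant.
        throw std::runtime_error("EOS token ID not found in model's runtime information.");
    }
    int64_t eos_token_id = it->second.as<int64_t>();
    std::cout << "eos_token_id = " << eos_token_id << '\n';
    return 0;
}

The design point is that the previously hardcoded SPECIAL_EOS_TOKEN = 2 is only correct for some vocabularies (2 is, for example, the Llama EOS id), whereas reading the value from rt_info works for any model whose converter recorded it.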
14 changes: 11 additions & 3 deletions text_generation/causal_lm/cpp/speculative_decoding_lm.cpp
@@ -11,8 +11,7 @@ constexpr size_t BATCH_SIZE = 1;
 // therefore usually SEQ_LEN_AXIS = 2
 constexpr size_t SEQ_LEN_AXIS = 2;
 
-// There's no way to extract special token values from the detokenizer for now
-constexpr int64_t SPECIAL_EOS_TOKEN = 2;
+int64_t SPECIAL_EOS_TOKEN;
 
 namespace {
 std::pair<ov::Tensor, ov::Tensor> tokenize(ov::InferRequest& tokenizer, std::string&& prompt) {
@@ -117,9 +116,10 @@ int main(int argc, char* argv[]) try {
     // tokenizer model
     ov::Core core;
     core.add_extension(OPENVINO_TOKENIZERS_PATH);  // OPENVINO_TOKENIZERS_PATH is defined in CMakeLists.txt
+    auto tokenizer_model = core.read_model(std::string{argv[1]} + "/openvino_tokenizer.xml");
     // tokenizer and detokenizer work on CPU only
     ov::InferRequest tokenizer = core.compile_model(
-        std::string{argv[1]} + "/openvino_tokenizer.xml", "CPU").create_infer_request();
+        tokenizer_model, "CPU").create_infer_request();
     auto [draft_input_ids, draft_attention_mask] = tokenize(tokenizer, argv[3]);
     ov::InferRequest detokenizer = core.compile_model(
         std::string{argv[1]} + "/openvino_detokenizer.xml", "CPU").create_infer_request();
@@ -183,6 +183,14 @@ int main(int argc, char* argv[]) try {
     draft_input_ids.set_shape({BATCH_SIZE, 1});
     draft_position_ids.set_shape({BATCH_SIZE, 1});
 
+    auto rt_info = tokenizer_model->get_rt_info();  // Get the runtime info for the model
+
+    if (rt_info.count("eos_token_id") > 0) {  // Check if the runtime information has a valid EOS token ID
+        SPECIAL_EOS_TOKEN = rt_info["eos_token_id"].as<int64_t>();
+    } else {
+        throw std::runtime_error("EOS token ID not found in model's runtime information.");
+    }
+
     /* Speculative decoding works the following way. The draft model predicts the next K
        tokens one by one in an autoregressive manner, while the main model validates these
        predictions and corrects them if necessary. We go through each predicted token, and
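To make the collapsed comment above concrete, here is a deliberately simplified sketch of one speculative-decoding step, assuming placeholder draft_next/main_next functions; the real speculative_decoding_lm.cpp drives ov::InferRequest objects with KV caches and validates all K draft tokens in a single batched forward pass of the main model, not one call per token as shown here.

#include <cstdint>
#include <functional>
#include <vector>

// One speculative step: the draft model proposes K tokens, the main model
// checks them in order, and the draft is cut off at the first mismatch.
std::vector<int64_t> speculative_step(
        const std::function<int64_t(const std::vector<int64_t>&)>& draft_next,
        const std::function<int64_t(const std::vector<int64_t>&)>& main_next,
        std::vector<int64_t> context,
        size_t K) {
    // 1. Draft model predicts the next K tokens autoregressively.
    std::vector<int64_t> proposed;
    std::vector<int64_t> draft_context = context;
    for (size_t i = 0; i < K; ++i) {
        int64_t token = draft_next(draft_context);
        proposed.push_back(token);
        draft_context.push_back(token);
    }
    // 2. Main model validates each prediction and corrects it if necessary;
    //    tokens after the first divergence are discarded.
    std::vector<int64_t> accepted;
    for (int64_t token : proposed) {
        int64_t verified = main_next(context);
        context.push_back(verified);
        accepted.push_back(verified);
        if (verified != token) {
            break;  // the draft diverged here; drop the rest of its proposals
        }
    }
    return accepted;
}

In the best case the main model accepts all K draft tokens in one validation pass instead of generating them one by one, which is where the speedup comes from.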
