Skip to content

Commit

Permalink
Automatically apply chat template in non-chat scenarios
Browse files Browse the repository at this point in the history
  • Loading branch information
sbalandi committed Jan 27, 2025
1 parent 834b677 commit 5924b23
Show file tree
Hide file tree
Showing 32 changed files with 193 additions and 32 deletions.
20 changes: 14 additions & 6 deletions .github/workflows/causal_lm_cpp.yml
Original file line number Diff line number Diff line change
Expand Up @@ -209,7 +209,7 @@ jobs:
"
echo "你好! 你好嗎?" passed
timeout 1m ${{ matrix.executable }} ./TinyLlama-1.1B-Chat-v1.0/ "Alan Turing was a" "return 0" "你好! 你好嗎?" > ./pred.txt
timeout 1m ${{ matrix.executable }} ./TinyLlama-1.1B-Chat-v1.0/ "Why is the Sun yellow?" "return 0" "你好! 你好嗎?" > ./pred.txt
python -c "
import transformers
with open('pred.txt', 'r', errors='ignore') as file:
Expand All @@ -219,15 +219,14 @@ jobs:
print('\n\n')
tokenizer = transformers.AutoTokenizer.from_pretrained('TinyLlama/TinyLlama-1.1B-Chat-v1.0')
prompts = [
'Alan Turing was a',
'Why is the Sun yellow?',
'return 0',
'你好! 你好嗎?'
]
for prompt in prompts:
if tokenizer.chat_template:
prompt = tokenizer.apply_chat_template([{'role': 'user', 'content': prompt}], tokenize=False, add_generation_prompt=True)
tokenized = tokenizer(prompt, return_tensors='pt', add_special_tokens=False)
print(tokenized)
for beam in transformers.LlamaForCausalLM.from_pretrained('TinyLlama/TinyLlama-1.1B-Chat-v1.0').generate(**tokenized, num_beam_groups=3, num_beams=15, num_return_sequences=15, diversity_penalty=1.0, max_new_tokens=20, early_stopping=False, length_penalty=1.0, no_repeat_ngram_size=9**9, do_sample=False):
ref = ': ' + tokenizer.decode(beam[tokenized['input_ids'].numel():], skip_special_tokens=True)
print(ref)
Expand Down Expand Up @@ -277,7 +276,10 @@ jobs:
echo import transformers > ref.py
echo predictions = open('cpp.txt', 'r').read() >> ref.py
echo tokenizer = transformers.AutoTokenizer.from_pretrained('TinyLlama/TinyLlama-1.1B-Chat-v1.0', trust_remote_code=True) >> ref.py
echo tokenized = tokenizer('69', return_tensors='pt') >> ref.py
echo prompt = '69' >> ref.py
echo if tokenizer.chat_template: >> ref.py
echo prompt = tokenizer.apply_chat_template([{'role': 'user', 'content': prompt}], tokenize=False, add_generation_prompt=True) >> ref.py
echo tokenized = tokenizer(prompt, return_tensors='pt', add_special_tokens=False) >> ref.py
echo for beam in transformers.AutoModelForCausalLM.from_pretrained('TinyLlama/TinyLlama-1.1B-Chat-v1.0', trust_remote_code=True).generate(**tokenized, max_new_tokens=100, do_sample=False): >> ref.py
echo ref = tokenizer.decode(beam[tokenized['input_ids'].numel():], skip_special_tokens=True) >> ref.py
echo idx = predictions.find(ref) >> ref.py
Expand Down Expand Up @@ -584,7 +586,10 @@ jobs:
with open('pred_greedy.txt', 'r') as file:
predictions = file.read()
tokenizer = transformers.AutoTokenizer.from_pretrained('microsoft/phi-1_5')
tokenized = tokenizer('Alan Turing was a', return_tensors='pt')
prompt = 'Alan Turing was a'
if tokenizer.chat_template:
prompt = tokenizer.apply_chat_template([{'role': 'user', 'content': prompt}], tokenize=False, add_generation_prompt=True)
tokenized = tokenizer(prompt, return_tensors='pt', add_special_tokens=False)
for output in transformers.AutoModelForCausalLM.from_pretrained('microsoft/phi-1_5').generate(**tokenized, max_length=100, do_sample=False):
ref = tokenizer.decode(output[tokenized['input_ids'].numel():], skip_special_tokens=True)
idx = predictions.find(ref)
Expand Down Expand Up @@ -639,7 +644,10 @@ jobs:
with open('pred_greedy.txt', 'r') as file:
predictions = file.read()
tokenizer = transformers.AutoTokenizer.from_pretrained('ikala/redpajama-3b-chat')
tokenized = tokenizer('Alan Turing was a', return_tensors='pt')
prompt = 'Alan Turing was a'
if tokenizer.chat_template:
prompt = tokenizer.apply_chat_template([{'role': 'user', 'content': prompt}], tokenize=False, add_generation_prompt=True)
tokenized = tokenizer(prompt, return_tensors='pt', add_special_tokens=False)
for output in transformers.AutoModelForCausalLM.from_pretrained('ikala/redpajama-3b-chat').generate(**tokenized, max_length=100, do_sample=False):
ref = tokenizer.decode(output[tokenized['input_ids'].numel():], skip_special_tokens=True)
idx = predictions.find(ref)
Expand Down
1 change: 0 additions & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -133,7 +133,6 @@ from PIL import Image

# Choose GPU instead of CPU in the line below to run the model on Intel integrated or discrete GPU
pipe = openvino_genai.VLMPipeline("./InternVL2-1B", "CPU")
pipe.start_chat()

image = Image.open("dog.jpg")
image_data = np.array(image.getdata()).reshape(1, image.size[1], image.size[0], 3).astype(np.uint8)
Expand Down
2 changes: 1 addition & 1 deletion samples/cpp/text_generation/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,7 @@ Recommended models: meta-llama/Llama-2-7b-chat-hf, TinyLlama/TinyLlama-1.1B-Chat
./chat_sample <MODEL_DIR>
```
#### Missing chat template
If you encounter an exception indicating a missing "chat template" when launching the `ov::genai::LLMPipeline` in chat mode, it likely means the model was not tuned for chat functionality. To work this around, manually add the chat template to tokenizer_config.json of your model.
If you encounter an exception indicating a missing "chat template" when launching the `ov::genai::LLMPipeline` in chat mode, it likely means the model was not tuned for chat functionality. To work this around, manually add the chat template to tokenizer_config.json of your model or update it using call `pipe.get_tokenizer().set_chat_template(new_chat_template)`.
The following template can be used as a default, but it may not work properly with every model:
```
"chat_template": "{% for message in messages %}{% if (message['role'] == 'user') %}{{'<|im_start|>user\n' + message['content'] + '<|im_end|>\n<|im_start|>assistant\n'}}{% elif (message['role'] == 'assistant') %}{{message['content'] + '<|im_end|>\n'}}{% endif %}{% endfor %}",
Expand Down
2 changes: 1 addition & 1 deletion samples/python/text_generation/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,7 @@ Recommended models: meta-llama/Llama-2-7b-chat-hf, TinyLlama/TinyLlama-1.1B-Chat
python chat_sample.py model_dir
```
#### Missing chat template
If you encounter an exception indicating a missing "chat template" when launching the `ov::genai::LLMPipeline` in chat mode, it likely means the model was not tuned for chat functionality. To work this around, manually add the chat template to tokenizer_config.json of your model.
If you encounter an exception indicating a missing "chat template" when launching the `ov::genai::LLMPipeline` in chat mode, it likely means the model was not tuned for chat functionality. To work this around, manually add the chat template to tokenizer_config.json of your model or update it using call `pipe.get_tokenizer().set_chat_template(new_chat_template)`.
The following template can be used as a default, but it may not work properly with every model:
```
"chat_template": "{% for message in messages %}{% if (message['role'] == 'user') %}{{'<|im_start|>user\n' + message['content'] + '<|im_end|>\n<|im_start|>assistant\n'}}{% elif (message['role'] == 'assistant') %}{{message['content'] + '<|im_end|>\n'}}{% endif %}{% endfor %}",
Expand Down
2 changes: 2 additions & 0 deletions src/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -73,6 +73,8 @@ output:
'it is made up of carbon atoms. The carbon atoms are arranged in a linear pattern, which gives the yellow color. The arrangement of carbon atoms in'
```
>**Note**: The chat_template from tokenizer_config.json or from tokenizer/detokenizer model will be automatically applied to the prompt at the generation stage. If you want to disable it, you can do it by calling pipe.get_tokenizer().set_chat_template("").
A simple chat in Python:
```python
import openvino_genai as ov_genai
Expand Down
4 changes: 4 additions & 0 deletions src/cpp/include/openvino/genai/generation_config.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -128,6 +128,8 @@ class OPENVINO_GENAI_EXPORTS GenerationConfig {

std::optional<AdapterConfig> adapters;

bool apply_chat_template = true;

/** @brief sets eos_token_id to tokenizer_eos_token_id if eos_token_id is less than 0.
* Otherwise verifies eos_token_id == tokenizer_eos_token_id.
*/
Expand Down Expand Up @@ -189,6 +191,8 @@ extern OPENVINO_GENAI_EXPORTS ov::Property<size_t> rng_seed;
static constexpr ov::Property<float> assistant_confidence_threshold{"assistant_confidence_threshold"};
static constexpr ov::Property<size_t> num_assistant_tokens{"num_assistant_tokens"};

static constexpr ov::Property<bool> apply_chat_template{"apply_chat_template"};

// Predefined Configs

OPENVINO_DEPRECATED("Please, use individual parameters instead of predefined configs. This method will be removed in 2026.0.0 release")
Expand Down
4 changes: 4 additions & 0 deletions src/cpp/include/openvino/genai/llm_pipeline.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -177,6 +177,8 @@ class OPENVINO_GENAI_EXPORTS LLMPipeline {
* @param generation_config optional GenerationConfig
* @param streamer optional streamer
* @return DecodedResults decoded resulting text
* chat_template will be applied to the prompt, run pipe.get_tokenizer().set_chat_template(custom_chat_template) to update it.
* To disable it for non-chat mode, please, use custom_chat_template eq "" or set generation_config.apply_chat_template to false.
*/
DecodedResults generate(
StringInputs inputs,
Expand All @@ -191,6 +193,8 @@ class OPENVINO_GENAI_EXPORTS LLMPipeline {
* @param inputs input prompt or a vector of prompts
* @param properties properties
* @return DecodedResults decoded resulting text
* chat_template will be applied to the prompt, run pipe.get_tokenizer().set_chat_template(custom_chat_template) to update it.
* To disable it for non-chat mode, please, use custom_chat_template eq "" or set generation_config.apply_chat_template to false.
*/
template <typename... Properties>
util::EnableIfAllStringAny<DecodedResults, Properties...> generate(
Expand Down
3 changes: 3 additions & 0 deletions src/cpp/include/openvino/genai/tokenizer.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -221,6 +221,9 @@ class OPENVINO_GENAI_EXPORTS Tokenizer {
/// @param chat_template The new template to override with.
void set_chat_template(const std::string& chat_template);

// get information about a chat template to check its status, for example whether it is empty
std::string get_chat_template() const;

// information about <bos>, <eos> tokens should be public,
// they are used at least in StreamerBase descendants
int64_t get_bos_token_id() const;
Expand Down
8 changes: 8 additions & 0 deletions src/cpp/include/openvino/genai/visual_language/pipeline.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -98,6 +98,8 @@ class OPENVINO_GENAI_EXPORTS VLMPipeline {
/// @param generation_config A config to follow for text generation.
/// @param streamer A streamer to acquire intermediate result.
/// @return A string generated by a model.
/// chat_template will be applied to the prompt, run pipe.set_chat_template(custom_chat_template) to update it.
/// To disable it for non-chat mode, please, use custom_chat_template eq "" or set generation_config.apply_chat_template to false.
VLMDecodedResults generate(
const std::string& prompt,
const std::vector<ov::Tensor>& rgbs,
Expand All @@ -111,6 +113,8 @@ class OPENVINO_GENAI_EXPORTS VLMPipeline {
/// @param generation_config A config to follow for text generation.
/// @param streamer A streamer to acquire intermediate result.
/// @return A string generated by a model.
/// chat_template will be applied to the prompt, run pipe.set_chat_template(custom_chat_template) to update it.
/// To disable it for non-chat mode, please, use custom_chat_template eq "" or set generation_config.apply_chat_template to false.
VLMDecodedResults generate(
const std::string& prompt,
const ov::Tensor& rgb,
Expand All @@ -124,6 +128,8 @@ class OPENVINO_GENAI_EXPORTS VLMPipeline {
/// for its members, StreamerVariant a single image or multiple
/// images.
/// @return A string generated by a model.
/// chat_template will be applied to the prompt, run pipe.set_chat_template(custom_chat_template) to update it.
/// To disable it for non-chat mode, please, use custom_chat_template eq "" or set generation_config.apply_chat_template to false.
VLMDecodedResults generate(
const std::string& prompt,
const ov::AnyMap& config_map
Expand All @@ -137,6 +143,8 @@ class OPENVINO_GENAI_EXPORTS VLMPipeline {
/// @param ...properties ov::Property instances to be combined into
/// ov::AnyMap.
/// @return A string generated by a model.
/// chat_template will be applied to the prompt, run pipe.set_chat_template(custom_chat_template) to update it.
/// To disable it for non-chat mode, please, use custom_chat_template eq "" or set generation_config.apply_chat_template to false.
template <typename... Properties>
util::EnableIfAllStringAny<VLMDecodedResults, Properties...> generate(
const std::string& prompt,
Expand Down
2 changes: 2 additions & 0 deletions src/cpp/include/openvino/genai/whisper_generation_config.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -97,6 +97,8 @@ class OPENVINO_GENAI_EXPORTS WhisperGenerationConfig : public GenerationConfig {
// A list containing the non-speech tokens that will be suppressed during generation.
std::vector<int64_t> suppress_tokens;

bool apply_chat_template = false;

void update_generation_config(const ov::AnyMap& config_map = {});

template <typename... Properties>
Expand Down
1 change: 1 addition & 0 deletions src/cpp/src/generation_config.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -128,6 +128,7 @@ void GenerationConfig::update_generation_config(const ov::AnyMap& properties) {
read_anymap_param(properties, "logprobs", logprobs);
read_anymap_param(properties, "num_return_sequences", num_return_sequences);
read_anymap_param(properties, "adapters", adapters);
read_anymap_param(properties, "apply_chat_template", apply_chat_template);

// penalties
read_anymap_param(properties, "frequency_penalty", frequency_penalty);
Expand Down
15 changes: 13 additions & 2 deletions src/cpp/src/icontinuous_batching.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -53,9 +53,20 @@ ContinuousBatchingPipeline::IContinuousBatchingPipeline::generate(
} else {
input_ids.reserve(prompts.size());
timer.start();
for (const std::string& prompt : prompts) {
for (size_t i = 0; i < prompts.size(); i++) {
const std::string& prompt = prompts.at(i);
const auto encode_start = std::chrono::steady_clock::now();
input_ids.push_back(m_tokenizer.encode(prompt).input_ids);
ov::Tensor encoded_inputs;
if (sampling_params.at(i).apply_chat_template && !m_tokenizer.get_chat_template().empty()) {
ChatHistory history({{{"role", "user"}, {"content", prompt}}});
constexpr bool add_generation_prompt = true;
auto templated_prompt = m_tokenizer.apply_chat_template(history, add_generation_prompt);
encoded_inputs = m_tokenizer.encode(templated_prompt, ov::genai::add_special_tokens(false)).input_ids;
} else {
// in case when chat_template was not found in tokenizer_config.json or set
encoded_inputs = m_tokenizer.encode(prompt).input_ids;
}
input_ids.push_back(encoded_inputs);
tokenization_durations.emplace_back(PerfMetrics::get_microsec(std::chrono::steady_clock::now() - encode_start));
}
timer.end();
Expand Down
34 changes: 25 additions & 9 deletions src/cpp/src/llm_pipeline_stateful.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,8 @@
#include "text_callback_streamer.hpp"
#include "utils.hpp"

#include "debug_utils.hpp"

namespace ov::genai {

StatefulLLMPipeline::StatefulLLMPipeline(
Expand Down Expand Up @@ -87,14 +89,19 @@ DecodedResults StatefulLLMPipeline::generate(
TokenizedInputs encoded_input;

if (auto input_vector = std::get_if<std::vector<std::string>>(&inputs)) {
std::vector<std::string> templated_input_vector;
for (auto& input : *input_vector) {
ChatHistory history({{{"role", "user"}, {"content", input}}});
constexpr bool add_generation_prompt = true;
auto templated_prompt = m_tokenizer.apply_chat_template(history, add_generation_prompt);
templated_input_vector.push_back(templated_prompt);
OPENVINO_ASSERT(!is_chat_conversation, "Can't chat with multiple prompts");
if (config.apply_chat_template && !m_tokenizer.get_chat_template().empty()) {
std::vector<std::string> templated_input_vector;
for (auto& input : *input_vector) {
ChatHistory history({{{"role", "user"}, {"content", input}}});
constexpr bool add_generation_prompt = true;
auto templated_prompt = m_tokenizer.apply_chat_template(history, add_generation_prompt);
templated_input_vector.push_back(templated_prompt);
}
encoded_input = m_tokenizer.encode(templated_input_vector, ov::genai::add_special_tokens(false));
} else {
encoded_input = m_tokenizer.encode(*input_vector);
}
encoded_input = m_tokenizer.encode(templated_input_vector, ov::genai::add_special_tokens(false));
} else if (auto input_prompt = std::get_if<std::string>(&inputs)) {
std::string& prompt = *input_prompt;

Expand All @@ -110,7 +117,7 @@ DecodedResults StatefulLLMPipeline::generate(

m_history.push_back({{"role", "user"}, {"content", prompt}});
constexpr bool add_generation_prompt = true;
auto new_templated_chat_history = m_tokenizer.apply_chat_template(m_history, add_generation_prompt);
auto new_templated_chat_history = m_tokenizer.apply_chat_template(m_history, add_generation_prompt);
// Do not add special tokens in chat scenario to be aligned with HF.
auto new_chat_tokens = m_tokenizer.encode(new_templated_chat_history, ov::genai::add_special_tokens(false));
auto prev_chat_tokens = m_tokenizer.encode(m_templated_chat_history, ov::genai::add_special_tokens(false));
Expand Down Expand Up @@ -163,7 +170,16 @@ DecodedResults StatefulLLMPipeline::generate(

// TODO: Forbid LoRA config change if we are in the chat mode, because it requires regenerating the history with LoRA applied
} else {
encoded_input = m_tokenizer.encode(prompt);
std::string& prompt = *input_prompt;
if (config.apply_chat_template && !m_tokenizer.get_chat_template().empty()) {
ChatHistory history({{{"role", "user"}, {"content", prompt}}});
constexpr bool add_generation_prompt = true;
auto templated_prompt = m_tokenizer.apply_chat_template(history, add_generation_prompt);
encoded_input = m_tokenizer.encode(templated_prompt, ov::genai::add_special_tokens(false));
} else {
// in case when chat_template was not found in tokenizer_config.json or set
encoded_input = m_tokenizer.encode(prompt);
}
}
}

Expand Down
20 changes: 18 additions & 2 deletions src/cpp/src/llm_pipeline_static.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -827,7 +827,15 @@ DecodedResults StatefulLLMPipeline::generate(
// for chat ov::genai::add_special_tokens(false) is aligned with stateful pipeline and HF
tokenized_input = m_tokenizer.encode(prompt, ov::genai::add_special_tokens(false));
} else {
tokenized_input = m_tokenizer.encode(prompt);
if (config.apply_chat_template && !m_tokenizer.get_chat_template().empty()) {
ChatHistory history({{{"role", "user"}, {"content", prompt}}});
constexpr bool add_generation_prompt = true;
auto templated_prompt = m_tokenizer.apply_chat_template(history, add_generation_prompt);
tokenized_input = m_tokenizer.encode(templated_prompt, ov::genai::add_special_tokens(false));
} else {
// in case when chat_template was not found in tokenizer_config.json or set
tokenized_input = m_tokenizer.encode(prompt);
}
}

auto encode_stop_time = std::chrono::steady_clock::now();
Expand Down Expand Up @@ -1294,7 +1302,15 @@ DecodedResults StatelessLLMPipeline::generate(
// for chat ov::genai::add_special_tokens(false) is aligned with stateful pipeline and HF
tokenized_input = m_tokenizer.encode(prompt, ov::genai::add_special_tokens(false));
} else {
tokenized_input = m_tokenizer.encode(prompt);
if (config.apply_chat_template && !m_tokenizer.get_chat_template().empty()) {
ChatHistory history({{{"role", "user"}, {"content", prompt}}});
constexpr bool add_generation_prompt = true;
auto templated_prompt = m_tokenizer.apply_chat_template(history, add_generation_prompt);
tokenized_input = m_tokenizer.encode(templated_prompt, ov::genai::add_special_tokens(false));
} else {
// in case when chat_template was not found in tokenizer_config.json or set
tokenized_input = m_tokenizer.encode(prompt);
}
}

auto encode_stop_time = std::chrono::steady_clock::now();
Expand Down
Loading

0 comments on commit 5924b23

Please sign in to comment.