Add multi prompt support for beam search #349

Merged
24 changes: 22 additions & 2 deletions .github/workflows/causal_lm_cpp.yml
@@ -170,6 +170,27 @@ jobs:
"
echo "你好! 你好嗎?" passed

timeout 1m ./build/beam_search_causal_lm ./TinyLlama-1.1B-Chat-v1.0/pytorch/dldt/FP16/ "Alan Turing was a" "return 0" "你好! 你好嗎?" > ./pred.txt
python -c "
import transformers
with open('pred.txt', 'r') as file:
predictions = file.read()
tokenizer = transformers.LlamaTokenizer.from_pretrained('TinyLlama/TinyLlama-1.1B-Chat-v1.0')
prompts = [
'Alan Turing was a',
'return 0',
'你好! 你好嗎?'
]
for prompt in prompts:
tokenized = tokenizer(prompt, return_tensors='pt')
for beam in transformers.LlamaForCausalLM.from_pretrained('TinyLlama/TinyLlama-1.1B-Chat-v1.0').generate(**tokenized, num_beam_groups=3, num_beams=15, num_return_sequences=15, diversity_penalty=1.0, max_new_tokens=20, early_stopping=False, length_penalty=1.0, no_repeat_ngram_size=9**9, do_sample=False):
ref = ': ' + tokenizer.decode(beam[tokenized['input_ids'].numel():], skip_special_tokens=True) + '\n'
idx = predictions.find(ref)
if -1 == idx:
raise RuntimeError(f'Missing "{ref=}" from predictions')
predictions = predictions[:idx] + predictions[idx + len(ref):]
"
echo "Multi prompt" passed
cpp-beam_search_causal_lm-windows:
runs-on: windows-latest
steps:
@@ -291,7 +312,6 @@ jobs:
source ./ov/setupvars.sh
convert_tokenizer ./Phi-2/pytorch/dldt/FP16/ --output ./Phi-2/pytorch/dldt/FP16/ --with-detokenizer --trust-remote-code
timeout 50s ./build/beam_search_causal_lm ./Phi-2/pytorch/dldt/FP16/ 69 > ./pred.txt

cpp-beam_search_causal_lm-notus-7b-v1:
runs-on: ubuntu-20.04-16-cores
steps:
@@ -331,7 +351,7 @@ jobs:
- name: Install OpenVINO
run: |
mkdir ./ov/
curl https://storage.openvinotoolkit.org/repositories/openvino/packages/2023.3/linux/l_openvino_toolkit_ubuntu20_2023.3.0.13775.ceeafaf64f3_x86_64.tgz | tar --directory ./ov/ --strip-components 1 -xz
curl https://storage.openvinotoolkit.org/repositories/openvino/packages/nightly/2024.1.0-14645-e6dc0865128/l_openvino_toolkit_ubuntu20_2024.1.0.dev20240304_x86_64.tgz | tar --directory ./ov/ --strip-components 1 -xz
sudo ./ov/install_dependencies/install_openvino_dependencies.sh
- name: Download, convert and build
run: |
209 changes: 170 additions & 39 deletions text_generation/causal_lm/cpp/beam_search_causal_lm.cpp
@@ -5,12 +5,8 @@
#include <openvino/openvino.hpp>

namespace {
std::pair<ov::Tensor, ov::Tensor> tokenize(ov::InferRequest& tokenizer, std::string&& prompt) {
constexpr size_t BATCH_SIZE = 1;
tokenizer.set_input_tensor(ov::Tensor{ov::element::string, {BATCH_SIZE}, &prompt});
tokenizer.infer();
return {tokenizer.get_tensor("input_ids"), tokenizer.get_tensor("attention_mask")};
}

enum SPECIAL_TOKEN { PAD_TOKEN = 2 };

std::string detokenize(ov::InferRequest& detokenizer, const std::vector<int64_t>& tokens) {
constexpr size_t BATCH_SIZE = 1;
@@ -22,52 +18,187 @@ std::string detokenize(ov::InferRequest& detokenizer, const std::vector<int64_t>
detokenizer.infer();
return detokenizer.get_output_tensor().data<std::string>()[0];
}

std::pair<ov::Tensor, ov::Tensor> pad_left(ov::Tensor&& input_ids, ov::Tensor&& attention_mask) {
const size_t batch_size = input_ids.get_shape().at(0);
const size_t sequence_length = input_ids.get_shape().at(1);
int64_t* inputs_data = input_ids.data<int64_t>();
int64_t* attention_mask_data = attention_mask.data<int64_t>();

for (size_t batch = 0; batch < batch_size; batch++) {
const size_t batch_offset = batch * sequence_length;

// last token in the sequence is not a PAD_TOKEN, skipping
if (inputs_data[batch_offset + sequence_length - 1] != SPECIAL_TOKEN::PAD_TOKEN) {
continue;
}

size_t pad_tokens_number = 0;
for (int i = sequence_length - 1; i >= 0; i--) {
const size_t token_offset = batch_offset + i;

if (inputs_data[token_offset] == SPECIAL_TOKEN::PAD_TOKEN) {
continue;
}

if (pad_tokens_number == 0) {
pad_tokens_number = sequence_length - i - 1;
}

std::swap(inputs_data[token_offset], inputs_data[token_offset + pad_tokens_number]);
std::swap(attention_mask_data[token_offset], attention_mask_data[token_offset + pad_tokens_number]);
}
}

return {input_ids, attention_mask};
}

std::pair<ov::Tensor, ov::Tensor> tokenize(ov::InferRequest& tokenizer, std::vector<std::string> prompts) {
tokenizer.set_input_tensor(ov::Tensor{ov::element::string, {prompts.size()}, prompts.data()});
Review comment (Contributor):
Side note: currently we use batched inference with the tokenizer, and it creates the attention mask for us, which we then need to "parse" later. Alternatively, we could tokenize the prompts one by one; that would return raw (unpadded) data, which we could use more efficiently to fill position_ids, etc.
I'm not sure which solution is more optimal, so let's stick with the current one, because it's already implemented. (A sketch of this per-prompt alternative appears after the tokenize function below.)


tokenizer.infer();

pad_left(tokenizer.get_tensor("input_ids"), tokenizer.get_tensor("attention_mask"));

// fix mask filled with '2' instead of '0'
ov::Tensor attention_mask = tokenizer.get_tensor("attention_mask");
int64_t* attention_mask_data = attention_mask.data<int64_t>();
std::replace(attention_mask_data, attention_mask_data + attention_mask.get_size(), 2, 0);

return {tokenizer.get_tensor("input_ids"), tokenizer.get_tensor("attention_mask")};
}
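
Related to the review side note above, here is a minimal sketch of the per-prompt alternative, assuming the same OpenVINO string tokenizer interface used elsewhere in this sample. The helper name tokenize_unpadded and its return type are illustrative and not part of this PR; besides the <openvino/openvino.hpp> include shown above it would also need <string>, <utility> and <vector>. It trades the single batched tokenizer call for one inference per prompt, but yields raw, unpadded sequences, so no pad_left step would be required.

// Illustrative sketch only (not part of this PR): tokenize each prompt separately
// so the tokenizer never pads, keeping the raw per-prompt ids and mask.
std::vector<std::pair<std::vector<int64_t>, std::vector<int64_t>>> tokenize_unpadded(
        ov::InferRequest& tokenizer, const std::vector<std::string>& prompts) {
    std::vector<std::pair<std::vector<int64_t>, std::vector<int64_t>>> results;
    results.reserve(prompts.size());
    for (std::string prompt : prompts) {  // copy: the string tensor wraps a non-const buffer
        tokenizer.set_input_tensor(ov::Tensor{ov::element::string, {1}, &prompt});
        tokenizer.infer();
        // With batch size 1 there is no padding, so input_ids and attention_mask
        // come back as raw sequences and no pad_left step is needed.
        ov::Tensor input_ids = tokenizer.get_tensor("input_ids");
        ov::Tensor attention_mask = tokenizer.get_tensor("attention_mask");
        const int64_t* ids = input_ids.data<int64_t>();
        const int64_t* mask = attention_mask.data<int64_t>();
        results.emplace_back(std::vector<int64_t>{ids, ids + input_ids.get_size()},
                             std::vector<int64_t>{mask, mask + attention_mask.get_size()});
    }
    return results;
}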

void initialize_position_ids(ov::Tensor& position_ids, const ov::Tensor& attention_mask) {
const size_t batch_size = attention_mask.get_shape().at(0);
const size_t sequence_length = attention_mask.get_shape().at(1);

const int64_t* attention_mask_data = attention_mask.data<int64_t>();
int64_t* position_ids_data = position_ids.data<int64_t>();

for (size_t batch = 0; batch < batch_size; batch++) {
const size_t batch_offset = batch * sequence_length;
size_t sum = 0;

for (size_t i = 0; i < sequence_length; i++) {
const size_t element_offset = batch_offset + i;
position_ids_data[element_offset] = sum;
if (attention_mask_data[element_offset] == 1) {
sum += 1;
}
}
}
}

void initialize_inputs(const ov::Tensor& input_ids, const ov::Tensor& attention_mask, ov::InferRequest& request) {
request.set_tensor("input_ids", input_ids);
request.set_tensor("attention_mask", attention_mask);

ov::Shape input_shape = input_ids.get_shape();

ov::Tensor position_ids = request.get_tensor("position_ids");
position_ids.set_shape(input_shape);
initialize_position_ids(position_ids, attention_mask);

ov::Tensor beam_idx = request.get_tensor("beam_idx");
beam_idx.set_shape({input_shape.at(0)});
std::fill_n(beam_idx.data<int32_t>(), input_shape.at(0), 0);
}

void set_attention_mask(ov::Tensor&& attention_mask, std::vector<int32_t> next_beams) {
ov::Tensor original_mask{ov::element::i64, attention_mask.get_shape()};
ov::Shape original_shape = original_mask.get_shape();
attention_mask.copy_to(original_mask);

ov::Shape new_shape{next_beams.size(), original_mask.get_shape().at(1) + 1};
attention_mask.set_shape(new_shape);

for (size_t beam_id = 0; beam_id < next_beams.size(); beam_id++) {
const size_t original_prompt_offset = next_beams.at(beam_id) * original_shape.at(1);
const size_t result_prompt_offset = beam_id * new_shape.at(1);

int64_t* dest = attention_mask.data<int64_t>() + result_prompt_offset;
const int64_t* src = original_mask.data<int64_t>() + original_prompt_offset;

std::memcpy(dest, src, original_shape.at(1) * sizeof(int64_t));
attention_mask.data<int64_t>()[result_prompt_offset + new_shape.at(1) - 1] = 1;
}
}

void set_position_ids(ov::Tensor&& position_ids, const ov::Tensor&& attention_mask) {
const size_t batch_size = attention_mask.get_shape().at(0);
const size_t sequence_length = attention_mask.get_shape().at(1);
position_ids.set_shape({batch_size, 1});

for (size_t batch = 0; batch < batch_size; batch++) {
int64_t* mask_start = attention_mask.data<int64_t>() + batch * sequence_length;
position_ids.data<int64_t>()[batch] = std::accumulate(mask_start, mask_start + sequence_length - 1, 0);
}
}

std::vector<std::string> prompts_arguments_to_vector(int argc, char* argv[]) {
std::vector<std::string> prompts;
prompts.reserve(argc - 2);
for (size_t i = 2; i < argc; i++) {
prompts.push_back(std::string{argv[i]});
}
return prompts;
}

} // namespace

int main(int argc, char* argv[]) try {
if (argc != 3) {
throw std::runtime_error(std::string{"Usage: "} + argv[0] + " <MODEL_DIR> '<PROMPT>'");
if (argc < 3) {
throw std::runtime_error(std::string{"Usage: "} + argv[0] + " <MODEL_DIR> '<PROMPT>'...");
}

// Compile models
ov::Core core;
core.add_extension(OPENVINO_TOKENIZERS_PATH); // OPENVINO_TOKENIZERS_PATH is defined in CMakeLists.txt
//Read the tokenizer model information from the file to later get the runtime information
// Read the tokenizer model information from the file to later get the runtime information
auto tokenizer_model = core.read_model(std::string{argv[1]} + "/openvino_tokenizer.xml");
// tokenizer and detokenizer work on CPU only
ov::InferRequest tokenizer = core.compile_model(
tokenizer_model, "CPU").create_infer_request();
auto [input_ids, attention_mask] = tokenize(tokenizer, argv[2]);
ov::InferRequest detokenizer = core.compile_model(
std::string{argv[1]} + "/openvino_detokenizer.xml", "CPU").create_infer_request();
ov::InferRequest tokenizer = core.compile_model(tokenizer_model, "CPU").create_infer_request();
ov::InferRequest detokenizer =
core.compile_model(std::string{argv[1]} + "/openvino_detokenizer.xml", "CPU").create_infer_request();
// The model can be compiled for GPU as well
ov::InferRequest lm = core.compile_model(
std::string{argv[1]} + "/openvino_model.xml", "CPU").create_infer_request();
// Initialize inputs
lm.set_tensor("input_ids", input_ids);
lm.set_tensor("attention_mask", attention_mask);
ov::Tensor position_ids = lm.get_tensor("position_ids");
position_ids.set_shape(input_ids.get_shape());
std::iota(position_ids.data<int64_t>(), position_ids.data<int64_t>() + position_ids.get_size(), 0);
lm.get_tensor("beam_idx").set_shape({1});
lm.get_tensor("beam_idx").data<int32_t>()[0] = 0;
ov::InferRequest lm =
core.compile_model(std::string{argv[1]} + "/openvino_model.xml", "CPU").create_infer_request();

auto [input_ids, attention_mask] = tokenize(tokenizer, prompts_arguments_to_vector(argc, argv));

// Initialize beam search
const int64_t* prompt_data = input_ids.data<const int64_t>();
std::vector<std::vector<int64_t>> prompts;
prompts.reserve(input_ids.get_shape().at(0));
for (size_t batch = 0; batch < input_ids.get_shape().at(0); batch++) {
size_t sequence_length = input_ids.get_shape().at(1);
size_t batch_offset = batch * sequence_length;
const int64_t* prompt_start = prompt_data + batch_offset;
prompts.push_back(std::vector<int64_t>{prompt_start, prompt_start + sequence_length});
}

// Get the runtime info from the tokenizer model that we read earlier
auto rt_info = tokenizer_model->get_rt_info(); //Get the runtime info for the model
auto rt_info = tokenizer_model->get_rt_info(); // Get the runtime info for the model
int64_t SPECIAL_EOS_TOKEN;

if (rt_info.count("eos_token_id") > 0) { //check if the runtime information has a valid EOS token ID
if (rt_info.count("eos_token_id") > 0) { // check if the runtime information has a valid EOS token ID
SPECIAL_EOS_TOKEN = rt_info["eos_token_id"].as<int64_t>();

} else {
throw std::runtime_error("EOS token ID not found in model's runtime information.");
}
const int64_t* prompt_data = input_ids.data<const int64_t>();
Parameters parameters{std::vector<int64_t>{prompt_data, prompt_data + input_ids.get_size()}, SPECIAL_EOS_TOKEN};

Parameters parameters{std::move(prompts), SPECIAL_EOS_TOKEN};
GroupBeamSearcher group_beam_searcher{parameters};

initialize_inputs(input_ids, attention_mask, lm);

std::vector<int64_t> next_tokens;
std::vector<int32_t> next_beams;

for (size_t length_count = 0; length_count < parameters.max_new_tokens; ++length_count) {
lm.infer();

std::tie(next_tokens, next_beams) = group_beam_searcher.select_next_tokens(lm.get_tensor("logits"));
if (next_tokens.empty()) {
break;
@@ -77,17 +208,17 @@ int main(int argc, char* argv[]) try {
lm.set_tensor("input_ids", ov::Tensor{ov::element::i64, {batch_size, 1}, next_tokens.data()});
lm.set_tensor("beam_idx", ov::Tensor{ov::element::i32, {batch_size}, next_beams.data()});
// Set auxiliary inputs
ov::Tensor attention_mask = lm.get_tensor("attention_mask");
ov::Shape mask_shape{batch_size, attention_mask.get_shape().at(1) + 1};
attention_mask.set_shape(mask_shape);
std::fill_n(attention_mask.data<int64_t>(), ov::shape_size(mask_shape), 1);
lm.get_tensor("position_ids").set_shape({batch_size, 1});
std::fill_n(lm.get_tensor("position_ids").data<int64_t>(), batch_size, mask_shape.at(1) - 1);
set_attention_mask(lm.get_tensor("attention_mask"), next_beams);
set_position_ids(lm.get_tensor("position_ids"), lm.get_tensor("attention_mask"));
}
for (const std::vector<Beam>& group : finalize(std::move(group_beam_searcher))) {
std::cout << "Group:\n";
for (const Beam& beam : group) {
std::cout << beam.score << ": " << detokenize(detokenizer, beam.tokens) << '\n';

for (const std::vector<std::vector<Beam>>& prompt_group : finalize(std::move(group_beam_searcher))) {
std::cout << "Prompt:\n";
for (const std::vector<Beam> group : prompt_group) {
std::cout << "Group:\n";
for (const Beam& beam : group) {
std::cout << beam.score << ": " << detokenize(detokenizer, beam.tokens) << '\n';
}
}
}
// Model is stateful which means that context (kv-cache) which belongs to a particular