Wenyi5608 greedy sampling #293

Closed

Changes from 3 commits
1 change: 1 addition & 0 deletions text_generation/causal_lm/cpp/CMakeLists.txt
@@ -8,6 +8,7 @@ add_subdirectory(../../../thirdparty/openvino_tokenizers/ "${CMAKE_CURRENT_BINAR

add_executable(greedy_causal_lm greedy_causal_lm.cpp)
target_compile_definitions(greedy_causal_lm PRIVATE OPENVINO_TOKENIZERS_PATH=\"$<TARGET_FILE:openvino_tokenizers>\")
target_include_directories(greedy_causal_lm PRIVATE ./)
find_package(OpenVINO REQUIRED COMPONENTS Runtime)
target_link_libraries(greedy_causal_lm PRIVATE openvino::runtime)
set_target_properties(greedy_causal_lm PROPERTIES CXX_STANDARD 17)
11 changes: 8 additions & 3 deletions text_generation/causal_lm/cpp/greedy_causal_lm.cpp
@@ -1,6 +1,7 @@
// Copyright (C) 2023-2024 Intel Corporation
// SPDX-License-Identifier: Apache-2.0

#include <greedy_sampling.hpp>
#include <openvino/openvino.hpp>

namespace {
@@ -82,9 +83,13 @@ int main(int argc, char* argv[]) try {
lm.get_tensor("beam_idx").set_shape({BATCH_SIZE});
lm.get_tensor("beam_idx").data<int32_t>()[0] = 0;
lm.infer();
int64_t sequence_len = lm.get_tensor("logits").get_shape().at(1) - 1;
size_t vocab_size = lm.get_tensor("logits").get_shape().back();
float* logits = lm.get_tensor("logits").data<float>() + (input_ids.get_size() - 1) * vocab_size;
int64_t out_token = std::max_element(logits, logits + vocab_size) - logits;
float* logits = lm.get_tensor("logits").data<float>() + (sequence_len) * vocab_size;
const int64_t* prompt_data = input_ids.data<const int64_t>();
SamplingParameters parameters{ std::vector<int64_t>{prompt_data, prompt_data + input_ids.get_size()} };
GreedySampling greedy_sampling{ parameters };
int64_t out_token = greedy_sampling.get_out_token(logits, vocab_size);

lm.get_tensor("input_ids").set_shape({BATCH_SIZE, 1});
position_ids.set_shape({BATCH_SIZE, 1});
@@ -100,7 +105,7 @@ int main(int argc, char* argv[]) try {
        text_streamer.put(out_token);
        lm.wait();
        logits = lm.get_tensor("logits").data<float>();
        out_token = std::max_element(logits, logits + vocab_size) - logits;
        out_token = greedy_sampling.get_out_token(logits, vocab_size);
    }
    text_streamer.end();
    // Model is stateful which means that context (kv-cache) which belongs to a particular
165 changes: 165 additions & 0 deletions text_generation/causal_lm/cpp/greedy_sampling.hpp
@@ -0,0 +1,165 @@
// Copyright (C) 2023-2024 Intel Corporation
// SPDX-License-Identifier: Apache-2.0

#include <regex>
#include <random>
#include <iostream>
#include <cmath>
#include <algorithm>
#include <numeric>
#include <vector>

struct TokenIdScore {
    int id;
    float score;

    TokenIdScore() = default;
    TokenIdScore(int id, float score) : id(id), score(score) {}

    bool operator<(const TokenIdScore& other) const { return score < other.score; }
    bool operator>(const TokenIdScore& other) const { return score > other.score; }

    friend std::ostream& operator<<(std::ostream& os, const TokenIdScore& self) {
        return os << "TokenIdScore(id=" << self.id << ", score=" << self.score << ")";
    }
};

void sampling_softmax_inplace(TokenIdScore* first, TokenIdScore* last) {
    float max_score = std::max_element(first, last)->score;
    float sum = 0.f;
    for (TokenIdScore* p = first; p != last; p++) {
        float s = std::exp(p->score - max_score);
        p->score = s;
        sum += s;
    }
    float inv_sum = 1.f / sum;
    for (TokenIdScore* p = first; p != last; p++) {
        p->score *= inv_sum;
    }
}

void sampling_top_k(TokenIdScore* first, TokenIdScore* kth, TokenIdScore* last) {
    std::nth_element(first, kth, last, std::greater<TokenIdScore>());
}

TokenIdScore* sampling_top_p(TokenIdScore* first, TokenIdScore* last, float top_p) {
    // fast top_p in expected O(n) time complexity
    sampling_softmax_inplace(first, last);

    while (first + 1 < last) {
        const float pivot_score = (last - 1)->score; // use mid score?
        TokenIdScore* mid =
            std::partition(first, last - 1, [pivot_score](const TokenIdScore& x) { return x.score > pivot_score; });
        std::swap(*mid, *(last - 1));

        const float prefix_sum =
            std::accumulate(first, mid, 0.f, [](float sum, const TokenIdScore& x) { return sum + x.score; });
        if (prefix_sum >= top_p) {
            last = mid;
        }
        else if (prefix_sum + mid->score < top_p) {
            first = mid + 1;
            top_p -= prefix_sum + mid->score;
        }
        else {
            return mid + 1;
        }
    }
    return last;
}

void sampling_repetition_penalty(float* first, float* last, const std::vector<int64_t>& input_ids,
                                 float penalty) {
    if (penalty <= 0) {
std::cout << "penalty must be a positive float, but got " << penalty;
return;
}
const float inv_penalty = 1.f / penalty;
const int vocab_size = last - first;
std::vector<bool> occurrence(vocab_size, false);
for (const int64_t id : input_ids) {
if (!occurrence[id]) {
first[id] *= (first[id] > 0) ? inv_penalty : penalty;
}
occurrence[id] = true;
}
}

void sampling_temperature(float* first, float* last, float temp) {
const float inv_temp = 1.f / temp;
for (float* it = first; it != last; it++) {
*it *= inv_temp;
}
}

struct SamplingParameters {
std::vector<int64_t> prompt;
int top_k = 0;
float top_p = 0.7;
float temp = 0.95;
float repeat_penalty = 1.1;
bool do_sample = true;
};

// GreedySampling processes logits produced by a language model and chooses the token with
// the highest probability as the next token in the sequence. get_out_token() returns the token
// id selected by the algorithm. The value is used for the next inference.
struct GreedySampling {
    SamplingParameters parameters;
    GreedySampling(SamplingParameters parameters) : parameters{ std::move(parameters) } {
    }

    int64_t get_out_token(float* logits, size_t vocab_size) {
        int64_t out_token;
        std::vector<int64_t> prompt{ parameters.prompt };

        // logits pre-process
        if (parameters.repeat_penalty != 1.f) {
            sampling_repetition_penalty(logits, logits + vocab_size, prompt, parameters.repeat_penalty);
        }

        if (parameters.do_sample)
        {
            if (parameters.temp > 0) {
                sampling_temperature(logits, logits + vocab_size, parameters.temp);
            }

            std::vector<TokenIdScore> token_scores(vocab_size);
            for (int i = 0; i < vocab_size; i++) {
                token_scores[i] = TokenIdScore(i, logits[i]);
            }

            // top_k sampling
            if (0 < parameters.top_k && parameters.top_k < (int)token_scores.size()) {
                sampling_top_k(token_scores.data(), token_scores.data() + parameters.top_k,
                               token_scores.data() + token_scores.size());
                token_scores.resize(parameters.top_k);
            }

            // top_p sampling
            if (0.f < parameters.top_p && parameters.top_p < 1.f) {
                auto pos = sampling_top_p(token_scores.data(), token_scores.data() + token_scores.size(), parameters.top_p);
                token_scores.resize(pos - token_scores.data());
            }

            // sample next token
            sampling_softmax_inplace(token_scores.data(), token_scores.data() + token_scores.size());
            for (size_t i = 0; i < token_scores.size(); i++) {
                logits[i] = token_scores[i].score;
Comment on lines +147 to +149
apaniukov (Contributor) commented on Apr 12, 2024:
You must take the values from logits for the remaining tokens here, not scores, and then do the softmax on them. The softmax already modified the scores in the top_p call, so the second softmax will skew the distribution towards the most probable token.
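
Concretely, the suggestion amounts to something like the sketch below. It is illustrative only, reuses the PR's TokenIdScore and sampling_softmax_inplace helpers, and assumes the logits buffer still holds the raw (temperature-scaled) values at this point; it is not the final patch.

// Sketch (not the merged code): rebuild the surviving tokens' scores from the
// untouched logits buffer, then apply softmax exactly once before sampling.
for (TokenIdScore& ts : token_scores) {
    ts.score = logits[ts.id];  // raw logit, not an already-softmaxed probability
}
sampling_softmax_inplace(token_scores.data(), token_scores.data() + token_scores.size());

std::vector<float> probs;
probs.reserve(token_scores.size());
for (const TokenIdScore& ts : token_scores) {
    probs.push_back(ts.score);
}
thread_local std::mt19937 gen{std::random_device{}()};
std::discrete_distribution<> dist(probs.begin(), probs.end());
int64_t out_token = token_scores[dist(gen)].id;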

Contributor Author replied:
I referred to chatglm.cpp (https://github.com/li-plus/chatglm.cpp/blob/main/chatglm.cpp#L825) to implement this function.
Is this implementation different from HuggingFace's?

Contributor replied:
I checked the softmax effects on the logits; for the llama model, the second softmax makes the distribution nearly uniform and the third one makes it exactly uniform. Here is how the probability of the most probable token evolves:

Softmax 0 times: tensor([[ 0.0000,  8.6805,  9.4036,  8.4245, 10.8393]])
Softmax 1 times: tensor([[3.1250e-05, 2.7906e-01, 2.0102e-01, 7.2127e-02, 2.6164e-01]])
Softmax 2 times: tensor([[3.1250e-05, 4.1308e-05, 3.8207e-05, 3.3586e-05, 4.0595e-05]])
Softmax 3 times: tensor([[3.1250e-05, 3.1250e-05, 3.1250e-05, 3.1250e-05, 3.1250e-05]])
For uniform distribution: 3.125e-05

The second softmax makes model predictions useless.

Huggingface doesn't transform the logits during top_p, just calculates softmax and filters the original logits based on the result. So sampling_softmax_inplace is not needed inside the sampling_top_p function.

Code for the softmax probs:

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer


model_checkpoint = "JackFram/llama-68m"
hf_tokenizer = AutoTokenizer.from_pretrained(model_checkpoint)
hf_model = AutoModelForCausalLM.from_pretrained(model_checkpoint)
texts = ["This is a test"]
tokenized = hf_tokenizer(texts, return_tensors="pt")

with torch.no_grad():
    t = hf_model(**tokenized).logits
    print("Softmax 0 times:", torch.max(t, dim=-1).values)
    for i in range(1, 4):
        t = torch.nn.functional.softmax(t, dim=-1,)
        print(f"Softmax {i} times:", torch.max(t, dim=-1).values)

print(f"For uniform distribution: {1 / hf_tokenizer.vocab_size}")

Contributor Author replied:
@apaniukov Thanks a lot!

Contributor Author added:
I re-implemented the sampling_top_p function; the implementation is similar to HuggingFace's.

            }

            thread_local std::random_device rd;
            thread_local std::mt19937 gen(rd());

            std::discrete_distribution<> dist(logits, logits + token_scores.size());
            out_token = token_scores[dist(gen)].id;
        }
        else {
            out_token = std::max_element(logits, logits + vocab_size) - logits;
        }

        prompt.push_back(out_token);
Contributor commented:
It seems token accumulation is needed for applying sampling_repetition_penalty. These ids need to be preserved between get_out_token calls.

Contributor Author replied:
Do you mean that the logits info remains the same before and after get_out_token calls?
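
What the reviewer appears to have in mind is keeping the generated ids as state on the sampler itself, roughly as in the sketch below. This is illustrative only: accumulated_ids is a hypothetical member, the sampling branch is omitted for brevity, and this is not the merged code.

// Sketch: accumulate every emitted token so the repetition penalty sees the whole
// history (prompt plus previously generated tokens), not just the original prompt.
struct GreedySampling {
    SamplingParameters parameters;
    std::vector<int64_t> accumulated_ids;  // persists between get_out_token calls

    GreedySampling(SamplingParameters parameters) : parameters{ std::move(parameters) } {
        accumulated_ids = this->parameters.prompt;
    }

    int64_t get_out_token(float* logits, size_t vocab_size) {
        if (parameters.repeat_penalty != 1.f) {
            sampling_repetition_penalty(logits, logits + vocab_size, accumulated_ids,
                                        parameters.repeat_penalty);
        }
        // Greedy path only, for brevity.
        int64_t out_token = std::max_element(logits, logits + vocab_size) - logits;
        accumulated_ids.push_back(out_token);  // remembered for the next call
        return out_token;
    }
};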


        return { out_token };
    }
};