
Min P support #4

Open
user-0a opened this issue Oct 25, 2024 · 3 comments

user-0a commented Oct 25, 2024

Hello,

Thank you for this incredible project!

I was wondering if it would be possible for you to add support for the min_p sampling parameter. It is a very useful parameter that unfortunately is not supported in the base version of TRT LLM.

mmoskal (Collaborator) commented Oct 28, 2024

I assume you actually mean min_p and not top_p? I initially misread it.

I guess it would be possible using a custom logits processor. Where do you find it useful, and what values do you use?

I noticed that OpenAI doesn't support it.

user-0a (Author) commented Oct 29, 2024

Hi,

Yes, min_p, not top_p.

It is not currently supported by OpenAI, but it is offered by other inference engines, such as vLLM, Aphrodite, and Llama.cpp. Inference providers such as Together.ai also support it. There is also a GitHub issue requesting support for it in TRT LLM: NVIDIA/TensorRT-LLM#1154.

As for where it is used: it is a common parameter when running models locally, and it is particularly useful for long-context chatting on lower-parameter-count models. If you search "min p" on subreddits such as r/LocalLLaMA, you will find a lot of discussion around it.

See this reddit thread: https://www.reddit.com/r/LocalLLaMA/comments/17vonjo/your_settings_are_probably_hurting_your_model_why/

Edit: The diagram (which is also included in the TRT LLM issue above) describes how it works.

[image: min_p sampling diagram]
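For reference, min_p keeps only the tokens whose probability is at least min_p times the probability of the most likely token, then renormalizes and samples from the survivors. A minimal CPU-side sketch of that rule (illustrative only, not code from this repo or from TRT LLM):

```cpp
#include <algorithm>
#include <cstddef>
#include <vector>

// Illustrative min_p filter over an already-softmaxed distribution:
// tokens below min_p * max_probability are zeroed out, and the rest
// are renormalized before sampling.
std::vector<float> apply_min_p(const std::vector<float>& probs, float min_p)
{
    float p_max = *std::max_element(probs.begin(), probs.end());
    float threshold = min_p * p_max;

    std::vector<float> filtered(probs.size(), 0.0f);
    float total = 0.0f;
    for (std::size_t i = 0; i < probs.size(); ++i)
    {
        if (probs[i] >= threshold)
        {
            filtered[i] = probs[i];
            total += probs[i];
        }
    }
    for (float& p : filtered)
    {
        p /= total; // the top token always survives, so total > 0
    }
    return filtered;
}
```

With min_p = 0.05, for example, a token is kept only if it is at least 5% as likely as the top token, so the cutoff adapts to how peaked the distribution is instead of removing a fixed probability mass like top_p.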

mmoskal (Collaborator) commented Oct 29, 2024

Just leaving a note here in case someone wants to take this on; the key change is the following:

diff --git a/trtllm-c/mask_logits.cu b/trtllm-c/mask_logits.cu
index 6d4339f..9c6a585 100644
--- a/trtllm-c/mask_logits.cu
+++ b/trtllm-c/mask_logits.cu
@@ -62,7 +62,7 @@ __inline__ __device__ void blockReduceMax2(T& val, int& idx, T flt_max)
 
 template <typename T>
 __global__ void mask_logits_kernel(T** logit_ptrs, int64_t* mask_offsets, size_t batch_size, size_t n_vocab,
-    size_t mask_stride, float* temperatures, T flt_max, float* mask_fractions)
+    size_t mask_stride, float* temperatures, float* ln_min_p, T flt_max, float* mask_fractions)
 {
     auto const batch_idx = blockIdx.x;
     auto logits_ptr = logit_ptrs[batch_idx];
@@ -135,6 +135,10 @@ __global__ void mask_logits_kernel(T** logit_ptrs, int64_t* mask_offsets, size_t
             else
             {
                 logit_adjusted = (logit - s_max_val_allowed) * beta;
+                if (logit_adjusted < ln_min_p[batch_idx])
+                {
+                    logit_adjusted = -flt_max;
+                }
             }
         }
 

where ln_min_p should be filled with log_e(min_p) for each request, and with -Infinity for requests that don't use min_p. It would also require adding a dummy allow-everything constraint to each request that uses min_p (if it has no constraint already), so that the logits processor is run for it.
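For anyone picking this up, here is a rough host-side sketch of filling that buffer (the request struct and field names are hypothetical, not from this repo). If I'm reading the kernel right, logit_adjusted is the logit shifted by the running maximum and scaled by beta, so the check logit_adjusted < log_e(min_p) corresponds to p_token < min_p * p_max under the temperature-adjusted softmax.

```cpp
#include <cmath>
#include <cstddef>
#include <limits>
#include <vector>

// Hypothetical per-request sampling settings; field names are illustrative only.
struct RequestSampling
{
    bool has_min_p;
    float min_p; // e.g. 0.05f
};

// Build the ln_min_p buffer the patched kernel expects: log_e(min_p) for
// requests that use min_p, and -Infinity ("never mask") for the rest.
std::vector<float> build_ln_min_p(const std::vector<RequestSampling>& batch)
{
    std::vector<float> ln_min_p(batch.size());
    for (std::size_t i = 0; i < batch.size(); ++i)
    {
        if (batch[i].has_min_p && batch[i].min_p > 0.0f)
        {
            ln_min_p[i] = std::log(batch[i].min_p);
        }
        else
        {
            ln_min_p[i] = -std::numeric_limits<float>::infinity();
        }
    }
    return ln_min_p;
}
```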
