From f8b1069a1cb930804183d86249167d47dd473546 Mon Sep 17 00:00:00 2001
From: chrisbarrera <34655880+chrisbarrera@users.noreply.github.com>
Date: Sat, 24 Feb 2024 16:51:34 -0600
Subject: [PATCH] add min_p sampling parameter (#2014)

Signed-off-by: Christopher Barrera
Co-authored-by: Jared Van Bortel
---
 gpt4all-backend/llamamodel.cpp                |  4 +-
 gpt4all-backend/llmodel.h                     |  1 +
 gpt4all-backend/llmodel_c.cpp                 |  2 +
 gpt4all-backend/llmodel_c.h                   |  1 +
 gpt4all-bindings/cli/app.py                   |  2 +
 .../Gpt4All/Bindings/LLPromptContext.cs       |  9 ++++
 .../csharp/Gpt4All/Bindings/NativeMethods.cs  |  2 +
 .../Extensions/LLPromptContextExtensions.cs   |  1 +
 .../PredictRequestOptionsExtensions.cs        |  1 +
 .../Prediction/PredictRequestOptions.cs       |  2 +
 gpt4all-bindings/golang/binding.cpp           |  4 +-
 gpt4all-bindings/golang/binding.h             |  4 +-
 gpt4all-bindings/golang/gpt4all.go            |  4 +-
 gpt4all-bindings/golang/options.go            | 10 +++-
 .../java/com/hexadevlabs/gpt4all/LLModel.java |  6 +++
 .../hexadevlabs/gpt4all/LLModelLibrary.java   |  1 +
 gpt4all-bindings/python/gpt4all/_pyllmodel.py |  6 +++
 gpt4all-bindings/python/gpt4all/gpt4all.py    |  3 ++
 gpt4all-bindings/typescript/index.cc          |  3 ++
 gpt4all-chat/chatllm.cpp                      | 10 +++-
 gpt4all-chat/chatllm.h                        |  2 +-
 gpt4all-chat/main.qml                         |  1 +
 gpt4all-chat/modellist.cpp                    | 22 ++++++++
 gpt4all-chat/modellist.h                      |  6 +++
 gpt4all-chat/mysettings.cpp                   | 23 ++++++++
 gpt4all-chat/mysettings.h                     |  3 ++
 gpt4all-chat/qml/ModelSettings.qml            | 52 +++++++++++++++++--
 gpt4all-chat/server.cpp                       |  5 ++
 28 files changed, 176 insertions(+), 14 deletions(-)

diff --git a/gpt4all-backend/llamamodel.cpp b/gpt4all-backend/llamamodel.cpp
index 7a2d894a6c23..e03892290ff7 100644
--- a/gpt4all-backend/llamamodel.cpp
+++ b/gpt4all-backend/llamamodel.cpp
@@ -64,6 +64,7 @@ static int llama_sample_top_p_top_k(
         int last_n_tokens_size,
         int top_k,
         float top_p,
+        float min_p,
         float temp,
         float repeat_penalty,
         int32_t pos) {
@@ -83,6 +84,7 @@ static int llama_sample_top_p_top_k(
     llama_sample_tail_free(ctx, &candidates_p, 1.0f, 1);
     llama_sample_typical(ctx, &candidates_p, 1.0f, 1);
     llama_sample_top_p(ctx, &candidates_p, top_p, 1);
+    llama_sample_min_p(ctx, &candidates_p, min_p, 1);
     llama_sample_temp(ctx, &candidates_p, temp);
     return llama_sample_token(ctx, &candidates_p);
 }
@@ -392,7 +394,7 @@ LLModel::Token LLamaModel::sampleToken(PromptContext &promptCtx) const
     const size_t n_prev_toks = std::min((size_t) promptCtx.repeat_last_n, promptCtx.tokens.size());
     return llama_sample_top_p_top_k(d_ptr->ctx, promptCtx.tokens.data() + promptCtx.tokens.size() - n_prev_toks,
-        n_prev_toks, promptCtx.top_k, promptCtx.top_p, promptCtx.temp,
+        n_prev_toks, promptCtx.top_k, promptCtx.top_p, promptCtx.min_p, promptCtx.temp,
         promptCtx.repeat_penalty, promptCtx.n_last_batch_tokens - 1);
 }
diff --git a/gpt4all-backend/llmodel.h b/gpt4all-backend/llmodel.h
index 5ccbea08a119..5139110d00ca 100644
--- a/gpt4all-backend/llmodel.h
+++ b/gpt4all-backend/llmodel.h
@@ -66,6 +66,7 @@ class LLModel {
         int32_t n_predict = 200;
         int32_t top_k = 40;
         float top_p = 0.9f;
+        float min_p = 0.0f;
         float temp = 0.9f;
         int32_t n_batch = 9;
         float repeat_penalty = 1.10f;
diff --git a/gpt4all-backend/llmodel_c.cpp b/gpt4all-backend/llmodel_c.cpp
index b6306a77d894..da2bc2b4c49b 100644
--- a/gpt4all-backend/llmodel_c.cpp
+++ b/gpt4all-backend/llmodel_c.cpp
@@ -134,6 +134,7 @@ void llmodel_prompt(llmodel_model model, const char *prompt,
     wrapper->promptContext.n_predict = ctx->n_predict;
     wrapper->promptContext.top_k = ctx->top_k;
     wrapper->promptContext.top_p = ctx->top_p;
+    wrapper->promptContext.min_p = ctx->min_p;
     wrapper->promptContext.temp = ctx->temp;
     wrapper->promptContext.n_batch = ctx->n_batch;
     wrapper->promptContext.repeat_penalty = ctx->repeat_penalty;
@@ -156,6 +157,7 @@ void llmodel_prompt(llmodel_model model, const char *prompt,
     ctx->n_predict = wrapper->promptContext.n_predict;
     ctx->top_k = wrapper->promptContext.top_k;
     ctx->top_p = wrapper->promptContext.top_p;
+    ctx->min_p = wrapper->promptContext.min_p;
     ctx->temp = wrapper->promptContext.temp;
     ctx->n_batch = wrapper->promptContext.n_batch;
     ctx->repeat_penalty = wrapper->promptContext.repeat_penalty;
diff --git a/gpt4all-backend/llmodel_c.h b/gpt4all-backend/llmodel_c.h
index eac4ae9b9666..a19bd0837453 100644
--- a/gpt4all-backend/llmodel_c.h
+++ b/gpt4all-backend/llmodel_c.h
@@ -39,6 +39,7 @@ struct llmodel_prompt_context {
     int32_t n_predict;     // number of tokens to predict
     int32_t top_k;         // top k logits to sample from
     float top_p;           // nucleus sampling probability threshold
+    float min_p;           // Min P sampling
     float temp;            // temperature to adjust model's output distribution
     int32_t n_batch;       // number of predictions to generate in parallel
     float repeat_penalty;  // penalty factor for repeated tokens
diff --git a/gpt4all-bindings/cli/app.py b/gpt4all-bindings/cli/app.py
index 083fa1735bbd..e584a318038e 100755
--- a/gpt4all-bindings/cli/app.py
+++ b/gpt4all-bindings/cli/app.py
@@ -120,6 +120,7 @@ def _old_loop(gpt4all_instance):
         n_predict=200,
         top_k=40,
         top_p=0.9,
+        min_p=0.0,
         temp=0.9,
         n_batch=9,
         repeat_penalty=1.1,
@@ -156,6 +157,7 @@ def _new_loop(gpt4all_instance):
         temp=0.9,
         top_k=40,
         top_p=0.9,
+        min_p=0.0,
         repeat_penalty=1.1,
         repeat_last_n=64,
         n_batch=9,
diff --git a/gpt4all-bindings/csharp/Gpt4All/Bindings/LLPromptContext.cs b/gpt4all-bindings/csharp/Gpt4All/Bindings/LLPromptContext.cs
index cec6948e58c5..002972b22378 100644
--- a/gpt4all-bindings/csharp/Gpt4All/Bindings/LLPromptContext.cs
+++ b/gpt4all-bindings/csharp/Gpt4All/Bindings/LLPromptContext.cs
@@ -64,6 +64,15 @@ public float TopP
         set => _ctx.top_p = value;
     }
 
+    /// <summary>
+    /// min p sampling probability threshold
+    /// </summary>
+    public float MinP
+    {
+        get => _ctx.min_p;
+        set => _ctx.min_p = value;
+    }
+
     /// <summary>
     /// temperature to adjust model's output distribution
    /// </summary>
diff --git a/gpt4all-bindings/csharp/Gpt4All/Bindings/NativeMethods.cs b/gpt4all-bindings/csharp/Gpt4All/Bindings/NativeMethods.cs
index 7ac955c5166b..2e61d9335a23 100644
--- a/gpt4all-bindings/csharp/Gpt4All/Bindings/NativeMethods.cs
+++ b/gpt4all-bindings/csharp/Gpt4All/Bindings/NativeMethods.cs
@@ -29,6 +29,8 @@ public unsafe partial struct llmodel_prompt_context
 
     public float top_p;
 
+    public float min_p;
+
     public float temp;
 
     [NativeTypeName("int32_t")]
diff --git a/gpt4all-bindings/csharp/Gpt4All/Extensions/LLPromptContextExtensions.cs b/gpt4all-bindings/csharp/Gpt4All/Extensions/LLPromptContextExtensions.cs
index 4426ef495217..5581e4584d4f 100644
--- a/gpt4all-bindings/csharp/Gpt4All/Extensions/LLPromptContextExtensions.cs
+++ b/gpt4all-bindings/csharp/Gpt4All/Extensions/LLPromptContextExtensions.cs
@@ -16,6 +16,7 @@ public static string Dump(this LLModelPromptContext context)
             n_predict = {ctx.n_predict}
             top_k = {ctx.top_k}
             top_p = {ctx.top_p}
+            min_p = {ctx.min_p}
             temp = {ctx.temp}
             n_batch = {ctx.n_batch}
             repeat_penalty = {ctx.repeat_penalty}
diff --git a/gpt4all-bindings/csharp/Gpt4All/Extensions/PredictRequestOptionsExtensions.cs b/gpt4all-bindings/csharp/Gpt4All/Extensions/PredictRequestOptionsExtensions.cs
index 48ebd1f1c785..07d1e104e333 100644
--- a/gpt4all-bindings/csharp/Gpt4All/Extensions/PredictRequestOptionsExtensions.cs
+++ b/gpt4all-bindings/csharp/Gpt4All/Extensions/PredictRequestOptionsExtensions.cs
@@ -12,6 +12,7 @@ public static LLModelPromptContext ToPromptContext(this PredictRequestOptions op
         TokensSize = opts.TokensSize,
         TopK = opts.TopK,
         TopP = opts.TopP,
+        MinP = opts.MinP,
         PastNum = opts.PastConversationTokensNum,
         RepeatPenalty = opts.RepeatPenalty,
         Temperature = opts.Temperature,
diff --git a/gpt4all-bindings/csharp/Gpt4All/Prediction/PredictRequestOptions.cs b/gpt4all-bindings/csharp/Gpt4All/Prediction/PredictRequestOptions.cs
index 2f3e57aff592..c151a5b63199 100644
--- a/gpt4all-bindings/csharp/Gpt4All/Prediction/PredictRequestOptions.cs
+++ b/gpt4all-bindings/csharp/Gpt4All/Prediction/PredictRequestOptions.cs
@@ -16,6 +16,8 @@ public record PredictRequestOptions
 
     public float TopP { get; init; } = 0.9f;
 
+    public float MinP { get; init; } = 0.0f;
+
     public float Temperature { get; init; } = 0.1f;
 
     public int Batches { get; init; } = 8;
diff --git a/gpt4all-bindings/golang/binding.cpp b/gpt4all-bindings/golang/binding.cpp
index 0026d8658572..de73026247b9 100644
--- a/gpt4all-bindings/golang/binding.cpp
+++ b/gpt4all-bindings/golang/binding.cpp
@@ -36,7 +36,7 @@ std::string res = "";
 void * mm;
 
 void model_prompt( const char *prompt, void *m, char* result, int repeat_last_n, float repeat_penalty, int n_ctx, int tokens, int top_k,
-                   float top_p, float temp, int n_batch,float ctx_erase)
+                   float top_p, float min_p, float temp, int n_batch,float ctx_erase)
 {
     llmodel_model* model = (llmodel_model*) m;
@@ -69,6 +69,7 @@ void model_prompt( const char *prompt, void *m, char* result, int repeat_last_n,
         .n_predict = 50,
         .top_k = 10,
         .top_p = 0.9,
+        .min_p = 0.0,
         .temp = 1.0,
         .n_batch = 1,
         .repeat_penalty = 1.2,
@@ -83,6 +84,7 @@ void model_prompt( const char *prompt, void *m, char* result, int repeat_last_n,
     prompt_context->top_k = top_k;
     prompt_context->context_erase = ctx_erase;
     prompt_context->top_p = top_p;
+    prompt_context->min_p = min_p;
     prompt_context->temp = temp;
     prompt_context->n_batch = n_batch;
diff --git a/gpt4all-bindings/golang/binding.h b/gpt4all-bindings/golang/binding.h
index bc203a304440..3a4d3656b6c9 100644
--- a/gpt4all-bindings/golang/binding.h
+++ b/gpt4all-bindings/golang/binding.h
@@ -7,7 +7,7 @@ extern "C" {
 void* load_model(const char *fname, int n_threads);
 
 void model_prompt( const char *prompt, void *m, char* result, int repeat_last_n, float repeat_penalty, int n_ctx, int tokens, int top_k,
-                   float top_p, float temp, int n_batch,float ctx_erase);
+                   float top_p, float min_p, float temp, int n_batch,float ctx_erase);
 
 void free_model(void *state_ptr);
 
@@ -15,4 +15,4 @@ extern unsigned char getTokenCallback(void *, char *);
 
 #ifdef __cplusplus
 }
-#endif
\ No newline at end of file
+#endif
diff --git a/gpt4all-bindings/golang/gpt4all.go b/gpt4all-bindings/golang/gpt4all.go
index 5a66fa1e2d5a..f97eebf62a88 100644
--- a/gpt4all-bindings/golang/gpt4all.go
+++ b/gpt4all-bindings/golang/gpt4all.go
@@ -7,7 +7,7 @@ package gpt4all
 // #cgo LDFLAGS: -lgpt4all -lm -lstdc++ -ldl
 // void* load_model(const char *fname, int n_threads);
 // void model_prompt( const char *prompt, void *m, char* result, int repeat_last_n, float repeat_penalty, int n_ctx, int tokens, int top_k,
-// float top_p, float temp, int n_batch,float ctx_erase);
+// float top_p, float min_p, float temp, int n_batch,float ctx_erase);
 // void free_model(void *state_ptr);
 // extern unsigned char getTokenCallback(void *, char *);
 // void llmodel_set_implementation_search_path(const char *path);
@@ -58,7 +58,7 @@ func (l *Model) Predict(text string, opts ...PredictOption) (string, error) {
 
     out := make([]byte, po.Tokens)
     C.model_prompt(input, l.state, (*C.char)(unsafe.Pointer(&out[0])), C.int(po.RepeatLastN), C.float(po.RepeatPenalty), C.int(po.ContextSize),
-        C.int(po.Tokens), C.int(po.TopK), C.float(po.TopP), C.float(po.Temperature), C.int(po.Batch), C.float(po.ContextErase))
+        C.int(po.Tokens), C.int(po.TopK), C.float(po.TopP), C.float(po.MinP), C.float(po.Temperature), C.int(po.Batch), C.float(po.ContextErase))
     res := C.GoString((*C.char)(unsafe.Pointer(&out[0])))
     res = strings.TrimPrefix(res, " ")
diff --git a/gpt4all-bindings/golang/options.go b/gpt4all-bindings/golang/options.go
index d79b1723929e..e2650ca07e7a 100644
--- a/gpt4all-bindings/golang/options.go
+++ b/gpt4all-bindings/golang/options.go
@@ -2,7 +2,7 @@ package gpt4all
 
 type PredictOptions struct {
     ContextSize, RepeatLastN, Tokens, TopK, Batch int
-    TopP, Temperature, ContextErase, RepeatPenalty float64
+    TopP, MinP, Temperature, ContextErase, RepeatPenalty float64
 }
 
 type PredictOption func(p *PredictOptions)
@@ -11,6 +11,7 @@ var DefaultOptions PredictOptions = PredictOptions{
     Tokens:       200,
     TopK:         10,
     TopP:         0.90,
+    MinP:         0.0,
     Temperature:  0.96,
     Batch:        1,
     ContextErase: 0.55,
@@ -50,6 +51,13 @@ func SetTopP(topp float64) PredictOption {
     }
 }
 
+// SetMinP sets the value for min p sampling
+func SetMinP(minp float64) PredictOption {
+    return func(p *PredictOptions) {
+        p.MinP = minp
+    }
+}
+
 // SetRepeatPenalty sets the repeat penalty.
 func SetRepeatPenalty(ce float64) PredictOption {
     return func(p *PredictOptions) {
diff --git a/gpt4all-bindings/java/src/main/java/com/hexadevlabs/gpt4all/LLModel.java b/gpt4all-bindings/java/src/main/java/com/hexadevlabs/gpt4all/LLModel.java
index 769de02a42b4..6114cfad62fe 100644
--- a/gpt4all-bindings/java/src/main/java/com/hexadevlabs/gpt4all/LLModel.java
+++ b/gpt4all-bindings/java/src/main/java/com/hexadevlabs/gpt4all/LLModel.java
@@ -32,6 +32,7 @@ private GenerationConfig() {
             n_predict.set(128);
             top_k.set(40);
             top_p.set(0.95);
+            min_p.set(0.0);
             temp.set(0.28);
             n_batch.set(8);
             repeat_penalty.set(1.1);
@@ -71,6 +72,11 @@ public Builder withTopP(float top_p) {
             return this;
         }
 
+        public Builder withMinP(float min_p) {
+            configToBuild.min_p.set(min_p);
+            return this;
+        }
+
         public Builder withTemp(float temp) {
             configToBuild.temp.set(temp);
             return this;
diff --git a/gpt4all-bindings/java/src/main/java/com/hexadevlabs/gpt4all/LLModelLibrary.java b/gpt4all-bindings/java/src/main/java/com/hexadevlabs/gpt4all/LLModelLibrary.java
index 356b6149b82f..d538a080d13d 100644
--- a/gpt4all-bindings/java/src/main/java/com/hexadevlabs/gpt4all/LLModelLibrary.java
+++ b/gpt4all-bindings/java/src/main/java/com/hexadevlabs/gpt4all/LLModelLibrary.java
@@ -48,6 +48,7 @@ class LLModelPromptContext extends Struct {
         public final int32_t n_predict = new int32_t();
         public final int32_t top_k = new int32_t();
         public final Float top_p = new Float();
+        public final Float min_p = new Float();
         public final Float temp = new Float();
         public final int32_t n_batch = new int32_t();
         public final Float repeat_penalty = new Float();
diff --git a/gpt4all-bindings/python/gpt4all/_pyllmodel.py b/gpt4all-bindings/python/gpt4all/_pyllmodel.py
index 9aaa94c10208..fcf7b7335d59 100644
--- a/gpt4all-bindings/python/gpt4all/_pyllmodel.py
+++ b/gpt4all-bindings/python/gpt4all/_pyllmodel.py
@@ -49,6 +49,7 @@ class LLModelPromptContext(ctypes.Structure):
         ("n_predict", ctypes.c_int32),
         ("top_k", ctypes.c_int32),
         ("top_p", ctypes.c_float),
+        ("min_p", ctypes.c_float),
         ("temp", ctypes.c_float),
         ("n_batch", ctypes.c_int32),
         ("repeat_penalty", ctypes.c_float),
@@ -241,6 +242,7 @@ def _set_context(
         n_predict: int = 4096,
         top_k: int = 40,
         top_p: float = 0.9,
+        min_p: float = 0.0,
         temp: float = 0.1,
         n_batch: int = 8,
         repeat_penalty: float = 1.2,
@@ -257,6 +259,7 @@ def _set_context(
                 n_predict=n_predict,
                 top_k=top_k,
                 top_p=top_p,
+                min_p=min_p,
                 temp=temp,
                 n_batch=n_batch,
                 repeat_penalty=repeat_penalty,
@@ -272,6 +275,7 @@ def _set_context(
             self.context.n_predict = n_predict
             self.context.top_k = top_k
             self.context.top_p = top_p
+            self.context.min_p = min_p
            self.context.temp = temp
             self.context.n_batch = n_batch
             self.context.repeat_penalty = repeat_penalty
@@ -297,6 +301,7 @@ def prompt_model(
         n_predict: int = 4096,
         top_k: int = 40,
         top_p: float = 0.9,
+        min_p: float = 0.0,
         temp: float = 0.1,
         n_batch: int = 8,
         repeat_penalty: float = 1.2,
@@ -334,6 +339,7 @@ def prompt_model(
             n_predict=n_predict,
             top_k=top_k,
             top_p=top_p,
+            min_p=min_p,
             temp=temp,
             n_batch=n_batch,
             repeat_penalty=repeat_penalty,
diff --git a/gpt4all-bindings/python/gpt4all/gpt4all.py b/gpt4all-bindings/python/gpt4all/gpt4all.py
index 82342b28babf..4510ec324c89 100644
--- a/gpt4all-bindings/python/gpt4all/gpt4all.py
+++ b/gpt4all-bindings/python/gpt4all/gpt4all.py
@@ -289,6 +289,7 @@ def generate(
         temp: float = 0.7,
         top_k: int = 40,
         top_p: float = 0.4,
+        min_p: float = 0.0,
         repeat_penalty: float = 1.18,
         repeat_last_n: int = 64,
         n_batch: int = 8,
@@ -305,6 +306,7 @@ def generate(
             temp: The model temperature. Larger values increase creativity but decrease factuality.
             top_k: Randomly sample from the top_k most likely tokens at each generation step. Set this to 1 for greedy decoding.
             top_p: Randomly sample at each generation step from the top most likely tokens whose probabilities add up to top_p.
+            min_p: Randomly sample at each generation step from the top most likely tokens whose probabilities are at least min_p.
             repeat_penalty: Penalize the model for repetition. Higher values result in less repetition.
             repeat_last_n: How far in the models generation history to apply the repeat penalty.
             n_batch: Number of prompt tokens processed in parallel. Larger values decrease latency but increase resource requirements.
@@ -325,6 +327,7 @@ def generate(
             temp=temp,
             top_k=top_k,
             top_p=top_p,
+            min_p=min_p,
             repeat_penalty=repeat_penalty,
             repeat_last_n=repeat_last_n,
             n_batch=n_batch,
diff --git a/gpt4all-bindings/typescript/index.cc b/gpt4all-bindings/typescript/index.cc
index d957b453eee5..2d4968c42ff3 100644
--- a/gpt4all-bindings/typescript/index.cc
+++ b/gpt4all-bindings/typescript/index.cc
@@ -248,6 +248,7 @@ Napi::Value NodeModelWrapper::GetRequiredMemory(const Napi::CallbackInfo& info)
         .n_predict = 128,
         .top_k = 40,
         .top_p = 0.9f,
+        .min_p = 0.0f,
         .temp = 0.72f,
         .n_batch = 8,
         .repeat_penalty = 1.0f,
@@ -277,6 +278,8 @@ Napi::Value NodeModelWrapper::GetRequiredMemory(const Napi::CallbackInfo& info)
         promptContext.top_k = inputObject.Get("top_k").As<Napi::Number>().Int32Value();
     if(inputObject.Has("top_p"))
         promptContext.top_p = inputObject.Get("top_p").As<Napi::Number>().FloatValue();
+    if(inputObject.Has("min_p"))
+        promptContext.min_p = inputObject.Get("min_p").As<Napi::Number>().FloatValue();
     if(inputObject.Has("temp"))
         promptContext.temp = inputObject.Get("temp").As<Napi::Number>().FloatValue();
     if(inputObject.Has("n_batch"))
diff --git a/gpt4all-chat/chatllm.cpp b/gpt4all-chat/chatllm.cpp
index d0c9d33b1f6a..36b6febb15ea 100644
--- a/gpt4all-chat/chatllm.cpp
+++ b/gpt4all-chat/chatllm.cpp
@@ -568,16 +568,17 @@ bool ChatLLM::prompt(const QList<QString> &collectionList, const QString &prompt
     const int32_t n_predict = MySettings::globalInstance()->modelMaxLength(m_modelInfo);
     const int32_t top_k = MySettings::globalInstance()->modelTopK(m_modelInfo);
     const float top_p = MySettings::globalInstance()->modelTopP(m_modelInfo);
+    const float min_p = MySettings::globalInstance()->modelMinP(m_modelInfo);
     const float temp = MySettings::globalInstance()->modelTemperature(m_modelInfo);
     const int32_t n_batch = MySettings::globalInstance()->modelPromptBatchSize(m_modelInfo);
     const float repeat_penalty = MySettings::globalInstance()->modelRepeatPenalty(m_modelInfo);
     const int32_t repeat_penalty_tokens = MySettings::globalInstance()->modelRepeatPenaltyTokens(m_modelInfo);
-    return promptInternal(collectionList, prompt, promptTemplate, n_predict, top_k, top_p, temp, n_batch,
+    return promptInternal(collectionList, prompt, promptTemplate, n_predict, top_k, top_p, min_p, temp, n_batch,
         repeat_penalty, repeat_penalty_tokens);
 }
 
 bool ChatLLM::promptInternal(const QList<QString> &collectionList, const QString &prompt, const QString &promptTemplate,
-    int32_t n_predict, int32_t top_k, float top_p, float temp, int32_t n_batch, float repeat_penalty,
+    int32_t n_predict, int32_t top_k, float top_p, float min_p, float temp, int32_t n_batch, float repeat_penalty,
     int32_t repeat_penalty_tokens)
 {
     if (!isModelLoaded())
@@ -608,6 +609,7 @@ bool ChatLLM::promptInternal(const QList<QString> &collectionList, const QString
     m_ctx.n_predict = n_predict;
     m_ctx.top_k = top_k;
     m_ctx.top_p = top_p;
+    m_ctx.min_p = min_p;
     m_ctx.temp = temp;
     m_ctx.n_batch = n_batch;
     m_ctx.repeat_penalty = repeat_penalty;
@@ -1020,6 +1022,7 @@ void ChatLLM::processSystemPrompt()
     const int32_t n_predict = MySettings::globalInstance()->modelMaxLength(m_modelInfo);
     const int32_t top_k = MySettings::globalInstance()->modelTopK(m_modelInfo);
     const float top_p = MySettings::globalInstance()->modelTopP(m_modelInfo);
+    const float min_p = MySettings::globalInstance()->modelMinP(m_modelInfo);
     const float temp = MySettings::globalInstance()->modelTemperature(m_modelInfo);
     const int32_t n_batch = MySettings::globalInstance()->modelPromptBatchSize(m_modelInfo);
     const float repeat_penalty = MySettings::globalInstance()->modelRepeatPenalty(m_modelInfo);
@@ -1028,6 +1031,7 @@ void ChatLLM::processSystemPrompt()
     m_ctx.n_predict = n_predict;
     m_ctx.top_k = top_k;
     m_ctx.top_p = top_p;
+    m_ctx.min_p = min_p;
     m_ctx.temp = temp;
     m_ctx.n_batch = n_batch;
     m_ctx.repeat_penalty = repeat_penalty;
@@ -1067,6 +1071,7 @@ void ChatLLM::processRestoreStateFromText()
     const int32_t n_predict = MySettings::globalInstance()->modelMaxLength(m_modelInfo);
     const int32_t top_k = MySettings::globalInstance()->modelTopK(m_modelInfo);
     const float top_p = MySettings::globalInstance()->modelTopP(m_modelInfo);
+    const float min_p = MySettings::globalInstance()->modelMinP(m_modelInfo);
     const float temp = MySettings::globalInstance()->modelTemperature(m_modelInfo);
     const int32_t n_batch = MySettings::globalInstance()->modelPromptBatchSize(m_modelInfo);
     const float repeat_penalty = MySettings::globalInstance()->modelRepeatPenalty(m_modelInfo);
@@ -1075,6 +1080,7 @@ void ChatLLM::processRestoreStateFromText()
     m_ctx.n_predict = n_predict;
     m_ctx.top_k = top_k;
     m_ctx.top_p = top_p;
+    m_ctx.min_p = min_p;
     m_ctx.temp = temp;
     m_ctx.n_batch = n_batch;
     m_ctx.repeat_penalty = repeat_penalty;
diff --git a/gpt4all-chat/chatllm.h b/gpt4all-chat/chatllm.h
index 278e79cc0b82..da8483389c48 100644
--- a/gpt4all-chat/chatllm.h
+++ b/gpt4all-chat/chatllm.h
@@ -139,7 +139,7 @@ public Q_SLOTS:
 
 protected:
     bool promptInternal(const QList<QString> &collectionList, const QString &prompt, const QString &promptTemplate,
-        int32_t n_predict, int32_t top_k, float top_p, float temp, int32_t n_batch, float repeat_penalty,
+        int32_t n_predict, int32_t top_k, float top_p, float min_p, float temp, int32_t n_batch, float repeat_penalty,
         int32_t repeat_penalty_tokens);
     bool handlePrompt(int32_t token);
     bool handleResponse(int32_t token, const std::string &response);
diff --git a/gpt4all-chat/main.qml b/gpt4all-chat/main.qml
index 70fe6dae9170..54375a28b4d4 100644
--- a/gpt4all-chat/main.qml
+++ b/gpt4all-chat/main.qml
@@ -1380,6 +1380,7 @@ Window {
                         MySettings.maxLength,
                         MySettings.topK,
                         MySettings.topP,
+                        MySettings.minP,
                         MySettings.temperature,
                         MySettings.promptBatchSize,
                         MySettings.repeatPenalty,
diff --git a/gpt4all-chat/modellist.cpp b/gpt4all-chat/modellist.cpp
index 58d1b554315b..642b01889cbe 100644
--- a/gpt4all-chat/modellist.cpp
+++ b/gpt4all-chat/modellist.cpp
@@ -60,12 +60,23 @@ double ModelInfo::topP() const
     return MySettings::globalInstance()->modelTopP(*this);
 }
 
+double ModelInfo::minP() const
+{
+    return MySettings::globalInstance()->modelMinP(*this);
+}
+
 void ModelInfo::setTopP(double p)
 {
     if (isClone) MySettings::globalInstance()->setModelTopP(*this, p, isClone /*force*/);
     m_topP = p;
 }
 
+void ModelInfo::setMinP(double p)
+{
+    if (isClone) MySettings::globalInstance()->setModelMinP(*this, p, isClone /*force*/);
+    m_minP = p;
+}
+
 int ModelInfo::topK() const
 {
     return MySettings::globalInstance()->modelTopK(*this);
@@ -321,6 +332,7 @@ ModelList::ModelList()
     connect(MySettings::globalInstance(), &MySettings::nameChanged, this, &ModelList::updateDataForSettings);
     connect(MySettings::globalInstance(), &MySettings::temperatureChanged, this, &ModelList::updateDataForSettings);
     connect(MySettings::globalInstance(), &MySettings::topPChanged, this, &ModelList::updateDataForSettings);
+    connect(MySettings::globalInstance(), &MySettings::minPChanged, this, &ModelList::updateDataForSettings);
     connect(MySettings::globalInstance(), &MySettings::topKChanged, this, &ModelList::updateDataForSettings);
     connect(MySettings::globalInstance(), &MySettings::maxLengthChanged, this, &ModelList::updateDataForSettings);
     connect(MySettings::globalInstance(), &MySettings::promptBatchSizeChanged, this, &ModelList::updateDataForSettings);
@@ -571,6 +583,8 @@ QVariant ModelList::dataInternal(const ModelInfo *info, int role) const
         return info->temperature();
     case TopPRole:
         return info->topP();
+    case MinPRole:
+        return info->minP();
     case TopKRole:
         return info->topK();
     case MaxLengthRole:
@@ -700,6 +714,8 @@ void ModelList::updateData(const QString &id, int role, const QVariant &value)
         info->setTemperature(value.toDouble()); break;
     case TopPRole:
         info->setTopP(value.toDouble()); break;
+    case MinPRole:
+        info->setMinP(value.toDouble()); break;
     case TopKRole:
         info->setTopK(value.toInt()); break;
     case MaxLengthRole:
@@ -797,6 +813,7 @@ QString ModelList::clone(const ModelInfo &model)
     updateData(id, ModelList::OnlineRole, model.isOnline);
     updateData(id, ModelList::TemperatureRole, model.temperature());
     updateData(id, ModelList::TopPRole, model.topP());
+    updateData(id, ModelList::MinPRole, model.minP());
     updateData(id, ModelList::TopKRole, model.topK());
     updateData(id, ModelList::MaxLengthRole, model.maxLength());
     updateData(id, ModelList::PromptBatchSizeRole, model.promptBatchSize());
@@ -1163,6 +1180,8 @@ void ModelList::parseModelsJsonFile(const QByteArray &jsonData, bool save)
             updateData(id, ModelList::TemperatureRole, obj["temperature"].toDouble());
         if (obj.contains("topP"))
             updateData(id, ModelList::TopPRole, obj["topP"].toDouble());
+        if (obj.contains("minP"))
+            updateData(id, ModelList::MinPRole, obj["minP"].toDouble());
         if (obj.contains("topK"))
             updateData(id, ModelList::TopKRole, obj["topK"].toInt());
         if (obj.contains("maxLength"))
@@ -1287,6 +1306,8 @@ void ModelList::updateModelsFromSettings()
         const double temperature = settings.value(g + "/temperature").toDouble();
         Q_ASSERT(settings.contains(g + "/topP"));
         const double topP = settings.value(g + "/topP").toDouble();
+        Q_ASSERT(settings.contains(g + "/minP"));
+        const double minP = settings.value(g + "/minP").toDouble();
         Q_ASSERT(settings.contains(g + "/topK"));
         const int topK = settings.value(g + "/topK").toInt();
         Q_ASSERT(settings.contains(g + "/maxLength"));
@@ -1312,6 +1333,7 @@ void ModelList::updateModelsFromSettings()
         updateData(id, ModelList::FilenameRole, filename);
         updateData(id, ModelList::TemperatureRole, temperature);
         updateData(id, ModelList::TopPRole, topP);
+        updateData(id, ModelList::MinPRole, minP);
         updateData(id, ModelList::TopKRole, topK);
         updateData(id, ModelList::MaxLengthRole, maxLength);
         updateData(id, ModelList::PromptBatchSizeRole, promptBatchSize);
diff --git a/gpt4all-chat/modellist.h b/gpt4all-chat/modellist.h
index 8ffd81639547..003dfe4488b4 100644
--- a/gpt4all-chat/modellist.h
+++ b/gpt4all-chat/modellist.h
@@ -36,6 +36,7 @@ struct ModelInfo {
     Q_PROPERTY(bool isClone MEMBER isClone)
     Q_PROPERTY(double temperature READ temperature WRITE setTemperature)
     Q_PROPERTY(double topP READ topP WRITE setTopP)
+    Q_PROPERTY(double minP READ minP WRITE setMinP)
     Q_PROPERTY(int topK READ topK WRITE setTopK)
     Q_PROPERTY(int maxLength READ maxLength WRITE setMaxLength)
     Q_PROPERTY(int promptBatchSize READ promptBatchSize WRITE setPromptBatchSize)
@@ -92,6 +93,8 @@ struct ModelInfo {
     void setTemperature(double t);
     double topP() const;
     void setTopP(double p);
+    double minP() const;
+    void setMinP(double p);
     int topK() const;
     void setTopK(int k);
     int maxLength() const;
@@ -119,6 +122,7 @@ struct ModelInfo {
     QString m_filename;
     double m_temperature = 0.7;
     double m_topP = 0.4;
+    double m_minP = 0.0;
     int m_topK = 40;
     int m_maxLength = 4096;
     int m_promptBatchSize = 128;
@@ -247,6 +251,7 @@ class ModelList : public QAbstractListModel
         RepeatPenaltyTokensRole,
         PromptTemplateRole,
         SystemPromptRole,
+        MinPRole,
     };
 
     QHash<int, QByteArray> roleNames() const override
@@ -282,6 +287,7 @@ class ModelList : public QAbstractListModel
         roles[IsCloneRole] = "isClone";
         roles[TemperatureRole] = "temperature";
         roles[TopPRole] = "topP";
+        roles[MinPRole] = "minP";
         roles[TopKRole] = "topK";
         roles[MaxLengthRole] = "maxLength";
         roles[PromptBatchSizeRole] = "promptBatchSize";
diff --git a/gpt4all-chat/mysettings.cpp b/gpt4all-chat/mysettings.cpp
index 9e5cdad0ce06..23beda2e93d1 100644
--- a/gpt4all-chat/mysettings.cpp
+++ b/gpt4all-chat/mysettings.cpp
@@ -87,6 +87,7 @@ void MySettings::restoreModelDefaults(const ModelInfo &model)
 {
     setModelTemperature(model, model.m_temperature);
     setModelTopP(model, model.m_topP);
+    setModelMinP(model, model.m_minP);
     setModelTopK(model, model.m_topK);;
     setModelMaxLength(model, model.m_maxLength);
     setModelPromptBatchSize(model, model.m_promptBatchSize);
@@ -201,6 +202,13 @@ double MySettings::modelTopP(const ModelInfo &m) const
     return setting.value(QString("model-%1").arg(m.id()) + "/topP", m.m_topP).toDouble();
 }
 
+double MySettings::modelMinP(const ModelInfo &m) const
+{
+    QSettings setting;
+    setting.sync();
+    return setting.value(QString("model-%1").arg(m.id()) + "/minP", m.m_minP).toDouble();
+}
+
 void MySettings::setModelTopP(const ModelInfo &m, double p, bool force)
 {
     if (modelTopP(m) == p && !force)
@@ -216,6 +224,21 @@ void MySettings::setModelTopP(const ModelInfo &m, double p, bool force)
         emit topPChanged(m);
 }
 
+void MySettings::setModelMinP(const ModelInfo &m, double p, bool force)
+{
+    if (modelMinP(m) == p && !force)
+        return;
+
+    QSettings setting;
+    if (m.m_minP == p && !m.isClone)
+        setting.remove(QString("model-%1").arg(m.id()) + "/minP");
+    else
+        setting.setValue(QString("model-%1").arg(m.id()) + "/minP", p);
+    setting.sync();
+    if (!force)
+        emit minPChanged(m);
+}
+
 int MySettings::modelTopK(const ModelInfo &m) const
 {
     QSettings setting;
diff --git a/gpt4all-chat/mysettings.h b/gpt4all-chat/mysettings.h
index c5019b91c8a4..94f86c32e754 100644
--- a/gpt4all-chat/mysettings.h
+++ b/gpt4all-chat/mysettings.h
@@ -47,6 +47,8 @@ class MySettings : public QObject
     Q_INVOKABLE void setModelTemperature(const ModelInfo &m, double t, bool force = false);
     double modelTopP(const ModelInfo &m) const;
     Q_INVOKABLE void setModelTopP(const ModelInfo &m, double p, bool force = false);
+    double modelMinP(const ModelInfo &m) const;
+    Q_INVOKABLE void setModelMinP(const ModelInfo &m, double p, bool force = false);
     int modelTopK(const ModelInfo &m) const;
     Q_INVOKABLE void setModelTopK(const ModelInfo &m, int k, bool force = false);
     int modelMaxLength(const ModelInfo &m) const;
@@ -119,6 +121,7 @@ class MySettings : public QObject
     void filenameChanged(const ModelInfo &model);
     void temperatureChanged(const ModelInfo &model);
     void topPChanged(const ModelInfo &model);
+    void minPChanged(const ModelInfo &model);
     void topKChanged(const ModelInfo &model);
     void maxLengthChanged(const ModelInfo &model);
     void promptBatchSizeChanged(const ModelInfo &model);
diff --git a/gpt4all-chat/qml/ModelSettings.qml b/gpt4all-chat/qml/ModelSettings.qml
index 3fdf49c9f348..9bfdece6c797 100644
--- a/gpt4all-chat/qml/ModelSettings.qml
+++ b/gpt4all-chat/qml/ModelSettings.qml
@@ -452,6 +452,50 @@ MySettingsTab {
             Accessible.name: topPLabel.text
             Accessible.description: ToolTip.text
         }
+        MySettingsLabel {
+            id: minPLabel
+            text: qsTr("Min P")
+            Layout.row: 3
+            Layout.column: 0
+        }
+        MyTextField {
+            id: minPField
+            text: root.currentModelInfo.minP
+            color: theme.textColor
+            font.pixelSize: theme.fontSizeLarge
+            ToolTip.text: qsTr("Sets the minimum relative probability for a token to be considered.")
+            ToolTip.visible: hovered
+            Layout.row: 3
+            Layout.column: 1
+            validator: DoubleValidator {
+                locale: "C"
+            }
+            Connections {
+                target: MySettings
+                function onMinPChanged() {
+                    minPField.text = root.currentModelInfo.minP;
+                }
+            }
+            Connections {
+                target: root
+                function onCurrentModelInfoChanged() {
+                    minPField.text = root.currentModelInfo.minP;
+                }
+            }
+            onEditingFinished: {
+                var val = parseFloat(text)
+                if (!isNaN(val)) {
+                    MySettings.setModelMinP(root.currentModelInfo, val)
+                    focus = false
+                } else {
+                    text = root.currentModelInfo.minP
+                }
+            }
+            Accessible.role: Accessible.EditableText
+            Accessible.name: minPLabel.text
+            Accessible.description: ToolTip.text
+        }
+
         MySettingsLabel {
             id: topKLabel
             visible: !root.currentModelInfo.isOnline
@@ -592,8 +636,8 @@ MySettingsTab {
             id: repeatPenaltyLabel
             visible: !root.currentModelInfo.isOnline
             text: qsTr("Repeat Penalty")
-            Layout.row: 3
-            Layout.column: 0
+            Layout.row: 4
+            Layout.column: 2
         }
         MyTextField {
             id: repeatPenaltyField
@@ -603,8 +647,8 @@ MySettingsTab {
             text: root.currentModelInfo.repeatPenalty
             color: theme.textColor
             font.pixelSize: theme.fontSizeLarge
             ToolTip.text: qsTr("Amount to penalize repetitiveness of the output")
             ToolTip.visible: hovered
-            Layout.row: 3
-            Layout.column: 1
+            Layout.row: 4
+            Layout.column: 3
             validator: DoubleValidator {
                 locale: "C"
             }
diff --git a/gpt4all-chat/server.cpp b/gpt4all-chat/server.cpp
index 7537cb20525c..b7587966f380 100644
--- a/gpt4all-chat/server.cpp
+++ b/gpt4all-chat/server.cpp
@@ -205,6 +205,10 @@ QHttpServerResponse Server::handleCompletionRequest(const QHttpServerRequest &re
     if (body.contains("top_p"))
         top_p = body["top_p"].toDouble();
 
+    float min_p = 0.f;
+    if (body.contains("min_p"))
+        min_p = body["min_p"].toDouble();
+
     int n = 1;
     if (body.contains("n"))
         n = body["n"].toInt();
@@ -312,6 +316,7 @@ QHttpServerResponse Server::handleCompletionRequest(const QHttpServerRequest &re
         max_tokens /*n_predict*/,
         top_k,
         top_p,
+        min_p,
         temperature,
         n_batch,
         repeat_penalty,
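
Usage note (not part of the patch above): once this change is in a released build of the Python bindings, the new knob is passed straight to generate() alongside top_p and temp. A minimal sketch, assuming the usual GPT4All entry point from the gpt4all package; the model filename is a placeholder, not something this patch ships:

    from gpt4all import GPT4All

    # Placeholder model file; substitute any locally available model the bindings can load.
    model = GPT4All("mistral-7b-instruct-v0.1.Q4_0.gguf")

    # min_p sets the minimum relative probability for a token to be considered
    # (per the chat UI tooltip added above); the patch's default of 0.0 leaves
    # sampling unchanged, while e.g. 0.05 prunes very unlikely tokens.
    print(model.generate("Why is the sky blue?", temp=0.7, top_p=0.4, min_p=0.05))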