From 082be0c46f2487edee3103f5f1eccd5354816342 Mon Sep 17 00:00:00 2001
From: Jack Tysoe
Date: Mon, 10 Jun 2024 02:27:05 +0100
Subject: [PATCH] fix(ai-proxy): invalid precedence on model tuning params

---
 kong/llm/drivers/anthropic.lua | 9 ++++-----
 kong/llm/drivers/openai.lua    | 4 ++--
 kong/llm/drivers/shared.lua    | 2 +-
 3 files changed, 7 insertions(+), 8 deletions(-)

diff --git a/kong/llm/drivers/anthropic.lua b/kong/llm/drivers/anthropic.lua
index c9b46525fb53..f873ce454e0a 100644
--- a/kong/llm/drivers/anthropic.lua
+++ b/kong/llm/drivers/anthropic.lua
@@ -93,8 +93,8 @@ local transformers_to = {
       return nil, nil, err
     end
 
-    messages.temperature = request_table.temperature or (model.options and model.options.temperature) or nil
-    messages.max_tokens = request_table.max_tokens or (model.options and model.options.max_tokens) or nil
+    messages.temperature = (model.options and model.options.temperature) or request_table.temperature or nil
+    messages.max_tokens = (model.options and model.options.max_tokens) or request_table.max_tokens or nil
     messages.model = model.name or request_table.model
     messages.stream = request_table.stream or false -- explicitly set this if nil
 
@@ -110,9 +110,8 @@ local transformers_to = {
       return nil, nil, err
     end
 
-    prompt.temperature = request_table.temperature or (model.options and model.options.temperature) or nil
-    prompt.max_tokens_to_sample = request_table.max_tokens or (model.options and model.options.max_tokens) or nil
-    prompt.model = model.name
+    prompt.temperature = (model.options and model.options.temperature) or request_table.temperature or nil
+    prompt.max_tokens_to_sample = (model.options and model.options.max_tokens) or request_table.max_tokens or nil
     prompt.model = model.name or request_table.model
     prompt.stream = request_table.stream or false -- explicitly set this if nil
 
diff --git a/kong/llm/drivers/openai.lua b/kong/llm/drivers/openai.lua
index b08f29bc3255..1c592e5ef60b 100644
--- a/kong/llm/drivers/openai.lua
+++ b/kong/llm/drivers/openai.lua
@@ -18,7 +18,7 @@ end
 
 local transformers_to = {
   ["llm/v1/chat"] = function(request_table, model_info, route_type)
-    request_table.model = request_table.model or model_info.name
+    request_table.model = model_info.name or request_table.model
     request_table.stream = request_table.stream or false -- explicitly set this
     request_table.top_k = nil -- explicitly remove unsupported default
 
@@ -26,7 +26,7 @@
   end,
 
   ["llm/v1/completions"] = function(request_table, model_info, route_type)
-    request_table.model = model_info.name
+    request_table.model = model_info.name or request_table.model
     request_table.stream = request_table.stream or false -- explicitly set this
     request_table.top_k = nil -- explicitly remove unsupported default
 
diff --git a/kong/llm/drivers/shared.lua b/kong/llm/drivers/shared.lua
index 0e60f89ae4f0..b41841ef0fd1 100644
--- a/kong/llm/drivers/shared.lua
+++ b/kong/llm/drivers/shared.lua
@@ -265,7 +265,7 @@ function _M.to_ollama(request_table, model)
 
   -- common parameters
   input.stream = request_table.stream or false -- for future capability
-  input.model = model.name
+  input.model = model.name or request_table.name
 
   if model.options then
     input.options = {}
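
For context, the fix flips each `or` chain so that tuning parameters configured on the plugin's model options take precedence over values sent in the client request body, rather than the other way around. Below is a minimal standalone Lua sketch of that precedence change; the `model` and `request_table` values are hypothetical examples and are not part of the patch:

  -- Illustrative values only (not from the patch): a plugin-configured model
  -- and an incoming request body that both specify tuning parameters.
  local model = {
    name = "claude-3-opus",
    options = { temperature = 0.2, max_tokens = 512 },
  }
  local request_table = {
    model = "claude-2",
    temperature = 0.9,
  }

  -- Old precedence: the request value wins over the plugin configuration.
  local old_temperature = request_table.temperature
    or (model.options and model.options.temperature)   -- 0.9

  -- New precedence: the plugin-configured value wins; the request value is
  -- only a fallback when no option is configured.
  local new_temperature = (model.options and model.options.temperature)
    or request_table.temperature                        -- 0.2

  print(old_temperature, new_temperature)               -- 0.9  0.2

The same inversion is applied to `max_tokens`, and the configured `model.name` is now preferred over the request-supplied model name across the Anthropic, OpenAI, and Ollama transformers.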