From 4f45191c171c799f836731a2d7ea1465348d5cb9 Mon Sep 17 00:00:00 2001
From: Jack Tysoe
Date: Tue, 18 Jun 2024 11:07:10 +0100
Subject: [PATCH] fix(ai-proxy): remove nil checks on model and tuning
 parameters

(cherry picked from commit 192f56ffbe3283e1906f1b3583677bf31792683c)
---
 kong/llm/drivers/anthropic.lua    | 25 ++++++++++---------------
 kong/llm/drivers/cohere.lua       | 33 +++++++++++++--------------------
 kong/plugins/ai-proxy/handler.lua |  4 ++--
 3 files changed, 25 insertions(+), 37 deletions(-)

diff --git a/kong/llm/drivers/anthropic.lua b/kong/llm/drivers/anthropic.lua
index f873ce454e0..fcc6419d33b 100644
--- a/kong/llm/drivers/anthropic.lua
+++ b/kong/llm/drivers/anthropic.lua
@@ -93,8 +93,8 @@ local transformers_to = {
       return nil, nil, err
     end
 
-    messages.temperature = (model.options and model.options.temperature) or request_table.temperature or nil
-    messages.max_tokens = (model.options and model.options.max_tokens) or request_table.max_tokens or nil
+    messages.temperature = (model.options and model.options.temperature) or request_table.temperature
+    messages.max_tokens = (model.options and model.options.max_tokens) or request_table.max_tokens
     messages.model = model.name or request_table.model
     messages.stream = request_table.stream or false  -- explicitly set this if nil
 
@@ -110,8 +110,8 @@ local transformers_to = {
       return nil, nil, err
     end
 
-    prompt.temperature = (model.options and model.options.temperature) or request_table.temperature or nil
-    prompt.max_tokens_to_sample = (model.options and model.options.max_tokens) or request_table.max_tokens or nil
+    prompt.temperature = (model.options and model.options.temperature) or request_table.temperature
+    prompt.max_tokens_to_sample = (model.options and model.options.max_tokens) or request_table.max_tokens
     prompt.model = model.name or request_table.model
     prompt.stream = request_table.stream or false  -- explicitly set this if nil
 
@@ -151,11 +151,9 @@ local function start_to_event(event_data, model_info)
 
   local metadata = {
     prompt_tokens = meta.usage
-                    and meta.usage.input_tokens
-                    or nil,
+                    and meta.usage.input_tokens,
     completion_tokens = meta.usage
-                        and meta.usage.output_tokens
-                        or nil,
+                        and meta.usage.output_tokens,
     model = meta.model,
     stop_reason = meta.stop_reason,
     stop_sequence = meta.stop_sequence,
@@ -208,14 +206,11 @@ local function handle_stream_event(event_t, model_info, route_type)
         and event_data.usage then
       return nil, nil, {
         prompt_tokens = nil,
-        completion_tokens = event_data.usage.output_tokens
-                            or nil,
+        completion_tokens = event_data.usage.output_tokens,
         stop_reason = event_data.delta
-                      and event_data.delta.stop_reason
-                      or nil,
+                      and event_data.delta.stop_reason,
         stop_sequence = event_data.delta
-                        and event_data.delta.stop_sequence
-                        or nil,
+                        and event_data.delta.stop_sequence,
       }
     else
       return nil, "message_delta is missing the metadata block", nil
@@ -266,7 +261,7 @@ local transformers_from = {
         prompt_tokens = usage.input_tokens,
         completion_tokens = usage.output_tokens,
         total_tokens = usage.input_tokens and usage.output_tokens and
-                       usage.input_tokens + usage.output_tokens or nil,
+                       usage.input_tokens + usage.output_tokens,
       }
 
     else
diff --git a/kong/llm/drivers/cohere.lua b/kong/llm/drivers/cohere.lua
index b59f14630d4..b96cbbbc2d4 100644
--- a/kong/llm/drivers/cohere.lua
+++ b/kong/llm/drivers/cohere.lua
@@ -219,18 +219,15 @@ local transformers_from = {
 
       local stats = {
        completion_tokens = response_table.meta
                             and response_table.meta.billed_units
-                            and response_table.meta.billed_units.output_tokens
-                            or nil,
+                            and response_table.meta.billed_units.output_tokens,
 
        prompt_tokens = response_table.meta
                         and response_table.meta.billed_units
-                        and response_table.meta.billed_units.input_tokens
-                        or nil,
+                        and response_table.meta.billed_units.input_tokens,
 
        total_tokens = response_table.meta
                        and response_table.meta.billed_units
-                       and (response_table.meta.billed_units.output_tokens + response_table.meta.billed_units.input_tokens)
-                       or nil,
+                       and (response_table.meta.billed_units.output_tokens + response_table.meta.billed_units.input_tokens),
       }
       messages.usage = stats
@@ -252,18 +249,15 @@ local transformers_from = {
 
       local stats = {
        completion_tokens = response_table.meta
                             and response_table.meta.billed_units
-                            and response_table.meta.billed_units.output_tokens
-                            or nil,
+                            and response_table.meta.billed_units.output_tokens,
 
        prompt_tokens = response_table.meta
                         and response_table.meta.billed_units
-                        and response_table.meta.billed_units.input_tokens
-                        or nil,
+                        and response_table.meta.billed_units.input_tokens,
 
        total_tokens = response_table.meta
                        and response_table.meta.billed_units
-                       and (response_table.meta.billed_units.output_tokens + response_table.meta.billed_units.input_tokens)
-                       or nil,
+                       and (response_table.meta.billed_units.output_tokens + response_table.meta.billed_units.input_tokens),
       }
       messages.usage = stats
@@ -271,7 +265,7 @@ local transformers_from = {
       return nil, "'text' or 'generations' missing from cohere response body"
     end
 
-    
+
     return cjson.encode(messages)
   end,
 
@@ -299,11 +293,10 @@ local transformers_from = {
 
     prompt.id = response_table.id
 
     local stats = {
-      completion_tokens = response_table.meta and response_table.meta.billed_units.output_tokens or nil,
-      prompt_tokens = response_table.meta and response_table.meta.billed_units.input_tokens or nil,
+      completion_tokens = response_table.meta and response_table.meta.billed_units.output_tokens,
+      prompt_tokens = response_table.meta and response_table.meta.billed_units.input_tokens,
       total_tokens = response_table.meta
-                     and (response_table.meta.billed_units.output_tokens + response_table.meta.billed_units.input_tokens)
-                     or nil,
+                     and (response_table.meta.billed_units.output_tokens + response_table.meta.billed_units.input_tokens),
     }
     prompt.usage = stats
@@ -323,9 +316,9 @@ local transformers_from = {
 
     prompt.id = response_table.generation_id
 
     local stats = {
-      completion_tokens = response_table.token_count and response_table.token_count.response_tokens or nil,
-      prompt_tokens = response_table.token_count and response_table.token_count.prompt_tokens or nil,
-      total_tokens = response_table.token_count and response_table.token_count.total_tokens or nil,
+      completion_tokens = response_table.token_count and response_table.token_count.response_tokens,
+      prompt_tokens = response_table.token_count and response_table.token_count.prompt_tokens,
+      total_tokens = response_table.token_count and response_table.token_count.total_tokens,
     }
     prompt.usage = stats
diff --git a/kong/plugins/ai-proxy/handler.lua b/kong/plugins/ai-proxy/handler.lua
index 6ec7c2ed529..35e13fbe8d9 100644
--- a/kong/plugins/ai-proxy/handler.lua
+++ b/kong/plugins/ai-proxy/handler.lua
@@ -335,7 +335,7 @@ function _M:access(conf)
 
   -- copy from the user request if present
   if (not multipart) and (not conf_m.model.name) and (request_table.model) then
-    if request_table.model ~= cjson.null then
+    if type(request_table.model) == "string" then
       conf_m.model.name = request_table.model
     end
   elseif multipart then
@@ -343,7 +343,7 @@ function _M:access(conf)
   end
 
   -- check that the user isn't trying to override the plugin conf model in the request body
-  if request_table and request_table.model and type(request_table.model) == "string" then
+  if request_table and request_table.model and type(request_table.model) == "string" and request_table.model ~= "" then
     if request_table.model ~= conf_m.model.name then
       return bad_request("cannot use own model - must be: " .. conf_m.model.name)
     end
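
Note on the "or nil" removals above: in Lua, "a and b" already evaluates to
nil when "a" is nil, so a trailing "or nil" only changes the result when
"a and b" is false, which cannot happen here because token counts are either
numbers or absent. A minimal standalone sketch of the equivalence (not part
of the patch; the usage table mirrors the shape used by the drivers):

    -- usage is either a table of numbers or nil, as in the driver code
    local function prompt_tokens(usage)
      return usage and usage.input_tokens   -- nil when usage is missing
    end

    assert(prompt_tokens(nil) == nil)                  -- absent metadata
    assert(prompt_tokens({ input_tokens = 42 }) == 42) -- present metadata

    -- the old and new forms agree for every nil/number input
    local usage = nil
    assert((usage and usage.input_tokens or nil) == (usage and usage.input_tokens))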
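
Note on the handler.lua change: lua-cjson decodes JSON null to the
cjson.null sentinel (a userdata), so type(request_table.model) == "string"
rejects null and every other non-string value in one test, and the added
~= "" guard stops an empty model string from passing validation. A
hypothetical, simplified sketch of the resulting check follows; check_model
and its plain nil-plus-message error return are stand-ins for the plugin's
access handler and its bad_request helper:

    local cjson = require("cjson.safe")

    local function check_model(request_table, configured_name)
      local m = request_table and request_table.model
      -- only a non-empty string that conflicts with the configured
      -- model is treated as an override attempt
      if type(m) == "string" and m ~= "" and m ~= configured_name then
        return nil, "cannot use own model - must be: " .. configured_name
      end
      return true
    end

    -- JSON null decodes to cjson.null, not a string, so it is ignored
    assert(check_model(cjson.decode('{"model":null}'), "claude-3-haiku"))
    -- a conflicting string model is rejected
    assert(not check_model({ model = "gpt-4" }, "claude-3-haiku"))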