Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[backport -> release/3.7.x] fix(ai-proxy): Fix Cohere breaks with model parameter in body; Fix OpenAI token counting for function requests; Fix user sending own-model parameter #13230

Merged
merged 12 commits into from
Jun 20, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions changelog/unreleased/kong/ai-proxy-azure-streaming.yml
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
message: |
**AI-proxy-plugin**: Fixed a bug where certain Azure models would return partial tokens/words
when in response-streaming mode.
scope: Plugin
type: bugfix
5 changes: 5 additions & 0 deletions changelog/unreleased/kong/ai-proxy-fix-model-parameter.yml
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
message: |
**AI-proxy-plugin**: Fixed a bug where the Cohere and Anthropic providers did not read the `model` parameter properly
from the caller's request body.
scope: Plugin
type: bugfix
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
message: |
**AI-proxy-plugin**: Fixed a bug where using "OpenAI Function" inference requests would log a
request error, and then hang until timeout.
scope: Plugin
type: bugfix
5 changes: 5 additions & 0 deletions changelog/unreleased/kong/ai-proxy-fix-sending-own-model.yml
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
message: |
**AI-proxy-plugin**: Fixed a bug where AI Proxy would still allow callers to specify their own model,
ignoring the plugin-configured model name.
scope: Plugin
type: bugfix
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
message: |
**AI-proxy-plugin**: Fixed a bug where AI Proxy would not give the plugin's configured
model tuning options precedence over those in the user's LLM request.
scope: Plugin
type: bugfix
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
message: |
**AI-proxy-plugin**: Fixed a bug where setting the OpenAI SDK model parameter to "null" caused analytics
not to be written to the logging plugin(s).
scope: Plugin
type: bugfix
33 changes: 11 additions & 22 deletions kong/llm/drivers/anthropic.lua
Original file line number Diff line number Diff line change
Expand Up @@ -93,8 +93,8 @@ local transformers_to = {
return nil, nil, err
end

messages.temperature = request_table.temperature or (model.options and model.options.temperature) or nil
messages.max_tokens = request_table.max_tokens or (model.options and model.options.max_tokens) or nil
messages.temperature = (model.options and model.options.temperature) or request_table.temperature
messages.max_tokens = (model.options and model.options.max_tokens) or request_table.max_tokens
messages.model = model.name or request_table.model
messages.stream = request_table.stream or false -- explicitly set this if nil

Expand All @@ -110,9 +110,8 @@ local transformers_to = {
return nil, nil, err
end

prompt.temperature = request_table.temperature or (model.options and model.options.temperature) or nil
prompt.max_tokens_to_sample = request_table.max_tokens or (model.options and model.options.max_tokens) or nil
prompt.model = model.name
prompt.temperature = (model.options and model.options.temperature) or request_table.temperature
prompt.max_tokens_to_sample = (model.options and model.options.max_tokens) or request_table.max_tokens
prompt.model = model.name or request_table.model
prompt.stream = request_table.stream or false -- explicitly set this if nil

Expand Down Expand Up @@ -152,11 +151,9 @@ local function start_to_event(event_data, model_info)

local metadata = {
prompt_tokens = meta.usage
and meta.usage.input_tokens
or nil,
and meta.usage.input_tokens,
completion_tokens = meta.usage
and meta.usage.output_tokens
or nil,
and meta.usage.output_tokens,
model = meta.model,
stop_reason = meta.stop_reason,
stop_sequence = meta.stop_sequence,
Expand Down Expand Up @@ -209,14 +206,11 @@ local function handle_stream_event(event_t, model_info, route_type)
and event_data.usage then
return nil, nil, {
prompt_tokens = nil,
completion_tokens = event_data.usage.output_tokens
or nil,
completion_tokens = event_data.usage.output_tokens,
stop_reason = event_data.delta
and event_data.delta.stop_reason
or nil,
and event_data.delta.stop_reason,
stop_sequence = event_data.delta
and event_data.delta.stop_sequence
or nil,
and event_data.delta.stop_sequence,
}
else
return nil, "message_delta is missing the metadata block", nil
Expand Down Expand Up @@ -267,7 +261,7 @@ local transformers_from = {
prompt_tokens = usage.input_tokens,
completion_tokens = usage.output_tokens,
total_tokens = usage.input_tokens and usage.output_tokens and
usage.input_tokens + usage.output_tokens or nil,
usage.input_tokens + usage.output_tokens,
}

else
Expand Down Expand Up @@ -442,12 +436,7 @@ function _M.post_request(conf)
end

function _M.pre_request(conf, body)
-- check for user trying to bring own model
if body and body.model then
return nil, "cannot use own model for this instance"
end

return true, nil
return true
end

-- returns err or nil
Expand Down
46 changes: 15 additions & 31 deletions kong/llm/drivers/cohere.lua
Original file line number Diff line number Diff line change
Expand Up @@ -219,18 +219,15 @@ local transformers_from = {
local stats = {
completion_tokens = response_table.meta
and response_table.meta.billed_units
and response_table.meta.billed_units.output_tokens
or nil,
and response_table.meta.billed_units.output_tokens,

prompt_tokens = response_table.meta
and response_table.meta.billed_units
and response_table.meta.billed_units.input_tokens
or nil,
and response_table.meta.billed_units.input_tokens,

total_tokens = response_table.meta
and response_table.meta.billed_units
and (response_table.meta.billed_units.output_tokens + response_table.meta.billed_units.input_tokens)
or nil,
and (response_table.meta.billed_units.output_tokens + response_table.meta.billed_units.input_tokens),
}
messages.usage = stats

Expand All @@ -252,26 +249,23 @@ local transformers_from = {
local stats = {
completion_tokens = response_table.meta
and response_table.meta.billed_units
and response_table.meta.billed_units.output_tokens
or nil,
and response_table.meta.billed_units.output_tokens,

prompt_tokens = response_table.meta
and response_table.meta.billed_units
and response_table.meta.billed_units.input_tokens
or nil,
and response_table.meta.billed_units.input_tokens,

total_tokens = response_table.meta
and response_table.meta.billed_units
and (response_table.meta.billed_units.output_tokens + response_table.meta.billed_units.input_tokens)
or nil,
and (response_table.meta.billed_units.output_tokens + response_table.meta.billed_units.input_tokens),
}
messages.usage = stats

else -- probably a fault
return nil, "'text' or 'generations' missing from cohere response body"

end

return cjson.encode(messages)
end,

Expand Down Expand Up @@ -299,11 +293,10 @@ local transformers_from = {
prompt.id = response_table.id

local stats = {
completion_tokens = response_table.meta and response_table.meta.billed_units.output_tokens or nil,
prompt_tokens = response_table.meta and response_table.meta.billed_units.input_tokens or nil,
completion_tokens = response_table.meta and response_table.meta.billed_units.output_tokens,
prompt_tokens = response_table.meta and response_table.meta.billed_units.input_tokens,
total_tokens = response_table.meta
and (response_table.meta.billed_units.output_tokens + response_table.meta.billed_units.input_tokens)
or nil,
and (response_table.meta.billed_units.output_tokens + response_table.meta.billed_units.input_tokens),
}
prompt.usage = stats

Expand All @@ -323,9 +316,9 @@ local transformers_from = {
prompt.id = response_table.generation_id

local stats = {
completion_tokens = response_table.token_count and response_table.token_count.response_tokens or nil,
prompt_tokens = response_table.token_count and response_table.token_count.prompt_tokens or nil,
total_tokens = response_table.token_count and response_table.token_count.total_tokens or nil,
completion_tokens = response_table.token_count and response_table.token_count.response_tokens,
prompt_tokens = response_table.token_count and response_table.token_count.prompt_tokens,
total_tokens = response_table.token_count and response_table.token_count.total_tokens,
}
prompt.usage = stats

Expand Down Expand Up @@ -400,12 +393,7 @@ function _M.post_request(conf)
end

function _M.pre_request(conf, body)
-- check for user trying to bring own model
if body and body.model then
return false, "cannot use own model for this instance"
end

return true, nil
return true
end

function _M.subrequest(body, conf, http_opts, return_res_table)
Expand Down Expand Up @@ -467,7 +455,7 @@ end
function _M.configure_request(conf)
local parsed_url

if conf.model.options.upstream_url then
if conf.model.options and conf.model.options.upstream_url then
parsed_url = socket_url.parse(conf.model.options.upstream_url)
else
parsed_url = socket_url.parse(ai_shared.upstream_url_format[DRIVER_NAME])
Expand All @@ -476,10 +464,6 @@ function _M.configure_request(conf)
or ai_shared.operation_map[DRIVER_NAME][conf.route_type]
and ai_shared.operation_map[DRIVER_NAME][conf.route_type].path
or "/"

if not parsed_url.path then
return false, fmt("operation %s is not supported for cohere provider", conf.route_type)
end
end

-- if the path is read from a URL capture, ensure that it is valid
Expand Down
4 changes: 2 additions & 2 deletions kong/llm/drivers/openai.lua
Original file line number Diff line number Diff line change
Expand Up @@ -18,15 +18,15 @@ end

local transformers_to = {
["llm/v1/chat"] = function(request_table, model_info, route_type)
request_table.model = request_table.model or model_info.name
request_table.model = model_info.name or request_table.model
request_table.stream = request_table.stream or false -- explicitly set this
request_table.top_k = nil -- explicitly remove unsupported default

return request_table, "application/json", nil
end,

["llm/v1/completions"] = function(request_table, model_info, route_type)
request_table.model = model_info.name
request_table.model = model_info.name or request_table.model
request_table.stream = request_table.stream or false -- explicitly set this
request_table.top_k = nil -- explicitly remove unsupported default

Expand Down
40 changes: 29 additions & 11 deletions kong/llm/drivers/shared.lua
Original file line number Diff line number Diff line change
Expand Up @@ -131,10 +131,10 @@ _M.clear_response_headers = {
-- @return {string} error if any is thrown - request should definitely be terminated if this is not nil
function _M.merge_config_defaults(request, options, request_format)
if options then
request.temperature = request.temperature or options.temperature
request.max_tokens = request.max_tokens or options.max_tokens
request.top_p = request.top_p or options.top_p
request.top_k = request.top_k or options.top_k
request.temperature = options.temperature or request.temperature
request.max_tokens = options.max_tokens or request.max_tokens
request.top_p = options.top_p or request.top_p
request.top_k = options.top_k or request.top_k
end

return request, nil
Expand Down Expand Up @@ -197,28 +197,44 @@ end
function _M.frame_to_events(frame)
local events = {}

-- todo check if it's raw json and
-- Cohere / Other flat-JSON format parser
-- just return the split up data frame
if string.sub(str_ltrim(frame), 1, 1) == "{" then
if (not kong or not kong.ctx.plugin.truncated_frame) and string.sub(str_ltrim(frame), 1, 1) == "{" then
for event in frame:gmatch("[^\r\n]+") do
events[#events + 1] = {
data = event,
}
end
else
-- standard SSE parser
local event_lines = split(frame, "\n")
local struct = { event = nil, id = nil, data = nil }

for _, dat in ipairs(event_lines) do
for i, dat in ipairs(event_lines) do
if #dat < 1 then
events[#events + 1] = struct
struct = { event = nil, id = nil, data = nil }
end

-- test for truncated chunk on the last line (no trailing \r\n\r\n)
if #dat > 0 and #event_lines == i then
ngx.log(ngx.DEBUG, "[ai-proxy] truncated sse frame head")
kong.ctx.plugin.truncated_frame = dat
break -- stop parsing immediately, server has done something wrong
end

-- test for abnormal start-of-frame (truncation tail)
if kong and kong.ctx.plugin.truncated_frame then
-- this is the tail of a previous incomplete chunk
ngx.log(ngx.DEBUG, "[ai-proxy] truncated sse frame tail")
dat = fmt("%s%s", kong.ctx.plugin.truncated_frame, dat)
kong.ctx.plugin.truncated_frame = nil
end

local s1, _ = str_find(dat, ":") -- find where the cut point is

if s1 and s1 ~= 1 then
local field = str_sub(dat, 1, s1-1) -- returns "data " from data: hello world
local field = str_sub(dat, 1, s1-1) -- returns "data" from data: hello world
local value = str_ltrim(str_sub(dat, s1+1)) -- returns "hello world" from data: hello world

-- for now not checking if the value is already been set
Expand Down Expand Up @@ -249,7 +265,7 @@ function _M.to_ollama(request_table, model)

-- common parameters
input.stream = request_table.stream or false -- for future capability
input.model = model.name
input.model = model.name or request_table.name

if model.options then
input.options = {}
Expand Down Expand Up @@ -603,8 +619,10 @@ end
-- Function to count the number of words in a string
local function count_words(str)
local count = 0
for word in str:gmatch("%S+") do
count = count + 1
if type(str) == "string" then
for word in str:gmatch("%S+") do
count = count + 1
end
end
return count
end
Expand Down
15 changes: 12 additions & 3 deletions kong/plugins/ai-proxy/handler.lua
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,8 @@ local function get_token_text(event_t)
-- - event_t.choices[1].delta.content
-- - event_t.choices[1].text
-- - ""
return (first_choice.delta or EMPTY).content or first_choice.text or ""
local token_text = (first_choice.delta or EMPTY).content or first_choice.text or ""
return (type(token_text) == "string" and token_text) or ""
end


Expand Down Expand Up @@ -334,17 +335,25 @@ function _M:access(conf)

-- copy from the user request if present
if (not multipart) and (not conf_m.model.name) and (request_table.model) then
conf_m.model.name = request_table.model
if type(request_table.model) == "string" then
conf_m.model.name = request_table.model
end
elseif multipart then
conf_m.model.name = "NOT_SPECIFIED"
end

-- check that the user isn't trying to override the plugin conf model in the request body
if request_table and request_table.model and type(request_table.model) == "string" and request_table.model ~= "" then
if request_table.model ~= conf_m.model.name then
return bad_request("cannot use own model - must be: " .. conf_m.model.name)
end
end

-- model is stashed in the copied plugin conf, for consistency in transformation functions
if not conf_m.model.name then
return bad_request("model parameter not found in request, nor in gateway configuration")
end

-- stash for analytics later
kong_ctx_plugin.llm_model_requested = conf_m.model.name

-- check the incoming format is the same as the configured LLM format
Expand Down
6 changes: 3 additions & 3 deletions spec/03-plugins/38-ai-proxy/01-unit_spec.lua
Original file line number Diff line number Diff line change
Expand Up @@ -629,7 +629,7 @@ describe(PLUGIN_NAME .. ": (unit)", function()
SAMPLE_LLM_V1_CHAT_WITH_SOME_OPTS,
{
max_tokens = 1024,
top_p = 1.0,
top_p = 0.5,
},
"llm/v1/chat"
)
Expand All @@ -638,9 +638,9 @@ describe(PLUGIN_NAME .. ": (unit)", function()

assert.is_nil(err)
assert.same({
max_tokens = 256,
max_tokens = 1024,
temperature = 0.1,
top_p = 0.2,
top_p = 0.5,
some_extra_param = "string_val",
another_extra_param = 0.5,
}, formatted)
Expand Down
Loading
Loading