diff --git a/changelog/unreleased/kong/fix-ai-semantic-cache-model.yml b/changelog/unreleased/kong/fix-ai-semantic-cache-model.yml new file mode 100644 index 000000000000..4b2eb99a5d8f --- /dev/null +++ b/changelog/unreleased/kong/fix-ai-semantic-cache-model.yml @@ -0,0 +1,4 @@ +message: "Fixed an bug that AI semantic cache can't use request provided models" +type: bugfix +scope: Plugin + diff --git a/kong/llm/drivers/shared.lua b/kong/llm/drivers/shared.lua index cc19a1f9c7e7..f408b671b633 100644 --- a/kong/llm/drivers/shared.lua +++ b/kong/llm/drivers/shared.lua @@ -709,7 +709,7 @@ function _M.post_request(conf, response_object) -- Set the model, response, and provider names in the current try context request_analytics_plugin[log_entry_keys.META_CONTAINER][log_entry_keys.PLUGIN_ID] = conf.__plugin_id request_analytics_plugin[log_entry_keys.META_CONTAINER][log_entry_keys.PROVIDER_NAME] = provider_name - request_analytics_plugin[log_entry_keys.META_CONTAINER][log_entry_keys.REQUEST_MODEL] = kong.ctx.plugin.llm_model_requested or conf.model.name + request_analytics_plugin[log_entry_keys.META_CONTAINER][log_entry_keys.REQUEST_MODEL] = llm_state.get_request_model() request_analytics_plugin[log_entry_keys.META_CONTAINER][log_entry_keys.RESPONSE_MODEL] = response_object.model or conf.model.name -- Set the llm latency meta, and time per token usage diff --git a/kong/llm/proxy/handler.lua b/kong/llm/proxy/handler.lua index 40ef85634b78..1ae9e1885ec6 100644 --- a/kong/llm/proxy/handler.lua +++ b/kong/llm/proxy/handler.lua @@ -269,7 +269,7 @@ function _M:header_filter(conf) end if ngx.var.http_kong_debug or conf.model_name_header then - local name = conf.model.provider .. "/" .. (kong.ctx.plugin.llm_model_requested or conf.model.name) + local name = conf.model.provider .. "/" .. (llm_state.get_request_model()) kong.response.set_header("X-Kong-LLM-Model", name) end @@ -386,7 +386,7 @@ function _M:access(conf) return bail(400, "model parameter not found in request, nor in gateway configuration") end - kong_ctx_plugin.llm_model_requested = conf_m.model.name + llm_state.set_request_model(conf_m.model.name) -- check the incoming format is the same as the configured LLM format local compatible, err = llm.is_compatible(request_table, route_type) diff --git a/kong/llm/state.lua b/kong/llm/state.lua index 1ba0eb52e748..35ab807c7402 100644 --- a/kong/llm/state.lua +++ b/kong/llm/state.lua @@ -104,4 +104,12 @@ function _M.get_metrics(key) return (kong.ctx.shared.llm_metrics or {})[key] end +function _M.set_request_model(model) + kong.ctx.shared.llm_model_requested = model +end + +function _M.get_request_model() + return kong.ctx.shared.llm_model_requested or "NOT_SPECIFIED" +end + return _M diff --git a/kong/plugins/ai-request-transformer/handler.lua b/kong/plugins/ai-request-transformer/handler.lua index dd4325183d45..6a22a6d8297e 100644 --- a/kong/plugins/ai-request-transformer/handler.lua +++ b/kong/plugins/ai-request-transformer/handler.lua @@ -46,6 +46,7 @@ local function create_http_opts(conf) end function _M:access(conf) + llm_state.set_request_model(conf.llm.model and conf.llm.model.name) local kong_ctx_shared = kong.ctx.shared kong.service.request.enable_buffering() diff --git a/kong/plugins/ai-response-transformer/handler.lua b/kong/plugins/ai-response-transformer/handler.lua index 872b8ea924f4..d119f98610c5 100644 --- a/kong/plugins/ai-response-transformer/handler.lua +++ b/kong/plugins/ai-response-transformer/handler.lua @@ -105,6 +105,7 @@ end function _M:access(conf) + llm_state.set_request_model(conf.llm.model and conf.llm.model.name) local kong_ctx_shared = kong.ctx.shared kong.service.request.enable_buffering() diff --git a/spec/03-plugins/38-ai-proxy/02-openai_integration_spec.lua b/spec/03-plugins/38-ai-proxy/02-openai_integration_spec.lua index 3e2e98829d2d..d0017dd96c2e 100644 --- a/spec/03-plugins/38-ai-proxy/02-openai_integration_spec.lua +++ b/spec/03-plugins/38-ai-proxy/02-openai_integration_spec.lua @@ -67,7 +67,7 @@ for _, strategy in helpers.all_strategies() do if strategy ~= "cassandra" then local client lazy_setup(function() - local bp = helpers.get_db_utils(strategy == "off" and "postgres" or strategy, nil, { PLUGIN_NAME }) + local bp = helpers.get_db_utils(strategy == "off" and "postgres" or strategy, nil, { PLUGIN_NAME, "ctx-checker-last", "ctx-checker" }) -- set up openai mock fixtures local fixtures = { @@ -274,6 +274,15 @@ for _, strategy in helpers.all_strategies() do if strategy ~= "cassandra" then path = FILE_LOG_PATH_STATS_ONLY, }, } + bp.plugins:insert { + name = "ctx-checker-last", + route = { id = chat_good.id }, + config = { + ctx_kind = "kong.ctx.shared", + ctx_check_field = "llm_model_requested", + ctx_check_value = "gpt-3.5-turbo", + } + } -- 200 chat good with one option local chat_good_no_allow_override = assert(bp.routes:insert { @@ -544,8 +553,8 @@ for _, strategy in helpers.all_strategies() do if strategy ~= "cassandra" then } -- - -- 200 chat good but no model set - local chat_good = assert(bp.routes:insert { + -- 200 chat good but no model set in plugin config + local chat_good_no_model = assert(bp.routes:insert { service = empty_service, protocols = { "http" }, strip_path = true, @@ -553,7 +562,7 @@ for _, strategy in helpers.all_strategies() do if strategy ~= "cassandra" then }) bp.plugins:insert { name = PLUGIN_NAME, - route = { id = chat_good.id }, + route = { id = chat_good_no_model.id }, config = { route_type = "llm/v1/chat", auth = { @@ -572,11 +581,20 @@ for _, strategy in helpers.all_strategies() do if strategy ~= "cassandra" then } bp.plugins:insert { name = "file-log", - route = { id = chat_good.id }, + route = { id = chat_good_no_model.id }, config = { path = "/dev/stdout", }, } + bp.plugins:insert { + name = "ctx-checker-last", + route = { id = chat_good_no_model.id }, + config = { + ctx_kind = "kong.ctx.shared", + ctx_check_field = "llm_model_requested", + ctx_check_value = "try-to-override-the-model", + } + } -- -- 200 completions good using post body key @@ -755,7 +773,7 @@ for _, strategy in helpers.all_strategies() do if strategy ~= "cassandra" then }, }, } - -- + -- start kong assert(helpers.start_kong({ @@ -764,7 +782,7 @@ for _, strategy in helpers.all_strategies() do if strategy ~= "cassandra" then -- use the custom test template to create a local mock server nginx_conf = "spec/fixtures/custom_nginx.template", -- make sure our plugin gets loaded - plugins = "bundled," .. PLUGIN_NAME, + plugins = "bundled,ctx-checker-last,ctx-checker," .. PLUGIN_NAME, -- write & load declarative config, only if 'strategy=off' declarative_config = strategy == "off" and helpers.make_yaml_file() or nil, }, nil, nil, fixtures)) @@ -835,6 +853,7 @@ for _, strategy in helpers.all_strategies() do if strategy ~= "cassandra" then assert.same(first_expected, first_got) assert.is_true(actual_llm_latency >= 0) assert.same(actual_time_per_token, time_per_token) + assert.same(first_got.meta.request_model, "gpt-3.5-turbo") end) it("does not log statistics", function() @@ -1030,6 +1049,9 @@ for _, strategy in helpers.all_strategies() do if strategy ~= "cassandra" then content = "The sum of 1 + 1 is 2.", role = "assistant", }, json.choices[1].message) + + -- from ctx-checker-last plugin + assert.equals(r.headers["ctx-checker-last-llm-model-requested"], "gpt-3.5-turbo") end) it("good request, parses model of cjson.null", function() @@ -1110,6 +1132,38 @@ for _, strategy in helpers.all_strategies() do if strategy ~= "cassandra" then assert.is_truthy(json.error) assert.equals(json.error.message, "request format not recognised") end) + + -- check that kong.ctx.shared.llm_model_requested is set + it("good request setting model from client body", function() + local r = client:get("/openai/llm/v1/chat/good-no-model-param", { + headers = { + ["content-type"] = "application/json", + ["accept"] = "application/json", + }, + body = pl_file.read("spec/fixtures/ai-proxy/openai/llm-v1-chat/requests/good_own_model.json"), + }) + + -- validate that the request succeeded, response status 200 + local body = assert.res_status(200 , r) + local json = cjson.decode(body) + + -- check this is in the 'kong' response format + assert.equals(json.id, "chatcmpl-8T6YwgvjQVVnGbJ2w8hpOA17SeNy2") + assert.equals(json.model, "gpt-3.5-turbo-0613") + assert.equals(json.object, "chat.completion") + assert.equals(r.headers["X-Kong-LLM-Model"], "openai/try-to-override-the-model") + + assert.is_table(json.choices) + assert.is_table(json.choices[1].message) + assert.same({ + content = "The sum of 1 + 1 is 2.", + role = "assistant", + }, json.choices[1].message) + + -- from ctx-checker-last plugin + assert.equals(r.headers["ctx-checker-last-llm-model-requested"], "try-to-override-the-model") + end) + end) describe("openai llm/v1/completions", function()