Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix(ai-prompr-decorator): fix unable to modify request #13966

Merged
merged 3 commits into from
Dec 3, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions kong/llm/plugin/ctx.lua
Original file line number Diff line number Diff line change
Expand Up @@ -139,6 +139,11 @@ local EMPTY_REQUEST_T = _M.immutable_table({})

function _M.get_request_body_table_inuse()
local request_body_table

if _M.has_namespace("decorate-prompt") then -- has ai-prompt-decorator and others in future
request_body_table = _M.get_namespaced_ctx("decorate-prompt", "request_body_table")
end

if _M.has_namespace("normalize-request") then -- has ai-proxy/ai-proxy-advanced
request_body_table = _M.get_namespaced_ctx("normalize-request", "request_body_table")
end
Expand Down
11 changes: 9 additions & 2 deletions kong/llm/plugin/shared-filters/normalize-request.lua
Original file line number Diff line number Diff line change
Expand Up @@ -76,7 +76,14 @@ local function validate_and_transform(conf)
local model_t = conf_m.model
local model_provider = conf.model.provider -- use the one from conf, not the merged one to avoid potential security risk

local request_table = ai_plugin_ctx.get_namespaced_ctx("parse-request", "request_body_table")
local request_table
if ai_plugin_ctx.has_namespace("decorate-prompt") and
ai_plugin_ctx.get_namespaced_ctx("decorate-prompt", "decorated") then
request_table = ai_plugin_ctx.get_namespaced_ctx("decorate-prompt", "request_body_table")
else
request_table = ai_plugin_ctx.get_namespaced_ctx("parse-request", "request_body_table")
end

if not request_table then
return bail(400, "content-type header does not match request body, or bad JSON formatting")
end
Expand Down Expand Up @@ -219,4 +226,4 @@ function _M:run(conf)
return true
end

return _M
return _M
5 changes: 1 addition & 4 deletions kong/llm/plugin/shared-filters/serialize-analytics.lua
Original file line number Diff line number Diff line change
Expand Up @@ -65,9 +65,6 @@ function _M:run(conf)
total_tokens = ai_plugin_o11y.metrics_get("llm_total_tokens_count"),
cost = ai_plugin_o11y.metrics_get("llm_usage_cost"),
}

kong.log.inspect(usage)

kong.log.set_serialize_value(string.format("ai.%s.usage", ai_plugin_o11y.NAMESPACE), usage)


Expand All @@ -82,4 +79,4 @@ function _M:run(conf)
return true
end

return _M
return _M
15 changes: 11 additions & 4 deletions kong/plugins/ai-prompt-decorator/filters/decorate-prompt.lua
Original file line number Diff line number Diff line change
Expand Up @@ -7,14 +7,16 @@

local new_tab = require("table.new")
local ai_plugin_ctx = require("kong.llm.plugin.ctx")
local cycle_aware_deep_copy = require("kong.tools.table").cycle_aware_deep_copy

local _M = {
NAME = "decorate-prompt",
STAGE = "REQ_TRANSFORMATION",
}
}

local FILTER_OUTPUT_SCHEMA = {
decorated = "boolean",
request_body_table = "table",
}

local _, set_ctx = ai_plugin_ctx.get_namespaced_accesors(_M.NAME, FILTER_OUTPUT_SCHEMA)
Expand All @@ -23,7 +25,7 @@ local EMPTY = {}


local function bad_request(msg)
kong.log.debug(msg)
kong.log.info(msg)
return kong.response.exit(400, { error = { message = msg } })
end

Expand Down Expand Up @@ -77,11 +79,16 @@ function _M:run(conf)
return bad_request("this LLM route only supports llm/chat type requests")
end

kong.service.request.set_body(execute(request_body_table, conf), "application/json")
-- Deep copy to avoid modifying the immutable table.
-- Re-assign it to trigger GC of the old one and save memory.
request_body_table = execute(cycle_aware_deep_copy(request_body_table), conf)

kong.service.request.set_body(request_body_table, "application/json") -- legacy

set_ctx("decorated", true)
set_ctx("request_body_table", request_body_table)

return true
end

return _M
return _M
241 changes: 187 additions & 54 deletions spec/03-plugins/41-ai-prompt-decorator/02-integration_spec.lua
Original file line number Diff line number Diff line change
@@ -1,23 +1,45 @@
local helpers = require "spec.helpers"
local helpers = require("spec.helpers")
local cjson = require("cjson")


local PLUGIN_NAME = "ai-prompt-decorator"


for _, strategy in helpers.all_strategies() do
local openai_flat_chat = {
messages = {
{
role = "user",
content = "I think that cheddar is the best cheese.",
},
{
role = "assistant",
content = "No, brie is the best cheese.",
},
{
role = "user",
content = "Why brie?",
},
},
}


for _, strategy in helpers.all_strategies() do if strategy ~= "cassandra" then
describe(PLUGIN_NAME .. ": (access) [#" .. strategy .. "]", function()
local client

lazy_setup(function()

local bp = helpers.get_db_utils(strategy == "off" and "postgres" or strategy, nil, { PLUGIN_NAME })
local bp = helpers.get_db_utils(strategy == "off" and "postgres" or strategy, nil, { PLUGIN_NAME, "ctx-checker-last", "ctx-checker" })


local route1 = bp.routes:insert({
hosts = { "test1.com" },
-- echo route, we don't need a mock AI here
local prepend = bp.routes:insert({
hosts = { "prepend.decorate.local" },
})

bp.plugins:insert {
name = PLUGIN_NAME,
route = { id = route1.id },
route = { id = prepend.id },
config = {
prompts = {
prepend = {
Expand All @@ -30,6 +52,28 @@ for _, strategy in helpers.all_strategies() do
content = "Prepend text 2 here.",
},
},
},
},
}

bp.plugins:insert {
name = "ctx-checker-last",
route = { id = prepend.id },
config = {
ctx_check_field = "ai_namespaced_ctx",
}
}


local append = bp.routes:insert({
hosts = { "append.decorate.local" },
})

bp.plugins:insert {
name = PLUGIN_NAME,
route = { id = append.id },
config = {
prompts = {
append = {
[1] = {
role = "assistant",
Expand All @@ -44,72 +88,161 @@ for _, strategy in helpers.all_strategies() do
},
}

bp.plugins:insert {
name = "ctx-checker-last",
route = { id = append.id },
config = {
ctx_check_field = "ai_namespaced_ctx",
}
}

local both = bp.routes:insert({
hosts = { "both.decorate.local" },
})


bp.plugins:insert {
name = PLUGIN_NAME,
route = { id = both.id },
config = {
prompts = {
prepend = {
[1] = {
role = "system",
content = "Prepend text 1 here.",
},
[2] = {
role = "assistant",
content = "Prepend text 2 here.",
},
},
append = {
[1] = {
role = "assistant",
content = "Append text 3 here.",
},
[2] = {
role = "user",
content = "Append text 4 here.",
},
},
},
},
}

bp.plugins:insert {
name = "ctx-checker-last",
route = { id = both.id },
config = {
ctx_check_field = "ai_namespaced_ctx",
}
}


assert(helpers.start_kong({
database = strategy,
nginx_conf = "spec/fixtures/custom_nginx.template",
plugins = "bundled," .. PLUGIN_NAME,
plugins = "bundled,ctx-checker-last,ctx-checker," .. PLUGIN_NAME,
declarative_config = strategy == "off" and helpers.make_yaml_file() or nil,
}))
end)


lazy_teardown(function()
helpers.stop_kong()
helpers.stop_kong(nil, true)
end)


before_each(function()
client = helpers.proxy_client()
end)


after_each(function()
if client then client:close() end
end)



it("blocks a non-chat message", function()
local r = client:get("/request", {
headers = {
host = "test1.com",
["Content-Type"] = "application/json",
},
body = [[
{
"anything": [
{
"random": "data"
}
]
}]],
method = "POST",
})

assert.response(r).has.status(400)
local json = assert.response(r).has.jsonbody()
assert.same({ error = { message = "this LLM route only supports llm/chat type requests" }}, json)
end)


it("blocks an empty messages array", function()
local r = client:get("/request", {
headers = {
host = "test1.com",
["Content-Type"] = "application/json",
},
body = [[
{
"messages": []
}]],
method = "POST",
})

assert.response(r).has.status(400)
local json = assert.response(r).has.jsonbody()
assert.same({ error = { message = "this LLM route only supports llm/chat type requests" }}, json)
describe("request", function()
it("modifies the LLM chat request - prepend", function()
local r = client:get("/", {
headers = {
host = "prepend.decorate.local",
["Content-Type"] = "application/json"
},
body = cjson.encode(openai_flat_chat),
})

-- get the REQUEST body, that left Kong for the upstream, using the echo system
assert.response(r).has.status(200)
local request = assert.response(r).has.jsonbody()
request = cjson.decode(request.post_data.text)

assert.same({ content = "Prepend text 1 here.", role = "system" }, request.messages[1])
assert.same({ content = "Prepend text 2 here.", role = "system" }, request.messages[2])

-- check ngx.ctx was set properly for later AI chain filters
local ctx = assert.response(r).has.header("ctx-checker-last-ai-namespaced-ctx")
ctx = ngx.unescape_uri(ctx)
assert.match_re(ctx, [[.*decorate-prompt.*]])
assert.match_re(ctx, [[.*decorated = true.*]])
assert.match_re(ctx, [[.*Prepend text 1 here.*]])
assert.match_re(ctx, [[.*Prepend text 2 here.*]])
end)

it("modifies the LLM chat request - append", function()
local r = client:get("/", {
headers = {
host = "append.decorate.local",
["Content-Type"] = "application/json"
},
body = cjson.encode(openai_flat_chat),
})

-- get the REQUEST body, that left Kong for the upstream, using the echo system
assert.response(r).has.status(200)
local request = assert.response(r).has.jsonbody()
request = cjson.decode(request.post_data.text)

assert.same({ content = "Append text 1 here.", role = "assistant" }, request.messages[#request.messages-1])
assert.same({ content = "Append text 2 here.", role = "user" }, request.messages[#request.messages])

-- check ngx.ctx was set properly for later AI chain filters
local ctx = assert.response(r).has.header("ctx-checker-last-ai-namespaced-ctx")
ctx = ngx.unescape_uri(ctx)
assert.match_re(ctx, [[.*decorate-prompt.*]])
assert.match_re(ctx, [[.*decorated = true.*]])
assert.match_re(ctx, [[.*Append text 1 here.*]])
assert.match_re(ctx, [[.*Append text 2 here.*]])
end)


it("modifies the LLM chat request - both", function()
local r = client:get("/", {
headers = {
host = "both.decorate.local",
["Content-Type"] = "application/json"
},
body = cjson.encode(openai_flat_chat),
})

-- get the REQUEST body, that left Kong for the upstream, using the echo system
assert.response(r).has.status(200)
local request = assert.response(r).has.jsonbody()
request = cjson.decode(request.post_data.text)

assert.same({ content = "Prepend text 1 here.", role = "system" }, request.messages[1])
assert.same({ content = "Prepend text 2 here.", role = "assistant" }, request.messages[2])
assert.same({ content = "Append text 3 here.", role = "assistant" }, request.messages[#request.messages-1])
assert.same({ content = "Append text 4 here.", role = "user" }, request.messages[#request.messages])

-- check ngx.ctx was set properly for later AI chain filters
local ctx = assert.response(r).has.header("ctx-checker-last-ai-namespaced-ctx")
ctx = ngx.unescape_uri(ctx)
assert.match_re(ctx, [[.*decorate-prompt.*]])
assert.match_re(ctx, [[.*decorated = true.*]])
assert.match_re(ctx, [[.*Prepend text 1 here.*]])
assert.match_re(ctx, [[.*Prepend text 2 here.*]])
assert.match_re(ctx, [[.*Append text 3 here.*]])
assert.match_re(ctx, [[.*Append text 4 here.*]])
end)
end)

end)

end
end end
Loading