From 8e3a66511533768137edbcc613aa45a020028380 Mon Sep 17 00:00:00 2001 From: Wangchong Zhou Date: Fri, 7 Jun 2024 14:58:39 +0800 Subject: [PATCH] fix(ai-prompt-guard): fix an issue when `allow_all_conversation_history` is set to false, the first user request is selected instead of the last one --- .../kong/fix-ai-prompt-guard-order.yml | 3 + kong/plugins/ai-prompt-guard/handler.lua | 43 +- .../42-ai-prompt-guard/01-unit_spec.lua | 386 ++++++++---------- 3 files changed, 189 insertions(+), 243 deletions(-) create mode 100644 changelog/unreleased/kong/fix-ai-prompt-guard-order.yml diff --git a/changelog/unreleased/kong/fix-ai-prompt-guard-order.yml b/changelog/unreleased/kong/fix-ai-prompt-guard-order.yml new file mode 100644 index 00000000000..a6bfdfab9ae --- /dev/null +++ b/changelog/unreleased/kong/fix-ai-prompt-guard-order.yml @@ -0,0 +1,3 @@ +message: "**AI-Prompt-Guard**: Fixed an issue when `allow_all_conversation_history` is set to false, the first user request is selected instead of the last one." +type: bugfix +scope: Plugin diff --git a/kong/plugins/ai-prompt-guard/handler.lua b/kong/plugins/ai-prompt-guard/handler.lua index 304b9f55e45..fcb37f54ed1 100644 --- a/kong/plugins/ai-prompt-guard/handler.lua +++ b/kong/plugins/ai-prompt-guard/handler.lua @@ -26,16 +26,21 @@ local execute do -- @tparam table request The deserialized JSON body of the request -- @tparam table conf The plugin configuration -- @treturn[1] table The decorated request (same table, content updated) - -- @treturn[2] nil -- @treturn[2] string The error message function execute(request, conf) - local user_prompt + local collected_prompts + local messages = request.messages - -- concat all 'user' prompts into one string, if conversation history must be checked - if type(request.messages) == "table" and not conf.allow_all_conversation_history then + -- concat all prompts into one string, if conversation history must be checked + if type(messages) == "table" then local buf = buffer.new() + -- Note allow_all_conversation_history means ignores history + local just_pick_latest = conf.allow_all_conversation_history - for _, v in ipairs(request.messages) do + -- iterate in reverse so we get the latest user prompt first + -- instead of the oldest one in history + for i=#messages, 1, -1 do + local v = messages[i] if type(v.role) ~= "string" then return nil, bad_format_error end @@ -44,33 +49,25 @@ local execute do return nil, bad_format_error end buf:put(v.content) - end - end - - user_prompt = buf:get() - elseif type(request.messages) == "table" then - -- just take the trailing 'user' prompt - for _, v in ipairs(request.messages) do - if type(v.role) ~= "string" then - return nil, bad_format_error - end - if v.role == "user" then - if type(v.content) ~= "string" then - return nil, bad_format_error + if just_pick_latest then + break end - user_prompt = v.content + + buf:put(" ") -- put a seperator to avoid adhension of words end end + collected_prompts = buf:get() + elseif type(request.prompt) == "string" then - user_prompt = request.prompt + collected_prompts = request.prompt else return nil, bad_format_error end - if not user_prompt then + if not collected_prompts then return nil, "no 'prompt' or 'messages' received" end @@ -78,7 +75,7 @@ local execute do -- check the prompt for explcit ban patterns for _, v in ipairs(conf.deny_patterns or EMPTY) do -- check each denylist; if prompt matches it, deny immediately - local m, _, err = ngx_re_find(user_prompt, v, "jo") + local m, _, err = ngx_re_find(collected_prompts, v, "jo") if err then -- regex failed, that's an error by the administrator kong.log.err("bad regex pattern '", v ,"', failed to execute: ", err) @@ -98,7 +95,7 @@ local execute do -- if any allow_patterns specified, make sure the prompt matches one of them for _, v in ipairs(conf.allow_patterns or EMPTY) do -- check each denylist; if prompt matches it, deny immediately - local m, _, err = ngx_re_find(user_prompt, v, "jo") + local m, _, err = ngx_re_find(collected_prompts, v, "jo") if err then -- regex failed, that's an error by the administrator diff --git a/spec/03-plugins/42-ai-prompt-guard/01-unit_spec.lua b/spec/03-plugins/42-ai-prompt-guard/01-unit_spec.lua index 9007376fcf0..ad72a693ee3 100644 --- a/spec/03-plugins/42-ai-prompt-guard/01-unit_spec.lua +++ b/spec/03-plugins/42-ai-prompt-guard/01-unit_spec.lua @@ -1,119 +1,57 @@ local PLUGIN_NAME = "ai-prompt-guard" - - -local general_chat_request = { - messages = { - [1] = { - role = "system", - content = "You are a mathematician." - }, - [2] = { - role = "user", - content = "What is 1 + 1?" - }, - }, -} - -local general_chat_request_with_history = { - messages = { - [1] = { - role = "system", - content = "You are a mathematician." - }, - [2] = { - role = "user", - content = "What is 12 + 1?" - }, - [3] = { - role = "assistant", - content = "The answer is 13.", - }, - [4] = { - role = "user", - content = "Now double the previous answer.", - }, - }, +local message_fixtures = { + user = "this is a user request", + system = "this is a system message", + assistant = "this is an assistant reply", } -local denied_chat_request = { - messages = { - [1] = { +local _M = {} +local function create_request(typ) + local messages = { + { role = "system", - content = "You are a mathematician." - }, - [2] = { - role = "user", - content = "What is 22 + 1?" - }, - }, -} - -local neither_allowed_nor_denied_chat_request = { - messages = { - [1] = { - role = "system", - content = "You are a mathematician." - }, - [2] = { - role = "user", - content = "What is 55 + 55?" - }, - }, -} - - -local general_completions_request = { - prompt = "You are a mathematician. What is 1 + 1?" -} - - -local denied_completions_request = { - prompt = "You are a mathematician. What is 22 + 1?" -} - -local neither_allowed_nor_denied_completions_request = { - prompt = "You are a mathematician. What is 55 + 55?" -} - -local allow_patterns_no_history = { - allow_patterns = { - [1] = ".*1 \\+ 1.*" - }, - allow_all_conversation_history = true, -} - -local allow_patterns_with_history = { - allow_patterns = { - [1] = ".*1 \\+ 1.*" - }, - allow_all_conversation_history = false, -} - -local deny_patterns_with_history = { - deny_patterns = { - [1] = ".*12 \\+ 1.*" - }, - allow_all_conversation_history = false, -} - -local deny_patterns_no_history = { - deny_patterns = { - [1] = ".*22 \\+ 1.*" - }, - allow_all_conversation_history = true, -} - -local both_patterns_no_history = { - allow_patterns = { - [1] = ".*1 \\+ 1.*" - }, - deny_patterns = { - [1] = ".*99 \\+ 99.*" - }, - allow_all_conversation_history = true, -} - + content = message_fixtures.system, + } + } + + if typ ~= "chat" and typ ~= "completions" then + error("type must be one of 'chat' or 'completions'", 2) + end + + return setmetatable({ + messages = messages, + type = typ, + }, { + __index = _M, + }) +end + +function _M:append_message(role, custom) + if not message_fixtures[role] then + assert("role must be one of: user, system or assistant") + end + + if self.type == "completion" then + self.prompt = "this is a completions request" + if custom then + self.prompt = self.prompt .. " with custom content " .. custom + end + return + end + + local message = message_fixtures[role] + if custom then + message = message .. " with custom content " .. custom + end + + self.messages[#self.messages+1] = { + role = "user", + content = message + } + + return self +end describe(PLUGIN_NAME .. ": (unit)", function() @@ -132,115 +70,123 @@ describe(PLUGIN_NAME .. ": (unit)", function() - describe("chat operations", function() - - it("allows request when only conf.allow_patterns is set", function() - local ok, err = access_handler._execute(general_chat_request, allow_patterns_no_history) - - assert.is_truthy(ok) - assert.is_nil(err) - end) - - - it("allows request when only conf.deny_patterns is set, and pattern should not match", function() - local ok, err = access_handler._execute(general_chat_request, deny_patterns_no_history) - - assert.is_truthy(ok) - assert.is_nil(err) - end) - - - it("denies request when only conf.allow_patterns is set, and pattern should not match", function() - local ok, err = access_handler._execute(denied_chat_request, allow_patterns_no_history) - - assert.is_falsy(ok) - assert.equal(err, "prompt doesn't match any allowed pattern") - end) - - - it("denies request when only conf.deny_patterns is set, and pattern should match", function() - local ok, err = access_handler._execute(denied_chat_request, deny_patterns_no_history) - - assert.is_falsy(ok) - assert.equal(err, "prompt pattern is blocked") - end) - - - it("allows request when both conf.allow_patterns and conf.deny_patterns are set, and pattern matches allow", function() - local ok, err = access_handler._execute(general_chat_request, both_patterns_no_history) - - assert.is_truthy(ok) - assert.is_nil(err) - end) - - - it("denies request when both conf.allow_patterns and conf.deny_patterns are set, and pattern matches neither", function() - local ok, err = access_handler._execute(neither_allowed_nor_denied_chat_request, both_patterns_no_history) - - assert.is_falsy(ok) - assert.equal(err, "prompt doesn't match any allowed pattern") - end) - - - it("denies request when only conf.allow_patterns is set and previous chat history should not match", function() - local ok, err = access_handler._execute(general_chat_request_with_history, allow_patterns_with_history) - - assert.is_falsy(ok) - assert.equal(err, "prompt doesn't match any allowed pattern") - end) - - - it("denies request when only conf.deny_patterns is set and previous chat history should match", function() - local ok, err = access_handler._execute(general_chat_request_with_history, deny_patterns_with_history) - - assert.is_falsy(ok) - assert.equal(err, "prompt pattern is blocked") - end) - - end) - - - describe("completions operations", function() - - it("allows request when only conf.allow_patterns is set", function() - local ok, err = access_handler._execute(general_completions_request, allow_patterns_no_history) - - assert.is_truthy(ok) - assert.is_nil(err) - end) - - - it("allows request when only conf.deny_patterns is set, and pattern should not match", function() - local ok, err = access_handler._execute(general_completions_request, deny_patterns_no_history) - - assert.is_truthy(ok) - assert.is_nil(err) - end) - - - it("denies request when only conf.allow_patterns is set, and pattern should not match", function() - local ok, err = access_handler._execute(denied_completions_request, allow_patterns_no_history) - - assert.is_falsy(ok) - assert.equal(err, "prompt doesn't match any allowed pattern") + for _, request_type in ipairs({"chat", "completions"}) do + describe(request_type .. " operations", function() + it("allows a user request when nothing is set", function() + -- deny_pattern in this case should be made to have no effect + local ctx = create_request(request_type):append_message("user", "pattern") + local ok, err = access_handler._execute(ctx, { + }) + + assert.is_truthy(ok) + assert.is_nil(err) + end) + + for _, has_history in ipairs({false, request_type == "chat" and true or nil}) do + + describe("conf.allow_patterns is set", function() + for _, has_deny_patterns in ipairs({true, false}) do + + local test_description = has_history and " in history" or " only the last" + test_description = test_description .. (has_deny_patterns and ", conf.deny_patterns is also set" or "") + + it("allows a matching user request" .. test_description, function() + -- deny_pattern in this case should be made to have no effect + local ctx = create_request(request_type):append_message("user", "pattern") + + if has_history then + ctx:append_message("user", "no match") + end + local ok, err = access_handler._execute(ctx, { + allow_patterns = { + "pa..ern" + }, + deny_patterns = has_deny_patterns and {"deny match"} or nil, + allow_all_conversation_history = not has_history, + }) + + assert.is_truthy(ok) + assert.is_nil(err) + end) + + it("denies an unmatched user request" .. test_description, function() + -- deny_pattern in this case should be made to have no effect + local ctx = create_request(request_type):append_message("user", "no match") + + if has_history then + ctx:append_message("user", "no match") + else + -- if we are ignoring history, actually put a matched message in history to test edge case + ctx:append_message("user", "pattern"):append_message("user", "no match") + end + + local ok, err = access_handler._execute(ctx, { + allow_patterns = { + "pa..ern" + }, + deny_patterns = has_deny_patterns and {"deny match"} or nil, + allow_all_conversation_history = not has_history, + }) + + assert.is_falsy(ok) + assert.equal("prompt doesn't match any allowed pattern", err) + end) + + end -- for _, has_deny_patterns in ipairs({true, false}) do + end) + + describe("conf.deny_patterns is set", function() + for _, has_allow_patterns in ipairs({true, false}) do + + local test_description = has_history and " in history" or " only the last" + test_description = test_description .. (has_allow_patterns and ", conf.allow_patterns is also set" or "") + + it("denies a matching user request" .. test_description, function() + -- allow_pattern in this case should be made to have no effect + local ctx = create_request(request_type):append_message("user", "pattern") + + if has_history then + ctx:append_message("user", "no match") + end + local ok, err = access_handler._execute(ctx, { + deny_patterns = { + "pa..ern" + }, + allow_patterns = has_allow_patterns and {"allow match"} or nil, + allow_all_conversation_history = not has_history, + }) + + assert.is_falsy(ok) + assert.equal("prompt pattern is blocked", err) + end) + + it("allows unmatched user request" .. test_description, function() + -- allow_pattern in this case should be made to have no effect + local ctx = create_request(request_type):append_message("user", "allow match") + + if has_history then + ctx:append_message("user", "no match") + else + -- if we are ignoring history, actually put a matched message in history to test edge case + ctx:append_message("user", "pattern"):append_message("user", "allow match") + end + + local ok, err = access_handler._execute(ctx, { + deny_patterns = { + "pa..ern" + }, + allow_patterns = has_allow_patterns and {"allow match"} or nil, + allow_all_conversation_history = not has_history, + }) + + assert.is_truthy(ok) + assert.is_nil(err) + end) + end -- for for _, has_allow_patterns in ipairs({true, false}) do + end) + + end -- for _, has_history in ipairs({true, false}) do end) - - - it("denies request when only conf.deny_patterns is set, and pattern should match", function() - local ok, err = access_handler._execute(denied_completions_request, deny_patterns_no_history) - - assert.is_falsy(ok) - assert.equal("prompt pattern is blocked", err) - end) - - - it("denies request when both conf.allow_patterns and conf.deny_patterns are set, and pattern matches neither", function() - local ok, err = access_handler._execute(neither_allowed_nor_denied_completions_request, both_patterns_no_history) - - assert.is_falsy(ok) - assert.equal(err, "prompt doesn't match any allowed pattern") - end) - - end) + end -- for _, request_type in ipairs({"chat", "completions"}) do end)