diff --git a/changelog/unreleased/kong/feat-ai-prompt-guard-all-roles.yml b/changelog/unreleased/kong/feat-ai-prompt-guard-all-roles.yml new file mode 100644 index 00000000000..5a1d9ca0cee --- /dev/null +++ b/changelog/unreleased/kong/feat-ai-prompt-guard-all-roles.yml @@ -0,0 +1,3 @@ +message: "**AI-Prompt-Guard**: add `match_all_roles` option to allow match all roles in addition to `user`." +type: feature +scope: Plugin diff --git a/kong/clustering/compat/removed_fields.lua b/kong/clustering/compat/removed_fields.lua index a91b8a6cecd..50ff3fc2080 100644 --- a/kong/clustering/compat/removed_fields.lua +++ b/kong/clustering/compat/removed_fields.lua @@ -174,6 +174,7 @@ return { "max_request_body_size", }, ai_prompt_guard = { + "match_all_roles", "max_request_body_size", }, ai_prompt_template = { diff --git a/kong/plugins/ai-prompt-guard/handler.lua b/kong/plugins/ai-prompt-guard/handler.lua index fcb37f54ed1..b2aab78dbc7 100644 --- a/kong/plugins/ai-prompt-guard/handler.lua +++ b/kong/plugins/ai-prompt-guard/handler.lua @@ -26,6 +26,7 @@ local execute do -- @tparam table request The deserialized JSON body of the request -- @tparam table conf The plugin configuration -- @treturn[1] table The decorated request (same table, content updated) + -- @treturn[2] nil -- @treturn[2] string The error message function execute(request, conf) local collected_prompts @@ -44,7 +45,7 @@ local execute do if type(v.role) ~= "string" then return nil, bad_format_error end - if v.role == "user" then + if v.role == "user" or conf.match_all_roles then if type(v.content) ~= "string" then return nil, bad_format_error end diff --git a/kong/plugins/ai-prompt-guard/schema.lua b/kong/plugins/ai-prompt-guard/schema.lua index 0864696cd29..2629f07154d 100644 --- a/kong/plugins/ai-prompt-guard/schema.lua +++ b/kong/plugins/ai-prompt-guard/schema.lua @@ -36,8 +36,12 @@ return { type = "integer", default = 8 * 1024, gt = 0, - description = "max allowed body size allowed to be introspected",} - }, + description = "max allowed body size allowed to be introspected" } }, + { match_all_roles = { + description = "If true, will match all roles in addition to 'user' role in conversation history.", + type = "boolean", + required = true, + default = false } }, } } } @@ -45,6 +49,10 @@ return { entity_checks = { { at_least_one_of = { "config.allow_patterns", "config.deny_patterns" }, - } + }, + { conditional = { + if_field = "config.match_all_roles", if_match = { eq = true }, + then_field = "config.allow_all_conversation_history", then_match = { eq = false }, + } }, } } diff --git a/spec/02-integration/09-hybrid_mode/09-config-compat_spec.lua b/spec/02-integration/09-hybrid_mode/09-config-compat_spec.lua index d5a3c9626c2..b6cd68b9861 100644 --- a/spec/02-integration/09-hybrid_mode/09-config-compat_spec.lua +++ b/spec/02-integration/09-hybrid_mode/09-config-compat_spec.lua @@ -649,6 +649,30 @@ describe("CP/DP config compat transformations #" .. strategy, function() -- cleanup admin.plugins:remove({ id = ai_response_transformer.id }) end) + + it("[ai-prompt-guard] sets unsupported match_all_roles to nil or defaults", function() + -- [[ 3.8.x ]] -- + local ai_prompt_guard = admin.plugins:insert { + name = "ai-prompt-guard", + enabled = true, + config = { + allow_patterns = { "a" }, + allow_all_conversation_history = false, + match_all_roles = true, + max_request_body_size = 8192, + }, + } + -- ]] + + local expected = cycle_aware_deep_copy(ai_prompt_guard) + expected.config.match_all_roles = nil + expected.config.max_request_body_size = nil + + do_assert(uuid(), "3.7.0", expected) + + -- cleanup + admin.plugins:remove({ id = ai_prompt_guard.id }) + end) end) describe("www-authenticate header in plugins (realm config)", function() diff --git a/spec/03-plugins/42-ai-prompt-guard/00-config_spec.lua b/spec/03-plugins/42-ai-prompt-guard/00-config_spec.lua index 103ed45840a..7bc8169e157 100644 --- a/spec/03-plugins/42-ai-prompt-guard/00-config_spec.lua +++ b/spec/03-plugins/42-ai-prompt-guard/00-config_spec.lua @@ -84,4 +84,22 @@ describe(PLUGIN_NAME .. ": (schema)", function() assert.same({ config = {allow_patterns = "length must be at most 10" }}, err) end) + it("allow_all_conversation_history needs to be false if match_all_roles is set to true", function() + local config = { + allow_patterns = { "wat" }, + allow_all_conversation_history = true, + match_all_roles = true, + } + + local ok, err = validate(config) + + assert.is_falsy(ok) + assert.not_nil(err) + assert.same({ + ["@entity"] = { + [1] = 'failed conditional validation given value of field \'config.match_all_roles\'' }, + ["config"] = { + ["allow_all_conversation_history"] = 'value must be false' }}, err) + end) + end) diff --git a/spec/03-plugins/42-ai-prompt-guard/01-unit_spec.lua b/spec/03-plugins/42-ai-prompt-guard/01-unit_spec.lua index ad72a693ee3..eab961081e6 100644 --- a/spec/03-plugins/42-ai-prompt-guard/01-unit_spec.lua +++ b/spec/03-plugins/42-ai-prompt-guard/01-unit_spec.lua @@ -71,6 +71,7 @@ describe(PLUGIN_NAME .. ": (unit)", function() for _, request_type in ipairs({"chat", "completions"}) do + describe(request_type .. " operations", function() it("allows a user request when nothing is set", function() -- deny_pattern in this case should be made to have no effect @@ -82,7 +83,13 @@ describe(PLUGIN_NAME .. ": (unit)", function() assert.is_nil(err) end) + -- only chat has history + -- match_all_roles require history for _, has_history in ipairs({false, request_type == "chat" and true or nil}) do + for _, match_all_roles in ipairs({false, has_history and true or nil}) do + + -- we only have user or not user, so testing "assistant" is not necessary + local role = match_all_roles and "system" or "user" describe("conf.allow_patterns is set", function() for _, has_deny_patterns in ipairs({true, false}) do @@ -92,7 +99,7 @@ describe(PLUGIN_NAME .. ": (unit)", function() it("allows a matching user request" .. test_description, function() -- deny_pattern in this case should be made to have no effect - local ctx = create_request(request_type):append_message("user", "pattern") + local ctx = create_request(request_type):append_message(role, "pattern") if has_history then ctx:append_message("user", "no match") @@ -103,6 +110,7 @@ describe(PLUGIN_NAME .. ": (unit)", function() }, deny_patterns = has_deny_patterns and {"deny match"} or nil, allow_all_conversation_history = not has_history, + match_all_roles = match_all_roles, }) assert.is_truthy(ok) @@ -117,7 +125,7 @@ describe(PLUGIN_NAME .. ": (unit)", function() ctx:append_message("user", "no match") else -- if we are ignoring history, actually put a matched message in history to test edge case - ctx:append_message("user", "pattern"):append_message("user", "no match") + ctx:append_message(role, "pattern"):append_message("user", "no match") end local ok, err = access_handler._execute(ctx, { @@ -126,6 +134,7 @@ describe(PLUGIN_NAME .. ": (unit)", function() }, deny_patterns = has_deny_patterns and {"deny match"} or nil, allow_all_conversation_history = not has_history, + match_all_roles = match_all_roles, }) assert.is_falsy(ok) @@ -143,7 +152,7 @@ describe(PLUGIN_NAME .. ": (unit)", function() it("denies a matching user request" .. test_description, function() -- allow_pattern in this case should be made to have no effect - local ctx = create_request(request_type):append_message("user", "pattern") + local ctx = create_request(request_type):append_message(role, "pattern") if has_history then ctx:append_message("user", "no match") @@ -162,13 +171,13 @@ describe(PLUGIN_NAME .. ": (unit)", function() it("allows unmatched user request" .. test_description, function() -- allow_pattern in this case should be made to have no effect - local ctx = create_request(request_type):append_message("user", "allow match") + local ctx = create_request(request_type):append_message(role, "allow match") if has_history then ctx:append_message("user", "no match") else -- if we are ignoring history, actually put a matched message in history to test edge case - ctx:append_message("user", "pattern"):append_message("user", "allow match") + ctx:append_message(role, "pattern"):append_message(role, "allow match") end local ok, err = access_handler._execute(ctx, { @@ -185,6 +194,7 @@ describe(PLUGIN_NAME .. ": (unit)", function() end -- for for _, has_allow_patterns in ipairs({true, false}) do end) + end -- for _, match_all_role in ipairs(false, true)) do end -- for _, has_history in ipairs({true, false}) do end) end -- for _, request_type in ipairs({"chat", "completions"}) do