Skip to content

Commit

Permalink
refactor(pdk): simplify private RL pdk
Browse files Browse the repository at this point in the history
  • Loading branch information
ADD-SP committed Jun 17, 2024
1 parent 9c0aad4 commit 194734c
Show file tree
Hide file tree
Showing 4 changed files with 238 additions and 257 deletions.
268 changes: 17 additions & 251 deletions kong/pdk/private/rate_limiting.lua
Original file line number Diff line number Diff line change
@@ -1,86 +1,14 @@
local table_new = require("table.new")
local buffer = require("string.buffer")
local table_clear = require("table.clear")

local type = type
local pairs = pairs
local assert = assert
local tostring = tostring
local resp_header = ngx.header

local tablex_keys = require("pl.tablex").keys

local RL_LIMIT = "RateLimit-Limit"
local RL_REMAINING = "RateLimit-Remaining"
local RL_RESET = "RateLimit-Reset"
local RETRY_AFTER = "Retry-After"


-- determine the number of pre-allocated fields at runtime
local max_fields_n = 4
local buf = buffer.new(64)

local LIMIT_BY = {
second = {
limit = "X-RateLimit-Limit-Second",
remain = "X-RateLimit-Remaining-Second",
limit_segment_0 = "X-",
limit_segment_1 = "RateLimit-Limit-",
limit_segment_3 = "-Second",
remain_segment_0 = "X-",
remain_segment_1 = "RateLimit-Remaining-",
remain_segment_3 = "-Second",
},
minute = {
limit = "X-RateLimit-Limit-Minute",
remain = "X-RateLimit-Remaining-Minute",
limit_segment_0 = "X-",
limit_segment_1 = "RateLimit-Limit-",
limit_segment_3 = "-Minute",
remain_segment_0 = "X-",
remain_segment_1 = "RateLimit-Remaining-",
remain_segment_3 = "-Minute",
},
hour = {
limit = "X-RateLimit-Limit-Hour",
remain = "X-RateLimit-Remaining-Hour",
limit_segment_0 = "X-",
limit_segment_1 = "RateLimit-Limit-",
limit_segment_3 = "-Hour",
remain_segment_0 = "X-",
remain_segment_1 = "RateLimit-Remaining-",
remain_segment_3 = "-Hour",
},
day = {
limit = "X-RateLimit-Limit-Day",
remain = "X-RateLimit-Remaining-Day",
limit_segment_0 = "X-",
limit_segment_1 = "RateLimit-Limit-",
limit_segment_3 = "-Day",
remain_segment_0 = "X-",
remain_segment_1 = "RateLimit-Remaining-",
remain_segment_3 = "-Day",
},
month = {
limit = "X-RateLimit-Limit-Month",
remain = "X-RateLimit-Remaining-Month",
limit_segment_0 = "X-",
limit_segment_1 = "RateLimit-Limit-",
limit_segment_3 = "-Month",
remain_segment_0 = "X-",
remain_segment_1 = "RateLimit-Remaining-",
remain_segment_3 = "-Month",
},
year = {
limit = "X-RateLimit-Limit-Year",
remain = "X-RateLimit-Remaining-Year",
limit_segment_0 = "X-",
limit_segment_1 = "RateLimit-Limit-",
limit_segment_3 = "-Year",
remain_segment_0 = "X-",
remain_segment_1 = "RateLimit-Remaining-",
remain_segment_3 = "-Year",
},
}

local _M = {}

Expand Down Expand Up @@ -114,201 +42,39 @@ local function _get_or_create_rl_ctx(ngx_ctx)
end


function _M.set_basic_limit(ngx_ctx, limit, remaining, reset)
local rl_ctx = _get_or_create_rl_ctx(ngx_ctx or ngx.ctx)

assert(
type(limit) == "number",
"arg #2 `limit` for `set_basic_limit` must be a number"
)
assert(
type(remaining) == "number",
"arg #3 `remaining` for `set_basic_limit` must be a number"
)
function _M.store_response_header(ngx_ctx, key, value)
assert(
type(reset) == "number",
"arg #4 `reset` for `set_basic_limit` must be a number"
type(key) == "string",
"arg #2 `key` for function `store_response_header` must be a string"
)

rl_ctx[RL_LIMIT] = limit
rl_ctx[RL_REMAINING] = remaining
rl_ctx[RL_RESET] = reset
end

function _M.set_retry_after(ngx_ctx, reset)
local rl_ctx = _get_or_create_rl_ctx(ngx_ctx or ngx.ctx)

assert(
type(reset) == "number",
"arg #2 `reset` for `set_retry_after` must be a number"
)

rl_ctx[RETRY_AFTER] = reset
end

function _M.set_limit_by(ngx_ctx, limit_by, limit, remaining)
local rl_ctx = _get_or_create_rl_ctx(ngx_ctx or ngx.ctx)

assert(
type(limit_by) == "string",
"arg #2 `limit_by` for `set_limit_by` must be a string"
)
assert(
type(limit) == "number",
"arg #3 `limit` for `set_limit_by` must be a number"
)
local value_type = type(value)
assert(
type(remaining) == "number",
"arg #4 `remaining` for `set_limit_by` must be a number"
value_type == "string" or value_type == "number",
"arg #3 `value` for function `store_response_header` must be a string or a number"
)

limit_by = LIMIT_BY[limit_by]
assert(limit_by, "invalid limit_by")

rl_ctx[limit_by.limit] = limit
rl_ctx[limit_by.remain] = remaining
local rl_ctx = _get_or_create_rl_ctx(ngx_ctx)
rl_ctx[key] = value
end

function _M.set_limit_by_with_identifier(ngx_ctx, limit_by, limit, remaining, id_seg_1, id_seg_2)
local rl_ctx = _get_or_create_rl_ctx(ngx_ctx or ngx.ctx)

function _M.get_stored_response_header(ngx_ctx, key)
assert(
type(limit_by) == "string",
"arg #2 `limit_by` for `set_limit_by_with_identifier` must be a string"
)
assert(
type(limit) == "number",
"arg #3 `limit` for `set_limit_by_with_identifier` must be a number"
)
assert(
type(remaining) == "number",
"arg #4 `remaining` for `set_limit_by_with_identifier` must be a number"
)

local id_seg_1_typ = type(id_seg_1)
local id_seg_2_typ = type(id_seg_2)
assert(
id_seg_1_typ == "nil" or id_seg_1_typ == "string",
"arg #5 `id_seg_1` for `set_limit_by_with_identifier` must be a string or nil"
)
assert(
id_seg_2_typ == "nil" or id_seg_2_typ == "string",
"arg #6 `id_seg_2` for `set_limit_by_with_identifier` must be a string or nil"
type(key) == "string",
"arg #2 `key` for function `get_stored_response_header` must be a string"
)

limit_by = LIMIT_BY[limit_by]
if not limit_by then
local valid_limit_bys = tablex_keys(LIMIT_BY)
local msg = string.format(
"arg #2 `limit_by` for `set_limit_by_with_identifier` must be one of: %s",
table.concat(valid_limit_bys, ", ")
)
error(msg)
if not _has_rl_ctx(ngx_ctx) then
return nil
end

id_seg_1 = id_seg_1 or ""
id_seg_2 = id_seg_2 or ""

-- construct the key like X-<id_seg_1>-RateLimit-Limit-<id_seg_2>-<limit_by>
local limit_key = buf:reset():put(
limit_by.limit_segment_0,
id_seg_1,
limit_by.limit_segment_1,
id_seg_2,
limit_by.limit_segment_3
):get()

-- construct the key like X-<id_seg_1>-RateLimit-Remaining-<id_seg_2>-<limit_by>
local remain_key = buf:reset():put(
limit_by.remain_segment_0,
id_seg_1,
limit_by.remain_segment_1,
id_seg_2,
limit_by.remain_segment_3
):get()

rl_ctx[limit_key] = limit
rl_ctx[remain_key] = remaining
end

function _M.get_basic_limit(ngx_ctx)
local rl_ctx = _get_rl_ctx(ngx_ctx or ngx.ctx)
return rl_ctx[RL_LIMIT], rl_ctx[RL_REMAINING], rl_ctx[RL_RESET]
end

function _M.get_retry_after(ngx_ctx)
local rl_ctx = _get_rl_ctx(ngx_ctx or ngx.ctx)
return rl_ctx[RETRY_AFTER]
end

function _M.get_limit_by(ngx_ctx, limit_by)
local rl_ctx = _get_rl_ctx(ngx_ctx or ngx.ctx)

assert(
type(limit_by) == "string",
"arg #2 `limit_by` for `get_limit_by` must be a string"
)

limit_by = LIMIT_BY[limit_by]
assert(limit_by, "invalid limit_by")

return rl_ctx[limit_by.limit], rl_ctx[limit_by.remain]
local rl_ctx = _get_rl_ctx(ngx_ctx)
return rl_ctx[key]
end

function _M.get_limit_by_with_identifier(ngx_ctx, limit_by, id_seg_1, id_seg_2)
local rl_ctx = _get_rl_ctx(ngx_ctx or ngx.ctx)

assert(
type(limit_by) == "string",
"arg #2 `limit_by` for `get_limit_by_with_identifier` must be a string"
)

local id_seg_1_typ = type(id_seg_1)
local id_seg_2_typ = type(id_seg_2)
assert(
id_seg_1_typ == "nil" or id_seg_1_typ == "string",
"arg #3 `id_seg_1` for `get_limit_by_with_identifier` must be a string or nil"
)
assert(
id_seg_2_typ == "nil" or id_seg_2_typ == "string",
"arg #4 `id_seg_2` for `get_limit_by_with_identifier` must be a string or nil"
)

limit_by = LIMIT_BY[limit_by]
if not limit_by then
local valid_limit_bys = tablex_keys(LIMIT_BY)
local msg = string.format(
"arg #2 `limit_by` for `get_limit_by_with_identifier` must be one of: %s",
table.concat(valid_limit_bys, ", ")
)
error(msg)
end

id_seg_1 = id_seg_1 or ""
id_seg_2 = id_seg_2 or ""

-- construct the key like X-<id_seg_1>-RateLimit-Limit-<id_seg_2>-<limit_by>
local limit_key = buf:reset():put(
limit_by.limit_segment_0,
id_seg_1,
limit_by.limit_segment_1,
id_seg_2,
limit_by.limit_segment_3
):get()

-- construct the key like X-<id_seg_1>-RateLimit-Remaining-<id_seg_2>-<limit_by>
local remain_key = buf:reset():put(
limit_by.remain_segment_0,
id_seg_1,
limit_by.remain_segment_1,
id_seg_2,
limit_by.remain_segment_3
):get()

return rl_ctx[limit_key], rl_ctx[remain_key]
end

function _M.set_response_headers(ngx_ctx)
function _M.apply_response_headers(ngx_ctx)
if not _has_rl_ctx(ngx_ctx) then
return
end
Expand Down
10 changes: 9 additions & 1 deletion kong/plugins/response-ratelimiting/access.lua
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
local policies = require "kong.plugins.response-ratelimiting.policies"
local timestamp = require "kong.tools.timestamp"
local pdk_private_rl = require "kong.pdk.private.rate_limiting"


local kong = kong
Expand All @@ -9,6 +10,10 @@ local error = error
local tostring = tostring


local pdk_rl_store_response_header = pdk_private_rl.store_response_header
local pdk_rl_apply_response_headers = pdk_private_rl.apply_response_headers


local EMPTY = {}
local HTTP_TOO_MANY_REQUESTS = 429
local RATELIMIT_REMAINING = "X-RateLimit-Remaining"
Expand Down Expand Up @@ -84,6 +89,7 @@ function _M.execute(conf)
end

-- Append usage headers to the upstream request. Also checks "block_on_first_violation".
local ngx_ctx = ngx.ctx
for k in pairs(conf.limits) do
local remaining
for _, lv in pairs(usage[k]) do
Expand All @@ -97,9 +103,11 @@ function _M.execute(conf)
end
end

kong.service.request.set_header(RATELIMIT_REMAINING .. "-" .. k, remaining)
pdk_rl_store_response_header(ngx_ctx, RATELIMIT_REMAINING .. "-" .. k, remaining)
end

pdk_rl_apply_response_headers(ngx_ctx)

kong.ctx.plugin.usage = usage -- For later use
end

Expand Down
17 changes: 12 additions & 5 deletions kong/plugins/response-ratelimiting/header_filter.lua
Original file line number Diff line number Diff line change
Expand Up @@ -13,8 +13,12 @@ local math_max = math.max

local strip = kong_string.strip
local split = kong_string.split
local pdk_rl_set_response_headers = pdk_private_rl.set_response_headers
local pdk_rl_set_limit_by_with_identifier = pdk_private_rl.set_limit_by_with_identifier
local pdk_rl_store_response_header = pdk_private_rl.store_response_header
local pdk_rl_apply_response_headers = pdk_private_rl.apply_response_headers


local RATELIMIT_LIMIT = "X-RateLimit-Limit"
local RATELIMIT_REMAINING = "X-RateLimit-Remaining"


local function parse_header(header_value, limits)
Expand Down Expand Up @@ -68,8 +72,12 @@ function _M.execute(conf)
for limit_name in pairs(usage) do
for period_name, lv in pairs(usage[limit_name]) do
if not conf.hide_client_headers then
local limit_hdr = RATELIMIT_LIMIT .. "-" .. limit_name .. "-" .. period_name
local remain_hdr = RATELIMIT_REMAINING .. "-" .. limit_name .. "-" .. period_name
local remain = math_max(0, lv.remaining - (increments[limit_name] and increments[limit_name] or 0))
pdk_rl_set_limit_by_with_identifier(ngx_ctx, period_name, lv.limit, remain, nil, limit_name)

pdk_rl_store_response_header(ngx_ctx, limit_hdr, lv.limit)
pdk_rl_store_response_header(ngx_ctx, remain_hdr, remain)
end

if increments[limit_name] and increments[limit_name] > 0 and lv.remaining <= 0 then
Expand All @@ -78,8 +86,7 @@ function _M.execute(conf)
end
end

-- Set rate-limiting response headers
pdk_rl_set_response_headers(ngx_ctx)
pdk_rl_apply_response_headers(ngx_ctx)

kong.response.clear_header(conf.header_name)

Expand Down
Loading

0 comments on commit 194734c

Please sign in to comment.