Skip to content

Commit

Permalink
perf(router/atc): lazy generate and cache field visit functions (#12378)
Browse files Browse the repository at this point in the history
  • Loading branch information
chronolaw committed Jan 24, 2024
1 parent ce4d6af commit f848e65
Show file tree
Hide file tree
Showing 2 changed files with 117 additions and 80 deletions.
57 changes: 13 additions & 44 deletions kong/router/atc.lua
Original file line number Diff line number Diff line change
Expand Up @@ -35,8 +35,6 @@ local ngx_ERR = ngx.ERR
local check_select_params = utils.check_select_params
local get_service_info = utils.get_service_info
local route_match_stat = utils.route_match_stat
local get_cache_key = fields.get_cache_key
local fill_atc_context = fields.fill_atc_context


local DEFAULT_MATCH_LRUCACHE_SIZE = utils.DEFAULT_MATCH_LRUCACHE_SIZE
Expand All @@ -57,35 +55,6 @@ local CACHED_SCHEMA
local HTTP_SCHEMA
local STREAM_SCHEMA
do
local HTTP_FIELDS = {

["String"] = {"net.protocol", "tls.sni",
"http.method", "http.host",
"http.path",
"http.path.segments.*",
"http.headers.*",
"http.queries.*",
},

["Int"] = {"net.src.port", "net.dst.port",
},

["IpAddr"] = {"net.src.ip", "net.dst.ip",
},
}

local STREAM_FIELDS = {

["String"] = {"net.protocol", "tls.sni",
},

["Int"] = {"net.src.port", "net.dst.port",
},

["IpAddr"] = {"net.src.ip", "net.dst.ip",
},
}

local function generate_schema(fields)
local s = schema.new()

Expand All @@ -99,8 +68,8 @@ do
end

-- used by validation
HTTP_SCHEMA = generate_schema(HTTP_FIELDS)
STREAM_SCHEMA = generate_schema(STREAM_FIELDS)
HTTP_SCHEMA = generate_schema(fields.HTTP_FIELDS)
STREAM_SCHEMA = generate_schema(fields.STREAM_FIELDS)

-- used by running router
CACHED_SCHEMA = is_http and HTTP_SCHEMA or STREAM_SCHEMA
Expand Down Expand Up @@ -227,14 +196,12 @@ local function new_from_scratch(routes, get_exp_and_priority)
yield(true, phase)
end

local fields = inst:get_fields()

return setmetatable({
context = context.new(CACHED_SCHEMA),
fields = fields.new(inst:get_fields()),
router = inst,
routes = routes_t,
services = services_t,
fields = fields,
updated_at = new_updated_at,
rebuilding = false,
}, _MT)
Expand Down Expand Up @@ -318,9 +285,7 @@ local function new_from_previous(routes, get_exp_and_priority, old_router)
yield(true, phase)
end

local fields = inst:get_fields()

old_router.fields = fields
old_router.fields = fields.new(inst:get_fields())
old_router.updated_at = new_updated_at
old_router.rebuilding = false

Expand Down Expand Up @@ -436,7 +401,7 @@ function _M:matching(params)

self.context:reset()

local c, err = fill_atc_context(self.context, self.fields, params)
local c, err = self.fields:fill_atc_context(self.context, params)

if not c then
return nil, err
Expand Down Expand Up @@ -503,6 +468,8 @@ end


function _M:exec(ctx)
local fields = self.fields

local req_uri = ctx and ctx.request_uri or var.request_uri
local req_host = var.http_host

Expand All @@ -520,7 +487,7 @@ function _M:exec(ctx)
CACHE_PARAMS.uri = req_uri
CACHE_PARAMS.host = req_host

local cache_key = get_cache_key(self.fields, CACHE_PARAMS)
local cache_key = fields:get_cache_key(CACHE_PARAMS)

-- cache lookup

Expand Down Expand Up @@ -580,7 +547,7 @@ function _M:matching(params)

self.context:reset()

local c, err = fill_atc_context(self.context, self.fields, params)
local c, err = self.fields:fill_atc_context(self.context, params)
if not c then
return nil, err
end
Expand Down Expand Up @@ -633,6 +600,8 @@ end


function _M:exec(ctx)
local fields = self.fields

-- cache key calculation

if not CACHE_PARAMS then
Expand All @@ -642,7 +611,7 @@ function _M:exec(ctx)

CACHE_PARAMS:clear()

local cache_key = get_cache_key(self.fields, CACHE_PARAMS, ctx)
local cache_key = fields:get_cache_key(CACHE_PARAMS, ctx)

-- cache lookup

Expand Down Expand Up @@ -681,7 +650,7 @@ function _M:exec(ctx)

-- preserve_host logic, modify cache result
if match_t.route.preserve_host then
match_t.upstream_host = fields.get_value("tls.sni", CACHE_PARAMS)
match_t.upstream_host = fields:get_value("tls.sni", CACHE_PARAMS)
end
end

Expand Down
140 changes: 104 additions & 36 deletions kong/router/fields.lua
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ local type = type
local ipairs = ipairs
local assert = assert
local tonumber = tonumber
local setmetatable = setmetatable
local tb_sort = table.sort
local tb_concat = table.concat
local replace_dashes_lower = require("kong.tools.string").replace_dashes_lower
Expand All @@ -22,6 +23,37 @@ local HTTP_HEADERS_PREFIX = "http.headers."
local HTTP_QUERIES_PREFIX = "http.queries."


local HTTP_FIELDS = {

["String"] = {"net.protocol", "tls.sni",
"http.method", "http.host",
"http.path",
"http.path.segments.*",
"http.headers.*",
"http.queries.*",
},

["Int"] = {"net.src.port", "net.dst.port",
},

["IpAddr"] = {"net.src.ip", "net.dst.ip",
},
}


local STREAM_FIELDS = {

["String"] = {"net.protocol", "tls.sni",
},

["Int"] = {"net.src.port", "net.dst.port",
},

["IpAddr"] = {"net.src.ip", "net.dst.ip",
},
}


local FIELDS_FUNCS = {
-- http.*

Expand Down Expand Up @@ -164,6 +196,10 @@ else -- stream
end -- is_http


-- stream subsystem need not to generate func
local get_field_accessor = function(funcs, field) end


if is_http then

local fmt = string.format
Expand Down Expand Up @@ -197,38 +233,61 @@ if is_http then
end


setmetatable(FIELDS_FUNCS, {
__index = function(_, field)
get_field_accessor = function(funcs, field)
local f = funcs[field]
if f then
return f
end

local prefix = field:sub(1, PREFIX_LEN)

-- generate for http.headers.*

if prefix == HTTP_HEADERS_PREFIX then
return function(params)
local name = field:sub(PREFIX_LEN + 1)

f = function(params)
if not params.headers then
params.headers = get_http_params(get_headers, "headers", "lua_max_req_headers")
end

return params.headers[field:sub(PREFIX_LEN + 1)]
end
return params.headers[name]
end -- f

elseif prefix == HTTP_QUERIES_PREFIX then
return function(params)
funcs[field] = f
return f
end -- if prefix == HTTP_HEADERS_PREFIX

-- generate for http.queries.*

if prefix == HTTP_QUERIES_PREFIX then
local name = field:sub(PREFIX_LEN + 1)

f = function(params)
if not params.queries then
params.queries = get_http_params(get_uri_args, "queries", "lua_max_uri_args")
end

return params.queries[field:sub(PREFIX_LEN + 1)]
end
return params.queries[name]
end -- f

funcs[field] = f
return f
end -- if prefix == HTTP_QUERIES_PREFIX

elseif field:sub(1, HTTP_SEGMENTS_PREFIX_LEN) == HTTP_SEGMENTS_PREFIX then
return function(params)
-- generate for http.path.segments.*

if field:sub(1, HTTP_SEGMENTS_PREFIX_LEN) == HTTP_SEGMENTS_PREFIX then
local range = field:sub(HTTP_SEGMENTS_PREFIX_LEN + 1)

f = function(params)
if not params.segments then
HTTP_SEGMENTS_REG_CTX.pos = 2 -- reset ctx, skip first '/'
params.segments = re_split(params.uri, "/", "jo", HTTP_SEGMENTS_REG_CTX)
end

local segments = params.segments

local range = field:sub(HTTP_SEGMENTS_PREFIX_LEN + 1)
local value = segments[range]

if value then
Expand Down Expand Up @@ -276,31 +335,47 @@ if is_http then
segments[range] = value

return value
end
end -- f

end -- if prefix
funcs[field] = f
return f
end -- if field:sub(1, HTTP_SEGMENTS_PREFIX_LEN)

-- others return nil
end
})

end -- is_http


local function get_value(field, params, ctx)
local func = FIELDS_FUNCS[field]
local _M = {}
local _MT = { __index = _M, }


if not func then -- unknown field
error("unknown router matching schema field: " .. field)
end -- if func
_M.HTTP_FIELDS = HTTP_FIELDS
_M.STREAM_FIELDS = STREAM_FIELDS


function _M.new(fields)
return setmetatable({
fields = fields,
funcs = {},
}, _MT)
end


function _M:get_value(field, params, ctx)
local func = FIELDS_FUNCS[field] or
get_field_accessor(self.funcs, field)

assert(func, "unknown router matching schema field: " .. field)

return func(params, ctx)
end


local function fields_visitor(fields, params, ctx, cb)
for _, field in ipairs(fields) do
local value = get_value(field, params, ctx)
function _M:fields_visitor(params, ctx, cb)
for _, field in ipairs(self.fields) do
local value = self:get_value(field, params, ctx)

local res, err = cb(field, value)
if not res then
Expand All @@ -316,11 +391,11 @@ end
local str_buf = buffer.new(64)


local function get_cache_key(fields, params, ctx)
function _M:get_cache_key(params, ctx)
str_buf:reset()

local res =
fields_visitor(fields, params, ctx, function(field, value)
self:fields_visitor(params, ctx, function(field, value)

-- these fields were not in cache key
if field == "net.protocol" then
Expand Down Expand Up @@ -361,11 +436,11 @@ local function get_cache_key(fields, params, ctx)
end


local function fill_atc_context(context, fields, params)
function _M:fill_atc_context(context, params)
local c = context

local res, err =
fields_visitor(fields, params, nil, function(field, value)
self:fields_visitor(params, nil, function(field, value)

local prefix = field:sub(1, PREFIX_LEN)

Expand Down Expand Up @@ -404,7 +479,7 @@ local function fill_atc_context(context, fields, params)
end


local function _set_ngx(mock_ngx)
function _M._set_ngx(mock_ngx)
if mock_ngx.var then
var = mock_ngx.var
end
Expand All @@ -425,11 +500,4 @@ local function _set_ngx(mock_ngx)
end


return {
get_value = get_value,

get_cache_key = get_cache_key,
fill_atc_context = fill_atc_context,

_set_ngx = _set_ngx,
}
return _M

0 comments on commit f848e65

Please sign in to comment.