diff --git a/kong/router/atc.lua b/kong/router/atc.lua index 492f361de9b0..8d2776386131 100644 --- a/kong/router/atc.lua +++ b/kong/router/atc.lua @@ -35,8 +35,6 @@ local ngx_ERR = ngx.ERR local check_select_params = utils.check_select_params local get_service_info = utils.get_service_info local route_match_stat = utils.route_match_stat -local get_cache_key = fields.get_cache_key -local fill_atc_context = fields.fill_atc_context local DEFAULT_MATCH_LRUCACHE_SIZE = utils.DEFAULT_MATCH_LRUCACHE_SIZE @@ -57,35 +55,6 @@ local CACHED_SCHEMA local HTTP_SCHEMA local STREAM_SCHEMA do - local HTTP_FIELDS = { - - ["String"] = {"net.protocol", "tls.sni", - "http.method", "http.host", - "http.path", - "http.path.segments.*", - "http.headers.*", - "http.queries.*", - }, - - ["Int"] = {"net.src.port", "net.dst.port", - }, - - ["IpAddr"] = {"net.src.ip", "net.dst.ip", - }, - } - - local STREAM_FIELDS = { - - ["String"] = {"net.protocol", "tls.sni", - }, - - ["Int"] = {"net.src.port", "net.dst.port", - }, - - ["IpAddr"] = {"net.src.ip", "net.dst.ip", - }, - } - local function generate_schema(fields) local s = schema.new() @@ -99,8 +68,8 @@ do end -- used by validation - HTTP_SCHEMA = generate_schema(HTTP_FIELDS) - STREAM_SCHEMA = generate_schema(STREAM_FIELDS) + HTTP_SCHEMA = generate_schema(fields.HTTP_FIELDS) + STREAM_SCHEMA = generate_schema(fields.STREAM_FIELDS) -- used by running router CACHED_SCHEMA = is_http and HTTP_SCHEMA or STREAM_SCHEMA @@ -227,14 +196,12 @@ local function new_from_scratch(routes, get_exp_and_priority) yield(true, phase) end - local fields = inst:get_fields() - return setmetatable({ context = context.new(CACHED_SCHEMA), + fields = fields.new(inst:get_fields()), router = inst, routes = routes_t, services = services_t, - fields = fields, updated_at = new_updated_at, rebuilding = false, }, _MT) @@ -318,9 +285,7 @@ local function new_from_previous(routes, get_exp_and_priority, old_router) yield(true, phase) end - local fields = inst:get_fields() - - old_router.fields = fields + old_router.fields = fields.new(inst:get_fields()) old_router.updated_at = new_updated_at old_router.rebuilding = false @@ -436,7 +401,7 @@ function _M:matching(params) self.context:reset() - local c, err = fill_atc_context(self.context, self.fields, params) + local c, err = self.fields:fill_atc_context(self.context, params) if not c then return nil, err @@ -503,6 +468,8 @@ end function _M:exec(ctx) + local fields = self.fields + local req_uri = ctx and ctx.request_uri or var.request_uri local req_host = var.http_host @@ -520,7 +487,7 @@ function _M:exec(ctx) CACHE_PARAMS.uri = req_uri CACHE_PARAMS.host = req_host - local cache_key = get_cache_key(self.fields, CACHE_PARAMS) + local cache_key = fields:get_cache_key(CACHE_PARAMS) -- cache lookup @@ -580,7 +547,7 @@ function _M:matching(params) self.context:reset() - local c, err = fill_atc_context(self.context, self.fields, params) + local c, err = self.fields:fill_atc_context(self.context, params) if not c then return nil, err end @@ -633,6 +600,8 @@ end function _M:exec(ctx) + local fields = self.fields + -- cache key calculation if not CACHE_PARAMS then @@ -642,7 +611,7 @@ function _M:exec(ctx) CACHE_PARAMS:clear() - local cache_key = get_cache_key(self.fields, CACHE_PARAMS, ctx) + local cache_key = fields:get_cache_key(CACHE_PARAMS, ctx) -- cache lookup @@ -681,7 +650,7 @@ function _M:exec(ctx) -- preserve_host logic, modify cache result if match_t.route.preserve_host then - match_t.upstream_host = fields.get_value("tls.sni", CACHE_PARAMS) + match_t.upstream_host = fields:get_value("tls.sni", CACHE_PARAMS) end end diff --git a/kong/router/fields.lua b/kong/router/fields.lua index 21dfc244f14a..3608459f556f 100644 --- a/kong/router/fields.lua +++ b/kong/router/fields.lua @@ -5,6 +5,7 @@ local type = type local ipairs = ipairs local assert = assert local tonumber = tonumber +local setmetatable = setmetatable local tb_sort = table.sort local tb_concat = table.concat local replace_dashes_lower = require("kong.tools.string").replace_dashes_lower @@ -22,6 +23,37 @@ local HTTP_HEADERS_PREFIX = "http.headers." local HTTP_QUERIES_PREFIX = "http.queries." +local HTTP_FIELDS = { + + ["String"] = {"net.protocol", "tls.sni", + "http.method", "http.host", + "http.path", + "http.path.segments.*", + "http.headers.*", + "http.queries.*", + }, + + ["Int"] = {"net.src.port", "net.dst.port", + }, + + ["IpAddr"] = {"net.src.ip", "net.dst.ip", + }, +} + + +local STREAM_FIELDS = { + + ["String"] = {"net.protocol", "tls.sni", + }, + + ["Int"] = {"net.src.port", "net.dst.port", + }, + + ["IpAddr"] = {"net.src.ip", "net.dst.ip", + }, +} + + local FIELDS_FUNCS = { -- http.* @@ -164,6 +196,10 @@ else -- stream end -- is_http +-- stream subsystem need not to generate func +local get_field_accessor = function(funcs, field) end + + if is_http then local fmt = string.format @@ -197,30 +233,54 @@ if is_http then end - setmetatable(FIELDS_FUNCS, { - __index = function(_, field) + get_field_accessor = function(funcs, field) + local f = funcs[field] + if f then + return f + end + local prefix = field:sub(1, PREFIX_LEN) + -- generate for http.headers.* + if prefix == HTTP_HEADERS_PREFIX then - return function(params) + local name = field:sub(PREFIX_LEN + 1) + + f = function(params) if not params.headers then params.headers = get_http_params(get_headers, "headers", "lua_max_req_headers") end - return params.headers[field:sub(PREFIX_LEN + 1)] - end + return params.headers[name] + end -- f - elseif prefix == HTTP_QUERIES_PREFIX then - return function(params) + funcs[field] = f + return f + end -- if prefix == HTTP_HEADERS_PREFIX + + -- generate for http.queries.* + + if prefix == HTTP_QUERIES_PREFIX then + local name = field:sub(PREFIX_LEN + 1) + + f = function(params) if not params.queries then params.queries = get_http_params(get_uri_args, "queries", "lua_max_uri_args") end - return params.queries[field:sub(PREFIX_LEN + 1)] - end + return params.queries[name] + end -- f + + funcs[field] = f + return f + end -- if prefix == HTTP_QUERIES_PREFIX - elseif field:sub(1, HTTP_SEGMENTS_PREFIX_LEN) == HTTP_SEGMENTS_PREFIX then - return function(params) + -- generate for http.path.segments.* + + if field:sub(1, HTTP_SEGMENTS_PREFIX_LEN) == HTTP_SEGMENTS_PREFIX then + local range = field:sub(HTTP_SEGMENTS_PREFIX_LEN + 1) + + f = function(params) if not params.segments then HTTP_SEGMENTS_REG_CTX.pos = 2 -- reset ctx, skip first '/' params.segments = re_split(params.uri, "/", "jo", HTTP_SEGMENTS_REG_CTX) @@ -228,7 +288,6 @@ if is_http then local segments = params.segments - local range = field:sub(HTTP_SEGMENTS_PREFIX_LEN + 1) local value = segments[range] if value then @@ -276,31 +335,47 @@ if is_http then segments[range] = value return value - end + end -- f - end -- if prefix + funcs[field] = f + return f + end -- if field:sub(1, HTTP_SEGMENTS_PREFIX_LEN) -- others return nil end - }) end -- is_http -local function get_value(field, params, ctx) - local func = FIELDS_FUNCS[field] +local _M = {} +local _MT = { __index = _M, } + - if not func then -- unknown field - error("unknown router matching schema field: " .. field) - end -- if func +_M.HTTP_FIELDS = HTTP_FIELDS +_M.STREAM_FIELDS = STREAM_FIELDS + + +function _M.new(fields) + return setmetatable({ + fields = fields, + funcs = {}, + }, _MT) +end + + +function _M:get_value(field, params, ctx) + local func = FIELDS_FUNCS[field] or + get_field_accessor(self.funcs, field) + + assert(func, "unknown router matching schema field: " .. field) return func(params, ctx) end -local function fields_visitor(fields, params, ctx, cb) - for _, field in ipairs(fields) do - local value = get_value(field, params, ctx) +function _M:fields_visitor(params, ctx, cb) + for _, field in ipairs(self.fields) do + local value = self:get_value(field, params, ctx) local res, err = cb(field, value) if not res then @@ -316,11 +391,11 @@ end local str_buf = buffer.new(64) -local function get_cache_key(fields, params, ctx) +function _M:get_cache_key(params, ctx) str_buf:reset() local res = - fields_visitor(fields, params, ctx, function(field, value) + self:fields_visitor(params, ctx, function(field, value) -- these fields were not in cache key if field == "net.protocol" then @@ -361,11 +436,11 @@ local function get_cache_key(fields, params, ctx) end -local function fill_atc_context(context, fields, params) +function _M:fill_atc_context(context, params) local c = context local res, err = - fields_visitor(fields, params, nil, function(field, value) + self:fields_visitor(params, nil, function(field, value) local prefix = field:sub(1, PREFIX_LEN) @@ -404,7 +479,7 @@ local function fill_atc_context(context, fields, params) end -local function _set_ngx(mock_ngx) +function _M._set_ngx(mock_ngx) if mock_ngx.var then var = mock_ngx.var end @@ -425,11 +500,4 @@ local function _set_ngx(mock_ngx) end -return { - get_value = get_value, - - get_cache_key = get_cache_key, - fill_atc_context = fill_atc_context, - - _set_ngx = _set_ngx, -} +return _M