diff --git a/conf/app.json b/conf/app.example.json similarity index 61% rename from conf/app.json rename to conf/app.example.json index e234b94..0c04924 100644 --- a/conf/app.json +++ b/conf/app.example.json @@ -1,17 +1,27 @@ { "appName": "codo-gateway", - "env": "prod", + "env": "test", "etcd": { "http_host": "http://127.0.0.1:2379", - "data_prefix": "/codo/gw/" + "data_prefix": "/my/gw/" }, "jwt_auth": { "key": "auth_key", - "token_secret": "yxLxQ0jGtmAhRT4RMMxRrXf" + "token_secret": "xxxxxx" }, "codo_rbac": { "key": "auth_key", - "token_secret": "yxLxQ0jGtmAhRT4RMMxRrXf" + "token_secret": "xxxxxx" + }, + "sso2internal": { + "sso_token_secret": "xxxxxx", + "sso_jwt_key": "sso_token", + "internal_token_secret": "xxxxxx", + "internal_jwt_key": "auth_key" + }, + "mfa": { + "mfa_secret": "xxxxxx", + "mfa_key": "mfa_key" }, "plugins": [ "default_plugin", @@ -25,35 +35,38 @@ "limit-req", "referer-restriction", "ip-restriction", - "cors_plugin" + "cors_plugin", + "sso2internal" ], "plugins_config": { "redis-logger": { "host": "127.0.0.1", "port": 6379, - "auth_pwd": "PuiJUnyKCSLhhRT4RMMxR", - "db": 11, + "auth_pwd": "1234567", + "db": 1, "alive_time": 604800, "channel": "gw", "full_log": "no" } }, "admin": { - "jwt_secret": "yxLxQ0jGtmAhRT4RMMxRrXf", + "jwt_secret": "xxxxxx", "account": { "admin": { "password": "tainiubile", "info": { - "roles": ["admin"], + "roles": [ + "admin" + ], "introduction": "I am a super administrator", - "avatar": "https://wpimg.wallstcn.com/f778738c-e4f8-4870-b634-56703b4acafe.gif", + "avatar": "https://xxx.com/1.gif", "name": "管理员" } } } }, "tokens": { - "e09d6153f1c15395144794GtmAhRT4": { + "xxx": { "desc": "系统默认 api token" } } diff --git a/conf/common/proxy.conf b/conf/common/proxy.conf index c1fc777..714b87c 100644 --- a/conf/common/proxy.conf +++ b/conf/common/proxy.conf @@ -1,3 +1,4 @@ +resolver 8.8.8.8; proxy_http_version 1.1; proxy_set_header Upgrade $http_upgrade; proxy_set_header Connection "upgrade"; diff --git a/conf/nginx.conf b/conf/nginx.conf index 0b18b4d..97d163c 100644 --- a/conf/nginx.conf +++ b/conf/nginx.conf @@ -2,12 +2,12 @@ pid logs/nginx.pid; # 开发时可以设置成 debug -error_log logs/error.log error; +error_log logs/error.log info; master_process on; # 生产环境设置成 auto -worker_processes 4; +worker_processes 2; # 最大文件描述符数量 worker_rlimit_nofile 20240; diff --git a/conf/servers/gateway.conf b/conf/servers/gateway.conf index 5e04e44..10b1bc5 100644 --- a/conf/servers/gateway.conf +++ b/conf/servers/gateway.conf @@ -34,14 +34,22 @@ server { location / { include common/proxy.conf; set $target_service_name ''; - set $origin_uri $uri; + set $origin_uri $request_uri; + # set $origin_uri $uri; + set $proxy_path ''; + set $biz_schema 'http'; + set $biz_domain 'backend_server'; lua_code_cache on; rewrite_by_lua_block { local app = require("app") app.http_rewrite() } include common/phase.conf; - proxy_pass http://backend_server; + + proxy_ssl_name $biz_domain; + proxy_ssl_server_name on; + proxy_ssl_session_reuse off; + proxy_pass $biz_schema://backend_server$proxy_path; } location @grpc_pass { diff --git a/deps/share/lua/5.1/resty/etcd/utils.lua b/deps/share/lua/5.1/resty/etcd/utils.lua index c9766e3..e94c7a1 100644 --- a/deps/share/lua/5.1/resty/etcd/utils.lua +++ b/deps/share/lua/5.1/resty/etcd/utils.lua @@ -84,12 +84,18 @@ end local ngx_log = ngx.log local ngx_ERR = ngx.ERR local ngx_INFO = ngx.INFO +local ngx_DEBUG = ngx.DEBUG local function log_error(...) return ngx_log(ngx_ERR, ...) end _M.log_error = log_error +local function log_debug( ... ) + return ngx_log(ngx_DEBUG, ...) +end +_M.log_debug = log_debug + local function log_info( ... ) return ngx_log(ngx_INFO, ...) end diff --git a/deps/share/lua/5.1/resty/etcd/v3.lua b/deps/share/lua/5.1/resty/etcd/v3.lua index 5d4c296..0c9fc5a 100644 --- a/deps/share/lua/5.1/resty/etcd/v3.lua +++ b/deps/share/lua/5.1/resty/etcd/v3.lua @@ -31,7 +31,7 @@ local mt = { __index = _M } local refresh_jwt_token local function _request_uri(self, method, uri, opts, timeout, ignore_auth) - utils.log_info("v3 request uri: ", uri, ", timeout: ", timeout) + utils.log_debug("v3 request uri: ", uri, ", timeout: ", timeout) local body if opts and opts.body and tab_nkeys(opts.body) > 0 then @@ -230,12 +230,12 @@ function refresh_jwt_token(self, timeout) end if self.last_refresh_jwt_err then - utils.log_info("v3 refresh jwt last err: ", self.last_refresh_jwt_err) + utils.log_debug("v3 refresh jwt last err: ", self.last_refresh_jwt_err) return nil, self.last_refresh_jwt_err end -- something unexpected happened, try again - utils.log_info("v3 try auth after waiting, timeout: ", timeout) + utils.log_debug("v3 try auth after waiting, timeout: ", timeout) end self.last_refresh_jwt_err = nil @@ -331,7 +331,7 @@ local function set(self, key, val, attr) -- get if res.status < 300 then -- TODO(optimize): delay json encode - utils.log_info("v3 set body: ", encode_json(res.body)) + utils.log_debug("v3 set body: ", encode_json(res.body)) end return res @@ -560,7 +560,7 @@ local function request_chunk(self, method, scheme, host, port, path, opts, timeo query = query, headers = headers, }) - utils.log_info("http request method: ", method, " path: ", path, + utils.log_debug("http request method: ", method, " path: ", path, " body: ", body, " query: ", query) if not res then diff --git a/deps/share/lua/5.1/resty/radixtree.lua b/deps/share/lua/5.1/resty/radixtree.lua index d89d199..da1633a 100644 --- a/deps/share/lua/5.1/resty/radixtree.lua +++ b/deps/share/lua/5.1/resty/radixtree.lua @@ -164,7 +164,17 @@ end local ngx_log = ngx.log local ngx_INFO = ngx.INFO + local ngx_DEBUG = ngx.DEBUG local ngx_ERR = ngx.ERR + +local function log_debug(...) + if cur_level and ngx_DEBUG > cur_level then + return + end + + return ngx_log(ngx_DEBUG, ...) +end + local function log_info(...) if cur_level and ngx_INFO > cur_level then return @@ -229,7 +239,7 @@ local function insert_route(self, opts) self.match_data[self.match_data_index] = {opts} radix.radix_tree_insert(self.tree, path, #path, self.match_data_index) - log_info("insert route path: ", path, " dataprt: ", self.match_data_index) + log_debug("insert route path: ", path, " dataprt: ", self.match_data_index) return true end @@ -340,7 +350,7 @@ function pre_insert_route(self, path, route) route_opts.path = path end - log_info("path: ", route_opts.path, " operator: ", route_opts.path_op) + log_debug("path: ", route_opts.path, " operator: ", route_opts.path_op) route_opts.metadata = route.metadata route_opts.handler = route.handler @@ -472,7 +482,7 @@ local function compare_param(req_path, route, opts) end local pat, names = fetch_pat(route.path_org) - log_info("pcre pat: ", pat) + log_debug("pcre pat: ", pat) if #names == 0 then return true end @@ -517,7 +527,7 @@ local function match_route_opts(route, opts, args) if matcher_ins then local ok, err = matcher_ins:match(opts.remote_addr) if err then - log_info("failed to match ip: ", err) + log_debug("failed to match ip: ", err) return false end if not ok then @@ -525,7 +535,7 @@ local function match_route_opts(route, opts, args) end end - -- log_info("route.hosts: ", type(route.hosts)) + -- log_debug("route.hosts: ", type(route.hosts)) if route.hosts then local matched = false @@ -548,7 +558,7 @@ local function match_route_opts(route, opts, args) end end - log_info("hosts match: ", matched) + log_debug("hosts match: ", matched) if not matched then return false end @@ -597,8 +607,8 @@ local function _match_from_routes(routes, path, opts, args) local opts_matched_exists = (opts.matched ~= nil) for _, route in ipairs(routes) do if match_route_opts(route, opts, args) then - -- log_info("matched route: ", require("cjson").encode(route)) - -- log_info("matched path: ", path) + -- log_debug("matched route: ", require("cjson").encode(route)) + -- log_debug("matched path: ", path) if compare_param(path, route, opts) then if opts_matched_exists then opts.matched._path = route.path_org diff --git a/docker-compose.yml b/docker-compose.yml index 4f12c6a..e605e08 100644 --- a/docker-compose.yml +++ b/docker-compose.yml @@ -1,15 +1,8 @@ -version: '3.8' +version: '3.3' services: - gateway: - image: tianmen2_image:latest - restart: unless-stopped - volumes: - - /data/gw/logs/:/usr/local/openresty/nginx/logs/ - - /sys/fs/cgroup:/sys/fs/cgroup - ports: - - "8886:8888" - - "11006:11000" - networks: - - mynetwork -networks: - mynetwork: + gateway: + restart: unless-stopped + build: . + ports: + - "8888:8888" + - "11000:11000" diff --git a/gateway/admin/init.lua b/gateway/admin/init.lua index 006cc41..09528fa 100644 --- a/gateway/admin/init.lua +++ b/gateway/admin/init.lua @@ -72,13 +72,13 @@ local function check_api_token() end local tokens = config_get("tokens") if not tokens then - log.info("no api token settings") + log.debug("no api token settings") return false end if not tokens[token] then return false end - log.info("api token auth: ", token) + log.debug("api token auth: ", token) return true end @@ -159,8 +159,7 @@ function _M.init_worker() mapping(res.apis) end router = radixtree.new(req_mapping) - log.info("admin init") - + log.debug("admin init") end function _M.http_admin() diff --git a/gateway/app/config.lua b/gateway/app/config.lua index e61fe4d..92b0054 100644 --- a/gateway/app/config.lua +++ b/gateway/app/config.lua @@ -50,13 +50,18 @@ function _M.get_jwt_auth() return get("jwt_auth") end +function _M.get_sso2internal() + return get("sso2internal") +end + +function _M.get_mfa() + return get("mfa") +end + function _M.get_codo_rbac() return get("codo_rbac") end -function _M.get_auth_rbac() - return get("auth_rbac") -end function _M.get_plugins_config(plugin_name) local _plugins_config = get("plugins_config") diff --git a/gateway/app/core/balancer.lua b/gateway/app/core/balancer.lua index b7d6cf0..b27af1f 100644 --- a/gateway/app/core/balancer.lua +++ b/gateway/app/core/balancer.lua @@ -26,9 +26,11 @@ local upstream_type_cache = ngx.shared.upstream_type_cache local _M = {} local balancer_cache +local upstream_node_cache do - balancer_cache = lrucache.new({count = 1024}) + balancer_cache = lrucache.new({ count = 1024 }) + upstream_node_cache = lrucache.new({ count = 1024 }) end -- end do function _M.set_upstream_type(service_name, type) @@ -53,7 +55,7 @@ local balancer_types = { -- 刷新服务节点缓存 local function refresh(service_name, nodes) local type = get_upstream_type(service_name) - log.info("refresh balancer: ", json.delay_encode({service_name, type, nodes})) + log.debug("refresh balancer: ", json.delay_encode({ service_name, type, nodes })) local balancer_up = balancer_types[type](nodes) return balancer_cache:set(service_name, balancer_up) end @@ -62,14 +64,16 @@ _M.refresh = refresh -- 通过服务名获取 balancer 缓存 local function get(service_name) - log.error("get service balancer: ", json.delay_encode(balancer_cache:get(service_name))) + log.debug("get service balancer: ", json.delay_encode(balancer_cache:get(service_name))) return balancer_cache:get(service_name) end -- 更新服务节点 -function _M.set(service_name, upstream, weight) - weight = weight or 1 +function _M.set(service_name, upstream, node) + local weight = node.weight or 1 local balancer_up = get(service_name) + log.debug("set service balancer: ", service_name, ", ", upstream, ", ", weight) + upstream_node_cache:set(upstream, node) if not balancer_up then local nodes = { [upstream] = weight @@ -77,7 +81,6 @@ function _M.set(service_name, upstream, weight) refresh(service_name, nodes) return end - log.error("set service balancer: ", service_name, ", ", upstream, ", ", weight) balancer_up:set(upstream, weight or 1) end @@ -88,18 +91,25 @@ function _M.find(service_name) log.error("can not found service balancer: ", service_name) return nil end - return balancer_up:find() + local upstream = balancer_up:find() + if not upstream then + log.error("can not found any service node: ", service_name) + return nil + end + local node = upstream_node_cache:get(upstream) + return node end -- 删除服务节点 function _M.delete(service_name, upstream) local balancer_up = get(service_name) if balancer_up then - log.error("remove service balancer: ", service_name, " - ", upstream," - ", tab_nkeys(balancer_up.nodes)) - if tab_nkeys(balancer_up.nodes) == 1 then - balancer_cache:delete(service_name) + log.error("remove service balancer: ", service_name, " - ", upstream, " - ", tab_nkeys(balancer_up.nodes)) + balancer_up:delete(upstream) + upstream_node_cache:delete(upstream) + if tab_nkeys(balancer_up.nodes) == 0 then + balancer_cache:delete(service_name) else - balancer_up:delete(upstream) end end end diff --git a/gateway/app/core/ctx.lua b/gateway/app/core/ctx.lua index 93eff1e..a8ffd06 100644 --- a/gateway/app/core/ctx.lua +++ b/gateway/app/core/ctx.lua @@ -16,14 +16,17 @@ -- local require = require local log = require("app.core.log") +local json = require("app.core.json") local dispatcher = require("app.core.dispatcher") local call_utils = require("app.utils.call_utils") local pcall = pcall local pairs = pairs +local ipairs = ipairs local ngx = ngx local config = require("app.config") local tab_nkeys = require("table.nkeys") local router = require("app.core.router") +local core_table = require("app.core.table") local _M = {} @@ -41,18 +44,18 @@ local function install_plugins() if func then func() end - log.info("install plugin:", plugin.name) + log.debug("install plugin:", plugin.name) end end end function _M.init() - log.info("ctx init") + log.debug("ctx init") install_plugins() end function _M.init_worker() - log.info("ctx init worker") + log.debug("ctx init worker") call_utils.call(plugins, "do_in_init_worker") end @@ -62,6 +65,7 @@ function _M.get_dispatcher() local ctx_dispatcher = ngx_ctx.dispatcher if not ctx_dispatcher then + --log.info("ngx.var.uri === ", ngx.var.uri) local route = router.match(ngx.var.uri) if not route then route = { @@ -76,13 +80,25 @@ function _M.get_dispatcher() end if not route or tab_nkeys(route.plugins) == 0 then - route.plugins = {"default"} + route.plugins = { "default" } end + -- 构造数组 local dispatcher_plugins = {} - for _, plugin_name in pairs(route.plugins) do - dispatcher_plugins[plugin_name] = plugins[plugin_name] + for _, plugin_name in ipairs(route.plugins) do + local plugin = plugins[plugin_name] + log.debug("ngx.var.uri === ", ngx.var.uri, " dispatcher_plugins === ", plugin_name) + core_table.insert(dispatcher_plugins, plugin) end + + -- 升序排序 + core_table.sort(dispatcher_plugins, function(a, b) + local ap = a.priority or 9999 + local bp = b.priority or 9999 + return ap < bp + end) + + -- 创建dispatcher ctx_dispatcher = dispatcher:new(dispatcher_plugins, route) ngx_ctx.dispatcher = ctx_dispatcher end diff --git a/gateway/app/core/dispatcher.lua b/gateway/app/core/dispatcher.lua index 528d4d0..d940473 100644 --- a/gateway/app/core/dispatcher.lua +++ b/gateway/app/core/dispatcher.lua @@ -24,37 +24,37 @@ local _M = {} local mt = { __index = _M } function _M.do_in_rewrite(self) - call_utils.call(self.plugins, "do_in_rewrite", self.route) + call_utils.call_alphabeta(self.plugins, "do_in_rewrite", self.route) end function _M.do_in_access(self) - call_utils.call(self.plugins, "do_in_access", self.route) + call_utils.call_alphabeta(self.plugins, "do_in_access", self.route) end function _M.do_in_content(self) - call_utils.call(self.plugins, "do_in_content", self.route) + call_utils.call_alphabeta(self.plugins, "do_in_content", self.route) end function _M.do_in_balancer(self) - call_utils.call(self.plugins, "do_in_balancer", self.route) + call_utils.call_alphabeta(self.plugins, "do_in_balancer", self.route) end function _M.do_in_header_filter(self) - call_utils.call(self.plugins, "do_in_header_filter", self.route) + call_utils.call_alphabeta(self.plugins, "do_in_header_filter", self.route) end function _M.do_in_body_filter(self) - call_utils.call(self.plugins, "do_in_body_filter", self.route) + call_utils.call_alphabeta(self.plugins, "do_in_body_filter", self.route) end function _M.do_in_log(self) - call_utils.call(self.plugins, "do_in_log", self.route) + call_utils.call_alphabeta(self.plugins, "do_in_log", self.route) end function _M.new(self, plugins, route) self.plugins = plugins self.route = route - log.info("new dispatcher ==> ", cjson.encode(route)) + log.debug("new dispatcher ==> ", cjson.encode(route)) return setmetatable(self, mt) end diff --git a/gateway/app/core/json.lua b/gateway/app/core/json.lua index fcd4941..f52e9c7 100644 --- a/gateway/app/core/json.lua +++ b/gateway/app/core/json.lua @@ -89,7 +89,7 @@ local delay_tab = ) -- this is a non-thread safe implementation --- it works well with log, eg: log.info(..., json.delay_encode({...})) +-- it works well with log, eg: log.debug(..., json.delay_encode({...})) function _M.delay_encode(data, force) delay_tab.data = data delay_tab.force = force @@ -97,10 +97,10 @@ function _M.delay_encode(data, force) end function _M.decode_json_file(path) - log.info("decode json path: ", path) + log.debug("decode json path: ", path) local jsonFile = io.open(path, "r") local jsonStr = jsonFile:read("*a") - log.info("decode json data: ", jsonStr) + log.debug("decode json data: ", jsonStr) jsonFile:close() return json_decode(jsonStr) end diff --git a/gateway/app/core/lrucache.lua b/gateway/app/core/lrucache.lua index b5bd862..1577945 100644 --- a/gateway/app/core/lrucache.lua +++ b/gateway/app/core/lrucache.lua @@ -94,7 +94,7 @@ function _M.fetch_cache(self, key, invalid_stale, create_val_fun, ...) elapsed, err = lock:lock(key_s) if not elapsed then - log.info("failed to acquire the lock: ", err) + log.debug("failed to acquire the lock: ", err) -- 没有获得锁,则从过期缓存返回数据 if not invalid_stale and stale_obj then return get_val(stale_obj) @@ -122,7 +122,7 @@ function _M.capacity(self) end function _M.get_keys(self, max_count) - log.info("get keys: ", type(self.lru.get_keys)) + log.debug("get keys: ", type(self.lru.get_keys)) return self.lru:get_keys(max_count) end diff --git a/gateway/app/core/lrucache2.lua b/gateway/app/core/lrucache2.lua index 54c4e65..d969bf9 100644 --- a/gateway/app/core/lrucache2.lua +++ b/gateway/app/core/lrucache2.lua @@ -109,7 +109,7 @@ local function new_lru_fun(opts) end local key_s = tostring(key) - log.info("try to lock with key ", key_s) + log.debug("try to lock with key ", key_s) local elapsed, err = lock:lock(key_s) if not elapsed then @@ -120,7 +120,7 @@ local function new_lru_fun(opts) nil, key, version) if cache_obj then lock:unlock() - log.info("unlock with key ", key_s) + log.debug("unlock with key ", key_s) return cache_obj.val end @@ -129,7 +129,7 @@ local function new_lru_fun(opts) lru_obj:set(key, {val = obj, ver = version}, item_ttl) end lock:unlock() - log.info("unlock with key ", key_s) + log.debug("unlock with key ", key_s) return obj, err end diff --git a/gateway/app/core/request.lua b/gateway/app/core/request.lua index dc2dee5..694e5b6 100644 --- a/gateway/app/core/request.lua +++ b/gateway/app/core/request.lua @@ -179,7 +179,7 @@ function _M.get_body(max_size, ctx) return nil end - log.info("attempt to read body from file: ", file_name) + log.debug("attempt to read body from file: ", file_name) if max_size then local size, err = lfs.attributes (file_name, "size") diff --git a/gateway/app/core/router.lua b/gateway/app/core/router.lua index 60ed9b1..ad12234 100644 --- a/gateway/app/core/router.lua +++ b/gateway/app/core/router.lua @@ -32,7 +32,7 @@ do end -- end do local function create_rx(routes) - log.info("routes: ", json.delay_encode(routes)) + log.debug("routes: ", json.delay_encode(routes)) local mapping = {} for _, route in ipairs(routes) do core_table.insert( @@ -43,7 +43,7 @@ local function create_rx(routes) } ) end - log.info("mapping: ", json.delay_encode(mapping)) + log.debug("mapping: ", json.delay_encode(mapping)) return radixtree.new(mapping) end @@ -51,7 +51,7 @@ end function _M.match(url) local rx = radix_cache:get(rx_key, false) local route = rx:match(url) - log.info("match route: ", json.delay_encode({url, route})) + log.debug("match route: ", json.delay_encode({url, route})) return route end diff --git a/gateway/app/core/table.lua b/gateway/app/core/table.lua index 9088d09..1c5e982 100644 --- a/gateway/app/core/table.lua +++ b/gateway/app/core/table.lua @@ -33,6 +33,7 @@ local _M = { nkeys = nkeys, insert = table.insert, concat = table.concat, + sort = table.sort, clone = require("table.clone"), } diff --git a/gateway/app/core/timer.lua b/gateway/app/core/timer.lua index 3e030b2..18b2aa9 100644 --- a/gateway/app/core/timer.lua +++ b/gateway/app/core/timer.lua @@ -75,14 +75,14 @@ local function callback_fun(self) if lock then local elapsed, err = lock:lock(name) if not elapsed then - log.info("timer[", name, "] failed to acquire the lock: ", err) + log.debug("timer[", name, "] failed to acquire the lock: ", err) if self.fail_sleep_time > 0 then sleep(self.fail_sleep_time) end return end end - log.info("timer[", name, "] start") + log.debug("timer[", name, "] start") local start_time = time.now() local ok, err = pcall(callback, self.ctx) if not ok then @@ -94,7 +94,7 @@ local function callback_fun(self) end local ms = time.now() - start_time - log.info("timer[", name, "] run finish, take ", string.format("%.2f", ms), "s") + log.debug("timer[", name, "] run finish, take ", string.format("%.2f", ms), "s") end end diff --git a/gateway/app/init.lua b/gateway/app/init.lua index e980204..8a07fb3 100644 --- a/gateway/app/init.lua +++ b/gateway/app/init.lua @@ -33,7 +33,7 @@ end function _M.http_init_worker() local ctx = require("app.core.ctx") local worker_id = ngx.worker.id() - log.info("init worker: ", worker_id) + log.debug("init worker: ", worker_id) ctx.init_worker() end diff --git a/gateway/app/my_core/sso_users.lua b/gateway/app/my_core/sso_users.lua new file mode 100644 index 0000000..12250ff --- /dev/null +++ b/gateway/app/my_core/sso_users.lua @@ -0,0 +1,45 @@ +local ipairs = ipairs +local log = require("app.core.log") +local json = require("app.core.json") +local core_table = require("app.core.table") +local lrucache = require("app.core.lrucache") +local radixtree = require("resty.radixtree") +local _M = {} + +local radix_cache +local rx_key = "sso_users.rx" + +do + radix_cache = lrucache.new({ count = 2048 }) +end -- end do + +local function create_rx(sso_users_list) + log.debug("[create_rx] sso_users_list ", json.delay_encode(sso_users_list, false)) + local mapping = {} + for _, data in ipairs(sso_users_list) do + core_table.insert(mapping, + { + paths = { data.email }, + metadata = data, + }) + end + -- log.error("mapping: ", json.delay_encode(mapping)) + return radixtree.new(mapping) +end + + +-- 匹配权限 +function _M.get_user(email) + local rx = radix_cache:get(rx_key, false) + local user_info = rx:match(email) + return user_info +end + + +-- 注册URI权限 +function _M.refresh(sso_users_list) + local rx = create_rx(sso_users_list) + radix_cache:set(rx_key, rx) +end + +return _M diff --git a/gateway/app/plugins/codo_rbac_plugin.lua b/gateway/app/plugins/codo_rbac_plugin.lua index 75ac223..f1b68d9 100644 --- a/gateway/app/plugins/codo_rbac_plugin.lua +++ b/gateway/app/plugins/codo_rbac_plugin.lua @@ -14,6 +14,7 @@ local config = require("app.config") local codo_rbac_dict = config.get_codo_rbac() local rbac_secret = codo_rbac_dict.token_secret local rbac_key = codo_rbac_dict.key +local log = require("app.core.log") local _M = { name = "CRBAC", @@ -28,13 +29,10 @@ local function decode_auth_token_verify(auth_token) return load_token end - function _M.do_in_init_worker() rbac_store.init() end - - function _M.do_in_access() local cookie, err = ck:new() @@ -96,6 +94,8 @@ function _M.do_in_access() end local uri = ngx.var.origin_uri + uri = string.gsub(uri, "%?.*", "") + log.debug("rbac uri =====", uri) local state = rbac_verify.match(app_code, uri, user_id, method) if state == true then return diff --git a/gateway/app/plugins/default_plugin.lua b/gateway/app/plugins/default_plugin.lua index a573935..877900c 100644 --- a/gateway/app/plugins/default_plugin.lua +++ b/gateway/app/plugins/default_plugin.lua @@ -24,7 +24,8 @@ local _M = { name = "default", desc = "默认插件", optional = false, - version = "v1.0" + version = "v1.0", + priority = 10000 } function _M.do_in_init() diff --git a/gateway/app/plugins/discovery_plugin.lua b/gateway/app/plugins/discovery_plugin.lua index 10dc394..0ca54b6 100644 --- a/gateway/app/plugins/discovery_plugin.lua +++ b/gateway/app/plugins/discovery_plugin.lua @@ -15,16 +15,19 @@ -- limitations under the License. -- local ngx = ngx +local core = require("app.core") local log = require("app.core.log") +local json = require("app.core.json") local route_store = require("app.store.route_store") local discovery_store = require("app.store.discovery_store") local resp = require("app.core.response") -local ngx_balancer = require ("ngx.balancer") +local ngx_balancer = require("ngx.balancer") local balancer = require("app.core.balancer") local _M = { name = "discovery", desc = "服务发现插件", + priority = 100, optional = true, version = "v1.0" } @@ -40,17 +43,36 @@ function _M.do_in_rewrite(route) local service_name = route.service_name var.target_service_name = service_name - local upstream = balancer.find(service_name) + local svc_node = balancer.find(service_name) + local upstream = svc_node.upstream if not upstream then log.error("can not find any service node") return resp.exit(ngx.HTTP_NOT_FOUND) end + + local dst_domain = svc_node.dst_domain + local dst_is_https = svc_node.dst_is_https + + if dst_domain ~= "" then + ngx.var.biz_domain = dst_domain + log.info(route.prefix, "set Host header: ", dst_domain) + core.request.set_header(ngx_ctx, "Host", dst_domain) + end + + if dst_is_https == "yes" then + log.info(route.prefix, "set https proxy") + ngx.var.biz_schema = "https" + end + + log.info("upstream==", upstream, " svc_node==", json.delay_encode(svc_node)) + ngx_ctx.upstream_server = upstream + ngx_ctx.upstream_node = svc_node end -function _M.do_in_balancer() +function _M.do_in_balancer(route) local ngx_ctx = ngx.ctx local server = ngx_ctx.upstream_server log.info("upstream server: ", server) diff --git a/gateway/app/plugins/kafka-logger.lua b/gateway/app/plugins/kafka-logger.lua index e60c595..cf2e992 100644 --- a/gateway/app/plugins/kafka-logger.lua +++ b/gateway/app/plugins/kafka-logger.lua @@ -19,25 +19,25 @@ local _M = { name = "kafka-logger", desc = "kafka日志", optional = true, + priority = 300, version = "v0.1" } - function _M.do_in_log(route) -- 记录日志操作 --- local method = ngx.req.get_method() --- local uri = ngx.var.uri --- local postargs = ngx.req.get_body_data() --str --- -- local postargs = ngx.req.get_post_args() --table --- local data = { --- username = user_info.username, --- nickname = user_info.nickname, --- login_ip = ngx.var.proxy_add_x_forwarded_for, --- method = method, --- uri = ngx.var.request_uri, --- data = postargs, --- time = os.date('%Y-%m-%d %H:%M:%S') --- } + -- local method = ngx.req.get_method() + -- local uri = ngx.var.uri + -- local postargs = ngx.req.get_body_data() --str + -- -- local postargs = ngx.req.get_post_args() --table + -- local data = { + -- username = user_info.username, + -- nickname = user_info.nickname, + -- login_ip = ngx.var.proxy_add_x_forwarded_for, + -- method = method, + -- uri = ngx.var.request_uri, + -- data = postargs, + -- time = os.date('%Y-%m-%d %H:%M:%S') + -- } local var = ngx.var local client_ip = str_utils.split(ngx.var.proxy_add_x_forwarded_for, ',')[1] @@ -64,15 +64,15 @@ function _M.do_in_log(route) start_time = ngx.req.start_time() * 1000, latency = (ngx.now() - ngx.req.start_time()) * 1000 } --- if conf.include_req_body then --- local body = ngx.req.get_body_data() --- if body then --- log.request.body = body --- end --- end + -- if conf.include_req_body then + -- local body = ngx.req.get_body_data() + -- if body then + -- log.request.body = body + -- end + -- end --- log = json.encode(log) --- ngx.log(ngx.ERR, 'log log--->>>>>>>>>>>>>', log) + -- log = json.encode(log) + -- ngx.log(ngx.ERR, 'log log--->>>>>>>>>>>>>', log) -- ngx.log(ngx.ERR, 'log data--->>>>>>>>>>>>>', log) end diff --git a/gateway/app/plugins/mfa.lua b/gateway/app/plugins/mfa.lua new file mode 100644 index 0000000..0ef17fe --- /dev/null +++ b/gateway/app/plugins/mfa.lua @@ -0,0 +1,101 @@ +local jwt = require "app.my_core.resty.jwt" +local sso_users = require "app.my_core.sso_users" +local sso_users_store = require "app.store.sso_users_store" +local user_info = ngx.shared.user_info +local resp = require("app.core.response") +local ck = require("resty.cookie") +local core = require("app.core") +local log = require("app.core.log") +local json = require("app.core.json") +local string = string +local ngx = ngx +local get_headers = ngx.req.get_headers +local pairs = pairs +local config = require("app.config") +local mfa_config = config.get_mfa() + +local _M = { + name = "mfa", + priority = 400, + desc = "MFA认证中间件", + optional = true, + version = "v1.0" +} + +local mfa_secret = mfa_config.mfa_secret +local mfa_key = mfa_config.mfa_key + +local function encode_jwt_token(token_secret, payload) + return jwt:sign( + token_secret, + { + header = { typ = "JWT", alg = "HS256" }, + payload = payload + } + ) +end + +local function decode_jwt_token(token_secret, jwt_token) + local load_token = jwt:verify(token_secret, jwt_token) + return load_token +end + +function _M.do_in_init_worker() + sso_users_store.init() +end + +function _M.do_in_access() + local ngx_ctx = ngx.ctx + local cookie = ck:new() + + -- 绕过GET + if ngx.req.get_method() == "GET" then + return + end + + + -- 绕过 PATCH + if ngx.req.get_method() == "PATCH" then + return + end + + -- 超管跳过认证 + if user_info.is_superuser then + return + end + + + -- 获取内部的 jwt + local auth_key = cookie:get(mfa_key) + + if auth_key == nil then + auth_key = get_headers()[mfa_key] + end + + if auth_key == nil then + local arg = ngx.req.get_uri_args() + if arg ~= nil then + for k, v in pairs(arg) do + if k == mfa_key then + auth_key = v + end + end + end + end + + if auth_key == nil then + return resp.exit(423, "MFA认证失败") + end + + local load_token = decode_jwt_token(mfa_secret, auth_key) + + -- 鉴定token是否正常 + if load_token.verified == false then + return resp.exit(423, "MFA认证失败") + end + + core.request.set_header(ngx_ctx, mfa_key, auth_key) + return +end + +return _M diff --git a/gateway/app/plugins/redis-logger.lua b/gateway/app/plugins/redis-logger.lua index 1815c65..20561db 100644 --- a/gateway/app/plugins/redis-logger.lua +++ b/gateway/app/plugins/redis-logger.lua @@ -11,10 +11,10 @@ local ngx = ngx local redis = require "resty.redis" redis.add_commands("xadd") -local core = require("app.core") +local core = require("app.core") local user_info = ngx.shared.user_info local batch_processor = require("app.utils.batch-processor") -local log_util = require("app.utils.log-util") +local log_util = require("app.utils.log-util") -- plugin config local plugin_name = "redis-logger" @@ -26,16 +26,14 @@ local stale_timer_running = false local timer_at = ngx.timer.at local buffers = {} - - local _M = { name = plugin_name, desc = "Redis日志", optional = true, + priority = 500, version = "v0.1" } - local function connect(conf) local red = redis:new() @@ -79,7 +77,7 @@ local function remove_stale_objects(premature) for key, batch in ipairs(buffers) do if #batch.entry_buffer.entries == 0 and #batch.batch_to_process == 0 then core.log.warn("removing batch processor stale object, conf: ", - core.json.delay_encode(key)) + core.json.delay_encode(key)) buffers[key] = nil end end @@ -99,7 +97,7 @@ local buffers_config = { } local function default_plugin_conf(conf) - local _conf = { + local _conf = { host = conf.host or '127.0.0.1', port = conf.port or 6379, auth_pwd = conf.auth_pwd or '123456', @@ -119,7 +117,7 @@ redis_config = default_plugin_conf(read_config) function _M.do_in_log(route) -- local entry = log_util.get_full_log(ngx, {include_req_body=true}) - if redis_config.full_log == 'no' and ngx.req.get_method() == "GET" then + if redis_config.full_log == 'no' and ngx.req.get_method() == "GET" then return end @@ -163,7 +161,7 @@ function _M.do_in_log(route) end -- red:xadd - for i,k in ipairs(entries[1]) do + for i, k in ipairs(entries[1]) do core.log.error(i, k) end local ok, _err = red:xadd( diff --git a/gateway/app/plugins/rewrite_plugin.lua b/gateway/app/plugins/rewrite_plugin.lua index 2e51cfd..e0efe7e 100644 --- a/gateway/app/plugins/rewrite_plugin.lua +++ b/gateway/app/plugins/rewrite_plugin.lua @@ -16,26 +16,33 @@ function _M.do_in_rewrite(route) local rewrite_url_regex = route.props.rewrite_url_regex if not rewrite_url_regex then - log.info("rewrite props[rewrite_url_regex] not set") + log.debug("rewrite props[rewrite_url_regex] not set") return end local rewrite_replace = route.props.rewrite_replace if not rewrite_url_regex then - log.info("rewrite props[rewrite_replace] not set") + log.debug("rewrite props[rewrite_replace] not set") return end - local uri = var.uri - - local target_uri, _, err = re.gsub(uri, rewrite_url_regex, rewrite_replace, "jo") + -- local uri = var.origin_uri + local uri = re.sub(var.origin_uri, "\\?.*", "") + local proxy_path, _, err = re.gsub(var.origin_uri, rewrite_url_regex, rewrite_replace, "jo") if err then - log.error("rewrite url error: ", err) + log.error("rewrite proxy_path error: ", err) + return + end + + local target_uri, _, err2 = re.gsub(uri, rewrite_url_regex, rewrite_replace, "jo") + if err2 then + log.error("rewrite url error: ", err2) return end - log.info("rewrite url ==> origin_uri: ", var.origin_uri, ", target_uri: ", target_uri) + log.error("rewrite url ==> origin_uri: ", var.origin_uri, ", target_uri: ", target_uri, ", proxy_path: ", proxy_path) req.set_uri(target_uri, false) + ngx.var.proxy_path = proxy_path end return _M diff --git a/gateway/app/plugins/sso2internal.lua b/gateway/app/plugins/sso2internal.lua new file mode 100644 index 0000000..9aec8e2 --- /dev/null +++ b/gateway/app/plugins/sso2internal.lua @@ -0,0 +1,125 @@ +local jwt = require "app.my_core.resty.jwt" +local sso_users = require "app.my_core.sso_users" +local sso_users_store = require "app.store.sso_users_store" +local user_info = ngx.shared.user_info +local resp = require("app.core.response") +local ck = require("resty.cookie") +local core = require("app.core") +local log = require("app.core.log") +local json = require("app.core.json") +local string = string +local ngx = ngx +local get_headers = ngx.req.get_headers +local pairs = pairs +local config = require("app.config") +local sso2internal_config = config.get_sso2internal() + +local _M = { + name = "sso2internal", + priority = 100, + desc = "将外部的 SSO TOKEN 转换成 内部的 AUTH KEY,用于接入外部系统", + optional = true, + version = "v1.0" +} + +local sso_token_secret = sso2internal_config.sso_token_secret +local sso_jwt_key = sso2internal_config.sso_jwt_key +local internal_token_secret = sso2internal_config.internal_token_secret +local internal_jwt_key = sso2internal_config.internal_jwt_key + +local function encode_jwt_token(token_secret, payload) + return jwt:sign( + token_secret, + { + header = { typ = "JWT", alg = "HS256" }, + payload = payload + } + ) +end + +local function decode_jwt_token(token_secret, jwt_token) + local load_token = jwt:verify(token_secret, jwt_token) + return load_token +end + +function _M.do_in_init_worker() + sso_users_store.init() +end + +function _M.do_in_access() + local ngx_ctx = ngx.ctx + local cookie = ck:new() + local current_time = ngx.time() + + -- 绕过WS + if ngx.req.get_method() == "GET" and get_headers()["upgrade"] == "websocket" then + return + end + + -- 获取内部的 jwt + local auth_key = cookie:get(internal_jwt_key) + + if auth_key == nil then + local arg = ngx.req.get_uri_args() + if arg ~= nil then + for k, v in pairs(arg) do + if k == internal_jwt_key then + auth_key = v + end + end + end + end + + -- 尝试获取 SSO TOKEN + if auth_key == nil then + local sso_jwt_token = cookie:get(sso_jwt_key) + -- 没有 SSO TOKEN 直接返回 + if sso_jwt_token == nil then + return + end + + -- 解密 sso jwt token + local sso_token_data = decode_jwt_token(sso_token_secret, sso_jwt_token) + + --log.info("sso_token_data ,", json.delay_encode(sso_token_data, false), "sso_token_secret ", sso_token_secret) + + local email = sso_token_data.payload.email + local sso_user = sso_users.get_user(email) + local jwt_token = encode_jwt_token(internal_token_secret, { + sub = "my token", + exp = current_time + 86400, + iat = current_time, + nbf = current_time, + data = { + user_id = sso_user.codo_user_id, + username = sso_user.name, + nickname = sso_user.name, + email = email, + is_superuser = sso_user.codo_is_superuser, + } + }) + auth_key = jwt_token + + -- set cookie + cookie:set({ + key = internal_jwt_key, + value = jwt_token, + path = "/", + }) + cookie:set({ + key = "is_login", + value = "yes", + path = "/", + }) + end + + log.debug("auth_key ==== ", auth_key) + + if auth_key then + -- 设置 header 并且 "_" 切换为 "-" + core.request.set_header(ngx_ctx, string.gsub(internal_jwt_key,"_","-"), auth_key) + end + return +end + +return _M diff --git a/gateway/app/plugins/tracing_plugin.lua b/gateway/app/plugins/tracing_plugin.lua index 7fcc3c6..50fd75a 100644 --- a/gateway/app/plugins/tracing_plugin.lua +++ b/gateway/app/plugins/tracing_plugin.lua @@ -23,13 +23,14 @@ local _M = { name = "tracing", desc = "链路跟踪插件", optional = true, + priority = 90, version = "v1.0" } function _M.do_in_init() -- automatic seeding with os.time(), LuaSocket, or ngx.time() uuid.seed() - log.info("jit-uuid init") + log.debug("jit-uuid init") end function _M.do_in_access() diff --git a/gateway/app/store/discovery_store.lua b/gateway/app/store/discovery_store.lua index 0b55c89..d4427e6 100644 --- a/gateway/app/store/discovery_store.lua +++ b/gateway/app/store/discovery_store.lua @@ -59,6 +59,8 @@ local function parse_node(service) key = create_key(service.service_name, service.upstream), service_name = service.service_name, upstream = service.upstream, + dst_domain = service.dst_domain, + dst_is_https = service.dst_is_https, weight = service.weight, status = service.status, time = service.time or time.now() * 1000 @@ -78,7 +80,7 @@ local function find_by_prefix(prefix) prefix = prefix and service_node_etcd_key(prefix) or etcd_prefix local resp, err = etcd.readdir(prefix) if err then - log.info("failed to query service list, key: ", prefix, ", err: ", err) + log.debug("failed to query service list, key: ", prefix, ", err: ", err) return nil, err end @@ -104,47 +106,22 @@ end _M.query_service_node_list = query_service_node_list --- 从 etcd 读取所有服务节点 {service_name = node_list} -local function query_services_nodes() - local list = query_service_node_list() - if not list and tab_nkeys(list) < 1 then - return nil - end - - local nodes = {} - for _, node in ipairs(list) do - if node.status ~= 1 then - goto CONTINUE - end - - local items = nodes[node.service_name] or {} - items[node.upstream] = node.weight - nodes[node.service_name] = items - - ::CONTINUE:: - end - return nodes -end - -_M.query_services_nodes = query_services_nodes - -- 更新服务节点信息 local function apply_balancer(node) local service_name = node.service_name local service_upstream = node.upstream - local weight = node.weight -- 下线状态,删除缓存服务节点 if node.status == 0 then balancer.delete(service_name, service_upstream) return end - balancer.set(service_name, service_upstream, weight) + balancer.set(service_name, service_upstream, node) end -- 从 etcd 删除服务节点 local function delete_etcd_node(key) - local xxx =service_node_etcd_key(key) + local xxx = service_node_etcd_key(key) local _, err = etcd.delete(service_node_etcd_key(key)) if not err then return false, err @@ -179,7 +156,7 @@ _M.set_service_node = set_service_node -- 监听服务节点数据变更 local function watch_services(ctx) - log.info("watch services start_revision: ", ctx.start_revision) + log.debug("watch services start_revision: ", ctx.start_revision) local opts = { timeout = etcd_watch_opts.timeout, prev_kv = etcd_watch_opts.prev_kv, @@ -200,7 +177,7 @@ local function watch_services(ctx) end break end - log.info("services watch result: ", json.delay_encode(chunk.result)) + log.debug("services watch result: ", json.delay_encode(chunk.result)) ctx.start_revision = chunk.result.header.revision + 1 if chunk.result.events then for _, event in ipairs(chunk.result.events) do @@ -218,13 +195,13 @@ end -- 加载服务节点注册信息 local function load_services() - local nodes = query_services_nodes() + local nodes = query_service_node_list() if nodes and tab_nkeys(nodes) > 0 then - for name, items in pairs(nodes) do - balancer.refresh(name, items) + for _, node in pairs(nodes) do + apply_balancer(node) end else - log.info("service nodes empty") + log.debug("service nodes empty") end end @@ -236,9 +213,9 @@ local function _init() end function _M.init() - discovery_timer = timer.new("discovery.timer", _init, {delay = 0}) - discovery_refresh_timer = timer.new("discovery.refresh.timer", load_services, {delay = 3}) - discovery_watcher_timer = timer.new("discovery.watcher.timer", watch_services, {delay = 0}) + discovery_timer = timer.new("discovery.timer", _init, { delay = 0 }) + discovery_refresh_timer = timer.new("discovery.refresh.timer", load_services, { delay = 3 }) + discovery_watcher_timer = timer.new("discovery.watcher.timer", watch_services, { delay = 0 }) local ok, err = discovery_timer:once() if not ok then error("failed to init discovery: " .. err) diff --git a/gateway/app/store/rbac_store.lua b/gateway/app/store/rbac_store.lua index e0bd87f..ac12f2c 100644 --- a/gateway/app/store/rbac_store.lua +++ b/gateway/app/store/rbac_store.lua @@ -58,8 +58,7 @@ local function query_list() end local crbac = {} - - --ngx.log(ngx.ERR, 'rbac query list--->',etcd_prefix, json.encode( resp.body.kvs)) + -- ngx.log(ngx.ERR, 'rbac query list--->',etcd_prefix, json.encode( resp.body.kvs)) if resp.body.kvs and tab_nkeys(resp.body.kvs) > 0 then for _, kv in ipairs(resp.body.kvs) do core_table.insert(crbac, kv.value) @@ -96,7 +95,7 @@ local function refresh_crbac() end local function watch_crbac(ctx) - log.info("watch crbacs start_revision: ", ctx.start_revision) + log.debug("watch crbacs start_revision: ", ctx.start_revision) local opts = { timeout = etcd_watch_opts.timeout, prev_kv = etcd_watch_opts.prev_kv, @@ -117,11 +116,11 @@ local function watch_crbac(ctx) end break end - log.info("rabc watch result: ", json.delay_encode(chunk.result)) + log.debug("rabc watch result: ", json.delay_encode(chunk.result)) ctx.start_revision = chunk.result.header.revision + 1 if chunk.result.events then for _, event in ipairs(chunk.result.events) do - --log.error("rabc event: ", event.type, " - ", json.delay_encode(event.kv)) + log.error("rabc event: ", event.type, " - ", json.delay_encode(event.kv)) refresh_crbac() end end diff --git a/gateway/app/store/route_store.lua b/gateway/app/store/route_store.lua index a5ec6d6..851442e 100644 --- a/gateway/app/store/route_store.lua +++ b/gateway/app/store/route_store.lua @@ -99,7 +99,7 @@ local function refresh_router() end local function watch_routes(ctx) - log.info("watch routes start_revision: ", ctx.start_revision) + log.debug("watch routes start_revision: ", ctx.start_revision) local opts = { timeout = etcd_watch_opts.timeout, prev_kv = etcd_watch_opts.prev_kv, @@ -120,7 +120,7 @@ local function watch_routes(ctx) end break end - log.info("routes watch result: ", json.delay_encode(chunk.result)) + log.debug("routes watch result: ", json.delay_encode(chunk.result)) ctx.start_revision = chunk.result.header.revision + 1 if chunk.result.events then for _, event in ipairs(chunk.result.events) do diff --git a/gateway/app/store/sso_users_store.lua b/gateway/app/store/sso_users_store.lua new file mode 100644 index 0000000..9b5d9dc --- /dev/null +++ b/gateway/app/store/sso_users_store.lua @@ -0,0 +1,118 @@ +-- +-- Licensed to the Apache Software Foundation (ASF) under one or more +-- contributor license agreements. See the NOTICE file distributed with +-- this work for additional information regarding copyright ownership. +-- The ASF licenses this file to You under the Apache License, Version 2.0 +-- (the "License"); you may not use this file except in compliance with +-- the License. You may obtain a copy of the License at +-- +-- http://www.apache.org/licenses/LICENSE-2.0 +-- +-- Unless required by applicable law or agreed to in writing, software +-- distributed under the License is distributed on an "AS IS" BASIS, +-- WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +-- See the License for the specific language governing permissions and +-- limitations under the License. +-- +local error = error +local ipairs = ipairs +local etcd = require("app.core.etcd") +local log = require("app.core.log") +local tab_nkeys = require("table.nkeys") +local core_table = require("app.core.table") +local sso_users = require("app.my_core.sso_users") +local timer = require("app.core.timer") +local json = require("app.core.json") + +local _M = {} + +local sso_users_timer +local sso_users_watch_timer +-- 防止网络异常导致路由数据监听处理失败,未及时更新路由信息,定时轮训路由配置 +local sso_users_refresh_timer + +local etcd_prefix = "uc/userinfo/" + +local etcd_watch_opts = { + timeout = 60, + prev_kv = true +} + +-- 查询所有用户信息,返回 list +local function query_list() + local resp, err = etcd.readdir(etcd_prefix) + if err ~= nil then + log.error("failed to load sso_users_list", err) + return nil, err + end + + local sso_users_list = {} + if resp.body.kvs and tab_nkeys(resp.body.kvs) > 0 then + for _, kv in ipairs(resp.body.kvs) do + core_table.insert(sso_users_list, kv.value) + end + end + return sso_users_list, nil +end + +_M.query_list = query_list + +local function refresh_sso_users() + local sso_users_list, err = query_list() + if not sso_users_list and tab_nkeys(sso_users_list) < 1 then + return nil, err + end + sso_users.refresh(sso_users_list) +end + +local function watch_sso_users_list(ctx) + log.debug("watch sso_users_list start_revision: ", ctx.start_revision) + local opts = { + timeout = etcd_watch_opts.timeout, + prev_kv = etcd_watch_opts.prev_kv, + start_revision = ctx.start_revision + } + local chunk_fun, err = etcd.watchdir(etcd_prefix, opts) + + if not chunk_fun then + log.error("sso_users_list chunk err: ", err) + return + end + while true do + local chunk + chunk, err = chunk_fun() + if not chunk then + if err ~= "timeout" then + log.error("sso_users_list chunk err: ", err) + end + break + end + log.debug("sso_users_list watch result: ", json.delay_encode(chunk.result)) + ctx.start_revision = chunk.result.header.revision + 1 + if chunk.result.events then + for _, event in ipairs(chunk.result.events) do + log.error("sso_users_list event: ", event.type, " - ", json.delay_encode(event.kv)) + refresh_sso_users() + end + end + end +end + +local function _init() + refresh_sso_users() + sso_users_watch_timer:recursion() + sso_users_refresh_timer:every() +end + +-- 初始化 +function _M.init() + sso_users_timer = timer.new("sso_users.timer", _init, {delay = 0}) + sso_users_refresh_timer = timer.new("sso_users.refresh.timer", refresh_sso_users, { delay = 3}) + sso_users_watch_timer = timer.new("sso_users.watch.timer", watch_sso_users_list, {delay = 0}) + local ok, err = sso_users_timer:once() + if not ok then + error("failed to load sso_users_list: " .. err) + end +end + +return _M diff --git a/gateway/app/utils/call_utils.lua b/gateway/app/utils/call_utils.lua index a87044e..5389970 100644 --- a/gateway/app/utils/call_utils.lua +++ b/gateway/app/utils/call_utils.lua @@ -17,6 +17,7 @@ local log = require("app.core.log") local pcall = pcall local pairs = pairs +local ipairs = ipairs local _M = {} function _M.call(modules, method_name, ...) @@ -36,4 +37,24 @@ function _M.call(modules, method_name, ...) end end +-- call_alphabeta 顺序调用模块方法 +function _M.call_alphabeta(modules, method_name, ...) + for _, m in ipairs(modules) do + local func = m[method_name] + local plugin_name = m.name + if not func then + log.debug("can not found module method, ", plugin_name, ".", method_name) + goto CONTINUE + end + log.debug("call, ", plugin_name, ".", method_name) + + local ok, err = pcall(func, ...) + if not ok then + log.error("call error:", method_name, " - ", err) + end + + ::CONTINUE:: + end +end + return _M