diff --git a/lualib/skynet-fly/db/mysqlf.lua b/lualib/skynet-fly/db/mysqlf.lua index f9dba89e3..ebe6b8d34 100644 --- a/lualib/skynet-fly/db/mysqlf.lua +++ b/lualib/skynet-fly/db/mysqlf.lua @@ -1,10 +1,15 @@ local skynet = require "skynet" local contriner_client = require "skynet-fly.client.contriner_client" +local mysql = require "skynet.db.mysql" +local log = require "skynet-fly.log" +local timer = require "skynet-fly.timer" local assert = assert local setmetatable = setmetatable +local pcall = pcall +local next = next -contriner_client:register("mysql_m") +contriner_client:register("mysql_m", "share_config_m") local g_instance = nil local g_instance_map = {} @@ -12,6 +17,8 @@ local g_instance_map = {} local M = {} local mt = {__index = M} +---------------------------------mysql_m-------------------------------------------- + function M:new(db_name) local client = contriner_client:new("mysql_m",db_name) local t = { @@ -52,4 +59,66 @@ function M:max_packet_size() end end +---------------------------------mysql_m-------------------------------------------- + +---------------------------------本服直连模式---------------------------------------- +local function keep_alive(week_t) + local t = next(week_t) + if not t then return end + local conn = t.conn + if conn then + local ok,ret = pcall(conn.ping, conn) + if not ok then + log.error("keep_alive err ", ret) + end + else + log.error("keep_alive not conn ", conn) + end +end + +local g_lmt = {__gc = function(t) + if t.keep_time then + t.keep_time:cancel() + end +end} + +local week_mt = {__mode = "kv"} + +function M.l_new_client(db_name) + local cli = contriner_client:new('share_config_m') + local conf_map = cli:mod_call('query','mysql') + assert(conf_map and conf_map[db_name],"not mysql conf") + + local conf = conf_map[db_name] + local database = conf.database + conf.database = nil + local conn = mysql.connect(conf) + conn:query('CREATE DATABASE IF NOT EXISTS ' .. database .. ';') + conn:disconnect() + conf.database = database + conn = mysql.connect(conf) + + local t = { + conf = conf, + conn = conn, + } + + local week_t = setmetatable({}, week_mt) + week_t[t] = true + + t.keep_time = timer:new(timer.second * 10,timer.loop, keep_alive, week_t) + t.keep_time:after_next() + + setmetatable(t, g_lmt) + return t +end + +function M.l_instance(db_name) + if not g_instance_map[db_name] then + g_instance_map[db_name] = M.l_new_client(db_name) + end + + return g_instance_map[db_name] +end + return M \ No newline at end of file diff --git a/lualib/skynet-fly/db/orm/ormtable.lua b/lualib/skynet-fly/db/orm/ormtable.lua index a3ee4f383..288c5dd9c 100644 --- a/lualib/skynet-fly/db/orm/ormtable.lua +++ b/lualib/skynet-fly/db/orm/ormtable.lua @@ -42,6 +42,7 @@ local FIELD_TYPE = { text = 51, blob = 52, + table = 53, } local INVAILD_POINT = {count = 0, total_count = 0} --无效叶点 @@ -68,6 +69,7 @@ local FIELD_LUA_DEFAULT = { [FIELD_TYPE.text] = "", [FIELD_TYPE.blob] = "", + [FIELD_TYPE.table] = {}, } local function create_check_str(len) @@ -98,6 +100,7 @@ local FIELD_TYPE_CHECK_FUNC = { [FIELD_TYPE.text] = function(str) return type(str) == 'string' end, [FIELD_TYPE.blob] = function(str) return type(str) == 'string' end, + [FIELD_TYPE.table] = function(tab) return type(tab) == 'table' end, } local function add_field_name_type(t,field_name,field_type) @@ -355,7 +358,11 @@ local function init_entry_data(t, entry_data, is_old) if entry_data[fn] then new_entry_data[fn] = entry_data[fn] else - new_entry_data[fn] = FIELD_LUA_DEFAULT[ft] + if ft ~= FIELD_TYPE.table then + new_entry_data[fn] = FIELD_LUA_DEFAULT[ft] + else + new_entry_data[fn] = {} + end end end return new_entry_data @@ -752,35 +759,38 @@ local function get_entry_by_in(t, in_values, key_values) else key_values[kv_len + 1] = nil local entry_data_list = t._adapterinterface:get_entry_by_in(in_values, key_values) - if not entry_data_list or not next(entry_data_list) then - --添加无效条目站位,防止缓存穿透 - for i = #in_values, 1, -1 do - local v = in_values[i] - key_values[kv_len + 1] = v - local invaild_entry = create_invaild_entry(t, key_values) - add_key_select(t, invaild_entry) - set_total_count(t, key_values, 0) - end - return res_entry_list, false - else - local in_v_count_map = {} - for i = 1,#entry_data_list do - local entry_data = init_entry_data(t, entry_data_list[i], true) - local entry = ormentry:new(t, entry_data) - tinsert(res_entry_list, add_key_select(t, entry)) - local inv = entry_data[in_field_name] - if not in_v_count_map[inv] then - in_v_count_map[inv] = 0 - end - in_v_count_map[inv] = in_v_count_map[inv] + 1 + local in_v_count_map = {} + local in_v_cnt = 0 + for i = 1,#entry_data_list do + local entry_data = init_entry_data(t, entry_data_list[i], true) + local entry = ormentry:new(t, entry_data) + tinsert(res_entry_list, add_key_select(t, entry)) + local inv = entry_data[in_field_name] + if not in_v_count_map[inv] then + in_v_count_map[inv] = 0 + in_v_cnt = in_v_cnt + 1 end + in_v_count_map[inv] = in_v_count_map[inv] + 1 + end - for inv, count in pairs(in_v_count_map) do - key_values[kv_len + 1] = inv - set_total_count(t, key_values, count) + for inv, count in pairs(in_v_count_map) do + key_values[kv_len + 1] = inv + set_total_count(t, key_values, count) + end + + --添加无效条目站位,防止缓存穿透 + if in_v_cnt ~= #in_values then + for i = 1, #in_values do + local v = in_values[i] + if not in_v_count_map[v] then + key_values[kv_len + 1] = v + local invaild_entry = create_invaild_entry(t, key_values) + add_key_select(t, invaild_entry) + set_total_count(t, key_values, 0) + end end - return res_entry_list, false end + return res_entry_list, false end end diff --git a/lualib/skynet-fly/db/ormadapter/ormadapter_mysql.lua b/lualib/skynet-fly/db/ormadapter/ormadapter_mysql.lua index 29b0e0c27..2eeb8c100 100644 --- a/lualib/skynet-fly/db/ormadapter/ormadapter_mysql.lua +++ b/lualib/skynet-fly/db/ormadapter/ormadapter_mysql.lua @@ -3,12 +3,13 @@ local table_util = require "skynet-fly.utils.table_util" local string_util = require "skynet-fly.utils.string_util" local mysqlf = require "skynet-fly.db.mysqlf" local log = require "skynet-fly.log" +local json = require "cjson" local FIELD_TYPE = require "skynet-fly.db.orm.ormtable".FIELD_TYPE local FIELD_LUA_DEFAULT = require "skynet-fly.db.orm.ormtable".FIELD_LUA_DEFAULT local setmetatable = setmetatable -local sfild = string.find +local sfind = string.find local sformat = string.format local assert = assert local tconcat = table.concat @@ -20,6 +21,8 @@ local type = type local tonumber = tonumber local tinsert = table.insert local tremove = table.remove +local tsort = table.sort +local pcall = pcall local FIELD_TYPE_SQL_TYPE = { [FIELD_TYPE.int8] = "tinyint", @@ -42,6 +45,7 @@ local FIELD_TYPE_SQL_TYPE = { [FIELD_TYPE.text] = "text", [FIELD_TYPE.blob] = "blob", + [FIELD_TYPE.table] = "blob", } local IS_NUMBER_TYPE = { @@ -62,17 +66,53 @@ do end end +local function new_prepare_obj(prepare_str) + local t = { + prepare_str = prepare_str, + } + + return t +end + +local function get_stmt(db, prepare_obj) + if not prepare_obj.stmt then + local ret = db.conn:prepare(prepare_obj.prepare_str) + if ret.err then + error(ret.err) + end + prepare_obj.stmt = ret + end + return prepare_obj.stmt +end + +local function prepare_execute(db, prepare_obj, ...) + local stmt = get_stmt(db, prepare_obj) + local ret = db.conn:execute(stmt, ...) + if ret.err and sfind(ret.err, "Unknown prepared", nil, true) then + prepare_obj.stmt = nil + stmt = get_stmt(db, prepare_obj) + else + return ret + end + + return db.conn:execute(stmt, ...) +end + local M = {} local mata = {__index = M} -- 新建适配对象 function M:new(db_name) local t = { - _db = mysqlf:new(db_name), + _db = mysqlf.l_new_client(db_name), _tab_name = nil, _field_list = nil, _field_map = nil, _key_list = nil, + _tab_encode = json.encode, + _tab_decode = json.decode, + batch_insert_num = 10, + batch_update_num = 10, } setmetatable(t, mata) @@ -80,6 +120,27 @@ function M:new(db_name) return t end +--设置单次整合批量插入的数量 +function M:set_batch_insert_num(num) + assert(num > 0) + self.batch_insert_num = num + return self +end + +--设置单次整合批量更新的数量 +function M:set_batch_update_num(num) + assert(num > 0) + self.batch_update_num = num + return self +end + +-- 设置table类型的打包解包函数 +function M:set_table_pack(encode, decode) + self._tab_encode = encode + self._tab_decode = decode + return self +end + local function create_table(t) local field_list = t._field_list local field_map = t._field_map @@ -89,7 +150,7 @@ local function create_table(t) local field_name = field_list[i] local field_type = field_map[field_name] local convert_type = assert(FIELD_TYPE_SQL_TYPE[field_type],"unknown type : " .. field_type) - if field_type == FIELD_TYPE.text or field_type == FIELD_TYPE.blob then --text 和 blob类型不支持指定默认值 + if field_type == FIELD_TYPE.text or field_type == FIELD_TYPE.blob or field_type == FIELD_TYPE.table then --text 和 blob类型不支持指定默认值 sql_str = sql_str .. sformat("\t`%s` %s,\n", field_name, convert_type) else sql_str = sql_str .. sformat("\t`%s` %s NOT NULL DEFAULT '%s',\n", field_name, convert_type, FIELD_LUA_DEFAULT[field_type]) @@ -99,7 +160,7 @@ local function create_table(t) sql_str = sql_str .. sformat("\tprimary key(%s)\n", tconcat(key_list,',')) sql_str = sql_str .. ');' - local ret = t._db:query(sql_str) + local ret = t._db.conn:query(sql_str) if not ret then error("create table err", sql_str) elseif ret.err then @@ -154,7 +215,11 @@ local function alter_table(t, describe, index_info) if def_info._flag == "add" then tinsert(new_field_list,field_name) elseif def_info._flag == "valuedef" then - error("can`t change type " .. field_name .. ' new:' .. def_info._new .. ' old:' .. def_info._old) --不能修改类型 + if (def_info._new == FIELD_TYPE.blob and def_info._old == FIELD_TYPE.table) or + (def_info._new == FIELD_TYPE.table and def_info._old == FIELD_TYPE.blob) then + else + error("can`t change type " .. field_name .. ' new:' .. def_info._new .. ' old:' .. def_info._old) --不能修改类型 + end end end @@ -165,13 +230,13 @@ local function alter_table(t, describe, index_info) local field_type = field_map[field_name] local convert_type = assert(FIELD_TYPE_SQL_TYPE[field_type],"unknown type : " .. field_type) if not is_end then - if field_type == FIELD_TYPE.text or field_type == FIELD_TYPE.blob then + if field_type == FIELD_TYPE.text or field_type == FIELD_TYPE.blob or field_type == FIELD_TYPE.table then sql_str = sql_str .. sformat("add `%s` %s,\n", field_name, convert_type) else sql_str = sql_str .. sformat("add `%s` %s NOT NULL DEFAULT '%s',\n", field_name, convert_type, FIELD_LUA_DEFAULT[field_type]) end else - if field_type == FIELD_TYPE.text or field_type == FIELD_TYPE.blob then + if field_type == FIELD_TYPE.text or field_type == FIELD_TYPE.blob or field_type == FIELD_TYPE.table then sql_str = sql_str .. sformat("add `%s` %s;\n", field_name, convert_type) else sql_str = sql_str .. sformat("add `%s` %s NOT NULL DEFAULT '%s';\n", field_name, convert_type, FIELD_LUA_DEFAULT[field_type]) @@ -179,7 +244,7 @@ local function alter_table(t, describe, index_info) end end - local ret = t._db:query(sql_str) + local ret = t._db.conn:query(sql_str) if not ret then log.error("alter_table err ",sql_str) error("alter_table table err") @@ -196,28 +261,22 @@ function M:builder(tab_name, field_list, field_map, key_list) self._key_list = key_list self._field_list = field_list + local tab_encode = self._tab_encode + local tab_decode = self._tab_decode + local table_field_list = {} -- 查询表的字段信息 - local describe = self._db:query("DESCRIBE " .. tab_name) + local describe = self._db.conn:query("DESCRIBE " .. tab_name) assert(describe, "not describe ret " .. tab_name) if describe.err then - assert(sfild(describe.err, "doesn't exist", nil, true), "unknown") + assert(sfind(describe.err, "doesn't exist", nil, true), "unknown") --不存在 创建 create_table(self) else --存在 检查变更 - local index_info = self._db:query("show index from " .. tab_name) + local index_info = self._db.conn:query("show index from " .. tab_name) assert(index_info, "can`t get index_info ") alter_table(self, describe, index_info) end - - local packet_max = self._db:query("SHOW VARIABLES LIKE 'max_allowed_packet'") - assert(packet_max and #packet_max >= 1, "can`t get packet max") - - local max_packet_size = tonumber(packet_max[1].Value) --包最大长度 - local local_max_packet_size = self._db:max_packet_size() - if local_max_packet_size < max_packet_size then - max_packet_size = local_max_packet_size - end local field_index_map = {} @@ -231,28 +290,21 @@ function M:builder(tab_name, field_list, field_map, key_list) local update_format_head = sformat("update %s set ",tab_name) local update_format_head_list = {} local update_format_end = " where " + local updates_format_end = " where(" + local updates_format_key = "(" local delete_format_head = sformat("delete from %s",tab_name) - local delete_format_center = " where " local len = #field_list for i = 1,len do local field_name = field_list[i] local field_type = field_map[field_name] - if IS_NUMBER_TYPE[field_type] then - if i == len then - insert_format_end = insert_format_end .. "%d" - else - insert_format_end = insert_format_end .. "%d," - end - update_format_head_list[i] = '`' .. field_name .. "` = %d" + update_format_head_list[i] = '`' .. field_name .. "`=?," + if i == len then + insert_format_end = insert_format_end .. "?" else - if i == len then - insert_format_end = insert_format_end .. "'%s'" - else - insert_format_end = insert_format_end .. "'%s'," - end - update_format_head_list[i] = '`' .. field_name .. "` = '%s'" + insert_format_end = insert_format_end .. "?," end + if i == len then insert_format_head = insert_format_head .. '`' .. field_name .. '`' select_format_head = select_format_head .. '`' .. field_name .. '`' @@ -262,32 +314,29 @@ function M:builder(tab_name, field_list, field_map, key_list) end field_index_map[field_name] = i + + if field_type == FIELD_TYPE.table then + tinsert(table_field_list, field_name) + end end + local table_field_len = #table_field_list + len = #key_list for i = 1,len do local field_name = key_list[i] - local field_type = field_map[field_name] - if IS_NUMBER_TYPE[field_type] then - if i == len then - select_format_end = select_format_end .. '`' .. field_name .. '`=' .. "%d" - update_format_end = update_format_end .. '`' .. field_name .. '`=' .. "%d" - select_format_key_head = select_format_key_head .. '`' .. field_name .. '`' - else - select_format_end = select_format_end .. '`' .. field_name .. '`=' .. "%d" - update_format_end = update_format_end .. '`' .. field_name .. '`=' .. "%d and " - select_format_key_head = select_format_key_head .. '`' .. field_name .. '`,' - end + if i == len then + select_format_end = select_format_end .. '`' .. field_name .. '`=' .. "?" + update_format_end = update_format_end .. '`' .. field_name .. '`=' .. "?;" + select_format_key_head = select_format_key_head .. '`' .. field_name .. '`' + updates_format_end = updates_format_end .. '`' .. field_name .. '`)' + updates_format_key = updates_format_key .. '?)' else - if i == len then - select_format_end = select_format_end .. '`' .. field_name .. '`=' .. "'%s'" - update_format_end = update_format_end .. '`' .. field_name .. '`=' .. "'%s'" - select_format_key_head = select_format_key_head .. '`' .. field_name .. '`' - else - select_format_end = select_format_end .. '`' .. field_name .. '`=' .. "'%s'" - update_format_end = update_format_end .. '`' .. field_name .. '`=' .. "'%s' and " - select_format_key_head = select_format_key_head .. '`' .. field_name .. '`,' - end + select_format_end = select_format_end .. '`' .. field_name .. '`=' .. "?" + update_format_end = update_format_end .. '`' .. field_name .. '`=' .. "? and " + select_format_key_head = select_format_key_head .. '`' .. field_name .. '`,' + updates_format_end = updates_format_end .. '`' .. field_name .. '`,' + updates_format_key = updates_format_key .. '?,' end select_format_end_list[i] = select_format_end @@ -299,20 +348,180 @@ function M:builder(tab_name, field_list, field_map, key_list) select_format_key_head = select_format_key_head .. ' from ' .. tab_name select_format_head = select_format_head .. ' from ' .. tab_name + --insert prepare 处理 + local insert_prepare_list = {} + for i = 1, self.batch_insert_num do + local end_str = "" + for j = 1, i do + if j ~= i then + end_str = end_str .. insert_format_end .. ',' + else + end_str = end_str .. insert_format_end + end + end + + insert_prepare_list[i] = new_prepare_obj(insert_format_head .. end_str) + end + + insert_format_head = nil + insert_format_end = nil + + --select prepare 处理 + local select_prepare_list = {} + select_prepare_list[0] = new_prepare_obj(select_format_head) + for i = 1, len do + select_prepare_list[i] = new_prepare_obj(select_format_head .. select_format_center .. select_format_end_list[i]) + end + + -- select * from player where key1 in (?); + -- select * from player where key1=?,key2 in (?); + -- select * from player where key1=?,key2=?,key3 in (?); + --select in prepare 处理 + local select_in_prepare_list = {} + for i = 1, len do + local end_field_name = key_list[i] + if i == 1 then + select_in_prepare_list[i] = sformat("%s%s`%s` in ", select_format_head, select_format_center, end_field_name) + else + select_in_prepare_list[i] = sformat("%s%s%s and `%s` in ", select_format_head, select_format_center, select_format_end_list[i - 1], end_field_name) + end + end + + --select limit prepare 处理 + --select * from player where key1 > ? order by ? desc limit ? + --select * from player where key1 < ? order by ? limit ? + local count_sql = nil + if len == 1 then + count_sql = sformat("select count(*) from %s;", self._tab_name) + else + count_sql = sformat("select count(*) from %s where %s;", self._tab_name, select_format_end_list[len - 1]) + end + local select_limit_desc_pre_pare + local select_limit_pre_pare + local select_limit_k_desc_pre_pare + local select_limit_k_pre_pare + local select_f_limit_desc_pre_pare + local select_f_limit_pre_pare + local select_f_limit_k_desc_pre_pare + local select_f_limit_k_pre_pare + + local select_count_pre_pare + + local end_field_name = key_list[len] + if len == 1 then + select_limit_desc_pre_pare = new_prepare_obj(sformat("%s%s`%s` < ? order by `%s` desc limit ?", select_format_head, select_format_center, end_field_name, end_field_name)) + select_limit_pre_pare = new_prepare_obj(sformat("%s%s`%s` > ? order by `%s` limit ?", select_format_head, select_format_center, end_field_name, end_field_name)) + select_limit_k_desc_pre_pare = new_prepare_obj(sformat("%s%s`%s` < ? order by `%s` desc limit ?", select_format_key_head, select_format_center, end_field_name, end_field_name)) + select_limit_k_pre_pare = new_prepare_obj(sformat("%s%s`%s` > ? order by `%s` limit ?", select_format_key_head, select_format_center, end_field_name, end_field_name)) + + select_f_limit_desc_pre_pare = new_prepare_obj(sformat("%s order by `%s` desc limit ?", select_format_head, end_field_name)) + select_f_limit_pre_pare = new_prepare_obj(sformat("%s order by `%s` limit ?", select_format_head, end_field_name)) + select_f_limit_k_desc_pre_pare = new_prepare_obj(sformat("%s order by `%s` desc limit ?", select_format_key_head, end_field_name)) + select_f_limit_k_pre_pare = new_prepare_obj(sformat("%s order by `%s` limit ?", select_format_key_head, end_field_name)) + else + select_limit_desc_pre_pare = new_prepare_obj(sformat("%s%s%s and `%s` < ? order by `%s` desc limit ?", select_format_head, select_format_center, select_format_end_list[len - 1], end_field_name, end_field_name)) + select_limit_pre_pare = new_prepare_obj(sformat("%s%s%s and `%s` > ? order by `%s` limit ?", select_format_head, select_format_center, select_format_end_list[len - 1], end_field_name, end_field_name)) + select_limit_k_desc_pre_pare = new_prepare_obj(sformat("%s%s%s and `%s` < ? order by `%s` desc limit ?", select_format_key_head, select_format_center, select_format_end_list[len - 1], end_field_name, end_field_name)) + select_limit_k_pre_pare = new_prepare_obj(sformat("%s%s%s and `%s` > ? order by `%s` limit ?", select_format_key_head, select_format_center, select_format_end_list[len - 1], end_field_name, end_field_name)) + + select_f_limit_desc_pre_pare = new_prepare_obj(sformat("%s%s%s order by `%s` desc limit ?", select_format_head, select_format_center, select_format_end_list[len - 1], end_field_name)) + select_f_limit_pre_pare = new_prepare_obj(sformat("%s%s%s order by `%s` limit ?", select_format_head, select_format_center, select_format_end_list[len - 1], end_field_name)) + select_f_limit_k_desc_pre_pare = new_prepare_obj(sformat("%s%s%s order by `%s` desc limit ?", select_format_key_head, select_format_center, select_format_end_list[len - 1], end_field_name)) + select_f_limit_k_pre_pare = new_prepare_obj(sformat("%s%s%s order by `%s` limit ?", select_format_key_head, select_format_center, select_format_end_list[len - 1], end_field_name)) + end + + select_count_pre_pare = new_prepare_obj(count_sql) + + select_format_head = nil + select_format_key_head = nil + count_sql = nil + select_format_end = nil + + --update prepare + local update_pre_pare_map = {} + local function get_update_pre_pare(change_map) + local index_list = {} + for field_name in pairs(change_map) do + local index = field_index_map[field_name] + tinsert(index_list, index) + end + tsort(index_list) + local indexs_str = tconcat(index_list, '') + if not update_pre_pare_map[indexs_str] then + local sql_str = update_format_head + for i = 1, #index_list do + local index = index_list[i] + sql_str = sql_str .. update_format_head_list[index] + end + sql_str = sql_str:sub(1,sql_str:len() - 1) .. update_format_end + update_pre_pare_map[indexs_str] = new_prepare_obj(sql_str) + end + + return update_pre_pare_map[indexs_str], index_list + end + + --delete prepare + local delete_prepare_list = {} + delete_prepare_list[0] = new_prepare_obj(delete_format_head) + for i = 1, len do + delete_prepare_list[i] = new_prepare_obj(delete_format_head .. select_format_center .. select_format_end_list[i]) + end + + --delete byrange prepare + --delete from player where key1=? and key2>=? and key2<=?; + --delete from player where key1=? and key2>=?; + --delete from player where key1=? and key2<=?; + local delete_range_prepare_list_b = {} --大于等于 + local delete_range_prepare_list_s = {} --小于等于 + local delete_range_prepare_list_c = {} --大于等于 and 小于等于 + for i = 1, len do + local end_field_name = key_list[i] + if i == 1 then + delete_range_prepare_list_b[i] = new_prepare_obj(sformat("%s%s`%s`>=?", delete_format_head, select_format_center, end_field_name)) + delete_range_prepare_list_s[i] = new_prepare_obj(sformat("%s%s`%s`<=?", delete_format_head, select_format_center, end_field_name)) + delete_range_prepare_list_c[i] = new_prepare_obj(sformat("%s%s`%s`>=? and `%s` <=?", delete_format_head, select_format_center, end_field_name, end_field_name)) + else + delete_range_prepare_list_b[i] = new_prepare_obj(sformat("%s%s%s and `%s`>=?", delete_format_head, select_format_center, select_format_end_list[i-1], end_field_name)) + delete_range_prepare_list_s[i] = new_prepare_obj(sformat("%s%s%s and `%s`<=?", delete_format_head, select_format_center, select_format_end_list[i-1], end_field_name)) + delete_range_prepare_list_c[i] = new_prepare_obj(sformat("%s%s%s and `%s`>=? and `%s` <=?", delete_format_head, select_format_center, select_format_end_list[i-1], end_field_name, end_field_name)) + end + end + + delete_format_head = nil + select_format_center = nil + local insert_list = {} - local function entry_data_to_list(entry_data) + local function entry_data_to_list(entry_data, add_list) for i = 1,#field_list do local fn = field_list[i] + local field_type = field_map[fn] local fv = entry_data[fn] - - if type(fv) == 'string' then - fv = string_util.quote_sql_str(fv) + if field_type == FIELD_TYPE.table then + fv = tab_encode(fv) + end + if not add_list then + insert_list[i] = fv + else + add_list[#add_list + 1] = fv end - insert_list[i] = fv end return insert_list end + --解包查询结果的table + local function decode_tables(sql_ret) + if table_field_len <= 0 then return end + for i = 1,#sql_ret do + local one_ret = sql_ret[i] + for j = 1, table_field_len do + local fn = table_field_list[j] + local v = one_ret[fn] + if v then + one_ret[fn] = tab_decode(v) + end + end + end + end --防止sql注入 local function quete_key_values(key_values) for i = 1, #key_values do @@ -323,287 +532,336 @@ function M:builder(tab_name, field_list, field_map, key_list) end end - local function handle_sql_ret(ret_list,s_start,s_end,sql_ret,sql_str) - if not sql_ret or sql_ret.err then - log.error("sql ret err ",sql_ret,sql_str) - for i = s_start, s_end do - ret_list[i] = false - end - else - for i = s_start, s_end do - ret_list[i] = true - end - end - end - --insert 批量插入 self._insert = function(entry_data_list) - local sql_str = insert_format_head - local add_str = nil - local add_count = 0 - local index = 1 - local len = #entry_data_list - local sql_ret = nil - local ret_list = {} - local s_index = index - - while index <= len do - local entry_data = entry_data_list[index] - if not entry_data then break end - - add_str = add_str or sformat(insert_format_end, tunpack(entry_data_to_list(entry_data))) - if sql_str:len() + add_str:len() > max_packet_size then - --一条都超过 - assert(add_count ~= 0, "can`t insert max_packet_size:" .. max_packet_size .. ' packlen:' .. sql_str:len() + add_str:len()) - --超过最大长度了 先插入一波 - sql_str = sql_str:sub(1,sql_str:len() - 1) - sql_ret = self._db:query(sql_str) - handle_sql_ret(ret_list, s_index, index - 1, sql_ret, sql_str) - sql_str = insert_format_head - add_count = 0 - s_index = index - elseif index == len then --到结尾了 - sql_str = sql_str .. add_str - sql_ret = self._db:query(sql_str) - handle_sql_ret(ret_list, s_index, index, sql_ret, sql_str) - add_count = 0 - index = index + 1 - s_index = index - else - sql_ret = nil - sql_str = sql_str .. add_str .. ',' - index = index + 1 - add_str = nil - add_count = add_count + 1 + --批量插入 + local res_list = {} + local ref_list = {} + local cur = 1 + local ret_index = 1 + local len = #entry_data_list + while true do + if cur > len then break end + local add_list = {} + local cnt = 0 + for j = 1, self.batch_insert_num do + local entry_data = entry_data_list[cur] + if entry_data then + entry_data_to_list(entry_data, add_list) + ref_list[j] = entry_data + cur = cur + 1 + cnt = cnt + 1 + else + break + end end - end - return ret_list + + if cnt <= 0 then break end + local prepare_obj = insert_prepare_list[cnt] + local isok, ret = pcall(prepare_execute, self._db, prepare_obj, tunpack(add_list)) + if isok and ret and not ret.err then + for i = 1, cnt do + res_list[ret_index] = true + ret_index = ret_index + 1 + end + else + log.error("_insert err ", self._tab_name, ret, ref_list) + for i = 1, cnt do + res_list[ret_index] = false + ret_index = ret_index + 1 + end + end + end + + return res_list end --insert_one插入单条 self._insert_one = function(entry_data) - local sql_str = insert_format_head .. sformat(insert_format_end, tunpack(entry_data_to_list(entry_data))) - assert(sql_str:len() <= max_packet_size, "can`t insert max_packet_size:" .. max_packet_size .. ' packlen:' .. sql_str:len()) - local sql_ret = self._db:query(sql_str) - if not sql_ret or sql_ret.err then - log.error("_insert_one err ",sql_ret,sql_str) - error("_insert_one err " .. sql_str) + local prepare_obj = insert_prepare_list[1] + local add_list = entry_data_to_list(entry_data) + local isok, ret = pcall(prepare_execute, self._db, prepare_obj, tunpack(add_list)) + if not isok or not ret or ret.err then + log.error("_insert_one err ", ret, entry_data) + error("_insert_one err ") end - return true end --select 查询 self._select = function(key_values) local len = #key_values - assert(len >= 0 and len <= #select_format_end_list, "err key_values len " .. len) - local sql_str = nil - if len == 0 then - sql_str = select_format_head - else - quete_key_values(key_values) - sql_str = select_format_head .. select_format_center .. sformat(select_format_end_list[len], tunpack(key_values)) - end - local sql_ret = self._db:query(sql_str) - if not sql_ret or sql_ret.err then - log.error("select err ",sql_str,sql_ret) - error("select err ".. sql_str) + assert(len >= 0 and len <= #key_list, "err key_values len " .. len) + local prepare_obj = select_prepare_list[len] + local isok, ret = pcall(prepare_execute, self._db, prepare_obj, tunpack(key_values)) + if not isok or not ret or ret.err then + log.error("_select err ", ret, key_values) + error("_select err ") end - return sql_ret + + decode_tables(ret) + return ret end --查询一条数据 local keys_max_len = #key_list self._select_one = function(key_values) - quete_key_values(key_values) - local sql_str = select_format_head .. select_format_center .. sformat(select_format_end_list[keys_max_len], tunpack(key_values)) - local sql_ret = self._db:query(sql_str) - if not sql_ret or sql_ret.err then - log.error("_select_one err ",sql_str,sql_ret) - error("_select_one err " .. sql_str) + local prepare_obj = select_prepare_list[keys_max_len] + local isok, ret = pcall(prepare_execute, self._db, prepare_obj, tunpack(key_values)) + if not isok or not ret or ret.err then + log.error("_select_one err ", ret, key_values) + error("_select_one err ") end - return sql_ret[1] + + decode_tables(ret) + return ret[1] end --IN 查询 self._select_in = function(in_values, key_values) local len = #key_values - if type(in_values[1]) == 'string' then - for i = 1,#in_values do - in_values[i] = "'" .. in_values[i] .. "'" + local prepare_str = select_in_prepare_list[len + 1] + prepare_str = prepare_str .. '(' + local in_len = #in_values + + local args = {} + for i = 1, len do + args[#args + 1] = key_values[i] + end + for i = 1, in_len do + if i == in_len then + prepare_str = prepare_str .. '?' + else + prepare_str = prepare_str .. '?,' end + args[#args + 1] = in_values[i] end - local end_field_name = key_list[len + 1] - local endstr = "" - quete_key_values(in_values) - if len > 0 then - quete_key_values(key_values) - endstr = sformat(select_format_end_list[len], tunpack(key_values)) - endstr = endstr .. sformat(" and `%s` in(%s)", end_field_name, tconcat(in_values, ',')) - else - endstr = endstr .. sformat(" `%s` in(%s)", end_field_name, tconcat(in_values, ',')) + prepare_str = prepare_str .. ')' + local prepare_obj = new_prepare_obj(prepare_str) + local isok, ret = pcall(prepare_execute, self._db, prepare_obj, tunpack(args)) + + local stmt = prepare_obj.stmt + if stmt then + pcall(self._db.conn.stmt_close, self._db.conn, stmt) end - local sql_str = select_format_head .. select_format_center .. endstr - local sql_ret = self._db:query(sql_str) - if not sql_ret or sql_ret.err then - log.error("_select_in err ",sql_str,sql_ret) - error("_select_in err " .. sql_str) + if not isok or not ret or ret.err then + log.error("_select_in err ", ret, key_values) + error("_select_in err ") end - return sql_ret + + decode_tables(ret) + return ret end --分页 查询 self._select_limit = function(cursor, limit, sort, key_values, is_only_key) assert(type(limit) == 'number') assert(type(sort) == 'number') - if type(cursor) == 'string' then - cursor = string_util.quote_sql_str(cursor) - end - quete_key_values(key_values) local len = #key_values - local sql_str = "" local end_field_name = key_list[len + 1] - local head = nil - if is_only_key then --是否仅查询主键 - head = select_format_key_head - else - head = select_format_head - end - - local flag = nil - local end_str = nil - if sort == 1 then --升序 - flag = '>' - end_str = ' order by ' .. end_field_name - else --降序 - flag = '<' - end_str = ' order by ' .. end_field_name .. ' desc' - end + local prepare_obj = nil - if not cursor then --开头把总数查出来 - local keys_str = "" - if len > 0 then - keys_str = sformat(select_format_end_list[len], tunpack(key_values)) - sql_str = "select count(*) from " .. tab_name .. select_format_center .. keys_str .. ';' - sql_str = sql_str .. head .. select_format_center .. keys_str .. end_str .. ' limit ' .. limit + local args = {} + if not cursor then + if is_only_key then + if sort == 1 then --升序 + prepare_obj = select_f_limit_k_pre_pare + else + prepare_obj = select_f_limit_k_desc_pre_pare + end else - sql_str = "select count(*) from " .. tab_name .. ';' - sql_str = sql_str .. head .. end_str .. ' limit ' .. limit + if sort == 1 then --升序 + prepare_obj = select_f_limit_pre_pare + else + prepare_obj = select_f_limit_desc_pre_pare + end end else - if len > 0 then - sql_str = head .. select_format_center .. sformat(select_format_end_list[len] .. ' ', tunpack(key_values)) - .. ' and ' .. end_field_name .. flag .. cursor .. end_str .. ' limit ' .. limit + if is_only_key then + if sort == 1 then --升序 + prepare_obj = select_limit_k_pre_pare + else + prepare_obj = select_limit_k_desc_pre_pare + end else - sql_str = head .. select_format_center .. end_field_name .. flag .. cursor .. end_str .. ' limit ' .. limit + if sort == 1 then --升序 + prepare_obj = select_limit_pre_pare + else + prepare_obj = select_limit_desc_pre_pare + end + end + end + + local count = nil + --拿一下count + if not cursor then + local isok, ret = pcall(prepare_execute, self._db, select_count_pre_pare, tunpack(key_values)) + if not isok or not ret or ret.err then + log.error("_select_limit err ", ret, key_values) + error("_select_limit err ") end + count = ret[1]["count(*)"] end - local sql_ret = self._db:query(sql_str) - if not sql_ret or sql_ret.err then - log.error("_select_limit err ",sql_str, sql_ret) - error("_select_limit err " .. sql_str) + --where参数 + for i = 1, len do + args[#args + 1] = key_values[i] end + if cursor then + args[#args + 1] = cursor + end + args[#args + 1] = limit - local cursor = nil - local count = nil - local ret_list = nil - if sql_ret.multiresultset then - count = sql_ret[1][1]["count(*)"] - ret_list = sql_ret[2] - else - ret_list = sql_ret + local isok, ret = pcall(prepare_execute, self._db, prepare_obj, tunpack(args)) + + if not isok or not ret or ret.err then + log.error("_select_limit err ", ret, key_values) + error("_select_limit err ") end - if #ret_list > 0 then - local end_ret = ret_list[#ret_list] + + local cursor = nil + if #ret > 0 then + local end_ret = ret[#ret] cursor = end_ret[end_field_name] end - - return cursor, ret_list, count + decode_tables(ret) + return cursor, ret, count end --update 更新 self._update = function(entry_data_list,change_map_list) - local sql_str = "" - local add_str = nil - local add_count = 0 - local index = 1 + local res_list = {} + local cur = 1 + local ret_index = 1 local len = #entry_data_list - local sql_ret = nil - local ret_list = {} - local s_index = index - - while index <= len do - local entry_data = entry_data_list[index] - if not entry_data then break end - - if not add_str then - local change_map = change_map_list[index] - local center_str = "" - for field_name in pairs(change_map) do - local index = field_index_map[field_name] - local field_value = entry_data[field_name] - if type(field_value) == 'string' then - field_value = string_util.quote_sql_str(field_value) + local min_len = self.batch_update_num + if len < min_len then + min_len = len + end + + while true do + if cur > len then break end + local updates = {} + local changes = {} + local field_name_map = {} + local fcnt = 0 + for i = 1, self.batch_update_num do + local entry_data = entry_data_list[cur] + local change_map = change_map_list[cur] + cur = cur + 1 + if entry_data then + for field_name in pairs(change_map) do + if not field_name_map[field_name] then + field_name_map[field_name] = {} + fcnt = fcnt + 1 + end + local field_n_list = field_name_map[field_name] + field_n_list[#field_n_list + 1] = entry_data + end + + updates[i] = entry_data + changes[i] = change_map + else + break + end + end + + local prepare_str = update_format_head + local args = {} + local key_len = #key_list + local center_str = select_format_end_list[key_len] + local ucnt = 0 + for field_name, list in pairs(field_name_map) do + ucnt = ucnt + 1 + prepare_str = prepare_str .. '`' .. field_name .. '`=case\n' + for i = 1, #list do + prepare_str = prepare_str .. "when " .. center_str .. " then ?\n" + local entry_data = list[i] + for j = 1, key_len do + local kn = key_list[j] + local kv = entry_data[kn] + args[#args + 1] = kv + end + + local fv = entry_data[field_name] + local ft = field_map[field_name] + if ft == FIELD_TYPE.table then + fv = tab_encode(fv) end - center_str = center_str .. sformat(update_format_head_list[index], field_value) .. ',' + args[#args + 1] = fv + end + + if ucnt < fcnt then + prepare_str = prepare_str .. 'else `' .. field_name .. '`\n end,\n' + else + prepare_str = prepare_str .. 'else `' .. field_name .. '`\n end\n' + end + end + + prepare_str = prepare_str .. updates_format_end .. ' in (' + local uplen = #updates + for i = 1, uplen do + local entry_data = updates[i] + if i == uplen then + prepare_str = prepare_str .. updates_format_key + else + prepare_str = prepare_str .. updates_format_key .. ',' end - center_str = center_str:sub(1,center_str:len() - 1) - local key_values = {} - for i = 1,#key_list do - key_values[i] = entry_data[key_list[i]] + for j = 1, key_len do + local kn = key_list[j] + local kv = entry_data[kn] + args[#args + 1] = kv end - add_str = update_format_head .. center_str .. sformat(update_format_end,tunpack(key_values)) end - if sql_str:len() + add_str:len() > max_packet_size then - --一条都超过 - assert(add_count ~= 0, "can`t update max_packet_size:" .. max_packet_size .. ' packlen:' .. sql_str:len() + add_str:len()) - --超过最大长度了 先插入一波 - sql_str = sql_str:sub(1,sql_str:len() - 1) - sql_ret = self._db:query(sql_str) - handle_sql_ret(ret_list, s_index, index - 1, sql_ret, sql_str) - sql_str = "" - add_count = 0 - s_index = index - elseif index == len then --到结尾了 - sql_str = sql_str .. add_str - sql_ret = self._db:query(sql_str) - handle_sql_ret(ret_list, s_index, index, sql_ret, sql_str) - add_count = 0 - index = index + 1 - s_index = index + prepare_str = prepare_str .. ');' + + if uplen <= 0 then break end + local prepare_obj = new_prepare_obj(prepare_str) + local isok, ret = pcall(prepare_execute, self._db, prepare_obj, tunpack(args)) + if not isok or not ret or ret.err then + log.error("_update_one err ", ret, updates, changes) + for i = 1, uplen do + res_list[ret_index] = false + ret_index = ret_index + 1 + end else - sql_ret = nil - sql_str = sql_str .. add_str .. ';' - index = index + 1 - add_str = nil - add_count = add_count + 1 + for i = 1, uplen do + res_list[ret_index] = true + ret_index = ret_index + 1 + end + end + local stmt = prepare_obj.stmt + if stmt then + pcall(self._db.conn.stmt_close, self._db.conn, stmt) end end - return ret_list + + return res_list end --更新一条数据 self._update_one = function(entry_data, change_map) - local center_str = "" - for field_name in pairs(change_map) do - local index = field_index_map[field_name] + local prepare_obj, index_list = get_update_pre_pare(change_map) + local args = {} + for i = 1, #index_list do + local index = index_list[i] + local field_name = field_list[index] + local field_type = field_map[field_name] local field_value = entry_data[field_name] - if type(field_value) == 'string' then - field_value = string_util.quote_sql_str(field_value) + if field_type == FIELD_TYPE.table then + field_value = tab_encode(field_value) end - center_str = center_str .. sformat(update_format_head_list[index], field_value) .. ',' + args[#args + 1] = field_value end - center_str = center_str:sub(1,center_str:len() - 1) - local key_values = {} + for i = 1,#key_list do - key_values[i] = entry_data[key_list[i]] + args[#args + 1] = entry_data[key_list[i]] end - local sql_str = update_format_head .. center_str .. sformat(update_format_end,tunpack(key_values)) - local sql_ret = self._db:query(sql_str) - if not sql_ret or sql_ret.err then - log.error("_update_one err ",sql_str,sql_ret) - error("_update_one err " .. sql_str) + + local isok, ret = pcall(prepare_execute, self._db, prepare_obj, tunpack(args)) + if not isok or not ret or ret.err then + log.error("_update_one err ", ret, entry_data, change_map) + error("_update_one err ") end return true @@ -612,64 +870,39 @@ function M:builder(tab_name, field_list, field_map, key_list) self._delete = function(key_values) local len = #key_values assert(len >= 0 and len <= #select_format_end_list, "err key_values len " .. len) - local sql_str = nil - if len == 0 then - sql_str = delete_format_head - else - quete_key_values(key_values) - sql_str = delete_format_head .. delete_format_center .. sformat(select_format_end_list[len], tunpack(key_values)) + local prepare_obj = delete_prepare_list[len] + local isok, ret = pcall(prepare_execute, self._db, prepare_obj, tunpack(key_values)) + if not isok or not ret or ret.err then + log.error("_delete err ", ret, key_values) + error("_delete err ") end - local sql_ret = self._db:query(sql_str) - if not sql_ret or sql_ret.err then - log.error("_delete err ",sql_str,sql_ret) - error("_delete err " .. sql_str) - end return true end self._delete_by_range = function(left, right, key_values) local len = #key_values - local end_field_name = key_list[len + 1] - local sql_str = nil - local end_str = nil - local field_type = field_map[end_field_name] - if type(left) == 'string' then - left = string_util.quote_sql_str(left) - end - if type(right) == 'string' then - right = string_util.quote_sql_str(right) + local args = {} + for i = 1, len do + args[#args + 1] = key_values[i] end + local prepare_obj = nil if left and right then - if IS_NUMBER_TYPE[field_type] then - end_str = sformat("`%s` >= %d and `%s` <= %d", end_field_name, left, end_field_name, right) - else - end_str = sformat("`%s` >= '%s' and `%s` <= '%s'", end_field_name, left, end_field_name, right) - end + prepare_obj = delete_range_prepare_list_c[len + 1] + args[#args + 1] = left + args[#args + 1] = right elseif left then - if IS_NUMBER_TYPE[field_type] then - end_str = sformat("`%s` >= %d", end_field_name, left) - else - end_str = sformat("`%s` >= '%s'", end_field_name, left) - end - else - if IS_NUMBER_TYPE[field_type] then - end_str = sformat("`%s` <= %d", end_field_name, right) - else - end_str = sformat("`%s` <= '%s'", end_field_name, right) - end - end - if len > 0 then - quete_key_values(key_values) - sql_str = delete_format_head .. delete_format_center .. sformat(select_format_end_list[len], tunpack(key_values)) .. ' and ' .. end_str + prepare_obj = delete_range_prepare_list_b[len + 1] + args[#args + 1] = left else - sql_str = delete_format_head .. delete_format_center .. end_str + prepare_obj = delete_range_prepare_list_s[len + 1] + args[#args + 1] = right end - local sql_ret = self._db:query(sql_str) - if not sql_ret or sql_ret.err then - log.error("_delete_by_range err ",sql_str,sql_ret) - error("_delete_by_range err " .. sql_str) + local isok, ret = pcall(prepare_execute, self._db, prepare_obj, tunpack(args)) + if not isok or not ret or ret.err then + log.error("_delete_by_range err ", ret, key_values) + error("_delete_by_range err ") end return true diff --git a/test/load_mods_mysql.lua b/test/load_mods_mysql.lua index 8f858502a..c56d0ebfd 100644 --- a/test/load_mods_mysql.lua +++ b/test/load_mods_mysql.lua @@ -4,8 +4,26 @@ local user = "root" local password = "123456" return { + share_config_m = { + launch_seq = 1, + launch_num = 1, + default_arg = { + --直连的配置 + mysql = { + game = { + host = host, + port = port, + max_packet_size = 1048576, + user = user, + password = password, + database = 'gamedb', + } + } + } + }, + mysql_m = { - launch_seq = 4, + launch_seq = 2, launch_num = 4, mod_args = { { @@ -60,7 +78,7 @@ return { }, mysql_test_m = { - launch_seq = 2, + launch_seq = 3, launch_num = 1, } } \ No newline at end of file diff --git a/test/load_mods_ormmysql.lua b/test/load_mods_ormmysql.lua index ef057a361..fbd1f6694 100644 --- a/test/load_mods_ormmysql.lua +++ b/test/load_mods_ormmysql.lua @@ -1,6 +1,23 @@ return { + share_config_m = { + launch_seq = 2, + launch_num = 1, + default_arg = { + mysql = { + admin = { + host = '127.0.0.1', + port = '3306', + max_packet_size = 1048576, + user = 'root', + password = '123456', + database = 'admin', + } + } + } + }, + mysql_m = { - launch_seq = 1, + launch_seq = 2, launch_num = 1, default_arg = { instance_name = "admin", @@ -17,7 +34,7 @@ return { }, ormmysql_test_m = { - launch_seq = 2, + launch_seq = 3, launch_num = 1, } } \ No newline at end of file diff --git a/test/module/mysql_test_m.lua b/test/module/mysql_test_m.lua index 324379d65..c3c6ccb8d 100644 --- a/test/module/mysql_test_m.lua +++ b/test/module/mysql_test_m.lua @@ -43,8 +43,39 @@ local function test() end +local function test_l_new_client() + local db = mysqlf.l_new_client("game") + db.conn:query("drop table if exists user") + + local create_sql = [[ + CREATE TABLE IF NOT EXISTS `user` ( + `id` int(11) NOT NULL, + `name` varchar(255) DEFAULT NULL, + PRIMARY KEY (`id`) + ) ENGINE=InnoDB DEFAULT CHARSET=latin1; + ]] + db.conn:query(create_sql) + + local stmt_insert = db.conn:prepare("INSERT user (id,name) VALUES (?,?),(?,?)") + log.info("stmt_insert >>> ", stmt_insert) + local r = db.conn:execute(stmt_insert,10001, "'", 10002, "'ddd") + log.info("insert ret:", r) + -- os.execute("pkill mysql") + -- log.info("杀掉数据库》》》》》》》》》》》》》") + + skynet.sleep(500) + + -- os.execute("systemctl start mysql") + -- log.info("启动数据库》》》》》》》》》》》》》") + local r = db.conn:execute(stmt_insert,10003, "'", 10004, "'ddd") + log.info("insert ret:", r) + + log.info("select:", db.conn:query("select * from user;")) +end + function CMD.start() - skynet.fork(test) + --skynet.fork(test) + skynet.fork(test_l_new_client) return true end diff --git a/test/module/ormmongo_test_m.lua b/test/module/ormmongo_test_m.lua index 88a52293a..b850aff91 100644 --- a/test/module/ormmongo_test_m.lua +++ b/test/module/ormmongo_test_m.lua @@ -1455,7 +1455,6 @@ end --压测 --stress testing ---用skynet.queue qps = 823 local function stress_testing() delete_table() local adapter = ormadapter_mongo:new("admin") @@ -1697,6 +1696,114 @@ local function test_get_entry_limit() end end delete_table() + + --测试1个key + delete_table() + --测试有缓存的 + local adapter = ormadapter_mongo:new("admin") + local orm_obj = ormtable:new("t_player") + :int64("player_id") + :int64("role_id") + :int8("sex") + :set_keys("player_id") + :set_cache(500,500) --5秒保存一次 + :builder(adapter) + + for i = 1, 100 do + orm_obj:create_one_entry({player_id = i}) + end + + --测试升序 + local entry_list = nil + local curson = nil + local count = nil + for i = 1, 10 do + curson, entry_list, count = orm_obj:get_entry_by_limit(curson, 10, 1) + assert(curson == i * 10) + + if i == 1 then + assert(count == 100) + else + assert(not count) + end + + for k,v in ipairs(entry_list) do + assert(v:get('player_id') == (i - 1) * 10 + k) + end + end + + --测试降序 + local entry_list = nil + local curson = nil + local count = nil + for i = 1, 10 do + curson, entry_list, count = orm_obj:get_entry_by_limit(curson, 10, -1) + assert(curson == i * 10) + + if i == 1 then + assert(count == 100) + else + assert(not count) + end + + for k,v in ipairs(entry_list) do + assert(v:get('player_id') == (10 - i + 1) * 10 - (k - 1)) + end + end + + delete_table() + + --测试没有缓存的 + local adapter = ormadapter_mongo:new("admin") + local orm_obj = ormtable:new("t_player") + :int64("player_id") + :int64("role_id") + :int8("sex") + :set_keys("player_id") + :builder(adapter) + + for i = 1, 100 do + orm_obj:create_one_entry({player_id = i}) + end + + --测试升序 + local entry_list = nil + local curson = nil + local count = nil + for i = 1, 10 do + curson, entry_list, count = orm_obj:get_entry_by_limit(curson, 10, 1) + assert(curson == i * 10) + + if i == 1 then + assert(count == 100) + else + assert(not count) + end + + for k,v in ipairs(entry_list) do + assert(v:get('player_id') == (i - 1) * 10 + k) + end + end + + --测试降序 + local entry_list = nil + local curson = nil + local count = nil + for i = 1, 10 do + curson, entry_list, count = orm_obj:get_entry_by_limit(curson, 10, -1) + assert(curson == i * 10) + + if i == 1 then + assert(count == 100) + else + assert(not count) + end + + for k,v in ipairs(entry_list) do + assert(v:get('player_id') == (10 - i + 1) * 10 - (k - 1)) + end + end + delete_table() end local function test_delete_by_range() @@ -1773,6 +1880,35 @@ local function test_delete_by_range() delete_table() end +--测试table +local function test_table_type() + delete_table() + + local adapter = ormadapter_mongo:new("admin") + local orm_obj = ormtable:new("t_player") + :int64("player_id") + :table("info") + :set_keys("player_id") + :set_cache(500,500) --5秒保存一次 + :builder(adapter) + + local entry = orm_obj:create_one_entry({player_id = 10001, info = {a = 1, b = 2, c = "'"}}) + assert(entry) + local info = entry:get('info') + info.d = 100 + info.c = nil + entry:set('info', info) + orm_obj:save_one_entry(entry) + + local entry = orm_obj:get_one_entry(10001) + assert(entry) + local info = entry:get('info') + assert(info.c == nil) + assert(info.d == 100) + + delete_table() +end + function CMD.start() skynet.fork(function() delete_table() @@ -1825,7 +1961,8 @@ function CMD.start() test_get_entry_limit() log.info("test_delete_by_range") test_delete_by_range() - + log.info("test_table_type") + test_table_type() delete_table() log.info("test over >>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>") end) diff --git a/test/module/ormmysql_test_m.lua b/test/module/ormmysql_test_m.lua index d92bb15a4..0a737d6ac 100644 --- a/test/module/ormmysql_test_m.lua +++ b/test/module/ormmysql_test_m.lua @@ -39,6 +39,7 @@ local function test_create_table(is_del) :string8192("sex13") :text("sex14") :blob("sex15") + :table("info") :set_keys("player_id","role_id","sex") :builder(adapter) @@ -84,6 +85,7 @@ local function test_alter_table() :int64("role_id") :int8("sex") :int8("nickname1") + :table("info2") :set_keys("player_id","role_id","sex") :builder(adapter) @@ -113,7 +115,9 @@ local function test_alter_table() assert(sqlret[18].Field == 'sex13' and sqlret[18].Type == 'varchar(8192)') assert(sqlret[19].Field == 'sex14' and sqlret[19].Type == 'text') assert(sqlret[20].Field == 'sex15' and sqlret[20].Type == 'blob') - assert(sqlret[21].Field == 'nickname1' and sqlret[21].Type == 'tinyint') + assert(sqlret[21].Field == 'info' and sqlret[21].Type == 'blob') + assert(sqlret[22].Field == 'nickname1' and sqlret[22].Type == 'tinyint') + assert(sqlret[23].Field == 'info2' and sqlret[23].Type == 'blob') delete_table() end @@ -1515,7 +1519,6 @@ end --压测 --stress testing ---用skynet.queue qps = 823 local function stress_testing() delete_table() local adapter = ormadapter_mysql:new("admin") @@ -1757,6 +1760,114 @@ local function test_get_entry_limit() end end delete_table() + + --测试1个key + delete_table() + --测试有缓存的 + local adapter = ormadapter_mysql:new("admin") + local orm_obj = ormtable:new("t_player") + :int64("player_id") + :int64("role_id") + :int8("sex") + :set_keys("player_id") + :set_cache(500,500) --5秒保存一次 + :builder(adapter) + + for i = 1, 100 do + orm_obj:create_one_entry({player_id = i}) + end + + --测试升序 + local entry_list = nil + local curson = nil + local count = nil + for i = 1, 10 do + curson, entry_list, count = orm_obj:get_entry_by_limit(curson, 10, 1) + assert(curson == i * 10) + + if i == 1 then + assert(count == 100) + else + assert(not count) + end + + for k,v in ipairs(entry_list) do + assert(v:get('player_id') == (i - 1) * 10 + k) + end + end + + --测试降序 + local entry_list = nil + local curson = nil + local count = nil + for i = 1, 10 do + curson, entry_list, count = orm_obj:get_entry_by_limit(curson, 10, -1) + assert(curson == (10 - i) * 10 + 1) + + if i == 1 then + assert(count == 100) + else + assert(not count) + end + + for k,v in ipairs(entry_list) do + assert(v:get('player_id') == (10 - i + 1) * 10 - (k - 1)) + end + end + + delete_table() + + --测试没有缓存的 + local adapter = ormadapter_mysql:new("admin") + local orm_obj = ormtable:new("t_player") + :int64("player_id") + :int64("role_id") + :int8("sex") + :set_keys("player_id") + :builder(adapter) + + for i = 1, 100 do + orm_obj:create_one_entry({player_id = i}) + end + + --测试升序 + local entry_list = nil + local curson = nil + local count = nil + for i = 1, 10 do + curson, entry_list, count = orm_obj:get_entry_by_limit(curson, 10, 1) + assert(curson == i * 10) + + if i == 1 then + assert(count == 100) + else + assert(not count) + end + + for k,v in ipairs(entry_list) do + assert(v:get('player_id') == (i - 1) * 10 + k) + end + end + + --测试降序 + local entry_list = nil + local curson = nil + local count = nil + for i = 1, 10 do + curson, entry_list, count = orm_obj:get_entry_by_limit(curson, 10, -1) + assert(curson == (10 - i) * 10 + 1) + + if i == 1 then + assert(count == 100) + else + assert(not count) + end + + for k,v in ipairs(entry_list) do + assert(v:get('player_id') == (10 - i + 1) * 10 - (k - 1)) + end + end + delete_table() end local function test_delete_by_range() @@ -1853,6 +1964,35 @@ local function test_quete_key_values() mysqlf:instance("admin"):query("drop table if exists t_user") end +--测试table +local function test_table_type() + delete_table() + + local adapter = ormadapter_mysql:new("admin") + local orm_obj = ormtable:new("t_player") + :int64("player_id") + :table("info") + :set_keys("player_id") + :set_cache(500,500) --5秒保存一次 + :builder(adapter) + + local entry = orm_obj:create_one_entry({player_id = 10001, info = {a = 1, b = 2, c = "'"}}) + assert(entry) + local info = entry:get('info') + info.d = 100 + info.c = nil + entry:set('info', info) + orm_obj:save_one_entry(entry) + + local entry = orm_obj:get_one_entry(10001) + assert(entry) + local info = entry:get('info') + assert(info.c == nil) + assert(info.d == 100) + + delete_table() +end + function CMD.start() skynet.fork(function() delete_table() @@ -1907,6 +2047,8 @@ function CMD.start() test_delete_by_range() log.info("test_quete_key_values") test_quete_key_values() + log.info("test_table_type") + test_table_type() delete_table() log.info("test over >>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>") end)