From 560c86275c41e5c26ac71aedf037ffe815575323 Mon Sep 17 00:00:00 2001
From: fuhui
Date: Tue, 9 Apr 2024 14:31:54 +0800
Subject: [PATCH] add redis vector store

---
 examples/flask/register.py                | 21 ++++++
 flask4modelcache.py                       | 29 +++++--
 modelcache/adapter/adapter.py             | 11 +++
 modelcache/adapter/adapter_register.py    | 13 ++++
 modelcache/manager/data_manager.py        |  3 +
 modelcache/manager/vector_data/manager.py | 22 ++++++
 modelcache/manager/vector_data/redis.py   | 92 +++++++----------------
 modelcache/utils/__init__.py              |  4 +
 modelcache/utils/index_util.py            |  9 +++
 requirements.txt                          |  2 +
 10 files changed, 133 insertions(+), 73 deletions(-)
 create mode 100644 examples/flask/register.py
 create mode 100644 modelcache/adapter/adapter_register.py
 create mode 100644 modelcache/utils/index_util.py

diff --git a/examples/flask/register.py b/examples/flask/register.py
new file mode 100644
index 0000000..737b495
--- /dev/null
+++ b/examples/flask/register.py
@@ -0,0 +1,21 @@
+# -*- coding: utf-8 -*-
+"""
+register index for redis
+"""
+import json
+import requests
+
+
+def run():
+    url = 'http://127.0.0.1:5000/modelcache'
+    type = 'register'
+    scope = {"model": "CODEGPT-1117"}
+    data = {'type': type, 'scope': scope}
+    headers = {"Content-Type": "application/json"}
+    res = requests.post(url, headers=headers, json=json.dumps(data))
+    res_text = res.text
+    print('res_text: {}'.format(res_text))
+
+
+if __name__ == '__main__':
+    run()
\ No newline at end of file
diff --git a/flask4modelcache.py b/flask4modelcache.py
index cde579d..8a3efa2 100644
--- a/flask4modelcache.py
+++ b/flask4modelcache.py
@@ -1,6 +1,5 @@
 # -*- coding: utf-8 -*-
 import time
-from datetime import datetime
 from flask import Flask, request
 import logging
 import configparser
@@ -15,7 +14,6 @@
 from modelcache.utils.model_filter import model_blacklist_filter
 from modelcache.embedding import Data2VecAudio
 
 
-
 # 创建一个Flask实例
 app = Flask(__name__)
@@ -36,13 +34,19 @@ def response_hitquery(cache_resp):
 data2vec = Data2VecAudio()
 
 mysql_config = configparser.ConfigParser()
 mysql_config.read('modelcache/config/mysql_config.ini')
+
 milvus_config = configparser.ConfigParser()
 milvus_config.read('modelcache/config/milvus_config.ini')
 
-# data_manager = get_data_manager(CacheBase("mysql", config=mysql_config),
-#                                 VectorBase("milvus", dimension=data2vec.dimension, milvus_config=milvus_config))
+
+# redis_config = configparser.ConfigParser()
+# redis_config.read('modelcache/config/redis_config.ini')
+
 data_manager = get_data_manager(CacheBase("mysql", config=mysql_config),
-                                VectorBase("redis", dimension=data2vec.dimension, milvus_config=milvus_config))
+                                VectorBase("milvus", dimension=data2vec.dimension, milvus_config=milvus_config))
+
+# data_manager = get_data_manager(CacheBase("mysql", config=mysql_config),
+#                                 VectorBase("redis", dimension=data2vec.dimension, redis_config=redis_config))
 
 cache.init(
@@ -88,9 +92,9 @@ def user_backend():
     model = model.replace('.', '_')
     query = param_dict.get("query")
    chat_info = param_dict.get("chat_info")
-    if request_type is None or request_type not in ['query', 'insert', 'detox', 'remove']:
+    if request_type is None or request_type not in ['query', 'insert', 'remove', 'register']:
         result = {"errorCode": 102,
-                  "errorDesc": "type exception, should one of ['query', 'insert', 'detox', 'remove']",
+                  "errorDesc": "type exception, should one of ['query', 'insert', 'remove', 'register']",
                   "cacheHit": False, "delta_time": 0, "hit_query": '', "answer": ''}
         cache.data_manager.save_query_resp(result, model=model, query='', delta_time=0)
         return json.dumps(result)
@@ -173,6 +177,17 @@ def user_backend():
             result = {"errorCode": 402, "errorDesc": "", "response": response, "writeStatus": "exception"}
         return json.dumps(result)
 
+    if request_type == 'register':
+        # iat_type = param_dict.get("iat_type")
+        response = adapter.ChatCompletion.create_register(
+            model=model
+        )
+        if response in ['create_success', 'already_exists']:
+            result = {"errorCode": 0, "errorDesc": "", "response": response, "writeStatus": "success"}
+        else:
+            result = {"errorCode": 502, "errorDesc": "", "response": response, "writeStatus": "exception"}
+        return json.dumps(result)
+
 
 if __name__ == '__main__':
     app.run(host='0.0.0.0', port=5000, debug=True)
diff --git a/modelcache/adapter/adapter.py b/modelcache/adapter/adapter.py
index f5e38eb..1428da2 100644
--- a/modelcache/adapter/adapter.py
+++ b/modelcache/adapter/adapter.py
@@ -5,6 +5,7 @@
 from modelcache.adapter.adapter_query import adapt_query
 from modelcache.adapter.adapter_insert import adapt_insert
 from modelcache.adapter.adapter_remove import adapt_remove
+from modelcache.adapter.adapter_register import adapt_register
 
 
 class ChatCompletion(openai.ChatCompletion):
@@ -44,6 +45,16 @@ def create_remove(cls, *args, **kwargs):
             logging.info('adapt_remove_e: {}'.format(e))
             return str(e)
 
+    @classmethod
+    def create_register(cls, *args, **kwargs):
+        try:
+            return adapt_register(
+                *args,
+                **kwargs
+            )
+        except Exception as e:
+            return str(e)
+
 
 def construct_resp_from_cache(return_message, return_query):
     return {
diff --git a/modelcache/adapter/adapter_register.py b/modelcache/adapter/adapter_register.py
new file mode 100644
index 0000000..53df128
--- /dev/null
+++ b/modelcache/adapter/adapter_register.py
@@ -0,0 +1,13 @@
+# -*- coding: utf-8 -*-
+from modelcache import cache
+
+
+def adapt_register(*args, **kwargs):
+    chat_cache = kwargs.pop("cache_obj", cache)
+    model = kwargs.pop("model", None)
+    if model is None or len(model) == 0:
+        return ValueError('')
+
+    register_resp = chat_cache.data_manager.create_index(model)
+    print('register_resp: {}'.format(register_resp))
+    return register_resp
diff --git a/modelcache/manager/data_manager.py b/modelcache/manager/data_manager.py
index db8b776..a83e638 100644
--- a/modelcache/manager/data_manager.py
+++ b/modelcache/manager/data_manager.py
@@ -256,6 +256,9 @@ def delete(self, id_list, **kwargs):
         return {'status': 'success', 'milvus': 'delete_count: '+str(v_delete_count),
                 'mysql': 'delete_count: '+str(s_delete_count)}
 
+    def create_index(self, model, **kwargs):
+        return self.v.create(model)
+
     def truncate(self, model_name):
         # model = kwargs.pop("model", None)
         # drop milvus data
diff --git a/modelcache/manager/vector_data/manager.py b/modelcache/manager/vector_data/manager.py
index 54f7c55..70448b2 100644
--- a/modelcache/manager/vector_data/manager.py
+++ b/modelcache/manager/vector_data/manager.py
@@ -68,6 +68,28 @@ def get(name, **kwargs):
                 local_mode=local_mode,
                 local_data=local_data
             )
+        elif name == "redis":
+            from modelcache.manager.vector_data.redis import RedisVectorStore
+            dimension = kwargs.get("dimension", DIMENSION)
+            VectorBase.check_dimension(dimension)
+
+            redis_config = kwargs.get("redis_config")
+            host = redis_config.get('redis', 'host')
+            port = redis_config.get('redis', 'port')
+            user = redis_config.get('redis', 'user')
+            password = redis_config.get('redis', 'password')
+            namespace = kwargs.get("namespace", "")
+            # collection_name = kwargs.get("collection_name", COLLECTION_NAME)
+
+            vector_base = RedisVectorStore(
+                host=host,
+                port=port,
+                username=user,
+                password=password,
+                namespace=namespace,
+                top_k=top_k,
+                dimension=dimension,
+            )
         elif name == "faiss":
             from modelcache.manager.vector_data.faiss import Faiss
diff --git a/modelcache/manager/vector_data/redis.py b/modelcache/manager/vector_data/redis.py
index fdd2ae8..e8fd3f4 100644
--- a/modelcache/manager/vector_data/redis.py
+++ b/modelcache/manager/vector_data/redis.py
@@ -6,11 +6,11 @@
 from redis.commands.search.field import TagField, VectorField, NumericField
 from redis.client import Redis
 
-from gptcache.manager.vector_data.base import VectorBase, VectorData
-from gptcache.utils import import_redis
-from gptcache.utils.log import gptcache_log
-from gptcache.utils.collection_util import get_collection_name
-from gptcache.utils.collection_util import get_collection_prefix
+from modelcache.manager.vector_data.base import VectorBase, VectorData
+from modelcache.utils import import_redis
+from modelcache.utils.log import modelcache_log
+from modelcache.utils.index_util import get_index_name
+from modelcache.utils.index_util import get_index_prefix
 
 import_redis()
 
@@ -21,9 +21,7 @@ def __init__(
         port: str = "6379",
         username: str = "",
         password: str = "",
-        table_suffix: str = "",
         dimension: int = 0,
-        collection_prefix: str = "gptcache",
         top_k: int = 1,
         namespace: str = "",
     ):
@@ -36,33 +34,28 @@ def __init__(
         )
         self.top_k = top_k
         self.dimension = dimension
-        self.collection_prefix = collection_prefix
-        self.table_suffix = table_suffix
         self.namespace = namespace
-        self.doc_prefix = f"{self.namespace}doc:"  # Prefix with the specified namespace
-        # self._create_collection(collection_name)
+        self.doc_prefix = f"{self.namespace}doc:"
 
     def _check_index_exists(self, index_name: str) -> bool:
         """Check if Redis index exists."""
         try:
             self._client.ft(index_name).info()
-        except:  # pylint: disable=W0702
-            gptcache_log.info("Index does not exist")
+        except:
+            modelcache_log.info("Index does not exist")
             return False
-        gptcache_log.info("Index already exists")
+        modelcache_log.info("Index already exists")
         return True
 
-    def create_collection(self, collection_name, index_prefix):
+    def create_index(self, index_name, index_prefix):
         dimension = self.dimension
         print('dimension: {}'.format(dimension))
-        if self._check_index_exists(collection_name):
-            gptcache_log.info(
-                "The %s already exists, and it will be used directly", collection_name
+        if self._check_index_exists(index_name):
+            modelcache_log.info(
+                "The %s already exists, and it will be used directly", index_name
             )
             return 'already_exists'
         else:
-            # id_field_name = collection_name + '_' + "id"
-            # embedding_field_name = collection_name + '_' + "vec"
             id_field_name = "data_id"
             embedding_field_name = "data_vector"
@@ -76,11 +69,10 @@ def create_collection(self, collection_name, index_prefix):
                 }
             )
             fields = [id, embedding]
-            # definition = IndexDefinition(index_type=IndexType.HASH)
             definition = IndexDefinition(prefix=[index_prefix], index_type=IndexType.HASH)
 
             # create Index
-            self._client.ft(collection_name).create_index(
+            self._client.ft(index_name).create_index(
                 fields=fields, definition=definition
             )
             return 'create_success'
@@ -90,23 +82,14 @@ def mul_add(self, datas: List[VectorData], model=None):
         for data in datas:
             id: int = data.id
             embedding = data.data.astype(np.float32).tobytes()
-            # id_field_name = collection_name + '_' + "id"
-            # embedding_field_name = collection_name + '_' + "vec"
             id_field_name = "data_id"
             embedding_field_name = "data_vector"
             obj = {id_field_name: id, embedding_field_name: embedding}
-            index_prefix = get_collection_prefix(model, self.table_suffix)
+            index_prefix = get_index_prefix(model)
             self._client.hset(f"{index_prefix}{id}", mapping=obj)
 
-        # obj = {
-        #     "vector": data.data.astype(np.float32).tobytes(),
-        # }
-        # pipe.hset(f"{self.doc_prefix}{key}", mapping=obj)
-        # pipe.execute()
-
     def search(self, data: np.ndarray, top_k: int = -1, model=None):
-        collection_name = get_collection_name(model, self.table_suffix)
-        print('collection_name: {}'.format(collection_name))
+        index_name = get_index_name(model)
         id_field_name = "data_id"
         embedding_field_name = "data_vector"
@@ -119,53 +102,30 @@ def search(self, data: np.ndarray, top_k: int = -1, model=None):
         )
         query_params = {"vector": data.astype(np.float32).tobytes()}
-        # print('query_params: {}'.format(query_params))
         results = (
-            self._client.ft(collection_name)
+            self._client.ft(index_name)
             .search(query, query_params=query_params)
             .docs
         )
-        print('results: {}'.format(results))
-        for i, doc in enumerate(results):
-            print('doc: {}'.format(doc))
-            print("id_field_name", getattr(doc, id_field_name), ", distance: ", doc.distance)
         return [(float(result.distance), int(getattr(result, id_field_name))) for result in results]
 
     def rebuild(self, ids=None) -> bool:
         pass
 
     def rebuild_col(self, model):
-        resp_info = 'failed'
-        if len(self.table_suffix) == 0:
-            raise ValueError('table_suffix is none error,please check!')
-
-        collection_name_model = get_collection_name(model, self.table_suffix)
-        print('collection_name_model: {}'.format(collection_name_model))
-        if self._check_index_exists(collection_name_model):
+        index_name_model = get_index_name(model)
+        if self._check_index_exists(index_name_model):
             try:
-                self._client.ft(collection_name_model).dropindex(delete_documents=True)
+                self._client.ft(index_name_model).dropindex(delete_documents=True)
             except Exception as e:
                 raise ValueError(str(e))
         try:
-            index_prefix = get_collection_prefix(model, self.table_suffix)
-            self.create_collection(collection_name_model, index_prefix)
+            index_prefix = get_index_prefix(model)
+            self.create_index(index_name_model, index_prefix)
         except Exception as e:
             raise ValueError(str(e))
         return 'rebuild success'
-        # print('remove collection_name_model: {}'.format(collection_name_model))
-        # try:
-        #     self._client.ft(collection_name_model).dropindex(delete_documents=True)
-        #     resp_info = 'rebuild success'
-        # except Exception as e:
-        #     print('exception: {}'.format(e))
-        #     resp_info = 'create only'
-        # try:
-        #     self.create_collection(collection_name_model)
-        # except Exception as e:
-        #     raise ValueError(str(e))
-        # return resp_info
 
     def delete(self, ids) -> None:
         pipe = self._client.pipeline()
         for data_id in ids:
@@ -173,9 +133,9 @@ def delete(self, ids) -> None:
         pipe.execute()
 
     def create(self, model=None):
-        collection_name = get_collection_name(model, self.table_suffix)
-        index_prefix = get_collection_prefix(model, self.table_suffix)
-        return self.create_collection(collection_name, index_prefix)
+        index_name = get_index_name(model)
+        index_prefix = get_index_prefix(model)
+        return self.create_index(index_name, index_prefix)
 
-    def get_collection_by_name(self, collection_name, table_suffix):
+    def get_index_by_name(self, index_name):
         pass
diff --git a/modelcache/utils/__init__.py b/modelcache/utils/__init__.py
index 425b926..147a56e 100644
--- a/modelcache/utils/__init__.py
+++ b/modelcache/utils/__init__.py
@@ -69,3 +69,7 @@ def import_timm():
 
 def import_pillow():
     _check_library("PIL", package="pillow")
+
+
+def import_redis():
+    _check_library("redis")
diff --git a/modelcache/utils/index_util.py b/modelcache/utils/index_util.py
new file mode 100644
index 0000000..be6e856
--- /dev/null
+++ b/modelcache/utils/index_util.py
@@ -0,0 +1,9 @@
+# -*- coding: utf-8 -*-
+
+
+def get_index_name(model):
+    return 'modelcache' + '_' + model
+
+
+def get_index_prefix(model):
+    return 'prefix' + '_' + model
diff --git a/requirements.txt b/requirements.txt
index e622636..3bf85e6 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -10,3 +10,5 @@ Requests==2.31.0
 torch==2.1.0
 transformers==4.34.1
 faiss-cpu==1.7.4
+redis==5.0.1
+
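
Note on wiring the new store: the "redis" branch added to modelcache/manager/vector_data/manager.py reads host, port, user, and password from a [redis] section of a config object, and the commented-out lines in flask4modelcache.py point at modelcache/config/redis_config.ini. Below is a minimal sketch, not part of the patch, of switching the Flask service over to the Redis vector store; it assumes that ini file exists with those keys and that CacheBase, VectorBase, and get_data_manager are importable from modelcache.manager, as flask4modelcache.py already uses them.

    # Sketch only. Assumes modelcache/config/redis_config.ini contains a
    # [redis] section with host, port, user and password keys, which is
    # what VectorBase("redis", ...) reads in vector_data/manager.py.
    import configparser

    from modelcache.embedding import Data2VecAudio
    from modelcache.manager import CacheBase, VectorBase, get_data_manager

    data2vec = Data2VecAudio()

    mysql_config = configparser.ConfigParser()
    mysql_config.read('modelcache/config/mysql_config.ini')

    redis_config = configparser.ConfigParser()
    redis_config.read('modelcache/config/redis_config.ini')

    # Same call shape as the commented-out block in flask4modelcache.py.
    data_manager = get_data_manager(CacheBase("mysql", config=mysql_config),
                                    VectorBase("redis", dimension=data2vec.dimension,
                                               redis_config=redis_config))

With Redis active, a client sends one 'register' request per model before inserting or querying (examples/flask/register.py does exactly this), so that RedisVectorStore.create_index has created the index and prefix that search and mul_add use for that model.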