add redis vector store
peng3307165 committed Apr 9, 2024
1 parent 59def83 commit 560c862
Showing 10 changed files with 133 additions and 73 deletions.
21 changes: 21 additions & 0 deletions examples/flask/register.py
@@ -0,0 +1,21 @@
# -*- coding: utf-8 -*-
"""
register index for redis
"""
import json
import requests


def run():
url = 'http://127.0.0.1:5000/modelcache'
type = 'register'
scope = {"model": "CODEGPT-1117"}
data = {'type': type, 'scope': scope}
headers = {"Content-Type": "application/json"}
res = requests.post(url, headers=headers, json=json.dumps(data))
res_text = res.text
print('res_text: {}'.format(res_text))


if __name__ == '__main__':
run()
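
The same pattern extends to the other request types accepted by flask4modelcache.py. As a companion to this register script, here is a hedged sketch of a query client in the same style; only the endpoint, type, and scope fields are confirmed by this commit, and the shape of the query payload is an assumption for illustration:

# -*- coding: utf-8 -*-
"""
query sketch for a registered model (illustrative; payload shape assumed)
"""
import json
import requests


def run():
    url = 'http://127.0.0.1:5000/modelcache'
    request_type = 'query'
    scope = {"model": "CODEGPT-1117"}
    query = [{"role": "user", "content": "hello"}]  # assumed shape
    data = {'type': request_type, 'scope': scope, 'query': query}
    headers = {"Content-Type": "application/json"}
    # json.dumps mirrors register.py above: the body is a JSON-encoded string
    res = requests.post(url, headers=headers, json=json.dumps(data))
    print('res_text: {}'.format(res.text))


if __name__ == '__main__':
    run()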
29 changes: 22 additions & 7 deletions flask4modelcache.py
@@ -1,6 +1,5 @@
# -*- coding: utf-8 -*-
import time
from datetime import datetime
from flask import Flask, request
import logging
import configparser
@@ -15,7 +14,6 @@
from modelcache.utils.model_filter import model_blacklist_filter
from modelcache.embedding import Data2VecAudio


# Create a Flask instance
app = Flask(__name__)

@@ -36,13 +34,19 @@ def response_hitquery(cache_resp):
data2vec = Data2VecAudio()
mysql_config = configparser.ConfigParser()
mysql_config.read('modelcache/config/mysql_config.ini')

milvus_config = configparser.ConfigParser()
milvus_config.read('modelcache/config/milvus_config.ini')
# data_manager = get_data_manager(CacheBase("mysql", config=mysql_config),
# VectorBase("milvus", dimension=data2vec.dimension, milvus_config=milvus_config))

# redis_config = configparser.ConfigParser()
# redis_config.read('modelcache/config/redis_config.ini')


data_manager = get_data_manager(CacheBase("mysql", config=mysql_config),
VectorBase("redis", dimension=data2vec.dimension, milvus_config=milvus_config))
VectorBase("milvus", dimension=data2vec.dimension, milvus_config=milvus_config))

# data_manager = get_data_manager(CacheBase("mysql", config=mysql_config),
# VectorBase("redis", dimension=data2vec.dimension, redis_config=redis_config))


cache.init(
@@ -88,9 +92,9 @@ def user_backend():
model = model.replace('.', '_')
query = param_dict.get("query")
chat_info = param_dict.get("chat_info")
if request_type is None or request_type not in ['query', 'insert', 'detox', 'remove']:
if request_type is None or request_type not in ['query', 'insert', 'remove', 'register']:
result = {"errorCode": 102,
"errorDesc": "type exception, should one of ['query', 'insert', 'detox', 'remove']",
"errorDesc": "type exception, should one of ['query', 'insert', 'remove', 'register']",
"cacheHit": False, "delta_time": 0, "hit_query": '', "answer": ''}
cache.data_manager.save_query_resp(result, model=model, query='', delta_time=0)
return json.dumps(result)
@@ -173,6 +177,17 @@ def user_backend():
result = {"errorCode": 402, "errorDesc": "", "response": response, "writeStatus": "exception"}
return json.dumps(result)

if request_type == 'register':
# iat_type = param_dict.get("iat_type")
response = adapter.ChatCompletion.create_register(
model=model
)
if response in ['create_success', 'already_exists']:
result = {"errorCode": 0, "errorDesc": "", "response": response, "writeStatus": "success"}
else:
result = {"errorCode": 502, "errorDesc": "", "response": response, "writeStatus": "exception"}
return json.dumps(result)


if __name__ == '__main__':
app.run(host='0.0.0.0', port=5000, debug=True)
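
To smoke-test the new register branch without a separate client process, Flask's built-in test client can post the same double-encoded payload; a minimal sketch, assuming the module imports cleanly (that is, the MySQL and vector-store backends named in the INI configs are reachable):

# minimal register smoke test via Flask's test client (illustrative)
import json
from flask4modelcache import app


def test_register():
    payload = json.dumps({'type': 'register', 'scope': {'model': 'CODEGPT-1117'}})
    resp = app.test_client().post('/modelcache', json=payload)
    result = json.loads(resp.get_data(as_text=True))
    # 'create_success' or 'already_exists' both map to errorCode 0
    print(result['errorCode'], result.get('response'))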
11 changes: 11 additions & 0 deletions modelcache/adapter/adapter.py
@@ -5,6 +5,7 @@
from modelcache.adapter.adapter_query import adapt_query
from modelcache.adapter.adapter_insert import adapt_insert
from modelcache.adapter.adapter_remove import adapt_remove
from modelcache.adapter.adapter_register import adapt_register


class ChatCompletion(openai.ChatCompletion):
@@ -44,6 +45,16 @@ def create_remove(cls, *args, **kwargs):
logging.info('adapt_remove_e: {}'.format(e))
return str(e)

@classmethod
def create_register(cls, *args, **kwargs):
try:
return adapt_register(
*args,
**kwargs
)
except Exception as e:
return str(e)


def construct_resp_from_cache(return_message, return_query):
return {
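create_register mirrors create_remove: exceptions are caught and returned as strings, so callers branch on the return value instead of wrapping the call in try/except. A hedged usage sketch of that contract:

# illustrative usage of the string-return contract (no exceptions leak out)
from modelcache.adapter.adapter import ChatCompletion

response = ChatCompletion.create_register(model='CODEGPT-1117')
if response in ['create_success', 'already_exists']:
    print('index ready')
else:
    # on failure, create_register returns str(e) from the raised exception
    print('register failed: {}'.format(response))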
13 changes: 13 additions & 0 deletions modelcache/adapter/adapter_register.py
@@ -0,0 +1,13 @@
# -*- coding: utf-8 -*-
from modelcache import cache


def adapt_register(*args, **kwargs):
chat_cache = kwargs.pop("cache_obj", cache)
model = kwargs.pop("model", None)
if model is None or len(model) == 0:
raise ValueError('model name is required for register')

register_resp = chat_cache.data_manager.create_index(model)
print('register_resp: {}'.format(register_resp))
return register_resp
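
Because create_register converts any exception to str(e), the ValueError raised here for a missing model surfaces as the response string in the errorCode 502 reply. A small sketch of that failure path:

# illustrative failure path: no model supplied
from modelcache.adapter.adapter_register import adapt_register

try:
    adapt_register()  # model defaults to None, so this raises
except ValueError as e:
    # via ChatCompletion.create_register this string becomes the 'response'
    # field of the errorCode 502 JSON reply in flask4modelcache.py
    print('register rejected: {}'.format(e))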
3 changes: 3 additions & 0 deletions modelcache/manager/data_manager.py
@@ -256,6 +256,9 @@ def delete(self, id_list, **kwargs):
return {'status': 'success', 'milvus': 'delete_count: '+str(v_delete_count),
'mysql': 'delete_count: '+str(s_delete_count)}

def create_index(self, model, **kwargs):
return self.v.create(model)

def truncate(self, model_name):
# model = kwargs.pop("model", None)
# drop milvus data
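Since create_index is a one-line delegation, the full register call chain added by this commit stays short; sketched here for orientation:

# call chain for a 'register' request, top to bottom (orientation only):
# flask4modelcache.user_backend              -- request_type == 'register'
#   -> ChatCompletion.create_register(model=model)
#     -> adapt_register(model=model)
#       -> data_manager.create_index(model)
#         -> self.v.create(model)            -- e.g. RedisVectorStore.create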
22 changes: 22 additions & 0 deletions modelcache/manager/vector_data/manager.py
@@ -68,6 +68,28 @@ def get(name, **kwargs):
local_mode=local_mode,
local_data=local_data
)
elif name == "redis":
from modelcache.manager.vector_data.redis import RedisVectorStore
dimension = kwargs.get("dimension", DIMENSION)
VectorBase.check_dimension(dimension)

redis_config = kwargs.get("redis_config")
host = redis_config.get('redis', 'host')
port = redis_config.get('redis', 'port')
user = redis_config.get('redis', 'user')
password = redis_config.get('redis', 'password')
namespace = kwargs.get("namespace", "")
# collection_name = kwargs.get("collection_name", COLLECTION_NAME)

vector_base = RedisVectorStore(
host=host,
port=port,
username=user,
password=password,
namespace=namespace,
top_k=top_k,
dimension=dimension,
)
elif name == "faiss":
from modelcache.manager.vector_data.faiss import Faiss

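The get() calls above imply a redis_config.ini shaped roughly as follows; the section and key names follow from the code, while the values are placeholders:

[redis]
host = 127.0.0.1
port = 6379
user = default
password = your_password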
92 changes: 26 additions & 66 deletions modelcache/manager/vector_data/redis.py
@@ -6,11 +6,11 @@
from redis.commands.search.field import TagField, VectorField, NumericField
from redis.client import Redis

from gptcache.manager.vector_data.base import VectorBase, VectorData
from gptcache.utils import import_redis
from gptcache.utils.log import gptcache_log
from gptcache.utils.collection_util import get_collection_name
from gptcache.utils.collection_util import get_collection_prefix
from modelcache.manager.vector_data.base import VectorBase, VectorData
from modelcache.utils import import_redis
from modelcache.utils.log import modelcache_log
from modelcache.utils.index_util import get_index_name
from modelcache.utils.index_util import get_index_prefix
import_redis()


@@ -21,9 +21,7 @@ def __init__(
port: str = "6379",
username: str = "",
password: str = "",
table_suffix: str = "",
dimension: int = 0,
collection_prefix: str = "gptcache",
top_k: int = 1,
namespace: str = "",
):
@@ -36,33 +34,28 @@
)
self.top_k = top_k
self.dimension = dimension
self.collection_prefix = collection_prefix
self.table_suffix = table_suffix
self.namespace = namespace
self.doc_prefix = f"{self.namespace}doc:" # Prefix with the specified namespace
# self._create_collection(collection_name)
self.doc_prefix = f"{self.namespace}doc:"

def _check_index_exists(self, index_name: str) -> bool:
"""Check if Redis index exists."""
try:
self._client.ft(index_name).info()
except: # pylint: disable=W0702
gptcache_log.info("Index does not exist")
except:
modelcache_log.info("Index does not exist")
return False
gptcache_log.info("Index already exists")
modelcache_log.info("Index already exists")
return True

def create_collection(self, collection_name, index_prefix):
def create_index(self, index_name, index_prefix):
dimension = self.dimension
print('dimension: {}'.format(dimension))
if self._check_index_exists(collection_name):
gptcache_log.info(
"The %s already exists, and it will be used directly", collection_name
if self._check_index_exists(index_name):
modelcache_log.info(
"The %s already exists, and it will be used directly", index_name
)
return 'already_exists'
else:
# id_field_name = collection_name + '_' + "id"
# embedding_field_name = collection_name + '_' + "vec"
id_field_name = "data_id"
embedding_field_name = "data_vector"

@@ -76,11 +69,10 @@ def create_collection(self, collection_name, index_prefix):
}
)
fields = [id, embedding]
# definition = IndexDefinition(index_type=IndexType.HASH)
definition = IndexDefinition(prefix=[index_prefix], index_type=IndexType.HASH)

# create Index
self._client.ft(collection_name).create_index(
self._client.ft(index_name).create_index(
fields=fields, definition=definition
)
return 'create_success'
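
The NumericField and VectorField definitions are collapsed in this view; a typical redis-py schema for a FLOAT32 vector of this dimension might look like the following (illustrative, not necessarily the commit's exact algorithm or metric):

# illustrative redis-py schema; algorithm and distance metric are assumptions
id = NumericField(name=id_field_name)
embedding = VectorField(
    embedding_field_name,
    "HNSW",                      # or "FLAT"; the commit's choice is collapsed here
    {
        "TYPE": "FLOAT32",
        "DIM": dimension,
        "DISTANCE_METRIC": "L2",
    },
)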
@@ -90,23 +82,14 @@ def mul_add(self, datas: List[VectorData], model=None):
for data in datas:
id: int = data.id
embedding = data.data.astype(np.float32).tobytes()
# id_field_name = collection_name + '_' + "id"
# embedding_field_name = collection_name + '_' + "vec"
id_field_name = "data_id"
embedding_field_name = "data_vector"
obj = {id_field_name: id, embedding_field_name: embedding}
index_prefix = get_collection_prefix(model, self.table_suffix)
index_prefix = get_index_prefix(model)
self._client.hset(f"{index_prefix}{id}", mapping=obj)

# obj = {
# "vector": data.data.astype(np.float32).tobytes(),
# }
# pipe.hset(f"{self.doc_prefix}{key}", mapping=obj)
# pipe.execute()

def search(self, data: np.ndarray, top_k: int = -1, model=None):
collection_name = get_collection_name(model, self.table_suffix)
print('collection_name: {}'.format(collection_name))
index_name = get_index_name(model)
id_field_name = "data_id"
embedding_field_name = "data_vector"

@@ -119,63 +102,40 @@ def search(self, data: np.ndarray, top_k: int = -1, model=None):
)

query_params = {"vector": data.astype(np.float32).tobytes()}
# print('query_params: {}'.format(query_params))
results = (
self._client.ft(collection_name)
self._client.ft(index_name)
.search(query, query_params=query_params)
.docs
)
print('results: {}'.format(results))
for i, doc in enumerate(results):
print('doc: {}'.format(doc))
print("id_field_name", getattr(doc, id_field_name), ", distance: ", doc.distance)
return [(float(result.distance), int(getattr(result, id_field_name))) for result in results]
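
The Query construction is likewise collapsed above; a redis-py KNN query over the fields used here would look roughly like this (illustrative; the paging and dialect settings are assumptions):

# illustrative KNN query (assumes: from redis.commands.search.query import Query)
query = (
    Query(f"*=>[KNN {top_k} @{embedding_field_name} $vector AS distance]")
    .sort_by("distance")
    .return_fields(id_field_name, "distance")
    .paging(0, top_k)
    .dialect(2)
)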

def rebuild(self, ids=None) -> bool:
pass

def rebuild_col(self, model):
resp_info = 'failed'
if len(self.table_suffix) == 0:
raise ValueError('table_suffix is none error,please check!')

collection_name_model = get_collection_name(model, self.table_suffix)
print('collection_name_model: {}'.format(collection_name_model))
if self._check_index_exists(collection_name_model):
index_name_model = get_index_name(model)
if self._check_index_exists(index_name_model):
try:
self._client.ft(collection_name_model).dropindex(delete_documents=True)
self._client.ft(index_name_model).dropindex(delete_documents=True)
except Exception as e:
raise ValueError(str(e))
try:
index_prefix = get_collection_prefix(model, self.table_suffix)
self.create_collection(collection_name_model, index_prefix)
index_prefix = get_index_prefix(model)
self.create_index(index_name_model, index_prefix)
except Exception as e:
raise ValueError(str(e))
return 'rebuild success'

# print('remove collection_name_model: {}'.format(collection_name_model))
# try:
# self._client.ft(collection_name_model).dropindex(delete_documents=True)
# resp_info = 'rebuild success'
# except Exception as e:
# print('exception: {}'.format(e))
# resp_info = 'create only'
# try:
# self.create_collection(collection_name_model)
# except Exception as e:
# raise ValueError(str(e))
# return resp_info

def delete(self, ids) -> None:
pipe = self._client.pipeline()
for data_id in ids:
pipe.delete(f"{self.doc_prefix}{data_id}")
pipe.execute()

def create(self, model=None):
collection_name = get_collection_name(model, self.table_suffix)
index_prefix = get_collection_prefix(model, self.table_suffix)
return self.create_collection(collection_name, index_prefix)
index_name = get_index_name(model)
index_prefix = get_index_prefix(model)
return self.create_index(index_name, index_prefix)

def get_collection_by_name(self, collection_name, table_suffix):
def get_index_by_name(self, index_name):
pass
4 changes: 4 additions & 0 deletions modelcache/utils/__init__.py
@@ -69,3 +69,7 @@ def import_timm():

def import_pillow():
_check_library("PIL", package="pillow")


def import_redis():
_check_library("redis")
9 changes: 9 additions & 0 deletions modelcache/utils/index_util.py
@@ -0,0 +1,9 @@
# -*- coding: utf-8 -*-


def get_index_name(model):
return 'modelcache' + '_' + model


def get_index_prefix(model):
return 'prefix' + '_' + model
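
Both helpers are plain string concatenation; for the model name used in examples/flask/register.py they produce:

>>> get_index_name('CODEGPT-1117')
'modelcache_CODEGPT-1117'
>>> get_index_prefix('CODEGPT-1117')
'prefix_CODEGPT-1117'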
2 changes: 2 additions & 0 deletions requirements.txt
@@ -10,3 +10,5 @@ Requests==2.31.0
torch==2.1.0
transformers==4.34.1
faiss-cpu==1.7.4
redis==5.0.1
