diff --git a/.gitignore b/.gitignore
index 0637d70..b0d11aa 100644
--- a/.gitignore
+++ b/.gitignore
@@ -149,4 +149,5 @@ resources/
 tests
 data/
 tests/mathvista
-running_logs/
\ No newline at end of file
+running_logs/
+*.db
\ No newline at end of file
diff --git a/README.md b/README.md
index eae99f0..4481d3d 100644
--- a/README.md
+++ b/README.md
@@ -77,13 +77,15 @@ For more details, check out our paper **[OmAgent: A Multi-modal Agent Framework
 ### Video Understanding Task
 #### Environment Preparation
-- Deploy [milvus vector database](https://milvus.io/docs/install_standalone-docker.md) using docker. The vector database is used to store video feature vectors and retrieve relevant vectors based on queries to reduce MLLM computation. Not installed docker? Refer to [docker installation guide](https://docs.docker.com/get-docker/).
+- **```Optional```** By default, OmAgent uses Milvus Lite as the vector database for storing vector data. If you wish to use the full Milvus service, you can deploy the [milvus vector database](https://milvus.io/docs/install_standalone-docker.md) using docker. The vector database is used to store video feature vectors and retrieve relevant vectors based on queries to reduce MLLM computation. Not installed docker? Refer to [docker installation guide](https://docs.docker.com/get-docker/).
   ```shell
   # Download milvus startup script
   curl -sfL https://raw.githubusercontent.com/milvus-io/milvus/master/scripts/standalone_embed.sh -o standalone_embed.sh
   # Start milvus in standalone mode
   bash standalone_embed.sh start
   ```
+  After the deployment, fill in the relevant configuration in ```workflows/video_understanding/config.yml```
+
 - **```Optional```** Configure the face recognition algorithm. The face recognition algorithm can be called as a tool by the agent, but it is optional. You can disable this feature by modifying the ```workflows/video_understanding/tools/video_tools.json``` configuration file and removing the FaceRecognition section. The default face recognition database is stored in the ```data/face_db``` directory, with different folders corresponding to different individuals.
 - **```Optional```** Open Vocabulary Detection (ovd) service, used to enhance OmAgent's ability to recognize various objects. The ovd tools depend on this service, but it is optional. You can disable ovd tools by following these steps.
   Remove the following from ```workflows/video_understanding/tools/video_tools.json```
   ```json
diff --git a/README_ZH.md b/README_ZH.md
index eca169a..62aa0fd 100644
--- a/README_ZH.md
+++ b/README_ZH.md
@@ -76,13 +76,15 @@ OmAgent包括三个核心组成部分:
 ### 视频理解任务
 #### 相关环境准备
-- 使用docker部署[milvus向量数据库](https://milvus.io/docs/install_standalone-docker.md)。向量数据库用于储存视频特征向量,以根据问题检索相关向量,减少MLLM的计算量。未安装docker?请参考[docker安装指南](https://docs.docker.com/get-docker/)。
+- **```可选```** OmAgent默认使用Milvus Lite作为向量数据库存储向量数据。如果你希望使用完整的Milvus服务,可以使用docker部署[milvus向量数据库](https://milvus.io/docs/install_standalone-docker.md)。向量数据库用于储存视频特征向量,以根据问题检索相关向量,减少MLLM的计算量。未安装docker?请参考[docker安装指南](https://docs.docker.com/get-docker/)。
   ```shell
   # 下载milvus启动脚本
   curl -sfL https://raw.githubusercontent.com/milvus-io/milvus/master/scripts/standalone_embed.sh -o standalone_embed.sh
   # 以standalone模式启动milvus
   bash standalone_embed.sh start
   ```
+  部署完成后填写相关配置信息```workflows/video_understanding/config.yml```
+
 - **```可选```** 配置人脸识别算法。人脸识别算法可以作为智能体的工具进行调用,当然这是可选的。你可以通过修改```workflows/video_understanding/tools/video_tools.json```配置文件,删除其中关于FaceRecognition的部分对该功能进行禁用。默认的人脸识别底库存储在```data/face_db```目录下,不同文件夹对应不同人物。
 - **```可选```** Open Vocabulary Detection(ovd)服务,开放词表检测,用于增强OmAgent对于各种目标物体的识别能力,ovd tools依赖于此,当然这是可选的。你可以按如下步骤对ovd tools进行禁用。
   删除```workflows/video_understanding/tools/video_tools.json```中的
   ```json
diff --git a/engine/video_process/scene.py b/engine/video_process/scene.py
index 329ee61..0e3cbe7 100755
--- a/engine/video_process/scene.py
+++ b/engine/video_process/scene.py
@@ -1,6 +1,7 @@
 from typing import Dict, List, Optional, Tuple, Union
 
 from PIL import Image
+import cv2
 from pydantic import BaseModel
 from pydub import AudioSegment
 from pydub.effects import normalize
@@ -118,6 +119,8 @@ def get_video_frames(
         for index in range(scene_len):
             if index % interval == 0:
                 f = self.stream.read()
+                if f is False: continue
+                f = cv2.cvtColor(f, cv2.COLOR_BGR2RGB)
                 frames.append(Image.fromarray(f))
                 time_stamps.append(self.stream.position.get_seconds())
             else:
diff --git a/omagent-core/pyproject.toml b/omagent-core/pyproject.toml
index 5c13379..75b545a 100644
--- a/omagent-core/pyproject.toml
+++ b/omagent-core/pyproject.toml
@@ -25,7 +25,7 @@ dependencies = [
     "tenacity>=8.2.3",
     "pyyaml>=6.0.1",
     "requests",
-    "pymilvus>=2.3.3",
+    "pymilvus>=2.4.6",
     "networkx>=3.2.1",
     "duckduckgo_search>=3.9.9",
     "bs4>=0.0.1",
diff --git a/omagent-core/src/omagent_core/handlers/data_handler/milvus_handler.py b/omagent-core/src/omagent_core/handlers/data_handler/milvus_handler.py
index c59c78e..74bbf75 100644
--- a/omagent-core/src/omagent_core/handlers/data_handler/milvus_handler.py
+++ b/omagent-core/src/omagent_core/handlers/data_handler/milvus_handler.py
@@ -1,27 +1,38 @@
-from uuid import uuid4
 from typing import Any
+from uuid import uuid4
+
 import numpy as np
-from pymilvus import Collection, DataType, connections, utility
-from pymilvus.client import types
 from pydantic import BaseModel
+from pymilvus import Collection, DataType, MilvusClient, connections, utility
+from pymilvus.client import types
 
-from ...utils.env import EnvVar
 from ...utils.registry import registry
 from ..error_handler.error import VQLError
 
 
 @registry.register_handler()
 class MilvusHandler(BaseModel):
-    host_url: str
-    port: int
-    alias: str
+    host_url: str = "./memory.db"
+    user: str = ""
+    password: str = ""
+    db_name: str = "default"
     primary_field: Any = None
     vector_field: Any = None
-    index_id: str
+
+    class Config:
+        """Configuration for this pydantic object."""
+
+        extra = "allow"
+        arbitrary_types_allowed = True
 
     def __init__(self, **data: Any):
         super().__init__(**data)
-        connections.connect(host=self.host_url, port=self.port, alias=self.alias)
+        self.milvus_client = MilvusClient(
+            uri=self.host_url,
+            user=self.user,
+            password=self.password,
+            db_name=self.db_name,
+        )
 
     def is_collection_in(self, collection_name):
         """
@@ -33,7 +44,7 @@ def is_collection_in(self, collection_name):
         Returns:
             bool: True if the collection exists, False otherwise.
         """
-        return utility.has_collection(collection_name)
+        return self.milvus_client.has_collection(collection_name)
 
     def make_collection(self, collection_name, schema):
         """
@@ -51,23 +62,28 @@ def make_collection(self, collection_name, schema):
         Raises:
             VQLError: If the schema does not have exactly one primary key.
         """
-        self.vector_field = [
-            each.name
-            for each in schema.fields
-            if each.dtype == DataType.FLOAT_VECTOR
-            or each.dtype == DataType.BINARY_VECTOR
-        ]
-        primary_candidate = [each.name for each in schema.fields if each.is_primary]
-        if len(primary_candidate) > 0:
-            self.primary_field = primary_candidate[0]
-        else:
-            raise VQLError(500, detail="The number of primary key is not one!")
+
+        index_params = self.milvus_client.prepare_index_params()
+        for field in schema.fields:
+            if (
+                field.dtype == DataType.FLOAT_VECTOR
+                or field.dtype == DataType.BINARY_VECTOR
+            ):
+                index_params.add_index(
+                    field_name=field.name,
+                    index_name=field.name,
+                    index_type="FLAT",
+                    metric_type="COSINE",
+                    params={"nlist": 128},
+                )
+                print(f"{field.name} of {collection_name} index created")
 
         if self.is_collection_in(collection_name):
             print(f"{collection_name} collection already exists")
         else:
-            Collection(name=collection_name, schema=schema, using=self.alias)
-            self.create_index(collection_name, self.vector_field)
+            self.milvus_client.create_collection(
+                collection_name, schema=schema, index_params=index_params
+            )
             print(f"Create collection {collection_name} successfully")
 
     def drop_collection(self, collection_name):
@@ -82,8 +98,7 @@ def drop_collection(self, collection_name):
             collection_name (str): The name of the collection to drop.
         """
         if self.is_collection_in(collection_name):
-            collection = Collection(name=collection_name, using=self.alias)
-            collection.drop()
+            self.milvus_client.drop_collection(collection_name)
             print(f"Drop collection {collection_name} successfully")
         else:
             print(f"{collection_name} collection does not exist")
@@ -107,38 +122,10 @@ def do_add(self, collection_name, vectors):
             VQLError: If the collection does not exist.
         """
         if self.is_collection_in(collection_name):
-            loaded_collection = Collection(collection_name)
-            ids = loaded_collection.insert(vectors)
-            loaded_collection.flush()
-            return ids
-        else:
-            raise VQLError(500, detail=f"{collection_name} collection does not exist")
-
-    def create_index(self, collection_name, vector_fields):
-        """
-        Create an index for the specified vector fields in a collection in Milvus.
-
-        This method will first check if a collection with the given name exists.
-        If it does, it will create an index for each of the specified vector fields in the collection.
-        If it doesn't, it will raise a VQLError.
-
-        Args:
-            collection_name (str): The name of the collection to create an index in.
-            vector_fields (list): The list of vector fields to create an index for.
-
-        Raises:
-            VQLError: If the collection does not exist.
- """ - if self.is_collection_in(collection_name): - loaded_collection = Collection(collection_name) - for vector_field in vector_fields: - index = { - "index_type": "IVF_FLAT", - "metric_type": "COSINE", - "params": {"nlist": 128}, - } - loaded_collection.create_index(vector_field, index) - print(f"{vector_field} of {collection_name} index created") + res = self.milvus_client.insert( + collection_name=collection_name, data=vectors + ) + return res["ids"] else: raise VQLError(500, detail=f"{collection_name} collection does not exist") @@ -147,10 +134,10 @@ def match( collection_name, query_vectors: list, query_field, - output_fields: list, - res_size, - filter_expr, - threshold, + output_fields: list = None, + res_size=10, + filter_expr="", + threshold=0, ): """ Perform a vector similarity search in a specified collection in Milvus. @@ -176,9 +163,6 @@ def match( VQLError: If the collection does not exist. """ if self.is_collection_in(collection_name): - loaded_collection = Collection(collection_name) - if utility.load_state(collection_name) != types.LoadState.Loaded: - loaded_collection.load() search_params = { "metric_type": "COSINE", "ignore_growing": False, @@ -188,14 +172,16 @@ def match( "range_filter": 1, }, } - hits = loaded_collection.search( + hits = self.milvus_client.search( + collection_name=collection_name, data=query_vectors, anns_field=query_field, - param=search_params, + search_params=search_params, limit=res_size, output_fields=output_fields, - expr=filter_expr, + filter=filter_expr, ) + return hits else: raise VQLError(500, detail=f"{collection_name} collection does not exist") @@ -216,9 +202,11 @@ def delete_doc_by_ids(self, collection_name, ids): VQLError: If the collection does not exist. """ if self.is_collection_in(collection_name): - loaded_collection = Collection(collection_name) delete_expr = f"{self.primary_field} in {ids}" - loaded_collection.delete(delete_expr) + res = self.milvus_client.delete( + collection_name=collection_name, filter=delete_expr + ) + return res else: raise VQLError(500, detail=f"{collection_name} collection does not exist") @@ -238,8 +226,7 @@ def delete_doc_by_expr(self, collection_name, expr): VQLError: If the collection does not exist. 
""" if self.is_collection_in(collection_name): - loaded_collection = Collection(collection_name) - loaded_collection.delete(expr) + self.milvus_client.delete(collection_name=collection_name, filter=expr) else: raise VQLError(500, detail=f"{collection_name} collection does not exist") @@ -264,17 +251,22 @@ def delete_doc_by_expr(self, collection_name, expr): ) data = [ - [str(uuid4())] * 1, - [str(uuid4())] * 1, - # rng.random((1, 512)) - [[1, 2] * 256], + { + "pk": str(uuid4()), + "bot_id": str(uuid4()), + # rng.random((1, 512)) + "vector": [1.0, 2.0] * 256, + } ] - # milvus_handler.drop_collection('test1') - # milvus_handler.make_collection('test1', schema) + milvus_handler.drop_collection("test1") + milvus_handler.make_collection("test1", schema) add_detail = milvus_handler.do_add("test1", data) print(add_detail) - # test_data = - # match_result = milvus_handler.match('test1', [test_data], 'vector', ['pk'], 10, '', 0.65) - # print(match_result) - milvus_handler.primary_field = "pk" - milvus_handler.delete_doc_by_ids("test1", "5b50a621-7745-41fc-87d8-726f7e1e51cf") + print(milvus_handler.milvus_client.describe_index("test1", "vector")) + test_data = [[1.0, 2.0] * 256, [100, 400] * 256] + match_result = milvus_handler.match( + "test1", test_data, "vector", ["pk"], 10, "", 0.65 + ) + print(match_result) + # milvus_handler.primary_field = "pk" + # milvus_handler.delete_doc_by_ids("test1", ["1f764837-b80b-4788-ad8c-7a89924e343b"]) diff --git a/omagent-core/src/omagent_core/handlers/data_handler/video_handler.py b/omagent-core/src/omagent_core/handlers/data_handler/video_handler.py index d45c3c3..6ff668b 100644 --- a/omagent-core/src/omagent_core/handlers/data_handler/video_handler.py +++ b/omagent-core/src/omagent_core/handlers/data_handler/video_handler.py @@ -11,6 +11,7 @@ @registry.register_handler() class VideoHandler(MilvusHandler): + collection_name: str text_encoder: Optional[EncoderBase] = None dim: int = None @@ -20,7 +21,6 @@ class Config: extra = "allow" arbitrary_types_allowed = True - @field_validator("text_encoder", mode="before") @classmethod def init_encoder(cls, text_encoder): @@ -29,12 +29,10 @@ def init_encoder(cls, text_encoder): elif isinstance(text_encoder, dict): return registry.get_encoder(text_encoder.get("name"))(**text_encoder) else: - raise ValueError("index_id must be EncoderBase or Dict") + raise ValueError("text_encoder must be EncoderBase or Dict") def __init__(self, **data: Any) -> None: super().__init__(**data) - self.fields = [] - self.vector_fields = [] self.dim = self.text_encoder.dim @@ -56,20 +54,12 @@ def __init__(self, **data: Any) -> None: name="end_time", dtype=DataType.FLOAT, ) - self.schema = CollectionSchema( + schema = CollectionSchema( fields=[_uid, video_md5, content, content_vector, start_time, end_time], description="video summary vector DB", enable_dynamic_field=True, ) - for each_field in self.schema.fields: - self.fields.append(each_field.name) - if ( - each_field.dtype == DataType.FLOAT_VECTOR - or each_field.dtype == DataType.BINARY_VECTOR - ): - self.vector_fields.append(each_field.name) - self.collection_name = self.index_id - self.make_collection(self.collection_name, self.schema) + self.make_collection(self.collection_name, schema) def text_add(self, video_md5, content, start_time, end_time): @@ -77,15 +67,24 @@ def text_add(self, video_md5, content, start_time, end_time): raise VQLError(500, detail="Missing text_encoder") content_vector = self.text_encoder.infer([content])[0] + # upload_data = [ + # [video_md5], + # [content], + # 
+        #     [content_vector],
+        #     [start_time],
+        #     [end_time],
+        # ]
 
         upload_data = [
-            [video_md5],
-            [content],
-            [content_vector],
-            [start_time],
-            [end_time],
+            {
+                "video_md5": video_md5,
+                "content": content,
+                "content_vector": content_vector,
+                "start_time": start_time,
+                "end_time": end_time,
+            }
         ]
-        add_detail = self.do_add(self.index_id, upload_data)
+        add_detail = self.do_add(self.collection_name, upload_data)
         # assert add_detail.succ_count == len(upload_data)
 
     def text_match(
@@ -114,7 +113,7 @@ def text_match(
             collection_name=self.collection_name,
             query_vectors=[content_vector],
             query_field="content_vector",
-            output_fields=self.fields,
+            output_fields=["content", "start_time", "end_time"],
             res_size=res_size,
             threshold=threshold,
             filter_expr=filter_expr,
@@ -122,6 +121,7 @@ def text_match(
 
         output = []
         for match in match_res[0]:
-            output.append(match.fields)
+            print(match)
+            output.append(match["entity"])
 
         return output
diff --git a/workflows/video_understanding/config.yml b/workflows/video_understanding/config.yml
index 4adaac7..1b2a96e 100644
--- a/workflows/video_understanding/config.yml
+++ b/workflows/video_understanding/config.yml
@@ -1,8 +1,9 @@
 custom_openai_key:
 #custom_openai_endpoint: http://xxx.com/v1 #The base api endpoint for the custom openai model
 bing_api_key:
-milvus_host_url:
-milvus_port: 19530
-milvus_alias: default
+# milvus_host_url:
+# milvus_db_name:
+# milvus_user:
+# milvus_password:
 ovd_endpoint: #The api endpoint for the OVD model
 ovd_model_id: OmDet-Turbo_tiny_SWIN_T
\ No newline at end of file
diff --git a/workflows/video_understanding/ltm.json b/workflows/video_understanding/ltm.json
index 63f6573..dfb2834 100644
--- a/workflows/video_understanding/ltm.json
+++ b/workflows/video_understanding/ltm.json
@@ -1,14 +1,15 @@
 [
     {
         "name": "VideoHandler",
-        "index_id": "video_understanding_openai_encoder",
+        "collection_name": "video_understanding_openai_encoder",
         "text_encoder": {
            "name": "OpenaiTextEmbeddingV3",
            "endpoint": "$",
            "api_key": "$"
         },
-        "host_url": "$",
-        "port": "$",
-        "alias": "$"
+        "host_url": "$",
+        "user": "$",