From c32090333212f91d4d1315ce3c3a0497cd315756 Mon Sep 17 00:00:00 2001 From: Bhargav Suryadevara Date: Fri, 13 Oct 2023 12:44:31 -0500 Subject: [PATCH] Add a Vector Database Service to allow stages to read and write to VDBs (#1225) Closes #1177 ## By Submitting this PR I confirm: - I am familiar with the [Contributing Guidelines](https://github.com/nv-morpheus/Morpheus/blob/main/docs/source/developer_guide/contributing.md). - When the PR is ready for review, new or existing tests cover these changes. - When the PR is ready for review, the documentation is up to date with these changes. Authors: - Bhargav Suryadevara (https://github.com/bsuryadevara) Approvers: - Michael Demoret (https://github.com/mdemoret-nv) URL: https://github.com/nv-morpheus/Morpheus/pull/1225 --- docker/conda/environments/cuda11.8_dev.yml | 2 + docs/source/conf.py | 1 + morpheus/service/__init__.py | 13 + morpheus/service/milvus_client.py | 268 ++++++++ morpheus/service/milvus_vector_db_service.py | 600 ++++++++++++++++++ morpheus/service/vector_db_service.py | 323 ++++++++++ morpheus/stages/output/write_to_vector_db.py | 123 ++++ morpheus/utils/vector_db_service_utils.py | 55 ++ tests/conftest.py | 28 + tests/test_milvus_vector_db_service.py | 428 +++++++++++++ ...st_milvus_write_to_vector_db_stage_pipe.py | 127 ++++ .../milvus_idx_part_collection_conf.json | 3 + .../milvus_simple_collection_conf.json | 3 + 13 files changed, 1974 insertions(+) create mode 100644 morpheus/service/__init__.py create mode 100644 morpheus/service/milvus_client.py create mode 100644 morpheus/service/milvus_vector_db_service.py create mode 100644 morpheus/service/vector_db_service.py create mode 100644 morpheus/stages/output/write_to_vector_db.py create mode 100644 morpheus/utils/vector_db_service_utils.py create mode 100644 tests/test_milvus_vector_db_service.py create mode 100755 tests/test_milvus_write_to_vector_db_stage_pipe.py create mode 100644 tests/tests_data/service/milvus_idx_part_collection_conf.json 
create mode 100644 tests/tests_data/service/milvus_simple_collection_conf.json diff --git a/docker/conda/environments/cuda11.8_dev.yml b/docker/conda/environments/cuda11.8_dev.yml index 7c1a034038..4a7dc39688 100644 --- a/docker/conda/environments/cuda11.8_dev.yml +++ b/docker/conda/environments/cuda11.8_dev.yml @@ -110,3 +110,5 @@ dependencies: # Add additional dev dependencies here - databricks-connect - pytest-kafka==0.6.0 + - pymilvus==2.3.1 + - milvus diff --git a/docs/source/conf.py b/docs/source/conf.py index adb924692f..ff3f98be38 100644 --- a/docs/source/conf.py +++ b/docs/source/conf.py @@ -169,6 +169,7 @@ "morpheus.cli.commands", # Dont document the CLI in Sphinx "nvtabular", "pandas", + "pymilvus", "tensorrt", "torch", "tqdm", diff --git a/morpheus/service/__init__.py b/morpheus/service/__init__.py new file mode 100644 index 0000000000..ce94db52fa --- /dev/null +++ b/morpheus/service/__init__.py @@ -0,0 +1,13 @@ +# Copyright (c) 2023, NVIDIA CORPORATION. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. diff --git a/morpheus/service/milvus_client.py b/morpheus/service/milvus_client.py new file mode 100644 index 0000000000..ff2956a93c --- /dev/null +++ b/morpheus/service/milvus_client.py @@ -0,0 +1,268 @@ +# Copyright (c) 2023, NVIDIA CORPORATION. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import typing + +from pymilvus import Collection +from pymilvus import DataType +from pymilvus import MilvusClient as PyMilvusClient +from pymilvus.orm.mutation import MutationResult + +# Milvus data type mapping dictionary +MILVUS_DATA_TYPE_MAP = { + "int8": DataType.INT8, + "int16": DataType.INT16, + "int32": DataType.INT32, + "int64": DataType.INT64, + "bool": DataType.BOOL, + "float": DataType.FLOAT, + "double": DataType.DOUBLE, + "binary_vector": DataType.BINARY_VECTOR, + "float_vector": DataType.FLOAT_VECTOR, + "string": DataType.STRING, + "varchar": DataType.VARCHAR, + "json": DataType.JSON, +} + + +def handle_exceptions(func_name: str, error_message: str) -> typing.Callable: + """ + Decorator function to handle exceptions and log errors. + + Parameters + ---------- + func_name : str + Name of the func being decorated. + error_message : str + Error message to log in case of an exception. + + Returns + ------- + typing.Callable + Decorated function. + """ + + def decorator(func): + + def wrapper(*args, **kwargs): + try: + return func(*args, **kwargs) + except Exception as ex: + raise RuntimeError(f"{error_message} - Failed to execute {func_name}") from ex + + return wrapper + + return decorator + + +class MilvusClient(PyMilvusClient): + """ + Extension of the `MilvusClient` class with custom functions. + + Parameters + ---------- + uri : str + URI for connecting to Milvus server. + user : str + User name for authentication. + password : str + Password for authentication. + db_name : str + Name of the Milvus database. 
+ token : str + Token for authentication. + **kwargs : dict[str, typing.Any] + Additional keyword arguments for the MilvusClient constructor. + """ + + def __init__(self, uri: str, user: str, password: str, db_name: str, token: str, **kwargs: dict[str, typing.Any]): + super().__init__(uri=uri, user=user, password=password, db_name=db_name, token=token, **kwargs) + + @handle_exceptions("has_collection", "Error checking collection existence") + def has_collection(self, collection_name: str) -> bool: + """ + Check if a collection exists in the database. + + Parameters + ---------- + collection_name : str + Name of the collection to check. + + Returns + ------- + bool + True if the collection exists, False otherwise. + """ + conn = self._get_connection() + return conn.has_collection(collection_name) + + @handle_exceptions("create_partition", "Error creating partition") + def create_partition(self, collection_name: str, partition_name: str, timeout: float = 1.0) -> None: + """ + Create a partition within a collection. + + Parameters + ---------- + collection_name : str + Name of the collection. + partition_name : str + Name of the partition to create. + timeout : float, optional + Timeout for the operation in seconds (default is 1.0). + """ + conn = self._get_connection() + conn.create_partition(collection_name=collection_name, partition_name=partition_name, timeout=timeout) + + @handle_exceptions("load_collection", "Error loading collection") + def load_collection(self, collection_name: str) -> None: + """ + Load a collection into memory. + + Parameters + ---------- + collection_name : str + Name of the collection to load. + """ + conn = self._get_connection() + conn.load_collection(collection_name=collection_name) + + @handle_exceptions("release_collection", "Error releasing collection") + def release_collection(self, collection_name: str) -> None: + """ + Release a loaded collection from memory. 
+ + Parameters + ---------- + collection_name : str + Name of the collection to release. + """ + conn = self._get_connection() + conn.release_collection(collection_name=collection_name) + + @handle_exceptions("upsert", "Error upserting collection entities") + def upsert(self, collection_name: str, entities: list, **kwargs: dict[str, typing.Any]) -> MutationResult: + """ + Upsert entities into a collection. + + Parameters + ---------- + collection_name : str + Name of the collection to upsert into. + entities : list + List of entities to upsert. + **kwargs : dict[str, typing.Any] + Additional keyword arguments for the upsert operation. + + Returns + ------- + MutationResult + Result of the upsert operation. + """ + conn = self._get_connection() + return conn.upsert(collection_name=collection_name, entities=entities, **kwargs) + + @handle_exceptions("delete_by_expr", "Error deleting collection entities") + def delete_by_expr(self, collection_name: str, expression: str, **kwargs: dict[str, typing.Any]) -> MutationResult: + """ + Delete entities from a collection using an expression. + + Parameters + ---------- + collection_name : str + Name of the collection to delete from. + expression : str + Deletion expression. + **kwargs : dict[str, typing.Any] + Additional keyword arguments for the delete operation. + + Returns + ------- + MutationResult + Returns result of delete operation. + """ + conn = self._get_connection() + return conn.delete(collection_name=collection_name, expression=expression, **kwargs) + + @handle_exceptions("has_partition", "Error checking partition existence") + def has_partition(self, collection_name: str, partition_name: str) -> bool: + """ + Check if a partition exists within a collection. + + Parameters + ---------- + collection_name : str + Name of the collection. + partition_name : str + Name of the partition to check. + + Returns + ------- + bool + True if the partition exists, False otherwise. 
+ """ + conn = self._get_connection() + return conn.has_partition(collection_name=collection_name, partition_name=partition_name) + + @handle_exceptions("drop_partition", "Error dropping partition") + def drop_partition(self, collection_name: str, partition_name: str) -> None: + """ + Drop a partition from a collection. + + Parameters + ---------- + collection_name : str + Name of the collection. + partition_name : str + Name of the partition to drop. + """ + conn = self._get_connection() + conn.drop_partition(collection_name=collection_name, partition_name=partition_name) + + @handle_exceptions("drop_index", "Error dropping index") + def drop_index(self, collection_name: str, field_name: str, index_name: str) -> None: + """ + Drop an index from a collection. + + Parameters + ---------- + collection_name : str + Name of the collection. + field_name : str + Name of the field associated with the index. + index_name : str + Name of the index to drop. + """ + conn = self._get_connection() + conn.drop_index(collection_name=collection_name, field_name=field_name, index_name=index_name) + + @handle_exceptions("get_collection", "Error getting collection object") + def get_collection(self, collection_name: str, **kwargs: dict[str, typing.Any]) -> Collection: + """ + Returns `Collection` object associated with the given collection name. + + Parameters + ---------- + collection_name : str + Name of the collection to delete from. + **kwargs : dict[str, typing.Any] + Additional keyword arguments to get Collection instance. + + Returns + ------- + Collection + Returns pymilvus Collection instance. + """ + collection = Collection(name=collection_name, using=self._using, **kwargs) + + return collection diff --git a/morpheus/service/milvus_vector_db_service.py b/morpheus/service/milvus_vector_db_service.py new file mode 100644 index 0000000000..18ae5dd4a2 --- /dev/null +++ b/morpheus/service/milvus_vector_db_service.py @@ -0,0 +1,600 @@ +# Copyright (c) 2023, NVIDIA CORPORATION. 
+# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import copy +import logging +import threading +import time +import typing + +import pandas as pd +import pymilvus +from pymilvus.orm.mutation import MutationResult + +import cudf + +from morpheus.service.milvus_client import MILVUS_DATA_TYPE_MAP +from morpheus.service.milvus_client import MilvusClient +from morpheus.service.vector_db_service import VectorDBService + +logger = logging.getLogger(__name__) + + +def with_collection_lock(func: typing.Callable) -> typing.Callable: + """ + A decorator to synchronize access to a collection with a lock. This decorator ensures that operations on a + specific collection within the Milvus Vector Database are synchronized by acquiring and + releasing a collection-specific lock. + + Parameters + ---------- + func : Callable + The function to be wrapped with the lock. + + Returns + ------- + Callable + The wrapped function with the lock acquisition logic. + """ + + def wrapper(self, name, *args, **kwargs): + collection_lock = MilvusVectorDBService.get_collection_lock(name) + with collection_lock: + logger.debug("Acquiring lock for collection: %s", name) + result = func(self, name, *args, **kwargs) + logger.debug("Releasing lock for collection: %s", name) + return result + + return wrapper + + +class MilvusVectorDBService(VectorDBService): + """ + Service class for Milvus Vector Database implementation. This class provides functions for interacting + with a Milvus vector database. 
+ + Parameters + ---------- + host : str + The hostname or IP address of the Milvus server. + port : str + The port number for connecting to the Milvus server. + alias : str, optional + Alias for the Milvus connection, by default "default". + **kwargs : dict + Additional keyword arguments specific to the Milvus connection configuration. + """ + + _collection_locks = {} + _cleanup_interval = 600 # 10mins + _last_cleanup_time = time.time() + + def __init__(self, + uri: str, + user: str = "", + password: str = "", + db_name: str = "", + token: str = "", + **kwargs: dict[str, typing.Any]): + + self._client = MilvusClient(uri=uri, user=user, password=password, db_name=db_name, token=token, **kwargs) + + def has_store_object(self, name: str) -> bool: + """ + Check if a collection exists in the Milvus vector database. + + Parameters + ---------- + name : str + Name of the collection to check. + + Returns + ------- + bool + True if the collection exists, False otherwise. + """ + return self._client.has_collection(collection_name=name) + + def list_store_objects(self, **kwargs: dict[str, typing.Any]) -> list[str]: + """ + List the names of all collections in the Milvus vector database. + + Returns + ------- + list[str] + A list of collection names. + """ + return self._client.list_collections(**kwargs) + + def _create_schema_field(self, field_conf: dict) -> pymilvus.FieldSchema: + + name = field_conf.pop("name") + dtype = field_conf.pop("dtype") + + dtype = MILVUS_DATA_TYPE_MAP[dtype.lower()] + + field_schema = pymilvus.FieldSchema(name=name, dtype=dtype, **field_conf) + + return field_schema + + @with_collection_lock + def create(self, name: str, overwrite: bool = False, **kwargs: dict[str, typing.Any]): + """ + Create a collection in the Milvus vector database with the specified name and configuration. This method + creates a new collection in the Milvus vector database with the provided name and configuration options. 
+ If the collection already exists, it can be overwritten if the `overwrite` parameter is set to True. + + Parameters + ---------- + name : str + Name of the collection to be created. + overwrite : bool, optional + If True, the collection will be overwritten if it already exists, by default False. + **kwargs : dict + Additional keyword arguments containing collection configuration. + + Raises + ------ + ValueError + If the provided schema fields configuration is empty. + """ + logger.debug("Creating collection: %s, overwrite=%s, kwargs=%s", name, overwrite, kwargs) + # Preserve original configuration. + kwargs = copy.deepcopy(kwargs) + + collection_conf = kwargs.get("collection_conf") + auto_id = collection_conf.get("auto_id", False) + index_conf = collection_conf.get("index_conf", None) + partition_conf = collection_conf.get("partition_conf", None) + + schema_conf = collection_conf.get("schema_conf") + schema_fields_conf = schema_conf.pop("schema_fields") + + index_param = {} + + if not self.has_store_object(name) or overwrite: + if overwrite and self.has_store_object(name): + self.drop(name) + + if len(schema_fields_conf) == 0: + raise ValueError("Cannot create collection as provided empty schema_fields configuration") + + schema_fields = [self._create_schema_field(field_conf=field_conf) for field_conf in schema_fields_conf] + + schema = pymilvus.CollectionSchema(fields=schema_fields, **schema_conf) + + if index_conf: + field_name = index_conf.pop("field_name") + metric_type = index_conf.pop("metric_type") + index_param = self._client.prepare_index_params(field_name=field_name, + metric_type=metric_type, + **index_conf) + + self._client.create_collection_with_schema(collection_name=name, + schema=schema, + index_param=index_param, + auto_id=auto_id, + shards_num=collection_conf.get("shards", 2), + consistency_level=collection_conf.get( + "consistency_level", "Strong")) + + if partition_conf: + timeout = partition_conf.get("timeout", 1.0) + # Iterate over each 
partition configuration + for part in partition_conf["partitions"]: + self._client.create_partition(collection_name=name, partition_name=part["name"], timeout=timeout) + + @with_collection_lock + def insert(self, name: str, data: list[list] | list[dict], **kwargs: dict[str, + typing.Any]) -> dict[str, typing.Any]: + """ + Insert a collection specific data in the Milvus vector database. + + Parameters + ---------- + name : str + Name of the collection to be inserted. + data : list[list] | list[dict] + Data to be inserted in the collection. + **kwargs : dict[str, typing.Any] + Additional keyword arguments containing collection configuration. + + Returns + ------- + dict + Returns response content as a dictionary. + + Raises + ------ + RuntimeError + If the collection not exists exists. + """ + + return self._collection_insert(name, data, **kwargs) + + def _collection_insert(self, name: str, data: list[list] | list[dict], + **kwargs: dict[str, typing.Any]) -> dict[str, typing.Any]: + + if not self.has_store_object(name): + raise RuntimeError(f"Collection {name} doesn't exist.") + + collection = None + try: + collection_conf = kwargs.get("collection_conf", {}) + partition_name = collection_conf.get("partition_name", "_default") + + collection = self._client.get_collection(collection_name=name, **collection_conf) + result = collection.insert(data, partition_name=partition_name) + collection.flush() + finally: + collection.release() + + result_dict = { + "primary_keys": result.primary_keys, + "insert_count": result.insert_count, + "delete_count": result.delete_count, + "upsert_count": result.upsert_count, + "timestamp": result.timestamp, + "succ_count": result.succ_count, + "err_count": result.err_count, + "succ_index": result.succ_index, + "err_index": result.err_index + } + + return result_dict + + @with_collection_lock + def insert_dataframe(self, + name: str, + df: typing.Union[cudf.DataFrame, pd.DataFrame], + **kwargs: dict[str, typing.Any]) -> dict[str, 
typing.Any]: + """ + Converts dataframe to rows and insert to a collection in the Milvus vector database. + + Parameters + ---------- + name : str + Name of the collection to be inserted. + df : typing.Union[cudf.DataFrame, pd.DataFrame] + Dataframe to be inserted in the collection. + **kwargs : dict[str, typing.Any] + Additional keyword arguments containing collection configuration. + + Returns + ------- + dict + Returns response content as a dictionary. + + Raises + ------ + RuntimeError + If the collection not exists exists. + """ + if not self.has_store_object(name): + raise RuntimeError(f"Collection {name} doesn't exist.") + + if isinstance(df, cudf.DataFrame): + df = df.to_pandas() + + dict_of_rows = df.to_dict(orient='records') + + return self._collection_insert(name, dict_of_rows, **kwargs) + + @with_collection_lock + def search(self, name: str, query: str = None, **kwargs: dict[str, typing.Any]) -> typing.Any: + """ + Search for data in a collection in the Milvus vector database. + + This method performs a search operation in the specified collection/partition in the Milvus vector database. + + Parameters + ---------- + name : str + Name of the collection to search within. + query : str, optional + The search query, which can be a filter expression, by default None. + **kwargs : dict + Additional keyword arguments for the search operation. + + Returns + ------- + typing.Any + The search result, which can vary depending on the query and options. + + Raises + ------ + RuntimeError + If an error occurs during the search operation. + If query argument is `None` and `data` keyword argument doesn't exist. + If `data` keyword arguement is `None`. 
+ """ + + logger.debug("Searching in collection: %s, query=%s, kwargs=%s", name, query, kwargs) + + try: + self._client.load_collection(collection_name=name) + if query is not None: + result = self._client.query(collection_name=name, filter=query, **kwargs) + else: + if "data" not in kwargs: + raise RuntimeError("The search operation requires that search vectors be " + + "provided as a keyword argument 'data'") + if kwargs["data"] is None: + raise RuntimeError("Argument 'data' cannot be None") + + data = kwargs.pop("data") + + result = self._client.search(collection_name=name, data=data, **kwargs) + return result + + except pymilvus.exceptions.MilvusException as exec_info: + raise RuntimeError(f"Unable to perform serach: {exec_info}") from exec_info + + finally: + self._client.release_collection(collection_name=name) + + @with_collection_lock + def update(self, name: str, data: list[typing.Any], **kwargs: dict[str, typing.Any]) -> dict[str, typing.Any]: + """ + Update data in the vector database. + + Parameters + ---------- + name : str + Name of the resource. + data : list[typing.Any] + Data to be updated in the collection. + **kwargs : dict[str, typing.Any] + Extra keyword arguments specific to upsert operation. + + Returns + ------- + dict[str, typing.Any] + Returns result of the updated operation stats. + """ + + if not isinstance(data, list): + raise RuntimeError("Data is not of type list.") + + result = self._client.upsert(collection_name=name, entities=data, **kwargs) + + return self._convert_mutation_result_to_dict(result=result) + + @with_collection_lock + def delete_by_keys(self, name: str, keys: int | str | list, **kwargs: dict[str, typing.Any]) -> typing.Any: + """ + Delete vectors by keys from the resource. + + Parameters + ---------- + name : str + Name of the resource. + keys : int | str | list + Primary keys to delete vectors. + **kwargs : dict[str, typing.Any] + Extra keyword arguments specific to the vector database implementation. 
+ + Returns + ------- + typing.Any + Returns result of the given keys that are delete from the collection. + """ + + result = self._client.delete(collection_name=name, pks=keys, **kwargs) + + return result + + @with_collection_lock + def delete(self, name: str, expr: str, **kwargs: dict[str, typing.Any]) -> dict[str, typing.Any]: + """ + Delete vectors from the resource using expressions. + + Parameters + ---------- + name : str + Name of the resource. + expr : str + Delete expression. + **kwargs : dict[str, typing.Any] + Extra keyword arguments specific to the vector database implementation. + + Returns + ------- + dict[str, typing.Any] + Returns result of the given keys that are delete from the collection. + """ + + result = self._client.delete_by_expr(collection_name=name, expression=expr, **kwargs) + + return self._convert_mutation_result_to_dict(result=result) + + @with_collection_lock + def retrieve_by_keys(self, name: str, keys: int | str | list, **kwargs: dict[str, typing.Any]) -> list[typing.Any]: + """ + Retrieve the inserted vectors using their primary keys from the Collection. + + Parameters + ---------- + name : str + Name of the collection. + keys : int | str | list + Primary keys to get vectors for. Depending on pk_field type it can be int or str + or a list of either. + **kwargs : dict[str, typing.Any] + Additional keyword arguments for the retrieval operation. + + Returns + ------- + list[typing.Any] + Returns result rows of the given keys from the collection. 
+ """ + + result = None + + try: + self._client.load_collection(collection_name=name) + result = self._client.get(collection_name=name, ids=keys, **kwargs) + except pymilvus.exceptions.MilvusException as exec_info: + raise RuntimeError(f"Unable to perform serach: {exec_info}") from exec_info + + finally: + self._client.release_collection(collection_name=name) + + return result + + def count(self, name: str, **kwargs: dict[str, typing.Any]) -> int: + """ + Returns number of rows/entities in the given collection. + + Parameters + ---------- + name : str + Name of the collection. + **kwargs : dict[str, typing.Any] + Additional keyword arguments for the count operation. + + Returns + ------- + int + Returns number of entities in the collection. + """ + + return self._client.num_entities(collection_name=name, **kwargs) + + def drop(self, name: str, **kwargs: dict[str, typing.Any]) -> None: + """ + Drop a collection, index, or partition in the Milvus vector database. + + This method allows you to drop a collection, an index within a collection, + or a specific partition within a collection in the Milvus vector database. + + Parameters + ---------- + name : str + Name of the collection, index, or partition to be dropped. + **kwargs : dict + Additional keyword arguments for specifying the type and partition name (if applicable). + + Notes on Expected Keyword Arguments: + ------------------------------------ + - 'resource' (str, optional): + Specifies the type of resource to drop. Possible values: 'collection' (default), 'index', 'partition'. + + - 'partition_name' (str, optional): + Required when dropping a specific partition within a collection. Specifies the partition name to be dropped. + + - 'field_name' (str, optional): + Required when dropping an index within a collection. Specifies the field name for which the index is created. + + - 'index_name' (str, optional): + Required when dropping an index within a collection. Specifies the name of the index to be dropped. 
+ + Raises + ------ + ValueError + If mandatory arguments are missing or if the provided 'resource' value is invalid. + """ + + logger.debug("Dropping collection: %s, kwargs=%s", name, kwargs) + + if self.has_store_object(name): + resource = kwargs.get("resource", "collection") + if resource == "collection": + self._client.drop_collection(collection_name=name) + elif resource == "partition": + if "partition_name" not in kwargs: + raise ValueError("Mandatory argument 'partition_name' is required when resource='partition'") + partition_name = kwargs["partition_name"] + if self._client.has_partition(collection_name=name, partition_name=partition_name): + self._client.drop_partition(collection_name=name, partition_name=partition_name) + elif resource == "index": + if "field_name" in kwargs and "index_name" in kwargs: + self._client.drop_index(collection_name=name, + field_name=kwargs["field_name"], + index_name=kwargs["index_name"]) + else: + raise ValueError( + "Mandatory arguments 'field_name' and 'index_name' are required when resource='index'") + + def describe(self, name: str, **kwargs: dict[str, typing.Any]) -> dict: + """ + Describe the collection in the vector database. + + Parameters + ---------- + name : str + Name of the collection. + **kwargs : dict[str, typing.Any] + Additional keyword arguments specific to the Milvus vector database. + + Returns + ------- + dict + Returns collection information. + """ + + return self._client.describe_collection(collection_name=name, **kwargs) + + def close(self) -> None: + """ + Close the connection to the Milvus vector database. + + This method disconnects from the Milvus vector database by removing the connection. 
+ + """ + self._client.close() + + def _convert_mutation_result_to_dict(self, result: MutationResult) -> dict[str, typing.Any]: + result_dict = { + "insert_count": result.insert_count, + "delete_count": result.delete_count, + "upsert_count": result.upsert_count, + "timestamp": result.timestamp, + "succ_count": result.succ_count, + "err_count": result.err_count + } + return result_dict + + @classmethod + def get_collection_lock(cls, name: str) -> threading.Lock: + """ + Get a lock for a given collection name. + + Parameters + ---------- + name : str + Name of the collection for which to acquire the lock. + + Returns + ------- + threading.Lock + A thread lock specific to the given collection name. + """ + + current_time = time.time() + + if name not in cls._collection_locks: + cls._collection_locks[name] = {"lock": threading.Lock(), "last_used": current_time} + else: + cls._collection_locks[name]["last_used"] = current_time + + if (current_time - cls._last_cleanup_time) >= cls._cleanup_interval: + for lock_name, lock_info in cls._collection_locks.copy().items(): + last_used = lock_info["last_used"] + if current_time - last_used >= cls._cleanup_interval: + logger.debug("Cleaning up lock for collection: %s", lock_name) + del cls._collection_locks[lock_name] + cls._last_cleanup_time = current_time + + return cls._collection_locks[name]["lock"] diff --git a/morpheus/service/vector_db_service.py b/morpheus/service/vector_db_service.py new file mode 100644 index 0000000000..650c7860e8 --- /dev/null +++ b/morpheus/service/vector_db_service.py @@ -0,0 +1,323 @@ +# Copyright (c) 2023, NVIDIA CORPORATION. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import logging +import typing +from abc import ABC +from abc import abstractmethod + +import pandas as pd + +import cudf + +logger = logging.getLogger(__name__) + + +class VectorDBService(ABC): + """ + Class used for vectorstore specific implementation. + """ + + @abstractmethod + def insert(self, name: str, data: list[list] | list[dict], **kwargs: dict[str, typing.Any]) -> dict: + """ + Insert data into the vector database. + + Parameters + ---------- + name : str + Name of the resource. + data : list[list] | list[dict] + Data to be inserted into the resource. + **kwargs : dict[str, typing.Any] + Extra keyword arguments specific to the vector database implementation. + + Returns + ------- + dict + Returns response content as a dictionary. + """ + + pass + + @abstractmethod + def insert_dataframe(self, + name: str, + df: typing.Union[cudf.DataFrame, pd.DataFrame], + **kwargs: dict[str, typing.Any]) -> dict: + """ + Converts dataframe to rows and insert into the vector database resource. + + Parameters + ---------- + name : str + Name of the resource to be inserted. + df : typing.Union[cudf.DataFrame, pd.DataFrame] + Dataframe to be inserted. + **kwargs : dict[str, typing.Any] + Additional keyword arguments containing collection configuration. + + Returns + ------- + dict + Returns response content as a dictionary. + + Raises + ------ + RuntimeError + If the resource not exists exists. 
+ """ + pass + + @abstractmethod + def search(self, name: str, query: str = None, **kwargs: dict[str, typing.Any]) -> typing.Any: + """ + Search for content in the vector database. + + Parameters + ---------- + name : str + Name of the resource. + query : str, default None + Query to execute on the given resource. + **kwargs : dict[str, typing.Any] + Extra keyword arguments specific to the vector database implementation. + + Returns + ------- + typing.Any + Returns search results. + """ + + pass + + @abstractmethod + def drop(self, name: str, **kwargs: dict[str, typing.Any]) -> None: + """ + Drop resources from the vector database. + + Parameters + ---------- + name : str + Name of the resource. + **kwargs : dict[str, typing.Any] + Extra keyword arguments specific to the vector database implementation. + """ + + pass + + @abstractmethod + def update(self, name: str, data: list[typing.Any], **kwargs: dict[str, typing.Any]) -> dict[str, typing.Any]: + """ + Update data in the vector database. + + Parameters + ---------- + name : str + Name of the resource. + data : list[typing.Any] + Data to be updated in the resource. + **kwargs : dict[str, typing.Any] + Extra keyword arguments specific to the vector database implementation. + + Returns + ------- + dict[str, typing.Any] + Returns stats of the update operation. + """ + + pass + + @abstractmethod + def delete(self, name: str, expr: str, **kwargs: dict[str, typing.Any]) -> dict[str, typing.Any]: + """ + Delete data in the vector database. + + Parameters + ---------- + name : str + Name of the resource. + expr : str + Delete expression. + **kwargs : dict[str, typing.Any] + Extra keyword arguments specific to the vector database implementation. + + Returns + ------- + dict[str, typing.Any] + Returns stats of the delete operation.
+ """ + + pass + + @abstractmethod + def create(self, name: str, overwrite: bool = False, **kwargs: dict[str, typing.Any]) -> None: + """ + Create resources in the vector database. + + Parameters + ---------- + name : str + Name of the resource. + overwrite : bool, default False + Whether to overwrite the resource if it already exists. + **kwargs : dict[str, typing.Any] + Extra keyword arguments specific to the vector database implementation. + """ + + pass + + @abstractmethod + def describe(self, name: str, **kwargs: dict[str, typing.Any]) -> dict: + """ + Describe resource in the vector database. + + Parameters + ---------- + name : str + Name of the resource. + **kwargs : dict[str, typing.Any] + Extra keyword arguments specific to the vector database implementation. + + Returns + ------- + dict + Returns resource information. + """ + + pass + + @abstractmethod + def close(self) -> None: + """ + Close connection to the vector database. + """ + + pass + + @abstractmethod + def has_store_object(self, name: str) -> bool: + """ + Check if a resource exists in the vector database. + + Parameters + ---------- + name : str + Name of the resource. + + Returns + ------- + bool + Returns True if resource exists in the vector database, otherwise False. + """ + + pass + + @abstractmethod + def list_store_objects(self, **kwargs: dict[str, typing.Any]) -> list[str]: + """ + List existing resources in the vector database. + + Parameters + ---------- + **kwargs : dict[str, typing.Any] + Extra keyword arguments specific to the vector database implementation. + + Returns + ------- + list[str] + Returns available resource names in the vector database. + """ + + pass + + # pylint: disable=unused-argument + def transform(self, data: typing.Any, **kwargs: dict[str, typing.Any]) -> typing.Any: + """ + Transform data according to the specific vector database implementation. + + Parameters + ---------- + data : typing.Any + Data to be transformed.
+ **kwargs : dict[str, typing.Any] + Extra keyword arguments specific to the vector database implementation. + + Returns + ------- + typing.Any + Returns transformed data as per the implementation. + """ + return data + + @abstractmethod + def retrieve_by_keys(self, name: str, keys: int | str | list, **kwargs: dict[str, typing.Any]) -> list[typing.Any]: + """ + Retrieve the inserted vectors using keys from the resource. + + Parameters + ---------- + name : str + Name of the resource. + keys : int | str | list + Primary keys to get vectors. + **kwargs : dict[str, typing.Any] + Extra keyword arguments specific to the vector database implementation. + + Returns + ------- + list[typing.Any] + Returns rows of the given keys that exist in the resource. + """ + pass + + @abstractmethod + def delete_by_keys(self, name: str, keys: int | str | list, **kwargs: dict[str, typing.Any]) -> typing.Any: + """ + Delete vectors by keys from the resource. + + Parameters + ---------- + name : str + Name of the resource. + keys : int | str | list + Primary keys to delete vectors. + **kwargs : dict[str, typing.Any] + Extra keyword arguments specific to the vector database implementation. + + Returns + ------- + typing.Any + Returns vectors of the given keys that are deleted from the resource. + """ + pass + + @abstractmethod + def count(self, name: str, **kwargs: dict[str, typing.Any]) -> int: + """ + Returns number of rows/entities in the given resource. + + Parameters + ---------- + name : str + Name of the resource. + **kwargs : dict[str, typing.Any] + Extra keyword arguments specific to the vector database implementation. + + Returns + ------- + int + Returns number of rows/entities in the given resource.
+ """ + pass diff --git a/morpheus/stages/output/write_to_vector_db.py b/morpheus/stages/output/write_to_vector_db.py new file mode 100644 index 0000000000..3fef0fcd66 --- /dev/null +++ b/morpheus/stages/output/write_to_vector_db.py @@ -0,0 +1,123 @@ +# Copyright (c) 2023, NVIDIA CORPORATION. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import logging +import typing + +import mrc +from mrc.core import operators as ops + +from morpheus.config import Config +from morpheus.messages import ControlMessage +from morpheus.pipeline.single_port_stage import SinglePortStage +from morpheus.pipeline.stream_pair import StreamPair +from morpheus.service.vector_db_service import VectorDBService +from morpheus.utils.vector_db_service_utils import VectorDBServiceFactory + +logger = logging.getLogger(__name__) + + +class WriteToVectorDBStage(SinglePortStage): + """ + Writes messages to a Vector Database. + + Parameters + ---------- + config : `morpheus.config.Config` + Pipeline configuration instance. + resource_name : str + The name of the resource managed by this instance. + resource_conf : dict + Additional resource configuration when performing vector database writes. + service : typing.Union[str, VectorDBService] + Either the name of the vector database service to use or an instance of VectorDBService + for managing the resource. + **service_kwargs : dict[str, typing.Any] + Additional keyword arguments to pass when creating a VectorDBService instance. 
+ + Raises + ------ + ValueError + If `service` is not a valid string (service name) or an instance of VectorDBService. + """ + + def __init__(self, + config: Config, + resource_name: str, + service: typing.Union[str, VectorDBService], + **service_kwargs: dict[str, typing.Any]): + + super().__init__(config) + + self._resource_name = resource_name + self._resource_kwargs = {} + + if "resource_kwargs" in service_kwargs: + self._resource_kwargs = service_kwargs.pop("resource_kwargs") + + if isinstance(service, str): + # If service is a string, assume it's the service name + self._service: VectorDBService = VectorDBServiceFactory.create_instance(service_name=service, + **service_kwargs) + elif isinstance(service, VectorDBService): + # If service is an instance of VectorDBService, use it directly + self._service: VectorDBService = service + else: + raise ValueError("service must be a string (service name) or an instance of VectorDBService") + + @property + def name(self) -> str: + return "to-vector-db" + + def accepted_types(self) -> typing.Tuple: + """ + Returns accepted input types for this stage. + + Returns + ------- + typing.Tuple(`morpheus.messages.ControlMessage`, ) + Accepted input types. + + """ + return (ControlMessage, ) + + def supports_cpp_node(self): + """Indicates whether this stage supports a C++ node.""" + return False + + def on_completed(self): + # Close vector database service connection + self._service.close() + + def _build_single(self, builder: mrc.Builder, input_stream: StreamPair) -> StreamPair: + + stream = input_stream[0] + + def on_data(ctrl_msg: ControlMessage) -> ControlMessage: + # Insert entries in the dataframe to vector database.
+ result = self._service.insert_dataframe(name=self._resource_name, + df=ctrl_msg.payload().df, + **self._resource_kwargs) + + ctrl_msg.set_metadata("insert_response", result) + + return ctrl_msg + + to_vector_db = builder.make_node(self.unique_name, ops.map(on_data), ops.on_completed(self.on_completed)) + + builder.make_edge(stream, to_vector_db) + stream = to_vector_db + + # Return input unchanged to allow passthrough + return stream, input_stream[1] diff --git a/morpheus/utils/vector_db_service_utils.py b/morpheus/utils/vector_db_service_utils.py new file mode 100644 index 0000000000..fa9145941f --- /dev/null +++ b/morpheus/utils/vector_db_service_utils.py @@ -0,0 +1,55 @@ +# Copyright (c) 2021-2023, NVIDIA CORPORATION. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import importlib +import typing + + +class VectorDBServiceFactory: + """ + Factory for creating instances of vector database service classes. This factory allows dynamically + creating instances of vector database service classes based on the provided service name. + Each service name corresponds to a specific implementation class. + + Parameters + ---------- + service_name : str + The name of the vector database service to create. + *args : typing.Any + Variable-length argument list to pass to the service constructor. + **kwargs : dict[str, typing.Any] + Arbitrary keyword arguments to pass to the service constructor. + + Returns + ------- + An instance of the specified vector database service class. 
+ + Raises + ------ + ValueError + If the specified service name is not found or does not correspond to a valid service class. + """ + + @classmethod + def create_instance(cls, service_name: str, *args: typing.Any, **kwargs: dict[str, typing.Any]): + try: + module_name = f"morpheus.service.{service_name}_vector_db_service" + module = importlib.import_module(module_name) + class_name = f"{service_name.capitalize()}VectorDBService" + class_ = getattr(module, class_name) + instance = class_(*args, **kwargs) + return instance + except (ModuleNotFoundError, AttributeError) as exc: + raise ValueError(f"Service {service_name} not found. Ensure that the corresponding service class," + + f"such as {module_name}, has been implemented.") from exc diff --git a/tests/conftest.py b/tests/conftest.py index e3df4500c5..425f604701 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -937,3 +937,31 @@ def filter_probs_df(dataset, use_cpp: bool): that as well, while excluding the combination of C++ execution and Pandas dataframes. """ yield dataset["filter_probs.csv"] + + +@pytest.fixture(scope="session") +def milvus_server_uri(): + """ + Pytest fixture to start and stop a Milvus server and provide its URI for testing. + + This fixture starts a Milvus server, retrieves its URI (Uniform Resource Identifier), and provides + the URI as a yield value to the tests using this fixture. After all tests in the session are
+ """ + from milvus import default_server + + logger = logging.getLogger(f"morpheus.{__name__}") + try: + default_server.start() + host = "127.0.0.1" + port = default_server.listen_port + uri = f"http://{host}:{port}" + + yield uri + except Exception as exec_inf: + logger.error("Error in starting Milvus server: %s", exec_inf) + finally: + try: + default_server.stop() + except Exception as exec_inf: + logger.error("Error in stopping Milvus server: %s", exec_inf) diff --git a/tests/test_milvus_vector_db_service.py b/tests/test_milvus_vector_db_service.py new file mode 100644 index 0000000000..28a3d646ed --- /dev/null +++ b/tests/test_milvus_vector_db_service.py @@ -0,0 +1,428 @@ +#!/usr/bin/env python +# SPDX-FileCopyrightText: Copyright (c) 2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +import concurrent.futures +import json +import os + +import numpy as np +import pytest + +from _utils import TEST_DIRS +from morpheus.service.milvus_client import MILVUS_DATA_TYPE_MAP +from morpheus.service.milvus_vector_db_service import MilvusVectorDBService + + +@pytest.fixture(scope="module", name="milvus_service_fixture") +def milvus_service(milvus_server_uri: str): + service = MilvusVectorDBService(uri=milvus_server_uri) + yield service + + +def load_json(filename): + conf_filepath = os.path.join(TEST_DIRS.tests_data_dir, "service", filename) + + with open(conf_filepath, 'r', encoding="utf-8") as json_file: + collection_config = json.load(json_file) + + return collection_config + + +@pytest.fixture(scope="module", name="data_fixture") +def data(): + inital_data = [{"id": i, "embedding": [i / 10.0] * 10, "age": 25 + i} for i in range(10)] + yield inital_data + + +@pytest.fixture(scope="module", name="idx_part_collection_config_fixture") +def idx_part_collection_config(): + collection_config = load_json(filename="milvus_idx_part_collection_conf.json") + yield collection_config + + +@pytest.fixture(scope="module", name="simple_collection_config_fixture") +def simple_collection_config(): + collection_config = load_json(filename="milvus_simple_collection_conf.json") + yield collection_config + + +@pytest.mark.slow +def test_list_store_objects(milvus_service_fixture: MilvusVectorDBService): + # List all collections in the Milvus server. + collections = milvus_service_fixture.list_store_objects() + assert isinstance(collections, list) + + +@pytest.mark.slow +def test_has_store_object(milvus_service_fixture: MilvusVectorDBService): + # Check if a non-existing collection exists in the Milvus server. 
+ collection_name = "non_existing_collection" + assert not milvus_service_fixture.has_store_object(collection_name) + + +@pytest.mark.slow +def test_create_and_drop_collection(milvus_service_fixture: MilvusVectorDBService, + idx_part_collection_config_fixture: dict): + # Create a collection and check if it exists. + collection_name = "test_collection" + milvus_service_fixture.create(collection_name, **idx_part_collection_config_fixture) + assert milvus_service_fixture.has_store_object(collection_name) + + # Drop the collection and check if it no longer exists. + milvus_service_fixture.drop(collection_name) + assert not milvus_service_fixture.has_store_object(collection_name) + + +@pytest.mark.slow +def test_insert_and_retrieve_by_keys(milvus_service_fixture: MilvusVectorDBService, + idx_part_collection_config_fixture: dict, + data_fixture: list[dict]): + # Create a collection. + collection_name = "test_insert_collection" + milvus_service_fixture.create(collection_name, **idx_part_collection_config_fixture) + + # Insert data into the collection. + response = milvus_service_fixture.insert(collection_name, data_fixture) + assert response["insert_count"] == len(data_fixture) + + # Retrieve inserted data by primary keys. + keys_to_retrieve = [2, 4, 6] + retrieved_data = milvus_service_fixture.retrieve_by_keys(collection_name, keys_to_retrieve) + assert len(retrieved_data) == len(keys_to_retrieve) + + # Clean up the collection. + milvus_service_fixture.drop(collection_name) + + +@pytest.mark.slow +def test_search(milvus_service_fixture: MilvusVectorDBService, + idx_part_collection_config_fixture: dict, + data_fixture: list[dict]): + # Create a collection. + collection_name = "test_search_collection" + milvus_service_fixture.create(collection_name, **idx_part_collection_config_fixture) + + # Insert data into the collection. + milvus_service_fixture.insert(collection_name, data_fixture) + + # Define a search query. 
+ query = "age==26 or age==27" + + # Perform a search in the collection. + search_result = milvus_service_fixture.search(collection_name, query) + assert len(search_result) == 2 + assert search_result[0]["age"] in [26, 27] + assert search_result[1]["age"] in [26, 27] + + # Clean up the collection. + milvus_service_fixture.drop(collection_name) + + +@pytest.mark.slow +def test_search_with_data(milvus_service_fixture: MilvusVectorDBService, + idx_part_collection_config_fixture: dict, + data_fixture: list[dict]): + # Create a collection. + collection_name = "test_search_with_data_collection" + milvus_service_fixture.create(collection_name, **idx_part_collection_config_fixture) + + # Insert data to the collection. + milvus_service_fixture.insert(collection_name, data_fixture) + + rng = np.random.default_rng(seed=100) + search_vec = rng.random((1, 10)) + + # Define a search filter. + fltr = "age==26 or age==27" + + # Perform a search in the collection. + search_result = milvus_service_fixture.search(collection_name, + data=search_vec, + filter=fltr, + output_fields=["id", "age"]) + + assert len(search_result[0]) == 2 + assert search_result[0][0]["entity"]["age"] in [26, 27] + assert search_result[0][1]["entity"]["age"] in [26, 27] + assert len(search_result[0][0]["entity"].keys()) == 2 + assert sorted(list(search_result[0][0]["entity"].keys())) == ["age", "id"] + + # Clean up the collection. + milvus_service_fixture.drop(collection_name) + + +@pytest.mark.slow +def test_count(milvus_service_fixture: MilvusVectorDBService, + idx_part_collection_config_fixture: dict, + data_fixture: list[dict]): + # Create a collection. + collection_name = "test_count_collection" + milvus_service_fixture.create(collection_name, **idx_part_collection_config_fixture) + + # Insert data into the collection. + milvus_service_fixture.insert(collection_name, data_fixture) + + # Get the count of entities in the collection. 
+ count = milvus_service_fixture.count(collection_name) + assert count == len(data_fixture) + + # Clean up the collection. + milvus_service_fixture.drop(collection_name) + + +@pytest.mark.slow +def test_overwrite_collection_on_create(milvus_service_fixture: MilvusVectorDBService, + idx_part_collection_config_fixture: dict, + data_fixture: list[dict]): + # Create a collection. + collection_name = "test_overwrite_collection" + milvus_service_fixture.create(collection_name, **idx_part_collection_config_fixture) + + # Insert data to the collection. + response1 = milvus_service_fixture.insert(collection_name, data_fixture) + assert response1["insert_count"] == len(data_fixture) + + # Create the same collection again with overwrite=True. + milvus_service_fixture.create(collection_name, overwrite=True, **idx_part_collection_config_fixture) + + # Insert different data into the collection. + data2 = [{"id": i, "embedding": [i / 10] * 10, "age": 26 + i} for i in range(10)] + + response2 = milvus_service_fixture.insert(collection_name, data2) + assert response2["insert_count"] == len(data2) + + # Retrieve the data from the collection and check if it matches the second set of data. + retrieved_data = milvus_service_fixture.retrieve_by_keys(collection_name, list(range(10))) + for i in range(10): + assert retrieved_data[i]["age"] == data2[i]["age"] + + # Clean up the collection. + milvus_service_fixture.drop(collection_name) + + +@pytest.mark.slow +def test_insert_into_partition(milvus_service_fixture: MilvusVectorDBService, + idx_part_collection_config_fixture: dict, + data_fixture: list[dict]): + # Create a collection with a partition. + collection_name = "test_partition_collection" + partition_name = idx_part_collection_config_fixture["collection_conf"]["partition_conf"]["partitions"][0]["name"] + milvus_service_fixture.create(collection_name, **idx_part_collection_config_fixture) + + # Insert data into the specified partition. 
+ response = milvus_service_fixture.insert(collection_name, + data_fixture, + collection_conf={"partition_name": partition_name}) + assert response["insert_count"] == len(data_fixture) + + # Retrieve inserted data by primary keys. + keys_to_retrieve = [2, 4, 6] + retrieved_data = milvus_service_fixture.retrieve_by_keys(collection_name, + keys_to_retrieve, + partition_names=[partition_name]) + assert len(retrieved_data) == len(keys_to_retrieve) + + retrieved_data_default_part = milvus_service_fixture.retrieve_by_keys(collection_name, + keys_to_retrieve, + partition_names=["_default"]) + assert len(retrieved_data_default_part) == 0 + assert len(retrieved_data_default_part) != len(keys_to_retrieve) + + # Raises error if resource is partition and the partition name is not passed. + with pytest.raises(ValueError, match="Mandatory argument 'partition_name' is required when resource='partition'"): + milvus_service_fixture.drop(name=collection_name, resource="partition") + + # Clean up the partition + milvus_service_fixture.drop(name=collection_name, resource="partition", partition_name=partition_name) + + # Raises error if resource is index and the field name and index name are not passed. + with pytest.raises(ValueError, + match="Mandatory arguments 'field_name' and 'index_name' are required when resource='index'"): + milvus_service_fixture.drop(name=collection_name, resource="index") + + milvus_service_fixture.drop(name=collection_name, + resource="index", + field_name="embedding", + index_name="_default_idx_") + + retrieved_data_after_part_drop = milvus_service_fixture.retrieve_by_keys(collection_name, keys_to_retrieve) + assert len(retrieved_data_after_part_drop) == 0 + + # Clean up the collection.
+ milvus_service_fixture.drop(collection_name) + + +@pytest.mark.slow +def test_update(milvus_service_fixture: MilvusVectorDBService, + simple_collection_config_fixture: dict, + data_fixture: list[dict]): + collection_name = "test_update_collection" + + # Create a collection with the specified schema configuration. + milvus_service_fixture.create(collection_name, **simple_collection_config_fixture) + + # Insert data to the collection. + milvus_service_fixture.insert(collection_name, data_fixture) + + # Use updated data to test the update/upsert functionality. + updated_data = [{ + "type": MILVUS_DATA_TYPE_MAP["int64"], "name": "id", "values": list(range(5, 12)) + }, + { + "type": MILVUS_DATA_TYPE_MAP["float_vector"], + "name": "embedding", + "values": [[i / 5.0] * 10 for i in range(5, 12)] + }, { + "type": MILVUS_DATA_TYPE_MAP["int64"], "name": "age", "values": [25 + i for i in range(5, 12)] + }] + + # Apply update/upsert on updated_data. + result_dict = milvus_service_fixture.update(collection_name, updated_data) + + assert result_dict["upsert_count"] == 7 + assert result_dict["insert_count"] == 7 + assert result_dict["succ_count"] == 7 + + # Clean up the collection. + milvus_service_fixture.drop(collection_name) + + +@pytest.mark.slow +def test_delete_by_keys(milvus_service_fixture: MilvusVectorDBService, + idx_part_collection_config_fixture: dict, + data_fixture: list[dict]): + # Create a collection. + collection_name = "test_delete_by_keys_collection" + milvus_service_fixture.create(collection_name, **idx_part_collection_config_fixture) + + # Insert data into the collection. + milvus_service_fixture.insert(collection_name, data_fixture) + + # Delete data by keys from the collection. + keys_to_delete = [2, 4, 6] + response = milvus_service_fixture.delete_by_keys(collection_name, keys_to_delete) + assert response == keys_to_delete + + # Clean up the collection. 
+ milvus_service_fixture.drop(collection_name) + + +@pytest.mark.slow +def test_delete(milvus_service_fixture: MilvusVectorDBService, + idx_part_collection_config_fixture: dict, + data_fixture: list[dict]): + # Create a collection. + collection_name = "test_delete_collection" + milvus_service_fixture.create(collection_name, **idx_part_collection_config_fixture) + + # Insert data into the collection. + milvus_service_fixture.insert(collection_name, data_fixture) + + # Delete expression. + delete_expr = "id in [0,1]" + + # Delete data from the collection using the expression. + delete_response = milvus_service_fixture.delete(collection_name, delete_expr) + assert delete_response["delete_count"] == 2 + + response = milvus_service_fixture.search(collection_name, query="id > 0") + assert len(response) == len(data_fixture) - 2 + + for item in response: + assert item["id"] > 1 + + # Clean up the collection. + milvus_service_fixture.drop(collection_name) + + +@pytest.mark.slow +def test_single_instance_with_collection_lock(milvus_service_fixture: MilvusVectorDBService, + idx_part_collection_config_fixture: dict, + data_fixture: list[dict]): + + # Create a collection. 
+ collection_name = "test_insert_and_search_order_with_collection_lock" + milvus_service_fixture.create(collection_name, **idx_part_collection_config_fixture) + + filter_query = "age == 26 or age == 27" + search_vec = np.random.random((1, 10)) + execution_order = [] + + def insert_data(): + result = milvus_service_fixture.insert(collection_name, data_fixture) + assert result['insert_count'] == len(data_fixture) + execution_order.append("Insert Executed") + + def search_data(): + result = milvus_service_fixture.search(collection_name, data=search_vec, filter=filter_query) + execution_order.append("Search Executed") + assert isinstance(result, list) + + def count_entities(): + milvus_service_fixture.count(collection_name) + execution_order.append("Count Collection Entities Executed") + + with concurrent.futures.ThreadPoolExecutor(max_workers=3) as executor: + executor.submit(insert_data) + executor.submit(search_data) + executor.submit(count_entities) + + # Assert the execution order + assert execution_order == ["Count Collection Entities Executed", "Insert Executed", "Search Executed"] + + +@pytest.mark.slow +def test_multi_instance_with_collection_lock(milvus_service_fixture: MilvusVectorDBService, + idx_part_collection_config_fixture: dict, + data_fixture: list[dict], + milvus_server_uri: str): + + milvus_service_2 = MilvusVectorDBService(uri=milvus_server_uri) + + collection_name = "test_insert_and_search_order_with_collection_lock" + filter_query = "age == 26 or age == 27" + search_vec = np.random.random((1, 10)) + + execution_order = [] + + def create_collection(): + milvus_service_fixture.create(collection_name, **idx_part_collection_config_fixture) + execution_order.append("Create Executed") + + def insert_data(): + result = milvus_service_2.insert(collection_name, data_fixture) + assert result['insert_count'] == len(data_fixture) + execution_order.append("Insert Executed") + + def search_data(): + result = milvus_service_fixture.search(collection_name, 
data=search_vec, filter=filter_query) + execution_order.append("Search Executed") + assert isinstance(result, list) + + with concurrent.futures.ThreadPoolExecutor(max_workers=3) as executor: + executor.submit(create_collection) + executor.submit(insert_data) + executor.submit(search_data) + + # Assert the execution order + assert execution_order == ["Create Executed", "Insert Executed", "Search Executed"] + + +def test_get_collection_lock(): + collection_name = "test_collection_lock" + lock = MilvusVectorDBService.get_collection_lock(collection_name) + assert "lock" == type(lock).__name__ + assert collection_name in MilvusVectorDBService._collection_locks diff --git a/tests/test_milvus_write_to_vector_db_stage_pipe.py b/tests/test_milvus_write_to_vector_db_stage_pipe.py new file mode 100755 index 0000000000..e065e54307 --- /dev/null +++ b/tests/test_milvus_write_to_vector_db_stage_pipe.py @@ -0,0 +1,127 @@ +#!/usr/bin/env python +# SPDX-FileCopyrightText: Copyright (c) 2022-2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +import json +import os +import random + +import pytest + +import cudf + +from _utils import TEST_DIRS +from morpheus.config import Config +from morpheus.messages import ControlMessage +from morpheus.modules import to_control_message # noqa: F401 # pylint: disable=unused-import +from morpheus.pipeline import LinearPipeline +from morpheus.service.milvus_vector_db_service import MilvusVectorDBService +from morpheus.stages.general.linear_modules_stage import LinearModulesStage +from morpheus.stages.input.in_memory_source_stage import InMemorySourceStage +from morpheus.stages.output.in_memory_sink_stage import InMemorySinkStage +from morpheus.stages.output.write_to_vector_db import WriteToVectorDBStage +from morpheus.utils.module_ids import MORPHEUS_MODULE_NAMESPACE +from morpheus.utils.module_ids import TO_CONTROL_MESSAGE + + +@pytest.fixture(scope="function", name="milvus_service_fixture") +def milvus_service(milvus_server_uri: str): + service = MilvusVectorDBService(uri=milvus_server_uri) + yield service + + +def get_test_df(num_input_rows): + + df = cudf.DataFrame({ + "id": list(range(num_input_rows)), + "age": [random.randint(20, 40) for i in range(num_input_rows)], + "embedding": [[random.random() for _ in range(10)] for _ in range(num_input_rows)] + }) + + return df + + +def create_milvus_collection(collection_name: str, conf_file: str, service: MilvusVectorDBService): + + conf_filepath = os.path.join(TEST_DIRS.tests_data_dir, "service", conf_file) + + with open(conf_filepath, 'r', encoding="utf-8") as json_file: + collection_config = json.load(json_file) + + service.create(name=collection_name, overwrite=True, **collection_config) + + +@pytest.mark.slow +@pytest.mark.use_cpp +@pytest.mark.parametrize("use_instance, num_input_rows, expected_num_output_rows", [(True, 5, 5), (False, 5, 5)]) +def test_write_to_vector_db_stage_pipe(milvus_service_fixture: MilvusVectorDBService, + milvus_server_uri: str, + use_instance: bool, + config: Config, + num_input_rows: 
int, + expected_num_output_rows: int): + + collection_name = "test_stage_insert_collection" + + # Create milvus collection using config file. + create_milvus_collection(collection_name, "milvus_idx_part_collection_conf.json", milvus_service_fixture) + df = get_test_df(num_input_rows) + + to_cm_module_config = { + "module_id": TO_CONTROL_MESSAGE, "module_name": "to_control_message", "namespace": MORPHEUS_MODULE_NAMESPACE + } + + pipe = LinearPipeline(config) + pipe.set_source(InMemorySourceStage(config, [df])) + pipe.add_stage( + LinearModulesStage(config, + to_cm_module_config, + input_port_name="input", + output_port_name="output", + output_type=ControlMessage)) + + # Provide partition name to insert data into the partition otherwise goes to '_default' partition. + resource_kwargs = {"collection_conf": {"partition_name": "age_partition"}} + + if use_instance: + # Instantiate stage with service instance and insert options. + write_to_vdb_stage = WriteToVectorDBStage(config, + resource_name=collection_name, + service=milvus_service_fixture, + resource_kwargs=resource_kwargs) + else: + # Instantiate stage with service name, uri and insert options. + write_to_vdb_stage = WriteToVectorDBStage(config, + resource_name=collection_name, + service="milvus", + uri=milvus_server_uri, + resource_kwargs=resource_kwargs) + + pipe.add_stage(write_to_vdb_stage) + sink_stage = pipe.add_stage(InMemorySinkStage(config)) + pipe.run() + + messages = sink_stage.get_messages() + + assert len(messages) == 1 + assert isinstance(messages[0], ControlMessage) + assert messages[0].has_metadata("insert_response") + + # Insert entities response as a dictionary. 
+ response = messages[0].get_metadata("insert_response") + + assert response["insert_count"] == expected_num_output_rows + assert response["succ_count"] == expected_num_output_rows + assert response["err_count"] == 0 diff --git a/tests/tests_data/service/milvus_idx_part_collection_conf.json b/tests/tests_data/service/milvus_idx_part_collection_conf.json new file mode 100644 index 0000000000..4d652055eb --- /dev/null +++ b/tests/tests_data/service/milvus_idx_part_collection_conf.json @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:ed114c4945f46cd7748ff96c868ddc02940aa3fa4dcc3857a23225f8943fe292 +size 1057 diff --git a/tests/tests_data/service/milvus_simple_collection_conf.json b/tests/tests_data/service/milvus_simple_collection_conf.json new file mode 100644 index 0000000000..399463cfeb --- /dev/null +++ b/tests/tests_data/service/milvus_simple_collection_conf.json @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:c4d5984904a207dad3a58515b1327f309542b80729a6172b02399ad67a6a83cd +size 766