From c7de8011d2feac651ed2403608bcb42f8847dae2 Mon Sep 17 00:00:00 2001 From: "zhenshan.cao" Date: Tue, 10 Sep 2024 21:03:00 +0800 Subject: [PATCH] add hybrid_search for MilvusClient (#2259) Signed-off-by: zhenshan.cao --- examples/hybrid_search.py | 2 + examples/milvus_client/hybrid_search.py | 75 +++++++++++++++++++++++++ pymilvus/milvus_client/milvus_client.py | 72 ++++++++++++++++++++++++ 3 files changed, 149 insertions(+) create mode 100644 examples/milvus_client/hybrid_search.py diff --git a/examples/hybrid_search.py b/examples/hybrid_search.py index 02e85343a..6a13045f0 100644 --- a/examples/hybrid_search.py +++ b/examples/hybrid_search.py @@ -16,6 +16,8 @@ has = utility.has_collection("hello_milvus") print(f"Does collection hello_milvus exist in Milvus: {has}") +if has: + utility.drop_collection("hello_milvus") fields = [ FieldSchema(name="pk", dtype=DataType.VARCHAR, is_primary=True, auto_id=False, max_length=100), diff --git a/examples/milvus_client/hybrid_search.py b/examples/milvus_client/hybrid_search.py new file mode 100644 index 000000000..28ae0b309 --- /dev/null +++ b/examples/milvus_client/hybrid_search.py @@ -0,0 +1,75 @@ +import numpy as np +from pymilvus import ( + MilvusClient, + DataType, + AnnSearchRequest, RRFRanker, WeightedRanker, +) + +fmt = "\n=== {:30} ===\n" +search_latency_fmt = "search latency = {:.4f}s" +num_entities, dim = 3000, 8 + +collection_name = "hello_milvus" +milvus_client = MilvusClient("http://localhost:19530") + +has_collection = milvus_client.has_collection(collection_name, timeout=5) +if has_collection: + milvus_client.drop_collection(collection_name) + +schema = milvus_client.create_schema(auto_id=False, description="hello_milvus is the simplest demo to introduce the APIs") +schema.add_field("pk", DataType.VARCHAR, is_primary=True, max_length=100) +schema.add_field("random", DataType.DOUBLE) +schema.add_field("embeddings", DataType.FLOAT_VECTOR, dim=dim) +schema.add_field("embeddings2", DataType.FLOAT_VECTOR, dim=dim) 
+ +index_params = milvus_client.prepare_index_params() +index_params.add_index(field_name = "embeddings", index_type = "IVF_FLAT", metric_type="L2", nlist=128) +index_params.add_index(field_name = "embeddings2",index_type = "IVF_FLAT", metric_type="L2", nlist=128) + +print(fmt.format("Create collection `hello_milvus`")) + +milvus_client.create_collection(collection_name, schema=schema, index_params=index_params, consistency_level="Strong") + +print(fmt.format("Start inserting entities")) +rng = np.random.default_rng(seed=19530) +entities = [ + # provide the pk field because `auto_id` is set to False + [str(i) for i in range(num_entities)], + rng.random(num_entities).tolist(), # field random, only supports list + rng.random((num_entities, dim)), # field embeddings, supports numpy.ndarray and list + rng.random((num_entities, dim)), # field embeddings2, supports numpy.ndarray and list +] + +rows = [ {"pk": entities[0][i], "random": entities[1][i], "embeddings": entities[2][i], "embeddings2": entities[3][i]} for i in range (num_entities)] + +insert_result = milvus_client.insert(collection_name, rows) + + +print(fmt.format("Start loading")) +milvus_client.load_collection(collection_name) + +field_names = ["embeddings", "embeddings2"] +field_names = ["embeddings"] + +req_list = [] +nq = 1 +default_limit = 5 +vectors_to_search = [] + +for i in range(len(field_names)): + # 4. 
generate search data + vectors_to_search = rng.random((nq, dim)) + search_param = { + "data": vectors_to_search, + "anns_field": field_names[i], + "param": {"metric_type": "L2"}, + "limit": default_limit, + "expr": "random > 0.5"} + req = AnnSearchRequest(**search_param) + req_list.append(req) + +print("rank by RRFRanker") +hybrid_res = milvus_client.hybrid_search(collection_name, req_list, RRFRanker(), default_limit, output_fields=["random"]) +for hits in hybrid_res: + for hit in hits: + print(f" hybrid search hit: {hit}") diff --git a/pymilvus/milvus_client/milvus_client.py b/pymilvus/milvus_client/milvus_client.py index d0fbb798f..0047c8ea3 100644 --- a/pymilvus/milvus_client/milvus_client.py +++ b/pymilvus/milvus_client/milvus_client.py @@ -4,6 +4,7 @@ from typing import Dict, List, Optional, Union from uuid import uuid4 +from pymilvus.client.abstract import AnnSearchRequest, BaseRanker from pymilvus.client.constants import DEFAULT_CONSISTENCY_LEVEL from pymilvus.client.types import ( ExceptionsMessage, @@ -282,6 +283,77 @@ def upsert( } ) + def hybrid_search( + self, + collection_name: str, + reqs: List[AnnSearchRequest], + ranker: BaseRanker, + limit: int = 10, + output_fields: Optional[List[str]] = None, + timeout: Optional[float] = None, + partition_names: Optional[List[str]] = None, + **kwargs, + ) -> List[List[dict]]: + """Conducts multi vector similarity search with a rerank for rearrangement. + + Args: + collection_name(``string``): The name of collection. + reqs (``List[AnnSearchRequest]``): The vector search requests. + ranker (``BaseRanker``): The ranker used to rearrange the `limit` number of results. + limit (``int``): The max number of returned records, also known as `topk`. + + partition_names (``List[str]``, optional): The names of partitions to search on. + output_fields (``List[str]``, optional): + The name of fields to return in the search result. Can only get scalar fields. 
+ round_decimal (``int``, optional): + The specified number of decimal places of returned distance. + Defaults to -1, which means the returned distance is not rounded. + timeout (``float``, optional): A duration of time in seconds to allow for the RPC. + If timeout is set to None, the client keeps waiting until the server + responds or an error occurs. + **kwargs (``dict``): Optional search params + + * *offset* (``int``, optional) + offset for pagination. + + * *consistency_level* (``str/int``, optional) + Which consistency level to use when searching in the collection. + + Options of consistency level: Strong, Bounded, Eventually, Session, Customized. + + Note: this parameter overwrites the same one specified when creating collection, + if no consistency level was specified, search will use the + consistency level when you create the collection. + + Returns: + List[List[dict]]: A nested list of dicts containing the result data. + + Raises: + MilvusException: If anything goes wrong + """ + + conn = self._get_connection() + try: + res = conn.hybrid_search( + collection_name, + reqs, + ranker, + limit=limit, + partition_names=partition_names, + output_fields=output_fields, + timeout=timeout, + **kwargs, + ) + except Exception as ex: + logger.error("Failed to hybrid search collection: %s", collection_name) + raise ex from ex + + ret = [] + for hits in res: + ret.append([hit.to_dict() for hit in hits]) + + return ExtraList(ret, extra=construct_cost_extra(res.cost)) + def search( self, collection_name: str,