diff --git a/llama_index/vector_stores/qdrant.py b/llama_index/vector_stores/qdrant.py index a6cf0b88d6063..16dc728f120be 100644 --- a/llama_index/vector_stores/qdrant.py +++ b/llama_index/vector_stores/qdrant.py @@ -47,6 +47,17 @@ class QdrantVectorStore(BasePydanticVectorStore): Args: collection_name: (str): name of the Qdrant collection client (Optional[Any]): QdrantClient instance from `qdrant-client` package + aclient (Optional[Any]): AsyncQdrantClient instance from `qdrant-client` package + url (Optional[str]): url of the Qdrant instance + api_key (Optional[str]): API key for authenticating with Qdrant + batch_size (int): number of points to upload in a single request to Qdrant. Defaults to 64 + parallel (int): number of parallel processes to use during upload. Defaults to 1 + max_retries (int): maximum number of retries in case of a failure. Defaults to 3 + client_kwargs (Optional[dict]): additional kwargs for QdrantClient and AsyncQdrantClient + enable_hybrid (bool): whether to enable hybrid search using dense and sparse vectors + sparse_doc_fn (Optional[SparseEncoderCallable]): function to encode sparse vectors + sparse_query_fn (Optional[SparseEncoderCallable]): function to encode sparse queries + hybrid_fusion_fn (Optional[HybridFusionCallable]): function to fuse hybrid search results """ stores_text: bool = True @@ -57,6 +68,8 @@ class QdrantVectorStore(BasePydanticVectorStore): url: Optional[str] api_key: Optional[str] batch_size: int + parallel: int + max_retries: int client_kwargs: dict = Field(default_factory=dict) enable_hybrid: bool @@ -74,7 +87,9 @@ def __init__( aclient: Optional[Any] = None, url: Optional[str] = None, api_key: Optional[str] = None, - batch_size: int = 100, + batch_size: int = 64, + parallel: int = 1, + max_retries: int = 3, client_kwargs: Optional[dict] = None, enable_hybrid: bool = False, sparse_doc_fn: Optional[SparseEncoderCallable] = None, @@ -138,6 +153,8 @@ def __init__( url=url, api_key=api_key, batch_size=batch_size, + parallel=parallel, + max_retries=max_retries, client_kwargs=client_kwargs or {}, enable_hybrid=enable_hybrid, ) @@ -227,12 +244,14 @@ def add(self, nodes: List[BaseNode], **add_kwargs: Any) -> List[str]: points, ids = self._build_points(nodes) - # batch upsert the points into Qdrant collection to avoid large payloads - for points_batch in iter_batch(points, self.batch_size): - self._client.upsert( - collection_name=self.collection_name, - points=points_batch, - ) + self._client.upload_points( + collection_name=self.collection_name, + points=points, + batch_size=self.batch_size, + parallel=self.parallel, + max_retries=self.max_retries, + wait=True, + ) return ids @@ -259,12 +278,14 @@ async def async_add(self, nodes: List[BaseNode], **kwargs: Any) -> List[str]: points, ids = self._build_points(nodes) - # batch upsert the points into Qdrant collection to avoid large payloads - for points_batch in iter_batch(points, self.batch_size): - await self._aclient.upsert( - collection_name=self.collection_name, - points=points_batch, - ) + await self._aclient.upload_points( + collection_name=self.collection_name, + points=points, + batch_size=self.batch_size, + parallel=self.parallel, + max_retries=self.max_retries, + wait=True, + ) return ids