
refactor: Qdrant points upload implementation (run-llama#10462)
refactor: Qdrant upload
Anush008 authored Feb 5, 2024
1 parent 61011d7 commit 703c346
Showing 1 changed file with 34 additions and 13 deletions.
47 changes: 34 additions & 13 deletions llama_index/vector_stores/qdrant.py
@@ -47,6 +47,17 @@ class QdrantVectorStore(BasePydanticVectorStore):
     Args:
         collection_name: (str): name of the Qdrant collection
         client (Optional[Any]): QdrantClient instance from `qdrant-client` package
+        aclient (Optional[Any]): AsyncQdrantClient instance from `qdrant-client` package
+        url (Optional[str]): url of the Qdrant instance
+        api_key (Optional[str]): API key for authenticating with Qdrant
+        batch_size (int): number of points to upload in a single request to Qdrant. Defaults to 64
+        parallel (int): number of parallel processes to use during upload. Defaults to 1
+        max_retries (int): maximum number of retries in case of a failure. Defaults to 3
+        client_kwargs (Optional[dict]): additional kwargs for QdrantClient and AsyncQdrantClient
+        enable_hybrid (bool): whether to enable hybrid search using dense and sparse vectors
+        sparse_doc_fn (Optional[SparseEncoderCallable]): function to encode sparse vectors
+        sparse_query_fn (Optional[SparseEncoderCallable]): function to encode sparse queries
+        hybrid_fusion_fn (Optional[HybridFusionCallable]): function to fuse hybrid search results
     """
 
     stores_text: bool = True
@@ -57,6 +68,8 @@ class QdrantVectorStore(BasePydanticVectorStore):
     url: Optional[str]
     api_key: Optional[str]
     batch_size: int
+    parallel: int
+    max_retries: int
     client_kwargs: dict = Field(default_factory=dict)
     enable_hybrid: bool
 
@@ -74,7 +87,9 @@ def __init__(
         aclient: Optional[Any] = None,
         url: Optional[str] = None,
         api_key: Optional[str] = None,
-        batch_size: int = 100,
+        batch_size: int = 64,
+        parallel: int = 1,
+        max_retries: int = 3,
         client_kwargs: Optional[dict] = None,
         enable_hybrid: bool = False,
         sparse_doc_fn: Optional[SparseEncoderCallable] = None,
@@ -138,6 +153,8 @@ def __init__(
             url=url,
             api_key=api_key,
             batch_size=batch_size,
+            parallel=parallel,
+            max_retries=max_retries,
             client_kwargs=client_kwargs or {},
             enable_hybrid=enable_hybrid,
         )
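
For context (not part of the diff): the constructor hunks above add three upload-tuning knobs. A minimal usage sketch, assuming the legacy `llama_index.vector_stores.qdrant` import path that matches this file, a Qdrant server at a placeholder URL, and an illustrative collection name:

import qdrant_client
from llama_index.vector_stores.qdrant import QdrantVectorStore

# Sync client pointing at a placeholder local Qdrant instance.
client = qdrant_client.QdrantClient(url="http://localhost:6333")

vector_store = QdrantVectorStore(
    collection_name="my_collection",  # placeholder collection name
    client=client,
    batch_size=64,   # points per upload request (new default; previously 100)
    parallel=1,      # number of parallel upload processes
    max_retries=3,   # retries per failed batch
)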
@@ -227,12 +244,14 @@ def add(self, nodes: List[BaseNode], **add_kwargs: Any) -> List[str]:
 
         points, ids = self._build_points(nodes)
 
-        # batch upsert the points into Qdrant collection to avoid large payloads
-        for points_batch in iter_batch(points, self.batch_size):
-            self._client.upsert(
-                collection_name=self.collection_name,
-                points=points_batch,
-            )
+        self._client.upload_points(
+            collection_name=self.collection_name,
+            points=points,
+            batch_size=self.batch_size,
+            parallel=self.parallel,
+            max_retries=self.max_retries,
+            wait=True,
+        )
 
         return ids
 
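The `add()` hunk replaces the manual `iter_batch` + `upsert` loop with qdrant-client's `upload_points`, which batches, parallelizes, and retries internally. A hedged sketch of that underlying call, assuming a qdrant-client release recent enough to expose `upload_points` and using an in-memory client with placeholder vectors and payloads:

from qdrant_client import QdrantClient, models

client = QdrantClient(":memory:")  # in-memory instance, for illustration only

client.create_collection(
    collection_name="my_collection",
    vectors_config=models.VectorParams(size=3, distance=models.Distance.COSINE),
)

points = [
    models.PointStruct(id=i, vector=[0.1, 0.2, 0.3], payload={"doc": f"node-{i}"})
    for i in range(1000)
]

# Sends batches of `batch_size` points per request; `parallel` controls the
# number of upload workers and `max_retries` re-sends failed batches.
client.upload_points(
    collection_name="my_collection",
    points=points,
    batch_size=64,
    parallel=1,
    max_retries=3,
    wait=True,  # block until points are persisted, matching the old upsert default
)

With `parallel=1` the upload stays in the calling process; larger values fan batches out to worker processes, which is presumably why the store now exposes it as a constructor parameter rather than hard-coding it.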
@@ -259,12 +278,14 @@ async def async_add(self, nodes: List[BaseNode], **kwargs: Any) -> List[str]:
 
         points, ids = self._build_points(nodes)
 
-        # batch upsert the points into Qdrant collection to avoid large payloads
-        for points_batch in iter_batch(points, self.batch_size):
-            await self._aclient.upsert(
-                collection_name=self.collection_name,
-                points=points_batch,
-            )
+        await self._aclient.upload_points(
+            collection_name=self.collection_name,
+            points=points,
+            batch_size=self.batch_size,
+            parallel=self.parallel,
+            max_retries=self.max_retries,
+            wait=True,
+        )
 
         return ids
 
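The async hunk makes the same swap through `AsyncQdrantClient`. A sketch of exercising that path end to end, under the assumptions (not confirmed by this diff) that the installed qdrant-client provides `AsyncQdrantClient`, that `async_add` creates the collection lazily like the sync path, and that the node below is a placeholder with a precomputed embedding:

import asyncio

import qdrant_client
from llama_index.schema import TextNode
from llama_index.vector_stores.qdrant import QdrantVectorStore

async def main() -> None:
    # Async client against a placeholder local Qdrant instance.
    aclient = qdrant_client.AsyncQdrantClient(url="http://localhost:6333")
    store = QdrantVectorStore(
        collection_name="my_collection",  # placeholder
        aclient=aclient,
        batch_size=64,
        parallel=1,
        max_retries=3,
    )
    # Single placeholder node with a tiny precomputed embedding.
    nodes = [TextNode(text="hello qdrant", embedding=[0.1, 0.2, 0.3])]
    ids = await store.async_add(nodes)
    print(ids)

asyncio.run(main())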
