From 7579792cb35575eb3bc6b941255399bb929d4775 Mon Sep 17 00:00:00 2001
From: EdgeNeko
Date: Wed, 8 May 2024 01:22:13 +0800
Subject: [PATCH] Move 'create_collection' to vector db service

---
Review notes (commentary only: `git am` ignores the text between the
three-dash line and the first "diff --git" line, so nothing in this
section is applied or recorded in the commit).

  * app/Services/vector_db_context.py
      check_collection() awaits self._client.get_collections(), collects
      the `name` of every returned collection and returns whether
      self.collection_name is among them.
      initialize_collection() logs a warning and returns without creating
      anything when check_collection() is true; otherwise it creates the
      collection with two named vectors (self.IMG_VECTOR and
      self.TEXT_VECTOR), each of size 768 with cosine distance, and logs
      success.

  * scripts/qdrant_create_collection.py
      The synchronous create_coll(host, port, name) is removed. It is
      replaced by an async main() that builds VectorDbContext() with no
      arguments and awaits initialize_collection().

  * main.py
      The branch shown now runs
      asyncio.run(qdrant_create_collection.main()) instead of calling
      create_coll() with config.qdrant.host / port / coll.

  Behaviour differences visible in this patch:
      The removed create_coll() called create_collection unconditionally;
      the new path skips creation when the collection name is already
      listed.
      The removed code used the literal vector names "image_vector" and
      "text_contain_vector"; the new code uses self.IMG_VECTOR and
      self.TEXT_VECTOR, whose values are defined outside this patch.
      NOTE(review): presumably the same strings; confirm.

  Open questions (not answerable from these hunks):
      main.py: no `import asyncio` is visible in the hunk; confirm it is
      already imported.
      main.py: `from app.config import config` remains as a context line
      although its only use visible here is removed; confirm whether it
      is still needed.
      create_coll() is deleted; confirm no other caller remains.

 app/Services/vector_db_context.py   | 18 ++++++++++++++++++
 main.py                             |  2 +-
 scripts/qdrant_create_collection.py | 16 ++++------------
 3 files changed, 23 insertions(+), 13 deletions(-)

diff --git a/app/Services/vector_db_context.py b/app/Services/vector_db_context.py
index 21c392a..36fc7f8 100644
--- a/app/Services/vector_db_context.py
+++ b/app/Services/vector_db_context.py
@@ -182,6 +182,24 @@ async def get_counts(self, exact: bool) -> int:
         resp = await self._client.count(collection_name=self.collection_name, exact=exact)
         return resp.count
 
+    async def check_collection(self) -> bool:
+        resp = await self._client.get_collections()
+        resp = [t.name for t in resp.collections]
+        return self.collection_name in resp
+
+    async def initialize_collection(self):
+        if await self.check_collection():
+            logger.warning("Collection already exists. Skip initialization.")
+            return
+        logger.info("Initializing database, collection name: {}", self.collection_name)
+        vectors_config = {
+            self.IMG_VECTOR: models.VectorParams(size=768, distance=models.Distance.COSINE),
+            self.TEXT_VECTOR: models.VectorParams(size=768, distance=models.Distance.COSINE)
+        }
+        await self._client.create_collection(collection_name=self.collection_name,
+                                             vectors_config=vectors_config)
+        logger.success("Collection created!")
+
     @classmethod
     def _get_vector_from_img_data(cls, img_data: ImageData) -> models.PointVectors:
         vector = {}
diff --git a/main.py b/main.py
index c1532a3..84c60ac 100644
--- a/main.py
+++ b/main.py
@@ -50,7 +50,7 @@ def parse_args():
         from scripts import qdrant_create_collection
         from app.config import config
 
-        qdrant_create_collection.create_coll(config.qdrant.host, config.qdrant.port, config.qdrant.coll)
+        asyncio.run(qdrant_create_collection.main())
     elif args.migrate_from_version is not None:
         from scripts import db_migrations
 
diff --git a/scripts/qdrant_create_collection.py b/scripts/qdrant_create_collection.py
index df7d89c..3095485 100644
--- a/scripts/qdrant_create_collection.py
+++ b/scripts/qdrant_create_collection.py
@@ -1,14 +1,6 @@
-from qdrant_client import qdrant_client, models
+from app.Services.vector_db_context import VectorDbContext
 
 
-def create_coll(host, port, name):
-    client = qdrant_client.QdrantClient(host=host, port=port)
-    # create or update
-    print("Creating collection")
-    vectors_config = {
-        "image_vector": models.VectorParams(size=768, distance=models.Distance.COSINE),
-        "text_contain_vector": models.VectorParams(size=768, distance=models.Distance.COSINE)
-    }
-    client.create_collection(collection_name=name,
-                             vectors_config=vectors_config)
-    print("Collection created")
+async def main():
+    context = VectorDbContext()
+    await context.initialize_collection()