Skip to content

Commit

Permalink
Move 'create_collection' to vector db service
Browse files Browse the repository at this point in the history
  • Loading branch information
hv0905 committed May 7, 2024
1 parent fa1d896 commit 7579792
Show file tree
Hide file tree
Showing 3 changed files with 23 additions and 13 deletions.
18 changes: 18 additions & 0 deletions app/Services/vector_db_context.py
Original file line number Diff line number Diff line change
Expand Up @@ -182,6 +182,24 @@ async def get_counts(self, exact: bool) -> int:
resp = await self._client.count(collection_name=self.collection_name, exact=exact)
return resp.count

async def check_collection(self) -> bool:
resp = await self._client.get_collections()
resp = [t.name for t in resp.collections]
return self.collection_name in resp

async def initialize_collection(self):
if await self.check_collection():
logger.warning("Collection already exists. Skip initialization.")
return
logger.info("Initializing database, collection name: {}", self.collection_name)
vectors_config = {
self.IMG_VECTOR: models.VectorParams(size=768, distance=models.Distance.COSINE),
self.TEXT_VECTOR: models.VectorParams(size=768, distance=models.Distance.COSINE)
}
await self._client.create_collection(collection_name=self.collection_name,
vectors_config=vectors_config)
logger.success("Collection created!")

@classmethod
def _get_vector_from_img_data(cls, img_data: ImageData) -> models.PointVectors:
vector = {}
Expand Down
2 changes: 1 addition & 1 deletion main.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,7 @@ def parse_args():
from scripts import qdrant_create_collection
from app.config import config

qdrant_create_collection.create_coll(config.qdrant.host, config.qdrant.port, config.qdrant.coll)
asyncio.run(qdrant_create_collection.main())

elif args.migrate_from_version is not None:
from scripts import db_migrations
Expand Down
16 changes: 4 additions & 12 deletions scripts/qdrant_create_collection.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,6 @@
from qdrant_client import qdrant_client, models
from app.Services.vector_db_context import VectorDbContext


def create_coll(host, port, name):
client = qdrant_client.QdrantClient(host=host, port=port)
# create or update
print("Creating collection")
vectors_config = {
"image_vector": models.VectorParams(size=768, distance=models.Distance.COSINE),
"text_contain_vector": models.VectorParams(size=768, distance=models.Distance.COSINE)
}
client.create_collection(collection_name=name,
vectors_config=vectors_config)
print("Collection created")
async def main():
context = VectorDbContext()
await context.initialize_collection()

0 comments on commit 7579792

Please sign in to comment.