diff --git a/app/Models/img_data.py b/app/Models/img_data.py
index c0a976c..f82ba47 100644
--- a/app/Models/img_data.py
+++ b/app/Models/img_data.py
@@ -3,7 +3,7 @@
 from uuid import UUID
 
 from numpy import ndarray
-from pydantic import BaseModel, Field
+from pydantic import BaseModel, Field, computed_field
 
 
 class ImageData(BaseModel):
@@ -19,10 +19,19 @@ class ImageData(BaseModel):
     aspect_ratio: Optional[float] = None
     starred: Optional[bool] = False
     categories: Optional[list[str]] = []
+    local: Optional[bool] = False
+
+    @computed_field()
+    @property
+    def ocr_text_lower(self) -> str | None:
+        if self.ocr_text is None:
+            return None
+        return self.ocr_text.lower()
+
 
     @property
     def payload(self):
-        result = self.model_dump(exclude={'image_vector', 'text_contain_vector', 'id', 'index_date'})
+        result = self.model_dump(exclude={'id', 'index_date'})
         # Qdrant database cannot accept datetime object, so we have to convert it to string
         result['index_date'] = self.index_date.isoformat()
         return result
diff --git a/app/Services/vector_db_context.py b/app/Services/vector_db_context.py
index 5615ea9..b4cb497 100644
--- a/app/Services/vector_db_context.py
+++ b/app/Services/vector_db_context.py
@@ -4,7 +4,6 @@
 from loguru import logger
 from qdrant_client import AsyncQdrantClient
 from qdrant_client.http import models
-from qdrant_client.http.models import PointStruct
 from qdrant_client.models import RecommendStrategy
 
 from app.Models.api_model import SearchModelEnum, SearchBasisEnum
@@ -25,32 +24,31 @@ class VectorDbContext:
 
     TEXT_VECTOR = "text_contain_vector"
 
     def __init__(self):
-        self.client = AsyncQdrantClient(host=config.qdrant.host, port=config.qdrant.port,
-                                        grpc_port=config.qdrant.grpc_port, api_key=config.qdrant.api_key,
-                                        prefer_grpc=config.qdrant.prefer_grpc)
+        self._client = AsyncQdrantClient(host=config.qdrant.host, port=config.qdrant.port,
+                                         grpc_port=config.qdrant.grpc_port, api_key=config.qdrant.api_key,
+                                         prefer_grpc=config.qdrant.prefer_grpc)
         self.collection_name = config.qdrant.coll
 
     async def retrieve_by_id(self, image_id: str, with_vectors=False) -> ImageData:
         logger.info("Retrieving item {} from database...", image_id)
-        result = await self.client.retrieve(collection_name=self.collection_name, ids=[image_id], with_payload=True,
-                                            with_vectors=with_vectors)
+        result = await self._client.retrieve(collection_name=self.collection_name, ids=[image_id], with_payload=True,
+                                             with_vectors=with_vectors)
         if len(result) != 1:
             logger.error("Point not exist.")
             raise PointNotFoundError(image_id)
-        return ImageData.from_payload(result[0].id, result[0].payload,
-                                      numpy.array(result[0].vector, dtype=numpy.float32) if with_vectors else None)
+        return self._get_img_data_from_point(result[0])
 
     async def querySearch(self, query_vector, query_vector_name: str = IMG_VECTOR, top_k=10, skip=0,
                           filter_param: FilterParams | None = None) -> list[SearchResult]:
         logger.info("Querying Qdrant... top_k = {}", top_k)
-        result = await self.client.search(collection_name=self.collection_name,
-                                          query_vector=(query_vector_name, query_vector),
-                                          query_filter=self.getFiltersByFilterParam(filter_param),
-                                          limit=top_k,
-                                          offset=skip,
-                                          with_payload=True)
+        result = await self._client.search(collection_name=self.collection_name,
+                                           query_vector=(query_vector_name, query_vector),
+                                           query_filter=self.getFiltersByFilterParam(filter_param),
+                                           limit=top_k,
+                                           offset=skip,
+                                           with_payload=True)
         logger.success("Query completed!")
-        return [SearchResult(img=ImageData.from_payload(t.id, t.payload), score=t.score) for t in result]
+        return [self._get_search_result_from_scored_point(t) for t in result]
 
     async def querySimilar(self,
@@ -70,61 +68,37 @@ async def querySimilar(self,
         _combined_search_need_vectors = [
             self.IMG_VECTOR if query_vector_name == self.TEXT_VECTOR else self.IMG_VECTOR] if with_vectors else None
         logger.info("Querying Qdrant... top_k = {}", top_k)
-        result = await self.client.recommend(collection_name=self.collection_name,
-                                             using=query_vector_name,
-                                             positive=_positive_vectors,
-                                             negative=_negative_vectors,
-                                             strategy=_strategy,
-                                             with_vectors=_combined_search_need_vectors,
-                                             query_filter=self.getFiltersByFilterParam(filter_param),
-                                             limit=top_k,
-                                             offset=skip,
-                                             with_payload=True)
+        result = await self._client.recommend(collection_name=self.collection_name,
+                                              using=query_vector_name,
+                                              positive=_positive_vectors,
+                                              negative=_negative_vectors,
+                                              strategy=_strategy,
+                                              with_vectors=_combined_search_need_vectors,
+                                              query_filter=self.getFiltersByFilterParam(filter_param),
+                                              limit=top_k,
+                                              offset=skip,
+                                              with_payload=True)
         logger.success("Query completed!")
 
-        def result_transform(t):
-            return SearchResult(
-                img=ImageData.from_payload(
-                    t.id,
-                    t.payload,
-                    numpy.array(t.vector['image_vector']) if t.vector and 'image_vector' in t.vector else None,
-                    numpy.array(
-                        t.vector['text_contain_vector']) if t.vector and 'text_contain_vector' in t.vector else None
-                ),
-                score=t.score
-            )
-
-        return [result_transform(t) for t in result]
+        return [self._get_search_result_from_scored_point(t) for t in result]
 
     async def insertItems(self, items: list[ImageData]):
         logger.info("Inserting {} items into Qdrant...", len(items))
 
-        def get_point(img_data):
-            vector = {
-                self.IMG_VECTOR: img_data.image_vector.tolist(),
-            }
-            if img_data.text_contain_vector is not None:
-                vector[self.TEXT_VECTOR] = img_data.text_contain_vector.tolist()
-            return PointStruct(
-                id=str(img_data.id),
-                vector=vector,
-                payload=img_data.payload
-            )
-
-        points = [get_point(t) for t in items]
-
-        response = await self.client.upsert(collection_name=self.collection_name,
-                                            wait=True,
-                                            points=points)
+        points = [self._get_point_from_img_data(t) for t in items]
+
+        response = await self._client.upsert(collection_name=self.collection_name,
+                                             wait=True,
+                                             points=points)
         logger.success("Insert completed! Status: {}", response.status)
 
     async def deleteItems(self, ids: list[str]):
         logger.info("Deleting {} items from Qdrant...", len(ids))
-        response = await self.client.delete(collection_name=self.collection_name,
-                                            points_selector=models.PointIdsList(
-                                                points=ids
-                                            ),
-                                            )
+        response = await self._client.delete(collection_name=self.collection_name,
+                                             points_selector=models.PointIdsList(
+                                                 points=ids
+                                             ),
+                                             )
         logger.success("Delete completed! Status: {}", response.status)
 
     async def updatePayload(self, new_data: ImageData):
@@ -133,12 +107,65 @@ async def updatePayload(self, new_data: ImageData):
         Warning: This method will not update the vector of the item.
         :param new_data: The new data to update.
         """
-        response = await self.client.set_payload(collection_name=self.collection_name,
-                                                 payload=new_data.payload,
-                                                 points=[str(new_data.id)],
-                                                 wait=True)
+        response = await self._client.set_payload(collection_name=self.collection_name,
+                                                  payload=new_data.payload,
+                                                  points=[str(new_data.id)],
+                                                  wait=True)
         logger.success("Update completed! Status: {}", response.status)
 
+    async def updateVectors(self, new_points: list[ImageData]):
+        resp = await self._client.update_vectors(collection_name=self.collection_name,
+                                                 points=[self._get_vector_from_img_data(t) for t in new_points],
+                                                 )
+        logger.success("Update vectors completed! Status: {}", resp.status)
+
+    async def scroll_points(self,
+                            from_id: str | None = None,
+                            count=50,
+                            with_vectors=False) -> tuple[list[ImageData], str]:
+        resp, next_id = await self._client.scroll(collection_name=self.collection_name,
+                                                  limit=count,
+                                                  offset=from_id,
+                                                  with_vectors=with_vectors
+                                                  )
+
+        return [self._get_img_data_from_point(t) for t in resp], next_id
+
+    @classmethod
+    def _get_vector_from_img_data(cls, img_data: ImageData) -> models.PointVectors:
+        vector = {}
+        if img_data.image_vector is not None:
+            vector[cls.IMG_VECTOR] = img_data.image_vector.tolist()
+        if img_data.text_contain_vector is not None:
+            vector[cls.TEXT_VECTOR] = img_data.text_contain_vector.tolist()
+        return models.PointVectors(
+            id=str(img_data.id),
+            vector=vector
+        )
+
+    @classmethod
+    def _get_point_from_img_data(cls, img_data: ImageData) -> models.PointStruct:
+        return models.PointStruct(
+            id=str(img_data.id),
+            payload=img_data.payload,
+            vector=cls._get_vector_from_img_data(img_data).vector
+        )
+
+    @classmethod
+    def _get_img_data_from_point(cls, point: models.Record | models.ScoredPoint | models.PointStruct) -> ImageData:
+        return (ImageData
+                .from_payload(point.id,
+                              point.payload,
+                              image_vector=numpy.array(point.vector[cls.IMG_VECTOR], dtype=numpy.float32)
+                              if point.vector and cls.IMG_VECTOR in point.vector else None,
+                              text_contain_vector=numpy.array(point.vector[cls.TEXT_VECTOR], dtype=numpy.float32)
+                              if point.vector and cls.TEXT_VECTOR in point.vector else None
+                              ))
+
+    @classmethod
+    def _get_search_result_from_scored_point(cls, point: models.ScoredPoint) -> SearchResult:
+        return SearchResult(img=cls._get_img_data_from_point(point), score=point.score)
+
     @classmethod
     def getVectorByBasis(cls, basis: SearchBasisEnum) -> str:
         match basis:
diff --git a/main.py b/main.py
index 22ae42f..fe3c6d8 100644
--- a/main.py
+++ b/main.py
@@ -1,4 +1,5 @@
 import argparse
+import asyncio
 import collections
 
 import uvicorn
@@ -13,6 +14,9 @@ def parse_args():
                         help="Initialize qdrant database using connection settings in "
                              "config.py. When this flag is set, will not"
                              "start the server.")
+    parser.add_argument('--migrate-db', dest="migrate_from_version", type=int,
+                        help="Migrate qdrant database using connection settings in config from version specified."
+                             "When this flag is set, will not start the server.")
     parser.add_argument('--local-index', dest="local_index_target_dir", type=str,
                         help="Index all the images in this directory and copy them to "
                              "static folder set in config.py. When this flag is set, "
@@ -39,17 +43,19 @@ def parse_args():
         qdrant_create_collection.create_coll(
             collections.namedtuple('Options', ['host', 'port', 'name'])(config.qdrant.host, config.qdrant.port,
                                                                         config.qdrant.coll))
+    elif args.migrate_from_version is not None:
+        from scripts import db_migrations
+
+        asyncio.run(db_migrations.migrate(args.migrate_from_version))
     elif args.local_index_target_dir is not None:
         from app.config import environment
         environment.local_indexing = True
 
         from scripts import local_indexing
-        import asyncio
 
         asyncio.run(local_indexing.main(args))
     elif args.local_create_thumbnail:
         from scripts import local_create_thumbnail
-        import asyncio
 
         asyncio.run(local_create_thumbnail.main())
     else:
diff --git a/scripts/db_migrations.py b/scripts/db_migrations.py
new file mode 100644
index 0000000..86805a9
--- /dev/null
+++ b/scripts/db_migrations.py
@@ -0,0 +1,40 @@
+from loguru import logger
+
+from app.Services import db_context, transformers_service
+
+CURRENT_VERSION = 2
+
+
+async def migrate_v1_v2():
+    logger.info("Migrating from v1 to v2...")
+    next_id = None
+    count = 0
+    while True:
+        points, next_id = await db_context.scroll_points(next_id, count=100)
+        for point in points:
+            count += 1
+            logger.info("[{}] Migrating point {}", count, point.id)
+            if point.url.startswith('/'):
+                # V1 database assuming all image with '/' as begins is a local image,
+                # v2 migrate to a more strict approach
+                point.local = True
+            await db_context.updatePayload(point)  # This will also store ocr_text_lower field, if present
+            if point.ocr_text is not None:
+                point.text_contain_vector = transformers_service.get_bert_vector(point.ocr_text_lower)
+
+        logger.info("Updating vectors...")
+        # Update vectors for this group of points
+        await db_context.updateVectors([t for t in points if t.text_contain_vector is not None])
+        if next_id is None:
+            break
+
+
+async def migrate(from_version: int):
+    match from_version:
+        case 1:
+            await migrate_v1_v2()
+        case 2:
+            logger.info("Already up to date.")
+            pass
+        case _:
+            raise Exception(f"Unknown version {from_version}")
diff --git a/scripts/local_indexing.py b/scripts/local_indexing.py
index 118f5d1..4f6853f 100644
--- a/scripts/local_indexing.py
+++ b/scripts/local_indexing.py
@@ -50,7 +50,8 @@ def copy_and_index(file_path: Path) -> ImageData | None:
                            width=width,
                            height=height,
                            aspect_ratio=float(width) / height,
-                           ocr_text=image_ocr_result)
+                           ocr_text=image_ocr_result,
+                           local=True)
 
     # copy to static
     copy2(file_path, Path(config.static_file.path) / f'{image_id}{img_ext}')