From abfc09ae353c031bd878877dc1a4efdebe2944a1 Mon Sep 17 00:00:00 2001 From: EdgeNeko Date: Sun, 24 Dec 2023 23:08:13 +0800 Subject: [PATCH] Fix lint --- app/Controllers/admin.py | 9 ++++++--- app/Controllers/search.py | 8 ++++---- app/Services/transformers_service.py | 2 +- app/Services/vector_db_context.py | 30 ++++++++++++++++------------ requirements.txt | 4 ++-- scripts/local_indexing.py | 24 ++++++---------------- 6 files changed, 36 insertions(+), 41 deletions(-) diff --git a/app/Controllers/admin.py b/app/Controllers/admin.py index c5d1635..0014175 100644 --- a/app/Controllers/admin.py +++ b/app/Controllers/admin.py @@ -1,6 +1,9 @@ -from fastapi import APIRouter +from fastapi import APIRouter, Depends + +from app.Services.authentication import force_admin_token_verify + +admin_router = APIRouter(dependencies=[Depends(force_admin_token_verify)]) -admin_router = APIRouter() def add_image_info(): - pass \ No newline at end of file + pass diff --git a/app/Controllers/search.py b/app/Controllers/search.py index 3e9fd56..965feb6 100644 --- a/app/Controllers/search.py +++ b/app/Controllers/search.py @@ -131,14 +131,14 @@ async def randomPick( return SearchApiResponse(result=result, message=f"Successfully get {len(result)} results.", query_id=uuid4()) -@searchRouter.get("/recall/{queryId}", description="Recall the query with given queryId") -async def recallQuery(queryId: str): +@searchRouter.get("/recall/{query_id}", description="Recall the query with given queryId") +async def recallQuery(query_id: str): raise NotImplementedError() async def process_advanced_and_combined_search_query(model: Union[AdvancedSearchModel, CombinedSearchModel], basis: Union[SearchBasisParams, SearchCombinedParams], - filter: FilterParams, + filter_param: FilterParams, paging: SearchPagingParams) -> List[SearchResult]: if basis.basis == SearchBasisEnum.ocr: positive_vectors = [transformers_service.get_bert_vector(t) for t in model.criteria] @@ -151,7 +151,7 @@ async def process_advanced_and_combined_search_query(model: Union[AdvancedSearch positive_vectors=positive_vectors, negative_vectors=negative_vectors, mode=model.mode, - filter_param=filter, + filter_param=filter_param, with_vectors=True if isinstance(basis, SearchCombinedParams) else False, top_k=paging.count, skip=paging.skip) diff --git a/app/Services/transformers_service.py b/app/Services/transformers_service.py index dee10bc..d4f580b 100644 --- a/app/Services/transformers_service.py +++ b/app/Services/transformers_service.py @@ -8,7 +8,7 @@ from torch import FloatTensor, no_grad from transformers import CLIPProcessor, CLIPModel, BertTokenizer, BertModel -from app.config import config, environment +from app.config import config class TransformersService: diff --git a/app/Services/vector_db_context.py b/app/Services/vector_db_context.py index c8765c4..e069e68 100644 --- a/app/Services/vector_db_context.py +++ b/app/Services/vector_db_context.py @@ -1,6 +1,7 @@ +from typing import Optional + import numpy from loguru import logger -from typing import Optional from qdrant_client import AsyncQdrantClient from qdrant_client.http import models from qdrant_client.http.models import PointStruct @@ -30,8 +31,7 @@ async def retrieve_by_id(self, id: str, with_vectors=False) -> ImageData: numpy.array(result[0].vector, dtype=numpy.float32) if with_vectors else None) async def querySearch(self, query_vector, query_vector_name: str = IMG_VECTOR, - top_k=10, skip=0, filter_param: FilterParams | None = None) -> list[ - SearchResult]: + top_k=10, skip=0, filter_param: FilterParams | None = None) -> list[SearchResult]: logger.info("Querying Qdrant... top_k = {}", top_k) result = await self.client.search(collection_name=self.collection_name, query_vector=(query_vector_name, query_vector), @@ -57,7 +57,7 @@ async def querySimilar(self, _strategy = None if mode is None else (RecommendStrategy.AVERAGE_VECTOR if mode == SearchModelEnum.average else RecommendStrategy.BEST_SCORE) # since only combined_search need return vectors, We can define _combined_search_need_vectors like below - _combined_search_need_vectors = [self.IMG_VECTOR if query_vector_name == self.TEXT_VECTOR else self.IMG_VECTOR]\ + _combined_search_need_vectors = [self.IMG_VECTOR if query_vector_name == self.TEXT_VECTOR else self.IMG_VECTOR] \ if with_vectors else None logger.info("Querying Qdrant... top_k = {}", top_k) result = await self.client.recommend(collection_name=self.collection_name, @@ -71,15 +71,19 @@ async def querySimilar(self, offset=skip, with_payload=True) logger.success("Query completed!") - result_transform = lambda t: SearchResult( - img=ImageData.from_payload( - t.id, - t.payload, - numpy.array(t.vector['image_vector']) if t.vector and 'image_vector' in t.vector else None, - numpy.array(t.vector['text_contain_vector']) if t.vector and 'text_contain_vector' in t.vector else None - ), - score=t.score - ) + + def result_transform(t): + return SearchResult( + img=ImageData.from_payload( + t.id, + t.payload, + numpy.array(t.vector['image_vector']) if t.vector and 'image_vector' in t.vector else None, + numpy.array( + t.vector['text_contain_vector']) if t.vector and 'text_contain_vector' in t.vector else None + ), + score=t.score + ) + return [result_transform(t) for t in result] async def insertItems(self, items: list[ImageData]): diff --git a/requirements.txt b/requirements.txt index 137e69e..ce6253b 100644 --- a/requirements.txt +++ b/requirements.txt @@ -9,8 +9,8 @@ pydantic-settings # AI - Manually install cuda-capable pytorch torch torchvision -transformers -pillow +transformers>4.35.2 +pillow>9.3.0 numpy # OCR - you can choose other option if necessary, or completely disable it if you don't need this feature diff --git a/scripts/local_indexing.py b/scripts/local_indexing.py index bb0259e..ee5e6ff 100644 --- a/scripts/local_indexing.py +++ b/scripts/local_indexing.py @@ -1,11 +1,4 @@ - -if __name__ == '__main__': - import sys - - sys.path.insert(1, './') - import argparse -import asyncio from datetime import datetime from pathlib import Path from shutil import copy2 @@ -26,14 +19,14 @@ def parse_args(): return parser.parse_args() -def copy_and_index(filePath: Path) -> ImageData | None: +def copy_and_index(file_path: Path) -> ImageData | None: try: - img = Image.open(filePath) + img = Image.open(file_path) except Exception as e: - logger.error("Error when opening image {}: {}", filePath, e) + logger.error("Error when opening image {}: {}", file_path, e) return None id = uuid4() - img_ext = filePath.suffix + img_ext = file_path.suffix image_ocr_result = None text_contain_vector = None [width, height] = img.size @@ -46,7 +39,7 @@ def copy_and_index(filePath: Path) -> ImageData | None: else: image_ocr_result = None except Exception as e: - logger.error("Error when processing image {}: {}", filePath, e) + logger.error("Error when processing image {}: {}", file_path, e) return None imgdata = ImageData(id=id, url=f'/static/{id}{img_ext}', @@ -59,7 +52,7 @@ def copy_and_index(filePath: Path) -> ImageData | None: ocr_text=image_ocr_result) # copy to static - copy2(filePath, Path(config.static_file.path) / f'{id}{img_ext}') + copy2(file_path, Path(config.static_file.path) / f'{id}{img_ext}') return imgdata @@ -88,8 +81,3 @@ async def main(args): logger.info("Upload {} element to database", len(buffer)) await db_context.insertItems(buffer) logger.success("Indexing completed! {} images indexed", counter) - - -if __name__ == '__main__': - args = parse_args() - asyncio.run(main(args))