Skip to content

Commit

Permalink
Fix lint
Browse files Browse the repository at this point in the history
  • Loading branch information
hv0905 committed Dec 24, 2023
1 parent 6991a7e commit abfc09a
Show file tree
Hide file tree
Showing 6 changed files with 36 additions and 41 deletions.
9 changes: 6 additions & 3 deletions app/Controllers/admin.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,9 @@
from fastapi import APIRouter
from fastapi import APIRouter, Depends

from app.Services.authentication import force_admin_token_verify

admin_router = APIRouter(dependencies=[Depends(force_admin_token_verify)])

admin_router = APIRouter()

def add_image_info():
pass
pass
8 changes: 4 additions & 4 deletions app/Controllers/search.py
Original file line number Diff line number Diff line change
Expand Up @@ -131,14 +131,14 @@ async def randomPick(
return SearchApiResponse(result=result, message=f"Successfully get {len(result)} results.", query_id=uuid4())


@searchRouter.get("/recall/{queryId}", description="Recall the query with given queryId")
async def recallQuery(queryId: str):
@searchRouter.get("/recall/{query_id}", description="Recall the query with given queryId")
async def recallQuery(query_id: str):
raise NotImplementedError()


async def process_advanced_and_combined_search_query(model: Union[AdvancedSearchModel, CombinedSearchModel],
basis: Union[SearchBasisParams, SearchCombinedParams],
filter: FilterParams,
filter_param: FilterParams,
paging: SearchPagingParams) -> List[SearchResult]:
if basis.basis == SearchBasisEnum.ocr:
positive_vectors = [transformers_service.get_bert_vector(t) for t in model.criteria]
Expand All @@ -151,7 +151,7 @@ async def process_advanced_and_combined_search_query(model: Union[AdvancedSearch
positive_vectors=positive_vectors,
negative_vectors=negative_vectors,
mode=model.mode,
filter_param=filter,
filter_param=filter_param,
with_vectors=True if isinstance(basis, SearchCombinedParams) else False,
top_k=paging.count,
skip=paging.skip)
Expand Down
2 changes: 1 addition & 1 deletion app/Services/transformers_service.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
from torch import FloatTensor, no_grad
from transformers import CLIPProcessor, CLIPModel, BertTokenizer, BertModel

from app.config import config, environment
from app.config import config


class TransformersService:
Expand Down
30 changes: 17 additions & 13 deletions app/Services/vector_db_context.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from typing import Optional

import numpy
from loguru import logger
from typing import Optional
from qdrant_client import AsyncQdrantClient
from qdrant_client.http import models
from qdrant_client.http.models import PointStruct
Expand Down Expand Up @@ -30,8 +31,7 @@ async def retrieve_by_id(self, id: str, with_vectors=False) -> ImageData:
numpy.array(result[0].vector, dtype=numpy.float32) if with_vectors else None)

async def querySearch(self, query_vector, query_vector_name: str = IMG_VECTOR,
top_k=10, skip=0, filter_param: FilterParams | None = None) -> list[
SearchResult]:
top_k=10, skip=0, filter_param: FilterParams | None = None) -> list[SearchResult]:
logger.info("Querying Qdrant... top_k = {}", top_k)
result = await self.client.search(collection_name=self.collection_name,
query_vector=(query_vector_name, query_vector),
Expand All @@ -57,7 +57,7 @@ async def querySimilar(self,
_strategy = None if mode is None else (RecommendStrategy.AVERAGE_VECTOR if
mode == SearchModelEnum.average else RecommendStrategy.BEST_SCORE)
# Only combined_search needs vectors returned, so _combined_search_need_vectors can be defined as below
_combined_search_need_vectors = [self.IMG_VECTOR if query_vector_name == self.TEXT_VECTOR else self.IMG_VECTOR]\
_combined_search_need_vectors = [self.IMG_VECTOR if query_vector_name == self.TEXT_VECTOR else self.IMG_VECTOR] \
if with_vectors else None
logger.info("Querying Qdrant... top_k = {}", top_k)
result = await self.client.recommend(collection_name=self.collection_name,
Expand All @@ -71,15 +71,19 @@ async def querySimilar(self,
offset=skip,
with_payload=True)
logger.success("Query completed!")
result_transform = lambda t: SearchResult(
img=ImageData.from_payload(
t.id,
t.payload,
numpy.array(t.vector['image_vector']) if t.vector and 'image_vector' in t.vector else None,
numpy.array(t.vector['text_contain_vector']) if t.vector and 'text_contain_vector' in t.vector else None
),
score=t.score
)

def result_transform(t):
    """Build a SearchResult from a single scored hit `t`.

    `t.vector` is only consulted when it is truthy; each named vector
    ('image_vector', 'text_contain_vector') is converted to a numpy array
    when present and passed as None otherwise.
    """
    named_vectors = t.vector
    image_vector = None
    text_contain_vector = None
    if named_vectors:
        # Each name is checked independently: a hit may carry one vector,
        # both, or neither.
        if 'image_vector' in named_vectors:
            image_vector = numpy.array(named_vectors['image_vector'])
        if 'text_contain_vector' in named_vectors:
            text_contain_vector = numpy.array(named_vectors['text_contain_vector'])

    image = ImageData.from_payload(t.id, t.payload, image_vector, text_contain_vector)
    return SearchResult(img=image, score=t.score)

return [result_transform(t) for t in result]

async def insertItems(self, items: list[ImageData]):
Expand Down
4 changes: 2 additions & 2 deletions requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -9,8 +9,8 @@ pydantic-settings
# AI - Manually install cuda-capable pytorch
torch
torchvision
transformers
pillow
transformers>4.35.2
pillow>9.3.0
numpy

# OCR - you can choose another option if necessary, or disable it completely if you don't need this feature
Expand Down
24 changes: 6 additions & 18 deletions scripts/local_indexing.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,4 @@

if __name__ == '__main__':
import sys

sys.path.insert(1, './')

import argparse
import asyncio
from datetime import datetime
from pathlib import Path
from shutil import copy2
Expand All @@ -26,14 +19,14 @@ def parse_args():
return parser.parse_args()


def copy_and_index(filePath: Path) -> ImageData | None:
def copy_and_index(file_path: Path) -> ImageData | None:
try:
img = Image.open(filePath)
img = Image.open(file_path)
except Exception as e:
logger.error("Error when opening image {}: {}", filePath, e)
logger.error("Error when opening image {}: {}", file_path, e)
return None
id = uuid4()
img_ext = filePath.suffix
img_ext = file_path.suffix
image_ocr_result = None
text_contain_vector = None
[width, height] = img.size
Expand All @@ -46,7 +39,7 @@ def copy_and_index(filePath: Path) -> ImageData | None:
else:
image_ocr_result = None
except Exception as e:
logger.error("Error when processing image {}: {}", filePath, e)
logger.error("Error when processing image {}: {}", file_path, e)
return None
imgdata = ImageData(id=id,
url=f'/static/{id}{img_ext}',
Expand All @@ -59,7 +52,7 @@ def copy_and_index(filePath: Path) -> ImageData | None:
ocr_text=image_ocr_result)

# copy to static
copy2(filePath, Path(config.static_file.path) / f'{id}{img_ext}')
copy2(file_path, Path(config.static_file.path) / f'{id}{img_ext}')
return imgdata


Expand Down Expand Up @@ -88,8 +81,3 @@ async def main(args):
logger.info("Upload {} element to database", len(buffer))
await db_context.insertItems(buffer)
logger.success("Indexing completed! {} images indexed", counter)


if __name__ == '__main__':
args = parse_args()
asyncio.run(main(args))

0 comments on commit abfc09a

Please sign in to comment.