Skip to content

Commit

Permalink
Add aspect_ratio and min_wh filter feature
Browse files: browse the repository at this point in the history
hv0905 committed Dec 23, 2023
1 parent e9743c7 commit a0c18aa
Showing 4 changed files with 120 additions and 26 deletions.
35 changes: 20 additions & 15 deletions app/Controllers/search.py
Original file line number Diff line number Diff line change
@@ -9,6 +9,7 @@

from app.Models.api_model import AdvancedSearchModel, SearchBasisEnum
from app.Models.api_response.search_api_response import SearchApiResponse
from app.Models.query_params import SearchPagingParams, FilterParams
from app.Services import db_context
from app.Services import transformers_service
from app.Services.authentication import force_access_token_verify
@@ -17,16 +18,6 @@
searchRouter = APIRouter(dependencies=([Depends(force_access_token_verify)] if config.access_protected else None))


class SearchPagingParams:
    """Pagination query parameters used as a dependency by the search endpoints.

    Attributes:
        count: Number of results requested (validated as 1..100, default 10).
        skip: Number of results to skip (validated as >= 0, default 0).
    """

    def __init__(
            self,
            count: Annotated[int, Query(ge=1, le=100, description="The number of results you want to get.")] = 10,
            skip: Annotated[int, Query(ge=0, description="The number of results you want to skip.")] = 0
    ):
        # Plain value holder: the range validation is declared in the Query metadata above.
        self.count, self.skip = count, skip


class SearchBasisParams:
def __init__(self,
basis: Annotated[SearchBasisEnum, Query(
@@ -41,13 +32,15 @@ async def textSearch(
prompt: Annotated[
str, Path(min_length=3, max_length=100, description="The image prompt text you want to search.")],
basis: Annotated[SearchBasisParams, Depends(SearchBasisParams)],
filter_param: Annotated[FilterParams, Depends(FilterParams)],
paging: Annotated[SearchPagingParams, Depends(SearchPagingParams)]
) -> SearchApiResponse:
logger.info("Text search request received, prompt: {}", prompt)
text_vector = transformers_service.get_text_vector(prompt) if basis.basis == SearchBasisEnum.vision \
else transformers_service.get_bert_vector(prompt)
results = await db_context.querySearch(text_vector,
query_vector_name=db_context.getVectorByBasis(basis.basis),
filter_param=filter_param,
top_k=paging.count,
skip=paging.skip)
return SearchApiResponse(result=results, message=f"Successfully get {len(results)} results.", query_id=uuid4())
@@ -57,12 +50,17 @@ async def textSearch(
async def imageSearch(
image: Annotated[bytes, File(max_length=10 * 1024 * 1024, media_type="image/*",
description="The image you want to search.")],
paging: Annotated[SearchPagingParams, Depends(SearchPagingParams)]) -> SearchApiResponse:
filter_param: Annotated[FilterParams, Depends(FilterParams)],
paging: Annotated[SearchPagingParams, Depends(SearchPagingParams)]
) -> SearchApiResponse:
fakefile = BytesIO(image)
img = Image.open(fakefile)
logger.info("Image search request received")
image_vector = transformers_service.get_image_vector(img)
results = await db_context.querySearch(image_vector, top_k=paging.count, skip=paging.skip)
results = await db_context.querySearch(image_vector,
top_k=paging.count,
skip=paging.skip,
filter_param=filter_param)
return SearchApiResponse(result=results, message=f"Successfully get {len(results)} results.", query_id=uuid4())


@@ -72,12 +70,14 @@ async def imageSearch(
async def similarWith(
        id: Annotated[UUID, Path(description="The id of the image you want to search.")],
        basis: Annotated[SearchBasisParams, Depends(SearchBasisParams)],
        filter_param: Annotated[FilterParams, Depends(FilterParams)],
        paging: Annotated[SearchPagingParams, Depends(SearchPagingParams)]
) -> SearchApiResponse:
    # Endpoint handler: asks the vector DB for images similar to the stored image `id`,
    # restricted by the optional filters and paged with count/skip.
    logger.info("Similar search request received, id: {}", id)

    # The DB layer takes the id as a string and the vector name as one keyword among several.
    matches = await db_context.querySimilar(
        str(id),
        top_k=paging.count,
        skip=paging.skip,
        filter_param=filter_param,
        query_vector_name=db_context.getVectorByBasis(basis.basis),
    )

    summary = f"Successfully get {len(matches)} results."
    return SearchApiResponse(result=matches, message=summary, query_id=uuid4())

@@ -86,6 +86,7 @@ async def similarWith(
async def advancedSearch(
model: AdvancedSearchModel,
basis: Annotated[SearchBasisParams, Depends(SearchBasisParams)],
filter_param: Annotated[FilterParams, Depends(FilterParams)],
paging: Annotated[SearchPagingParams, Depends(SearchPagingParams)]) -> SearchApiResponse:
if len(model.criteria) + len(model.negative_criteria) == 0:
raise ValueError("At least one criteria should be provided.")
@@ -98,16 +99,20 @@ async def advancedSearch(
negative_vectors = [transformers_service.get_text_vector(t) for t in model.negative_criteria]
result = await db_context.queryAdvanced(positive_vectors, negative_vectors,
db_context.getVectorByBasis(basis.basis), model.mode,
filter_param=filter_param,
top_k=paging.count,
skip=paging.skip)
skip=paging.skip
)
return SearchApiResponse(result=result, message=f"Successfully get {len(result)} results.", query_id=uuid4())


@searchRouter.get("/random", description="Get random images")
async def randomPick(paging: Annotated[SearchPagingParams, Depends(SearchPagingParams)]) -> SearchApiResponse:
async def randomPick(
filter_param: Annotated[FilterParams, Depends(FilterParams)],
paging: Annotated[SearchPagingParams, Depends(SearchPagingParams)]) -> SearchApiResponse:
logger.info("Random pick request received")
random_vector = transformers_service.get_random_vector()
result = await db_context.querySearch(random_vector, top_k=paging.count)
result = await db_context.querySearch(random_vector, top_k=paging.count, filter_param=filter_param)
return SearchApiResponse(result=result, message=f"Successfully get {len(result)} results.", query_id=uuid4())


35 changes: 35 additions & 0 deletions app/Models/query_params.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,35 @@
from typing import Annotated

from fastapi.params import Query


class SearchPagingParams:
    """Pagination query parameters, intended to be injected as a dependency.

    Attributes:
        count: Number of results requested (validated as 1..100, default 10).
        skip: Number of results to skip (validated as >= 0, default 0).
    """

    def __init__(
            self,
            count: Annotated[int, Query(ge=1, le=100, description="The number of results you want to get.")] = 10,
            skip: Annotated[int, Query(ge=0, description="The number of results you want to skip.")] = 0
    ):
        # Plain value holder: the range validation is declared in the Query metadata above.
        self.count, self.skip = count, skip


class FilterParams:
    """Optional result filters: an aspect-ratio window and minimum image dimensions.

    Attributes:
        preferred_ratio: Requested aspect ratio, or ``None`` for no ratio filter.
        ratio_tolerance: Relative tolerance applied around ``preferred_ratio``.
        min_width: Minimum image width, or ``None``.
        min_height: Minimum image height, or ``None``.
        min_ratio: ``preferred_ratio * (1 - ratio_tolerance)``, or ``None``.
        max_ratio: ``preferred_ratio * (1 + ratio_tolerance)``, or ``None``.
    """

    def __init__(
            self,
            preferred_ratio: Annotated[
                float | None, Query(gt=0, description="The preferred aspect ratio of the image.")] = None,
            ratio_tolerance: Annotated[
                float, Query(gt=0, lt=1, description="The tolerance of the aspect ratio.")] = 0.1,
            min_width: Annotated[int | None, Query(gt=0, description="The minimum width of the image.")] = None,
            min_height: Annotated[int | None, Query(gt=0, description="The minimum height of the image.")] = None):
        self.preferred_ratio = preferred_ratio
        self.ratio_tolerance = ratio_tolerance
        self.min_width = min_width
        self.min_height = min_height

        # A falsy preferred_ratio (None, or 0 when constructed directly) means "no ratio window".
        has_ratio = bool(self.preferred_ratio)
        self.min_ratio = self.preferred_ratio * (1 - self.ratio_tolerance) if has_ratio else None
        self.max_ratio = self.preferred_ratio * (1 + self.ratio_tolerance) if has_ratio else None
25 changes: 17 additions & 8 deletions app/Services/authentication.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
from typing import Annotated

from fastapi import HTTPException
from fastapi.params import Header
from fastapi.params import Header, Depends

from app.config import config

@@ -12,14 +12,23 @@ def verify_access_token(token: str | None) -> bool:
return token is not None and token == config.access_token


def force_access_token_verify(
        x_access_token: Annotated[str | None, Header(
            description="Access token set in configuration (if access_protected is enabled)")] = None):
    """Reject the request with HTTP 401 unless the access-token header verifies.

    Raises:
        HTTPException: status 401 when ``verify_access_token`` returns a falsy result.
    """
    if verify_access_token(x_access_token):
        return
    raise HTTPException(status_code=401, detail="Access token is not present or invalid.")


def permissive_access_token_verify(
        x_access_token: Annotated[str | None, Header(
            description="Access token set in configuration (if access_protected is enabled)")] = None) -> bool:
    """Report whether the access-token header verifies, without rejecting the request."""
    token_ok = verify_access_token(x_access_token)
    return token_ok


def force_access_token_verify(token_passed: Annotated[bool, Depends(permissive_access_token_verify)]):
    """Reject the request with HTTP 401 unless the permissive access-token check passed.

    Raises:
        HTTPException: status 401 when ``token_passed`` is falsy.
    """
    if token_passed:
        return
    raise HTTPException(status_code=401, detail="Access token is not present or invalid.")


def permissive_admin_token_verify(
        x_admin_token: Annotated[str | None, Header(
            description="Admin token set in configuration (if admin_api_enable is enabled)")] = None) -> bool:
    """Report whether the admin-token header matches the configured admin token.

    Returns:
        ``True`` only when the header is present and equals ``config.admin_token``.
    """
    # A missing header must never verify. Without this guard, an unset (None) configured
    # token would compare equal to the absent header and the check would pass.
    return x_admin_token is not None and x_admin_token == config.admin_token


def force_admin_token_verify(token_passed: Annotated[bool, Depends(permissive_admin_token_verify)]):
    """Reject the request with HTTP 401 unless the permissive admin-token check passed.

    Raises:
        HTTPException: status 401 when ``token_passed`` is falsy.
    """
    if token_passed:
        return
    raise HTTPException(status_code=401, detail="Admin token is not present or invalid.")
51 changes: 48 additions & 3 deletions app/Services/vector_db_context.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,13 @@
import numpy
from loguru import logger
from qdrant_client import AsyncQdrantClient
from qdrant_client.http import models
from qdrant_client.http.models import PointStruct
from qdrant_client.models import RecommendStrategy

from app.Models.api_model import SearchModelEnum, SearchBasisEnum
from app.Models.img_data import ImageData
from app.Models.query_params import FilterParams
from app.Models.search_result import SearchResult
from app.config import config

@@ -26,23 +28,27 @@ async def retrieve_by_id(self, id: str, with_vectors=False) -> ImageData:
return ImageData.from_payload(result[0].id, result[0].payload,
numpy.array(result[0].vector, dtype=numpy.float32) if with_vectors else None)

async def querySearch(self, query_vector, query_vector_name: str = IMG_VECTOR, top_k=10, skip=0) -> list[
async def querySearch(self, query_vector, query_vector_name: str = IMG_VECTOR,
                      top_k=10, skip=0, filter_param: FilterParams | None = None) -> list[SearchResult]:
    """Run a vector search against the collection and wrap the hits as ``SearchResult``.

    Args:
        query_vector: The vector to search with.
        query_vector_name: Name of the named vector to search against.
        top_k: Maximum number of hits to return.
        skip: Number of hits to skip (passed to Qdrant as ``offset``).
        filter_param: Optional filter parameters; converted with
            ``getFiltersByFilterParam`` (``None`` means no filtering).

    Returns:
        One ``SearchResult`` per hit, built from the hit's id, payload and score.
    """
    logger.info("Querying Qdrant... top_k = {}", top_k)
    result = await self.client.search(collection_name=self.collection_name,
                                      query_vector=(query_vector_name, query_vector),
                                      # The client's keyword is `query_filter` (the same one the
                                      # `recommend` calls use); `filters` is not a search parameter.
                                      query_filter=self.getFiltersByFilterParam(filter_param),
                                      limit=top_k,
                                      offset=skip,
                                      with_payload=True)
    logger.success("Query completed!")
    return [SearchResult(img=ImageData.from_payload(t.id, t.payload), score=t.score) for t in result]

async def querySimilar(self, id: str, query_vector_name: str = IMG_VECTOR, top_k=10, skip=0) -> list[SearchResult]:
async def querySimilar(self, id: str, query_vector_name: str = IMG_VECTOR,
top_k=10, skip=0, filter_param: FilterParams | None = None) -> list[SearchResult]:
logger.info("Querying Qdrant... top_k = {}", top_k)
result = await self.client.recommend(collection_name=self.collection_name,
positive=[id],
negative=[],
using=query_vector_name,
query_filter=self.getFiltersByFilterParam(filter_param),
limit=top_k,
offset=skip,
with_vectors=False,
@@ -52,12 +58,13 @@ async def querySimilar(self, id: str, query_vector_name: str = IMG_VECTOR, top_k

async def queryAdvanced(self, positive_vectors: list[numpy.ndarray], negative_vectors: list[numpy.ndarray],
query_vector_name: str = IMG_VECTOR, mode: SearchModelEnum = SearchModelEnum.average,
top_k=10, skip=0) -> list[SearchResult]:
top_k=10, skip=0, filter_param: FilterParams | None = None) -> list[SearchResult]:
logger.info("Querying Qdrant... top_k = {}", top_k)
result = await self.client.recommend(collection_name=self.collection_name,
using=query_vector_name,
positive=[t.tolist() for t in positive_vectors],
negative=[t.tolist() for t in negative_vectors],
query_filter=self.getFiltersByFilterParam(filter_param),
limit=top_k,
offset=skip,
strategy=
@@ -111,3 +118,41 @@ def getVectorByBasis(cls, basis: SearchBasisEnum) -> str:
return cls.TEXT_VECTOR
case _:
raise ValueError("Invalid basis")

@staticmethod
def getFiltersByFilterParam(filter_param: FilterParams | None) -> models.Filter | None:
    """Translate ``filter_param`` into a Qdrant filter.

    Returns ``None`` when ``filter_param`` is ``None`` or none of its constraints is set.
    Otherwise returns a ``models.Filter`` whose ``must`` list holds one range condition
    per active constraint, on the ``width``, ``height`` and ``aspect_ratio`` payload keys.
    """
    if filter_param is None:
        return None

    conditions = []

    # Lower bounds on the image dimensions: width first, then height.
    for payload_key, minimum in (("width", filter_param.min_width),
                                 ("height", filter_param.min_height)):
        if minimum is not None:
            conditions.append(models.FieldCondition(key=payload_key, range=models.Range(gte=minimum)))

    # The aspect-ratio window is closed on both ends.
    if filter_param.min_ratio is not None:
        ratio_window = models.Range(gte=filter_param.min_ratio, lte=filter_param.max_ratio)
        conditions.append(models.FieldCondition(key="aspect_ratio", range=ratio_window))

    if not conditions:
        return None
    return models.Filter(must=conditions)

0 comments on commit a0c18aa

Please sign in to comment.