Skip to content

Commit

Permalink
Merge pull request #42 from hv0905/images-api
Browse files Browse the repository at this point in the history
Add images API
  • Loading branch information
hv0905 authored Jul 9, 2024
2 parents a893a8e + 547429e commit 6468433
Show file tree
Hide file tree
Showing 8 changed files with 122 additions and 11 deletions.
41 changes: 41 additions & 0 deletions app/Controllers/images.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,41 @@
from typing import Annotated
from uuid import UUID

from fastapi import APIRouter, Depends, Path, HTTPException, Query

from app.Models.api_response.images_api_response import QueryByIdApiResponse, ImageStatus, QueryImagesApiResponse
from app.Models.query_params import FilterParams
from app.Services.authentication import force_access_token_verify
from app.Services.provider import ServiceProvider
from app.Services.vector_db_context import PointNotFoundError
from app.config import config

# Router for the image query endpoints. When `config.access_protected` is truthy,
# `force_access_token_verify` is attached as a dependency of every route on this
# router; otherwise no router-level dependency is registered.
images_router = APIRouter(dependencies=([Depends(force_access_token_verify)] if config.access_protected else None),
                          tags=["Images"])

# Module-level handle to the shared services. It stays `None` until the web
# application assigns it, so the route handlers below must not run before that.
services: ServiceProvider | None = None # The service provider will be injected in the webapp initialize


@images_router.get("/id/{image_id}", description="Query the image info with the given image ID. \n"
                                                 "This can also be used to check the status"
                                                 " of an image in the index queue.")
async def query_image_by_id(image_id: Annotated[UUID, Path(description="The id of the image you want to query.")]):
    """Return the stored image for ``image_id``, or its queue status if not yet indexed.

    The vector database is asked first. If it reports the point as missing, the
    upload service's ``uploading_ids`` is consulted so that callers can tell an
    image that is still waiting to be indexed apart from an unknown one.

    Args:
        image_id: ID of the image to look up.

    Returns:
        A ``QueryByIdApiResponse`` with ``img_status=MAPPED`` and the image data,
        or with ``img_status=IN_QUEUE`` and ``img=None`` while the image is queued.

    Raises:
        HTTPException: 404 when the ID is neither stored nor queued.
    """
    try:
        image = await services.db_context.retrieve_by_id(str(image_id))
    except PointNotFoundError as ex:
        upload_service = services.upload_service
        # NOTE(review): the membership test uses the UUID object itself; this assumes
        # `uploading_ids` stores UUIDs rather than strings - confirm against the upload service.
        if upload_service and image_id in upload_service.uploading_ids:
            # `img` is passed explicitly because the response model declares it without a
            # default; omitting it would fail validation where such fields are required.
            return QueryByIdApiResponse(img=None,
                                        img_status=ImageStatus.IN_QUEUE,
                                        message="The image is in the indexing queue.")
        raise HTTPException(404, "Cannot find the image with the given ID.") from ex
    return QueryByIdApiResponse(img=image,
                                img_status=ImageStatus.MAPPED,
                                message="Success query the image with the given ID.")


@images_router.get("/", description="Query images in order of ID.")
async def scroll_images(filter_param: Annotated[FilterParams, Depends()],
                        prev_offset_id: Annotated[UUID | None, Query(description="The previous offset image ID.")] = None,
                        count: Annotated[int, Query(ge=1, le=100, description="The number of images to query.")] = 15):
    """Return one page of images, optionally starting from a given image ID.

    Args:
        filter_param: Filter criteria forwarded to the database scroll.
        prev_offset_id: ID to start the page from; omit it to start from the beginning.
        count: Maximum number of images in the page (1-100).

    Returns:
        A ``QueryImagesApiResponse`` holding the page of images and the offset
        value reported by the database for the next page.

    Raises:
        HTTPException: 404 when ``prev_offset_id`` is given but unknown to the database.
    """
    # Convert only when an offset was supplied: `str(None)` would yield the literal
    # "None", which would be forwarded to the database as if it were a real point ID.
    offset_id = str(prev_offset_id) if prev_offset_id is not None else None

    # validate the offset ID
    if offset_id is not None and len(await services.db_context.validate_ids([offset_id])) == 0:
        raise HTTPException(404, "The previous offset ID is invalid.")

    images, offset = await services.db_context.scroll_points(offset_id, count, filter_param=filter_param)
    return QueryImagesApiResponse(images=images, next_page_offset=offset, message="Success query images.")
5 changes: 2 additions & 3 deletions app/Models/api_models/admin_query_params.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@


class UploadImageThumbnailMode(str, Enum):
DEFAULT = "default"
IF_NECESSARY = "if_necessary"
ALWAYS = "always"
NEVER = "never"
Expand All @@ -30,7 +29,7 @@ def __init__(self,
description="When set to true, the image will be uploaded to local storage. "
"Otherwise, it will only be indexed in the database."),
local_thumbnail: UploadImageThumbnailMode =
Query(default=UploadImageThumbnailMode.DEFAULT,
Query(default=None,
description="Whether to generate thumbnail locally. Possible values:\n"
"- `if_necessary`: Only generate thumbnail if the image is larger than 500KB. "
"This is the default value if `local=True`\n"
Expand All @@ -43,7 +42,7 @@ def __init__(self,
self.starred = starred
self.local = local
self.skip_ocr = skip_ocr
self.local_thumbnail = local_thumbnail if local_thumbnail is not UploadImageThumbnailMode.DEFAULT else (
self.local_thumbnail = local_thumbnail if (local_thumbnail is not None) else (
UploadImageThumbnailMode.IF_NECESSARY if local else UploadImageThumbnailMode.NEVER)
if not self.url and not self.local:
raise HTTPException(422, "A correspond url must be provided for a non-local image.")
25 changes: 25 additions & 0 deletions app/Models/api_response/images_api_response.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
from enum import Enum

from pydantic import Field

from app.Models.api_response.base import NekoProtocol
from app.Models.img_data import ImageData


class ImageStatus(str, Enum):
    """Status values reported for an image by the query-by-ID response."""

    # The image data was retrieved from the database.
    MAPPED = "mapped"

    # The image ID is in the upload service's indexing queue.
    IN_QUEUE = "in_queue"


class QueryByIdApiResponse(NekoProtocol):
    """Response body of the query-image-by-ID endpoint."""

    img_status: ImageStatus = Field(description="The status of the image.\n"
                                                "Warning: If NekoImageGallery is deployed in a cluster, "
                                                "the `in_queue` might not be accurate since the index queue "
                                                "is independent of each service instance.")
    # An explicit `default=None` keeps this field optional. A nullable annotation alone
    # does not make a field optional in pydantic v2, yet the `in_queue` response has no
    # image data to supply.
    img: ImageData | None = Field(default=None,
                                  description="The mapped image data. Only available when `img_status = mapped`.")


class QueryImagesApiResponse(NekoProtocol):
    """Response body of the paged image listing endpoint."""

    # The images contained in the returned page.
    images: list[ImageData] = Field(
        description="The list of images.",
    )
    # Offset to pass back when requesting the following page.
    next_page_offset: str | None = Field(
        description=(
            "The offset ID for the next page query. "
            "If there are no more images, this field will be null."
        ),
    )
7 changes: 5 additions & 2 deletions app/Services/vector_db_context.py
Original file line number Diff line number Diff line change
Expand Up @@ -184,11 +184,14 @@ async def updateVectors(self, new_points: list[ImageData]):
async def scroll_points(self,
from_id: str | None = None,
count=50,
with_vectors=False) -> tuple[list[ImageData], str]:
with_vectors=False,
filter_param: FilterParams | None = None,
) -> tuple[list[ImageData], str]:
resp, next_id = await self._client.scroll(collection_name=self.collection_name,
limit=count,
offset=from_id,
with_vectors=with_vectors
with_vectors=with_vectors,
scroll_filter=self._get_filters_by_filter_param(filter_param)
)

return [self._get_img_data_from_point(t) for t in resp], next_id
Expand Down
3 changes: 3 additions & 0 deletions app/webapp.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@

import app
import app.Controllers.admin as admin_controller
import app.Controllers.images as images_controller
import app.Controllers.search as search_controller
from app.Services.authentication import permissive_access_token_verify, permissive_admin_token_verify
from app.Services.provider import ServiceProvider
Expand All @@ -26,6 +27,7 @@ async def lifespan(_: FastAPI):

search_controller.services = provider
admin_controller.services = provider
images_controller.services = provider
yield

await provider.onexit()
Expand All @@ -44,6 +46,7 @@ async def lifespan(_: FastAPI):
)

app.include_router(search_controller.search_router, prefix="/search")
app.include_router(images_controller.images_router, prefix="/images")
if config.admin_api_enable:
app.include_router(admin_controller.admin_router, prefix="/admin")

Expand Down
11 changes: 9 additions & 2 deletions main.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
import uvicorn

import app
from app.Models.api_models.admin_query_params import UploadImageThumbnailMode

parser = typer.Typer(name=app.__title__,
epilog="Build with ♥ By EdgeNeko. Github: "
Expand Down Expand Up @@ -36,7 +37,7 @@ def server(ctx: typer.Context,
"""
Ciallo~ Welcome to NekoImageGallery Server.
- Website: https://image-insights.edgneko.com
- Website: https://image-insights.edgeneko.com
- Repository & Issue tracker: https://github.com/hv0905/NekoImageGallery
Expand Down Expand Up @@ -77,6 +78,12 @@ def local_index(
help="Directories you want to index.")],
categories: Annotated[Optional[list[str]], typer.Option(help="Categories for the indexed images.")] = None,
starred: Annotated[bool, typer.Option(help="Whether the indexed images are starred.")] = False,
thumbnail_mode: Annotated[
UploadImageThumbnailMode, typer.Option(
help="Whether to generate thumbnail for images. Possible values:\n"
"- `if_necessary`:(Recommended) Only generate thumbnail if the image is larger than 500KB.\n"
"- `always`: Always generate thumbnail.\n"
"- `never`: Never generate thumbnail.")] = UploadImageThumbnailMode.IF_NECESSARY
):
"""
Index all the images in the specified directory.
Expand All @@ -85,7 +92,7 @@ def local_index(
from scripts import local_indexing
if categories is None:
categories = []
asyncio.run(local_indexing.main(target_dir, categories, starred))
asyncio.run(local_indexing.main(target_dir, categories, starred, thumbnail_mode))


@parser.command('local-create-thumbnail', deprecated=True)
Expand Down
9 changes: 5 additions & 4 deletions scripts/local_indexing.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
services: ServiceProvider | None = None


async def index_task(file_path: Path, categories: list[str], starred: bool):
async def index_task(file_path: Path, categories: list[str], starred: bool, thumbnail_mode: UploadImageThumbnailMode):
try:
img_id = await services.upload_service.assign_image_id(file_path)
image_data = ImageData(id=img_id,
Expand All @@ -25,15 +25,16 @@ async def index_task(file_path: Path, categories: list[str], starred: bool):
format=file_path.suffix[1:], # remove the dot
index_date=datetime.now())
await services.upload_service.sync_upload_image(image_data, file_path.read_bytes(), skip_ocr=False,
thumbnail_mode=UploadImageThumbnailMode.IF_NECESSARY)
thumbnail_mode=thumbnail_mode)
except PointDuplicateError as ex:
logger.warning("Image {} already exists in the database", file_path)
except PIL.UnidentifiedImageError as e:
logger.error("Error when processing image {}: {}", file_path, e)


@logger.catch()
async def main(root_directory: list[Path], categories: list[str], starred: bool):
async def main(root_directory: list[Path], categories: list[str], starred: bool,
thumbnail_mode: UploadImageThumbnailMode):
global services
services = ServiceProvider()
await services.onload()
Expand All @@ -47,6 +48,6 @@ async def main(root_directory: list[Path], categories: list[str], starred: bool)
for idx, item in enumerate(progress.track(files, description="Indexing...")):
logger.info("[{} / {}] Indexing {}", idx + 1, len(files), str(item))

await index_task(item, categories, starred)
await index_task(item, categories, starred, thumbnail_mode)

logger.success("Indexing completed!")
32 changes: 32 additions & 0 deletions tests/api/test_search.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,6 @@
import itertools
import uuid

import pytest_asyncio

from .conftest import check_local_dir_empty
Expand Down Expand Up @@ -92,3 +95,32 @@ def test_search_filters(test_client, img_ids):
resp = test_client.get("/search/text/cat", params={'starred': True})
assert resp.status_code == 200
assert resp.json()['result'][0]['img']['id'] == img_ids['bsn'][0]


def test_images_query_by_id(test_client, img_ids):
    """Querying an indexed image by ID returns 200 and echoes the same ID."""
    expected_id = img_ids['bsn'][0]

    response = test_client.get(f"/images/id/{expected_id}")

    assert response.status_code == 200
    assert response.json()['img']['id'] == expected_id


def test_images_query_not_exist(test_client, img_ids):
    """Querying a freshly generated random ID yields 404."""
    # `img_ids` is not read here; it is still requested as a fixture, as in the other tests.
    unknown_id = uuid.uuid4()

    response = test_client.get(f"/images/id/{unknown_id}")

    assert response.status_code == 404


def test_images_query_scroll(test_client, img_ids):
    """Exercise the listing endpoint: first page, paging from an offset, unknown offset."""
    # First page: every returned image must be one of the known fixture images.
    first_page = test_client.get("/images/", params={'count': 50})
    assert first_page.status_code == 200
    listed = first_page.json()['images']
    known_ids = list(itertools.chain.from_iterable(img_ids.values()))
    for entry in listed:
        assert entry['id'] in known_ids

    # Paging: starting from the middle image must return that image first.
    pivot_id = listed[len(listed) // 2]['id']
    from_pivot = test_client.get('/images', params={'prev_offset_id': pivot_id})
    assert from_pivot.status_code == 200
    assert from_pivot.json()['images'][0]['id'] == pivot_id

    # An offset ID that does not exist is rejected with 404.
    from_unknown = test_client.get('/images', params={'prev_offset_id': uuid.uuid4()})
    assert from_unknown.status_code == 404

0 comments on commit 6468433

Please sign in to comment.