diff --git a/app/Controllers/images.py b/app/Controllers/images.py new file mode 100644 index 0000000..d16be2c --- /dev/null +++ b/app/Controllers/images.py @@ -0,0 +1,41 @@ +from typing import Annotated +from uuid import UUID + +from fastapi import APIRouter, Depends, Path, HTTPException, Query + +from app.Models.api_response.images_api_response import QueryByIdApiResponse, ImageStatus, QueryImagesApiResponse +from app.Models.query_params import FilterParams +from app.Services.authentication import force_access_token_verify +from app.Services.provider import ServiceProvider +from app.Services.vector_db_context import PointNotFoundError +from app.config import config + +images_router = APIRouter(dependencies=([Depends(force_access_token_verify)] if config.access_protected else None), + tags=["Images"]) + +services: ServiceProvider | None = None # The service provider will be injected in the webapp initialize + + +@images_router.get("/id/{image_id}", description="Query the image info with the given image ID. \n" + "This can also be used to check the status" + " of an image in the index queue.") +async def query_image_by_id(image_id: Annotated[UUID, Path(description="The id of the image you want to query.")]): + try: + return QueryByIdApiResponse(img=await services.db_context.retrieve_by_id(str(image_id)), + img_status=ImageStatus.MAPPED, + message="Success query the image with the given ID.") + except PointNotFoundError as ex: + if services.upload_service and image_id in services.upload_service.uploading_ids: + return QueryByIdApiResponse(img_status=ImageStatus.IN_QUEUE, message="The image is in the indexing queue.") + raise HTTPException(404, "Cannot find the image with the given ID.") from ex + + +@images_router.get("/", description="Query images in order of ID.") +async def scroll_images(filter_param: Annotated[FilterParams, Depends()], + prev_offset_id: Annotated[UUID, Query(description="The previous offset image ID.")] = None, + count: Annotated[int, Query(ge=1, le=100, description="The number of images to query.")] = 15): + # validate the offset ID + if prev_offset_id is not None and len(await services.db_context.validate_ids([str(prev_offset_id)])) == 0: + raise HTTPException(404, "The previous offset ID is invalid.") + images, offset = await services.db_context.scroll_points(str(prev_offset_id), count, filter_param=filter_param) + return QueryImagesApiResponse(images=images, next_page_offset=offset, message="Success query images.") diff --git a/app/Models/api_models/admin_query_params.py b/app/Models/api_models/admin_query_params.py index e63cdb1..eba8dd4 100644 --- a/app/Models/api_models/admin_query_params.py +++ b/app/Models/api_models/admin_query_params.py @@ -5,7 +5,6 @@ class UploadImageThumbnailMode(str, Enum): - DEFAULT = "default" IF_NECESSARY = "if_necessary" ALWAYS = "always" NEVER = "never" @@ -30,7 +29,7 @@ def __init__(self, description="When set to true, the image will be uploaded to local storage. " "Otherwise, it will only be indexed in the database."), local_thumbnail: UploadImageThumbnailMode = - Query(default=UploadImageThumbnailMode.DEFAULT, + Query(default=None, description="Whether to generate thumbnail locally. Possible values:\n" "- `if_necessary`: Only generate thumbnail if the image is larger than 500KB. " "This is the default value if `local=True`\n" @@ -43,7 +42,7 @@ def __init__(self, self.starred = starred self.local = local self.skip_ocr = skip_ocr - self.local_thumbnail = local_thumbnail if local_thumbnail is not UploadImageThumbnailMode.DEFAULT else ( + self.local_thumbnail = local_thumbnail if (local_thumbnail is not None) else ( UploadImageThumbnailMode.IF_NECESSARY if local else UploadImageThumbnailMode.NEVER) if not self.url and not self.local: raise HTTPException(422, "A correspond url must be provided for a non-local image.") diff --git a/app/Models/api_response/images_api_response.py b/app/Models/api_response/images_api_response.py new file mode 100644 index 0000000..875d19a --- /dev/null +++ b/app/Models/api_response/images_api_response.py @@ -0,0 +1,25 @@ +from enum import Enum + +from pydantic import Field + +from app.Models.api_response.base import NekoProtocol +from app.Models.img_data import ImageData + + +class ImageStatus(str, Enum): + MAPPED = "mapped" + IN_QUEUE = "in_queue" + + +class QueryByIdApiResponse(NekoProtocol): + img_status: ImageStatus = Field(description="The status of the image.\n" + "Warning: If NekoImageGallery is deployed in a cluster, " + "the `in_queue` might not be accurate since the index queue " + "is independent of each service instance.") + img: ImageData | None = Field(description="The mapped image data. Only available when `img_status = mapped`.") + + +class QueryImagesApiResponse(NekoProtocol): + images: list[ImageData] = Field(description="The list of images.") + next_page_offset: str | None = Field(description="The offset ID for the next page query. " + "If there are no more images, this field will be null.") diff --git a/app/Services/vector_db_context.py b/app/Services/vector_db_context.py index 06fa7e6..fe1c22b 100644 --- a/app/Services/vector_db_context.py +++ b/app/Services/vector_db_context.py @@ -184,11 +184,14 @@ async def updateVectors(self, new_points: list[ImageData]): async def scroll_points(self, from_id: str | None = None, count=50, - with_vectors=False) -> tuple[list[ImageData], str]: + with_vectors=False, + filter_param: FilterParams | None = None, + ) -> tuple[list[ImageData], str]: resp, next_id = await self._client.scroll(collection_name=self.collection_name, limit=count, offset=from_id, - with_vectors=with_vectors + with_vectors=with_vectors, + scroll_filter=self._get_filters_by_filter_param(filter_param) ) return [self._get_img_data_from_point(t) for t in resp], next_id diff --git a/app/webapp.py b/app/webapp.py index dc3c7cb..3aec3a4 100644 --- a/app/webapp.py +++ b/app/webapp.py @@ -10,6 +10,7 @@ import app import app.Controllers.admin as admin_controller +import app.Controllers.images as images_controller import app.Controllers.search as search_controller from app.Services.authentication import permissive_access_token_verify, permissive_admin_token_verify from app.Services.provider import ServiceProvider @@ -26,6 +27,7 @@ async def lifespan(_: FastAPI): search_controller.services = provider admin_controller.services = provider + images_controller.services = provider yield await provider.onexit() @@ -44,6 +46,7 @@ async def lifespan(_: FastAPI): ) app.include_router(search_controller.search_router, prefix="/search") +app.include_router(images_controller.images_router, prefix="/images") if config.admin_api_enable: app.include_router(admin_controller.admin_router, prefix="/admin") diff --git a/main.py b/main.py index 9b12198..eaa2217 100644 --- a/main.py +++ b/main.py @@ -7,6 +7,7 @@ import uvicorn import app +from app.Models.api_models.admin_query_params import UploadImageThumbnailMode parser = typer.Typer(name=app.__title__, epilog="Build with ♥ By EdgeNeko. Github: " @@ -36,7 +37,7 @@ def server(ctx: typer.Context, """ Ciallo~ Welcome to NekoImageGallery Server. - - Website: https://image-insights.edgneko.com + - Website: https://image-insights.edgeneko.com - Repository & Issue tracker: https://github.com/hv0905/NekoImageGallery @@ -77,6 +78,12 @@ def local_index( help="Directories you want to index.")], categories: Annotated[Optional[list[str]], typer.Option(help="Categories for the indexed images.")] = None, starred: Annotated[bool, typer.Option(help="Whether the indexed images are starred.")] = False, + thumbnail_mode: Annotated[ + UploadImageThumbnailMode, typer.Option( + help="Whether to generate thumbnail for images. Possible values:\n" + "- `if_necessary`:(Recommended) Only generate thumbnail if the image is larger than 500KB.\n" + "- `always`: Always generate thumbnail.\n" + "- `never`: Never generate thumbnail.")] = UploadImageThumbnailMode.IF_NECESSARY ): """ Index all the images in the specified directory. @@ -85,7 +92,7 @@ def local_index( from scripts import local_indexing if categories is None: categories = [] - asyncio.run(local_indexing.main(target_dir, categories, starred)) + asyncio.run(local_indexing.main(target_dir, categories, starred, thumbnail_mode)) @parser.command('local-create-thumbnail', deprecated=True) diff --git a/scripts/local_indexing.py b/scripts/local_indexing.py index 8392990..11c086e 100644 --- a/scripts/local_indexing.py +++ b/scripts/local_indexing.py @@ -15,7 +15,7 @@ services: ServiceProvider | None = None -async def index_task(file_path: Path, categories: list[str], starred: bool): +async def index_task(file_path: Path, categories: list[str], starred: bool, thumbnail_mode: UploadImageThumbnailMode): try: img_id = await services.upload_service.assign_image_id(file_path) image_data = ImageData(id=img_id, @@ -25,7 +25,7 @@ async def index_task(file_path: Path, categories: list[str], starred: bool): format=file_path.suffix[1:], # remove the dot index_date=datetime.now()) await services.upload_service.sync_upload_image(image_data, file_path.read_bytes(), skip_ocr=False, - thumbnail_mode=UploadImageThumbnailMode.IF_NECESSARY) + thumbnail_mode=thumbnail_mode) except PointDuplicateError as ex: logger.warning("Image {} already exists in the database", file_path) except PIL.UnidentifiedImageError as e: @@ -33,7 +33,8 @@ async def index_task(file_path: Path, categories: list[str], starred: bool): @logger.catch() -async def main(root_directory: list[Path], categories: list[str], starred: bool): +async def main(root_directory: list[Path], categories: list[str], starred: bool, + thumbnail_mode: UploadImageThumbnailMode): global services services = ServiceProvider() await services.onload() @@ -47,6 +48,6 @@ async def main(root_directory: list[Path], categories: list[str], starred: bool) for idx, item in enumerate(progress.track(files, description="Indexing...")): logger.info("[{} / {}] Indexing {}", idx + 1, len(files), str(item)) - await index_task(item, categories, starred) + await index_task(item, categories, starred, thumbnail_mode) logger.success("Indexing completed!") diff --git a/tests/api/test_search.py b/tests/api/test_search.py index 3b21517..93aded8 100644 --- a/tests/api/test_search.py +++ b/tests/api/test_search.py @@ -1,3 +1,6 @@ +import itertools +import uuid + import pytest_asyncio from .conftest import check_local_dir_empty @@ -92,3 +95,32 @@ def test_search_filters(test_client, img_ids): resp = test_client.get("/search/text/cat", params={'starred': True}) assert resp.status_code == 200 assert resp.json()['result'][0]['img']['id'] == img_ids['bsn'][0] + + +def test_images_query_by_id(test_client, img_ids): + resp = test_client.get(f"/images/id/{img_ids['bsn'][0]}") + assert resp.status_code == 200 + assert resp.json()['img']['id'] == img_ids['bsn'][0] + + +def test_images_query_not_exist(test_client, img_ids): + resp = test_client.get(f"/images/id/{uuid.uuid4()}") + assert resp.status_code == 404 + + +def test_images_query_scroll(test_client, img_ids): + resp = test_client.get("/images/", params={'count': 50}) + assert resp.status_code == 200 + resp_imgs = resp.json()['images'] + all_images_id = list(itertools.chain(*img_ids.values())) + for item in resp_imgs: + assert item['id'] in all_images_id + + paging_test = test_client.get(f'/images', + params={'prev_offset_id': resp_imgs[len(resp_imgs) // 2]['id']}) + assert paging_test.status_code == 200 + assert paging_test.json()['images'][0]['id'] == resp_imgs[len(resp_imgs) // 2]['id'] + + no_exist_test = test_client.get(f'/images', + params={'prev_offset_id': uuid.uuid4()}) + assert no_exist_test.status_code == 404