From dd8d0d3f1bafb4430c9f3bc758e361f1b4b46918 Mon Sep 17 00:00:00 2001 From: EdgeNeko Date: Wed, 26 Jun 2024 20:42:47 +0800 Subject: [PATCH] - feat: Add duplication validate API - fix: Return validation error when querying advance search with no positive criteria - fix: Add response type annotations for some API in Admin router - test: Add test to ensure the id consistency across versions --- app/Controllers/admin.py | 26 ++++++++++--- app/Controllers/search.py | 11 ++---- app/Models/api_models/admin_api_model.py | 12 +++++- app/Models/api_models/search_api_model.py | 9 ++++- app/Models/api_response/admin_api_response.py | 9 +++++ app/util/generate_uuid.py | 6 ++- tests/api/conftest.py | 2 - tests/api/integrate_test.py | 3 +- tests/api/test_upload.py | 38 ++++++++++++++++--- tests/assets/__init__.py | 3 ++ tests/unit/test_image_uuid.py | 19 ++++++++++ 11 files changed, 111 insertions(+), 27 deletions(-) create mode 100644 tests/assets/__init__.py create mode 100644 tests/unit/test_image_uuid.py diff --git a/app/Controllers/admin.py b/app/Controllers/admin.py index a4607fa..97dc420 100644 --- a/app/Controllers/admin.py +++ b/app/Controllers/admin.py @@ -8,16 +8,17 @@ from fastapi import APIRouter, Depends, HTTPException, params, UploadFile, File from loguru import logger -from app.Models.api_models.admin_api_model import ImageOptUpdateModel +from app.Models.api_models.admin_api_model import ImageOptUpdateModel, DuplicateValidationModel from app.Models.api_models.admin_query_params import UploadImageModel -from app.Models.api_response.admin_api_response import ServerInfoResponse, ImageUploadResponse +from app.Models.api_response.admin_api_response import ServerInfoResponse, ImageUploadResponse, \ + DuplicateValidationResponse from app.Models.api_response.base import NekoProtocol from app.Models.img_data import ImageData from app.Services.authentication import force_admin_token_verify from app.Services.provider import ServiceProvider from app.Services.vector_db_context import PointNotFoundError from app.config import config -from app.util.generate_uuid import generate_uuid +from app.util.generate_uuid import generate_uuid, generate_uuid_from_sha1 admin_router = APIRouter(dependencies=[Depends(force_admin_token_verify)], tags=["Admin"]) @@ -98,7 +99,7 @@ async def update_image(image_id: Annotated[UUID, params.Path(description="The id description="Upload image to server. The image will be indexed and stored in the database. If " "local is set to true, the image will be uploaded to local storage.") async def upload_image(image_file: Annotated[UploadFile, File(description="The image to be uploaded.")], - model: Annotated[UploadImageModel, Depends()]): + model: Annotated[UploadImageModel, Depends()]) -> ImageUploadResponse: # generate an ID for the image img_type = None if image_file.content_type.lower() in IMAGE_MIMES: @@ -140,7 +141,22 @@ async def upload_image(image_file: Annotated[UploadFile, File(description="The i @admin_router.get("/server_info", description="Get server information") -async def server_info(): +async def server_info() -> ServerInfoResponse: return ServerInfoResponse(message="Successfully get server information!", image_count=await services.db_context.get_counts(exact=True), index_queue_length=services.upload_service.get_queue_size()) + + +@admin_router.post("/duplication_validate", + description="Check if an image exists in the server by its SHA1 hash. If the image exists, " + "the image ID will be returned.\n" + "This is helpful for checking if an image is already in the server without " + "uploading the image.") +async def duplication_validate(model: DuplicateValidationModel) -> DuplicateValidationResponse: + ids = [generate_uuid_from_sha1(t) for t in model.hashes] + valid_ids = await services.db_context.validate_ids([str(t) for t in ids]) + exists_matrix = [str(t) in valid_ids or t in services.upload_service.uploading_ids for t in ids] + return DuplicateValidationResponse( + exists=exists_matrix, + entity_ids=[(str(t) if exists else None) for (t, exists) in zip(ids, exists_matrix)], + message="Validation completed.") diff --git a/app/Controllers/search.py b/app/Controllers/search.py index 1bf6321..16b64f9 100644 --- a/app/Controllers/search.py +++ b/app/Controllers/search.py @@ -127,8 +127,6 @@ async def advancedSearch( basis: Annotated[SearchBasisParams, Depends(SearchBasisParams)], filter_param: Annotated[FilterParams, Depends(FilterParams)], paging: Annotated[SearchPagingParams, Depends(SearchPagingParams)]) -> SearchApiResponse: - if len(model.criteria) + len(model.negative_criteria) == 0: - raise HTTPException(status_code=422, detail="At least one criteria should be provided.") logger.info("Advanced search request received: {}", model) result = await process_advanced_and_combined_search_query(model, basis, filter_param, paging) return await result_postprocessing( @@ -141,8 +139,6 @@ async def combinedSearch( basis: Annotated[SearchCombinedParams, Depends(SearchCombinedParams)], filter_param: Annotated[FilterParams, Depends(FilterParams)], paging: Annotated[SearchPagingParams, Depends(SearchPagingParams)]) -> SearchApiResponse: - if len(model.criteria) + len(model.negative_criteria) == 0: - raise HTTPException(status_code=422, detail="At least one criteria should be provided.") logger.info("Combined search request received: {}", model) result = await process_advanced_and_combined_search_query(model, basis, filter_param, paging) calculate_and_sort_by_combined_scores(model, basis, result) @@ -166,10 +162,9 @@ async def randomPick( SearchApiResponse(result=result, message=f"Successfully get {len(result)} results.", query_id=uuid4())) -@search_router.get("/recall/{query_id}", description="Recall the query with given queryId") -async def recallQuery(query_id: str): - raise NotImplementedError() - +# @search_router.get("/recall/{query_id}", description="Recall the query with given queryId") +# async def recallQuery(query_id: str): +# raise NotImplementedError() async def process_advanced_and_combined_search_query(model: Union[AdvancedSearchModel, CombinedSearchModel], basis: Union[SearchBasisParams, SearchCombinedParams], diff --git a/app/Models/api_models/admin_api_model.py b/app/Models/api_models/admin_api_model.py index c53fdc7..12037aa 100644 --- a/app/Models/api_models/admin_api_model.py +++ b/app/Models/api_models/admin_api_model.py @@ -1,6 +1,6 @@ -from typing import Optional +from typing import Optional, Annotated -from pydantic import BaseModel, Field +from pydantic import BaseModel, Field, StringConstraints class ImageOptUpdateModel(BaseModel): @@ -21,3 +21,11 @@ class ImageOptUpdateModel(BaseModel): def empty(self) -> bool: return all([item is None for item in self.model_dump().values()]) + + +Sha1HashString = Annotated[ + str, StringConstraints(min_length=40, max_length=40, pattern=r"[0-9a-f]+", to_lower=True, strip_whitespace=True)] + + +class DuplicateValidationModel(BaseModel): + hashes: list[Sha1HashString] = Field(description="The SHA1 hash of the image.", min_length=1) diff --git a/app/Models/api_models/search_api_model.py b/app/Models/api_models/search_api_model.py index 0a620ae..5026fc7 100644 --- a/app/Models/api_models/search_api_model.py +++ b/app/Models/api_models/search_api_model.py @@ -19,8 +19,13 @@ class SearchCombinedBasisEnum(str, Enum): class AdvancedSearchModel(BaseModel): - criteria: list[str] = Field([], description="The positive criteria you want to search with", max_length=16) - negative_criteria: list[str] = Field([], description="The negative criteria you want to search with", max_length=16) + criteria: list[str] = Field([], + description="The positive criteria you want to search with", + max_length=16, + min_length=1) + negative_criteria: list[str] = Field([], + description="The negative criteria you want to search with", + max_length=16) mode: SearchModelEnum = Field(SearchModelEnum.average, description="The mode you want to use to combine the criteria.") diff --git a/app/Models/api_response/admin_api_response.py b/app/Models/api_response/admin_api_response.py index 47f4b9d..2a450f9 100644 --- a/app/Models/api_response/admin_api_response.py +++ b/app/Models/api_response/admin_api_response.py @@ -1,5 +1,7 @@ from uuid import UUID +from pydantic import Field + from .base import NekoProtocol @@ -8,5 +10,12 @@ class ServerInfoResponse(NekoProtocol): index_queue_length: int +class DuplicateValidationResponse(NekoProtocol): + entity_ids: list[UUID | None] = Field( + description="The image id for each hash. If the image does not exist in the server, the value will be null.") + exists: list[bool] = Field( + description="Whether the image exists in the server. True if the image exists, False otherwise.") + + class ImageUploadResponse(NekoProtocol): image_id: UUID diff --git a/app/util/generate_uuid.py b/app/util/generate_uuid.py index 0f493ef..28f96c3 100644 --- a/app/util/generate_uuid.py +++ b/app/util/generate_uuid.py @@ -19,4 +19,8 @@ def generate_uuid(file_input: pathlib.Path | io.BytesIO | bytes) -> UUID: else: raise ValueError("Unsupported file type. Must be pathlib.Path or io.BytesIO.") file_hash = hashlib.sha1(file_content).hexdigest() - return uuid5(namespace_uuid, file_hash) + return generate_uuid_from_sha1(file_hash) + + +def generate_uuid_from_sha1(sha1_hash: str) -> UUID: + return uuid5(namespace_uuid, sha1_hash.lower()) diff --git a/tests/api/conftest.py b/tests/api/conftest.py index 5a6c767..41dbc10 100644 --- a/tests/api/conftest.py +++ b/tests/api/conftest.py @@ -10,8 +10,6 @@ TEST_ACCESS_TOKEN = 'test_token' TEST_ADMIN_TOKEN = 'test_admin_token' -assets_path = Path(__file__).parent / '..' / 'assets' - @pytest.fixture(scope="session") def test_client(tmp_path_factory) -> TestClient: diff --git a/tests/api/integrate_test.py b/tests/api/integrate_test.py index 2976c84..45654b0 100644 --- a/tests/api/integrate_test.py +++ b/tests/api/integrate_test.py @@ -1,6 +1,7 @@ import pytest -from .conftest import TEST_ADMIN_TOKEN, TEST_ACCESS_TOKEN, assets_path +from .conftest import TEST_ADMIN_TOKEN, TEST_ACCESS_TOKEN +from ..assets import assets_path test_images = {'bsn': ['bsn_0.jpg', 'bsn_1.jpg', 'bsn_2.jpg'], 'cat': ['cat_0.jpg', 'cat_1.jpg'], diff --git a/tests/api/test_upload.py b/tests/api/test_upload.py index 42e454f..81462ce 100644 --- a/tests/api/test_upload.py +++ b/tests/api/test_upload.py @@ -3,9 +3,13 @@ import pytest -from .conftest import TEST_ADMIN_TOKEN, TEST_ACCESS_TOKEN, assets_path +from .conftest import TEST_ADMIN_TOKEN, TEST_ACCESS_TOKEN +from ..assets import assets_path test_file_path = assets_path / 'test_images' / 'bsn_0.jpg' +test_file_2_path = assets_path / 'test_images' / 'bsn_1.jpg' + +test_file_hashes = ['648351F7CBD472D0CA23EADCCF3B9E619EC9ADDA', 'C5DE90DAC2F75FBDBE48023DF4DE7585A86B2392'] def get_single_img_info(test_client, image_id): @@ -46,15 +50,37 @@ def upload(file): headers={'x-admin-token': TEST_ADMIN_TOKEN}, params={'local': True}) + def validate(hashes): + return test_client.post('/admin/duplication_validate', + json={'hashes': hashes}, + headers={'x-admin-token': TEST_ADMIN_TOKEN}) + with open(test_file_path, 'rb') as f: + # Validate 1# + val_resp = validate(test_file_hashes) + assert val_resp.status_code == 200 + assert val_resp.json()['exists'] == [False, False] + assert val_resp.json()['entity_ids'] == [None, None] + + # Upload resp = upload(f) assert resp.status_code == 200 image_id = resp.json()['image_id'] - resp = upload(f) # The previous image is still in queue - assert resp.status_code == 409 - await wait_for_background_task(1) - resp = upload(f) # The previous image is indexed now - assert resp.status_code == 409 + + for i in range(0, 2): + # Re-upload + resp = upload(f) + assert resp.status_code == 409, i + + # Validate + val_resp = validate(test_file_hashes) + assert val_resp.status_code == 200, i + assert val_resp.json()['exists'] == [True, False], i + assert val_resp.json()['entity_ids'] == [str(image_id), None], i + + # Wait for the image to be indexed + if i == 0: + await wait_for_background_task(1) # cleanup resp = test_client.delete(f'/admin/delete/{image_id}', headers={'x-admin-token': TEST_ADMIN_TOKEN}) diff --git a/tests/assets/__init__.py b/tests/assets/__init__.py new file mode 100644 index 0000000..8381764 --- /dev/null +++ b/tests/assets/__init__.py @@ -0,0 +1,3 @@ +from pathlib import Path + +assets_path = Path(__file__).parent diff --git a/tests/unit/test_image_uuid.py b/tests/unit/test_image_uuid.py new file mode 100644 index 0000000..8c3a41a --- /dev/null +++ b/tests/unit/test_image_uuid.py @@ -0,0 +1,19 @@ +import io +from uuid import UUID + +from app.util.generate_uuid import generate_uuid +from ..assets import assets_path + +BSN_UUID = UUID('b3aff1e9-8085-5300-8e06-37b522384659') # To test consistency of UUID across versions + + +def test_uuid_consistency(): + file_path = assets_path / 'test_images' / 'bsn_0.jpg' + with open(file_path, 'rb') as f: + file_content = f.read() + + uuid1 = generate_uuid(file_path) + uuid2 = generate_uuid(io.BytesIO(file_content)) + uuid3 = generate_uuid(file_content) + + assert uuid1 == uuid2 == uuid3 == BSN_UUID