Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: Add duplication validate API #38

Merged
merged 1 commit into from
Jun 26, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
26 changes: 21 additions & 5 deletions app/Controllers/admin.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,16 +8,17 @@
from fastapi import APIRouter, Depends, HTTPException, params, UploadFile, File
from loguru import logger

from app.Models.api_models.admin_api_model import ImageOptUpdateModel
from app.Models.api_models.admin_api_model import ImageOptUpdateModel, DuplicateValidationModel
from app.Models.api_models.admin_query_params import UploadImageModel
from app.Models.api_response.admin_api_response import ServerInfoResponse, ImageUploadResponse
from app.Models.api_response.admin_api_response import ServerInfoResponse, ImageUploadResponse, \
DuplicateValidationResponse
from app.Models.api_response.base import NekoProtocol
from app.Models.img_data import ImageData
from app.Services.authentication import force_admin_token_verify
from app.Services.provider import ServiceProvider
from app.Services.vector_db_context import PointNotFoundError
from app.config import config
from app.util.generate_uuid import generate_uuid
from app.util.generate_uuid import generate_uuid, generate_uuid_from_sha1

admin_router = APIRouter(dependencies=[Depends(force_admin_token_verify)], tags=["Admin"])

Expand Down Expand Up @@ -98,7 +99,7 @@ async def update_image(image_id: Annotated[UUID, params.Path(description="The id
description="Upload image to server. The image will be indexed and stored in the database. If "
"local is set to true, the image will be uploaded to local storage.")
async def upload_image(image_file: Annotated[UploadFile, File(description="The image to be uploaded.")],
model: Annotated[UploadImageModel, Depends()]):
model: Annotated[UploadImageModel, Depends()]) -> ImageUploadResponse:
# generate an ID for the image
img_type = None
if image_file.content_type.lower() in IMAGE_MIMES:
Expand Down Expand Up @@ -140,7 +141,22 @@ async def upload_image(image_file: Annotated[UploadFile, File(description="The i


@admin_router.get("/server_info", description="Get server information")
async def server_info():
async def server_info() -> ServerInfoResponse:
return ServerInfoResponse(message="Successfully get server information!",
image_count=await services.db_context.get_counts(exact=True),
index_queue_length=services.upload_service.get_queue_size())


@admin_router.post("/duplication_validate",
                   description="Check if an image exists in the server by its SHA1 hash. If the image exists, "
                               "the image ID will be returned.\n"
                               "This is helpful for checking if an image is already in the server without "
                               "uploading the image.")
async def duplication_validate(model: DuplicateValidationModel) -> DuplicateValidationResponse:
    """Report, for every submitted SHA1 hash, whether a matching image id is already known.

    Each hash is converted to an id with ``generate_uuid_from_sha1``. An id counts as existing when
    ``services.db_context.validate_ids`` returns it, or when it is present in
    ``services.upload_service.uploading_ids`` (presumably uploads that are accepted but not indexed
    yet -- confirm against the upload service).

    :param model: request body carrying the list of SHA1 hashes to check.
    :return: ``exists`` and ``entity_ids`` lists aligned index-by-index with ``model.hashes``;
             ``entity_ids`` holds ``None`` for hashes that are not known.
    """
    ids = [generate_uuid_from_sha1(sha1) for sha1 in model.hashes]
    # Materialise the lookup result once as a set: it is probed once per hash below, so this keeps
    # each probe O(1) and stays correct even if validate_ids hands back a one-shot iterable.
    indexed_ids = set(await services.db_context.validate_ids([str(image_id) for image_id in ids]))
    exists_matrix = [str(image_id) in indexed_ids or image_id in services.upload_service.uploading_ids
                     for image_id in ids]
    return DuplicateValidationResponse(
        exists=exists_matrix,
        entity_ids=[(str(image_id) if exists else None) for image_id, exists in zip(ids, exists_matrix)],
        message="Validation completed.")
11 changes: 3 additions & 8 deletions app/Controllers/search.py
Original file line number Diff line number Diff line change
Expand Up @@ -127,8 +127,6 @@ async def advancedSearch(
basis: Annotated[SearchBasisParams, Depends(SearchBasisParams)],
filter_param: Annotated[FilterParams, Depends(FilterParams)],
paging: Annotated[SearchPagingParams, Depends(SearchPagingParams)]) -> SearchApiResponse:
if len(model.criteria) + len(model.negative_criteria) == 0:
raise HTTPException(status_code=422, detail="At least one criteria should be provided.")
logger.info("Advanced search request received: {}", model)
result = await process_advanced_and_combined_search_query(model, basis, filter_param, paging)
return await result_postprocessing(
Expand All @@ -141,8 +139,6 @@ async def combinedSearch(
basis: Annotated[SearchCombinedParams, Depends(SearchCombinedParams)],
filter_param: Annotated[FilterParams, Depends(FilterParams)],
paging: Annotated[SearchPagingParams, Depends(SearchPagingParams)]) -> SearchApiResponse:
if len(model.criteria) + len(model.negative_criteria) == 0:
raise HTTPException(status_code=422, detail="At least one criteria should be provided.")
logger.info("Combined search request received: {}", model)
result = await process_advanced_and_combined_search_query(model, basis, filter_param, paging)
calculate_and_sort_by_combined_scores(model, basis, result)
Expand All @@ -166,10 +162,9 @@ async def randomPick(
SearchApiResponse(result=result, message=f"Successfully get {len(result)} results.", query_id=uuid4()))


@search_router.get("/recall/{query_id}", description="Recall the query with given queryId")
async def recallQuery(query_id: str):
raise NotImplementedError()

# @search_router.get("/recall/{query_id}", description="Recall the query with given queryId")
# async def recallQuery(query_id: str):
# raise NotImplementedError()

async def process_advanced_and_combined_search_query(model: Union[AdvancedSearchModel, CombinedSearchModel],
basis: Union[SearchBasisParams, SearchCombinedParams],
Expand Down
12 changes: 10 additions & 2 deletions app/Models/api_models/admin_api_model.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
from typing import Optional
from typing import Optional, Annotated

from pydantic import BaseModel, Field
from pydantic import BaseModel, Field, StringConstraints


class ImageOptUpdateModel(BaseModel):
Expand All @@ -21,3 +21,11 @@ class ImageOptUpdateModel(BaseModel):

def empty(self) -> bool:
return all([item is None for item in self.model_dump().values()])


# A SHA1 digest written as exactly 40 hexadecimal characters.
# The pattern is anchored on both ends: pydantic applies string patterns as an unanchored search,
# so the previous r"[0-9a-f]+" accepted any 40-character string containing a single hex character.
# Both letter cases are allowed in the pattern so upper-case input keeps being accepted no matter
# whether the pattern is checked before or after `to_lower` normalises the value.
Sha1HashString = Annotated[
    str,
    StringConstraints(min_length=40, max_length=40, pattern=r"^[0-9a-fA-F]{40}$", to_lower=True,
                      strip_whitespace=True)]


# Request body for the duplication-validate endpoint: a non-empty list of SHA1 hex digests.
class DuplicateValidationModel(BaseModel):
    hashes: list[Sha1HashString] = Field(description="The SHA1 hash of the image.", min_length=1)
9 changes: 7 additions & 2 deletions app/Models/api_models/search_api_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,8 +19,13 @@ class SearchCombinedBasisEnum(str, Enum):


class AdvancedSearchModel(BaseModel):
criteria: list[str] = Field([], description="The positive criteria you want to search with", max_length=16)
negative_criteria: list[str] = Field([], description="The negative criteria you want to search with", max_length=16)
criteria: list[str] = Field([],
description="The positive criteria you want to search with",
max_length=16,
min_length=1)
negative_criteria: list[str] = Field([],
description="The negative criteria you want to search with",
max_length=16)
mode: SearchModelEnum = Field(SearchModelEnum.average,
description="The mode you want to use to combine the criteria.")

Expand Down
9 changes: 9 additions & 0 deletions app/Models/api_response/admin_api_response.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
from uuid import UUID

from pydantic import Field

from .base import NekoProtocol


Expand All @@ -8,5 +10,12 @@ class ServerInfoResponse(NekoProtocol):
index_queue_length: int


# Response body for the duplication-validate endpoint. Documented with `#` comments rather than a
# class docstring so that the generated schema description is left untouched.
class DuplicateValidationResponse(NekoProtocol):
    # Image id per requested hash; null where no image is known for that hash.
    entity_ids: list[UUID | None] = Field(
        description=(
            "The image id for each hash. "
            "If the image does not exist in the server, the value will be null."
        ),
    )
    # Existence flag per requested hash.
    exists: list[bool] = Field(
        description=(
            "Whether the image exists in the server. "
            "True if the image exists, False otherwise."
        ),
    )


class ImageUploadResponse(NekoProtocol):
image_id: UUID
6 changes: 5 additions & 1 deletion app/util/generate_uuid.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,5 +18,9 @@
file_content = file_input
else:
raise ValueError("Unsupported file type. Must be pathlib.Path or io.BytesIO.")
file_hash = hashlib.sha1(file_content).hexdigest()

Check failure

Code scanning / CodeQL

Use of a broken or weak cryptographic hashing algorithm on sensitive data (severity: High)

Sensitive data (id)
is used in a hashing algorithm (SHA1) that is insecure.
Sensitive data (id)
is used in a hashing algorithm (SHA1) that is insecure.
Sensitive data (id)
is used in a hashing algorithm (SHA1) that is insecure.
Sensitive data (id)
is used in a hashing algorithm (SHA1) that is insecure.
Sensitive data (id)
is used in a hashing algorithm (SHA1) that is insecure.
Sensitive data (id)
is used in a hashing algorithm (SHA1) that is insecure.
Sensitive data (id)
is used in a hashing algorithm (SHA1) that is insecure.
Sensitive data (id)
is used in a hashing algorithm (SHA1) that is insecure.
Sensitive data (id)
is used in a hashing algorithm (SHA1) that is insecure.
Sensitive data (id)
is used in a hashing algorithm (SHA1) that is insecure.
return uuid5(namespace_uuid, file_hash)
return generate_uuid_from_sha1(file_hash)


def generate_uuid_from_sha1(sha1_hash: str) -> UUID:
    """Derive the UUID that corresponds to a SHA1 hex digest.

    The digest is lower-cased before it is used as the UUIDv5 name under ``namespace_uuid``, so
    upper- and lower-case spellings of the same digest yield the same UUID.

    :param sha1_hash: SHA1 digest as a hexadecimal string.
    :return: the UUIDv5 computed from the lower-cased digest.
    """
    normalized_hash = sha1_hash.lower()
    return uuid5(namespace_uuid, normalized_hash)
2 changes: 0 additions & 2 deletions tests/api/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,8 +10,6 @@
TEST_ACCESS_TOKEN = 'test_token'
TEST_ADMIN_TOKEN = 'test_admin_token'

assets_path = Path(__file__).parent / '..' / 'assets'


@pytest.fixture(scope="session")
def test_client(tmp_path_factory) -> TestClient:
Expand Down
3 changes: 2 additions & 1 deletion tests/api/integrate_test.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import pytest

from .conftest import TEST_ADMIN_TOKEN, TEST_ACCESS_TOKEN, assets_path
from .conftest import TEST_ADMIN_TOKEN, TEST_ACCESS_TOKEN
from ..assets import assets_path

test_images = {'bsn': ['bsn_0.jpg', 'bsn_1.jpg', 'bsn_2.jpg'],
'cat': ['cat_0.jpg', 'cat_1.jpg'],
Expand Down
38 changes: 32 additions & 6 deletions tests/api/test_upload.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,9 +3,13 @@

import pytest

from .conftest import TEST_ADMIN_TOKEN, TEST_ACCESS_TOKEN, assets_path
from .conftest import TEST_ADMIN_TOKEN, TEST_ACCESS_TOKEN
from ..assets import assets_path

test_file_path = assets_path / 'test_images' / 'bsn_0.jpg'
test_file_2_path = assets_path / 'test_images' / 'bsn_1.jpg'

test_file_hashes = ['648351F7CBD472D0CA23EADCCF3B9E619EC9ADDA', 'C5DE90DAC2F75FBDBE48023DF4DE7585A86B2392']


def get_single_img_info(test_client, image_id):
Expand Down Expand Up @@ -46,15 +50,37 @@ def upload(file):
headers={'x-admin-token': TEST_ADMIN_TOKEN},
params={'local': True})

def validate(hashes):
return test_client.post('/admin/duplication_validate',
json={'hashes': hashes},
headers={'x-admin-token': TEST_ADMIN_TOKEN})

with open(test_file_path, 'rb') as f:
# Validate 1#
val_resp = validate(test_file_hashes)
assert val_resp.status_code == 200
assert val_resp.json()['exists'] == [False, False]
assert val_resp.json()['entity_ids'] == [None, None]

# Upload
resp = upload(f)
assert resp.status_code == 200
image_id = resp.json()['image_id']
resp = upload(f) # The previous image is still in queue
assert resp.status_code == 409
await wait_for_background_task(1)
resp = upload(f) # The previous image is indexed now
assert resp.status_code == 409

for i in range(0, 2):
# Re-upload
resp = upload(f)
assert resp.status_code == 409, i

# Validate
val_resp = validate(test_file_hashes)
assert val_resp.status_code == 200, i
assert val_resp.json()['exists'] == [True, False], i
assert val_resp.json()['entity_ids'] == [str(image_id), None], i

# Wait for the image to be indexed
if i == 0:
await wait_for_background_task(1)

# cleanup
resp = test_client.delete(f'/admin/delete/{image_id}', headers={'x-admin-token': TEST_ADMIN_TOKEN})
Expand Down
3 changes: 3 additions & 0 deletions tests/assets/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
from pathlib import Path

# Directory that contains this package's __init__.py; tests build asset file paths from it.
assets_path = Path(__file__).parent
19 changes: 19 additions & 0 deletions tests/unit/test_image_uuid.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
import io
from uuid import UUID

from app.util.generate_uuid import generate_uuid
from ..assets import assets_path

BSN_UUID = UUID('b3aff1e9-8085-5300-8e06-37b522384659') # To test consistency of UUID across versions


def test_uuid_consistency():
    """generate_uuid must return the pinned UUID for a path, a BytesIO and the raw bytes of one image."""
    image_path = assets_path / 'test_images' / 'bsn_0.jpg'
    raw_bytes = image_path.read_bytes()

    from_path = generate_uuid(image_path)
    from_stream = generate_uuid(io.BytesIO(raw_bytes))
    from_bytes = generate_uuid(raw_bytes)

    # All three input forms must agree with each other and with the value pinned in BSN_UUID.
    assert from_path == from_stream == from_bytes == BSN_UUID