Skip to content

Commit

Permalink
Merge pull request #32 from hv0905/thumbnail
Browse files Browse the repository at this point in the history
Allow setting thumbnail mode when uploading
  • Loading branch information
hv0905 authored Jun 4, 2024
2 parents 799d6e1 + d87b2c0 commit 43d30ab
Show file tree
Hide file tree
Showing 10 changed files with 169 additions and 42 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/test_lint.yml
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ jobs:
steps:
- uses: actions/checkout@v4
- name: Set up Python ${{ matrix.python-version }}
uses: actions/setup-python@v3
uses: actions/setup-python@v5
with:
python-version: ${{ matrix.python-version }}
cache: 'pip'
Expand Down
23 changes: 12 additions & 11 deletions app/Controllers/admin.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,16 +36,17 @@ async def delete_image(
await services.db_context.deleteItems([str(point.id)])
logger.success("Image {} deleted from database.", point.id)

if point.local and config.storage.method.enabled: # local image
image_files = [itm[0] async for itm in services.storage_service.active_storage.list_files("", f"{point.id}.*")]
assert len(image_files) <= 1

if not image_files:
logger.warning("Image {} is a local image but not found in static folder.", point.id)
else:
await services.storage_service.active_storage.move(image_files[0], f"_deleted/{image_files[0].name}")
logger.success("Image {} removed.", image_files[0].name)
if point.thumbnail_url is not None:
if config.storage.method.enabled: # local image
if point.local:
image_files = [itm[0] async for itm in
services.storage_service.active_storage.list_files("", f"{point.id}.*")]
assert len(image_files) <= 1
if not image_files:
logger.warning("Image {} is a local image but not found in static folder.", point.id)
else:
await services.storage_service.active_storage.move(image_files[0], f"_deleted/{image_files[0].name}")
logger.success("Image {} removed.", image_files[0].name)
if point.thumbnail_url is not None and (point.local or point.local_thumbnail):
thumbnail_file = PurePath(f"thumbnails/{point.id}.webp")
if await services.storage_service.active_storage.is_exist(thumbnail_file):
await services.storage_service.active_storage.delete(thumbnail_file)
Expand Down Expand Up @@ -119,7 +120,7 @@ async def upload_image(image_file: Annotated[UploadFile, File(description="The i
format=img_type,
index_date=datetime.now())

await services.upload_service.upload_image(image, image_data, img_bytes, model.skip_ocr)
await services.upload_service.upload_image(image, image_data, img_bytes, model.skip_ocr, model.local_thumbnail)
return ImageUploadResponse(message="OK. Image added to upload queue.", image_id=img_id)


Expand Down
12 changes: 7 additions & 5 deletions app/Controllers/search.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,15 +43,17 @@ def __init__(self,


async def result_postprocessing(resp: SearchApiResponse) -> SearchApiResponse:
    """Replace stored image and thumbnail URLs in a search response with presigned URLs.

    If ``config.storage.method.enabled`` is false the response is returned
    untouched. Otherwise, for every item in ``resp.result``:

    * an image flagged ``local`` gets its ``url`` replaced by a presigned URL
      for ``<id>.<ext>``, where ``<ext>`` is ``img.format`` or, when that is
      empty, the text after the last ``.`` of the current URL;
    * a thumbnail that is set and stored locally (the image is ``local`` or
      ``local_thumbnail`` is true) gets its ``thumbnail_url`` replaced by a
      presigned URL for ``thumbnails/<id>.webp``.

    :param resp: The search response to post-process; it is modified in place.
    :return: The same ``resp`` object.
    """
    if not config.storage.method.enabled:
        return resp

    storage = services.storage_service.active_storage
    for item in resp.result:
        img = item.img
        if img.local:
            img_extension = img.format or img.url.split('.')[-1]
            img.url = await storage.presign_url(f"{img.id}.{img_extension}")
        # Thumbnails that are not stored locally keep their existing URL.
        if img.thumbnail_url is not None and (img.local or img.local_thumbnail):
            img.thumbnail_url = await storage.presign_url(f"thumbnails/{img.id}.webp")
    return resp


Expand Down
24 changes: 22 additions & 2 deletions app/Models/api_models/admin_query_params.py
Original file line number Diff line number Diff line change
@@ -1,29 +1,49 @@
from enum import Enum
from typing import Optional

from fastapi import Query, HTTPException


class UploadImageThumbnailMode(str, Enum):
    """Thumbnail generation mode accepted as an upload query parameter.

    Members are also ``str`` instances, so each one compares equal to its
    plain string value.
    """

    # Placeholder meaning "not specified"; UploadImageModel resolves it to
    # IF_NECESSARY for local uploads and to NEVER otherwise.
    DEFAULT = "default"
    # Generate a thumbnail only when the uploaded image is large enough.
    IF_NECESSARY = "if_necessary"
    # Always generate a thumbnail.
    ALWAYS = "always"
    # Never generate a thumbnail.
    NEVER = "never"


class UploadImageModel:
    """Query parameters of the image upload endpoint.

    Attributes set by ``__init__``:

    * ``url`` / ``thumbnail_url``: the values passed in, unchanged.
    * ``categories``: list of non-empty, whitespace-stripped entries split on
      commas, or ``None`` when no categories were given.
    * ``starred`` / ``local`` / ``skip_ocr``: the flags passed in, unchanged.
    * ``local_thumbnail``: the requested mode, with ``DEFAULT`` resolved to
      ``IF_NECESSARY`` when ``local`` is true and to ``NEVER`` otherwise.

    :raises HTTPException: 422 when the image is not local and no ``url`` is given.
    """

    def __init__(self,
                 url: Optional[str] = Query(None,
                                            description="The image's url. If the image is local, this field will be "
                                                        "ignored. Otherwise it is required."),
                 thumbnail_url: Optional[str] = Query(None,
                                                      description="The image's thumbnail url. If the image is local "
                                                                  "or local_thumbnail's value is always, "
                                                                  "this field will be ignored. Currently setting an "
                                                                  "external thumbnail for a local image is "
                                                                  "unsupported due to compatibility issues."),
                 categories: Optional[str] = Query(None,
                                                   description="The categories of the image. The entries should be "
                                                               "separated by comma."),
                 starred: bool = Query(False, description="If the image is starred."),
                 local: bool = Query(False,
                                     description="When set to true, the image will be uploaded to local storage. "
                                                 "Otherwise, it will only be indexed in the database."),
                 local_thumbnail: UploadImageThumbnailMode =
                 Query(default=UploadImageThumbnailMode.DEFAULT,
                       description="Whether to generate thumbnail locally. Possible values:\n"
                                   "- `if_necessary`: Only generate thumbnail if the image is larger than 500KB. "
                                   "This is the default value if `local=True`\n"
                                   " - `always`: Always generate thumbnail.\n"
                                   " - `never`: Never generate thumbnail. This is the default value if `local=False`."),
                 skip_ocr: bool = Query(False, description="Whether to skip the OCR process.")):
        self.url = url
        self.thumbnail_url = thumbnail_url
        self.categories = [t.strip() for t in categories.split(',') if t.strip()] if categories else None
        self.starred = starred
        self.local = local
        self.skip_ocr = skip_ocr

        # DEFAULT is only a placeholder: pick the effective mode from `local`.
        if local_thumbnail is UploadImageThumbnailMode.DEFAULT:
            local_thumbnail = UploadImageThumbnailMode.IF_NECESSARY if local else UploadImageThumbnailMode.NEVER
        self.local_thumbnail = local_thumbnail

        # An image that is only indexed (not stored locally) needs its URL.
        if not self.url and not self.local:
            raise HTTPException(422, "A correspond url must be provided for a non-local image.")
1 change: 1 addition & 0 deletions app/Models/img_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ class ImageData(BaseModel):
starred: Optional[bool] = False
categories: Optional[list[str]] = []
local: Optional[bool] = False
local_thumbnail: Optional[bool] = False
format: Optional[str] = None # required for s3 local storage

@computed_field()
Expand Down
36 changes: 22 additions & 14 deletions app/Services/upload_service.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
from PIL import Image
from loguru import logger

from app.Models.api_models.admin_query_params import UploadImageThumbnailMode
from app.Models.img_data import ImageData
from app.Services.index_service import IndexService
from app.Services.storage import StorageService
Expand All @@ -25,9 +26,9 @@ def __init__(self, storage_service: StorageService, db_context: VectorDbContext,

async def _upload_worker(self):
while True:
img, img_data, img_bytes, skip_ocr = await self._queue.get()
img, img_data, *args = await self._queue.get()
try:
await self._upload_task(img, img_data, img_bytes, skip_ocr)
await self._upload_task(img, img_data, *args)
logger.success("Image {} uploaded and indexed. Queue Length: {} [-1]", img_data.id, self._queue.qsize())
except Exception as ex:
logger.error("Error occurred while uploading image {}", img_data.id)
Expand All @@ -38,15 +39,20 @@ async def _upload_worker(self):
if self._processed_count % 50 == 0:
gc.collect()

async def _upload_task(self, img: Image.Image, img_data: ImageData, img_bytes: bytes, skip_ocr: bool):
async def _upload_task(self, img: Image.Image, img_data: ImageData, img_bytes: bytes, skip_ocr: bool,
thumbnail_mode: UploadImageThumbnailMode):
logger.info('Start indexing image {}. Local: {}. Size: {}', img_data.id, img_data.local, len(img_bytes))
file_name = f"{img_data.id}.{img_data.format}"
thumb_path = f"thumbnails/{img_data.id}.webp"
gen_thumb = thumbnail_mode == UploadImageThumbnailMode.ALWAYS or (
thumbnail_mode == UploadImageThumbnailMode.IF_NECESSARY and len(img_bytes) > 1024 * 500)

if img_data.local:
img_data.url = await self._storage_service.active_storage.url(file_name)
if len(img_bytes) > 1024 * 500:
img_data.thumbnail_url = await self._storage_service.active_storage.url(
f"thumbnails/{img_data.id}.webp")
if gen_thumb:
img_data.thumbnail_url = await self._storage_service.active_storage.url(
f"thumbnails/{img_data.id}.webp")
img_data.local_thumbnail = True

await self._index_service.index_image(img, img_data, skip_ocr=skip_ocr, background=True)
logger.success("Image {} indexed.", img_data.id)
Expand All @@ -55,17 +61,19 @@ async def _upload_task(self, img: Image.Image, img_data: ImageData, img_bytes: b
logger.info("Start uploading image {} to local storage.", img_data.id)
await self._storage_service.active_storage.upload(img_bytes, file_name)
logger.success("Image {} uploaded to local storage.", img_data.id)
if len(img_bytes) > 1024 * 500:
img.thumbnail((256, 256), resample=Image.Resampling.LANCZOS)
img_byte_arr = BytesIO()
img.save(img_byte_arr, 'WebP')
await self._storage_service.active_storage.upload(img_byte_arr.getvalue(), thumb_path)
logger.success("Thumbnail for {} generated and uploaded!", img_data.id)
if gen_thumb:
logger.info("Start generate and upload thumbnail for {}.", img_data.id)
img.thumbnail((256, 256), resample=Image.Resampling.LANCZOS)
img_byte_arr = BytesIO()
img.save(img_byte_arr, 'WebP')
await self._storage_service.active_storage.upload(img_byte_arr.getvalue(), thumb_path)
logger.success("Thumbnail for {} generated and uploaded!", img_data.id)

img.close()

async def upload_image(self, img: Image.Image, img_data: ImageData, img_bytes: bytes, skip_ocr: bool,
                       thumbnail_mode: UploadImageThumbnailMode):
    """Put an image on the upload queue for the background worker.

    The five values are queued as one tuple; ``_upload_worker`` unpacks it and
    forwards everything after ``img`` and ``img_data`` positionally to
    ``_upload_task``, so the tuple order must match that method's parameters.

    :param img: The decoded image.
    :param img_data: Metadata record of the image; its ``id`` is used for logging.
    :param img_bytes: The raw bytes of the image file.
    :param skip_ocr: Whether to skip the OCR process.
    :param thumbnail_mode: The thumbnail generation mode for this upload.
    """
    await self._queue.put((img, img_data, img_bytes, skip_ocr, thumbnail_mode))
    logger.success("Image {} added to upload queue. Queue Length: {} [+1]", img_data.id, self._queue.qsize())

def get_queue_size(self):
Expand Down
1 change: 1 addition & 0 deletions scripts/local_create_thumbnail.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,7 @@ async def main():

# update payload
imgdata.thumbnail_url = await services.storage_service.active_storage.url(f'thumbnails/{str(image_id)}.webp')
imgdata.local_thumbnail = True
await services.db_context.updatePayload(imgdata)
logger.success("Payload for {} updated!", image_id)

Expand Down
28 changes: 28 additions & 0 deletions tests/api/conftest.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,6 @@
import asyncio
import importlib
from pathlib import Path

import pytest
from fastapi.testclient import TestClient
Expand All @@ -25,6 +27,32 @@ def test_client(tmp_path_factory) -> TestClient:
yield client


@pytest.fixture()
def check_local_dir_empty():
    """Teardown check that the test left no files in the local storage directory.

    Nothing happens before the test. After it, the fixture asserts that the
    configured local storage path contains no files matching ``*.*`` and, if a
    ``thumbnails`` sub-directory exists, that it contains none either.
    Sub-directories themselves are ignored.
    """
    yield
    storage_dir = Path(config.config.storage.local.path)
    leftover_files = [f for f in storage_dir.glob('*.*') if f.is_file()]
    assert not leftover_files, f"Files left in local storage: {leftover_files}"

    thumbnail_dir = storage_dir / 'thumbnails'
    if thumbnail_dir.exists():
        leftover_thumbnails = [f for f in thumbnail_dir.glob('*.*') if f.is_file()]
        assert not leftover_thumbnails, f"Thumbnails left in local storage: {leftover_thumbnails}"


@pytest.fixture()
def wait_for_background_task(test_client):
    """Provide an async helper that waits until the server has indexed enough images.

    The returned coroutine function polls ``/admin/server_info`` every 0.2
    seconds until ``image_count`` reaches ``expected_image_count``, then
    asserts that ``index_queue_length`` is 0.

    :param test_client: The test client used to query the server.
    :return: ``func(expected_image_count, timeout=300.0)``. ``timeout`` is the
        maximum number of seconds to wait; when it elapses the test fails
        instead of polling forever (e.g. after a failed background upload).
    """
    async def func(expected_image_count, timeout=300.0):
        loop = asyncio.get_running_loop()
        deadline = loop.time() + timeout
        while True:
            resp = test_client.get('/admin/server_info', headers={'x-admin-token': TEST_ADMIN_TOKEN})
            info = resp.json()
            if info['image_count'] >= expected_image_count:
                break
            if loop.time() >= deadline:
                pytest.fail(f"Timed out after {timeout}s waiting for {expected_image_count} image(s); "
                            f"server reports {info['image_count']}.")
            await asyncio.sleep(0.2)
        assert info['index_queue_length'] == 0

    return func


@pytest.fixture
def anyio_backend():
    """Return ``'asyncio'`` as the backend name for anyio-based tests."""
    return 'asyncio'
16 changes: 8 additions & 8 deletions tests/api/integrate_test.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@
import asyncio
from pathlib import Path

import pytest
Expand All @@ -13,7 +12,7 @@


@pytest.mark.asyncio
async def test_integrate(test_client):
async def test_integrate(test_client, check_local_dir_empty, wait_for_background_task):
credentials = {'x-admin-token': TEST_ADMIN_TOKEN, 'x-access-token': TEST_ACCESS_TOKEN}
resp = test_client.get("/", headers=credentials)
assert resp.status_code == 200
Expand All @@ -32,13 +31,8 @@ async def test_integrate(test_client):

print('Waiting for images to be processed...')

while True:
resp = test_client.get('/admin/server_info', headers=credentials)
if resp.json()['image_count'] >= 7:
break
await asyncio.sleep(1)
await wait_for_background_task(sum(len(v) for v in test_images.values()))

assert resp.json()['index_queue_length'] == 0
resp = test_client.get('/search/text/hatsune+miku',
headers=credentials)
assert resp.status_code == 200
Expand Down Expand Up @@ -80,3 +74,9 @@ async def test_integrate(test_client):
resp = test_client.get("/search/text/cat", params={'categories': 'bsn'}, headers=credentials)
assert resp.status_code == 200
assert len(resp.json()['result']) == 0

# cleanup
for img_cls in test_images.keys():
for img_id in img_ids[img_cls]:
resp = test_client.delete(f"/admin/delete/{img_id}", headers=credentials)
assert resp.status_code == (404 if img_id == img_ids['bsn'][0] else 200)
68 changes: 67 additions & 1 deletion tests/api/test_upload.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,9 @@
import io
import random

from .conftest import TEST_ADMIN_TOKEN
import pytest

from .conftest import TEST_ADMIN_TOKEN, TEST_ACCESS_TOKEN


def test_upload_bad_img_file(test_client):
Expand All @@ -24,3 +26,67 @@ def test_upload_unsupported_types(test_client):
headers={'x-admin-token': TEST_ADMIN_TOKEN},
params={'local': True})
assert resp.status_code == 415


# Placeholder URLs passed for non-local uploads; the test only compares them for equality.
TEST_FAKE_URL = 'fake-url'
TEST_FAKE_THUMBNAIL_URL = 'fake-thumbnail-url'

# Cases for test_upload_auto_local_thumbnail. Each tuple is
# (add_trailing_bytes, params, expect_local_url, expect_thumbnail_mode):
#   add_trailing_bytes    - pad the image with 500KB of extra bytes before uploading
#   params                - query parameters sent to /admin/upload
#   expect_local_url      - True if the image url should point to /static/, False if it should be TEST_FAKE_URL
#   expect_thumbnail_mode - 'local' (thumbnail under /static/thumbnails/), 'fake' (TEST_FAKE_THUMBNAIL_URL)
#                           or 'none' (no thumbnail url)
TEST_PARAMS = [
(True, {'local': True}, True, 'local'),
(True, {'local': True, 'local_thumbnail': 'never'}, True, 'none'),
(False, {'local': True, 'local_thumbnail': 'always'}, True, 'local'),
(False, {'local': True}, True, 'none'),
(False, {'local': False, 'url': TEST_FAKE_URL, 'thumbnail_url': TEST_FAKE_THUMBNAIL_URL}, False, 'fake'),
(False, {'local': False, 'url': TEST_FAKE_URL, 'local_thumbnail': 'always'}, False, 'local'),
(False, {'local': False, 'url': TEST_FAKE_URL}, False, 'none'),
]


@pytest.mark.parametrize('add_trailing_bytes,params,expect_local_url,expect_thumbnail_mode', TEST_PARAMS)
@pytest.mark.asyncio
async def test_upload_auto_local_thumbnail(test_client, check_local_dir_empty, wait_for_background_task,
add_trailing_bytes, params, expect_local_url, expect_thumbnail_mode):
with open('tests/assets/test_images/bsn_0.jpg', 'rb') as f:
img_bytes = f.read()
# append 500KB to the image, to make it large enough to generate a thumbnail
if add_trailing_bytes:
img_bytes += bytearray(random.getrandbits(8) for _ in range(1024 * 500))
f_patched = io.BytesIO(img_bytes)
f_patched.name = 'bsn_0.jpg'
else:
f_patched = f
resp = test_client.post('/admin/upload',
files={'image_file': f_patched},
headers={'x-admin-token': TEST_ADMIN_TOKEN},
params=params)
assert resp.status_code == 200
id = resp.json()['image_id']
await wait_for_background_task(1)

query = test_client.get('/search/random', headers={'x-access-token': TEST_ACCESS_TOKEN})
assert query.status_code == 200
assert query.json()['result'][0]['img']['id'] == id

if expect_local_url:
assert query.json()['result'][0]['img']['url'].startswith(f'/static/{id}.')
img_request = test_client.get(query.json()['result'][0]['img']['url'])
assert img_request.status_code == 200
else:
assert query.json()['result'][0]['img']['url'] == TEST_FAKE_URL

match expect_thumbnail_mode:
case 'local':
assert query.json()['result'][0]['img']['thumbnail_url'] == f'/static/thumbnails/{id}.webp'

thumbnail_request = test_client.get(query.json()['result'][0]['img']['thumbnail_url'])
assert thumbnail_request.status_code == 200
# IDK why starlette doesn't return the correct content type, but it works on the browser anyway
# assert thumbnail_request.headers['Content-Type'] == 'image/webp'
case 'fake':
assert query.json()['result'][0]['img']['thumbnail_url'] == TEST_FAKE_THUMBNAIL_URL
case 'none':
assert query.json()['result'][0]['img']['thumbnail_url'] is None

# cleanup
resp = test_client.delete(f'/admin/delete/{id}', headers={'x-admin-token': TEST_ADMIN_TOKEN})
assert resp.status_code == 200

0 comments on commit 43d30ab

Please sign in to comment.