Skip to content

Commit

Permalink
Rename ImageData to MappedImage to have a better readability
Browse files Browse the repository at this point in the history
  • Loading branch information
hv0905 committed Jul 12, 2024
1 parent d693e03 commit 0d9ee76
Show file tree
Hide file tree
Showing 8 changed files with 73 additions and 72 deletions.
24 changes: 12 additions & 12 deletions app/Controllers/admin.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
DuplicateValidationResponse
from app.Models.api_response.base import NekoProtocol
from app.Models.errors import PointDuplicateError
from app.Models.img_data import ImageData
from app.Models.mapped_image import MappedImage
from app.Services.authentication import force_admin_token_verify
from app.Services.provider import ServiceProvider
from app.Services.vector_db_context import PointNotFoundError
Expand Down Expand Up @@ -131,17 +131,17 @@ async def upload_image(image_file: Annotated[UploadFile, File(description="The i
logger.warning("Invalid image file from upload request. id: {}", img_id)
raise HTTPException(422, "Cannot open the image file.") from ex

image_data = ImageData(id=img_id,
url=model.url,
thumbnail_url=model.thumbnail_url,
local=model.local,
categories=model.categories,
starred=model.starred,
comments=model.comments,
format=img_type,
index_date=datetime.now())

await services.upload_service.queue_upload_image(image_data, img_bytes, model.skip_ocr, model.local_thumbnail)
mapped_image = MappedImage(id=img_id,
url=model.url,
thumbnail_url=model.thumbnail_url,
local=model.local,
categories=model.categories,
starred=model.starred,
comments=model.comments,
format=img_type,
index_date=datetime.now())

await services.upload_service.queue_upload_image(mapped_image, img_bytes, model.skip_ocr, model.local_thumbnail)
return ImageUploadResponse(message="OK. Image added to upload queue.", image_id=img_id)


Expand Down
6 changes: 3 additions & 3 deletions app/Models/api_response/images_api_response.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
from pydantic import Field

from app.Models.api_response.base import NekoProtocol
from app.Models.img_data import ImageData
from app.Models.mapped_image import MappedImage


class ImageStatus(str, Enum):
Expand All @@ -16,10 +16,10 @@ class QueryByIdApiResponse(NekoProtocol):
"Warning: If NekoImageGallery is deployed in a cluster, "
"the `in_queue` might not be accurate since the index queue "
"is independent of each service instance.")
img: ImageData | None = Field(description="The mapped image data. Only available when `img_status = mapped`.")
img: MappedImage | None = Field(description="The mapped image data. Only available when `img_status = mapped`.")


class QueryImagesApiResponse(NekoProtocol):
images: list[ImageData] = Field(description="The list of images.")
images: list[MappedImage] = Field(description="The list of images.")
next_page_offset: str | None = Field(description="The offset ID for the next page query. "
"If there are no more images, this field will be null.")
2 changes: 1 addition & 1 deletion app/Models/img_data.py → app/Models/mapped_image.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
from pydantic import BaseModel, Field, ConfigDict


class ImageData(BaseModel):
class MappedImage(BaseModel):
model_config = ConfigDict(arbitrary_types_allowed=True, extra='ignore')

id: UUID = Field(description="The unique ID of the image. The ID is generated from the digest of the image.")
Expand Down
5 changes: 3 additions & 2 deletions app/Models/search_result.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,8 @@
from pydantic import BaseModel
from .img_data import ImageData

from .mapped_image import MappedImage


class SearchResult(BaseModel):
img: ImageData
img: MappedImage
score: float
10 changes: 5 additions & 5 deletions app/Services/index_service.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
from fastapi.concurrency import run_in_threadpool

from app.Models.errors import PointDuplicateError
from app.Models.img_data import ImageData
from app.Models.mapped_image import MappedImage
from app.Services.lifespan_service import LifespanService
from app.Services.ocr_services import OCRService
from app.Services.transformers_service import TransformersService
Expand All @@ -16,7 +16,7 @@ def __init__(self, ocr_service: OCRService, transformers_service: TransformersSe
self._transformers_service = transformers_service
self._db_context = db_context

def _prepare_image(self, image: Image.Image, image_data: ImageData, skip_ocr=False):
def _prepare_image(self, image: Image.Image, image_data: MappedImage, skip_ocr=False):
image_data.width = image.width
image_data.height = image.height
image_data.aspect_ratio = float(image.width) / image.height
Expand All @@ -34,12 +34,12 @@ def _prepare_image(self, image: Image.Image, image_data: ImageData, skip_ocr=Fal
image_data.ocr_text = None

# currently, here only need just a simple check
async def _is_point_duplicate(self, image_data: list[ImageData]) -> bool:
async def _is_point_duplicate(self, image_data: list[MappedImage]) -> bool:
image_id_list = [str(item.id) for item in image_data]
result = await self._db_context.validate_ids(image_id_list)
return len(result) != 0

async def index_image(self, image: Image.Image, image_data: ImageData, skip_ocr=False, skip_duplicate_check=False,
async def index_image(self, image: Image.Image, image_data: MappedImage, skip_ocr=False, skip_duplicate_check=False,
background=False):
if not skip_duplicate_check and (await self._is_point_duplicate([image_data])):
raise PointDuplicateError("The uploaded points are contained in the database!", image_data.id)
Expand All @@ -51,7 +51,7 @@ async def index_image(self, image: Image.Image, image_data: ImageData, skip_ocr=

await self._db_context.insertItems([image_data])

async def index_image_batch(self, image: list[Image.Image], image_data: list[ImageData],
async def index_image_batch(self, image: list[Image.Image], image_data: list[MappedImage],
skip_ocr=False, allow_overwrite=False):
if not allow_overwrite and (await self._is_point_duplicate(image_data)):
raise PointDuplicateError("The uploaded points are contained in the database!")
Expand Down
46 changes: 23 additions & 23 deletions app/Services/upload_service.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@

from app.Models.api_models.admin_query_params import UploadImageThumbnailMode
from app.Models.errors import PointDuplicateError
from app.Models.img_data import ImageData
from app.Models.mapped_image import MappedImage
from app.Services.index_service import IndexService
from app.Services.lifespan_service import LifespanService
from app.Services.storage import StorageService
Expand Down Expand Up @@ -46,44 +46,44 @@ async def _upload_worker(self):
if self._processed_count % 50 == 0:
gc.collect()

async def _upload_task(self, img_data: ImageData, img_bytes: bytes, skip_ocr: bool,
async def _upload_task(self, mapped_img: MappedImage, img_bytes: bytes, skip_ocr: bool,
thumbnail_mode: UploadImageThumbnailMode):
img = Image.open(BytesIO(img_bytes))
logger.info('Start indexing image {}. Local: {}. Size: {}', img_data.id, img_data.local, len(img_bytes))
file_name = f"{img_data.id}.{img_data.format}"
thumb_path = f"thumbnails/{img_data.id}.webp"
logger.info('Start indexing image {}. Local: {}. Size: {}', mapped_img.id, mapped_img.local, len(img_bytes))
file_name = f"{mapped_img.id}.{mapped_img.format}"
thumb_path = f"thumbnails/{mapped_img.id}.webp"
gen_thumb = thumbnail_mode == UploadImageThumbnailMode.ALWAYS or (
thumbnail_mode == UploadImageThumbnailMode.IF_NECESSARY and len(img_bytes) > 1024 * 500)

if img_data.local:
img_data.url = await self._storage_service.active_storage.url(file_name)
if mapped_img.local:
mapped_img.url = await self._storage_service.active_storage.url(file_name)
if gen_thumb:
img_data.thumbnail_url = await self._storage_service.active_storage.url(
f"thumbnails/{img_data.id}.webp")
img_data.local_thumbnail = True
mapped_img.thumbnail_url = await self._storage_service.active_storage.url(
f"thumbnails/{mapped_img.id}.webp")
mapped_img.local_thumbnail = True

await self._index_service.index_image(img, img_data, skip_ocr=skip_ocr, background=True)
logger.success("Image {} indexed.", img_data.id)
await self._index_service.index_image(img, mapped_img, skip_ocr=skip_ocr, background=True)
logger.success("Image {} indexed.", mapped_img.id)

if img_data.local:
logger.info("Start uploading image {} to local storage.", img_data.id)
if mapped_img.local:
logger.info("Start uploading image {} to local storage.", mapped_img.id)
await self._storage_service.active_storage.upload(img_bytes, file_name)
logger.success("Image {} uploaded to local storage.", img_data.id)
logger.success("Image {} uploaded to local storage.", mapped_img.id)
if gen_thumb:
logger.info("Start generate and upload thumbnail for {}.", img_data.id)
logger.info("Start generate and upload thumbnail for {}.", mapped_img.id)
img.thumbnail((256, 256), resample=Image.Resampling.LANCZOS)
img_byte_arr = BytesIO()
img.save(img_byte_arr, 'WebP', save_all=True)
await self._storage_service.active_storage.upload(img_byte_arr.getvalue(), thumb_path)
logger.success("Thumbnail for {} generated and uploaded!", img_data.id)
logger.success("Thumbnail for {} generated and uploaded!", mapped_img.id)

img.close()

async def queue_upload_image(self, img_data: ImageData, img_bytes: bytes, skip_ocr: bool,
async def queue_upload_image(self, mapped_img: MappedImage, img_bytes: bytes, skip_ocr: bool,
thumbnail_mode: UploadImageThumbnailMode):
self.uploading_ids.add(img_data.id)
await self._queue.put((img_data, img_bytes, skip_ocr, thumbnail_mode))
logger.success("Image {} added to upload queue. Queue Length: {} [+1]", img_data.id, self._queue.qsize())
self.uploading_ids.add(mapped_img.id)
await self._queue.put((mapped_img, img_bytes, skip_ocr, thumbnail_mode))
logger.success("Image {} added to upload queue. Queue Length: {} [+1]", mapped_img.id, self._queue.qsize())

async def assign_image_id(self, img_file: pathlib.Path | io.BytesIO | bytes):
img_id = generate_uuid(img_file)
Expand All @@ -94,9 +94,9 @@ async def assign_image_id(self, img_file: pathlib.Path | io.BytesIO | bytes):
img_id)
return img_id

async def sync_upload_image(self, img_data: ImageData, img_bytes: bytes, skip_ocr: bool,
async def sync_upload_image(self, mapped_img: MappedImage, img_bytes: bytes, skip_ocr: bool,
thumbnail_mode: UploadImageThumbnailMode):
await self._upload_task(img_data, img_bytes, skip_ocr, thumbnail_mode)
await self._upload_task(mapped_img, img_bytes, skip_ocr, thumbnail_mode)

Check warning on line 99 in app/Services/upload_service.py

View check run for this annotation

Codecov / codecov/patch

app/Services/upload_service.py#L99

Added line #L99 was not covered by tests

def get_queue_size(self):
return self._queue.qsize()
Expand Down
36 changes: 18 additions & 18 deletions app/Services/vector_db_context.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
from qdrant_client.models import RecommendStrategy

from app.Models.api_models.search_api_model import SearchModelEnum, SearchBasisEnum
from app.Models.img_data import ImageData
from app.Models.mapped_image import MappedImage
from app.Models.query_params import FilterParams
from app.Models.search_result import SearchResult
from app.Services.lifespan_service import LifespanService
Expand Down Expand Up @@ -50,7 +50,7 @@ async def on_load(self):
logger.warning("Collection not found. Initializing...")
await self.initialize_collection()

async def retrieve_by_id(self, image_id: str, with_vectors=False) -> ImageData:
async def retrieve_by_id(self, image_id: str, with_vectors=False) -> MappedImage:
"""
Retrieve an item from database by id. Will raise PointNotFoundError if the given ID doesn't exist.
:param image_id: The ID to retrieve.
Expand All @@ -65,9 +65,9 @@ async def retrieve_by_id(self, image_id: str, with_vectors=False) -> ImageData:
if len(result) != 1:
logger.error("Point not exist.")
raise PointNotFoundError(image_id)
return self._get_img_data_from_point(result[0])
return self._get_mapped_image_from_point(result[0])

async def retrieve_by_ids(self, image_id: list[str], with_vectors=False) -> list[ImageData]:
async def retrieve_by_ids(self, image_id: list[str], with_vectors=False) -> list[MappedImage]:
"""
Retrieve items from the database by IDs.
An exception is thrown if there are items in the IDs that do not exist in the database.
Expand All @@ -85,7 +85,7 @@ async def retrieve_by_ids(self, image_id: list[str], with_vectors=False) -> list
if len(missing_point_ids) > 0:
logger.error("{} points not exist.", len(missing_point_ids))
raise PointNotFoundError(str(missing_point_ids))
return self._get_img_data_from_points(result)
return self._get_mapped_image_from_point_batch(result)

Check warning on line 88 in app/Services/vector_db_context.py

View check run for this annotation

Codecov / codecov/patch

app/Services/vector_db_context.py#L88

Added line #L88 was not covered by tests

async def validate_ids(self, image_id: list[str]) -> list[str]:
"""
Expand Down Expand Up @@ -144,10 +144,10 @@ async def querySimilar(self,

return [self._get_search_result_from_scored_point(t) for t in result]

async def insertItems(self, items: list[ImageData]):
async def insertItems(self, items: list[MappedImage]):
logger.info("Inserting {} items into Qdrant...", len(items))

points = [self._get_point_from_img_data(t) for t in items]
points = [self._get_point_from_mapped_image(t) for t in items]

response = await self._client.upsert(collection_name=self.collection_name,
wait=True,
Expand All @@ -163,7 +163,7 @@ async def deleteItems(self, ids: list[str]):
)
logger.success("Delete completed! Status: {}", response.status)

async def updatePayload(self, new_data: ImageData):
async def updatePayload(self, new_data: MappedImage):
"""
Update the payload of an existing item in the database.
Warning: This method will not update the vector of the item.
Expand All @@ -175,7 +175,7 @@ async def updatePayload(self, new_data: ImageData):
wait=True)
logger.success("Update completed! Status: {}", response.status)

async def updateVectors(self, new_points: list[ImageData]):
async def updateVectors(self, new_points: list[MappedImage]):
resp = await self._client.update_vectors(collection_name=self.collection_name,
points=[self._get_vector_from_img_data(t) for t in new_points],
)
Expand All @@ -186,15 +186,15 @@ async def scroll_points(self,
count=50,
with_vectors=False,
filter_param: FilterParams | None = None,
) -> tuple[list[ImageData], str]:
) -> tuple[list[MappedImage], str]:
resp, next_id = await self._client.scroll(collection_name=self.collection_name,
limit=count,
offset=from_id,
with_vectors=with_vectors,
scroll_filter=self._get_filters_by_filter_param(filter_param)
)

return [self._get_img_data_from_point(t) for t in resp], next_id
return [self._get_mapped_image_from_point(t) for t in resp], next_id

async def get_counts(self, exact: bool) -> int:
resp = await self._client.count(collection_name=self.collection_name, exact=exact)
Expand All @@ -219,7 +219,7 @@ async def initialize_collection(self):
logger.success("Collection created!")

@classmethod
def _get_vector_from_img_data(cls, img_data: ImageData) -> models.PointVectors:
def _get_vector_from_img_data(cls, img_data: MappedImage) -> models.PointVectors:
vector = {}
if img_data.image_vector is not None:
vector[cls.IMG_VECTOR] = img_data.image_vector.tolist()
Expand All @@ -231,15 +231,15 @@ def _get_vector_from_img_data(cls, img_data: ImageData) -> models.PointVectors:
)

@classmethod
def _get_point_from_img_data(cls, img_data: ImageData) -> models.PointStruct:
def _get_point_from_mapped_image(cls, img_data: MappedImage) -> models.PointStruct:
return models.PointStruct(
id=str(img_data.id),
payload=img_data.payload,
vector=cls._get_vector_from_img_data(img_data).vector
)

def _get_img_data_from_point(self, point: AVAILABLE_POINT_TYPES) -> ImageData:
return (ImageData
def _get_mapped_image_from_point(self, point: AVAILABLE_POINT_TYPES) -> MappedImage:
return (MappedImage
.from_payload(point.id,
point.payload,
image_vector=numpy.array(point.vector[self.IMG_VECTOR], dtype=numpy.float32)
Expand All @@ -248,11 +248,11 @@ def _get_img_data_from_point(self, point: AVAILABLE_POINT_TYPES) -> ImageData:
if point.vector and self.TEXT_VECTOR in point.vector else None
))

def _get_img_data_from_points(self, points: list[AVAILABLE_POINT_TYPES]) -> list[ImageData]:
return [self._get_img_data_from_point(t) for t in points]
def _get_mapped_image_from_point_batch(self, points: list[AVAILABLE_POINT_TYPES]) -> list[MappedImage]:
return [self._get_mapped_image_from_point(t) for t in points]

Check warning on line 252 in app/Services/vector_db_context.py

View check run for this annotation

Codecov / codecov/patch

app/Services/vector_db_context.py#L252

Added line #L252 was not covered by tests

def _get_search_result_from_scored_point(self, point: models.ScoredPoint) -> SearchResult:
return SearchResult(img=self._get_img_data_from_point(point), score=point.score)
return SearchResult(img=self._get_mapped_image_from_point(point), score=point.score)

@classmethod
def vector_name_for_basis(cls, basis: SearchBasisEnum) -> str:
Expand Down
16 changes: 8 additions & 8 deletions scripts/local_indexing.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@

from app.Models.api_models.admin_query_params import UploadImageThumbnailMode
from app.Models.errors import PointDuplicateError
from app.Models.img_data import ImageData
from app.Models.mapped_image import MappedImage
from app.Services.provider import ServiceProvider
from app.util.local_file_utility import glob_local_files

Expand All @@ -18,13 +18,13 @@
async def index_task(file_path: Path, categories: list[str], starred: bool, thumbnail_mode: UploadImageThumbnailMode):
try:
img_id = await services.upload_service.assign_image_id(file_path)
image_data = ImageData(id=img_id,
local=True,
categories=categories,
starred=starred,
format=file_path.suffix[1:], # remove the dot
index_date=datetime.now())
await services.upload_service.sync_upload_image(image_data, file_path.read_bytes(), skip_ocr=False,
mapped_image = MappedImage(id=img_id,
local=True,
categories=categories,
starred=starred,
format=file_path.suffix[1:], # remove the dot
index_date=datetime.now())
await services.upload_service.sync_upload_image(mapped_image, file_path.read_bytes(), skip_ocr=False,
thumbnail_mode=thumbnail_mode)
except PointDuplicateError as ex:
logger.warning("Image {} already exists in the database", file_path)
Expand Down

0 comments on commit 0d9ee76

Please sign in to comment.