Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Basic implementation of storage service #14

Merged
merged 8 commits into from
Apr 26, 2024
Merged
Show file tree
Hide file tree
Changes from 6 commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
30 changes: 13 additions & 17 deletions app/Controllers/admin.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
from pathlib import PurePath
from typing import Annotated
from uuid import UUID

Expand All @@ -8,10 +9,9 @@
from app.Models.api_response.admin_api_response import ServerInfoResponse
from app.Models.api_response.base import NekoProtocol
from app.Services.authentication import force_admin_token_verify
from app.Services.provider import db_context
from app.Services.provider import db_context, storage_service
from app.Services.vector_db_context import PointNotFoundError
from app.config import config
from app.util import directories

admin_router = APIRouter(dependencies=[Depends(force_admin_token_verify)], tags=["Admin"])

Expand All @@ -23,28 +23,24 @@ async def delete_image(
image_id: Annotated[UUID, params.Path(description="The id of the image you want to delete.")]) -> NekoProtocol:
try:
point = await db_context.retrieve_by_id(str(image_id))
except PointNotFoundError:
raise HTTPException(404, "Cannot find the image with the given ID.")

except PointNotFoundError as ex:
raise HTTPException(404, "Cannot find the image with the given ID.") from ex
await db_context.deleteItems([str(point.id)])
logger.success("Image {} deleted from database.", point.id)

if point.url.startswith('/') and config.static_file.enable: # local image
image_files = list(directories.static_dir.glob(f"{point.id}.*"))
if point.local and config.storage.method.enabled: # local image
image_files = [itm[0] async for itm in storage_service.active_storage.list_files("", f"{point.id}.*")]
assert len(image_files) <= 1

if not image_files:
logger.warning("Image {} is a local image but not found in static folder.", point.id)
else:
directories.deleted_dir.mkdir(parents=True, exist_ok=True)

image_files[0].rename(directories.deleted_dir / image_files[0].name)
logger.success("Local image {} removed.", image_files[0].name)

await storage_service.active_storage.move(image_files[0], f"_deleted/{image_files[0].name}")
logger.success("Image {} removed.", image_files[0].name)
if point.thumbnail_url is not None:
thumbnail_file = directories.thumbnails_dir / f"{point.id}.webp"
if thumbnail_file.is_file():
thumbnail_file.unlink()
thumbnail_file = PurePath(f"thumbnails/{point.id}.webp")
if await storage_service.active_storage.is_exist(thumbnail_file):
await storage_service.active_storage.delete(thumbnail_file)
logger.success("Thumbnail {} removed.", thumbnail_file.name)
else:
logger.warning("Thumbnail {} not found.", thumbnail_file.name)
Expand All @@ -59,8 +55,8 @@ async def update_image(image_id: Annotated[UUID, params.Path(description="The id
raise HTTPException(422, "Nothing to update.")
try:
point = await db_context.retrieve_by_id(str(image_id))
except PointNotFoundError:
raise HTTPException(404, "Cannot find the image with the given ID.")
except PointNotFoundError as ex:
raise HTTPException(404, "Cannot find the image with the given ID.") from ex

if model.starred is not None:
point.starred = model.starred
Expand Down
11 changes: 10 additions & 1 deletion app/Controllers/search.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,14 +13,23 @@
from app.Models.query_params import SearchPagingParams, FilterParams
from app.Models.search_result import SearchResult
from app.Services.authentication import force_access_token_verify
from app.Services.provider import db_context, transformers_service
from app.Services.provider import db_context, transformers_service, storage_service
from app.config import config
from app.util.calculate_vectors_cosine import calculate_vectors_cosine

searchRouter = APIRouter(dependencies=([Depends(force_access_token_verify)] if config.access_protected else None),
tags=["Search"])


async def result_postprocessing(resp: SearchApiResponse) -> SearchApiResponse:
    """
    Rewrite the URLs of locally stored images in a search response.
    Entries whose image is not flagged as local, or seen while storage is not enabled, are left untouched.
    :param resp: The search response whose result entries are updated in place
    :return: The same response object that was passed in
    """
    for entry in resp.result:
        image = entry.img
        # Guard clause: only local images handled by an enabled storage method are rewritten.
        if not (image.local and config.storage.method.enabled):
            continue
        backend = storage_service.active_storage
        image.url = await backend.get_image_url(image)
        if image.thumbnail_url is not None:
            image.thumbnail_url = await backend.get_url(image.thumbnail_url)
    return resp


class SearchBasisParams:
def __init__(self,
basis: Annotated[SearchBasisEnum, Query(
Expand Down
1 change: 1 addition & 0 deletions app/Models/img_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ class ImageData(BaseModel):
starred: Optional[bool] = False
categories: Optional[list[str]] = []
local: Optional[bool] = False
type: Optional[str] = None # required for s3 local storage

@computed_field()
@property
Expand Down
5 changes: 5 additions & 0 deletions app/Services/provider.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,6 @@
from loguru import logger
from .index_service import IndexService
from .storage import StorageService
from .transformers_service import TransformersService
from .vector_db_context import VectorDbContext
from ..config import config, environment
Expand Down Expand Up @@ -27,5 +29,8 @@
from .ocr_services import DisabledOCRService

ocr_service = DisabledOCRService()
logger.info(f"OCR service '{type(ocr_service).__name__}' initialized.")

index_service = IndexService(ocr_service, transformers_service, db_context)
storage_service = StorageService()
logger.info(f"Storage service '{type(storage_service.active_storage).__name__}' initialized.")
20 changes: 20 additions & 0 deletions app/Services/storage/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
from app.Services.storage.local_storage import LocalStorage
from app.Services.storage.s3_compatible_storage import S3Storage
from app.config import config, StorageMode


class StorageService:
    """
    Holds the storage backend selected by ``config.storage.method``.

    ``local_storage`` is always created; ``active_storage`` is the selected backend,
    or None when storage is disabled.
    """

    def __init__(self):
        self.local_storage = LocalStorage()
        self.active_storage = None

        selected = config.storage.method
        if selected == StorageMode.LOCAL:
            self.active_storage = self.local_storage
        elif selected == StorageMode.S3:
            self.active_storage = S3Storage()
        elif selected == StorageMode.DISABLED:
            # Nothing to activate and nothing to check.
            return
        else:
            raise NotImplementedError(f"Storage method {config.storage.method} not implemented. "
                                      f"Available methods: local, s3")

        # Only reached for an activated backend (LOCAL or S3).
        self.active_storage.pre_check()
153 changes: 153 additions & 0 deletions app/Services/storage/base.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,153 @@
import abc
import os
from typing import TypeVar, Generic, TypeAlias, Optional, AsyncGenerator

from app.Models.img_data import ImageData

FileMetaDataT = TypeVar('FileMetaDataT')

PathLikeType: TypeAlias = str | os.PathLike
LocalFilePathType: TypeAlias = PathLikeType | bytes
RemoteFilePathType: TypeAlias = PathLikeType
LocalFileMetaDataType: TypeAlias = FileMetaDataT
RemoteFileMetaDataType: TypeAlias = FileMetaDataT


class BaseStorage(abc.ABC, Generic[FileMetaDataT]):
def __init__(self):
self.static_dir: os.PathLike
self.thumbnails_dir: os.PathLike
self.deleted_dir: os.PathLike
self.file_metadata: FileMetaDataT

@abc.abstractmethod
def pre_check(self):
raise NotImplementedError

@abc.abstractmethod
async def is_exist(self,
remote_file: RemoteFilePathType) -> bool:
"""
Check if a remote_file exists.
:param remote_file: The file path relative to static_dir
:return: True if the file exists, False otherwise
"""
raise NotImplementedError

@abc.abstractmethod
async def size(self,
remote_file: RemoteFilePathType) -> int:
"""
Get the size of a file in static_dir
:param remote_file: The file path relative to static_dir
:return: file's size
"""
raise NotImplementedError

@abc.abstractmethod
async def url(self,
remote_file: RemoteFilePathType) -> str:
"""
Get the original URL of a file in static_dir.
This url will be placed in the payload field of the qdrant.
:param remote_file: The file path relative to static_dir
:return: file's "original URL"
"""
raise NotImplementedError

@abc.abstractmethod
async def presign_url(self,
remote_file: RemoteFilePathType,
expire_second: int = 3600) -> str:
"""
Get the presign URL of a file in static_dir.
:param remote_file: The file path relative to static_dir
:param expire_second: Valid time for presign url
:return: file's "presign URL"
"""
raise NotImplementedError

@abc.abstractmethod
async def fetch(self,
remote_file: RemoteFilePathType) -> bytes:
"""
Fetch a file from static_dir
:param remote_file: The file path relative to static_dir
:return: file's content
"""
raise NotImplementedError

@abc.abstractmethod
async def upload(self,
local_file: "LocalFilePathType",
remote_file: RemoteFilePathType) -> None:
"""
Move a local picture file to the static_dir.
:param local_file: The absolute path to the local file or bytes.
:param remote_file: The file path relative to static_dir
"""
raise NotImplementedError

@abc.abstractmethod
async def copy(self,
old_remote_file: RemoteFilePathType,
new_remote_file: RemoteFilePathType) -> None:
"""
Copy a file in static_dir.
:param old_remote_file: The file path relative to static_dir
:param new_remote_file: The file path relative to static_dir
"""
raise NotImplementedError

@abc.abstractmethod
async def move(self,
old_remote_file: RemoteFilePathType,
new_remote_file: RemoteFilePathType) -> None:
"""
Move a file in static_dir.
:param old_remote_file: The file path relative to static_dir
:param new_remote_file: The file path relative to static_dir
"""
raise NotImplementedError

@abc.abstractmethod
async def delete(self,
remote_file: RemoteFilePathType) -> None:
"""
Move a file in static_dir.
:param remote_file: The file path relative to static_dir
"""
raise NotImplementedError

@abc.abstractmethod
async def list_files(self,
path: RemoteFilePathType,
pattern: Optional[str] = "*",
batch_max_files: Optional[int] = None,
valid_extensions: Optional[set[str]] = None) \
-> AsyncGenerator[list[RemoteFilePathType], None]:
"""
Asynchronously generates a list of files from a given base directory path that match a specified pattern and set
of file extensions.

:param path: The relative base directory path from which relative to static_dir to start listing files.
:param pattern: A glob pattern to filter files based on their names. Defaults to '*' which selects all files.
:param batch_max_files: The maximum number of files to return. If None, all matching files are returned.
:param valid_extensions: An extra set of file extensions to include (e.g., {".jpg", ".png"}).
If None, files are not filtered by extension.
:return: An asynchronous generator yielding lists of RemoteFilePathType objects representing the matching files.

Usage example:
async for batch in list_files(base_path=".", pattern="*", max_files=100, valid_extensions={".jpg", ".png"}):
print(f"Batch: {batch}")
"""
raise NotImplementedError

@abc.abstractmethod
async def update_metadata(self,
local_file_metadata: LocalFileMetaDataType,
remote_file_metadata: RemoteFileMetaDataType) -> None:
raise NotImplementedError

async def get_image_url(self, img: ImageData) -> str:
return img.url
30 changes: 30 additions & 0 deletions app/Services/storage/exception.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,30 @@
class StorageExtension(Exception):
    """Common base class of every storage error defined in this module."""
    # NOTE(review): the name reads like a typo of "StorageException"; kept as-is because
    # other modules may import it under this name.


class LocalFileNotFoundError(StorageExtension):
    """Storage error concerning a local file that was not found."""


class LocalFileExistsError(StorageExtension):
    """Storage error concerning a local file that already exists."""


class LocalFilePermissionError(StorageExtension):
    """Storage error concerning permissions on a local file."""


class RemoteFileNotFoundError(StorageExtension):
    """Storage error concerning a remote file that was not found."""


class RemoteFileExistsError(StorageExtension):
    """Storage error concerning a remote file that already exists."""


class RemoteFilePermissionError(StorageExtension):
    """Storage error concerning permissions on a remote file."""


class RemoteConnectError(StorageExtension):
    """Storage error concerning the connection to the remote side."""
Loading
Loading