diff --git a/app/Controllers/admin.py b/app/Controllers/admin.py index cd3995b..2e9956f 100644 --- a/app/Controllers/admin.py +++ b/app/Controllers/admin.py @@ -6,8 +6,8 @@ from app.Models.admin_api_model import ImageOptUpdateModel from app.Models.api_response.base import NekoProtocol -from app.Services import db_context from app.Services.authentication import force_admin_token_verify +from app.Services.provider import db_context from app.Services.vector_db_context import PointNotFoundError from app.config import config from app.util import directories diff --git a/app/Controllers/search.py b/app/Controllers/search.py index fa7904f..00fc860 100644 --- a/app/Controllers/search.py +++ b/app/Controllers/search.py @@ -11,9 +11,8 @@ from app.Models.api_response.search_api_response import SearchApiResponse from app.Models.query_params import SearchPagingParams, FilterParams from app.Models.search_result import SearchResult -from app.Services import db_context -from app.Services import transformers_service from app.Services.authentication import force_access_token_verify +from app.Services.provider import db_context, transformers_service from app.config import config from app.util.calculate_vectors_cosine import calculate_vectors_cosine diff --git a/app/Models/img_data.py b/app/Models/img_data.py index 4e6b9f8..f82ba47 100644 --- a/app/Models/img_data.py +++ b/app/Models/img_data.py @@ -18,6 +18,7 @@ class ImageData(BaseModel): height: Optional[int] = None aspect_ratio: Optional[float] = None starred: Optional[bool] = False + categories: Optional[list[str]] = [] local: Optional[bool] = False @computed_field() @@ -27,6 +28,7 @@ def ocr_text_lower(self) -> str | None: return None return self.ocr_text.lower() + @property def payload(self): result = self.model_dump(exclude={'id', 'index_date'}) diff --git a/app/Services/__init__.py b/app/Services/__init__.py index 9bfd075..e69de29 100644 --- a/app/Services/__init__.py +++ b/app/Services/__init__.py @@ -1,28 +0,0 @@ -from 
class IndexService:
    """Fill in the derived fields of ``ImageData`` records and store them in the vector database.

    The service combines three collaborators handed in at construction time:
    an OCR service (``ocr_interface``), a transformers service
    (``get_image_vector`` / ``get_bert_vector``) and a database context
    (``insertItems``).
    """

    def __init__(self, ocr_service: OCRService, transformers_service: TransformersService, db_context: VectorDbContext):
        self._ocr_service = ocr_service
        self._transformers_service = transformers_service
        self._db_context = db_context

    def _prepare_image(self, image: Image.Image, image_data: ImageData, skip_ocr=False):
        """Populate size, aspect ratio, image vector and (optionally) OCR fields of *image_data* in place.

        :param image: the opened image to analyse; it is converted to RGB first when needed.
        :param image_data: the record to fill in; it is mutated, nothing is returned.
        :param skip_ocr: when true, OCR is not run even if ``config.ocr_search.enable`` is set.
        """
        image_data.width = image.width
        image_data.height = image.height
        image_data.aspect_ratio = float(image.width) / image.height

        if image.mode != 'RGB':
            image = image.convert('RGB')  # to reduce convert in next steps
        image_data.image_vector = self._transformers_service.get_image_vector(image)
        if not skip_ocr and config.ocr_search.enable:
            image_data.ocr_text = self._ocr_service.ocr_interface(image)
            if image_data.ocr_text:
                image_data.text_contain_vector = self._transformers_service.get_bert_vector(image_data.ocr_text)
            else:
                # Empty (or missing) OCR output is stored as "no text" instead of being embedded.
                image_data.ocr_text = None

    async def index_image(self, image: Image.Image, image_data: ImageData, skip_ocr=False):
        """Prepare a single image record and insert it into the database.

        :param image: the opened image to analyse.
        :param image_data: the record to fill in and insert.
        :param skip_ocr: forwarded to the preparation step; true disables OCR for this image.
        """
        self._prepare_image(image, image_data, skip_ocr)
        await self._db_context.insertItems([image_data])

    async def index_image_batch(self, image: list[Image.Image], image_data: list[ImageData], skip_ocr=False):
        """Prepare several image records and insert them with one database call.

        ``image[i]`` is analysed to fill in ``image_data[i]``.

        :param image: the opened images.
        :param image_data: the records to fill in and insert, in the same order as *image*.
        :param skip_ocr: forwarded to the preparation step; true disables OCR for the whole batch.
        :raises ValueError: if the two lists differ in length. The check runs before any
            record is touched, so a mismatched batch is rejected without partial mutation.
        """
        if len(image) != len(image_data):
            raise ValueError(
                f"image and image_data must have the same length, got {len(image)} and {len(image_data)}")
        if not image_data:
            # Nothing to prepare or insert.
            return
        for single_image, single_data in zip(image, image_data):
            self._prepare_image(single_image, single_data, skip_ocr)
        await self._db_context.insertItems(image_data)
NotImplementedError(f"OCR module {config.ocr_search.ocr_module} not implemented.") +else: + from .ocr_services import DisabledOCRService + + ocr_service = DisabledOCRService() + +index_service = IndexService(ocr_service, transformers_service, db_context) diff --git a/app/Services/transformers_service.py b/app/Services/transformers_service.py index edd104a..e745832 100644 --- a/app/Services/transformers_service.py +++ b/app/Services/transformers_service.py @@ -18,12 +18,12 @@ def __init__(self): self.device = "cuda" if torch.cuda.is_available() else "cpu" logger.info("Using device: {}; CLIP Model: {}, BERT Model: {}", self.device, config.clip.model, config.ocr_search.bert_model) - self.clip_model = CLIPModel.from_pretrained(config.clip.model).to(self.device) - self.clip_processor = CLIPProcessor.from_pretrained(config.clip.model) + self._clip_model = CLIPModel.from_pretrained(config.clip.model).to(self.device) + self._clip_processor = CLIPProcessor.from_pretrained(config.clip.model) logger.success("CLIP Model loaded successfully") if config.ocr_search.enable: - self.bert_model = BertModel.from_pretrained(config.ocr_search.bert_model).to(self.device) - self.bert_tokenizer = BertTokenizer.from_pretrained(config.ocr_search.bert_model) + self._bert_model = BertModel.from_pretrained(config.ocr_search.bert_model).to(self.device) + self._bert_tokenizer = BertTokenizer.from_pretrained(config.ocr_search.bert_model) logger.success("BERT Model loaded successfully") else: logger.info("OCR search is disabled. 
Skipping OCR and BERT model loading.") @@ -34,11 +34,10 @@ def get_image_vector(self, image: Image.Image) -> ndarray: image = image.convert("RGB") logger.info("Processing image...") start_time = time() - inputs = self.clip_processor(images=image, return_tensors="pt").to(self.device) + inputs = self._clip_processor(images=image, return_tensors="pt").to(self.device) logger.success("Image processed, now inferencing with CLIP model...") - outputs: FloatTensor = self.clip_model.get_image_features(**inputs) + outputs: FloatTensor = self._clip_model.get_image_features(**inputs) logger.success("Inference done. Time elapsed: {:.2f}s", time() - start_time) - logger.info("Norm: {}", outputs.norm(dim=-1).item()) outputs /= outputs.norm(dim=-1, keepdim=True) return outputs.numpy(force=True).reshape(-1) @@ -46,11 +45,10 @@ def get_image_vector(self, image: Image.Image) -> ndarray: def get_text_vector(self, text: str) -> ndarray: logger.info("Processing text...") start_time = time() - inputs = self.clip_processor(text=text, return_tensors="pt").to(self.device) + inputs = self._clip_processor(text=text, return_tensors="pt").to(self.device) logger.success("Text processed, now inferencing with CLIP model...") - outputs: FloatTensor = self.clip_model.get_text_features(**inputs) + outputs: FloatTensor = self._clip_model.get_text_features(**inputs) logger.success("Inference done. 
Time elapsed: {:.2f}s", time() - start_time) - logger.info("Norm: {}", outputs.norm(dim=-1).item()) outputs /= outputs.norm(dim=-1, keepdim=True) return outputs.numpy(force=True).reshape(-1) @@ -58,8 +56,8 @@ def get_text_vector(self, text: str) -> ndarray: def get_bert_vector(self, text: str) -> ndarray: start_time = time() logger.info("Inferencing with BERT model...") - inputs = self.bert_tokenizer(text.strip().lower(), return_tensors="pt").to(self.device) - outputs = self.bert_model(**inputs) + inputs = self._bert_tokenizer(text.strip().lower(), return_tensors="pt").to(self.device) + outputs = self._bert_model(**inputs) vector = outputs.last_hidden_state.mean(dim=1).squeeze() logger.success("BERT inference done. Time elapsed: {:.2f}s", time() - start_time) return vector.cpu().numpy() diff --git a/requirements.dev.txt b/requirements.dev.txt new file mode 100644 index 0000000..04bb894 --- /dev/null +++ b/requirements.dev.txt @@ -0,0 +1,4 @@ +# Requirements for development and testing + +pytest +pylint \ No newline at end of file diff --git a/scripts/__init__.py b/scripts/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/scripts/db_migrations.py b/scripts/db_migrations.py index 86805a9..919a823 100644 --- a/scripts/db_migrations.py +++ b/scripts/db_migrations.py @@ -1,6 +1,6 @@ from loguru import logger -from app.Services import db_context, transformers_service +from app.Services.provider import db_context, transformers_service CURRENT_VERSION = 2 diff --git a/scripts/local_create_thumbnail.py b/scripts/local_create_thumbnail.py index f9f2a42..04c768f 100644 --- a/scripts/local_create_thumbnail.py +++ b/scripts/local_create_thumbnail.py @@ -4,8 +4,9 @@ from PIL import Image from loguru import logger -from app.Services import db_context +from app.Services.provider import db_context from app.config import config +from .local_utility import gather_valid_files async def main(): @@ -14,7 +15,7 @@ async def main(): if not static_thumb_path.exists(): 
async def copy_and_index(file_path: Path) -> bool:
    """Index one local image file and copy it into the static file directory.

    A new ``ImageData`` record is created with a fresh UUID, a
    ``/static/<uuid><ext>`` URL, the current time as index date and
    ``local=True``; the remaining fields are filled in by
    ``index_service.index_image``. After indexing succeeds the source file is
    copied to ``config.static_file.path`` under ``<uuid><ext>``.

    :param file_path: path of the image file to index.
    :return: ``True`` when the image was indexed and copied, ``False`` when it
        was skipped because it could not be opened or processed (the reason is
        logged). Callers that ignore the return value keep working.
    """
    try:
        img = Image.open(file_path)
    except (PIL.UnidentifiedImageError, OSError) as e:
        # An unreadable file should only skip this file, not abort the whole run.
        logger.error("Error when opening image {}: {}", file_path, e)
        return False
    image_id = uuid4()
    img_ext = file_path.suffix
    imgdata = ImageData(id=image_id,
                        url=f'/static/{image_id}{img_ext}',
                        index_date=datetime.now(),
                        local=True)
    # Close the underlying file handle whether or not indexing succeeds;
    # otherwise every processed file leaves a descriptor open.
    with img:
        try:
            await index_service.index_image(img, imgdata)
        except Exception as e:
            logger.error("Error when processing image {}: {}", file_path, e)
            return False
    # copy to static
    # NOTE(review): the record is inserted before the copy, so a failing copy2
    # presumably leaves a record whose static file is missing - confirm whether
    # a rollback is wanted here.
    copy2(file_path, Path(config.static_file.path) / f'{image_id}{img_ext}')
    return True
# File extensions (lower-case, with the leading dot) accepted by the local scripts.
_SUPPORTED_SUFFIXES = frozenset({'.jpg', '.png', '.jpeg', '.jfif', '.webp', '.gif'})


def gather_valid_files(root: Path, pattern: str = '**/*.*'):
    """Yield the files below *root* that match *pattern* and have a supported image suffix.

    The suffix is compared case-insensitively, so ``photo.JPG`` is accepted
    just like ``photo.jpg``. Directories that happen to match the pattern are
    skipped silently; files with any other suffix are skipped with a warning.

    :param root: directory to search.
    :param pattern: glob pattern passed to ``Path.glob``; the default searches recursively.
    :return: a generator of ``Path`` objects, one per accepted file.
    """
    for item in root.glob(pattern):
        if not item.is_file():
            # '*.*' also matches directory names containing a dot.
            continue
        if item.suffix.lower() in _SUPPORTED_SUFFIXES:
            yield item
        else:
            logger.warning("Unsupported file type: {}. Skip...", item.suffix)