Refactor indexing to a standalone service and decouple it from local directories #10

Merged · 9 commits · Dec 29, 2023
2 changes: 1 addition & 1 deletion app/Controllers/admin.py
@@ -6,8 +6,8 @@

from app.Models.admin_api_model import ImageOptUpdateModel
from app.Models.api_response.base import NekoProtocol
from app.Services import db_context
from app.Services.authentication import force_admin_token_verify
from app.Services.provider import db_context
from app.Services.vector_db_context import PointNotFoundError
from app.config import config
from app.util import directories
3 changes: 1 addition & 2 deletions app/Controllers/search.py
@@ -11,9 +11,8 @@
from app.Models.api_response.search_api_response import SearchApiResponse
from app.Models.query_params import SearchPagingParams, FilterParams
from app.Models.search_result import SearchResult
from app.Services import db_context
from app.Services import transformers_service
from app.Services.authentication import force_access_token_verify
from app.Services.provider import db_context, transformers_service
from app.config import config
from app.util.calculate_vectors_cosine import calculate_vectors_cosine

2 changes: 2 additions & 0 deletions app/Models/img_data.py
@@ -18,6 +18,7 @@ class ImageData(BaseModel):
    height: Optional[int] = None
    aspect_ratio: Optional[float] = None
    starred: Optional[bool] = False
    categories: Optional[list[str]] = []
    local: Optional[bool] = False

    @computed_field()
@@ -27,6 +28,7 @@ def ocr_text_lower(self) -> str | None:
            return None
        return self.ocr_text.lower()


    @property
    def payload(self):
        result = self.model_dump(exclude={'id', 'index_date'})
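The only functional change to the model is the new optional categories field. A minimal construction sketch (field values are illustrative, and it assumes the vector fields are optional and get populated later by the indexing code):

# Illustrative only: the values are made up; vector fields are assumed optional here.
from datetime import datetime
from uuid import uuid4

from app.Models.img_data import ImageData

image_id = uuid4()
data = ImageData(id=image_id,
                 url=f'/static/{image_id}.jpg',
                 index_date=datetime.now(),
                 categories=['artwork', 'screenshot'],  # new optional field
                 local=True)
# `payload` drops `id` and `index_date`, so `categories` is stored in the
# vector database together with the other searchable fields.
print(data.payload)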
28 changes: 0 additions & 28 deletions app/Services/__init__.py
@@ -1,28 +0,0 @@
from .transformers_service import TransformersService
from .vector_db_context import VectorDbContext
from ..config import config, environment

transformers_service = TransformersService()
db_context = VectorDbContext()
ocr_service = None

if environment.local_indexing:
    match config.ocr_search.ocr_module:
        case "easyocr":
            from .ocr_services import EasyOCRService

            ocr_service = EasyOCRService()
        case "easypaddleocr":
            from .ocr_services import EasyPaddleOCRService

            ocr_service = EasyPaddleOCRService()
        case "paddleocr":
            from .ocr_services import PaddleOCRService

            ocr_service = PaddleOCRService()
        case _:
            raise NotImplementedError(f"OCR module {config.ocr_search.ocr_module} not implemented.")
else:
    from .ocr_services import DisabledOCRService

    ocr_service = DisabledOCRService()
4 changes: 1 addition & 3 deletions app/Services/authentication.py
@@ -7,9 +7,7 @@


def verify_access_token(token: str | None) -> bool:
    if not config.access_protected:
        return True
    return token is not None and token == config.access_token
    return (not config.access_protected) or (token is not None and token == config.access_token)


def permissive_access_token_verify(
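The rewritten check is behaviourally equivalent: access is always granted when protection is disabled, otherwise the token must be present and match. Since this PR also introduces requirements.dev.txt with pytest, a test along these lines could pin that behaviour down (a sketch only; it assumes the config object tolerates plain attribute patching):

# Sketch of a pytest for verify_access_token; assumes `config` attributes
# can be monkeypatched directly, which may differ in the real settings object.
from app.Services.authentication import verify_access_token
from app.config import config


def test_verify_access_token(monkeypatch):
    monkeypatch.setattr(config, "access_protected", False)
    assert verify_access_token(None)         # protection off: always allowed

    monkeypatch.setattr(config, "access_protected", True)
    monkeypatch.setattr(config, "access_token", "secret")
    assert verify_access_token("secret")     # matching token accepted
    assert not verify_access_token("wrong")  # mismatch rejected
    assert not verify_access_token(None)     # missing token rejected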
38 changes: 38 additions & 0 deletions app/Services/index_service.py
@@ -0,0 +1,38 @@
from PIL import Image

from app.Models.img_data import ImageData
from app.Services.ocr_services import OCRService
from app.Services.transformers_service import TransformersService
from app.Services.vector_db_context import VectorDbContext
from app.config import config


class IndexService:
    def __init__(self, ocr_service: OCRService, transformers_service: TransformersService, db_context: VectorDbContext):
        self._ocr_service = ocr_service
        self._transformers_service = transformers_service
        self._db_context = db_context

    def _prepare_image(self, image: Image.Image, image_data: ImageData, skip_ocr=False):
        image_data.width = image.width
        image_data.height = image.height
        image_data.aspect_ratio = float(image.width) / image.height

        if image.mode != 'RGB':
            image = image.convert('RGB')  # convert once here to avoid repeated conversions in later steps
        image_data.image_vector = self._transformers_service.get_image_vector(image)
        if not skip_ocr and config.ocr_search.enable:
            image_data.ocr_text = self._ocr_service.ocr_interface(image)
            if image_data.ocr_text != "":
                image_data.text_contain_vector = self._transformers_service.get_bert_vector(image_data.ocr_text)
            else:
                image_data.ocr_text = None

    async def index_image(self, image: Image.Image, image_data: ImageData, skip_ocr=False):
        self._prepare_image(image, image_data, skip_ocr)
        await self._db_context.insertItems([image_data])

    async def index_image_batch(self, image: list[Image.Image], image_data: list[ImageData], skip_ocr=False):
        for i in range(len(image)):
            self._prepare_image(image[i], image_data[i], skip_ocr)
        await self._db_context.insertItems(image_data)
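IndexService now owns the whole pipeline: it fills in the image dimensions, computes the CLIP vector, optionally runs OCR plus the BERT text vector, and writes the result through VectorDbContext. A rough usage sketch (the file name is hypothetical, and it assumes ImageData accepts a partially filled model, as in the reworked scripts/local_indexing.py below):

# Sketch: index a single local image through the new service.
import asyncio
from datetime import datetime
from uuid import uuid4

from PIL import Image

from app.Models.img_data import ImageData
from app.Services.provider import index_service  # wired up in provider.py below


async def demo():
    image_id = uuid4()
    img = Image.open('example.jpg')  # hypothetical input file
    data = ImageData(id=image_id,
                     url=f'/static/{image_id}.jpg',
                     index_date=datetime.now(),
                     local=True)
    # Width/height, vectors and OCR text are filled in by the service before insert.
    await index_service.index_image(img, data)


asyncio.run(demo())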
31 changes: 31 additions & 0 deletions app/Services/provider.py
@@ -0,0 +1,31 @@
from .index_service import IndexService
from .transformers_service import TransformersService
from .vector_db_context import VectorDbContext
from ..config import config, environment

transformers_service = TransformersService()
db_context = VectorDbContext()
ocr_service = None

if environment.local_indexing:
    match config.ocr_search.ocr_module:
        case "easyocr":
            from .ocr_services import EasyOCRService

            ocr_service = EasyOCRService()
        case "easypaddleocr":
            from .ocr_services import EasyPaddleOCRService

            ocr_service = EasyPaddleOCRService()
        case "paddleocr":
            from .ocr_services import PaddleOCRService

            ocr_service = PaddleOCRService()
        case _:
            raise NotImplementedError(f"OCR module {config.ocr_search.ocr_module} not implemented.")
else:
    from .ocr_services import DisabledOCRService

    ocr_service = DisabledOCRService()

index_service = IndexService(ocr_service, transformers_service, db_context)
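provider.py keeps the same eagerly constructed singletons that used to live in app/Services/__init__.py, so importing app.Services no longer pulls in OCR and model initialisation as a side effect; call sites only change their import path, as the controller and script diffs in this PR show. A short sketch of the new usage (the query string is just an example):

# Before this PR (removed from app/Services/__init__.py):
#     from app.Services import db_context, transformers_service
# After this PR:
from app.Services.provider import db_context, transformers_service, index_service

# The singletons behave exactly as before, e.g. embedding a query string:
query_vector = transformers_service.get_text_vector("a cat sitting on a desk")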
22 changes: 10 additions & 12 deletions app/Services/transformers_service.py
@@ -18,12 +18,12 @@ def __init__(self):
self.device = "cuda" if torch.cuda.is_available() else "cpu"
logger.info("Using device: {}; CLIP Model: {}, BERT Model: {}",
self.device, config.clip.model, config.ocr_search.bert_model)
self.clip_model = CLIPModel.from_pretrained(config.clip.model).to(self.device)
self.clip_processor = CLIPProcessor.from_pretrained(config.clip.model)
self._clip_model = CLIPModel.from_pretrained(config.clip.model).to(self.device)
self._clip_processor = CLIPProcessor.from_pretrained(config.clip.model)
logger.success("CLIP Model loaded successfully")
if config.ocr_search.enable:
self.bert_model = BertModel.from_pretrained(config.ocr_search.bert_model).to(self.device)
self.bert_tokenizer = BertTokenizer.from_pretrained(config.ocr_search.bert_model)
self._bert_model = BertModel.from_pretrained(config.ocr_search.bert_model).to(self.device)
self._bert_tokenizer = BertTokenizer.from_pretrained(config.ocr_search.bert_model)
logger.success("BERT Model loaded successfully")
else:
logger.info("OCR search is disabled. Skipping OCR and BERT model loading.")
Expand All @@ -34,32 +34,30 @@ def get_image_vector(self, image: Image.Image) -> ndarray:
image = image.convert("RGB")
logger.info("Processing image...")
start_time = time()
inputs = self.clip_processor(images=image, return_tensors="pt").to(self.device)
inputs = self._clip_processor(images=image, return_tensors="pt").to(self.device)
logger.success("Image processed, now inferencing with CLIP model...")
outputs: FloatTensor = self.clip_model.get_image_features(**inputs)
outputs: FloatTensor = self._clip_model.get_image_features(**inputs)
logger.success("Inference done. Time elapsed: {:.2f}s", time() - start_time)
logger.info("Norm: {}", outputs.norm(dim=-1).item())
outputs /= outputs.norm(dim=-1, keepdim=True)
return outputs.numpy(force=True).reshape(-1)

@no_grad()
def get_text_vector(self, text: str) -> ndarray:
logger.info("Processing text...")
start_time = time()
inputs = self.clip_processor(text=text, return_tensors="pt").to(self.device)
inputs = self._clip_processor(text=text, return_tensors="pt").to(self.device)
logger.success("Text processed, now inferencing with CLIP model...")
outputs: FloatTensor = self.clip_model.get_text_features(**inputs)
outputs: FloatTensor = self._clip_model.get_text_features(**inputs)
logger.success("Inference done. Time elapsed: {:.2f}s", time() - start_time)
logger.info("Norm: {}", outputs.norm(dim=-1).item())
outputs /= outputs.norm(dim=-1, keepdim=True)
return outputs.numpy(force=True).reshape(-1)

@no_grad()
def get_bert_vector(self, text: str) -> ndarray:
start_time = time()
logger.info("Inferencing with BERT model...")
inputs = self.bert_tokenizer(text.strip().lower(), return_tensors="pt").to(self.device)
outputs = self.bert_model(**inputs)
inputs = self._bert_tokenizer(text.strip().lower(), return_tensors="pt").to(self.device)
outputs = self._bert_model(**inputs)
vector = outputs.last_hidden_state.mean(dim=1).squeeze()
logger.success("BERT inference done. Time elapsed: {:.2f}s", time() - start_time)
return vector.cpu().numpy()
4 changes: 4 additions & 0 deletions requirements.dev.txt
@@ -0,0 +1,4 @@
# Requirements for development and testing

pytest
pylint
Empty file added scripts/__init__.py
2 changes: 1 addition & 1 deletion scripts/db_migrations.py
@@ -1,6 +1,6 @@
from loguru import logger

from app.Services import db_context, transformers_service
from app.Services.provider import db_context, transformers_service

CURRENT_VERSION = 2

8 changes: 3 additions & 5 deletions scripts/local_create_thumbnail.py
@@ -4,8 +4,9 @@
from PIL import Image
from loguru import logger

from app.Services import db_context
from app.Services.provider import db_context
from app.config import config
from .local_utility import gather_valid_files


async def main():
@@ -14,17 +15,14 @@ async def main():
    if not static_thumb_path.exists():
        static_thumb_path.mkdir()
    count = 0
    for item in static_path.glob('*.*'):
    for item in gather_valid_files(static_path, '*.*'):
        count += 1
        logger.info("[{}] Processing {}", str(count), str(item.relative_to(static_path)))
        size = item.stat().st_size
        if size < 1024 * 500:
            logger.warning("File size too small: {}. Skip...", size)
            continue
        try:
            if item.suffix not in ['.jpg', '.png', '.jpeg']:
                logger.warning("Unsupported file type: {}. Skip...", item.suffix)
                continue
            if (static_thumb_path / f'{item.stem}.webp').exists():
                logger.warning("Thumbnail for {} already exists. Skip...", item.stem)
                continue
64 changes: 13 additions & 51 deletions scripts/local_indexing.py
@@ -1,4 +1,3 @@
import argparse
from datetime import datetime
from pathlib import Path
from shutil import copy2
Expand All @@ -9,77 +8,40 @@
from loguru import logger

from app.Models.img_data import ImageData
from app.Services import transformers_service, db_context, ocr_service
from app.Services.provider import index_service
from app.config import config
from .local_utility import gather_valid_files


def parse_args():
    parser = argparse.ArgumentParser(description='Create Qdrant collection')
    parser.add_argument('--copy-from', dest="local_index_target_dir", type=str, required=True,
                        help="Copy from this directory")
    return parser.parse_args()


def copy_and_index(file_path: Path) -> ImageData | None:
async def copy_and_index(file_path: Path):
    try:
        img = Image.open(file_path)
    except PIL.UnidentifiedImageError as e:
        logger.error("Error when opening image {}: {}", file_path, e)
        return None
        return
    image_id = uuid4()
    img_ext = file_path.suffix
    image_ocr_result = None
    text_contain_vector = None
    [width, height] = img.size
    try:
        image_vector = transformers_service.get_image_vector(img)
        if config.ocr_search.enable:
            image_ocr_result = ocr_service.ocr_interface(img)  # This will modify img if you use preprocess!
            if image_ocr_result != "":
                text_contain_vector = transformers_service.get_bert_vector(image_ocr_result)
            else:
                image_ocr_result = None
    except Exception as e:
        logger.error("Error when processing image {}: {}", file_path, e)
        return None
    imgdata = ImageData(id=image_id,
                        url=f'/static/{image_id}{img_ext}',
                        image_vector=image_vector,
                        text_contain_vector=text_contain_vector,
                        index_date=datetime.now(),
                        width=width,
                        height=height,
                        aspect_ratio=float(width) / height,
                        ocr_text=image_ocr_result,
                        local=True)

    try:
        await index_service.index_image(img, imgdata)
    except Exception as e:
        logger.error("Error when processing image {}: {}", file_path, e)
        return
    # copy to static
    copy2(file_path, Path(config.static_file.path) / f'{image_id}{img_ext}')
    return imgdata


@logger.catch()
async def main(args):
    root = Path(args.local_index_target_dir)
    static_path = Path(config.static_file.path)
    if not static_path.exists():
        static_path.mkdir()
    buffer = []
    static_path.mkdir(exist_ok=True)
    counter = 0
    for item in root.glob('**/*.*'):
    for item in gather_valid_files(root):
        counter += 1
        logger.info("[{}] Indexing {}", str(counter), str(item.relative_to(root)))
        if item.suffix in ['.jpg', '.png', '.jpeg', '.jfif', '.webp', '.gif']:
            imgdata = copy_and_index(item)
            if imgdata is not None:
                buffer.append(imgdata)
                if len(buffer) >= 20:
                    logger.info("Upload {} element to database", len(buffer))
                    await db_context.insertItems(buffer)
                    buffer.clear()
        else:
            logger.warning("Unsupported file type: {}. Skip...", item.suffix)
    if len(buffer) > 0:
        logger.info("Upload {} element to database", len(buffer))
        await db_context.insertItems(buffer)
    logger.success("Indexing completed! {} images indexed", counter)
        await copy_and_index(item)
    logger.success("Indexing completed! {} images indexed", counter)
11 changes: 11 additions & 0 deletions scripts/local_utility.py
@@ -0,0 +1,11 @@
from pathlib import Path

from loguru import logger


def gather_valid_files(root: Path, pattern: str = '**/*.*'):
    for item in root.glob(pattern):
        if item.suffix in ['.jpg', '.png', '.jpeg', '.jfif', '.webp', '.gif']:
            yield item
        else:
            logger.warning("Unsupported file type: {}. Skip...", item.suffix)