From 9cd5b16b45e3e71d83eeb050fbdcb878be219136 Mon Sep 17 00:00:00 2001 From: EdgeNeko Date: Mon, 1 Jul 2024 18:28:57 +0800 Subject: [PATCH] Skip OCR model loading when OCR search is disabled by config. --- app/Services/provider.py | 2 +- app/Services/transformers_service.py | 2 +- tests/unit/test_transformers_service.py | 14 +++++++------- 3 files changed, 9 insertions(+), 9 deletions(-) diff --git a/app/Services/provider.py b/app/Services/provider.py index c85f36c..986ff40 100644 --- a/app/Services/provider.py +++ b/app/Services/provider.py @@ -14,7 +14,7 @@ def __init__(self): self.db_context = VectorDbContext() self.ocr_service = None - if environment.local_indexing or config.admin_api_enable: + if config.ocr_search.enable and (environment.local_indexing or config.admin_api_enable): match config.ocr_search.ocr_module: case "easyocr": from .ocr_services import EasyOCRService diff --git a/app/Services/transformers_service.py b/app/Services/transformers_service.py index 62c9a66..7c72ab1 100644 --- a/app/Services/transformers_service.py +++ b/app/Services/transformers_service.py @@ -26,7 +26,7 @@ def __init__(self): self._bert_tokenizer = BertTokenizer.from_pretrained(config.model.bert) logger.success("BERT Model loaded successfully") else: - logger.info("OCR search is disabled. Skipping OCR and BERT model loading.") + logger.info("OCR search is disabled. Skipping BERT model loading.") @no_grad() def get_image_vector(self, image: Image.Image) -> ndarray: diff --git a/tests/unit/test_transformers_service.py b/tests/unit/test_transformers_service.py index 3c755cd..fec9abe 100644 --- a/tests/unit/test_transformers_service.py +++ b/tests/unit/test_transformers_service.py @@ -1,20 +1,18 @@ -from pathlib import Path - from PIL import Image from app.Services.transformers_service import TransformersService from app.util.calculate_vectors_cosine import calculate_vectors_cosine +from ..assets import assets_path class TestTransformersService: def setup_class(self): self.transformers_service = TransformersService() - self.assets_root = Path(__file__).parent / '..' / 'assets' def test_get_image_vector(self): - vector1 = self.transformers_service.get_image_vector(Image.open(self.assets_root / 'test_images/cat_0.jpg')) - vector2 = self.transformers_service.get_image_vector(Image.open(self.assets_root / 'test_images/cat_1.jpg')) + vector1 = self.transformers_service.get_image_vector(Image.open(assets_path / 'test_images/cat_0.jpg')) + vector2 = self.transformers_service.get_image_vector(Image.open(assets_path / 'test_images/cat_1.jpg')) assert vector1.shape == (768,) assert vector2.shape == (768,) assert calculate_vectors_cosine(vector1, vector2) > 0.8 @@ -34,5 +32,7 @@ def test_get_bert_vector(self): assert calculate_vectors_cosine(vector1, vector2) > 0.8 def test_get_bert_vector_long_text(self): - vector = self.transformers_service.get_bert_vector('我可以吞下玻璃而不伤身体' * 100) - assert vector.shape == (768,) + vector1 = self.transformers_service.get_bert_vector('The quick brown fox jumps over the lazy dog ' * 100) + vector2 = self.transformers_service.get_bert_vector('我可以吞下玻璃而不伤身体' * 100) + assert vector1.shape == (768,) + assert vector2.shape == (768,)