Skip to content

Commit

Permalink
Skip OCR model loading when OCR search is disabled by config.
Browse files (browse the repository at this point in the history)
  • Loading branch information
hv0905 committed Jul 1, 2024
1 parent a4fdb86 commit 9cd5b16
Show file tree
Hide file tree
Showing 3 changed files with 9 additions and 9 deletions.
2 changes: 1 addition & 1 deletion app/Services/provider.py
Original file line number | Diff line number | Diff line change
Expand Up @@ -14,7 +14,7 @@ def __init__(self):
self.db_context = VectorDbContext()
self.ocr_service = None

if environment.local_indexing or config.admin_api_enable:
if config.ocr_search.enable and (environment.local_indexing or config.admin_api_enable):
match config.ocr_search.ocr_module:
case "easyocr":
from .ocr_services import EasyOCRService
Expand Down
2 changes: 1 addition & 1 deletion app/Services/transformers_service.py
Original file line number | Diff line number | Diff line change
Expand Up @@ -26,7 +26,7 @@ def __init__(self):
self._bert_tokenizer = BertTokenizer.from_pretrained(config.model.bert)
logger.success("BERT Model loaded successfully")
else:
logger.info("OCR search is disabled. Skipping OCR and BERT model loading.")
logger.info("OCR search is disabled. Skipping BERT model loading.")

@no_grad()
def get_image_vector(self, image: Image.Image) -> ndarray:
Expand Down
14 changes: 7 additions & 7 deletions tests/unit/test_transformers_service.py
Original file line number | Diff line number | Diff line change
@@ -1,20 +1,18 @@
from pathlib import Path

from PIL import Image

from app.Services.transformers_service import TransformersService
from app.util.calculate_vectors_cosine import calculate_vectors_cosine
from ..assets import assets_path


class TestTransformersService:

def setup_class(self):
self.transformers_service = TransformersService()
self.assets_root = Path(__file__).parent / '..' / 'assets'

def test_get_image_vector(self):
vector1 = self.transformers_service.get_image_vector(Image.open(self.assets_root / 'test_images/cat_0.jpg'))
vector2 = self.transformers_service.get_image_vector(Image.open(self.assets_root / 'test_images/cat_1.jpg'))
vector1 = self.transformers_service.get_image_vector(Image.open(assets_path / 'test_images/cat_0.jpg'))
vector2 = self.transformers_service.get_image_vector(Image.open(assets_path / 'test_images/cat_1.jpg'))
assert vector1.shape == (768,)
assert vector2.shape == (768,)
assert calculate_vectors_cosine(vector1, vector2) > 0.8
Expand All @@ -34,5 +32,7 @@ def test_get_bert_vector(self):
assert calculate_vectors_cosine(vector1, vector2) > 0.8

def test_get_bert_vector_long_text(self):
vector = self.transformers_service.get_bert_vector('我可以吞下玻璃而不伤身体' * 100)
assert vector.shape == (768,)
vector1 = self.transformers_service.get_bert_vector('The quick brown fox jumps over the lazy dog ' * 100)
vector2 = self.transformers_service.get_bert_vector('我可以吞下玻璃而不伤身体' * 100)
assert vector1.shape == (768,)
assert vector2.shape == (768,)

0 comments on commit 9cd5b16

Please sign in to comment.